In [None]:
import datetime
import os

import matplotlib.pyplot as plt
import numpy as np
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from sklearn.model_selection import train_test_split
from torch.optim import lr_scheduler
from torchsummary import summary
from tqdm import notebook, tqdm

from model_kpn import KPN, LossBasic
#from model_baseline import Unet
#from model_gdfn import GDFN

from data import ims, ims_noise

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Convert the raw image arrays from (N, H, W, C) layout to float32 (N, C, H, W).
# The raw arrays are re-read from `data` instead of transforming `ims` in place,
# so re-running this cell does not transpose the already-transposed arrays again.
from data import ims as ims_raw, ims_noise as ims_noise_raw

N_ims, h, w, color = ims_raw.shape
ims = ims_raw[:N_ims].astype(np.float32).transpose(0, 3, 1, 2)
ims_noise = ims_noise_raw[:N_ims].astype(np.float32).transpose(0, 3, 1, 2)

In [None]:
# Training hyperparameters.
batch_size = 8    # samples per mini-batch (used by both train and test loaders)
lr = 3e-4         # initial Adam learning rate (decayed once per epoch by the scheduler)
epochs = 50       # number of passes over the training set
test_size = 0.1   # fraction of samples held out by train_test_split

# Seed the global RNGs so weight initialisation and DataLoader shuffling are
# reproducible across runs (the split itself already uses random_state=42).
# np.random.seed is last because it returns None, so the cell displays nothing.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Split noisy inputs (X) and clean targets (Y); random_state keeps the split fixed.
train_X, test_X, train_Y, test_Y = train_test_split(
    ims_noise, ims, test_size=test_size, random_state=42
)

# Insert a singleton axis at position 1 on the inputs only.
# NOTE(review): presumably the burst dimension KPN expects (burst_length=1) — confirm.
train_X = np.expand_dims(train_X, axis=1)
test_X = np.expand_dims(test_X, axis=1)

for split_name, split_arr in [('Training X: ', train_X), ('Training Y: ', train_Y),
                              ('Testing X: ', test_X), ('Testing Y: ', test_Y)]:
    print(split_name, split_arr.shape, split_arr.dtype, split_arr.max(), split_arr.min())

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each input sample with its target.

    Args:
        X: indexable collection of inputs.
        Y: indexable collection of targets, aligned with X by index.
    """

    def __init__(self, X, Y):
        self.X, self.Y = X, Y

    def __getitem__(self, index):
        # Required by Dataset: return the (input, target) pair stored at `index`.
        sample, label = self.X[index], self.Y[index]
        return sample, label

    def __len__(self):
        # Number of samples (not the number of batches a DataLoader will yield).
        return len(self.X)
        
# Wrap the split arrays so DataLoader can index them sample by sample.
train_set, test_set = MyDataset(train_X, train_Y), MyDataset(test_X, test_Y)

def collate(batch):
    """Stack a list of (input, target) array pairs into two float32 tensors.

    Args:
        batch: list of (input, target) pairs of equally-shaped NumPy arrays.

    Returns:
        (inputs, target): float32 tensors with a new leading batch axis.
    """
    # Stacking into one ndarray first avoids torch's very slow (and warned-about)
    # conversion from a Python list of ndarrays.
    inputs = torch.from_numpy(np.stack([item[0] for item in batch])).float()
    target = torch.from_numpy(np.stack([item[1] for item in batch])).float()
    return inputs, target

# Shuffle the training data each epoch; keep test order fixed so evaluation is repeatable.
train_loader = torch.utils.data.DataLoader(
    train_set, collate_fn=collate, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_set, collate_fn=collate, batch_size=batch_size, shuffle=False
)

In [None]:
# Choose the model: 'kpn', 'unet' or 'gdfn'.
# `mode` selects the forward-call convention used by the later cells:
#   1 -> model(inputs, inputs), which returns a pair whose first item is the prediction
#   2 -> model(inputs), which returns the prediction directly
choice = 'gdfn'

if choice == 'kpn':
    model = KPN(color=False, burst_length=1, blind_est=True, kernel_size=[3], sep_conv=False,
                channel_att=False, spatial_att=True, core_bias=False).to(device)
    mode = 1
elif choice == 'unet':
    # Imported here so only the selected model's module has to be importable.
    from model_baseline import Unet
    model = Unet(color=False, blind_est=True, channel_att=False, spatial_att=False, core_bias=False).to(device)
    mode = 2
elif choice == 'gdfn':
    from model_gdfn import GDFN
    model = GDFN(filter_size=(3, 3), color=False, blind_est=True, channel_att=False,
                 spatial_att=False, core_bias=False).to(device)
    mode = 2
else:
    raise ValueError(f"unknown model choice {choice!r}; expected 'kpn', 'unet' or 'gdfn'")

print('# model parameters:', sum(param.numel() for param in model.parameters()))

In [None]:
# Optionally initialise the model from previously saved weights.
if_load = False
if if_load:
    # map_location remaps tensors saved on another device (e.g. a GPU) onto
    # `device`, so the checkpoint also loads on a CPU-only machine.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    model.load_state_dict(torch.load(r'./model_weights/kpn_satt.pkl', map_location=device))

In [None]:
# Optimizer, learning-rate schedule and loss function.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Multiply the learning rate by 0.95 on every scheduler.step() (called once per epoch).
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
# Alternative: lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # step_size counted in epochs

# Mean absolute error between prediction and target.
# Alternative: LossBasic(gradient_L1=True) from model_kpn.
loss_func = nn.L1Loss()

In [None]:
# Metric history filled by the training cell and read by the plotting cell.
train_losses, test_losses, lrs = [], [], []  # per-step train loss, per-epoch test loss, per-epoch LR
train_steps = 0                               # optimizer steps taken so far

In [None]:
# Train for `epochs` epochs, evaluating on the held-out set after each one.
for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    total_train_loss = total = 0
    # desc is the initial label; it is replaced by the running loss below.
    progress_bar = notebook.tqdm(train_loader, desc='Training', leave=False)
    for inputs, target in progress_bar:
        total += 1
        inputs, target = inputs.to(device), target.to(device)

        optimizer.zero_grad()
        if mode == 1:
            outputs, _ = model(inputs, inputs)
        elif mode == 2:
            outputs = model(inputs)
        else:
            raise ValueError(f'unknown mode {mode!r}; expected 1 or 2')

        loss = loss_func(outputs, target)

        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        total_train_loss += step_loss
        train_losses.append(step_loss)
        train_steps += 1
        progress_bar.set_description(f'Loss: {step_loss:.5f}')

    total_train_loss /= total

    # ---- evaluation pass ----
    # no_grad: evaluation needs no autograd graph, which saves memory and time.
    model.eval()
    total_test_loss = total = 0
    with torch.no_grad():
        for inputs, target in test_loader:
            total += 1
            inputs, target = inputs.to(device), target.to(device)

            if mode == 1:
                outputs, _ = model(inputs, inputs)
            elif mode == 2:
                outputs = model(inputs)
            else:
                raise ValueError(f'unknown mode {mode!r}; expected 1 or 2')

            total_test_loss += loss_func(outputs, target).item()
    total_test_loss /= total
    test_losses.append(total_test_loss)

    tqdm.write(f'Epoch #{epoch + 1:3d}\tTrain Loss: {total_train_loss:.5f}\tTest Loss: {total_test_loss:.5f}')

    # Record the LR used during this epoch before the scheduler decays it.
    lrs.append(optimizer.param_groups[0]['lr'])
    scheduler.step()

In [None]:
# Curves are plotted with matplotlib because TensorBoard summaries were found
# to slow training down considerably.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = './logs/'
os.makedirs(log_dir, exist_ok=True)  # savefig fails if the directory is missing

# x-axes are derived from the recorded histories rather than `epochs`, so the
# plots stay valid even if the training cell was run more than once.
fig, (ax_train, ax_test, ax_lr) = plt.subplots(1, 3, figsize=(15, 5))

ax_train.plot(range(len(train_losses)), np.log(train_losses))
ax_train.set(xlabel='steps', ylabel='value in logarithm', title='training loss')

ax_test.plot(range(len(test_losses)), np.log(test_losses))
ax_test.set(xlabel='epoch', ylabel='value in logarithm', title='test loss')

ax_lr.plot(range(len(lrs)), lrs)
ax_lr.set(xlabel='epoch', ylabel='value', title='learning rate')

# NOTE(review): the 'kpn_' prefix is used regardless of the selected `choice`.
fig.savefig(log_dir + 'kpn_' + current_time + '.png')
plt.show()

In [None]:
# Visualise one test batch; each row shows noisy input | ground truth | prediction.
# batch_* names are used so the global test_X / test_Y split arrays are not overwritten.
model.eval()
batch_X, batch_Y = next(iter(test_loader))
batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)

with torch.no_grad():  # inference only; no autograd graph needed
    if mode == 1:
        batch_pred, _ = model(batch_X, batch_X)
    elif mode == 2:
        batch_pred = model(batch_X)
    else:
        raise ValueError(f'unknown mode {mode!r}; expected 1 or 2')

# Size the grid by the actual batch: the first batch is smaller than batch_size
# when the test set has fewer samples than batch_size.
n_rows = len(batch_X)
plt.figure(figsize=(15, 5 * n_rows))
for row, images in enumerate(zip(batch_X, batch_Y, batch_pred)):
    for col, image in enumerate(images):
        plt.subplot(n_rows, 3, 3 * row + col + 1)
        plt.imshow(image.detach().cpu().numpy().squeeze(), cmap='gray')
        plt.axis('off')

results_dir = './results/images/'
os.makedirs(results_dir, exist_ok=True)  # savefig fails if the directory is missing
plt.savefig(results_dir + 'kpn_' + current_time + '.png')
plt.show()

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

def error(x1, x2, mode='mse'):
    """Return the mean elementwise error between two equally-shaped arrays.

    Args:
        x1, x2: arrays to compare.
        mode: 'mse' for mean squared error, 'mae' for mean absolute error.

    Returns:
        The scalar mean error.

    Raises:
        ValueError: if `mode` is not 'mse' or 'mae' (previously returned None,
            which later crashed string formatting with an unrelated TypeError).
    """
    diff = x1 - x2
    if mode == 'mse':
        return np.mean(np.square(diff))
    if mode == 'mae':
        return np.mean(np.abs(diff))
    raise ValueError(f"unknown error mode {mode!r}; expected 'mse' or 'mae'")

In [None]:
# Run the model over the full test set, collecting inputs, targets and
# predictions as NumPy arrays (in test_loader order).
model.eval()
test_X, test_Y, pred_Y = [], [], []
with torch.no_grad():  # inference only; avoids building autograd graphs
    for inputs, target in test_loader:
        test_X.append(inputs.numpy())
        test_Y.append(target.numpy())

        inputs = inputs.to(device)
        if mode == 1:
            outputs, _ = model(inputs, inputs)
        elif mode == 2:
            outputs = model(inputs)
        else:
            raise ValueError(f'unknown mode {mode!r}; expected 1 or 2')

        pred_Y.append(outputs.cpu().numpy())

test_X = np.concatenate(test_X, axis=0)
test_Y = np.concatenate(test_Y, axis=0)
pred_Y = np.concatenate(pred_Y, axis=0)

# Squeeze singleton axes once so every metric (psnr included) compares
# arrays of the same layout.
# NOTE(review): data_range=1 assumes pixel values lie in [0, 1] — confirm.
noisy, clean, recovered = test_X.squeeze(), test_Y.squeeze(), pred_Y.squeeze()

print('Evaluation of ground truth and noised images:')
print('psnr:{:.3f}\tssim:{:.3f}\tmse:{:.3f}'.format(psnr(noisy, clean, data_range=1),
                                        ssim(noisy, clean, data_range=1),
                                        error(noisy, clean)))

# The comparison is against the ground truth, not the noisy inputs.
print('\nEvaluation of recovered images and ground truth:')
print('psnr:{:.3f}\tssim:{:.3f}\tmse:{:.3f}'.format(psnr(recovered, clean, data_range=1),
                                        ssim(recovered, clean, data_range=1),
                                        error(recovered, clean)))

for label, arr in (('Ground Truth', test_Y), ('Noised images', test_X), ('Recovered images', pred_Y)):
    print('\n{}:'.format(label))
    print('max:{:.3f}\tmin:{:.3f}\tmean:{:.3f}'.format(arr.max(), arr.min(), arr.mean()))

In [None]:
import os

# Persist only the learned parameters (state_dict), creating the output
# directory on first use.
root = r'./model_weights'
os.makedirs(root, exist_ok=True)
torch.save(model.state_dict(), root + '/kpn.pkl')