In [1]:
import sys
# Make modules from the parent directory importable (presumably where the
# project's `model` and `utils` modules imported below live).
sys.path.append('../')

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import torch
import torch.nn as nn
from torchsummary import summary
import model
from utils import evaluate_psnr

In [3]:
# Pick up edits to the project `model` module without restarting the kernel.
from importlib import reload
model = reload(model)

In [4]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
gpu_available = torch.cuda.is_available()
device = torch.device('cuda' if gpu_available else 'cpu')
print('Using GPU. ✅' if gpu_available else 'Using CPU ❌ 😭')

Using CPU ❌ 😭


## Load data

In [5]:
# Load train data: two noisy versions of each training image (input, target).
TRAIN_PATH = '../../data/train_data.pkl'
train_noisy_imgs_input, train_noisy_imgs_target = torch.load(TRAIN_PATH)
print('training size: ', len(train_noisy_imgs_input))
(training_set_size,
 num_channels,
 y_size,
 x_size) = train_noisy_imgs_input.shape

training size:  50000


In [6]:
# Load validation data: noisy images paired with their clean counterparts.
VALIDATION_PATH = '../../data/val_data.pkl'
validation_noisy_imgs, validation_clean_imgs = torch.load(VALIDATION_PATH)
print('validation size: ', len(validation_clean_imgs))

validation size:  1000


## Torch implementation

In [7]:
class TorchModel():
    """Reference denoiser built only from stock ``torch.nn`` layers.

    Exposes what the rest of this notebook uses on a network: the ``model``,
    ``criterion`` and ``optimizer`` attributes, ``losses`` (filled by
    ``train``), and the ``train`` / ``predict`` / ``load_pretrained_model``
    methods. This lets it be swapped in for ``model.Model`` when comparing
    implementations.
    """

    def __init__(self, **kwargs):
        """Build the network, the MSE criterion and the optimizer.

        Keyword Args:
            optimizer: ``'Adam'`` (default, lr=1e-2) or ``'SGD'`` (lr=1).

        Raises:
            ValueError: if ``optimizer`` names an unsupported optimizer.
        """
        # Two stride-2 convolutions halve H and W twice; two 2x upsamplings
        # restore them. For H, W divisible by 4 the output matches the input shape.
        self.model = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, 16, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(16, 3, 3, padding=1),
            nn.Sigmoid(),  # outputs in (0, 1), the range of inputs scaled by 1/255
        )
        self.criterion = nn.MSELoss()

        optimizer = kwargs.get('optimizer', 'Adam')
        if optimizer == 'Adam':
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-2)
        elif optimizer == 'SGD':
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1)
        else:
            raise ValueError('Optimizer not implemented')

    def _device(self):
        """Return the device holding the network's parameters."""
        return next(self.model.parameters()).device

    def load_pretrained_model(self, path='bestmodel.pth'):
        """Load saved parameters from ``path`` (default ``bestmodel.pth``).

        Accepts any of three formats:
        - a pickled ``nn.Module`` (what the original implementation expected);
        - a bare ``state_dict``;
        - a checkpoint dict whose ``'model'`` entry is a ``state_dict``, as
          written by this notebook's training cell.
        """
        loaded = torch.load(path, map_location=self._device())
        if isinstance(loaded, nn.Module):
            # NOTE: replacing the module leaves self.optimizer bound to the old
            # parameters; prefer saving/loading state_dicts.
            self.model = loaded
            return
        state_dict = loaded['model'] if 'model' in loaded else loaded
        # In-place load keeps self.optimizer pointing at the live parameters.
        self.model.load_state_dict(state_dict)

    def train(self, train_input, train_target, num_epochs, **kwargs):
        """Fit the network on pairs of differently-noised copies of the same images.

        Args:
            train_input: tensor (N, C, H, W), a noisy version of the images.
                Values are assumed to be in [0, 255]; they are divided by 255.
            train_target: tensor (N, C, H, W), another noisy version of the same
                images that differs from the input only by its noise.
            num_epochs: number of passes over the data (no shuffling).

        Keyword Args:
            batch_size: mini-batch size (default 32).
            log_every: number of batches averaged into each entry of
                ``self.losses`` (default 5).
            debug: if true, print every 10th recorded average loss.
        """
        batch_size = kwargs.get('batch_size', 32)
        log_every = kwargs.get('log_every', 5)
        debug = kwargs.get('debug', False)
        dev = self._device()
        self.losses = []
        running_loss, running_batches = 0.0, 0

        # set_model('Ours_*') disables autograd globally; training needs it.
        with torch.enable_grad():
            for e in range(num_epochs):
                print('Doing epoch %d' % e)
                batches = zip(train_input.split(batch_size), train_target.split(batch_size))
                for noisy_in, noisy_tgt in batches:
                    # Scale per batch (not the whole dataset) to avoid full float copies.
                    noisy_in = noisy_in.to(dev) / 255.0
                    noisy_tgt = noisy_tgt.to(dev) / 255.0

                    loss = self.criterion(self.model(noisy_in), noisy_tgt)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    # Average exactly `log_every` batches per recorded point,
                    # independent of batch index and epoch boundaries.
                    running_loss += loss.item()
                    running_batches += 1
                    if running_batches == log_every:
                        self.losses.append(running_loss / log_every)
                        running_loss, running_batches = 0.0, 0
                        if debug and len(self.losses) % 10 == 0:
                            print(self.losses[-1])

    def predict(self, test_input):
        """Denoise ``test_input``.

        Args:
            test_input: tensor (N1, C, H, W) to be denoised by the trained or
                loaded network; values assumed in [0, 255].

        Returns:
            Tensor (N1, C, H, W) with values in [0, 255], on the same device
            as ``test_input``, detached from autograd.
        """
        with torch.no_grad():
            output = self.model(test_input.to(self._device()) / 255.0) * 255.0
        return output.to(test_input.device)

In [8]:
# Per-layer output shapes and parameter counts for one 3x32x32 image.
# torchsummary defaults to device='cuda'; the freshly built TorchModel lives on
# the CPU, so pin the dummy input to the CPU as well.
summary(TorchModel().model, (3, 32, 32), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 16, 16]             448
              ReLU-2           [-1, 16, 16, 16]               0
            Conv2d-3             [-1, 32, 8, 8]           4,640
              ReLU-4             [-1, 32, 8, 8]               0
          Upsample-5           [-1, 32, 16, 16]               0
            Conv2d-6           [-1, 16, 16, 16]           4,624
              ReLU-7           [-1, 16, 16, 16]               0
          Upsample-8           [-1, 16, 32, 32]               0
            Conv2d-9            [-1, 3, 32, 32]             435
          Sigmoid-10            [-1, 3, 32, 32]               0
Total params: 10,147
Trainable params: 10,147
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.39
Params size (MB): 0.04
Estimated Tot

## Training

In [9]:
# The training cell saves checkpoints under models/ (not modules/), so create
# that directory; exist_ok makes re-running this cell harmless.
os.makedirs('models', exist_ok=True)

In [13]:
def set_model(model_name):
  """Build the network named `model_name`, move it to `device`, and return it.

  'Ours_*' networks come from the project `model` module and are built with
  autograd switched off globally. 'Torch_*' networks use TorchModel and
  switch autograd back on. The suffix selects the optimizer.

  Raises:
    ValueError: if `model_name` is not one of the four supported names.
  """
  registry = {
    'Ours_Adam':  (False, lambda: model.Model(optimizer='Adam')),
    'Ours_SGD':   (False, lambda: model.Model(optimizer='SGD')),
    'Torch_Adam': (True,  lambda: TorchModel(optimizer='Adam')),
    'Torch_SGD':  (True,  lambda: TorchModel(optimizer='SGD')),
  }
  if model_name not in registry:
    raise ValueError('Invalid model name!')

  grad_on, build = registry[model_name]
  torch.set_grad_enabled(grad_on)
  net = build()

  net.model.to(device)
  net.criterion.to(device)
  return net

In [14]:
# Network to train: one of 'Ours_Adam', 'Ours_SGD', 'Torch_Adam', 'Torch_SGD'.
model_name = 'Torch_Adam'

In [None]:
nb_epochs = 10

# time.perf_counter() returns seconds.
train_start_s = time.perf_counter()
net = set_model(model_name)
net.train(train_noisy_imgs_input, train_noisy_imgs_target, nb_epochs)
training_time = time.perf_counter() - train_start_s
print('Training took: ', training_time, 's')

# Save model. Create the target directory here so saving does not depend on
# another cell having created it.
os.makedirs('models', exist_ok=True)
checkpoint = {'model': net.model.state_dict(),
              'optimizer': net.optimizer.state_dict(),
              'losses': net.losses,
              # Read back by load_model. The train call above passes no
              # batch_size, so TorchModel uses its default of 32.
              'batch_size': getattr(net, 'batch_size', 32),
              }
torch.save(checkpoint, 'models/' + model_name + '_' + '.pth')

## Evaluation

In [26]:
FIGURES_DIR = 'figures'
# Create the output directory for figures only if nothing exists at that path yet.
if not os.path.exists(FIGURES_DIR):
    os.mkdir(FIGURES_DIR)

In [30]:
def load_model(model_name, model_path):
    """Rebuild the network named `model_name` and restore it from `model_path`.

    The checkpoint is a dict with:
    - 'model': the network's state_dict;
    - 'losses': the logged training losses;
    - 'batch_size' (optional; older checkpoints may lack it): when absent,
      falls back to the network's own `batch_size` attribute if it has one,
      else 32 (TorchModel.train's default).

    Raises:
        ValueError: if `model_name` is not one of the four supported names.
        KeyError: if the checkpoint lacks 'model' or 'losses'.
    """
    if model_name == 'Ours_Adam':
        net = model.Model(optimizer='Adam')
    elif model_name == 'Ours_SGD':
        net = model.Model(optimizer='SGD')
    elif model_name == 'Torch_Adam':
        net = TorchModel(optimizer='Adam')
    elif model_name == 'Torch_SGD':
        net = TorchModel(optimizer='SGD')
    else:
        raise ValueError('Invalid model name!')

    checkpoint = torch.load(model_path, map_location=device)
    net.model.load_state_dict(checkpoint['model'])
    net.batch_size = checkpoint.get('batch_size', getattr(net, 'batch_size', 32))
    net.losses = checkpoint['losses']
    return net

In [None]:
# Checkpoint to evaluate: one of 'Ours_Adam', 'Ours_SGD', 'Torch_Adam', 'Torch_SGD'.
# NOTE: this overrides the `model_name` chosen for training above.
model_name = 'Ours_Adam'
checkpoint_path = 'models/' + model_name + '_' + '.pth'
net = load_model(model_name, model_path=checkpoint_path)

In [None]:
# Score the loaded network on the validation set. `evaluate_psnr` comes from
# utils and presumably reports the PSNR of the denoised noisy images against
# the clean ones — confirm in utils.py.
evaluate_psnr(net, validation_noisy_imgs, validation_clean_imgs)

In [None]:
# Convert loss-point index to epochs. Each recorded loss averages 5 batches
# (b_freq in TorchModel.train), and each batch holds `batch_size` images.
LOSS_POINT_BATCHES = 5
# load_model restores net.batch_size; 32 is TorchModel.train's default.
batch_size = getattr(net, 'batch_size', 32)
epochs_per_point = LOSS_POINT_BATCHES * batch_size / train_noisy_imgs_input.shape[0]
xs = np.arange(len(net.losses)) * epochs_per_point

fig, ax = plt.subplots()
ax.plot(xs, net.losses)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training loss: ' + model_name)
plt.show()

In [None]:
test_input = validation_noisy_imgs
test_target = validation_clean_imgs


def to_displayable(img):
    """Turn one (C, H, W) image with values assumed in [0, 255] into an
    (H, W, C) float array in [0, 1].

    Float RGB data is clipped by imshow to [0, 1], so the [0, 255] float
    output of `predict` must be rescaled to render correctly.
    """
    return img.detach().cpu().float().div(255.0).clamp(0.0, 1.0).permute(1, 2, 0).numpy()


nr = 3  # index of the validation image to inspect
with torch.no_grad():
    denoised = net.predict(test_input[nr:nr + 1])

fig, ax = plt.subplots(1, 3, figsize=(13, 5))
panels = ((test_input[nr], 'Input'), (denoised[0], 'Denoised'), (test_target[nr], 'Target'))
for axis, (img, title) in zip(ax, panels):
    axis.imshow(to_displayable(img))  # RGB data: no colormap applies
    axis.set_title(title)
    axis.axis('off')
plt.show()