## Helper functions

In [83]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [84]:
import numpy as np

from others.implementations import *
import torch
from target_model import *
from torch import optim


In [85]:
def generate_model(shallow_hidden_channels, deep_hidden_channels):
    return  nn.Sequential(
            nn.Conv2d(3, shallow_hidden_channels, kernel_size=2, stride=2, padding=0, bias=True),
            nn.ReLU(),
            nn.Conv2d(shallow_hidden_channels, deep_hidden_channels, kernel_size=2, stride=2, padding=0, bias=True),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(deep_hidden_channels, shallow_hidden_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(shallow_hidden_channels, 3, kernel_size=3, stride=1, padding=1, bias=True),
            nn.Sigmoid()
            )

def test_and_train(batch_size, hidden_channels, lr, momentum, nesterov, epochs):
    network = generate_model(shallow_hidden_channels=hidden_channels[0], deep_hidden_channels=hidden_channels[1])
    optimizer = optim.SGD(network.parameters(), lr=lr, momentum=momentum, nesterov=nesterov)

    # Instantiate model and replace network and optimizer
    m = Model()
    m.model = network
    m.optimizer = optimizer
    m.batch_size = batch_size

    # Train
    m.train(s1, s2, epochs)
    return compute_psnr(m.predict(t1)/255.0, t2), m

def sample(tensor1, tensor2, k):
    perm = torch.randperm(tensor1.size(0))
    idx = perm[:k]
    return tensor1[idx], tensor2[idx]

## Loading data

In [86]:
path_train = '../data/train_data.pkl'
path_val = '../data/val_data.pkl'
noisy_imgs_1, noisy_imgs_2 = torch.load(path_train)
noisy_imgs_1, noisy_imgs_2 = noisy_imgs_1.float(), noisy_imgs_2.float() / 255.0
test, truth = torch.load(path_val)
test, truth = test.float(), truth.float() / 255.0

## Parameters to tune

In [87]:
# Channels parameter
shallow_hidden_channels = [32]
deep_hidden_channels = [128]
hidden_channels = [(shallow_channel, deep_channel) for shallow_channel in shallow_hidden_channels for deep_channel in deep_hidden_channels if shallow_channel <= deep_channel]

# Optimizer parameters
# lrs = np.logspace(-7, -1, 5)
lrs = [1e-6]
momentums = [0.9]
nesterovs = [True]

# Batch sizes
batch_sizes = [10]

## Training

In [88]:
s1, s2 = sample(noisy_imgs_1, noisy_imgs_2, 5000)
t1, t2 = sample(test, truth, 5000)

In [89]:
epochs = 20

results = dict()
for lr in lrs:
    for momentum in momentums:
        for nesterov in nesterovs:
            for hidden_channel in hidden_channels:
                for batch_size in batch_sizes:
                    description = f'lr{lr}_HiddenChannels{hidden_channel}_Batch{batch_size}_Epochs{epochs}_Momentum{momentum}_Nesterov{nesterov}_Sample1000'
                    print("\nSTARTING TRAINING FOR:", description)
                    error, m = test_and_train(batch_size, hidden_channel, lr, momentum, nesterov, epochs)
                    description = f'lr{lr}_HiddenChannels{hidden_channel}_Batch{batch_size}_Epochs{epochs}_Momentum{momentum}_Nesterov{nesterov}_Sample1000_{error}'
                    results[description] = m
                    print(f'PSNR: {error}')


STARTING TRAINING FOR: lr1e-06_HiddenChannels(32, 128)_Batch10_Epochs20_Momentum0.9_NesterovTrue_Sample1000
EPOCH 0 --- LOSS 0.3729703426361084
EPOCH 1 --- LOSS 0.031113997101783752
EPOCH 2 --- LOSS 0.020727170631289482
EPOCH 3 --- LOSS 0.015968862920999527
EPOCH 4 --- LOSS 0.01271134428679943
EPOCH 5 --- LOSS 0.010362786240875721
EPOCH 6 --- LOSS 0.008750085718929768
EPOCH 7 --- LOSS 0.007626623380929232
EPOCH 8 --- LOSS 0.006809890735894442
EPOCH 9 --- LOSS 0.006190726533532143
EPOCH 10 --- LOSS 0.005704669281840324
EPOCH 11 --- LOSS 0.005312198773026466
EPOCH 12 --- LOSS 0.004987902473658323
EPOCH 13 --- LOSS 0.0047148000448942184
EPOCH 14 --- LOSS 0.004481127019971609
EPOCH 15 --- LOSS 0.004278520587831736
EPOCH 16 --- LOSS 0.0041008200496435165
EPOCH 17 --- LOSS 0.003943456336855888
EPOCH 18 --- LOSS 0.0038029036950320005
EPOCH 19 --- LOSS 0.003676437307149172
FINAL LOSS 0.0035619009286165237
PSNR: 6.448063850402832


## Final results