Test models

In [None]:
import torch
from model import Model
from others.psnr import compute_psnr
from others.unet import UNet
from others.autoencoder import AutoEncoder
from others.rednet import REDNet
from torch.optim import Adam, SGD
from custom_model import MyModel

In [None]:
path_train = '../data/train_data.pkl'
path_val = '../data/val_data.pkl'
# Choose the target device once and load both datasets directly onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.load unpickles the file contents; only load trusted data files.
# Training set: two paired image tensors, cast to float for the networks.
noisy_imgs_1, noisy_imgs_2 = torch.load(path_train, map_location=device)
noisy_imgs_1, noisy_imgs_2 = noisy_imgs_1.float(), noisy_imgs_2.float()
# Validation set: model inputs (`test`) and the targets PSNR is measured against
# (`truth`, presumably clean ground truth — confirm against the data pipeline).
test, truth = torch.load(path_val, map_location=device)
test, truth = test.float(), truth.float()

In [None]:
def sample(tensor1, tensor2, k, generator=None):
    """Draw the same random subset of ``k`` rows from two paired tensors.

    Rows are selected along dimension 0 without replacement, using one shared
    permutation, so row ``i`` of both outputs comes from the same original
    index and pairs stay aligned.

    Args:
        tensor1: First tensor; rows are taken along dimension 0.
        tensor2: Second tensor, paired row-for-row with ``tensor1``.
        k: Number of rows to draw. If ``k`` exceeds the number of rows, every
            row is returned (in shuffled order).
        generator: Optional ``torch.Generator`` for reproducible draws; the
            default (``None``) uses the global RNG.

    Returns:
        ``(tensor1[idx], tensor2[idx])`` for a shared random index ``idx``.

    Raises:
        ValueError: If the tensors differ in size along dimension 0, or if
            ``k`` is negative (a negative slice would silently drop rows
            from the end instead of sampling).
    """
    n_rows = tensor1.size(0)
    if tensor2.size(0) != n_rows:
        raise ValueError(
            f"paired tensors must have the same length along dim 0, "
            f"got {n_rows} and {tensor2.size(0)}"
        )
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    idx = torch.randperm(n_rows, generator=generator)[:k]
    return tensor1[idx], tensor2[idx]

## Training and testing

In [None]:
models = ["UNet", "REDNet", "AutoEncoder"]
batch_sizes = [5, 10, 20]
epochs = 10
sample_size = 1000  # rows drawn from each dataset for the quick comparison
seed = 0
# Seed the global RNG so re-running the notebook draws the same subsets
# (and presumably makes default weight init repeatable — confirm for each network).
torch.manual_seed(seed)
# One shared training subset and one shared validation subset: every
# model/batch-size run in test_and_train reads these same globals.
s1, s2 = sample(noisy_imgs_1, noisy_imgs_2, sample_size)
t1, t2 = sample(test, truth, sample_size)


In [None]:
def test_and_train(model, batch, epochs=10):
    """Train one network on the shared training subset and score it by PSNR.

    Args:
        model: Architecture name: "UNet", "REDNet" or "AutoEncoder".
        batch: Batch size passed to ``MyModel``.
        epochs: Epoch count passed to ``MyModel.train`` (default 10, the value
            used for the recorded results).

    Returns:
        ``(psnr, m)``: the PSNR of ``m``'s predictions on the validation
        inputs ``t1`` measured against ``t2``, and the trained ``MyModel``.

    Raises:
        ValueError: If ``model`` is not one of the known architecture names.
            (Previously any unknown name silently fell back to AutoEncoder.)
    """
    architectures = {"UNet": UNet, "REDNet": REDNet, "AutoEncoder": AutoEncoder}
    try:
        network_cls = architectures[model]
    except KeyError:
        raise ValueError(
            f"unknown model {model!r}; expected one of {sorted(architectures)}"
        ) from None
    m = MyModel(network_cls(), batch)
    # Trains on the module-level subsets s1/s2 and evaluates on t1/t2.
    m.train(s1, s2, epochs)
    return compute_psnr(m.predict(t1), t2), m

In [None]:
# Train and score every (architecture, batch size) combination; each entry in
# `results` is a run label ending with the validation PSNR.
results = []
for model in models:
    for batch_size in batch_sizes:
        psnr, m = test_and_train(model, batch_size)
        # Build the label once so the result list and any checkpoint file agree
        # (the old commented-out save said "Layers10" while the label says "Layers5").
        run_name = f'{model}_Layers5_Batch{batch_size}_Epochs10_Sample1000_{psnr}'
        results.append(run_name)
        # Uncomment to checkpoint each trained network under its run label:
        # torch.save(m.model.state_dict(), f'{run_name}.pth')

In [None]:
# Display the collected run labels (the notebook echoes the cell's last expression).
results

In [8]:
# Display the collected run labels; each ends with that run's validation PSNR.
results

['UNet_Layers5_Batch5_Epochs10_Sample1000_24.809358596801758',
 'UNet_Layers5_Batch10_Epochs10_Sample1000_24.92894744873047',
 'UNet_Layers5_Batch20_Epochs10_Sample1000_24.560302734375',
 'REDNet_Layers5_Batch5_Epochs10_Sample1000_24.46434211730957',
 'REDNet_Layers5_Batch10_Epochs10_Sample1000_24.47145652770996',
 'REDNet_Layers5_Batch20_Epochs10_Sample1000_24.19704818725586',
 'AutoEncoder_Layers5_Batch5_Epochs10_Sample1000_21.843154907226562',
 'AutoEncoder_Layers5_Batch10_Epochs10_Sample1000_21.672786712646484',
 'AutoEncoder_Layers5_Batch20_Epochs10_Sample1000_15.181829452514648']