In [None]:
%matplotlib inline

import time
import random

import numpy as np

from matplotlib import pyplot

import torch
from torch.utils.data import DataLoader, random_split

from utils.torchutils import ImageListDataset, UNet

In [None]:
# Set14 images from https://github.com/majedelhelou/denoising_datasets,
# located relative to the current working directory.
dataset_root = 'denoising_datasets-main'
data_path = f'{dataset_root}/Set14'

In [None]:
# Load the images as single-channel ('L') data. `size` and `std` are passed
# straight to ImageListDataset (presumably image size and noise level —
# TODO confirm against utils.torchutils).
images_dataset = ImageListDataset(
    data_path,
    mode='L',
    size=256,
    std=0.5,
)

In [None]:
# Hold out one of `fold` parts of the images for evaluation and train on the
# remaining fold-1 parts (fold=4 -> roughly 75% train / 25% test).
fold = 4
SPLIT_SEED = 0  # fixed so the train/test partition is identical on every run

len_trainset = (fold-1)*len(images_dataset)//fold
len_testset = len(images_dataset) - len_trainset

# A tiny dataset (or fold=1) would otherwise silently yield an empty split.
assert len_trainset > 0 and len_testset > 0, (
    f'Cannot split {len(images_dataset)} images into non-empty '
    f'train/test sets with fold={fold}.'
)

train_dataset, test_dataset = random_split(
    images_dataset,
    [len_trainset, len_testset],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Mini-batch size shared by both loaders.
BATCH_SIZE = 16

# Reshuffle the training set every epoch so batches vary between epochs.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

# Evaluation gains nothing from shuffling; a fixed order keeps the test batch
# composition reproducible and does not consume random state.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False
)

In [None]:
# UNet for single-channel images. [1, 64, 128] is presumably the per-level
# channel widths starting from the 1 input channel — TODO confirm against
# utils.torchutils.UNet.
unet_channels = [1, 64, 128]
net = UNet(unet_channels, n_classes=1, double_conv=True)

In [None]:
# Pixel-wise mean-squared error against the clean target
# (torch.nn.L1Loss() is a drop-in alternative worth trying).
criterion = torch.nn.MSELoss()

# Adam over all network parameters with PyTorch's default hyper-parameters.
optimizer = torch.optim.Adam(params=net.parameters())

In [None]:
# One pass over the training loader with grad_enabled=False, timed, so that
# pipeline errors surface before the long training run.
print('Doing one iteration through train set as a sanity check.')

start = time.perf_counter()
net.train_test_epoch(train_dataloader, optimizer, criterion, grad_enabled=False)
elapsed = time.perf_counter() - start

print(f'Took {elapsed : 0.4f}s for 1 epoch of {len(train_dataset)} images.')

In [None]:
# Main training run; test_dataloader is passed as the second (presumably
# validation) loader. Because patience equals n_epochs, early stopping likely
# cannot trigger before the last epoch — TODO confirm against UNet.fit.
net.fit(
    train_dataloader,
    test_dataloader,
    optimizer,
    criterion,
    n_epochs=20,
    patience=20,
)

In [None]:
# Pick one random (input, target) pair from the training split and run the
# network on it in eval mode without building an autograd graph.
input, target = random.choice(train_dataset)
net.eval()

with torch.no_grad():
    # Add a leading batch dimension for the forward pass, then strip it off.
    output = net(input.unsqueeze(0))[0]

In [None]:
def show_panel(ax, image, title):
    """Draw a channel-first image tensor on `ax` in grayscale with a colorbar.

    Args:
        ax: matplotlib Axes to draw into.
        image: tensor laid out (C, H, W); permuted to (H, W, C) for imshow.
        title: panel title.

    The tensor is detached and moved to the CPU first because matplotlib
    converts it through NumPy, which fails for CUDA or grad-tracking tensors.
    """
    pixels = image.detach().cpu().permute(1, 2, 0)
    mappable = ax.imshow(pixels, cmap='gray')
    ax.figure.colorbar(mappable, ax=ax)
    ax.set_title(title)


fig, axes = pyplot.subplots(2, 2, figsize=(15, 15))

# Loss of the raw input and of the prediction, both against the clean target.
input_loss = criterion(input, target).item()
output_loss = criterion(output, target).item()

show_panel(axes[0, 0], target, 'target')
show_panel(axes[0, 1], input, f'input (loss ={input_loss : 0.4f})')
show_panel(axes[1, 0], (target - output).abs(), '|target-predicted|')
show_panel(axes[1, 1], output, f'predicted (loss ={output_loss : 0.4f})')

pyplot.show()

