In [1]:
%load_ext autoreload
%autoreload 2

from torch import empty, cat, arange
from torch.nn.functional import fold, unfold
from torch import nn


from model import *
import pickle


In [66]:
# loading image data
noisy_imgs_1, noisy_imgs_2 = torch.load('../../data/train_data.pkl')
noisy_imgs, clean_imgs = torch.load('../../data/val_data.pkl')

train_input, train_target = noisy_imgs_1.float()/255.0, noisy_imgs_2.float()/255.0 
test_input, test_target = noisy_imgs.float()/255.0, clean_imgs.float()/255.0

In [67]:
# take batch of size 10
input_batch = train_input[:100,:,:,:]
target_batch = train_target[:10,:,:,:]
input_batch.shape

torch.Size([100, 3, 32, 32])

In [68]:
# create some tensors for simple tests
x = torch.randn((1, 3, 4, 4))

# y = torch.ones(x.shape)
y = torch.ones((1, 3, 4, 4))

### Testing our Conv2d Layer 

Forward pass

In [69]:
input_batch.shape, target_batch.shape

(torch.Size([100, 3, 32, 32]), torch.Size([10, 3, 32, 32]))

In [70]:
# testing our conv2d layer
my_conv = Conv2d(3, 4, 2, stride=2)
output = my_conv.forward(input_batch)
output.shape

torch.Size([100, 4, 16, 16])

In [71]:
torch.testing.assert_allclose(my_conv.forward(input_batch), torch.nn.functional.conv2d(input_batch, my_conv.weight, my_conv.bias, stride=2))

In [113]:
x = torch.randn(1, 3, 32, 32)
conv = Conv2d(3, 3, 3)
torch.testing.assert_allclose(conv.forward(x), torch.nn.functional.conv2d(x, conv.weight, conv.bias))

Backward pass

In [72]:
back = my_conv.backward(output)
back.shape

torch.Size([100, 3, 32, 32])

### Testing our Upsampling (TransposeConv2d) layer

Forward pass

In [75]:
x = input_batch[:1,:,:,:]
x.shape

torch.Size([1, 3, 32, 32])

In [83]:
my_t_conv = Upsampling(3, 5, 2, stride=2)
upsampled = my_t_conv.forward(x)
upsampled.shape

torch.Size([1, 5, 64, 64])


torch.Size([1, 5, 64, 64])

In [77]:
torch.testing.assert_allclose(my_t_conv.forward(x), torch.nn.functional.conv_transpose2d(x, my_t_conv.weight, my_t_conv.bias, stride=2))

torch.Size([20, 1024])
torch.Size([5, 64, 64])


Backward pass

In [14]:
back_t = my_t_conv.backward(upsampled)
back_t.shape

torch.Size([10, 3, 32, 32])

### Sequential model testing

In [16]:
model = Sequential(Conv2d(3,4,2), Upsampling(4,3,2, stride=2))
forward_pass = model.forward(x)
backward_pass = model.backward(forward_pass)
forward_pass.shape, backward_pass.shape

(torch.Size([1, 3, 6, 6]), torch.Size([1, 3, 4, 4]))

### Testing model

In [101]:
train_input_batch = train_input[:1000, :, :, :]
train_target_batch = train_target[:1000, :, :, :]
train_input_batch.shape, train_target_batch.shape

(torch.Size([1000, 3, 32, 32]), torch.Size([1000, 3, 32, 32]))

Train the model 

In [105]:
model = Model()
model.train(train_input_batch, train_target_batch, 50)

0 iteration: loss=0.7427013516426086
1 iteration: loss=0.7255274653434753
2 iteration: loss=0.7243703007698059
3 iteration: loss=0.7242486476898193
4 iteration: loss=0.7241997122764587
5 iteration: loss=0.7241531014442444
6 iteration: loss=0.7241024971008301
7 iteration: loss=0.7240540981292725
8 iteration: loss=0.7240176796913147
9 iteration: loss=0.7239911556243896
10 iteration: loss=0.7239688634872437
11 iteration: loss=0.7239482402801514
12 iteration: loss=0.72392737865448
13 iteration: loss=0.7238973379135132
14 iteration: loss=0.7238321900367737
15 iteration: loss=0.7237774729728699
16 iteration: loss=0.7237519025802612
17 iteration: loss=0.7237323522567749
18 iteration: loss=0.7237142324447632
19 iteration: loss=0.7236966490745544
20 iteration: loss=0.723679780960083
21 iteration: loss=0.7236634492874146
22 iteration: loss=0.7236475944519043
23 iteration: loss=0.7236322164535522
24 iteration: loss=0.7236173748970032
25 iteration: loss=0.7236026525497437
26 iteration: loss=0.7235

Prediction and psnr score

In [106]:
def psnr(denoised,ground_truth):
    mse=torch.mean((denoised-ground_truth)**2)
    return  - 10 * torch.log10(((denoised-ground_truth) ** 2).mean((1,2,3))).mean()

In [107]:
prediction = model.predict(test_input)
psnr(prediction, test_target)

tensor(12.6920)

Save the model

In [121]:
pickle.dump(model.model.param(), open('bestmodel.pkl', 'wb'))