In [1]:
import sys
sys.path.append('../')
import matplotlib.pyplot as plt

from Miniproject_2.model import *

import torch
from torch.nn import functional as F

torch.set_grad_enabled(True);

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

valid_input, valid_target = torch.load('../val_data.pkl',map_location=device)#validation set (noise-clean)
train_input, train_target = torch.load('../train_data.pkl',map_location=device) #test set (noise-noise)

num_samples = 1000
valid_input=torch.narrow(valid_input,0,0,num_samples)
valid_target=torch.narrow(valid_target,0,0,num_samples)
train_input=torch.narrow(train_input,0,0,num_samples)
train_target=torch.narrow(train_target,0,0,num_samples)

# plt.imshow(valid_input[164].permute(1,2,0))

print("Vector shape: ",train_input.shape)

Vector shape:  torch.Size([1000, 3, 32, 32])


In [3]:
def compare(x,y,decimals=7):
    return torch.all(torch.round(torch.abs(x - y), decimals=decimals)==0.).item()

## Test MSE loss derivative

In [4]:
# class MSELoss(Module):
#     def __init__(self, *input):
#         super(MSELoss,self).__init__()
#         self.input = None
#         self.reference = None
    
#     def forward(self,input,reference):
#         self.input     = input
#         self.reference = reference
#         n              = input.size().numel()
#         output         = ((input-reference)**2).sum()/n
#         return output

#     __call__ = forward

#     def backward(self):
#         return 2*(self.input - self.reference)/self.input.size().numel()



In [5]:
# Test MSE loss
idx = 164
len = 1
y=(valid_input[idx].float()/255.).requires_grad_()
y_true=(valid_target[idx].float()/255.).requires_grad_()

# fig,ax = plt.subplots(1,2, figsize=(8,8))
# ax[0].imshow(input.permute(1,2,0))
# ax[1].imshow(target.permute(1,2,0))
L = MSELoss()
out = L.forward(y,y_true)

L_true = F.mse_loss(y,y_true, reduction='mean')
print('Output: ', L_true, '\nTrue output: ', out)


Output:  tensor(0.0110, grad_fn=<MseLossBackward0>) 
True output:  0.011031925678253174


In [6]:
dL_dy = torch.autograd.grad(L_true,(y))[0]
dLdy = L.backward()

print(compare(dL_dy, dLdy, decimals=3))
print(dLdy.shape)

True
torch.Size([3, 32, 32])


In [7]:
# plt.imshow(out[0].permute(1,2,0)[:,:,::2])

## Test Sequential

In [8]:
f = torch.empty(5,3,3,3)

f[0, 0] = torch.tensor([ [ +0., +0., -1. ], [ +0., +1., +0. ], [ -1., +0., +0. ]])
f[1, 0] = torch.tensor([ [ +1., +1., +1. ], [ +1., +1., +1. ], [ +1., +1., +1. ]])
f[2, 0] = torch.tensor([ [ -1., +0., +1. ], [ -1., +0., +1. ], [ -1., +0., +1. ]])
f[3, 0] = torch.tensor([ [ -1., -1., -1. ], [ +0., +0., +0. ], [ +1., +1., +1. ]])
f[4, 0] = torch.tensor([ [ +0., -1., +0. ], [ -1., +4., -1. ], [ +0., -1., +0. ]])

for j in range(0,5):
    for i in range(1,3):
        f[j,i] = f[j,0]

ff = f.transpose(0,1)
ff.requires_grad_();

In [9]:
stride = 2
kernel_size = 2

conv1 = Conv2d(3,5, kernel_size, stride=stride, padding=0, dilation=1)
# conv1.weight=f
relu1 = Relu()
conv2 = Conv2d(5,5,kernel_size, stride=stride, padding=0, dilation=1)
relu2 = Relu()
tconv3 = ConvTranspose2d(5,5, kernel_size, stride=stride, padding=0, dilation=1)
relu3 = Relu()
tconv4 = ConvTranspose2d(5,3, kernel_size, stride=stride, padding=0, dilation=1)
sig4 = Sigmoid()

net = Sequential(conv1,  
                relu1,  
                conv2, 
                relu2, 
                tconv3,
                relu3, 
                tconv4,
                sig4
                )

In [10]:
idx = 164
len = 1
a=valid_input[[idx]].float()/255.
target=valid_target[idx].float()/255.
# plt.imshow(a[0].permute(1,2,0))

In [11]:
out=net(a)
# plt.imshow(out[0].permute(1,2,0)[:,:,:])
# plt.colorbar()

In [12]:
loss = MSELoss()
l_val = loss(out, target)
print(l_val)

net.backward(loss.backward())

0.1814565807580948
<Miniproject_2.model.Conv2d object at 0x7fb69cde6940>


AssertionError: 

## Test ```self.parameters```

In [None]:
stride = 2
kernel_size = 2

conv1 = Conv2d(3,5, kernel_size, stride=stride, padding=0, dilation=1)

In [None]:
net.param()


[[tensor([[[[-2.0634,  0.9640],
            [ 0.2477,  3.2047]],
  
           [[-0.9109,  0.2193],
            [-1.1209, -0.6504]],
  
           [[-0.6069,  0.8905],
            [ 0.3814,  0.1239]]],
  
  
          [[[ 1.1259,  1.1637],
            [ 1.7956, -2.3122]],
  
           [[ 0.5978, -2.4755],
            [-0.6630, -0.8699]],
  
           [[-0.6026, -0.5721],
            [-0.2921, -0.3357]]],
  
  
          [[[ 0.9216, -1.4277],
            [ 2.2568,  0.2003]],
  
           [[ 2.3050,  2.1170],
            [-0.2490, -1.2836]],
  
           [[ 0.8843, -0.5604],
            [-1.0172, -1.8599]]],
  
  
          [[[ 0.5274, -0.2988],
            [ 1.0247, -0.8511]],
  
           [[ 0.3414, -0.1291],
            [ 0.2659,  0.7151]],
  
           [[-1.6160,  0.5138],
            [ 1.5900,  1.9041]]],
  
  
          [[[-0.2093,  0.3122],
            [-2.0011, -1.5988]],
  
           [[-0.8707,  1.8550],
            [ 0.0707, -0.4577]],
  
           [[-0.2838, -0.5247],
