In [1]:
import sys
sys.path.append('../')  # make the parent directory importable so `Miniproject_2` resolves
import matplotlib.pyplot as plt  # NOTE(review): plt is not used in the cells shown

# NOTE(review): star import; presumably supplies the unfold / fold / pad / Module /
# conv_backward names used below without a visible definition — confirm against
# Miniproject_2/model.py and prefer importing them explicitly.
from Miniproject_2.model import *

import torch
from torch.nn import functional as F  # PyTorch reference implementations used as ground truth

# Grad mode is already on by default; this only makes it explicit (autograd
# produces the reference gradients below).
torch.set_grad_enabled(True);

In [2]:
def conv2d(X, K, stride=1, padding=0, dilation=1):
    """2-D cross-correlation without bias, equivalent to torch.nn.functional.conv2d.

    Implemented as unfold (im2col) followed by a single matrix product.

    Args:
        X: input tensor of shape (N, C_in, H, W).
        K: kernel of shape (C_out, C_in, kh, kw); kh and kw may differ.
        stride, padding, dilation: ints, applied to both spatial axes.

    Returns:
        Tensor of shape (N, C_out, H_out, W_out) where
        H_out = (H + 2*padding - dilation*(kh-1) - 1) // stride + 1 (same for W).
    """
    N, _, H, W = X.shape
    out_channels, in_channels, h, w = K.shape
    assert X.shape[1] == in_channels, "channel mismatch between input and kernel"

    # Exact integer floor division (int(a / b) would round-trip through a float).
    h_out = (H + 2*padding - dilation*(h-1) - 1) // stride + 1
    w_out = (W + 2*padding - dilation*(w-1) - 1) // stride + 1

    # (N, C_in*kh*kw, L): one column per sliding-window position, L = h_out*w_out.
    cols = unfold(X, kernel_size=(h, w), padding=padding, dilation=dilation, stride=stride)
    patch_size, L = cols.shape[1], cols.shape[2]
    rows = cols.transpose(1, 2).reshape(N * L, patch_size)      # (N*L, C_in*kh*kw)

    # Each unfolded column is ordered (C_in, kh, kw), matching K's row-major layout.
    K_mat = K.reshape(out_channels, patch_size)

    Y = rows @ K_mat.t()                                          # (N*L, C_out)
    # Back to image layout; a plain reshape replaces the former fold(kernel_size=1).
    return Y.reshape(N, L, out_channels).transpose(1, 2).reshape(N, out_channels, h_out, w_out)


def conv_transpose2d(Y, K, stride=1, padding=0, dilation=1):
    """2-D transposed convolution without bias, equivalent to
    torch.nn.functional.conv_transpose2d (no output_padding, no groups).

    Every input pixel is multiplied with the kernel (one matrix product) and the
    resulting patches are scatter-added into the output with fold (col2im).

    Args:
        Y: input of shape (N, C_y, H, W).
        K: kernel of shape (C_y, C_x, kh, kw) — torch's conv_transpose2d weight
           layout; kh and kw may differ.
        stride, padding, dilation: ints applied to both spatial axes.

    Returns:
        Tensor of shape (N, C_x, H_out, W_out) with
        H_out = (H-1)*stride - 2*padding + dilation*(kh-1) + 1 (same for W).
    """
    N, _, H, W = Y.shape
    y_channels, x_channels, h, w = K.shape
    assert Y.shape[1] == y_channels, "channel mismatch between input and kernel"

    h_out = (H-1)*stride - 2*padding + dilation*(h-1) + 1
    w_out = (W-1)*stride - 2*padding + dilation*(w-1) + 1

    # (N*H*W, C_y): one row per input pixel. Out-of-place transpose so no in-place
    # op is ever applied to a view of the caller's Y (which may require grad).
    rows = Y.flatten(2).transpose(1, 2).reshape(N * H * W, y_channels)

    # (N*H*W, C_x*kh*kw): the kernel patch each pixel contributes.
    patches = rows @ K.flatten(1)
    patches = patches.reshape(N, H * W, -1).transpose(1, 2)      # (N, C_x*kh*kw, H*W)

    # fold sums overlapping patches and discards `padding` rows/cols on each border.
    return fold(patches, output_size=[h_out, w_out], kernel_size=(h, w),
                padding=padding, dilation=dilation, stride=stride)


def augment(input, nzeros, padding=0):
    """Dilate feature maps with zeros and optionally add a zero border.

    Inserts `nzeros` zeros between neighbouring entries along both spatial axes,
    then surrounds the result with a `padding`-wide zero border. Used below to
    turn strided (transposed) convolutions into unit-stride ones.

    Args:
        input:   tensor of shape (..., H, W); any leading dimensions.
        nzeros:  zeros inserted between consecutive rows/cols (stride - 1).
        padding: width of the zero border added on every side.

    Returns:
        Tensor of shape (..., H + (H-1)*nzeros + 2*padding, W + (W-1)*nzeros + 2*padding),
        with the same dtype and device as `input`.
    """
    *lead, h_old, w_old = input.shape
    step = nzeros + 1
    h_in = h_old + (h_old - 1) * nzeros
    w_in = w_old + (w_old - 1) * nzeros

    # new_zeros keeps input's dtype/device; torch.zeros used the global default
    # (float32 on CPU), silently down-casting double inputs.
    new = input.new_zeros(*lead, h_in + 2*padding, w_in + 2*padding)
    # Place the input on every `step`-th position inside the border.
    new[..., padding:padding + h_in:step, padding:padding + w_in:step] = input
    return new


def weight_backward(input, dL_dy, weight, ignored, stride=1, padding=0, dilation=1):
    """Gradient of a loss w.r.t. the kernel of a conv2d (cross-correlation) layer.

        dL/dK[o, c, i, j] = sum_{n, a, b} dL_dy[n, o, a, b]
                            * x[n, c, a*stride + i*dilation, b*stride + j*dilation]

    where x is `input` zero-padded by `padding`.

    Args:
        input:   forward input, (N, C_in, H, W).
        dL_dy:   gradient w.r.t. the forward output, (N, C_out, H_out, W_out).
        weight:  the kernel, (C_out, C_in, k, k); only its shape is used (for a check).
        ignored: number of trailing rows/cols of the (padded) input that no
                 sliding window reached; they are cropped before correlating.
        stride, padding, dilation: hyper-parameters of the forward convolution.
                 `padding`/`dilation` default to 0/1, the values previously assumed.

    Returns:
        Tensor with the same shape as `weight`.
    """
    # Insert stride-1 zeros between gradient entries so a unit-step correlation
    # with the input reproduces the strided sum over output positions.
    dL_dy_aug = augment(dL_dy, nzeros=stride-1, padding=0)

    x = input
    if padding:
        N, C, H, W = input.shape
        x = input.new_zeros(N, C, H + 2*padding, W + 2*padding)
        x[:, :, padding:padding+H, padding:padding+W] = input
    if ignored:
        x = x[:, :, :-ignored, :-ignored]

    # Swap batch and channel axes: the batch becomes the contracted "channel"
    # axis of conv2d, so a single call sums over all samples for every
    # (c_in, c_out) pair instead of looping over them in Python.
    # Output stride = dilation selects the taps of a dilated kernel.
    dL_df = conv2d(x.transpose(0, 1), dL_dy_aug.transpose(0, 1), stride=dilation)  # (C_in, C_out, k, k)
    dL_df = dL_df.transpose(0, 1)
    assert dL_df.shape == weight.shape, (tuple(dL_df.shape), tuple(weight.shape))
    return dL_df




class Conv2d():
    """2-D convolution layer (no bias) with a hand-written backward pass.

    The weight has shape (out_channels, in_channels, kernel_size, kernel_size);
    stride, padding and dilation are ints applied to both spatial axes.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
        self.in_channels  = in_channels
        self.out_channels = out_channels
        self.kernel_size  = kernel_size
        self.stride   = stride
        self.padding  = padding
        self.dilation = dilation

        # torch.Tensor(*sizes) returns uninitialised memory (possibly NaN/inf).
        # Sample U(-b, b), b = 1/sqrt(fan_in): the distribution torch.nn.Conv2d
        # uses for its default weight initialisation.
        bound = (in_channels * kernel_size * kernel_size) ** -0.5
        self.weight   = torch.empty(out_channels, in_channels, kernel_size, kernel_size).uniform_(-bound, bound)

    def forward(self, input):
        """Convolve `input` of shape (N, in_channels, H, W) with self.weight."""
        return conv2d(input, self.weight, stride=self.stride, padding=self.padding, dilation=self.dilation)

    __call__ = forward

    def backward(self, input, dL_dy):
        """Back-propagate `dL_dy`, the gradient of the loss w.r.t. the forward output.

        Args:
            input: the tensor given to forward, (N, in_channels, H, W).
            dL_dy: gradient w.r.t. the output, (N, out_channels, H_out, W_out).

        Returns:
            (dL_dx, dL_df): gradients w.r.t. `input` (same shape) and `self.weight`.

        NOTE: weight_backward is called without a dilation argument, so the kernel
        gradient is only supported for dilation == 1.
        """
        p = self.padding
        H, W = input.shape[-2:]

        # Gradient w.r.t. the *padded* input. Using padding=0 here keeps the
        # border, so real input rows near the end are never mistaken for padding
        # (which dropped their gradient when padding > 0 and the strided windows
        # stopped short of the last rows).
        dL_dxpad = conv_transpose2d(dL_dy, self.weight, stride=self.stride, padding=0, dilation=self.dilation)
        Hc, Wc = dL_dxpad.shape[-2:]   # extent of the padded input covered by some window

        # Trailing rows/cols never visited by a window receive zero gradient.
        miss_h, miss_w = H + 2*p - Hc, W + 2*p - Wc
        if miss_h or miss_w:
            dL_dxpad = pad(dL_dxpad, (0, miss_w, 0, miss_h))
        # Strip the padding border to return to input coordinates.
        dL_dx = dL_dxpad[:, :, p:p+H, p:p+W]

        # Kernel gradient: correlate the covered part of the padded input with dL_dy.
        x_used = pad(input, (p, p, p, p)) if p else input
        x_used = x_used[:, :, :Hc, :Wc]
        dL_df = weight_backward(x_used, dL_dy, self.weight, 0, stride=self.stride)
        return dL_dx, dL_df
    
    
class ConvTranspose2d():
    """2-D transposed convolution layer (no bias) with a hand-written backward pass.

    The weight has shape (in_channels, out_channels, kernel_size, kernel_size),
    matching torch.nn.ConvTranspose2d. output_padding is not supported.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
        self.in_channels  = in_channels
        self.out_channels = out_channels
        self.kernel_size  = kernel_size
        self.stride   = stride
        self.padding  = padding
        self.dilation = dilation

        # torch.Tensor(*sizes) returns uninitialised memory. Sample U(-b, b) with
        # b = 1/sqrt(out_channels*k*k), as torch.nn.ConvTranspose2d does by default
        # (its fan_in is computed from dimension 1 of this weight layout).
        bound = (out_channels * kernel_size * kernel_size) ** -0.5
        self.weight   = torch.empty(in_channels, out_channels, kernel_size, kernel_size).uniform_(-bound, bound)

    def forward(self, input):
        """Apply the transposed convolution to `input` of shape (N, in_channels, H, W)."""
        return conv_transpose2d(input, self.weight, stride=self.stride,
                                padding=self.padding, dilation=self.dilation)

    __call__ = forward

    def backward(self, input, dL_dy):
        """Back-propagate `dL_dy`, the gradient of the loss w.r.t. the forward output.

        Uses the identity
            conv_transpose2d(x, W, s, p, d)
              == conv2d(augment(x, s-1, d*(k-1)-p), W.flip(2, 3).transpose(0, 1), stride=1, dilation=d)
        and back-propagates through the right-hand side.

        Returns:
            (dL_dx, dL_df): gradients w.r.t. `input` (same shape) and `self.weight`.
        """
        # Padding of the equivalent direct convolution: extent of the dilated kernel minus `padding`.
        p = self.dilation*(self.kernel_size-1) - self.padding
        if p < 0:
            raise ValueError("padding larger than dilation*(kernel_size-1) is not supported")
        z = self.stride-1

        eff_input  = augment(input, nzeros=z, padding=p)
        eff_weight = self.weight.flip(2,3).transpose(0,1)
        # NOTE(review): conv_backward is defined outside this notebook (presumably
        # via the star import); assumed to return (dL/d input, dL/d weight) of a conv2d.
        dL_dx, dL_df = conv_backward(eff_input, dL_dy, eff_weight, stride=1, padding=0, dilation=self.dilation)

        # Undo the flip/transpose to express the kernel gradient in self.weight's layout.
        dL_df = dL_df.flip(2,3).transpose(0,1)

        # Keep only the positions of eff_input that hold real input values: skip the
        # p-wide border and step over the z inserted zeros. An explicit stop (instead
        # of `-p`) keeps this correct when p == 0, where `p:-p` would select nothing.
        step = z + 1
        H, W = input.shape[-2:]
        rows = slice(p, p + (H-1)*step + 1, step)
        cols = slice(p, p + (W-1)*step + 1, step)
        return dL_dx[:, :, rows, cols], dL_df
    
    
class Sigmoid(Module):
    """Element-wise logistic sigmoid with a hand-written backward pass.

    forward() remembers its input so that backward() can evaluate the local
    derivative sigma'(x) = sigma(x) * (1 - sigma(x)).
    """
    def __init__(self, *input):
        super().__init__()
        # Input of the latest forward call; read by backward().
        self.input = None

    def forward(self, input):
        self.input = input
        return torch.sigmoid(input)

    __call__ = forward

    def backward(self, dL_dy):
        """Chain rule: multiply the upstream gradient by sigma'(x) element-wise."""
        sig = torch.sigmoid(self.input)
        local_grad = sig * (1. - sig)
        return dL_dy * local_grad

In [3]:
def compare(x, y, decimals=7):
    """Return True iff every entry of |x - y| rounds to zero at `decimals` places.

    Roughly a max-abs-error check with tolerance 0.5 * 10**-decimals; a NaN
    difference makes the result False.
    """
    abs_diff = torch.abs(x - y)
    is_zero = torch.round(abs_diff, decimals=decimals) == 0.
    return is_zero.all().item()

## Test Forward Conv2d

In [16]:
# NOTE(review): torch.load unpickles the file — only use with trusted data.
images = torch.load('../train_data.pkl')
a = images[0][:10].float()  # first 10 entries of the first stored tensor, as float
a.shape

torch.Size([10, 3, 32, 32])

In [17]:
# Five fixed 3x3 kernels, one per output channel: anti-diagonal, all-ones box,
# Prewitt x, Prewitt y and a 4-neighbour Laplacian; each is copied onto all
# 3 input channels, giving a (5, 3, 3, 3) weight.
f = torch.tensor([
    [[ 0.,  0., -1.], [ 0.,  1.,  0.], [-1.,  0.,  0.]],
    [[ 1.,  1.,  1.], [ 1.,  1.,  1.], [ 1.,  1.,  1.]],
    [[-1.,  0.,  1.], [-1.,  0.,  1.], [-1.,  0.,  1.]],
    [[-1., -1., -1.], [ 0.,  0.,  0.], [ 1.,  1.,  1.]],
    [[ 0., -1.,  0.], [-1.,  4., -1.], [ 0., -1.,  0.]],
]).unsqueeze(1).repeat(1, 3, 1, 1)

f.requires_grad_();

In [18]:
# Reference convolution computed with torch.nn.functional.
conv_a = F.conv2d(a, f, stride=1, padding=0, dilation=1)
conv_a.shape

torch.Size([10, 5, 30, 30])

In [19]:
# Same convolution with the hand-written conv2d defined above.
myconv_a = conv2d(a, f, stride=1, padding=0, dilation=1)

In [20]:
# Exact element-wise equality (no tolerance) with the PyTorch reference.
torch.all(conv_a==myconv_a)

tensor(True)

## Test forward ConvTranspose2d (stride 2, padding 1, dilation 2)

In [21]:
# First `select` validation inputs/targets, cast to float; x tracks gradients.
valid_input, valid_target = torch.load('../val_data.pkl')

select=10
x      = (valid_input[:select].float()).requires_grad_()
xtrue  = (valid_target[:select].float())

# Hyper-parameters for the transposed-convolution test below.
stride = 2
padding= 1
dilation=2 

In [22]:
# Five fixed 3x3 kernels, one per output channel: anti-diagonal, all-ones box,
# Prewitt x, Prewitt y and a 4-neighbour Laplacian; each is copied onto all
# 3 input channels, giving a (5, 3, 3, 3) weight.
f = torch.tensor([
    [[ 0.,  0., -1.], [ 0.,  1.,  0.], [-1.,  0.,  0.]],
    [[ 1.,  1.,  1.], [ 1.,  1.,  1.], [ 1.,  1.,  1.]],
    [[-1.,  0.,  1.], [-1.,  0.,  1.], [-1.,  0.,  1.]],
    [[-1., -1., -1.], [ 0.,  0.,  0.], [ 1.,  1.,  1.]],
    [[ 0., -1.,  0.], [-1.,  4., -1.], [ 0., -1.,  0.]],
]).unsqueeze(1).repeat(1, 3, 1, 1)

f.requires_grad_();

In [23]:
# f.transpose(0,1) has shape (3, 5, 3, 3), the (in, out, kH, kW) layout of
# conv_transpose2d weights: 3-channel x is mapped to 5 channels.
# The own implementation gets x.detach(), so it builds no autograd graph.
dd   = F.conv_transpose2d(x, f.transpose(0,1), stride=stride, padding=padding, dilation=dilation)
mydd = conv_transpose2d(x.detach(), f.transpose(0,1), stride=stride, padding=padding, dilation=dilation)

In [24]:
# Exact element-wise equality (no tolerance) with the PyTorch reference.
torch.all(mydd==dd)

tensor(True)

## Test derivative of the MSE loss

In [26]:
def dL(y, ytrue):
    """Closed-form gradient of the mean squared error w.r.t. `y`.

    For L = mean((y - ytrue)**2) over all n elements, dL/dy = 2 * (y - ytrue) / n.
    """
    n = y.numel()
    return 2 * (y - ytrue) / n

In [27]:
# Double-precision copies divided by 255 (presumably uint8 pixels -> [0, 1]).
valid_input, valid_target = torch.load('../val_data.pkl')

select=10
y     = (valid_input[:select].double()/255.).requires_grad_()
# NOTE(review): the gradient w.r.t. ytrue is never requested; requires_grad_ is not needed.
ytrue = (valid_target[:select].double()/255.).requires_grad_()

y.shape

torch.Size([10, 3, 32, 32])

In [28]:
# Autograd gradient of the MSE loss vs. the closed form dL.
L = F.mse_loss(y, ytrue)
dL_dy = torch.autograd.grad(L, (y))[0]  # `(y)` is just `y`, not a tuple; grad accepts either

mydL_dy = dL(y,ytrue)

In [29]:
# Double precision allows agreement to 14 decimals.
compare(dL_dy, mydL_dy, decimals=14)

True

## Test backward Conv2d

In [30]:
# Fresh float copies of the first `select` validation samples; x tracks gradients.
valid_input, valid_target = torch.load('../val_data.pkl')

select=10
x     = (valid_input[:select].float()).requires_grad_()
xtrue = (valid_target[:select].float())

In [31]:
# Five fixed 3x3 kernels, one per output channel: anti-diagonal, all-ones box,
# Prewitt x, Prewitt y and a 4-neighbour Laplacian; each is copied onto all
# 3 input channels, giving a (5, 3, 3, 3) weight.
f = torch.tensor([
    [[ 0.,  0., -1.], [ 0.,  1.,  0.], [-1.,  0.,  0.]],
    [[ 1.,  1.,  1.], [ 1.,  1.,  1.], [ 1.,  1.,  1.]],
    [[-1.,  0.,  1.], [-1.,  0.,  1.], [-1.,  0.,  1.]],
    [[-1., -1., -1.], [ 0.,  0.,  0.], [ 1.,  1.,  1.]],
    [[ 0., -1.,  0.], [-1.,  4., -1.], [ 0., -1.,  0.]],
]).unsqueeze(1).repeat(1, 3, 1, 1)

f.requires_grad_();

In [32]:
# Reference forward pass and autograd gradients for a stride-1 conv2d.
stride = 1
y = F.conv2d(x, f, stride=stride)

# Target: xtrue through the same filter; no graph is needed for it.
with torch.no_grad():
    ytrue = F.conv2d(xtrue, f, stride=stride)

L = F.mse_loss(y,ytrue)
dL_dy, dL_dx, dL_df = torch.autograd.grad(L, (y,x,f))

print("Shapes")
print("x : \t", tuple(x.shape))
print("y : \t", tuple(y.shape))
print("f : \t", tuple(f.shape))
print("dL_dy : ", tuple(dL_dy.shape))
print("dL_dx : ", tuple(dL_dx.shape))
print("dL_df : ", tuple(dL_df.shape))

Shapes
x : 	 (10, 3, 32, 32)
y : 	 (10, 5, 30, 30)
f : 	 (5, 3, 3, 3)
dL_dy :  (10, 5, 30, 30)
dL_dx :  (10, 3, 32, 32)
dL_df :  (5, 3, 3, 3)


In [74]:
# Own layer with the same hyper-parameters; its weight is replaced by f so the
# gradients are comparable with the autograd reference.
myconv = Conv2d(3 ,5, 3, stride=stride, padding=0, dilation=1)
myconv.weight = f

In [75]:
# Manual backward pass, fed the upstream gradient dL_dy obtained from autograd.
mydL_dx, mydL_df = myconv.backward(x, dL_dy)

print("Shapes")
print("mydL_dx : ",mydL_dx.shape)
print("mydL_df : ",mydL_df.shape)

Shapes
mydL_dx :  torch.Size([10, 3, 32, 32])
mydL_df :  torch.Size([5, 3, 3, 3])


In [76]:
# Input gradient: agreement to 6 decimals.
compare(dL_dx, mydL_dx, decimals=6)

True

In [77]:
# Kernel gradient: looser tolerance (presumably float32 summation-order differences).
compare(dL_df, mydL_df, decimals=3)

True

## Test forward ConvTranspose2d (stride 2, no padding)

In [21]:
# Stride-2 transposed convolution; unlike the earlier test, x is passed
# without detach(), so it still requires grad here.
dd   = F.conv_transpose2d(x, f.transpose(0,1), stride=2, padding=0)
mydd =   conv_transpose2d(x, f.transpose(0,1), stride=2, padding=0)

In [22]:
# Element-wise agreement to 3 decimals.
compare(dd,mydd,decimals=3)

True

## Test backward Transposed Conv2d

In [23]:
# Fresh float copies of the first `select` validation samples; x tracks gradients.
valid_input, valid_target = torch.load('../val_data.pkl')

select=10
x     = (valid_input[:select].float()).requires_grad_()
xtrue = (valid_target[:select].float())

In [24]:
# Five fixed 3x3 kernels, one per output channel: anti-diagonal, all-ones box,
# Prewitt x, Prewitt y and a 4-neighbour Laplacian; each is copied onto all
# 3 input channels, giving a (5, 3, 3, 3) tensor.
f = torch.tensor([
    [[ 0.,  0., -1.], [ 0.,  1.,  0.], [-1.,  0.,  0.]],
    [[ 1.,  1.,  1.], [ 1.,  1.,  1.], [ 1.,  1.,  1.]],
    [[-1.,  0.,  1.], [-1.,  0.,  1.], [-1.,  0.,  1.]],
    [[-1., -1., -1.], [ 0.,  0.,  0.], [ 1.,  1.,  1.]],
    [[ 0., -1.,  0.], [-1.,  4., -1.], [ 0., -1.,  0.]],
]).unsqueeze(1).repeat(1, 3, 1, 1)

# (3, 5, 3, 3) view in the (in, out, kH, kW) layout of transposed-conv weights;
# only this view tracks gradients.
ff = f.transpose(0, 1)
ff.requires_grad_();

In [25]:
# Reference forward pass and autograd gradients for a stride-2 transposed convolution.
stride = 2
y = F.conv_transpose2d(x, ff, stride=stride)

# Target: xtrue through the same layer; no graph is needed for it.
with torch.no_grad():
    ytrue = F.conv_transpose2d(xtrue, ff, stride=stride)

L = F.mse_loss(y,ytrue)
dL_dy, dL_dx, dL_df = torch.autograd.grad(L, (y,x,ff))

print("Shapes")
print("x : \t", tuple(x.shape))
print("y : \t", tuple(y.shape))
# NOTE(review): prints f's shape; the weight actually used is ff = f.transpose(0,1).
print("f : \t", tuple(f.shape))
print("dL_dy : ", tuple(dL_dy.shape))
print("dL_dx : ", tuple(dL_dx.shape))
print("dL_df : ", tuple(dL_df.shape))

Shapes
x : 	 (10, 3, 32, 32)
y : 	 (10, 5, 65, 65)
f : 	 (5, 3, 3, 3)
dL_dy :  (10, 5, 65, 65)
dL_dx :  (10, 3, 32, 32)
dL_df :  (3, 5, 3, 3)


In [26]:
# Own transposed-conv layer; its weight is replaced by ff to match the reference.
mytconv = ConvTranspose2d(3 ,5, 3, stride=stride, padding=0, dilation=1)
mytconv.weight = ff

In [27]:
# Manual backward pass, fed the upstream gradient dL_dy obtained from autograd.
mydL_dx, mydL_df = mytconv.backward(x, dL_dy)

print("Shapes")
print("mydL_dx : ",mydL_dx.shape)
print("mydL_df : ",mydL_df.shape)

Shapes
mydL_dx :  torch.Size([10, 3, 32, 32])
mydL_df :  torch.Size([3, 5, 3, 3])


In [28]:
# Input gradient: agreement to 6 decimals.
compare(dL_dx, mydL_dx, decimals=6)

True

In [29]:
# Kernel gradient: looser tolerance of 3 decimals.
compare(dL_df, mydL_df, decimals=3)

True

## Test Sigmoid backward

In [30]:
# Reference: autograd gradients through an element-wise sigmoid followed by MSE.
select= 10
x     = (valid_input[:select].float()).requires_grad_()
xtrue = (valid_target[:select].float())

y     = torch.sigmoid(x)
ytrue = torch.sigmoid(xtrue)  # xtrue does not require grad, so no graph is built here

L = F.mse_loss(y, ytrue)
dL_dy, dL_dx = torch.autograd.grad(L, (y,x))


print("Shapes")
print("x : \t", tuple(x.shape))
print("y : \t", tuple(y.shape))
print("dL_dy : ", tuple(dL_dy.shape))
print("dL_dx : ", tuple(dL_dx.shape))

Shapes
x : 	 (10, 3, 32, 32)
y : 	 (10, 3, 32, 32)
dL_dy :  (10, 3, 32, 32)
dL_dx :  (10, 3, 32, 32)


In [31]:
# Own Sigmoid: forward caches x, backward multiplies dL_dy by sigma'(x).
mysig   = Sigmoid()
myy     = mysig(x)
mydL_dx = mysig.backward(dL_dy)

print("Shapes")
print("myy :  \t ", tuple(myy.shape))
print("mydL_dx :" , tuple(mydL_dx.shape))

Shapes
myy :  	  (10, 3, 32, 32)
mydL_dx : (10, 3, 32, 32)


In [32]:
# Sigmoid input gradient: agreement to 7 decimals.
compare(dL_dx, mydL_dx, decimals=7)

True

## Test ConvTranspose2d + Sigmoid backward

In [33]:
# DEFINE FILTER
# Five fixed 3x3 kernels, one per output channel: anti-diagonal, all-ones box,
# Prewitt x, Prewitt y and a 4-neighbour Laplacian; each is copied onto all
# 3 input channels, giving a (5, 3, 3, 3) tensor.
f = torch.tensor([
    [[ 0.,  0., -1.], [ 0.,  1.,  0.], [-1.,  0.,  0.]],
    [[ 1.,  1.,  1.], [ 1.,  1.,  1.], [ 1.,  1.,  1.]],
    [[-1.,  0.,  1.], [-1.,  0.,  1.], [-1.,  0.,  1.]],
    [[-1., -1., -1.], [ 0.,  0.,  0.], [ 1.,  1.,  1.]],
    [[ 0., -1.,  0.], [-1.,  4., -1.], [ 0., -1.,  0.]],
]).unsqueeze(1).repeat(1, 3, 1, 1)

# (3, 5, 3, 3) view in the (in, out, kH, kW) layout of transposed-conv weights.
ff = f.transpose(0, 1)
ff.requires_grad_()

stride = 2

In [34]:
# DEFINE DERIVATIVES WITH AUTOGRAD
# Reference pipeline: x0 -> transposed conv (weight ff) -> sigmoid -> MSE.
select= 10
x0    = (valid_input[:select].float()).requires_grad_()
xtrue = (valid_target[:select].float())

y0 = F.conv_transpose2d(x0, ff, stride=stride)
y1 = torch.sigmoid(y0)

# Target: xtrue through the same pipeline; no graph is needed for it.
with torch.no_grad():
    ytrue = F.conv_transpose2d(xtrue, ff, stride=stride)
    ytrue = torch.sigmoid(ytrue)

L = F.mse_loss(y1,ytrue)

# Gradients w.r.t. every intermediate, to check each manual backward step.
dL_dy1, dL_dy0, dL_dx0, dL_df = torch.autograd.grad(L, (y1, y0, x0, ff))

print("Shapes")
print("x0 : \t", tuple(x0.shape))
print("y0=x1 : ", tuple(y0.shape))
print("y1 : \t", tuple(y1.shape))
print("f  : \t", tuple(ff.shape),"\n")
# Each label now prints its own tensor (dL_dy1 and dL_dy0 were swapped).
print("dL_dy1 :", tuple(dL_dy1.shape))
print("dL_dy0 :", tuple(dL_dy0.shape))
print("dL_dx0 :", tuple(dL_dx0.shape))
print("dL_df  :", tuple(dL_df.shape))

Shapes
x0 : 	 (10, 3, 32, 32)
y0=x1 :  (10, 5, 65, 65)
y1 : 	 (10, 5, 65, 65)
f  : 	 (3, 5, 3, 3) 

dL_dy1 : (10, 5, 65, 65)
dL_dy0 : (10, 5, 65, 65)
dL_dx0 : (10, 3, 32, 32)
dL_df  : (3, 5, 3, 3)


In [35]:
#DEFINE DERIVATIVES WO AUTOGRAD
# Same pipeline built from the own layers, chaining their manual backward passes.
mytconv = ConvTranspose2d(3 ,5, 3, stride=stride, padding=0, dilation=1)
mytconv.weight = ff

mysig = Sigmoid()

myy0 = mytconv(x0)
myy1 = mysig(myy0)  # forward call caches myy0 for mysig.backward

# Chain rule from the output back: sigmoid first (fed autograd's dL_dy1), then the conv.
mydL_dy0 = mysig.backward(dL_dy1)
mydL_dx0, mydL_df = mytconv.backward(x0, mydL_dy0)

print("Shapes")
print("mydL_dy0 :", tuple(mydL_dy0.shape))
print("mydL_dx0 :", tuple(mydL_dx0.shape))
print("mydL_df  :", tuple(mydL_df.shape))

Shapes
mydL_dy0 : (10, 5, 65, 65)
mydL_dx0 : (10, 3, 32, 32)
mydL_df  : (3, 5, 3, 3)


In [36]:
# Gradient w.r.t. the pipeline input: agreement to 4 decimals.
compare(dL_dx0, mydL_dx0, decimals=4)

True

In [37]:
# Gradient w.r.t. the intermediate y0 (after the sigmoid backward): agreement to 4 decimals.
compare(dL_dy0, mydL_dy0, decimals=4)

True

In [38]:
# Kernel gradient of the transposed conv: agreement to 4 decimals.
compare(dL_df, mydL_df, decimals=4)

True

In [9]:
import torch.nn