In [1]:
import random
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F

from matplotlib import pyplot as plt

In [2]:
EPS = 1e-6

In [3]:
class Color:
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    END = '\033[0m'

In [4]:
def zero_pad(X, pad):
    assert len(X.shape) == 4, "X must be with shape [batch, h, w, c]"
    
    X_pad = np.pad(X, (
        (0, 0), 
        (pad, pad), 
        (pad, pad), 
        (0, 0)
    ), mode="constant", constant_values=(0, 0))

    return X_pad


class Conv2D:
    def __init__(self, prev_channels, next_channels, kernel=3, stride=1, pad=0, dtype=float):
        self._cache = None
        self.stride = stride
        self.pad = pad
        self.dtype = dtype 
        
        self.kernel = np.random.randn(kernel, kernel, prev_channels, next_channels).astype(self.dtype)
        self.b = np.zeros(next_channels, dtype=self.dtype)
    
    def forward(self, A_prev):
        # (m, n_H_prev, n_W_prev, n_C_prev) -> (m, n_H, n_W, n_C) by (k, k, n_C_prev, n_C)

        m, n_H_prev, n_W_prev, n_C_prev = A_prev.shape
        k, k, n_C_prev, n_C = self.kernel.shape

        n_H = (n_H_prev + 2*self.pad - k) // self.stride + 1
        n_W = (n_W_prev + 2*self.pad - k) // self.stride + 1

        output = np.zeros((m, n_H, n_W, n_C), dtype=A_prev.dtype)
        A_prev_pad = zero_pad(A_prev, pad=self.pad)
        
        for i in range(m):
            a_prev_pad = A_prev_pad[i]

            for h in range(n_H):
                h_start = h * self.stride 
                h_end = h_start + k

                for w in range(n_W):
                    w_start = w * self.stride
                    w_end = w_start + k

                    for c in range(n_C):
                        a_slice = a_prev_pad[h_start:h_end, w_start:w_end, :]
                        current_kernel = self.kernel[:, :, :, c]
                        current_bias = self.b[c]
                       
                        output[i, h, w, c] = np.sum(a_slice * current_kernel) + current_bias
        
        self.cache = A_prev
        return output
    
    def backward(self, dZ):
        A_prev = self.cache
        
        m, n_H_prev, n_W_prev, n_C_prev = A_prev.shape
        k, k, n_C_prev, n_C = self.kernel.shape
        m, n_H, n_W, n_C = dZ.shape 
        
        dA_prev = np.zeros((m, n_H_prev, n_W_prev, n_C_prev), dtype=self.dtype)
        dW = np.zeros((k, k, n_C_prev, n_C), dtype=self.dtype)
        db = np.zeros(n_C, dtype=self.dtype)
        
        A_prev_pad = zero_pad(A_prev, self.pad)
        dA_prev_pad = zero_pad(dA_prev, self.pad)
        
        for i in range(m):
            a_prev_pad = A_prev_pad[i]
            da_prev_pad = dA_prev_pad[i]
            
            for h in range(n_H):
                h_start = h * self.stride 
                h_end = h_start + k
                
                for w in range(n_W):
                    w_start = w * self.stride
                    w_end = w_start + k
                    
                    for c in range(n_C):
                        a_slice = a_prev_pad[h_start:h_end, w_start:w_end, :]
                        
                        da_prev_pad[h_start:h_end, w_start:w_end, :] += self.kernel[:, :, :, c] * dZ[i, h, w, c]
                        dW[:, :, :, c] += a_slice * dZ[i, h, w, c]
                        db[c] += dZ[i, h, w, c]
            if self.pad == 0:
                dA_prev[i, :, :, :] = da_prev_pad[:, :, :]
            else:
                dA_prev[i, :, :, :] = da_prev_pad[self.pad:-self.pad, self.pad:-self.pad, :]
            
        assert dA_prev.shape == (m, n_H_prev, n_W_prev, n_C_prev)
            
        return dA_prev, dW, db

In [5]:
def upsample(X, stride=2, pad=0):
    assert len(X.shape) == 4, "X should be with shape [batch, h, w, c]"
    
    m, n_H_prev, n_W_prev, c = X.shape
    
    n_H = stride * (n_H_prev - 1) + 1
    n_W = stride * (n_W_prev - 1) + 1
    
    upsampled = np.zeros((m, n_H, n_W, c), dtype=X.dtype)
    
    for h in range(n_H_prev):
        for w in range(n_W_prev):
            upsampled[:, h*stride, w*stride, :] = X[:, h, w, :]
    
    upsampled_pad = zero_pad(upsampled, pad=pad)
    
    return upsampled_pad


def downsample(X, stride, pad):
    if pad == 0:
        pad = None 

    res = X[:, pad:-pad:stride, pad:-pad:stride, :]
    return res


class Conv2DTranspose:
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0):
        self.cache = None 
        self.padding = padding
        self.stride = stride
        
        self.conv = Conv2D(in_channels, out_channels, kernel=kernel_size, stride=1, pad=0)
        
    def forward(self, A_prev):
        m, n_H_prev, n_W_prev, n_C_prev = A_prev.shape 
        k, k, n_C_prev, n_C = self.kernel.shape
        
        n_H = self.stride * (n_H_prev - 1) + k - 2 * self.padding
        n_W = self.stride * (n_W_prev - 1) + k - 2 * self.padding
        
        implicit_padding = k - self.padding - 1
        assert implicit_padding >= 0
        
        A_prev_upsampled = upsample(A_prev, stride=self.stride, pad=implicit_padding)
        out = self.conv.forward(A_prev_upsampled)
        
        assert out.shape == (m, n_H, n_W, n_C)
        
        self.cache = A_prev
        return out
    
    def backward(self, dZ):
        A_prev = self.cache 
        m, n_H_prev, n_W_prev, n_C_prev = A_prev.shape
        k, k, n_C_prev, n_C = self.kernel.shape
        m, n_H, n_W, n_C = dZ.shape
        
        dA_upsampled, dW, db = self.conv.backward(dZ)
        
        implicit_padding = k - self.padding - 1
        dA = downsample(dA_upsampled, self.stride, pad=implicit_padding)
        
        return dA, dW, db
    
    @property 
    def kernel(self):
        return self.conv.kernel
    
    @kernel.setter
    def kernel(self, weights):
        self.conv.kernel = weights
        
    @property 
    def b(self):
        return self.conv.b
    
    @kernel.setter
    def b(self, biases):
        self.conv.b = biases

# Tests

## Test zero_pad

In [6]:
m = 30
h = 40
w = 40
c = 3
pad = 1

X_batch = np.random.rand(m, h, w, c)
shape1 = X_batch.shape
X_pad = zero_pad(X_batch, pad=pad)
shape2 = X_pad.shape

print(shape1, shape2)
print(
    shape1[0] == shape2[0], 
    shape1[1]+pad*2 == shape2[1], 
    shape1[2]+2*pad == shape2[2], 
    shape1[3] == shape2[3]
)

(30, 40, 40, 3) (30, 42, 42, 3)
True True True True


In [7]:
# b = np.random.rand(1, 1, 1)
b = 1
b = np.array(b)
b.shape == (1, 1, 1) or b.shape == ()

True

## Test window_mmult

In [8]:
# input_slice = np.array([
#     [1, 0, 0],
#     [1, 0, 0],
#     [1, 0, 0]
# ], dtype=np.float32)

# kernel = np.array([
#     [-1, 0, 1],
#     [-1, 0, 1],
#     [-1, 0, 1]
# ], dtype=np.float32)

# # b = 1
# b = np.array([[[1]]], dtype=np.float32)

# kernel = kernel[..., np.newaxis]
# input_slice = input_slice[..., np.newaxis]

# print(kernel.shape)
# print(input_slice.shape)
# print(b.shape)

# conv2d = Conv2D(1, 1)

# z = conv2d._window_mult(input_slice, kernel, b)
# print(z)

## Conv2D forward 

In [9]:
m = 16
n_H_prev = 33
n_W_prev = 33
n_C_prev = 8
n_C = 16

kernel = 3
pad = 1
stride = 1


X_batch = np.random.rand(m, n_H_prev, n_W_prev, n_C_prev)
conv2d = Conv2D(n_C_prev, n_C, kernel=kernel, pad=pad, stride=stride)
out = conv2d.forward(X_batch)
print(out.shape)

del m 
del n_H_prev
del n_W_prev
del n_C_prev
del n_C
del kernel
del pad 
del stride

(16, 33, 33, 16)


## Comparing forward with torch

In [10]:
def compare_conv(
            m, 
            n_H_prev, 
            n_W_prev, 
            n_C_prev, 
            n_C, 
            kernel,
            pad, 
            stride
        ):
    
    input_data = np.random.rand(m, n_H_prev, n_W_prev, n_C_prev)
    
    # my realization 
    my_conv = Conv2D(n_C_prev, n_C, kernel=kernel, stride=stride, pad=pad)
    my_out = my_conv.forward(input_data)
    
    # torch_realizatoin 
    torch_input = torch.permute(torch.from_numpy(input_data), [0, 3, 1, 2]) # (m, n_C_prev, n_H_prev, n_W_prev)
    torch_weights = torch.permute(torch.from_numpy(my_conv.kernel), [3, 2, 0, 1]) # [k, k, channels, filters] -> [filters, channels, k, k]
    torch_biases = torch.from_numpy(my_conv.b)
    torch_out_ = F.conv2d(torch_input, torch_weights, bias=torch_biases, padding=pad, stride=stride)
    
    torch_out = torch.permute(torch_out_, [0, 2, 3, 1]).numpy() 
    
    mse = np.power(my_out - torch_out, 2).mean()
    
    return mse

res = compare_conv(
    m = 16,
    n_H_prev = 32,
    n_W_prev = 32,
    n_C_prev = 8,
    n_C = 16,

    kernel = 3,
    pad = 1,
    stride = 1,
)
print(res)

if res < EPS:
    print(Color.GREEN + Color.BOLD + "OK")
else:
    print(Color.RED + Color.BOLD + "Error")
    
del res

1.5816505494095304e-30
[92m[1mOK


## Conv2d backward

In [11]:
kernel = 3
pad = 1
stride = 1

m = 16
n_H_prev = 32
n_W_prev = 32
n_C_prev = 8
n_C = 16

n_H = (n_H_prev + 2*pad - kernel) // stride + 1
n_W = (n_W_prev + 2*pad - kernel) // stride + 1

input_data = np.random.rand(m, n_H_prev, n_W_prev, n_C_prev)
dZ = np.random.rand(m, n_H, n_W, n_C)

my_conv = Conv2D(n_C_prev, n_C, kernel=kernel, stride=stride, pad=pad)
conv_out = my_conv.forward(input_data)

dA, dW, db = my_conv.backward(dZ)
print(dA.shape)
print(dW.shape)
print(db.shape)


del kernel
del pad
del stride

del m
del n_H_prev
del n_W_prev
del n_C_prev
del n_C

(16, 32, 32, 8)
(3, 3, 8, 16)
(16,)


## Comparsion backward with torch

In [12]:
def compare_conv_back(
            kernel,
            pad,
            stride,
            m,
            n_H_prev,
            n_W_prev,
            n_C_prev,
            n_C,
        ):
    n_H = (n_H_prev + 2*pad - kernel) // stride + 1
    n_W = (n_W_prev + 2*pad - kernel) // stride + 1
    
    input_data = np.random.rand(m, n_H_prev, n_W_prev, n_C_prev)
    dZ = np.random.rand(m, n_H, n_W, n_C)
    
    weights = np.random.randn(kernel, kernel, n_C_prev, n_C)
    biases = np.zeros(n_C, dtype=float)
    
    # my realization 
    my_conv = Conv2D(n_C_prev, n_C, kernel=kernel, stride=stride, pad=pad)
    my_conv.kernel = weights.copy()
    my_conv.b = biases.copy()
    
    conv_out = my_conv.forward(input_data)

    dA, dW, db = my_conv.backward(dZ)
    
    # torch realization
    torch_input = torch.permute(torch.from_numpy(input_data), [0, 3, 1, 2])
    torch_dZ = torch.permute(torch.from_numpy(dZ), [0, 3, 1, 2])
    torch_weights = torch.permute(torch.from_numpy(weights.copy()), [3, 2, 0, 1])
    torch_biases = torch.from_numpy(biases.copy())
    
    torch_input.requires_grad = True
    torch_weights.requires_grad = True
    torch_biases.requires_grad = True
    torch_dZ.requires_grad = False
    
    torch_out = F.conv2d(torch_input, torch_weights, bias=torch_biases, padding=pad, stride=stride)
    # loss = torch.sum(torch_out * torch_dZ)
    
    # loss.backward()
    
    torch_out.backward(gradient=torch_dZ)
    
    dA_torch = torch_input.grad
    dW_torch = torch_weights.grad
    db_torch = torch_biases.grad
    
    dA_torch = torch.permute(dA_torch, [0, 2, 3, 1]).numpy() # [m, channels, h, w] -> [m, h, w, channels]
    dW_torch = torch.permute(dW_torch, [2, 3, 1, 0]).numpy() # [filters, channels, k, k] -> [k, k, channels, filters] 
    db_torch = db_torch.numpy()
    
    assert dA.shape == dA_torch.shape and dW.shape == dW_torch.shape and db.shape == db_torch.shape
    
    mse = lambda a, b: np.power(a - b, 2).mean()
    mse_dA = mse(dA, dA_torch)
    mse_dW = mse(dW, dW_torch)
    mse_db = mse(db, db_torch)
    
    return mse_dA, mse_dW, mse_db

In [13]:
np.random.seed(42)
random.seed(42)

mse_dA, mse_dW, mse_db = compare_conv_back(
    kernel = 3,
    pad = 1,
    stride = 1,

    m = 16,
    n_H_prev = 32,
    n_W_prev = 32,
    n_C_prev = 8,
    n_C = 16
)

if mse_dA < EPS and mse_dW < EPS and mse_db < EPS:
    print(Color.GREEN + Color.BOLD + "OK")
else:
    print(Color.RED + Color.BOLD + "ERROR")

[92m[1mOK


## Conv2D Transpose

In [14]:
arr = np.array([
    [2, 5],
    [4, 13]
])
kernel = np.array([
    [3, 1, 2],
    [2, 0, 0],
    [5, 4, 7]
])

batch = arr[np.newaxis, ..., np.newaxis]
weights = kernel[..., np.newaxis, np.newaxis]

convt = Conv2DTranspose(1, 1, kernel_size=3, stride=2, padding=1)
convt.kernel = weights

out = convt.forward(batch)
print(out.squeeze())


del convt
del out 
del weights 
del batch
del kernel 
del arr

 # [  0   4   0]
 # [ 18 127  57]
 # [  0   8   0]

[[  0   4   0]
 [ 18 127  57]
 [  0   8   0]]


## Conv2D Comparsion with Torch

In [15]:
def compare_convt(
            m,
            n_H_prev,
            n_W_prev,
            n_C_prev,
            n_C,
            kernel,
            pad,
            stride
        ):
    input_data = np.random.rand(m, n_H_prev, n_W_prev, n_C_prev)
    weights = np.random.rand(kernel, kernel, n_C_prev, n_C)
    biases = np.zeros(n_C)
    
    # my realizatoin 
    my_convt = Conv2DTranspose(n_C_prev, n_C, kernel_size=kernel, stride=stride, padding=pad)
    my_convt.kernel = weights
    my_convt.b = biases
    
    my_out = my_convt.forward(input_data)
    
    # torch realization 
    torch_input = torch.permute(torch.from_numpy(input_data), [0, 3, 1, 2]) # (m, n_C_prev, n_H_prev, n_W_prev)
    torch_weights = torch.permute(torch.from_numpy(weights), [2, 3, 0, 1]) # [k, k, channels, filters] -> [channels, filters, k, k]
    torch_biases = torch.from_numpy(biases)
    torch_out_ = F.conv_transpose2d(torch_input, torch_weights, bias=torch_biases, padding=pad, stride=stride)
    
    torch_out = torch.permute(torch_out_, [0, 2, 3, 1]).numpy() 
    
    mse = np.power(my_out - torch_out, 2).mean()
    return mse


np.random.seed(1)
random.seed(1)

res = compare_convt(
    m = 1,
    n_H_prev = 2,
    n_W_prev = 2,
    n_C_prev = 1,
    n_C = 1,
    kernel = 3,
    pad = 0,
    stride = 1
)
print(res)

0.05124965867449682


## Conv2D Transpose backward 

In [16]:
kernel = 3
pad = 1
stride = 1

m = 4
n_H_prev = 16
n_W_prev = 16
n_C_prev = 32
n_C = 8

n_H = stride * (n_H_prev - 1) + kernel - 2 * pad
n_W = stride * (n_W_prev - 1) + kernel - 2 * pad

input_data = np.random.rand(m, n_H_prev, n_W_prev, n_C_prev)
dZ = np.random.rand(m, n_H, n_W, n_C)

convt = Conv2DTranspose(n_C_prev, n_C, kernel_size=kernel, stride=stride, padding=pad)
fout = convt.forward(input_data)
dA, dW, dz = convt.backward(dZ)

print(dA.shape)
print(dW.shape)
print(db.shape)

(4, 16, 16, 32)
(3, 3, 32, 8)
(16,)


## Comparsion Conv2D transpose with torch

In [17]:
def compare_convt_back(
            kernel,
            pad,
            stride,
            m,
            n_H_prev,
            n_W_prev,
            n_C_prev,
            n_C,
        ):
    n_H = stride * (n_H_prev - 1) + kernel - 2 * pad
    n_W = stride * (n_W_prev - 1) + kernel - 2 * pad
    
    input_data = np.random.rand(m, n_H_prev, n_W_prev, n_C_prev)
    dZ = np.random.randn(m, n_H, n_W, n_C)
    
    weights = np.random.randn(kernel, kernel, n_C_prev, n_C)
    biases = np.random.rand(n_C)
    
    # my realization 
    convt = Conv2DTranspose(n_C_prev, n_C, kernel_size=kernel, stride=stride, padding=pad)
    convt.kernel = weights.copy()
    convt.biases = biases.copy()
    
    my_out = convt.forward(input_data)
    dA, dW, db = convt.backward(dZ)
    
    # torch realization 
    torch_input = torch.permute(torch.from_numpy(input_data), [0, 3, 1, 2]) # (batch, h, w, in_channels) -> (batch, in_channels, h, w)
    torch_dZ = torch.permute(torch.from_numpy(dZ), [0, 3, 1, 2]) # (batch, h, w, out_channels) -> (batch, out_channels, h, w)
    torch_weights = torch.permute(torch.from_numpy(weights.copy()), [2, 3, 0, 1]) # (k, k, in_channels, out_channels) -> (in_channels, out_channels, k, k)
    torch_biases = torch.from_numpy(biases.copy())
    
    torch_input.requires_grad = True
    torch_weights.requires_grad = True
    torch_biases.requires_grad = True
    torch_dZ.requires_grad = False
    
    torch_out = F.conv_transpose2d(torch_input, torch_weights, bias=torch_biases, padding=pad, stride=stride)
    torch_out.backward(gradient=torch_dZ)
    
    dA_torch = torch_input.grad
    dW_torch = torch_weights.grad
    db_torch = torch_biases.grad
    
    dA_torch = torch.permute(dA_torch, [0, 2, 3, 1]).numpy() # (batch, in_channels, h, w) -> (batch, h, w, in_channels)
    dW_torch = torch.permute(dW_torch, [2, 3, 0, 1]).numpy() # (in_channels, out_channels, k, k) -> (k, k, in_channels, out_channels) 
    db_torch = db_torch.numpy()
    
    assert dA.shape == dA_torch.shape and dW.shape == dW_torch.shape and db.shape == db_torch.shape
    
    mse = lambda a, b: np.power(a - b, 2).mean()
    mse_dA = mse(dA, dA_torch)
    mse_dW = mse(dW, dW_torch)
    mse_db = mse(db, db_torch)
    
    return mse_dA, mse_dW, mse_db
    
    

mse_dA, mse_dW, mse_db = compare_convt_back(
    m = 1,
    n_H_prev = 16,
    n_W_prev = 16,
    n_C_prev = 7,
    n_C = 5,
    kernel = 3,
    pad = 1,
    stride = 2
)

print(mse_dA)
print(mse_dW)
print(mse_db)

90.80346068783682
37.201668270385746
4.13994203059987e-28


## Experiments