In [5]:
import torch as t
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import PIL
from PIL import Image
import json
from pathlib import Path
from typing import Union, Tuple, Callable, Optional
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import utils
import torch.nn as nn
from einops import rearrange, repeat

In [6]:
class BatchNorm2d(nn.Module):
    running_mean: t.Tensor         # shape: (num_features,)
    running_var: t.Tensor          # shape: (num_features,)
    num_batches_tracked: t.Tensor  # shape: ()

    def __init__(self, num_features: int, eps=1e-05, momentum=0.1):
        '''Like nn.BatchNorm2d with track_running_stats=True and affine=True.

        Name the learnable affine parameters `weight` and `bias` in that order.
        '''
        super().__init__()
        self.weight = t.nn.parameter.Parameter(t.ones(num_features))
        self.bias = t.nn.parameter.Parameter(t.zeros(num_features))
        self.eps=eps
        self.momentum=momentum
        self.num_features = num_features
        self.register_buffer("running_mean", t.zeros(num_features))
        self.register_buffer("running_var", t.ones(num_features))
        self.register_buffer("num_batches_tracked", t.tensor(0))

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Normalize each channel.

        x.shape = (batch channel height width)


        Compute the variance using `torch.var(x, unbiased=False)`
        Hint: you may also find it helpful to use the argument `keepdim`.

        x: shape (batch, channels, height, width)
        Return: shape (batch, channels, height, width)
        '''
        if self.train:
            mean = t.mean(x, dim=(0, 2, 3), keepdim=True) # reducing keeps channel dim. result is shape 1 everywhere except 'channels'
            var = t.var(x, dim=(0, 2, 3), unbiased=False, keepdim=True)
            self.running_mean = (1 - self.momentum) * mean.squeeze() + self.momentum * self.running_mean
            self.running_var = (1 - self.momentum) * var.squeeze() + self.momentum * self.running_var
            self.num_batches_tracked += 1
        else:
            mean = rearrange(self.running_mean, "channels -> 1 channels 1 1")
            var = rearrange(self.running_var, "channels -> 1 channels 1 1")

        weight = rearrange(self.weight, "channels -> 1 channels 1 1")
        bias = rearrange(self.bias, "channels -> 1 channels 1 1")

        return ((x - mean) / t.sqrt(var + self.eps)) * weight + bias

    def extra_repr(self) -> str:
        pass
arr = t.randn(1,2,3,3)
bn = BatchNorm2d(2)
tbn = t.nn.BatchNorm2d(2)

MAIN = True
    
if MAIN:
    utils.test_batchnorm2d_module(BatchNorm2d)
    utils.test_batchnorm2d_forward(BatchNorm2d)
    utils.test_batchnorm2d_running_mean(BatchNorm2d)


All tests in `test_batchnorm2d_module` passed!
All tests in `test_batchnorm2d_forward` passed!
All tests in `test_batchnorm2d_running_mean` passed!


In [7]:
class AveragePool(nn.Module):
    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, channels, height, width)
        Return: shape (batch, channels)
        '''
        return x.mean(dim=(2,3))

In [34]:
from torch.nn import Conv2d, ReLU, Sequential


class ResidualBlock(nn.Module):
    def __init__(self, in_feats: int, out_feats: int, first_stride=1):
        '''A single residual block with optional downsampling.

        For compatibility with the pretrained model, declare the left side branch first using a `Sequential`.

        If first_stride is > 1, this means the optional (conv + bn) should be present on the right branch. Declare it second using another `Sequential`.
        '''
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.first_stride = first_stride
        self.left = Sequential(
            Conv2d(self.in_feats, self.out_feats, kernel_size=3, stride=self.first_stride, bias=False),
            BatchNorm2d(self.out_feats),
            ReLU(),
            Conv2d(self.out_feats, self.out_feats, kernel_size=3, padding=1, bias=False),
            BatchNorm2d(self.out_feats)
        )
        self.right = nn.Identity() if self.first_stride == 1 else Sequential(
            Conv2d(self.in_feats, self.out_feats, kernel_size=1, bias=False),
            BatchNorm2d(self.out_feats)
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Compute the forward pass.

        x: shape (batch, in_feats, height, width)

        Return: shape (batch, out_feats, height / first_stride, width / first_stride)
        '''
        return t.relu(self.left(x) + self.right(x))



In [28]:
class BlockGroup(nn.Module):
    def __init__(self, n_blocks: int, in_feats: int, out_feats: int, first_stride=1):
        '''A single residual block with optional downsampling.

        For compatibility with the pretrained model, declare the left side branch first using a `Sequential`.

        If first_stride is > 1, this means the optional (conv + bn) should be present on the right branch. Declare it second using another `Sequential`.
        '''
        super().__init__()
        first = [ResidualBlock(in_feats, out_feats, first_stride)]
        rest = [ResidualBlock(out_feats, out_feats)] * (n_blocks - 1)
        self.model = nn.Sequential(*(first + rest))

        

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Compute the forward pass.

        x: shape (batch, in_feats, height, width)

        Return: shape (batch, out_feats, height / first_stride, width / first_stride)
        '''
        return self.model(x)

In [35]:

import torch.nn as nn

class ResNet34(nn.Module):
    def __init__(
        self,
        n_blocks_per_group=[3, 4, 6, 3],
        out_features_per_group=[64, 128, 256, 512],
        strides_per_group=[1, 2, 2, 2],
        n_classes=1000,
    ):
        super().__init__()
        self.n_blocks_per_group = n_blocks_per_group
        self.out_features_per_group = out_features_per_group
        self.strides_per_group = strides_per_group
        self.n_classes = n_classes

        first_in_features = 64
        in_features_per_group = [first_in_features] + out_features_per_group[:-1]

        blocks = [BlockGroup(n_blocks, in_feats, out_feats, first_stride) for n_blocks, in_feats, out_feats, first_stride in zip(
            n_blocks_per_group, in_features_per_group, out_features_per_group, strides_per_group
        )]

        self.model = Sequential(
            Conv2d(3, first_in_features, kernel_size=7, stride=2, padding=3, bias=False),
            BatchNorm2d(first_in_features),
            ReLU(),
            nn.MaxPool2d(3, 2),
            *blocks,
            AveragePool(),
            nn.Flatten(),
            nn.Linear(out_features_per_group[-1], 1000),
        )


    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, channels, height, width)

        Return: shape (batch, n_classes)
        '''
        return self.model(x)

pretrained = torchvision.models.resnet34(weights="DEFAULT")
myresnet = ResNet34()
# myresnet = copy_weights(myresnet, pretrained)

mydict = myresnet.state_dict()
pretraineddict = pretrained.state_dict()

import utils

utils.compare_my_resnet_to_pytorch(myresnet)

Unnamed: 0,their name,their shape,your name,your shape
0,conv1.weight,"(64, 3, 7, 7)",model.0.weight,"(64, 3, 7, 7)"
1,bn1.weight,"(64,)",model.1.weight,"(64,)"
2,bn1.bias,"(64,)",model.1.bias,"(64,)"
3,bn1.running_mean,"(64,)",model.1.running_mean,"(64,)"
4,bn1.running_var,"(64,)",model.1.running_var,"(64,)"
5,bn1.num_batches_tracked,(),model.1.num_batches_tracked,()
6,layer1.0.conv1.weight,"(64, 64, 3, 3)",model.4.model.0.left.0.weight,"(64, 64, 3, 3)"
7,layer1.0.bn1.weight,"(64,)",model.4.model.0.left.1.weight,"(64,)"
8,layer1.0.bn1.bias,"(64,)",model.4.model.0.left.1.bias,"(64,)"
9,layer1.0.bn1.running_mean,"(64,)",model.4.model.0.left.1.running_mean,"(64,)"


In [38]:
def copy_weights(myresnet: ResNet34, pretrained_resnet: torchvision.models.resnet.ResNet) -> ResNet34:
    '''Copy over the weights of `pretrained_resnet` to your resnet.'''

    mydict = myresnet.state_dict()
    pretraineddict = pretrained_resnet.state_dict()

    # Check the number of params/buffers is correct
    assert len(mydict) == len(pretraineddict), "Number of layers is wrong. Have you done the prev step correctly?"

    # Initialise an empty dictionary to store the correct key-value pairs
    state_dict_to_load = {}

    for (mykey, myvalue), (pretrainedkey, pretrainedvalue) in zip(mydict.items(), pretraineddict.items()):
        state_dict_to_load[mykey] = pretrainedvalue

    myresnet.load_state_dict(state_dict_to_load)

    return myresnet

pretrained = torchvision.models.resnet34(weights="DEFAULT")
myresnet = ResNet34()
myresnet = copy_weights(myresnet, pretrained)

In [12]:
print(myresnet)
print(pretrained)

ResNet34(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (1): BatchNorm2d()
    (2): ReLU()
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): BlockGroup(
      (model): Sequential(
        (0): ResidualBlock(
          (left): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
            (1): BatchNorm2d()
            (2): ReLU()
            (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (4): BatchNorm2d()
          )
          (right): Identity()
        )
        (1): ResidualBlock(
          (left): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
            (1): BatchNorm2d()
            (2): ReLU()
            (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (4): BatchNorm2d()
          )
          (right): Identity()
        )
        (2): ResidualBlock(
  

Unnamed: 0,their name,their shape,your name,your shape
0,conv1.weight,"(64, 3, 7, 7)",model.0.weight,"(64, 3, 7, 7)"
1,bn1.weight,"(64,)",model.0.bias,"(64,)"
2,bn1.bias,"(64,)",model.1.weight,"(64,)"
3,bn1.running_mean,"(64,)",model.1.bias,"(64,)"
4,bn1.running_var,"(64,)",model.1.running_mean,"(64,)"
5,bn1.num_batches_tracked,(),model.1.running_var,"(64,)"
6,layer1.0.conv1.weight,"(64, 64, 3, 3)",model.1.num_batches_tracked,()
7,layer1.0.bn1.weight,"(64,)",model.4.model.0.left.0.weight,"(64, 64, 3, 3)"
8,layer1.0.bn1.bias,"(64,)",model.4.model.0.left.0.bias,"(64,)"
9,layer1.0.bn1.running_mean,"(64,)",model.4.model.0.left.1.weight,"(64,)"
