In [1]:
from typing import List

import torch as t
from torch import nn
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from einops import reduce, rearrange

import PIL
from PIL import Image
from pathlib import Path
import json

import utils

In [2]:
class BatchNorm2d(nn.Module):
    running_mean: t.Tensor         # shape: (num_features,)
    running_var: t.Tensor          # shape: (num_features,)
    num_batches_tracked: t.Tensor  # shape: ()

    def __init__(self, num_features: int, eps=1e-05, momentum=0.1):
        '''Like nn.BatchNorm2d with track_running_stats=True and affine=True.

        Name the learnable affine parameters `weight` and `bias` in that order.

        num_features: number of channels C of the (batch, C, height, width) input.
        eps: constant added to the variance before the square root.
        momentum: weight given to the current batch statistics when the running
            statistics are updated.
        '''
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(t.ones(num_features))
        self.bias = nn.Parameter(t.zeros(num_features))

        # Buffers are saved in the state dict but are not trainable parameters.
        self.register_buffer("running_mean", t.zeros(num_features))
        self.register_buffer("running_var", t.ones(num_features))
        self.register_buffer("num_batches_tracked", t.tensor(0))

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Normalize each channel.

        In training mode the statistics of the current batch are used and the
        running statistics are updated; in eval mode the stored running
        statistics are used instead.

        x: shape (batch, channels, height, width)
        Return: shape (batch, channels, height, width)
        '''
        if self.training:
            self.num_batches_tracked += 1

            mean = t.mean(x, dim=(0, 2, 3), keepdim=True)
            var = t.var(x, dim=(0, 2, 3), unbiased=False, keepdim=True)

            # The running statistics are bookkeeping, not part of the loss.
            # Detach the batch statistics so the buffers never acquire a
            # grad_fn; otherwise every update would chain the autograd graphs
            # of all previous training steps together.
            # reshape(-1) (rather than squeeze()) keeps the shape (C,) even
            # when there is a single channel.
            batch_mean = mean.detach().reshape(-1)
            batch_var = var.detach().reshape(-1)

            # NOTE: the biased batch variance is used for the running estimate
            # as well; per the PyTorch docs nn.BatchNorm2d uses the unbiased
            # estimate for this update, so running_var can differ slightly.
            self.running_mean = (1 - self.momentum) * self.running_mean + \
                                self.momentum * batch_mean
            self.running_var = (1 - self.momentum) * self.running_var + \
                                self.momentum * batch_var
        else:
            mean = self.running_mean.reshape(1, -1, 1, 1)
            var = self.running_var.reshape(1, -1, 1, 1)

        # Broadcast the per-channel affine parameters over (batch, height, width).
        weight = self.weight.reshape(1, -1, 1, 1)
        bias = self.bias.reshape(1, -1, 1, 1)
        return ((x - mean) / t.sqrt(var + self.eps)) * weight + bias

    def extra_repr(self) -> str:
        '''Return the "key=value" summary shown inside the module's repr.'''
        return ", ".join(
            [f"{key}={getattr(self, key)}" for key in ["num_features", "eps", "momentum"]]
        )


# if __name__ == "__main__":
#     utils.test_batchnorm2d_module(BatchNorm2d)
#     utils.test_batchnorm2d_forward(BatchNorm2d)
#     utils.test_batchnorm2d_running_mean(BatchNorm2d)

In [3]:
class AveragePool(nn.Module):
    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Average every channel over its spatial dimensions.

        The spatial axes are kept as size-1 dimensions, so the result has the
        same layout as the output of nn.AdaptiveAvgPool2d((1, 1)).

        x: shape (batch, channels, height, width)
        Return: shape (batch, channels, 1, 1)
        Raises: ValueError if x is not 4-dimensional.
        '''
        if x.ndim != 4:
            raise ValueError(
                f"AveragePool expects a 4D (batch, channels, height, width) input, got {x.ndim}D"
            )
        return x.mean(dim=(2, 3), keepdim=True)

# Sanity check: AveragePool must agree with nn.AdaptiveAvgPool2d((1, 1))
# for a range of channel counts and spatial sizes.
if __name__ == "__main__":
    for c in range(1, 10):
        x = t.rand(4, c, c + 2, c + 2)
        t.testing.assert_close(
            actual=AveragePool()(x),
            expected=nn.AdaptiveAvgPool2d(output_size=(1, 1))(x),
        )

In [4]:
class ResidualBlock(nn.Module):
    def __init__(self, in_feats: int, out_feats: int, first_stride=1):
        '''Residual block: two 3x3 conv + batchnorm stages plus a shortcut.

        The main branch is stored as `left` and the shortcut as `right`, and
        they are registered in that order so the state-dict key order is
        left-then-right (needed for compatibility with the pretrained model).

        When first_stride > 1 the shortcut is a strided 1x1 conv + batchnorm,
        otherwise it is the identity. With first_stride == 1 the channel count
        cannot change, so in_feats must equal out_feats.
        '''
        super().__init__()

        assert first_stride != 1 or in_feats == out_feats, \
            "Invalid ResBlock: if first_stride==1, we require in_feats == out_feats"

        # Main branch: only the first conv applies the stride.
        main_branch = [
            nn.Conv2d(
                in_feats, out_feats,
                kernel_size=3, stride=first_stride, padding=1, bias=False
            ),
            nn.BatchNorm2d(out_feats),
            nn.ReLU(),
            nn.Conv2d(out_feats, out_feats, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_feats),
        ]
        self.left = nn.Sequential(*main_branch)

        # Shortcut branch: projection when downsampling, identity otherwise.
        if first_stride > 1:
            self.right = nn.Sequential(
                nn.Conv2d(in_feats, out_feats, kernel_size=1, stride=first_stride, bias=False),
                nn.BatchNorm2d(out_feats)
            )
        else:
            self.right = nn.Identity()

        self.relu = nn.ReLU()

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Apply both branches to x, add the results, then apply ReLU.

        x: shape (batch, in_feats, height, width)

        Return: shape (batch, out_feats, height / stride, width / stride)

        Without a downsampling shortcut the main-branch output is simply added
        to the unmodified input.
        '''
        main_out = self.left(x)
        shortcut_out = self.right(x)
        return self.relu(main_out + shortcut_out)

In [5]:
class BlockGroup(nn.Module):
    def __init__(self, n_blocks: int, in_feats: int, out_feats: int, first_stride=1):
        '''Chain of ResidualBlocks stored as a Sequential in `self.model`.

        The first block maps in_feats -> out_feats using first_stride; the
        remaining n_blocks - 1 blocks keep out_feats channels and use stride 1.
        '''
        super().__init__()

        # The leading block is always created; the rest follow in order.
        blocks = [ResidualBlock(in_feats, out_feats, first_stride)]
        for _ in range(n_blocks - 1):
            blocks.append(ResidualBlock(out_feats, out_feats, 1))

        self.model = nn.Sequential(*blocks)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Run x through every block in sequence.

        x: shape (batch, in_feats, height, width)

        Return: shape (batch, out_feats, height / first_stride, width / first_stride)
        '''
        return self.model(x)

In [6]:
class ResNet34(nn.Module):
    def __init__(
        self,
        n_blocks_per_group=(3, 4, 6, 3),
        out_features_per_group=(64, 128, 256, 512),
        first_strides_per_group=(1, 2, 2, 2),
        n_classes=1000,
    ):
        '''Build the network: stem, block groups, then pooling + linear head.

        n_blocks_per_group: number of ResidualBlocks in each BlockGroup.
        out_features_per_group: output channel count of each BlockGroup.
        first_strides_per_group: stride of the first block in each BlockGroup.
        n_classes: size of the final linear layer's output.

        The three per-group arguments may be any sequences (list, tuple, ...)
        and must all have the same length.
        '''
        super().__init__()

        # Normalise to lists so tuples (or other sequences) are accepted; the
        # list concatenation below would otherwise raise TypeError for them.
        n_blocks_per_group = list(n_blocks_per_group)
        out_features_per_group = list(out_features_per_group)
        first_strides_per_group = list(first_strides_per_group)

        assert (
            len(n_blocks_per_group) == len(out_features_per_group) == len(first_strides_per_group)
        ), "BlockGroup params need to properly defined."

        # Each group consumes the previous group's output channels; the first
        # group consumes the stem's output channels.
        in_feat = 64
        in_features_per_group = [in_feat] + out_features_per_group[:-1]
        zipped_params = zip(
            n_blocks_per_group,
            in_features_per_group,
            out_features_per_group,
            first_strides_per_group
        )

        # The layer order fixes the state-dict key order, so it must not change.
        self.model = nn.Sequential(
            nn.Conv2d(3, in_feat, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(in_feat),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            *[BlockGroup(*params) for params in zipped_params],
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(out_features_per_group[-1], n_classes)
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Compute class logits for a batch of images.

        x: shape (batch, channels, height, width)

        Return: shape (batch, n_classes)
        '''
        return self.model(x)

In [7]:
# Build our implementation and the torchvision reference model, both switched
# to eval mode. `my_resnet` still has its freshly initialised weights here.
my_resnet = ResNet34().eval()
pt_resnet = torchvision.models.resnet34(weights="DEFAULT").eval()
# Output captured on first run (the pretrained checkpoint was downloaded):
# Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to
# /home/jmsdao/.cache/torch/hub/checkpoints/resnet34-b627a593.pth

In [8]:
def copy_weights(
        myresnet: ResNet34,
        pretrained_resnet: torchvision.models.resnet.ResNet
    ) -> ResNet34:
    '''Load the tensors of `pretrained_resnet` into `myresnet` and return it.

    Entries of the two state dicts are paired purely by position: the i-th
    tensor of the pretrained model is stored under the i-th key of `myresnet`.
    Both models must therefore declare their parameters and buffers in the
    same order.
    '''
    own_state = myresnet.state_dict()
    donor_state = pretrained_resnet.state_dict()

    # Positional pairing only makes sense if both sides have the same count.
    assert len(own_state) == len(donor_state), \
        "Number of layers is wrong. Have you done the prev step correctly?"

    # Our key names, their tensors.
    state_dict_to_load = dict(zip(own_state.keys(), donor_state.values()))

    myresnet.load_state_dict(state_dict_to_load)
    return myresnet

# Load the pretrained tensors into our model; copy_weights returns the same
# model object it was given, so the name is simply rebound to it.
my_resnet = copy_weights(my_resnet, pt_resnet)

In [9]:
# Files to classify; results further down are reported in this order.
IMAGE_FILENAMES = [
    "chimpanzee.jpg",
    "golden_retriever.jpg",
    "platypus.jpg",
    "frogs.jpg",
    "fireworks.jpg",
    "astronaut.jpg",
    "iguana.jpg",
    "volcano.jpg",
    "goofy.jpg",
    "dragonfly.jpg",
]

# Folder holding the files above, relative to the working directory.
IMAGE_FOLDER = Path("./resnet_inputs")

# One PIL image per filename, in the same order as IMAGE_FILENAMES.
images = list(map(Image.open, map(IMAGE_FOLDER.joinpath, IMAGE_FILENAMES)))

In [10]:
def prepare_data(images: List[Image.Image]) -> t.Tensor:
    '''Turn PIL images into one normalised batch tensor.

    Each image is converted to a tensor, resized to 224x224 and normalised
    per channel; the results are stacked along a new leading batch dimension.

    Return: shape (batch=len(images), num_channels=3, height=224, width=224)
    '''
    # Per-channel statistics used for normalisation (presumably the ones the
    # pretrained weights were trained with -- confirm against the weights' docs).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    to_tensor = transforms.ToTensor()
    resize = transforms.Resize((224, 224))
    normalize = transforms.Normalize(mean=channel_mean, std=channel_std)

    # Same step order as before: tensor conversion, then resize, then normalise.
    batch = [normalize(resize(to_tensor(image))) for image in images]
    return t.stack(batch, dim=0)

# Batch of all sample images, ready to be fed to either network.
prepared_images = prepare_data(images)

In [11]:
# def predict(model, images):
#     logits = model(images)
#     return logits.argmax(dim=1)

# my_preds = predict(my_resnet, prepared_images)
# pt_preds = predict(pt_resnet, prepared_images)

# Pure inference: disable autograd so the forward passes do not build (and
# keep alive) a computation graph that is never used for a backward pass.
with t.no_grad():
    my_logits = my_resnet(prepared_images)
    pt_logits = pt_resnet(prepared_images)

In [12]:
# Read the label file with an explicit encoding so the result does not depend
# on the platform's default locale.
# The labels are kept in the file's own key order; presumably that order
# matches the class indices of the logits -- verify against the file.
with open("imagenet_labels.json", encoding="utf-8") as f:
    imagenet_labels = list(json.load(f).values())

In [15]:
# Both models now hold the same weights and received the same batch, so their
# logits must agree within assert_close's default tolerances.
t.testing.assert_close(my_logits, pt_logits)

In [13]:
# Show the k highest logits (with their labels) of both models for each image.
k = 3
for i, filename in enumerate(IMAGE_FILENAMES):
    print('\nImage filename:', filename)

    # Same report for each model: a header line, then one line per top-k entry.
    for header, logits in (
        (f'  my_resnet top {k}:', my_logits),
        (f'  pt_resnet top {k}:', pt_logits),
    ):
        print(header)
        top_values, top_indices = logits[i].topk(k)
        for value, index in zip(top_values, top_indices):
            print(f'    ({value:.2f}) {imagenet_labels[index]}')


Image filename: chimpanzee.jpg
  my_resnet top 3:
    (19.29) chimpanzee, chimp, Pan troglodytes
    (15.70) siamang, Hylobates syndactylus, Symphalangus syndactylus
    (12.54) gorilla, Gorilla gorilla
  pt_resnet top 3:
    (19.29) chimpanzee, chimp, Pan troglodytes
    (15.70) siamang, Hylobates syndactylus, Symphalangus syndactylus
    (12.54) gorilla, Gorilla gorilla

Image filename: golden_retriever.jpg
  my_resnet top 3:
    (12.12) golden retriever
    (8.31) Newfoundland, Newfoundland dog
    (8.30) Pekinese, Pekingese, Peke
  pt_resnet top 3:
    (12.12) golden retriever
    (8.31) Newfoundland, Newfoundland dog
    (8.30) Pekinese, Pekingese, Peke

Image filename: platypus.jpg
  my_resnet top 3:
    (18.12) platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus
    (11.81) electric ray, crampfish, numbfish, torpedo
    (11.17) stingray
  pt_resnet top 3:
    (18.12) platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynch