<a href="https://colab.research.google.com/github/ckkissane/resnet-34/blob/main/resnet_34.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import torch
from torch.nn import Module, Parameter
import math
import numpy as np

Implement building blocks

In [15]:
# implement Conv2d
def force_pair(v):
    return v if isinstance(v, tuple) else (v, v)

def conv2d(x, weights, stride=1, padding=0):
    sH, sW = force_pair(stride)
    pH, pW = force_pair(padding)
    B, iC, iH, iW = x.shape
    oC, _, kH, kW = weights.shape
    oH = (iH + 2*pH - kH) // sH + 1
    oW = (iW + 2*pW - kW) // sW + 1

    padded_x = torch.nn.functional.pad(x, [pW, pW, pH, pH])

    conv_size = (B, iC, oH, oW, kH, kW)
    bs, cs, hs, ws = padded_x.stride()
    conv_stride = (bs, cs, hs*sH, ws*sW, hs, ws)
    strided_x = torch.as_strided(padded_x, size=conv_size, stride=conv_stride)

    return torch.einsum('bcxyij,ocij->boxy', strided_x, weights)

class Conv2d(Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        kernel_size = force_pair(kernel_size)
        self.stride = force_pair(stride)
        self.padding = force_pair(padding)

        weight_size = (out_channels, in_channels, *kernel_size)
        fan_in = np.prod(weight_size[1:])
        self.weight = Parameter(torch.randn(weight_size) * math.sqrt(2 / fan_in))
    
    def forward(self, x):
        return conv2d(
            x,
            self.weight,
            stride=self.stride,
            padding=self.padding
        )

In [16]:
# implement BatchNorm2d
class BatchNorm2d(Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.weight = Parameter(torch.ones(num_features))
        self.bias = Parameter(torch.zeros(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches_tracked', torch.tensor(0))

    def forward(self, x):
        ids = (0, 2, 3)
        if self.training:
            mean = x.mean(ids)
            var = x.var(ids, unbiased=False)
            a = self.momentum
            self.running_mean.data = (1 - a) * self.running_mean.data + a * mean
            self.running_var.data = (1 - a) * self.running_var.data + a * var
            self.num_batches_tracked.data += 1
        else:
            mean = self.running_mean
            var = self.running_var

        rs = lambda u : u.reshape(1, -1, 1, 1)
        return rs(self.weight) * (x - rs(mean)) / torch.sqrt(rs(var) + self.eps) + rs(self.bias)

In [17]:
# implement ReLU
def relu(tensor):
    tensor[tensor < 0] = 0
    return tensor

class ReLU(Module):
    def forward(self, x):
        return relu(x)

In [18]:
# implement MaxPool2d
def maxpool2d(x, kernel_size, stride=None, padding=0):
    if stride is None:
        stride = kernel_size
    B, iC, iH, iW = x.shape
    kH, kW = force_pair(kernel_size)
    sH, sW = force_pair(stride)
    pH, pW = force_pair(padding)
    oH = (iH + 2*pH - kH) // sH + 1
    oW = (iW + 2*pW - kW) // sW + 1

    padded_x = torch.functional.F.pad(x, [pW, pW, pH, pH], value=-float('inf'))

    conv_size = (B, iC, oH, oW, kH, kW)
    bs, cs, hs, ws = padded_x.stride()
    conv_stride = (bs, cs, hs*sH, ws*sW, hs, ws)
    strided_x = torch.as_strided(padded_x, size=conv_size, stride=conv_stride)

    return strided_x.amax((-2, -1))

class MaxPool2d(Module):
    def __init__(self, kernel_size, stride=None, padding=1):
        super().__init__()
        if stride is None:
            stride = kernel_size
        self.kernel_size = force_pair(kernel_size)
        self.stride = force_pair(stride)
        self.padding = force_pair(padding)
    
    def forward(self, x):
        return maxpool2d(
            x,
            self.kernel_size,
            stride=self.stride,
            padding=self.padding
        )

In [19]:
# implement AdaptiveAvgPool2d
def avg_pool2d(x, kernel_size, stride, padding=0):
    B, iC, iH, iW = x.shape
    kH, kW = force_pair(kernel_size)
    sH, sW = force_pair(stride)
    pH, pW = force_pair(padding)
    oH = (iH + 2*pH - kH) // sH + 1
    oW = (iW + 2*pW - kW) // sW + 1

    padded_x = torch.functional.F.pad(x, [pW, pW, pH, pH])

    conv_size = (B, iC, oH, oW, kH, kW)
    bs, cs, hs, ws = padded_x.stride()
    conv_stride = (bs, cs, hs*sH, ws*sW, hs, ws)
    strided_x = torch.as_strided(padded_x, size=conv_size, stride=conv_stride)

    return strided_x.mean((-2, -1))

def adaptive_avg_pool2d(x, output_size):
    input_size = torch.tensor(x.size()[-2:])
    output_size = torch.tensor(output_size)
    stride = input_size // output_size
    kernel_size = input_size - (output_size - 1) * stride
    return avg_pool2d(
        x, 
        kernel_size=tuple(kernel_size), 
        stride=tuple(stride)
    )

class AdaptiveAvgPool2d(Module):
    def __init__(self, output_size):
        super().__init__()
        self.output_size = output_size
    
    def forward(self, x):
        return adaptive_avg_pool2d(x, self.output_size)

In [20]:
# implement Flatten
class Flatten(Module):
    def __init__(self, start_dim=1, end_dim=-1):
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, x):
        return x.flatten(self.start_dim, self.end_dim)

In [21]:
# implement Linear
class Linear(Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        weight_bound = 1 / np.sqrt(in_features)
        self.weight = Parameter(torch.FloatTensor(out_features, in_features).uniform_(-weight_bound, weight_bound))
        if bias:
            bias_bound = 1 / np.sqrt(in_features)
            self.bias = Parameter(torch.FloatTensor(out_features).uniform_(-bias_bound, bias_bound)) 
        else:
            self.bias = None

    def forward(self, x):
        x = torch.einsum('...j,kj->...k', x, self.weight)
        if self.bias is not None:
            x += self.bias
        return x

In [22]:
# implement Sequential
class Sequential(Module):
    def __init__(self, *args):
        super(Sequential, self).__init__()
        for idx, module in enumerate(args):
            self.add_module(str(idx), module)
    
    def __iter__(self):
        return iter(self._modules.values())
    
    def forward(self, x):
        for module in self:
            x = module(x)
        return x

Implement ResNet34

In [23]:
# implement ResidualBlock
class ResidualBlock(Module):
    def __init__(self, in_feats, out_feats, stride=1):
        super().__init__()
        self.net = Sequential(
            Conv2d(in_feats, out_feats, kernel_size=3, stride=stride, padding=1),
            BatchNorm2d(out_feats),
            ReLU(),
            Conv2d(out_feats, out_feats, kernel_size=3, padding=1),
            BatchNorm2d(out_feats),
        )
        self.downsample = Sequential(
            Conv2d(in_feats, out_feats, kernel_size=1, stride=stride),
            BatchNorm2d(out_feats)
        ) if stride != 1 else None

    def forward(self, x):
        y_out = self.net(x)
        x_out = x if self.downsample is None else self.downsample(x)
        out = relu(x_out + y_out)
        return out

In [24]:
# implement ResNet34 model
class ResNet34(Module):
    def __init__(self, n_outs=1000, n_blocks_per_n_feats=[3, 4, 6, 3]):
        super().__init__()
        in_feats0 = 64
        self.in_layers = Sequential(
            Conv2d(3, in_feats0, kernel_size=7, stride=2, padding=3),
            BatchNorm2d(in_feats0),
            ReLU(),
            MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        all_out_feats = [64, 128, 256, 512]
        all_in_feats = [in_feats0] + all_out_feats[:-1]
        strides = [1, 2, 2, 2]
        self.residual_layers = Sequential(
            *(
                Sequential(
                    ResidualBlock(in_feats, out_feats, stride),
                    *(ResidualBlock(out_feats, out_feats) for _ in range(num_blocks - 1))
                ) for in_feats, out_feats, stride, num_blocks in zip(all_in_feats, all_out_feats, strides, n_blocks_per_n_feats)
            )
        )

        self.out_layers = Sequential(
            AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            Linear(in_features=512, out_features=n_outs)
        )

    def forward(self, x):
        x = self.in_layers(x)
        x = self.residual_layers(x)
        x = self.out_layers(x)
        return x

Training

In [25]:
# train model on CIFAR10 training data
import torchvision
from tqdm import tqdm
from torch import optim

device = "cuda" if torch.cuda.is_available() else "cpu"

train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.PILToTensor(), 
    torchvision.transforms.ConvertImageDtype(torch.float)
])

cifar10_train = torchvision.datasets.CIFAR10(
    root='./data', 
    train=True,
    download=True, 
    transform=train_transforms
)
trainloader = torch.utils.data.DataLoader(cifar10_train, batch_size=128, shuffle=True)

model = ResNet34(n_outs=10).to(device).train()
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

epochs = 20
for epoch in range(epochs):
    for i, (x, y) in enumerate(tqdm(trainloader)):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_hat = model(x)
        loss = loss_fn(y_hat, y)
        loss.backward()
        optimizer.step()
        if i % 500 == 0:
            print(f"epoch {epoch}, loss is {loss}")

Files already downloaded and verified


  import sys
  
  1%|          | 2/391 [00:00<01:22,  4.69it/s]

epoch 0, loss is 2.7241697311401367


100%|██████████| 391/391 [01:00<00:00,  6.47it/s]
  1%|          | 2/391 [00:00<01:05,  5.98it/s]

epoch 1, loss is 1.1102850437164307


100%|██████████| 391/391 [01:00<00:00,  6.42it/s]
  1%|          | 2/391 [00:00<01:06,  5.87it/s]

epoch 2, loss is 0.7511013150215149


100%|██████████| 391/391 [01:01<00:00,  6.39it/s]
  1%|          | 2/391 [00:00<01:07,  5.74it/s]

epoch 3, loss is 0.5957945585250854


100%|██████████| 391/391 [01:01<00:00,  6.39it/s]
  1%|          | 2/391 [00:00<01:05,  5.90it/s]

epoch 4, loss is 0.5541190505027771


100%|██████████| 391/391 [01:01<00:00,  6.39it/s]
  1%|          | 2/391 [00:00<01:07,  5.75it/s]

epoch 5, loss is 0.504300594329834


100%|██████████| 391/391 [01:00<00:00,  6.44it/s]
  1%|          | 2/391 [00:00<01:08,  5.66it/s]

epoch 6, loss is 0.4518352150917053


100%|██████████| 391/391 [01:01<00:00,  6.40it/s]
  1%|          | 2/391 [00:00<01:05,  5.90it/s]

epoch 7, loss is 0.3907870948314667


100%|██████████| 391/391 [01:00<00:00,  6.45it/s]
  1%|          | 2/391 [00:00<01:06,  5.84it/s]

epoch 8, loss is 0.2827386260032654


100%|██████████| 391/391 [01:00<00:00,  6.46it/s]
  1%|          | 2/391 [00:00<01:07,  5.76it/s]

epoch 9, loss is 0.2609924376010895


100%|██████████| 391/391 [01:01<00:00,  6.41it/s]
  1%|          | 2/391 [00:00<01:08,  5.67it/s]

epoch 10, loss is 0.208006352186203


100%|██████████| 391/391 [01:00<00:00,  6.45it/s]
  1%|          | 2/391 [00:00<01:07,  5.78it/s]

epoch 11, loss is 0.11383470892906189


100%|██████████| 391/391 [01:01<00:00,  6.40it/s]
  1%|          | 2/391 [00:00<01:06,  5.89it/s]

epoch 12, loss is 0.06702566146850586


100%|██████████| 391/391 [01:00<00:00,  6.45it/s]
  1%|          | 2/391 [00:00<01:07,  5.78it/s]

epoch 13, loss is 0.09093046188354492


100%|██████████| 391/391 [01:00<00:00,  6.45it/s]
  1%|          | 2/391 [00:00<01:05,  5.96it/s]

epoch 14, loss is 0.09256793558597565


100%|██████████| 391/391 [01:00<00:00,  6.41it/s]
  1%|          | 2/391 [00:00<01:07,  5.77it/s]

epoch 15, loss is 0.11451936513185501


100%|██████████| 391/391 [01:00<00:00,  6.42it/s]
  1%|          | 2/391 [00:00<01:06,  5.87it/s]

epoch 16, loss is 0.047571007162332535


100%|██████████| 391/391 [01:01<00:00,  6.41it/s]
  1%|          | 2/391 [00:00<01:06,  5.82it/s]

epoch 17, loss is 0.028243839740753174


100%|██████████| 391/391 [01:00<00:00,  6.46it/s]
  1%|          | 2/391 [00:00<01:05,  5.91it/s]

epoch 18, loss is 0.03648705407977104


100%|██████████| 391/391 [01:00<00:00,  6.48it/s]
  1%|          | 2/391 [00:00<01:07,  5.74it/s]

epoch 19, loss is 0.045246873050928116


100%|██████████| 391/391 [01:01<00:00,  6.39it/s]


Inference

In [26]:
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.PILToTensor(),
    torchvision.transforms.ConvertImageDtype(torch.float)
])
                                                  
cifar_test = torchvision.datasets.CIFAR10(
    "./data",
    transform=test_transforms,
    download=True,
    train=False
)
testloader = torch.utils.data.DataLoader(cifar_test, batch_size=128, shuffle=False, num_workers=2)

correct = 0
total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy on 10,000 test images: ', 100*(correct/total), '%')

Files already downloaded and verified


  import sys
  


Accuracy on 10,000 test images:  74.77000000000001 %
