In [1]:
import torch
from torch import nn
import sys
sys.path.insert(0, "..")
from helper import train, test
from data import get_dataloader
import torch.nn.functional as F
from time import time

In [2]:
# Pick the compute backend: CUDA if present, then Apple-silicon MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# ResNet
![ResNet](https://www.researchgate.net/publication/349646156/figure/fig4/AS:995806349897731@1614430143429/The-architecture-of-ResNet-50-vd-a-Stem-block-b-Stage1-Block1-c-Stage1-Block2.png)


![ResNet architectures](ResNet.png)
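*This is a scalable implementation: the block type (`bottleneck`) and the number of blocks per stage (`layers`) are parameters, so the same class can express ResNet-18 (basic blocks, `[2, 2, 2, 2]`), ResNet-50 (bottleneck blocks, `[3, 4, 6, 3]`) and ResNet-152 (bottleneck blocks, `[3, 8, 36, 3]`).*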

In [3]:
class SmallBlock(nn.Module):
    """Basic two-conv residual block (ResNet-18/34 style).

    Main branch: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN.

    When ``in_channels != out_channels`` the block downsamples:
    - the first 3x3 conv uses stride 2;
    - the shortcut is a 1x1 stride-2 projection (``conv1x1``).
    Both branches then produce ``out_channels`` maps with the spatial size
    halved (rounded up) before they are added. Otherwise the shortcut is the
    identity and the spatial size is preserved.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super(SmallBlock, self).__init__()
        self.stride = 1
        if in_channels != out_channels:
            self.stride = 2
            # Projection shortcut: matches the main branch's channel count and stride.
            self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=self.stride)
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # The main branch must see the raw input: its first conv expects
        # `in_channels` and performs the striding itself.
        output = self.block(x)
        # Shortcut: identity, or the 1x1 projection when the block changes shape.
        shortcut = self.conv1x1(x) if self.stride == 2 else x
        return self.relu(output + shortcut)

class BottleneckBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus a projected shortcut.

    Args:
        first_channel: channels of the incoming feature map.
        in_channels: width of the inner (bottleneck) convolutions.
        out_channels: channels produced by the block.
        reduce: if True, the first 1x1 conv and the shortcut both use stride 2,
            halving the spatial size (rounded up).

    The shortcut is always a 1x1 convolution (``conv1x1``), even when the
    input and output shapes already match.
    """

    def __init__(self, first_channel, in_channels, out_channels, reduce=False) -> None:
        super(BottleneckBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        step = 2 if reduce else 1

        # Main branch; any downsampling happens in the first 1x1 conv.
        main_layers = [
            nn.Conv2d(first_channel, in_channels, kernel_size=1, stride=step),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
            nn.BatchNorm2d(out_channels),
        ]
        self.block = nn.Sequential(*main_layers)
        # Projection shortcut with the same channel change and stride as the main branch.
        self.conv1x1 = nn.Conv2d(first_channel, out_channels, kernel_size=1, stride=step)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        main = self.block(x)
        return self.relu(main + self.conv1x1(x))


class ResNet(nn.Module):
    """Scalable ResNet classifier for small (CIFAR-sized) RGB images.

    The network consists of:
    - a single 3x3 stride-1 stem conv;
    - a sequence of residual stages;
    - a 3x3/stride-1 average pool;
    - a three-layer MLP classifier.

    Args:
        bottleneck: if True, stages are built from ``BottleneckBlock``
            (1x1-3x3-1x1; output widths 256, 512, 1024, 2048 for four
            non-empty stages). Otherwise they are built from ``SmallBlock``
            (two 3x3 convs; widths 64, 128, 256, 512).
        layers: number of blocks in each stage. The first stage keeps the
            spatial size; the first block of every later non-empty stage halves it.
        num_classes: size of the classifier output (default 10).
        input_size: side length of the square input images (default 32). It is
            used only to size the first ``nn.Linear`` of the classifier.

    With the defaults (four stages, 32x32 input) the feature map entering the
    classifier is 2x2. That gives 512*2*2 = 2048 (basic) or 2048*2*2 = 8192
    (bottleneck) features.

    Raises:
        ValueError: if ``input_size`` is too small to survive the downsampling
            stages plus the 3x3 average pool.
    """

    def __init__(self, bottleneck=True, layers=(4, 4, 4, 4), num_classes=10, input_size=32):

        super(ResNet, self).__init__()

        # Stem: one 3x3 stride-1 conv, so the stages start at full input resolution.
        resnet = [nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)]
        num_downsamples = 0

        if not bottleneck:
            in_channel = 64
            out_channel = 64
            for layer in layers:
                for _ in range(layer):
                    # SmallBlock uses stride 2 exactly when the width changes.
                    if in_channel != out_channel:
                        num_downsamples += 1
                    resnet.append(SmallBlock(in_channel, out_channel))
                    # The next block consumes this block's output width.
                    in_channel = out_channel
                out_channel = in_channel * 2
            final_channels = in_channel
            hidden = (2048, 1024)
        else:
            # layers greater than 34
            reduce = False
            in_channel = 64
            first_channel = 64
            out_channel = 256
            for layer in layers:
                for _ in range(layer):
                    if reduce:
                        num_downsamples += 1
                    resnet.append(BottleneckBlock(first_channel, in_channel, out_channel, reduce=reduce))
                    first_channel = out_channel
                    reduce = False
                # The first block of the next stage halves the resolution.
                reduce = True
                in_channel = in_channel * 2
                out_channel = out_channel * 2
            final_channels = first_channel
            hidden = (4096, 2048)

        resnet.append(nn.AvgPool2d(kernel_size=3, stride=1))
        self.resnet = nn.Sequential(*resnet)

        # Side length of the map entering the classifier. Each stride-2 block
        # maps s -> ceil(s / 2) (3x3/pad-1 and 1x1/pad-0 convs alike), and the
        # unpadded 3x3/stride-1 average pool then removes 2.
        side = input_size
        for _ in range(num_downsamples):
            side = (side + 1) // 2
        side -= 2
        if side < 1:
            raise ValueError(
                f"input_size={input_size} is too small for {num_downsamples} downsampling "
                "blocks followed by a 3x3 average pool"
            )
        in_features = final_channels * side * side

        self.classifier = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=hidden[0]),
            nn.ReLU(),
            nn.Linear(in_features=hidden[0], out_features=hidden[1]),
            nn.ReLU(),
            nn.Linear(in_features=hidden[1], out_features=num_classes)
        )


    def forward(self, x):

        x = self.resnet(x)
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        return logits
        


In [4]:
# Hyperparameters for this run.
SEED = 0
BATCH_SIZE = 128
LEARNING_RATE = 3e-4
epochs = 10

# Seed torch's global RNG so weight initialisation is reproducible. Whether this
# also fixes data shuffling depends on get_dataloader (not visible here) — TODO confirm.
torch.manual_seed(SEED)

# get the data
train_dl, test_dl = get_dataloader("cifar", batch_size=BATCH_SIZE)
# Bottleneck ResNet with two blocks per stage.
model = ResNet(bottleneck=True, layers=[2, 2, 2, 2]).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for t in range(epochs):
    start = time()
    print(f"Epoch {t+1}\n---")
    train(train_dl, model, loss_fn, optimizer, device)
    test(test_dl, model, loss_fn, device)
    print(f"Total time taken: {(time()-start):>0.1f} seconds")
    


Files already downloaded and verified
Files already downloaded and verified
ResNet(
  (resnet): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BottleneckBlock(
      (block): Sequential(
        (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU()
        (6): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv1x1): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (2): BottleneckBlock(
      (block): Sequential(
        (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
       

KeyboardInterrupt: 