In [2]:
import torch
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

In [25]:
class ResNet(nn.Module):
    """ResNet image classifier built from a pluggable residual ``block``.

    Layout: 7x7 / stride-2 stem conv -> BN -> ReLU -> 3x3 / stride-2 max-pool,
    then four stages of ``block`` modules producing 64, 128, 256 and 512
    channels (stages 2-4 downsample with stride 2 in their first block),
    followed by global average pooling and a linear classifier.

    Args:
        block: module class called as ``block(in_channels, out_channels, stride=s)``.
            The classifier head takes 512 features, so a block must output exactly
            ``out_channels`` channels (no channel expansion).
        num_blocks: sequence of four ints, the number of blocks in each stage.
        in_channels: number of channels of the input images; inputs are assumed
            to be batched as (N, in_channels, H, W).
        num_classes: size of the output logit vector.
    """

    def __init__(self, block, num_blocks, in_channels=3, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes

        # Stem: halves the resolution twice (conv stride 2, then max-pool stride 2).
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, 64, num_blocks[0])
        self.layer2 = self._make_layer(block, 64, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 128, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 256, 512, num_blocks[3], stride=2)
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc_layer = nn.Linear(512, num_classes)

    def _make_layer(self, block, in_channels, out_channels, num_blocks, stride=1):
        """Build one stage of ``num_blocks`` residual blocks.

        Only the first block changes the channel count and applies ``stride``.
        Every following block maps ``out_channels -> out_channels`` at stride 1,
        so the stage downsamples exactly once.

        Returns:
            nn.Sequential containing the ``num_blocks`` blocks in order.
        """
        layers = []
        for _ in range(num_blocks):
            layers.append(block(in_channels, out_channels, stride=stride))
            # Subsequent blocks consume this block's output at the same resolution.
            in_channels = out_channels
            stride = 1
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for a batch ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avg(x)
        # (N, 512, 1, 1) -> (N, 512) for the linear head.
        x = torch.flatten(x, 1)
        x = self.fc_layer(x)

        return x

In [4]:
class ConvBlock(nn.Module):
    """Basic residual block: two 3x3 conv+BN layers plus a skip connection.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the block.
        stride: stride of the first 3x3 convolution and of the projection
            shortcut; ``stride=2`` halves the spatial size.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ConvBlock, self).__init__()

        # Downsample in the first conv (as in the original ResNet basic block),
        # so the second conv already runs on the smaller feature map.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU()

        # Identity shortcut only when input and output shapes already match;
        # otherwise a 1x1 projection so the residual sum is well-defined. Both a
        # stride change AND a channel change require the projection.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return self.relu(out)

In [17]:
def ResNet18(in_channels=3, num_classes=1000):
    """Build an 18-layer ResNet: stem conv + 4 stages x 2 ConvBlocks x 2 convs + fc.

    Args:
        in_channels: number of input image channels (default 3).
        num_classes: number of output logits (default 1000).

    Returns:
        A ``ResNet`` built from ``ConvBlock`` with ``[2, 2, 2, 2]`` blocks per stage.
    """
    return ResNet(ConvBlock, [2, 2, 2, 2], in_channels=in_channels, num_classes=num_classes)