In [None]:
import torchvision
import torchvision.transforms as transforms
import torch
from torch import nn
import torchinfo
import matplotlib as plt



In [None]:
dataset = torchvision.datasets.Omniglot(
    root="./data", download=True, transform=torchvision.transforms.ToTensor()
)

image, label = dataset[0]
print(type(image))  # torch.Tensor
print(type(label))  # int
print(image[0].size())
print(label)

In [None]:
image_size = 28

train_set = torchvision.datasets.Omniglot(
    root="./data",
    background=True,
    transform=transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize([int(image_size), int(image_size)]),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)
test_set = torchvision.datasets.Omniglot(
    root="./data",
    background=False,
    transform=transforms.Compose(
        [
            # Omniglot images have 1 channel, but our model will expect 3-channel images
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize([int(image_size), int(image_size)]),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)

In [None]:
print(train_set.__class__)  # class of the training dataset object

In [None]:

class ResBlock(nn.Module):
    """Two-convolution residual block in the ResNet "basic block" layout.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN.
    Shortcut:  identity, or a 1x1 conv + BN projection whenever the block
               downsamples or changes the channel count.
    Output:    ReLU(main + shortcut).

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels of the output feature map.
        downsample: if True, the first conv and the projection shortcut use
            stride 2, so the spatial size becomes ceil(H / 2) x ceil(W / 2).
    """

    def __init__(self, in_channels, out_channels, downsample):
        super().__init__()
        stride = 2 if downsample else 1
        # Conv biases are redundant ahead of BatchNorm but are kept so the
        # module's state_dict keys stay unchanged.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        if downsample or in_channels != out_channels:
            # Projection shortcut: needed so `main + shortcut` has matching
            # shape when the resolution or the channel count changes.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Sequential()

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, input):
        shortcut = self.shortcut(input)
        out = torch.relu(self.bn1(self.conv1(input)))
        # No ReLU before the addition: the residual branch must be able to
        # contribute negative values; the only activation comes after the sum.
        out = self.bn2(self.conv2(out))
        return torch.relu(out + shortcut)

In [None]:
class MiniResNet(nn.Module):
    """Small ResNet-style classifier: stem -> three residual stages -> GAP -> linear.

    Args:
        in_channels: number of channels of the input images.
        resblock: block factory called as ``resblock(in_ch, out_ch, downsample=bool)``
            (for example ``ResBlock``).
        outputs: number of logits produced per sample. The default 1623 is
            presumably the total number of Omniglot character classes -- confirm
            it covers the label range of the split being trained on.

    Input is a batch of shape (N, in_channels, H, W); output has shape (N, outputs).
    """

    def __init__(self, in_channels, resblock, outputs=1623):
        super().__init__()
        # Stem. NOTE(review): padding=3 is more than a 3x3 kernel needs (it
        # mirrors ResNet's 7x7/padding-3 stem); left as-is to preserve feature
        # map sizes. For a 28x28 input: conv -> 16x16, max-pool -> 8x8.
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=3, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU()
        )

        self.layer1 = nn.Sequential(
            resblock(8, 8, downsample=False),
            resblock(8, 8, downsample=False)
        )

        self.layer2 = nn.Sequential(
            resblock(8, 16, downsample=True),
            resblock(16, 16, downsample=False)
        )

        self.layer3 = nn.Sequential(
            resblock(16, 32, downsample=True),
            resblock(32, 32, downsample=False)
        )

        self.gap = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(32, outputs)

    def forward(self, input):
        input = self.layer0(input)
        input = self.layer1(input)
        input = self.layer2(input)
        input = self.layer3(input)
        input = self.gap(input)
        # Flatten everything except the batch dimension: (N, 32, 1, 1) -> (N, 32).
        # Flattening from dim 0 would merge samples into one vector and break fc
        # for any batch larger than 1.
        input = torch.flatten(input, 1)
        input = self.fc(input)

        return input

In [None]:


# Use the first GPU when available. The choice is kept in `device` so later
# cells can move input batches onto the same device as the model.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

mini_resnet = MiniResNet(1, ResBlock, outputs=1623)
# nn.Module.to moves the parameters in place and returns the module itself,
# so this last expression also displays the model as the cell output.
mini_resnet.to(device)


In [None]:
batch_size = 16  # mini-batch size; not used in this cell (presumably by later training cells)

# No input_size is passed, so torchinfo can only report layers and parameter counts.
model_summary = torchinfo.summary(mini_resnet)
print(model_summary)