Initial CNN architecture

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

In [12]:
# initial CNN model placeholder
class simple_cnn(nn.Module):
    """Small three-stage convolutional encoder.

    Each stage is a 3x3 convolution (stride 1, padding 1) followed by ReLU
    and a 2x2 max pool with stride 2. Channel widths go 1 -> 32 -> 64 -> 128.
    An adaptive average pool then reduces every feature map to 1x1, so
    ``forward`` returns one 128-dimensional vector per sample.
    """

    def __init__(self):
        # nn.Module must be initialised before sub-modules are assigned,
        # otherwise attribute registration fails.
        super().__init__()
        self.c1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.p1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.c2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.p2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.c3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.p3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.act = nn.ReLU()
        self.avg = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Encode a batch of single-channel images.

        Args:
            x: Tensor of shape ``(batch, 1, H, W)``. The three stride-2
                poolings each halve H and W, so both must be at least 8.

        Returns:
            Tensor of shape ``(batch, 128)``.
        """
        x = self.act(self.c1(x))
        x = self.p1(x)
        x = self.act(self.c2(x))
        x = self.p2(x)
        x = self.act(self.c3(x))
        x = self.p3(x)
        x = self.avg(x)
        # Collapse the trailing 1x1 spatial dims: (batch, 128, 1, 1) -> (batch, 128).
        return x.view(x.size(0), -1)
        



In [None]:
# Data augmentation placeholder 
class contrastive_transforms:
    """Callable augmentation pipeline that returns two views of one input.

    The pipeline is stored on ``self.transform``. Calling an instance runs
    that pipeline twice on the same input and returns both results, so the
    random augmentations are applied separately for each view.
    """

    def __init__(self):
        augmentations = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.RandomResizedCrop(32),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.ToTensor(),
            # Single-channel mean/std of 0.5 for now.
            # TODO: normalise with the custom dataset's own statistics so
            # it ends up with mean 0 and std 1.
            transforms.Normalize((0.5,), (0.5,)),
        ]
        self.transform = transforms.Compose(augmentations)

    def __call__(self, x):
        """Return ``(view_1, view_2)``: two passes of the pipeline over ``x``."""
        first_view = self.transform(x)
        second_view = self.transform(x)
        return first_view, second_view

In [13]:
# Initial SimCLR model placeholder
class SimCLR(nn.Module):
    """Encoder followed by a two-layer MLP projection head.

    Args:
        base_encoder: Module mapping an input batch to feature vectors of
            shape ``(batch, feature_dim)``.
        out_dim: Size of the projected embedding returned by ``forward``.
        feature_dim: Width of the encoder output, i.e. the projector's
            input size. Defaults to 128, the value previously hard-coded.
        hidden_dim: Width of the projector's hidden layer. Defaults to 128,
            the value previously hard-coded.
    """

    def __init__(self, base_encoder, out_dim=64, feature_dim=128, hidden_dim=128):
        super().__init__()
        self.encoder = base_encoder
        # Same layer order as before, so state_dict keys are unchanged.
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        """Return the projected embedding of ``x``, shape ``(batch, out_dim)``."""
        h = self.encoder(x)
        z = self.projector(h)
        return z