In [None]:
# Create the folders the notebook writes to (datasets, checkpoints, figures).
# os.makedirs with exist_ok=True does not depend on the shell in use and lets
# this cell be re-run without "File exists" errors.
import os

for output_dir in ("Data", "Models", "Media", "Media/SSL", "Models/SSL"):
    os.makedirs(output_dir, exist_ok=True)

# Imports

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision.datasets
from torchsummary import summary
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import f1_score
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset
import random
import itertools

# CNN Class

In [None]:
# Network class definition
class Net(nn.Module):
    """CNN classifier: seven convolutions, four max-pool stages and a two-layer dense head.

    Args:
        num_classes: number of logits produced by the last linear layer.
        size: side length of the (square) input images; it is only used to
            work out how many features reach the first dense layer.
    """

    def __init__(self, num_classes, size):
        super(Net, self).__init__()

        # hyper-parameters of the convolutional stack, one entry per layer;
        # a pooling kernel of 0 means the convolution is not followed by pooling
        kernels = [9, 7, 7, 5, 5, 3, 3]
        paddings = [1, 1, 1, 1, 1, 1, 1]
        poolingsStride = [2, 0, 2, 0, 2, 0, 2]
        poolingsKernels = [2, 0, 2, 0, 2, 0, 2]
        filters = [8, 16, 16, 32, 64, 64, 128]

        # Layers are registered as conv1..conv7 / pool1..pool4 while the
        # spatial size of the feature map is tracked alongside.
        in_channels = 3
        pool_count = 0
        for idx, out_channels in enumerate(filters):
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernels[idx], padding=paddings[idx])
            setattr(self, f"conv{idx + 1}", conv)
            # output size of a stride-1 convolution
            size = size - kernels[idx] + 2 * paddings[idx] + 1

            if poolingsKernels[idx] != 0:
                pool_count += 1
                pool = nn.MaxPool2d(kernel_size=poolingsKernels[idx], stride=poolingsStride[idx])
                setattr(self, f"pool{pool_count}", pool)
                # output size of the pooling window, truncated to an integer
                size = int((size - poolingsKernels[idx]) / poolingsStride[idx] + 1)

            in_channels = out_channels

        # number of values left after flattening the last feature map
        fc1_input_size = filters[-1] * size * size

        print("First layer size: ", fc1_input_size)

        # dense head
        self.fc1 = nn.Linear(fc1_input_size, 256)
        self.fc2 = nn.Linear(256, num_classes)

        self.dropout = nn.Dropout(0.15)

    def forward(self, x):
        """Return the class logits for a batch of images `x`."""
        act = F.leaky_relu

        # feature extraction: conv -> pool, then three (conv, conv, pool) stages
        x = self.pool1(act(self.conv1(x)))
        x = self.pool2(act(self.conv3(act(self.conv2(x)))))
        x = self.pool3(act(self.conv5(act(self.conv4(x)))))
        x = self.pool4(act(self.conv7(act(self.conv6(x)))))

        # one flat feature vector per sample
        x = x.view(x.size(0), -1)

        # classification head
        hidden = self.dropout(act(self.fc1(x)))
        return self.fc2(hidden)

In [None]:
# Pick the compute device: first CUDA GPU when available, otherwise the CPU
train_on_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if train_on_gpu else "cpu")
print("running on: ", device)

# MacOS specific code: only reports whether the MPS backend is usable,
# the `device` selected above is left unchanged
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    print("MPS device not found.")

# Image size (side length in pixels of the square network input)
size = 128

# mean and std values for the dataset calculated
mean = [0.6388, 0.5446, 0.4452]
std = [0.2252, 0.2437, 0.2661]

# network, moved to the selected device before it is summarised:
# torchsummary builds its probe input on CUDA whenever a GPU is available,
# so a model still sitting on the CPU would fail with a device mismatch
net = Net(num_classes=251, size=size).to(device)

# torch-summary to check the network parameters
summary(net, (3, size, size))

In [None]:
# make sure the network lives on the selected device
net.to(device)

# dataset root; for kaggle use the commented alternative instead
datasetPath = './Data/'
# datasetPath = '/kaggle/input/supervised/'

# per-epoch history of the plain CNN, filled in by the training loop
simpleCNNLossOvertime, simpleCNNAccuracyOvertime = [], []

# preprocessing pipeline: tensor conversion -> resize -> normalisation
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((size, size), antialias = True),
    torchvision.transforms.Normalize(mean=mean, std=std)
])

# one ImageFolder per split, created in the order train, test, validation
trainSet, testSet, valSet = (
    torchvision.datasets.ImageFolder(root= datasetPath + split_dir, transform=transforms)
    for split_dir in (
        'processedData/processed_train_set',
        'processedData/processed_test_set',
        'processedData/processed_val_set',
    )
)

# matching data loaders, all with the same batching options
trainLoader, testLoader, valLoader = (
    DataLoader(split, batch_size=64, shuffle=True, num_workers=4)
    for split in (trainSet, testSet, valSet)
)

# report how many samples each split contains
for split_name, loader in (("Training:", trainLoader), ("Validation:", valLoader), ("Test:", testLoader)):
    print(split_name, len(loader.dataset))

In [None]:
# Supervised training of the plain CNN: one pass over the training set per
# epoch, followed by an accuracy check on the validation set and a checkpoint.

# criterion for classification
criterion = nn.CrossEntropyLoss()

# optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# learning rate scheduler: multiplies the learning rate by 0.1 every 5 epochs
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

# number of epochs
epochs = 1

for epoch in range(epochs):
    running_loss = 0.0

    # tqdm to show a progress bar
    for inputs, labels in tqdm(trainLoader, total=len(trainLoader)):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # advance the schedule once per epoch; without this call the StepLR
    # created above never changes the learning rate
    scheduler.step()

    # mean batch loss of this epoch, rounded for logging and plotting
    epoch_loss = round(running_loss / len(trainLoader), 2)
    simpleCNNLossOvertime.append(epoch_loss)

    # put the network in evaluation mode
    net.eval()

    # check accuracy on val set
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(valLoader, total=len(valLoader)):
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            # index of the highest logit = predicted class
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    accuracy = round(accuracy, 2)
    simpleCNNAccuracyOvertime.append(accuracy)
    print(f"Epoch {epoch + 1}, loss: {epoch_loss}, accuracy: {accuracy}")

    # save model after each epoch
    torch.save(net.state_dict(), f"Models/{size}-model_{epoch}.pth")

    # put the network back in training mode
    net.train()

print('Finished Training')
print(simpleCNNLossOvertime)
print(simpleCNNAccuracyOvertime)

# plot loss and accuracy in separate graphs
plt.plot(simpleCNNLossOvertime)
plt.title('Simple CNN loss')
plt.grid()
plt.savefig('Media/loss.png')
plt.close()

plt.plot(simpleCNNAccuracyOvertime)
plt.title('Simple CNN accuracy')
plt.grid()
plt.savefig('Media/accuracy.png')
plt.close()

# final evaluation of the plain CNN on the test split
print("Starting testing")

# evaluation mode for the whole pass
net.eval()

# running counters plus the complete label / prediction lists for the F1 score
correct, total = 0, 0
y_true, y_pred = [], []

with torch.no_grad():
    for images, labels in tqdm(testLoader, total=len(testLoader)):
        images = images.to(device)
        labels = labels.to(device)
        outputs = net(images)
        # index of the highest logit = predicted class
        predicted = torch.max(outputs.data, 1)[1]
        y_true.extend(labels.tolist())
        y_pred.extend(predicted.tolist())
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Testings accuracy: {accuracy}")

# macro average: every class contributes equally to the score
f1 = f1_score(y_true, y_pred, average='macro')
print(f"F1 score: {f1}")

# Self-Supervised Learning Task

In [None]:
# dataset class which extends the Dataset class
class JigsawPuzzleDataset(Dataset):
    """Dataset wrapper that turns every image of `dataset` into a jigsaw puzzle.

    Each image is cut into a `grid_size` x `grid_size` grid of tiles, the
    tiles are re-arranged following a randomly chosen permutation, and the
    puzzle is returned together with a one-hot vector identifying that
    permutation.

    Args:
        dataset: indexable collection whose items are `(image, label)` pairs;
            `image` is indexed as a 3-D tensor (channels, height, width).
        grid_size: number of tiles per side (2 by default, i.e. 4! = 24
            possible permutations).
    """

    def __init__(self, dataset, grid_size=2):
        self.dataset = dataset
        self.grid_size = grid_size
        self.permutations = self._generate_permutations(grid_size)
        # permutation -> class index, so the label lookup is a dict access
        # instead of a linear search through `self.permutations`
        self._permutation_index = {perm: idx for idx, perm in enumerate(self.permutations)}

    def __len__(self):
        """Number of samples, identical to the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return `(puzzle, label)` for the image at `index`.

        `puzzle` has the shape of the source image; `label` is a float one-hot
        vector over `self.permutations`. The original class label is discarded.
        """
        image, _ = self.dataset[index]
        image_pieces, correct_order = self._divide_image(image)
        shuffled_pieces, shuffled_order = self._shuffle_pieces(image_pieces, correct_order)

        # Blank canvas of the source size; rows/columns not covered by a whole
        # tile (size not divisible by grid_size) stay zero.
        reconstructed_image = torch.zeros_like(image)

        piece_h, piece_w = self._piece_size(image)
        for i in range(self.grid_size):
            for j in range(self.grid_size):
                # top-left corner of grid cell (row i, column j)
                top, left = i * piece_h, j * piece_w
                reconstructed_image[:, top:top + piece_h, left:left + piece_w] = \
                    shuffled_pieces[i * self.grid_size + j]

        # One-hot encode the permutation index
        permutation_index = self._permutation_index[tuple(shuffled_order)]
        label = torch.zeros(len(self.permutations), dtype=torch.float)
        label[permutation_index] = 1.0

        return reconstructed_image, label

    def _piece_size(self, image):
        """Return `(tile_height, tile_width)`; axis 1 is height, axis 2 is width."""
        return image.shape[1] // self.grid_size, image.shape[2] // self.grid_size

    def _divide_image(self, image):
        """Cut `image` into tiles in row-major order.

        Returns the list of tiles and the identity ordering `[0, 1, ...]`.
        """
        pieces = []

        piece_h, piece_w = self._piece_size(image)
        for i in range(self.grid_size):
            for j in range(self.grid_size):
                piece = image[:, i * piece_h:(i + 1) * piece_h, j * piece_w:(j + 1) * piece_w]
                pieces.append(piece)

        correct_order = list(range(self.grid_size ** 2))
        return pieces, correct_order

    def _shuffle_pieces(self, pieces, correct_order=None):
        """Re-arrange `pieces` following a randomly drawn permutation.

        `correct_order` is accepted because `__getitem__` passes it, but it is
        not needed: position k of the result holds `pieces[order[k]]`.
        Returns `(shuffled_pieces, order)` with `order` taken from
        `self.permutations`.
        """
        shuffled_order = random.choice(self.permutations)
        shuffled_pieces = [pieces[i] for i in shuffled_order]
        return shuffled_pieces, shuffled_order

    def _generate_permutations(self, grid_size):
        """List every ordering of the `grid_size ** 2` tile indices as tuples."""
        indices = list(range(grid_size ** 2))
        permutations = list(itertools.permutations(indices))
        return permutations

In [None]:
# define new network class, that extends the previous one
class NetWithJigSawPrediction(Net):
    """`Net` extended with a second linear head that predicts a jigsaw permutation.

    Args:
        num_classes: number of logits of the classification head (see `Net`).
        size: side length of the square input images (see `Net`).
        num_permutations: number of logits of the jigsaw head. The default of
            24 is the number of permutations of a 2x2 puzzle (4!).
    """

    def __init__(self, num_classes, size, num_permutations=24):
        super(NetWithJigSawPrediction, self).__init__(num_classes, size)

        # Jigsaw head, fed with the flattened convolutional features. The
        # attribute keeps the name `rotation_fc` so the keys of previously
        # saved state_dicts do not change.
        self.rotation_fc = nn.Linear(self.fc1.in_features, num_permutations)

    def _extract_features(self, x):
        """Run the convolutional stack and return one flat feature vector per sample."""
        x = F.leaky_relu(self.conv1(x))
        x = self.pool1(x)

        x = F.leaky_relu(self.conv2(x))
        x = F.leaky_relu(self.conv3(x))
        x = self.pool2(x)

        x = F.leaky_relu(self.conv4(x))
        x = F.leaky_relu(self.conv5(x))
        x = self.pool3(x)

        x = F.leaky_relu(self.conv6(x))
        x = F.leaky_relu(self.conv7(x))
        x = self.pool4(x)

        # Flatten the tensor for the fully connected layers
        return x.view(x.size(0), -1)

    def forward(self, x, predictJigSaw=False):
        """Return jigsaw logits when `predictJigSaw` is true, class logits otherwise."""
        x = self._extract_features(x)

        if predictJigSaw:
            return self.rotation_fc(x)

        x = F.leaky_relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

In [None]:
# as before use the same mean and std values
mean = [0.6388, 0.5446, 0.4452]
std = [0.2252, 0.2437, 0.2661]

# preprocessing for the jigsaw task: resize first, then convert to a tensor
# and normalise
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 128), antialias= True),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=mean, std=std)
])

# load the dataset with the pipeline defined in this cell (the `transforms`
# object of the supervised section used to be passed here, which left
# `transform` unused)
trainSet = torchvision.datasets.ImageFolder(root= datasetPath + 'processedData/processed_train_set', transform=transform)
testSet = torchvision.datasets.ImageFolder(root= datasetPath + 'processedData/processed_test_set', transform=transform)
valSet = torchvision.datasets.ImageFolder(root= datasetPath + 'processedData/processed_val_set', transform=transform)

# wrap every split so that samples become (shuffled puzzle, permutation label)
jigsaw_dataset = JigsawPuzzleDataset(trainSet)
jigsaw_datasetTest = JigsawPuzzleDataset(testSet)
jigsaw_datasetVal = JigsawPuzzleDataset(valSet)

# batch size shared by the jigsaw loaders
batch_size = 64
jigsaw_loader = DataLoader(jigsaw_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
jigsaw_loaderTest = DataLoader(jigsaw_datasetTest, batch_size=batch_size, shuffle=True, num_workers=4)
jigsaw_loaderVal = DataLoader(jigsaw_datasetVal, batch_size=batch_size, shuffle=True, num_workers=4)

In [None]:
# Build the jigsaw network and place it on the training device
ssl_model = NetWithJigSawPrediction(num_classes=251, size=128)
ssl_model.to(device)

# loss, optimizer and step-wise learning-rate decay for the pretext task
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ssl_model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

num_epochs = 1

# per-epoch history of the pretext task
SSLLossOvertime, SSLAccuracyOvertime = [], []

ssl_model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for puzzles, targets in tqdm(jigsaw_loader, total=len(jigsaw_loader)):
        puzzles = puzzles.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        loss = criterion(ssl_model(puzzles, predictJigSaw=True), targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # one scheduler step per epoch
    scheduler.step()

    SSLLossOvertime.append(round(running_loss / len(jigsaw_loader), 2))

    # validation pass in evaluation mode
    ssl_model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for puzzles, targets in tqdm(jigsaw_loaderVal, total=len(jigsaw_loaderVal)):
            puzzles, targets = puzzles.to(device), targets.to(device)
            outputs = ssl_model(puzzles, predictJigSaw=True)
            predicted = torch.max(outputs.data, 1)[1]
            total += targets.size(0)
            # targets are one-hot: recover the permutation index before comparing
            target_idx = torch.argmax(targets, dim=1)
            correct += (predicted == target_idx).sum().item()

    accuracy = round(correct / total, 2)
    SSLAccuracyOvertime.append(accuracy)
    print(f"Epoch {epoch + 1}, loss: {round(running_loss / len(jigsaw_loader), 2)}, accuracy: {accuracy}")

    # back to training mode, then checkpoint the weights of this epoch
    ssl_model.train()
    torch.save(ssl_model.state_dict(), f"ssl_model_epoch_{epoch + 1}.pth")

print("Finished SSL Training")
print(SSLLossOvertime)
print(SSLAccuracyOvertime)

# one figure per curve, written to disk and closed right away
for series, title, path in (
    (SSLLossOvertime, 'SSL Loss', 'Media/SSLLoss.png'),
    (SSLAccuracyOvertime, 'SSL Accuracy', 'Media/SSLAccuracy.png'),
):
    plt.plot(series)
    plt.title(title)
    plt.grid()
    plt.savefig(path)
    plt.close()

# evaluate the jigsaw head on the test split
print("Starting testing")

# evaluation mode for the whole pass
ssl_model.eval()

correct, total = 0, 0
with torch.no_grad():
    for puzzles, targets in tqdm(jigsaw_loaderTest, total=len(jigsaw_loaderTest)):
        puzzles = puzzles.to(device)
        targets = targets.to(device)
        outputs = ssl_model(puzzles, predictJigSaw=True)
        # most likely permutation according to the network
        predicted = torch.max(outputs.data, 1)[1]
        total += targets.size(0)
        # one-hot target -> permutation index
        target_idx = torch.argmax(targets, dim=1)
        correct += (predicted == target_idx).sum().item()

accuracy = correct / total
print(f"Testings accuracy: {accuracy}")

print("train the CNN with the new weights")

# preprocessing for fine-tuning: resize -> tensor -> normalise
classification_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 128)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=mean, std=std)
])

# image folders for the three splits, created in the order train, test, validation
classification_dataset, classification_datasetTest, classification_datasetVal = (
    torchvision.datasets.ImageFolder(root=datasetPath + split_dir, transform=classification_transform)
    for split_dir in (
        "processedData/processed_train_set",
        "processedData/processed_test_set",
        "processedData/processed_val_set",
    )
)

# matching loaders, all shuffled and served by four worker processes
classification_loader, classification_loaderTest, classification_loaderVal = (
    DataLoader(split, batch_size=batch_size, shuffle=True, num_workers=4)
    for split in (classification_dataset, classification_datasetTest, classification_datasetVal)
)

# Fine-tuning for classification: the network trained on the jigsaw task gets
# a freshly initialised final layer with 251 outputs
ssl_model.fc2 = nn.Linear(ssl_model.fc2.in_features, 251)
ssl_model.to(device)

# loss, optimizer and step-wise learning-rate decay for the fine-tuning run
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ssl_model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

num_epochs = 10

# per-epoch history of the fine-tuned network
SSLCNNLossOvertime, SSLCNNAccuracyOvertime = [], []

ssl_model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for batch, targets in tqdm(classification_loader, total=len(classification_loader)):
        batch = batch.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        loss = criterion(ssl_model(batch, predictJigSaw=False), targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # one scheduler step per epoch
    scheduler.step()

    SSLCNNLossOvertime.append(round(running_loss / len(classification_loader), 2))

    # validation pass in evaluation mode
    ssl_model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch, targets in tqdm(classification_loaderVal, total=len(classification_loaderVal)):
            batch, targets = batch.to(device), targets.to(device)
            outputs = ssl_model(batch, predictJigSaw=False)
            predicted = torch.max(outputs.data, 1)[1]
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    accuracy = round(correct / total, 2)
    SSLCNNAccuracyOvertime.append(accuracy)
    print(f"Epoch {epoch + 1}, loss: {round(running_loss / len(classification_loader), 2)}, accuracy: {accuracy}")

    # back to training mode, then checkpoint the weights of this epoch
    ssl_model.train()
    torch.save(ssl_model.state_dict(), f"ssl_model_classification_epoch_{epoch + 1}.pth")

print("Finished Classification Training")
print(SSLCNNLossOvertime)
print(SSLCNNAccuracyOvertime)

# one figure per curve; only the loss figure is created with an explicit size
for series, title, path, explicit_size in (
    (SSLCNNLossOvertime, 'SSL CNN Loss', 'Media/classification_loss.png', True),
    (SSLCNNAccuracyOvertime, 'SSL CNN Accuracy', 'Media/classification_accuracy.png', False),
):
    if explicit_size:
        plt.figure(figsize=(10, 5))
    plt.plot(series)
    plt.title(title)
    plt.grid()
    plt.savefig(path)
    plt.close()

# evaluate the fine-tuned network on the test split
print("Starting testing")

# evaluation mode for the whole pass
ssl_model.eval()

# running counters plus the complete label / prediction lists for the F1 score
correct, total = 0, 0
y_true, y_pred = [], []

with torch.no_grad():
    for batch, targets in tqdm(classification_loaderTest, total=len(classification_loaderTest)):
        batch = batch.to(device)
        targets = targets.to(device)
        outputs = ssl_model(batch, predictJigSaw=False)
        # index of the highest logit = predicted class
        predicted = torch.max(outputs.data, 1)[1]
        total += targets.size(0)
        y_true.extend(targets.tolist())
        y_pred.extend(predicted.tolist())
        correct += (predicted == targets).sum().item()

accuracy = correct / total
print(f"Testings accuracy: {accuracy}")

# macro average: every class contributes equally to the score
f1 = f1_score(y_true, y_pred, average='macro')
print(f"F1 score: {f1}")

In [None]:
# Side-by-side curves of the fine-tuned SSL network and the plain CNN:
# one figure for the loss and one for the accuracy, each saved and closed.
for curves, title, path in (
    ((SSLCNNLossOvertime, simpleCNNLossOvertime), 'CNN Loss Compared', './Media/lossCompared.png'),
    ((SSLCNNAccuracyOvertime, simpleCNNAccuracyOvertime), 'CNN Accuracy Compared', './Media/accuracyCompared.png'),
):
    plt.figure(figsize=(10, 5))
    # SSL curve first, plain CNN second, matching the legend order below
    for curve in curves:
        plt.plot(curve)
    plt.title(title)
    plt.legend(['SSL CNN', 'CNN'])
    plt.grid()
    plt.savefig(path)
    plt.close()