# Import & CUDA

In [None]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split, Subset
from collections import Counter
import torch.nn as nn
import torch.nn.functional as F

In [None]:
torch.manual_seed(132)  # Fixed seed so that runs are repeatable

# This notebook is GPU-only: abort right away when CUDA cannot be used.
if not torch.cuda.is_available():
    raise PermissionError("CUDA is not usable.")

device = torch.device("cuda")
print(torch.cuda.get_device_name())

# Dataset

In [None]:
# Preprocessing applied to every image: resize to 256 x 256, convert to a
# tensor, then normalise each channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Full training folder; it is also the source for the SimCLR pre-training set.
train_ds_simclr = datasets.ImageFolder(root='./train', transform=transform)

print(f"Train dataset size for SimCLR : {len(train_ds_simclr)}")

In [None]:
# Count the images per class index and report them by class name.
class_counts = Counter(train_ds_simclr.targets)
for class_idx in class_counts:
    count = class_counts[class_idx]
    print(f"Classe '{train_ds_simclr.classes[class_idx]}' : {count} images")

In [None]:
def get_limited_indices(dataset, max_samples_per_class=50):
    """Return the indices of the first samples of each class, capped per class.

    The dataset is scanned in index order and an index is kept while its class
    has fewer than ``max_samples_per_class`` kept samples. The scan stops as
    soon as every class listed in ``dataset.classes`` has reached the cap.

    Args:
        dataset: Object exposing ``classes`` (one entry per class). Labels are
            read from ``dataset.targets`` when that attribute exists and no
            ``target_transform`` is set; otherwise the dataset is iterated and
            each item must be an ``(image, label)`` pair.
        max_samples_per_class: Maximum number of indices kept for each class.

    Returns:
        list[int]: The selected indices, in increasing order.
    """
    # Reading the label list avoids decoding and transforming every image just
    # to look at its label. It is only equivalent to indexing the dataset when
    # no target_transform rewrites the labels, hence the second condition.
    targets = getattr(dataset, "targets", None)
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        labels = (int(target) for target in targets)
    else:
        labels = (label for _, label in dataset)

    class_counts = {cls_idx: 0 for cls_idx in range(len(dataset.classes))}
    limited_indices = []

    for idx, label in enumerate(labels):
        if class_counts[label] < max_samples_per_class:
            limited_indices.append(idx)
            class_counts[label] += 1

        # Stop scanning once every class is full.
        if all(count >= max_samples_per_class for count in class_counts.values()):
            break

    return limited_indices

# Labelled subset used for the supervised classifiers: at most 50 images per
# class taken from the full training folder.
limited_indices = get_limited_indices(train_ds_simclr, max_samples_per_class=50)
train_classifier_ds = Subset(train_ds_simclr, limited_indices)
train_classifier_loader = DataLoader(train_classifier_ds, batch_size=64, shuffle=True)

print(f"Size of the dataset to train the classifier : {len(train_classifier_ds)}")

In [None]:
# Validation and test folders share the preprocessing of the training set.
val_ds_classifier = datasets.ImageFolder(root='./val', transform=transform)
test_ds_classifier = datasets.ImageFolder(root='./test', transform=transform)

# Evaluation loaders keep the folder order (no shuffling).
val_classifier_loader = DataLoader(val_ds_classifier, batch_size=64, shuffle=False)
test_classifier_loader = DataLoader(test_ds_classifier, batch_size=64, shuffle=False)

# CNN

In [None]:
# Model Basic CNN

class BasicCNN(nn.Module):
    """Small two-stage CNN classifier for 3 x 256 x 256 images.

    Each convolutional stage is conv(3x3, padding 1) -> ReLU -> 2x2 max-pool ->
    dropout, so the spatial size is halved twice (256 -> 128 -> 64). The
    flattened 16 x 64 x 64 feature map then goes through a two-layer MLP head.
    ``fc1`` is sized for that map, so inputs must be 256 x 256.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(64 * 64 * 16, 128)
        self.fc2 = nn.Linear(128, num_classes)

        self.dp = nn.Dropout(p=0.3)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        # Output size of a conv: 1 + (H - K + 2 * P) / S, so 3x3/pad 1 keeps H and W.
        x = self.conv1(x)  # (3, 256, 256) -> (8, 256, 256)
        x = self.relu(x)
        x = self.pool(x)  # (8, 256, 256) -> (8, 128, 128)
        x = self.dp(x)

        x = self.conv2(x)  # (8, 128, 128) -> (16, 128, 128)
        x = self.relu(x)
        x = self.pool(x)  # (16, 128, 128) -> (16, 64, 64)
        x = self.dp(x)

        x = self.flatten(x)  # (16, 64, 64) -> (16*64*64), same as x.view(x.size(0), -1)
        x = self.fc1(x)  # (16*64*64) -> (128)
        # Without a non-linearity here fc1 and fc2 would collapse into a single
        # linear map.
        x = self.relu(x)
        x = self.dp(x)
        # No dropout after fc2: zeroing random logits corrupts the scores that
        # the cross-entropy loss is computed on.
        return self.fc2(x)  # (128) -> (num_classes)

In [None]:
# Model CNN1

class CNN1(nn.Module):
    """Three-stage CNN classifier for 3 x 256 x 256 images.

    Each convolutional stage is conv(3x3, padding 1) -> ReLU -> 2x2 max-pool ->
    dropout, so the spatial size is halved three times (256 -> 128 -> 64 -> 32).
    The flattened 64 x 32 x 32 feature map then goes through a two-layer MLP
    head. ``fc1`` is sized for that map, so inputs must be 256 x 256.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 32 * 32, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.dp = nn.Dropout(p=0.3)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        # Output size of a conv: 1 + (H - K + 2 * P) / S, so 3x3/pad 1 keeps H and W.
        x = self.conv1(x)  # (3, 256, 256) -> (16, 256, 256)
        x = self.relu(x)
        x = self.pool(x)  # (16, 256, 256) -> (16, 128, 128)
        x = self.dp(x)

        x = self.conv2(x)  # (16, 128, 128) -> (32, 128, 128)
        x = self.relu(x)
        x = self.pool(x)  # (32, 128, 128) -> (32, 64, 64)
        x = self.dp(x)

        x = self.conv3(x)  # (32, 64, 64) -> (64, 64, 64)
        x = self.relu(x)
        x = self.pool(x)  # (64, 64, 64) -> (64, 32, 32)
        x = self.dp(x)

        x = self.flatten(x)  # (64, 32, 32) -> (64*32*32), same as x.view(x.size(0), -1)
        x = self.fc1(x)  # (64*32*32) -> (128)
        # Without a non-linearity here fc1 and fc2 would collapse into a single
        # linear map.
        x = self.relu(x)
        x = self.dp(x)
        # No dropout after fc2: zeroing random logits corrupts the scores that
        # the cross-entropy loss is computed on.
        return self.fc2(x)  # (128) -> (num_classes)

In [None]:
# Hyper-parameters of the supervised CNN baseline.
lr = 1e-3
epochs = 20
num_classes = len(train_ds_simclr.classes)

# Build the network on the selected device and show its layers.
model = BasicCNN(num_classes)
model = model.to(device)
print(model)

# Cross-entropy on the logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
for epoch in range(epochs):
    # ---- Training pass: dropout active, weights updated ----
    model.train()
    train_loss, train_correct, train_total = 0.0, 0, 0

    for image, label in train_classifier_loader:
        image = image.to(device)
        label = label.to(device)

        # Forward pass
        pred = model(image)
        loss = criterion(pred, label)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the sample-weighted loss and the number of correct predictions
        train_loss += loss.item() * image.size(0)
        predicted = pred.argmax(dim=1)
        train_correct += (predicted == label).sum().item()
        train_total += label.size(0)

    train_accuracy = train_correct / train_total
    train_loss /= train_total

    # ---- Validation pass: no gradients, dropout disabled ----
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0

    with torch.no_grad():
        for image, label in val_classifier_loader:
            image = image.to(device)
            label = label.to(device)

            # Forward pass
            pred = model(image)
            loss = criterion(pred, label)

            # Same bookkeeping as in the training pass
            val_loss += loss.item() * image.size(0)
            predicted = pred.argmax(dim=1)
            val_correct += (predicted == label).sum().item()
            val_total += label.size(0)

    val_accuracy = val_correct / val_total
    val_loss /= val_total

    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")


In [None]:
# torch.save(model.state_dict(), "CNN.pth")
# print("Saved model parameters.")

In [None]:
# Final evaluation of the CNN on the held-out test folder.
model.eval()

test_loss, test_correct, test_total = 0.0, 0, 0

with torch.no_grad():
    for image, label in test_classifier_loader:
        image = image.to(device)
        label = label.to(device)

        # Forward pass
        pred = model(image)
        loss = criterion(pred, label)

        # Sample-weighted loss and number of correct predictions
        test_loss += loss.item() * image.size(0)
        predicted = pred.argmax(dim=1)
        test_correct += (predicted == label).sum().item()
        test_total += label.size(0)

test_accuracy = test_correct / test_total
test_loss /= test_total

print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")


# SimCLR

In [None]:
# Random augmentation pipeline for SimCLR: each call yields a different view
# (random horizontal flip, then a random crop resized to 128 x 128).
transform_simclr = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(size=128, scale=(0.2, 1.0)),
    transforms.ToTensor(),
])

# SimCLR dataset
class SimCLRDataset(torch.utils.data.Dataset):
    """Wrap a labelled dataset so that each item yields two augmented views.

    Args:
        dataset: Underlying dataset returning ``(image, label)`` pairs; the
            image may be a PIL image or a tensor.
        transform: Random augmentation applied twice, independently, to the
            (PIL) image.
        mean: Optional per-channel mean the underlying dataset used to
            normalise its tensors. Must be given together with ``std``.
        std: Optional per-channel std of that normalisation.

    Raises:
        ValueError: If only one of ``mean`` / ``std`` is provided.
    """

    def __init__(self, dataset, transform, mean=None, std=None):
        if (mean is None) != (std is None):
            raise ValueError("mean and std must be provided together.")

        self.dataset = dataset
        self.transform = transform
        # Stored as (C, 1, 1) so they broadcast over (C, H, W) images.
        self.mean = None if mean is None else torch.as_tensor(mean, dtype=torch.float32).view(-1, 1, 1)
        self.std = None if std is None else torch.as_tensor(std, dtype=torch.float32).view(-1, 1, 1)
        self.to_pil = transforms.ToPILImage()

    def __getitem__(self, idx):
        img, label = self.dataset[idx]

        if isinstance(img, torch.Tensor):
            # ToPILImage scales float tensors by 255 and casts to uint8, so it
            # needs values in [0, 1]. A normalised tensor (e.g. in [-1, 1])
            # would be corrupted, hence the normalisation is undone first.
            if self.mean is not None:
                img = img * self.std + self.mean
            if img.is_floating_point():
                img = img.clamp(0.0, 1.0)
            img = self.to_pil(img)

        img1 = self.transform(img)  # Augmented view 1
        img2 = self.transform(img)  # Augmented view 2
        # img1 and img2 form a positive pair; views of other images in the
        # batch act as negatives.

        return img1, img2, label

    def __len__(self):
        return len(self.dataset)


# train_ds_simclr yields tensors normalised with mean 0.5 / std 0.5 (see the
# `transform` pipeline), so the same statistics are passed to undo it.
simclr_train_ds = SimCLRDataset(
    train_ds_simclr,
    transform_simclr,
    mean=(0.5, 0.5, 0.5),
    std=(0.5, 0.5, 0.5),
)

In [None]:
class EncoderImproved(nn.Module):
    """Convolutional encoder producing a flat 256 * 4 * 4 feature vector.

    Three identical stages (conv 3x3 -> batch norm -> ReLU -> 2x2 max-pool ->
    dropout) widen the channels 3 -> 64 -> 128 -> 256 and halve the spatial
    size each time. An adaptive average pool then reduces the map to 4 x 4.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = self._make_stage(3, 64)
        self.layer2 = self._make_stage(64, 128)
        self.layer3 = self._make_stage(128, 256)

        # Despite the name, this pools to a 4 x 4 grid rather than to 1 x 1.
        self.global_avg_pool = nn.AdaptiveAvgPool2d((4, 4))

    @staticmethod
    def _make_stage(in_channels, out_channels):
        """Build one conv -> batch norm -> ReLU -> max-pool -> dropout stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=0.3),
        )

    def forward(self, x):
        """Encode a batch of images into vectors of size 256 * 4 * 4."""
        # For a 128 x 128 input the stages give (64, 64, 64), (128, 32, 32)
        # and (256, 16, 16).
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        pooled = self.global_avg_pool(x)  # (256, 4, 4)
        return pooled.view(pooled.size(0), -1)  # Flatten to (256 * 4 * 4)


In [None]:
class ProjectionHead(nn.Module):
    """Two-layer MLP mapping encoder features to the contrastive space.

    Args:
        input_dim: Size of the incoming feature vector.
        output_dim: Size of the projected embedding.
    """

    def __init__(self, input_dim=256*4*4, output_dim=64):
        super().__init__()
        hidden_dim = 128
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Project a batch of feature vectors to ``output_dim`` dimensions."""
        projected = self.mlp(x)
        return projected

In [None]:
class NTXentLoss(nn.Module):
    """Normalised temperature-scaled cross-entropy loss (NT-Xent) of SimCLR.

    The two batches of projections are concatenated into 2N embeddings. For
    each embedding the positive is the other view of the same image and the
    remaining 2N - 2 embeddings are negatives. The loss is averaged over all
    2N anchors, which makes it symmetric in its two arguments.

    Args:
        temperature: Divisor applied to the cosine similarities.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """Compute the loss for two aligned batches of projections.

        Args:
            z_i: Tensor of shape (N, D), projections of the first views.
            z_j: Tensor of shape (N, D), projections of the second views;
                row k must come from the same image as row k of ``z_i``.

        Returns:
            Scalar tensor holding the mean loss over the 2N anchors.
        """
        batch_size = z_i.size(0)

        # Cosine similarity between every pair of the 2N embeddings.
        z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
        similarity_matrix = torch.mm(z, z.T) / self.temperature

        # An embedding must not be its own candidate: -inf removes the diagonal
        # from the softmax denominator.
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
        similarity_matrix = similarity_matrix.masked_fill(self_mask, float("-inf"))

        # The positive of row k is row k + N (and of row k + N is row k).
        labels = (torch.arange(2 * batch_size, device=z.device) + batch_size) % (2 * batch_size)
        return F.cross_entropy(similarity_matrix, labels)


In [None]:
# SimCLR pre-training: the encoder and projection head are optimised jointly
# with the contrastive loss on pairs of augmented views.
encoder = EncoderImproved().to(device)
projection_head = ProjectionHead().to(device)
criterion = NTXentLoss(temperature=0.5)
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(projection_head.parameters()), lr=1e-3)

# Built once; with shuffle=True the loader still draws a new order each epoch.
simclr_loader = DataLoader(simclr_train_ds, batch_size=128, shuffle=True)

for epoch in range(5):
    encoder.train()
    projection_head.train()
    total_loss = 0
    for img1, img2, _ in simclr_loader:
        img1, img2 = img1.to(device), img2.to(device)

        # Encode both views, then project them into the contrastive space.
        h1, h2 = encoder(img1), encoder(img2)
        z1, z2 = projection_head(h1), projection_head(h2)
        loss = criterion(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Each loss.item() is already a batch mean, so the epoch average divides
    # by the number of batches, not by the number of samples.
    print(f"Epoch {epoch + 1}, Loss: {total_loss / max(len(simclr_loader), 1)}")

In [None]:
# Persist the pre-trained weights and the optimizer state, one file each.
for stateful, filename in (
    (encoder, "encoder.pth"),
    (projection_head, "projection_head.pth"),
    (optimizer, "optimizer.pth"),
):
    torch.save(stateful.state_dict(), filename)

In [None]:
class Classifier(nn.Module):
    """Classification head stacked on top of a feature encoder.

    Args:
        encoder: Module mapping a batch of inputs to flat feature vectors of
            size ``feature_dim``.
        num_classes: Number of output logits (defaults to 25).
        feature_dim: Size of the encoder output (defaults to 256 * 4 * 4, the
            output size of ``EncoderImproved``).
    """

    def __init__(self, encoder, num_classes=25, feature_dim=256 * 4 * 4):
        super().__init__()
        self.encoder = encoder
        self.mlp = nn.Sequential(
            nn.Linear(feature_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        x = self.encoder(x)
        x = self.mlp(x)
        return x

In [None]:
classifier = Classifier(encoder).to(device)
# Freeze the encoder that was already pre-trained: only the MLP head on top of
# it will receive gradient updates.
classifier.encoder.requires_grad_(False)

In [None]:
# Supervised objective; only the parameters of the MLP head are optimised.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=classifier.mlp.parameters(), lr=1e-3)

def _enter_train_mode(model):
    """Put `model` in training mode while keeping fully frozen sub-modules in eval mode."""
    model.train()
    for module in model.modules():
        params = list(module.parameters())
        # A sub-module whose parameters are all frozen is a fixed feature
        # extractor: its dropout must stay off and its batch-norm statistics
        # must not drift while the rest of the model trains.
        if params and not any(p.requires_grad for p in params):
            module.eval()


def train_classifier(model, train_loader, val_loader, criterion, optimizer, device, epochs=10):
    """Train `model` and report train/validation accuracy after every epoch.

    Args:
        model: Network mapping a batch of images to class logits.
        train_loader: Iterable of ``(images, labels)`` training batches;
            must support ``len()``.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer holding the parameters to update.
        device: Device the batches are moved to.
        epochs: Number of passes over ``train_loader``.
    """
    for epoch in range(epochs):
        # Re-enter training mode every epoch: the validation pass below
        # switches the model to eval mode, and it would otherwise stay there
        # for all later epochs.
        _enter_train_mode(model)

        total_loss = 0
        correct = 0
        total = 0

        # Training loop
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running loss and number of correct predictions
            total_loss += loss.item()
            predictions = torch.argmax(outputs, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

        # Training accuracy (0 when the loader yielded no samples)
        train_accuracy = correct / total * 100 if total else 0.0

        # Validation loop
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for val_images, val_labels in val_loader:
                val_images, val_labels = val_images.to(device), val_labels.to(device)

                val_outputs = model(val_images)
                val_predictions = torch.argmax(val_outputs, dim=1)
                val_correct += (val_predictions == val_labels).sum().item()
                val_total += val_labels.size(0)

        val_accuracy = val_correct / val_total * 100 if val_total else 0.0

        # The reported loss is the mean of the per-batch losses.
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / max(len(train_loader), 1):.4f}, "
              f"Train Accuracy: {train_accuracy:.2f}%, Validation Accuracy: {val_accuracy:.2f}%")

# Train the classifier on the capped labelled subset for 20 epochs.
train_classifier(
    classifier,
    train_classifier_loader,
    val_classifier_loader,
    criterion,
    optimizer,
    device,
    epochs=20,
)

In [None]:
# Final evaluation of the SimCLR-based classifier on the held-out test folder.
classifier.eval()

test_loss, test_correct, test_total = 0.0, 0, 0

with torch.no_grad():
    for image, label in test_classifier_loader:
        image = image.to(device)
        label = label.to(device)

        # Forward pass
        pred = classifier(image)
        loss = criterion(pred, label)

        # Sample-weighted loss and number of correct predictions
        test_loss += loss.item() * image.size(0)
        predicted = pred.argmax(dim=1)
        test_correct += (predicted == label).sum().item()
        test_total += label.size(0)

test_accuracy = test_correct / test_total
test_loss /= test_total

print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")