In [None]:
import os
from torchvision import transforms, datasets
import glob
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torch import nn
import torch.optim as optim

# Define the transformations for your dataset
data_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# The pre-trained weights of the ResNet-50 model are optimized for the original input size (224x224)
# The values used for normalization in the transforms.Normalize function are the mean and standard deviation of the RGB channels for the ImageNet dataset. Since ResNet-50 was pre-trained on ImageNet, using these values helps keep the input distribution consistent with the original training data, which can improve model performance.

# Create datasets
train_dir = 'dataset/train'
unlabeled_dir = 'dataset/unlabelled'

train_dataset = datasets.ImageFolder(train_dir, transform=data_transforms)

batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)


# TODO Check if the unlabeled_loader works correctly.
class UnlabeledDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = glob.glob(os.path.join(root_dir, '*.*'))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, img_path

unlabeled_dir = 'dataset/unlabelled'
unlabeled_dataset = UnlabeledDataset(unlabeled_dir, transform=data_transforms)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

In [4]:
# We split the dataset into a train_subset and a test_substet.
# TODO We should implement k-fold cross validation, for better precision, as the test_subset is very small here. However, it will take more time to compute.
from torch.utils.data import random_split

train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size

train_subset, val_subset = random_split(train_dataset, [train_size, val_size])

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=4)

In [5]:
from torchvision.models.resnet import ResNet50_Weights

model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

In [6]:
num_classes = 10
num_ftrs = model.fc.in_features

# We replace the last layer by a nn.Linear which is going to be trained
model.fc = nn.Linear(num_ftrs, num_classes)

In [7]:
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu') #Metal framework for Apple Silicon support
print(device)

model.to(device)

mps


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

In [9]:
def train_model(model, criterion, optimizer, train_loader, val_loader, device, num_epochs=25):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_corrects = 0

        # Training step
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects.float() / len(train_loader.dataset)

        # Validation step
        model.eval()
        val_running_loss = 0.0
        val_running_corrects = 0

        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            with torch.set_grad_enabled(False):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

            val_running_loss += loss.item() * inputs.size(0)
            val_running_corrects += torch.sum(preds == labels.data)

        val_epoch_loss = val_running_loss / len(val_loader.dataset)
        val_epoch_acc = val_running_corrects.float() / len(val_loader.dataset)

        print(f'Epoch {epoch}/{num_epochs - 1}, Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_acc:.4f}')

In [47]:
train_model(model, criterion, optimizer, train_loader, val_loader, device)

Epoch 0/24, Train Loss: 1.8565, Train Acc: 0.5000, Val Loss: 2.0437, Val Acc: 0.2667
Epoch 1/24, Train Loss: 1.4715, Train Acc: 0.7333, Val Loss: 1.7821, Val Acc: 0.5000
Epoch 2/24, Train Loss: 1.2036, Train Acc: 0.8583, Val Loss: 1.6177, Val Acc: 0.5333
Epoch 3/24, Train Loss: 0.9632, Train Acc: 0.9417, Val Loss: 1.6022, Val Acc: 0.4333
Epoch 4/24, Train Loss: 0.8231, Train Acc: 0.9167, Val Loss: 1.5902, Val Acc: 0.4333
Epoch 5/24, Train Loss: 0.6871, Train Acc: 0.9333, Val Loss: 1.4621, Val Acc: 0.5667
Epoch 6/24, Train Loss: 0.5704, Train Acc: 0.9750, Val Loss: 1.3820, Val Acc: 0.6000
Epoch 7/24, Train Loss: 0.5250, Train Acc: 0.9500, Val Loss: 1.4581, Val Acc: 0.5000
Epoch 8/24, Train Loss: 0.3912, Train Acc: 0.9833, Val Loss: 1.4240, Val Acc: 0.5667
Epoch 9/24, Train Loss: 0.3673, Train Acc: 1.0000, Val Loss: 1.3660, Val Acc: 0.6000
Epoch 10/24, Train Loss: 0.3216, Train Acc: 0.9833, Val Loss: 1.3585, Val Acc: 0.6333
Epoch 11/24, Train Loss: 0.2687, Train Acc: 1.0000, Val Loss: 1.