# Self-supervised learning amb Autoencoders

In [1]:
import os
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

## Definim variables globals

In [None]:
data_dir = "./dataset_bin/"
save_dir = "./reconstructions"
os.makedirs(save_dir, exist_ok=True)

batch_size = 32
num_epochs = 100
learning_rate = 1e-4
image_size = (128, 128)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Cream el dataset amb imatges rotades

Cream un dataset personalitzat que carrega imatges des d'una carpeta i les retorna juntament amb una versió rotada de la mateixa imatge. Per fer-ho utilitzem la classe `ImageFolder` de `torchvision.datasets` per carregar les imatges i apliquem una rotació aleatòria a cada imatge durant la recuperació de l'element del dataset.

In [None]:
class RotatedImageDataset(torch.utils.data.Dataset):
    def __init__(self, root, transform=None):
        self.dataset = datasets.ImageFolder(root=root, transform=None)
        self.transform = transform

    def __getitem__(self, index):
        img, _ = self.dataset[index]

        angle = random.choice((90, 180, 270))
        target = img.copy()

        target, inp = self.transform(img), self.transform(target)

        if angle != 0:
            inp = transforms.functional.rotate(target, angle)


        return inp, target

    def __len__(self):
        return len(self.dataset)

input_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])


### Dividim el dataset en train i test

El *dataset* complet es divideix en conjunts d'entrenament i de prova utilitzant `train_test_split` de `sklearn.model_selection`. Una vegada dividit, es creen subconjunts utilitzant `torch.utils.data.Subset` i es carreguen en *DataLoaders* per a l'entrenament.

In [None]:
whole_dataset = RotatedImageDataset(
    root=data_dir,
    transform=input_transform,
)


idx_datasets = np.arange(len(whole_dataset))
train_dss, test_dss = train_test_split(idx_datasets)

train_ds = torch.utils.data.Subset(whole_dataset, train_dss)
test_ds = torch.utils.data.Subset(whole_dataset, test_dss)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

## Definició de l'Autoencoder

In [None]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 512, 3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [2]:
model = Autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

## Entrenament de l'Autoencoder

In [None]:
for epoch in tqdm(range(num_epochs)):
    model.train()
    total_loss = 0
    for inputs, targets in tqdm(train_loader, leave=False):
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)

    if (epoch + 1) % 5 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
        save_image(outputs[:8], f"{save_dir}/recon_epoch_{epoch+1}.png")
        save_image(inputs[:8], f"{save_dir}/inputs_epoch_{epoch+1}.png")
        save_image(targets[:8], f"{save_dir}/targets_epoch_{epoch+1}.png")

print("Training complete! Reconstructed images saved in", save_dir)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [5/100], Loss: 0.0619


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [10/100], Loss: 0.0549


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [15/100], Loss: 0.0533


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [20/100], Loss: 0.0493


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [25/100], Loss: 0.0470


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [30/100], Loss: 0.0427


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [35/100], Loss: 0.0408


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [40/100], Loss: 0.0374


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [45/100], Loss: 0.0346


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [50/100], Loss: 0.0334


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [55/100], Loss: 0.0310


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [60/100], Loss: 0.0282


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [65/100], Loss: 0.0269


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [70/100], Loss: 0.0251


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [75/100], Loss: 0.0241


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [80/100], Loss: 0.0225


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [85/100], Loss: 0.0219


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [90/100], Loss: 0.0209


  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

## Entrenam el classificador utilitzant transfer learning
Utilitzam l'encoder de l'autoencoder entrenat com a extractor de característiques per a una tasca de classificació d'imatges. Congelam els pesos de l'encoder i afegim un nou model de classificació.

In [None]:
for param in model.parameters():
    param.requires_grad = False

In [5]:
model.decoder = nn.Sequential(
    nn.Flatten(),
    torch.nn.Linear(8*8*512, 512),
    nn.ReLU(),
    torch.nn.Linear(512, 128),
    nn.ReLU(),
    torch.nn.Linear(128, 4),
) 

model = model.to(device)

In [None]:
train_dataset_cls = datasets.ImageFolder(root=data_dir, transform=input_transform)
train_dataset_cls = torch.utils.data.Subset(train_dataset_cls, train_dss)
train_dataloader_cls = DataLoader(train_dataset_cls, batch_size=batch_size, shuffle=True)

### Entrenament del classificador

In [6]:
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
for epoch in tqdm(range(5), desc="Èpoques"):
    model.train()
    total_loss = 0
    for inputs, targets in tqdm(train_dataloader_cls, desc=f"Batches {epoch}", leave=False):
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [1/100], Loss: 1.1368


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [2/100], Loss: 0.5437


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [3/100], Loss: 0.2955


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [4/100], Loss: 0.1870


  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [5/100], Loss: 0.0918


In [7]:

test_dataset_cls = torch.utils.data.Subset(datasets.ImageFolder(root=data_dir, transform=input_transform), test_dss)
test_dataloader_cls = DataLoader(train_dataset_cls, batch_size=batch_size, shuffle=True)

total_loss = 0
total_acc = 0
model = model.eval()
for inputs, targets in tqdm(test_dataloader_cls):
    inputs, targets = inputs.to(device), targets.to(device)

    outputs = model(inputs)
    loss = criterion(outputs, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_loss += loss.item()
    total_acc += accuracy_score(torch.argmax(outputs, dim=1).cpu().detach().numpy(), targets.cpu().detach().numpy())
    

avg_loss = total_loss / len(train_loader)
print(avg_loss, (total_acc / len(train_loader)))

  0%|          | 0/9 [00:00<?, ?it/s]

0.30754276778962875 0.8848824786324786


In [8]:
torch.argmax(outputs, dim=1), targets

(tensor([3, 0, 3, 1, 2, 3, 3, 0, 0, 3, 1, 1, 3, 1, 0, 2, 3, 2, 0, 2, 3, 3, 3, 3,
         1, 0], device='cuda:0'),
 tensor([3, 0, 1, 1, 2, 0, 3, 0, 0, 3, 1, 1, 0, 1, 0, 2, 1, 2, 0, 2, 3, 3, 3, 1,
         1, 0], device='cuda:0'))