# Fine-tuning de modeles de vision pour la classification.

Dans ce TD, on s'interesse a fine-tuner des modeles de fondation pour classifier des images.

In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

cuda


### 1. Preparation du dataset

Nous allons maintenant utiliser le jeu de donnees [Intel Image Classification](https://www.kaggle.com/datasets/puneet6060/intel-image-classification) comme exemple dans ce TD. N'hesitez pas a modifier le code ci-dessous pour essayer d'autres jeux de donnees !

On commence par definir les datasets d'entrainement et de validation et les dataloaders associés.



In [None]:
def get_dataloaders(
    url_data_dir: str = './data/chest_xray',
    batch_size: int = 16,
    num_workers: int = 0,
    add_augmentation: bool = False,
    num_images: int = -1,
) -> tuple[DataLoader, DataLoader]:
    transform = []
    if add_augmentation:
        transform.append(transforms.RandomHorizontalFlip(p=0.5))
        transform.append(transforms.RandomVerticalFlip(p=0.5))
        transform.append(transforms.RandomResizedCrop(224, scale=(0.7, 1.0), ratio=(0.75, 1.33)))

    # Taille d'image standard attendue par le model
    transform.append(transforms.Resize((224, 224)))
    transform.append(transforms.ToTensor())

    # Le processor attend des images entre 0 et 255 :
    transform.append(transforms.Lambda(lambda x: x * 255))
    transform.append(transforms.Lambda(lambda x: x.to(torch.uint8)))
    transform = transforms.Compose(transform)
    # Pas besoin de normaliser les données, c'est géré par le processor
    
    train_dataset = ImageFolder(root=url_data_dir + '/train', transform=transform)
    test_dataset = ImageFolder(root=url_data_dir + '/test', transform=transform)
    
    # if num_images > 0:
    #     train_indices = np.random.choice(len(train_dataset), num_images, replace=False)
    #     test_indices = np.random.choice(len(test_dataset), num_images // 4, replace=False)
    #     train_dataset = Subset(train_dataset, train_indices)
    #     test_dataset = Subset(test_dataset, test_indices)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

Le parametre `num_images` permet juste de limiter le nombre d'images utilisees pour l'entrainement et le test, si besoin (pour du debugging par exemple !). Notons que le dataset d'origine contient ~14000 images pour l'entrainement et ~3000 pour le test.

### 2. Classification head

On commence par definir une classification head basique : une unique couche lineaire.

In [4]:
class LinearClassificationHead(nn.Module):
    '''Classification head tres basique.'''
    def __init__(self, in_channels: int, num_classes: int = 2):
        super().__init__()
        self.classifier = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        return self.classifier(x)

### 3. Model class pour le fine-tuning.

A present, on va definir le modele que nous allons finetuner. Le modele consiste en l'enchainement d'un backbone pre-entraine, dont nous freezons les parametres, et d'une classification head dont les parametres sont entrainables.

Pour correctement freezer les parametres du backbone, on peut utiliser la fonction `requires_grad` de PyTorch (`model.eval()` ne fait pas exactement la meme chose) :
```
for param in self.backbone.parameters():
    param.requires_grad = False
```

In [5]:
class ModelLinearProbing(nn.Module):
    def __init__(
        self,
        model_name: str = 'microsoft/resnet-50',
        num_classes: int = 10,
        device: str = 'cuda',
    ):
        super().__init__()

        self.processor = AutoImageProcessor.from_pretrained(model_name, use_fast=False)
        self.backbone = AutoModel.from_pretrained(model_name).to(device)
        self.device = device

        # On ajoute la classification head (ici, une couche linéaire)
        if model_name == 'microsoft/resnet-50':
            self.head = LinearClassificationHead(self.backbone.config.hidden_sizes[-1], num_classes)
        else:
            self.head = LinearClassificationHead(self.backbone.config.hidden_size, num_classes)

        # On freeze le backbone
        for param in self.backbone.parameters():
            param.requires_grad = False

        # On initialise les poids de la classification head
        nn.init.xavier_uniform_(self.head.classifier.weight)
        nn.init.zeros_(self.head.classifier.bias)

        self.to(device)

    def forward(self, x) -> torch.Tensor:
        inputs = self.processor(images=x, return_tensors="pt").to(self.device)
        outputs = self.backbone(**inputs)
        cls_tokens = outputs[1]
        x = self.head(cls_tokens.squeeze())
        return x

On definit ensuite une fonction `train_model` qui effectue le fine-tuning du modele.

In [6]:
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    lr: float = 1E-4,
    num_epochs: int = 25,
) -> None:
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=1E-4)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=5, verbose=True)

    for epoch in range(num_epochs):

        # Training loop
        model.train()
        avg_loss_train = 0
        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}: Training'):
            images, labels = images.to(model.device), labels.to(model.device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            avg_loss_train += loss.item()
        print(f'Epoch {epoch+1}/{num_epochs}, Training loss: {avg_loss_train / len(train_loader):.4f}')

        scheduler.step(avg_loss_train)

        # Validation loop
        model.eval()
        avg_loss_val = 0
        true_labels, pred_labels = [], []
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs}: Validation'):
                images, labels = images.to(model.device), labels.to(model.device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                avg_loss_val += loss.item()
                _, preds = torch.max(outputs, 1)
                pred_labels.append(preds.cpu().numpy())
                true_labels.append(labels.cpu().numpy())

        pred_labels = np.concatenate(pred_labels)
        true_labels = np.concatenate(true_labels)
        accuracy = accuracy_score(true_labels, pred_labels)
        print(f'Epoch {epoch+1}/{num_epochs}, Validation loss: {avg_loss_val / len(val_loader):.4f}')
        print(f'Epoch {epoch+1}/{num_epochs}, Validation accuracy: {accuracy * 100:.2f}%')

**Exercice** : lancer l'entrainement du modele, et tester divers hyperparametres. Quelle performance arrivez-vous a obtenir ?

In [None]:
model = ModelLinearProbing()
train_loader, val_loader = get_dataloaders(num_workers=16)
lr=1E-4
num_epochs=25

train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    lr=lr,
    num_epochs=num_epochs,
)

Epoch 1/25: Training: 100%|██████████| 326/326 [00:57<00:00,  5.71it/s]


Epoch 1/25, Training loss: 1.0400


Epoch 1/25: Validation: 100%|██████████| 39/39 [00:08<00:00,  4.57it/s]


Epoch 1/25, Validation loss: 0.7882
Epoch 1/25, Validation accuracy: 62.50%


Epoch 2/25: Training: 100%|██████████| 326/326 [00:57<00:00,  5.67it/s]


Epoch 2/25, Training loss: 0.5404


Epoch 2/25: Validation: 100%|██████████| 39/39 [00:08<00:00,  4.39it/s]


Epoch 2/25, Validation loss: 0.6523
Epoch 2/25, Validation accuracy: 62.82%


Epoch 3/25: Training: 100%|██████████| 326/326 [00:57<00:00,  5.64it/s]


Epoch 3/25, Training loss: 0.4530


Epoch 3/25: Validation: 100%|██████████| 39/39 [00:08<00:00,  4.35it/s]


Epoch 3/25, Validation loss: 0.5926
Epoch 3/25, Validation accuracy: 65.22%


Epoch 4/25: Training: 100%|██████████| 326/326 [00:54<00:00,  5.97it/s]


Epoch 4/25, Training loss: 0.4004


Epoch 4/25: Validation: 100%|██████████| 39/39 [00:08<00:00,  4.80it/s]


Epoch 4/25, Validation loss: 0.5591
Epoch 4/25, Validation accuracy: 69.55%


Epoch 5/25: Training: 100%|██████████| 326/326 [00:53<00:00,  6.14it/s]


Epoch 5/25, Training loss: 0.3650


Epoch 5/25: Validation: 100%|██████████| 39/39 [00:07<00:00,  4.93it/s]


Epoch 5/25, Validation loss: 0.5418
Epoch 5/25, Validation accuracy: 71.63%


Epoch 6/25: Training: 100%|██████████| 326/326 [00:52<00:00,  6.18it/s]


Epoch 6/25, Training loss: 0.3346


Epoch 6/25: Validation: 100%|██████████| 39/39 [00:08<00:00,  4.69it/s]


Epoch 6/25, Validation loss: 0.5128
Epoch 6/25, Validation accuracy: 72.92%


Epoch 7/25: Training: 100%|██████████| 326/326 [00:51<00:00,  6.38it/s]


Epoch 7/25, Training loss: 0.3134


Epoch 7/25: Validation: 100%|██████████| 39/39 [00:08<00:00,  4.72it/s]


Epoch 7/25, Validation loss: 0.4924
Epoch 7/25, Validation accuracy: 76.12%


Epoch 8/25: Training: 100%|██████████| 326/326 [00:57<00:00,  5.67it/s]


Epoch 8/25, Training loss: 0.2955


Epoch 8/25: Validation: 100%|██████████| 39/39 [00:08<00:00,  4.63it/s]


Epoch 8/25, Validation loss: 0.4912
Epoch 8/25, Validation accuracy: 75.64%


Epoch 9/25: Training: 100%|██████████| 326/326 [00:50<00:00,  6.49it/s]


Epoch 9/25, Training loss: 0.2814


Epoch 9/25: Validation: 100%|██████████| 39/39 [00:07<00:00,  4.97it/s]


Epoch 9/25, Validation loss: 0.4616
Epoch 9/25, Validation accuracy: 79.01%


Epoch 10/25: Training: 100%|██████████| 326/326 [00:53<00:00,  6.14it/s]


Epoch 10/25, Training loss: 0.2748


Epoch 10/25: Validation: 100%|██████████| 39/39 [00:07<00:00,  5.39it/s]


Epoch 10/25, Validation loss: 0.4854
Epoch 10/25, Validation accuracy: 77.56%


Epoch 11/25: Training: 100%|██████████| 326/326 [00:53<00:00,  6.07it/s]


Epoch 11/25, Training loss: 0.2603


Epoch 11/25: Validation: 100%|██████████| 39/39 [00:08<00:00,  4.86it/s]


Epoch 11/25, Validation loss: 0.4735
Epoch 11/25, Validation accuracy: 79.17%


Epoch 12/25: Training: 100%|██████████| 326/326 [00:50<00:00,  6.42it/s]


Epoch 12/25, Training loss: 0.2530


Epoch 12/25: Validation: 100%|██████████| 39/39 [00:08<00:00,  4.76it/s]


Epoch 12/25, Validation loss: 0.4651
Epoch 12/25, Validation accuracy: 79.17%


Epoch 13/25: Training: 100%|██████████| 326/326 [00:52<00:00,  6.22it/s]


Epoch 13/25, Training loss: 0.2485


Epoch 13/25: Validation: 100%|██████████| 39/39 [00:07<00:00,  5.48it/s]


Epoch 13/25, Validation loss: 0.4804
Epoch 13/25, Validation accuracy: 78.04%


Epoch 14/25: Training: 100%|██████████| 326/326 [00:50<00:00,  6.48it/s]


Epoch 14/25, Training loss: 0.2397


Epoch 14/25: Validation:  13%|█▎        | 5/39 [00:02<00:11,  3.06it/s]

A noter que le code peut etre simplifie en utilisant `pytorch-lightning`. La model class peut ainsi etre reecrite de la sorte :

In [None]:
import pytorch_lightning as pl
from torchmetrics.functional import accuracy


class LightningLinearProbing(pl.LightningModule):
    def __init__(self, model_name: str, num_classes: int = 10, lr: float = 1E-4):
        super().__init__()
        self.save_hyperparameters()

        self.processor = AutoImageProcessor.from_pretrained(model_name, use_fast=False)
        self.backbone = AutoModel.from_pretrained(model_name)
        if model_name == 'microsoft/resnet-50':
            self.head = LinearClassificationHead(self.backbone.config.hidden_sizes[-1], num_classes)
        else:
            self.head = LinearClassificationHead(self.backbone.config.hidden_size, num_classes)

        # Freeze the backbone
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Initialize the classification head
        nn.init.xavier_uniform_(self.head.weight)
        nn.init.zeros_(self.head.bias)

    def forward(self, x):
        inputs = self.processor(images=x, return_tensors="pt").to(self.device)
        outputs = self.backbone(**inputs)
        cls_tokens = outputs[1]
        return self.head(cls_tokens.squeeze())

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        _, preds = torch.max(outputs, 1)
        acc = accuracy(preds, labels)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.head.parameters(), lr=self.hparams.lr,
            betas=(0.9, 0.999), weight_decay=1E-4
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.2, patience=5, verbose=True
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

L'entrainement du modele se fait alors de la sorte :

In [None]:
# On initialise le Lightning model
lightning_model = LightningLinearProbing(model_name='facebook/dinov2-base', num_classes=6, lr=1E-4)

# Puis le trainer
trainer = pl.Trainer(max_epochs=50, gpus=1 if torch.cuda.is_available() else 0)

# Et enfin on lance l'entrainement
trainer.fit(lightning_model, train_loader, test_loader)

**Exercice** : comparer aux performances en kNN classification. Le fine-tuning ameliore-t-il les performances ?

### 4. Complexifier la tache

**Exercice** : modifier la classe `LinearClassificationHead` pour ajouter des couches supplémentaires, et tester les performances en finetuning.

In [None]:
class LinearClassificationHead(nn.Module):
    '''Une classification head plus complexe avec 3 couches lineaires.'''
    def __init__(
        self,
        in_channels: int,
        embed_dim: int = 512,
        num_classes: int = 2,
        dropout: float = 0.2
    ):
        super().__init__()

        self.classifier = nn.Sequential(
            nn.Linear(in_channels, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes),
        )

    def forward(self, x):
        return self.classifier(x)

### 5. Images medicales : au-dela du RGB

A present, on va tester les performances d'un modele de vision sur un dataset d'images medicales.

On va utiliser le dataset [Chest X-Ray Pneumonia](https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia) pour illustrer le cas de la classification binaire.

On commence par definir les datasets d'entrainement et de validation et les dataloaders associés.

In [None]:
train_loader, test_loader = get_dataloaders(
    url_data_dir='./data/chest_xray/',
    batch_size=32,
    num_images=-1,
    num_workers=2,
)

On affiche quelques exemples d'images :

In [None]:
import matplotlib.pyplot as plt

fig, axs = plt.subplots(3, 3, figsize=(10, 10))
for i, ax in enumerate(axs.flatten()):
    ax.imshow(train_loader.dataset[i][0].permute(1, 2, 0))
    ax.set_title(train_loader.dataset[i][1], fontsize=14)
plt.tight_layout()
plt.show()

Il est important de noter que les images d'origine sont en grayscale (une seule channel), mais que `torchvision.datasets.ImageFolder` les convertit automatiquement en 3 canaux (RGB) en repliquant simplement la couche de grayscale. En pratique, cela se fait via la librairie `PIL` (Python Imaging Library) avec le code suivant : `PIL.Image.open(img_path).convert('RGB')`. Dans d'autres cas (par exemple quand le dataset n'est pas formatte comme il faut pour utiliser ImageFolder), il peut etre necessaire de le faire soi-meme. Ci-dessous un code ecrit de zero pour definir son propre dataset :

In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class ChestXRayDataset(Dataset):
    def __init__(self, root_dir: str, transform: transforms.Compose = None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(os.listdir(root_dir))
        self.image_paths = []
        self.labels = []

        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for img_name in os.listdir(class_dir):
                img_path = os.path.join(class_dir, img_name)
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')  # replique la single channel en RGB
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

**Exercice** : Calculer l'accuracy de classification avec kNN et en finetuning sur un modele de vision pre-entraine sur des images naturelles.

Des modeles pre-entraines sur des images medicales existent et sont meme disponibles sur HuggingFace. Tentons a present d'utiliser le modele `https://huggingface.co/microsoft/rad-dino`. Quelle performance arrive-t-on a obtenir  en kNN et en finetuning ?