In [27]:
# model_training.ipynb
# Entraînement et validation d'un modèle Swin Transformer pour radiographies multi-label

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import timm
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
import os
import torch 

In [3]:
# --- Chargement des données ---
X_train = np.load('../data/processed/X_train.npy')
X_test = np.load('../data/processed/X_test.npy')
y_train = np.load('../data/processed/y_train.npy')
y_test = np.load('../data/processed/y_test.npy')

print(f"Données chargées : {X_train.shape}, {X_test.shape}")


Données chargées : (4005, 224, 224, 3), (994, 224, 224, 3)


In [4]:
# --- Conversion en tenseurs PyTorch ---
X_train = torch.tensor(X_train).permute(0, 3, 1, 2)  # (N, C, H, W)
X_test = torch.tensor(X_test).permute(0, 3, 1, 2)
y_train = torch.tensor(y_train).float()
y_test = torch.tensor(y_test).float()

# --- Dataloader ---
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)


In [5]:
# --- Modèle Swin Transformer ---
num_classes = y_train.shape[1]

model = timm.create_model(
    'swin_tiny_patch4_window7_224',
    pretrained=True,
    num_classes=num_classes
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (layers): Sequential(
    (0): SwinTransformerStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path1): Identity()
          (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, bias=True)
            (act): GELU(approximate='none')
            (drop1): 

In [6]:
# --- Fonction de perte et optimiseur ---
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# --- Entraînement ---
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0

    for X_batch, y_batch in tqdm(train_loader, desc=f"Époque {epoch+1}/{num_epochs}"):
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)


Époque 1/5: 100%|██████████| 251/251 [34:49<00:00,  8.32s/it]
Époque 2/5: 100%|██████████| 251/251 [25:06<00:00,  6.00s/it]
Époque 3/5: 100%|██████████| 251/251 [33:30<00:00,  8.01s/it]
Époque 4/5: 100%|██████████| 251/251 [32:11<00:00,  7.70s/it]
Époque 5/5: 100%|██████████| 251/251 [25:14<00:00,  6.03s/it]


In [24]:
# --- Validation ---
model.eval()
val_loss = 0.0
all_preds, all_targets = [], []

with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            val_loss += loss.item()

            preds = torch.sigmoid(outputs).cpu().numpy()
            all_preds.append(preds)
            all_targets.append(y_batch.cpu().numpy())

avg_val_loss = val_loss / len(val_loader)
all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)

try:
        auc = roc_auc_score(all_targets, all_preds, average='macro')
except:
        auc = 0.0

print(f"Époque [{epoch+1}/{num_epochs}]  "
          f"Perte entraînement: {avg_train_loss:.4f}  "
          f"Perte validation: {avg_val_loss:.4f}  "
          f"AUC: {auc:.4f}")



Époque [5/5]  Perte entraînement: 0.1776  Perte validation: 0.2090  AUC: nan




In [28]:
model_path = '../models/swin_transformer_radiography.pth'

os.makedirs(os.path.dirname(model_path), exist_ok=True)

# Sauvegarde du modèle
torch.save(model.state_dict(), model_path)
print("Modèle sauvegardé ")

Modèle sauvegardé 
