In [None]:
from google.colab import drive
# Mount Google Drive so the dataset paths in Config (under /content/drive/MyDrive) resolve.
drive.mount('/content/drive')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import timm
import torch.cuda.amp as amp
import random
from torchvision.utils import make_grid
import random
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [None]:
class Config:
    """Central experiment configuration, read as class attributes (``Config.X``)."""

    # Split directories (ImageFolder layout: one sub-folder per class).
    TRAIN_DIR = '/content/drive/MyDrive/Drepa_Vite/Dataset_drépano/train'
    VAL_DIR   = '/content/drive/MyDrive/Drepa_Vite/Dataset_drépano/val'
    TEST_DIR  = '/content/drive/MyDrive/Drepa_Vite/Dataset_drépano/test'

    # Data and optimisation.
    IMG_SIZE = 224
    NUM_CLASSES = 3
    BATCH_SIZE = 32
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-4
    NUM_EPOCHS = 100
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Reproducibility.
    SEED = 42

    # Model hyper-parameters. These defaults mirror the values assigned in the
    # model-initialisation cell, so HybridModel(Config) also works without it.
    CNN_BACKBONE = 'resnet50'
    HIDDEN_DIM = 256
    MLP_DIM = 512
    NUM_HEADS = 4
    NUM_LAYERS = 4
    DROPOUT = 0.1
    NUM_CHANNELS = 3
    PATCH_SIZE = 16


def seed_everything(seed: int = Config.SEED) -> None:
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


seed_everything(Config.SEED)

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss (Lin et al., 2017) computed from raw logits.

    Per sample: ``ce_i = cross_entropy(logits_i, y_i)``, ``p_t = exp(-ce_i)``,
    ``focal_i = (1 - p_t) ** gamma * ce_i``. With ``alpha`` given, each term is
    weighted by ``alpha[y_i]`` and the sum is normalised by the sum of those
    weights, so ``gamma=0`` reproduces ``nn.CrossEntropyLoss(weight=alpha)``.

    Args:
        alpha: optional per-class weights (1-D, length = number of classes).
            Stored as a buffer, so ``.to(device)`` moves it with the module.
        gamma: focusing parameter; 0 disables the focal modulation.
    """

    def __init__(self, alpha=None, gamma=2):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        # Buffer (not a Parameter): follows .to()/.cuda() but is not optimised.
        self.register_buffer(
            'alpha', None if alpha is None else torch.as_tensor(alpha, dtype=torch.float32)
        )

    def forward(self, inputs, targets):
        """Return the scalar loss for logits ``(N, C)`` and class indices ``(N,)``."""
        # Per-sample CE. A mean-reduced CE here would apply the focal factor to
        # the batch average instead of to each sample, defeating focal loss.
        ce = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce)
        focal = (1 - pt) ** self.gamma * ce
        if self.alpha is None:
            return focal.mean()
        alpha_t = self.alpha.to(focal.dtype)[targets]
        return (alpha_t * focal).sum() / alpha_t.sum()

In [None]:
# Training augmentation (random crop/scale, flips, rotation, colour jitter),
# then normalisation with the standard ImageNet mean/std.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(Config.IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Deterministic preprocessing for validation / test.
val_test_transforms = transforms.Compose([
    transforms.Resize((Config.IMG_SIZE, Config.IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(Config.TRAIN_DIR, transform=train_transforms)
val_dataset = datasets.ImageFolder(Config.VAL_DIR, transform=val_test_transforms)
test_dataset = datasets.ImageFolder(Config.TEST_DIR, transform=val_test_transforms)

# Read labels from ImageFolder's label list: iterating the dataset would load
# and augment every image (twice) just to obtain its label.
train_targets = torch.as_tensor(train_dataset.targets, dtype=torch.long)
class_counts = torch.bincount(train_targets, minlength=Config.NUM_CLASSES)
# clamp avoids inf weights if a class folder happens to be empty.
class_weights = 1. / class_counts.float().clamp(min=1)
class_weights = class_weights / class_weights.sum()

# Per-sample weights stay on CPU for the sampler; only the loss weights go to the device.
sample_weights = class_weights[train_targets].double()
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)
class_weights = class_weights.to(Config.DEVICE)
# NOTE(review): class imbalance is corrected twice — by this sampler AND by the
# class-weighted loss that receives `class_weights`. Consider keeping only one.

train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, sampler=sampler, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=2)

## 2.3 Visualisation et répartition des données (train / validation / test)

In [None]:
def count_images_in_folder(folder):
    """Return the total number of entries inside every sub-directory of ``folder``.

    Entries placed directly in ``folder`` are ignored; everything listed inside
    each class sub-directory is counted.
    """
    candidate_dirs = (os.path.join(folder, entry) for entry in os.listdir(folder))
    return sum(len(os.listdir(path)) for path in candidate_dirs if os.path.isdir(path))

# Number of files per split, as laid out on disk.
split_dirs = {"Training": Config.TRAIN_DIR, "Validation": Config.VAL_DIR, "Test": Config.TEST_DIR}
train_count, val_count, test_count = (count_images_in_folder(path) for path in split_dirs.values())

print("Number of images:")
for split_name, n_images in zip(split_dirs, (train_count, val_count, test_count)):
    print(f" - {split_name}: {n_images}")

def visualize_batch(dataloader, title, class_names, num_images=8,
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display the first ``num_images`` images of one batch with their labels.

    Args:
        dataloader: yields ``(images, labels)`` batches of normalised tensors.
        title: figure title prefix.
        class_names: maps a label index to its name (e.g. ``ImageFolder.classes``).
        num_images: how many images of the first batch to show.
        mean, std: normalisation to undo before display; the defaults match
            the ``Normalize`` used in the data transforms.
    """
    images, labels = next(iter(dataloader))
    images = images[:num_images]
    labels = labels[:num_images]

    # Undo Normalize channel-wise: x = x_norm * std + mean.
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)
    # Augmentation/rounding can push values just outside [0, 1]; clamping avoids
    # imshow's clipping warning.
    images = (images * std_t + mean_t).clamp(0, 1)

    grid = make_grid(images, nrow=4)
    plt.figure(figsize=(12, 6))
    plt.imshow(grid.permute(1, 2, 0))
    plt.title(title + " | " + ", ".join(class_names[int(label)] for label in labels))
    plt.axis('off')
    plt.show()

# Class names in label-index order, as discovered by ImageFolder.
class_names = train_dataset.classes

for loader, split_title in ((train_loader, "Training Images"),
                            (val_loader, "Validation Images"),
                            (test_loader, "Test Images")):
    visualize_batch(loader, split_title, class_names)

In [None]:
# MultiHeadAttention
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Sizes come from ``config.HIDDEN_DIM`` / ``config.NUM_HEADS``; the same
    dropout module (``config.DROPOUT``) is applied to the attention weights
    and to the projected output.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config.HIDDEN_DIM
        self.num_heads = config.NUM_HEADS
        self.head_dim = embed_dim // self.num_heads

        # A single fused projection yields queries, keys and values.
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(config.DROPOUT)

    def forward(self, x):
        """Attend over ``x`` (batch, tokens, HIDDEN_DIM); returns the same shape."""
        batch, tokens, channels = x.shape
        # (batch, tokens, 3*C) -> (3, batch, heads, tokens, head_dim)
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.head_dim ** -0.5
        weights = self.dropout(torch.softmax(scores, dim=-1))

        context = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.dropout(self.proj(context))

In [None]:
# Bloc MLP
class MLPBlock(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Maps HIDDEN_DIM -> MLP_DIM -> HIDDEN_DIM along the last dimension.
    """

    def __init__(self, config):
        super().__init__()
        width, expanded = config.HIDDEN_DIM, config.MLP_DIM
        self.fc1 = nn.Linear(width, expanded)
        self.fc2 = nn.Linear(expanded, width)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(config.DROPOUT)

    def forward(self, x):
        """Apply the two-layer MLP; output has the same shape as ``x``."""
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

3. Module TransformerBlock

In [None]:
# Bloc Transformer
class TransformerBlock(nn.Module):
    """Pre-norm encoder layer: ``x + Attn(LN(x))`` followed by ``x + MLP(LN(x))``."""

    def __init__(self, config):
        super().__init__()
        width = config.HIDDEN_DIM
        self.norm1 = nn.LayerNorm(width)
        self.attn = MultiHeadAttention(config)
        self.norm2 = nn.LayerNorm(width)
        self.mlp = MLPBlock(config)

    def forward(self, x):
        """Apply both residual sub-layers; the output keeps the shape of ``x``."""
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))

4. Module HybridModel

In [None]:
class HybridModel(nn.Module):
    """CNN + Vision-Transformer hybrid image classifier.

    Both branches see the same image:

    * CNN branch: timm ``config.CNN_BACKBONE`` (``features_only``, stage 4),
      projected to ``HIDDEN_DIM`` with a 1x1 conv and global-average pooled;
    * Transformer branch: non-overlapping ``PATCH_SIZE`` patch embedding, a
      learnable [CLS] token, learnable positional embeddings and
      ``NUM_LAYERS`` TransformerBlocks.

    The CLS output and the pooled CNN vector are summed, layer-normalised and
    mapped to ``NUM_CLASSES`` logits.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # CNN branch
        self.cnn = timm.create_model(
            config.CNN_BACKBONE,
            pretrained=True,
            features_only=True,
            out_indices=[4]
        )

        # Channel count of the selected stage from timm's feature metadata.
        # Running a dummy image through the backbone here (while it is still in
        # train mode) would update and corrupt the pretrained BatchNorm statistics.
        self.cnn_feat_dim = self.cnn.feature_info.channels()[-1]

        self.cnn_proj = nn.Sequential(
            nn.Conv2d(self.cnn_feat_dim, config.HIDDEN_DIM, 1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten()
        )

        # Transformer branch
        self.patch_embed = nn.Conv2d(
            config.NUM_CHANNELS,
            config.HIDDEN_DIM,
            kernel_size=config.PATCH_SIZE,
            stride=config.PATCH_SIZE
        )

        num_patches = (config.IMG_SIZE // config.PATCH_SIZE) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, config.HIDDEN_DIM))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.HIDDEN_DIM))
        # ViT-style small init (std 0.02). Unit-variance embeddings would swamp the
        # patch embeddings at the start of training.
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.dropout = nn.Dropout(config.DROPOUT)

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.NUM_LAYERS)
        ])

        self.norm = nn.LayerNorm(config.HIDDEN_DIM)
        self.head = nn.Linear(config.HIDDEN_DIM, config.NUM_CLASSES)

    def forward(self, x):
        """Return logits ``(B, NUM_CLASSES)`` for an image batch ``x`` ``(B, C, H, W)``.

        The input must produce the same number of patches as an
        ``IMG_SIZE x IMG_SIZE`` image, because ``pos_embed`` has a fixed length.
        """
        # CNN branch -> (B, HIDDEN_DIM)
        cnn_features = self.cnn_proj(self.cnn(x)[0])

        # Patchify: (B, HIDDEN_DIM, H/P, W/P) -> (B, num_patches, HIDDEN_DIM)
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)
        cls_token = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat([cls_token, tokens], dim=1) + self.pos_embed
        tokens = self.dropout(tokens)

        for block in self.transformer_blocks:
            tokens = block(tokens)

        # Fuse the CLS output with the pooled CNN vector.
        fused = self.norm(tokens[:, 0] + cnn_features)
        return self.head(fused)

## 3.2 Initialisation du Modèle

In [None]:
# Model hyper-parameters consumed by HybridModel and its sub-modules.
model_hparams = {
    'CNN_BACKBONE': 'resnet50',
    'HIDDEN_DIM': 256,
    'MLP_DIM': 512,
    'NUM_HEADS': 4,
    'NUM_LAYERS': 4,
    'DROPOUT': 0.1,
    'NUM_CHANNELS': 3,
    'PATCH_SIZE': 16,
}
for hparam_name, hparam_value in model_hparams.items():
    setattr(Config, hparam_name, hparam_value)

model = HybridModel(Config).to(Config.DEVICE)

# Class-weighted focal loss, AdamW, AMP gradient scaler and a cosine LR schedule
# spanning NUM_EPOCHS.
criterion = FocalLoss(alpha=class_weights, gamma=2.0).to(Config.DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=Config.WEIGHT_DECAY)
scaler = amp.GradScaler()
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.NUM_EPOCHS)

In [None]:
from torchinfo import summary
from thop import profile

In [None]:
from torchinfo import summary
from thop import profile

device_str = str(Config.DEVICE)
# verbose=0 + explicit print: the summary object is not the cell's last
# expression, so it would otherwise not be rendered reliably.
model_stats = summary(model, input_size=(1, 3, Config.IMG_SIZE, Config.IMG_SIZE),
                      device=device_str, verbose=0)
print(model_stats)

model.eval()
dummy = torch.randn(1, 3, Config.IMG_SIZE, Config.IMG_SIZE, device=Config.DEVICE)
with torch.no_grad():
    # thop.profile counts multiply-accumulate operations (MACs), not FLOPs;
    # one MAC is roughly 2 FLOPs.
    macs, params = profile(model, inputs=(dummy,), verbose=False)
flops = 2 * macs

print(f"\nParamètres: {params/1e6:.2f} M")
print(f"MACs (forward, bs=1): {macs/1e9:.2f} GMACs")
print(f"FLOPs (forward, bs=1): ≈{flops/1e9:.2f} GFLOPs (2 × MACs)")
print(f"FLOPs par batch (bs={Config.BATCH_SIZE}): ≈{(flops*Config.BATCH_SIZE)/1e9:.2f} GFLOPs")
print("Note: l'entraînement (forward+backward) ≈ 2–3× ces FLOPs.")

In [None]:
def train_epoch(model, loader, optimizer, criterion, scaler):
    """Run one mixed-precision training epoch over ``loader``.

    Returns:
        (mean of the per-batch losses, accuracy in percent over all samples).
    """
    model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(Config.DEVICE)
        batch_labels = batch_labels.to(Config.DEVICE)

        optimizer.zero_grad()
        with amp.autocast():
            logits = model(batch_images)
            loss = criterion(logits, batch_labels)
        # GradScaler cycle: scaled backward, (unscale +) step, update the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        batch_preds = logits.max(dim=1).indices
        n_correct += (batch_preds == batch_labels).sum().item()
        n_seen += batch_labels.size(0)
    return running_loss / len(loader), 100. * n_correct / n_seen

In [None]:
def evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        (mean per-batch loss, accuracy in percent, predicted class indices,
        true class indices) — the two lists are in loader order.
    """
    model.eval()
    loss_sum, hits, count = 0, 0, 0
    predictions, targets = [], []
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(Config.DEVICE)
            batch_labels = batch_labels.to(Config.DEVICE)
            logits = model(batch_images)
            loss_sum += criterion(logits, batch_labels).item()
            batch_preds = logits.max(dim=1).indices
            hits += (batch_preds == batch_labels).sum().item()
            count += batch_labels.size(0)
            predictions.extend(batch_preds.cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())
    return loss_sum / len(loader), 100. * hits / count, predictions, targets

In [None]:
train_losses, val_losses, train_accuracies, val_accuracies = [], [], [], []

for epoch in range(Config.NUM_EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, scaler)
    val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion)
    # Advance the cosine schedule once per epoch (T_max = NUM_EPOCHS). Without this
    # call the learning rate stays at its initial value for the whole run.
    scheduler.step()

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)

    print(f"Epoch {epoch+1}/{Config.NUM_EPOCHS} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")

# Plot against the number of epochs actually recorded.
epochs = range(1, len(train_losses) + 1)

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(14, 6))

ax_loss.plot(epochs, train_losses, label='Train Loss')
ax_loss.plot(epochs, val_losses, label='Validation Loss')
ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Loss Over Time')
ax_loss.legend()
ax_loss.grid(True)

ax_acc.plot(epochs, train_accuracies, label='Train Accuracy')
ax_acc.plot(epochs, val_accuracies, label='Validation Accuracy')
ax_acc.set(xlabel='Epoch', ylabel='Accuracy (%)', title='Accuracy Over Time')
ax_acc.legend()
ax_acc.grid(True)

fig.tight_layout()
plt.show()

## Évaluation sur les données d'entraînement

In [None]:
# The training-set evaluation runs in the evaluation cell placed after the
# definition of `full_evaluation` (next code cell). Calling it here, before that
# definition has executed, raised NameError on a fresh "Restart & Run All".

In [None]:
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, matthews_corrcoef, roc_auc_score, classification_report
)
from sklearn.preprocessing import label_binarize
import numpy as np

def full_evaluation(model, loader, criterion, name=""):
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    total_loss = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(Config.DEVICE), labels.to(Config.DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            probs = torch.softmax(outputs, dim=1)
            preds = torch.argmax(probs, dim=1)

            all_probs.extend(probs.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    acc = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    mcc = matthews_corrcoef(all_labels, all_preds)

    cm = confusion_matrix(all_labels, all_preds)
    sensitivity = np.mean(np.diag(cm) / np.maximum(1, np.sum(cm, axis=1)))
    specificity = np.mean([
        (np.sum(cm) - (cm[i, :].sum() + cm[:, i].sum() - cm[i, i])) /
        max(1, (np.sum(cm) - cm[:, i].sum()))
        for i in range(len(cm))
    ])

    try:
        one_hot_labels = label_binarize(all_labels, classes=np.unique(all_labels))
        auc = roc_auc_score(one_hot_labels, all_probs, average='macro', multi_class='ovo')
    except Exception as e:
        print(f" Erreur AUC ({name}) :", e)
        auc = float('nan')

    print(f"\n Évaluation sur {name} :")
    print(f"Loss        : {total_loss / len(loader):.4f}")
    print(f"Accuracy    : {acc:.4f}")
    print(f"Precision   : {precision:.4f}")
    print(f"Recall      : {recall:.4f}")
    print(f"F1-Score    : {f1:.4f}")
    print(f"Sensitivity : {sensitivity:.4f}")
    print(f"Specificity : {specificity:.4f}")
    print(f"MCC         : {mcc:.4f}")
    print(f"AUC         : {auc:.4f}")

    target_names = ["Autres", "Falciformes", "Normales"]
    print("\n Rapport de classification :\n")
    print(classification_report(all_labels, all_preds, target_names=target_names, digits=2))

In [None]:
# Full metric report for each split.
for split_loader, split_name in ((train_loader, "Entraînement"),
                                 (val_loader, "Validation"),
                                 (test_loader, "Test")):
    full_evaluation(model, split_loader, criterion, name=split_name)

In [None]:
# NOTE(review): this repeats the evaluation of the previous cell (only the
# training-split label differs: "Train" vs "Entraînement"); consider dropping one.
evaluation_splits = [("Train", train_loader), ("Validation", val_loader), ("Test", test_loader)]
for split_label, split_loader in evaluation_splits:
    full_evaluation(model, split_loader, criterion, name=split_label)

In [None]:
# Relative path: written to the current working directory.
# NOTE(review): on Colab the default working directory is presumably not
# persistent across sessions; consider a path under /content/drive — TODO confirm.
weights_path = "hybrid_model_weights.pth"
torch.save(model.state_dict(), weights_path)


In [None]:
import torch.nn.functional as F

y_true = []
y_pred_proba = []

# Inference mode is required (the model contains Dropout); do not rely on the
# mode left behind by whichever cell ran before (e.g. training).
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(Config.DEVICE)
        outputs = model(images)
        probs = F.softmax(outputs, dim=1)

        # Labels are only needed on CPU, so they are not sent to the device.
        y_true.extend(labels.numpy())
        y_pred_proba.extend(probs.cpu().numpy())

In [None]:
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np
from itertools import cycle

y_true = np.array(y_true)
y_pred_proba = np.array(y_pred_proba)
n_classes = y_pred_proba.shape[1]

y_test_bin = label_binarize(y_true, classes=np.arange(n_classes))

# Legend names must follow label-index order, i.e. test_dataset.classes
# (ImageFolder sorts class folders). The previous hard-coded list
# ['Normal', 'Sickle', 'Other'] did not match the order used elsewhere in this
# notebook (["Autres", "Falciformes", "Normales"]) and swapped classes 0 and 2.
english_names = {'Autres': 'Other', 'Falciformes': 'Sickle', 'Normales': 'Normal'}
target_names = [english_names.get(c, c) for c in test_dataset.classes]
if len(target_names) != n_classes:
    target_names = [f'Class {i}' for i in range(n_classes)]

fpr, tpr, roc_auc = {}, {}, {}

# One-vs-rest ROC per class.
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], y_pred_proba[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Micro-average: pool every (sample, class) decision.
fpr["micro"], tpr["micro"], _ = roc_curve(y_test_bin.ravel(), y_pred_proba.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

plt.figure(figsize=(8, 6))
colors = cycle(['#1f77b4', '#d62728', '#2ca02c'])

for i, color in zip(range(n_classes), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=2,
             label=f'{target_names[i]} (AUC = {roc_auc[i]:.2f})')

plt.plot(fpr["micro"], tpr["micro"],
         color='gray', linestyle='--', linewidth=2.5,
         label=f'Micro-average (AUC = {roc_auc["micro"]:.2f})')

plt.plot([0, 1], [0, 1], 'k--', lw=1)

plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.xlabel('False Positive Rate (1 - Specificity)', fontsize=12)
plt.ylabel('True Positive Rate (Sensitivity)', fontsize=12)
plt.title('ROC Curves per Class with Micro-Average', fontsize=14)
plt.legend(loc='lower right', fontsize=11)
plt.grid(True, linestyle='--', linewidth=0.5)
plt.tight_layout()
plt.show()

In [None]:
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import numpy as np

TARGET_NAMES = ["Autres", "Falciformes", "Normales"]

def plot_pr_curve(model, loader, name="Test"):
    """Plot one-vs-rest precision-recall curves per class plus the micro-average.

    The title reports the macro AP (unweighted mean of the per-class APs).

    Args:
        model: classifier returning logits (N, num_classes).
        loader: yields ``(images, labels)`` batches.
        name: split label used in the title.
    """
    model.eval()
    all_labels, all_probs = [], []

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(Config.DEVICE), labels.to(Config.DEVICE)
            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)

            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_probs = np.array(all_probs)
    all_labels = np.array(all_labels)

    # Index classes by probability column, not by the labels present in this
    # split: with np.unique(labels), a missing class shifts the binarised columns
    # so y_true[:, i] and all_probs[:, i] would refer to different classes.
    n_classes = all_probs.shape[1]
    classes = np.arange(n_classes)
    y_true = label_binarize(all_labels, classes=classes)
    if n_classes == 2:
        # label_binarize returns a single column for two classes.
        y_true = np.hstack([1 - y_true, y_true])

    if len(TARGET_NAMES) == n_classes:
        class_names = TARGET_NAMES
    else:
        class_names = [f"Classe {c}" for c in classes]

    ap_per_class = []
    plt.figure(figsize=(7,6))

    for i in range(n_classes):
        precision, recall, _ = precision_recall_curve(y_true[:, i], all_probs[:, i])
        ap = average_precision_score(y_true[:, i], all_probs[:, i])
        ap_per_class.append(ap)
        plt.plot(recall, precision, lw=2, label=f"{class_names[i]} (AP = {ap:.3f})")

    precision_micro, recall_micro, _ = precision_recall_curve(y_true.ravel(), all_probs.ravel())
    ap_micro = average_precision_score(y_true, all_probs, average="micro")
    plt.plot(recall_micro, precision_micro, linestyle="--", lw=2,
             label=f"Micro-average (AP = {ap_micro:.3f})")

    ap_macro = float(np.mean(ap_per_class))

    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"Precision–Recall Curve - {name} (macro AP = {ap_macro:.3f})")
    plt.xlim([0.0, 1.0]); plt.ylim([0.0, 1.05])
    plt.grid(True, linestyle="--", alpha=0.6)
    plt.legend(loc="lower left")
    plt.show()

plot_pr_curve(model, test_loader, name="Test")


## Interprétabilité : Grad-CAM (test 1)

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

class GradCAM:
    """Grad-CAM (Selvaraju et al., 2017) for a convolutional ``target_layer``.

    Hooks on ``target_layer`` capture its output activations A and the gradient
    of the selected class score w.r.t. A. The map is ``ReLU(sum_k w_k * A_k)``,
    where w_k is the spatial mean of channel k's gradient. It is then min-max
    normalised per image to [0, 1].

    Call :meth:`remove_hooks` when done, or use the object as a context manager.
    """

    def __init__(self, model, target_layer):
        self.model = model.eval()
        self.target_layer = target_layer

        self.activations = None
        self.gradients = None

        self.fwd_hook = target_layer.register_forward_hook(self._save_activations)
        self.bwd_hook = target_layer.register_full_backward_hook(self._save_gradients)

    def _save_activations(self, module, input, output):
        self.activations = output.detach()

    def _save_gradients(self, module, grad_input, grad_output):
        # grad_output[0]: gradient of the backpropagated score w.r.t. the layer output.
        self.gradients = grad_output[0].detach()

    def remove_hooks(self):
        """Detach both hooks from ``target_layer``."""
        self.fwd_hook.remove()
        self.bwd_hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove_hooks()
        return False

    @torch.no_grad()
    def _normalize(self, cam):
        """Min-max scale ``cam`` to [0, 1]; the epsilon guards against an all-zero map."""
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        return cam

    def __call__(self, images, targets=None, device=None, upsample=True):
        """Compute Grad-CAM heatmaps for a batch.

        Args:
            images: input batch (B, C, H, W).
            targets: class index per image to explain; defaults to the
                predicted class.
            device: where to run; defaults to the device of the model parameters.
            upsample: if True (default), bilinearly resize the maps to the
                input's (H, W) so they can be overlaid on the images; if False,
                return them at the target layer's spatial resolution.

        Returns:
            ``(heatmaps (B, h, w) in [0, 1], predicted classes (B,),
            softmax probabilities (B, num_classes))``, all on CPU.
        """
        device = device or next(self.model.parameters()).device
        images = images.to(device)

        # Gradients are needed even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            outputs = self.model(images)
            probs = torch.softmax(outputs, dim=1)
            preds = outputs.argmax(dim=1)

            if targets is None:
                targets = preds
            else:
                targets = torch.as_tensor(targets, device=device)

            self.model.zero_grad(set_to_none=True)
            selected = outputs.gather(1, targets.reshape(-1, 1)).sum()
            selected.backward()

        activations = self.activations
        grads = self.gradients
        # Channel weights = spatial mean of the gradient; mean() also works for
        # non-contiguous gradients, where .view() would raise.
        weights = grads.mean(dim=(2, 3), keepdim=True)
        cam = F.relu((weights * activations).sum(dim=1))

        if upsample:
            cam = F.interpolate(cam.unsqueeze(1), size=images.shape[-2:],
                                mode='bilinear', align_corners=False).squeeze(1)

        heatmaps = torch.stack([self._normalize(single) for single in cam], dim=0)
        return heatmaps.cpu(), preds.cpu(), probs.detach().cpu()


In [None]:
def show_gradcam(images, heatmaps, true_labels=None, pred_labels=None, class_names=None, cols=4,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Overlay Grad-CAM heatmaps on their images in a grid.

    Args:
        images: batch (B, C, H, W) as fed to the model; single-channel images
            are repeated to 3 channels.
        heatmaps: (B, h, w) maps in [0, 1]; each is stretched over its image's
            full extent, so (h, w) need not equal (H, W).
        true_labels, pred_labels: optional per-image class indices.
        class_names: index -> name; titles are only drawn when given.
        cols: images per row.
        mean, std: normalisation to undo for display (defaults match the
            ``Normalize`` in the data transforms); pass ``None`` to skip.
    """
    imgs = images.detach().cpu()
    if imgs.shape[1] == 1:
        imgs = imgs.repeat(1, 3, 1, 1)
    if mean is not None and std is not None:
        # Undo Normalize (x * std + mean) so colours are displayed correctly.
        imgs = imgs * torch.tensor(std).view(1, -1, 1, 1) + torch.tensor(mean).view(1, -1, 1, 1)
    imgs = imgs.clamp(0, 1)

    B = imgs.shape[0]
    rows = int(np.ceil(B / cols))
    plt.figure(figsize=(4*cols, 4*rows))

    for i in range(B):
        plt.subplot(rows, cols, i+1)
        img = np.transpose(imgs[i].numpy(), (1, 2, 0))
        hm = np.asarray(heatmaps[i])
        height, width = img.shape[:2]
        # Align the heatmap with the image's pixel extent; otherwise a smaller
        # heatmap resets the axes to its own size and hides most of the image.
        extent = (-0.5, width - 0.5, height - 0.5, -0.5)

        plt.imshow(img, interpolation='bilinear')
        plt.imshow(hm, cmap='jet', alpha=0.4, interpolation='bilinear', extent=extent)
        title = ""
        if pred_labels is not None and class_names is not None:
            title += f"Pred: {class_names[int(pred_labels[i])]}"
        if true_labels is not None and class_names is not None:
            title += f"\nTrue: {class_names[int(true_labels[i])]}"
        plt.title(title.strip(), fontsize=10)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
def show_gradcam_with_original(images, heatmaps, true_labels=None, pred_labels=None,
                               class_names=None, scores=None, score_name="p", cols=4,
                               mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show each image next to its Grad-CAM overlay.

    Each image uses two subplots: the original (titled with the true label)
    and the overlay (titled with the predicted label and optional score).
    ``cols`` is the number of image/overlay pairs per row.

    Args:
        images: batch (B, C, H, W) as fed to the model.
        heatmaps: (B, h, w) maps in [0, 1], stretched over each image.
        true_labels, pred_labels: optional per-image class indices.
        class_names: index -> name, needed for label titles.
        scores: optional per-image scalar shown as ``{score_name}=...``.
        mean, std: normalisation to undo for display (defaults match the
            ``Normalize`` in the data transforms); pass ``None`` to skip.
    """
    imgs = images.detach().cpu()
    if imgs.shape[1] == 1:
        imgs = imgs.repeat(1, 3, 1, 1)
    if mean is not None and std is not None:
        # Undo Normalize (x * std + mean) so colours are displayed correctly.
        imgs = imgs * torch.tensor(std).view(1, -1, 1, 1) + torch.tensor(mean).view(1, -1, 1, 1)
    imgs = imgs.clamp(0, 1)

    B = imgs.shape[0]
    rows = int(np.ceil(B / cols))
    plt.figure(figsize=(8*cols, 4*rows))

    for i in range(B):
        img = np.transpose(imgs[i].numpy(), (1, 2, 0))
        hm = np.asarray(heatmaps[i])
        height, width = img.shape[:2]
        # Stretch the heatmap over the image's pixel extent so both layers align.
        extent = (-0.5, width - 0.5, height - 0.5, -0.5)

        plt.subplot(rows, cols*2, 2*i+1)
        plt.imshow(img, interpolation='bilinear')
        title_parts = []
        if true_labels is not None and class_names is not None:
            title_parts.append(f"True: {class_names[int(true_labels[i])]}")
        plt.title(" | ".join(title_parts), fontsize=9)
        plt.axis("off")

        plt.subplot(rows, cols*2, 2*i+2)
        plt.imshow(img, interpolation='bilinear')
        plt.imshow(hm, cmap='jet', alpha=0.4, interpolation='bilinear', extent=extent)

        title_parts = []
        if pred_labels is not None and class_names is not None:
            title_parts.append(f"Pred: {class_names[int(pred_labels[i])]}")
        if scores is not None:
            title_parts.append(f"{score_name}={float(scores[i]):.3f}")
        plt.title(" | ".join(title_parts), fontsize=9)
        plt.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
CLASS_NAMES = ["Autres", "Falciformes", "Normales"]
device = Config.DEVICE

# Last block of the backbone's final stage (assumes the timm backbone exposes
# a torchvision-style `layer4`, as this cell already relies on).
target_layer = model.cnn.layer4[-1]

cam = GradCAM(model.to(device).eval(), target_layer=target_layer)
try:
    batch_images, batch_labels = next(iter(test_loader))
    batch_images = batch_images.to(device)
    batch_labels = batch_labels.to(device)

    # Explain the TRUE class of each image (targets=batch_labels); the titles
    # still show the predicted class.
    heatmaps, preds, probs = cam(batch_images, targets=batch_labels, device=device)

    show_gradcam(batch_images, heatmaps, true_labels=batch_labels.cpu(),
                 pred_labels=preds, class_names=CLASS_NAMES, cols=4)
finally:
    # Always detach the hooks; if the cell is interrupted they would otherwise
    # keep firing on every later forward/backward pass of the model.
    cam.remove_hooks()
