# 🚀 Google Colab Setup

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ogautier1980/sandbox-ml/blob/main/cours/12_vision_avancee/12_demo_vision_transformers.ipynb)

**If you are running this notebook on Google Colab**, run the following cell to install the dependencies.

In [None]:
# Install dependencies (Google Colab only)
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print('📦 Installing packages...')

    # Core ML packages
    !pip install -q numpy pandas matplotlib seaborn scikit-learn

    # Detect the chapter and install chapter-specific dependencies
    notebook_name = '12_demo_vision_transformers.ipynb'  # Will be replaced automatically

    # Ch 06-08: Deep Learning
    if any(x in notebook_name for x in ['06_', '07_', '08_']):
        !pip install -q torch torchvision torchaudio

    # Ch 08: NLP
    if '08_' in notebook_name:
        !pip install -q transformers datasets tokenizers
        if 'rag' in notebook_name:
            !pip install -q sentence-transformers faiss-cpu rank-bm25

    # Ch 09: Reinforcement Learning
    if '09_' in notebook_name:
        !pip install -q gymnasium[classic-control]

    # Ch 04: Boosting
    if '04_' in notebook_name and 'boosting' in notebook_name:
        !pip install -q xgboost lightgbm catboost

    # Ch 05: Advanced clustering
    if '05_' in notebook_name:
        !pip install -q umap-learn

    # Ch 11: Time series
    if '11_' in notebook_name:
        !pip install -q statsmodels prophet

    # Ch 12: Advanced vision
    if '12_' in notebook_name:
        !pip install -q ultralytics timm segmentation-models-pytorch

    # Ch 13: Recommendation
    if '13_' in notebook_name:
        !pip install -q scikit-surprise implicit

    # Ch 14: MLOps
    if '14_' in notebook_name:
        !pip install -q mlflow fastapi pydantic

    print('✅ Installation complete!')
else:
    print('ℹ️  Local environment detected, packages are assumed to be installed already.')

# Chapter 12 - Vision Transformers (ViT) and CLIP

This notebook explores **Vision Transformers** and **vision-language** models (CLIP).

## Objectives
- Understand the Vision Transformer (ViT) architecture
- Use ViT with timm (PyTorch Image Models)
- Fine-tune ViT on CIFAR-10
- Compare ViT with a CNN (ResNet)
- Visualize attention maps
- Discover CLIP for zero-shot classification

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

# timm (PyTorch Image Models)
try:
    import timm
    TIMM_AVAILABLE = True
    print(f"timm version: {timm.__version__}")
except ImportError:
    print("⚠️ timm not installed. Install with: pip install timm")
    TIMM_AVAILABLE = False

# CLIP (optional)
try:
    import clip
    CLIP_AVAILABLE = True
except ImportError:
    print("⚠️ CLIP not installed. Install with: pip install git+https://github.com/openai/CLIP.git")
    CLIP_AVAILABLE = False

print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

## 1. Vision Transformer (ViT) Architecture

A simplified ViT implementation, written to make the underlying mechanisms explicit.
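
Before reading the implementation, it can help to see how an image becomes a token sequence. The sketch below is plain arithmetic; the helper name `vit_sequence_stats` is hypothetical and exists only for this illustration, and its defaults mirror the 224x224 image and 16x16 patch settings used in this notebook.

```python
def vit_sequence_stats(img_size=224, patch_size=16, in_channels=3):
    """Token counts for a square image split into square, non-overlapping patches.

    Hypothetical helper for illustration; not part of timm or PyTorch.
    """
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")

    grid = img_size // patch_size                       # patches per side
    n_patches = grid * grid                             # number of patch tokens
    seq_len = n_patches + 1                             # patch tokens + [CLS]
    patch_dim = in_channels * patch_size * patch_size   # raw values in one patch
    attn_entries = seq_len * seq_len                    # attention matrix size per head

    return {
        "grid": grid,
        "n_patches": n_patches,
        "seq_len": seq_len,
        "patch_dim": patch_dim,
        "attn_entries": attn_entries,
    }


stats = vit_sequence_stats()
print(stats)
```

Because the attention matrix grows with the square of the sequence length, halving the patch size multiplies the number of patch tokens by four and the attention entries by roughly sixteen.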

In [None]:
class PatchEmbedding(nn.Module):
    """Splits the image into patches and projects each patch to an embedding."""
    
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        # Convolution with kernel_size == stride == patch_size: one output position per patch
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
    
    def forward(self, x):
        # x: [B, C, H, W]
        x = self.projection(x)  # [B, embed_dim, H/P, W/P]
        x = x.flatten(2)  # [B, embed_dim, n_patches]
        x = x.transpose(1, 2)  # [B, n_patches, embed_dim]
        return x


class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention."""
    
    def __init__(self, embed_dim=768, num_heads=12, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        B, N, C = x.shape
        
        # Compute Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, num_heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Scaled Dot-Product Attention
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, num_heads, N, N]
        attn = attn.softmax(dim=-1)
        attn = self.attn_dropout(attn)
        
        # Apply attention to values
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_dropout(x)
        
        return x, attn


class MLP(nn.Module):
    """MLP with GELU activation."""
    
    def __init__(self, embed_dim=768, hidden_dim=3072, dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """Transformer Encoder Block."""
    
    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), dropout)
    
    def forward(self, x):
        # Attention with residual
        attn_output, attn_weights = self.attn(self.norm1(x))
        x = x + attn_output
        
        # MLP with residual
        x = x + self.mlp(self.norm2(x))
        
        return x, attn_weights


class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) Architecture."""
    
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, dropout=0.0):
        super().__init__()
        
        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches
        
        # [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Positional embeddings
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.pos_dropout = nn.Dropout(dropout)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
    
    def forward(self, x):
        B = x.shape[0]
        
        # Patch embedding
        x = self.patch_embed(x)  # [B, n_patches, embed_dim]
        
        # Add [CLS] token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # [B, n_patches+1, embed_dim]
        
        # Add positional embedding
        x = x + self.pos_embed
        x = self.pos_dropout(x)
        
        # Transformer blocks
        attn_weights = []
        for block in self.blocks:
            x, attn = block(x)
            attn_weights.append(attn)
        
        # Classification head (use [CLS] token)
        x = self.norm(x[:, 0])  # [B, embed_dim]
        x = self.head(x)  # [B, num_classes]
        
        return x, attn_weights


# Test the architecture
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    num_classes=10,
    embed_dim=384,
    depth=6,
    num_heads=6
)

x = torch.randn(2, 3, 224, 224)
y, attn = model(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Number of attention layers: {len(attn)}")
print(f"Attention shape: {attn[0].shape}")
print(f"\nNumber of parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")
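
The parameter count printed by the cell above can be cross-checked by hand. The following is a minimal sketch that assumes exactly the layers defined in this section (biased `Conv2d` patch projection, biased linear layers, `LayerNorm` with weight and bias); the function name `vit_param_count` is hypothetical. The number of heads does not appear because splitting into heads only reshapes the same `qkv` projection.

```python
def vit_param_count(img_size=224, patch_size=16, in_channels=3, num_classes=1000,
                    embed_dim=768, depth=12, mlp_ratio=4):
    """Closed-form parameter count for the simplified ViT defined in this notebook.

    Hypothetical helper for illustration only.
    """
    n_patches = (img_size // patch_size) ** 2
    hidden = int(embed_dim * mlp_ratio)

    # Conv2d patch projection: weight + bias
    patch_embed = in_channels * patch_size * patch_size * embed_dim + embed_dim
    # [CLS] token + positional embeddings (one per token, including [CLS])
    cls_and_pos = embed_dim + (n_patches + 1) * embed_dim

    layer_norm = 2 * embed_dim  # weight + bias
    attention = (embed_dim * 3 * embed_dim + 3 * embed_dim) \
        + (embed_dim * embed_dim + embed_dim)           # qkv + output projection
    mlp = (embed_dim * hidden + hidden) + (hidden * embed_dim + embed_dim)
    block = 2 * layer_norm + attention + mlp

    # Final LayerNorm + linear classification head
    head = layer_norm + (embed_dim * num_classes + num_classes)

    return patch_embed + cls_and_pos + depth * block + head


vit_base = vit_param_count()
small_demo = vit_param_count(num_classes=10, embed_dim=384, depth=6)
print(f"ViT-Base/16, 1000 classes: {vit_base / 1e6:.2f}M")
print(f"Demo model above:          {small_demo / 1e6:.2f}M")
```

With the demo configuration (`embed_dim=384`, `depth=6`, 10 classes) the formula gives about 11.02M parameters, which should agree with the figure printed by the test cell.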

## 2. Vision Transformers with timm

Use timm to load pretrained ViT models.

In [None]:
if TIMM_AVAILABLE:
    # List all ViT models that ship with pretrained weights
    vit_models = timm.list_models('vit*', pretrained=True)
    print(f"Number of pretrained ViT models: {len(vit_models)}")
    print(f"\nExamples:")
    for model_name in vit_models[:10]:
        print(f"  - {model_name}")
else:
    print("timm not available")

In [None]:
if TIMM_AVAILABLE:
    # Load ViT-Base/16 pretrained on ImageNet, with a fresh 10-class head
    model_vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
    model_vit.to(device)
    
    print(f"Model loaded: vit_base_patch16_224")
    print(f"Parameters: {sum(p.numel() for p in model_vit.parameters()) / 1e6:.2f}M")
    
    # Compare with ResNet-50
    model_resnet = timm.create_model('resnet50', pretrained=True, num_classes=10)
    model_resnet.to(device)
    
    print(f"\nModel loaded: resnet50")
    print(f"Parameters: {sum(p.numel() for p in model_resnet.parameters()) / 1e6:.2f}M")
    
    # Model info
    print(f"\nViT-Base configuration:")
    print(f"  Patch size: 16x16")
    print(f"  Embed dim: 768")
    print(f"  Depth: 12 layers")
    print(f"  Num heads: 12")
    print(f"  Number of patches: {(224//16)**2} = 196")
else:
    print("timm not available")

## 3. CIFAR-10 Dataset

Prepare CIFAR-10 for training.

In [None]:
# Transformations
transform_train = transforms.Compose([
    transforms.Resize(224),  # ViT expects 224x224 inputs
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=16),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load CIFAR-10
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck']

print(f"Train dataset: {len(train_dataset)} images")
print(f"Test dataset: {len(test_dataset)} images")
print(f"Classes: {classes}")

In [None]:
# Visualize a few samples
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

for i in range(10):
    img, label = train_dataset[i]
    
    # Denormalize
    img_display = img.permute(1, 2, 0).numpy()
    img_display = img_display * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img_display = np.clip(img_display, 0, 1)
    
    ax = axes[i // 5, i % 5]
    ax.imshow(img_display)
    ax.set_title(classes[label])
    ax.axis('off')

plt.tight_layout()
plt.show()
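
The denormalization step in the cell above is simply the inverse of `transforms.Normalize`. As a minimal sketch on a single RGB pixel (plain Python, no tensors), the round trip looks like this; the helper names `normalize` and `denormalize` are hypothetical, and the example pixel is arbitrary.

```python
import math

# ImageNet channel statistics, as passed to transforms.Normalize in this notebook
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize(pixel):
    """Apply (value - mean) / std per channel to one RGB pixel in [0, 1]."""
    return tuple((v - m) / s for v, m, s in zip(pixel, MEAN, STD))


def denormalize(pixel):
    """Invert normalize() and clip to the displayable range [0, 1]."""
    return tuple(min(1.0, max(0.0, v * s + m)) for v, m, s in zip(pixel, MEAN, STD))


original = (0.20, 0.50, 0.90)  # arbitrary example pixel, not taken from CIFAR-10
restored = denormalize(normalize(original))
round_trip_ok = all(math.isclose(a, b, abs_tol=1e-9) for a, b in zip(original, restored))
print(restored, round_trip_ok)
```

The clipping matters mostly for augmented images, where padding or interpolation can push values slightly outside the valid range after the inverse transform.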

## 4. Fine-Tuning ViT on CIFAR-10

In [None]:
if TIMM_AVAILABLE:
    def train_epoch(model, loader, criterion, optimizer, device):
        """Train the model for one epoch and return (mean loss, accuracy in %)."""
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        pbar = tqdm(loader, desc='Training')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            
            # Forward
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Metrics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            pbar.set_postfix({'loss': f'{loss.item():.4f}', 
                             'acc': f'{100.*correct/total:.2f}%'})
        
        return running_loss / len(loader), 100. * correct / total
    
    
    def validate(model, loader, criterion, device):
        """Evaluate the model and return (mean loss, accuracy in %)."""
        model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in tqdm(loader, desc='Validation'):
                images, labels = images.to(device), labels.to(device)
                
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        
        return running_loss / len(loader), 100. * correct / total
    
    
    # Train ViT (fine-tuning)
    print("\n" + "="*60)
    print("Fine-tuning ViT-Base on CIFAR-10")
    print("="*60 + "\n")
    
    # Freeze the encoder and train only the classification head
    for param in model_vit.parameters():
        param.requires_grad = False
    for param in model_vit.head.parameters():
        param.requires_grad = True
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model_vit.head.parameters(), lr=1e-3)
    
    num_epochs = 3  # Kept short for the demo
    
    history_vit = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        
        train_loss, train_acc = train_epoch(model_vit, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model_vit, test_loader, criterion, device)
        
        history_vit['train_loss'].append(train_loss)
        history_vit['train_acc'].append(train_acc)
        history_vit['val_loss'].append(val_loss)
        history_vit['val_acc'].append(val_acc)
        
        print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
        print(f"Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")
    
    print("\n" + "="*60)
    print(f"Fine-tuning complete! Final accuracy: {history_vit['val_acc'][-1]:.2f}%")
    print("="*60)
else:
    print("timm not available")
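
Freezing the encoder means only a single linear layer is updated. The sketch below estimates how small that trainable part is, assuming a 768-dimensional ViT-Base embedding (as listed in the configuration printout) and the 2048-dimensional pooled features of ResNet-50; the helper name `linear_head_params` is hypothetical.

```python
import math


def linear_head_params(in_features, num_classes):
    """Weights plus biases of a single linear classification layer."""
    return in_features * num_classes + num_classes


NUM_CLASSES = 10
vit_head_params = linear_head_params(768, NUM_CLASSES)      # ViT-Base embedding dim
resnet_fc_params = linear_head_params(2048, NUM_CLASSES)    # ResNet-50 pooled feature dim

# Optimizer steps per epoch: 50,000 CIFAR-10 training images with batch_size=32
steps_per_epoch = math.ceil(50_000 / 32)

print(f"Trainable parameters (ViT head):  {vit_head_params}")
print(f"Trainable parameters (ResNet fc): {resnet_fc_params}")
print(f"Steps per epoch: {steps_per_epoch}")
```

With so few trainable parameters, the cost of each step is dominated by the forward pass through the frozen backbone, not by the gradient update itself.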

## 5. ViT vs ResNet Comparison

In [None]:
if TIMM_AVAILABLE:
    # Train ResNet for comparison
    print("\nFine-tuning ResNet-50 on CIFAR-10")
    
    for param in model_resnet.parameters():
        param.requires_grad = False
    for param in model_resnet.fc.parameters():
        param.requires_grad = True
    
    optimizer_resnet = torch.optim.AdamW(model_resnet.fc.parameters(), lr=1e-3)
    
    history_resnet = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        
        train_loss, train_acc = train_epoch(model_resnet, train_loader, criterion, optimizer_resnet, device)
        val_loss, val_acc = validate(model_resnet, test_loader, criterion, device)
        
        history_resnet['train_loss'].append(train_loss)
        history_resnet['train_acc'].append(train_acc)
        history_resnet['val_loss'].append(val_loss)
        history_resnet['val_acc'].append(val_acc)
        
        print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
        print(f"Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%")
    
    print(f"\nResNet final accuracy: {history_resnet['val_acc'][-1]:.2f}%")

In [None]:
if TIMM_AVAILABLE:
    # Plot the comparison
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    epochs = range(1, num_epochs + 1)
    
    # Loss
    axes[0].plot(epochs, history_vit['train_loss'], 'b-', marker='o', label='ViT Train')
    axes[0].plot(epochs, history_vit['val_loss'], 'b--', marker='o', label='ViT Val')
    axes[0].plot(epochs, history_resnet['train_loss'], 'r-', marker='s', label='ResNet Train')
    axes[0].plot(epochs, history_resnet['val_loss'], 'r--', marker='s', label='ResNet Val')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Loss Comparison')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Accuracy
    axes[1].plot(epochs, history_vit['train_acc'], 'b-', marker='o', label='ViT Train')
    axes[1].plot(epochs, history_vit['val_acc'], 'b--', marker='o', label='ViT Val')
    axes[1].plot(epochs, history_resnet['train_acc'], 'r-', marker='s', label='ResNet Train')
    axes[1].plot(epochs, history_resnet['val_acc'], 'r--', marker='s', label='ResNet Val')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title('Accuracy Comparison')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Comparison table
    print("\n" + "="*70)
    print(f"{'Model':<20} {'Params (M)':<15} {'Val Accuracy':<20} {'Val Loss'}")
    print("="*70)
    print(f"{'ViT-Base':<20} {sum(p.numel() for p in model_vit.parameters())/1e6:<15.2f} "
          f"{history_vit['val_acc'][-1]:<20.2f} {history_vit['val_loss'][-1]:.4f}")
    print(f"{'ResNet-50':<20} {sum(p.numel() for p in model_resnet.parameters())/1e6:<15.2f} "
          f"{history_resnet['val_acc'][-1]:<20.2f} {history_resnet['val_loss'][-1]:.4f}")
    print("="*70)

## 6. Attention Map Visualization

Visualize what ViT "looks at" in the image.

In [None]:
def visualize_attention(model, image, device):
    """Return the last layer's attention weights for a single image.

    Only the custom VisionTransformer defined above is supported, because its
    forward pass returns (logits, attn_weights). timm models also expose a
    `blocks` attribute but return logits only, so capturing their attention
    would require a forward hook that depends on the installed timm version.
    """
    model.eval()

    if not isinstance(model, VisionTransformer):
        print("Attention visualization is not supported for this model")
        return None

    with torch.no_grad():
        _, attn_weights = model(image.unsqueeze(0).to(device))

    # Last layer, first item of the batch: [num_heads, N, N]
    return attn_weights[-1][0]

# Example with our custom VisionTransformer
# (randomly initialized here, so the maps illustrate the mechanics rather than learned behavior)
model_custom = VisionTransformer(
    img_size=224, patch_size=16, num_classes=10,
    embed_dim=384, depth=6, num_heads=6
).to(device)

# Load an image
img, label = test_dataset[0]

# Get the attention weights
attn = visualize_attention(model_custom, img, device)

if attn is not None:
    print(f"Attention shape: {attn.shape}")
    print(f"[num_heads, num_patches+1, num_patches+1]")
    
    # Visualize the attention of the [CLS] token
    n_heads = attn.shape[0]
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()
    
    # Original image
    img_display = img.permute(1, 2, 0).cpu().numpy()
    img_display = img_display * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img_display = np.clip(img_display, 0, 1)
    
    for i in range(6):
        if i < n_heads:
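            # Attention from the [CLS] token (first row, skipping [CLS] itself)
            attn_map = attn[i, 0, 1:].cpu()
            
            # Reshape into the patch grid, with batch and channel dims for interpolation
            n_patches = int(math.sqrt(attn_map.numel()))
            attn_map = attn_map.reshape(1, 1, n_patches, n_patches)
            
            # Interpolate to the image size (bilinear, using torch so no extra import is needed)
            attn_map_resized = F.interpolate(
                attn_map, size=(224, 224), mode='bilinear', align_corners=False
            )[0, 0].numpy()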
            
            # Overlay
            axes[i].imshow(img_display)
            axes[i].imshow(attn_map_resized, cmap='hot', alpha=0.6)
            axes[i].set_title(f'Head {i+1}')
            axes[i].axis('off')
        else:
            axes[i].axis('off')
    
    plt.suptitle(f'Attention Maps - Class: {classes[label]}')
    plt.tight_layout()
    plt.show()
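
The reshaping used for the overlay can be followed on a toy example. The sketch below works on plain Python lists with a 2x2 patch grid and made-up weights; the helper names `cls_attention_grid` and `upsample_nearest` are hypothetical. Note that it upsamples by nearest-neighbour repetition, which is a simpler stand-in for the smooth interpolation used in the notebook cell.

```python
import math


def cls_attention_grid(attn_row):
    """Turn the [CLS] attention row (length n_patches + 1) into a square patch grid."""
    patch_weights = attn_row[1:]                 # drop [CLS] -> [CLS] attention
    side = math.isqrt(len(patch_weights))
    if side * side != len(patch_weights):
        raise ValueError("number of patch tokens must be a perfect square")
    return [patch_weights[r * side:(r + 1) * side] for r in range(side)]


def upsample_nearest(grid, factor):
    """Repeat every cell `factor` times along both axes (nearest-neighbour upsampling)."""
    return [
        [value for value in row for _ in range(factor)]
        for row in grid
        for _ in range(factor)
    ]


# Made-up attention row for 1 [CLS] token + 4 patch tokens (sums to 1)
row = [0.2, 0.1, 0.3, 0.15, 0.25]
grid = cls_attention_grid(row)
heatmap = upsample_nearest(grid, factor=2)

for line in heatmap:
    print(line)
```

In the notebook the same idea applies at full scale: 196 patch weights become a 14x14 grid, which is then stretched to 224x224 so that each cell covers its 16x16 patch.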

## 7. CLIP - Zero-Shot Classification

Use CLIP to classify images without any task-specific training.

In [None]:
if CLIP_AVAILABLE:
    # Load CLIP
    model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
    
    print("CLIP ViT-B/32 loaded")
    print(f"Parameters: {sum(p.numel() for p in model_clip.parameters()) / 1e6:.2f}M")
else:
    print("CLIP not available")

In [None]:
if CLIP_AVAILABLE:
    # Zero-shot classification on CIFAR-10
    text_prompts = [f"a photo of a {c}" for c in classes]
    text_tokens = clip.tokenize(text_prompts).to(device)
    
    # Test on a few samples
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.flatten()
    
    correct = 0
    total = 0
    
    with torch.no_grad():
        # Encode the text prompts only once
        text_features = model_clip.encode_text(text_tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        
        for i in range(10):
            # Raw CIFAR-10 image (32x32 uint8 array) converted to PIL,
            # bypassing the ViT transforms attached to test_dataset
            img_pil = transforms.ToPILImage()(test_dataset.data[i])
            label = test_dataset.targets[i]
            
            # Preprocess for CLIP
            img_clip = preprocess_clip(img_pil).unsqueeze(0).to(device)
            
            # Encode the image
            image_features = model_clip.encode_image(img_clip)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            
            # Compute similarities (the softmax yields probabilities in [0, 1])
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            values, indices = similarity[0].topk(3)
            
            # Prediction
            pred_idx = indices[0].item()
            pred_class = classes[pred_idx]
            true_class = classes[label]
            
            if pred_idx == label:
                correct += 1
            total += 1
            
            # Display (probabilities are scaled to percentages)
            axes[i].imshow(np.array(img_pil))
            axes[i].set_title(f'True: {true_class}\nPred: {pred_class} ({100 * values[0].item():.1f}%)',
                             color='green' if pred_idx == label else 'red')
            axes[i].axis('off')
            
            # Print the top 3
            print(f"\nImage {i+1} - True: {true_class}")
            for rank, (value, index) in enumerate(zip(values, indices)):
                print(f"  {rank+1}. {classes[index]:12s} {100 * value.item():5.1f}%")
    
    plt.tight_layout()
    plt.show()
    
    print(f"\n{'='*60}")
    print(f"CLIP Zero-Shot Accuracy: {100.*correct/total:.2f}% ({correct}/{total})")
    print(f"{'='*60}")
else:
    print("CLIP not available")
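
The scoring rule used in the cell above (normalize both embeddings, scale the cosine similarities by 100, then apply a softmax) can be reproduced without CLIP. The following is a minimal sketch on made-up 3-dimensional vectors; the embeddings are invented for illustration and are not real CLIP outputs, and the helper names `l2_normalize` and `zero_shot_probs` are hypothetical.

```python
import math


def l2_normalize(vector):
    """Scale a vector to unit Euclidean norm."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def zero_shot_probs(image_feat, text_feats, scale=100.0):
    """Softmax over scaled cosine similarities between one image and several prompts."""
    img = l2_normalize(image_feat)
    logits = [
        scale * sum(a * b for a, b in zip(img, l2_normalize(text)))
        for text in text_feats
    ]
    max_logit = max(logits)                        # subtract the max for numerical stability
    exps = [math.exp(l - max_logit) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Toy embeddings (made-up numbers, assumption for illustration only)
image = [0.9, 0.1, 0.2]
texts = {
    "a photo of a cat": [1.0, 0.0, 0.1],
    "a photo of a dog": [0.2, 1.0, 0.0],
    "a photo of a ship": [0.0, 0.1, 1.0],
}

probs = zero_shot_probs(image, list(texts.values()))
best_prompt = max(zip(texts, probs), key=lambda pair: pair[1])[0]
print(best_prompt)
```

The large scale factor sharpens the distribution: cosine similarities only span [-1, 1], so without scaling the softmax would stay close to uniform even when one prompt clearly matches best.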

## Summary

In this notebook, we explored:

1. **ViT architecture**:
   - Patch embedding (splitting the image into patches)
   - Positional encoding
   - Multi-head self-attention
   - Transformer encoder
   - Classification via the [CLS] token

2. **timm (PyTorch Image Models)**:
   - ViT-Base pretrained on ImageNet
   - Fine-tuning on CIFAR-10
   - Comparison with ResNet-50

3. **Attention maps**:
   - Visualizing what ViT "looks at"
   - Interpretability of the individual heads

4. **CLIP**:
   - Vision-language model
   - Zero-shot classification
   - No training needed for new classes

### Key Points
- **ViT**: Transformers adapted to vision through patches
- **Self-attention**: captures global dependencies (vs. the local receptive fields of CNNs)
- **Data**: ViT needs large pretraining datasets (ImageNet-21k, JFT-300M)
- **CLIP**: aligns images and text to enable zero-shot classification
- **Attention maps**: often a more direct visualization than CNN feature maps

### Trade-offs ViT vs CNN
- **ViT**: better accuracy with large amounts of data, global dependencies
- **CNN**: better with little data, thanks to built-in inductive biases (locality, translation equivariance)

### Next Steps
- Explore Swin Transformer (local attention)
- Try DeiT (Data-efficient Image Transformers)
- Apply CLIP to your own tasks