# Swin Transformer

## 1. Définition du Modèle

In [None]:
import torch
import torch.nn as nn
import timm

class RakutenSwin(nn.Module):
    """Swin Transformer backbone with an MLP classification head for Rakuten.

    The timm backbone is created with its own classifier removed
    (``num_classes=0``) and average pooling, and its ``num_features``-wide
    output feeds a LayerNorm -> Dropout -> Linear -> GELU -> Dropout -> Linear
    head. Stochastic depth inside the backbone is set by ``drop_path_rate``.

    Args:
        model_name: timm identifier of the Swin variant to instantiate.
        num_classes: number of logits produced by the head.
        pretrained: whether timm loads pretrained backbone weights.
        drop_path_rate: stochastic-depth rate forwarded to the backbone.
    """

    def __init__(self, model_name='swin_base_patch4_window7_224', num_classes=27,
                 pretrained=True, drop_path_rate=0.3):
        super().__init__()

        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,  # drop timm's classifier; the custom head below replaces it
            global_pool='avg',
            drop_path_rate=drop_path_rate,
        )

        in_features = self.backbone.num_features
        hidden_dim = 512
        head_layers = [
            nn.LayerNorm(in_features),
            nn.Dropout(p=0.5),
            nn.Linear(in_features, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=0.3),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.head = nn.Sequential(*head_layers)

        self.num_classes = num_classes
        self.model_name = model_name

    def forward(self, x):
        """Return the class logits computed by the head on the backbone features of ``x``."""
        features = self.backbone(x)
        return self.head(features)


print("Modèle RakutenSwin défini")

## 2. Configuration

In [None]:
import sys
from pathlib import Path
import pandas as pd
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.metrics import classification_report, accuracy_score, f1_score
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import json

import random

import numpy as np

# Paths: the notebook is assumed to sit two directories below the project
# root (TODO confirm if the notebook is moved); both the root and its
# ``scripts`` folder are put on sys.path for the project imports below.
project_root = Path.cwd().parent.parent
sys.path.insert(0, str(project_root))
sys.path.insert(0, str(project_root / "scripts"))

from src.rakuten_image.datasets import RakutenImageDataset
from load_data import split_data

_cuda_available = torch.cuda.is_available()

# Configuration
CONFIG = {
    "data_dir": Path("/workspace/data"),
    "img_dir": Path("/workspace/data/images/image_train"),
    "checkpoint_dir": Path("/workspace/checkpoints/swin_final"),
    "model_name": "swin_base_patch4_window7_224",
    "img_size": 224,
    "batch_size": 32,
    "num_epochs": 30,
    "learning_rate": 5e-5,
    "weight_decay": 0.05,
    "random_state": 42,
    "early_stopping_patience": 5,
    "drop_path_rate": 0.3,
    "mixup_alpha": 0.8,
    "cutmix_alpha": 1.0,
    "label_smoothing": 0.1,
    "device": "cuda" if _cuda_available else "cpu",
    "num_workers": 4,
    # Mixed precision only makes sense on CUDA: the training and evaluation
    # cells hardcode autocast(device_type="cuda") / GradScaler('cuda'), so
    # enabling AMP on a CPU-only machine would be misleading.
    "use_amp": _cuda_available,
}


def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    This makes head initialisation, data shuffling and random augmentations
    repeatable across runs; bit-exact GPU determinism is not guaranteed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


_seed_everything(CONFIG["random_state"])

device = torch.device(CONFIG["device"])
print(f"Device: {device}")

## 3. Chargement des Données

In [None]:
# Load the development / hold-out split provided by the project helper
X_dev, X_holdout, y_dev, y_holdout = split_data()

# Attach the raw product-type code as a column of each feature frame
df_dev = X_dev.assign(prdtypecode=y_dev)
df_holdout = X_holdout.assign(prdtypecode=y_holdout)

print(f"Développement: {len(df_dev):,} | Test: {len(df_holdout):,}")

# Encode labels: classes are learned from the development set only and the
# same mapping is applied to the hold-out set
label_encoder = LabelEncoder()
df_dev['encoded_label'] = label_encoder.fit_transform(df_dev['prdtypecode'])
df_holdout['encoded_label'] = label_encoder.transform(df_holdout['prdtypecode'])
num_classes = len(label_encoder.classes_)

# Stratified 85% / 15% train/validation split of the development rows
dev_train_idx, dev_val_idx = train_test_split(
    df_dev.index,
    test_size=0.15,
    random_state=CONFIG["random_state"],
    stratify=df_dev['encoded_label'],
)

# Give every split a fresh 0..n-1 positional index
df_train, df_val, df_holdout = [
    frame.reset_index(drop=True)
    for frame in (df_dev.loc[dev_train_idx], df_dev.loc[dev_val_idx], df_holdout)
]

print(f"Train: {len(df_train):,} | Val: {len(df_val):,} | Test: {len(df_holdout):,}")

## 4. Préparation des DataLoaders

In [None]:
# Normalisation statistics (ImageNet mean/std) shared by every split
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
_target_size = (CONFIG["img_size"], CONFIG["img_size"])
_normalize = transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std)

# Training pipeline: resize, random flip + RandAugment, then tensor/normalise
train_transform = transforms.Compose([
    transforms.Resize(_target_size),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    _normalize,
])

# Evaluation pipeline: deterministic resize + tensor/normalise only
val_transform = transforms.Compose([
    transforms.Resize(_target_size),
    transforms.ToTensor(),
    _normalize,
])

# One dataset per split, all reading images from the same directory
train_dataset, val_dataset, test_dataset = (
    RakutenImageDataset(frame, CONFIG["img_dir"], pipeline, "encoded_label")
    for frame, pipeline in (
        (df_train, train_transform),
        (df_val, val_transform),
        (df_holdout, val_transform),
    )
)

# Loaders share batch size / workers / pinning; only training shuffles and
# drops the last incomplete batch
_loader_kwargs = dict(
    batch_size=CONFIG["batch_size"],
    num_workers=CONFIG["num_workers"],
    pin_memory=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

print(f"Batches - Train: {len(train_loader)} | Val: {len(val_loader)} | Test: {len(test_loader)}")

## 5. Initialisation du Modèle et Optimiseur

In [None]:
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy

# Model: Swin backbone + MLP head, moved to the training device
model = RakutenSwin(
    model_name=CONFIG["model_name"],
    num_classes=num_classes,
    pretrained=True,
    drop_path_rate=CONFIG["drop_path_rate"]
).to(device)

# Mixup/CutMix on every batch (prob=1.0), switching between the two with
# probability 0.5; targets come out as label-smoothed soft distributions,
# hence SoftTargetCrossEntropy for training.
mixup_fn = Mixup(
    mixup_alpha=CONFIG["mixup_alpha"], cutmix_alpha=CONFIG["cutmix_alpha"],
    prob=1.0, switch_prob=0.5, mode='batch',
    label_smoothing=CONFIG["label_smoothing"], num_classes=num_classes
)

criterion_train = SoftTargetCrossEntropy()
criterion_val = nn.CrossEntropyLoss()  # validation uses hard integer labels


def _param_groups_weight_decay(module, weight_decay,
                               skip_keywords=("relative_position_bias_table",
                                              "absolute_pos_embed")):
    """Split trainable parameters into a decayed and a non-decayed AdamW group.

    Biases, 1-D parameters (e.g. LayerNorm weights) and any parameter whose
    name contains one of ``skip_keywords`` (Swin's position-bias tables and
    absolute position embedding) receive no weight decay, matching the usual
    Swin fine-tuning recipe; everything else uses ``weight_decay``.

    Args:
        module: model whose ``named_parameters()`` are grouped.
        weight_decay: decay applied to the decayed group.
        skip_keywords: name substrings that force a parameter into the
            non-decayed group.

    Returns:
        list[dict]: parameter groups suitable for ``torch.optim.AdamW``.
    """
    decay, no_decay = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        exempt = (
            param.ndim <= 1
            or name.endswith(".bias")
            or any(keyword in name for keyword in skip_keywords)
        )
        (no_decay if exempt else decay).append(param)
    return [
        {"params": no_decay, "weight_decay": 0.0},
        {"params": decay, "weight_decay": weight_decay},
    ]


# Optimizer and scheduler: AdamW with per-group decay, cosine decay of the
# learning rate down to 1e-6 over num_epochs scheduler steps
optimizer = torch.optim.AdamW(
    _param_groups_weight_decay(model, CONFIG["weight_decay"]),
    lr=CONFIG["learning_rate"],
    weight_decay=CONFIG["weight_decay"],
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG["num_epochs"],
                                                       eta_min=1e-6)
scaler = torch.amp.GradScaler('cuda') if CONFIG["use_amp"] else None

CONFIG["checkpoint_dir"].mkdir(parents=True, exist_ok=True)
print("Modèle et optimiseur initialisés")

## 6. Entraînement

In [None]:
import json
from contextlib import nullcontext


def _autocast():
    """Return a fresh CUDA autocast context when AMP is enabled, else a no-op context."""
    if CONFIG["use_amp"]:
        return torch.amp.autocast(device_type="cuda")
    return nullcontext()


def _train_one_epoch():
    """Run one Mixup/CutMix training pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``scaler``, ``mixup_fn``
    and ``criterion_train``. Gradients are clipped to a global norm of 5.0.

    Returns:
        float: mean per-batch training loss (measured against the mixed soft targets).
    """
    model.train()
    running_loss = 0.0

    for images, labels in tqdm(train_loader, desc="Entraînement"):
        images, labels = images.to(device), labels.to(device)
        # Mixup/CutMix blends the images and turns the labels into soft targets.
        images, labels = mixup_fn(images, labels)

        optimizer.zero_grad()
        with _autocast():
            outputs = model(images)
            loss = criterion_train(outputs, labels)

        if CONFIG["use_amp"]:
            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

        running_loss += loss.item()

    return running_loss / len(train_loader)


@torch.no_grad()
def _validate():
    """Evaluate ``model`` on ``val_loader`` with hard labels.

    Returns:
        tuple[float, float, float]: mean per-batch cross-entropy loss,
        accuracy in percent, and weighted F1 score.
    """
    model.eval()
    total_loss = 0.0
    preds, targets = [], []

    for images, labels in tqdm(val_loader, desc="Validation"):
        images, labels = images.to(device), labels.to(device)
        with _autocast():
            outputs = model(images)
            loss = criterion_val(outputs, labels)

        total_loss += loss.item()
        preds.extend(torch.argmax(outputs, dim=-1).cpu().numpy())
        targets.extend(labels.cpu().numpy())

    accuracy = 100.0 * accuracy_score(targets, preds)
    weighted_f1 = f1_score(targets, preds, average='weighted')
    # Plain Python floats keep the checkpoint and history JSON free of NumPy scalars.
    return total_loss / len(val_loader), float(accuracy), float(weighted_f1)


best_val_acc = 0.0
patience_counter = 0
history = {"train_loss": [], "val_loss": [], "val_acc": [], "val_f1": []}
history_path = CONFIG["checkpoint_dir"] / "history.json"

for epoch in range(CONFIG["num_epochs"]):
    print(f"\nEpoch {epoch + 1}/{CONFIG['num_epochs']}")

    avg_train_loss = _train_one_epoch()
    avg_val_loss, val_accuracy, val_f1 = _validate()

    history["train_loss"].append(avg_train_loss)
    history["val_loss"].append(avg_val_loss)
    history["val_acc"].append(val_accuracy)
    history["val_f1"].append(val_f1)
    # Persist the learning curves every epoch so they survive an interrupted run.
    with open(history_path, "w") as f:
        json.dump(history, f, indent=2)

    print(f"Loss - Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}")
    print(f"Val Acc: {val_accuracy:.2f}% | F1: {val_f1:.4f}")

    # Checkpoint on improved validation accuracy; otherwise count toward early stopping.
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        patience_counter = 0
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'val_acc': val_accuracy,
            'val_f1': val_f1
        }, CONFIG["checkpoint_dir"] / "best_model.pth")
        print("Meilleur modèle sauvegardé!")
    else:
        patience_counter += 1
        if patience_counter >= CONFIG["early_stopping_patience"]:
            print(f"Arrêt précoce après {epoch + 1} epochs")
            break

    scheduler.step()

print(f"\nEntraînement terminé. Meilleure Val Acc: {best_val_acc:.2f}%")

## 7. Évaluation Finale

In [None]:
from contextlib import nullcontext

# Reload the best weights written by the training loop. map_location keeps
# this cell usable if the checkpoint was saved on a different device.
# weights_only=False: the file is produced locally by this notebook and may
# hold NumPy scalar metrics, which the restricted weights-only loader rejects.
checkpoint = torch.load(
    CONFIG["checkpoint_dir"] / "best_model.pth",
    map_location=device,
    weights_only=False,
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Predict on the held-out test set
all_preds, all_labels = [], []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Test"):
        images = images.to(device)
        amp_ctx = torch.amp.autocast(device_type="cuda") if CONFIG["use_amp"] else nullcontext()
        with amp_ctx:
            outputs = model(images)

        all_preds.extend(torch.argmax(outputs, dim=-1).cpu().numpy())
        all_labels.extend(labels.numpy())

test_acc = 100.0 * accuracy_score(all_labels, all_preds)
test_f1 = f1_score(all_labels, all_preds, average='weighted')

print(f"\nRésultats Test - Acc: {test_acc:.2f}% | F1: {test_f1:.4f}")
print("\nRapport de classification:")
# Report per original prdtypecode rather than per encoded index. The full
# label list is passed explicitly because target_names must have one entry
# per label; a class missing from both y_true and y_pred would otherwise
# make classification_report raise.
print(classification_report(
    all_labels,
    all_preds,
    labels=list(range(num_classes)),
    target_names=[str(code) for code in label_encoder.classes_],
    digits=4,
))

# Save a JSON summary of the run next to the checkpoint
results = {
    "best_epoch": int(checkpoint['epoch']),
    "val_acc": float(checkpoint['val_acc']),
    "val_f1": float(checkpoint['val_f1']),
    "test_acc": float(test_acc),
    "test_f1": float(test_f1),
    "num_classes": int(num_classes)
}

with open(CONFIG["checkpoint_dir"] / "results.json", "w") as f:
    json.dump(results, f, indent=2)

print("Résultats sauvegardés")