# Swin Transformer Training Notebook (Kaggle Version)

This notebook trains a **Swin Transformer** model for Face Recognition.
It uses `timm` for easy model creation and is optimized for Kaggle.

**Prerequisite:** Upload your `data_processed` folder as a Kaggle Dataset.

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm.notebook import tqdm
import sys
import matplotlib.pyplot as plt
import seaborn as sns

# ==========================================
# KAGGLE SETUP
# ==========================================
# Install dependencies (Clean & Simple!)
!pip install -q timm albumentations scikit-learn

import timm
print(f"‚úÖ timm version: {timm.__version__}")

# Check GPU
if torch.cuda.is_available():
    print(f"‚úÖ GPU Detected: {torch.cuda.get_device_name(0)}")
else:
    print("‚ö†Ô∏è Warning: No GPU detected. Training will be slow.")

In [None]:
# ==========================================
# INLINE UTILS (Dataset & Augmentations)
# ==========================================
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

class FaceDataset(Dataset):
    """In-memory image dataset yielding ``(image_tensor, label_tensor)`` pairs.

    Args:
        images: Indexable collection of images as NumPy arrays. The
            no-transform path permutes axes ``(2, 0, 1)``, so it expects
            channel-last 3-D arrays.
        labels: Indexable collection of integer class ids, aligned with
            ``images``.
        transform: Optional callable invoked as ``transform(image=...)`` that
            returns a mapping with an ``'image'`` entry (the Albumentations
            calling convention). A falsy value disables it.
    """

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as ``(image, label)``.

        The label is always a ``torch.long`` scalar tensor. Without a
        transform the image is a float tensor in channel-first order scaled
        by 1/255; with a transform it is whatever the transform put under
        ``'image'``.
        """
        sample = self.images[idx]
        # The pipelines operate on 8-bit pixels, so coerce any other dtype.
        if sample.dtype != np.uint8:
            sample = sample.astype(np.uint8)

        target = torch.tensor(self.labels[idx], dtype=torch.long)

        if not self.transform:
            # HWC uint8 -> CHW float scaled by 1/255.
            pixels = torch.from_numpy(sample).permute(2, 0, 1).float() / 255.0
            return pixels, target

        return self.transform(image=sample)['image'], target

# Augmentation pipelines
# Channel statistics used by both pipelines (the widely used ImageNet
# mean/std values).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Training: random geometric/photometric jitter, then normalise and convert
# to a tensor. NOTE: RandomResizedCrop is applied with p=0.5, so samples that
# skip it keep their incoming size; this relies on the images already being
# 224x224 when they reach the pipeline.
_train_steps = [
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=15, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0), p=0.5),
    A.GaussianBlur(blur_limit=3, p=0.2),
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
]
train_transforms = A.Compose(_train_steps)

# Evaluation: deterministic resize, normalise, convert to a tensor.
_eval_steps = [
    A.Resize(224, 224),
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
]
test_transforms = A.Compose(_eval_steps)

In [None]:
# ==========================================
# DATA LOADING
# ==========================================
# Replace 'face-recognition-data' with the name of your own Kaggle dataset.
DATA_DIR = "/kaggle/input/face-recognition-data/data_processed"
# Size handed to cv2.resize for every loaded image.
TARGET_SIZE = (224, 224)

def load_dataset_kaggle(data_dir=None, target_size=None):
    """Load a folder-per-class image dataset fully into memory.

    Every sub-directory of the dataset root is treated as one class; class
    ids are assigned by the sorted order of the directory names. Images are
    read with OpenCV, converted from BGR to RGB and resized.

    Args:
        data_dir: Dataset root. Defaults to the module-level ``DATA_DIR``.
        target_size: Size passed to ``cv2.resize``. Defaults to the
            module-level ``TARGET_SIZE``.

    Returns:
        ``(X, y, label_map)`` where ``X`` is an array of RGB images, ``y`` the
        matching class ids and ``label_map`` maps folder name -> class id.
        If the root directory is missing, ``([], [], {})`` is returned after
        printing a hint.
    """
    data_dir = DATA_DIR if data_dir is None else data_dir
    target_size = TARGET_SIZE if target_size is None else target_size
    image_suffixes = (".jpg", ".png", ".jpeg", ".webp")

    X = []
    y = []
    label_map = {}

    if not os.path.isdir(data_dir):
        print(f"‚ùå Error: Path {data_dir} tidak ditemukan.")
        print("Pastikan Anda sudah Add Data di sidebar kanan Kaggle.")
        return [], [], {}

    # Only directories are classes; stray files in the root (README, CSV, ...)
    # would otherwise get a label and crash os.listdir below.
    class_folders = sorted(
        name for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name))
    )

    for idx, folder in enumerate(class_folders):
        label_map[folder] = idx
        folder_path = os.path.join(data_dir, folder)

        # Sorted so the sample order (and hence the seeded split) is the same
        # on every run; os.listdir order is arbitrary.
        for filename in sorted(os.listdir(folder_path)):
            if not filename.lower().endswith(image_suffixes):
                continue
            img_path = os.path.join(folder_path, filename)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread signals an unreadable/corrupt file with None.
                print(f"Skipping unreadable image: {img_path}")
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, target_size)
            X.append(img)
            y.append(idx)

    X = np.array(X)
    y = np.array(y)
    print(f"üì¶ Loaded {len(X)} images from {len(label_map)} classes.")
    return X, y, label_map

def split_dataset(X, y):
    """Split samples into train / validation / test subsets.

    20% of the data is held out as the test set, then 10% of the remainder
    becomes the validation set. Both cuts are stratified by label first; if
    scikit-learn rejects that with ``ValueError``, both cuts are redone
    without stratification. ``random_state=42`` is used for every cut.

    Returns:
        ``(X_train, X_val, X_test, y_train, y_val, y_test)``.
    """
    def _two_stage(use_strata):
        # First cut: hold out the test portion.
        X_fit, X_hold, y_fit, y_hold = train_test_split(
            X, y,
            test_size=0.2,
            stratify=y if use_strata else None,
            random_state=42,
        )
        # Second cut: carve the validation portion out of what is left.
        X_fit, X_dev, y_fit, y_dev = train_test_split(
            X_fit, y_fit,
            test_size=0.1,
            stratify=y_fit if use_strata else None,
            random_state=42,
        )
        return X_fit, X_dev, X_hold, y_fit, y_dev, y_hold

    try:
        # Preferred: keep class proportions equal across the subsets.
        return _two_stage(True)
    except ValueError:
        print("‚ö†Ô∏è Warning: Stratified split failed. Falling back to random split.")
        return _two_stage(False)

In [None]:
# ==========================================
# CONFIGURATION (HYPERPARAMETERS)
# ==========================================
# Change the values here to run experiments.
CONFIG = {
    'BATCH_SIZE': 32,          # Images per batch (lower this on out-of-memory errors)
    'EPOCHS': 20,              # Number of training epochs
    'LEARNING_RATE': 1e-4,     # Optimizer step size (1e-4 = 0.0001)
    'WEIGHT_DECAY': 0.05,      # Regularisation to curb overfitting
    'IMAGE_SIZE': 224,         # Input image size (Swin Tiny uses 224)
                               # NOTE(review): not referenced by the code visible in
                               # this notebook; the transforms hard-code 224.
    'NUM_WORKERS': 2,          # Number of DataLoader worker processes
    'SEED': 42,                # Seed for reproducibility
    'CHECKPOINT_DIR': "/kaggle/working/checkpoints",
    'EARLY_STOPPING_PATIENCE': 5 # Stop if the loss has not improved for 5 epochs
}

# Actually apply the configured seed. Without these calls the 'SEED' entry has
# no effect on weight initialisation, DataLoader shuffling or the
# NumPy-based sampling used later in the notebook.
np.random.seed(CONFIG['SEED'])
torch.manual_seed(CONFIG['SEED'])

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs(CONFIG['CHECKPOINT_DIR'], exist_ok=True)

print("‚öôÔ∏è Configuration:")
for key, val in CONFIG.items():
    print(f"   {key}: {val}")

def create_swin_model(num_classes, model_name='swin_tiny_patch4_window7_224', pretrained=True):
    """Build a timm classification model with a ``num_classes``-way head.

    Args:
        num_classes: Size of the classifier output, forwarded to timm.
        model_name: timm architecture identifier. Defaults to the Swin-Tiny
            variant this notebook was written for.
        pretrained: Whether timm should load pretrained weights. Set this to
            ``False`` when the weights cannot be downloaded (presumably the
            case in a session without internet access).

    Returns:
        The model returned by ``timm.create_model`` (not yet moved to a
        device).
    """
    print("ü¶¢ Creating Swin Transformer Model...")
    return timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)

In [None]:
# ==========================================
# TRAINING HELPERS
# ==========================================
def train_one_epoch(model, loader, criterion, optimizer, epoch):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to train mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable returning a scalar tensor per batch.
        optimizer: Optimizer whose gradients are zeroed and stepped per batch.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` for the epoch.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{CONFIG['EPOCHS']} [Train]")
    for images, labels in pbar:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # running_loss sums one value per batch, so the running average must
        # divide by batches seen (not samples) to match the returned loss.
        pbar.set_postfix({'loss': running_loss / num_batches, 'acc': 100 * correct / total})

    if total == 0:
        raise ValueError("train_one_epoch: loader produced no samples")
    return running_loss / num_batches, 100 * correct / total

def validate(model, loader, criterion, epoch):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable returning a scalar tensor per batch.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` over the loader.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{CONFIG['EPOCHS']} [Val]")
        for images, labels in pbar:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # running_loss sums one value per batch, so the running average
            # must divide by batches seen (not samples) to match the return.
            pbar.set_postfix({'loss': running_loss / num_batches, 'acc': 100 * correct / total})

    if total == 0:
        raise ValueError("validate: loader produced no samples")
    return running_loss / num_batches, 100 * correct / total

In [None]:
# ==========================================
# EXECUTION
# ==========================================
# End-to-end run: load -> split -> train with early stopping -> plot curves ->
# report top-k accuracy on the held-out test split -> show sample predictions.
print("üöÄ Starting Swin Transformer Training...")

# 1. Load Data
X, y, label_map = load_dataset_kaggle()

if len(X) > 0:
    num_classes = len(label_map)
    print(f"‚úÖ Detected {num_classes} classes.")

    # 2. Split
    X_train, X_val, X_test, y_train, y_val, y_test = split_dataset(X, y)

    # 3. Dataloaders
    train_dataset = FaceDataset(X_train, y_train, transform=train_transforms)
    val_dataset = FaceDataset(X_val, y_val, transform=test_transforms)
    # The test split is only touched after training, for the final report.
    test_dataset = FaceDataset(X_test, y_test, transform=test_transforms)
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=CONFIG['NUM_WORKERS'])
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=CONFIG['NUM_WORKERS'])
    test_loader = DataLoader(test_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=CONFIG['NUM_WORKERS'])

    # 4. Model
    model = create_swin_model(num_classes).to(DEVICE)

    # 5. Train
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=CONFIG['LEARNING_RATE'], weight_decay=CONFIG['WEIGHT_DECAY'])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['EPOCHS'])

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_ckpt_path = os.path.join(CONFIG['CHECKPOINT_DIR'], "swin_best.pth")
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even if validation accuracy is exactly 0%.
    best_acc = -1.0
    checkpoint_saved = False
    best_loss = float('inf')
    patience_counter = 0

    for epoch in range(CONFIG['EPOCHS']):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, epoch)
        val_loss, val_acc = validate(model, val_loader, criterion, epoch)
        scheduler.step()

        # Store history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"Summary: Train Loss {train_loss:.4f} | Val Acc {val_acc:.2f}%")

        # Save Best Model (Based on Accuracy)
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), best_ckpt_path)
            checkpoint_saved = True
            print("üíæ Saved Best Model (Best Accuracy)")

        # Early Stopping Logic (Based on Loss)
        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"‚ö†Ô∏è Early Stopping Counter: {patience_counter}/{CONFIG['EARLY_STOPPING_PATIENCE']}")

        if patience_counter >= CONFIG['EARLY_STOPPING_PATIENCE']:
            print("üõë Early Stopping Triggered! Validation loss stopped improving.")
            break

    # ==========================================
    # VISUALIZATION & METRICS
    # ==========================================
    # A. Plot Loss & Accuracy
    plt.figure(figsize=(12, 5))

    # Loss
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.title('Loss History')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # Accuracy
    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.title('Accuracy History')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(CONFIG['CHECKPOINT_DIR'], 'training_history.png'))
    plt.show()

    # B. Top-K Accuracy (Better for many classes)
    print("üîç Calculating Top-1 and Top-5 Accuracy...")
    # Only reload a checkpoint written during this run; a file left over from
    # an earlier run may not match the current model.
    if checkpoint_saved:
        model.load_state_dict(torch.load(best_ckpt_path, map_location=DEVICE))
    model.eval()

    correct_1 = 0
    correct_5 = 0
    total = 0
    max_k = min(5, num_classes)  # Handle if classes < 5

    # Evaluate on the test split: the validation split already drove
    # checkpoint selection and early stopping, so it is not an unbiased score.
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)

            # One sorted top-k call serves both metrics: column 0 is the
            # top-1 prediction, and each row holds distinct class indices so
            # it can match the label at most once.
            _, top_preds = outputs.topk(max_k, dim=1, largest=True, sorted=True)
            hits = top_preds.eq(labels.view(-1, 1))
            correct_1 += hits[:, 0].sum().item()
            correct_5 += hits.sum().item()

            total += labels.size(0)

    acc_1 = 100 * correct_1 / total
    acc_5 = 100 * correct_5 / total
    print(f"\nüèÜ Final Test Results:")
    print(f"   Top-1 Accuracy: {acc_1:.2f}%")
    print(f"   Top-5 Accuracy: {acc_5:.2f}%")

    # C. Visualize Predictions (up to 5 Random Images)
    print("\nüñºÔ∏è Visualizing Predictions...")

    # Get a batch
    images, labels = next(iter(test_loader))
    images, labels = images.to(DEVICE), labels.to(DEVICE)

    # Predict (inference only, so skip autograd bookkeeping)
    with torch.no_grad():
        outputs = model(images)
    preds = outputs.argmax(dim=1)

    # Select up to 5 random indices; the batch may hold fewer than 5 images.
    n_show = min(5, len(images))
    indices = np.random.choice(len(images), n_show, replace=False)

    # squeeze=False keeps `axes` 2-D even when only one panel is drawn.
    fig, axes = plt.subplots(1, n_show, figsize=(4 * n_show, 4), squeeze=False)
    axes = axes[0]

    # Inverse Normalize for display
    mean = torch.tensor([0.485, 0.456, 0.406]).to(DEVICE).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).to(DEVICE).view(3, 1, 1)

    idx_to_class = {v: k for k, v in label_map.items()}

    for i, idx in enumerate(indices):
        img = images[idx]
        # Un-normalize
        img = img * std + mean
        img = torch.clamp(img, 0, 1)
        img = img.permute(1, 2, 0).cpu().numpy()

        true_label = idx_to_class[labels[idx].item()]
        pred_label = idx_to_class[preds[idx].item()]

        color = 'green' if true_label == pred_label else 'red'

        axes[i].imshow(img)
        axes[i].set_title(f"True: {true_label}\nPred: {pred_label}", color=color)
        axes[i].axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(CONFIG['CHECKPOINT_DIR'], 'prediction_samples.png'))
    plt.show()

else:
    print("‚ö†Ô∏è No data found. Please check dataset path.")