# Vision Mamba Training Notebook (Kaggle Version)

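This notebook is set up to run on Kaggle.
**Prerequisite:** upload your `data_processed` folder as a Kaggle Dataset (one sub-folder per class, each containing that class's images) and attach it to the notebook via *Add Data*.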

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm.notebook import tqdm
import sys

# ==========================================
# KAGGLE SETUP
# ==========================================
# Install dependencies (Kaggle usually needs these)
!pip install -q causal-conv1d>=1.1.0
!pip install -q mamba-ssm
!pip install -q timm albumentations

# Clone Vim Repo if not exists
if not os.path.exists('Vim'):
    !git clone https://github.com/hustvl/Vim.git

# Add paths
sys.path.append(os.path.abspath('Vim'))

# Copy utils from uploaded dataset or create them inline if they are not in the dataset
# Assuming you uploaded the WHOLE project folder as a dataset named 'face-recognition-project'
# The path would be /kaggle/input/face-recognition-project/...

# For simplicity, let's define the dataset class and transforms INLINE here
# to avoid dependency on uploading python files correctly.

In [None]:
# ==========================================
# INLINE UTILS (Dataset & Augmentations)
# ==========================================
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

class FaceDataset(Dataset):
    """In-memory image classification dataset backed by parallel sequences.

    Each item is ``(image, label)``:

    * ``image`` is produced by ``transform(image=...)['image']`` when a
      transform is given (albumentations-style call convention), otherwise a
      float tensor permuted from HWC to CHW and scaled by 1/255.
    * ``label`` is returned as a ``torch.long`` scalar tensor.

    Images that are not ``uint8`` are cast to ``uint8`` before either path.
    """

    def __init__(self, images, labels, transform=None):
        # images[i] pairs with labels[i]; lengths are presumably equal (not checked).
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        target = self.labels[idx]

        # Normalise the dtype so both the transform and the /255 path see 8-bit data.
        img = img if img.dtype == np.uint8 else img.astype(np.uint8)

        if not self.transform:
            sample = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        else:
            sample = self.transform(image=img)['image']

        return sample, torch.tensor(target, dtype=torch.long)

def _random_resized_crop(size, scale, p):
    """Build ``A.RandomResizedCrop`` in a way that works across albumentations versions.

    Newer albumentations releases take ``size=(height, width)`` and 2.x no
    longer accepts separate ``height``/``width`` arguments; older releases only
    accept ``height``/``width`` and raise ``TypeError`` for ``size``.
    """
    try:
        return A.RandomResizedCrop(size=size, scale=scale, p=p)
    except TypeError:
        return A.RandomResizedCrop(height=size[0], width=size[1], scale=scale, p=p)


# Training-time augmentation (mild geometric + photometric jitter), followed by
# ImageNet mean/std normalisation and conversion to a CHW tensor.
train_transforms = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=10, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    # `scale` is the fraction of the source image area kept by the crop, so it
    # must lie in (0, 1]; an area above 1.0 cannot be cropped from the image.
    _random_resized_crop(size=(224, 224), scale=(0.9, 1.0), p=0.5),
    A.GaussianBlur(blur_limit=3, p=0.2),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# Deterministic evaluation pipeline: resize, normalise, convert to tensor.
test_transforms = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

In [None]:
# ==========================================
# DATA LOADING (KAGGLE SPECIFIC)
# ==========================================
# Replace 'face-recognition-data' with the name of your dataset on Kaggle.
DATA_DIR = "/kaggle/input/face-recognition-data/data_processed"
# Output size handed to cv2.resize, given as (width, height).
TARGET_SIZE = (224,) * 2

def load_dataset_kaggle(data_dir=None, target_size=None):
    """Load every image under ``data_dir`` into memory, one class per sub-folder.

    Args:
        data_dir: Root folder whose sub-directories are class names. Defaults
            to the module-level ``DATA_DIR``.
        target_size: ``(width, height)`` passed to ``cv2.resize``. Defaults to
            the module-level ``TARGET_SIZE``.

    Returns:
        ``(X, y, label_map)``: ``X`` is a numpy array of the resized RGB images,
        ``y`` the matching integer labels, and ``label_map`` maps each class
        folder name to its label index (assigned in sorted folder order).
        If ``data_dir`` does not exist, ``([], [], {})`` is returned instead.

    Files whose extension is not .jpg/.jpeg/.png/.webp are ignored, and images
    OpenCV cannot decode are skipped with a warning.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    if target_size is None:
        target_size = TARGET_SIZE

    X = []
    y = []
    label_map = {}

    if not os.path.exists(data_dir):
        print(f"❌ Error: Path {data_dir} tidak ditemukan.")
        print("Pastikan Anda sudah Add Data di sidebar kanan Kaggle.")
        return [], [], {}

    # Only sub-directories are classes: stray top-level files (README,
    # .DS_Store, ...) would otherwise become bogus classes and os.listdir
    # would fail on them.
    class_dirs = sorted(
        name for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name))
    )
    for idx, folder in enumerate(class_dirs):
        label_map[folder] = idx
        folder_path = os.path.join(data_dir, folder)

        # Sorted so the sample order (and therefore the seeded split) does not
        # depend on the filesystem's directory listing order.
        for filename in sorted(os.listdir(folder_path)):
            if not filename.lower().endswith((".jpg", ".png", ".jpeg", ".webp")):
                continue
            img_path = os.path.join(folder_path, filename)
            bgr = cv2.imread(img_path)
            # imread reports unreadable/corrupt files by returning None, not raising.
            if bgr is None:
                print(f"⚠️ Skipping unreadable image: {img_path}")
                continue
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            X.append(cv2.resize(rgb, target_size))
            y.append(idx)

    X = np.array(X)
    y = np.array(y)
    print(f"📦 Loaded {len(X)} images from {len(label_map)} classes.")
    return X, y, label_map

def split_dataset(X, y):
    """Split samples into stratified train / validation / test subsets.

    20% of the data is held out as the test set, then 10% of the remainder
    becomes the validation set (about 72% / 8% / 20% overall). Both splits
    are stratified on the labels and seeded with 42.

    Returns:
        ``(X_train, X_val, X_test, y_train, y_val, y_test)``
    """
    seed = 42
    X_rest, X_test, y_rest, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=seed
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest, y_rest, test_size=0.1, stratify=y_rest, random_state=seed
    )
    return X_train, X_val, X_test, y_train, y_val, y_test

In [None]:
# ==========================================
# CONFIG & MODEL
# ==========================================
# Optimisation hyper-parameters
BATCH_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 1e-4

# Use the GPU when one is attached to the session, otherwise fall back to CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Kaggle outputs must be written under /kaggle/working.
CHECKPOINT_DIR = "/kaggle/working/checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def create_vim_model(num_classes):
    """Build a Vision Mamba (Vim) classifier for 224x224 inputs.

    Args:
        num_classes: Number of output classes for the classification head.

    Returns:
        The ``VisionMamba`` module, or ``None`` if an ``ImportError`` is raised
        while importing or constructing it. The underlying error is printed so
        the missing module can be identified.
    """
    # In the hustvl/Vim repo, models_mamba.py imports its sibling module with
    # `from rope import *`, so Vim/vim must be importable in addition to Vim/.
    for vim_path in (os.path.abspath('Vim'), os.path.abspath(os.path.join('Vim', 'vim'))):
        if os.path.isdir(vim_path) and vim_path not in sys.path:
            sys.path.append(vim_path)

    try:
        from vim.models_mamba import VisionMamba
        print("🐍 Menggunakan Vision Mamba (Vim) Model...")
        model = VisionMamba(
            img_size=224,
            patch_size=16,
            embed_dim=192,
            depth=24,
            rms_norm=True,
            residual_in_fp32=True,
            fused_add_norm=True,
            final_pool_type='mean',
            if_abs_pos_embed=True,
            if_rope=False,
            if_rope_residual=False,
            bimamba_type="v2",
            if_cls_token=True,
            if_devide_out=True,
            use_middle_cls_token=True,
            num_classes=num_classes
        )
        return model
    except ImportError as err:
        print("❌ Error: Library 'vim' tidak ditemukan.")
        # Surface the real cause (e.g. a missing `rope` or `mamba_ssm` module),
        # which the generic message above does not identify.
        print(f"   {type(err).__name__}: {err}")
        return None

In [None]:
# ==========================================
# TRAINING HELPERS
# ==========================================
def train_one_epoch(model, loader, criterion, optimizer, epoch):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; it is switched to train mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function; assumed to return the batch-mean loss
            (as ``nn.CrossEntropyLoss`` does with its default reduction).
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        ``(mean_loss, accuracy_percent)``, both averaged per sample over the
        epoch; ``(0.0, 0.0)`` if the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0  # sum of per-sample losses
    correct = 0
    total = 0
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
    for images, labels in pbar:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # The criterion yields a batch mean; weighting it by batch size keeps the
        # running value a true per-sample average (the last batch may be smaller).
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size
        pbar.set_postfix({'loss': loss_sum / total, 'acc': 100 * correct / total})

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, 100 * correct / total

def validate(model, loader, criterion, epoch):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function; assumed to return the batch-mean loss
            (as ``nn.CrossEntropyLoss`` does with its default reduction).
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        ``(mean_loss, accuracy_percent)``, both averaged per sample;
        ``(0.0, 0.0)`` if the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0  # sum of per-sample losses
    correct = 0
    total = 0
    with torch.no_grad():
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]")
        for images, labels in pbar:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight the batch-mean loss by batch size for a per-sample average.
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size
            pbar.set_postfix({'loss': loss_sum / total, 'acc': 100 * correct / total})

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, 100 * correct / total

In [None]:
# ==========================================
# EXECUTION
# ==========================================
print("🚀 Starting Kaggle Training...")

# 1. Load Data
X, y, label_map = load_dataset_kaggle()

if len(X) > 0:
    num_classes = len(label_map)
    print(f"✅ Detected {num_classes} classes.")

    # 2. Split
    X_train, X_val, X_test, y_train, y_val, y_test = split_dataset(X, y)

    # 3. Dataloaders. Only the training split is augmented. The test split is
    # never used for checkpoint selection, so it gives an unbiased final score.
    train_dataset = FaceDataset(X_train, y_train, transform=train_transforms)
    val_dataset = FaceDataset(X_val, y_val, transform=test_transforms)
    test_dataset = FaceDataset(X_test, y_test, transform=test_transforms)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # 4. Model
    model = create_vim_model(num_classes)
    if model is None:
        # create_vim_model returns None when Vim cannot be imported; stop with a
        # clear message instead of an AttributeError on `None.to(...)`.
        raise RuntimeError("Vision Mamba model could not be created; see the import error printed above.")
    model = model.to(DEVICE)

    # 5. Train
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_path = os.path.join(CHECKPOINT_DIR, "vim_best.pth")
    # Start at -inf (not 0.0) so the first epoch always writes a checkpoint;
    # otherwise a run whose validation accuracy stays at 0% never saves one
    # and the reload below fails.
    best_acc = float('-inf')

    for epoch in range(EPOCHS):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, epoch)
        val_loss, val_acc = validate(model, val_loader, criterion, epoch)
        scheduler.step()

        # Store history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"Summary: Train Loss {train_loss:.4f} | Val Acc {val_acc:.2f}%")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), best_path)
            print("💾 Saved Best Model")

    # ==========================================
    # VISUALIZATION & METRICS
    # ==========================================
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.metrics import confusion_matrix, classification_report

    # A. Plot Loss & Accuracy
    plt.figure(figsize=(12, 5))

    # Loss
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.title('Loss History')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # Accuracy
    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.title('Accuracy History')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(CHECKPOINT_DIR, 'training_history.png'))
    plt.show()

    # B. Confusion Matrix: evaluate the best checkpoint on the held-out test split
    print("🔍 Generating Confusion Matrix...")
    model.load_state_dict(torch.load(best_path, map_location=DEVICE))
    model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(DEVICE))
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.numpy())

    test_acc = 100 * np.mean(np.asarray(all_preds) == np.asarray(all_labels)) if all_labels else 0.0
    print(f"Test Accuracy: {test_acc:.2f}%")

    # Get Class Names
    idx_to_class = {v: k for k, v in label_map.items()}
    class_names = [idx_to_class[i] for i in range(num_classes)]
    # Pass the full label list explicitly. Otherwise a class absent from the test
    # split shrinks the matrix, and classification_report rejects target_names
    # whose length no longer matches the labels it finds.
    class_ids = list(range(num_classes))

    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix (Test Set)')
    plt.savefig(os.path.join(CHECKPOINT_DIR, 'confusion_matrix.png'))
    plt.show()

    # C. Classification Report
    print("\n📑 Classification Report:")
    print(classification_report(all_labels, all_preds, labels=class_ids,
                                target_names=class_names, zero_division=0))

else:
    print("⚠️ No data found. Please check dataset path.")