# Vision Mamba Training Notebook

This notebook walks through training the Vision Mamba (Vim) model for face recognition, step by step.

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm.notebook import tqdm
import sys

# Tambahkan path lokal agar bisa import utils
sys.path.append(os.path.abspath('..')) # Asumsi notebook ada di folder 'notebooks'

from utils.dataloader import load_dataset, split_dataset
from utils.dataset import FaceDataset
from utils.augmentations import train_transforms, test_transforms

In [None]:
# ==========================================
# CONFIGURATION
# ==========================================
BATCH_SIZE = 32        # mini-batch size used by both DataLoaders
EPOCHS = 20            # number of train/validate rounds
LEARNING_RATE = 1e-4   # initial learning rate given to the optimizer
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resolve the checkpoint directory to an absolute path immediately.
# The data-loading cell may call os.chdir('..') when the notebook is started
# from 'notebooks/'; a relative "../checkpoints" would then refer to a
# different directory than the one created here, and saving/loading the
# model or the figures would fail.
CHECKPOINT_DIR = os.path.abspath("../checkpoints")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"Device: {DEVICE}")

In [None]:
# ==========================================
# SETUP MODEL (VISION MAMBA)
# ==========================================
def create_vim_model(num_classes):
    """
    Build a Vision Mamba (Vim) model with ``num_classes`` output classes.

    Assumes the 'Vim' repository has already been cloned (e.g. on Colab) and
    is importable as the ``vim`` package.

    Args:
        num_classes: Forwarded to ``VisionMamba`` as ``num_classes``.

    Returns:
        The constructed ``VisionMamba`` instance.

    Raises:
        ImportError: If ``vim.models_mamba`` cannot be imported. The setup
            hints are printed first, then the error is re-raised so the
            caller never receives ``None`` in place of a model.
    """
    try:
        # Import from the cloned Vim repo; the 'Vim' folder must be on the path.
        from vim.models_mamba import VisionMamba
    except ImportError as err:
        print("‚ùå Error: Library 'vim' tidak ditemukan.")
        print("Pastikan Anda sudah clone repo Vim: 'git clone https://github.com/hustvl/Vim.git'")
        print("Dan install requirements-nya.")
        # Raise instead of falling through: an exception only stops the
        # current notebook cell (no sys.exit), and it avoids a confusing
        # AttributeError later when the caller does `.to(DEVICE)` on None.
        raise ImportError(
            "Vision Mamba is unavailable: could not import 'vim.models_mamba'"
        ) from err

    print("üêç Menggunakan Vision Mamba (Vim) Model...")

    # Standard Vim-Tiny style configuration (adjust to the available VRAM).
    model = VisionMamba(
        img_size=224,
        patch_size=16,
        embed_dim=192,  # Vim-Tiny
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=False,
        if_rope_residual=False,
        bimamba_type="v2",
        if_cls_token=True,
        if_devide_out=True,
        use_middle_cls_token=True,
        num_classes=num_classes,
    )
    return model

In [None]:
# ==========================================
# TRAINING LOOP
# ==========================================
def train_one_epoch(model, loader, criterion, optimizer, epoch):
    """
    Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to training mode here.
        loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)`` where the loss is the
        sum of per-batch loss values divided by ``len(loader)``.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
    for batches_seen, (images, labels) in enumerate(pbar, start=1):
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # detach() instead of the discouraged `.data`: predictions need no graph.
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # running_loss accumulates one value per batch, so the running
        # average divides by batches seen (not by samples); this keeps the
        # displayed loss on the same scale as the value returned below.
        pbar.set_postfix({'loss': running_loss / batches_seen, 'acc': 100 * correct / total})

    return running_loss / len(loader), 100 * correct / total

def validate(model, loader, criterion, epoch):
    """
    Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)`` where the loss is the
        sum of per-batch loss values divided by ``len(loader)``.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]")
        for batches_seen, (images, labels) in enumerate(pbar, start=1):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            # Already inside no_grad(), so the outputs carry no graph.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # One loss value is accumulated per batch, so average over the
            # batches seen (not samples) to match the returned loss scale.
            pbar.set_postfix({'loss': running_loss / batches_seen, 'acc': 100 * correct / total})

    return running_loss / len(loader), 100 * correct / total

In [None]:
# 1. Load the dataset
print("üìÇ Loading Dataset...")
# dataloader.py resolves its DATA_DIR relative to the working directory, so
# when the notebook is launched from 'notebooks/' we first move up to the
# project root.
if os.path.split(os.getcwd())[1] == 'notebooks':
    os.chdir(os.pardir)
    print(f"Changed working directory to: {os.getcwd()}")

# load_dataset() yields the samples, their labels and the label mapping;
# the number of classes is the size of that mapping.
X, y, label_map = load_dataset()
num_classes = len(label_map)
print(f"‚úÖ Detected {num_classes} classes.")

In [None]:
# 2. Split the data (train / val / test)
X_train, X_val, X_test, y_train, y_val, y_test = split_dataset(X, y)

# 3. Build a dataset and a loader per split.
# Training split: train_transforms, batches reshuffled on every pass.
train_dataset = FaceDataset(X_train, y_train, transform=train_transforms)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# Validation split: test_transforms, fixed order.
val_dataset = FaceDataset(X_val, y_val, transform=test_transforms)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)

In [None]:
# 4. Initialize Model
model = create_vim_model(num_classes)
if model is None:
    # Fail with an explicit message here instead of an AttributeError from
    # calling `.to()` on None when the model could not be built.
    raise RuntimeError(
        "Vision Mamba model could not be created; make sure the 'vim' "
        "package (https://github.com/hustvl/Vim) is installed and importable."
    )
model = model.to(DEVICE)

# 5. Loss & Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)
# Cosine learning-rate schedule spanning the whole run (stepped once per epoch).
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [None]:
# 6. Training Loop
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint: the evaluation cell loads "vim_best.pth" unconditionally, and
# with a 0.0 start nothing would be saved if validation accuracy stayed at 0.
best_acc = float('-inf')

for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, epoch)
    val_loss, val_acc = validate(model, val_loader, criterion, epoch)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    # Store history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"üìä Epoch {epoch+1} Summary:")
    print(f"   Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"   Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")

    # Save Best Model
    if val_acc > best_acc:
        best_acc = val_acc
        # CHECKPOINT_DIR may be a relative path and the working directory can
        # change after it was first created, so make sure it exists now.
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        save_path = os.path.join(CHECKPOINT_DIR, "vim_best.pth")
        torch.save(model.state_dict(), save_path)
        print(f"üíæ Model saved to {save_path}")

print("üéâ Training Completed!")

In [None]:
# ==========================================
# VISUALIZATION & METRICS
# ==========================================
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

# A. Loss and accuracy curves, side by side in one figure
plt.figure(figsize=(12, 5))

# Left panel: per-epoch loss for both splits
loss_ax = plt.subplot(1, 2, 1)
loss_ax.plot(history['train_loss'], label='Train Loss')
loss_ax.plot(history['val_loss'], label='Val Loss')
loss_ax.set_title('Loss History')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

# Right panel: per-epoch accuracy for both splits
acc_ax = plt.subplot(1, 2, 2)
acc_ax.plot(history['train_acc'], label='Train Acc')
acc_ax.plot(history['val_acc'], label='Val Acc')
acc_ax.set_title('Accuracy History')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.legend()

# Write the figure next to the checkpoints before displaying it
plt.tight_layout()
history_png = os.path.join(CHECKPOINT_DIR, 'training_history.png')
plt.savefig(history_png)
plt.show()

# B. Confusion Matrix
print("üîç Generating Confusion Matrix...")
# map_location lets the checkpoint load onto whatever device is active now,
# even if it was saved from a different one (e.g. GPU -> CPU session).
best_state = torch.load(os.path.join(CHECKPOINT_DIR, "vim_best.pth"), map_location=DEVICE)
model.load_state_dict(best_state)
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Get Class Names
# Reverse label_map into {index: name}, then order the names by index.
idx_to_class = {v: k for k, v in label_map.items()}
class_names = [idx_to_class[i] for i in range(len(label_map))]

# Pass the full list of class indices explicitly: without `labels=` the
# matrix only covers classes that occur in the validation labels or
# predictions, and its rows/columns would no longer line up with class_names.
class_ids = list(range(len(class_names)))
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.savefig(os.path.join(CHECKPOINT_DIR, 'confusion_matrix.png'))
plt.show()

# C. Classification Report
print("\nüìë Classification Report:")
# `labels=` pins the report to every known class index; without it
# scikit-learn infers the classes from the data and raises a ValueError when
# that count differs from len(target_names) (a class missing from the
# validation split and from the predictions).
print(classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(class_names))),
    target_names=class_names,
))