### Local kernel check
Open this notebook **from the project root folder** (the folder that contains `data/`, `utils/`, `Vim/`). Then **Run All**. The Drive cell will no-op locally. For a quick sanity run, set `config["epochs"] = 2` in the config cell before running.

# Vision Mamba (Vim) Training on Colab

This notebook trains a Vision Mamba model on CIFAR-100 using Google Colab's free GPU.

## Setup Instructions

1. **Upload your project**: Upload the entire `Vim` folder to Colab (or clone from GitHub)
2. **Run all cells**: Execute cells sequentially
3. **Monitor training**: Check TensorBoard logs or print statements
4. **Download checkpoints**: Save model checkpoints to Google Drive or download locally

## 1. Install Dependencies

In [None]:
!pip install torch torchvision timm einops tqdm matplotlib tensorboard -q

### Optional: Mount Google Drive

Run this if you want to save checkpoints or use data from Drive. Then set `config["output_dir"]` and/or `config["data_dir"]` to paths under `/content/drive/MyDrive/...`.

In [None]:
# Only runs on Colab; skip on local kernel
try:
    from google.colab import drive
    drive.mount("/content/drive")
except Exception:
    print("Not in Colab — skipping Drive mount. (Optional: use local paths for data/checkpoints.)")

In [None]:
# Colab: run this first to clone the repo (skip if you already uploaded the project)
# Local: skip this cell or it will no-op (no /content).
import os
if os.path.exists("/content"):  # Colab
    if not os.path.exists("/content/Vim/data") or not os.path.exists("/content/Vim/Vim"):
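        import subprocess
        # check=False keeps the notebook running if the clone fails, so report the failure explicitly
        result = subprocess.run(["git", "clone", "https://github.com/ns-1456/Vim.git", "/content/Vim"], check=False)
        if result.returncode != 0:
            print("git clone failed; upload the project manually or check network access.")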
    if os.path.exists("/content/Vim"):
        os.chdir("/content/Vim")
print("Working directory:", os.getcwd())
print("Has data/:", os.path.exists("data"), "| Has Vim/:", os.path.exists("Vim") or os.path.exists("vim"))

## 2. Setup Project Structure

**Option A**: If you uploaded files, make sure the project structure is correct.

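**Option B**: If you are cloning from GitHub, the clone cell above already fetches the repository; edit its URL if you use a fork. The cell below then locates the project root and adds it to `sys.path`.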

In [None]:
# If you uploaded the Vim folder to Colab, it may be at /content/Vim.
# If you cloned a repo, %cd into the folder that contains vim/, data/, utils/.

import os
import sys

# Ensure project root is on path (folder containing Vim/, data/, utils/)
def _find_project_root():
    if os.path.exists("data") and (os.path.exists("vim") or os.path.exists("Vim")):
        return os.getcwd()
    for path in ["/content/Vim", "/content", os.getcwd()]:
        data_ok = os.path.exists(os.path.join(path, "data"))
        vim_ok = os.path.exists(os.path.join(path, "vim")) or os.path.exists(os.path.join(path, "Vim"))
        if data_ok and vim_ok:
            return path
    return os.getcwd()

PROJECT_ROOT = _find_project_root()
os.chdir(PROJECT_ROOT)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
print(f"Project root: {PROJECT_ROOT}")
print(f"Contents: {os.listdir(PROJECT_ROOT)}")


## 3. Import Libraries and Setup

In [None]:
import os
import sys
import random
from pathlib import Path
from typing import Dict

# Ensure project root is on path (run "Setup Project Structure" cell first; this fixes it if run out of order)
def _find_root():
    if os.path.exists("data") and (os.path.exists("vim") or os.path.exists("Vim")):
        return os.getcwd()
    for path in ["/content/Vim", "/content", os.getcwd()]:
        if os.path.exists(os.path.join(path, "data")) and (
            os.path.exists(os.path.join(path, "vim")) or os.path.exists(os.path.join(path, "Vim"))
        ):
            return path
    return os.getcwd()
_root = _find_root()
os.chdir(_root)
if _root not in sys.path:
    sys.path.insert(0, _root)

import numpy as np
import torch
import torch.nn as nn
from torch.cuda import amp
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import matplotlib.pyplot as plt

from data import build_cifar100_dataloaders
from utils import AverageMeter, accuracy
try:
    from Vim import vim_tiny_cifar100
except ImportError:
    from Vim.vim import vim_tiny_cifar100

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")

## 4. Configuration

In [None]:
# Training configuration (tune for Colab: smaller batch_size if OOM)
config = {
    "data_dir": "./data",  # CIFAR-100 will download here
    "epochs": 200,         # Reduce (e.g. 10) for a quick test run
    "batch_size": 128,     # Reduce to 64 if GPU OOM
    "lr": 1e-3,
    "weight_decay": 0.05,
    "warmup_epochs": 5,
    "img_size": 32,
    "num_workers": 2,      # Colab-friendly
    "seed": 42,
    "grad_clip": 1.0,
    "log_dir": "./runs/vim_cifar100",
    "output_dir": "./checkpoints",
}

# Optional: use Google Drive for data and checkpoints (run Drive mount cell first)
# config["data_dir"] = "/content/drive/MyDrive/vim_data"
# config["output_dir"] = "/content/drive/MyDrive/vim_checkpoints"

# Create output directories
Path(config["output_dir"]).mkdir(parents=True, exist_ok=True)
Path(config["log_dir"]).mkdir(parents=True, exist_ok=True)

print("Configuration:")
for k, v in config.items():
    print(f"  {k}: {v}")

## 5. Set Random Seed

In [None]:
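def set_seed(seed: int):
    """Seed Python, NumPy, and PyTorch (CPU and all GPUs)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Favor speed over bit-exact reproducibility: cuDNN autotuning may pick
    # different kernels between runs. Flip both flags for deterministic runs.
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True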

set_seed(config["seed"])
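
To see what seeding provides, the sketch below uses a private `random.Random` generator, so it does not disturb the global state set by `set_seed`, and shows that the same seed reproduces the same draws. The helper name `draws_for_seed` is made up for this illustration. Keep in mind that with `cudnn.benchmark = True`, GPU results may still differ slightly between runs even with fixed seeds.

```python
import random


def draws_for_seed(seed: int, n: int = 3) -> list:
    """Return n pseudo-random floats from a generator seeded with `seed`."""
    rng = random.Random(seed)  # private generator; global random state is untouched
    return [rng.random() for _ in range(n)]


same_a = draws_for_seed(42)
same_b = draws_for_seed(42)
other = draws_for_seed(43)
print(same_a == same_b)  # same seed, same sequence
print(same_a == other)   # different seed, different sequence
```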

## 6. Load Data

In [None]:
print("Loading CIFAR-100 dataset...")
train_loader, val_loader = build_cifar100_dataloaders(
    data_dir=config["data_dir"],
    img_size=config["img_size"],
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Batch size: {config['batch_size']}")
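
As a sanity check on the batch counts printed above, the arithmetic can be done by hand. CIFAR-100 ships 50,000 training and 10,000 test images. The sketch assumes `build_cifar100_dataloaders` uses that standard split and keeps the last partial batch (`drop_last=False`); neither is confirmed by this notebook. The helper name `num_batches` is illustrative only.

```python
import math


def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a loader yields for the given dataset size."""
    if drop_last:
        return num_samples // batch_size        # partial final batch is discarded
    return math.ceil(num_samples / batch_size)  # partial final batch is kept


train_batches = num_batches(50_000, 128)
val_batches = num_batches(10_000, 128)
print(train_batches, val_batches)
```

If the printed loader lengths differ from these values, the loader probably uses a different split or drops the last batch.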

## 7. Create Model

In [None]:
model = vim_tiny_cifar100(img_size=config["img_size"])
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params / 1e6:.2f}M")
print(f"Trainable parameters: {trainable_params / 1e6:.2f}M")

# Test forward pass
with torch.no_grad():
    dummy_input = torch.randn(1, 3, config["img_size"], config["img_size"]).to(device)
    dummy_output = model(dummy_input)
    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {dummy_output.shape}")

## 8. Setup Training Components

In [None]:
# Loss function
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Optimizer
optimizer = AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])

# Scheduler
scheduler = CosineAnnealingLR(optimizer, T_max=config["epochs"] - config["warmup_epochs"])

# Mixed precision scaler (disabled on CPU so local runs work without CUDA)
scaler = amp.GradScaler(enabled=device.type == "cuda")

# TensorBoard writer
writer = SummaryWriter(log_dir=config["log_dir"])

print("Training components initialized.")
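
The cosine scheduler above is combined with the manual linear warmup in the training loop (section 10). A minimal sketch of the resulting per-epoch learning rate, written in closed form, is shown below. It assumes `eta_min = 0` (the `CosineAnnealingLR` default) and one `scheduler.step()` per epoch after warmup; `lr_for_epoch` is a hypothetical helper, not part of the project.

```python
import math


def lr_for_epoch(epoch: int, base_lr: float = 1e-3,
                 warmup_epochs: int = 5, total_epochs: int = 200) -> float:
    """Learning rate used during `epoch` (1-indexed), assuming eta_min = 0."""
    if epoch <= warmup_epochs:
        # Linear warmup: lr grows from base_lr / warmup_epochs up to base_lr
        return base_lr * epoch / max(1, warmup_epochs)
    t = epoch - warmup_epochs             # scheduler.step() calls made so far
    t_max = total_epochs - warmup_epochs  # T_max passed to CosineAnnealingLR
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * t / t_max))


for e in (1, 5, 6, 100, 200):
    print(f"epoch {e:3d}: lr = {lr_for_epoch(e):.6f}")
```

Under these assumptions the last epoch runs at a learning rate of essentially zero, because `t` reaches `T_max` there.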

## 9. Training Functions

In [None]:
def train_one_epoch(
    model, criterion, optimizer, scaler, dataloader, device, epoch
) -> Dict[str, float]:
    model.train()
    loss_meter = AverageMeter("loss")
    acc1_meter = AverageMeter("acc1")

    pbar = tqdm(dataloader, desc=f"Epoch {epoch} [train]")
    for images, targets in pbar:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with amp.autocast(enabled=device.type == "cuda"):
            outputs = model(images)
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config["grad_clip"])
        scaler.step(optimizer)
        scaler.update()

        acc1, = accuracy(outputs.detach(), targets, topk=(1,))
        loss_meter.update(loss.item(), images.size(0))
        acc1_meter.update(acc1.item(), images.size(0))

        pbar.set_postfix(loss=loss_meter.avg, acc1=acc1_meter.avg)

    writer.add_scalar("train/loss", loss_meter.avg, epoch)
    writer.add_scalar("train/acc1", acc1_meter.avg, epoch)

    return {"loss": loss_meter.avg, "acc1": acc1_meter.avg}


def validate(model, criterion, dataloader, device, epoch) -> Dict[str, float]:
    model.eval()
    loss_meter = AverageMeter("loss")
    acc1_meter = AverageMeter("acc1")
    acc5_meter = AverageMeter("acc5")

    with torch.no_grad():
        pbar = tqdm(dataloader, desc=f"Epoch {epoch} [val]")
        for images, targets in pbar:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, targets)

            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))

            loss_meter.update(loss.item(), images.size(0))
            acc1_meter.update(acc1.item(), images.size(0))
            acc5_meter.update(acc5.item(), images.size(0))

            pbar.set_postfix(loss=loss_meter.avg, acc1=acc1_meter.avg, acc5=acc5_meter.avg)

    writer.add_scalar("val/loss", loss_meter.avg, epoch)
    writer.add_scalar("val/acc1", acc1_meter.avg, epoch)
    writer.add_scalar("val/acc5", acc5_meter.avg, epoch)

    return {"loss": loss_meter.avg, "acc1": acc1_meter.avg, "acc5": acc5_meter.avg}
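
Both functions rely on `AverageMeter` from `utils`, whose source is not shown in this notebook. The sketch below assumes it keeps a sample-weighted running mean, which is what the `update(value, images.size(0))` calls suggest. The class is named `SketchAverageMeter` so it does not shadow the real import.

```python
class SketchAverageMeter:
    """Sample-weighted running mean; assumed behaviour of utils.AverageMeter."""

    def __init__(self, name: str):
        self.name = name
        self.sum = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1) -> None:
        # `value` is a per-batch mean, so weight it by the batch size n
        self.sum += value * n
        self.count += n

    @property
    def avg(self) -> float:
        return self.sum / self.count if self.count else 0.0


meter = SketchAverageMeter("loss")
meter.update(1.0, n=2)  # batch of 2 samples with mean loss 1.0
meter.update(4.0, n=1)  # batch of 1 sample with loss 4.0
print(meter.avg)
```

Weighting by batch size keeps the epoch average correct when the final batch is smaller than the rest.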

## 10. Training Loop

In [None]:
best_acc1 = 0.0
epochs_no_improve = 0
patience = 20

train_history = {"loss": [], "acc1": []}
val_history = {"loss": [], "acc1": [], "acc5": []}

print("Starting training...")
print(f"Total epochs: {config['epochs']}")
print("-" * 60)

for epoch in range(1, config["epochs"] + 1):
    # Learning rate warmup
    if epoch <= config["warmup_epochs"]:
        warmup_factor = epoch / float(max(1, config["warmup_epochs"]))
        for param_group in optimizer.param_groups:
            param_group["lr"] = config["lr"] * warmup_factor
    else:
        scheduler.step()

    # Train
    train_stats = train_one_epoch(
        model, criterion, optimizer, scaler, train_loader, device, epoch
    )
    
    # Validate
    val_stats = validate(model, criterion, val_loader, device, epoch)

    # Save history
    train_history["loss"].append(train_stats["loss"])
    train_history["acc1"].append(train_stats["acc1"])
    val_history["loss"].append(val_stats["loss"])
    val_history["acc1"].append(val_stats["acc1"])
    val_history["acc5"].append(val_stats["acc5"])

    # Checkpointing
    is_best = val_stats["acc1"] > best_acc1
    if is_best:
        best_acc1 = val_stats["acc1"]
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    ckpt = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),  # needed to resume the LR schedule
        "scaler_state": scaler.state_dict(),
        "best_acc1": best_acc1,
        "config": config,
    }
    
    # Save checkpoint
    ckpt_path = Path(config["output_dir"]) / f"checkpoint_{epoch:03d}.pth"
    torch.save(ckpt, ckpt_path)
    
    if is_best:
        best_path = Path(config["output_dir"]) / "checkpoint_best.pth"
        torch.save(ckpt, best_path)
        print(f"\n✓ New best model! Val Acc1: {best_acc1:.2f}%")

    # Print epoch summary
    print(f"\nEpoch {epoch}/{config['epochs']}:")
    print(f"  Train Loss: {train_stats['loss']:.4f}, Train Acc1: {train_stats['acc1']:.2f}%")
    print(f"  Val Loss: {val_stats['loss']:.4f}, Val Acc1: {val_stats['acc1']:.2f}%, Val Acc5: {val_stats['acc5']:.2f}%")
    print(f"  LR: {optimizer.param_groups[0]['lr']:.6f}")
    print("-" * 60)

    # Early stopping
    if epochs_no_improve >= patience:
        print(f"\nEarly stopping triggered after {epoch} epochs (no improvement for {patience} epochs).")
        break

writer.close()
print("\nTraining completed!")
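
The best-accuracy tracking and early-stopping rule in the loop above can be isolated into a few lines. The sketch below replays that logic on a plain list of validation accuracies; `stop_epoch_with_patience` is a hypothetical helper written for illustration, and the accuracy values are made up.

```python
def stop_epoch_with_patience(val_acc_per_epoch, patience):
    """Return (last_epoch_run, best_acc) under the loop's early-stopping rule."""
    best_acc, epochs_no_improve = 0.0, 0
    for epoch, acc in enumerate(val_acc_per_epoch, start=1):
        if acc > best_acc:            # strict improvement resets the counter
            best_acc, epochs_no_improve = acc, 0
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            return epoch, best_acc    # training would break here
    return len(val_acc_per_epoch), best_acc


# Made-up accuracies: the best value (20.0) is reached at epoch 2
last_epoch, best = stop_epoch_with_patience([10.0, 20.0, 19.0, 18.0, 17.0, 30.0], patience=3)
print(last_epoch, best)
```

In this example training stops before the later, higher value is ever seen, which is the usual trade-off of a small `patience`.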

## 11. Plot Training Curves

In [None]:
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(train_history["loss"], label="Train")
plt.plot(val_history["loss"], label="Val")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Training and Validation Loss")
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(train_history["acc1"], label="Train@1")
plt.plot(val_history["acc1"], label="Val@1")
plt.plot(val_history["acc5"], label="Val@5")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.legend()
plt.title("Training and Validation Accuracy")
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("training_curves.png", dpi=150)
plt.show()

print(f"Best validation accuracy: {best_acc1:.2f}%")

## 12. TensorBoard in Colab

Run the cell below to launch TensorBoard. In Colab the dashboard renders inline below the cell; it reads the logs written to `./runs`.

In [None]:
%load_ext tensorboard
%tensorboard --logdir ./runs --port 6006

## 13. Save to Google Drive (Optional)

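Uncomment and run the cell below to mount Drive and copy the checkpoints there. Alternatively, set the Drive paths in the config cell before training so checkpoints are written to Drive directly.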

In [None]:
# from google.colab import drive
# import shutil
# 
# drive.mount('/content/drive')
# 
# # Copy checkpoints to Drive
# drive_path = '/content/drive/MyDrive/vim_checkpoints'
# shutil.copytree(config['output_dir'], drive_path, dirs_exist_ok=True)
# print(f"Checkpoints saved to {drive_path}")
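
Because the training loop writes one `checkpoint_XXX.pth` file per epoch, a long run can leave many files to copy. One option is to prune older per-epoch checkpoints first. The sketch below is an illustration only: `prune_checkpoints` is a hypothetical helper, it assumes the three-digit file naming used in the training loop, and the demo runs on empty placeholder files in a temporary directory rather than on real checkpoints.

```python
import tempfile
from pathlib import Path


def prune_checkpoints(output_dir, keep=3):
    """Delete all but the `keep` newest per-epoch checkpoints.

    checkpoint_best.pth does not match the pattern, so it is never removed.
    """
    per_epoch = sorted(Path(output_dir).glob("checkpoint_[0-9][0-9][0-9].pth"))
    stale = per_epoch[:-keep] if keep > 0 else per_epoch
    for path in stale:
        path.unlink()
    return [p.name for p in stale]


# Demo on a throwaway directory with empty placeholder files
with tempfile.TemporaryDirectory() as tmp:
    for epoch in range(1, 6):
        (Path(tmp) / f"checkpoint_{epoch:03d}.pth").touch()
    (Path(tmp) / "checkpoint_best.pth").touch()
    removed = prune_checkpoints(tmp, keep=2)
    remaining = sorted(p.name for p in Path(tmp).iterdir())

print(removed)
print(remaining)
```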

## 14. Load Best Model and Evaluate (Optional)

In [None]:
# Load best checkpoint and run final evaluation (no TensorBoard logging)
best_ckpt_path = Path(config["output_dir"]) / "checkpoint_best.pth"
if best_ckpt_path.exists():
    ckpt = torch.load(best_ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    print(f"Loaded best model from epoch {ckpt['epoch']} with Acc1: {ckpt['best_acc1']:.2f}%")

    model.eval()
    loss_meter = AverageMeter("loss")
    acc1_meter = AverageMeter("acc1")
    acc5_meter = AverageMeter("acc5")
    with torch.no_grad():
        for images, targets in tqdm(val_loader, desc="Evaluating"):
            images, targets = images.to(device), targets.to(device)
            outputs = model(images)
            loss_meter.update(criterion(outputs, targets).item(), images.size(0))
            a1, a5 = accuracy(outputs, targets, topk=(1, 5))
            acc1_meter.update(a1.item(), images.size(0))
            acc5_meter.update(a5.item(), images.size(0))
    print(f"\nFinal validation: Loss {loss_meter.avg:.4f}, Top-1 {acc1_meter.avg:.2f}%, Top-5 {acc5_meter.avg:.2f}%")
else:
    print("No checkpoint_best.pth found. Train first or check output_dir.")
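
The Top-1 and Top-5 figures come from `utils.accuracy`, which is not shown in this notebook. The pure-Python sketch below illustrates the metric itself on made-up scores. It assumes the result is a percentage, as the `%` in the print statements suggests; `topk_accuracy` is a stand-in name, not the project's function.

```python
def topk_accuracy(scores, targets, k=1):
    """Percentage of rows whose target index is among the k highest scores."""
    hits = 0
    for row, target in zip(scores, targets):
        # Class indices ordered from highest to lowest score
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        if target in ranked[:k]:
            hits += 1
    return 100.0 * hits / len(targets)


# Made-up scores for 4 samples over 3 classes
scores = [
    [0.1, 0.7, 0.2],  # highest score: class 1
    [0.5, 0.3, 0.2],  # highest score: class 0
    [0.2, 0.3, 0.5],  # highest score: class 2
    [0.6, 0.3, 0.1],  # highest score: class 0
]
targets = [1, 1, 2, 2]
top1 = topk_accuracy(scores, targets, k=1)
top2 = topk_accuracy(scores, targets, k=2)
print(top1, top2)
```

Top-k accuracy can only stay equal or rise as `k` grows, which is why Top-5 is always at least as high as Top-1.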

---
**Quick test**: Set `config["epochs"]` to a small value (for example, `2` or `10`) in the config cell to run a short sanity check before full training.