# NVAE Debug Experiment with WandB & Overfitting Check

This debug notebook tests:
1. **WandB Logging**: Ensures metrics and images are logged correctly.
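2. **Overfitting**: Uses the same training subset as the validation set, so a falling validation loss confirms the model can fit (memorize) the data it sees.
3. **Config**: 10 epochs on a 5,000-image subset of the CIFAR-10 training data.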

## 1. Google Colab Setup
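Mount Google Drive and clone a fresh copy of the repository. **Warning:** any existing copy at `REPO_PATH` is deleted first, including checkpoints or results saved inside it by earlier runs.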

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

import os
import sys
import shutil

# --- CONFIGURATION ---
REPO_PATH = '/content/drive/MyDrive/Generative-Modeling-on-CIFAR-10'
REPO_URL = "https://github.com/konstantine25b/Generative-Modeling-on-CIFAR-10.git"

# 1. Delete repo if it already exists (Ensure fresh code)
if os.path.exists(REPO_PATH):
    print(f"Deleting existing repository at {REPO_PATH}...")
    shutil.rmtree(REPO_PATH)

# 2. Clone repository
os.chdir('/content/drive/MyDrive')
print(f"Cloning repository to {REPO_PATH}...")
!git clone {REPO_URL}

# 3. Enter the repository
os.chdir(REPO_PATH)
print(f"Current working directory: {os.getcwd()}")

# 4. Add source code to Python path
sys.path.append(os.path.join(REPO_PATH, 'src'))

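## 2. GitHub Configuration
**Important:** This cell only configures Git. Provide a GitHub personal access token in `my_token` only if you want to push changes; your WandB API key is requested in the next section by `wandb.login()`.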

In [None]:
# GitHub Configuration & Setup
import os

try:
    # 1. Configure Git
    user_name = "konstantine25b"
    mail = "konstantine25b@gmail.com"

    # --- IMPORTANT: PASTE YOUR TOKEN BELOW ---
    my_token = "YOUR_TOKEN_HERE"

    if my_token == "YOUR_TOKEN_HERE":
        print("⚠️ PLEASE UPDATE 'my_token' in the code cell with your actual GitHub token to enable pushing.")

    repo_url = f"https://{my_token}@github.com/konstantine25b/Generative-Modeling-on-CIFAR-10.git"

    !git config --global user.name "{user_name}"
    !git config --global user.email "{mail}"

    # 2. Set Remote URL
    if os.path.isdir(".git") and my_token != "YOUR_TOKEN_HERE":
        !git remote set-url origin "{repo_url}"
        print("Git configured successfully for pushing.")
    else:
        print("Skipping remote setup (either not a git repo or token not set).")

except Exception as e:
    print(f"Error setting up GitHub: {e}")

## 3. Install Dependencies

In [None]:
!pip install -r requirements.txt
!pip install wandb -q

# Login to WandB
import wandb
wandb.login()

## 4. Set Up the WandB Test Experiment

In [None]:
import os
import random

import torch
from src.utils.data_loader import get_cifar10_loaders
from src.vae.train import train_vae
from src.vae.sampling import generate_samples, save_sample_grid
import matplotlib.pyplot as plt
import numpy as np
import torchvision

# Seed the Python, NumPy and PyTorch RNGs so that the shuffled subset loader, weight
# init and sampling repeat run-to-run. This covers RNG state only; GPU kernels may
# still be nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# WandB Test Configuration
config = {
    'epochs': 10,
    'batch_size': 64,
    'lr': 1e-3,
    'hidden_dim': 64,
    'latent_dim': 20,
    'num_scales': 2,
    'warmup_epochs': 3,
    'weight_decay': 3e-4,
    'use_wandb': True,    # ENABLED WandB
    'run_name': 'nvae_debug_wandb_test',
    # Relative paths: resolved against the cwd, which the setup cell sets to the repo.
    'model_save_dir': 'models/debug_wandb',
    'results_dir': 'results/debug_wandb'
}

# Create directories
os.makedirs(config['model_save_dir'], exist_ok=True)
os.makedirs(config['results_dir'], exist_ok=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

## 5. Load Data (Train set used as Val set)
We use a subset of the training data for both training AND validation to check for overfitting.

In [None]:
from torch.utils.data import DataLoader, Subset

# get_cifar10_loaders returns three values; only the first (the training loader)
# is kept. The other two (presumably val/test loaders) are not needed, because the
# subset built below serves as the validation data.
cifar_data_dir = './data'
full_train_loader, _, _ = get_cifar10_loaders(
    batch_size=config['batch_size'],
    data_dir=cifar_data_dir,
)

# Create a loader over the first `size` images of the training data.
def create_subset_loader(original_loader, size=2000, shuffle=True, num_workers=2, pin_memory=True):
    """Return a DataLoader over the first ``size`` samples of ``original_loader.dataset``.

    The new loader reuses the source loader's dataset object (and therefore whatever
    transforms it applies) and its batch size; only the number of samples changes.

    Args:
        original_loader: Loader whose ``.dataset`` and ``.batch_size`` are reused.
        size: Number of leading samples to keep. Clamped to the dataset length, so an
            oversized request keeps every sample instead of producing out-of-range
            indices that would only fail once the loader is iterated.
        shuffle: Whether the new loader reshuffles every epoch (default True, as before).
        num_workers: Worker processes used by the new loader.
        pin_memory: Whether batches are copied into pinned host memory.

    Returns:
        A new ``DataLoader`` wrapping a ``Subset`` of the original dataset.

    Raises:
        ValueError: If the resulting subset would be empty (``size`` < 1 or an empty
            dataset).
    """
    dataset = original_loader.dataset
    subset_len = min(size, len(dataset))
    if subset_len < 1:
        raise ValueError(
            f"create_subset_loader needs at least one sample "
            f"(size={size}, dataset length={len(dataset)})"
        )
    subset = Subset(dataset, list(range(subset_len)))

    return DataLoader(
        subset,
        batch_size=original_loader.batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

# Number of leading training images used for the overfitting check.
SUBSET_SIZE = 5000

# Build the subset loader once and reuse that single object for both roles.
subset_loader = create_subset_loader(full_train_loader, size=SUBSET_SIZE)

# Training and validation see the identical images: if the pipeline can learn at
# all, validation loss should fall together with training loss.
train_loader = val_loader = subset_loader

print(f"Using same subset ({len(train_loader.dataset)} images) for Train and Val.")

## 6. Train Model (WandB Enabled)

In [None]:
# Start training. Train and validation use the same subset (Section 5), so the
# validation metrics double as an overfitting check. Section 7 reloads
# `nvae_best.pth` from config['model_save_dir']; presumably this call writes it.
train_vae(
    config,
    train_loader,
    val_loader,
    device,
)

## 7. Generate Samples & Log to WandB

In [None]:
# Load best model
from src.vae.model import NVAE
import torchvision.utils as vutils

model = NVAE(
    hidden_dim=config['hidden_dim'],
    latent_dim=config['latent_dim'],
    num_scales=config['num_scales']
).to(device)

# map_location lets a checkpoint written on a GPU runtime be restored on a CPU-only one.
checkpoint_path = os.path.join(config['model_save_dir'], 'nvae_best.pth')
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Switch to inference mode before sampling (matters for dropout/batch-norm layers, if any).
model.eval()
print("Loaded best model.")

# Generate
samples = generate_samples(model, num_samples=64, temperature=0.8, device=device)

# Build the grid once, then detach it and move it to CPU so the same tensor can feed
# both matplotlib (via NumPy) and wandb.Image.
grid_img = vutils.make_grid(samples, nrow=8, normalize=True).detach().cpu()

# Visualize locally
plt.figure(figsize=(10, 10))
plt.imshow(grid_img.permute(1, 2, 0).numpy())
plt.axis('off')
plt.title("Generated Samples (WandB Test)")
plt.show()

# Log to WandB if a run is still active; otherwise report it instead of skipping silently.
if wandb.run is not None:
    wandb.log({
        "final_evaluation/generated_samples_grid": [wandb.Image(grid_img, caption="Final Generated Samples (T=0.8)")]
    })
    print("Logged final samples to WandB.")
    wandb.finish()
else:
    print("No active WandB run; final samples were not logged.")