# Real NVAE Architecture Experiment on CIFAR-10

This notebook implements the training and evaluation pipeline for the **Real NVAE Architecture** on CIFAR-10.

Unlike the previous simplified experiments, this version implements:
- **Deep Hierarchical Latent Space**: Multiple groups of latent variables per scale.
- **Residual Parameterization**: Posterior distributions are learned as residuals to the prior.
- **Cell-Based Architecture**: Encoder and Decoder built from repeating Residual Cells.
- **Separable Convolutions**: Efficient depthwise separable convolutions throughout.

This configuration is much closer to the official paper implementation.

## 1. Google Colab Setup
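Mount Google Drive and clone a fresh copy of the repository. Any existing copy at `REPO_PATH` is deleted first, including anything saved inside it (e.g. the repo-relative `results/` directory).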

In [None]:
# Mount Google Drive so the cloned repository lives on persistent storage
from google.colab import drive
drive.mount('/content/drive')

import os
import sys
import shutil

# --- CONFIGURATION ---
# REPO_PATH: working copy on Drive; REPO_URL: upstream repository to clone.
REPO_PATH = '/content/drive/MyDrive/Generative-Modeling-on-CIFAR-10'
REPO_URL = "https://github.com/konstantine25b/Generative-Modeling-on-CIFAR-10.git"

# 1. Delete repo if it already exists (Ensure fresh code)
# NOTE(review): this removes everything under REPO_PATH, including anything a
# previous run wrote to the repo-relative 'results/' directory.
if os.path.exists(REPO_PATH):
    print(f"Deleting existing repository at {REPO_PATH}...")
    shutil.rmtree(REPO_PATH)

# 2. Clone repository
# `git clone` creates a directory named after the repository inside the cwd;
# that directory is expected to be exactly REPO_PATH.
os.chdir('/content/drive/MyDrive')
print(f"Cloning repository to {REPO_PATH}...")
!git clone {REPO_URL}

# 3. Enter the repository
# Later cells rely on this cwd (e.g. requirements.txt, './data', 'results/').
os.chdir(REPO_PATH)
print(f"Current working directory: {os.getcwd()}")

# 4. Add source code to Python path
# NOTE(review): later cells import via `src.<module>`, which presumably resolves
# through the cwd (repo root) rather than this entry; this entry likely serves
# modules inside src/ that import each other without the `src.` prefix — confirm.
# Re-running the cell appends a duplicate entry.
sys.path.append(os.path.join(REPO_PATH, 'src'))

## 2. GitHub Configuration (Optional)
Configure this only if you want to push results back to the repository. Do not commit or share the notebook with a real token pasted into it.

In [None]:
# GitHub Configuration & Setup
# Optional: sets the global git identity and, when a real token is provided,
# rewrites the `origin` remote URL so that `git push` authenticates with it.
import os

try:
    # 1. Configure Git
    user_name = "konstantine25b"
    mail = "konstantine25b@gmail.com"

    # --- IMPORTANT: PASTE YOUR TOKEN BELOW ---
    # NOTE(review): a token pasted here is saved in the notebook file and, after
    # `set-url`, in .git/config on Drive. Never commit/share the notebook with a
    # real token; consider reading it from a Colab secret instead.
    my_token = "YOUR_TOKEN_HERE"

    if my_token == "YOUR_TOKEN_HERE":
        print("⚠️ PLEASE UPDATE 'my_token' in the code cell with your actual GitHub token to enable pushing.")

    # Token-authenticated HTTPS remote; only applied below if a real token was set.
    repo_url = f"https://{my_token}@github.com/konstantine25b/Generative-Modeling-on-CIFAR-10.git"

    !git config --global user.name "{user_name}"
    !git config --global user.email "{mail}"

    # 2. Set Remote URL
    # Only rewrite the remote inside a git checkout and when the placeholder was replaced.
    if os.path.isdir(".git") and my_token != "YOUR_TOKEN_HERE":
        !git remote set-url origin "{repo_url}"
        print("Git configured successfully for pushing.")
    else:
        print("Skipping remote setup (either not a git repo or token not set).")

except Exception as e:
    # NOTE(review): IPython `!` shell commands typically do not raise on a
    # non-zero exit status, so this handler presumably only catches errors from
    # the Python statements above — verify.
    print(f"Error setting up GitHub: {e}")

## 3. Install Dependencies

In [None]:
# Install the project's requirements (path is relative to the cwd, i.e. the repo
# root set by the setup cell), then Weights & Biases quietly.
!pip install -r requirements.txt
!pip install wandb -q

import wandb
# Authenticate this session with Weights & Biases; the training config below
# enables W&B logging ('use_wandb': True).
wandb.login()

## 4. Setup Experiment

In [None]:
import torch
from src.utils.data_loader import get_cifar10_loaders
from src.vae.train import train_vae
from src.vae.sampling import generate_samples, save_sample_grid
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import wandb

# Hyperparameters for the deep hierarchical NVAE run, grouped by purpose.
config = dict(
    # --- optimisation ---
    epochs=100,          # deeper model, so train for longer
    batch_size=64,       # smaller batches so the deeper model fits in memory
    lr=1e-3,             # base learning rate
    # --- architecture ---
    hidden_dim=64,       # base channel width
    latent_dim=20,       # latent dimension per group
    num_scales=3,        # spatial scales: 32x32, 16x16, 8x8
    # --- KL schedule / regularisation ---
    warmup_epochs=10,    # number of KL warm-up epochs
    weight_decay=3e-4,
    # --- logging and output locations ---
    use_wandb=True,
    run_name='nvae_real_arch',
    model_save_dir='/content/drive/MyDrive/Generative-Modeling-on-CIFAR-10-Checkpoints/nvae_real',
    results_dir='results/',
)

# Ensure the checkpoint and results directories exist (no-op if already present).
for _dir_key in ('model_save_dir', 'results_dir'):
    os.makedirs(config[_dir_key], exist_ok=True)

# Prefer the GPU whenever one is visible to PyTorch.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f"Using device: {device}")
print(f"Checkpoints will be saved to: {config['model_save_dir']}")

## 5. Load Data

In [None]:
# Build the CIFAR-10 train / validation / test loaders from ./data using the
# batch size chosen in `config`.
loaders = get_cifar10_loaders(batch_size=config['batch_size'], data_dir='./data')
train_loader, val_loader, test_loader = loaders

## 6. Start Training

This will train the deep NVAE model. Expect this to take significantly longer per epoch than the simplified version.

In [None]:
# Start Training
# Runs training with the settings in `config`, using the train/validation loaders.
# The evaluation cell below reloads 'nvae_best.pth' from config['model_save_dir'],
# which is presumably written by train_vae — confirm.
# NOTE(review): any return value of train_vae is discarded here.
train_vae(config, train_loader, val_loader, device)

## 7. Evaluation

After training, we reload the best checkpoint and evaluate it on the test set. We first use the standard ELBO, then the importance-weighted ELBO (IWELBO), which gives a tighter lower bound on the log-likelihood.

In [None]:
# Load best model
from src.vae.model import NVAE
from src.vae.train import evaluate_with_importance_sampling, evaluate

# Number of importance samples for the IWELBO estimate.
# k=100 provides a good balance between speed and accuracy for debugging;
# the paper uses k=1000.
IW_SAMPLES = 100

# Rebuild the architecture so the checkpoint's parameter names/shapes match.
# NOTE(review): config['num_scales'] is not passed; this assumes NVAE's default
# matches what train_vae used — confirm, otherwise load_state_dict will report
# missing/unexpected keys.
best_model = NVAE(
    hidden_dim=config['hidden_dim'],
    latent_dim=config['latent_dim']
).to(device)

checkpoint_path = os.path.join(config['model_save_dir'], 'nvae_best.pth')
checkpoint = torch.load(checkpoint_path, map_location=device)

# Accept both full training checkpoints ({'model_state_dict': ..., ...}) and
# bare state_dicts.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    best_model.load_state_dict(checkpoint['model_state_dict'])
else:
    best_model.load_state_dict(checkpoint)

# Switch to inference mode (affects layers such as BatchNorm/Dropout, if present)
# so test metrics are not computed with training-time behaviour.
best_model.eval()
print(f"Loaded best model from {checkpoint_path}")

# Evaluation needs no gradients; disabling autograd prevents the k-sample
# importance-weighted pass from retaining activations for backprop.
with torch.no_grad():
    # Standard Evaluation
    loss, bpd = evaluate(best_model, test_loader, device)
    print(f"Standard Test Set Results (ELBO) -> Loss: {loss:.4f} | BPD: {bpd:.4f}")

    # Importance Weighted Evaluation
    iw_loss, iw_bpd = evaluate_with_importance_sampling(
        best_model, test_loader, device, k=IW_SAMPLES
    )
    print(f"Importance Weighted Results (k={IW_SAMPLES}) -> Loss: {iw_loss:.4f} | BPD: {iw_bpd:.4f}")

# Log final results to WandB (both the ELBO and IWELBO test metrics).
if config['use_wandb']:
    # train_vae may already have finished its run; calling wandb.log without an
    # active run can raise, so only log when a run is still open.
    if wandb.run is not None:
        wandb.log({
            "test/loss": loss,
            "test/bpd": bpd,
            "test/iw_loss": iw_loss,
            "test/iw_bpd": iw_bpd
        })
        print("Logged test results to WandB.")
    else:
        print("No active WandB run; skipping logging of test results.")