# Ratio Guidance (Lightweight) - Training Demo

This notebook demonstrates the complete training pipeline for ratio-guided diffusion on MNIST.

**Runtime:** Approximately 30-40 minutes on Colab GPU

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/foubari/ratio-guidance-light/blob/main/ratio_guidance_light_colab.ipynb)

## Setup
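
Every command in this notebook passes `--device cuda`, so it can help to confirm that a GPU is visible before starting. The cell below is a minimal sketch, not part of the repository: `pick_device` is a hypothetical helper, and it falls back to `"cpu"` when PyTorch is missing or no GPU is found.

```python
import importlib.util


def pick_device():
    """Return "cuda" when PyTorch is installed and sees a GPU, else "cpu"."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"  # PyTorch is not installed in this environment
    import torch
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"Using device: {device}")
```

If this prints `cpu` on Colab, switch the runtime type to GPU before continuing. Whether the training scripts accept `--device cpu` is an assumption that should be checked against the repository.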

In [None]:
# Clone repository
!git clone https://github.com/foubari/ratio-guidance-light.git
%cd ratio-guidance-light

In [None]:
# Install dependencies
!pip install -q torch torchvision tqdm matplotlib numpy Pillow

In [None]:
# Verify installation
!python test_basic.py

## Step 1: Train Diffusion Models

Train two DDPM models: one on standard MNIST and one on MNIST rotated by 90°.

**Time:** ~10 minutes each on Colab GPU
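
For reference, the standard DDPM training objective is sketched below. This is the textbook formulation; whether `src/train_diffusion.py` uses exactly this noise schedule and loss weighting is an assumption. Here $\beta_t$ is the noise schedule and $\epsilon_\theta$ is the noise-prediction network.

$$
q(x_t \mid x_0) = \mathcal{N}\!\left(\sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)\,I\right),
\qquad \bar\alpha_t = \prod_{s=1}^{t} (1-\beta_s)
$$

$$
L_{\text{simple}} = \mathbb{E}_{x_0,\,t,\,\epsilon \sim \mathcal{N}(0, I)}
\left[\, \bigl\lVert \epsilon - \epsilon_\theta\!\left(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ t\right) \bigr\rVert^2 \right]
$$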

In [None]:
# Train standard MNIST diffusion model
!python src/train_diffusion.py \
    --dataset standard \
    --epochs 50 \
    --batch_size 128 \
    --lr 1e-4 \
    --device cuda

In [None]:
# Train rotated MNIST diffusion model
!python src/train_diffusion.py \
    --dataset rotated \
    --epochs 50 \
    --batch_size 128 \
    --lr 1e-4 \
    --device cuda

## Step 2: Train Ratio Estimator

Train the density-ratio estimator to enable guided sampling.

**Time:** ~5 minutes on Colab GPU
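
As background, the `disc` loss type can be read as the usual classifier-based density-ratio trick. The sketch below assumes that $x$ is a standard digit, $y$ is a rotated digit, positive pairs come from the joint $p(x, y)$, and negative pairs come from the product of marginals $p(x)\,p(y)$ in equal proportion. The pairing scheme and any noise-level conditioning used by `src/train_ratio.py` are not specified here.

$$
L_{\text{disc}}(T) = -\,\mathbb{E}_{p(x,y)}\bigl[\log \sigma(T(x,y))\bigr]
\;-\; \mathbb{E}_{p(x)\,p(y)}\bigl[\log\bigl(1 - \sigma(T(x,y))\bigr)\bigr]
$$

$$
T^{*}(x, y) = \log \frac{p(x, y)}{p(x)\,p(y)}
$$

Under these assumptions the logit of the optimal classifier is the log density ratio, and its gradient with respect to $x$ is the term that guidance adds to the unconditional score.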

In [None]:
# Train ratio estimator with discriminator loss
!python src/train_ratio.py \
    --loss_type disc \
    --epochs 30 \
    --batch_size 128 \
    --lr 1e-4 \
    --device cuda

## Step 3: Generate Samples

### Unconditional Sampling

In [None]:
# Generate unconditional samples
!python src/sample.py \
    --dataset standard \
    --num_samples 16 \
    --device cuda

In [None]:
# Display unconditional samples
from PIL import Image
import matplotlib.pyplot as plt

img = Image.open('outputs/unconditional_standard.png')
plt.figure(figsize=(10, 10))
plt.imshow(img)
plt.axis('off')
plt.title('Unconditional Samples (Standard MNIST)')
plt.show()

### Guided Sampling

In [None]:
# Generate guided samples with different guidance scales
for scale in [0.5, 1.0, 2.0, 5.0]:
    print(f"\n=== Guidance Scale: {scale} ===")
    !python src/sample.py \
        --dataset standard \
        --num_samples 16 \
        --guided \
        --loss_type disc \
        --guidance_scale {scale} \
        --condition_dataset rotated \
        --device cuda

In [None]:
# Display guided samples at different scales
import glob
from PIL import Image
import matplotlib.pyplot as plt

guided_files = sorted(glob.glob('outputs/guided_*.png'))

fig, axes = plt.subplots(len(guided_files), 1, figsize=(15, 5*len(guided_files)))
if len(guided_files) == 1:
    axes = [axes]

for ax, file in zip(axes, guided_files):
    img = Image.open(file)
    ax.imshow(img)
    ax.axis('off')
    # Extract guidance scale from filename, keeping the decimal part (e.g. "2.0")
    scale = file.split('scale')[-1].rsplit('.', 1)[0]
    ax.set_title(f'Guided Samples (scale={scale})', fontsize=14)

plt.tight_layout()
plt.show()

print("\nTop row: Condition (rotated)")
print("Bottom row: Generated (standard)")

## Step 4: Experiment

Try different configurations:

In [None]:
# Train ratio estimator with different loss function (DV)
!python src/train_ratio.py \
    --loss_type dv \
    --epochs 30 \
    --batch_size 128 \
    --device cuda
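
Assuming `dv` stands for the Donsker-Varadhan objective, the critic $T$ is trained by maximizing a variational lower bound on the KL divergence between the joint $P = p(x, y)$ and the product of marginals $Q = p(x)\,p(y)$. This is the standard form of the bound; the exact variant implemented in `src/train_ratio.py` is not confirmed here.

$$
\mathrm{KL}(P \,\Vert\, Q) = \sup_{T}\ \mathbb{E}_{P}\bigl[T(x,y)\bigr] - \log \mathbb{E}_{Q}\bigl[e^{T(x,y)}\bigr]
$$

$$
T^{*}(x, y) = \log \frac{p(x, y)}{p(x)\,p(y)} + c
$$

The constant $c$ is arbitrary, but it drops out of the gradient with respect to $x$, so it does not affect guidance.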

In [None]:
# Compare DV vs Discriminator
!python src/sample.py --dataset standard --num_samples 16 --guided --loss_type dv --guidance_scale 2.0 --device cuda

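# Display comparison (imports repeated so this cell also runs after a kernel restart)
from PIL import Image
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))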

img_disc = Image.open('outputs/guided_standard_disc_scale2.0.png')
img_dv = Image.open('outputs/guided_standard_dv_scale2.0.png')

ax1.imshow(img_disc)
ax1.axis('off')
ax1.set_title('Discriminator Loss', fontsize=16)

ax2.imshow(img_dv)
ax2.axis('off')
ax2.set_title('DV Loss', fontsize=16)

plt.tight_layout()
plt.show()

## Analysis

### Expected Behavior

- **Unconditional samples**: Should look like recognizable MNIST digits
- **Guided samples (low scale, e.g. 0.5)**: Weak correspondence with the condition
- **Guided samples (medium scale, ~2.0)**: Good correspondence
- **Guided samples (high scale, 5.0 and above)**: Very strong correspondence, possibly with visual artifacts
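
The trend in the list above can be illustrated with a 1-D toy analogue. This is not the repository's sampler: it assumes a prior $p(x) = \mathcal{N}(0, 1)$ and a ratio $r(x) \propto \exp(-(x - c)^2 / (2 s^2))$, so that the guided density $p(x)\,r(x)^{w}$ is again Gaussian and can be computed in closed form. The names `guided_gaussian`, `guided_score`, `cond`, and `ratio_std` are hypothetical.

```python
def guided_gaussian(scale, cond=2.0, ratio_std=1.0):
    """Mean and variance of p(x) * r(x)**scale for p = N(0, 1) and
    r(x) proportional to exp(-(x - cond)**2 / (2 * ratio_std**2))."""
    precision = 1.0 + scale / ratio_std**2
    mean = (scale * cond / ratio_std**2) / precision
    return mean, 1.0 / precision


def guided_score(x, scale, cond=2.0, ratio_std=1.0):
    """Score of the guided density: grad log p(x) + scale * grad log r(x)."""
    return -x + scale * (-(x - cond) / ratio_std**2)


# Same guidance scales as the sampling loop, plus 0.0 (no guidance)
for scale in [0.0, 0.5, 1.0, 2.0, 5.0]:
    mean, var = guided_gaussian(scale)
    print(f"scale={scale:>3}: mean={mean:.3f}, var={var:.3f}")
```

As the scale grows, the mean moves from the prior toward the condition value while the variance shrinks. That loosely mirrors the behavior described above: stronger correspondence with the condition, at the cost of diversity and, in image models, possible artifacts.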

### Understanding the Results

The paired images show:
- **Top**: Condition image (rotated MNIST)
- **Bottom**: Generated image (standard MNIST)

With good guidance, the generated digit should match the identity of the rotated condition (e.g., if the condition is a rotated "5", the generated image should be an upright "5").

## Download Results

In [None]:
# Create archive of results
!zip -r results.zip outputs/ checkpoints/

# Download
from google.colab import files
files.download('results.zip')