# Training DDO on CIFAR-10 Dataset

This notebook trains a DDO (Denoising Diffusion Operator) model, i.e. a diffusion model formulated in function space, on CIFAR-10, working directly on pixel values (no input transform is applied).

In [None]:
from IPython.display import HTML, display

# Inject two CSS overrides so that the notebook container and the output
# area may use up to 98% of the browser width.
for _css_rule in (
    "<style>.container { width:98% !important; }</style>",
    "<style>.output_result { max-width:98% !important; }</style>",
):
    display(HTML(_css_rule))

In [None]:
import os
os.environ["PROJ_DIR"] = '/PATH/TO/REPO/ddo'
os.environ["FID_DIR"] = '/PATH/TO/CACHE/FOLDER/fid-stats'
os.environ["EXP_PATH"] = '/PATH/TO/CACHE/FOLDER/exp'
%cd /PATH/TO/REPO/ddo

In [None]:
# Shell escape: record the GPUs visible to this machine in the notebook output.
! nvidia-smi

## 1. Imports and Setup

In [None]:
import importlib
import sys
import functools
import math
import time
import argparse

import numpy as np
import torch
import torch.nn as nn

import matplotlib
import matplotlib.pyplot as plt

from tqdm.auto import tqdm  # Fixed: use tqdm.auto instead

# Remember which matplotlib backend is active right now. The value is
# re-applied with matplotlib.use(MYBACKEND) after the project modules have
# been imported, so it must be captured before those imports run.
MYBACKEND = plt.get_backend()
print(MYBACKEND)

# Render figures inline in the notebook output.
%matplotlib inline

In [None]:
# NOTE(review): this star import is what puts the unqualified helpers used by
# the later cells in scope (init_model, count_parameters_in_M, get_mgrid,
# sample_image, save_checkpoint and the transform functions such as identity /
# to_center / logit). They are presumably all defined in main.py; confirm there.
from main import *
from utils import datasets
from utils.visualize import get_grid_image
from utils.utils import Writer

# Re-import tqdm correctly in case main.py imports it differently
from tqdm.auto import tqdm

# Re-apply the backend captured in MYBACKEND before the imports above, in case
# they switched it (presumably main.py selects a non-interactive backend;
# TODO confirm).
matplotlib.use(MYBACKEND)

## 2. Configuration

In [None]:
# Every hyper-parameter of the run lives in one flat namespace, built in a
# single constructor call and grouped by purpose. The three per-GPU batch
# sizes are derived afterwards because they depend on values set here.
args = argparse.Namespace(
    # seed
    seed=1,
    command_type='train',

    # i/o paths
    exp_path=os.path.join(os.getenv('EXP_PATH', './experiments'), 'cifar10_ddo_notebook'),
    data=os.path.join(os.getenv('SLURM_TMPDIR', './data'), 'data'),
    fid_dir=os.getenv('FID_DIR', './fid_stats/cifar10'),
    print_every=100,
    save_every=5000,
    ckpt_every=10000,
    eval_every=5000,
    vis_every=1000,
    plot=False,
    resume=True,
    checkpoint_file='checkpoint.pt',

    # optimization
    train_batch_size=128,
    vis_batch_size=64,
    optimizer='adam',
    lr=0.0002,
    lr_rampup_kimg=0,
    lr_scheduler='none',
    ema_decay=0.999,
    weight_decay=0.,
    beta1=0.9,
    beta2=0.999,
    num_iterations=100000,  # Shorter for notebook demo

    # dataset - CIFAR-10 is 32x32 RGB
    train_img_height=32,
    dataset='cifar10',
    dequantize=False,
    transform=None,  # No transform - work in pixel space
    input_dim=3,  # RGB
    coord_dim=2,
    centered=False,
    interpolation='bilinear',
    antialias=False,

    # model - REDUCED SIZE for CIFAR-10
    model='fnounet2d',
    modes=16,  # Reduced from 32
    act=None,
    ch=64,  # Reduced from 128
    ch_mult=(1, 2, 2),  # Reduced from (1,2,2,2) - one less level
    num_res_blocks=2,  # Reduced from 4
    dropout=0.1,
    discard_resamp_with_conv=False,
    use_pointwise_op=True,
    use_radial=False,
    use_pos=True,
    norm='group_norm',

    # diffusion forward process
    timestep_sampler='low_discrepancy',
    ns_method='vp_cosine',
    disp_method='sine',
    sigma_blur_min=0.05,
    sigma_blur_max=0.25,

    # gaussian process noise
    gp_type='exponential',
    gp_exponent=2.0,
    gp_length_scale=0.05,
    gp_sigma=1.0,
    gp_modes=None,

    # sampling
    num_steps=250,
    s_min=1e-4,
    sampler='denoise',
    use_clip=False,
    weight_method=None,

    # evaluation
    eval_img_height=32,
    eval_batch_size=256,
    eval_use_ema=True,
    eval_fid=True,
    eval_pr=False,
    eval_num_samples=5000,
    eval_resize_mode='tensor',
    eval_interpolation='bilinear',
    eval_antialias=False,
    eval_cache=False,

    # distributed training (not used in notebook)
    num_proc_node=1,
    num_process_per_node=1,
    node_rank=0,
    local_rank=0,
    global_rank=0,
    global_size=1,
    distributed=False,
    master_address='127.0.0.1',
    master_port=None,
)

# batch sizes per GPU (single process here, so they equal the global sizes)
args.train_batch_size_per_gpu = args.train_batch_size
args.eval_batch_size_per_gpu = args.eval_batch_size
args.batch_size_per_gpu = args.train_batch_size

# The experiment directory must exist before anything is written into it.
os.makedirs(args.exp_path, exist_ok=True)

# Summary of the run, emitted as one write instead of nine.
print("\n".join([
    f"Experiment path: {args.exp_path}",
    f"Data path: {args.data}",
    f"Training for {args.num_iterations} iterations",
    f"\nModel configuration:",
    f"  Base channels: {args.ch}",
    f"  Channel multipliers: {args.ch_mult}",
    f"  Residual blocks: {args.num_res_blocks}",
    f"  Fourier modes: {args.modes}",
    f"  Expected params: ~10-50M (much smaller than before!)",
]))

## 3. Load Data

In [None]:
# Build the train/validation loaders from the dataset settings in `args`.
_loader_options = dict(
    dataset=args.dataset,
    root=args.data,
    distributed=args.distributed,
    batch_size=args.train_batch_size_per_gpu,
    centered=args.centered,
    num_workers=4,
)
train_loader, valid_loader, num_classes = datasets.get_loaders_eval(**_loader_options)

print(f"Dataset: {args.dataset}")
print(f"Number of training batches: {len(train_loader)}")
print(f"Number of classes: {num_classes}")

# Sanity check: show an nrow x nrow grid taken from the first training batch.
_first_batch = next(iter(train_loader))
x_sample, _ = _first_batch
nrow = 8
plt.figure(figsize=(12, 12))
_preview_grid = get_grid_image(
    x_sample[:nrow**2],
    nrow=nrow,
    pad_value=0,
    padding=2,
    to_numpy=True,
)
plt.imshow(_preview_grid)
plt.title('CIFAR-10 Training Samples')
plt.axis('off')
plt.show()

## 4. Initialize Model

In [None]:
# init_model returns the model wrapper, its optimizer and LR scheduler, the
# iteration counter to start from and the best FID seen so far.
# NOTE(review): presumably these are restored from a checkpoint when
# args.resume is set; confirm in main.init_model.
_init_state = init_model(args)
gen_sde, gen_sde_optimizer, gen_sde_scheduler, count, best_fid_score = _init_state

# The network is reached through `_model` because `model` on the wrapper is a
# method, not the module itself.
num_params = count_parameters_in_M(gen_sde._model)

print("\n".join([
    f"Model parameters: {num_params:.2f}M",
    f"Starting from iteration: {count}",
    f"Best FID score: {best_fid_score}",
]))

## 5. Training Loop

In [None]:
# Pick the data transform pair: `forward` is applied to every training batch,
# `reverse` to generated samples before they are displayed or saved. Any
# value of args.transform other than the three names below (including None)
# selects the identity pair.
if args.transform == "center":
    forward, reverse = to_center, to_01_clip
elif args.transform == "sdf":
    forward, reverse = x_to_image, from_sdf_to_01_clip
elif args.transform == "logit":
    forward, reverse = logit, inverse_logit
else:
    forward, reverse = identity, identity
args.forward, args.reverse = forward, reverse

# Number of batches in one pass over the training set.
num_iters_per_epoch = len(train_loader)

# Scalar/image logger used by the training loop; writes under the experiment
# directory.
writer = Writer(args.global_rank, args.exp_path)

In [None]:
# Training loop
#
# Runs optimisation steps until `count` reaches args.num_iterations, cycling
# through `train_loader` as many times as needed. Every args.print_every steps
# the loss and learning rate are logged, every args.vis_every steps a grid of
# generated samples is written to the writer and to a PNG, and every
# args.save_every steps a checkpoint is saved.
from tqdm.auto import tqdm  # Re-import here since main.py overwrites it
import torchvision

# An empty loader would make the epoch loop below spin forever without ever
# advancing `count`.
if num_iters_per_epoch == 0:
    raise ValueError("train_loader is empty; cannot train")

sample_dir = os.path.join(args.exp_path, 'samples')
# Use the configured file name rather than a hard-coded one.
# NOTE(review): presumably init_model resumes from this same file; confirm.
checkpoint_path = os.path.join(args.exp_path, args.checkpoint_file)

start_time = time.time()

# One pass over the loader is a single epoch, which is far fewer steps than
# args.num_iterations, so keep starting new epochs until the budget is used.
while count < args.num_iterations:
    epoch = count // num_iters_per_epoch

    # leave=False: do not stack one finished progress bar per epoch.
    for x, _ in tqdm(train_loader, desc=f"Epoch {epoch}", leave=False):
        if count >= args.num_iterations:
            break

        # Model to training mode (visualization below switches it to eval)
        gen_sde.train()
        gen_sde_optimizer.zero_grad()

        # Prepare batch
        if args.dequantize:
            x = x * 255 / 256 + torch.rand_like(x) / 256
        x = args.forward(x).cuda()

        # Get coordinate grid for function space; rebuilt per batch because
        # the last batch of an epoch may be smaller.
        v = get_mgrid(2, x.shape[-1]).repeat(x.shape[0], 1, 1, 1).cuda()

        # Compute DSM loss (Denoising Score Matching)
        loss = gen_sde.dsm(x, v).mean()

        # Backward pass
        loss.backward()
        gen_sde_optimizer.step()

        # Update learning rate
        if gen_sde_scheduler is not None:
            gen_sde_scheduler.step()

        count += 1

        # Logging
        if count % args.print_every == 0:
            # Average wall time per step since the last log line; this
            # includes any visualization/checkpoint time in that window.
            elapsed = (time.time() - start_time) / args.print_every
            lr = gen_sde_optimizer.param_groups[0]['lr']
            print(f"Iter {count:6d} | Loss: {loss.item():.4f} | LR: {lr:.6f} | Time: {elapsed:.2f}s/iter")
            writer.add_scalar('train/loss', loss.item(), count)
            writer.add_scalar('train/lr', lr, count)
            start_time = time.time()

        # Visualization
        if count % args.vis_every == 0:
            gen_sde.eval()
            with torch.no_grad():
                # Generate samples
                # NOTE(review): clip=True is passed here although
                # args.use_clip is False; confirm which one is intended.
                sample = sample_image(
                    gen_sde,
                    batch_size=min(64, args.vis_batch_size),
                    img_height=args.train_img_height,
                    num_steps=args.num_steps,
                    transform=None,
                    clip=True,
                    disable_tqdm=True,
                    sampler=args.sampler
                )
                sample = args.reverse(sample)
                nrow = 8

                # Save to TensorBoard
                writer.add_image(
                    'train/samples',
                    get_grid_image(sample[:nrow**2].cpu(), nrow=nrow, pad_value=0, padding=2, to_numpy=False),
                    count
                )

                # Also save as PNG file
                os.makedirs(sample_dir, exist_ok=True)
                torchvision.utils.save_image(
                    sample[:nrow**2],
                    os.path.join(sample_dir, f'iter_{count:06d}.png'),
                    nrow=nrow,
                    padding=2,
                    normalize=True,
                    value_range=(0, 1)
                )

            writer.flush()
            print(f"Visualized samples at iter {count}")
            print(f"  - TensorBoard: {args.exp_path}/tensorboard/")
            print(f"  - PNG file: {sample_dir}/iter_{count:06d}.png")

        # Save checkpoint
        if count % args.save_every == 0:
            if args.global_rank == 0:
                save_checkpoint(
                    checkpoint_path,
                    gen_sde,
                    gen_sde_optimizer,
                    gen_sde_scheduler,
                    count,
                    best_fid_score
                )
                print(f"Saved checkpoint at iter {count}")

print(f"Training completed at iteration {count}")

## 6. Generate Final Samples

In [None]:
# Generate final samples
#
# Draws `num_samples` images from the trained model (with EMA weights when
# they are available), shows them as a grid and saves the figure into the
# experiment directory.
gen_sde.eval()

# Switch to EMA parameters if available
_use_ema = args.eval_use_ema and hasattr(gen_sde_optimizer, 'swap_parameters_with_ema')
if _use_ema:
    gen_sde_optimizer.swap_parameters_with_ema(store_params_in_ema=True)

num_samples = 64
try:
    with torch.no_grad():
        sample = sample_image(
            gen_sde,
            batch_size=num_samples,
            # Same resolution as the samples drawn during training.
            img_height=args.train_img_height,
            num_steps=args.num_steps,
            transform=None,
            clip=True,
            disable_tqdm=False,
            sampler=args.sampler
        )
        sample = args.reverse(sample)
finally:
    # Switch back to original parameters even if sampling fails or the cell
    # is interrupted. Otherwise the model keeps the EMA weights, and running
    # this cell again would swap once more and sample from the wrong set.
    if _use_ema:
        gen_sde_optimizer.swap_parameters_with_ema(store_params_in_ema=True)

# Visualize
nrow = 8
image = get_grid_image(sample[:nrow**2].cpu(), nrow=nrow, pad_value=0, padding=2, to_numpy=True)
plt.figure(figsize=(12, 12))
plt.imshow(image)
plt.title(f'Generated CIFAR-10 Samples (Iteration {count})')
plt.axis('off')
# Save before show(): the figure may be cleared once it has been displayed.
plt.savefig(os.path.join(args.exp_path, f'samples_iter_{count}.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"Generated {num_samples} samples")