# Sparse Image Reconstruction Training with DDO

Complete training pipeline for conditional image reconstruction:
- **Input**: 10% randomly observed pixels
- **Output**: Full reconstructed 32×32 RGB image
- **Model**: Conditional DDO with sparse context encoding

In [None]:
from IPython.display import display, HTML
# Widen the notebook container and the cell result area to 98% of the window.
for _css in (
    "<style>.container { width:98% !important; }</style>",
    "<style>.output_result { max-width:98% !important; }</style>",
):
    display(HTML(_css))

In [None]:
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Change to the project root (the parent of the notebook directory) so that
# relative paths and the local packages imported below (utils/,
# main_sparse_reconstruction) resolve.
# The root is remembered in PROJECT_ROOT so re-running this cell does not
# keep climbing one more directory up, and sys.path gets the absolute root
# only once (instead of another '.' entry per re-run).
PROJECT_ROOT = globals().get('PROJECT_ROOT') or os.path.abspath('..')
os.chdir(PROJECT_ROOT)
if not sys.path or sys.path[0] != PROJECT_ROOT:
    sys.path.insert(0, PROJECT_ROOT)

%matplotlib inline

print(f"Working directory: {os.getcwd()}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Import after main.py
from utils.sparse_datasets import SparseImageDatasetWrapper, create_sparse_mask_image
from utils.visualize import get_grid_image
from utils.utils import Writer, count_parameters_in_M
from main_sparse_reconstruction import (
    init_model, get_mgrid, create_context_conditioning, get_args
)

# Re-import tqdm
from tqdm.auto import tqdm

## 1. Configuration

In [None]:
# Create configuration
# Built directly as an argparse.Namespace, presumably mirroring what
# get_args() would return from the CLI, so it can be passed to init_model /
# save_checkpoint unchanged.
import argparse

args = argparse.Namespace(
    # Paths
    exp_path='./experiments/sparse_recon_notebook',
    data='./data',
    seed=1,
    command_type='train',
    # Dataset
    dataset='cifar10',
    train_img_height=32,
    input_dim=3,
    coord_dim=2,
    # Sparse settings
    context_ratio=0.1,  # 10% observed
    query_ratio=0.1,    # 10% query for training
    # Model (reduced size)
    model='fnounet2d',
    ch=64,
    ch_mult=[1, 2, 2],
    num_res_blocks=2,
    modes=16,
    dropout=0.1,
    norm='group_norm',
    use_pos=True,
    use_pointwise_op=True,
    # Diffusion
    ns_method='vp_cosine',
    timestep_sampler='low_discrepancy',
    disp_method='sine',
    sigma_blur_min=0.05,
    sigma_blur_max=0.25,
    gp_type='exponential',
    gp_exponent=2.0,
    gp_length_scale=0.05,
    gp_sigma=1.0,
    # Training
    train_batch_size=64,  # Smaller for notebook
    lr=0.0002,
    weight_decay=0.0,
    num_iterations=10000,  # Shorter for demo
    ema_decay=0.999,
    optimizer='adam',
    beta1=0.9,
    beta2=0.999,
    # Logging
    print_every=50,
    save_every=2000,
    vis_every=500,
    eval_every=5000,
    vis_batch_size=16,
    resume=True,
    # Sampling
    num_steps=250,
    sampler='denoise',
    s_min=0.0001,
    # Misc
    distributed=False,
    global_rank=0,
    checkpoint_file='checkpoint.pt',
    use_clip=False,
    weight_method=None,
)

# makedirs also creates exp_path itself as the parent of samples/.
os.makedirs(os.path.join(args.exp_path, 'samples'), exist_ok=True)

print("\n".join([
    "Configuration:",
    f"  Experiment: {args.exp_path}",
    f"  Context ratio: {args.context_ratio*100:.0f}%",
    f"  Model: ch={args.ch}, ch_mult={args.ch_mult}",
    f"  Iterations: {args.num_iterations}",
]))

## 2. Load Dataset

In [None]:
# Load CIFAR-10 (training split). ToTensor turns each PIL image into a
# float tensor scaled to [0, 1].
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.ToTensor()])

base_dataset = torchvision.datasets.CIFAR10(
    root=args.data,
    train=True,
    download=True,
    transform=transform,
)

# Wrap with the sparse dataset. Later cells read the keys 'image',
# 'context_indices', 'context_coords' and 'context_values' from its items.
sparse_dataset = SparseImageDatasetWrapper(
    dataset=base_dataset,
    mode='train',
    context_ratio=args.context_ratio,
    query_ratio=args.query_ratio,
    return_full_image=True,
)

# Training loader options; drop_last keeps every batch at exactly
# train_batch_size samples.
loader_settings = dict(
    batch_size=args.train_batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)
train_loader = torch.utils.data.DataLoader(sparse_dataset, **loader_settings)

print(f"Dataset: {len(sparse_dataset)} images")
print(f"Context pixels: {sparse_dataset.num_context}")
print(f"Query pixels: {sparse_dataset.num_query}")
print(f"Batches per epoch: {len(train_loader)}")

In [None]:
# Preview a few training items: the full images (top panel) and the same
# images passed through create_sparse_mask_image with fill_value=0.5, i.e.
# only the observed context pixels kept (bottom panel).
num_vis = 8
pairs = [
    (
        item['image'],
        create_sparse_mask_image(item['image'], item['context_indices'], fill_value=0.5),
    )
    for item in (sparse_dataset[k] for k in range(num_vis))
]
originals = torch.stack([full for full, _ in pairs])
masked = torch.stack([sparse for _, sparse in pairs])

fig, axes = plt.subplots(2, 1, figsize=(12, 6))
panels = (
    (originals, 'Original Images'),
    (masked, '10% Observed Pixels (Sparse Input)'),
)
for ax, (grid, title) in zip(axes, panels):
    ax.imshow(get_grid_image(grid, nrow=4, to_numpy=True))
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

## 3. Initialize Model

In [None]:
# Initialize model, optimizer and training state.
# init_model returns (gen_sde, optimizer, start iteration, best loss);
# presumably count/best_loss are restored from a checkpoint when
# args.resume is set — confirm in main_sparse_reconstruction.init_model.
#
# Seed *before* construction so freshly initialized weights are
# reproducible: the training cell only seeds after this point, so without
# this the initial weights differ from run to run.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
gen_sde, optimizer, count, best_loss = init_model(args)

# Count parameters
num_params = count_parameters_in_M(gen_sde._model)
print(f"Model parameters: {num_params:.2f}M")
print(f"Starting from iteration: {count}")
print(f"Best loss: {best_loss:.4f}")

## 4. Training Loop

In [None]:
# Training
#
# Runs optimizer steps until `count` reaches args.num_iterations, resuming
# from the `count` returned by init_model. The cell:
#   - logs the loss every print_every steps;
#   - writes a checkpoint (args.checkpoint_file) every save_every steps,
#     plus once at the end if the last step was not already saved;
#   - saves a sparse-input / ground-truth grid every vis_every steps.
#
# NOTE(review): context_coords / context_values are moved to the GPU but
# never passed to the model. gen_sde.dsm only sees the full images, so this
# currently trains an *unconditional* model. `create_context_conditioning`
# (imported above) looks intended for this; TODO wire it in once its
# signature is confirmed.
from utils.utils import save_checkpoint  # hoisted: was re-imported inside the loop

torch.manual_seed(args.seed)
np.random.seed(args.seed)


def _sparse_vis_batch(dataset, num):
    """Stack the first `num` items of `dataset` for visualization.

    Returns (originals, contexts): the full images, and the same images
    passed through create_sparse_mask_image with fill_value=0.5.
    """
    originals, contexts = [], []
    for idx in range(num):
        sample = dataset[idx]
        originals.append(sample['image'])
        contexts.append(create_sparse_mask_image(
            sample['image'], sample['context_indices'], fill_value=0.5
        ))
    return torch.stack(originals), torch.stack(contexts)


writer = Writer(args.global_rank, args.exp_path)
start_time = time.time()

# The coordinate grid depends only on the image size. Build it and move it
# to the GPU once; .repeat() below still yields a fresh per-step copy, so
# in-place use inside dsm cannot corrupt it.
base_grid = get_mgrid(2, args.train_img_height).cuda()

gen_sde.train()
train_iter = iter(train_loader)
pbar = tqdm(total=args.num_iterations, initial=count, desc='Training')
loss = None          # stays None if no step runs (e.g. resumed at/after the end)
last_saved = count   # iteration of the most recent checkpoint written/loaded

try:
    while count < args.num_iterations:
        try:
            batch = next(train_iter)
        except StopIteration:
            # Loader exhausted: start a new epoch.
            train_iter = iter(train_loader)
            batch = next(train_iter)

        # non_blocking lets the host->GPU copy overlap with compute; it is
        # effective because the DataLoader uses pin_memory=True.
        full_images = batch['image'].cuda(non_blocking=True)
        context_coords = batch['context_coords'].cuda(non_blocking=True)  # unused, see NOTE
        context_values = batch['context_values'].cuda(non_blocking=True)  # unused, see NOTE

        v = base_grid.repeat(full_images.shape[0], 1, 1, 1)

        # Forward / backward / step
        optimizer.zero_grad()
        loss = gen_sde.dsm(full_images, v).mean()
        loss.backward()
        optimizer.step()

        count += 1
        pbar.update(1)

        # Logging
        if count % args.print_every == 0:
            loss_val = loss.item()
            sec_per_it = (time.time() - start_time) / args.print_every
            lr = optimizer.param_groups[0]['lr']
            pbar.set_postfix({
                'loss': f'{loss_val:.4f}',
                'lr': f'{lr:.6f}',
                's/it': f'{sec_per_it:.3f}',
            })
            writer.add_scalar('train/loss', loss_val, count)
            start_time = time.time()

        # Checkpointing: honor args.save_every / args.checkpoint_file
        # (previously written on the vis_every cadence to a hard-coded name).
        if count % args.save_every == 0:
            save_checkpoint(
                args, count, loss.item(), gen_sde, optimizer, None, args.checkpoint_file
            )
            last_saved = count

        # Visualization
        if count % args.vis_every == 0:
            # NOTE(review): no sampling happens here. With nrow=4 the grid
            # shows the 8 sparse inputs (top two rows) above the 8
            # ground-truth images (bottom two rows) — not reconstructions.
            originals_vis, contexts_vis = _sparse_vis_batch(sparse_dataset, 8)
            fig_path = os.path.join(args.exp_path, 'samples', f'iter_{count:06d}.png')
            comparison = torch.cat([contexts_vis, originals_vis], dim=0)
            torchvision.utils.save_image(
                comparison, fig_path, nrow=4, padding=2, normalize=True
            )
finally:
    # Close the progress bar even if training is interrupted.
    pbar.close()

# Persist the final state unless the last step was already checkpointed.
if loss is not None and count != last_saved:
    save_checkpoint(
        args, count, loss.item(), gen_sde, optimizer, None, args.checkpoint_file
    )

print('Training completed!')

## 5. Results

In [None]:
# Show the most recent visualization saved by the training loop.
# Previously only iter_{count:06d}.png was checked. That file exists only
# when count is a multiple of vis_every, so earlier samples were reported as
# missing. Now the highest-numbered iter_*.png is picked instead.
# NOTE(review): the training loop saves sparse inputs next to ground truth;
# these images are not model reconstructions yet.
import glob
import re
from PIL import Image

samples_dir = os.path.join(args.exp_path, 'samples')
saved_samples = {}
for path in glob.glob(os.path.join(samples_dir, 'iter_*.png')):
    match = re.fullmatch(r'iter_(\d+)\.png', os.path.basename(path))
    if match:
        saved_samples[int(match.group(1))] = path

if saved_samples:
    latest_iter = max(saved_samples)
    latest_img = saved_samples[latest_iter]
    # Context manager closes the file handle once the image has been drawn.
    with Image.open(latest_img) as img:
        plt.figure(figsize=(12, 6))
        plt.imshow(img)
        plt.title(f'Results at iteration {latest_iter}')
        plt.axis('off')
        plt.show()
else:
    print("No visualization found yet")

print(f"\nExperiment path: {args.exp_path}")
print(f"Samples: {os.path.join(args.exp_path, 'samples')}")
print(f"TensorBoard: tensorboard --logdir={args.exp_path}")