# Conditional Sparse Image Reconstruction with DDO

## Task
Train a **conditional diffusion model** for sparse image reconstruction:
- Each image has a **fixed** sparse mask (doesn't change during training)
- 20% total allowance: **10% context** (input) + **10% query** (ground truth)
- Model learns: `context (10%) → full image`, trained on query pixels

## Key Features
1. **Fixed masks per instance** - Each of the 50K CIFAR-10 training images keeps the same mask every epoch
2. **Conditional generation** - Model sees context as input
3. **Query-based loss** - Loss can be computed on the 10% query pixels only or on the full image; the training loop in this notebook uses the full-image DSM loss
4. **DDO framework** - Uses function-space diffusion with GP noise

In [None]:
# Setup: inject CSS so the notebook container and output area use 98% width.
from IPython.display import HTML, display

for style_tag in (
    "<style>.container { width:98% !important; }</style>",
    "<style>.output_result { max-width:98% !important; }</style>",
):
    display(HTML(style_tag))

In [None]:
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Change to the parent directory (the project root) so that the `lib` and
# `utils` packages imported in the next cell resolve.
#
# A bare os.chdir('..') is not idempotent: every re-run of this cell would
# climb one directory higher. Only move up when the packages are missing from
# the current directory but present in the parent.
if not (os.path.isdir('lib') and os.path.isdir('utils')):
    if (os.path.isdir(os.path.join('..', 'lib'))
            and os.path.isdir(os.path.join('..', 'utils'))):
        os.chdir('..')

# Make the working directory importable, without stacking duplicate entries
# on sys.path when the cell is executed more than once.
if '.' not in sys.path:
    sys.path.insert(0, '.')

# IPython magic (not plain Python): selects matplotlib's inline backend so
# figures are shown in the notebook output.
%matplotlib inline

# Report the working directory, PyTorch version and CUDA availability so that
# path or GPU problems are visible before any heavy work starts.
print(f"Working directory: {os.getcwd()}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Import utilities
from utils.sparse_datasets_fixed import (
    FixedSparseMaskDataset,
    create_context_image_batched,
    create_sparse_mask_image
)
from utils.visualize import get_grid_image
from utils.utils import Writer, count_parameters_in_M, save_checkpoint, load_checkpoint
from utils.ema import EMA

# Import DDO components
from lib.diffusion import BlurringDiffusion, DenoisingDiffusion
from lib.models.fourier_unet import FNOUNet2d
from lib.conditional_model import ConditionalDDOModel, ConditionalDDOModelSimple

# Re-import tqdm after main imports
from tqdm.auto import tqdm

print("All imports successful!")

## 1. Configuration

In [None]:
import argparse

# All hyperparameters for the experiment live on a single Namespace so they
# can be handed around like parsed command-line arguments.
args = argparse.Namespace(
    # Paths
    exp_path='./experiments/conditional_sparse_recon',
    data='./data',
    seed=1,

    # Dataset
    dataset='cifar10',
    train_img_height=32,
    input_dim=3,
    coord_dim=2,

    # Sparse conditioning settings
    context_ratio=0.1,    # 10% for context (input)
    query_ratio=0.1,      # 10% for query (GT target)
    mask_seed=42,         # Seed for fixed masks

    # Model architecture
    model='fnounet2d',
    ch=64,                          # Base channels
    ch_mult=[1, 2, 2],              # Channel multipliers
    num_res_blocks=2,               # Residual blocks per level
    modes=16,                       # Fourier modes
    dropout=0.1,
    norm='group_norm',
    use_pos=True,
    use_pointwise_op=True,
    context_feature_dim=32,         # Context encoder output channels
    use_simple_conditioning=False,  # True = simple concatenation, False = encoder

    # Diffusion settings (function-space DDO)
    ns_method='vp_cosine',
    timestep_sampler='low_discrepancy',
    disp_method='sine',
    sigma_blur_min=0.05,
    sigma_blur_max=0.25,
    gp_type='exponential',
    gp_exponent=2.0,
    gp_length_scale=0.05,
    gp_sigma=1.0,

    # Training
    train_batch_size=128,
    lr=0.0002,
    weight_decay=0.0,
    num_iterations=100000,
    ema_decay=0.999,
    optimizer='adam',
    beta1=0.9,
    beta2=0.999,

    # Logging
    print_every=100,
    save_every=5000,
    vis_every=1000,
    vis_batch_size=16,
    resume=True,

    # Sampling
    num_steps=250,
    sampler='denoise',
    s_min=0.0001,

    # Misc
    distributed=False,
    global_rank=0,
    checkpoint_file='checkpoint.pt',
    use_clip=False,
    weight_method=None,
)

# Creating the nested samples directory also creates the experiment directory.
os.makedirs(os.path.join(args.exp_path, 'samples'), exist_ok=True)

# Configuration summary, emitted as one block of text.
print("\n".join([
    "=" * 60,
    "Conditional Sparse Reconstruction Configuration",
    "=" * 60,
    f"Experiment: {args.exp_path}",
    f"Context ratio: {args.context_ratio*100:.0f}% (input)",
    f"Query ratio: {args.query_ratio*100:.0f}% (GT target)",
    f"Model: ch={args.ch}, ch_mult={args.ch_mult}, modes={args.modes}",
    f"Conditioning: {'Simple' if args.use_simple_conditioning else 'With Encoder'}",
    f"Iterations: {args.num_iterations}",
    f"Batch size: {args.train_batch_size}",
    "=" * 60,
]))

## 2. Load Dataset with Fixed Masks

In [None]:
# Load the CIFAR-10 training split; ToTensor is the only transform applied.
transform = transforms.Compose([transforms.ToTensor()])

base_dataset = torchvision.datasets.CIFAR10(
    root=args.data, train=True, download=True, transform=transform
)

print(f"Base dataset: {len(base_dataset)} images")

# Wrap with fixed sparse masks (context/query split controlled by the ratios
# and the dedicated mask seed from the configuration).
sparse_dataset = FixedSparseMaskDataset(
    dataset=base_dataset,
    context_ratio=args.context_ratio,
    query_ratio=args.query_ratio,
    seed=args.mask_seed
)

print(sparse_dataset)

# Create dataloader. drop_last=True keeps every batch at the full batch size.
train_loader = torch.utils.data.DataLoader(
    sparse_dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)

# Derive the pixel budget from the configured resolution instead of
# hard-coding 32x32 = 1024, so the summary stays correct if
# args.train_img_height changes. Images are treated as square, matching the
# training loop, which passes train_img_height as both height and width.
total_pixels = args.train_img_height * args.train_img_height

print(f"\nDataloader: {len(train_loader)} batches per epoch")
print(f"Total pixels: {args.train_img_height}×{args.train_img_height} = {total_pixels}")
print(f"Context pixels: {sparse_dataset.num_context} (input)")
print(f"Query pixels: {sparse_dataset.num_query} (GT target)")
print(f"Remaining: {total_pixels - sparse_dataset.num_context - sparse_dataset.num_query} (not used)")

### Visualize Fixed Masks

In [None]:
# Visualize some examples: originals next to their context and query masks.
num_vis = 8
originals = []
contexts = []
queries = []

for i in range(num_vis):
    sample = sparse_dataset[i]
    originals.append(sample['image'])

    # Context visualization
    # NOTE(review): presumably pixels outside the given indices are painted
    # with fill_value - confirm against create_sparse_mask_image.
    contexts.append(create_sparse_mask_image(
        sample['image'], sample['context_indices'], fill_value=0.5
    ))

    # Query visualization
    queries.append(create_sparse_mask_image(
        sample['image'], sample['query_indices'], fill_value=0.5
    ))

originals = torch.stack(originals)
contexts = torch.stack(contexts)
queries = torch.stack(queries)

# Titles take their percentages from the configuration (same formatting as
# the configuration summary) rather than hard-coding "10%".
panels = (
    (originals, 'Original Images'),
    (contexts, f'Context ({args.context_ratio*100:.0f}% Input) - FIXED per image'),
    (queries, f'Query ({args.query_ratio*100:.0f}% GT Target) - FIXED per image'),
)

fig, axes = plt.subplots(len(panels), 1, figsize=(14, 10))

for ax, (panel_images, panel_title) in zip(axes, panels):
    ax.imshow(get_grid_image(panel_images, nrow=4, to_numpy=True))
    ax.set_title(panel_title, fontsize=14)
    ax.axis('off')

plt.tight_layout()
plt.show()

print("Note: Context and Query masks are FIXED for each image throughout training!")

## 3. Initialize Model with Conditioning

In [None]:
def get_mgrid(dim, img_height):
    """Build a normalized coordinate grid for square images.

    Coordinates take the values ``i / img_height`` for
    ``i = 0 .. img_height - 1``.

    Args:
        dim: Number of spatial dimensions; only ``2`` is supported.
        img_height: Side length of the (square) grid.

    Returns:
        Tensor of shape ``(1, 2, img_height, img_height)``. Channel 0 varies
        along the second-to-last axis, channel 1 along the last axis.

    Raises:
        NotImplementedError: If ``dim`` is not 2.
    """
    axis = torch.linspace(0, img_height - 1, img_height) / img_height
    if dim != 2:
        raise NotImplementedError

    # Tile the 1-D axis along each spatial direction in turn.
    along_rows = axis.reshape(1, 1, img_height, 1).repeat(1, 1, 1, img_height)
    along_cols = axis.reshape(1, 1, 1, img_height).repeat(1, 1, img_height, 1)
    return torch.cat((along_rows, along_cols), dim=1)


def init_conditional_model(args):
    """Build the conditional DDO model, its optimizer, and the resume state.

    In order: configures the GP noise and blurring settings, creates the
    ``BlurringDiffusion`` process, creates the ``FNOUNet2d`` backbone, wraps
    it in a conditional model, wraps everything in ``DenoisingDiffusion``
    (moved to CUDA), creates an Adam optimizer wrapped in ``EMA``, and
    finally loads a checkpoint when resuming is enabled and the file exists.

    Args:
        args: Namespace holding the experiment hyperparameters.

    Returns:
        Tuple ``(gen_sde, optimizer, count, best_loss)``. Without a checkpoint
        ``count`` is ``0`` and ``best_loss`` is ``1e10``; otherwise all four
        values come from ``load_checkpoint``.

    Raises:
        ValueError: If ``args.optimizer`` is anything other than ``"adam"``.
    """
    # GP settings for the function-space noise; the device is fixed to CUDA.
    noise_cfg = argparse.Namespace(
        device='cuda',
        exponent=args.gp_exponent,
        length_scale=args.gp_length_scale,
        sigma=args.gp_sigma,
    )

    # Blur range used by the diffusion process.
    blur_cfg = argparse.Namespace(
        sigma_blur_min=args.sigma_blur_min,
        sigma_blur_max=args.sigma_blur_max,
    )

    # Diffusion process (function-space).
    forward_process = BlurringDiffusion(
        dim=args.coord_dim,
        ch=args.input_dim,
        ns_method=args.ns_method,
        disp_method=args.disp_method,
        disp_config=blur_cfg,
        gp_type=args.gp_type,
        gp_config=noise_cfg,
    )

    # Base FNO-UNet.
    backbone = FNOUNet2d(
        modes_height=args.modes,
        modes_width=args.modes,
        in_channels=args.input_dim,
        in_height=args.train_img_height,
        ch=args.ch,
        ch_mult=tuple(args.ch_mult),
        num_res_blocks=args.num_res_blocks,
        dropout=args.dropout,
        norm=args.norm,
        use_pos=args.use_pos,
        use_pointwise_op=args.use_pointwise_op,
    )

    # Conditional wrapper: plain variant or the one with a context encoder.
    if args.use_simple_conditioning:
        print("Using simple conditioning (concatenation only)")
        model = ConditionalDDOModelSimple(backbone, input_dim=args.input_dim)
    else:
        print("Using conditioning with encoder")
        model = ConditionalDDOModel(
            backbone,
            input_dim=args.input_dim,
            context_feature_dim=args.context_feature_dim,
        )

    # Denoising diffusion wrapper around the process and the model.
    gen_sde = DenoisingDiffusion(
        forward_process,
        model=model,
        timestep_sampler=args.timestep_sampler,
        use_clip=args.use_clip,
        weight_method=args.weight_method,
    ).cuda()

    # Only Adam is supported; anything else is rejected here.
    if args.optimizer != "adam":
        raise ValueError(f"Unknown optimizer: {args.optimizer}")
    base_optimizer = torch.optim.Adam(
        gen_sde.parameters(),
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        weight_decay=args.weight_decay,
    )

    # EMA wraps the optimizer itself.
    optimizer = EMA(base_optimizer, ema_decay=args.ema_decay)

    # Fresh-start defaults, returned as-is when there is nothing to resume.
    count, best_loss = 0, 1e10
    checkpoint_file = os.path.join(args.exp_path, args.checkpoint_file)
    if not (args.resume and os.path.exists(checkpoint_file)):
        return gen_sde, optimizer, count, best_loss

    print(f'Loading checkpoint from {checkpoint_file}')
    # The third element returned by load_checkpoint is discarded (None is
    # passed in its slot).
    gen_sde, optimizer, _, count, best_loss = load_checkpoint(
        checkpoint_file, gen_sde, optimizer, None
    )
    print(f'Resumed from iteration {count}')
    return gen_sde, optimizer, count, best_loss


# Build the model and optimizer (resuming from a checkpoint when available).
gen_sde, optimizer, count, best_loss = init_conditional_model(args)

# Report the parameter count (in millions) and the resume state in one go.
num_params = count_parameters_in_M(gen_sde._model)
print("\n".join([
    f"\nModel parameters: {num_params:.2f}M",
    f"Starting from iteration: {count}",
    f"Best loss: {best_loss:.6f}",
]))

## 4. Conditional Training Loop

Key differences from standard training:
1. Fixed masks per image (same context every epoch)
2. Context image passed to model during denoising
3. Loss computed on full image (model learns full reconstruction)

In [None]:
# Training setup
torch.manual_seed(args.seed)
np.random.seed(args.seed)

writer = Writer(args.global_rank, args.exp_path)
start_time = time.time()

gen_sde.train()
train_iter = iter(train_loader)
pbar = tqdm(total=args.num_iterations, initial=count, desc='Conditional Training')

# Pre-compute coordinate grid (same for all batches)
v_grid = get_mgrid(2, args.train_img_height).cuda()

# The saved comparison grid is built only from dataset items (no model output),
# so it is assembled once, on first use, and reused afterwards.
# NOTE(review): assumes sparse_dataset[i] returns the same tensors on every
# access (fixed masks, ToTensor-only transform) - confirm.
vis_comparison = None

# try/finally guarantees the progress bar is closed even when training is
# interrupted (e.g. KeyboardInterrupt in the notebook) or raises.
try:
    while count < args.num_iterations:
        # Restart the loader transparently at the end of each epoch.
        try:
            batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            batch = next(train_iter)

        # Get data
        full_images = batch['image'].cuda()  # (B, C, H, W)
        context_indices = batch['context_indices'].cuda()  # (B, num_context)
        context_values = batch['context_values'].cuda()  # (B, num_context, C)

        batch_size = full_images.shape[0]

        # Create context image (sparse observations as dense image)
        context_image = create_context_image_batched(
            context_values,
            context_indices,
            height=args.train_img_height,
            width=args.train_img_height,
            num_channels=args.input_dim
        )

        # Coordinate grid, repeated once per batch element
        v = v_grid.repeat(batch_size, 1, 1, 1)

        # Forward pass with conditioning
        optimizer.zero_grad()

        # DSM loss with context conditioning
        # The model will receive context_image through **kwargs
        loss = gen_sde.dsm(full_images, v, context_image=context_image).mean()

        # Backward
        loss.backward()
        optimizer.step()

        count += 1
        pbar.update(1)

        # Logging
        if count % args.print_every == 0:
            # Average wall-clock seconds per iteration since the last report.
            elapsed = (time.time() - start_time) / args.print_every
            lr = optimizer.param_groups[0]['lr']
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'lr': f'{lr:.6f}',
                's/it': f'{elapsed:.2f}'
            })
            writer.add_scalar('train/loss', loss.item(), count)
            writer.add_scalar('train/lr', lr, count)
            start_time = time.time()

        # Visualization
        if count % args.vis_every == 0:
            if vis_comparison is None:
                num_vis = min(args.vis_batch_size, 16)
                vis_samples = [sparse_dataset[i] for i in range(num_vis)]

                contexts_vis = torch.stack([
                    create_sparse_mask_image(
                        s['image'], s['context_indices'], fill_value=0.5
                    )
                    for s in vis_samples
                ])
                queries_vis = torch.stack([
                    create_sparse_mask_image(
                        s['image'], s['query_indices'], fill_value=0.5
                    )
                    for s in vis_samples
                ])
                originals_vis = torch.stack([s['image'] for s in vis_samples])

                # Comparison layout: [context | query | original]
                vis_comparison = torch.cat(
                    [contexts_vis, queries_vis, originals_vis], dim=0
                )

            # One file per visualization step, named by iteration.
            fig_path = os.path.join(args.exp_path, 'samples', f'iter_{count:06d}.png')
            torchvision.utils.save_image(
                vis_comparison, fig_path, nrow=4, padding=2, normalize=True, value_range=(0, 1)
            )

            print(f'\n[Iter {count}] Saved visualization to {fig_path}')

        # Save checkpoint
        if count % args.save_every == 0:
            # Use the configured file name so the checkpoint written here is
            # the one init_conditional_model looks for when resuming.
            save_checkpoint(
                args, count, loss.item(), gen_sde, optimizer, None, args.checkpoint_file
            )
            print(f'\n[Iter {count}] Saved checkpoint')
finally:
    pbar.close()

print('\n' + '='*60)
print('Training completed!')
print('='*60)

## 5. Visualize Training Progress

In [None]:
# Show latest visualization
from PIL import Image

samples_dir = os.path.join(args.exp_path, 'samples')

# Look for the newest saved grid instead of assuming one exists for the
# current `count`: grids are written only every args.vis_every iterations, so
# after an interrupted run `iter_{count}.png` usually does not exist.
saved_grids = []
if os.path.isdir(samples_dir):
    for fname in os.listdir(samples_dir):
        stem, ext = os.path.splitext(fname)
        iter_part = stem[len('iter_'):]
        if ext == '.png' and stem.startswith('iter_') and iter_part.isdigit():
            saved_grids.append((int(iter_part), fname))

if saved_grids:
    latest_iter, latest_name = max(saved_grids)
    latest_img = os.path.join(samples_dir, latest_name)
    plt.figure(figsize=(14, 10))
    # The context manager releases the file handle once the image is drawn.
    with Image.open(latest_img) as img:
        plt.imshow(img)
    # The grid is saved with 4 images per row, so each group spans several
    # rows; label the groups as thirds rather than as single rows.
    plt.title(f'Training Progress at Iteration {latest_iter}\n'
              f'Top third: Context (10% input) | Middle third: Query (10% GT) | Bottom third: Original',
              fontsize=12)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
else:
    print("No visualization found yet")

print(f"\nExperiment path: {args.exp_path}")
print(f"Samples: {os.path.join(args.exp_path, 'samples')}")
print(f"TensorBoard: tensorboard --logdir={args.exp_path}")

## 6. Conditional Sampling (TODO)

After training, we can generate reconstructions:
1. Take an image from the test set
2. Extract context (10% pixels)
3. Run reverse diffusion conditioned on context
4. Compare reconstruction with ground truth

In [None]:
# TODO: Implement conditional sampling
# This requires modifying the diffuse() method to accept context_image
# Status message, emitted as one block of text.
print("\n".join([
    "Conditional sampling implementation coming soon!",
    "\nFor now, the model is training to denoise images conditioned on sparse context.",
    "The key achievement is that context information is now properly flowing through:",
    "  1. Fixed masks per image ✓",
    "  2. Context image created ✓",
    "  3. Context passed to model ✓",
    "  4. Model conditioned on context ✓",
]))