# Sparse CIFAR-10 Reconstruction with DDO

This notebook demonstrates sparse image reconstruction:
- **Task**: Given a random 10% of the pixels as observations, reconstruct the full image
- **Training**: Use another 10% of the pixels as query targets for supervision
- **Model**: Conditional DDO that takes observed pixels as context
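
The sketch below illustrates the random context/query split described above. It draws one random permutation of the 32×32 = 1024 pixel indices and takes disjoint context and query sets from it. This is an assumption about how `SparseImageDatasetWrapper` samples points, not its actual code. `split_context_query` is a hypothetical helper, and the rounding down with `int()` is also assumed.

```python
import torch

def split_context_query(num_pixels, context_ratio=0.1, query_ratio=0.1, generator=None):
    """Hypothetical helper: sample disjoint context and query pixel indices."""
    num_context = int(num_pixels * context_ratio)  # 1024 * 0.1 -> 102
    num_query = int(num_pixels * query_ratio)
    # One permutation guarantees the two index sets never overlap
    perm = torch.randperm(num_pixels, generator=generator)
    context_idx = perm[:num_context]
    query_idx = perm[num_context:num_context + num_query]
    return context_idx, query_idx

g = torch.Generator().manual_seed(0)
context_idx, query_idx = split_context_query(32 * 32, generator=g)
print(len(context_idx), len(query_idx))  # 102 102
```

Drawing a fresh permutation per sample gives a different mask every time an image is loaded, which matches the "different masks each iteration" setting listed in the training notes.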

In [None]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:98% !important; }</style>"))
display(HTML("<style>.output_result { max-width:98% !important; }</style>"))

In [None]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

# Add parent directory to path
sys.path.insert(0, '..')

from utils import datasets
from utils.sparse_datasets import SparseImageDatasetWrapper, create_sparse_mask_image
from utils.visualize import get_grid_image

%matplotlib inline

## 1. Load and Visualize Sparse CIFAR-10 Dataset

In [None]:
# Load base CIFAR-10 dataset
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),
])

base_dataset = torchvision.datasets.CIFAR10(
    root='../data',
    train=True,
    download=True,
    transform=transform
)

print(f"CIFAR-10 dataset loaded: {len(base_dataset)} images")

In [None]:
# Create sparse dataset wrapper
sparse_dataset = SparseImageDatasetWrapper(
    dataset=base_dataset,
    context_ratio=0.1,  # 10% observed pixels
    query_ratio=0.1,    # 10% query pixels for training
    mode='train',
    return_full_image=True  # For visualization
)

print(sparse_dataset)
print(f"\nContext points: {sparse_dataset.num_context} pixels")
print(f"Query points: {sparse_dataset.num_query} pixels")
print(f"Total pixels: {sparse_dataset.num_pixels} pixels")

## 2. Visualize Sparse Sampling
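
The following plain-PyTorch sketch makes the shapes printed below concrete by showing what a sparse sample and its masked view might contain. It assumes that `context_indices` are flat row-major indices into the H×W grid and that coordinates are integer (row, col) pairs. The real wrapper may normalize coordinates differently, and `sparse_sample` is a hypothetical helper, not `create_sparse_mask_image` itself.

```python
import torch

def sparse_sample(image, indices, fill_value=0.5):
    """Hypothetical helper: gather observed pixels and build a masked view.

    image:   (C, H, W) tensor
    indices: (N,) flat row-major pixel indices
    """
    C, H, W = image.shape
    rows, cols = indices // W, indices % W
    coords = torch.stack([rows, cols], dim=-1)    # (N, 2) integer (row, col)
    values = image[:, rows, cols].T               # (N, C) observed pixel values
    flat = image.reshape(C, H * W)
    masked_view = torch.full_like(flat, fill_value)
    masked_view[:, indices] = flat[:, indices]    # copy observed pixels only
    return coords, values, masked_view.reshape(C, H, W)

image = torch.rand(3, 32, 32)
indices = torch.randperm(32 * 32)[:102]
coords, values, masked_view = sparse_sample(image, indices)
print(coords.shape, values.shape, masked_view.shape)
# torch.Size([102, 2]) torch.Size([102, 3]) torch.Size([3, 32, 32])
```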

In [None]:
# Get a sample
sample = sparse_dataset[0]

# Original image
original_image = sample['image']

# Create masked visualization (only showing observed pixels)
masked_image = create_sparse_mask_image(
    original_image,
    sample['context_indices'],
    fill_value=0.5  # Gray for unobserved
)

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].imshow(original_image.permute(1, 2, 0))
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(masked_image.permute(1, 2, 0))
axes[1].set_title(f'Observed Pixels (10% = {sparse_dataset.num_context} pixels)')
axes[1].axis('off')

plt.tight_layout()
plt.show()

print(f"Context coords shape: {sample['context_coords'].shape}")
print(f"Context values shape: {sample['context_values'].shape}")
print(f"Query coords shape: {sample['query_coords'].shape}")
print(f"Query values shape: {sample['query_values'].shape}")

## 3. Visualize Multiple Samples

In [None]:
# Visualize a batch of sparse samples
num_samples = 16
nrow = 4

originals = []
masked = []

for i in range(num_samples):
    sample = sparse_dataset[i]
    originals.append(sample['image'])
    masked.append(create_sparse_mask_image(
        sample['image'],
        sample['context_indices'],
        fill_value=0.5
    ))

originals = torch.stack(originals)
masked = torch.stack(masked)

fig, axes = plt.subplots(2, 1, figsize=(15, 8))

# Original images
grid_orig = get_grid_image(originals, nrow=nrow, pad_value=0, padding=2, to_numpy=True)
axes[0].imshow(grid_orig)
axes[0].set_title('Original Images')
axes[0].axis('off')

# Sparse observations
grid_masked = get_grid_image(masked, nrow=nrow, pad_value=0, padding=2, to_numpy=True)
axes[1].imshow(grid_masked)
axes[1].set_title('Sparse Observations (10% pixels)')
axes[1].axis('off')

plt.tight_layout()
plt.show()

## 4. Training Setup (Demonstration)

This section only demonstrates how batches are structured; no model is trained here. A training loop would:
1. Take context points as conditional input
2. Predict query point values
3. Compute loss on query points only
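
The three steps above can be sketched with a deliberately tiny stand-in model. `ToySparseModel` is hypothetical and is not a DDO. It embeds each context (coord, value) pair with an MLP, mean-pools the embeddings into one vector per image, and decodes each query coordinate against that vector. The batch shapes `(B, N, 2)` for coordinates and `(B, N, 3)` for values are assumed to match what `collate_sparse_batch` produces. A diffusion-based model would normally train with a denoising objective instead of a direct regression loss. This sketch follows the direct MSE used in this notebook, and the point it illustrates is that the loss only touches query points.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySparseModel(nn.Module):
    """Hypothetical stand-in: set encoder over context, MLP decoder over queries."""
    def __init__(self, coord_dim=2, value_dim=3, hidden=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(coord_dim + value_dim, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
        self.decoder = nn.Sequential(
            nn.Linear(coord_dim + hidden, hidden), nn.ReLU(), nn.Linear(hidden, value_dim))

    def forward(self, context_coords, context_values, query_coords):
        # Encode each observed (coord, value) pair, then mean-pool over points -> (B, hidden)
        z = self.encoder(torch.cat([context_coords, context_values], dim=-1)).mean(dim=1)
        # Broadcast the per-image code to every query point and decode its value
        z = z.unsqueeze(1).expand(-1, query_coords.shape[1], -1)
        return self.decoder(torch.cat([query_coords, z], dim=-1))

# Random tensors with the assumed batch shapes (B=4, 102 context / 102 query points)
B, N_ctx, N_qry = 4, 102, 102
context_coords, context_values = torch.rand(B, N_ctx, 2), torch.rand(B, N_ctx, 3)
query_coords, query_values = torch.rand(B, N_qry, 2), torch.rand(B, N_qry, 3)

model = ToySparseModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

pred = model(context_coords, context_values, query_coords)  # (B, N_qry, 3)
loss = F.mse_loss(pred, query_values)  # supervised on query points only
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(pred.shape, loss.item())
```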

In [None]:
# Create data loader
from utils.sparse_datasets import collate_sparse_batch

train_loader = torch.utils.data.DataLoader(
    sparse_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_sparse_batch
)

# Get a batch
batch = next(iter(train_loader))

print("Batch contents:")
for key, val in batch.items():
    if isinstance(val, torch.Tensor):
        print(f"  {key}: {val.shape}")
    else:
        print(f"  {key}: {type(val)}")

print("\nTraining workflow:")
print("  1. Input: context_coords + context_values (observed pixels)")
print("  2. Model predicts: values at query_coords")
print("  3. Loss: MSE between predictions and query_values")
print("  4. At inference: predict all pixels given context")

## 5. Different Mask Patterns
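
The three patterns generated below differ in where the points fall and in how many points there are. The following is a hypothetical re-implementation in plain PyTorch, not the code of `GridMaskGenerator`. The grid offset (starting at pixel 0) and the choice of the pixels nearest the image center for the "center" mask are assumptions. With stride 3 on a 32×32 image, a grid starting at 0 keeps ceil(32/3)² = 11² = 121 pixels, about 11.8%, so the grid mask does not match the 10% budget exactly.

```python
import torch

def random_mask(num_pixels, num_samples):
    # Uniformly random flat indices, without replacement
    return torch.randperm(num_pixels)[:num_samples]

def grid_mask(H, W, stride):
    # Every `stride`-th row and column, starting at (0, 0), as flat row-major indices
    rows = torch.arange(0, H, stride)
    cols = torch.arange(0, W, stride)
    return (rows[:, None] * W + cols[None, :]).reshape(-1)

def center_mask(H, W, num_samples):
    # The `num_samples` pixels closest (Euclidean) to the image center
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    dist = (ys - (H - 1) / 2) ** 2 + (xs - (W - 1) / 2) ** 2
    return torch.argsort(dist.reshape(-1))[:num_samples]

H = W = 32
budget = int(H * W * 0.1)
for name, idx in [('Random', random_mask(H * W, budget)),
                  ('Grid (stride=3)', grid_mask(H, W, 3)),
                  ('Center', center_mask(H, W, budget))]:
    print(f'{name}: {len(idx)} pixels ({len(idx) / (H * W):.1%})')
# Random: 102 pixels (10.0%)
# Grid (stride=3): 121 pixels (11.8%)
# Center: 102 pixels (10.0%)
```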

In [None]:
from utils.sparse_datasets import GridMaskGenerator

# Get a sample image
sample_img, _ = base_dataset[0]
C, H, W = sample_img.shape
num_samples = int(H * W * 0.1)

# Generate different mask patterns
masks = {
    'Random': GridMaskGenerator.random_mask(H * W, num_samples),
    'Grid (stride=3)': GridMaskGenerator.grid_mask(H, W, stride=3),
    'Center': GridMaskGenerator.center_mask(H, W, num_samples),
}

# Visualize
fig, axes = plt.subplots(1, len(masks) + 1, figsize=(15, 4))

axes[0].imshow(sample_img.permute(1, 2, 0))
axes[0].set_title('Original')
axes[0].axis('off')

for idx, (name, mask) in enumerate(masks.items()):
    masked = create_sparse_mask_image(sample_img, mask, fill_value=0.5)
    axes[idx + 1].imshow(masked.permute(1, 2, 0))
    axes[idx + 1].set_title(f'{name}\n({len(mask)} pixels)')
    axes[idx + 1].axis('off')

plt.tight_layout()
plt.show()

## Notes on Training

To train a DDO model for sparse reconstruction:

1. **Model Architecture**: Use conditional DDO that takes context points as input
   - Context encoding: Process observed (coord, value) pairs
   - Diffusion: Apply to full image representation
   - Decoding: Generate values at all pixel locations

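2. **Training Loss**: Compute on query points only
   ```python
   # Sketch: `model` stands in for the conditional DDO (not defined in this notebook)
   import torch.nn.functional as F

   pred = model(batch['context_coords'], batch['context_values'], batch['query_coords'])
   loss = F.mse_loss(pred, batch['query_values'])
   ```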

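3. **Inference**: Predict values at every pixel location, including the unobserved ones
   ```python
   # All H*W pixel coordinates as (row, col) pairs in row-major order.
   # Assumption: this must match the coordinate convention of SparseImageDatasetWrapper
   # (it may, for example, normalize coordinates to [0, 1] or [-1, 1]).
   ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
   all_coords = torch.stack([ys, xs], dim=-1).reshape(1, H * W, 2).float()
   reconstructed = model(context_coords, context_values, all_coords)  # assumed (1, H*W, C)
   reconstructed = reconstructed.reshape(H, W, -1).permute(2, 0, 1)   # back to (C, H, W)
   ```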

4. **Key Hyperparameters**:
   - `context_ratio=0.1`: 10% observed pixels
   - `query_ratio=0.1`: 10% query pixels for training
   - Random sampling during training (different masks each iteration)
   - Fixed or grid sampling for evaluation

5. **Evaluation Metrics**:
   - PSNR between the reconstructed and ground-truth images
   - SSIM for structural similarity (a common proxy for perceptual quality)
   - MSE restricted to the unobserved pixels
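
Below is a minimal sketch of the PSNR and unobserved-pixel MSE metrics. It assumes that images are float tensors in [0, 1] (as produced by `transforms.ToTensor()`) and that the observed pixels are given as flat row-major indices. `psnr` and `masked_mse` are hypothetical helpers. For SSIM, an existing implementation such as `skimage.metrics.structural_similarity` is a safer choice than a hand-rolled one.

```python
import math
import torch

def psnr(pred, target, max_val=1.0):
    # PSNR = 10 * log10(MAX^2 / MSE), computed over all pixels
    mse = torch.mean((pred - target) ** 2).item()
    return float('inf') if mse == 0 else 10 * math.log10(max_val ** 2 / mse)

def masked_mse(pred, target, observed_indices):
    # MSE restricted to pixels that were NOT in the context set
    C, H, W = target.shape
    unobserved = torch.ones(H * W, dtype=torch.bool)
    unobserved[observed_indices] = False
    diff = (pred - target).reshape(C, H * W)[:, unobserved]
    return torch.mean(diff ** 2).item()

target = torch.rand(3, 32, 32)
pred = target + 0.1  # placeholder "reconstruction" with a uniform error of 0.1
observed = torch.randperm(32 * 32)[:102]
print(psnr(pred, target), masked_mse(pred, target, observed))  # ~20.0 dB and ~0.01
```

Reporting MSE on unobserved pixels separately avoids rewarding a model for simply copying the context values it was given.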