# Toy Example for Diffusion

This notebook demonstrates how to train a diffusion model on toy 2D spiral datasets using the SciREX library.

In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from flax import nnx as nn

from scirex.data import create_dataloader
from scirex.diffusion import ScheduleCosine, diffusion_loss, sample_ddim
from scirex.diffusion.helpers import TimeInputMLP
from scirex.training import Trainer

# Set up plotting style
# NOTE(review): the "seaborn-v0_8-*" style-sheet names only exist in newer
# Matplotlib releases (3.6+, as far as I know) — confirm the pinned version.
plt.style.use("seaborn-v0_8-muted")
# IPython line magic selecting the inline backend so figures render in the notebook.
%matplotlib inline

# Part 1: 2D Spiral Diffusion

## 1. Data Generation

We create a 2D spiral dataset and normalize it.

In [None]:
def create_spiral_data(tmin=0, tmax=5 * jnp.pi, n_points: int = 100) -> jnp.ndarray:
    """Sample points along a 2D spiral and rescale them into [-1, 1].

    The curve is ``(t * cos(t), t * sin(t)) / tmax`` evaluated at ``n_points``
    evenly spaced values of ``t`` between ``tmin`` and ``tmax``.

    Args:
        tmin: First value of the curve parameter ``t``.
        tmax: Last value of the curve parameter ``t``; also divides the radius.
        n_points: Number of points sampled along the curve.

    Returns:
        Array of shape ``(n_points, 2)`` with ``(x, y)`` coordinates, divided by
        the largest absolute coordinate. If every coordinate is zero (for
        example ``n_points=1`` with ``tmin=0``) the points are returned
        unscaled instead of turning into NaN.
    """
    t = jnp.linspace(tmin, tmax, n_points)
    x = t * jnp.cos(t) / tmax
    y = t * jnp.sin(t) / tmax
    data = jnp.stack([x, y], axis=1)
    # Normalize to [-1, 1]. A zero maximum would make 0 / 0 = NaN for every
    # coordinate, so fall back to a scale of 1 in that degenerate case.
    scale = jnp.abs(data).max()
    safe_scale = jnp.where(scale > 0, scale, 1.0)
    return data / safe_scale


# Training set: 2000 points along the spiral.
train_data_2d = create_spiral_data(n_points=2000)

# Scatter the points on a square canvas with equal axis scaling and a faint grid.
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(train_data_2d[:, 0], train_data_2d[:, 1], s=10, alpha=0.6)
ax.set_title("2D Spiral Training Data")
ax.axis("equal")
ax.grid(True, alpha=0.3)
plt.show()

## 2. Model Architecture

We use `TimeInputMLP` from `scirex.diffusion.helpers`, which handles time/sigma embeddings internally.

In [None]:
# Random-number streams created from seed 0; the same `rngs` object is later
# passed to the Trainer as well.
rngs = nn.Rngs(0)

# TimeInputMLP for 2D data: dim=2 in, output_dim=2 out.
# `hidden_dims` controls the capacity of the MLP.
model_2d = TimeInputMLP(
    dim=2,
    output_dim=2,
    hidden_dims=(128, 128, 128),
    rngs=rngs,
)

In [None]:
# Create trainer for 2D spiral diffusion model
# Cosine noise schedule. NOTE(review): N=1000 is presumably the number of
# discretisation steps — confirm against ScheduleCosine.
schedule = ScheduleCosine(N=1000)

# Adam with learning rate 1e-3. `wrt=nn.Param` is spelled out because recent
# Flax NNX releases require this keyword; on older releases it is the default,
# so the set of optimised variables does not change.
optimizer = nn.Optimizer(model_2d, optax.adam(1e-3), wrt=nn.Param)

# diffusion_loss(schedule) is called once here; its return value is what the
# Trainer receives as `loss_fn`. The model's `rngs` object is reused.
trainer = Trainer(model=model_2d, optimizer=optimizer, loss_fn=diffusion_loss(schedule), rngs=rngs)

## 3. Training and Sampling

We train the model and sample from it.  
**Note:** We set `worker_count=0` in `create_dataloader` to avoid multiprocessing issues in some interactive environments.

In [None]:
# The spiral was built with jax.numpy; give the loader a NumPy copy of it.
spiral_points_np = np.array(train_data_2d)

# worker_count=0: run in main process to avoid potential errors.
train_loader = create_dataloader(
    spiral_points_np, batch_size=128, shuffle=True, seed=0, worker_count=0
)

In [None]:
# Train the diffusion model
# Runs the trainer on `train_loader` with n_epochs=100. This is the cell's last
# expression, so whatever `train` returns is shown as the cell output.
trainer.train(train_loader=train_loader, n_epochs=100)

In [None]:
# Sampling
print("Sampling from model...")

# Sigma sequence for 20 sampling steps, taken from the schedule used in training.
sigmas = schedule.sample_sigmas(20)

# Separate RNG (seed 42) for sampling; 1000 samples, each of shape (2,).
sample_rng = nn.Rngs(42).sample
sampled_data = sample_ddim(model_2d, sigmas, batchsize=1000, shape=(2,), rng=sample_rng)

# Plot results: generated points first, training spiral drawn faintly on top.
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(sampled_data[:, 0], sampled_data[:, 1], s=10, alpha=0.6, label="Sampled")
ax.scatter(train_data_2d[:, 0], train_data_2d[:, 1], s=10, alpha=0.1, label="Real")
ax.set_title("2D Spiral Diffusion Generation")
ax.legend()
ax.axis("equal")
plt.show()