# Training Diffusion Priors on Weather/Ocean Datasets

This notebook demonstrates how to train an unconditional diffusion model on weather reanalysis data using the `diffusers` library. We'll use NCEP reanalysis 2-meter air temperature data as an example, but the approach generalizes to other weather/ocean variables.

## Overview

1. **Data Loading**: Download and preprocess NCEP reanalysis data
2. **Dataset Preparation**: Create PyTorch datasets with proper normalization
3. **Model Setup**: Configure UNet2D model and noise scheduler
4. **Training**: Train the diffusion model to learn the data distribution
5. **Sampling**: Generate new samples from the trained model
6. **Visualization**: Examine training progress and generated samples


## 1. Installation and Imports


In [None]:
# Install required packages if needed
# Uncomment the lines below if running in a fresh environment
# !pip install torch torchvision diffusers xarray netcdf4 matplotlib numpy tqdm


In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import xarray as xr
from tqdm import tqdm

from diffusers import UNet2DModel, HeunDiscreteScheduler

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


## 2. Download and Load Weather Data

We'll use NCEP reanalysis daily 2-meter air temperature data. This is publicly available from NOAA PSL.

**Note**: You can modify the region selection (`lat`/`lon` slices) to focus on your area of interest (e.g., US, Europe).


In [None]:
# Download NCEP reanalysis data (example years)
# Modify years as needed for your use case
data_dir = "data"
os.makedirs(data_dir, exist_ok=True)

years = ["2019", "2020", "2021", "2022", "2023", "2024"]
base_url = "https://downloads.psl.noaa.gov/Datasets/ncep.reanalysis/Dailies/surface_gauss/"

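import urllib.request

for year in years:
    filename = f"air.2m.gauss.{year}.nc"
    filepath = os.path.join(data_dir, filename)
    
    if not os.path.exists(filepath):
        print(f"Downloading {filename}...")
        # Download to a temporary name, then rename, so an interrupted or failed
        # transfer never leaves a partial file that later runs would skip
        tmp_path = filepath + ".part"
        urllib.request.urlretrieve(base_url + filename, tmp_path)
        os.replace(tmp_path, filepath)
    else:
        print(f"{filename} already exists, skipping download.")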


In [None]:
# Load and subset data
# Example: US region (adjust lat/lon for your region of interest)
# NCEP latitudes are stored north-to-south, so slice from max to min:
#   lat=slice(lat_max, lat_min), lon=slice(lon_min, lon_max)
# NCEP longitudes run 0-360 degrees east

# US region example
lat_slice = slice(50, 25)    # 50N down to 25N
lon_slice = slice(235, 295)  # 235E to 295E, i.e., roughly 125W to 65W

# Load training data (multiple years)
train_files = [os.path.join(data_dir, f"air.2m.gauss.{year}.nc") for year in ["2021", "2022", "2023", "2024"]]
ds_train = xr.open_mfdataset(train_files).sel(lat=lat_slice, lon=lon_slice)

# Load validation and test data
ds_valid = xr.open_dataset(os.path.join(data_dir, "air.2m.gauss.2020.nc")).sel(lat=lat_slice, lon=lon_slice)
ds_test = xr.open_dataset(os.path.join(data_dir, "air.2m.gauss.2019.nc")).sel(lat=lat_slice, lon=lon_slice)

print(f"Training data shape: {ds_train.air.shape}")
print(f"Spatial dimensions: {ds_train.air.isel(time=0).shape}")
print(f"\nCoordinate ranges:")
print(f"  Latitude: {float(ds_train.lat.min()):.2f} to {float(ds_train.lat.max()):.2f}")
print(f"  Longitude: {float(ds_train.lon.min()):.2f} to {float(ds_train.lon.max()):.2f}")
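
Europe straddles the prime meridian, so on the 0-360 longitude grid a single `lon` slice cannot select it. One option, sketched below, is to convert longitudes to -180..180 and re-sort before slicing. The helper `to_signed_longitude`, the tiny synthetic grid, and the 30W to 30E box are illustrative assumptions, not part of the original workflow.

```python
import numpy as np
import xarray as xr


def to_signed_longitude(ds, lon_name="lon"):
    """Map 0..360 longitudes to -180..180 and re-sort, so slices can cross 0 degrees."""
    signed = ((ds[lon_name].values + 180.0) % 360.0) - 180.0
    return ds.assign_coords({lon_name: signed}).sortby(lon_name)


# Tiny synthetic grid standing in for an NCEP file (values are arbitrary)
lon = np.arange(0.0, 360.0, 30.0)    # 0, 30, ..., 330 degrees east
lat = np.array([60.0, 45.0, 30.0])   # stored north-to-south, like NCEP
demo = xr.Dataset(
    {"air": (("lat", "lon"), np.arange(lat.size * lon.size, dtype=float).reshape(lat.size, lon.size))},
    coords={"lat": lat, "lon": lon},
)

# Hypothetical box from 30W to 30E
europe = to_signed_longitude(demo).sel(lat=slice(60, 30), lon=slice(-30, 30))
print(europe.lon.values.tolist())  # [-30.0, 0.0, 30.0]
```

With real data, apply the conversion to `ds_train`, `ds_valid`, and `ds_test` before `.sel(...)`, and use negative values for western longitudes in `lon_slice`.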


In [None]:
# Visualize mean temperature field
fig, ax = plt.subplots(figsize=(10, 6))
ds_train.mean(dim='time').air.plot(ax=ax, cmap='coolwarm')
ax.set_title("Mean 2-meter Air Temperature (Training Period)")
plt.tight_layout()
plt.show()


## 3. Dataset and DataLoader Setup

We'll create a PyTorch Dataset that:
1. Extracts daily snapshots from the xarray dataset
2. Normalizes the data to the [0, 1] range using training-set statistics (the training loop later maps each batch to [-1, 1])
3. Reshapes to the format expected by UNet2D: `[C, H, W]`
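
The full mapping between physical units and the range the model sees is two affine steps: min-max scaling to [0, 1] in the dataset, then `x * 2 - 1` in the training loop. The sketch below composes them and checks the inverse used after sampling. The helper names and the three Kelvin values are illustrative only.

```python
import numpy as np


def to_model_range(x, data_min, data_max):
    """Min-max scale to [0, 1] (as WeatherDataset does), then to [-1, 1] (as the training loop does)."""
    return 2.0 * (x - data_min) / (data_max - data_min) - 1.0


def from_model_range(y, data_min, data_max):
    """Invert to_model_range: [-1, 1] -> [0, 1] -> original units (e.g. Kelvin)."""
    return (y + 1.0) / 2.0 * (data_max - data_min) + data_min


# Synthetic "temperatures" in Kelvin; statistics come from the training split only
train_demo = np.array([250.0, 270.0, 300.0])
demo_min, demo_max = train_demo.min(), train_demo.max()

y = to_model_range(train_demo, demo_min, demo_max)
print(y)
print(from_model_range(y, demo_min, demo_max))
```

Because the statistics come from the training years only, validation and test fields can land slightly outside [0, 1] (or [-1, 1]); that is expected and preferable to leaking test statistics into the normalization.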


In [None]:
class WeatherDataset(Dataset):
    """
    PyTorch Dataset for weather reanalysis data.
    
    Args:
        data: numpy array of shape [time, lat, lon] or xarray DataArray
        normalize: tuple of (min, max) for normalization, or None to skip
    """
    def __init__(self, data, normalize=None):
        if isinstance(data, xr.DataArray):
            data = data.values
        
        self.data = data.astype(np.float32)
        
        if normalize is None:
            # Compute normalization from data
            self.min = self.data.min()
            self.max = self.data.max()
        else:
            self.min, self.max = normalize
        
        # Normalize to [0, 1]
        self.data_norm = (self.data - self.min) / (self.max - self.min)
        
        print(f"Dataset shape: {self.data_norm.shape}")
        print(f"Value range: [{self.data_norm.min():.3f}, {self.data_norm.max():.3f}] (normalized)")
        print(f"Original range: [{self.min:.2f}, {self.max:.2f}] (original units)")
    
    def __len__(self):
        return len(self.data_norm)
    
    def __getitem__(self, idx):
        # Get single timestep: shape [lat, lon] -> [1, lat, lon] (add channel dim)
        x = torch.from_numpy(self.data_norm[idx]).unsqueeze(0)
        return x.float()


In [None]:
# Create datasets
train_data = ds_train.air.values
valid_data = ds_valid.air.values
test_data = ds_test.air.values

# Normalize using training data statistics
data_min, data_max = train_data.min(), train_data.max()

train_dataset = WeatherDataset(train_data, normalize=(data_min, data_max))
valid_dataset = WeatherDataset(valid_data, normalize=(data_min, data_max))
test_dataset = WeatherDataset(test_data, normalize=(data_min, data_max))


In [None]:
# Create DataLoaders
batch_size = 16

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=torch.cuda.is_available()
)

valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=torch.cuda.is_available()
)

# Check a batch
sample_batch = next(iter(train_loader))
print(f"Batch shape: {sample_batch.shape}")  # Should be [batch_size, 1, H, W]


In [None]:
# Visualize a batch of training samples
import torchvision

fig, ax = plt.subplots(figsize=(12, 6))
grid = torchvision.utils.make_grid(sample_batch[:8], nrow=4, padding=2)
ax.imshow(grid[0].cpu().numpy(), cmap='coolwarm')
ax.set_title("Sample Training Batch (8 random days)")
ax.axis('off')
plt.tight_layout()
plt.show()


## 4. Model and Scheduler Setup

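We'll use a UNet2D model from diffusers with `HeunDiscreteScheduler`, which implements the second-order Heun sampler from Karras et al., "Elucidating the Design Space of Diffusion-Based Generative Models" (EDM), over noise levels $\sigma_t$ derived from its discrete beta schedule. The scheduler noises data as $x_t = x_0 + \sigma_t \epsilon$, and the model learns to predict the noise $\epsilon$ at each timestep.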


In [None]:
# Determine image size from data
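_, _, img_h, img_w = sample_batch.shape
# sample_size is stored in the model config; the convolutional UNet itself accepts
# rectangular inputs, but H and W should each be divisible by 8 (three downsamplings)
img_size = max(img_h, img_w)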
print(f"Image size: {img_size} (height={img_h}, width={img_w})")
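
With four blocks, `UNet2DModel` downsamples three times, so the skip connections only line up when height and width are divisible by 2**3 = 8. A regional subset of the roughly 1.9-degree NCEP Gaussian grid usually is not, so check the sizes printed above. A minimal sketch, assuming replicate padding at the bottom and right edges is acceptable for your use (the helpers `pad_to_multiple` and `crop_to` are hypothetical): pad each batch before it reaches the model and crop outputs back afterwards.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x, multiple=8):
    """Replicate-pad the bottom/right of a [B, C, H, W] tensor so H and W divide by `multiple`."""
    h, w = x.shape[-2:]
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    # F.pad takes (left, right, top, bottom) for the last two dimensions
    return F.pad(x, (0, pad_w, 0, pad_h), mode="replicate"), (h, w)


def crop_to(x, size):
    """Undo pad_to_multiple by cropping back to the original (H, W)."""
    h, w = size
    return x[..., :h, :w]


x_demo = torch.randn(2, 1, 13, 30)          # hypothetical odd-sized batch
x_pad, orig_size = pad_to_multiple(x_demo)
print(x_pad.shape)                          # torch.Size([2, 1, 16, 32])
```

In practice you would pad `x` before noising in the training and validation loops, draw the initial sampling noise at the padded shape, and `crop_to` the final samples before denormalizing.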


In [None]:
# Initialize noise scheduler (second-order Heun sampler, as used in the EDM paper)
num_train_timesteps = 1000
noise_scheduler = HeunDiscreteScheduler(num_train_timesteps=num_train_timesteps)

print(f"Scheduler: {type(noise_scheduler).__name__}")
print(f"Number of training timesteps: {num_train_timesteps}")
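
`add_noise` in this scheduler uses the variance-exploding form $x_t = x_0 + \sigma_t \epsilon$ with $\sigma_t = \sqrt{(1-\bar\alpha_t)/\bar\alpha_t}$. The sketch below computes $\sigma_t$ assuming the diffusers defaults for `HeunDiscreteScheduler` (a linear beta schedule from 0.00085 to 0.012); print `noise_scheduler.config` to confirm yours. It shows why the high-timestep panels in the noising plot that follows saturate: $\sigma_t$ grows far beyond the [-1, 1] data range.

```python
import numpy as np

# Assumption: default linear beta schedule; check noise_scheduler.config for your values
num_train_timesteps = 1000
betas = np.linspace(0.00085, 0.012, num_train_timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)

# Noise level applied by add_noise(): x_t = x_0 + sigma_t * eps
sigmas = np.sqrt((1.0 - alphas_cumprod) / alphas_cumprod)

for t in (0, 250, 500, 999):
    print(f"t={t:4d}  sigma_t={sigmas[t]:.3f}")
```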


In [None]:
# Visualize the noising process
fig, axs = plt.subplots(1, 4, figsize=(16, 4))

sample = sample_batch[0:1].to(device)  # [1, 1, H, W]

# Map to [-1, 1], the data range used for training below
sample = sample * 2.0 - 1.0

# Show original
axs[0].imshow(sample[0, 0].cpu().numpy(), cmap='coolwarm', vmin=-1, vmax=1)
axs[0].set_title('Original (t=0)')
axs[0].axis('off')

# Add noise at different timesteps
for i, timestep in enumerate([250, 500, 999]):
    noise = torch.randn_like(sample)
    timesteps = torch.tensor([timestep], device=device)
    noisy = noise_scheduler.add_noise(sample, noise, timesteps)
    
    axs[i+1].imshow(noisy[0, 0].cpu().numpy(), cmap='coolwarm', vmin=-1, vmax=1)
    axs[i+1].set_title(f'Noisy (t={timestep})')
    axs[i+1].axis('off')

plt.tight_layout()
plt.show()


In [None]:
# Initialize UNet2D model
# A DDPM-style UNet; channel widths and attention placement are starting points, not tuned values
model = UNet2DModel(
    sample_size=img_size,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(64, 128, 256, 512),
    down_block_types=(
        "DownBlock2D",
        "AttnDownBlock2D",  # Attention layers help with spatial relationships
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
    ),
    time_embedding_type="positional",  # sinusoidal timestep embedding; "fourier" targets sigma-conditioned score models
).to(device)

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Model initialized with {num_params:,} parameters")
print(f"Model architecture: {type(model).__name__}")


## 5. Training Loop

The training objective is to predict the noise that was added to the data:
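$$L = \|\epsilon - \epsilon_\theta(x_t, t)\|^2$$

where $x_t = x_0 + \sigma_t \epsilon$ is the noisy input at timestep $t$ (the form produced by `HeunDiscreteScheduler.add_noise`), $\epsilon \sim \mathcal{N}(0, I)$ is the true noise, and $\epsilon_\theta$ is the model prediction. The loss is averaged over training fields $x_0$, random timesteps $t$, and noise draws.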
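
The training loop in this notebook feeds the output of `add_noise` to the model unscaled. diffusers sampling pipelines usually divide the sample by $\sqrt{\sigma_t^2 + 1}$ via `scheduler.scale_model_input` so the network sees unit-scale inputs. If you adopt that preconditioning, apply it identically in training (computing $\sigma_t$ from `noise_scheduler.alphas_cumprod`) and in sampling. The check below, on synthetic arrays with an arbitrary $\bar\alpha$, shows that the scaled input equals the familiar DDPM form $\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$.

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(4, 4))   # a field already mapped to [-1, 1]
eps = rng.standard_normal((4, 4))
alpha_bar = 0.3                             # arbitrary cumulative alpha in (0, 1)

sigma = np.sqrt((1.0 - alpha_bar) / alpha_bar)
x_t = x0 + sigma * eps                      # what add_noise() returns

# Input scaling as in scale_model_input(): x_t / sqrt(sigma^2 + 1)
scaled = x_t / np.sqrt(sigma**2 + 1.0)

# DDPM-style noising of the same x0 and eps
ddpm = np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * eps
print(np.allclose(scaled, ddpm))  # True
```

The identity holds because $\sigma_t^2 + 1 = 1/\bar\alpha_t$, so either parameterization trains the same epsilon predictor up to input scale.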


In [None]:
# Training configuration
num_epochs = 20
learning_rate = 1e-4
weight_decay = 1e-2

# Loss function
loss_fn = nn.MSELoss()

# Optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay
)

# Learning rate scheduler (cosine annealing)
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_training_steps,
    eta_min=1e-6
)

print(f"Training for {num_epochs} epochs")
print(f"Steps per epoch: {len(train_loader)}")
print(f"Total training steps: {num_training_steps}")


In [None]:
# Training loop with validation
train_losses = []
val_losses = []

best_val_loss = float('inf')
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    # Training phase
    model.train()
    epoch_train_losses = []
    
    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
    for batch in train_pbar:
        x = batch.to(device)  # [B, 1, H, W]
        
        # Map to [-1, 1] for EDM-style training
        x = x * 2.0 - 1.0
        
        # Sample random timesteps
        timesteps = torch.randint(
            0,
            noise_scheduler.config.num_train_timesteps,
            (x.shape[0],),
            device=device
        ).long()
        
        # Add noise
        noise = torch.randn_like(x)
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
        
        # Predict noise
        pred_noise = model(noisy_x, timesteps).sample
        
        # Compute loss
        loss = loss_fn(pred_noise, noise)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
        optimizer.step()
        lr_scheduler.step()
        
        epoch_train_losses.append(loss.item())
        train_pbar.set_postfix({'loss': f"{loss.item():.4f}"})
    
    avg_train_loss = float(np.mean(epoch_train_losses))  # plain float, so the checkpoint holds no NumPy scalars
    train_losses.append(avg_train_loss)
    
    # Validation phase
    model.eval()
    epoch_val_losses = []
    
    with torch.no_grad():
        val_pbar = tqdm(valid_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Valid]")
        for batch in val_pbar:
            x = batch.to(device)
            x = x * 2.0 - 1.0
            
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (x.shape[0],),
                device=device
            ).long()
            
            noise = torch.randn_like(x)
            noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
            pred_noise = model(noisy_x, timesteps).sample
            loss = loss_fn(pred_noise, noise)
            
            epoch_val_losses.append(loss.item())
            val_pbar.set_postfix({'loss': f"{loss.item():.4f}"})
    
    avg_val_loss = float(np.mean(epoch_val_losses))
    val_losses.append(avg_val_loss)
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{num_epochs}:")
    print(f"  Train Loss: {avg_train_loss:.6f}")
    print(f"  Val Loss: {avg_val_loss:.6f}")
    
    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': avg_val_loss,
            'train_losses': train_losses,
            'val_losses': val_losses,
        }, checkpoint_path)
        print(f"  ✓ Saved best model (val_loss={avg_val_loss:.6f})")
    
    print('-' * 50)


In [None]:
# Plot training curves
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses, label='Train Loss', marker='o')
ax.plot(val_losses, label='Validation Loss', marker='s')
ax.set_xlabel('Epoch')
ax.set_ylabel('MSE Loss')
ax.set_title('Training Progress')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
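
The validation loss above draws fresh random timesteps and noise every epoch, so part of its epoch-to-epoch variation is sampling noise rather than model change, which matters because it decides which checkpoint is saved. A sketch of one remedy, assuming you are happy to fix the validation draws: seed a dedicated generator per batch (for example with the batch index) so every epoch sees identical draws. The helper `validation_noise` is hypothetical.

```python
import torch


def validation_noise(x, num_train_timesteps, seed=0):
    """Draw timesteps and noise for a validation batch from a fixed-seed CPU generator."""
    g = torch.Generator(device="cpu").manual_seed(seed)
    timesteps = torch.randint(0, num_train_timesteps, (x.shape[0],), generator=g)
    noise = torch.randn(x.shape, generator=g, dtype=x.dtype)
    return timesteps.to(x.device), noise.to(x.device)


# Same seed, same draws: repeated calls are reproducible across epochs
x_val = torch.zeros(4, 1, 16, 32)
t1, n1 = validation_noise(x_val, 1000, seed=3)
t2, n2 = validation_noise(x_val, 1000, seed=3)
print(torch.equal(t1, t2) and torch.equal(n1, n2))  # True
```

Inside the validation loop this would replace the `torch.randint` and `torch.randn_like` calls, e.g. `for i, batch in enumerate(valid_loader):` followed by `timesteps, noise = validation_noise(x, noise_scheduler.config.num_train_timesteps, seed=i)`.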


## 6. Sampling from the Trained Model

Now we'll generate new samples by starting from pure noise and iteratively denoising using the trained model.


In [None]:
# Load best model
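# weights_only=False: this is our own trusted checkpoint, which stores loss histories next to the state dicts
checkpoint = torch.load(
    os.path.join(checkpoint_dir, 'best_model.pth'),
    map_location=device,
    weights_only=False,
)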
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(f"Loaded model from epoch {checkpoint['epoch']+1} with val_loss={checkpoint['val_loss']:.6f}")


In [None]:
# Generate samples
num_inference_steps = 50  # Number of denoising steps (can be reduced for faster sampling)
num_samples = 4

# Set timesteps for inference
noise_scheduler.set_timesteps(num_inference_steps, device=device)

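# Start from noise scaled to the scheduler's largest noise level
sample_shape = (num_samples, 1, img_h, img_w)
samples = torch.randn(sample_shape, device=device) * noise_scheduler.init_noise_sigma

# Denoising loop. Heun is second order, so noise_scheduler.timesteps lists most
# timesteps twice and the loop makes roughly two model calls per inference step.
# The training loop feeds the unscaled noisy sample to the model, so
# scale_model_input() is deliberately not applied here.
with torch.no_grad():
    for t in tqdm(noise_scheduler.timesteps, desc="Sampling"):
        # Expand timestep to batch size
        timestep = t.expand(samples.shape[0])
        
        # Predict noise
        noise_pred = model(samples, timestep).sample
        
        # Step scheduler
        samples = noise_scheduler.step(noise_pred, t, samples).prev_sample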

# Map back to [0, 1] and then to original scale
samples = (samples + 1.0) / 2.0  # [-1, 1] -> [0, 1]
samples = samples * (data_max - data_min) + data_min  # Denormalize

print(f"Generated {num_samples} samples")
print(f"Sample value range: [{samples.min():.2f}, {samples.max():.2f}]")
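
Why two model evaluations per step? Heun's method takes an Euler step, re-evaluates the slope at the predicted point, and averages the two slopes. The toy below, which is not the diffusers implementation, integrates the probability-flow ODE $dx/d\sigma = \hat\epsilon(x, \sigma)$ for 1-D Gaussian data $\mathcal{N}(0, s^2)$, where the ideal noise prediction is known in closed form, $\hat\epsilon = \sigma x / (s^2 + \sigma^2)$, and so is the exact endpoint. The geometric sigma grid and starting point are arbitrary choices; the final step down to $\sigma = 0$ stays an Euler step, as in the EDM paper's Heun sampler.

```python
import numpy as np

S = 1.0  # std of the toy 1-D Gaussian "data"


def eps_hat(x, sigma):
    """Exact noise prediction for data ~ N(0, S^2) noised as x = x0 + sigma * eps."""
    return sigma * x / (S**2 + sigma**2)


def integrate(x_start, sigmas, heun=True):
    """Integrate dx/dsigma = eps_hat(x, sigma) from sigmas[0] down to sigmas[-1] = 0."""
    x = x_start
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        d = eps_hat(x, s_cur)                  # slope at the current noise level
        x_euler = x + (s_next - s_cur) * d     # first-order (Euler) prediction
        if heun and s_next > 0:
            d_next = eps_hat(x_euler, s_next)  # second evaluation at the predicted point
            x = x + (s_next - s_cur) * 0.5 * (d + d_next)
        else:
            x = x_euler                        # last step to sigma = 0 stays Euler
    return x


sigma_max = 25.0
sigmas = np.append(np.geomspace(sigma_max, 0.01, 18), 0.0)
x_start = 2.0 * sigma_max
exact = x_start * S / np.sqrt(S**2 + sigma_max**2)  # x scales like sqrt(S^2 + sigma^2)

err_euler = abs(integrate(x_start, sigmas, heun=False) - exact)
err_heun = abs(integrate(x_start, sigmas, heun=True) - exact)
print(f"Euler error: {err_euler:.4f}, Heun error: {err_heun:.4f}")
```

For the same number of sigma intervals Heun costs about twice the model calls but is second-order accurate, which is the trade-off behind `num_inference_steps`.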


In [None]:
# Visualize generated samples
fig, axs = plt.subplots(2, 2, figsize=(12, 12))
axs = axs.flatten()

for i in range(num_samples):
    im = axs[i].imshow(samples[i, 0].cpu().numpy(), cmap='coolwarm')
    axs[i].set_title(f'Generated Sample {i+1}')
    axs[i].axis('off')
    plt.colorbar(im, ax=axs[i], fraction=0.046)

plt.tight_layout()
plt.show()


In [None]:
# Compare with real data
fig, axs = plt.subplots(2, 4, figsize=(16, 8))

# Real samples from the test set (each item is [1, H, W]; stack to [N, 1, H, W])
real_samples = [test_dataset[i] for i in range(num_samples)]
real_samples_tensor = torch.stack(real_samples) * (data_max - data_min) + data_min

for i in range(num_samples):
    # Generated
    im1 = axs[0, i].imshow(samples[i, 0].cpu().numpy(), cmap='coolwarm')
    axs[0, i].set_title(f'Generated {i+1}')
    axs[0, i].axis('off')
    plt.colorbar(im1, ax=axs[0, i], fraction=0.046)
    
    # Real
    im2 = axs[1, i].imshow(real_samples_tensor[i, 0].cpu().numpy(), cmap='coolwarm')
    axs[1, i].set_title(f'Real {i+1}')
    axs[1, i].axis('off')
    plt.colorbar(im2, ax=axs[1, i], fraction=0.046)

plt.suptitle('Generated vs Real Samples', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()


## 7. Next Steps

Congratulations! You've trained a diffusion model on weather data. Here are some next steps:

1. **Generative Data Assimilation**: Use the trained model with observations to improve initial conditions. See the `02_generative_data_assimilation.ipynb` notebook.

2. **Experiment with Different Regions**: Train models on different regions (Europe, global, etc.) or different variables (precipitation, winds, etc.).

3. **Architecture Improvements**: Try different UNet configurations (depth, channel widths, attention placement) or experiment with Diffusion Transformers (DiT).

4. **Distillation**: Use scheduler-only distillation to reduce inference timesteps for faster sampling.

5. **Evaluation Metrics**: Add quantitative evaluation (e.g., distribution statistics, spatial correlation, spectral properties).

6. **High-Resolution Models**: Extend to higher resolutions using latent diffusion models or tile-based approaches.
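
For the evaluation-metrics item, a minimal starting point is to compare summary statistics and a zonal (east-west) power spectrum of generated fields against held-out real fields. The sketch below runs on synthetic arrays; the function names are our own. For the real comparison you might pass `samples[:, 0].cpu().numpy()` and the denormalized test fields. The FFT treats each latitude row as periodic, which a regional domain is not, so consider detrending or windowing the rows first.

```python
import numpy as np


def field_stats(fields):
    """Summary statistics for a stack of 2-D fields with shape [N, H, W]."""
    return {
        "mean": float(fields.mean()),
        "std": float(fields.std()),
        # Spatial variability within each field, averaged over the stack
        "mean_spatial_std": float(fields.std(axis=(1, 2)).mean()),
    }


def zonal_power_spectrum(fields):
    """Power along longitude (last axis), averaged over fields and latitude rows."""
    anomalies = fields - fields.mean(axis=-1, keepdims=True)  # remove each row's zonal mean
    power = np.abs(np.fft.rfft(anomalies, axis=-1)) ** 2
    return power.mean(axis=(0, 1))                          # one value per zonal wavenumber


rng = np.random.default_rng(0)
real = 280.0 + 10.0 * rng.standard_normal((8, 16, 32))       # stand-ins for real fields (K)
generated = 280.0 + 10.0 * rng.standard_normal((8, 16, 32))  # stand-ins for generated fields (K)

print(field_stats(real))
print(field_stats(generated))
spec_real, spec_gen = zonal_power_spectrum(real), zonal_power_spectrum(generated)
print(spec_real.shape)  # (17,)
```

Plotting `spec_real` and `spec_gen` on log axes is a quick way to spot generated fields that are too smooth (missing power at high wavenumbers) or too noisy.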

## Using the deepassimilate Library

You can also use the `deepassimilate` library's training function for a more streamlined workflow:

```python
import deepassimilate as da

cfg = da.UncondTrainConfig(
    architecture="edm_unet_2d",
    scheduler="heun_edm",
    img_size=64,
    channels=1,
    num_epochs=20,
    batch_size=16,
    lr=1e-4,
    device="cuda"
)

model, scheduler, distilled_steps = da.train_unconditional(
    dataset=train_dataset,
    cfg=cfg
)
```
