## CBottle Video Conditional Animation (Memory Optimized)

Create conditional weather forecast animations using CBottleVideo with ERA5 data.

**Key optimization:** Process initialization times sequentially rather than as a single batch to avoid out-of-memory (OOM) errors.

## Import packages

In [None]:
import torch
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.colors
import numpy as np
import pandas as pd
import os
import gc
from datetime import datetime, timedelta
import cartopy.crs as ccrs

from earth2studio.data import WB2ERA5
from earth2studio.lexicon import WB2Lexicon, CBottleLexicon
from earth2studio.models.dx import CBottleInfill
from earth2studio.models.px import CBottleVideo
from earth2studio.data.utils import fetch_data

## Configuration

In [None]:
# Runtime setup: output directory and compute device
os.makedirs("outputs", exist_ok=True)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# User-tunable parameters - CONFIGURE THESE
VIDEO_VARIABLE = "t850"  # variable rendered in the animation
N_FRAMES = 12  # frames kept per initialization (0-66 hours in 6-hour steps)
SEED = 42  # base seed; the processing loop offsets it per initialization
INTERPOLATE = True  # mask middle timesteps for interpolation

# First initialization time, given as datetime components
date_dict = dict(year=2022, month=6, day=1, hour=0, minute=0, second=0)

## Generate Time Array

We'll generate 12 initialization times, but **process them one at a time** to avoid OOM errors.

In [None]:
def generate_datetime_array(year, month, day, hour=0, minute=0, second=0, n_times=12, step_hours=6):
    """
    Generate a numpy array of evenly spaced timestamps.

    Parameters
    ----------
    year, month, day, hour, minute, second : int
        Starting datetime components. Invalid combinations raise
        ``ValueError`` from ``datetime``.
    n_times : int
        Number of timestamps to generate. Zero or a negative value yields
        an empty array.
    step_hours : int or float
        Spacing between consecutive timestamps, in hours. Defaults to 6,
        which reproduces the original fixed 6-hour spacing.

    Returns
    -------
    numpy.ndarray
        1-D array of ``datetime64[ns]`` values; element ``i`` equals the start
        time plus ``i * step_hours`` hours.
    """
    start_time = datetime(year, month, day, hour, minute, second)
    times = [start_time + timedelta(hours=step_hours * i) for i in range(n_times)]
    return np.array(times, dtype="datetime64[ns]")

# Generate initialization times.
# Unpack every entry of date_dict (its keys match the function's parameter
# names) so that a non-zero hour/minute/second in the configuration is
# honoured instead of being silently dropped.
init_times = generate_datetime_array(**date_dict, n_times=12)
print(f"Processing {len(init_times)} initialization times:")
print(init_times)
print("\nNote: Processing ONE at a time to avoid OOM errors")

## Determine Available ERA5 Variables

In [None]:
print("Determining available ERA5 variables...")
# Variable names known to each lexicon
cbottle_vars = list(CBottleLexicon.VOCAB.keys())
wb2_vars = set(WB2Lexicon.VOCAB.keys())
# Sorted list of the names present in both vocabularies
available_in_era5 = sorted(wb2_vars.intersection(cbottle_vars))

print(f"Using {len(available_in_era5)} ERA5 variables for conditioning")

## Load Models

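Load the ERA5 data source and the CBottleInfill package. CBottleVideo is loaded in the next section, and CBottleInfill is re-instantiated from its package for each initialization time so that it can be freed before the video model runs.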

In [None]:
# ERA5 data source used to fetch the conditioning fields
print("Loading data sources and models...")
era5_ds = WB2ERA5()

# Build CBottleInfill from its default package, taking the ERA5-available
# variables as inputs, and place it on the selected device.
package_infill = CBottleInfill.load_default_package()
cbottle_infill = CBottleInfill.load_model(
    package_infill, input_variables=available_in_era5, sampler_steps=18
).to(device)
cbottle_infill.set_seed(SEED)
print("✓ CBottleInfill loaded")

## Process Each Initialization Time

**Critical:** We process one initialization time at a time to avoid GPU memory overflow.

For each initialization time:
1. Fetch ERA5 data
2. Run CBottleInfill
3. Generate forecast with CBottleVideo
4. Clean up GPU memory
5. Move to next time

In [None]:
# Storage for all results: one list of frames / coordinate dicts per
# initialization time, filled by the processing loop.
all_outputs = []
all_coords = []

# Free CBottleInfill before loading CBottleVideo.
# Run the garbage collector first: objects kept alive only by reference
# cycles are released by gc.collect(), and empty_cache() can only hand back
# memory that is already unreferenced.
del cbottle_infill
gc.collect()
torch.cuda.empty_cache()
print("✓ Freed CBottleInfill from GPU\n")

# Load CBottleVideo once; it is reused for every initialization time
package_video = CBottleVideo.load_default_package()
cbottle_video = CBottleVideo.load_model(package_video, seed=SEED)
cbottle_video = cbottle_video.to(device)
print("✓ CBottleVideo loaded\n")

# Process each initialization time sequentially.
# Per iteration: fetch ERA5 -> infill -> run CBottleVideo -> store frames on
# the CPU -> release everything that may still reference GPU memory.
for init_idx, init_time in enumerate(init_times):
    print("=" * 60)
    print(f"Processing initialization time {init_idx+1}/{len(init_times)}")
    print(f"Time: {pd.Timestamp(init_time)}")
    print("=" * 60)

    # Wrap the single time in a length-1 array for fetch_data and the coords
    times_single = np.array([init_time], dtype="datetime64[ns]")

    # Fetch ERA5 data for this single time
    print("  Fetching ERA5 data...")
    era5_x, era5_coords = fetch_data(era5_ds, times_single, available_in_era5, device=device)
    print(f"  ERA5 shape: {era5_x.shape}")

    # Re-instantiate CBottleInfill for this iteration only, so it can be
    # freed again before the video model runs.
    cbottle_infill_temp = CBottleInfill.load_model(
        package_infill,
        input_variables=available_in_era5,
        sampler_steps=18
    )
    cbottle_infill_temp = cbottle_infill_temp.to(device)
    cbottle_infill_temp.set_seed(SEED + init_idx)  # Different seed per init time

    # Run infilling
    print("  Running CBottleInfill...")
    infilled_x, infilled_coords = cbottle_infill_temp(era5_x, era5_coords)
    print(f"  Infilled shape: {infilled_x.shape}")

    # Free infill model and ERA5 data
    del cbottle_infill_temp, era5_x, era5_coords
    torch.cuda.empty_cache()

    # Prepare input for CBottleVideo: add a leading batch axis when missing
    if infilled_x.ndim == 5:  # [time, lead_time, variable, lat, lon]
        x_cond = infilled_x.unsqueeze(0)  # [batch, time, lead_time, variable, lat, lon]
    else:
        x_cond = infilled_x

    print(f"  Conditional input shape: {x_cond.shape}")

    # Optional: keep only the first and last timestep as conditioning and
    # mask everything in between with NaN.
    if INTERPOLATE and x_cond.shape[1] > 2:
        x_cond_masked = x_cond.clone()
        x_cond_masked[:, 1:-1, :, :, :, :] = float('nan')
        del x_cond
        x_cond = x_cond_masked
        print("  Applied interpolation masking")
    elif INTERPOLATE:
        # Make the skip visible: with one time per iteration there are no
        # middle timesteps to mask.
        print(
            f"  Skipped interpolation masking "
            f"(needs > 2 timesteps, got {x_cond.shape[1]})"
        )

    # Setup coordinates
    coords_cond = cbottle_video.input_coords()
    coords_cond["time"] = times_single
    coords_cond["batch"] = np.array([0])
    coords_cond["variable"] = infilled_coords["variable"]

    # Run CBottleVideo inference
    print("  Running CBottleVideo...")
    cbottle_video.set_seed(SEED + init_idx)
    iterator = cbottle_video.create_iterator(x_cond, coords_cond)

    # Drop this cell's own references to the inputs. Nothing is copied to the
    # CPU here; the iterator keeps whatever it still needs.
    del x_cond, infilled_x, infilled_coords
    torch.cuda.empty_cache()

    # Collect outputs for this initialization
    init_outputs = []
    init_coords = []

    for step, (output, output_coords) in enumerate(iterator):
        lead_time = output_coords["lead_time"][0]
        hours = int(lead_time / np.timedelta64(1, "h"))

        # Move to CPU immediately so stored frames do not occupy GPU memory
        init_outputs.append(output.cpu())
        init_coords.append(output_coords)

        del output
        if step % 3 == 0:
            torch.cuda.empty_cache()

        print(f"    Frame {step}: +{hours}h")

        if step >= N_FRAMES - 1:
            break

    # Store results
    all_outputs.append(init_outputs)
    all_coords.append(init_coords)

    # Cleanup. The iterator was left mid-iteration by the break above and may
    # still reference the conditioning tensor and other GPU state, so drop it
    # before collecting garbage and emptying the CUDA cache.
    del iterator
    gc.collect()
    torch.cuda.empty_cache()
    print(f"  ✓ Completed initialization time {init_idx+1}\n")

# Free CBottleVideo. Collect garbage before emptying the CUDA cache so that
# memory released by the collector is also returned.
del cbottle_video
gc.collect()
torch.cuda.empty_cache()
print("\n" + "="*60)
print("✓ All forecasts completed!")
# Guard the frame count so an empty run (no initialization times) reports 0
# instead of raising IndexError.
print(f"Generated {len(all_outputs)} initialization times × {len(all_outputs[0]) if all_outputs else 0} forecast frames")
print("="*60)

## Create Animation

Now create an animated visualization of all the forecasts.

In [None]:
print(f"Creating animation for variable: {VIDEO_VARIABLE}")

# Coordinates of the very first stored frame; reused for the variable lookup,
# the mesh axes and the initial title.
first_coords = all_coords[0][0]

# Position of VIDEO_VARIABLE along the variable axis
var_names = first_coords["variable"]
try:
    var_idx = np.where(var_names == VIDEO_VARIABLE)[0][0]
except IndexError:
    print(f"Error: Variable '{VIDEO_VARIABLE}' not found!")
    print(f"Available variables: {list(var_names)}")
    raise

# Dark-themed figure with an orthographic globe view
plt.style.use("dark_background")
fig = plt.figure(figsize=(12, 8))
projection = ccrs.Orthographic(central_longitude=0.0, central_latitude=45.0)
ax = fig.add_subplot(111, projection=projection)

# Colour limits come from the frames of the first initialization time only
data_min = min(frame[0, 0, 0, var_idx].min() for frame in all_outputs[0])
data_max = max(frame[0, 0, 0, var_idx].max() for frame in all_outputs[0])
norm = matplotlib.colors.Normalize(vmin=data_min, vmax=data_max)

# Draw the first frame; later frames only swap the mesh's data array
data_first = all_outputs[0][0][0, 0, 0, var_idx].numpy()
img = ax.pcolormesh(
    first_coords["lon"],
    first_coords["lat"],
    data_first,
    transform=ccrs.PlateCarree(),
    cmap="viridis",
    norm=norm,
)
ax.coastlines()
ax.gridlines()
plt.colorbar(
    img,
    ax=ax,
    orientation="horizontal",
    shrink=0.5,
    pad=0.05,
    label=VIDEO_VARIABLE,
)

# Initial title: lead time in hours plus the valid time (init time + lead time)
lead_time = first_coords["lead_time"][0]
hours = int(lead_time / np.timedelta64(1, "h"))
valid_time = pd.Timestamp(first_coords["time"][0] + lead_time)
time_str = valid_time.strftime("%Y-%m-%d %H:%M")
title = ax.set_title(
    f"Conditional: {VIDEO_VARIABLE} Init 1/{len(all_outputs)} +{hours:03d}h ({time_str})"
)
fig.tight_layout()


# Flat lookup from animation frame number to (initialization, forecast frame),
# built from the frames actually stored. This keeps the animation valid even
# if an initialization produced fewer than N_FRAMES frames.
frame_index = [
    (i, j)
    for i, init_frames in enumerate(all_outputs)
    for j in range(len(init_frames))
]


def update(global_frame):
    """Redraw the mesh and title for one animation frame.

    Parameters
    ----------
    global_frame : int
        Frame number across all initializations, in the range
        ``0 .. len(frame_index) - 1``.

    Returns
    -------
    list
        The artists that were modified (mesh and title), as required by
        ``FuncAnimation`` when ``blit=True``.
    """
    # Map the global frame number to its initialization and forecast frame
    init_idx, frame_idx = frame_index[global_frame]
    frame_coords = all_coords[init_idx][frame_idx]

    data = all_outputs[init_idx][frame_idx][0, 0, 0, var_idx].numpy()
    # The mesh is updated with a flattened array
    img.set_array(data.ravel())

    lead_time = frame_coords["lead_time"][0]
    hours = int(lead_time / np.timedelta64(1, "h"))
    time_str = pd.Timestamp(
        frame_coords["time"][0] + lead_time
    ).strftime("%Y-%m-%d %H:%M")
    title.set_text(
        f"Conditional: {VIDEO_VARIABLE} Init {init_idx+1}/{len(all_outputs)} +{hours:03d}h ({time_str})"
    )
    return [img, title]


# Create and save animation
total_frames = len(frame_index)
print(f"Rendering {total_frames} total frames...")
anim = animation.FuncAnimation(
    fig, update, frames=total_frames, interval=500, blit=True
)

output_file = f"outputs/conditional_video_{VIDEO_VARIABLE}_fixed.mp4"
writer = animation.FFMpegWriter(fps=2)
print(f"Saving to {output_file}...")
anim.save(output_file, writer=writer, dpi=100)
# Close this specific figure rather than whichever figure is current
plt.close(fig)

print(f"\n✓ Animation saved to {output_file}")
print("Memory-efficient execution complete!")

## Memory Usage Summary

**Original approach (OOM):**
- Loaded all 12 initialization times at once
- Shape: `[1, 12, 1, 45, 721, 1440]`
- Generated 12 × 12 = 144 frames simultaneously
- Memory: ~25+ GB just for outputs

**Fixed approach (Success):**
- Process one initialization time at a time
- Shape per iteration: `[1, 1, 1, 45, 721, 1440]`
- Generate 12 frames per iteration
- Memory: ~2-3 GB peak per iteration
- Clean up GPU memory between iterations