# Part 3: Weather Foundation Model with Prithvi WxC

## Introduction

In this part of the workshop, we will explore the concept of **weather foundation models** and use a state-of-the-art model called **Prithvi WxC** for a downscaling task.

### What is a Weather Foundation Model?

A **foundation model** in the context of weather and climate is a large AI model pre-trained on vast amounts of meteorological data (such as ERA5 or MERRA-2 reanalysis). Traditional numerical weather prediction (NWP) models numerically solve the physical equations that govern the atmosphere. These AI models instead learn atmospheric patterns and dynamics directly from data. They are designed to be:

*   **General-purpose**: Once pre-trained, they can be fine-tuned for various downstream tasks such as forecasting, downscaling, or gravity wave parameterization.
*   **Efficient**: Inference is typically much faster than running high-resolution physical simulations.

### Prithvi WxC

**Prithvi WxC** is a 2.3-billion-parameter foundation model developed by IBM and NASA. It is based on a **Vision Transformer (ViT)** architecture with an encoder-decoder structure. It treats weather data as a sequence of tokens (patches of the gridded input), which allows it to capture both local and global interactions in the atmosphere.
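To make the idea of tokens concrete, here is a minimal NumPy sketch of ViT-style patchification: a gridded field with `C` variables is cut into non-overlapping `p × p` patches, and each patch is flattened into one token vector. The function name `patchify`, the patch size, and the grid shape are illustrative assumptions; Prithvi WxC's actual tokenization and embedding layers are defined in its own code.

```python
import numpy as np


def patchify(field: np.ndarray, patch: int) -> np.ndarray:
    """Split a (C, H, W) grid into flattened, non-overlapping patches.

    Returns an array of shape (num_tokens, C * patch * patch), where
    num_tokens = (H // patch) * (W // patch). H and W must be divisible by patch.
    """
    c, h, w = field.shape
    if h % patch or w % patch:
        raise ValueError("H and W must be divisible by the patch size")
    # (C, H/p, p, W/p, p) -> (H/p, W/p, C, p, p): group pixels by patch position
    blocks = field.reshape(c, h // patch, patch, w // patch, patch)
    blocks = blocks.transpose(1, 3, 0, 2, 4)
    # One row per patch (token), one column per (channel, pixel) value
    return blocks.reshape(-1, c * patch * patch)


# Hypothetical example: 3 variables on a 12 x 16 grid, split into 4 x 4 patches
field = np.arange(3 * 12 * 16, dtype=float).reshape(3, 12, 16)
tokens = patchify(field, patch=4)
print(tokens.shape)  # (12, 48): 3 * 4 = 12 tokens, each with 3 * 4 * 4 = 48 values
```

Tokens are ordered row by row across the grid, so `tokens[0]` holds the top-left patch of every variable and `tokens[1]` the patch to its right.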

### The Task: Downscaling

**Downscaling** is the process of generating high-resolution weather data from low-resolution inputs. This is crucial for local impact assessments that need detailed information (e.g., city-level temperature), which global models with grid spacings of roughly 50 to 100 km cannot provide.

In this notebook, we will use a fine-tuned version of Prithvi WxC to downscale **2-meter temperature (T2M)** from the MERRA-2 dataset, increasing its spatial resolution by a factor of **6**.
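As a rough illustration of what a factor of 6 means, the snippet below works out the grid spacings involved. It assumes the target is MERRA-2 on its native 0.5° (latitude) × 0.625° (longitude) grid and that the low-resolution input is that grid coarsened by 6 in both directions; the actual spatial extent handled by the fine-tuned model is set by its configuration.

```python
# Illustrative arithmetic for a 6x downscaling factor.
# Assumption: the target is native MERRA-2 (0.5 deg lat x 0.625 deg lon) and the
# low-resolution input is that grid coarsened by FACTOR along both axes.
FACTOR = 6
target_spacing = (0.5, 0.625)  # degrees (lat, lon)
input_spacing = tuple(s * FACTOR for s in target_spacing)

print(f"input spacing:  {input_spacing[0]} x {input_spacing[1]} deg")
print(f"target spacing: {target_spacing[0]} x {target_spacing[1]} deg")

# Each coarse input cell is covered by FACTOR x FACTOR high-resolution target cells
cells_per_input_cell = FACTOR ** 2
print(f"target cells per input cell: {cells_per_input_cell}")
```

So the model has to produce 36 high-resolution values for every low-resolution input value it sees.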

Let's get started!

## 1. Setup and Installation

We need to install the `PrithviWxC` and `granitewxc` libraries, along with standard data handling libraries.

In [None]:
try:
    import google.colab  # only importable inside Google Colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print("Running on Google Colab. Installing dependencies...")
    !pip install h5netcdf matplotlib wget pyyaml xarray scipy torch PrithviWxC granitewxc huggingface_hub
else:
    print("Running locally. Skipping dependency installation.")
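When running locally, you can check which of the main modules are importable before continuing. This is a small sketch using the standard library's `importlib.util.find_spec`; the helper name `missing_modules` is hypothetical, and the module names are assumed to match the import statements used later in the notebook. Note that import names can differ from pip package names (for example, `pyyaml` is imported as `yaml`).

```python
import importlib.util

# Import names (not pip names) of the main packages this notebook uses.
# Assumption: PrithviWxC and granitewxc are importable under these names.
REQUIRED_MODULES = [
    "h5netcdf", "matplotlib", "yaml", "xarray", "scipy",
    "torch", "PrithviWxC", "granitewxc", "huggingface_hub",
]


def missing_modules(names):
    """Return the top-level module names that cannot be found on the current Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All required modules are available.")
```

`find_spec` only locates a module without importing it, so the check is cheap even for heavy packages such as `torch`.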

## 2. Imports and Configuration

Import the necessary modules and set up the computing device (GPU is highly recommended).

In [None]:
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
from itertools import product
import numpy as np
import torch
from torch.utils.data import DataLoader
from huggingface_hub import hf_hub_download

from granitewxc.utils.config import get_config
from granitewxc.utils.data import _get_transforms
from granitewxc.datasets.merra2 import Merra2DownscaleDataset
from granitewxc.utils.downscaling_model import get_finetune_model
from PrithviWxC.dataloaders.merra2 import SampleSpec

torch.jit.enable_onednn_fusion(True)
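if torch.cuda.is_available():
    # Favor reproducibility: cuDNN autotuning (benchmark=True) can select
    # nondeterministic kernels, which conflicts with deterministic=True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True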

random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
torch.manual_seed(42)
np.random.seed(42)

device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")
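The seeding above can be wrapped in a reusable helper, which makes it easy to re-seed before repeating an experiment. This is a minimal sketch; `seed_everything` is a hypothetical name, and PyTorch is seeded only if it is installed.

```python
import random

import numpy as np


def seed_everything(seed: int) -> None:
    """Seed Python's, NumPy's and (if installed) PyTorch's random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return
    # torch.manual_seed seeds the generators on all devices, including CUDA
    torch.manual_seed(seed)


# Re-seeding reproduces the same random draws
seed_everything(42)
first = (random.random(), np.random.rand())
seed_everything(42)
second = (random.random(), np.random.rand())
print(first == second)  # True
```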

## 3. Load Model Configuration

We use a `config.yaml` file to define the data variables, model parameters, and training settings. Loading it ensures the model is rebuilt with the same architecture and settings used during fine-tuning.

In [None]:
config_path = hf_hub_download(
    repo_id="ibm-granite/granite-geospatial-wxc-downscaling",
    filename="config.yaml",
    local_dir="../data",
)
config = get_config(config_path)
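The rest of the notebook reads settings through nested attribute access (for example, `config.data.input_surface_vars`). As an illustration of that pattern only, the sketch below converts a nested dictionary into nested `SimpleNamespace` objects. It is not how `granitewxc.utils.config.get_config` is implemented, and the keys and values in `example` are hypothetical placeholders rather than the contents of the real `config.yaml`.

```python
from types import SimpleNamespace


def to_namespace(obj):
    """Recursively convert dicts (including dicts inside lists) into SimpleNamespace objects."""
    if isinstance(obj, dict):
        return SimpleNamespace(**{key: to_namespace(value) for key, value in obj.items()})
    if isinstance(obj, list):
        return [to_namespace(item) for item in obj]
    return obj


# Hypothetical subset of a config, mirroring attribute names used later in this notebook
example = {
    "data": {
        "input_surface_vars": ["T2M", "U10M"],
        "input_levels": [34.0, 39.0],
        "output_vars": ["T2M"],
    },
    "model": {"input_scalers_surface_path": None},
}

cfg = to_namespace(example)
print(cfg.data.output_vars)  # ['T2M']

# Attributes can be overwritten after loading, as Section 5 does with the real config
cfg.data.data_path_surface = "../data/merra-2"
```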

## 4. Download Data and Weights

We will download sample MERRA-2 data for a single day (Jan 1, 2020), the fine-tuned weights for the downscaling model, and the climatology and scaler files used to normalize the data.

In [None]:
config.download_path = '../data'

DOWNSCALING_REPO = "ibm-granite/granite-geospatial-wxc-downscaling"
PRITHVI_REPO = "ibm-nasa-geospatial/Prithvi-WxC-1.0-2300M"

# Download the fine-tuned downscaling weights
hf_hub_download(repo_id=DOWNSCALING_REPO, filename="pytorch_model.bin", local_dir=config.download_path)

# Download sample MERRA-2 data (surface and vertical-level variables) for 2020-01-01
for filename in ["merra-2/MERRA2_sfc_20200101.nc", "merra-2/MERRA_pres_20200101.nc"]:
    hf_hub_download(repo_id=PRITHVI_REPO, filename=filename, local_dir=config.download_path)

# Download the scalers: musigma files for the inputs, anomaly variance files for the outputs
for filename in [
    "climatology/anomaly_variance_surface.nc",
    "climatology/anomaly_variance_vertical.nc",
    "climatology/musigma_surface.nc",
    "climatology/musigma_vertical.nc",
]:
    hf_hub_download(repo_id=PRITHVI_REPO, filename=filename, local_dir=config.download_path)

# Download the 3-hourly climatology mean files for day of year 001 (used for normalization)
for hour in range(0, 24, 3):
    for level_type in ("surface", "vertical"):
        hf_hub_download(
            repo_id=PRITHVI_REPO,
            filename=f"climatology/climate_{level_type}_doy001_hour{hour:02d}.nc",
            local_dir=config.download_path,
        )
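The climatology filenames downloaded above are keyed by day of year and a 3-hourly hour slot. The sketch below builds the matching surface and vertical filenames for a given timestamp. `climatology_filenames` is a hypothetical helper, and rounding the hour down to the nearest 3-hour slot is an assumption made for illustration; the dataset class selects climatology files internally.

```python
from datetime import datetime


def climatology_filenames(timestamp: datetime, step_hours: int = 3):
    """Return (surface, vertical) climatology filenames for a timestamp.

    Assumption: files are keyed by day of year (doyNNN) and an hour slot
    rounded down to a multiple of `step_hours`, matching the downloaded names.
    """
    doy = timestamp.timetuple().tm_yday
    hour = (timestamp.hour // step_hours) * step_hours
    suffix = f"doy{doy:03d}_hour{hour:02d}.nc"
    return (
        f"climatology/climate_surface_{suffix}",
        f"climatology/climate_vertical_{suffix}",
    )


print(climatology_filenames(datetime(2020, 1, 1, 7, 30)))
# ('climatology/climate_surface_doy001_hour06.nc', 'climatology/climate_vertical_doy001_hour06.nc')
```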

## 5. Prepare Dataset and Dataloader

The `Merra2DownscaleDataset` handles the data preparation:
1.  **Input**: It applies a "coarsening" transform to the original MERRA-2 data to simulate a low-resolution input.
2.  **Target**: It uses the original high-resolution MERRA-2 data as the ground truth.
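A common way to simulate a coarse input is block averaging: each non-overlapping `factor × factor` block of the high-resolution grid is replaced by its mean. The NumPy sketch below shows that idea. Whether `Merra2DownscaleDataset` coarsens exactly this way (rather than by subsampling or another pooling operation) is not stated here, so treat `coarsen_mean` as a hypothetical illustration, not the library's transform.

```python
import numpy as np


def coarsen_mean(field: np.ndarray, factor: int) -> np.ndarray:
    """Block-average a (..., H, W) array by `factor` along its last two axes.

    H and W must be divisible by `factor`; the result has shape (..., H // factor, W // factor).
    """
    *lead, h, w = field.shape
    if h % factor or w % factor:
        raise ValueError("H and W must be divisible by the coarsening factor")
    blocks = field.reshape(*lead, h // factor, factor, w // factor, factor)
    # Average over the two within-block axes
    return blocks.mean(axis=(-3, -1))


high_res = np.arange(36, dtype=float).reshape(6, 6)
low_res = coarsen_mean(high_res, factor=3)
print(low_res.tolist())  # [[7.0, 10.0], [25.0, 28.0]]
```

Leading axes such as batch and channel are preserved, so the same function applies to a `(batch, channels, H, W)` array.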

In [None]:
# Point the config at the downloaded data and climatology files
config.data.data_path_surface = os.path.join(config.download_path, 'merra-2')
config.data.data_path_vertical = os.path.join(config.download_path, 'merra-2')
config.data.climatology_path_surface = os.path.join(config.download_path, 'climatology')
config.data.climatology_path_vertical = os.path.join(config.download_path, 'climatology')

config.model.input_scalers_surface_path = os.path.join(config.download_path, 'climatology/musigma_surface.nc')
config.model.input_scalers_vertical_path = os.path.join(config.download_path, 'climatology/musigma_vertical.nc')
config.model.output_scalers_surface_path = os.path.join(config.download_path, 'climatology/anomaly_variance_surface.nc')
config.model.output_scalers_vertical_path = os.path.join(config.download_path, 'climatology/anomaly_variance_vertical.nc')

# Restrict the validation period to the single downloaded day
config.data.val_time_range_start = '2020-01-01T00:00:00'
config.data.val_time_range_end = '2020-01-01T23:59:59'

# Initialize Dataset
dataset = Merra2DownscaleDataset(
    time_range=(config.data.val_time_range_start, config.data.val_time_range_end),
    data_path_surface=config.data.data_path_surface,
    data_path_vertical=config.data.data_path_vertical,
    climatology_path_surface=config.data.climatology_path_surface,
    climatology_path_vertical=config.data.climatology_path_vertical,
    input_surface_vars=config.data.input_surface_vars,
    input_static_surface_vars=config.data.input_static_surface_vars,
    input_vertical_vars=config.data.input_vertical_vars,
    input_levels=config.data.input_levels,
    n_input_timestamps=config.data.n_input_timestamps,
    output_vars=config.data.output_vars,
    transforms=_get_transforms(config),
)

dataloader = DataLoader(dataset, batch_size=1)
print(f"Dataset initialized with {len(dataset)} samples.")

## 6. Initialize Model and Load Weights

We initialize the `ClimateDownscaleFinetuneModel`. This specific architecture adds an upscaling head to the core Prithvi WxC encoder.
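One widely used building block for an upscaling head is sub-pixel convolution: a layer produces `C × r²` channels at low resolution, and a *pixel shuffle* rearranges them into `C` channels at `r` times the resolution (PyTorch provides this as `torch.nn.PixelShuffle`). The NumPy sketch below reproduces that rearrangement so the index bookkeeping is visible. Whether the Prithvi WxC downscaling head uses a pixel shuffle is an assumption here, not something this notebook states.

```python
import numpy as np


def pixel_shuffle(x: np.ndarray, r: int) -> np.ndarray:
    """Rearrange a (C * r**2, H, W) array into (C, H * r, W * r).

    Mirrors the channel-to-space layout of torch.nn.PixelShuffle:
    out[c, h * r + i, w * r + j] = x[c * r * r + i * r + j, h, w]
    """
    channels, h, w = x.shape
    if channels % (r * r):
        raise ValueError("channel count must be divisible by r**2")
    c = channels // (r * r)
    # Split channels into (C, r, r), then interleave the two r-axes with H and W
    out = x.reshape(c, r, r, h, w).transpose(0, 3, 1, 4, 2)
    return out.reshape(c, h * r, w * r)


# 4 low-resolution channels on a 2 x 2 grid -> 1 channel on a 4 x 4 grid (r = 2)
x = np.arange(16).reshape(4, 2, 2)
y = pixel_shuffle(x, r=2)
print(y.shape)  # (1, 4, 4)
print(y[0].tolist())
# [[0, 4, 1, 5], [8, 12, 9, 13], [2, 6, 3, 7], [10, 14, 11, 15]]
```

Each low-resolution pixel thus expands into an `r × r` block whose values come from `r²` different channels, which lets a network learn sub-grid detail with ordinary convolutions.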

In [None]:
model = get_finetune_model(config, logger=None)

weights_path = Path(config.download_path, 'pytorch_model.bin')
model.load_state_dict(torch.load(weights_path, weights_only=False, map_location=device))
model.to(device)
print("Model loaded successfully.")
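If loading ever fails with missing or unexpected keys, comparing key names is a quick diagnostic. A frequent cause is a `module.` prefix added when a model was saved while wrapped in `torch.nn.DataParallel` or `DistributedDataParallel`. The sketch below uses plain dictionaries so it runs without a model; `strip_prefix` and `diff_keys` are hypothetical helpers, the key names are made up, and nothing here implies the downscaling checkpoint actually carries this prefix.

```python
def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def diff_keys(expected_keys, loaded_keys):
    """Return (missing, unexpected) key lists, each sorted."""
    expected, loaded = set(expected_keys), set(loaded_keys)
    return sorted(expected - loaded), sorted(loaded - expected)


# Hypothetical key names for illustration
checkpoint = {"module.encoder.weight": 0, "module.head.bias": 1}
model_keys = ["encoder.weight", "head.bias"]

print(diff_keys(model_keys, checkpoint))
# (['encoder.weight', 'head.bias'], ['module.encoder.weight', 'module.head.bias'])
print(diff_keys(model_keys, strip_prefix(checkpoint)))  # ([], [])
```

With the real model, `model.state_dict().keys()` supplies the expected keys. Alternatively, `model.load_state_dict(..., strict=False)` returns the missing and unexpected keys directly instead of raising.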

## 7. Run Inference

We'll run the model on a single sample from our dataloader.

In [None]:
model.eval()  # switch off training-only behavior such as dropout

with torch.no_grad():
    # Take the first batch and move every tensor in it to the selected device
    batch = next(iter(dataloader))
    batch = {k: v.to(device) for k, v in batch.items()}
    outputs = model(batch)

inputs = batch['x']
targets = batch['y']

print("Inference complete.")
print(f"Input shape: {inputs.shape}")
print(f"Target shape: {targets.shape}")
print(f"Output shape: {outputs.shape}")
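A quick sanity check after inference is to compare spatial dimensions: for a 6× downscaling model, the output height and width should be six times those of the low-resolution input. The helper below assumes a `(batch, channels, height, width)` layout, which matches how the tensors are indexed in the plotting code; `spatial_scale` is a hypothetical name and the shapes in the example are made up.

```python
def spatial_scale(input_shape, output_shape):
    """Return (height_ratio, width_ratio) of output over input, assuming (B, C, H, W) shapes."""
    (_, _, h_in, w_in), (_, _, h_out, w_out) = input_shape, output_shape
    return h_out / h_in, w_out / w_in


# Hypothetical shapes: a 60 x 96 input grid upscaled 6x to 360 x 576
print(spatial_scale((1, 10, 60, 96), (1, 1, 360, 576)))  # (6.0, 6.0)
```

In the notebook this would be called as `spatial_scale(inputs.shape, outputs.shape)`, since `torch.Size` unpacks like a tuple.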

## 8. Visualization

Let's compare the **input** (low resolution), **prediction** (downscaled high resolution), and **target** (original high resolution) side-by-side.
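The plotting cell locates T2M among the input channels by rebuilding the input variable list in channel order: surface variables first, then one entry per (vertical variable, level) pair produced by `itertools.product`. That the input tensor's channels follow exactly this order is an assumption carried over from that cell. Here is the same lookup with hypothetical variable names and levels.

```python
from itertools import product

# Hypothetical variable lists; the real ones come from config.data
surface_vars = ["U10M", "V10M", "T2M"]
vertical_vars = ["U", "T"]
levels = [34.0, 39.0]

# Surface variables are single channels; vertical variables get one channel per level
input_vars = [*surface_vars, *product(vertical_vars, levels)]
print(input_vars)
# ['U10M', 'V10M', 'T2M', ('U', 34.0), ('U', 39.0), ('T', 34.0), ('T', 39.0)]
print(input_vars.index("T2M"))        # 2
print(input_vars.index(("T", 39.0)))  # 6
```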

In [None]:
var_name = "T2M"
var_name_title = '2 m air temperature'

input_vars = [*config.data.input_surface_vars, *product(config.data.input_vertical_vars, config.data.input_levels)]
input_t2m_index = input_vars.index(var_name)

# Extract data for plotting
plot_input = inputs[0, input_t2m_index, :, :].detach().cpu().numpy()
plot_target = targets[0, 0, : ,:].detach().cpu().numpy()
plot_output = outputs[0, 0, :, :].detach().cpu().numpy()
plot_residual = plot_target - plot_output

fig, axes = plt.subplots(1, 4, figsize=(24, 6))

# Input (Low Res)
im0 = axes[0].imshow(plot_input, origin='lower', cmap='plasma')
axes[0].set_title(f'Input (Low Res)\n{plot_input.shape}')
plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)

# Share one color scale between prediction and ground truth so they are directly comparable
vmin = min(plot_output.min(), plot_target.min())
vmax = max(plot_output.max(), plot_target.max())

# Prediction (High Res)
im1 = axes[1].imshow(plot_output, origin='lower', cmap='plasma', vmin=vmin, vmax=vmax)
axes[1].set_title(f'Prediction (High Res)\n{plot_output.shape}')
plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

# Ground Truth (High Res)
im2 = axes[2].imshow(plot_target, origin='lower', cmap='plasma', vmin=vmin, vmax=vmax)
axes[2].set_title(f'Ground Truth (High Res)\n{plot_target.shape}')
plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)

# Residual / Error
im3 = axes[3].imshow(plot_residual, origin='lower', cmap='bwr', vmin=-5, vmax=5)
axes[3].set_title('Residual (Target - Pred)')
plt.colorbar(im3, ax=axes[3], fraction=0.046, pad=0.04)

plt.suptitle(f"Downscaling Results for {var_name_title}", fontsize=16)
plt.tight_layout()
plt.show()
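The residual map can also be summarized with a few scalar scores. The sketch below computes bias, mean absolute error (MAE), and root-mean-square error (RMSE) between a prediction and a target array; `error_summary` is a hypothetical helper, and the scores come out in the same units as the arrays passed in. Here bias is prediction minus target, so its sign is the opposite of the plotted residual.

```python
import numpy as np


def error_summary(prediction: np.ndarray, target: np.ndarray) -> dict:
    """Return bias, MAE and RMSE of prediction versus target (arrays of the same shape)."""
    error = prediction - target
    return {
        "bias": float(error.mean()),
        "mae": float(np.abs(error).mean()),
        "rmse": float(np.sqrt((error ** 2).mean())),
    }


# Toy example with hypothetical values
pred = np.array([[1.0, 2.0], [3.0, 4.0]])
truth = np.array([[1.0, 1.0], [5.0, 4.0]])
scores = error_summary(pred, truth)
print(scores)  # bias -0.25, MAE 0.75, RMSE about 1.118
```

In the notebook it could be applied as `error_summary(plot_output, plot_target)`.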