# Prithvi WxC Downscaling with ECCC Data: Model Inference

This notebook walks through using a fine-tuned downscaling model to generate inferences.

We show how to initialize the model, load the fine-tuned `Prithvi` weights, and use the model for inference.

To replicate the results shown in this notebook, download the required files from our [Hugging Face](https://huggingface.co/ibm-granite/granite-geospatial-wxc-downscaling) repository.

In [1]:
from huggingface_hub import snapshot_download
import os

# Download the repository files once into a local folder; skip if it already exists
local_dir = './granite-geospatial-wxc-downscaling'
if not os.path.exists(local_dir):
    snapshot_download(repo_id='ibm-granite/granite-geospatial-wxc-downscaling',
                      allow_patterns="*", repo_type='model', local_dir=local_dir)



**This notebook is a simple plug-and-play example** 

We provide only **1 data sample**. See `./examples/eccc_downscaling/notebooks/README.md` to download and preprocess the remaining files.

---

## Setup

Python >= 3.10 is required.

In [2]:
import logging
import warnings
logging.disable(logging.CRITICAL)
warnings.simplefilter(action='ignore', category=FutureWarning)

In [3]:
import torch
import numpy as np
from itertools import product

from granitewxc.utils.config import get_config
from granitewxc.utils.eccc_data import get_dataloaders_eccc
from granitewxc.utils.plot import plot_eccc_results
from granitewxc.models.model import get_finetune_model_UNET, get_finetune_model




Configure the backends and PyTorch settings, and set the random seeds so that the random crops in this example are reproducible.

In [4]:
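torch.jit.enable_onednn_fusion(True)           # enable oneDNN JIT fusion (affects TorchScript graphs)
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True      # autotune conv algorithms; set False for strict run-to-run reproducibility
    torch.backends.cudnn.deterministic = True  # restrict cuDNN to deterministic algorithms
    torch.cuda.manual_seed(42)
torch.manual_seed(42)
np.random.seed(42)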
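
The cell above seeds PyTorch and NumPy individually. If you prefer a single call, a helper along these lines makes re-seeding reproduce the same random draws. The name `seed_everything` is hypothetical, and it additionally seeds Python's built-in `random` module, which the notebook does not touch:

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 42) -> None:
    """Seed Python's random module, NumPy's global RNG, and PyTorch (CPU and CUDA)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the RNG on all devices
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # explicit, for every visible GPU


# Re-seeding reproduces the same random draws (e.g. crop offsets).
seed_everything(42)
a = torch.rand(3)
seed_everything(42)
b = torch.rand(3)
print(torch.equal(a, b))
```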

Inference can run on either a CPU or a GPU. We set the model's device to `cuda` when a CUDA GPU is available, and to `cpu` otherwise.

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')
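
Because the model has roughly 1.45 billion parameters and a GPU may be shared with other processes, it can help to check how much memory is free before loading the weights. A sketch assuming a single visible GPU; `pick_device` is a hypothetical helper built on `torch.cuda.mem_get_info`, which returns free and total device memory in bytes:

```python
import torch


def pick_device():
    """Return (device, free GiB). Free memory is None when falling back to the CPU."""
    if torch.cuda.is_available():
        dev = torch.device("cuda")
        free_bytes, _total_bytes = torch.cuda.mem_get_info(dev)
        return dev, free_bytes / 2**30
    return torch.device("cpu"), None


dev, free_gib = pick_device()
print(dev, free_gib)
```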

## Configuration File

The model is configured using `YAML` files.

In these files, you specify:
- Paths to the input data  
- Locations of the pretrained weights  

To ensure compatibility with the provided weights during inference, keep the model configuration consistent with the original definitions.
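
The weights are loaded later with `load_state_dict(..., strict=True)`, which raises an error if the parameter names in the checkpoint and in the model differ, as happens when the configuration no longer matches the one the weights were trained with. The toy example below uses two small `torch.nn.Sequential` models as stand-ins for the real one (an assumption for illustration) to show how `strict=False` reports the mismatch instead of raising:

```python
import torch
from torch import nn

# Toy stand-ins: the "checkpoint" comes from a 3-layer model, while the model we
# build has an extra final layer, mimicking a changed configuration value.
trained = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
changed = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2), nn.Linear(2, 2))
checkpoint = {"model": trained.state_dict()}  # same ['model'] key the notebook indexes

# strict=False loads the matching tensors and reports the rest instead of raising
result = changed.load_state_dict(checkpoint["model"], strict=False)
print(result.missing_keys)     # ['3.weight', '3.bias']
print(result.unexpected_keys)  # []

# strict=True (as in the notebook) refuses the mismatched checkpoint.
# Note: a changed layer *width* raises a size-mismatch error even with strict=False.
try:
    changed.load_state_dict(checkpoint["model"], strict=True)
    strict_failed = False
except RuntimeError:
    strict_failed = True
print(strict_failed)  # True
```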

In [6]:
config_path = './granite-geospatial-wxc-downscaling/ECCC/configs/config_UNET.yaml'
config = get_config(config_path)

## Dataloader 

In this example, we use only **1 sample of data**.

To download and set up the remaining data, follow the instructions in `./examples/eccc_downscaling/notebooks/README.md`.

In [7]:
test_dl = get_dataloaders_eccc(config, test=True)

--> Test samples: 1


## Model Initialization

We provide **2** different model architectures: `UNET-like` and `CONV`.

Both architectures include:  
1. **Patch Embedding**: Extracts shallow features from the input data  
2. **Feature Extraction**: Utilizes the Prithvi backbone to extract deeper features  
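
One common way to implement a patch-embedding step is a convolution whose kernel size equals its stride, so that each non-overlapping patch becomes one feature vector. This is a generic sketch, not the `granitewxc` implementation; the class name, channel counts, and grid size are placeholders:

```python
import torch
from torch import nn


class PatchEmbed(nn.Module):
    """Hypothetical patch embedding: non-overlapping p x p patches -> embed_dim channels."""

    def __init__(self, in_channels: int, embed_dim: int, patch_size: int):
        super().__init__()
        # kernel_size == stride yields one output vector per non-overlapping patch
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C, H, W) -> (B, embed_dim, H / p, W / p)
        return self.proj(x)


x = torch.randn(1, 5, 64, 128)  # placeholder channel count and grid size
tokens = PatchEmbed(5, 32, patch_size=4)(x)
print(tokens.shape)  # torch.Size([1, 32, 16, 32])
```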

The key difference is that the UNET-like version incorporates **static high-resolution data** into the model.
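
A simplified illustration of how static high-resolution fields can enter a UNET-like decoder: upsample the coarse features to the high-resolution grid and concatenate the static fields as extra channels. This is an assumption about the general technique, not the code behind `get_finetune_model_UNET`; `fuse_static` and all shapes are hypothetical:

```python
import torch
import torch.nn.functional as F


def fuse_static(coarse_feats: torch.Tensor, static_hr: torch.Tensor) -> torch.Tensor:
    """Upsample coarse features to the static grid and stack them as extra channels.

    coarse_feats: (B, C, h, w); static_hr: (B, S, H, W). Returns (B, C + S, H, W).
    """
    up = F.interpolate(coarse_feats, size=static_hr.shape[-2:],
                       mode="bilinear", align_corners=False)
    return torch.cat([up, static_hr], dim=1)


feats = torch.randn(1, 16, 32, 64)    # placeholder coarse features
static = torch.randn(1, 3, 128, 256)  # placeholder static high-resolution fields
fused = fuse_static(feats, static)
print(fused.shape)  # torch.Size([1, 19, 128, 256])
```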

In this notebook, we use the **UNET-like** version.

To switch to the **CONV** model, update the configuration file accordingly and use `get_finetune_model(config)`.

In [8]:
model = get_finetune_model_UNET(config)

Creating the model.
--> model has 1,448,690,568 params.
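
The parameter count printed above also gives a rough lower bound on the memory needed just to hold the weights. Assuming float32 weights (4 bytes each; the actual dtype depends on the checkpoint), a sketch of both calculations, where `count_params` and `fp32_gib` are hypothetical helpers:

```python
from torch import nn


def count_params(module: nn.Module) -> int:
    """Total number of parameters (trainable or not) in a module."""
    return sum(p.numel() for p in module.parameters())


def fp32_gib(n_params: int) -> float:
    """Memory for the weights alone at 4 bytes per float32 value, in GiB."""
    return n_params * 4 / 2**30


print(count_params(nn.Linear(10, 5)))     # 55  (10*5 weights + 5 biases)
print(round(fp32_gib(1_448_690_568), 2))  # 5.4
```

At about 5.4 GiB for the weights alone, before any activations for the full-domain input, the GPU needs more free memory than that.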


We can now load the fine-tuned weights.

In [9]:
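weights_path = config.path_model_weights
# Load the checkpoint on the CPU so that only one copy of the weights ends up on the GPU
weights = torch.load(weights_path, map_location='cpu')['model']

model.load_state_dict(weights, strict=True)  # fail loudly if the config does not match the checkpoint
model.to(device)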


### Inference

The model is now ready for inference. We run inference on a single batch (in this example, `batch_size=1`).

Unlike training, which used fixed-size random crops, inference runs on the entire Canadian region, a (1280, 2528) image.
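
To make the training/inference contrast concrete, here is a sketch of a fixed-size random crop next to the full domain. It uses NumPy's `Generator` API rather than the global `np.random.seed`; `random_crop` and the 256 x 256 crop size are hypothetical, not the values used in training:

```python
import numpy as np


def random_crop(field: np.ndarray, crop_hw: tuple[int, int],
                rng: np.random.Generator) -> np.ndarray:
    """Cut a random fixed-size (h, w) window from the last two axes of `field`."""
    H, W = field.shape[-2:]
    h, w = crop_hw
    top = rng.integers(0, H - h + 1)   # upper bound is exclusive
    left = rng.integers(0, W - w + 1)
    return field[..., top:top + h, left:left + w]


full = np.zeros((1, 1280, 2528), dtype=np.float32)  # one channel over the full domain
crop = random_crop(full, (256, 256), np.random.default_rng(42))  # training-style crop
print(crop.shape, full.shape)  # (1, 256, 256) (1, 1280, 2528)
```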

In [None]:
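model.eval()  # put layers such as dropout into evaluation mode

with torch.no_grad():
    # Take the first (and here only) batch and move every tensor to the model's device
    batch = next(iter(test_dl))
    batch = {k: v.to(device) for k, v in batch.items()}

    out = model(batch)

inputs = batch['x']   # coarse-resolution inputs
targets = batch['y']  # high-resolution targets
outputs = out         # model predictions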

In [None]:
inputs.shape, targets.shape, outputs.shape

### Plotting

We select the variable to plot and build the lists of input and output variable names used to index the tensors.

In [None]:
var_name = "UUWE" # UUWE or VVSN
var_unit = "m/s"

output_vars = [*config.data.output_vars]
input_vars = [*config.data.input_surface_vars,
              *product(config.data.vertical_pres_vars, config.data.input_level_pres),
              *product(config.data.vertical_level1_vars, config.data.input_level1),
              *product(config.data.vertical_level2_vars, config.data.input_level2),
              *config.data.other
             ]

coarsening_factor = targets.shape[-1] / inputs.shape[-1]

f"Downscaling '{var_name}' by {coarsening_factor}x"
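
The flattened `input_vars` list mixes plain names (surface and `other` variables) with `(variable, level)` tuples produced by `itertools.product`, so `input_vars.index(var_name)` finds a channel only when `var_name` is a plain entry; a pressure-level variable has to be looked up as a tuple. A sketch with placeholder lists (only `UUWE` and `VVSN` come from this notebook; `TT` and the levels are hypothetical):

```python
from itertools import product

# Placeholder lists; in the notebook they come from config.data
input_surface_vars = ["UUWE", "VVSN"]
vertical_pres_vars = ["TT"]
input_level_pres = [500, 850]

input_vars = [*input_surface_vars, *product(vertical_pres_vars, input_level_pres)]
print(input_vars)  # ['UUWE', 'VVSN', ('TT', 500), ('TT', 850)]

# Channel index = position in the flattened list
print(input_vars.index("VVSN"), input_vars.index(("TT", 850)))  # 1 3
```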

In [None]:
# Channel indices of the selected variable in the output and input tensors
var = output_vars.index(var_name)
input_var = input_vars.index(var_name)

for idx in range(len(outputs)):
    plot_input = inputs[idx, input_var, :, :].cpu().numpy()
    plot_target = targets[idx, var, :, :].cpu().numpy()
    plot_pred = outputs[idx, var, :, :].cpu().numpy()
    plot_residual = plot_target - plot_pred  # error field (target - prediction); not passed to the plot

    plot_eccc_results(plot_input, plot_pred, plot_target)
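
The loop computes `plot_residual` but does not plot it. A minimal way to summarize it is the root-mean-square error and the mean bias; with the notebook's sign convention (`target - prediction`), a positive bias means the model underestimates on average. `residual_stats` is a hypothetical helper, not necessarily how the authors evaluate the model:

```python
import numpy as np


def residual_stats(target: np.ndarray, pred: np.ndarray) -> dict[str, float]:
    """RMSE and mean bias of (target - pred), ignoring NaN cells if any."""
    residual = target - pred
    return {
        "rmse": float(np.sqrt(np.nanmean(residual ** 2))),
        "bias": float(np.nanmean(residual)),
    }


t = np.array([[1.0, 2.0], [3.0, 4.0]])
p = np.array([[1.0, 1.0], [3.0, 2.0]])
stats = residual_stats(t, p)
print(stats)  # rmse = sqrt(1.25), about 1.118; bias = 0.75
```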