# Inference through makani model package

## General

While we generally suggest using Makani's `inferencer.py` module or Earth2Studio for inference, it can be useful to use stand-alone model packages. To support this, Makani introduces a model package format, which we showcase here.

## Setting up the model package

In [None]:
import os
# Configure the CUDA caching allocator through the environment. The variable is
# set before `import torch` on purpose — presumably so that it is already in
# place when torch initializes its allocator (confirm against the PyTorch
# memory-management notes before reordering these lines).
os.environ["PYTORCH_CUDA_ALLOC_CONF"]="expandable_segments:True"
import torch
import numpy as np

# helpers for wrapping a local directory as a model package and loading it
from makani.models.model_package import LocalPackage, load_model_package

In [None]:
# directory where the model package resides
model_package_dir = "/runs/fcn3_sc2_edim45_layers10_finetune_2013-2016_8step_centered_4member/flexible"

# device that we want to use
device = torch.device("cuda:0")

# wrap the directory as a local package, load the model package from it,
# and move the result to the chosen device
local_package = LocalPackage(model_package_dir)
model_package = load_model_package(local_package)
model_package = model_package.to(device)

In [None]:
# channel names stored in the model package parameters; used below to select
# the matching channels from the dataset and to locate the plotted channel
variables = model_package.params.channel_names
# step applied to the time stamp after every model call; it is passed to
# `timedelta(hours=timestep)` below, i.e. interpreted as hours
# NOTE(review): hard-coded — confirm it matches the step the model was trained with
timestep = 6

## Load data from a local HDF5 file

In [None]:
import h5py as h5
import json

# NOTE: the HDF5 file is intentionally left open: `era5_2018_data` refers to
# the "fields" dataset inside it and is indexed again in later cells
era5_2018_file = h5.File("/out_of_sample/2018.h5", "r")
era5_2018_data = era5_2018_file["fields"]

# get the channel names from the description file; the context manager makes
# sure the file is closed even if parsing the JSON fails
with open("/metadata/data.json", encoding="utf-8") as era5_2018_desc_file:
    era5_metadata = json.load(era5_2018_desc_file)

In [None]:
# channel names of the dataset, taken from the metadata's "coords" -> "channel" entry
era5_channels = era5_metadata["coords"]["channel"]
# "dhours" entry of the metadata; used below as the number of hours spanned by
# one sample index (presumably the temporal spacing of the dataset — confirm
# against the metadata file)
era5_dhours = era5_metadata["dhours"]

In [None]:
from datetime import datetime, timedelta, timezone

# index of the dataset sample that serves as initial condition
iic = 0

# time stamp of the initial condition: the reference time below, shifted by
# the number of hours that `iic` samples span
reference_time = datetime.fromisoformat("2018-01-01T00:00:00+00:00")
ic_time = reference_time + timedelta(hours=iic * era5_dhours)

# position of every model channel within the dataset's channel list
ich = list(map(era5_channels.index, variables))

## Run the inference

In [None]:
# number of autoregressive model calls
autoreg_steps = 18

# h5py only accepts index lists in increasing order, so read the channels
# sorted by their position in the file and afterwards permute them back into
# the order given by `ich`; the result is unchanged when `ich` is already sorted
read_order = np.argsort(ich, kind="stable")
sorted_ich = [ich[i] for i in read_order]
fields = era5_2018_data[iic, sorted_ich]
fields = fields[np.argsort(read_order, kind="stable")]

inpt = torch.as_tensor(fields).to(device)
# datetime objects are immutable, so advancing `time` below leaves `ic_time` untouched
time = ic_time

# normalize the input once up front to avoid converting back and forth
# between normalized and physical units on every step
inpt = (inpt - model_package.in_bias)/model_package.in_scale

# inference_mode already disables gradient tracking, so no extra no_grad is needed
with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    pred = inpt.clone()

    # feed every prediction back in as the next input and advance the time stamp
    for _ in range(autoreg_steps):
        pred = model_package(pred, time, normalized_data=True, replace_state=True)
        time += timedelta(hours=timestep)

# undo the normalization of the final prediction
pred = pred * model_package.out_scale + model_package.out_bias

In [None]:
import matplotlib.pyplot as plt
from torch_harmonics.plotting import plot_sphere

# channel to visualize
plt_channel = "u10m"

# The rollout advanced the time stamp by autoreg_steps * timestep hours, while
# the initial-condition time was computed with one sample index spanning
# `era5_dhours` hours. Convert the lead time into a sample offset instead of
# assuming that both step sizes are equal.
lead_hours = autoreg_steps * timestep
sample_offset, leftover_hours = divmod(lead_hours, era5_dhours)
if leftover_hours:
    raise ValueError(
        f"lead time of {lead_hours}h is not a multiple of the dataset spacing of {era5_dhours}h"
    )

ground_truth = era5_2018_data[iic + int(sample_offset), era5_channels.index(plt_channel)]
# first entry along the leading axis of the prediction, selected channel
prediction = pred.cpu().detach().numpy()[0, variables.index(plt_channel)]

# symmetric color range derived from the ground truth so both plots share one scale
vmax = np.abs(ground_truth).max()
vmin = -vmax

fig = plt.figure(figsize=(10,8))
plot_sphere(prediction, title = f"FCN3 prediction at {time}", vmin=vmin, vmax=vmax, fig=fig)
fig = plt.figure(figsize=(10,8))
plot_sphere(ground_truth, title = f"ERA5 ground truth at {time}", vmin=vmin, vmax=vmax, fig=fig)