# ClimaX Prediction Notebook

This notebook demonstrates how to run predictions with the ClimaX model at the NCI computing facility.

In [None]:
import os
from typing import Optional
from typing import Any
import numpy as np
import math
import random
import numpy as np
import click
import xarray as xr
from tqdm import tqdm
import glob
import warnings
from typing import List

import torch
import torchdata.datapipes as dp
from pytorch_lightning import LightningDataModule
from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader, IterableDataset
from torchvision.transforms import transforms
from pytorch_lightning.cli import LightningCLI

from climax.arch import ClimaX
from climax.pretrain.datamodule import collate_fn
from climax.pretrain.datamodule import MultiSourceDataModule
from climax.pretrain.module import PretrainModule

from climax.pretrain.dataset import (
	Forecast,
	IndividualForecastDataIter,
	NpyReader,
	ShuffleIterableDataset,
)
from climax.utils.lr_scheduler import LinearWarmupCosineAnnealingLR
from climax.utils.metrics import (
	lat_weighted_acc,
	lat_weighted_mse,
	lat_weighted_mse_val,
	lat_weighted_rmse,
)
from climax.utils.pos_embed import interpolate_pos_embed
from climax.global_forecast.datamodule import GlobalForecastDataModule
from climax.global_forecast.module import GlobalForecastModule

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from functools import lru_cache

import torch.nn as nn
from timm.models.vision_transformer import Block, PatchEmbed, trunc_normal_

from climax.utils.pos_embed import (
	get_1d_sincos_pos_embed_from_grid,
	get_2d_sincos_pos_embed,
)
from climax.utils.data_utils import DEFAULT_PRESSURE_LEVELS, NAME_TO_VAR
import holoviews as hv
import matplotlib
from matplotlib import pyplot as plt
from IPython.display import Image
from IPython.core.display import HTML 

In [None]:
def RMSELoss(yhat, y):
    """Root-mean-square error between ``yhat`` and ``y`` over all elements.

    Both arguments are torch tensors of broadcast-compatible shapes. The result
    is a 0-d tensor; call ``.item()`` on it to get a Python float.
    """
    residual = yhat - y
    mse = torch.mean(residual ** 2)
    return torch.sqrt(mse)

In [None]:
import os

# Working directory for this session. Relative paths used later in the notebook
# (e.g. the OUTPUT_DIR the prediction script writes to and pred.pt is read from)
# resolve against it. Set the CLIMAX_WORKDIR environment variable, or replace the
# placeholder below with your own directory.
WORKDIR = os.environ.get("CLIMAX_WORKDIR", "YOUR_OWN_WORKDIR")
if not os.path.isdir(WORKDIR):
    # Fail with an actionable message instead of a bare os.chdir error on the placeholder.
    raise FileNotFoundError(
        f"Working directory {WORKDIR!r} does not exist; set CLIMAX_WORKDIR "
        "or edit WORKDIR in this cell."
    )
os.chdir(WORKDIR)

# ClimaX model
Paper link: https://arxiv.org/pdf/2301.10343.pdf

In [None]:
# Illustration accompanying the ClimaX model section (see the paper link above).
Image(width=700, height=750, filename="/g/data/dk92/apps/climax/Clim-1.png")

# WeatherBench data
Paper link: https://arxiv.org/pdf/2002.00469.pdf

In [None]:
# Illustration accompanying the WeatherBench section (see the paper link above).
Image(width=700, height=750, filename="/g/data/dk92/apps/climax/weather-1.png")

# Data location.
The WeatherBench data have been re-processed for ClimaX training, testing and prediction, and are stored under `/g/data/dk92/apps/climax/weatherbench/5.625deg_npz`.

In [None]:
# List the pre-processed WeatherBench 5.625deg data directory that the prediction run below uses as --data.root_dir.
! ls -lh  /g/data/dk92/apps/climax/weatherbench/5.625deg_npz

# PBS Script - Model fine-tune

Fine-tuning can take a very long time, so it is best run as a PBS batch job. Example PBS job scripts for fine-tuning are available in `/g/data/dk92/apps/climax/0.2.3/examples`. For more details, see https://opus.nci.org.au/x/vgJDDQ.

# Prediction

The original ClimaX code does not include a prediction module, so we have added prediction code to the NCI ClimaX environment. This code differs from the original repository: it adds a prediction script, a configuration file, and a data-loading module.

In [None]:
import os
os.environ['OUTPUT_DIR']='climax_train_global_output'
!python ${CLIMAX_ROOT}/src/climax/global_forecast/predict.py \
    --config "${CLIMAX_ROOT}/configs/prediction.yaml" \
    --trainer.num_nodes=1 \
    --trainer.strategy=ddp --trainer.devices=1 \
    --data.root_dir='/g/data/dk92/apps/climax/weatherbench/5.625deg_npz' \
    --data.predict_range=72 \
    --data.out_variables=['geopotential_500','temperature_850','2m_temperature'] \
    --data.batch_size=16 \
    --data.num_workers=1 \
    --model.pretrained_path='/g/data/dk92/apps/climax/weatherbench/ClimaX-5.625deg.ckpt' 

The prediction script above writes a file named `pred.pt` to the directory given by the `OUTPUT_DIR` environment variable.

In [None]:
import os
import torch

# Maximum number of prediction batches to assemble. Every sample becomes one
# animation frame below, so this keeps the embedded animation manageable.
N_BATCHES = 16

# pred.pt is written by predict.py. torch.load unpickles the file, so only load
# outputs you produced yourself. map_location='cpu' lets predictions saved as
# GPU tensors be loaded (and plotted) on a CPU-only node.
with open(os.path.join(os.environ['OUTPUT_DIR'], 'pred.pt'), 'rb') as f:
    pred = torch.load(f, map_location='cpu')

# Clamp to the batches actually present so a shorter prediction run does not
# raise IndexError.
n_batches = min(N_BATCHES, len(pred))
if n_batches == 0:
    raise ValueError("pred.pt does not contain any prediction batches")
batches = [pred[i] for i in range(n_batches)]

def stack_channel(batches, item, channel):
    """Concatenate ``batch[item][:, channel]`` over ``batches`` along dim 0.

    ``item`` selects the entry of each batch (1: ground truth, 2: prediction) and
    ``channel`` the output variable (0: geopotential_500, 1: temperature_850,
    2: 2m_temperature).
    """
    return torch.cat([batch[item][:, channel] for batch in batches], 0)

geopotential_500_y = stack_channel(batches, 1, 0)
temperature_850_y  = stack_channel(batches, 1, 1)
temperature_y      = stack_channel(batches, 1, 2)

geopotential_500_pred = stack_channel(batches, 2, 0)
temperature_850_pred  = stack_channel(batches, 2, 1)
temperature_pred      = stack_channel(batches, 2, 2)

print (geopotential_500_y.shape, temperature_850_y.shape, temperature_y.shape)
print (geopotential_500_pred.shape, temperature_850_pred.shape, temperature_pred.shape)

In [None]:
%%time 
%matplotlib inline 
import matplotlib.colors as colors
import matplotlib.image as mpimg
from matplotlib import rcParams
import matplotlib.animation as animation
from IPython.display import HTML
import matplotlib.pyplot as plt
import numpy as np

plt.rcParams['animation.embed_limit'] = 2**128    
plt.rcParams["figure.figsize"] = [15,12]
#plt.rcParams["figure.autolayout"] = True 
plt.subplots_adjust(bottom=0, right=1, top=1, left=0)
print (geopotential_500_y.shape)

fig = plt.figure()
#fig.set_figheight(10)
#fig.set_figwidth(10)
ax1 = fig.add_subplot(321)
ax4 = fig.add_subplot(322)
ax2 = fig.add_subplot(323)
ax5 = fig.add_subplot(324)
ax3 = fig.add_subplot(325)
ax6 = fig.add_subplot(326)

ax1.title.set_text('geopotential_500')
ax2.title.set_text('temperature_850')
ax3.title.set_text('2m_temperature')
ax4.title.set_text("Prediction, RMSE: " + str(round ( RMSELoss(geopotential_500_pred[0],geopotential_500_y [0]).item(),3) ) )
ax5.title.set_text("Prediction, RMSE: " + str(round ( RMSELoss(temperature_850_pred[0], temperature_850_y[0]).item(),3) ) )
ax6.title.set_text("Prediction, RMSE: " + str(round ( RMSELoss(temperature_pred[0],     temperature_y [0]).item(),3) ) )

color_map ='Spectral_r'
inter = 'nearest'
#mi = np.min (geopotential_500_y  )
#mx = np.max (geopotential_500_y  )
imart1 = ax1.imshow(geopotential_500_y [0])
imart2 = ax2.imshow(temperature_850_y  [0])
imart3 = ax3.imshow(temperature_y      [0])
imart4 = ax4.imshow(geopotential_500_pred [0])
imart5 = ax5.imshow(temperature_850_pred  [0])
imart6 = ax6.imshow(temperature_pred      [0])
imart  = [imart1, imart2, imart3, imart4, imart5, imart6 ] 

def init():
    return imart 

def update(val):
    imart[0].set_data(geopotential_500_y [val])
    imart[1].set_data(temperature_850_y [val])
    imart[2].set_data(temperature_y [val])
    imart[3].set_data(geopotential_500_pred [val])
    imart[4].set_data(temperature_850_pred [val])
    imart[5].set_data(temperature_pred [val])
    ax4.title.set_text("Prediction, RMSE: " + str(round ( RMSELoss(geopotential_500_pred[val],geopotential_500_y [val]).item(),3) ) )
    ax5.title.set_text("Prediction, RMSE: " + str(round ( RMSELoss(temperature_850_pred[val], temperature_850_y[val]).item(),3) ) )
    ax6.title.set_text("Prediction, RMSE: " + str(round ( RMSELoss(temperature_pred[val],     temperature_y [val]).item(),3) ) )

    return imart
ani = animation.FuncAnimation(fig, update, frames=len(geopotential_500_y), 
                              init_func=init,
                              interval=3, blit=True)
plt.close()


In [None]:
%%time 
# Note that it will take some time to generate the animation.
# Make animations display as an interactive JavaScript (HTML) player.
from matplotlib import rc
rc('animation', html='jshtml')
# Displaying the animation as the cell's last expression triggers the rendering.
ani