In [None]:
import os, sys, time
import numpy as np
import torch
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np
from matplotlib.ticker import FuncFormatter
import xarray
from mpi4py import MPI

In [None]:
# These paths will need to be altered to fit the current environment.
# Everything that lives under the shared FourCastNet checkout is derived from
# FCN_ROOT, which can be overridden with the FCN_ROOT environment variable
# instead of editing each path by hand.
FCN_ROOT = os.environ.get("FCN_ROOT", "/home/jovyan/shared/fourcastnet/FourCastNet")

# data and model paths
data_path = os.path.join(FCN_ROOT, "data", "FCN_ERA5_data_v0", "out_of_sample")
# alternate out-of-sample data directory (not under FCN_ROOT)
d_2 = '/home/jovyan/ccai_demo/data/FCN_ERA5_data_v0/out_of_sample'
data_file = os.path.join(data_path, "2018.h5")
data_file2 = os.path.join(d_2, "2018.h5")

#model_path = "/home/jovyan/new_checkpoints/backbone1.ckpt"
model_path = os.path.join(FCN_ROOT, "model_weights", "FCN_weights_v0", "backbone.ckpt")

# precomputed normalization statistics
stats_dir = os.path.join(FCN_ROOT, "additional", "stats_v0")
global_means_path = os.path.join(stats_dir, "global_means.npy")
global_stds_path = os.path.join(stats_dir, "global_stds.npy")
time_means_path = os.path.join(stats_dir, "time_means.npy")

In [None]:
# Ordering of the atmospheric variables along the channel dimension:
# channel i of the data / model tensors holds variables[i].
variables = [
    'u10', 'v10', 't2m', 'sp', 'msl',            # channels 0-4
    't850', 'u1000', 'v1000', 'z1000', 'u850',   # channels 5-9
    'v850', 'z850', 'u500', 'v500', 'z500',      # channels 10-14
    't500', 'z50', 'r500', 'r850', 'tcwv',       # channels 15-19
]

In [None]:

from FourCastNet.utils.YParams import YParams

In [None]:
# Model configuration: the named section of the AFNO YAML file, wrapped by YParams.
config_name = "afno_backbone"
config_file = "/home/jovyan/FourCastNet/config/AFNO.yaml"
params = YParams(config_file, config_name)
print(f"Model architecture used = {params['nettype']}")

In [None]:
# import model
from FourCastNet.networks.afnonet import AFNONet
from collections import OrderedDict

def load_model(model, params, checkpoint_file, map_location='cpu'):
    '''Load FourCastNet weights from a checkpoint into `model` and set it to eval mode.

    FourCastNet is trained with distributed data parallel (DDP), which prepends
    'module.' to every parameter name. Non-DDP models need that prefix stripped,
    so it is removed wherever present. A stripped key named 'ged' is dropped.
    If the stripped state dict does not fit `model`, the checkpoint's state dict
    is loaded exactly as stored.

    Args:
        model: network to fill; its parameters are updated in place.
        params: run configuration (unused; kept for call compatibility).
        checkpoint_file: path to a checkpoint dict with the weights under 'model_state'.
        map_location: forwarded to torch.load. Defaults to 'cpu' so checkpoints
            saved from GPU also load on CPU-only machines; load_state_dict then
            copies the values into the parameters of `model` on its current device.

    Returns:
        `model`, with the weights loaded and in eval mode.

    Raises:
        KeyError: if the checkpoint has no 'model_state' entry.
        RuntimeError: if neither the stripped nor the stored state dict fits `model`.
    '''
    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    state_dict = checkpoint['model_state']

    prefix = 'module.'
    new_state_dict = OrderedDict()
    for key, val in state_dict.items():
        # strip the DDP prefix only when it is actually there
        name = key[len(prefix):] if key.startswith(prefix) else key
        if name != 'ged':
            new_state_dict[name] = val

    try:
        model.load_state_dict(new_state_dict)
    except RuntimeError:
        # key/shape mismatch after stripping: retry with the checkpoint as stored
        model.load_state_dict(state_dict)
    model.eval() # set to inference mode
    return model

device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

# in and out channels: FourCastNet uses 20 input channels corresponding to 20 prognostic variables
in_channels = np.array(params.in_channels)
out_channels = np.array(params.out_channels)
params['N_in_channels'] = len(in_channels)
params['N_out_channels'] = len(out_channels)
# precomputed training-set statistics, restricted to the output channels
# (used to normalize data and as climatology)
params.means = np.load(global_means_path)[0, out_channels]
params.stds = np.load(global_stds_path)[0, out_channels]
params.time_means = np.load(time_means_path)[0, out_channels]

# build the network; only the AFNO backbone is supported in this notebook
if params.nettype != 'afno':
    raise NotImplementedError("nettype {!r} is not implemented".format(params.nettype))
model = AFNONet(params).to(device)  # AFNO model

# load saved model weights; load_state_dict fills the parameters of the module
# that is already on `device`, so no extra .to(device) is needed afterwards
model = load_model(model, params, model_path)

In [None]:
# Normalization statistics and climatology, moved to `device` for the metrics.
# Model grid: the first 720 latitude rows by 1440 longitudes.
img_shape_x = 720
img_shape_y = 1440

# means and stds over the training data (stored on params in the model-setup cell)
means, stds = params.means, params.stds

# climatology: temporal mean for every pixel
time_means = params.time_means

# Standardize the climatology like the model inputs, crop it to img_shape_x rows and
# add a leading batch axis. `m` (climatology) and `std` (per-channel scale) are what
# the ACC and RMSE metrics need.
m = torch.as_tensor((time_means - means)/stds)[:, 0:img_shape_x].unsqueeze(0).to(device, dtype=torch.float)
std = torch.as_tensor(stds[:, 0, 0]).to(device, dtype=torch.float)

print(f"Shape of time means = {m.shape}")
print(f"Shape of std = {std.shape}")

In [None]:
# Latitude-weighted verification metrics. Grid row j lies at latitude
# 90 - 180 * j / (num_lat - 1) degrees: row 0 is +90 and the last row is -90.
def lat(j: torch.Tensor, num_lat: int) -> torch.Tensor:
    """Latitude in degrees of grid row(s) ``j`` on an equally spaced grid of ``num_lat`` rows."""
    degrees_from_north = j * 180. / float(num_lat - 1)
    return 90. - degrees_from_north


def latitude_weighting_factor(j: torch.Tensor, num_lat: int, s: torch.Tensor) -> torch.Tensor:
    """cos(latitude) weight of row(s) ``j``, scaled by ``num_lat / s``.

    With ``s`` equal to the sum of cos(latitude) over all rows, the weights of all
    rows sum to ``num_lat`` (i.e. they average to 1).
    """
    cos_lat = torch.cos(3.1416/180. * lat(j, num_lat))
    return num_lat * cos_lat / s


def _latitude_weights(num_lat: int, device) -> torch.Tensor:
    """Weights for every row, shaped (1, 1, num_lat, 1) to broadcast over [n, c, h, w]."""
    rows = torch.arange(start=0, end=num_lat, device=device)
    s = torch.sum(torch.cos(3.1416/180. * lat(rows, num_lat)))
    return torch.reshape(latitude_weighting_factor(rows, num_lat, s), (1, 1, -1, 1))


def weighted_rmse_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Latitude-weighted RMSE for each channel.

    Takes [n, c, h, w] tensors (h = latitude rows) and returns an [n, c] tensor.
    """
    weight = _latitude_weights(pred.shape[2], pred.device)
    squared_error = (pred - target) ** 2.
    return torch.sqrt(torch.mean(weight * squared_error, dim=(-1, -2)))


def weighted_acc_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Latitude-weighted anomaly correlation coefficient for each channel.

    Takes [n, c, h, w] tensors (h = latitude rows) and returns an [n, c] tensor.
    No mean is removed here: the inputs are expected to already be anomalies
    (e.g. a field minus the climatology ``m``).
    """
    weight = _latitude_weights(pred.shape[2], pred.device)
    covariance = torch.sum(weight * pred * target, dim=(-1, -2))
    pred_power = torch.sum(weight * pred * pred, dim=(-1, -2))
    target_power = torch.sum(weight * target * target, dim=(-1, -2))
    return covariance / torch.sqrt(pred_power * target_power)

In [None]:
#BSD 3-Clause License
#
#Copyright (c) 2022, FourCastNet authors
#All rights reserved.
#
#Redistribution and use in source and binary forms, with or without
#modification, are permitted provided that the following conditions are met:
#
#1. Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
#2. Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
#3. Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
#FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
#DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
#SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
#CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#The code was authored by the following people:
#
#Jaideep Pathak - NVIDIA Corporation
#Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory
#Peter Harrington - NERSC, Lawrence Berkeley National Laboratory
#Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory 
#Ashesh Chattopadhyay - Rice University 
#Morteza Mardani - NVIDIA Corporation 
#Thorsten Kurth - NVIDIA Corporation 
#David Hall - NVIDIA Corporation 
#Zongyi Li - California Institute of Technology, NVIDIA Corporation 
#Kamyar Azizzadenesheli - Purdue University 
#Pedram Hassanzadeh - Rice University 
#Karthik Kashinath - NVIDIA Corporation 
#Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation


# Instructions:
# Set Nimgtot (the number of time steps writetofile copies) to match the source file.

import h5py
from mpi4py import MPI
import numpy as np
import time
from netCDF4 import Dataset as DS
import os

def print_hdf5_contents(name, obj):
    """Print the HDF5 object path ``name``; ``obj`` is accepted but not used.

    The (name, obj) signature looks intended for h5py's ``visititems`` callback;
    it is not called anywhere in this part of the notebook.
    """
    print(name)

def writetofile(src, dest, channel_idx, varslist, src_idx=0, frmt='nc', nimgtot=4):
    '''Copy variables from `src` into the `fields` dataset of the HDF5 file `dest`.

    The work is split across MPI ranks. Each rank copies its contiguous share of
    the first `nimgtot` time steps, 16 steps at a time. The variables in
    `varslist` go to consecutive channels starting at `channel_idx`. If `fields`
    does not exist yet, it is created with shape (nimgtot, 20, H, W), where H and W
    are the last two dimensions of the source variable.

    Args:
        src: source file; netCDF4 when frmt == 'nc', HDF5 when frmt == 'h5'.
            If it does not exist, a message is printed and nothing is written.
        dest: destination HDF5 file, opened in append mode with the MPI-IO driver.
        channel_idx: channel of `fields` that receives varslist[0].
        varslist: names of the source variables to copy, in channel order.
        src_idx: index taken along axis 1 of 4-D source variables
            (presumably the pressure level -- confirm against the source file).
        frmt: 'nc' or 'h5'.
        nimgtot: number of leading time steps to copy (previously hard-coded to 4).

    Raises:
        ValueError: if `frmt` is neither 'nc' nor 'h5'.
    '''
    if not os.path.isfile(src):
        print("writetofile: source file not found, skipping: {}".format(src))
        return
    if frmt not in ('nc', 'h5'):
        raise ValueError("frmt must be 'nc' or 'h5', got {!r}".format(frmt))

    batch = 2**4
    comm = MPI.COMM_WORLD
    rank, nproc = comm.rank, comm.size
    # contiguous split of the time axis; the last rank also takes the remainder
    nimg = nimgtot // nproc
    base = rank * nimg
    end = (rank + 1) * nimg if rank < nproc - 1 else nimgtot

    with h5py.File(dest, 'a', driver='mpio', comm=comm) as fdest:
        for offset, variable_name in enumerate(varslist):
            channel = channel_idx + offset
            with _open_source(src, frmt) as src_file:
                fsrc = src_file.variables[variable_name] if frmt == 'nc' else src_file[variable_name]
                print("fsrc shape", fsrc.shape)
                # Parallel HDF5 requires metadata changes such as dataset creation
                # to be collective, so every rank reaches this point before doing
                # its rank-specific share of the copy.
                if 'fields' not in fdest:
                    fdest.create_dataset('fields',
                                         (nimgtot, 20) + tuple(fsrc.shape[-2:]),
                                         dtype=fsrc.dtype)
                start = time.time()
                _copy_time_steps(fsrc, fdest['fields'], channel, src_idx, base, end, batch)
                print("rank {}: copied {} into channel {} in {:.1f} s".format(
                    rank, variable_name, channel, time.time() - start))


def _open_source(src, frmt):
    '''Open `src` read-only (netCDF4 for 'nc', HDF5 otherwise); usable in a `with` block.'''
    if frmt == 'nc':
        return DS(src, 'r', format="NETCDF4")
    return h5py.File(src, 'r')


def _copy_time_steps(fsrc, fields, channel, src_idx, base, end, batch):
    '''Copy time steps [base, end) of `fsrc` into `fields[:, channel]`, `batch` steps at a time.'''
    for lo in range(base, end, batch):
        hi = min(lo + batch, end)
        ims = fsrc[lo:hi, src_idx] if len(fsrc.shape) == 4 else fsrc[lo:hi]
        fields[lo:hi, channel, :, :] = ims
# Convert the ERA5 extract in `src` into the `fields` layout read later in the notebook.
# The third argument of each call is the destination channel and follows the order of the
# `variables` list (0 = u10 ... 19 = tcwv). The optional last argument (src_idx) selects
# an index along axis 1 of 4-D source variables; the calls below imply the level order
# 1000, 850, 500, 50 hPa in out1.nc -- NOTE(review): confirm against the file.
# Only channels 0-2 are converted at the moment; channels 3-19 are never written here
# because their calls are commented out.
dest = '/home/jovyan/processed.h5'

src = '/home/jovyan/out1.nc'
#u10 v10 t2m
writetofile(src, dest, 0, ['u10'])
writetofile(src, dest, 1, ['v10'])
writetofile(src, dest, 2, ['t2m'])

#sp mslp
# writetofile(src, dest, 3, ['sp'])
# writetofile(src, dest, 4, ['msl'])

#t850
# writetofile(src, dest, 5, ['t'], 1)

#uvz1000
# writetofile(src, dest, 6, ['u'], 0)
# writetofile(src, dest, 7, ['v'], 0)
# writetofile(src, dest, 8, ['z'], 0)

#uvz850
# writetofile(src, dest, 9, ['u'], 1)
# writetofile(src, dest, 10, ['v'], 1)
# writetofile(src, dest, 11, ['z'], 1)

#uvz 500
# writetofile(src, dest, 12, ['u'], 2)
# writetofile(src, dest, 13, ['v'], 2)
# writetofile(src, dest, 14, ['z'], 2)

# #t500
# writetofile(src, dest, 15, ['t'], 2)

# #z50
# writetofile(src, dest, 16, ['z'], 3)

# #r500
# writetofile(src, dest, 17, ['r'], 2)

# #r850
# writetofile(src, dest, 18, ['r'], 1)
#######
#tcwv
# writetofile(src, dest, 19, ['tcwv'])

#sst
# NOTE: channel 20 lies outside the 20-channel `fields` dataset that writetofile creates,
# so this call would need a larger dataset before it could be enabled.
#src = '/project/projectdirs/dasrepo/ERA5/oct_2021_19_31_sfc.nc'
#writetofile(src, dest, 20, ['sst'])




In [None]:
import h5py
dest = '/home/jovyan/processed.h5'
src = '/home/jovyan/out1.nc'

# Load the converted fields: every time step, the 20 model channels and the first 720
# latitude rows, so the grid matches img_shape_x x img_shape_y (720 x 1440).
# Plain slices select the same data as `range(0, 20)` without h5py fancy indexing,
# and the `with` block closes the file once the array is in memory.
with h5py.File(dest, 'r') as fields_file:
    datat = fields_file['fields'][:, 0:20, 0:720]

# autoregressive inference helper
def inference(data_slice, model, prediction_length, idx, compute_metrics=False):
    '''Run an autoregressive FourCastNet rollout starting from ``data_slice[0]``.

    Uses notebook globals defined in earlier cells: ``params``, ``device``,
    ``img_shape_x``/``img_shape_y`` and, when ``compute_metrics`` is True,
    ``weighted_rmse_channels``, ``weighted_acc_channels``, ``std`` and ``m``.

    Args:
        data_slice: standardized tensor indexed [time, channel, lat, lon] on
            ``device``; row 0 is the initial condition.
        model: network mapping a (1, C, H, W) state to the next state.
        prediction_length: number of states returned, including the initial
            condition (so ``prediction_length - 1`` model steps are taken).
        idx: channel index whose fields are kept in ``predictions``/``targets``.
        compute_metrics: if True, step ``i`` is compared against
            ``data_slice[i]`` for every ``i < data_slice.shape[0]``. This assumes
            consecutive rows of ``data_slice`` are exactly one model step apart --
            confirm for the input file. Steps with no ground-truth row are NaN.
            If False (default), ``acc``, ``rmse`` and ``targets`` are all zeros.

    Returns:
        Numpy arrays ``(acc, rmse, predictions, targets)``. The metrics have shape
        (prediction_length, N_out_channels) and the fields have shape
        (prediction_length, 1, img_shape_x, img_shape_y).
    '''
    n_out_channels = params['N_out_channels']
    # with metrics requested, steps without ground truth are NaN so they cannot be
    # mistaken for a perfect score; otherwise keep the all-zero placeholders
    fill = float('nan') if compute_metrics else 0.
    metric_shape = (prediction_length, n_out_channels)
    field_shape = (prediction_length, 1, img_shape_x, img_shape_y)
    acc = torch.full(metric_shape, fill, dtype=torch.float, device=device)
    rmse = torch.full(metric_shape, fill, dtype=torch.float, device=device)
    # to conserve GPU mem, only save one channel (can be changed if sufficient GPU mem or move to CPU)
    targets = torch.full(field_shape, fill, dtype=torch.float, device=device)
    predictions = torch.zeros(field_shape, dtype=torch.float, device=device)

    n_truth = data_slice.shape[0] if compute_metrics else 0

    with torch.no_grad():
        state = data_slice[0:1]  # initial condition, kept with its batch axis
        for i in range(prediction_length):
            if i > 0:
                state = model(state)  # autoregressive step
            predictions[i, 0] = state[0, idx]

            if i < n_truth:
                # compute metrics using the ground truth ERA5 data at the same step
                truth = data_slice[i:i + 1]
                targets[i, 0] = truth[0, idx]
                rmse[i] = (weighted_rmse_channels(state, truth) * std)[0]
                acc[i] = weighted_acc_channels(state - m, truth - m)[0]

    # copy to cpu for plotting/vis
    return (acc.cpu().numpy(), rmse.cpu().numpy(),
            predictions.cpu().numpy(), targets.cpu().numpy())

In [None]:
###### run inference
datat_standardized = (datat - means)/stds # standardize the data
datat_standardized = torch.as_tensor(datat_standardized).to(device, dtype=torch.float) # move to gpu for inference
field_idx = 5  # channel to forecast and plot: variables[5] == 't850'
acc_cpu, rmse_cpu, predictions_t, targets_t = inference(datat_standardized,
                                                        model, 20, idx=field_idx)

# Robinson projection centred on the Pacific (central_longitude=180)
central_longitude = 180
projection = ccrs.Robinson(central_longitude=central_longitude)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 5), subplot_kw={'projection': projection})
t = 4  # rollout index: 4 autoregressive model steps after the initial condition

# Extent of the image in degrees (lon_min, lon_max, lat_min, lat_max)
extent = (-180, 180, -90, 90)

# Undo the standardization so the colour bar is in the variable's physical units.
field = predictions_t[t, 0] * stds[field_idx, 0, 0] + means[field_idx, 0, 0]
# Shift by half the longitudes so the image spans -180..180 (presumably the data
# longitudes run 0..360 -- confirm against the source file).
field = np.roll(field, shift=field.shape[-1]//2, axis=-1)

ax.set_global()
im1 = ax.imshow(field,
                transform=ccrs.PlateCarree(central_longitude=0),
                cmap="bwr", extent=extent, origin='upper')
ax.coastlines()
ax.set_title("FourCastNet prediction - temperature - Oct 30 2024")

fig.colorbar(im1, ax=ax, orientation='horizontal', fraction=0.046, pad=0.08)

# one set of labelled gridlines, labels on the bottom and left edges only
gl = ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False)
gl.top_labels = False
gl.right_labels = False

fig.tight_layout()
plt.show()


In [None]:
###### run inference
datat_standardized = (datat - means)/stds # standardize the data
datat_standardized = torch.as_tensor(datat_standardized).to(device, dtype=torch.float) # move to gpu for inference
field_idx = 9  # channel to forecast and plot: variables[9] == 'u850'
acc_cpu, rmse_cpu, predictions, targets = inference(datat_standardized,
                                                    model, 20, idx=field_idx)

# Robinson projection centred on the Pacific (central_longitude=180)
central_longitude = 180
projection = ccrs.Robinson(central_longitude=central_longitude)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 5), subplot_kw={'projection': projection})
t = 4  # rollout index: 4 autoregressive model steps after the initial condition

# Extent of the image in degrees (lon_min, lon_max, lat_min, lat_max)
extent = (-180, 180, -90, 90)

# Undo the standardization so the colour bar is in the variable's physical units.
field = predictions[t, 0] * stds[field_idx, 0, 0] + means[field_idx, 0, 0]
# Shift by half the longitudes so the image spans -180..180 (presumably the data
# longitudes run 0..360 -- confirm against the source file).
field = np.roll(field, shift=field.shape[-1]//2, axis=-1)

ax.set_global()
im1 = ax.imshow(field,
                transform=ccrs.PlateCarree(central_longitude=0),
                cmap="jet", extent=extent, origin='upper')
ax.coastlines()
ax.set_title("FourCastNet prediction - u850 - Oct 30 2024")

fig.colorbar(im1, ax=ax, orientation='horizontal', fraction=0.046, pad=0.08)

# one set of labelled gridlines, labels on the bottom and left edges only
gl = ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False)
gl.top_labels = False
gl.right_labels = False

fig.tight_layout()
plt.show()
