In [None]:
import os, sys, time
import numpy as np
import h5py
import torch
import torch.distributed as dist

import matplotlib.pyplot as plt
sys.path.insert(1, './FourCastNet/') # insert code repo into path

# If the imports below fail, install the dependencies first:
# %pip install ruamel.yaml einops timm
# (or install them with conda)

from utils.YParams import YParams
from networks.afnonet import AFNONet

from constants import VARIABLES
from proj_utils import load_model, inference, lat, latitude_weighting_factor, weighted_rmse_channels

PLOT_INPUTS = False  # set True to plot a sample of the raw input fields
# NOTE(review): COMPILE is not referenced anywhere else in this notebook, so
# torch.compile() is currently never applied — confirm whether it should be wired in.
COMPILE = True  # intended toggle for torch.compile()

# One-time data setup: run these inside your scratch directory, then point
# FCN_BASE_PATH (or the default below) at the directory containing `ccai_demo/`.
#   wget https://portal.nersc.gov/project/m4134/ccai_demo.tar
#   tar -xvf ccai_demo.tar
#   rm ccai_demo.tar

# The default is the original author's scratch directory; override it with the
# FCN_BASE_PATH environment variable instead of editing the notebook.
# os.path.join makes the trailing slash optional.
base_path = os.environ.get("FCN_BASE_PATH", "/pscratch/sd/m/mpowell/hpml/")
demo_root = os.path.join(base_path, "ccai_demo")
stats_dir = os.path.join(demo_root, "additional", "stats_v0")

# data and model paths
data_path = os.path.join(demo_root, "data", "FCN_ERA5_data_v0", "out_of_sample")
data_file = os.path.join(data_path, "2018.h5")
model_path = os.path.join(demo_root, "model_weights", "FCN_weights_v0", "backbone.ckpt")
global_means_path = os.path.join(stats_dir, "global_means.npy")
global_stds_path = os.path.join(stats_dir, "global_stds.npy")
time_means_path = os.path.join(stats_dir, "time_means.npy")
land_sea_mask_path = os.path.join(stats_dir, "land_sea_mask.npy")

# default FourCastNet AFNO backbone configuration
config_file = "./FourCastNet/config/AFNO.yaml"
config_name = "afno_backbone"
params = YParams(config_file, config_name)
print("Model architecture used = {}".format(params["nettype"]))

if PLOT_INPUTS:
    # Open the HDF5 file in a context manager so the handle is closed after
    # plotting; dataset slicing reads only the requested arrays from disk.
    with h5py.File(data_file, 'r') as h5f:
        sample_data = h5f['fields']
        print('Total data shape:', sample_data.shape)
        timestep_idx = 0
        fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 10))
        for i, varname in enumerate(['u10', 't2m', 'z500', 'tcwv']):
            cm = 'bwr' if varname in ('u10', 'z500') else 'viridis'
            varidx = VARIABLES.index(varname)
            panel = ax[i // 2][i % 2]  # 2x2 grid, filled row by row
            panel.imshow(sample_data[timestep_idx, varidx], cmap=cm)
            panel.set_title(varname)
        fig.tight_layout()

# import model
# Pick the compute device for this process. torch.cuda.current_device() is
# GPU 0 until torch.cuda.set_device() is called, and setup_distributed() (which
# calls it) only runs in a later cell, so under a multi-process launch every
# rank would place the model on GPU 0. Use the launcher-provided LOCAL_RANK,
# the same variable setup_distributed() binds the process to.
if torch.cuda.is_available():
    device = torch.device('cuda', int(os.environ.get('LOCAL_RANK', 0)))
else:
    device = 'cpu'

# in and out channels: FourCastNet uses 20 input channels corresponding to 20 prognostic variables
in_channels = np.array(params.in_channels)
out_channels = np.array(params.out_channels)
params['N_in_channels'] = len(in_channels)
params['N_out_channels'] = len(out_channels)
# Precomputed training statistics restricted to the output channels; used to
# normalize the data and to build the climatology for the ACC metric.
params.means = np.load(global_means_path)[0, out_channels]
params.stds = np.load(global_stds_path)[0, out_channels]
params.time_means = np.load(time_means_path)[0, out_channels]

# Build the network; only the AFNO backbone is supported in this notebook.
if params.nettype == 'afno':
    model = AFNONet(params).to(device)  # AFNO model
else:
    raise NotImplementedError(f"nettype {params.nettype!r} is not implemented")

# load saved model weights (load_model comes from proj_utils), then make sure
# the model lives on the selected device
model = load_model(model, params, model_path)
model = model.to(device)

In [None]:
# Build the normalization / climatology tensors on the device; they are
# needed to compute the ACC and RMSE metrics.
img_shape_x = 720   # rows kept (0:720) along the latitude axis
img_shape_y = 1440  # columns along the longitude axis

# means and stds over training data
means = params.means
stds = params.stds

# Climatology: per-pixel temporal mean, normalized with the training stats
# and cropped to the same rows as the data.
time_means = params.time_means
m = torch.as_tensor((time_means - means) / stds)[:, 0:img_shape_x]
m = torch.unsqueeze(m, 0).to(device, dtype=torch.float)
std = torch.as_tensor(stds[:, 0, 0]).to(device, dtype=torch.float)

print("Shape of time means = {}".format(m.shape))
print("Shape of std = {}".format(std.shape))

# setup data for inference
dt = 1  # time step (x 6 hours)
ic = 0  # start the inference from here
prediction_length = 20  # number of steps (x 6 hours)

# which field to track for visualization (also used for the logged metrics)
field = 'u10'
idx_vis = VARIABLES.index(field)

# Read only the needed slice: time steps ic, ic+dt, ...; the input channels;
# the first img_shape_x rows. The context manager closes the file afterwards.
print('Loading inference data')
print('Inference data from {}'.format(data_file))
with h5py.File(data_file, 'r') as h5f:
    data = h5f['fields'][ic:(ic + prediction_length * dt):dt, in_channels, 0:img_shape_x]
print("Shape of data = {}".format(data.shape))

data = (data - means) / stds  # standardize the data

In [None]:
np.shape(data)  # display the shape of the standardized inference sequence

In [None]:
# Build the ensemble: ensemble_size copies of the standardized sequence,
# shaped (member, time, channel, lat, lon).
ensemble_size = 2
epsilon = 1e-3  # perturbation magnitude (in standardized units)
# A fixed seed makes the perturbations reproducible. Each rank builds the full
# ensemble itself, so every rank gets identical members regardless of world_size.
perturbation_seed = 0

# Convert once to float32 on the device and replicate there. This leaves
# `data` unchanged, so the cell can be re-run, and avoids building the
# replicated ensemble on the host.
ensemble_init = (
    torch.as_tensor(data, dtype=torch.float, device=device)
    .unsqueeze(0)
    .repeat(ensemble_size, 1, 1, 1, 1)
)

# Perturb only the initial condition (time index 0). The later time steps
# serve as verification targets in the inference loop and must stay clean.
noise_gen = torch.Generator(device=ensemble_init.device).manual_seed(perturbation_seed)
ensemble_init[:, 0] += epsilon * torch.randn(
    ensemble_init[:, 0].shape,
    generator=noise_gen,
    dtype=ensemble_init.dtype,
    device=ensemble_init.device,
)

# prediction_length is intentionally NOT reassigned here: it comes from the
# data-loading cell and must match the number of loaded time steps.

In [None]:
def _latitude_weights(num_lat: int, device) -> torch.Tensor:
    """Return per-row latitude weights reshaped to (1, 1, num_lat, 1) for broadcasting.

    Uses the `lat` and `latitude_weighting_factor` helpers imported from
    proj_utils. The normalizer `s` is the sum of cos(lat) over all rows.
    """
    lat_t = torch.arange(start=0, end=num_lat, device=device)
    # 3.1416/180 approximates the degree->radian factor. It is kept exactly as
    # before (presumably matching the constant inside latitude_weighting_factor
    # — TODO confirm).
    s = torch.sum(torch.cos(3.1416 / 180. * lat(lat_t, num_lat)))
    return torch.reshape(latitude_weighting_factor(lat_t, num_lat, s), (1, 1, -1, 1))


def weighted_rmse_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Latitude-weighted RMSE for each channel.

    Args:
        pred, target: tensors of shape [n, c, h, w]; dim 2 is treated as latitude.

    Returns:
        Tensor of shape [n, c] (mean taken over the last two dims).

    NOTE(review): this definition shadows the weighted_rmse_channels imported
    from proj_utils at the top of the notebook.
    """
    weight = _latitude_weights(pred.shape[2], pred.device)
    return torch.sqrt(torch.mean(weight * (pred - target) ** 2., dim=(-1, -2)))


def weighted_acc_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Latitude-weighted anomaly correlation coefficient for each channel.

    Args:
        pred, target: tensors of shape [n, c, h, w] (callers pass anomalies,
            i.e. values with the climatology already subtracted); dim 2 is latitude.

    Returns:
        Tensor of shape [n, c].
    """
    weight = _latitude_weights(pred.shape[2], pred.device)
    numerator = torch.sum(weight * pred * target, dim=(-1, -2))
    denominator = torch.sqrt(
        torch.sum(weight * pred * pred, dim=(-1, -2))
        * torch.sum(weight * target * target, dim=(-1, -2))
    )
    return numerator / denominator

In [None]:
def setup_distributed():
    """Initialize torch.distributed and return ``(rank, world_size)``.

    When launched by torchrun, RANK / WORLD_SIZE / LOCAL_RANK / MASTER_ADDR /
    MASTER_PORT are already set and are respected. In a plain notebook kernel
    (a single process) the defaults describe a 1-process group, so init does
    not block waiting for peers that never join. If a process group already
    exists (e.g. the cell is re-run), it is reused.
    """
    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('WORLD_SIZE', '1')
    os.environ.setdefault('LOCAL_RANK', os.environ['RANK'])
    # Do not clobber a rendezvous address provided by the launcher.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Bind this process to its GPU before the NCCL group is created.
        torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

    if not dist.is_initialized():
        # NCCL requires CUDA; fall back to gloo on CPU-only machines.
        dist.init_process_group(backend='nccl' if use_cuda else 'gloo', init_method='env://')

    return dist.get_rank(), dist.get_world_size()

def cleanup():
    """Destroy the default process group if one exists (no-op otherwise)."""
    if dist.is_initialized():
        dist.destroy_process_group()

In [None]:
# Join (or create) the process group; later cells rely on rank and world_size.
(rank, world_size) = setup_distributed()

In [None]:
# Wrap the model with DistributedDataParallel so every rank runs a replica.
# NOTE(review): this notebook only runs inference (no backward pass), so DDP's
# gradient synchronization is never used; a plain per-rank model may suffice.
model = torch.nn.parallel.DistributedDataParallel(module=model)

In [None]:
# Give each rank a contiguous chunk of ceil(ensemble_size / world_size)
# members; trailing ranks get fewer (possibly zero) when it does not divide evenly.
local_ensemble_size = -(-ensemble_size // world_size)
start_idx = rank * local_ensemble_size
end_idx = min(ensemble_size, start_idx + local_ensemble_size)
local_ensemble_range = range(start_idx, end_idx)

In [None]:
# Run the autoregressive forecast for every ensemble member owned by this rank.
local_results = []
# CUDA kernels launch asynchronously, so wall-clock step timings are only
# meaningful after an explicit synchronize.
sync_cuda = torch.cuda.is_available()
for i in local_ensemble_range:
    data_slice = ensemble_init[i]  # one member: (time, channel, lat, lon)
    idx = idx_vis  # channel whose predictions/targets are stored and logged

    # Untimed warm-up forward pass on random input of the model's input shape,
    # presumably to absorb one-time setup cost before the timed loop.
    with torch.no_grad():
        dummy_input = torch.randn(1, data_slice.shape[1], img_shape_x, img_shape_y).to(device)
        _ = model(dummy_input)

    # Create memory for the different stats (per step, per output channel)
    n_out_channels = params['N_out_channels']
    acc = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)
    rmse = torch.zeros((prediction_length, n_out_channels)).to(device, dtype=torch.float)

    # To conserve GPU mem, only save one channel (can be changed if sufficient GPU mem or move to CPU)
    targets = torch.zeros((prediction_length, 1, img_shape_x, img_shape_y)).to(device, dtype=torch.float)
    predictions = torch.zeros((prediction_length, 1, img_shape_x, img_shape_y)).to(device, dtype=torch.float)

    total_time = 0
    with torch.no_grad():
        for j in range(data_slice.shape[0]):
            iter_start = time.perf_counter()
            if j == 0:
                # Step 0 scores the initial condition against itself.
                first = data_slice[0:1]
                future = data_slice[1:2]
                pred = first
                tar = first
                # Predict
                future_pred = model(first)
            else:
                if j < prediction_length - 1:
                    future = data_slice[j+1:j+2]
                future_pred = model(future_pred)  # Autoregressive step

            if j < prediction_length - 1:
                predictions[j+1, 0] = future_pred[0, idx]
                targets[j+1, 0] = future[0, idx]
            # Step-j metrics compare the previous step's prediction with its
            # ground truth; pred/tar are advanced at the end of the iteration.
            rmse[j] = weighted_rmse_channels(pred, tar) * std
            acc[j] = weighted_acc_channels(pred-m, tar-m)
            if sync_cuda:
                torch.cuda.synchronize(device)
            iter_time = time.perf_counter() - iter_start

            if rank == 0:  # Only main process logs
                print(f'Ensemble {i}, Predicted timestep {j} of {prediction_length}. {field} RMS Error: {rmse[j, idx].item()}, ACC: {acc[j, idx].item()}')

            pred = future_pred
            tar = future
            total_time += iter_time

    avg_time = total_time / prediction_length
    if rank == 0:
        print(f'Ensemble {i}, Total inference time: {total_time:.2f}s, Average time per step: {avg_time:.2f}s')

    # Save local results on the CPU. 'predictions' and 'targets' are read by
    # the rank-0 aggregation below, so they must be included here.
    local_results.append({
        'ensemble_idx': i,
        'acc': acc.cpu(),
        'rmse': rmse.cpu(),
        'predictions': predictions.cpu(),
        'targets': targets.cpu(),
        'total_time': total_time,
        'avg_time': avg_time,
    })

# Gather results from all processes (each rank contributes a list)
all_results = [None for _ in range(world_size)]
dist.all_gather_object(all_results, local_results)

# Only the main process handles the combined results
if rank == 0:
    combined_results = [res for process_results in all_results for res in process_results]

    # Sort by ensemble index
    combined_results.sort(key=lambda x: x['ensemble_idx'])

    # Log combined metrics (final forecast step of the tracked field)
    for result in combined_results:
        i = result['ensemble_idx']
        acc = result['acc']
        rmse = result['rmse']
        predictions = result['predictions']
        targets = result['targets']
        total_time = result['total_time']
        avg_time = result['avg_time']
        print(f'Ensemble {i}: final-step {field} RMS Error: {rmse[-1, idx_vis].item()}, '
              f'ACC: {acc[-1, idx_vis].item()}, average time per step: {avg_time:.2f}s')

# Clean up
cleanup()