In [1]:
import os, sys, time
import numpy as np
import h5py
import torch
import torchvision
import torch.nn as nn
import torch.quantization
import matplotlib.pyplot as plt
import wandb
sys.path.insert(1, './FourCastNet/') # insert code repo into path

"""
*******************************
Usage:
Run this script by passing arguments to the command line as follows:

python base_script.py --torch.compile --quantization --num-gpus XX --prediction-length XX -- ensemble_size XX --variable


--torch.compile: Boolean, choose if running in compile mode for speedup
--quantization: Boolean, choose if you'd like to run with quantized linear layer weights
--num-gpus: Int, number of GPUs to use for distributed inference
--prediction-length: Int, number of timesteps for autoregressive loop
--variable: String, variable name you'd like to calculate. Options are:
    variables = ['u10' (10 metre zonal wind speed m s-1),
             'v10' (10 metre meridional wind speed m s-1),
             't2m' (2 metre temperature K),
             'sp' (Surface pressure Pa),
             'msl' (Mean sea level pressure Pa),
             't850' (temperature at the 850 hPa pressure level K),
             'u1000' (zonal wind at 1000 mbar pressure surface m s-1),
             'v1000' (meridional wind at 1000 mbar pressure surface m s-1),
             'z1000' (vertical wind at 1000 mbar pressure surface m s-1),
             'u850' (zonal wind at 850 mbar pressure surface m s-1),
             'v850' (meridional wind at 850 mbar pressure surface m s-1),
             'z850' (vertical wind at 850 mbar pressure surface m s-1),
             'u500' (zonal wind at 500 mbar pressure surface m s-1),
             'v500' (meridional wind at 500 mbar pressure surface m s-1),
             'z500' (vertical wind at 500 mbar pressure surface m s-1),
             't500' (temperature wind at 500 mbar pressure surface K),
             'z50'  (geopotential height at 50 hPa),
             'r500' (relative humidity at 500 mbar pressure surface),
             'r850' (relative humidity at 850 mbar pressure surface),
             'tcwv' (total column water vapor kg m-2)]

*******************************
"""


# You may need to install the model's dependencies first:
# %pip install ruamel.yaml einops timm
# (or install them with conda)

from utils.YParams import YParams
from networks.afnonet import AFNONet

from constants import VARIABLES
from proj_utils import load_model, inference
from quantize import replace_linear_with_target_and_quantize, W8A16LinearLayer, model_size

# Notebook-wide switches
QDTYPE = torch.int8   # integer dtype handed to the W8A16 linear-layer quantizer
PLOT_INPUTS = False   # set True to draw a 2x2 preview of sample input fields




In [2]:
# Run configuration -- edit these values to change what the notebook computes.
COMPILE = False          # wrap the model with torch.compile (inductor backend)
QUANTIZE = False         # replace linear layers with W8A16 quantized layers
distributed = True       # NOTE(review): not read by any cell shown here
prediction_length = 20   # number of autoregressive timesteps to roll out
ensemble_size = 20       # number of perturbed ensemble members
field = 'u10'            # variable whose metrics are printed (a name from VARIABLES)

# Root of the data / weights tree. Set the FCN_BASE_PATH environment variable to
# point elsewhere instead of editing this user-specific default.
base_path = os.environ.get("FCN_BASE_PATH", "/pscratch/sd/m/mpowell/hpml/")
demo_root = os.path.join(base_path, "ccai_demo")
stats_dir = os.path.join(demo_root, "additional", "stats_v0")

# data and model paths (os.path.join tolerates a base path with or without a trailing slash)
data_path = os.path.join(demo_root, "data", "FCN_ERA5_data_v0", "out_of_sample")
data_file = os.path.join(data_path, "2018.h5")
model_path = os.path.join(demo_root, "model_weights", "FCN_weights_v0", "backbone.ckpt")
global_means_path = os.path.join(stats_dir, "global_means.npy")
global_stds_path = os.path.join(stats_dir, "global_stds.npy")
time_means_path = os.path.join(stats_dir, "time_means.npy")
land_sea_mask_path = os.path.join(stats_dir, "land_sea_mask.npy")


In [3]:
# Load the AFNO config, attach the normalization statistics, and build the model.
config_file = "./FourCastNet/config/AFNO.yaml"  # default config shipped with FourCastNet
config_name = "afno_backbone"
params = YParams(config_file, config_name)
print("Model architecture used = {}".format(params["nettype"]))

if PLOT_INPUTS:
    # Preview a few raw input fields; the `with` block closes the HDF5 file afterwards.
    with h5py.File(data_file, 'r') as sample_file:
        sample_data = sample_file['fields']
        print('Total data shape:', sample_data.shape)
        timestep_idx = 0
        fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15, 10))
        for i, varname in enumerate(['u10', 't2m', 'z500', 'tcwv']):
            cm = 'bwr' if varname in ('u10', 'z500') else 'viridis'
            varidx = VARIABLES.index(varname)
            panel = ax[i // 2][i % 2]
            panel.imshow(sample_data[timestep_idx, varidx], cmap=cm)
            panel.set_title(varname)
        fig.tight_layout()

# Model device: the current CUDA device index when CUDA is available, otherwise CPU.
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

# in and out channels: FourCastNet uses 20 input channels corresponding to 20 prognostic variables
in_channels = np.array(params.in_channels)
out_channels = np.array(params.out_channels)
params['N_in_channels'] = len(in_channels)
params['N_out_channels'] = len(out_channels)
# Precomputed training statistics, restricted to the output channels
# (used to normalize the data).
params.means = np.load(global_means_path)[0, out_channels]
params.stds = np.load(global_stds_path)[0, out_channels]
params.time_means = np.load(time_means_path)[0, out_channels]

# Only the AFNO backbone is supported by this notebook.
if params.nettype != 'afno':
    raise NotImplementedError(f"nettype {params.nettype!r} is not implemented")
model = AFNONet(params).to(device)  # AFNO model
# load saved model weights
model = load_model(model, params, model_path)
model = model.to(device)

Model architecture used = afno
Load time: 1.3829610908869654 seconds


In [4]:
_MB = 1024 ** 2


def _report_model_size(module, label):
    """Print the parameter + buffer footprint of `module` and return the total.

    Sizes come from quantize.model_size; they are presumably bytes, since they are
    reported as MB after dividing by 1024**2.
    """
    param_size, buffer_size = model_size(module)
    total = param_size + buffer_size
    print(f"{label} model size: {(total) / _MB:.2f} MB, {param_size / _MB:.2f} MB (parameters), {buffer_size / _MB:.2f} MB (buffers)")
    return total


if QUANTIZE:
    init_size = _report_model_size(model, "Initial")
    print(QDTYPE)
    # Replace linear layers with W8A16LinearLayer (implementation in quantize.py).
    replace_linear_with_target_and_quantize(model, W8A16LinearLayer, QDTYPE)
    final_size = _report_model_size(model, "Final")
    # wandb.log() requires an active run; only log when wandb.init() has been called.
    if wandb.run is not None:
        wandb.log({"model_size_reduction": final_size / init_size})

if COMPILE:
    model = torch.compile(model, backend='inductor')


In [5]:
# Build the normalized climatology and per-channel std tensors used by the ACC and
# RMSE metrics, then load the ERA5 trajectory that the rollout is scored against.
img_shape_x = 720   # rows kept along the latitude axis (arrays are sliced to 0:img_shape_x)
img_shape_y = 1440  # columns along the longitude axis

# means and stds over training data
means = params.means
stds = params.stds

# Climatology (temporal mean for every pixel), normalized with the training stats.
time_means = params.time_means
m = torch.as_tensor((time_means - means) / stds)[:, 0:img_shape_x]
m = torch.unsqueeze(m, 0).to(device, dtype=torch.float)
# Per-channel std, used to turn RMSE on normalized data back into physical units.
std = torch.as_tensor(stds[:, 0, 0]).to(device, dtype=torch.float)

print("Shape of time means = {}".format(m.shape))
print("Shape of std = {}".format(std.shape))

# setup data for inference
dt = 1  # stride between loaded timesteps (x 6 hours)
ic = 0  # index of the initial condition within the file

idx_vis = VARIABLES.index(field)  # channel whose metrics are printed during inference
# NOTE(review): this indexes VARIABLES (every channel in the file); it lines up with the
# model's channel order only if in_channels selects all variables in order -- confirm.

print('Loading inference data')
print('Inference data from {}'.format(data_file))
with h5py.File(data_file, 'r') as data_h5:
    # Read prediction_length timesteps into memory, then let the file close.
    data = data_h5['fields'][ic:(ic + prediction_length * dt):dt, in_channels, 0:img_shape_x]
print("Shape of data = {}".format(data.shape))

# Announce variable name:
print('Running inference for variable {}'.format(field))

Shape of time means = torch.Size([1, 20, 720, 1440])
Shape of std = torch.Size([20])
Loading inference data
Inference data from /pscratch/sd/m/mpowell/hpml/ccai_demo/data/FCN_ERA5_data_v0/out_of_sample/2018.h5
(20, 20, 720, 1440)
Shape of data = (20, 20, 720, 1440)
Running inference for variable 


In [6]:
from accelerate import Accelerator
from accelerate.utils import gather_object
from tqdm import tqdm
import wandb
import torch
import time

from proj_utils import lat, latitude_weighting_factor, weighted_rmse_channels

# A modified version of the `inference` function from proj_utils.py,
# adapted for distributed inference across GPUs

In [7]:
# NOTE: these redefine weighted_rmse_channels (imported from proj_utils) and add an
# ACC counterpart; the definitions below are the ones used by the rollout loop.


def _latitude_weights(num_lat: int, device) -> torch.Tensor:
    """Return per-row latitude weights shaped (1, 1, num_lat, 1).

    The shape broadcasts over [n, c, h, w] tensors whose h axis is latitude. Built
    from the proj_utils helpers `lat` and `latitude_weighting_factor`, normalized by
    the sum of cos(latitude) over all rows.
    """
    lat_t = torch.arange(start=0, end=num_lat, device=device)
    # 3.1416 is kept (rather than math.pi) so metric values match previous runs exactly.
    s = torch.sum(torch.cos(3.1416 / 180. * lat(lat_t, num_lat)))
    return torch.reshape(latitude_weighting_factor(lat_t, num_lat, s), (1, 1, -1, 1))


def weighted_rmse_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Latitude-weighted RMSE for each (sample, channel).

    Args:
        pred, target: tensors of shape [n, c, h, w]; h is the latitude axis.

    Returns:
        Tensor of shape [n, c].
    """
    weight = _latitude_weights(pred.shape[2], pred.device)
    return torch.sqrt(torch.mean(weight * (pred - target) ** 2., dim=(-1, -2)))


def weighted_acc_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Latitude-weighted anomaly correlation coefficient for each (sample, channel).

    The rollout in this notebook passes anomalies (field minus climatology).

    Args:
        pred, target: tensors of shape [n, c, h, w]; h is the latitude axis.

    Returns:
        Tensor of shape [n, c]. The result is NaN where either input is identically zero.
    """
    weight = _latitude_weights(pred.shape[2], pred.device)
    dims = (-1, -2)
    numerator = torch.sum(weight * pred * target, dim=dims)
    denominator = torch.sqrt(torch.sum(weight * pred * pred, dim=dims) * torch.sum(weight * target * target, dim=dims))
    return numerator / denominator

In [8]:
# Build the ensemble: one copy of the trajectory per member, with each member's
# initial condition (timestep 0) scaled by a slightly different factor.
ENSEMBLE_SEED = 0   # fixed seed so the ensemble perturbations are reproducible
epsilon = 1e-8      # relative perturbation size
# NOTE(review): for float32 data, a relative change of epsilon * u <= 1e-7 is at or below
# float32 resolution (~1.2e-7), so many members may end up bit-identical to the input.
# Consider a larger epsilon or additive noise.

hold = data[np.newaxis, :, :, :]
# NOTE(review): this materializes ensemble_size full copies of `data` in host memory.
ensemble_init = np.tile(hold, (ensemble_size, 1, 1, 1, 1))

rng = np.random.default_rng(ENSEMBLE_SEED)
random_values = rng.uniform(0, 10, ensemble_size)

for i in range(ensemble_size):
    # Scale by (1 + eps*u) so the field is only nudged; scaling by eps*u alone would
    # shrink it to ~0. Only timestep 0 is perturbed because later timesteps serve as
    # the ground-truth targets the rollout is scored against.
    ensemble_init[i, 0] *= 1 + epsilon * random_values[i]

In [9]:
# Sanity check; expected (ensemble_size, timesteps, channels, img_shape_x, img_shape_y)
ensemble_init.shape

(20, 20, 20, 720, 1440)

In [10]:
idx = idx_vis  # channel whose per-step metrics are printed during the rollout

# Let Accelerate choose this process's device and wrap the model for it.
accelerator = Accelerator()
device = accelerator.device
model = accelerator.prepare(model)

# `device` was just redefined and may differ from the one used to build the metric
# tensors earlier, so move them too (no-op when they already live there).
m = m.to(device)
std = std.to(device)

print(accelerator.num_processes)

1


In [12]:
# Number of input channels per timestep. This reads ensemble_init rather than data_slice,
# which is only defined once the rollout cell below has run (fails on a fresh kernel).
ensemble_init.shape[2]

20

In [13]:

# Split the ensemble across processes: each rank takes a contiguous chunk of
# ceil(ensemble_size / num_processes) members; the last chunk may be short or empty.
local_ensemble_size = (ensemble_size + accelerator.num_processes - 1) // accelerator.num_processes
start_idx = accelerator.process_index * local_ensemble_size
end_idx = min(start_idx + local_ensemble_size, ensemble_size)

n_out_channels = params['N_out_channels']
# Climatology and per-channel std must sit on this process's device (no-op if they already do).
clim_mean = m.to(device)
channel_std = std.to(device)


def _sync_device():
    """Block until queued CUDA work has finished so perf_counter timings include it."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


ens_idx_results = []
for ens in range(start_idx, end_idx):
    # Keep this member's trajectory in host memory and copy one timestep at a time to
    # the device, instead of moving the whole (T, C, H, W) array up front, to keep
    # peak GPU memory low.
    data_slice = torch.as_tensor(ensemble_init[ens], dtype=torch.float)
    n_steps = data_slice.shape[0]
    if accelerator.is_main_process:
        print('Data slice shape:')
        print(data_slice.shape)

    # create memory for the different stats (per step, per channel)
    acc = torch.zeros((prediction_length, n_out_channels), device=device, dtype=torch.float)
    rmse = torch.zeros((prediction_length, n_out_channels), device=device, dtype=torch.float)

    # to conserve GPU mem, only save one channel (can be changed if sufficient GPU mem or move to CPU)
    targets = torch.zeros((prediction_length, 1, img_shape_x, img_shape_y), device=device, dtype=torch.float)
    predictions = torch.zeros((prediction_length, 1, img_shape_x, img_shape_y), device=device, dtype=torch.float)

    total_time = 0
    with torch.no_grad():
        for i in range(n_steps):
            # Stage this step's inputs on the device before starting the clock so
            # host-to-device copies are not counted as inference time.
            if i == 0:
                first = data_slice[0:1].to(device)
            if i == 0 or i < n_steps - 1:
                future = data_slice[i + 1:i + 2].to(device)  # ground truth for step i+1
            _sync_device()
            iter_start = time.perf_counter()

            if i == 0:
                pred = first
                tar = first
                # also save out predictions for visualizing channel index idx
                targets[0, 0] = first[0, idx]
                predictions[0, 0] = first[0, idx]
                future_pred = model(first)
            else:
                future_pred = model(future_pred)  # autoregressive step

            if i < n_steps - 1:
                predictions[i + 1, 0] = future_pred[0, idx]
                targets[i + 1, 0] = future[0, idx]

            # compute metrics using the ground truth ERA5 data as "true" predictions
            rmse[i] = weighted_rmse_channels(pred, tar) * channel_std
            acc[i] = weighted_acc_channels(pred - clim_mean, tar - clim_mean)
            _sync_device()
            iter_time = time.perf_counter() - iter_start

            if accelerator.is_main_process:  # print from one process only
                print('Predicted timestep {} of {}. {} RMS Error: {}, ACC: {}'.format(i, prediction_length, field, rmse[i, idx], acc[i, idx]))

            pred = future_pred
            tar = future
            total_time += iter_time

    avg_time = total_time / max(n_steps, 1)
    if accelerator.is_main_process:
        print(f'Total inference time: {total_time:.2f}s, Average time per step: {avg_time:.2f}s')

    # copy metrics to cpu as numpy arrays for plotting and visualization
    ens_idx_results.append({
        "acc": acc.cpu().numpy(),
        "rmse": rmse.cpu().numpy(),
        "total_inference_time": total_time,
        "avg_time": avg_time,
        "ensemble_idx": ens,
    })

# Gather results across processes exactly once, after the loop: gather_object is a
# collective, so every rank must call it the same number of times. Calling it once
# per member would hang when ranks own different member counts (e.g. 7/7/6).
all_results = gather_object(ens_idx_results)


Data slice shape:
torch.Size([20, 20, 720, 1440])


OutOfMemoryError: CUDA out of memory. Tried to allocate 48.00 MiB. GPU 0 has a total capacity of 39.38 GiB of which 45.81 MiB is free. Process 1227695 has 35.33 GiB memory in use. Including non-PyTorch memory, this process has 3.98 GiB memory in use. Of the allocated memory 3.35 GiB is allocated by PyTorch, and 133.43 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)