In [2]:
import os
import re
import time
import warnings
from pathlib import Path

import numpy as np
import torch
import xarray as xr

from spherical_unet.models.spherical_unet.unet_model import SphericalUNet
import time

In [3]:
# Data locations. The directory defaults to the original author's local path;
# set the AIBEDO_DATA_DIR environment variable to point elsewhere instead of
# editing this cell. Both paths are kept as plain strings.
DATA_DIR = Path(
    os.environ.get("AIBEDO_DATA_DIR", "/Users/kramea/Documents/AIBEDO_dir/data_aibedo")
).expanduser()
input_file = str(DATA_DIR / 'compress.isosph.CESM2.historical.r1i1p1f1.Input.Exp8_fixed.nc')
output_file = str(DATA_DIR / 'compress.isosph.CESM2.historical.r1i1p1f1.Output.nc')

In [5]:
# Open the input and target datasets. They stay open because later cells read
# variables from them.
inDS, outDS = (xr.open_dataset(path) for path in (input_file, output_file))
# Number of grid cells: the length of the input file's `ncells` axis.
n_pixels = len(inDS["ncells"])

In [6]:
# Dataset variables fed to / predicted by the network. List order becomes the
# channel order of the stacked arrays built from them, so it presumably must
# match the order used when the checkpoint was trained — do not reorder casually.
in_vars = [
    'crelSurf_pre',
    'crel_pre',
    'cresSurf_pre',
    'cres_pre',
    'netTOAcs_pre',
    'lsMask',
    'netSurfcs_pre',
]
out_vars = [
    'tas_pre',
    'psl_pre',
    'pr_pre',
]

In [7]:
# Load the trained checkpoint (a state dict: parameter name -> tensor) onto the CPU.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
weights_file = torch.load('/Users/kramea/Documents/AIBEDO_dir/data_aibedo/sunet_state_6.pt',
                          map_location=torch.device('cpu'))
# Keys carry a leading "module." prefix, presumably because the checkpoint was saved
# from a DataParallel-wrapped model. Strip only that leading prefix: str.replace
# would also mangle names that merely contain "module." (e.g. "x.submodule.w").
weights_file = {
    (key[len("module."):] if key.startswith("module.") else key): value
    for key, value in weights_file.items()
}

In [8]:
# Build the spherical U-Net on the same icosahedral grid as the data (n_pixels cells).
# NOTE(review): the remaining positional arguments (6, 'combinatorial', 3, 7, 3) are
# unnamed here. 7 and 3 equal len(in_vars) and len(out_vars), so they are presumably
# the input/output channel counts — confirm against the SphericalUNet signature.
# They presumably must match the configuration the loaded checkpoint was trained with.
unet = SphericalUNet('icosahedron', n_pixels, 6, 'combinatorial', 3, 7, 3)

In [9]:
# Climate-model name (e.g. "CESM2") parsed from the output filename; it is used to
# name the saved prediction/ground-truth files.
modelfilename = Path(output_file).stem
# Dots are escaped so they only match the literal '.' separators in the filename.
p = re.compile(r'compress\.isosph\.(.*)\.historical\.r1i1p1f1\.Output')
model_match = p.search(modelfilename)
if model_match is None:
    raise ValueError(
        f"could not extract a model name from {modelfilename!r} "
        f"(expected pattern {p.pattern!r})"
    )
modelname = model_match.group(1)

In [None]:
# Model input array: each variable in in_vars is reshaped to (n_samples, n_pixels, 1)
# and the variables are joined on the last axis, one channel per variable in list order.
# np.concatenate(..., axis=0) over the variable's data joins its sub-arrays along the
# leading axis (for a 2-D (time, ncells) field that is a flatten). It needs every
# variable to be at least 2-D — presumably all, including lsMask, carry a time axis.
data_all = [
    np.reshape(np.concatenate(inDS[name].data, axis=0), [-1, n_pixels, 1])
    for name in in_vars
]
dataset_in = np.concatenate(data_all, axis=2)

In [None]:
# Target array: same layout as the input array, one channel per variable in out_vars
# (shape (n_samples, n_pixels, len(out_vars))).
data_all = [
    np.reshape(np.concatenate(outDS[name].data, axis=0), [-1, n_pixels, 1])
    for name in out_vars
]
dataset_out = np.concatenate(data_all, axis=2)

In [None]:
# Load the trained weights into the freshly built network and time it.
before = time.perf_counter()
# strict=False tolerates key mismatches, but load_state_dict still reports them.
# Surface the report so a partially loaded (randomly initialised) model is never
# used silently.
load_result = unet.load_state_dict(weights_file, strict=False)
after = time.perf_counter()
print("time taken to load model weights", after - before)
if load_result.missing_keys:
    warnings.warn(
        f"model parameters not found in checkpoint (left at initial values): "
        f"{load_result.missing_keys}"
    )
if load_result.unexpected_keys:
    warnings.warn(f"checkpoint entries not used by the model: {load_result.unexpected_keys}")
# Evaluation mode (changes behaviour of e.g. dropout / batch-norm layers, if any).
# The trailing ';' suppresses the very long module repr in the notebook output.
unet.eval();

In [None]:
# Timesteps to predict: half-open index range [PRED_START, PRED_STOP) along the leading
# axis. Inputs and targets are sliced with the same indices, so both arrays are assumed
# to start at the same timestep.
PRED_START, PRED_STOP = 100, 120
n_available = min(len(dataset_in), len(dataset_out))
# NumPy silently truncates out-of-range slices, which would yield fewer (or zero)
# samples than requested, or mismatched input/target lengths. Fail loudly instead.
if not 0 <= PRED_START < PRED_STOP <= n_available:
    raise ValueError(
        f"prediction range [{PRED_START}, {PRED_STOP}) is outside the "
        f"{n_available} timesteps available in both datasets"
    )
inPredict = dataset_in[PRED_START:PRED_STOP]
outPredict = dataset_out[PRED_START:PRED_STOP]

In [None]:
# Run the model on every timestep selected in inPredict (change the slice in the
# previous cell to predict a different range). Gradients are not needed for inference,
# so autograd tracking is disabled to save memory and time.
before = time.perf_counter()
with torch.no_grad():
    preds = unet(torch.Tensor(inPredict))  # torch.Tensor(...) gives a float32 copy
after = time.perf_counter()
print("time taken to perform prediction", after - before)

In [None]:
# Predictions as a NumPy array: moved to CPU and detached from autograd first.
pred_numpy = preds.cpu().detach().numpy()

In [None]:
# Save the predictions and the matching ground truth next to the output data file,
# named after the climate model parsed from the filename. Deriving the directory from
# output_file avoids repeating the hardcoded absolute path (same directory as before).
save_dir = Path(output_file).parent
np.save(save_dir / f"{modelname}_predictions.npy", pred_numpy)
np.save(save_dir / f"{modelname}_groundtruth.npy", outPredict)