In [1]:
import numpy as np
import xarray as xr
import torch
from spherical_unet.models.spherical_unet.unet_model import SphericalUNet
from pathlib import Path
import re
import time
from torch.utils.data import DataLoader

In [2]:
# NetCDF file the model input variables are read from.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
input_file = (
    '/Users/kalairamea/Documents/DARPA-AIBEDO/aibedo_local/'
    'compress.isosph.CESM2-WACCM-FV2.historical.r1i1p1f1.Input.Exp8_fixed.nc'
)

In [3]:
# NetCDF file the target (ground-truth) variables are read from.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
output_file = (
    '/Users/kalairamea/Documents/DARPA-AIBEDO/aibedo_local/'
    'compress.isosph.CESM2-WACCM-FV2.historical.r1i1p1f1.Output.nc'
)

In [4]:
# Open the input and target files as xarray datasets (input first, then output).
inDS, outDS = map(xr.open_dataset, (input_file, output_file))
# Length of `ncells` in the input dataset; used below as the second axis of the
# reshaped arrays and passed to the model constructor.
n_pixels = len(inDS.ncells)

In [5]:
# Variable names read from the input dataset; their order fixes the channel
# order of `dataset_in`.
in_vars = [
    'crelSurf_pre',
    'crel_pre',
    'cresSurf_pre',
    'cres_pre',
    'netTOAcs_pre',
    'lsMask',
    'netSurfcs_pre',
]
# Variable names read from the output dataset; their order fixes the channel
# order of `dataset_out`.
out_vars = [
    'tas_pre',
    'psl_pre',
    'pr_pre',
]

In [6]:
# Load the saved state dict, mapping every tensor onto the CPU.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source (recent PyTorch releases also accept
# weights_only=True to restrict what may be unpickled).
weights_file = torch.load('/Users/kalairamea/Documents/DARPA-AIBEDO/aibedo_local/sunet_state_6.pt',
                          map_location=torch.device('cpu'))
# Presumably the checkpoint was saved from a DataParallel-wrapped model, whose keys
# start with "module." (TODO confirm). Strip only that *leading* prefix: a plain
# str.replace would also mangle keys that merely contain the substring
# (e.g. "submodule.weight" -> "subweight").
dataparallel_prefix = "module."
weights_file = {
    (key[len(dataparallel_prefix):] if key.startswith(dataparallel_prefix) else key): value
    for key, value in weights_file.items()
}

In [7]:
# Instantiate the network that the checkpoint weights are loaded into.
# NOTE(review): the meaning of the positional arguments is not visible in this
# notebook; the guesses below must be confirmed against SphericalUNet's signature.
unet = SphericalUNet(
    'icosahedron',    # presumably the pooling / sampling scheme
    n_pixels,         # length of `ncells` in the input dataset
    6,                # presumably the network depth (TODO confirm)
    'combinatorial',  # presumably the graph Laplacian type (TODO confirm)
    3,                # presumably the Chebyshev kernel size (TODO confirm)
    7,                # presumably input channels, matching len(in_vars) (TODO confirm)
    3,                # presumably output channels, matching len(out_vars) (TODO confirm)
)

In [8]:
# Extract the model name embedded in the output file name.
modelfilename = Path(output_file).stem  # file name without its final suffix
# Raw string with escaped dots: an unescaped "." matches ANY character, which would
# let unrelated file names slip through the pattern.
p = re.compile(r'compress\.isosph\.(.*)\.historical\.r1i1p1f1\.Output')
# findall returns the captured group(s); [0] raises IndexError when the file name
# does not follow the expected pattern.
modelname = p.findall(modelfilename)[0]

In [9]:
# Build the model input array. Each variable is flattened along its leading axis,
# reshaped to (-1, n_pixels, 1), and the per-variable arrays are then joined along
# the last axis, giving one channel per entry of `in_vars` (in that order).
data_all = [
    np.concatenate(inDS[name].data, axis=0).reshape(-1, n_pixels, 1)
    for name in in_vars
]
dataset_in = np.concatenate(data_all, axis=2)

In [10]:
# Build the target array the same way as the input array. Each variable is
# flattened along its leading axis, reshaped to (-1, n_pixels, 1), then joined
# along the last axis, giving one channel per entry of `out_vars` (in that order).
data_all = [
    np.concatenate(outDS[name].data, axis=0).reshape(-1, n_pixels, 1)
    for name in out_vars
]
dataset_out = np.concatenate(data_all, axis=2)

In [11]:
before = time.perf_counter()
# strict=False tolerates key mismatches between checkpoint and model, so keep the
# returned report instead of discarding it.
load_report = unet.load_state_dict(weights_file, strict=False)
after = time.perf_counter()
print("time taken to load model weights", after - before)
# Surface any mismatch: a missing key means that parameter was NOT loaded and keeps
# its initial value, which would silently degrade the predictions.
if load_report.missing_keys or load_report.unexpected_keys:
    print("WARNING: checkpoint does not fully match the model")
    print("  missing keys:", list(load_report.missing_keys))
    print("  unexpected keys:", list(load_report.unexpected_keys))
# Put the model in evaluation mode; as the cell's last expression this also
# displays the model structure.
unet.eval()

time taken to load model weights 0.06885841700000128


SphericalUNet(
  (encoder): Encoder(
    (pooling): IcosahedronPool()
    (enc_l5): SphericalChebBN2(
      (spherical_cheb_bn_1): SphericalChebBN(
        (spherical_cheb): SphericalChebConv(
          (chebconv): ChebConv()
        )
        (batchnorm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (spherical_cheb_bn_2): SphericalChebBN(
        (spherical_cheb): SphericalChebConv(
          (chebconv): ChebConv()
        )
        (batchnorm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (enc_l4): SphericalChebBNPool(
      (pooling): IcosahedronPool()
      (spherical_cheb_bn): SphericalChebBN(
        (spherical_cheb): SphericalChebConv(
          (chebconv): ChebConv()
        )
        (batchnorm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (enc_l3): SphericalChebBNPool(
      (pooling): IcosahedronPool()
      (spherical_c

In [12]:
# Timesteps to predict: a window over the leading axis of both arrays.
# Widen the slice to run the model on more timesteps.
predict_window = slice(0, 1)
inPredict = dataset_in[predict_window]
groundTruth = dataset_out[predict_window]

In [13]:
before = time.perf_counter()
# Inference only: disable autograd so no computation graph is recorded during the
# forward pass (saves memory and time). The number of timesteps predicted is set
# by the slice used to build `inPredict`.
with torch.no_grad():
    preds = unet(torch.Tensor(inPredict))  # torch.Tensor(...) converts the NumPy array to a tensor
after = time.perf_counter()
print("time taken to perform prediction", after - before)

time taken to perform prediction 0.4900247919999998


In [14]:
# Predictions as a NumPy array: detach from autograd and move to the CPU first,
# then convert.
preds_cpu = preds.detach().cpu()
pred_numpy = preds_cpu.numpy()