### Load saved model to predict and save output

In [1]:
import yaml
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import s2sml.torch_s2s_dataset as torch_s2s_dataset
from s2sml.load_model import load_model

In [2]:
def device_assignment_using_trials():
    """
    Assign the device to use.

    When CUDA is available, every visible GPU is queried with
    ``torch.cuda.memory_reserved`` and the one reporting the smallest value
    is selected (ties resolve to the lowest index). That GPU is made the
    current CUDA device and a small tensor is moved onto it. When CUDA is
    not available the CPU device is returned.

    Returns:
        torch.device: the selected CUDA device, or ``torch.device("cpu")``.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")

    # Enumerate the GPUs that are actually visible instead of assuming that
    # exactly four exist; indexing a missing device raises an error.
    n_gpus = torch.cuda.device_count()
    which_gpu = min(range(n_gpus), key=torch.cuda.memory_reserved)

    torch.cuda.set_device(which_gpu)
    # Build the device from a plain Python int (not a NumPy scalar)
    device = torch.device("cuda", which_gpu)

    # Move a throwaway tensor onto the device
    # NOTE(review): presumably a warm-up / sanity check of the chosen GPU
    torch.randn(1, 3, 12, 12).to(device)

    return device


@torch.no_grad()
def validate(model, dataloader, device):
    """
    Validation function.

    Puts the model in eval mode, runs it on every batch of the dataloader
    with gradients disabled, and returns the results for ALL batches joined
    along the first (sample) dimension.

    Args:
        model: pytorch neural network
        dataloader: pytorch dataloader (or any iterable) yielding dicts with
            the keys "input", "label" and "lmask"; dimension 2 of each
            tensor is squeezed out when it has size one
        device: gpu or cpu

    Returns:
        tuple of four tensors on ``device``:
        (inputs, model outputs, labels, land masks).

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    # set model to eval mode
    model.eval()

    inputs, outputs, labels, lmasks = [], [], [], []

    # loop thru data in loader, keeping every batch (not only the last one)
    for data in dataloader:

        # load input features
        img_noisy = data["input"].squeeze(dim=2).to(device, dtype=torch.float)

        # load labels
        img_label = data["label"].squeeze(dim=2).to(device, dtype=torch.float)

        # load masks
        img_lmask = data["lmask"].squeeze(dim=2).to(device, dtype=torch.float)

        inputs.append(img_noisy)
        outputs.append(model(img_noisy))  # predict the model output
        labels.append(img_label)
        lmasks.append(img_lmask)

    if not inputs:
        raise ValueError("dataloader yielded no batches; nothing to validate")

    def _join(chunks):
        # A single batch is returned as-is to avoid an unnecessary copy
        return chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=0)

    return _join(inputs), _join(outputs), _join(labels), _join(lmasks)


def _normalization_stats(train, norm):
    """Return the normalization keyword arguments for the test dataset, read from ``train``."""
    stats = dict.fromkeys(
        ("minv", "maxv", "mini", "maxi", "mnv", "stdv", "mni", "stdi")
    )

    if not norm or norm == "None":
        # no normalization: every statistic stays None
        return stats

    if norm in ("minmax", "negone"):
        # min-max statistics ("val" marked era5, "inp" marked cesm in the
        # original notebook comments)
        stats.update(
            minv=train.min_val,
            maxv=train.max_val,
            mini=train.min_inp,
            maxi=train.max_inp,
        )
        return stats

    if norm == "zscore":
        # z-score statistics
        stats.update(
            mnv=train.mean_val,
            stdv=train.std_val,
            mni=train.mean_inp,
            stdi=train.std_inp,
        )
        return stats

    # Previously an unknown option fell through and failed later with a
    # NameError on the unset statistics; fail here with a clear message.
    raise ValueError(f"unsupported normalization option: {norm!r}")


def _crop_to_numpy(tensor):
    """Crop the two trailing spatial axes and return the result as a CPU NumPy array."""
    # NOTE(review): the 12:-12 / 6:-5 crop presumably removes padding added
    # by the dataset so the grid matches the coordinate file — confirm.
    # .cpu() is required because the tensors live on the GPU whenever one
    # was selected, and .numpy() only works on CPU tensors.
    return tensor[:, :, 12:-12, 6:-5].detach().cpu().numpy()


def load_model_and_predict(
    CONFIG_PATH,
    MODEL_PATH,
    coords_path="/glade/derecho/scratch/molina/ml_coordsv2.nc",
):
    """
    Load a saved model, run it on the test period and package the results.

    The YAML config supplies the model definition (``conf["model"]``) and
    the dataset options (``conf["data"]``). A dataset spanning 1999-02-01 to
    2014-12-31 is built first so its normalization statistics can be passed
    to the 2018-01-01 to 2020-12-31 dataset that is used for prediction.

    Args:
        CONFIG_PATH: path to the YAML configuration file.
        MODEL_PATH: path to a checkpoint loadable with ``torch.load`` that
            contains a ``"model_state_dict"`` entry.
        coords_path: path to a NetCDF file whose ``x`` and ``y`` coordinates
            label the output grid. Defaults to the location used originally.

    Returns:
        xarray.Dataset with variables ``cesm_input`` (samples, channel, x, y)
        and ``era5_label``, ``ml_predict``, ``land_masks`` (samples, x, y).

    Raises:
        ValueError: if ``conf["data"]["norm"]`` is not one of the supported
            options (falsy / "None", "minmax", "negone", "zscore").
    """
    # NOTE(review): FullLoader is kept as in the original; only load config
    # files from trusted sources.
    with open(CONFIG_PATH) as cf:
        conf = yaml.load(cf, Loader=yaml.FullLoader)

    data_conf = conf["data"]
    norm = data_conf["norm"]

    # Options shared by the statistics-providing and the test dataset
    shared_kwargs = dict(
        week=data_conf["wks"],
        variable=data_conf["var"],
        norm=norm,
        norm_pixel=data_conf["norm_pixel"],
        dual_norm=data_conf["dual_norm"],
        region=data_conf["region"],
        lon0=data_conf["lon0"],
        lat0=data_conf["lat0"],
        dxdy=data_conf["dxdy"],
        feat_topo=data_conf["feat_topo"],
        feat_lats=data_conf["feat_coord"],
        feat_lons=data_conf["feat_coord"],
        homedir=data_conf["homedir"],
    )

    # assign device
    device = device_assignment_using_trials()

    # Load model; the checkpoint is mapped to CPU first and then copied into
    # the model's parameters on ``device`` by load_state_dict
    model = load_model(conf["model"]).to(device)
    checkpoint = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # Dataset over the training period, built without external statistics;
    # its attributes provide the statistics for the test dataset
    train = torch_s2s_dataset.S2SDataset(
        minv=None,
        maxv=None,
        mini=None,
        maxi=None,
        mnv=None,
        stdv=None,
        mni=None,
        stdi=None,
        startdt="1999-02-01",
        enddt="2014-12-31",
        **shared_kwargs,
    )

    tests = torch_s2s_dataset.S2SDataset(
        startdt="2018-01-01",
        enddt="2020-12-31",
        **_normalization_stats(train, norm),
        **shared_kwargs,
    )

    # One batch holding the whole test set, in dataset order
    tests_loader = DataLoader(
        tests, batch_size=len(tests), shuffle=False, drop_last=False
    )

    tmp_inp, tmp_out, tmp_lbl, tmp_msk = validate(model, tests_loader, device)

    # Read the coordinate values and close the file again
    with xr.open_dataset(coords_path) as ds_coords:
        x_vals = ds_coords.coords["x"].values
        y_vals = ds_coords.coords["y"].values

    # Only the channel axis (axis 1) is squeezed, so a test set with a single
    # sample keeps its "samples" dimension
    ds_output = xr.Dataset(
        data_vars=dict(
            cesm_input=(["samples", "channel", "x", "y"], _crop_to_numpy(tmp_inp)),
            era5_label=(["samples", "x", "y"], _crop_to_numpy(tmp_lbl).squeeze(axis=1)),
            ml_predict=(["samples", "x", "y"], _crop_to_numpy(tmp_out).squeeze(axis=1)),
            land_masks=(["samples", "x", "y"], _crop_to_numpy(tmp_msk).squeeze(axis=1)),
        ),
        coords=dict(
            x=(["x"], x_vals),
            y=(["y"], y_vals),
        ),
    )

    return ds_output

In [3]:
# Inputs for inference: the YAML configuration file and the saved model checkpoint
confg_path = "/glade/work/molina/studies/test_best/config/best_1.yml"
model_path = "/glade/work/molina/studies/test_best/echo_stuff/trial/model_.pt"

In [4]:
# Load the model, predict on the test period and collect the results in an xarray Dataset
ds_final = load_model_and_predict(confg_path, model_path)

In [5]:
# Bare expression: shows the dataset summary as the notebook cell output
ds_final

In [6]:
# Save the inference dataset to a NetCDF file next to the model checkpoint
ds_final.to_netcdf("/glade/work/molina/studies/test_best/echo_stuff/trial/inference_data.nc")