In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules are
# reloaded before each cell executes and edits to local source files take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Relative paths used further down (e.g. ./out_dir) are resolved against the
# working directory, so step up one level when the kernel was started from
# inside the "notebooks" folder.
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir(os.pardir)
os.getcwd()

In [None]:
import sys
import h5py
import json
import time
import wandb
import hydra
import logging
import xarray as xr
import numpy as np
from typing import *
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from aibedo.utilities.plotting import data_snapshots_plotting, data_mean_plotting
from aibedo.utilities.wandb_api import reload_checkpoint_from_wandb

In [None]:
# --- Run configuration -------------------------------------------------------
run_id = "3l3tun8f"   # wandb run id
num_workers = 2       # forwarded as datamodule.num_workers
test_set = "merra2"   # third entry of the datamodule.partition override
DATA_DIR = "../Data/aibedo"   # forwarded as datamodule.data_dir

# "key=value" strings handed to reload_checkpoint_from_wandb(override_key_value=...)
# to override the configuration stored with the run.
overrides = [
    f'datamodule.num_workers={num_workers}',
    f'datamodule.partition={[0.85, 0.15, test_set]}',
    'datamodule.eval_batch_size=5',
    'verbose=False',
    f'datamodule.data_dir={DATA_DIR}',
]

# gpus=-1 asks Lightning for every visible GPU.
trainer = pl.Trainer(gpus=-1, max_epochs=1)

In [None]:
def get_model_and_dm_from_run_id(
    run_id: str,
    project: str = 'AIBEDO',
    override_key_value: Optional[List[str]] = None,
) -> Tuple[Any, Any, Any]:
    """Reload a wandb run's checkpoint and return its model, datamodule and config.

    Args:
        run_id: wandb run id of the run to reload.
        project: wandb project that contains the run. Defaults to 'AIBEDO'.
        override_key_value: list of "key=value" config override strings that is
            forwarded to ``reload_checkpoint_from_wandb``. When ``None``, the
            notebook-level ``overrides`` list is used, which is what this
            function always did before the parameter existed.

    Returns:
        A ``(model, datamodule, config)`` tuple taken from the 'model',
        'datamodule' and 'config' entries of the reloaded values.
    """
    if override_key_value is None:
        # Looked up at call time, so re-running the config cell is picked up.
        override_key_value = overrides
    values = reload_checkpoint_from_wandb(
        run_id=run_id, project=project, override_key_value=override_key_value
    )
    return values['model'], values['datamodule'], values['config']

## The following will evaluate the model on the validation set

In [None]:
# Reload the model, datamodule and config of the wandb run `run_id`.
model, dm, cfg = get_model_and_dm_from_run_id(run_id)
# Prepare only the validation data.
dm.setup(stage="val")   # stage can be 'val', 'test', or 'predict' and will only load the respective data

In [None]:
# Point the datamodule's prediction data at the validation data.
# NOTE(review): this writes a private attribute of the datamodule and may break
# if its internals change — confirm there is no public way to do this.
dm._data_predict = dm._data_val   # if you don't do this, the prediction data_loader will be the test one

In [None]:
# Run the model over the validation dataloader and collect the results in `ds`
# (presumably an xarray.Dataset, given the method name and the later
# `.attrs` / `.to_netcdf` usage — TODO confirm).
ds = dm.get_predictions_xarray(model, dataloader=dm.val_dataloader())   # get predictions on valid set

In [None]:
# Set the 'variable_names' attribute to an empty string before writing.
# NOTE(review): presumably the original value cannot be serialized to netCDF — confirm.
ds.attrs['variable_names'] = ""
PREDS_DIR = "./out_dir/preds"    # where to save nc file
# Create the output directory first: on a fresh checkout it does not exist and
# the write below would fail.
os.makedirs(PREDS_DIR, exist_ok=True)
ds.to_netcdf(PREDS_DIR + '/MLP_CESM2_preds_targets_errors.nc')

In [None]:
FIG_DIR="./out_dir/figs"   # where to save figures
# Create the output directory first: savefig raises FileNotFoundError when the
# target folder is missing.
os.makedirs(FIG_DIR, exist_ok=True)

# Plot the means of the prediction dataset and save the current figure as PNG.
fig, axs = data_mean_plotting(ds)
plt.savefig(f"{FIG_DIR}/MLP_mean_plots_CESM2_val_set.png")