In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import xarray as xr
import torch
from pathlib import Path
from hython.utils import write_to_zarr, build_mask_dataarray
import matplotlib.pyplot as plt
from pathlib import Path
from hython.datasets.datasets import get_dataset
from numcodecs import Blosc

from torch.utils.data import Dataset, DataLoader

In [None]:
from hython.trainer import XBatcherTrainer
from hython.trainer import train_val
from hython.sampler import SamplerBuilder
from hython.metrics import MSEMetric
from hython.losses import RMSELoss
from hython.utils import read_from_zarr, set_seed
from hython.models.cudnnLSTM import CuDNNLSTM
from hython.trainer import RNNTrainer, RNNTrainParams
from hython.normalizer import Normalizer

import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset

In [None]:
#from torchvision.datasets import MovingMNIST

In [None]:
dir_surr_input = Path("/mnt/CEPH_PROJECTS/InterTwin/Wflow/models/adg1km_eobs")


static = xr.open_dataset(dir_surr_input / "staticmaps.nc")#.chunk("auto")
dynamic = xr.open_dataset(dir_surr_input/ "forcings.nc").chunk("auto") # C T W H => N T C H W
target = xr.open_dataset(dir_surr_input / "run_default/output.nc").sel(layer=1).isel(lat=slice(None, None, -1))#.chunk("auto") # C T W H => N T C H W


surr_model_output = "convlstm.pt"
experiment = "exp1" # experiment name

dir_surr_output = "/mnt/CEPH_PROJECTS/InterTwin/hydrologic_data/surrogate_model"

SEED = 1696

dynamic_names = ["precip", "pet", "temp"] 
static_names = [ 'thetaS', 'thetaR', 'RootingDepth', 'Swood','KsatVer', "Sl"] 
target_names = [ "vwc","actevap"] 

mask_from_static = ["wflow_lakeareas"]
rename_mask = ["mask_lake"]


dataset = "XBatchDataset"


# DL model hyper parameters
HIDDEN_SIZE = 24
DYNAMIC_INPUT_SIZE = len(dynamic_names)
STATIC_INPUT_SIZE = len(static_names)
KERNEL_SIZE = (3, 3)
NUM_LSTM_LAYER = 2
OUTPUT_SIZE = len(target_names)


TARGET_WEIGHTS = {t:0.5 for t in target_names}



# train/test parameters
train_temporal_range = slice("2016-01-01","2018-12-31")
test_temporal_range = slice("2019-01-01", "2020-12-31")

EPOCHS = 90
BATCH = 64
TEMPORAL_SUBSAMPLING = True
TEMPORAL_SUBSET = [150, 150] 
SEQ_LENGTH = 360


# Sanity check on the loss weights. Compare with a tolerance rather than `== 1`:
# weights such as 1/3 do not add up to exactly 1.0 in floating point.
assert abs(sum(TARGET_WEIGHTS.values()) - 1.0) < 1e-9, "check target weights"
# Upper-cased first letter of every target variable, concatenated.
TARGET_INITIALS = "".join(name[0].capitalize() for name in target_names)


In [None]:
# Working directory: same location as the surrogate input directory defined in
# the configuration cell, so the path is only spelled out once.
wd = dir_surr_input

In [None]:
def _harmonise_latlon(ds):
    """Rename ``latitude``/``longitude`` to ``lat``/``lon`` where present.

    Only names that actually exist in ``ds`` are renamed, so an already
    harmonised dataset is returned unchanged and any other error raised by
    ``rename`` (e.g. a name clash) is no longer silently swallowed.
    """
    wanted = {"latitude": "lat", "longitude": "lon"}
    present = {old: new for old, new in wanted.items() if old in ds.dims or old in ds.coords}
    return ds.rename(present) if present else ds


# Each dataset is handled independently: previously a failure while renaming
# `dynamic` also skipped the rename of `static`.
dynamic = _harmonise_latlon(dynamic)
static = _harmonise_latlon(static)

In [None]:
# masking 

# True where the first static variable is NaN.
mask_missing = np.isnan(static[static_names[0]]).rename("mask_missing")

masks = []
masks.append(mask_missing)

# One boolean layer per entry of mask_from_static (True where the static map
# is > 0), renamed to the entry of rename_mask at the same position.
for i, mask in enumerate(mask_from_static):
    masks.append((static[mask] > 0).astype(np.bool_).rename(rename_mask[i]))

# Combine the layers and reduce over "mask_layer": a cell ends up True as soon
# as any of the individual masks flags it.
masks = build_mask_dataarray(masks, names = ["mask_missing"]+ rename_mask).any(dim="mask_layer")

In [None]:
# filter 
dynamic = dynamic[dynamic_names]
target = target[target_names]
static = static.drop_dims("time").sel(layer=1)[ static_names ] 

In [None]:
# expand static to dynamic 
time_da = xr.DataArray(dynamic.time.values, [('time', dynamic.time.values)])
static = static.expand_dims({"time":time_da})

In [None]:
dynamic = dynamic.to_array() # C T H W
static = static.to_array() # C T H W
target = target.to_array() # C T H W

In [None]:
dynamic_train = dynamic.sel(time=train_temporal_range)
static_train = static.sel(time=train_temporal_range)
target_train = target.sel(time=train_temporal_range)

dynamic_test = dynamic.sel(time=test_temporal_range)
static_test = static.sel(time=test_temporal_range)
target_test = target.sel(time=test_temporal_range)

In [None]:
target_train.shape, static_train.shape, dynamic_train.shape, masks.shape

In [None]:
# normalize

normalizer_dynamic = Normalizer(method="standardize", type="spacetime", shape="2D")
normalizer_static = Normalizer(method="standardize", type="space", shape="2D")
normalizer_target = Normalizer(method="standardize", type="spacetime", shape="2D")

# Statistics are computed on the training arrays only and then applied to both
# the training and the test arrays.
normalizer_dynamic.compute_stats(dynamic_train)
normalizer_static.compute_stats(static_train)
normalizer_target.compute_stats(target_train)

# TODO: save stats, implement caching of stats to save computation

# Write the normalized training data back to the *_train names. The masking
# cell and the XBatchDataset cell read *_train; previously the result was stored
# only in dynamic/static/target, so the model was trained on unnormalized data
# and validated on normalized data.
dynamic_train = normalizer_dynamic.normalize(dynamic_train)
static_train = normalizer_static.normalize(static_train)
target_train = normalizer_target.normalize(target_train)

# Later cells (custom cubelets section) read the normalized training data through
# these names, so keep them bound to the same arrays.
dynamic, static, target = dynamic_train, static_train, target_train

dynamic_test = normalizer_dynamic.normalize(dynamic_test)
static_test = normalizer_static.normalize(static_test)
target_test = normalizer_target.normalize(target_test)

# NOTE: this cell is not idempotent (it normalizes its own outputs when run
# twice); re-run the train/test split cell first.

In [None]:
# masking
# Set every cell flagged by `masks` to 0 in all six train/test arrays; cells
# where the mask is False keep their value.
(
    dynamic_train,
    static_train,
    target_train,
    dynamic_test,
    static_test,
    target_test,
) = (
    array.where(~masks, 0)
    for array in (
        dynamic_train,
        static_train,
        target_train,
        dynamic_test,
        static_test,
        target_test,
    )
)

In [None]:
# compressor = Blosc(cname='zl4', clevel=9, shuffle=Blosc.BITSHUFFLE)

# ss = params.drop_dims("time")[[ 'thetaS', 'thetaR', 'RootingDepth', 'Swood','KsatVer', "Sl"]].expand_dims({"time": ds.time}).chunk({"time":500, "latitude":50, "longitude":50})

# ss.to_zarr(wd / "test.zarr",storage_options={"compressor":compressor})

# ss = xr.open_dataset( wd / "test.zarr", engine = "zarr")

# time, lat, lon = 365, 16, 16 

In [None]:
import xbatcher

In [None]:
xbatcher.BatchGenerator?

## Test xbatcher

Check whether I can get an index of the chunks, so that a single (flat) index can be mapped to a 2D chunk position:

$$ \mathbf{f}: \mathbb{R}^1 \rightarrow \mathbb{R}^2$$

f(0) => (0,0) <br>
f(1) => (0,1) <br>
f(2) => (1,0) <br>
f(3) => (1,1) <br>

In [None]:
xgen = xbatcher.BatchGenerator(
    dynamic_train,
    input_dims={"lat":80, "lon":80, "time":360}, # dimension size of the sample cube
    preload_batch=True,
    #batch_dims={"time":60, "lat":80, "lon":80},
    #concat_input_dims= True,
    #input_overlap={"time":10, "lat":10, "lon":10} # overlaps between dimensions of each cube
)

In [None]:
len(xgen)

In [None]:
sample = xgen[11]

In [None]:
ds = xr.merge(xgen)

In [None]:
ds.isel(time=199).precip.plot()

In [None]:
# this should work for the convLSTM 
sample_convlstm = sample.to_stacked_array(new_dim="feat", sample_dims=("lat","lon", "time"))

In [None]:
sample_convlstm.shape

In [None]:
# how to handle null cubes
sample_convlstm.isnull().all().item(0)

In [None]:
#sample_convlstm.sel(feat="precip").isel(time=10).plot()

In [None]:
# this should work for the 1D lstm
sample_lstm = sample_convlstm.stack(gridcell=["lat","lon"])

In [None]:
#ds.sel(variable="precip").to_dataset(name="ds").to_stacked_array(new_dim="batch", sample_dims=("lon","lat"))

## Example xbatcher from https://github.com/earth-mover/dataloader-demo/blob/main/main.py

In [None]:
# Imports are local to this cell so it does not rely on a later cell having run.
# The time module is aliased because another cell rebinds the bare name `time`
# to an integer, which would break `time.time()` inside __getitem__.
import json
import multiprocessing
import time as _time

from torch.utils.data import Dataset as TorchDataset


def print_json(obj):
    """Print ``obj`` serialised as a single JSON document."""
    print(json.dumps(obj))


class XBatcherPyTorchDataset(TorchDataset):
    """Torch dataset that serves the batches of an xbatcher ``BatchGenerator``.

    Every access logs a JSON "start" and "end" event (timestamp, index, process
    id, duration) so that loading can be traced across DataLoader workers.
    """

    def __init__(self, batch_generator: "xbatcher.BatchGenerator"):
        # Anything supporting len() and integer indexing works here.
        self.bgen = batch_generator

    def __len__(self):
        return len(self.bgen)

    def __getitem__(self, idx):
        """Load batch ``idx`` and return it as a tensor with time first.

        The variables of the batch are stacked along a new "batch" dimension;
        the result is transposed so that "time" and "batch" are the leading axes.
        """
        pid = multiprocessing.current_process().pid
        t0 = _time.time()
        print_json(
            {
                "event": "get-batch start",
                "time": t0,
                "idx": idx,
                "pid": pid,
            }
        )
        # load before stacking
        batch = self.bgen[idx].load()

        print(batch)

        # Use to_stacked_array to stack without broadcasting,
        stacked = batch.to_stacked_array(
            new_dim="batch", sample_dims=("time", "longitude", "latitude")
        ).transpose("time", "batch", ...)
        print(stacked)
        x = torch.tensor(stacked.data)
        t1 = _time.time()
        print_json(
            {
                "event": "get-batch end",
                "time": t1,
                "idx": idx,
                "pid": pid,
                "duration": t1 - t0,
            }
        )
        return x

In [None]:
import time, json
def setup(source="gcs", patch_size: int = 48, input_steps: int = 3):
    if source == "gcs":
        ds = xr.open_dataset(
            "gs://weatherbench2/datasets/era5/1959-2022-6h-128x64_equiangular_with_poles_conservative.zarr",
            engine="zarr",
            chunks={},
        )
    elif source == "arraylake":
        config.set({"s3.endpoint_url": "https://storage.googleapis.com", "s3.anon": True})
        ds = (
            Client()
            .get_repo("earthmover-public/weatherbench2")
            .to_xarray(
                group="datasets/era5/1959-2022-6h-128x64_equiangular_with_poles_conservative",
                chunks={},
            )
        )
    else:
        raise ValueError(f"Unknown source {source}")

    DEFAULT_VARS = [
        "10m_wind_speed",
        "2m_temperature",
        "specific_humidity",
    ]

    ds = ds[DEFAULT_VARS]
    patch = dict(
        latitude=patch_size,
        longitude=patch_size,
        time=input_steps,
    )
    overlap = dict(latitude=32, longitude=32, time=input_steps // 3 * 2)

    bgen = xbatcher.BatchGenerator(
        ds,
        input_dims=patch,
        input_overlap=overlap,
        preload_batch=False,
    )

    dataset = XBatcherPyTorchDataset(bgen)

    return dataset

In [None]:
xgen = setup()

In [None]:
len(xgen)

In [None]:
ds = xgen[1]

In [None]:
# NOTE(review): `res` is only assigned much further down (`res = dataset[1]`),
# so on a fresh kernel this line raises NameError. The preceding cell stores
# `xgen[1]` in `ds` - presumably that was the intended operand. TODO confirm.
res = res[1]
res.shape, res.dims, res.coords, len(xgen)

In [None]:
# for i, b in enumerate(xgen):
#     print(i)
#     print(b.shape)
#     plt.figure()
#     try:
#         b.isel(variable_input=1, sample=0).plot()
#         b.isel(variable_input=1, lat_input=100, lon_input=100).plot(x="time")
#     except Exception as e:
#         print(e)
#         b.isel(variable=1, time=1).plot()
#         b.isel(variable=1, lat=10, lon=10).plot()
#     if i > 20:
#         break

## Test custom "xbatcher"

The current implementation of xbatcher looks cool, but it lacks the following:
- A way to index tiles (space) and sequences (time) separately: it only indexes whole cubes, so how can I subsample along only one of the two dimensions?
- A known ordering of the cube samples: since I don't know it, how can I subsample?
- A way to handle NULL cubes (cubes that contain only missing values).
- Handling of the "edges" of each dimension: incomplete cubes are dropped so that every returned sample has the same size, and no collate function is provided to pad them instead.

In [None]:
from hython.sampler import compute_grid_indices
import itertools

In [None]:
# Indexing cubelets

In [None]:
dynamic

In [None]:
data = dynamic#.transpose("lat", "lon", "time") # x, y, t

data.sizes

In [None]:
data

In [None]:
xsize, ysize, tsize = 20, 20, 360 #4748 # cubelet dimension size
xover, yover, tover= 0, 0, 0 # cubelets overlaps


space_idx = compute_grid_indices(grid=data)

print(space_idx.shape)

In [None]:
# # create cubelets, keep or not edge cubelets
# keep_edge_cubelets = True

# space_indices = []
# space_indices_all_missing = []
# space_slices = []
# idx = 0
# for ix,iy in zip(range(0, data.shape[0], xsize - xover), range(0, data.shape[1], ysize - yover)):
#     xslice = slice(ix, ix + xsize)
#     yslice = slice(iy, iy + ysize)
#     cubelet = space_idx[xslice, yslice]
#     mask_cubelet = masks[xslice, yslice]
    
#     #plt.figure(figsize=(2,2))
#     #plt.imshow(cubelet)
#     #plt.annotate(idx, list(map(lambda x: int(x/4), cubelet.shape)) ,color="red", size=20)
#     #plt.colorbar()
#     space_slices.append([xslice, yslice])
#     space_indices.append(idx)
#     if mask_cubelet.all().item(0):
#         space_indices_all_missing.append(idx)
#     idx += 1

In [None]:
# create cubelets, keep or not edge cubelets

# THIS DEPENDS ON WHICH AXES are the spatial coordinates!

data_current_coordinates = {"time":0, "lat":1, "lon":2}

data_time_size = len(data.time)
data_lat_size = len(data.lat)
data_lon_size = len(data.lon)

keep_edge_cubelets = False

space_indices = []
space_indices_all_missing = []
space_slices = []
idx = 0
# Convention used here and when the slices are turned into a mapping later on:
# x <-> lon (xsize, xover, xslice) and y <-> lat (ysize, yover, yslice).
for ix in range(0, data_lon_size, xsize - xover):
    for iy in range(0, data_lat_size, ysize - yover):
        xslice = slice(ix, ix + xsize)
        yslice = slice(iy, iy + ysize)

        # An edge cubelet runs past the end of the grid along lon or lat and is
        # therefore smaller than (ysize, xsize). Dropped cubelets can otherwise be
        # restored in the dataset with a collate function that pads with zeros.
        is_edge = (ix + xsize > data_lon_size) or (iy + ysize > data_lat_size)
        if is_edge and not keep_edge_cubelets:
            continue

        space_slices.append([xslice, yslice])
        space_indices.append(idx)

        # Flag cubelets whose cells are all masked. Select by dimension name so
        # the lon slice can never be applied to the lat axis (or vice versa).
        mask_cubelet = masks.isel(lat=yslice, lon=xslice)
        if mask_cubelet.all().item(0):
            space_indices_all_missing.append(idx)

        idx += 1

In [None]:
data_time_size, data_lat_size, data_lon_size

In [None]:
time_indices = []
time_slices = []
idx = 0

latlon = data_current_coordinates["lat"], data_current_coordinates["lon"]

for it in range(0, data_time_size, tsize - tover):
    tslice = slice(it, it + tsize)

    # A window is an edge window when it runs past the end of the time axis.
    # This only needs the axis length, not the data: the previous version sliced
    # `data.precip` positionally, compared the axis position against
    # len(...) instead of len(...) - 1, and used two ellipses in one index.
    if not keep_edge_cubelets and it + tsize > data_time_size:
        continue

    time_indices.append(idx)
    time_slices.append(tslice)
    idx += 1

In [None]:
time_indices

In [None]:
# cubelets idx
cube_idx = list(itertools.product(*(space_indices, time_indices)))

In [None]:
cube_idx[:10]

In [None]:
# slices
slice_idx = list(itertools.product(*(space_slices, time_slices)))

In [None]:
len(slice_idx)

In [None]:
slice_idx[:10]

In [None]:
# create mapping
mapping_cubelets_slices = {}  # coordinates sequence should be as the model expects 
print(data_current_coordinates)
# Actually the slicing occurs at the getitem of the dataset, so after the data is kind of transposed

data.to_stacked_array( new_dim="feat", sample_dims = ["time", "lat", "lon"])

In [None]:
# Map every (space index, time index) key to the slices that cut the cubelet.
# Each spatial entry is [xslice, yslice], i.e. [lon slice, lat slice].
# The insertion order time, lat, lon is significant: return_cubelet_data indexes
# positionally with the dict values in this order.
for ic, islice in zip(cube_idx, slice_idx):
    sp_slice, t_slice = islice
    mapping_cubelets_slices[ic] = {
        "time": t_slice,
        "lat": sp_slice[1],
        "lon": sp_slice[0],
    }

In [None]:
mapping_cubelets_slices[(0,0)].values()

In [None]:
# function that maps the cubelet indices to the grid indices for chunking

def return_cubelet_slices(cubelet_idx):
    """Return the slice mapping stored for ``cubelet_idx``.

    ``cubelet_idx`` is a key of the module-level ``mapping_cubelets_slices``;
    a missing key raises ``KeyError``.
    """
    return mapping_cubelets_slices[cubelet_idx]

def return_cubelet_data(data,cubelet_idx):
    """Cut the cubelet ``cubelet_idx`` out of ``data`` by positional indexing.

    The slices are applied in the order in which they are stored in the mapping,
    so the axes of ``data`` must follow that same order.
    """
    # tuple(...) instead of `data[*values]`: star-unpacking inside a subscript is
    # only valid syntax from Python 3.11 on; the tuple form is equivalent.
    return data[tuple(mapping_cubelets_slices[cubelet_idx].values())]

In [None]:
return_cubelet_data(data.precip, (1,0))

In [None]:
# missing values
def cubelet_idx_with_all_missing_values(mapping_cubelets, cubelets_idx_missing, time_indices):
    """Return a copy of ``mapping_cubelets`` without the all-missing cubelets.

    Every key ``(space_idx, time_idx)`` with ``space_idx`` in
    ``cubelets_idx_missing`` and ``time_idx`` in ``time_indices`` is removed.
    Keys that are not present are ignored; the input mapping is not modified.
    """
    new_map = mapping_cubelets.copy()
    for t in time_indices:
        for idx in cubelets_idx_missing:
            # pop with a default ignores absent keys without a bare `except`,
            # which used to hide every other error as well.
            new_map.pop((idx, t), None)
    return new_map

# can create 

In [None]:
new_mapping = cubelet_idx_with_all_missing_values(mapping_cubelets_slices, space_indices_all_missing, time_indices)

In [None]:
#subsample space and time, this becomes a class like RandomCubeletsSampler

# keys (space, time)
new_mapping

def subsample(mapping, time_indices, space_indices):
    """Keep only the cubelets whose (space, time) key is requested.

    Keys are formed as ``(space_idx, time_idx)``. Requested keys that are absent
    from ``mapping`` (or whose stored value is ``None``) are skipped. The result
    is a new dict ordered by space index first, then time index.
    """
    requested_keys = itertools.product(space_indices, time_indices)
    return {
        key: mapping[key]
        for key in requested_keys
        if mapping.get(key, None) is not None
    }



In [None]:
time_indices, space_indices[:10]

In [None]:
new_sub_mapping = subsample(new_mapping, [0,1,2], [i for i in space_indices if i % 2 == 0]) # only even indices 

In [None]:
list(new_sub_mapping.items())[:3]

In [None]:
# for k in new_mapping:
#     data[new_mapping[k]].isel(time=1).plot(figsize=(1,1), add_colorbar=False)
#     plt.axis('off')
#     plt.title("")

In [None]:
# torch dataset
# collate function to make cubelets of the same shape! and do padding with zeros!

In [None]:
class CubeletsDataset(Dataset):
    """Torch dataset that serves space-time cubelets cut from xarray datasets.

    Parameters
    ----------
    xd : xr.Dataset
        Dynamic inputs with dimensions time, lat, lon.
    xs : xr.Dataset or None
        Static inputs with dimensions lat, lon. When ``None`` the items are
        ``(xd, y)`` pairs.
    y : xr.Dataset
        Targets with dimensions time, lat, lon.
    cubelet_indices : mapping
        Maps a cubelet key to a dict with "time", "lat" and "lon" slices.
        The keys are snapshotted at construction time.
    persist : bool
        Call ``.persist()`` on the stacked arrays.
    lstm_1d : bool
        Flatten the spatial axes so each grid cell becomes one sample:
        dynamic/target become (N, L, C) and static becomes (N, C).
    static_to_dynamic : bool
        Repeat the static features along the sequence axis so they have the
        same leading shape as the dynamic features.
    """

    def __init__(self, xd: "xr.Dataset", xs: "xr.Dataset", y: "xr.Dataset", cubelet_indices, persist=False, lstm_1d = False, static_to_dynamic=False):
        # Stack the variables into a "feat" axis and order the axes as T C H W.
        self.xd = xd.to_stacked_array(
            new_dim="feat", sample_dims=["time", "lat", "lon"]
        ).transpose("time", "feat", "lat", "lon")

        self.y = y.to_stacked_array(
            new_dim="feat", sample_dims=["time", "lat", "lon"]
        ).transpose("time", "feat", "lat", "lon")

        # Static inputs are optional; stacked and ordered as C H W when given.
        self.xs = None
        if xs is not None:
            self.xs = xs.to_stacked_array(
                new_dim="feat", sample_dims=["lat", "lon"]
            ).transpose("feat", "lat", "lon")

        if persist:
            self.xd = self.xd.persist()
            self.y = self.y.persist()
            if self.xs is not None:
                self.xs = self.xs.persist()

        self.lstm_1d = lstm_1d
        self.static_to_dynamic = static_to_dynamic

        self.cubelet_indices = cubelet_indices
        # Built once instead of rebuilding the key list on every __getitem__.
        self._cubelet_keys = list(cubelet_indices)

    def __len__(self):
        return len(self._cubelet_keys)

    def __getitem__(self, index):
        """Return ``(xd, xs, y)`` for cubelet ``index`` (``(xd, y)`` without statics)."""
        slices = self.cubelet_indices[self._cubelet_keys[index]]
        time_slice = slices["time"]
        lat_slice = slices["lat"]
        lon_slice = slices["lon"]

        # L C H W
        xd = torch.tensor(self.xd[time_slice, :, lat_slice, lon_slice].values)
        y = torch.tensor(self.y[time_slice, :, lat_slice, lon_slice].values)
        xs = None
        if self.xs is not None:
            # C H W
            xs = torch.tensor(self.xs[:, lat_slice, lon_slice].values)

        if self.lstm_1d:
            # L C H W => L C N => N L C
            xd = xd.flatten(2, 3).permute(2, 0, 1)
            y = y.flatten(2, 3).permute(2, 0, 1)
            if xs is not None:
                # C H W => C N => N C. The previous call passed the tensor
                # itself as start_dim, which raises a TypeError.
                xs = xs.flatten(1, 2).permute(1, 0)

        if xs is None:
            return xd, y

        if self.static_to_dynamic:
            # Repeat along the sequence axis of xd. The previous code repeated
            # by xs.size(0), i.e. the number of static channels.
            if self.lstm_1d:
                xs = xs.unsqueeze(1).repeat(1, xd.size(1), 1)  # N C => N L C
            else:
                xs = xs.unsqueeze(0).repeat(xd.size(0), 1, 1, 1)  # C H W => L C H W
        return xd, xs, y

In [None]:
list(new_sub_mapping.keys())[1]

In [None]:
y = target.chunk("auto")
xs = static.chunk("auto")

In [None]:
dataset = CubeletsDataset(data, xs, y, new_sub_mapping, persist=True, lstm_1d=False, static_to_dynamic=True)

In [None]:
#dataset[1]

In [None]:
# sampler 
from hython.sampler import SubsetRandomSampler, SubsetSequentialSampler

In [None]:
sampler = SubsetRandomSampler(range(len(dataset)))

In [None]:
#sampler = SubsetSequentialSampler(range(len(new_mapping.keys())))

In [None]:
dataloader = DataLoader(dataset = dataset , batch_size = 1, sampler=sampler) # for lstm_1d  the batches are decided by lat*lon so here put batch = 1

In [None]:
len(dataloader)

In [None]:
#next(iter(dataloader)).shape

In [None]:
for tx,ts,ty in dataloader:
    print(tx.shape,ts.shape, ty.shape)
    fig, axs = plt.subplots(1,3, figsize=(5,5))
    axs[0].imshow(tx[0 , 0, 0 , ...])
    axs[1].imshow(ts[0 , 0, 0 , ...])
    axs[2].imshow(ty[0 , 0, 0 , ...])

In [None]:
# change shape for old lstm

res = dataset[1]

In [None]:
# Sample size along time, lat and lon; read by the XBatchDataset cell below.
# NOTE: this rebinds the global name `time`, shadowing the `time` module that
# an earlier cell imports - `time.time()` no longer works after this cell runs.
time, lat ,lon = 360, 32,32

In [None]:
def _build_xbatch_dataset(dynamic_da, target_da, static_da):
    """Create an "XBatchDataset" with the cubelet configuration of this notebook.

    A fresh kwargs dict is built on every call, so the train and test datasets
    never share a mutable configuration object.
    """
    xbatcher_kwargs = {
        "input_dims": {"time": time, "lat": lat, "lon": lon},
        "batch_dims": {"lat": lat, "lon": lon},
        #"input_overlap":{"time":1},
        "concat_input_dims": False,
        "preload_batch": True,
    }
    dataset_cls = get_dataset("XBatchDataset")
    return dataset_cls(
        dynamic_da,
        target_da,
        static_da,
        lstm=False,
        xbatcher_kwargs=xbatcher_kwargs,
    )


train_dataset = _build_xbatch_dataset(dynamic_train, target_train, static_train)
test_dataset = _build_xbatch_dataset(dynamic_test, target_test, static_test)

In [None]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=16)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=16)

In [None]:
len(train_dataloader), len(test_dataloader)

In [None]:
# Shapes of the first three tensors of one batch. The batch is fetched once;
# the previous version created three iterators and loaded three batches.
tuple(tensor.shape for tensor in next(iter(train_dataloader))[:3])

In [None]:
from hython.models.convLSTM import ConvLSTM

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
model = ConvLSTM(
    input_dim =  DYNAMIC_INPUT_SIZE + STATIC_INPUT_SIZE,
    output_dim= OUTPUT_SIZE,
    hidden_dim = (HIDDEN_SIZE),
    kernel_size = KERNEL_SIZE,
    num_layers = NUM_LSTM_LAYER,
    batch_first = True,
    bias = True,
    return_all_layers = False
).to(device)

In [None]:
opt = optim.Adam(model.parameters(), lr=1e-3)
# Halve the learning rate when the monitored value has not improved for 10 steps.
lr_scheduler = ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=10)

# Take target names and weights from the configuration cell instead of repeating
# them here, so changing `target_names` cannot leave the loss/metric out of sync.
# Copies are passed so the configuration objects themselves are never mutated.
loss_fn = RMSELoss(target_weight=dict(TARGET_WEIGHTS))
metric_fn = MSEMetric(target_names=list(target_names))

In [None]:
trainer = XBatcherTrainer(
    RNNTrainParams(
               experiment=experiment, 
               temporal_subsampling=False, 
               temporal_subset=1, 
               seq_length=SEQ_LENGTH, 
               target_names=target_names,
               metric_func=metric_fn,
               loss_func=loss_fn)
)

In [None]:
file_surr_output = f"{dir_surr_output}/{experiment}_{surr_model_output}"

In [None]:
# train
# NOTE(review): the literal 10 is presumably the number of epochs; the
# configuration cell defines EPOCHS = 90, which is not used here - confirm
# which value is intended.
model, loss_history, metric_history = train_val(
    trainer,
    model,
    train_dataloader,
    test_dataloader,
    10,
    opt,
    lr_scheduler,
    file_surr_output,
    device,
)

In [None]:
# Load the checkpoint from the path that was handed to train_val
# (`file_surr_output` evaluates to the path that was previously hard-coded here).
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(file_surr_output, map_location=device))

In [None]:
model

In [None]:
# Build the x-axis from the recorded history rather than from EPOCHS: training
# is launched with a different epoch count, and mismatched lengths make
# matplotlib raise.
n_epochs_recorded = len(loss_history['train'])
lepochs = list(range(1, n_epochs_recorded + 1))

fig, axs = plt.subplots(3, 1, figsize= (12,6), sharex=True)

axs[0].plot(lepochs, metric_history['train_vwc'], marker='.', linestyle='-', color='b', label='Training')
axs[0].plot(lepochs, metric_history['val_vwc'], marker='.', linestyle='-', color='r', label='Validation')
axs[0].set_title('SM')
axs[0].set_ylabel(metric_fn.__class__.__name__)
axs[0].grid(True)
axs[0].legend(bbox_to_anchor=(1,1))

axs[1].plot(lepochs, metric_history['train_actevap'], marker='.', linestyle='-', color='b', label='Training')
axs[1].plot(lepochs, metric_history['val_actevap'], marker='.', linestyle='-', color='r', label='Validation')
axs[1].set_title('ET')
axs[1].set_ylabel(metric_fn.__class__.__name__)
axs[1].grid(True)

# Loss values are tensors; detach and move them to the CPU before plotting.
axs[2].plot(lepochs, [i.detach().cpu().numpy() for i in loss_history['train']], marker='.', linestyle='-', color='b', label='Training')
axs[2].plot(lepochs, [i.detach().cpu().numpy() for i in loss_history['val']], marker='.', linestyle='-', color='r', label='Validation')
axs[2].set_title('Loss')
axs[2].set_xlabel('Epochs')
# Instances do not have `__name__` unless the class defines it explicitly; fall
# back to the class name, as done for the metric labels above.
axs[2].set_ylabel(getattr(loss_fn, "__name__", loss_fn.__class__.__name__))
axs[2].grid(True)

In [None]:
def predict(Xd, Xs, model, batch_size, device):
    """Run ``model`` over the samples in mini-batches and return a numpy array.

    Parameters
    ----------
    Xd, Xs : torch.Tensor
        Dynamic and static inputs; they are concatenated along dimension 2 and
        must therefore agree on every other dimension.
    model : torch.nn.Module
        Called as ``model(batch)``; element 0 of its return value is used and
        is expected to have the batch on its first axis.
    batch_size : int
        Number of samples per forward pass.
    device : torch.device
        Device on which the forward passes run.

    Returns
    -------
    numpy.ndarray
        Outputs of all mini-batches stacked along the first axis.
    """
    model = model.to(device)
    X = torch.concat([Xd, Xs], 2)

    # Inference only: switch to eval mode for the duration of the call and
    # restore the previous mode afterwards.
    was_training = model.training
    model.eval()
    arr = []
    try:
        with torch.no_grad():
            for i in range(0, X.shape[0], batch_size):
                # Feed only the current mini-batch. The previous version ran the
                # model on the full input in every iteration and sliced the
                # output, so batch_size saved neither memory nor time.
                batch = X[i : (i + batch_size)].to(device)
                out = model(batch)[0]
                arr.append(out.detach().cpu().numpy())
    finally:
        model.train(was_training)
    return np.vstack(arr)

In [None]:
output = predict(res[0], res[1], model, batch_size=8,device=device)

In [None]:
import matplotlib.pyplot as plt

In [None]:
output.shape

In [None]:
plt.imshow(output[10,-1,:,:,0])
plt.colorbar()

In [None]:
plt.imshow(res[2][0,-1,0,:,:])
plt.colorbar()

In [None]:
plt.imshow((output[0,-1,:,:,0] - np.array(res[2][0,-1,0,:,:])), cmap="RdBu")
plt.colorbar()

In [None]:
torch

In [None]:
torch.split?