# LapSRN

A simple notebook for testing the LapSRN (Laplacian Pyramid Super-Resolution Network) model on MERRA-2 reanalysis data.


## Preparation

In [1]:
%matplotlib ipympl
import torch, math
from models.sres.legacy.lapsrn.network import LapSrnMS
import xarray as xa
import hydra
from typing import Dict, List
from fmod.plot.sres import mplplot, create_plot_data
from fmod.base.util.config import configure, cfg, coerce_to_data_grid
from fmod.pipeline.dual_trainer import ModelTrainer
from fmod.pipeline.ncbatch import ncBatchDataset

In [2]:
# Hydra refuses to initialize twice in one process, so clear any existing
# global state first; this lets the cell be re-run in the same kernel.
from hydra.core.global_hydra import GlobalHydra

if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
hydra.initialize(version_base=None, config_path="../config")
configure('merra2-lapsrn-s1')
# Uncomment for debug logging (note: `lgm` and `logging` are not imported in this notebook):
# lgm().set_level( logging.DEBUG )

# Forwarded to trainer.train(load_state=..., save_state=...) below.
# NOTE(review): presumably "" means "do not load a saved state" -- confirm in ModelTrainer.train.
load_state = ""
save_state = True
# Number of epochs; replaces whatever value configure() placed in the task config.
cfg().task['nepochs'] = 100

# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    torch.cuda.set_device(device.index)

### Training data
Prepare the low-resolution input dataset and the high-resolution target dataset, which load batches on the fly:

In [None]:
# Dataset that loads only the low-resolution ("low") inputs.
input_dataset = ncBatchDataset(
    cfg().task,
    vres="low",
    load_inputs=True,
    load_base=False,
    load_targets=False,
)
# Fetch the input batch for the first training date and hand it to coerce_to_data_grid.
# NOTE(review): the return value is discarded, so this is presumably called for a side
# effect (e.g. setting up grid information) -- confirm before reordering these lines.
sample_batch: xa.DataArray = input_dataset.get_batch(input_dataset.train_dates[0])['input']
coerce_to_data_grid(sample_batch)
# Dataset that loads only the high-resolution ("high") targets.
target_dataset = ncBatchDataset(
    cfg().task,
    vres="high",
    load_inputs=False,
    load_base=False,
    load_targets=True,
)

In [None]:
# Build the trainer from both datasets, then fetch one sample (not as tensors)
# for the first training date and report its input/target shapes.
trainer = ModelTrainer(input_dataset, target_dataset, device)
results = trainer.get_batch(input_dataset.train_dates[0], as_tensor=False)
sample_input: xa.DataArray = results['input']
sample_target: xa.DataArray = results['target']
print(f"sample_input: shape={sample_input.shape}")
print(f"sample_target: shape={sample_target.shape}")

In [None]:

# Model hyper-parameters, taken from the `model` section of the configuration.
model_cfg = cfg().model
scale_factors: List[int] = model_cfg.upscale_factors
# Overall upscaling factor is the product of the configured upscale factors.
scale: int = math.prod(scale_factors)
# NOTE(review): assumes axis 1 of sample_input is the channel axis -- confirm.
nchannels: int = sample_input.shape[1]
nfeatures: int = model_cfg.nfeatures
rdepth: int = model_cfg.rdepth
rlayers: int = model_cfg.rlayers

In [6]:
# Instantiate the multi-scale LapSRN network, then move it to the selected device.
model = LapSrnMS(nchannels, nfeatures, rdepth, rlayers, scale)
model = model.to(device)

## Training the model

In [None]:
# Train the model; load_state and save_state are set in the configuration cell.
trainer.train(
    model,
    load_state=load_state,
    save_state=save_state,
)

In [10]:
# Apply the trained model via trainer.apply at date index 0; the result unpacks
# into (inputs, targets, predictions).
# NOTE(review): these are plotted together with sample_input/sample_target, which came
# from input_dataset.train_dates[0]; confirm date_index=0 refers to the same date.
inputs, targets, predictions = trainer.apply(date_index=0)

In [None]:
# Assemble the arrays to display, then draw them with matplotlib.
pdata: Dict[str, xa.DataArray] = create_plot_data(
    inputs, targets, predictions, sample_input, sample_target
)
mplplot(pdata, fsize=8.0)

In [13]:
# DISABLED: legacy 3D plots of input / prediction / target / error for sample `s`, channel `ch`.
# NOTE(review): this code refers to `inp`, `out`, `tar`, `gridops` and `plt`, none of which
# are defined or imported in the cells above. Before re-enabling, map them to the current
# variables (presumably `inputs`, `predictions`, `targets`) and import the plotting helpers.
# s = 0; ch = 0
# print( f'Input shape: {inp.shape}, Output shape: {out.shape}, type = {type(inp)} ')
# input_image = inp[s, ch]
# vmin, vmax = gridops.color_range(input_image,2.0)
# 
# fig = plt.figure()
# im = gridops.plot_griddata(input_image, fig, projection='3d', title='input', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata(out[s, ch], fig, projection='3d', title='prediction', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata(tar[s, ch], fig, projection='3d', title='target', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata((tar-out)[s, ch], fig, projection='3d', title='error', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()