# MERRA2 Super Resolution

A simple notebook for testing assorted super-resolution (SR) models on MERRA2 data.

## Preparation

In [1]:
%matplotlib ipympl
from typing import Dict

import hydra
import torch
import xarray as xa
from hydra.core.global_hydra import GlobalHydra

from fmod.base.io.loader import BaseDataset
from fmod.base.util.config import cfg, fmconfig
from fmod.controller.dual_trainer import ModelTrainer
from fmod.data.dataset import Datasets
from fmod.model.sres.manager import SRModels
from fmod.view.sres import create_plot_data, mplplot

In [2]:
# Hydra keeps a process-wide singleton: calling hydra.initialize() a second time
# (e.g. when this cell is re-run) raises "GlobalHydra is already initialized".
# Clearing it first keeps the cell idempotent.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
hydra.initialize(version_base=None, config_path="../config")

# Experiment selection, passed to fmconfig (assumed to select/compose the matching
# configuration -- see fmod.base.util.config).
task = "sres"
model = "mscnn"
dataset = "merra2"
scenario = "s1"
fmconfig(task, model, dataset, scenario)
# lgm().set_level( logging.DEBUG )

# Checkpoint options, passed to trainer.train() in the training section.
load_state = "best"
save_state = True

# Training hyperparameters written into the active task config.
cfg().task['nepochs'] = 100
cfg().task['lr'] = 1e-4

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    torch.cuda.set_device(device.index)

print(cfg().model.name)

### Training data
Prepare the low-resolution input dataset and the high-resolution target dataset, which compute results on the fly:

In [None]:
# Build a single dataset manager for the selected dataset, then request two
# datasets from it: low-resolution inputs only, and high-resolution targets only.
dataset_mgr: Datasets = Datasets(dataset, cfg().task)
input_dataset: BaseDataset = dataset_mgr.get_dataset(
    vres="low",
    load_inputs=True,
    load_base=False,
    load_targets=False,
)
target_dataset: BaseDataset = dataset_mgr.get_dataset(
    vres="high",
    load_inputs=False,
    load_base=False,
    load_targets=True,
)

In [6]:
# Hand the input/target datasets and the chosen device to the SR model manager,
# and wrap that manager in a trainer.
model_manager: SRModels = SRModels(input_dataset, target_dataset, device)
trainer: ModelTrainer = ModelTrainer(model_manager)

## Training the model

In [None]:
# Checkpoint options come from the configuration cell
# (load_state = "best", save_state = True).
trainer.train(
    load_state=load_state,
    save_state=save_state,
)

In [10]:
# Apply the model at date index 0 and unpack the inputs, targets and predictions
# that are plotted in the next cell.
inputs, targets, predictions = trainer.apply(date_index=0)

In [None]:
# Collect the arrays to plot, keyed by name. The model manager's sample input and
# target are passed along as well (presumably as coordinate templates -- confirm
# in fmod.view.sres). Then plot them.
pdata: Dict[str, xa.DataArray] = create_plot_data(
    inputs,
    targets,
    predictions,
    model_manager.sample_input,
    model_manager.sample_target,
)
mplplot(pdata, fsize=8.0)

In [13]:
# NOTE(review): stale, commented-out plotting cell. It references `inp`, `out`,
# `tar`, `gridops` and `plt`, none of which are defined or imported in the visible
# notebook, so uncommenting it as-is would raise NameError.
# The four stanzas below are copy-pasted variants of one plot (input, prediction,
# target, error), all sharing the input's color range.
# TODO: either delete this cell, or port it to the `inputs` / `targets` /
# `predictions` returned by trainer.apply() and factor the stanzas into a helper.
# s = 0; ch = 0
# print( f'Input shape: {inp.shape}, Output shape: {out.shape}, type = {type(inp)} ')
# input_image = inp[s, ch]
# vmin, vmax = gridops.color_range(input_image,2.0)
# 
# fig = plt.figure()
# im = gridops.plot_griddata(input_image, fig, projection='3d', title='input', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata(out[s, ch], fig, projection='3d', title='prediction', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata(tar[s, ch], fig, projection='3d', title='target', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata((tar-out)[s, ch], fig, projection='3d', title='error', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()