# Cape Basin Super Resolution

A simple notebook for testing assorted super-resolution (SR) models on the Cape Basin dataset.

## Preparation

In [1]:
%matplotlib ipympl
import torch
import xarray as xa, numpy as np
import hydra, os
from datetime import datetime
from typing import Dict, List
from fmod.view.sres import mplplot, create_plot_data
from fmod.base.util.config import fmconfig, cfg
from fmod.controller.dual_trainer import ModelTrainer
from fmod.controller.dual_trainer import LearningContext
from fmod.model.sres.manager import SRModels
from fmod.data.batch import BatchDataset
os.environ["PYTORCH_CUDA_ALLOC_CONF"]="expandable_segments:True"

In [2]:
from hydra.core.global_hydra import GlobalHydra

# hydra.initialize() raises if Hydra is already initialized in this process, so
# clear any earlier initialization to keep this cell re-runnable in one kernel.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
hydra.initialize(version_base=None, config_path="../config")

# Select the configuration: task / model / dataset / scenario.
task = "sres"
model = "edsr"
dataset = "LLC4320"
scenario = "s1"
fmconfig(task, model, dataset, scenario)
# Optional verbose logging. NOTE(review): `lgm` and `logging` are not imported
# in this notebook; import them before enabling this line.
# lgm().set_level( logging.DEBUG )

# Run knobs, passed to trainer.train() / trainer.evaluate() below.
load_state = "current"
save_state = True
cfg().task['nepochs'] = 3  # overrides the configured epoch count for this run
eval_tileset = LearningContext.Validation

# Prefer the first GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    torch.cuda.set_device(device.index)

print( cfg().model.name )

### Training data
Prepare the paired input (`vres="low"`) and target (`vres="high"`) datasets, which compute their batches on the fly:

In [None]:
# Two datasets built from the same task configuration, differing only in `vres`:
#   input_dataset  -> vres="low"  (model inputs)
#   target_dataset -> vres="high" (targets the model should reproduce)
input_dataset: BatchDataset = BatchDataset(cfg().task, vres="low")
target_dataset: BatchDataset = BatchDataset(cfg().task, vres="high")

In [6]:
# Build the SR model manager from the paired datasets and wrap it in a trainer.
# The manager also supplies a sample input/target pair, which the plotting cells
# below pass to create_plot_data.
model_manager: SRModels = SRModels(input_dataset, target_dataset, device)
trainer: ModelTrainer = ModelTrainer(model_manager)
sample_input: xa.DataArray = model_manager.get_sample_input()
sample_target: xa.DataArray = model_manager.get_sample_target()

## Training the model

In [None]:
# Train the model; load_state / save_state come from the configuration cell.
# The result is annotated as a Dict[str, float] of losses.
train_losses: Dict[str, float] = trainer.train(
    load_state=load_state,
    save_state=save_state,
)

In [None]:
# Plot the trainer's current input/target/product/upsampled data next to the
# training losses. When train_losses is empty, mplplot receives an empty dict.
pdata: Dict[str, xa.DataArray] = {}
if len(train_losses) > 0:
    inp = trainer.get_current_input()
    targ = trainer.get_current_target()
    prod = trainer.get_current_product()
    ups = trainer.get_current_upsampled()
    pdata = create_plot_data(inp, targ, prod, ups, sample_input, sample_target)
    # Also include the full target timeslice under the 'domain' key.
    pdata['domain'] = target_dataset.load_global_timeslice()
mplplot(pdata, LearningContext.Training, fsize=6.0, losses=train_losses)

## Validating the model

In [None]:
# Evaluate on the tileset chosen in the configuration cell.
eval_losses = trainer.evaluate(
    eval_tileset
)

In [None]:
# Plot the trainer's current input/target/product/upsampled data (presumably
# left over from the evaluation above) next to the evaluation losses.
# The plot is labeled with eval_tileset, the tileset that was actually evaluated,
# so it stays correct if eval_tileset is changed in the configuration cell.
inp = trainer.get_current_input()
targ = trainer.get_current_target()
prod = trainer.get_current_product()
ups = trainer.get_current_upsampled()
pdata: Dict[str, xa.DataArray] = create_plot_data(inp, targ, prod, ups, sample_input, sample_target)
mplplot(pdata, eval_tileset, fsize=6.0, losses=eval_losses)