# MERRA2 UNET

A simple notebook for testing the UNet architecture on MERRA-2 data: the model is trained to map low-resolution inputs to high-resolution targets.


## Preparation

In [1]:
%matplotlib ipympl
import logging, os, time
from typing import Any, Dict, List, Tuple, Type, Optional, Union, Sequence, Mapping

import torch
import xarray as xa
import hydra
from hydra.core.global_hydra import GlobalHydra

from fmod.base.util.logging import lgm, exception_handled, log_timing
from fmod.base.util.config import configure, cfg, start_date,  cfg2args, get_roi
from fmod.plot.sres import mplplot, create_plot_data
from fmod.plot.training_results import ResultsPlotter
from fmod.model.sres.emul_unet.trainer import ModelTrainer
from fmod.model.sres.emul_unet.unet import UNet
from fmod.controller.ncbatch import ncBatchDataset

In [2]:
# Hydra keeps process-global state: a second `hydra.initialize` call in the same kernel raises
# ValueError("GlobalHydra is already initialized ..."). Clear it first so this cell can be re-run.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
hydra.initialize(version_base=None, config_path="../config")
configure('merra2-unet-s1')
# Uncomment for verbose logging:
# lgm().set_level( logging.DEBUG )

# State options forwarded to trainer.train(). An empty load_state presumably means
# "train from scratch"; "best" is the alternative value noted here — confirm in ModelTrainer.
load_state = "" # "best"
save_state = True
# TODO(review): no random seed is set in this notebook; seed torch here if runs must be
# reproducible (unless ModelTrainer already does so).

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    torch.cuda.set_device(device.index)

### Training data
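Prepare the low-resolution input and high-resolution target datasets (batches are computed on the fly), build the trainer, and inspect one sample input/target pair: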

In [None]:
# Low-resolution dataset: loads only the model inputs (no base fields, no targets).
input_dataset   = ncBatchDataset( cfg().task, vres="low",  load_inputs=True,  load_base=False, load_targets=False )
sample_batch: Dict[str, xa.DataArray] = input_dataset.get_batch( input_dataset.train_dates[0] )
# Show only the dimension sizes rather than dumping the full xarray repr.
print( f"sample_batch['input'] sizes: {dict(sample_batch['input'].sizes)}" )
# Store the region of interest derived from the input coordinates in the global task config
# before the high-res dataset is built — presumably so the targets cover the same region.
cfg().task['roi'] = get_roi( sample_batch['input'].coords )
# High-resolution dataset: loads only the targets.
target_dataset  = ncBatchDataset( cfg().task,  vres="high", load_inputs=False, load_base=False, load_targets=True )
trainer = ModelTrainer( input_dataset, target_dataset, device )
# Fetch one paired input/target sample through the trainer to check shapes before training.
results = trainer.get_batch( input_dataset.train_dates[0] )
sample_input: xa.DataArray = results['input']
sample_target: xa.DataArray = results['target']
print( f"sample_input: {sample_input.shape}")
print( f"sample_target: {sample_target.shape}")

In [6]:
# Axis 1 of the sample input is used as the UNet's input-channel count
# (assumes a (batch, channel, ...) layout — TODO confirm against ModelTrainer.get_batch).
inchannels: int = sample_input.shape[1]
model = UNet( inchannels, cfg().model )
model = model.to( device )

## Training the model

In [None]:
# Train the model; save_state / load_state are the state options set in the config cell.
trainer.train( model, save_state=save_state, load_state=load_state )

In [10]:
# Run the trained model for date index 0, keeping inputs, targets and predictions for plotting.
(inputs, targets, predictions) = trainer.apply( date_index=0 )

In [None]:
# Assemble the named arrays to display, then render them with mplplot.
pdata: Dict[str, xa.DataArray] = create_plot_data(
    inputs, targets, predictions,
    sample_input, sample_target,
)
mplplot( pdata )

In [13]:
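# NOTE(review): this cell is entirely commented out and stale. It references `inp`, `out`, `tar`,
# `gridops` and `plt`, none of which are defined in this notebook. Before re-enabling it, adapt it to
# `inputs` / `predictions` / `targets` (presumably the equivalents) and import the plotting modules,
# or delete it.
# s = 0; ch = 0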
# print( f'Input shape: {inp.shape}, Output shape: {out.shape}, type = {type(inp)} ')
# input_image = inp[s, ch]
# vmin, vmax = gridops.color_range(input_image,2.0)
# 
# fig = plt.figure()
# im = gridops.plot_griddata(input_image, fig, projection='3d', title='input', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata(out[s, ch], fig, projection='3d', title='prediction', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata(tar[s, ch], fig, projection='3d', title='target', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata((tar-out)[s, ch], fig, projection='3d', title='error', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()