# MERRA2 Super Resolution

A simple notebook for testing assorted super-resolution (SR) models on MERRA-2 data.

## Preparation

In [1]:
%matplotlib ipympl
import logging, torch, math
from fmod.base.util.logging import lgm, exception_handled, log_timing
from fmod.models.sres.util import *
import torch.nn as nn
import xarray as xa
import hydra, os, time
from typing import Any, Dict, List, Tuple, Type, Optional, Union, Sequence, Mapping, Callable
from fmod.plot.sres import mplplot, create_plot_data
from fmod.base.util.config import fmconfig, cfg
from fmod.pipeline.dual_trainer import ModelTrainer
from fmod.base.io.loader import BaseDataset
from fmod.models.sres.manager import SRModels
from fmod.pipeline.data import Datasets

In [2]:
# Configure Hydra and the fmod task/model/dataset/scenario for this run.
# Hydra allows only one global initialization per process, so re-running this
# cell in a live kernel would raise. Clear any prior instance first to keep the
# cell idempotent (a fresh kernel is unaffected: nothing is initialized yet).
from hydra.core.global_hydra import GlobalHydra
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
hydra.initialize(version_base=None, config_path="../config")

task = "sres"
model = "mscnn"
dataset = "merra2"
scenario = "s1"
fmconfig(task, model, dataset, scenario)
# lgm().set_level( logging.DEBUG )   # uncomment to enable DEBUG-level logging

# Checkpoint handling passed to trainer.train(): "best" resumes from the best
# saved checkpoint (see the "Loaded model from ...best.pt" log in the training cell).
load_state = "best"
save_state = True

# Training overrides applied on top of the Hydra task config.
# NOTE(review): the exact meaning of nbatches / batch_iter is defined by the
# trainer. Presumably they are batches per epoch and iterations per batch; confirm.
cfg().task['nbatches'] = 1
cfg().task['batch_iter'] = 10
cfg().task['nepochs'] = 2
cfg().task['lr'] = 1e-4

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    torch.cuda.set_device(device.index)

print(cfg().model.name)

 *** Configuration sres-mscnn-merra2-s1 initialized *** 

  --------- Opening log file:  '/explore/nobackup/projects/ilab/data/FMF/cache/logs/sres-mscnn-merra2-s1.log' ---------  

mscnn


### Training data
Prepare the low-resolution input dataset and the high-resolution target dataset; data are loaded on the fly:

In [3]:
# Build the dataset manager for the configured dataset, then request two views:
# low-resolution inputs only, and high-resolution targets only.
dataset_mgr: Datasets = Datasets(dataset, cfg().task)
input_dataset: BaseDataset = dataset_mgr.get_dataset(
    load_inputs=True,
    load_targets=False,
    load_base=False,
    vres="low",
)
target_dataset: BaseDataset = dataset_mgr.get_dataset(
    load_inputs=False,
    load_targets=True,
    load_base=False,
    vres="high",
)


Task start date: 01/01/1995: [1995, 1, 1]

Task start date: 01/01/1995: [1995, 1, 1]


In [4]:
# Create the SR model manager from the input/target datasets on the selected
# device, then wrap it in a trainer. The load_dataset log lines below are
# emitted while this cell runs.
model_manager: SRModels = SRModels(
    input_dataset,
    target_dataset,
    device,
)
trainer: ModelTrainer = ModelTrainer(model_manager)

 * load_dataset[low](1995-01-01) [lon[64]:[-120.00,37.50:2.50],lat[64]:[-60.00,66.00:2.00]] nts=4 /explore/nobackup/projects/ilab/data/FMF/processed/merra2.sr.1/1995-1-1.us4.nc
 * load_dataset[low](1995-01-02) [lon[64]:[-120.00,37.50:2.50],lat[64]:[-60.00,66.00:2.00]] nts=4 /explore/nobackup/projects/ilab/data/FMF/processed/merra2.sr.1/1995-1-2.us4.nc
 * load_dataset[low](1995-01-03) [lon[64]:[-120.00,37.50:2.50],lat[64]:[-60.00,66.00:2.00]] nts=4 /explore/nobackup/projects/ilab/data/FMF/processed/merra2.sr.1/1995-1-3.us4.nc
 * load_dataset[low](1995-01-04) [lon[64]:[-120.00,37.50:2.50],lat[64]:[-60.00,66.00:2.00]] nts=4 /explore/nobackup/projects/ilab/data/FMF/processed/merra2.sr.1/1995-1-4.us4.nc
 * load_dataset[low](1995-01-05) [lon[64]:[-120.00,37.50:2.50],lat[64]:[-60.00,66.00:2.00]] nts=4 /explore/nobackup/projects/ilab/data/FMF/processed/merra2.sr.1/1995-1-5.us4.nc
 * load_dataset[low](1995-01-06) [lon[64]:[-120.00,37.50:2.50],lat[64]:[-60.00,66.00:2.00]] nts=4 /explore/nobackup

## Training the model

In [5]:
# Train, resuming/saving checkpoints per the config cell. The value returned by
# trainer.train is displayed as the cell result (it appears to be a loss value).
train_result = trainer.train(save_state=save_state, load_state=load_state)
train_result

Loaded model from /explore/nobackup/projects/ilab/data/FMF/results/checkpoints/merra2-mscnn.1.best.pt, loss = 0.28
Epoch 1503/1504: 
  ----------- Epoch 1503/1504   ----------- 
 ** Loss[1995-01-01:0]:  0.02840  [0.0094,0.0190]
 ** Loss[1995-01-01:1]:  0.02801  [0.0093,0.0188]
 ** Loss[1995-01-01:2]:  0.02767  [0.0091,0.0185]
 ** Loss[1995-01-01:3]:  0.02743  [0.0090,0.0184]
 ** Loss[1995-01-01:4]:  0.02732  [0.0090,0.0183]
 ** Loss[1995-01-01:5]:  0.02734  [0.0090,0.0184]
 ** Loss[1995-01-01:6]:  0.02744  [0.0090,0.0185]
 ** Loss[1995-01-01:7]:  0.02760  [0.0090,0.0186]
 ** Loss[1995-01-01:8]:  0.02778  [0.0091,0.0187]
 ** Loss[1995-01-01:9]:  0.02785  [0.0090,0.0188]
Saving current model to /explore/nobackup/projects/ilab/data/FMF/results/checkpoints/merra2-mscnn.1.current.pt
   ---- Saving best model (loss=0.2768) to /explore/nobackup/projects/ilab/data/FMF/results/checkpoints/merra2-mscnn.1.best.pt
Epoch 1502, time: 4.8, loss: 0.27685 [0.0090,0.0188]
Epoch 1504/1504: 
  -----------

0.2737642154097557

In [6]:
# Apply the trained model at date index 0 (presumably the first date in the
# dataset) and keep inputs, targets, and predictions for plotting below.
inputs, targets, predictions = trainer.apply(date_index=0)

 * in: [64, 1, 64, 64], target: [64, 1, 256, 256], out: [64, 1, 256, 256]


In [7]:
# Package the arrays as labelled DataArrays and render them in an interactive
# (ipympl) figure; the returned widget is the cell's displayed output.
# NOTE(review): the sample input comes from model_manager but the sample target
# from trainer. Confirm this asymmetry is intended.
pdata: Dict[str, xa.DataArray] = create_plot_data(
    inputs,
    targets,
    predictions,
    model_manager.sample_input,
    trainer.sample_target,
)
mplplot(pdata, fsize=8.0)

Plotting 4 images, sample('batch', 'channels', 'lat', 'lon'): (64, 1, 64, 64)


VBox(children=(Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Ba…

In [8]:
# NOTE(review): this cell is entirely commented out and would not run as-is.
# It refers to `inp`, `out`, `tar`, `gridops` and `plt`, none of which are
# defined or imported in the visible notebook. The apply cell above produces
# `inputs`, `targets`, `predictions`. TODO: confirm whether these correspond to
# inp/tar/out, then either revive this cell or delete it.
# NOTE(review): the 'error' panel reuses the input's vmin/vmax, which may not
# suit the range of an error field. Confirm intent.
# s = 0; ch = 0
# print( f'Input shape: {inp.shape}, Output shape: {out.shape}, type = {type(inp)} ')
# input_image = inp[s, ch]
# vmin, vmax = gridops.color_range(input_image,2.0)
# 
# fig = plt.figure()
# im = gridops.plot_griddata(input_image, fig, projection='3d', title='input', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata(out[s, ch], fig, projection='3d', title='prediction', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata(tar[s, ch], fig, projection='3d', title='target', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata((tar-out)[s, ch], fig, projection='3d', title='error', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()