In [None]:
%matplotlib ipympl
from fmod.base.plot.image import mplplot
from fmod.base.source.merra2.model import load_const_dataset, load_merra2_norm_data, load_dataset
import xarray as xa
import hydra, dataclasses
import logging, torch
from typing import List, Union, Tuple, Optional, Dict, Type
from fmod.base.util.logging import lgm, exception_handled, log_timing
from fmod.base.util.ops import print_norms, vars3d
from fmod.base.util.dates import date_list
from fmod.base.util.config import configure, cfg, cfg_date, cfg2args, pp
from fmod.pipeline.merra2 import MERRA2Dataset

# Session configuration for this notebook.
# Initialize hydra with the relative config directory "../config".
hydra.initialize(version_base=None, config_path="../config")
# Select the 'merra2-sr' configuration (presumably composed through the hydra
# context initialized above — TODO confirm against fmod.base.util.config).
configure('merra2-sr')
# Override the task's device setting to CPU for this session.
# NOTE(review): the `device` variable computed later in this cell prefers CUDA
# when it is available; confirm the two settings are meant to differ.
cfg().task.device = "cpu"
# Raise verbosity of the logger returned by lgm() to DEBUG.
lgm().set_level( logging.DEBUG )

def nnan(varray: torch.Tensor) -> int:
    """Return the number of NaN elements contained in ``varray``."""
    nan_mask = torch.isnan(varray)
    return nan_mask.sum().item()
def pctnan(varray: torch.Tensor) -> str:
    """Return the share of NaN elements in ``varray`` as a percentage string.

    The value is formatted with two decimals and a trailing ``%``
    (e.g. ``"12.50%"``). An empty tensor contains no NaNs, so it is reported
    as ``"0.00%"`` instead of raising ``ZeroDivisionError``.
    """
    total = torch.numel(varray)
    if total == 0:
        # Nothing to count; avoid dividing by zero.
        return "0.00%"
    nan_count = torch.isnan(varray).sum().item()
    return f"{nan_count * 100.0 / total:.2f}%"

# Pick the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    # Make the chosen GPU the current CUDA device for this process.
    torch.cuda.set_device(device.index)
else:
    device = torch.device('cpu')

In [None]:
# Build the low-vertical-resolution MERRA2 dataset over the configured date range.
train_dates = date_list( cfg_date('task'), cfg().task.max_steps )
dataset = MERRA2Dataset( train_dates=train_dates, vres="low" )
data_iter = iter(dataset)
norm_data: Dict[str, xa.Dataset] = load_merra2_norm_data()

# Inspect only the first (input, target) pair produced by the dataset.
try:
    inp, tar = next(data_iter)
except StopIteration:
    # The dataset yielded nothing: there is no sample to report or plot.
    pass
else:
    print(f" ** inp shape={inp.shape}, pct-nan= {pctnan(inp)}")
    print(f" ** tar shape={tar.shape}, pct-nan= {pctnan(tar)}")
    pvars: List[str] = vars3d(inp)
    # Plot at most the first five 3D variables of the input.
    mplplot( inp, pvars[:5], norms=norm_data )