In [None]:
import numpy as np
import xarray as xr
import torch
from src.train_nn_pytorch import Dataset

# Choose the compute device once and make newly created tensors default to
# float32 on that device.
use_cuda = torch.cuda.is_available()
print('using CUDA !' if use_cuda else "CUDA not available")
device = torch.device("cuda" if use_cuda else "cpu")
torch.set_default_tensor_type(
    "torch.cuda.FloatTensor" if use_cuda else "torch.FloatTensor"
)

# Root directory the input .nc files are read from (one sub-directory per field),
# and root directory the computed statistics are written to.
datadir = '/gpfs/work/nonnenma/data/forecast_predictability/weatherbench/5_625deg/'
res_dir = '/gpfs/work/nonnenma/results/forecast_predictability/weatherbench/5_625deg/'


# multi-level fields

In [None]:
# One single-entry dict per multi-level field:
#   {dataset directory name: (short variable name, pressure levels)}
fields = [
    {'geopotential': ('z', [1, 10, 100, 200, 300, 400, 500, 600, 700, 850, 1000])},
    {'temperature': ('t', [1, 10, 100, 200, 300, 400, 500, 600, 700, 850, 1000])},
    {'u_component_of_wind': ('u', [1, 10, 100, 200, 300, 400, 500, 600, 700, 850, 1000])},
    {'v_component_of_wind': ('v', [1, 10, 100, 200, 300, 400, 500, 600, 700, 850, 1000])},
]

for field in fields:

    # Each entry holds exactly one field; pull out its name and level list once.
    var_dict = field
    field_name = next(iter(var_dict))
    levels = var_dict[field_name][1]
    print(var_dict.keys())

    # Open every .nc file of the field lazily and merge into a single dataset;
    # values missing after the merge are filled with 0.
    datasets = [
        xr.open_mfdataset(f'{datadir}/{var}/*.nc', combine='by_coords')
        for var in var_dict
    ]
    x = xr.merge(datasets, fill_value=0)

    # Re-chunk so the full time axis forms one chunk; lat/lon chunking is unchanged.
    time_chunk = np.sum(x.chunks['time'])
    x = x.chunk({'time': time_chunk, 'lat': x.chunks['lat'], 'lon': x.chunks['lon']})

    # Build the dataset on the 1979-2015 slice only so its `mean`/`std` attributes
    # can be read below. The field itself (at its first listed level) is passed as
    # the target.
    # NOTE(review): presumably `Dataset` computes mean/std over exactly this slice
    # when normalize=True - confirm in src.train_nn_pytorch.
    dg_train = Dataset(
        x.sel(time=slice('1979', '2015')),
        var_dict,
        lead_time=72,
        normalize=True,
        norm_subsample=1,
        target_vars=[field_name],
        target_levels=[levels[0]],
    )

    print(dg_train.mean.values, dg_train.std.values, dg_train.level_names)

    # Store the statistics in the field's own sub-directory of the results root.
    dg_train.mean.to_netcdf(f'{res_dir}{field_name}/mean_1979_2015.nc')
    dg_train.std.to_netcdf(f'{res_dir}{field_name}/std_1979_2015.nc')


# constant fields

In [None]:
# Constant fields: every variable listed here is read from the single
# 'constants' sub-directory of `datadir`.
var_dict = {'constants': ['lat2d','lon2d', 'lsm','orography', 'slt']} # note alphabetical ordering
x = xr.merge(
    [xr.open_mfdataset(f'{datadir}/{var}/*.nc', combine='by_coords')
     for var in var_dict.keys()],
    fill_value=0  # values missing after the merge are filled with 0
)
x = x.chunk({'lat' : x.chunks['lat'], 'lon': x.chunks['lon']})

# Give each constant a length-1 'level' axis and broadcast it along the time
# axis of the geopotential_500 files, so it has (time, level, ...) dimensions.
generic_level = xr.DataArray([1], coords={'level': [1]}, dims=['level'])
time = xr.open_mfdataset(f'{datadir}/geopotential_500/*.nc', combine='by_coords').time.values

level_names = []  # variable name for each entry along the concatenated 'level' axis
expanded = []
for _, params in var_dict.items():
    for var in params:
        level_names.append(var)
        expanded.append(x[var].expand_dims(
            {'level': generic_level, 'time': time}, (1, 0)
        ).astype(np.float32))
data = xr.concat(expanded, 'level')

# Label the 'level' axis with the variable names. The names come from `var_dict`
# (this cell's own definition), not from the `field` loop variable left over
# from the multi-level cell, which has no 'constants' key and raised a KeyError.
# assign_coords is used rather than writing to `.values` because recent xarray
# versions may reject in-place assignment on a dimension coordinate.
data = data.assign_coords(level=level_names)

# Per-variable statistics over all of time, lat and lon.
const_mean = data.mean(('time', 'lat', 'lon')).compute()
const_std = data.std(('time', 'lat', 'lon')).compute()

# Write to <res_dir>/constants/. The original built this path from the stale
# `field` variable, which would have overwritten the v_component_of_wind files.
# The file names are kept as-is for consistency with the multi-level fields,
# although the time axis used here is not restricted to 1979-2015.
out_dir = res_dir + next(iter(var_dict))
const_mean.to_netcdf(out_dir + '/mean_1979_2015.nc')
const_std.to_netcdf(out_dir + '/std_1979_2015.nc')