In [None]:
import numpy as np
import xarray as xr

datadir = '/gpfs/work/nonnenma/data/forecast_predictability/weatherbench/5_625deg/'
res_dir = '/gpfs/work/nonnenma/results/forecast_predictability/weatherbench/5_625deg/'

# Pressure levels kept for every multi-level field (presumably hPa — TODO confirm).
_pressure_levels = (1, 10, 100, 200, 300, 400, 500, 600, 700, 850, 1000)

# Maps each data sub-directory to (short variable name, levels to keep).
# Key order matters: the fields are stacked into channels in this order.
var_dict = {
    long_name: (short_name, list(_pressure_levels))  # fresh list per entry
    for long_name, short_name in (
        ('geopotential', 'z'),
        ('temperature', 't'),
        ('u_component_of_wind', 'u'),
        ('v_component_of_wind', 'v'),
        ('specific_humidity', 'q'),
        ('relative_humidity', 'r'),
        ('vorticity', 'vo'),
        ('potential_vorticity', 'pv'),
    )
}
# Single-level fields carry a [None] placeholder instead of a level list.
var_dict.update({
    'total_cloud_cover': ('tcc', [None]),
    'total_precipitation': ('tp', [None]),
    'toa_incident_solar_radiation': ('tisr', [None]),
    # Time-invariant fields: a bare list of variable names, not a (name, levels) pair.
    'constants': ['lsm', 'orography', 'lat2d'],
})

In [None]:
# Lazily open every sub-directory listed in var_dict (including 'constants') and merge them into
# one Dataset. fill_value=0 is used for values that become missing when merge aligns the inputs
# (instead of NaN); this is needed for 'tisr', whose coordinates presumably differ from the
# other fields' (TODO confirm).
x = xr.merge(
    [
        xr.open_mfdataset(f'{datadir}/{long_name}/*.nc', combine='by_coords')
        for long_name in var_dict
    ],
    fill_value=0,
)
x

In [None]:
# Stack all fields along 'level' so every channel (each pressure level of each variable, each
# single-level field, each constant) becomes one index on that axis, and record a readable
# name per channel in `level_names`.
ds, dtype = x, np.float32
data = []
level_names = []
# Length-1 'level' axis for fields without a vertical dimension.
generic_level = xr.DataArray([1], coords={'level': [1]}, dims=['level'])
for long_var, params in var_dict.items():
    if long_var == 'constants':
        # Constants have no time axis: broadcast them along `time` and insert a level axis
        # (time at axis 0, level at axis 1).
        for var in params:
            data.append(ds[var].expand_dims(
                {'level': generic_level, 'time': ds.time}, (1, 0)
            ).astype(dtype))
            level_names.append(var)
        continue
    var, levels = params
    # Branch on the field's dims explicitly instead of catching the ValueError that
    # .sel(level=...) raises for single-level fields: an unrelated ValueError must not be
    # mistaken for "this field has no vertical dimension".
    # Every field (not only the constants) is cast to `dtype`: concatenating mixed
    # float32/float64 inputs would silently upcast the whole stack and double its memory.
    if 'level' in ds[var].dims:
        data.append(ds[var].sel(level=levels).astype(dtype))
        level_names += [f'{var}_{level}' for level in levels]
    else:
        data.append(ds[var].expand_dims({'level': generic_level}, 1).astype(dtype))
        level_names.append(var)
data = xr.concat(data, 'level')
# The z-scoring step reshapes per-channel statistics to (1, -1, 1, 1), i.e. it expects
# 'level' to be axis 1.
assert data.dims[1] == 'level', data.dims
# Attach the names positionally along 'level'. This avoids index alignment, which matters
# because the 'level' coordinate contains duplicate values (generic level 1 is reused).
data = data.assign_coords(level_names=('level', level_names))

In [None]:
# Display the stacked DataArray (its dims, the 'level' coordinate and 'level_names').
data

In [None]:
# Materialise the stacked array as one in-memory numpy array. The fields come from
# open_mfdataset, so this is where the data are actually read; it is the expensive step.
allvalues = np.asarray(data)
allvalues.shape

In [None]:
# Cache the raw (not yet normalised) stacked array; np.save appends the '.npy' extension.
np.save(f'{datadir}5_625deg_all', allvalues, allow_pickle=False)

In [None]:
# Save the channel names next to the array. Coerce them to a fixed-width unicode array and
# forbid pickling (as for the data file), so the file always loads with np.load's default
# allow_pickle=False.
np.save(datadir + '5_625deg_all_level_names',
        np.asarray(data['level_names'].values, dtype=str),
        allow_pickle=False)

In [None]:
# Load the per-channel normalisation statistics computed on the training years and check that
# their channel order matches the stacked `data`.
# Project-local helper: it must be imported here, otherwise a fresh kernel raises NameError.
from src.pytorch.Dataset import load_mean_std

train_years = ('1979', '2015')

mean, std, level, level_names = load_mean_std(res_dir, var_dict, train_years)
# Single-level fields may come back as '<var>_None' (var_dict stores their level as None),
# whereas the stacking cell names them plain '<var>' (it uses a generic level 1 for them).
# Strip the suffix so the two name lists are comparable.
level_names = [name[:-len('_None')] if name.endswith('_None') else name for name in level_names]
np.testing.assert_array_equal(np.asarray(level_names), data['level_names'].values)
# mean/std are reshaped to (1, -1, 1, 1) for z-scoring, so they need exactly one entry per channel.
assert np.size(mean) == np.size(std) == data.sizes['level'], (np.shape(mean), np.shape(std))
mean.shape, std.shape, len(level_names)

In [None]:
# Z-score every channel in place: (x - mean) / std, broadcasting the per-channel statistics
# over axis 1. This is done in place because allvalues hardly fits into memory twice.
# In-place updates make the cell non-idempotent, so the array is frozen afterwards. A second
# run then raises instead of silently normalising twice (and the zscored file being saved
# with wrong values).
if not allvalues.flags.writeable:
    raise RuntimeError('allvalues is read-only, presumably because it was already z-scored; '
                       'reload it (e.g. re-run `allvalues = data.values`) before normalising again.')
allvalues -= mean.reshape(1, -1, 1, 1)
allvalues /= std.reshape(1, -1, 1, 1)
allvalues.flags.writeable = False

In [None]:
# Persist the per-channel z-scored stack; np.save appends the '.npy' extension.
np.save(f'{datadir}5_625deg_all_zscored', allvalues, allow_pickle=False)