In [None]:
# Reload edited modules (e.g. the local `config` and `util`) automatically before running code.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import cftime
import numpy as np
import xarray as xr

import matplotlib.pyplot as plt

import config
import util

In [None]:
# Open the curated catalog of flux products and display it.
flux_products = util.curate_flux_products()
catalog = flux_products.open_catalog()
catalog

In [None]:
# Keys of the catalog entries, in the order yielded by `_entries.keys()`.
cat_keys = [entry_key for entry_key in catalog._entries.keys()]
catalog._entries

In [None]:
# Output directory for the generated netCDF files, under the current user's scratch space.
USER = os.environ['USER']
scratch_dir = f'/glade/scratch/{USER}'
dirout = f'{scratch_dir}/sno-flux-products'
os.makedirs(dirout, exist_ok=True)

In [None]:
# Destination lat-lon grid built from the config settings; only the `area` variable is selected.
grid_kwargs = config.config_dict["flux-dst-grid-kwargs"]
dso_grid = util.generate_latlon_grid(**grid_kwargs)[["area"]]
dso_grid

In [None]:
# Start the (presumably Dask) cluster and client via the project helper and scale the cluster to 12.
cluster_size = 12
cluster, client = util.get_ClusterClient()
cluster.scale(cluster_size)

client

In [None]:
def get_varname_from_key(key, dsi):
    """Return the data-variable name in `dsi` that corresponds to catalog `key`.

    The candidate name is the part of `key` before the first '.'. If that name
    is not a data variable of `dsi`, the '_ocn'-suffixed name is tried next.

    Parameters
    ----------
    key : str
        Catalog entry key; only the text before the first '.' is used.
    dsi : xarray.Dataset
        Dataset that should contain the variable (anything with `.data_vars`).

    Returns
    -------
    str
        The first candidate name present in `dsi.data_vars`.

    Raises
    ------
    KeyError
        If neither candidate is a data variable of `dsi`. An explicit raise is used
        instead of `assert` so the check survives `python -O` and explains itself.
    """
    base_name = key.split('.')[0]
    candidates = (base_name, f'{base_name}_ocn')
    for candidate in candidates:
        if candidate in dsi.data_vars:
            return candidate
    raise KeyError(
        f"none of {candidates} found in data_vars for catalog key '{key}'"
    )

In [None]:
# Build a daily-resolution file for every catalog product over `year_range`.
# Monthly climatologies are first tiled into a multi-year monthly series; every
# product is then interpolated onto the daily time axis and written to disk.
year_range = 1999, 2018

# units used to encode the daily time axis in the output files
time_units = 'days since 1990-01-01 00:00:00'

# the target daily axis is the same for every product, so build it once
daily_time, daily_time_bnds = util.gen_daily_cftime_coord(year_range)


def _output_path(key, is_climatology):
    """Return the output netCDF path in `dirout` for catalog entry `key`."""
    tag = '1x1.repeat_monclim' if is_climatology else '1x1'
    return f'{dirout}/{key}.{tag}.{year_range[0]}0101-{year_range[1]}1231.nc'


def _repeat_monthly_climatology(dsi, varname, year_range):
    """Tile the climatology of `varname` over `year_range` padded by one year each side.

    The padding keeps the monthly series bracketing the whole daily target axis.
    Returns a new Dataset holding the time-independent variables of `dsi`, the
    tiled `varname`, and numeric `time` / `time_bnds` variables encoded in the
    original time units of `dsi`.
    """
    first_year, last_year = year_range[0] - 1, year_range[1] + 1
    n_years = last_year - first_year + 1
    da = dsi[varname]
    units = dsi.time.units

    # repeat the climatology once per year along its own time axis
    # (do not assume time is the leading dimension)
    data = np.concatenate([da.data] * n_years, axis=da.dims.index('time'))

    monthly_time, monthly_time_bounds = util.gen_midmonth_cftime_coord([first_year, last_year])
    monthly_time_num = xr.DataArray(
        cftime.date2num(monthly_time, units=units),
        dims=('time',),
        attrs={'units': units, 'bounds': 'time_bnds'},
    )
    monthly_time_bounds_num = xr.DataArray(
        cftime.date2num(monthly_time_bounds, units=units),
        dims=('time', 'd2'),
    )

    # keep only variables without a time dimension, then attach the tiled series
    dso = dsi[[v for v in dsi.variables if 'time' not in dsi[v].dims]]
    dso[varname] = xr.DataArray(
        data,
        dims=da.dims,
        coords={'time': monthly_time_num},
        attrs=da.attrs,
        name=varname,
    )
    dso['time_bnds'] = monthly_time_bounds_num
    dso['time'] = monthly_time_num
    return dso


def _interp_to_daily(dsi, varname):
    """Interpolate `varname` onto the daily axis and attach time/grid metadata."""
    # interpolate in the dataset's own numeric time units
    daily_time_num = cftime.date2num(daily_time, units=dsi.time.units)

    # other time-dependent data variables are dropped rather than interpolated
    drop_list = [v for v in dsi.data_vars if 'time' in dsi[v].dims and v != varname]
    dsi_daily = dsi.drop_vars(drop_list).interp(time=daily_time_num)

    dsi_daily['time'] = daily_time
    dsi_daily.time.encoding['units'] = time_units
    dsi_daily[daily_time.bounds] = daily_time_bnds
    dsi_daily['time_components'] = util.gen_time_components_variable(daily_time)

    # grid variables come from the destination grid
    for v in ['area', 'lat', 'lon']:
        dsi_daily[v] = dso_grid[v]
    return dsi_daily


dsets = {}
for key in cat_keys:
    dsi = catalog[key].to_dask()
    varname = get_varname_from_key(key, dsi)
    is_climatology = 'climatology' in dsi.time.attrs
    file_out = _output_path(key, is_climatology)

    if is_climatology:
        dsi = _repeat_monthly_climatology(dsi, varname, year_range)

    dsi_daily = _interp_to_daily(dsi, varname)

    flux_units = dsi_daily[varname].attrs.get('units')
    assert flux_units == 'mol/m^2/s', f"{key}: expected units 'mol/m^2/s', got {flux_units!r}"

    util.to_netcdf_clean(dsi_daily, file_out)
    dsets[key] = dsi_daily

dsi_daily

In [None]:
# Build two latitude-band area masks (lat >= 20 labelled 'NET', lat <= -20 labelled 'SET'),
# stacked along a new `region` dimension, and plot the first as a visual check.
dsi = dsets[cat_keys[0]]
region = xr.DataArray(['NET', 'SET'], dims='region', name='region')

north_area = dsi.area.where(dsi.lat >= 20)
south_area = dsi.area.where(dsi.lat <= -20)
masked_area = xr.concat([north_area, south_area], dim=region)

masked_area.isel(region=0).plot()

In [None]:
# Sanity-check the daily products for one year: print global and regional annual
# totals and plot the regional totals (as monthly rates) for each product.
sel_year = '2009'

# unit conversions: mol -> Tmol, and per-second -> per-day
TMOL_PER_MOL = 1e-12
SEC_PER_DAY = 86400.
# absolute latitude bounding the two regions (lat >= +20 and lat <= -20)
LAT_CUTOFF = 20

for key in cat_keys:
    print(key)

    dsi = dsets[key].sel(time=sel_year)
    v = get_varname_from_key(key, dsi)
    assert dsi[v].units in ['mol/m^2/s', 'mol m-2 s-1']

    masked_area = xr.concat(
        [
            dsi.area.where(dsi.lat >= LAT_CUTOFF),
            dsi.area.where(dsi.lat <= -LAT_CUTOFF),
        ],
        dim=region,
    )

    fig, ax = plt.subplots()

    # flux * area summed over the grid, then each daily value * seconds/day summed
    # over the year -> annual total (assumes `area` is in m^2 — TODO confirm)
    global_sum = (
        (dsi[v] * dsi.area).sum(['lat', 'lon']).sum('time').values.item()
        * TMOL_PER_MOL * SEC_PER_DAY
    )
    print(f'GLB {key}: {global_sum:.4f} Tmol/yr')

    with xr.set_options(keep_attrs=True):
        dsi_region = (dsi[[v]] * masked_area).sum(['lat', 'lon']) * TMOL_PER_MOL * SEC_PER_DAY
        dsi_region[v].attrs['units'] = 'Tmol day$^{-1}$'
        # express the daily rate per month using a 365/12-day month
        dsi_mon = dsi_region * 365. / 12.
        dsi_mon[v].attrs['units'] = 'Tmol month$^{-1}$'

    for region_name in dsi_region.region.values:
        regional_sum = dsi_region[v].sel(region=region_name, drop=True).sum('time').values.item()
        dsi_mon[v].sel(region=region_name, drop=True).plot(ax=ax, label=region_name)
        print(f'{region_name} {key}: {regional_sum:.4f} Tmol/yr')

    ax.set_title(f'{key} ({sel_year})')
    ax.legend()
    print()

In [None]:
# Release compute resources: close the client first, then the cluster.
for resource in (client, cluster):
    resource.close()