#### Generate river input files on ocean grid

Currently the CESM ocean component (either MOM6 or POP) reads in riverine fluxes and adds them to surface fluxes returned from MARBL.
Matt Long has river inputs on the `rJRA025` and `rx1` runoff grids.
For each pair of runoff and ocean grids, we use the same runoff -> ocean map as CESM.

In [None]:
from datetime import datetime
import os

import xarray as xr
import numpy as np
from scipy.sparse import csc_matrix 

# Date stamp (YYYYMMDD, local time) that is embedded in the output file name
date_str = datetime.now().strftime("%Y%m%d")

# Report package versions and the date stamp, one item per line
print(
    f'xarray: {xr.__version__}',
    f'numpy: {np.__version__}',
    f'date string: {date_str}',
    sep='\n',
)

In [None]:
# Runoff (source) grid and ocean (destination) grid used by the cells below.
# Alternative ocean grid: 'tx0.66v1'
src_grid, dst_grid = 'rJRA025', 'tx2_3v2'

In [None]:
# Runoff grids handled by this notebook ('rx1' is currently disabled)
rof_grids = ['rJRA025']

# Lookup tables: runoff_data is keyed by runoff grid; the other three are
# keyed by runoff grid first and ocean grid second
runoff_data = {}
mapping_dict = {grid: {} for grid in rof_grids}
mapping_suffix = {grid: {} for grid in rof_grids}
out_file = {grid: {} for grid in rof_grids}

# Nutrient file on the runoff grid:
# /glade/work/mclong/cesm_inputdata/work/river_nutrients.GNEWS_GNM.JRA55.20190602.nc
runoff_data['rJRA025'] = os.path.join(
    os.sep, 'glade', 'work', 'mclong', 'cesm_inputdata', 'work',
    'river_nutrients.GNEWS_GNM.JRA55.20190602.nc',
)

# Runoff -> ocean mapping files, plus the suffix that identifies each map
root_dir = os.path.join(
    os.sep, 'glade', 'campaign', 'cesm', 'cesmdata', 'inputdata', 'cpl', 'gridmaps',
)

suffix = 'nnsm_e333r100_190910'
mapping_suffix['rJRA025']['tx0.66v1'] = suffix
mapping_dict['rJRA025']['tx0.66v1'] = os.path.join(
    root_dir, 'rJRA025', f'map_JRA025m_to_tx0.66v1_{suffix}.nc',
)

suffix = 'nnsm_e333r100_230415'
mapping_suffix['rJRA025']['tx2_3v2'] = suffix
mapping_dict['rJRA025']['tx2_3v2'] = os.path.join(
    os.sep, 'glade', 'work', 'gmarques', 'cesm', 'tx2_3', 'runoff_mapping',
    f'map_jra_to_tx2_3_{suffix}.nc',
)

# Output netCDF name for every (runoff grid, ocean grid) pair that has a map
for rof_grid in rof_grids:
    for ocn_grid in mapping_dict[rof_grid]:
        out_file[rof_grid][ocn_grid] = f'riv_nut.gnews_gnm.{rof_grid}_to_{ocn_grid}_{mapping_suffix[rof_grid][ocn_grid]}.{date_str}.nc'

# Bare expression so the notebook displays the file name that will be written
out_file[src_grid][dst_grid]

In [None]:
# River flux variables that are read from the runoff file and remapped
out_vars = [
            'din_riv_flux',
            'dip_riv_flux',
            'don_riv_flux',
            'dop_riv_flux',
            'dsi_riv_flux',
            'dfe_riv_flux',
            'dic_riv_flux',
            'alk_riv_flux',
            'doc_riv_flux'
           ]

# Mapping file for the selected (runoff grid, ocean grid) pair
ds_map = xr.open_dataset(mapping_dict[src_grid][dst_grid])
# Runoff-grid data restricted to the variables listed in out_vars;
# decode_times=False asks xarray not to decode the time coordinate
ds_roff = xr.open_dataset(runoff_data[src_grid], decode_times=False)[out_vars]
# Bare expression so the notebook displays the dataset
ds_roff

In [None]:
# Subroutines for unit conversion to check global integrals after mapping

r = 6371220  # radius of earth in m
# The mapping file stores cell areas in rad^2; scaling by r^2 converts to m^2
native_area = ds_map['area_a'].data * r**2  # source (runoff) grid cells
mapped_area = ds_map['area_b'].data * r**2  # destination (ocean) grid cells
def compute_global_integral(native_values, mapped_values, conv_factor, var, print_stat=False):
    """Print the global integral of a field before and after mapping.

    Both integrals are area-weighted sums that use the module-level
    ``native_area`` and ``mapped_area`` arrays.  The native sum is additionally
    multiplied by 0.01 (nmol/cm^2 -> mmol/m^2) so that it is comparable with
    the mapped sum.  Nothing is computed unless ``print_stat`` is true.

    Parameters
    ----------
    native_values : array-like
        Values on the source grid, flattened to match ``native_area``.
    mapped_values : array-like
        Values on the destination grid, flattened to match ``mapped_area``.
    conv_factor : float
        Extra factor applied to both sums (conversion to reporting units).
    var : str
        Label used in the printed line.
    print_stat : bool, optional
        When false (the default) the function returns immediately.

    Returns
    -------
    None
        The statistics are only printed.
    """
    if not print_stat:
        return

    global_sum_native = np.sum(native_values*native_area*(0.01*conv_factor))  # 0.01 nmol/cm^2 -> mmol/m^2
    global_sum_mapped = np.sum(mapped_values*mapped_area*conv_factor)
    if global_sum_native == 0:
        # A zero reference (e.g. conv_factor == 0) would give 0/0: report an
        # exact match as 0 and any mismatch as infinite instead of NaN
        rel_err = 0.0 if global_sum_mapped == 0 else np.inf
    else:
        rel_err = np.abs((global_sum_native-global_sum_mapped)/global_sum_native)
    print(f'{var} stats: sums are {global_sum_native:.3e} (native) and {global_sum_mapped:.3e} (mapped); rel_err is {rel_err:.3e}')

In [None]:
# Total number of points on each grid (product of the grid dimensions)
src_grid_size = ds_map['src_grid_dims'].data.prod()
dst_grid_size = ds_map['dst_grid_dims'].data.prod()

# 2-D array shapes: entries 1 and 0 of the grid dims, in that order
src_grid_shape = tuple(ds_map['src_grid_dims'].data[[1, 0]])
dst_grid_shape = tuple(ds_map['dst_grid_dims'].data[[1, 0]])

# Sparse (destination x source) weight matrix; row/col from the mapping file
# are shifted by one to become zero-based indices
mapping_weights_sparse = csc_matrix(
    (ds_map['S'].data, (ds_map['row'].data - 1, ds_map['col'].data - 1)),
    shape=(dst_grid_size, src_grid_size),
)

def map_var(ds_src, var_name, scale_factor=0.01, weights=mapping_weights_sparse):
    """Map one variable from the runoff grid to the ocean grid, time level by time level.

    Each time slice of ``ds_src[var_name]`` is flattened, multiplied by
    ``scale_factor`` and by the sparse mapping matrix, and reshaped to
    ``dst_grid_shape``.  For the first and last time level a global-integral
    comparison is printed through ``compute_global_integral`` (never for
    ``alk_riv_flux``).

    Parameters
    ----------
    ds_src : xarray.Dataset
        Dataset on the runoff grid; must provide ``var_name`` and ``time``.
    var_name : str
        Name of the variable to map.
    scale_factor : float, optional
        Factor applied to the native values before mapping.  The default of
        0.01 matches the nmol/cm^2 -> mmol/m^2 factor used by
        ``compute_global_integral``.
    weights : sparse matrix, optional
        Mapping weights with shape (dst_grid_size, src_grid_size).

    Returns
    -------
    xarray.DataArray
        Mapped field with dims ('time', 'y', 'x'), the time values of
        ``ds_src``, the source attributes with ``units`` set to 'mmol/m^2/s',
        and the ``_FillValue`` encoding set to None.

    Notes
    -----
    ``compute_global_integral`` hard-codes the 0.01 factor and the
    module-level cell areas, so the printed check is only meaningful for the
    default ``scale_factor`` and ``weights``.
    """
    # Conversion of the global integral to reporting units.  It depends only
    # on the variable, so it is selected once rather than on every time step.
    if var_name in ('din_riv_flux', 'don_riv_flux'):
        # NOTE(review): the atomic mass of N is ~14 g/mol, so 28 looks like it
        # doubles the reported total (rel_err is unaffected) -- confirm
        conv_factor = 28*1e-15*86400*365  # 28 mg / mmol, 1e-15 Tg / mg, 86400 s/d, 365 d/yr
    elif var_name in ('dic_riv_flux', 'doc_riv_flux'):
        conv_factor = 12*1e-18*86400*365  # 12 mg / mmol, 1e-18 Pg / mg, 86400 s/d, 365 d/yr
    elif var_name in ('dip_riv_flux', 'dop_riv_flux', 'dsi_riv_flux'):
        conv_factor = 1e-15*86400*365  # 1e-15 Tmol / mmol, 86400 s/d, 365 d/yr
    elif var_name == 'dfe_riv_flux':
        conv_factor = 1e-12*86400*365  # 1e-12 Gmol / mmol, 86400 s/d, 365 d/yr
    elif var_name == 'alk_riv_flux':
        conv_factor = None  # no global-integral check for alkalinity
    else:
        conv_factor = 0  # unknown variable: sums are reported as zero

    n_time = len(ds_src['time'])
    da_list = []
    for t_idx in range(n_time):
        # Use the dataset and weights passed in, not the notebook globals
        native_values = ds_src[var_name].isel(time=t_idx).data.reshape(src_grid_size)
        mapped_values = weights @ (scale_factor*native_values)
        da_list.append(xr.DataArray(mapped_values.reshape(dst_grid_shape), dims=('y', 'x'), name=var_name))
        if conv_factor is not None:
            # Statistics are printed for the first and last time level only
            compute_global_integral(native_values, mapped_values, conv_factor, var_name,
                                    print_stat=t_idx in (0, n_time-1))

    da = xr.concat(da_list, dim='time').assign_coords({'time': ds_src['time'].data})
    da.attrs = ds_src[var_name].attrs
    da.attrs['units'] = 'mmol/m^2/s'
    da.encoding['_FillValue'] = None
    return da

In [None]:
# Seed the output dataset with the first variable ...
da_out = map_var(ds_roff, out_vars[0])
ds = da_out.to_dataset()
# ... then map and add every variable that is not in the dataset yet
for var in out_vars:
    if var not in ds:
        ds[var] = map_var(ds_roff, var)

In [None]:
# Time attributes: start from the runoff file's and mark time as the T axis
ds['time'].attrs = {**ds_roff['time'].attrs, 'axis': 'T', 'cartesian_axis': 'T'}
# No _FillValue for the time coordinate when the file is written
ds['time'].encoding['_FillValue'] = None
# Bare expression so the notebook displays the assembled dataset
ds

In [None]:
# Quick look at the first time level of the mapped DIN flux, using decade-spaced levels
ds['din_riv_flux'].isel(time=0).plot(
    levels=[0, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4],
)

In [None]:
# Write the mapped fluxes, with time as the unlimited (record) dimension
ds.to_netcdf(path=out_file[src_grid][dst_grid], unlimited_dims=['time'])