In [1]:
from datetime import date, datetime, timezone

import cftime
import git
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

import config
import util

In [2]:
# script identifier to go in generated file

repo = 'github.com/marbl-ecosys/marbl-forcing'
# NOTE(review): this path names the Fe_aeolian_dep directory although the notebook
# generates ndep files -- confirm it matches where the notebook actually lives
script_fname = 'Fe_aeolian_dep/gen_cesm2_SMYLE_ndep.ipynb'

# hexsha of the commit currently checked out in the enclosing git repository
# (the repository is looked up starting from the current working directory)
checkout = git.Repo(search_parent_directories=True)
sha = checkout.head.object.hexsha

# <repo>/tree/<commit>/<notebook path>, identifying the exact version of this script
id_string = f'{repo}/tree/{sha}/{script_fname}'

In [3]:
# input files

grids = ['gx3v7', 'gx1v6']

dirin = f'{config.inputdata}/ocn/pop'

# one historical and one ssp file name per grid, in the same order as `grids`
fnames_hist = [
    f'{dirin}/{grid}/forcing/ndep_ocn_1850-2000_w_nhx_emis_{grid}_c180926.nc'
    for grid in grids
]
fnames_ssp = [
    f'{dirin}/{grid}/forcing/ndep_ocn_ssp370_w_nhx_emis_{grid}_c190412.nc'
    for grid in grids
]

print(fnames_hist)
print(fnames_ssp)

# datasets are kept in lists that parallel `grids`
ds_hist = list(map(xr.open_dataset, fnames_hist))
ds_ssp = list(map(xr.open_dataset, fnames_ssp))

# assume all files have same yr_range, so get yr_range from first file
# (year of the first and of the last time value)
time_vals_hist = ds_hist[0].time.values
yr_range_hist = (time_vals_hist[0].year, time_vals_hist[-1].year)
print(yr_range_hist)

# assume all files have same yr_range, so get yr_range from first file
time_vals_ssp = ds_ssp[0].time.values
yr_range_ssp = (time_vals_ssp[0].year, time_vals_ssp[-1].year)
print(yr_range_ssp)

# year at which the generated data switches from the historical files to the ssp files
yr_hist_ssp_boundary = 2015

['/glade/p/cesmdata/cseg/inputdata/ocn/pop/gx3v7/forcing/ndep_ocn_1850-2000_w_nhx_emis_gx3v7_c180926.nc', '/glade/p/cesmdata/cseg/inputdata/ocn/pop/gx1v6/forcing/ndep_ocn_1850-2000_w_nhx_emis_gx1v6_c180926.nc']
['/glade/p/cesmdata/cseg/inputdata/ocn/pop/gx3v7/forcing/ndep_ocn_ssp370_w_nhx_emis_gx3v7_c190412.nc', '/glade/p/cesmdata/cseg/inputdata/ocn/pop/gx1v6/forcing/ndep_ocn_ssp370_w_nhx_emis_gx1v6_c190412.nc']
(1849, 2015)
(2014, 2101)


In [4]:
# details on file being generated

# today's local date formatted as YYMMDD; used as the `_c<datestamp>` suffix
# of the generated file names
datestamp = date.today().strftime("%y%m%d")

def yr_start_cycle0_fosi(yr_range_fosi, cycle_cnt_fosi):
    """
    Return the first year of the earliest cycle of a repeated year range.

    yr_range_fosi is an indexable (first_year, last_year) pair, with both years
    included in the cycle length. The result is first_year shifted back by
    (cycle_cnt_fosi - 1) whole cycle lengths, so cycle_cnt_fosi=1 returns
    first_year unchanged.
    """
    first_yr = yr_range_fosi[0]
    last_yr = yr_range_fosi[1]

    # both end years are part of the cycle
    cycle_len = last_yr - first_yr + 1

    # the final cycle is the range itself; the others are stacked before it
    earlier_cycles = cycle_cnt_fosi - 1
    return first_yr - earlier_cycles * cycle_len

def yr_range_fmt(yr_range):
    """
    Format an indexable (first_year, last_year) pair as 'YYYY-YYYY'.

    Each year is zero-padded to at least four digits.
    """
    yr_lo = yr_range[0]
    yr_hi = yr_range[1]
    return f'{yr_lo:04d}-{yr_hi:04d}'

# year range that is cycled, and the first year of the earliest of 6 such cycles
yr_range_JRA = (1958, 2018)
yr_start_cycle0_JRA = yr_start_cycle0_fosi(yr_range_JRA, cycle_cnt_fosi=6)

# generated files start one year before the first cycle and run through 2025
yr_hi_SMYLE = 2025
yr_lo_SMYLE = yr_start_cycle0_JRA - 1
yr_range_SMYLE = (yr_lo_SMYLE, yr_hi_SMYLE)

# one output file name per grid, in the same order as `grids`
fnames_SMYLE = [
    f'ndep_ocn_SMYLE_w_nhx_emis_{grid}_{yr_range_fmt(yr_range_SMYLE)}_c{datestamp}.nc'
    for grid in grids
]

In [5]:
# construct time values for new datasets

def time_vars(yr_range, time_units):
    """
    Build monthly time and time-bounds variables on a 'noleap' calendar.

    Parameters
    ----------
    yr_range : indexable pair (first_year, last_year), both years included
    time_units : units string handed to cftime.date2num and stored as the
        'units' attribute of the time variable.  Month lengths are added as
        day counts, so this presumably has to be a 'days since ...' string
        -- TODO confirm against callers.

    Returns
    -------
    (time_var, time_bnds_var) : xarray.DataArray objects with 12 entries per year.
        time_bnds_var has dims ('time', 'd2') holding each month's start and end;
        time_var holds the mean of each pair of bounds.
    """
    calendar = 'noleap'

    # month lengths of a 365-day year, repeated once per year in the range
    month_lengths = np.tile(
        np.array([31.0, 28.0, 31.0, 30.0, 31.0, 30.0, 31.0, 31.0, 30.0, 31.0, 30.0, 31.0]),
        yr_range[1] - yr_range[0] + 1,
    )

    # month edges: running day count, starting at 0, shifted to Jan 1 of the first year
    edge_offsets = np.concatenate(([0.0], np.cumsum(month_lengths)))
    first_edge = cftime.date2num(cftime.DatetimeNoLeap(yr_range[0], 1, 1), time_units, calendar=calendar)
    time_edges = edge_offsets + first_edge

    # pair consecutive edges into (start, end) rows; time is the middle of each row
    time_bnds_vals = np.column_stack((time_edges[:-1], time_edges[1:]))
    time_vals = time_bnds_vals.mean(axis=1)

    time_attrs = {'long_name':'time', 'units':time_units, 'calendar':calendar, 'bounds':'time_bnds'}
    time_var = xr.DataArray(time_vals, dims='time', coords={'time':time_vals}, attrs=time_attrs)
    time_bnds_var = xr.DataArray(time_bnds_vals, dims=('time', 'd2'), coords={'time':time_var})

    return time_var, time_bnds_var

# time axis and bounds for the full SMYLE year range, expressed in the units found
# in the encoding of the first historical dataset's time variable
time_var_SMYLE, time_bnds_var_SMYLE = time_vars(yr_range_SMYLE, ds_hist[0].time.encoding['units'])

In [6]:
# construct new datasets, preserving grid and domain variables from ds_hist

# Creation date written to each file's 'history' attribute.  It is computed once,
# before the loop, so every generated file records the same date, and it has its
# own name so that the notebook-level `datestamp` (YYMMDD, used in the output file
# names) is not overwritten with a differently formatted string.
history_datestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")

for grid_ind, grid in enumerate(grids):
    print(f'grid = {grid}')

    ds_out = xr.Dataset({'time': time_var_SMYLE, 'time_bnds': time_bnds_var_SMYLE})

    for varname, var_in in ds_hist[grid_ind].data_vars.items():
        # only variables defined on the horizontal (Y, X) grid are carried over
        if 'Y' not in var_in.dims or 'X' not in var_in.dims:
            continue

        # time-independent grid/domain variables are copied unchanged
        if 'time' not in var_in.dims:
            print(f'copying {varname}')
            ds_out[varname] = var_in
            continue

        print(f'processing {varname}')

        # NOTE(review): .sel with a label slice includes the end label, so the
        # slices below rely on the inputs having no time value on Jan 1 of the
        # end year -- confirm against the input files.

        # initial years use first year of ds_hist[grid_ind], repeated once for
        # every year between the start of the SMYLE range and that first year
        # (the (n, 1, 1) repetition assumes dims are ordered (time, Y, X) -- TODO confirm)
        yr_hist = yr_range_hist[0]
        var_in_slice = var_in.sel(time=slice(f'{yr_hist:04d}-01-01', f'{(yr_hist+1):04d}-01-01'))
        var_out_prehist_vals = np.tile(var_in_slice, (yr_hist - yr_range_SMYLE[0], 1, 1))

        # historical file supplies its first year up to the hist/ssp boundary
        var_out_hist_slice = var_in.sel(time=slice(f'{yr_hist:04d}-01-01', f'{yr_hist_ssp_boundary:04d}-01-01'))

        # ssp file supplies the boundary year through the last SMYLE year
        var_out_posthist_slice = ds_ssp[grid_ind][varname].sel(time=slice(f'{yr_hist_ssp_boundary:04d}-01-01', f'{(yr_hi_SMYLE+1):04d}-01-01'))

        # stack the three segments along the leading (time) axis
        var_out_vals = np.concatenate((var_out_prehist_vals, var_out_hist_slice.values, var_out_posthist_slice.values))

        # xarray raises here if the stacked length does not match time_var_SMYLE
        var_out = xr.DataArray(var_out_vals, dims=var_in.dims,
                               coords={'time':time_var_SMYLE, 'Y':ds_hist[grid_ind].Y, 'X':ds_hist[grid_ind].X})
        var_out.attrs = var_in.attrs
        # carry the input's missing_value / _FillValue (read from its encoding) to the output
        var_out.attrs['missing_value'] = var_in.encoding['missing_value']
        var_out.encoding['_FillValue'] = var_in.encoding['_FillValue']
        ds_out[varname] = var_out

    # provenance: generating script + date, and the input files used for this grid
    ds_out.attrs['history'] = f'created by {id_string} on {history_datestamp}'
    ds_out.attrs['input_file_list'] = ' '.join([fnames_hist[grid_ind], fnames_ssp[grid_ind]])

    # unlimited_dims is documented as an iterable of dimension names, so pass a list
    util.ds_clean(ds_out).to_netcdf(fnames_SMYLE[grid_ind], unlimited_dims=['time'])


grid = gx3v7
copying ULAT
copying ULONG
copying TAREA
copying REGION_MASK
copying KMT
processing NOy_deposition
processing NHx_deposition
grid = gx1v6
copying ULAT
copying ULONG
copying TAREA
copying REGION_MASK
copying KMT
processing NOy_deposition
processing NHx_deposition
