In [1]:
# Automatically reload edited local modules (e.g. `config`, `util`) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os

import cftime
import numpy as np
import xarray as xr

import matplotlib.pyplot as plt

import config
import util

In [3]:
# Curate the flux-product collection via the project helper, then open it as an intake catalog.
catalog = (
    util.curate_flux_products()
    .open_catalog()
)
catalog

flux_products-catalog-local:
  args:
    path: catalogs/flux_products-catalog-local.yml
  description: Flux products for transport modeling
  driver: intake.catalog.local.YAMLFileCatalog
  metadata: {}


In [4]:
# Entry names, in catalog order; iterating an intake catalog yields entry names
# via its public API, so no need to reach into the private `_entries` dict.
cat_keys = list(catalog)

# Show each entry's spec (urlpath, xarray_kwargs, ...) for reference;
# depth=1 keeps the listing to this catalog's own entries.
catalog.walk(depth=1)

{'SFAPO.carboscope.apo99X_v2021': name: SFAPO.carboscope.apo99X_v2021
 container: xarray
 plugin: ['netcdf']
 driver: ['netcdf']
 description: APO fluxes from CarboScope inversion apo99X_v2021
 direct_access: forbid
 user_parameters: []
 metadata: 
 args: 
   urlpath: /glade/work/mclong/sno-analysis/flux-products/SFAPO.CarboScope.apo99X_v2021.nc
   xarray_kwargs: 
     decode_times: False,
 'SFCO2_FF.GCP-GridFED.v2021.3': name: SFCO2_FF.GCP-GridFED.v2021.3
 container: xarray
 plugin: ['netcdf']
 driver: ['netcdf']
 description: Gridded fossil CO2 emissions consistent with national inventories 1959-2020
 direct_access: forbid
 user_parameters: []
 metadata: 
 args: 
   urlpath: /glade/work/mclong/sno-analysis/flux-products/SFCO2_FF.GCP-GridFED.v2021.3.19840101-20201231.nc
   xarray_kwargs: 
     decode_times: False,
 'SFCO2_FF.OCO2-MIP.v2020.1': name: SFCO2_FF.OCO2-MIP.v2020.1
 container: xarray
 plugin: ['netcdf']
 driver: ['netcdf']
 description: daily integral of Fossil Fuel CO₂ Emis

In [5]:
USER = os.environ['USER']
# Per-user scratch location that receives the daily flux files written below.
scratch_root = f'/glade/scratch/{USER}'
dirout = f'{scratch_root}/sno-flux-products'
os.makedirs(dirout, exist_ok=True)

In [6]:
# Destination lat/lon grid built from the project config; only the cell-area
# field is kept (lat/lon remain available as coordinates and are copied onto
# every output dataset later).
dso_grid = util.generate_latlon_grid(
    **config.config_dict["flux-dst-grid-kwargs"]
)[["area"]]
dso_grid

In [7]:
# Obtain a dask cluster + client from the project helper and scale the cluster
# to 12 workers; the client repr is shown for its dashboard link.
cluster, client = util.get_ClusterClient()
cluster.scale(n=12)
client

Perhaps you already have a cluster running?
Hosting the HTTP server on port 37493 instead
  f"Port {expected} is already in use.\n"


0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/mclong/calcs/proxy/37493/status,

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/mclong/calcs/proxy/37493/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.12.206.60:44746,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/mclong/calcs/proxy/37493/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [8]:
def get_varname_from_key(key, dsi):
    """Return the flux variable name in `dsi` for catalog entry `key`.

    The candidate name is the text before the first '.' of `key`
    (e.g. 'SFAPO' for 'SFAPO.carboscope.apo99X_v2021'). If `dsi` has no data
    variable with that name, the '_OCN'-suffixed form is used instead.

    Raises
    ------
    AssertionError
        If neither candidate is a data variable of `dsi`.
    """
    base, _, _ = key.partition('.')
    candidate = base if base in dsi.data_vars else base + '_OCN'
    assert candidate in dsi.data_vars
    return candidate

In [None]:
year_range = 1986, 2020

time_units = 'days since 1980-01-01 00:00:00'

# The daily target time axis is identical for every product, so build it once
# outside the loop rather than on every iteration.
daily_time, daily_time_bnds = util.gen_daily_cftime_coord(year_range)


def _repeat_monthly_climatology(dsi, varname, year_range):
    """Tile a monthly climatology of `varname` over `year_range`, padded by one year on each side.

    The extra year on each side gives the daily interpolation samples to
    interpolate between at the very start and end of the period of interest.

    Parameters
    ----------
    dsi : xarray.Dataset
        Product dataset with a numeric ``time`` carrying a ``units`` attribute
        (presumably opened with ``decode_times=False`` as in the catalog
        entries). Assumes ``time`` is the leading dimension of ``dsi[varname]``
        and that it spans one climatological year -- TODO confirm.
    varname : str
        Name of the flux variable to tile.
    year_range : tuple of int
        First and last year of the period of interest (inclusive).

    Returns
    -------
    xarray.Dataset
        The time-independent variables of `dsi`, plus `varname` tiled along
        time, a numeric ``time`` coordinate (from
        ``util.gen_midmonth_cftime_coord``) expressed in ``dsi.time.units``,
        and the matching ``time_bnds``.
    """
    time_units_in = dsi.time.units
    first_year, last_year = year_range[0] - 1, year_range[1] + 1
    n_years = last_year - first_year + 1

    var_attrs = dsi[varname].attrs
    dims = dsi[varname].dims
    # One copy of the climatology per year, stacked along the leading (time) axis.
    data = np.concatenate([dsi[varname].data] * n_years, axis=0)

    monthly_time, monthly_time_bounds = util.gen_midmonth_cftime_coord([first_year, last_year])
    monthly_time_num = xr.DataArray(
        cftime.date2num(monthly_time, units=time_units_in),
        dims='time',
        attrs={'units': time_units_in, 'bounds': 'time_bnds'},
    )
    monthly_time_bounds_num = xr.DataArray(
        cftime.date2num(monthly_time_bounds, units=time_units_in),
        dims=('time', 'd2'),
    )

    # Keep only time-independent variables (grid coordinates, area, ...); the
    # time-dependent ones are rebuilt on the new time axis below.
    dso = dsi[[v for v in dsi.variables if 'time' not in dsi[v].dims]]
    dso[varname] = xr.DataArray(
        data,
        dims=dims,
        coords={'time': monthly_time_num},
        attrs=var_attrs,
        name=varname,
    )
    dso['time_bnds'] = monthly_time_bounds_num
    dso['time'] = monthly_time_num
    return dso


dsets = {}
for key in cat_keys:
    dsi = catalog[key].to_dask()
    varname = get_varname_from_key(key, dsi)
    is_climatology = 'climatology' in dsi.time.attrs

    if is_climatology:
        file_out = f'{dirout}/{key}.1x1.repeat_monclim.{year_range[0]}0101-{year_range[1]}1231.nc'
        # A monthly climatology is repeated over the period of interest ± 1 year.
        dsi = _repeat_monthly_climatology(dsi, varname, year_range)
    else:
        file_out = f'{dirout}/{key}.1x1.{year_range[0]}0101-{year_range[1]}1231.nc'

    # Daily target times expressed in this product's own numeric time units.
    daily_time_num_data = cftime.date2num(daily_time, units=dsi.time.units)

    # Interpolate to daily; other time-dependent data variables (e.g. time_bnds)
    # are dropped so that only `varname` is interpolated.
    drop_list = [v for v in dsi.data_vars if 'time' in dsi[v].dims and v != varname]
    dsi_daily = dsi.drop_vars(drop_list).interp(time=daily_time_num_data)

    # Swap the numeric time for the cftime axis and attach its bounds.
    dsi_daily['time'] = daily_time
    dsi_daily.time.encoding['units'] = time_units
    dsi_daily[daily_time.bounds] = daily_time_bnds

    dsi_daily['time_components'] = util.gen_time_components_variable(daily_time)

    # Copy grid fields from the destination grid -- assumes the product is
    # already on the same lat/lon grid as dso_grid; TODO confirm.
    for v in ['area', 'lat', 'lon']:
        dsi_daily[v] = dso_grid[v]

    assert dsi_daily[varname].attrs['units'] == 'mol/m^2/s', (
        f"{key}: unexpected units {dsi_daily[varname].attrs.get('units')!r}"
    )

    util.to_netcdf_clean(dsi_daily, file_out, format='NETCDF4')
    dsets[key] = dsi_daily

dsi_daily

------------------------------
Writing /glade/scratch/mclong/sno-flux-products/SFAPO.carboscope.apo99X_v2021.1x1.19860101-20201231.nc


In [None]:
dsi = dsets[cat_keys[0]]

# Extratropical cell-area masks stacked along a new 'region' dimension:
# NET keeps cells with lat >= 20, SET keeps cells with lat <= -20.
region = xr.DataArray(['NET', 'SET'], dims='region', name='region')
masked_area = xr.concat(
    [dsi.area.where(dsi.lat >= 20), dsi.area.where(dsi.lat <= -20)],
    dim=region,
)
masked_area.sel(region='NET').plot()

In [None]:
for key in cat_keys:
    print(key)

    dsi = dsets[key]
    v = get_varname_from_key(key, dsi)
    assert dsi[v].units in ['mol/m^2/s']

    fig, ax = plt.subplots()

    # Area-weighted global sum of the daily flux: ×1e-12 converts mol to Tmol
    # and ×86400 converts per-second to per-day (assumes area is in m^2).
    with xr.set_options(keep_attrs=True):
        global_sum = ((dsi[v] * dsi.area).sum(['lat', 'lon'])) * 1e-12 * 86400.
        # The values are per day; previously this label was overwritten with
        # 'Tmol month^-1', mislabelling the plot.
        global_sum.attrs['units'] = 'Tmol day$^{-1}$'

    global_sum.plot(ax=ax)
    ax.set_title(key)


In [None]:
sel_year = '2009'

region = xr.DataArray(['NET', 'SET'], dims=('region'), name='region')

for key in cat_keys:
    print(key)

    dsi = dsets[key].sel(time=sel_year)
    v = get_varname_from_key(key, dsi)
    assert dsi[v].units in ['mol/m^2/s']

    # Cell-area masks for the extratropics: NET (lat >= 20) and SET (lat <= -20).
    masked_area = xr.concat(
        [dsi.area.where(dsi.lat >= 20), dsi.area.where(dsi.lat <= -20)],
        dim=region,
    )

    fig, ax = plt.subplots()

    # Daily values in Tmol/day summed over the days of `sel_year` give Tmol/yr.
    global_sum = ((dsi[v] * dsi.area).sum(['lat', 'lon'])).sum('time').values * 1e-12 * 86400.
    print(f'GLB {key}: {global_sum:0.3f} Tmol/yr')

    with xr.set_options(keep_attrs=True):
        dsi_region = (dsi[[v]] * masked_area).sum(['lat', 'lon']) * 1e-12 * 86400.
        dsi_region[v].attrs['units'] = 'Tmol day$^{-1}$'
        # Approximate month = 365/12 days, for plotting in per-month units.
        dsi_mon = dsi_region * 365. / 12.
        dsi_mon[v].attrs['units'] = 'Tmol month$^{-1}$'

    for region_name in dsi_region.region.values:
        regional_sum = dsi_region[v].sel(region=region_name, drop=True).sum('time').values
        da_mon = dsi_mon[v].sel(region=region_name, drop=True)

        da_mon.plot(label=region_name, ax=ax)
        print(f'{region_name} {key}: {regional_sum:0.3f} Tmol/yr')

    ax.set_title(f'{key} ({sel_year})')
    ax.legend()
    print()

In [None]:
sel_year = '2009'

# Region order shared by the masks, the panels and the `totals` lists.
regions = ['GLB', 'TRP', 'NET', 'SET']
fig, axs = plt.subplots(len(regions), 1, figsize=(8, 12))

totals = {}
for key in cat_keys:
    # Only catalog entries whose key contains 'fgn2' or 'fgo2' are compared here.
    if 'fgn2' not in key and 'fgo2' not in key:
        continue
    print(key)
    totals[key] = []

    dsi = dsets[key].sel(time=sel_year)
    v = get_varname_from_key(key, dsi)
    assert dsi[v].units in ['mol/m^2/s']

    # GLB: all cells; TRP: -20 < lat < 20; NET: lat >= 20; SET: lat <= -20.
    masked_area = xr.concat(
        [
            dsi.area,
            dsi.area.where((dsi.lat < 20) & (dsi.lat > -20)),
            dsi.area.where(dsi.lat >= 20),
            dsi.area.where(dsi.lat <= -20),
        ],
        dim=xr.DataArray(regions, dims=('region'), name='region'),
    )

    # mol m^-2 s^-1 -> Tmol day^-1 (×1e-12 mol->Tmol, ×86400 s->day).
    with xr.set_options(keep_attrs=True):
        dsi_region = (dsi[[v]] * masked_area).sum(['lat', 'lon']) * 1e-12 * 86400.
        dsi_region[v].attrs['units'] = 'Tmol day$^{-1}$'
        dsi_mon = dsi_region * 365. / 12.
        dsi_mon[v].attrs['units'] = 'Tmol month$^{-1}$'

    for ax, region_name in zip(axs, dsi_region.region.values):
        regional_sum = dsi_region[v].sel(region=region_name, drop=True).sum('time').values
        da_mon = dsi_mon[v].sel(region=region_name, drop=True)

        print(f'{region_name} {key}: {regional_sum:0.4f} Tmol/yr')
        totals[key].append(float(regional_sum))

        da_mon.plot(label=f"{key}", ax=ax)
    print()

# Decorate each panel once, after all products are drawn; doing it inside the
# loop stacked a duplicate zero line per product.
for ax, region_name in zip(axs, regions):
    ax.set_title(region_name)
    ax.set_xlabel('')
    ax.axhline(0, color='k', linewidth=0.5)

# One legend, beside the bottom panel (skip if no product matched).
if totals:
    axs[-1].legend(loc=(1.01, 0));

In [None]:
region_labels = ['GLB', 'TRP', 'NET', 'SET']
n_products = max(len(totals), 1)
# Each region's group of bars spans 0.8 x-units, split evenly among products.
bar_width = 0.8 / n_products

fig, ax = plt.subplots()
for i, (key, values) in enumerate(totals.items()):
    # Closure check: TRP, NET and SET partition the latitudes, so their sum
    # should match the GLB total.
    print(f'{key}: GLB = {values[0]:0.4f}, TRP+NET+SET = {sum(values[1:]):0.4f}')
    # Offset grows with each product so bars never overlap (previously the
    # offset was reset to 0.4 instead of incremented).
    ax.bar(np.arange(len(values)) + i * bar_width, values, width=bar_width, label=key)

ax.axhline(0, color='k', linewidth=0.5)
# Centre each tick under its group of bars.
ax.set_xticks(np.arange(len(region_labels)) + (n_products - 1) * bar_width / 2)
ax.set_xticklabels(region_labels)
if totals:
    ax.legend()
ax.set_ylabel('Annual flux [Tmol/yr]');

In [None]:
# Per-product annual totals [Tmol/yr], listed in GLB, TRP, NET, SET order for sel_year.
totals

In [None]:
# Shut down dask: close the client connection first, then the cluster itself.
client.close()
cluster.close()