# functions

In [1]:
import numpy as np
import pandas as pd
import xarray as xr

# Updated for PBS scheduler
# this could go into utils.
# By default gets 1 core w/ 25 GB memory
def get_ClusterClient(ncores=1, nmem='25GB'):
    """
    Creates a dask PBS cluster and connects a client to it.

    Original code from Daniel Kennedy.

    Parameters:
    ncores (int):
        Passed to PBSCluster as both `cores` and `processes`, and written
        into the `ncpus` field of the PBS resource request.
    nmem (str):
        Passed to PBSCluster as `memory` and written into the `mem` field
        of the PBS resource request (e.g., '25GB').

    Returns:
    cluster (dask_jobqueue.PBSCluster):
        The cluster object; no workers are requested here.
    client (dask.distributed.Client):
        A client connected to `cluster`.
    """
    from dask_jobqueue import PBSCluster
    from dask.distributed import Client

    # PBS resource request: one chunk with `ncores` CPUs and `nmem` of memory
    pbs_request = ''.join(['select=1:ncpus=', str(ncores), ':mem=', nmem])

    cluster_settings = {
        'cores': ncores,
        'memory': nmem,
        'processes': ncores,
        'queue': 'casper',
        'resource_spec': pbs_request,
        'project': 'UWAS0155',
        'walltime': '1:00:00',
        # The interface is set on the scheduler only, not through the
        # top-level `interface` argument
        'scheduler_options': {"interface": "ib0"},
    }
    cluster = PBSCluster(**cluster_settings)

    return cluster, Client(cluster)


def format_ds_coords(dataset):
    """
    Wraps the longitude coordinate of an xarray Dataset to [-180, 180).

    This function performs the following operations on a copy of the input Dataset:
        1. Maps every value of 'lon' into the half-open range [-180, 180)
           (so a longitude of exactly 180 becomes -180).
        2. Copies the encoding and attributes of the original 'lon' onto
           the remapped coordinate.
        3. Sorts the Dataset in ascending order along the dimension(s) of 'lon'.

    No coordinate bounds are added.

    Parameters:
    dataset (xarray.Dataset):
        The input Dataset to be formatted. It must provide 'lon'.

    Returns:
    ds_formatted (xarray.Dataset):
        The formatted Dataset with the remapped, sorted longitude.
    """
    ds = dataset.copy()
    lon = ds.lon

    # Shift, wrap, and shift back: [0, 360) -> [-180, 180)
    new_lon = ((lon + 180) % 360) - 180

    # Arithmetic builds a new object, so carry the metadata over explicitly.
    # The attributes are copied into a fresh dict so the two coordinates do
    # not share one mutable mapping.
    new_lon.encoding = lon.encoding
    new_lon.attrs = dict(lon.attrs)

    ds.coords['lon'] = new_lon

    # The wrapped values are no longer monotonic, so restore ascending order
    ds = ds.sortby(list(lon.dims), ascending=True)
    return ds


def select_sites_from_gridded_data(xr_grid_data, df_site_data):
    """
    Selects grid data for specific sites from a gridded dataset.

    Sites are numbered by their row position in df_site_data (0, 1, 2, ...),
    regardless of the DataFrame's index labels.

    Parameters:
    xr_grid_data (xarray.Dataset or xarray.DataArray):
        The gridded dataset from which to select data. It must be
        selectable by 'lat' and 'lon'.
    df_site_data (pandas.DataFrame):
        A DataFrame containing site information with columns
        'lat' and 'lon' for latitude and longitude.

    Returns:
    xr_site_data (xarray.Dataset or xarray.DataArray):
        The selected site data, with a new dimension 'site' whose
        coordinate holds the row position of each site in df_site_data.

    Raises:
    ValueError:
        If df_site_data contains no rows.
    """
    nsite = len(df_site_data)
    if nsite == 0:
        raise ValueError('df_site_data contains no sites to select')

    # Select the grid box nearest to the coordinates of each site. The rows
    # are walked in order rather than addressed through the DataFrame index,
    # so a filtered or re-labelled DataFrame is handled correctly.
    site_selections = [
        xr_grid_data.sel(lat=lat, lon=lon, method='nearest')
        for lat, lon in zip(df_site_data['lat'], df_site_data['lon'])
    ]

    # Concatenate the selected data along a new 'site' dimension
    xr_site_data = xr.concat(site_selections, dim='site')
    xr_site_data = xr_site_data.assign_coords({'site': np.arange(nsite)})
    xr_site_data['site'].attrs = {'long_name': 'tree ring site, arbitrary numbering from csv file'}

    return xr_site_data

# metadata and setup

In [2]:
# Global attributes attached to every output file. The first two entries
# describe this processing step; per the description text, the remaining
# entries are copied from the original history file.
global_metadata = dict(
    latlon_site_file='/glade/u/home/bbuchovecky/projects/wue_trend/lat_lon_pft_all.csv',
    description='Gridded to (time,lat,lon,pft) then only grid boxes corresponding to tree ring sites were selected. All global attributes below are copied from the original history file.',
    title='CLM History file information',
    comment='NOTE: None of the variables are weighted by land fraction!',
    Conventions='CF-1.0',
    history='created on 04/27/21 13:27:32',
    source='Community Terrestrial Systems Model',
    hostname='cheyenne',
    username='oleson',
    version='cesm2_3_alpha02c',
    revision_id='$Id: histFileMod.F90 42903 2012-12-21 15:32:10Z muszala $',
    case_title='UNSET',
    case_id='clm50_cesm23a02cPPEn08ctsm51d030_1deg_GSWP3V1_hist',
    Surface_dataset='surfdata_0.9x1.25_hist_78pfts_CMIP6_simyr1850_c190214.nc',
    Initial_conditions_dataset='finidat_interp_dest.nc',
    PFT_physiological_constants_dataset='clm50_params.c210217.nc',
    time_period_freq='month_1',
    Time_constant_3Dvars_filename='./clm50_cesm23a02cPPEn08ctsm51d030_1deg_GSWP3V1_hist.clm2.h3.1850-01-01-00000.nc',
    Time_constant_3Dvars='ZSOI:DZSOI:WATSAT:SUCSAT:BSW:HKSAT:ZLAKE:DZLAKE:PCT_SAND:PCT_CLAY',
)

# Per-variable attributes, keyed by variable name. Every variable shares the
# same cell_methods entry, so only the long name and units are listed.
# NOTE(review): 'H20' in the GSSUNLN units is spelled with a zero; presumably
# copied verbatim from the history file -- confirm before changing it.
variable_metadata = {
    name: {
        'long_name': long_name,
        'units': units,
        'cell_methods': 'time: mean',
    }
    for name, long_name, units in (
        ('GSSUNLN', 'sunlit leaf stomatal conductance at local noon', 'umol H20/m2/s'),
        ('GPP', 'gross primary production', 'gC/m^2/s'),
        ('FCTR', 'canopy transpiration', 'W/m^2'),
        ('TLAI', 'total projected leaf area index', 'm^2/m^2'),
    )
}

# Location of the CLM5 output, the case name used in its file names, and the
# time window kept for the analysis
directory = '/glade/work/bbuchovecky/WUE_analysis'
casename = 'clm50_cesm23a02cPPEn08ctsm51d030_1deg_GSWP3V1_hist'
tperiod = slice('1901-01', '2014-12')

# Tree ring sites: only the lat, lon, and PFT columns of the csv are read
site_data = pd.read_csv('./lat_lon_pft_all.csv', usecols=['lat', 'lon', 'PFT'])

# Distinct PFT values found across the sites, in ascending order
sitepft = sorted(pd.unique(site_data['PFT']))

# regrid main variables

In [3]:
# Start Dask cluster
# cluster, client = get_ClusterClient(ncores=4, nmem='25GB')
# cluster.scale(20)

In [None]:
# For each main variable: read the gridded file, keep the PFTs and time period
# of interest, pick the grid box nearest to every tree ring site, and write one
# file per PFT (holding only the sites of that PFT) plus one file with all
# sites and all selected PFTs.
variables = ['GSSUNLN', 'GPP', 'FCTR', 'TLAI']

for var in variables:
    # Common prefix of every site-level output file of this variable
    out_stem = f'{directory}/for/marja/mon/{var.lower()}/clm50_1deg_GSWP3V1.{var}.190101-201412.sites'

    # The gridded file is opened with chunks, so its data are read on demand.
    # It therefore has to stay open until the last output file of this
    # variable has been written; the `with` block closes it afterwards.
    with xr.open_dataset(
        f'{directory}/gridded/{casename}.clm2.h1.{var}.185001-201412_gridded.nc',
        chunks={'time': 36, 'lat': 192, 'lon': 288}
    ) as source:
        # Format coordinates
        gridded = format_ds_coords(source)

        # Format PFT coordinates
        gridded = gridded.rename({'vegtype': 'pft', 'vegtype_name': 'pft_name'})
        gridded['pft'].attrs = {'long_name': 'plant functional type'}
        gridded['pft_name'].attrs = {'long_name': 'plant functional type name'}

        # Select PFTs and time period
        gridded = gridded[var].sel(pft=sitepft, time=tperiod)

        # Select sites from gridded variable
        indvsites = select_sites_from_gridded_data(gridded, site_data)

        # Iterate through PFTs
        for pft in sitepft:
            # Select site indices with PFT. The index labels of site_data are
            # used as 'site' coordinate values below, so site_data is expected
            # to carry its default 0..n-1 index.
            pft_site = site_data[site_data['PFT'] == pft]

            # Select corresponding sites; drop_vars replaces the deprecated
            # DataArray.drop for removing the now-scalar 'pft' coordinate
            pft_data = indvsites.sel(site=pft_site.index, pft=pft)
            pft_data = pft_data.drop_vars('pft')

            # Convert to dataset
            pft_data = pft_data.to_dataset(name=var)

            # Add metadata
            pft_data.attrs = global_metadata
            pft_data[var].attrs = variable_metadata[var]

            # Save to NetCDF
            pft_data.to_netcdf(f'{out_stem}.pft{pft}.nc')
            print(f'done saving pft{pft}')

        # Convert to dataset
        indvsites = indvsites.to_dataset(name=var)

        # Add metadata
        indvsites.attrs = global_metadata
        indvsites[var].attrs = variable_metadata[var]

        # Save to NetCDF
        indvsites.to_netcdf(f'{out_stem}.pftall.nc')
        print('done saving pftall')

done saving pft1
done saving pft2
done saving pft3
done saving pft4
done saving pft5
done saving pft6
done saving pft7
done saving pft8


In [7]:
# client.shutdown()

# format area variables

In [None]:
# Landunit type names, stored as the 'ltype_name' attribute below
ltype_name = [
    'vegetated_or_bare_soil',
    'crop',
    'UNUSED',
    'landice_multiple_elevation_classes',
    'deep_lake',
    'wetland',
    'urban_tbd',
    'urban_hd',
    'urban_md',
]

# PFT names, stored as the 'natpft_name' attribute below.
# NOTE(review): this list has 8 entries while the 'pft' coordinate assigned
# below has 15 values; presumably the names cover only the tree PFTs at the
# tree ring sites -- confirm which 'pft' values they correspond to.
natpft_name = [
    'needleleaf_evergreen_temperate_tree',
    'needleleaf_evergreen_boreal_tree',
    'needleleaf_deciduous_boreal_tree',
    'broadleaf_evergreen_tropical_tree',
    'broadleaf_evergreen_temperate_tree',
    'broadleaf_deciduous_tropical_tree',
    'broadleaf_deciduous_temperate_tree',
    'broadleaf_deciduous_boreal_tree',
]

# --- PCT_NAT_PFT: percent -> fraction, with the 'natpft' dimension renamed
# to 'pft' and given integer coordinate values 0..14
pct_natpft = format_ds_coords(
    xr.open_dataset(f'{directory}/gridded/{casename}.clm2.h3.PCT_NAT_PFT.18500101-20141231.nc')
).sel(time=tperiod)
pct_natpft['PCT_NAT_PFT'] = (
    pct_natpft['PCT_NAT_PFT']
    .rename({'natpft': 'pft'})
    .assign_coords(pft=np.arange(15))
    / 100
)
pct_natpft['PCT_NAT_PFT'].attrs = dict(
    long_name='fraction of each PFT on the natural vegetation (i.e., soil) landunit',
    units='fraction',
    natpft_name=natpft_name,
)
# Uncomment to write the gridded file that the site-selection cell reads:
# pct_natpft['PCT_NAT_PFT'].to_netcdf(f'{directory}/for/marja/clm50_1deg_GSWP3V1.PCT_NAT_PFT.190101-201412.gridded.nc')

# --- PCT_LANDUNIT: percent -> fraction, with integer 'ltype' coordinate
# values 0..8
pct_landunit = format_ds_coords(
    xr.open_dataset(f'{directory}/gridded/{casename}.clm2.h3.PCT_LANDUNIT.18500101-20141231.nc')
).sel(time=tperiod)
pct_landunit['PCT_LANDUNIT'] = (
    pct_landunit['PCT_LANDUNIT'].assign_coords(ltype=np.arange(9)) / 100
)
pct_landunit['PCT_LANDUNIT'].attrs = dict(
    long_name='fraction of each landunit on grid cell',
    units='fraction',
    ltype_name=ltype_name,
)
# Uncomment to write the gridded file that the site-selection cell reads:
# pct_landunit['PCT_LANDUNIT'].to_netcdf(f'{directory}/for/marja/clm50_1deg_GSWP3V1.PCT_LANDUNIT.190101-201412.gridded.nc')

# --- PCT_NAT_PFT_GRIDBOX: PFT fraction of the landunit multiplied by the
# grid cell fraction of the first landunit (ltype index 0)
pct_natpft_gridbox = (
    pct_natpft['PCT_NAT_PFT'] * pct_landunit['PCT_LANDUNIT'].isel(ltype=0)
).rename('PCT_NAT_PFT_GRIDBOX')
pct_natpft_gridbox.attrs = dict(
    long_name='fraction of each PFT within each grid cell',
    units='fraction',
    ltype_name=ltype_name[0],
    natpft_name=natpft_name,
)
# Uncomment to write the gridded file that the site-selection cell reads:
# pct_natpft_gridbox.to_netcdf(f'{directory}/for/marja/clm50_1deg_GSWP3V1.PCT_NAT_PFT_GRIDBOX.190101-201412.gridded.nc')

# --- Grid cell area: taken from the PCT_NAT_PFT history file, which is
# opened a second time here and not subset in time
area = format_ds_coords(
    xr.open_dataset(f'{directory}/gridded/{casename}.clm2.h3.PCT_NAT_PFT.18500101-20141231.nc')
)
area['area'].attrs = dict(
    long_name='grid cell areas',
    units='km^2',
)
# area['area'].to_netcdf(f'{directory}/for/marja/clm50_1deg_GSWP3V1.AREA.190101-201412.gridded.nc')

# regrid area variables

In [None]:
# Start Dask cluster
# cluster, client = get_ClusterClient(ncores=4, nmem='25GB')
# cluster.scale(20)

In [11]:
# For each area variable: read the gridded file, pick the grid box nearest to
# every tree ring site, and write the site-level files. PCT_LANDUNIT gets only
# the all-sites file; the other variables also get one file per PFT.
variables = [
    'PCT_NAT_PFT',
    'PCT_LANDUNIT',
    'PCT_NAT_PFT_GRIDBOX',
]

for var in variables:
    print(var)

    # Common prefix of every site-level output file of this variable
    out_stem = f'{directory}/for/marja/mon/{var.lower()}/clm50_1deg_GSWP3V1.{var}.190101-201412.sites'

    # The gridded file is opened with chunks, so its data are read on demand.
    # It therefore has to stay open until the last output file of this
    # variable has been written; the `with` block closes it afterwards.
    with xr.open_dataset(
        f'{directory}/for/marja/clm50_1deg_GSWP3V1.{var}.190101-201412.gridded.nc',
        chunks={'time': 36, 'lat': 192, 'lon': 288}
    ) as source:
        # Metadata is taken from the file itself. It is kept in loop-local
        # names so the module-level global_metadata / variable_metadata
        # dictionaries used for the main variables are not overwritten.
        file_attrs = dict(source.attrs)
        var_attrs = dict(source[var].attrs)

        # Format coordinates and select the time period
        gridded = format_ds_coords(source)
        gridded = gridded[var].sel(time=tperiod)

        # Select sites from gridded variable. No PFT subset is applied first,
        # so the all-sites file keeps every PFT present in the gridded file.
        indvsites = select_sites_from_gridded_data(gridded, site_data)

        if var != 'PCT_LANDUNIT':
            # Iterate through PFTs
            for pft in sitepft:
                # Select site indices with PFT. The index labels of site_data
                # are used as 'site' coordinate values below, so site_data is
                # expected to carry its default 0..n-1 index.
                pft_site = site_data[site_data['PFT'] == pft]

                # Select corresponding sites; drop_vars replaces the deprecated
                # DataArray.drop for removing the now-scalar 'pft' coordinate
                pft_data = indvsites.sel(site=pft_site.index, pft=pft)
                pft_data = pft_data.drop_vars('pft')

                # Convert to dataset
                pft_data = pft_data.to_dataset(name=var)

                # Add metadata
                pft_data.attrs = file_attrs
                pft_data[var].attrs = var_attrs

                # Save to NetCDF
                pft_data.to_netcdf(f'{out_stem}.pft{pft}.nc')
                print(f'done saving pft{pft}')

        # Convert to dataset
        indvsites = indvsites.to_dataset(name=var)

        # Add metadata
        indvsites.attrs = file_attrs
        indvsites[var].attrs = var_attrs

        # Save to NetCDF
        indvsites.to_netcdf(f'{out_stem}.pftall.nc')
        print('done saving pftall')

done saving pftall
done saving pft1
done saving pft2
done saving pft3
done saving pft4
done saving pft5
done saving pft6
done saving pft7
done saving pft8
done saving pftall


# compute annual mean

In [3]:
def calculate_annual_timeseries(da):
    """
    Calculates the annual mean timeseries, weighted by the number of days in each month.

    Missing values are left out of the mean: the weights of the remaining
    months of that year are renormalised, and a year without any valid month
    is returned as NaN (a plain weighted sum would report it as 0).

    Parameters:
    da (xarray.DataArray or xarray.Dataset):
        Data with a datetime-like 'time' coordinate.

    Returns:
    annual (same type as da):
        The weighted annual means, with 'time' replaced by 'year'.
        Callers in this file re-attach attributes afterwards, so do not
        rely on them being preserved.

    Raises:
    AssertionError:
        If the month weights of any year do not sum to 1.
    """
    # Weight of each month = its number of days / number of days of all
    # months of the same year that are present in the data
    month_length = da.time.dt.days_in_month.astype(float)
    weights = month_length.groupby('time.year') / month_length.groupby('time.year').sum()

    # Sanity check: the weights of every year add up to one
    yearly_weight = weights.groupby('time.year').sum()
    np.testing.assert_allclose(yearly_weight.values, np.ones(yearly_weight.size))

    # Weighted sum over the valid months (NaN is skipped by sum) ...
    weighted_sum = (da * weights).groupby('time.year').sum(dim='time')

    # ... divided by the total weight of those valid months. Masking a zero
    # total weight turns years without any valid month into NaN.
    valid_weight = (da.notnull() * weights).groupby('time.year').sum(dim='time')
    return weighted_sum / valid_weight.where(valid_weight > 0)

In [None]:
# Convert every monthly site-level file into an annual mean file with the
# same global and variable attributes.
variables = [
    'GSSUNLN',
    'GPP',
    'FCTR',
    'TLAI',
    'PCT_NAT_PFT',
    'PCT_LANDUNIT',
    'PCT_NAT_PFT_GRIDBOX',
]

for var in variables:
    # PCT_LANDUNIT only has the all-sites file; every other variable has one
    # file per PFT as well. The per-PFT files are processed first.
    tags = [] if var == 'PCT_LANDUNIT' else [f'pft{pft}' for pft in sitepft]
    tags.append('pftall')

    for tag in tags:
        mon_path = f'{directory}/for/marja/mon/{var.lower()}/clm50_1deg_GSWP3V1.{var}.190101-201412.sites.{tag}.nc'
        year_path = f'{directory}/for/marja/year/{var.lower()}/clm50_1deg_GSWP3V1.{var}.190101-201412.year.sites.{tag}.nc'

        # The `with` block closes the monthly file once the annual file has
        # been written
        with xr.open_dataset(mon_path) as monthly:
            annual = calculate_annual_timeseries(monthly)

            # Re-attach the attributes of the monthly file. They are read
            # straight from it rather than stored in the module-level
            # global_metadata / variable_metadata names, which therefore
            # keep their original contents.
            annual[var].attrs = monthly[var].attrs
            annual.attrs = monthly.attrs

            annual.to_netcdf(year_path)