In [11]:
import xarray as xr
from dask.distributed import Client
import dask
import numpy as np
import pandas as pd
from pathlib import Path
import intake
import cftime
import datetime as dt
import hvplot.xarray
from dask_jobqueue import PBSCluster

In [12]:
# One single-threaded Dask worker per PBS job on the casper queue. The worker
# memory limit (6 GiB) is set below the PBS memory request (8 GB), presumably to
# leave headroom for the worker process itself.
cluster = PBSCluster(
    # --- PBS job request ---
    job_name='siparcs-vis-aggregate-cesm',
    queue='casper',
    walltime='01:00:00',
    resource_spec='select=1:ncpus=1:mem=8GB',
    # --- Dask worker layout inside each job ---
    cores=1,
    processes=1,
    memory='6GiB',
    # Spill-to-disk directory. `$PBS_JOBID` is written verbatim into the bash job
    # script and expanded there, so each job gets its own directory.
    local_directory='/glade/work/pdas47/scratch/pbs.$PBS_JOBID/dask/spill',
    # Network interface the workers bind to (presumably InfiniBand on casper).
    interface='ib0',
)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 41495 instead


In [13]:
# Attach this notebook session to the PBS-backed scheduler; the rich repr
# below shows the dashboard link and current worker count.
client = Client(address=cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/pdas47/viz/proxy/41495/status,

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/pdas47/viz/proxy/41495/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.12.206.63:36201,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/pdas47/viz/proxy/41495/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [14]:
# Inspect the submission script dask-jobqueue generates for each worker job;
# printed rather than displayed so its newlines render as lines.
worker_job_script = cluster.job_script()
print(worker_job_script)

#!/usr/bin/env bash

#PBS -N siparcs-vis-aggregate-cesm
#PBS -q casper
#PBS -A SCSG0002
#PBS -l select=1:ncpus=1:mem=8GB
#PBS -l walltime=01:00:00

/glade/u/home/pdas47/SIParCS-Vis/.env/bin/python -m distributed.cli.dask_worker tcp://10.12.206.63:36201 --nthreads 1 --memory-limit 6.00GiB --name dummy-name --nanny --death-timeout 60 --local-directory /glade/work/pdas47/scratch/pbs.$PBS_JOBID/dask/spill --interface ib0



In [15]:
# Request the worker jobs and block until all of them have connected, so the
# following cells run on the full cluster. A single name keeps the requested
# count and the awaited count from drifting apart.
n_workers = 6
cluster.scale(n_workers)
client.wait_for_workers(n_workers)
cluster.workers

{'PBSCluster-1': <dask_jobqueue.pbs.PBSJob: status=running>,
 'PBSCluster-0': <dask_jobqueue.pbs.PBSJob: status=running>,
 'PBSCluster-2': <dask_jobqueue.pbs.PBSJob: status=running>,
 'PBSCluster-4': <dask_jobqueue.pbs.PBSJob: status=running>,
 'PBSCluster-5': <dask_jobqueue.pbs.PBSJob: status=running>,
 'PBSCluster-3': <dask_jobqueue.pbs.PBSJob: status=running>}

# Load CESM2 Large Ensemble (LENS2) data

In [16]:
# Intake-ESM catalog for the CESM2 Large Ensemble data hosted on AWS.
# `col` is the catalog that `create_annual_dataset` searches.
catalog_url = 'https://raw.githubusercontent.com/NCAR/cesm2-le-aws/main/intake-catalogs/aws-cesm2-le.json'
col = intake.open_esm_datastore(
    catalog_url
)
col  # summary: number of unique values per catalog column

Unnamed: 0,unique
variable,53
long_name,51
component,4
experiment,2
forcing_variant,2
frequency,3
vertical_levels,3
spatial_domain,3
units,20
start_time,4


In [19]:
def weighted_temporal_mean(ds, var):
    """Calculate the annual mean of `var`, weighting each month by its length in days.

    Adapted from https://ncar.github.io/esds/posts/2021/yearly-averages-xarray/

    Parameters
    ----------
    ds : `xr.Dataset`
        Dataset containing `var` variable and `time` dimension with monthly frequency.
        Weighted average will be created for the `var` variable.
    var : `str`
        Name of variable in the `ds` dataset.

    Returns
    -------
    wgt_avg : `xr.DataArray`
        Annual mean of the variable, weighted by days in a month, with one value
        per year labelled at the start of that year. Months where `var` is NaN
        are excluded and the remaining weights renormalised, so a year with no
        valid data comes out as NaN (0 / 0).

    Notes
    -----
    NOTE(review): CESM monthly history output is often time-stamped at the end
    of each averaging interval. If that holds for these stores, both
    `days_in_month` and the year binning are shifted by one month. Confirm
    against `time_bnds` before relying on exact year boundaries.
    """
    # Number of days in each month of the time coordinate.
    month_length = ds.time.dt.days_in_month

    # Normalise month lengths within each year; by construction each year's
    # weights sum to 1.
    month_length_by_year = month_length.groupby("time.year")
    wgts = month_length_by_year / month_length_by_year.sum()

    # Subset our dataset for our variable
    obs = ds[var]

    # True where `obs` has data, False where it is NaN. Multiplied by the float
    # weights this acts as a 1.0 / 0.0 mask used to renormalise around gaps.
    valid = obs.notnull()

    # "YS" = year-start bins. The older "AS" alias is deprecated in pandas and
    # xarray and produces the same bins.
    obs_sum = (obs * wgts).resample(time="YS").sum(dim="time")
    valid_wgt_sum = (valid * wgts).resample(time="YS").sum(dim="time")

    wgt_avg = obs_sum / valid_wgt_sum

    return wgt_avg

def create_annual_dataset(var_name, member_mean=True):
    """Create a combined (historical & future, cmip6 & smbb forcings) annual
    average dataset for one output variable (monthly datasets only) in the
    CESM-LENS2 catalog, selected by `var_name`.

    The module-level intake-esm catalog `col` is searched for monthly entries of
    `var_name`. Each needed dataset is reduced to annual means with
    `weighted_temporal_mean`. The historical and ssp370 periods are then joined
    along `time`, and the cmip6 and smbb forcing variants are stacked along a new
    leading `forcing_type` dimension.

    Parameters
    ----------
    var_name : `str`
        Name of variable to search for in the CESM-LENS2 catalog.
    member_mean : `bool`
        Whether to average across ensemble members (`member_id`) before the
        annual mean is taken.

    Returns
    -------
    res_ds : `xr.Dataset`
        Dataset containing the annual average of `var_name`, combining
        historical & ssp370 along `time` and cmip6 & smbb along `forcing_type`.

    Raises
    ------
    ValueError
        If the catalog has no monthly entries for `var_name`.
    KeyError
        If any of the historical/ssp370 x cmip6/smbb datasets is not returned.
    """
    print(f"Creating annual dataset of {var_name}")

    freq = 'monthly'
    experiments = ('historical', 'ssp370')   # joined along `time`, in this order
    forcing_types = ('cmip6', 'smbb')        # stacked along `forcing_type`, in this order

    col_subset = col.search(
        variable=var_name,
        frequency=freq
    )
    if col_subset.df.empty:
        # Without this guard, `.iloc[0]` below fails with an opaque IndexError.
        raise ValueError(
            f"No {freq} entries for variable {var_name!r} in the CESM-LENS2 catalog"
        )
    # NOTE(review): if the variable exists in several components, only the
    # component of the first catalog row is used to build the keys below.
    component = col_subset.df['component'].iloc[0]
    dset_dict = col_subset.to_dataset_dict(storage_options={'anon': True})

    def _annual_da(key):
        """Annual-mean DataArray for one catalog key, tagged with its forcing type."""
        ds = dset_dict[key]
        if member_mean:
            # `assign` returns a new Dataset instead of mutating the one in dset_dict.
            ds = ds.assign({var_name: ds[var_name].mean('member_id', keep_attrs=True)})

        da = weighted_temporal_mean(ds, var_name)

        # Keys look like 'component.experiment.frequency.forcing_variant'.
        forcing_type = key.split('.')[-1]
        da = da.expand_dims(dim={"forcing_type": [forcing_type]}, axis=0)

        da.name = var_name
        # Copy the source variable's attrs onto the annual result.
        da.attrs = ds[var_name].attrs
        return da

    per_forcing = []
    for forcing in forcing_types:
        keys = [f'{component}.{experiment}.{freq}.{forcing}' for experiment in experiments]
        missing = [k for k in keys if k not in dset_dict]
        if missing:
            raise KeyError(f"Expected datasets not returned by the catalog: {missing}")
        per_forcing.append(
            xr.concat([_annual_da(k) for k in keys], dim='time', combine_attrs='no_conflicts')
        )

    res_ds = xr.concat(per_forcing, dim='forcing_type', combine_attrs='no_conflicts').to_dataset()

    # add attrs for forcing_type
    res_ds.coords['forcing_type'].attrs = {
        'comments': '`cmip6` refers to the original CMIP6 BMB protocol, `smbb` refers to smoothed CMIP6 BMB protocol which are evenly distributed amonst different initialization dates. Visit https://www.cesm.ucar.edu/community-projects/lens2 for definitions.'
    }

    return res_ds

In [22]:
vars_of_interest = ['PRECL', 'PS', 'TS', 'FSNS', 'PSL', 'FSNO', 'RAIN', 'SNOW', 'TREFMXAV', 'TREFHTMN', 'TREFHTMX']
save_dir = Path('/glade/work/pdas47/cesm-annual')
# `to_netcdf` does not create missing parent directories, so ensure the target
# exists before writing (a no-op if it is already there).
save_dir.mkdir(parents=True, exist_ok=True)

# One NetCDF file per variable. The datasets are presumably lazy (dask-backed),
# in which case writing is what triggers the computation on the cluster.
for var_name in vars_of_interest:
    ds = create_annual_dataset(var_name)
    ds.to_netcdf(save_dir / f"{var_name}.nc")

Creating annual dataset of PRECL

--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


Creating annual dataset of PS

--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


Creating annual dataset of TS

--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


Creating annual dataset of FSNS

--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


Creating annual dataset of PSL

--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


Creating annual dataset of FSNO

--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


Creating annual dataset of RAIN

--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


Creating annual dataset of SNOW

--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


Creating annual dataset of TREFMXAV

--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


Creating annual dataset of TREFHTMN

--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


Creating annual dataset of TREFHTMX

--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


In [23]:
# Tear down in dependency order: disconnect the client first, then close the
# cluster that owns the PBS worker jobs.
for resource in (client, cluster):
    resource.close()