# Loading CESM2 Large Ensemble data from the cloud

## Imports

In [None]:
## packages
import intake
import time
import xarray as xr
from distributed import LocalCluster, Client

## Functions

In [None]:
def load_cesm_from_cloud_parallel(n_workers=4, **kwargs):
    """Wrapper that loads CESM data using a temporary local dask cluster. Args:
        - n_workers: how many workers to parallelize with
        - **kwargs: arguments passed to 'load_cesm_from_cloud' function

    Returns whatever 'load_cesm_from_cloud' returns. The cluster and client
    are torn down when loading finishes, even if 'load_cesm_from_cloud' raises.
    """

    ## start a local cluster and connect a client to it; while the client is
    ## open it acts as the default dask scheduler, so the .compute() inside
    ## load_cesm_from_cloud runs on this cluster. Using context managers
    ## guarantees the cluster is shut down even if loading fails.
    with LocalCluster(n_workers=n_workers) as cluster, Client(cluster):
        data = load_cesm_from_cloud(**kwargs)

    return data

def load_cesm_from_cloud(lon_range, lat_range, varname="TREFHT", load_ssp370=False, n_members=8):
    """Load a monthly CESM2 Large Ensemble variable from the cloud. Args:
        - lon_range, lat_range: each is a two-element array (start, stop),
          used as label-based slices on the "lon" and "lat" coordinates
        - varname: variable to load ("TREFHT" is 2m-temperature)
        - load_ssp370: bool; if True, load historical AND ssp370 simulations,
          concatenated along "time"
        - n_members: number of ensemble members to load (the first
          n_members entries along "member_id")

    Returns:
        xarray DataArray of `varname`, trimmed to the requested lon/lat box
        and ensemble members, loaded into memory.
    """

    ## get catalog of available data
    catalog = intake.open_esm_datastore(
        "https://raw.githubusercontent.com/NCAR/cesm2-le-aws/main/intake-catalogs/aws-cesm2-le.json"
    )

    ## subset catalog for the requested variable (monthly output)
    ## to look at available data, use: catalog.df
    catalog_subset = catalog.search(variable=varname, frequency="monthly")

    ## kwargs for opening data
    open_kwargs = dict(
        aggregate=True,
        xarray_open_kwargs=dict(engine="zarr", decode_timedelta=True),
        zarr_kwargs={"consolidated": True},
        storage_options={"anon": True},
    )

    ## open data (but don't load to memory)
    dsets = catalog_subset.to_dataset_dict(**open_kwargs)

    ## NOTE: the dataset keys below assume an atmosphere ("atm") variable
    ## from the "cmip6" forcing variant.
    ## Extract the requested variable *before* concatenating: with whole
    ## Datasets, xr.concat's default data_vars="all" would also concatenate
    ## every other data variable along "time".
    data = dsets["atm.historical.monthly.cmip6"][varname]

    ## optionally append the ssp370 simulations in time
    if load_ssp370:
        data = xr.concat([data, dsets["atm.ssp370.monthly.cmip6"][varname]], dim="time")

    ## trim data (label-based lon/lat box, first n_members ensemble members)
    lonlat_slices = dict(lon=slice(*lon_range), lat=slice(*lat_range))
    data = data.sel(lonlat_slices).isel(member_id=slice(None, n_members))

    ## Load data to memory
    return data.compute()

## Test: time serial vs. parallel loading

In [None]:
## specify kwargs (an alternative, wider region is lon_range=[280, 300])
kwargs = dict(lon_range=[285, 295], lat_range=[35, 45], n_members=8, load_ssp370=True)

## time serial loading (perf_counter is monotonic, suited to elapsed time)
t0 = time.perf_counter()
data0 = load_cesm_from_cloud(**kwargs)
print(f"serial load: {time.perf_counter()-t0:.1f} seconds")

## time parallel loading on an 8-worker local cluster
## (overwrites data0 so only one copy is held in memory)
t0 = time.perf_counter()
data0 = load_cesm_from_cloud_parallel(n_workers=8, **kwargs)
print(f"parallel load: {time.perf_counter()-t0:.1f} seconds")

In [None]:
## open the historical SST file (xarray opens lazily; values load on access).
## Presumably used below to compare its TLONG longitude convention with the
## CESM "lon" coordinate -- TODO confirm intent
d = xr.open_dataset("../single_model_ensemble/data/SST_hist.nc")

In [None]:
## inspect the loaded field at the first time step of the first ensemble member
data0.isel(time=0, member_id=0)

In [None]:
## minimum TLONG value in the SST file (numpy .min() returns NaN if any value is NaN)
d.TLONG.values.min()

## Old version (kept for reference)

In [None]:
## NOTE(review): `data` is not assigned at the top level of the cells shown
## (only as a local inside the loader functions); this cell presumably relies
## on a Dataset from an earlier version of the notebook -- TODO confirm
lonlat_vals = {"lon": slice(285, 295), "lat": slice(35, 45)}  # lon/lat box to keep

## restrict 2m-temperature to that box
data_ = data["TREFHT"].sel(indexers=lonlat_vals)

In [None]:
## time loading the trimmed data into memory (and report it)
t0 = time.perf_counter()
data_.load()
t1 = time.perf_counter()
print(f"Elapsed time: {t1-t0:.2f} seconds.")

In [None]:
## select only the first ensemble member
## NOTE(review): relies on the same top-level `data` as the cells above -- TODO confirm
member_idx = dict(member_id=0)

## load data into memory and time it (perf_counter is monotonic)
t0 = time.perf_counter()
data_loaded = data["TREFHT"].isel(member_idx).sel(lonlat_vals).compute()
t1 = time.perf_counter()
print(f"Elapsed time: {t1-t0:.2f} seconds.")