<span style="color:hotpink; font-size:40px; font-weight:bold;">Packages and imports</span>

In [4]:
# Shell escapes that set the global git identity (user name and e-mail) for this environment.
# NOTE(review): these are personal identity details unrelated to the analysis below;
# consider moving them out of the shared notebook.
!git config --global user.name "mauriekeppens"
!git config --global user.email "keppens_maurie@hotmail.com"

In [2]:
from xmip.preprocessing import combined_preprocessing
from xmip.postprocessing import concat_experiments, merge_variables
from xmip.utils import cmip6_dataset_id

from dask.diagnostics import ProgressBar
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import xesmf as xe
import pandas as pd
import intake

import gcsfs
fs = gcsfs.GCSFileSystem()

<span style="color:hotpink; font-size:40px; font-weight:bold;">Setting path for saving cleaned CMIP6 testbed files</span>

In [None]:
### set paths ###

# Path to the already-saved 1982-2023 testbed.
# NOTE: the final cell of this notebook writes every regridded member under this
# directory with mode='w', i.e. it overwrites whatever is stored there.
ensemble_dir = 'gs://leap-persistent/abbysh/pco2_all_members_1982-2023/post00_regridded_members'

# To save your own version of the testbed (for instance with a different year
# range, set a few cells below, or with additional variables), define your own
# path here instead:

# your_username = # leap pangeo bucket name, also should be your github username
# ensemble_dir = f'gs://leap-persistent/{your_username}/pco2_residual/post00_regridded_members'

<span style="color:hotpink; font-size:40px; font-weight:bold;">Load CMIP6 datasets from cloud that satisfy our requirements</span>

<span style="color:lightblue; font-size:30px; font-weight:bold;">CMIP6 Earth System Models (ESMs) we want, called "source_id"</span>

In [None]:
# CMIP6 Earth System Models ("source_id" values) to search the catalog for.
source_id = [
    'ACCESS-ESM1-5',
    'UKESM1-0-LL',
    'CMCC-ESM2',
    'CESM2-WACCM',
    'CESM2',
    'CanESM5-CanOE',
    'CanESM5',
    'MPI-ESM1-2-LR',
    'GFDL-ESM4',
]

<span style="color:lightblue; font-size:30px; font-weight:bold;">Searching CMIP6 catalog</span>

<span style="color:thistle; font-size:25px; font-weight:bold;">Definitions and units</span>

Note: to convert pascals (Pa) to microatmospheres (µatm), use 1 atm = 101325 Pa and 1 atm = 10^6 µatm, i.e. pCO2 [µatm] = pCO2 [Pa] × 10^6 / 101325.

| Our variable name | CMIP6 output name  | Description                      | Units                                  |
|-------------------|--------------------|----------------------------------|----------------------------------------|
| pCO2              | spco2              | sea surface co2 partial pressure |pascals, CONVERTS TO MICROATM LATER     |
| SST               | tos                | sea surface temperature          |degrees Celsius                         |
| SSS               | sos                | sea surface salinity             |.001 (parts per thousand)               |
| Chl               | chl                | sea surface chlorophyll          |kilograms per cubic meter               |
| MLD               | mlotst             | mixed layer depth                |meters (defined by sigma T criterion)   |

In [None]:
# Filter the full catalog down to the datasets this testbed can use.

# Catalog of CMIP6 datasets that pass ingestion tests. For more information: https://github.com/leap-stc/cmip6-leap-feedstock
url = "https://storage.googleapis.com/cmip6/cmip6-pgf-ingestion-test/catalog/catalog.json"
col = intake.open_esm_datastore(url)

## Search the catalog ##
cat = col.search(
    source_id = source_id, # ESM list defined in the cell above
    variable_id=['tos', 'sos', 'chl', 'mlotst', 'spco2'], # variables we want (see the definitions table above)
    table_id=['Omon'], # monthly ocean output only
    experiment_id=['ssp245','historical'], # SSP scenario of choice, plus historical for the pre-2015 model output
    require_all_on=['source_id', 'member_id', 'grid_label'] # keep only (model, member, grid) groups for which every requested variable and experiment is available
)

<span style="color:thistle; font-size:25px; font-weight:bold;">To view available datasets, number of members per ESM</span>

In [None]:
# Number of distinct ensemble members found per model, grid label and table.
(
    cat.df
    .groupby(['source_id', 'grid_label', 'table_id'])[['member_id']]
    .nunique()
)

<span style="color:hotpink; font-size:40px; font-weight:bold;">Regridding</span>

<span style="color:lightblue; font-size:30px; font-weight:bold;">Turn catalog of ESM datasets into dictionary</span>

In [None]:
## Open the search results as a dictionary of xarray datasets ##
# aggregate=False presumably keeps every catalog entry as its own dictionary item
# (the keys used in the time fix below are per-variable, per-member paths).
ddict = cat.to_dataset_dict(
    aggregate=False,
    xarray_open_kwargs={'use_cftime': True},
    preprocess=combined_preprocessing,
)

## Some datasets carry an "area" variable; drop it wherever it is present.
for key, dset in ddict.items():
    if 'area' in dset:
        ddict[key] = dset.drop_vars('area')

## ignore the warnings!

<span style="color:lightblue; font-size:30px; font-weight:bold;">Filtering out members with buggy times</span>

What the following hack does is take the "time" data from a historical member, and a scenario member, that have no bugs.

We then apply the non-buggy historical "time" data to all the historical members, and apply the non-buggy scenario "time" data to all the scenario members. 

If a member does not satisfy the necessary requirements (historical "time" data starts in 1850 and ends in 2014, for example), it is removed from our list of members.

In [None]:
## Temporary workaround for members whose time coordinates are buggy ##

# Reference time axes, taken from one historical and one ssp245 CanESM5 "sos" dataset.
hist_time = ddict['CMIP.CCCma.CanESM5.historical.r6i1p1f1.Omon.sos.gn.none.r6i1p1f1.v20190429.gs://cmip6/CMIP6/CMIP/CCCma/CanESM5/historical/r6i1p1f1/Omon/sos/gn/v20190429/'].time
fut_time = ddict['ScenarioMIP.CCCma.CanESM5.ssp245.r3i1p1f1.Omon.sos.gn.none.r3i1p1f1.v20190429.gs://cmip6/CMIP6/ScenarioMIP/CCCma/CanESM5/ssp245/r3i1p1f1/Omon/sos/gn/v20190429/'].time

# Overwrite "time" in place on every dataset whose first/last year matches the
# historical (1850-2014) or scenario (2015-2100) span; all others are left untouched.
dict_list = list(ddict.values())
for item in dict_list:
    first_year = item.time.data[0].year
    last_year = item.time.data[-1].year
    if (first_year, last_year) == (1850, 2014):
        item['time'] = hist_time
    elif (first_year, last_year) == (2015, 2100):
        item['time'] = fut_time

In [None]:
## Works like xarray.concat: chains the historical and scenario runs of each member
## together along time, so each member spans 1850-2100.

ds = concat_experiments(ddict)

In [None]:
## Works like xarray.merge: combines the separate per-variable datasets into one dataset per member.
# NOTE(review): `ds` is rebound to the output of a call on itself, so this cell is not
# idempotent -- re-run the concat cell above before re-running this one.

ds = merge_variables(ds) 

# Lots of warnings are expected right now, but that should be fine: the warnings mean the bad members are getting removed.
# TODO: I think we can fix this bug by doing a time slice for when we want on each dataset BEFORE merging

<span style="color:lightblue; font-size:30px; font-weight:bold;">Setting up target grid</span>

In [None]:
## Target resolution and time frame
# Cell centres on a 1-degree grid: longitude runs -180..180 and latitude -90..90
# (as opposed to, for example, 0..360 for longitude).

lat_centers = [lat + .5 for lat in range(-90, 90)]
lon_centers = [lon + .5 for lon in range(-180, 180)]
ylat = xr.DataArray(data=lat_centers, dims=['ylat'], coords={'ylat': (['ylat'], lat_centers)})
xlon = xr.DataArray(data=lon_centers, dims=['xlon'], coords={'xlon': (['xlon'], lon_centers)})
# alternatively: xlon = xr.DataArray(data=[x+.5 for x in range(0,360,1)], dims=['xlon'], coords=dict( xlon=(['xlon'],[x+.5 for x in range(0,360,1)]) ),)

## First and last 'year-month' kept in the testbed
processed_start_yearmonth = '1982-02'
processed_end_yearmonth = '2023-12'

## desired start and end year for testbed
# init_year = 1982
# fin_year = 2023

## Monthly time stamps in the middle of each month: freq='MS' is "month start",
## and 14 days are added to land mid-month.
# The time axis does not affect the regridding itself; its length is used to
# check each member's time span before regridding.
month_starts = pd.date_range(
    start=str(processed_start_yearmonth),
    end=str(processed_end_yearmonth),
    freq='MS',
)
ttime = month_starts + np.timedelta64(14, 'D')

## The grid itself. The latitude/longitude names are required by old xESMF versions.
target_grid = xr.Dataset({
    'time': (['time'], ttime.values),
    'latitude': (['latitude'], ylat.values),
    'longitude': (['longitude'], xlon.values),
})

<span style="color:lightblue; font-size:30px; font-weight:bold;">Functions for regridding data</span>

In [None]:
def replace_calendar(ds:xr.Dataset) -> xr.Dataset:
    """
    Replaces the time coordinate with a regular month-start cftime axis.

    The new axis begins on the first day of the month of the dataset's first
    time stamp and has exactly one entry per existing time step.

    Args:
        ds (xr.Dataset): Dataset whose `time` coordinate should be rebuilt.
    Returns:
        xr.Dataset: Dataset with the rebuilt `time` coordinate.
    """
    first_stamp = ds.time.data[0]
    n_steps = len(ds.time)
    start_date = f'{first_stamp.year}-{first_stamp.month:0>2}-01'
    return ds.assign_coords(
        time=xr.cftime_range(start_date, periods=n_steps, freq='1MS')
    )

#TODO:  create a regridder dict per source_id (faster)

# target_grid = xe.util.grid_global(1,1) # this is old, from julius
def regrid(target_grid, ds:xr.Dataset) -> xr.Dataset:
    """
    Interpolates a dataset bilinearly onto the latitude/longitude of `target_grid`.

    Args:
        target_grid (xr.Dataset): Dataset describing the destination grid.
        ds (xr.Dataset): Dataset to interpolate onto `target_grid`.
    Returns:
        xr.Dataset: Regridded dataset, with the input attributes kept.
    """
    # FIXME: the regridder is rebuilt for every dataset; it should be created once and reused.
    # TODO: check whether conservative regridding should be used instead of bilinear.
    interpolate = xe.Regridder(
        ds,
        target_grid,
        method='bilinear',
        periodic=True,
        ignore_degenerate=True,
    )
    return interpolate(ds, keep_attrs=True)

def full_testbed_processing(target_grid, init_year_month, fin_year_month, timespan, ds:xr.Dataset) -> xr.Dataset:
    """
    Trims one member to the requested period, regrids it and rebuilds its time axis.

    Steps, in order: squeeze out length-1 dimensions, keep only the first `lev`
    entry when a `lev` dimension is present, slice `time` to
    [init_year_month, fin_year_month], sanity-check the slice, regrid with
    `regrid`, and finally apply `replace_calendar`.

    Args:
        target_grid (xr.Dataset): Ideal dataset example in terms of format of times/lat/lon.
        init_year_month (str): Initial year/month included in dataset, e.g. '1982-02'.
        fin_year_month (str): Final year/month included in dataset.
        timespan (int): Number of months in dataset (time length of dataset).
        ds (xr.Dataset): Dataset to regrid to match format of target_grid.

    Returns:
        ds_new_cal (xr.Dataset): Dataset gridded to match target_grid format, with changed calendar and the `lev` dimension removed.

    Raises:
        AssertionError: If the sliced dataset does not contain `timespan` time
            steps or does not start in the year given by `init_year_month`.
    """
    ds = ds.squeeze(drop=True)
    if 'lev' in ds.dims:
        # Keep only the first level and discard the leftover scalar coordinate.
        ds = ds.isel(lev=0).drop_vars('lev')

    ds = ds.sel(time=slice(init_year_month, fin_year_month))

    # Sanity checks: fail with a readable message instead of a bare AssertionError.
    n_steps = len(ds.time)
    assert n_steps == timespan, (
        f"expected {timespan} time steps between {init_year_month} and "
        f"{fin_year_month}, found {n_steps}"
    )
    first_year = ds.time.data[0].year
    expected_year = int(init_year_month[0:4])
    assert first_year == expected_year, (
        f"expected the sliced data to start in {expected_year}, found {first_year}"
    )

    # Processing
    ds_regridded = regrid(target_grid, ds)
    ds_new_cal = replace_calendar(ds_regridded)

    return ds_new_cal

<span style="color:lightblue; font-size:30px; font-weight:bold;">Regridding all members and saving them</span>

In [None]:
member_counter = 0

## Loop through all members: regrid, clean up, derive pCO2 components, save.

for k, item in ds.items():
    print(f"Processing member no.{member_counter},{k}")

    ## regridding step here (also trims to the testbed period and rebuilds the time axis)
    item_out = full_testbed_processing(target_grid, processed_start_yearmonth, processed_end_yearmonth, len(ttime), item)

    ## get CMIP6 ID for member (used as the name of the zarr store)
    item_id = cmip6_dataset_id(item_out, id_attrs=['source_id', 'variant_label', 'table_id'])

    ## converting pco2 to microatmospheres, from pascals!
    # .get() avoids a KeyError for members that carry no "units" attribute.
    if 'spco2' in item_out.data_vars and item_out['spco2'].attrs.get('units') == 'Pa':
        print('fixing spco2 units')
        spco2_pa = item_out['spco2']
        spco2_uatm = spco2_pa * 10**6 / 101325  # 1 atm = 101325 Pa = 10^6 microatm
        # Assign the converted array back as a variable rather than through
        # `.values`, which can force the lazy array to be loaded into memory.
        # Attributes and encoding are carried over explicitly so the stored
        # variable keeps its metadata, with only the units changed.
        spco2_uatm.attrs = {**spco2_pa.attrs, 'units': 'microatmospheres'}
        spco2_uatm.encoding = dict(spco2_pa.encoding)
        item_out['spco2'] = spco2_uatm

    ### removing unnecessary variables (only those that are actually present) ###
    item_out = item_out.drop_vars(['lev_bounds', 'time_bounds', 'lev_partial', 'nbnd'], errors='ignore')
    if 'lev_partial' in item_out.chl.dims:
        # NOTE(review): any `lev_partial` coordinate has already been dropped above,
        # so confirm that this selection still picks the intended level.
        item_out['chl'] = item_out['chl'].sel({'lev_partial': 1}, drop=True)

    ## fixing variable/coord names ###
    item_out = item_out.rename({
        'latitude': 'ylat',
        'longitude': 'xlon',
        'mlotst': 'mld',
        'sos': 'sss',
        'tos': 'sst',
    })

    ## calculate pco2-T (temperature component of pco2) and pco2-residual (or "non-T pco2")
    # 0.0423 per degree is presumably the standard temperature sensitivity of pCO2 -- confirm the source.
    pco2_T_calc = item_out['spco2'].mean('time') * np.exp(0.0423 * (item_out['sst'] - item_out['sst'].mean("time")))
    pco2_resid_calc = item_out['spco2'] - pco2_T_calc

    item_out = item_out.assign(pco2_T=pco2_T_calc, pco2_residual=pco2_resid_calc)

    ## to make sure dimensions are lined up for all variables
    item_out = item_out.transpose('time', 'ylat', 'xlon')

    ## save regridded member data (mode='w' overwrites an existing store at this path)
    save_path = f"{ensemble_dir}/{item.attrs['source_id']}/member_{item.attrs['intake_esm_attrs:member_id']}/{item_id}.zarr"
    print(f"Writing to {save_path = }")
    with ProgressBar():
        item_out.chunk({'time': 200}).to_zarr(save_path, mode='w')
    member_counter += 1