# Preprocess CMIP6 data
Script for downloading and saving CMIP6 files, with the ability to subset by time and space. CMIP6 data are lazily loaded directly from the cloud, using the Pangeo - Google Cloud Public Dataset Program collaboration (more info [here](https://medium.com/pangeo/cmip6-in-the-cloud-five-ways-96b177abe396)).

For each model, files are placed in a subdirectory of the `raw_data_dir` set in the `dwnld_config.py` file, i.e. in `[raw_data_dir]/[model_name]/`. If this subdirectory doesn't yet exist, it is created.

# Setup

## Packages

In [4]:
import xarray as xr
import pandas as pd
import numpy as np
import cftime
from tqdm.notebook import tqdm
import re
from operator import itemgetter # For list subsetting but this is idiotic
import intake
import gcsfs
import os
import warnings 

In [None]:
# Get config file 
import dwnld_config as cfg

In [5]:
# Switch for regridding 360-day calendars to 365-day calendars.
# When True, the processing loop interpolates any dataset whose
# 'dayofyear' dimension has length 360 onto 365 days before saving.
# Recommended to leave this False while saving raw files, and to do
# any calendar regridding only in downstream processing code that
# doesn't affect the original file.
regrid_360 = False

## Variables

In [6]:
# Set CMIP6 data parameters: each of these dict keys
# is a column in the dataframe that gives the link
# of each of the relevant files. You'll hopefully not
# need to go more specific than this. Files defined by
# each dict are processed separately.
#
# One entry per (variable, experiment) combination: precipitation
# ('pr') and near-surface air temperature ('tas'), each for the
# historical, ssp370, and ssp585 experiments.
data_params_all = [
    # Precipitation
    {'experiment_id':'historical','table_id':'day','variable_id':'pr','member_id':'r1i1p1f1'},
    {'experiment_id':'ssp370','table_id':'day','variable_id':'pr','member_id':'r1i1p1f1'},
    {'experiment_id':'ssp585','table_id':'day','variable_id':'pr','member_id':'r1i1p1f1'},
    # Near-surface air temperature
    {'experiment_id':'historical','table_id':'day','variable_id':'tas','member_id':'r1i1p1f1'},
    {'experiment_id':'ssp370','table_id':'day','variable_id':'tas','member_id':'r1i1p1f1'},
    {'experiment_id':'ssp585','table_id':'day','variable_id':'tas','member_id':'r1i1p1f1'},
]



## Subsetting

In [None]:
# Spatiotemporal subsetting options. Each dict in this list describes
# one output subset; every dataset is saved once per dict.
subset_params_all = [
    {
        # Latitude / longitude bounds of the subset, as [min, max]
        'lat': [-38, 40],
        'lon': [-20, 54],
        # Time range per experiment - make sure to specify the
        # experiment id for each time range
        'time': {
            'historical': ['1979-01-01', '2014-12-31'],
            'ssp585': ['2015-01-01', '2100-12-31'],
            'ssp370': ['2015-01-01', '2100-12-31'],
        },
        # Added to the end of the filename when saving
        'fn_suffix': '_Africa',
        # 180 or 360 - should the output file count lon as
        # -180:180 or as 0:360?
        'lon_range': 180,
        # Origin (first lon value) of the pre-processed grid
        'lon_origin': -180,
    },
]



## `fix_lons` aux function

In [7]:
def fix_lons(ds,subset_params):
    """
    Re-index and roll the longitude coordinate of a dataset so that a
    later longitude slice runs over consecutively increasing values.

    Parameters
    ----------
    ds : an xarray dataset with a longitude dimension called "lon"
    subset_params : dict with the keys
        - 'lon_range': 180 or 360; selects -180:180 or 0:360 longitudes
        - 'lon_origin': desired first longitude value; only read when
          'lon_range' is 360

    Returns
    -------
    The dataset with re-assigned and rolled "lon" coordinate. If
    'lon_range' is neither 180 nor 360, `ds` is returned unchanged.

    Changes made
    ------------
    - 'lon_range' == 180: longitudes are wrapped to -180:180 (only if
      some value exceeds 180). If the first longitude is then greater
      than -175, the grid is rolled by half its length, which is
      intended to move [0:180 -180:0] to [-180:0:180].
    - 'lon_range' == 360: longitudes are wrapped to 0:360, and the grid
      is rolled so it starts at the first longitude value lying in
      [lon_origin, 2*lon_origin). For a regular grid this is the first
      value at or just above lon_origin, so the range becomes
      [lon_origin:360 0:lon_origin]. This makes sure that subsetting to,
      say, [45, 275] doesn't subset to [275, 45] or vice-versa. Set
      lon_origin to a longitude value lower than your first subset
      value. A lon_origin of 0 rolls the grid to start at its smallest
      longitude.

    Raises
    ------
    ValueError : if 'lon_range' is 360 and no longitude lies in
        [lon_origin, 2*lon_origin) (e.g. for a negative lon_origin).
    """

    lon_range = subset_params['lon_range']

    if lon_range==180:
        # Switch to -180:180 longitude if necessary
        if (ds.lon>180).any():
            ds = ds.assign_coords(lon=(((ds.lon + 180) % 360) - 180))
        # Change origin to half the world over, to allow for the
        # longitude indexing to cross the prime meridian, but only
        # if the first lon isn't around -180 (using 5deg as an approx
        # biggest grid spacing).
        if ds.lon[0] > -175:
            ds = ds.roll(lon=(ds.sizes['lon'] // 2),roll_coords=True)
    elif lon_range==360:
        # Switch to 0:360 longitude if necessary
        ds = ds.assign_coords(lon = ds.lon % 360)

        lon_origin = subset_params['lon_origin']
        if lon_origin == 0:
            # The floor-division test below would divide by zero;
            # an origin of 0 simply means "start at the smallest lon"
            origin_idx = int(ds.lon.values.argmin())
        else:
            # Indices of all lon values in [lon_origin, 2*lon_origin)
            candidates = ((ds.lon // lon_origin)==1).values.nonzero()[0]
            if candidates.size == 0:
                raise ValueError('No longitude value found in [lon_origin, 2*lon_origin) '
                                 'for lon_origin='+str(lon_origin)+'; with lon_range=360, '
                                 'lon_origin must lie within the 0:360 grid.')
            origin_idx = int(candidates[0])

        # Change origin to the lon_origin
        ds = ds.roll(lon=-origin_idx,roll_coords=True)
    return ds

# Process

In [8]:
## Prepare the full query for all the datasets that will end up getting used in
# this process - it selects every catalog row that could match any of the
# entries in data_params_all.
#
# Keys whose value is identical across all entries become a single
# "key == 'value'" term; keys whose value differs between entries become a
# parenthesized group of OR'ed terms. All terms are AND'ed together, so the
# query matches a superset of the individual entries (every combination of
# the OR'ed values).

# Columns to query on. The optional 'other' key holds processing options
# rather than a catalog column, so it is not part of the query.
query_keys = [k for k in data_params_all[0] if k != 'other']

# source_calls[i] is 1 if all entries share the same value for
# query_keys[i], and 0 if the value differs between entries
source_calls = np.array([float(len({params[k] for params in data_params_all})==1)
                         for k in query_keys])

query_terms = []
for k, is_shared in zip(query_keys, source_calls):
    if is_shared:
        # Same value everywhere: a plain equality term
        query_terms.append(k+" == '"+data_params_all[0][k]+"'")
    else:
        # Different values: OR them together (dict.fromkeys drops repeated
        # values while keeping their first-seen order)
        values = list(dict.fromkeys(params[k] for params in data_params_all))
        query_terms.append('('+' or '.join([k+" == '"+v+"'" for v in values])+')')

# Shared-value terms go first, then the OR'ed groups
subset_query = ' and '.join([t for t in query_terms if not t.startswith('(')] +
                            [t for t in query_terms if t.startswith('(')])


In [None]:
# Anonymous, read-only handle on Google Cloud Storage
fs = gcsfs.GCSFileSystem(token='anon', access='read_only')

# Catalog of the CMIP6 zarr stores, read straight from the cloud
cmip6_datasets = pd.read_csv('https://storage.googleapis.com/cmip6/cmip6-zarr-consolidated-stores.csv')

# Rows of the catalog matching the combined query built from all the
# search parameters
cmip6_sub = cmip6_datasets.query(subset_query)

# Let the user know if nothing matched
if not len(cmip6_sub):
    warnings.warn('Query unsuccessful, no files found! Check to make sure your table_id matches the domain - for example, SSTs are listed as "Oday" instead of "day"')
    

In [None]:
#------ Process by variable and dataset in the subset ------
# For every entry in data_params_all, every matching zarr store in the
# catalog is opened, cleaned up (coordinate names, duplicated coordinates,
# unsorted time), subset in time and space once per entry of
# subset_params_all, and saved as a NetCDF file.

# If True, output files that already exist are deleted and re-created;
# if False, they are left alone and skipped.
overwrite=True

for data_params in data_params_all:
    var_id = data_params['variable_id']

    # Optional pressure-level subsetting options; None if not requested
    plev_subset = data_params.get('other',{}).get('plev_subset')

    # Variable name used in the output filename (and, for pressure-level
    # subsets, for the renamed variable in the file)
    fn_var = plev_subset['outputfn'] if plev_subset is not None else var_id

    # Get subset based on the data params above, now just for this one variable
    # ('other' holds processing options, not a catalog column)
    cmip6_sub = cmip6_datasets.query(' and '.join([k+" == '"+data_params[k]+"'" for k in data_params if k != 'other']))

    for url in tqdm(cmip6_sub.zstore.values):
        # Pieces of the store URL used for naming the outputs.
        # NOTE(review): element 5 is presumably the model name and element 8
        # the run/member label - this depends on the URL layout of the
        # catalog; confirm against the current zstore format.
        url_parts = url.split('/')
        model_tag = url_parts[5]
        member_tag = url_parts[8]
        model_dir = cfg.lpaths['raw_data_dir']+model_tag+'/'
        # Human-readable label of this dataset, for warnings
        run_label = (var_id+' '+data_params['table_id']+' '+model_tag+' '+
                     data_params['experiment_id']+' '+member_tag)

        # Set output filenames, one per subset. The filename is built
        # directly with the output variable name, instead of substituting
        # it into the finished path, which could also alter directory or
        # model names that happen to contain the variable name.
        output_fns = [(model_dir+fn_var+'_'+
                       data_params['table_id']+'_'+model_tag+'_'+
                       data_params['experiment_id']+'_'+member_tag+'_'+
                       '-'.join([re.sub('-','',t) for t in sp['time'][data_params['experiment_id']]])+
                       sp['fn_suffix']+'.nc')
                      for sp in subset_params_all]
        path_exists = [os.path.exists(fn) for fn in output_fns]

        if (not overwrite) and all(path_exists):
            warnings.warn('All files already created for '+run_label+', skipped.')
            continue
        if overwrite:
            # Remove whichever output files already exist
            for output_fn, exists in zip(output_fns, path_exists):
                if exists:
                    os.remove(output_fn)
                    warnings.warn('All files already exist for '+run_label+
                                  ', because OVERWRITE=TRUE theses files have been deleted.')

        # Open dataset
        ds = xr.open_zarr(fs.get_mapper(url),consolidated=True)

        # Rename to lat / lon (let's hope there's no
        # Latitude / latitude_1 / etc. in this dataset).
        # Best effort: if the alternative names aren't (both) in the
        # dataset, the rename fails and the dataset is left as it is.
        for alt_names in ({'longitude':'lon','latitude':'lat'},
                          {'nav_lon':'lon','nav_lat':'lat'}):
            try:
                ds = ds.rename(alt_names)
            except (KeyError, ValueError):
                pass

        # If precip, kg/m^2/s, switch to mm/day
        #if data_params['variable_id']=='pr':
        #    ds[data_params['variable_id']] = ds[data_params['variable_id']]*60*60*24

        # Fix coordinate doubling (this was an issue in NorCPM1,
        # where thankfully the values of the variables were nans,
        # though I still don't know how this happened - some lat
        # values were doubled within floating point errors).
        # NOTE(review): the nan check uses index 1 (the second element)
        # along lon/lat and time, while the warning text says "first
        # timestep" - confirm which one is intended.
        if 'lat' in ds[var_id].dims:
            if len(np.unique(np.round(ds.lat.values,10))) != ds.sizes['lat']:
                ds = ds.isel(lat=(~np.isnan(ds.isel(lon=1,time=1)[var_id].values)).nonzero()[0],drop=True)
                warnings.warn('Model '+ds.source_id+' has duplicate lat values; attempting to compensate by dropping lat values that are nan in the main variable in the first timestep')
            if len(np.unique(np.round(ds.lon.values,10))) != ds.sizes['lon']:
                ds = ds.isel(lon=(~np.isnan(ds.isel(lat=1,time=1)[var_id].values)).nonzero()[0],drop=True)
                warnings.warn('Model '+ds.source_id+' has duplicate lon values; attempting to compensate by dropping lon values that are nan in the main variable in the first timestep')

        # Sort by time, if not sorted (this happened with
        # a model; keeping a warning, cuz this seems weird)
        if (ds.time.values != np.sort(ds.time)).any():
            warnings.warn('Model '+ds.source_id+' has an unsorted time dimension.')
            ds = ds.sortby('time')

        # If 360-day calendar, regrid to 365-day calendar. This only
        # applies to datasets that have a 'dayofyear' dimension.
        if regrid_360:
            if 'dayofyear' not in ds.sizes:
                warnings.warn('regrid_360 is set, but the dataset has no dayofyear dimension; calendar not regridded.')
            elif ds.sizes['dayofyear'] == 360:
                # Have to put in the compute() because these
                # are by default dask arrays, chunked along
                # the time dimension, and can't interpolate
                # across dask chunks...
                ds = ds.compute().interp(dayofyear=(np.arange(1,366)/365)*360)
                # And reset it to 1:365 indexing on day of year
                ds['dayofyear'] = np.arange(1,366)
                # Throw in a warning, too, why not
                warnings.warn('Model '+ds.source_id+' has a 360-day calendar; daily values were interpolated to a 365-day calendar')

        # Is this a 360-day calendar? Checked both through the last day
        # in the time dimension and explicitly through the type of the
        # time values; some monthly data is still shown as 360-day even
        # when it's monthly, and will fail on date ranges with date 31
        # in a month.
        is_360_day = (isinstance(ds.time.values[0],cftime.Datetime360Day) or
                      bool(ds.time.max().dt.day==30))

        # Now, save by the subsets desired in subset_params_all above
        for subset_params, output_fn, exists in zip(subset_params_all, output_fns, path_exists):
            # Make sure this file hasn't already been processed
            if (not overwrite) and exists:
                warnings.warn(output_fn+' already exists; skipped.')
                continue

            # Make sure the target directory exists
            if not os.path.exists(model_dir):
                os.makedirs(model_dir,exist_ok=True)
                warnings.warn('Directory '+model_dir+' created!')

            # Fix longitude (by setting it to either [-180:180]
            # or [0:360] as determined by subset_params, and
            # to roll them so the correct range is consecutive
            # in lon (so if you're looking at the Equatorial
            # Pacific, make it 0:360, with the first lon value
            # at 45E).
            if 'lat' in ds[var_id].dims:
                ds_tmp = fix_lons(ds,subset_params)
                # Now, cutoff the values below the 'lon_origin',
                # because slice doesn't work if the indices aren't
                # monotonically increasing (and the above changes it
                # to [lon_origin:360 0:lon_origin]
                if np.abs(ds_tmp.lon[0]-subset_params['lon_origin'])>5:
                    ds_tmp = ds_tmp.isel(lon=np.arange(0,(ds_tmp.lon // (subset_params['lon_origin']) == 0).values.nonzero()[0][0]))
            else:
                ds_tmp = ds
                warnings.warn('fix_lons did not work because of the multi-dimensional index')

            # Subset by time as set in subset_params
            time_range = subset_params['time'][data_params['experiment_id']]
            if is_360_day:
                # (If it's a 360-day calendar, then subsetting to "12-31"
                # will throw an error; this switches that call to "12-30")
                ds_tmp = ds_tmp.sel(time=slice(time_range[0],
                                               re.sub('-31','-30',time_range[1])))
            else:
                ds_tmp = ds_tmp.sel(time=slice(*time_range))

            # Subset by space as set in subset_params
            if 'lat' not in ds[var_id].dims:
                # lat / lon aren't dimensions of the variable, so mask
                # by their values instead of slicing
                ds_tmp = ds_tmp.where((ds_tmp.lat >= subset_params['lat'][0]) & (ds_tmp.lat <= subset_params['lat'][1]) &
                                      (ds_tmp.lon >= subset_params['lon'][0]) & (ds_tmp.lon <= subset_params['lon'][1]),drop=True)
            else:
                ds_tmp = ds_tmp.sel(lat=slice(*subset_params['lat']),
                                    lon=slice(*subset_params['lon']))

            # If subsetting by pressure level...
            if plev_subset is not None:
                # Have to use np.allclose for floating point errors
                plev_matches = np.where([np.allclose(p,plev_subset['plev']) for p in ds_tmp.plev])[0]
                if plev_matches.size == 0:
                    # Requested level isn't in this dataset; skip this subset
                    print('The pressure levels: ')
                    print(ds_tmp.plev.values)
                    print(' do not contain '+str(plev_subset['plev'])+'; skipping.')
                    continue
                ds_tmp = ds_tmp.isel(plev=plev_matches[0])
                ds_tmp = ds_tmp.rename({var_id:plev_subset['outputfn']})

            # Save as NetCDF file
            ds_tmp.to_netcdf(output_fn)

            # Status update
            print(output_fn+' processed!')

        # Drop the references to the (potentially large) datasets before
        # the next one is opened. ds_tmp may never have been assigned if
        # every subset was skipped, so it is rebound rather than deleted.
        ds = None
        ds_tmp = None
        