# Preprocess CMIP6 data
Script for downloading and saving CMIP6 files, with the ability to subset them by time and space. CMIP6 data are lazily loaded directly from the cloud through the Pangeo / Google Cloud Public Dataset Program collaboration (more info [here](https://medium.com/pangeo/cmip6-in-the-cloud-five-ways-96b177abe396)).

# Setup

In [1]:
import xarray as xr
import xagg as xa
import pandas as pd
import numpy as np
import cftime
from tqdm.notebook import tqdm
import re
from operator import itemgetter # For list subsetting but this is idiotic
import intake
import gcsfs
import os
import warnings 
import xagg as xa

In [2]:
from funcs_support import (get_varlist,get_params)

In [3]:
# Load the project settings from funcs_support. The processing loop below
# reads dir_list['raw'] as the root folder under which one sub-folder per
# model is created for the saved NetCDF files.
# NOTE(review): presumably a dict of directory paths - confirm in funcs_support.
dir_list = get_params()

In [4]:
# Set whether to regrid 360-day calendars to 365-day calendars.
# This flag is read by the processing loop below: when True, a dataset
# whose 'dayofyear' dimension has length 360 is loaded into memory and
# interpolated onto 365 days before it is saved.
# (Probably don't do this while saving files. Only do this in
# processing code that doesn't affect the original file.)
regrid_360 = False

In [9]:
# Spatiotemporal subsets to cut out of every dataset: one dict per subset,
# and one output file is written per dataset and subset.
subset_params_all = [
    {
        # Bounding box, as [min, max]
        'lat': [16.5, 21.5],
        'lon': [70, 75],
        # [start, end] dates to keep, looked up by experiment_id
        'time': {
            'historical': ['1958-01-01', '2014-12-31'],
            'amip': ['1958-01-01', '2014-12-31'],
            'ssp370': ['2015-01-01', '2099-12-31'],
            'ssp585': ['2015-01-01', '2099-12-31'],
        },
        # Appended to the output filename
        'fn_suffix': '_Mumbai',
        # Longitude convention handed to the longitude fix in the
        # processing loop (here: -180 to 180)
        'lon_range': 180,
        'lon_origin': -180,
    },
]

# Alternative search (daily precipitation, all models):
#data_params_all = [{'experiment_id':'historical','table_id':'day','variable_id':'pr'},
#                   {'experiment_id':'ssp370','table_id':'day','variable_id':'pr'},
#                   {'experiment_id':'ssp585','table_id':'day','variable_id':'pr'}]

# Catalog search terms, one dict per experiment. Key order is kept as
# experiment_id, table_id, variable_id, source_id.
data_params_all = [
    {
        'experiment_id': experiment,
        'table_id': 'day',
        'variable_id': 'tas',
        'source_id': 'ACCESS-CM2',
    }
    for experiment in ('historical', 'ssp370', 'ssp585')
]


# Process

In [10]:
## Prepare the full query for all the datasets that will end up getting used
# in this process: one pandas query string that matches every entry of
# data_params_all at once.
#
# Keys whose value is the same in every entry become plain
# "key == 'value'" clauses. Keys whose value differs between entries become
# a parenthesised group of OR'ed clauses, one group per key.
# All clauses and groups are joined with 'and'.

# The 'other' key holds processing options (e.g. 'plev_subset'), not catalog
# columns, so it never takes part in the query.
all_keys = list(data_params_all[0].keys())
query_keys = [k for k in all_keys if k != 'other']

# source_calls[i] == 1 flags the keys whose value is identical in all entries
source_calls = np.zeros(len(all_keys))
for key in query_keys:
    if len({x[key] for x in data_params_all}) == 1:
        source_calls[all_keys.index(key)] = 1

# Plain lists are used here (rather than itemgetter) because itemgetter
# returns a bare string instead of a tuple when given a single index, and
# fails outright when given none.
shared_keys = [k for k in query_keys if source_calls[all_keys.index(k)] == 1]
varying_keys = [k for k in query_keys if source_calls[all_keys.index(k)] == 0]

# First get all the ones with the same value for each key
query_clauses = [k+" == '"+data_params_all[0][k]+"'" for k in shared_keys]

# Now add all that are different between entries - i.e. those that need an
# OR statement (repeated values are listed only once, in order of appearance)
for k in varying_keys:
    values = list(dict.fromkeys(dp[k] for dp in data_params_all))
    query_clauses.append('('+' or '.join([k+" == '"+v+"'" for v in values])+')')

subset_query = ' and '.join(query_clauses)


In [11]:
# Anonymous, read-only handle on Google Cloud Storage
fs = gcsfs.GCSFileSystem(access='read_only', token='anon')

# Read the table describing the CMIP6 zarr stores
catalog_url = 'https://storage.googleapis.com/cmip6/cmip6-zarr-consolidated-stores.csv'
cmip6_datasets = pd.read_csv(catalog_url)

# Keep only the rows matching the combined query (all search parameters)
cmip6_sub = cmip6_datasets.query(subset_query)

# Nothing matched: warn instead of failing, so the search can be adjusted
if not len(cmip6_sub):
    warnings.warn('Query unsuccessful, no files found! Check to make sure your table_id matches the domain - SSTs are listed as "Oday" instead of "day" for example')
    

In [12]:
# Display the catalog rows matched by the combined query
cmip6_sub

Unnamed: 0,activity_id,institution_id,source_id,experiment_id,member_id,table_id,variable_id,grid_label,zstore,dcpp_init_year,version
377958,ScenarioMIP,CSIRO-ARCCSS,ACCESS-CM2,ssp370,r1i1p1f1,day,tas,gn,gs://cmip6/CMIP6/ScenarioMIP/CSIRO-ARCCSS/ACCE...,,20191108
379386,CMIP,CSIRO-ARCCSS,ACCESS-CM2,historical,r1i1p1f1,day,tas,gn,gs://cmip6/CMIP6/CMIP/CSIRO-ARCCSS/ACCESS-CM2/...,,20191108
391447,CMIP,CSIRO-ARCCSS,ACCESS-CM2,historical,r2i1p1f1,day,tas,gn,gs://cmip6/CMIP6/CMIP/CSIRO-ARCCSS/ACCESS-CM2/...,,20191125
423702,ScenarioMIP,CSIRO-ARCCSS,ACCESS-CM2,ssp370,r2i1p1f1,day,tas,gn,gs://cmip6/CMIP6/ScenarioMIP/CSIRO-ARCCSS/ACCE...,,20200303
424456,ScenarioMIP,CSIRO-ARCCSS,ACCESS-CM2,ssp585,r2i1p1f1,day,tas,gn,gs://cmip6/CMIP6/ScenarioMIP/CSIRO-ARCCSS/ACCE...,,20200303
425192,CMIP,CSIRO-ARCCSS,ACCESS-CM2,historical,r3i1p1f1,day,tas,gn,gs://cmip6/CMIP6/CMIP/CSIRO-ARCCSS/ACCESS-CM2/...,,20200306
439379,ScenarioMIP,CSIRO-ARCCSS,ACCESS-CM2,ssp585,r3i1p1f1,day,tas,gn,gs://cmip6/CMIP6/ScenarioMIP/CSIRO-ARCCSS/ACCE...,,20200428
439638,ScenarioMIP,CSIRO-ARCCSS,ACCESS-CM2,ssp370,r3i1p1f1,day,tas,gn,gs://cmip6/CMIP6/ScenarioMIP/CSIRO-ARCCSS/ACCE...,,20200428
515792,ScenarioMIP,CSIRO-ARCCSS,ACCESS-CM2,ssp585,r1i1p1f1,day,tas,gn,gs://cmip6/CMIP6/ScenarioMIP/CSIRO-ARCCSS/ACCE...,,20210317


In [13]:
#------ Process by variable and dataset in the subset ------
# For every entry of data_params_all: find the matching zarr stores in the
# catalog, open each one lazily, standardize its coordinates, cut out every
# subset listed in subset_params_all and save one NetCDF file per subset.
# Existing output files are skipped unless `overwrite` is True.
overwrite=False
for data_params in data_params_all:
    # Get subset based on the data params above, now just for this one variable
    cmip6_sub = cmip6_datasets.query(' and '.join([k+" == '"+data_params[k]+"'" for k in data_params.keys() if k != 'other']))

    var_id = data_params['variable_id']
    # Optional pressure-level settings; None when no plev subset is requested
    plev_subset = data_params.get('other', {}).get('plev_subset')
    # Variable name used at the start of the output filename
    fn_var = plev_subset['outputfn'] if plev_subset is not None else var_id

    for url in tqdm(cmip6_sub.zstore.values):
        # The model name and the member id are read from fixed positions of
        # the store URL
        url_parts = url.split('/')
        model = url_parts[6]
        print('processing '+model+'!')
        try:
            if 'member_id' in data_params:
                member_id = data_params['member_id']
            else:
                member_id = url_parts[8]

            out_dir = dir_list['raw']+model+'/'
            # Shared by the "already exists" warnings below
            run_label = (var_id+' '+data_params['table_id']+' '+model+' '+
                         data_params['experiment_id']+' '+member_id)

            # Set output filenames (one per subset) and check which exist
            output_fns = []
            path_exists = []
            for subset_params in subset_params_all:
                if 'time' in subset_params:
                    time_str = '-'.join([re.sub('-','',t) for t in subset_params['time'][data_params['experiment_id']]])
                else:
                    time_str = ''

                # The plev output name replaces the variable name only at
                # the start of the file name; running a substitution over
                # the whole path would also rewrite matching directory names.
                output_fn = (out_dir+
                             fn_var+'_'+
                             data_params['table_id']+'_'+model+'_'+
                             data_params['experiment_id']+'_'+member_id+'_'+
                             time_str+
                             subset_params['fn_suffix']+'.nc')
                output_fns.append(output_fn)
                path_exists.append(os.path.exists(output_fn))

            if (not overwrite) and all(path_exists):
                warnings.warn('All files already created for '+run_label+', skipped.')
                continue
            elif overwrite and any(path_exists):
                for output_fn, exists in zip(output_fns, path_exists):
                    if exists:
                        os.remove(output_fn)
                        warnings.warn('All files already exist for '+run_label+
                                      ', because OVERWRITE=TRUE these files have been deleted.')

            # Open dataset
            ds = xr.open_zarr(fs.get_mapper(url),consolidated=True)
            # Bound up front so the cleanup at the end never hits an
            # undefined name, even if every subset below is skipped
            ds_tmp = None

            # Rename to lat / lon (let's hope there's no
            # Latitude / latitude_1 / etc. in this dataset);
            # same with 'nav_lat' and 'nav_lon'
            for coord_names in ({'longitude':'lon','latitude':'lat'},
                                {'nav_lon':'lon','nav_lat':'lat'}):
                try:
                    ds = ds.rename(coord_names)
                except (ValueError, KeyError):
                    # This pair of names can't be renamed in this dataset
                    # (e.g. they are absent); leave it as it is
                    pass

            # If precip, kg/m^2/s, switch to mm/day
            #if data_params['variable_id']=='pr':
            #    ds[data_params['variable_id']] = ds[data_params['variable_id']]*60*60*24

            # Fix coordinate doubling (this was an issue in NorCPM1,
            # where thankfully the values of the variables were nans,
            # though I still don't know how this happened - some lat
            # values were doubled within floating point errors).
            # The nan check looks at index 1 along the other two dimensions.
            if 'lat' in ds[var_id].dims:
                if len(np.unique(np.round(ds.lat.values,10))) != ds.sizes['lat']:
                    ds = ds.isel(lat=(~np.isnan(ds.isel(lon=1,time=1)[var_id].values)).nonzero()[0],drop=True)
                    warnings.warn('Model '+ds.source_id+' has duplicate lat values; attempting to compensate by dropping lat values that are nan in the main variable in the first timestep')
                if len(np.unique(np.round(ds.lon.values,10))) != ds.sizes['lon']:
                    ds = ds.isel(lon=(~np.isnan(ds.isel(lat=1,time=1)[var_id].values)).nonzero()[0],drop=True)
                    warnings.warn('Model '+ds.source_id+' has duplicate lon values; attempting to compensate by dropping lon values that are nan in the main variable in the first timestep')

            # Sort by time, if not sorted (this happened with
            # a model; keeping a warning, cuz this seems weird).
            # Only needed when at least one subset selects by time.
            if any('time' in sp for sp in subset_params_all):
                if (ds.time.values != np.sort(ds.time)).any():
                    warnings.warn('Model '+ds.source_id+' has an unsorted time dimension.')
                    ds = ds.sortby('time')

            # If 360-day calendar, regrid to 365-day calendar
            if regrid_360:
                if ds.sizes['dayofyear'] == 360:
                    # Have to put in the compute() because these
                    # are by default dask arrays, chunked along
                    # the time dimension, and can't interpolate
                    # across dask chunks...
                    ds = ds.compute().interp(dayofyear=(np.arange(1,366)/365)*360)
                    # And reset it to 1:365 indexing on day of year
                    ds['dayofyear'] = np.arange(1,366)
                    # Throw in a warning, too, why not
                    warnings.warn('Model '+ds.source_id+' has a 360-day calendar; daily values were interpolated to a 365-day calendar')

            # Now, save by the subsets desired in subset_params_all above
            for subset_params, output_fn, exists in zip(subset_params_all, output_fns, path_exists):
                # Make sure this file hasn't already been processed
                if (not overwrite) and exists:
                    warnings.warn(output_fn+' already exists; skipped.')
                    continue

                # Make sure the target directory exists (makedirs also
                # creates any missing parent directories)
                if not os.path.exists(out_dir):
                    os.makedirs(out_dir, exist_ok=True)
                    warnings.warn('Directory '+out_dir+' created!')

                # Fix longitude (by setting it to either [-180:180]
                # or [0:360] as determined by subset_params, and
                # to roll them so the correct range is consecutive
                # in lon (so if you're looking at the Equatorial
                # Pacific, make it 0:360, with the first lon value
                # at 45E).
                if 'lat' in ds[var_id].dims:
                    ds_tmp = xa.fix_ds(ds,subset_params)
                    # Now, cutoff the values below the 'lon_origin',
                    # because slice doesn't work if the indices aren't
                    # monotonically increasing (and the above changes it
                    # to [lon_origin:360 0:lon_origin]
                    if np.abs(ds_tmp.lon[0]-subset_params['lon_origin'])>5:
                        ds_tmp = ds_tmp.isel(lon=np.arange(0,(ds_tmp.lon // (subset_params['lon_origin']) == 0).values.nonzero()[0][0]))
                else:
                    ds_tmp = ds.copy()
                    warnings.warn('fix_ds did not work because of the multi-dimensional index')

                # Subset by time as set in subset_params
                if 'time' in subset_params:
                    time_range = subset_params['time'][data_params['experiment_id']]
                    if (ds.time.max().dt.day==30) | (type(ds.time.values[0]) == cftime._cftime.Datetime360Day):
                        # (If it's a 360-day calendar, then subsetting to "12-31"
                        # will throw an error; this switches that call to "12-30")
                        # Also checking explicitly for 360day calendar; some monthly
                        # data is still shown as 360-day even when it's monthly, and will
                        # fail on date ranges with date 31 in a month
                        ds_tmp = ds_tmp.sel(time=slice(time_range[0],
                                                       re.sub('-31','-30',time_range[1])))
                    else:
                        ds_tmp = ds_tmp.sel(time=slice(*time_range))

                # Subset by space as set in subset_params
                if 'lat' in subset_params.keys():
                    if not 'lat' in ds[var_id].dims:
                        # lat / lon are not dimensions here, so mask instead of slicing
                        ds_tmp = ds_tmp.where((ds_tmp.lat >= subset_params['lat'][0]) & (ds_tmp.lat <= subset_params['lat'][1]) &
                         (ds_tmp.lon >= subset_params['lon'][0]) & (ds_tmp.lon <= subset_params['lon'][1]),drop=True)
                    else:
                        ds_tmp = (ds_tmp.sel(lat=slice(*subset_params['lat']),
                                             lon=slice(*subset_params['lon'])))

                # If subsetting by pressure level...
                if plev_subset is not None:
                    # Have to use np.allclose for floating point errors
                    try:
                        ds_tmp = ds_tmp.isel(plev=np.where([np.allclose(p,plev_subset['plev']) for p in ds_tmp.plev])[0][0])
                        ds_tmp = ds_tmp.rename({var_id:plev_subset['outputfn']})
                    except (KeyError, IndexError):
                        # IndexError: no level matched, so the index list
                        # returned by np.where is empty
                        print('The pressure levels: ')
                        print(ds_tmp.plev.values)
                        print(' do not contain '+str(plev_subset['plev'])+'; skipping.')
                        continue
                # Trick to make the loading process go faster (otherwise it
                # gets stuck forever in .to_netcdf below; and .load() is
                # just as slow for some reason)
                #if 'time' in subset_params:
                    #tmp = ds_tmp.mean('time')
                    #del tmp

                # Save as NetCDF file
                if ds_tmp.sizes['time']>0:
                    try:
                        ds_tmp.to_netcdf(output_fn)
                    except ValueError:
                        print('issue with export; skipping')
                        continue
                else:
                    print('time dimension is 0, skipping')
                    continue

                # Status update
                print(output_fn+' processed!')

            # Drop the references to the datasets before moving to the next
            # store (both names are always bound at this point)
            del ds, ds_tmp
        except AssertionError:
            print('checksum error with model '+model+', skipping for now.')
            continue
        

  0%|          | 0/3 [00:00<?, ?it/s]

processing ACCESS-CM2!
/dx01/kschwarz/project_data/mumbai_projs/ACCESS-CM2/tas_day_ACCESS-CM2_historical_r1i1p1f1_19580101-20141231_Mumbai.nc processed!
processing ACCESS-CM2!
/dx01/kschwarz/project_data/mumbai_projs/ACCESS-CM2/tas_day_ACCESS-CM2_historical_r2i1p1f1_19580101-20141231_Mumbai.nc processed!
processing ACCESS-CM2!
/dx01/kschwarz/project_data/mumbai_projs/ACCESS-CM2/tas_day_ACCESS-CM2_historical_r3i1p1f1_19580101-20141231_Mumbai.nc processed!


  0%|          | 0/3 [00:00<?, ?it/s]

processing ACCESS-CM2!
/dx01/kschwarz/project_data/mumbai_projs/ACCESS-CM2/tas_day_ACCESS-CM2_ssp370_r1i1p1f1_20150101-20991231_Mumbai.nc processed!
processing ACCESS-CM2!
/dx01/kschwarz/project_data/mumbai_projs/ACCESS-CM2/tas_day_ACCESS-CM2_ssp370_r2i1p1f1_20150101-20991231_Mumbai.nc processed!
processing ACCESS-CM2!
/dx01/kschwarz/project_data/mumbai_projs/ACCESS-CM2/tas_day_ACCESS-CM2_ssp370_r3i1p1f1_20150101-20991231_Mumbai.nc processed!


  0%|          | 0/3 [00:00<?, ?it/s]

processing ACCESS-CM2!
/dx01/kschwarz/project_data/mumbai_projs/ACCESS-CM2/tas_day_ACCESS-CM2_ssp585_r2i1p1f1_20150101-20991231_Mumbai.nc processed!
processing ACCESS-CM2!
/dx01/kschwarz/project_data/mumbai_projs/ACCESS-CM2/tas_day_ACCESS-CM2_ssp585_r3i1p1f1_20150101-20991231_Mumbai.nc processed!
processing ACCESS-CM2!


  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  return np.asarray(array[self.key], dtype=None)


time dimension is 0, skipping
