# Calculate pixel-level SST trends for specific year combinations

This notebook calculates, file by file, pixel-wise linear trends for multiple specific start/end year pairs; the results are used in Figures 4/5 and their derivatives. It is set up to produce pixel-wise linear trends of SSTs and of 500 hPa $\omega$. 

In [1]:
import xarray as xr
import numpy as np
import pandas as pd
import os
import glob
import re
from scipy import stats as sstats
from matplotlib import pyplot as plt
from cartopy import crs as ccrs
from tqdm.notebook import tqdm
import cmocean
import warnings

from funcs_support import get_filepaths,get_params
dir_list = get_params()

In [2]:
def seasmean(ds,seas_range):
    ''' Calculate seasonal means of one season

    Note: the output year is always the year of the _first_
    month of the season - so 2013/12 - 2014/2 is returned 
    as the "2013" DJF season. 

    Parameters
    -------------
    ds : :py:class:`xr.Dataset` or :py:class:`xr.DataArray`
        Input data with a datetime-like "time" dimension.

    seas_range : :py:class:`list` of :py:class:`int`
        e.g., [3,5] for MAM. A list of the integer indices
        of the first and last months (inclusive) in the season. 
        THIS IS 1-, NOT 0-INDEXED. If the first month is larger 
        than the last month (e.g., [12,2]), the season wraps 
        around the new year. A single-month season (e.g., [3,3])
        is treated as a regular, non-wrapping season.

    Returns
    -------------
    ds_out : :py:class:`xr.Dataset` or :py:class:`xr.DataArray`
        Seasonal means of the input data, with the time 
        dimension renamed "year" and showing integer years
        instead of time. For wrap-around seasons, the first 
        and last seasons are dropped.
    
    '''
    start_month, end_month = seas_range[0], seas_range[1]
    months = ds.time.dt.month

    # Only a start month strictly after the end month wraps around the
    # new year; start == end is a one-month season and must NOT be
    # treated as a wrap-around (which would select every month)
    wraps_new_year = start_month > end_month

    if wraps_new_year:
        # Wrap-around: months that are at least the start month
        # OR at most the end month
        tsub = (months >= start_month) | (months <= end_month)
    else:
        # No wrap-around: months that are at least the start month
        # AND at most the end month
        tsub = (months >= start_month) & (months <= end_month)

    # Subset to just this season
    ds_out = ds.isel(time=tsub)

    # Resample annually to get seasonal means, anchoring the year to the
    # first month of the season (e.g., "1YS-DEC") so that wrap-around
    # seasons are averaged together instead of split at the new year
    anchor = pd.Timestamp(year=1900,month=start_month,day=1).month_name()[0:3].upper()
    ds_out = ds_out.resample(time='1YS-'+anchor).mean(skipna=False)

    # Change time dimension to year int
    ds_out['time'] = ds_out['time'].dt.year
    ds_out = ds_out.rename({'time':'year'})

    # Take out incomplete seasons (from wrap around, both the first
    # and last year's seasons are incomplete means) 
    if wraps_new_year:
        ds_out = ds_out.isel(year=slice(1,-1))

    return ds_out

def seasmeans(ds,seasons):
    ''' Calculate seasonal means of multiple seasons

    A wrapper for :py:meth:`seasmean`. 

    Parameters
    -------------
    ds : :py:class:`xr.Dataset` or :py:class:`xr.DataArray`
        Input data. For a Dataset, only variables with a "time"
        dimension are averaged; the remaining variables are 
        copied unchanged into the output.

    seasons : :py:class:`dict` 
        Dictionary, with keys as season names and items as 
        lists with length 2 of the start, end months (inclusive)
        of each season. Note that seasonal limits are 1-indexed, 
        i.e., Jan == 1. 

    Returns
    -------------
    ds_out : :py:class:`xr.Dataset` or :py:class:`xr.DataArray`
        Seasonal means of the input data, along a new dimension 
        "season", with the time dimension renamed "year" and 
        showing integer years instead of time. For Dataset input,
        a "season_bnds" variable (dims "season", "bnds") holding 
        the start, end month of each season is added. It is not 
        added for DataArray input, because it needs a "bnds" 
        dimension that the DataArray does not have.

    '''
    is_dataset = isinstance(ds,xr.Dataset)

    if is_dataset:
        # Only variables with a time dimension can be seasonally averaged
        vars_wtime = [v for v in ds if 'time' in ds[v].sizes]
        ds_in = ds[vars_wtime]
    else:
        # A DataArray is averaged as is
        ds_in = ds

    season_names = list(seasons)

    # Aggregate to season
    ds_out = xr.concat([seasmean(ds_in,seasons[seas]) for seas in season_names],
                       dim = pd.Index(season_names,name='season'))

    if is_dataset:
        # Add back non-time vars
        for var in [v for v in ds if v not in vars_wtime]:
            ds_out[var] = ds[var].copy()

        # Input seasonal information
        ds_out['season_bnds'] = xr.DataArray(np.array([seasons[seas] for seas in season_names]),
                                             dims = ('season','bnds'),
                                             coords = {'season':(('season'),season_names)})
        ds_out['season_bnds'].attrs['DESCRIPTION'] = 'Start, end month (inclusive) of season (1-indexed, i.e., January = 1)'

    return ds_out
    

In [28]:
#-------------- Inputs --------------
# If True, all previously saved trend files for a model / variable are
# deleted before that model is processed
remove_all_existing = False
# If True, existing output files are recalculated and replaced
overwrite = False

# Seasons to average over, as [start month, end month] (inclusive,
# 1-indexed); a start month larger than the end month wraps the new year
seasons = {'MAM':[3,5],
           'AMJ':[4,6],
           'SON':[9,11],
           'OND':[10,12],
           'DJF':[12,2]}
# Seasons for which an additional "prev<season>" entry is added, holding
# the previous year's seasonal mean
prev_seasons = ['OND']

# Trend periods, as [start year, end year] (inclusive)
trends = {'rowell':[1986,2004],
          'long0':[1982,2014],
          'long1':[1982,2022],
          'wet0':[2000,2014],
          'wet1':[2000,2020],
          'dry0':[1985,2000]}

# Base experiment, and the future experiments used to extend its runs
# (alternative base experiment: 'hindcastsf')
base_exp = 'historical'
future_exps = ['ssp245']
# Range over which to calculate standard deviation
sd_range = slice(1981,2014)

acceptable_trend_meths = ['linear']
trend_meths = ['linear']

# Earliest start / latest end year across all trend periods; used both
# for the output filename and to subset the input data
first_year = int(np.min([ys for ys in trends.values()]))
last_year = int(np.max([ys for ys in trends.values()]))


# Process by variable
for process_var,freq in zip(['tos','wap500'],['Omon','Amon']):
    
    # Filepaths for source files
    df = get_filepaths()
    df = df.query('varname == "'+process_var+'" and freq == "'+freq+'"')
    
    # Filepaths for (possibly previously) output files
    dfs_proc = get_filepaths(source_dir='proc')
    
    mods = df.model.unique()
    for mod in tqdm(mods):
        print('\nProcessing model '+mod)
        #-------------- Setup --------------
        
        # Get files for that model
        df_tmp = df.query('model == "'+mod+'"')
    
        if base_exp not in df_tmp.exp.values:
            warnings.warn('No "'+base_exp+'" files for '+mod+', skipped!')
            continue
        
        # Get (sorted, unique) runs for each experiment; unique so that a
        # run split across several files is only processed once
        run_dict = df_tmp.groupby('exp')['run'].apply(lambda x: np.unique(list(x))).to_dict()
        
        # Get for each future exp to process which runs
        # match up with the base experiment runs
        overlap_runs = {exp:[run for run in run_dict[base_exp] if run in run_dict[exp]]
                        for exp in future_exps if exp in run_dict}
        
        # Set which run-exp combinations to process; each entry is the
        # table of files (base + future experiment) for one run
        process_list = {'hist-'+exp:[df_tmp.loc[((df_tmp.exp == base_exp) | 
                                                 (df_tmp.exp == exp)) & (df_tmp.run == run)]
                                     for run in runs]
                        for exp,runs in overlap_runs.items()}
        
        # Now get any runs that are only in the base experiment to save
        # in hist-none trends. A set union is used because the number of
        # runs can differ between future experiments
        future_runs = set()
        for exp in future_exps:
            if exp in run_dict:
                future_runs.update(run_dict[exp])
        hist_only_runs = [run for run in run_dict[base_exp] if run not in future_runs]
        if hist_only_runs:
            process_list['hist-none'] = [df_tmp.loc[(df_tmp.exp == base_exp) & (df_tmp.run == run)]
                                         for run in hist_only_runs]
        
        # Get any existing output files for this model / variable
        dfs_proc_tmp = (dfs_proc.query('varname == "'+process_var+'trends" and model == "'+mod+
                                       '" and freq == "seasavg" and suffix == "spectrends"'))
        
        # Remove existing files once per model, before any run is 
        # processed, so that files saved below are not deleted again
        if remove_all_existing:
            for fn in dfs_proc_tmp.path.values:
                if os.path.exists(fn):
                    os.remove(fn) 
                    print(fn+' removed.')
        
        #-------------- Process by run, exp combination --------------
        for exp in process_list:
            print('Processing exp '+exp+', '+str(len(process_list[exp]))+' runs...')
            for fparams in tqdm(process_list[exp]): 
                run = fparams.run.values[0]
                
                # Get save filename
                output_fn = (dir_list['proc']+mod+'/'+
                             process_var+'trends_seasavg_'+mod+'_'+exp+'_'+run+'_'+
                             str(first_year)+'0101'+'-'+
                             str(last_year)+'1231'+'_'+
                             'spectrends.nc')
        
                # Skip if the output is already catalogued and still on disk
                # (literal match; the path is not a regular expression)
                listed = ((len(dfs_proc_tmp)>0) and 
                          dfs_proc_tmp.path.str.contains(output_fn,regex=False).any())
                process = overwrite or not (listed and os.path.exists(output_fn))
        
                if not process:
                    print(output_fn+' already exists, skipped!')
                    continue
        
                try:
                    #--------- Load and cleanup
                    # Load
                    ds = xr.open_mfdataset(fparams.path)
        
                    # Subset to just years needed (with 1 year padding,
                    # if available, to catch previous seasons). The last
                    # day of December is taken from the data, to allow for
                    # calendars without a December 31st
                    ds = ds.sel(time=slice(str(first_year-1)+'-01-01',
                                           str(last_year+1)+'-12-'+str(np.max(ds.time.dt.daysinmonth.values))))
        
                    # Load into memory
                    ds = ds.load()
                    
                    # Clean up
                    # open_mfdataset sometimes concatenates along time even 
                    # if there's no time dependence
                    # NOTE(review): `.cf` is presumably the cf_xarray accessor, 
                    # which is not imported in this notebook's import cell - 
                    # confirm it is registered by funcs_support
                    for gridv in ['lon','lat']:
                        if gridv in ds.cf.bounds:
                            if 'time' in ds[ds.cf.bounds[gridv]].sizes:
                                ds[ds.cf.bounds[gridv][0]] = ds[ds.cf.bounds[gridv][0]].isel(time=0)
        
                    #--------- Get seasonal means
                    ds = seasmeans(ds,seasons)
                    
                    # Add an additional seasonal mean from the previous season
                    for seas in prev_seasons:
                        ds_tmp = ds.sel(season=seas)
                        # Shift years forward by one, so that each year holds
                        # the previous year's season
                        ds_tmp['year'] = ds_tmp['year'].values+1
                        ds_tmp = ds_tmp.isel(year=slice(0,-1))
                        ds_tmp[process_var] = ds_tmp[process_var].expand_dims({'season':['prev'+seas]})
    
                        # data_vars 'minimal' ensures that only the desired variable
                        # is concatenated 
                        ds = xr.concat([ds,ds_tmp],dim='season',data_vars='minimal')
    
                    #--------- Get std
                    with warnings.catch_warnings():
                        # Catches warning about dof == 0
                        warnings.filterwarnings('ignore')
                        ds[process_var+'_std'] = ds[process_var].sel(year=sd_range).std('year')
            
                    #--------- Get trends
                    # Get which trends to calculate (only those for which both
                    # the start and end year are present in the data)
                    trends_tmp = {trend_name:trend_years for trend_name,trend_years in trends.items()
                                  if np.all([y in ds.year.values for y in trend_years])}
                    if not trends_tmp:
                        raise ValueError('none of the requested trend periods are fully covered by the data')
        
                    # Calculate by different desired trend methods
                    for trend_meth in trend_meths:
                        if trend_meth=='linear':
                            # Calculate trends (OLS slope of a degree-1 fit 
                            # against year) for each trend period
                            slopes = [(ds[process_var].
                                       sel(year = slice(*trend_years)).
                                       polyfit(dim='year',deg=1).sel(degree=1))
                                      for trend_years in trends_tmp.values()]
                            ds[process_var+'_lslope'] = xr.concat(slopes,
                                                                  dim = pd.Index(list(trends_tmp),name='time_period'))['polyfit_coefficients']
                            
                            # Build on the source variable's attributes where available
                            ds[process_var+'_lslope'].attrs['long_name'] = ds[process_var].attrs.get('long_name',process_var)+' OLS slope'
                            ds[process_var+'_lslope'].attrs['units'] = ds[process_var].attrs.get('units','')+'/year'
                                
                        elif trend_meth == 'theilsen':
                            raise NotImplementedError('theilsen slope estimator not yet implemented')
                        else:
                            raise KeyError('trend_meth '+trend_meth+' must be one of '+', '.join(acceptable_trend_meths))
                        
                    # Add start / end years for each trend period
                    ds = xr.merge([ds,
                               xr.Dataset({'start_year':(('time_period'),[ys[0] for ys in trends_tmp.values()]),
                                           'end_year':(('time_period'),[ys[1] for ys in trends_tmp.values()])},
                                          coords = {'time_period':ds.time_period})])
            
                    #--------- Export
                    # Clean up (the seasonal means themselves are not saved)
                    ds = (ds.set_coords(['start_year','end_year','season_bnds']).
                             drop_vars('degree').
                             drop_vars(process_var))
        
                    # Attributes
                    ds.attrs['SOURCE'] = 'calculate_sst_trends_specifictimes.ipynb'
                    ds.attrs['DESCRIPTION'] = ', '.join(trend_meths)+' trends in seasonal means calculated at each grid cell for specific year pairs.'
        
                    # Remove existing if needed
                    if os.path.exists(output_fn):
                        os.remove(output_fn)
                        print(output_fn+' removed to allow overwrite!')
                
                    # Save
                    ds.to_netcdf(output_fn)
                    print(output_fn+' saved!')
        
                    del ds
                except Exception as err:
                    # Best-effort: report why this run failed and move on to
                    # the next one (KeyboardInterrupt etc. are not caught)
                    warnings.warn('Issue with '+mod+' run '+run+', skipped ('+
                                  type(err).__name__+': '+str(err)+')')


        


Processing model ERSST
Processing exp hist-none, 1 runs...


  0%|          | 0/1 [00:00<?, ?it/s]

/dx02/kschwarz/climate_proc/ERSST/tostrends_seasavg_ERSST_hist-none_obs_19820101-20221231_spectrends.nc saved!
