In [1]:
%matplotlib inline

In [2]:
import os
import sys 
import pathlib

In [3]:
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta

In [4]:
HOME = pathlib.Path.home()

In [5]:
sys.path.append('../../../utils/')

In [6]:
from set_root_dir import set_root_dir

In [7]:
from matplotlib import pyplot as plt

In [8]:
import proplot as plot

In [9]:
import numpy as np
import pandas as pd

In [10]:
import xarray as xr

In [11]:
xr.__version__

'0.15.1'

In [12]:
rpath = set_root_dir(root='gdata')

In [13]:
# %%writefile ../../../utils/get_GCM_outputs.py
def get_GCM_outputs(provider='CDS', GCM='ECMWF', var_name='T2M', period='hindcasts', rpath=None, domain=[90, 300, -65, 50], step=None, verbose=False, flatten=True):
    """
    Open the processed seasonal anomalies of one GCM and concatenate them along time.

    The files are searched in
    ``rpath / 'GCMs' / 'processed' / period / provider / GCM / var_name``
    and must be named ``{GCM}_{var_name}_seasonal_anomalies_interp_YYYY_MM.nc``.
    Only the variable ``var_name.lower()`` is kept from each file.

    Parameters
    ----------

    - provider: in ['CDS','IRI','JMA']
    - GCM: name of the GCM
    - var_name: in ['T2M', 'PRECIP']
    - period: in ['hindcasts', 'forecasts']
    - rpath: root path (str or pathlib.Path, see `set_root_dir` in the utils module), required
    - domain: [lon_min, lon_max, lat_min, lat_max], or None to keep the full grid.
      The selection uses label-based slices, so it assumes 'lon' and 'lat' are
      stored in ascending order -- TODO confirm against the processed files
    - step: ( in [3, 4, 5] ), or None to keep all the steps
    - verbose: Boolean, whether to print names of files successfully opened
    - flatten: Boolean, whether or not to stack the dataset over the spatial
      (+ member if present) dimensions into a single 'z' dimension to get 2D fields

    Return
    ------

    - dset: xarray.Dataset concatenated along the time dimension

    Raises
    ------

    - ValueError: if `rpath` is not provided
    - FileNotFoundError: if no file matches the expected pattern in the input directory

    """

    import pathlib

    if rpath is None:
        raise ValueError("rpath must be provided (see `set_root_dir` in the utils module)")

    # pathlib.Path() is a no-op on Path objects and lets callers pass a plain string
    ipath = pathlib.Path(rpath) / 'GCMs' / 'processed' / period / provider / GCM / var_name

    pattern = f"{GCM}_{var_name}_seasonal_anomalies_interp_????_??.nc"

    lfiles_gcm = sorted(ipath.glob(pattern))

    # sanity check on the number of files expected for each period (warning only)
    min_files = {'hindcasts': 200, 'forecasts': 20}

    if period in min_files and len(lfiles_gcm) < min_files[period]:
        print(f"Something wrong with the number of files in the list for the {period} period, the length is {len(lfiles_gcm)}")

    # fail with an explicit message rather than an IndexError on lfiles_gcm[0] below
    if not lfiles_gcm:
        raise FileNotFoundError(f"no file matching {pattern} found in {ipath}")

    # imported only once there is something to open, so that bad inputs fail fast
    import xarray as xr

    print(f"first file is {str(lfiles_gcm[0])}")
    print(f"last file is {str(lfiles_gcm[-1])}")

    dset_l = []

    for fname in lfiles_gcm:

        # the context manager closes the file handle once the subset has been
        # loaded in memory, instead of leaving one handle open per file
        with xr.open_dataset(fname) as ds:

            dset = ds[[var_name.lower()]]

            # select the domain

            if domain is not None:
                dset = dset.sel(lon=slice(domain[0], domain[1]), lat=slice(domain[2], domain[3]))

            if step is not None:
                dset = dset.sel(step=step)

            # the data must be read before the file is closed
            dset = dset.load()

        if verbose:
            print(f"successfully opened and extracted {fname}")

        dset_l.append(dset)

    dset = xr.concat(dset_l, dim='time')

    if flatten:

        if 'member' in dset.dims:

            dset = dset.stack(z=('member','lat','lon'))

        else:

            dset = dset.stack(z=('lat','lon'))

    return dset 

In [14]:
# ECMWF (CDS) T2M hindcasts, step=3, stacked over the spatial (+ member) dimensions
dset_t2m_ecmwf_hindcasts = get_GCM_outputs(
    provider='CDS',
    GCM='ECMWF',
    var_name='T2M',
    period='hindcasts',
    rpath=rpath,
    domain=[90, 300, -65, 50],
    step=3,
    flatten=True,
)

first file is /media/nicolasf/GDATA/END19101/Working/data/GCMs/processed/hindcasts/CDS/ECMWF/T2M/ECMWF_T2M_seasonal_anomalies_interp_1993_01.nc
last file is /media/nicolasf/GDATA/END19101/Working/data/GCMs/processed/hindcasts/CDS/ECMWF/T2M/ECMWF_T2M_seasonal_anomalies_interp_2016_12.nc


In [15]:
dset_t2m_ecmwf_hindcasts

In [16]:
# ECMWF (CDS) PRECIP hindcasts, step=3, stacked over the spatial (+ member) dimensions
dset_precip_ecmwf_hindcasts = get_GCM_outputs(
    provider='CDS',
    GCM='ECMWF',
    var_name='PRECIP',
    period='hindcasts',
    rpath=rpath,
    domain=[90, 300, -65, 50],
    step=3,
    flatten=True,
)

first file is /media/nicolasf/GDATA/END19101/Working/data/GCMs/processed/hindcasts/CDS/ECMWF/PRECIP/ECMWF_PRECIP_seasonal_anomalies_interp_1993_01.nc
last file is /media/nicolasf/GDATA/END19101/Working/data/GCMs/processed/hindcasts/CDS/ECMWF/PRECIP/ECMWF_PRECIP_seasonal_anomalies_interp_2016_12.nc


In [17]:
# `Dataset.drop` is deprecated for dropping variables, `drop_vars` is the supported spelling.
# The result is only displayed here: `dset_precip_ecmwf_hindcasts` itself is left unchanged.
dset_precip_ecmwf_hindcasts.drop_vars('valid_time')

In [18]:
# ECMWF (CDS) T2M forecasts, step=3, stacked over the spatial (+ member) dimensions
dset_t2m_ecmwf_forecasts = get_GCM_outputs(
    provider='CDS',
    GCM='ECMWF',
    var_name='T2M',
    period='forecasts',
    rpath=rpath,
    domain=[90, 300, -65, 50],
    step=3,
    flatten=True,
)

first file is /media/nicolasf/GDATA/END19101/Working/data/GCMs/processed/forecasts/CDS/ECMWF/T2M/ECMWF_T2M_seasonal_anomalies_interp_2017_01.nc
last file is /media/nicolasf/GDATA/END19101/Working/data/GCMs/processed/forecasts/CDS/ECMWF/T2M/ECMWF_T2M_seasonal_anomalies_interp_2019_12.nc


In [19]:
# ECMWF (CDS) PRECIP forecasts, step=3, stacked over the spatial (+ member) dimensions
dset_precip_ecmwf_forecasts = get_GCM_outputs(
    provider='CDS',
    GCM='ECMWF',
    var_name='PRECIP',
    period='forecasts',
    rpath=rpath,
    domain=[90, 300, -65, 50],
    step=3,
    flatten=True,
)

first file is /media/nicolasf/GDATA/END19101/Working/data/GCMs/processed/forecasts/CDS/ECMWF/PRECIP/ECMWF_PRECIP_seasonal_anomalies_interp_2017_01.nc
last file is /media/nicolasf/GDATA/END19101/Working/data/GCMs/processed/forecasts/CDS/ECMWF/PRECIP/ECMWF_PRECIP_seasonal_anomalies_interp_2019_12.nc


In [20]:
dset_precip_ecmwf_hindcasts

In [21]:
dset_precip_ecmwf_hindcasts.nbytes / 1e6

230.917928

### now shift the time index so that the time corresponds to the time of the forecast, not the initialisation time 

### also shift to the end of the month, to match the convention used in the target time-series 

In [22]:
step = 3

In [23]:
# %%writefile ../../../utils/shift_dset_time.py
def shift_dset_time(dset, name='time', step=3, end_month = True): 
    """
    Shift the time index of a xarray.Dataset by the specified number of steps (in months)

    Parameters
    ----------
    - dset: the xarray.Dataset whose time variable is shifted. The variable is
      re-assigned on `dset` itself (the input is modified in place)
    - name: str, the name of the time variable (usually 'time')
    - step: the number of steps (in months) by which to shift the time index
    - end_month: Boolean. If True, the index is shifted by `step + 1` month ends,
      so that each date falls on the last day of a month. This is only done when
      every date of the index is the first day of a month: otherwise a warning is
      printed and the dataset is returned unchanged.
      If False, the index is shifted by `step` month starts

    Return
    ------

    - dset: the xarray.Dataset with the shifted time variable

    """
    import pandas as pd

    # always read the variable called `name`, not a hard-coded `time` attribute
    index = dset[name].to_index()

    if isinstance(index, pd.DatetimeIndex):
        # offset objects instead of the 'M' / 'MS' aliases: the 'M' alias is
        # deprecated in recent pandas versions (renamed 'ME')
        month_end, month_start = pd.offsets.MonthEnd(), pd.offsets.MonthBegin()
    else:
        # other index types keep the original string aliases
        month_end, month_start = 'M', 'MS'

    if end_month:
        if not (index.day == 1).all():
            print("""warning, the end_month argument is set to True,
            but the time variable does NOT start at the beinning of the month
            """)
        else:
            # from the first day of a month, the first month end is the end of the
            # same month, hence `step + 1` periods
            dset[name] = index.shift(periods=step + 1, freq=month_end)
    else:
        dset[name] = index.shift(periods=step, freq=month_start)

    return dset 

In [None]:
index = dset_t2m_ecmwf_hindcasts.time.to_index()

In [None]:
index[0]

In [None]:
index.shift(periods=3+1, freq='M')

### scikit learn imports 

### scaler 

In [None]:
from sklearn.preprocessing import StandardScaler

### PCA 

In [None]:
# public estimator classes: this import path is the supported one
from sklearn.decomposition import PCA, KernelPCA
from eofs.xarray import Eof

# `kernel_pca` and `pca` are private, deprecated submodules that newer
# scikit-learn releases no longer expose. The names are kept when available so
# that existing `pca.PCA` references keep working on old installations.
try:
    from sklearn.decomposition import kernel_pca, pca
except ImportError:
    pass

### creates the weights now 

In [None]:
# square root of the cosine of the latitude, with a trailing axis of length 1 added
latitudes_deg = dset_t2m_ecmwf_forecasts.coords['lat'].data
coslat = np.cos(np.radians(latitudes_deg))
wgts = np.sqrt(coslat)[..., np.newaxis]

In [None]:
wgts.shape

### pipeline 

In [None]:
from sklearn.pipeline import make_pipeline

### TARGETS 

In [None]:
dpath_target = HOME / 'research' / 'Smart_Ideas' / 'outputs' / 'targets' / 'NZ_regions' / 'NZ_6_regions'

#### RAINFALL 

In [None]:
target_var = 'RAIN'

In [None]:
region_name = 'NNI'

In [None]:
# one DataFrame per region, its columns prefixed with the region name as the
# first level of a MultiIndex; the list is concatenated in the next cell
targets_rain = []
for region_name in ('NNI', 'WNI', 'ENI', 'NSI', 'WSI', 'ESI'):
    csv_path = dpath_target / target_var / region_name / f'TS_NZ_region_{region_name}_{target_var}_3_quantiles_anoms.csv'
    target = pd.read_csv(csv_path, index_col=0, parse_dates=True)
    target.columns = pd.MultiIndex.from_product([[region_name], target.columns])
    targets_rain.append(target)

In [None]:
targets_rain = pd.concat(targets_rain, axis=1)

In [None]:
targets_rain.head()

In [None]:
targets_rain_anomalies = targets_rain.loc[:, (slice(None), ["anomalies"])]

In [None]:
target_rain_terciles = targets_rain.loc[:, (slice(None), ["cat_3"])]

In [None]:
targets_rain_anomalies.columns = targets_rain_anomalies.columns.droplevel(1)

In [None]:
targets_rain_anomalies.corr()

In [None]:
target_rain_terciles.columns = target_rain_terciles.columns.droplevel(1)

In [None]:
target_rain_terciles.corr()

In [None]:
f, ax = plt.subplots()
targets_rain_anomalies.NNI.plot(ax=ax, lw=2)
ax.grid(ls=':', color='w')
# label the figure so that it can be read on its own
ax.set_title(f"NNI {target_var} anomalies")
ax.set_ylabel('anomalies');

### reduce the dimensionality of the hindcasts / forecasts using PCA 

In [None]:
X_t2m = dset_t2m_ecmwf_hindcasts['t2m'].data

In [None]:
scaler_t2m = StandardScaler()

In [None]:
X_t2m = scaler_t2m.fit_transform(X_t2m)

In [None]:
X_t2m.mean(0)

In [None]:
X_t2m.std(0)

In [None]:
# use the public class: the `sklearn.decomposition.pca` submodule is private
# and not available in newer scikit-learn releases
from sklearn.decomposition import PCA

# a float n_components keeps the number of PCs needed to explain that fraction of the variance (80 % here)
skpca_t2m = PCA(n_components=0.8)

In [None]:
skpca_t2m_PCs = skpca_t2m.fit_transform(X_t2m)

In [None]:
skpca_t2m_PCs.shape

### try Bayesian Gaussian Mixtures to estimate the number of clusters in the GCM data 

In [None]:
from sklearn.mixture import BayesianGaussianMixture

In [None]:
# the fit starts from random initialisations: fix the seed so that the
# mixture weights are reproducible from one run of the notebook to the next
bgm = BayesianGaussianMixture(n_components=100, n_init=100, random_state=42)

#### fit the bayesian mixture model on the PCs coming from the EOF on the hindcast data

In [None]:
bgm.fit(skpca_t2m_PCs)

### get the weights: it gives weights close to or equal to zero for the unnecessary clusters 

In [None]:
np.round(bgm.weights_, 2)