In [1]:
# Parameters
# NOTE(review): if this is used as a papermill "parameters" cell, injected overrides land in a
# later cell, so `target_type` below would NOT follow an overridden `num_quantiles` — confirm.

application = 'Innovation_Vineyards'  # application case; used in the target path and model directory names
varname = 'RAIN_BC'  # target variable; used in the target path, target CSV name and model directory names
stat = 'mean'  # seasonal statistic; part of the target CSV file name
num_quantiles = 3  # number of target categories; selects the Q<num_quantiles> target CSV
target_type = f'cat{num_quantiles}'  # name of the categorical target column (the model label)
step = 4  # number of time steps by which the SST predictors are shifted ahead of the target
lag_sst = True  # if True, append lagged copies of the SST predictors
max_lag = 6  # number of SST copies when lagging: lags 0 .. max_lag - 1
detrend_sst = False  # if True, detrend every SST column with scipy.signal.detrend

In [2]:
# Load IPython's autoreload extension (re-imports edited modules without restarting the kernel).
%load_ext autoreload

In [3]:
# Mode 2: reload modules automatically before executing each cell.
%autoreload 2

In [4]:
# Render matplotlib figures inline, below the cell that creates them.
%matplotlib inline

In [5]:
import sys 
import pathlib

In [6]:
import matplotlib.pyplot as plt 

In [7]:
import numpy as np 
import pandas as pd 
import xarray as xr
import cartopy.crs as ccrs
from scipy.signal import detrend

In [8]:
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

In [9]:
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import RepeatedStratifiedKFold

In [10]:
import autogluon as ag
from autogluon import TabularPrediction as task

In [11]:
# Seed numpy's legacy global RNG so numpy-based randomness is reproducible from run to run.
np.random.seed(seed=42)

In [12]:
# Base directories: the user's home directory and the notebook's current working directory.
HOME, CWD = pathlib.Path.home(), pathlib.Path.cwd()

### read the target variable 

In [13]:
# Directory holding the pre-computed seasonal target for this application and variable.
ipath_target = HOME / f"research/Smart_Ideas/outputs/targets/application_cases/{application}/SEASONAL/{varname}"

# Seasonal anomalies and quantile categories; keep only the categorical target column,
# as a one-column DataFrame indexed by the parsed dates of the first CSV column.
target = pd.read_csv(
    ipath_target / f"Seasonal_{varname}_{stat}_anomalies_and_Q{num_quantiles}_categories.csv",
    index_col=0,
    parse_dates=True,
)[[target_type]]

### keep only data from 1981 onwards, to enable direct comparison with GCM-derived fields

In [16]:
# Rows from 1981-01 onwards (partial-string date slicing on the DatetimeIndex)
target = target.loc['1981':]

### read the pre-computed seasonal SST anomalies

In [17]:
# Absolute path to the processed ERSST seasonal anomaly files. It is deliberately not built
# from HOME: pathlib discards every earlier segment when joining an absolute path, so HOME
# would be ignored anyway.
ipath_sst = pathlib.Path("/media/nicolasf/END19101/data/ERSST/processed")

# File names end in YYYY-MM, so sorting by name puts them in chronological order,
# which the order-based concatenation along `time` relies on.
lfiles_sst = sorted(ipath_sst.glob("*.nc"))
# Fail early with a clear message rather than an IndexError / opaque xarray error later.
assert lfiles_sst, f"no NetCDF files found in {ipath_sst}"

lfiles_sst[0]

PosixPath('/media/nicolasf/END19101/data/ERSST/processed/ERSST_seasonal_anomalies_1979-03.nc')

In [21]:
# Most recent file (the list is sorted chronologically by name)
lfiles_sst[len(lfiles_sst) - 1]

PosixPath('/media/nicolasf/END19101/data/ERSST/processed/ERSST_seasonal_anomalies_2020-03.nc')

In [22]:
# Concatenate the per-file SST anomalies along `time` in the (chronological) order of
# lfiles_sst. combine='nested' makes this order-based concatenation explicit; recent
# xarray releases reject `concat_dim` without it.
dset_sst = xr.open_mfdataset(lfiles_sst, combine='nested', concat_dim='time')

dset_sst

Unnamed: 0,Array,Chunk
Bytes,31.59 MB,64.08 kB
Shape,"(493, 89, 180)","(1, 89, 180)"
Count,1972 Tasks,493 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 31.59 MB 64.08 kB Shape (493, 89, 180) (1, 89, 180) Count 1972 Tasks 493 Chunks Type float32 numpy.ndarray",180  89  493,

Unnamed: 0,Array,Chunk
Bytes,31.59 MB,64.08 kB
Shape,"(493, 89, 180)","(1, 89, 180)"
Count,1972 Tasks,493 Chunks
Type,float32,numpy.ndarray


In [24]:
# Same period as the target: 1981 onwards, open-ended
dset_sst = dset_sst.sel({'time': slice('1981', None)})

### domain selection 

In [25]:
# Candidate SST domains as [lon_min, lon_max, lat_min, lat_max] in degrees (longitudes
# apparently on a 0-360 grid). The first two entries slice `lon`, the last two slice `lat`.
domain_def = {
    'HB_seasonal': [120, 290, -60, 40],
    'local': [150, 200, -50, -10],
    'regional': [90, 300, -65, 50],
    'ext_regional': [70, 300, -70, 60],
    'global': [0, 360, -70, 70],
    'tropics': [0, 360, -40, 40],
}

In [26]:
# Key of domain_def used to crop the SST predictors
domain = "HB_seasonal"

##### initial domain (superseded by `domain_def`, kept for reference)

In [27]:
# Superseded, kept for reference only: equivalent to selecting domain_def['HB_seasonal'].
# dset_sst = dset_sst.sel(lat=slice(-60, 40), lon=slice(120, 290)) 

##### extended domain (superseded by `domain_def`, kept for reference)

In [28]:
# Superseded, kept for reference only: latitude-only restriction to 70S-70N (all longitudes kept).
# dset_sst = dset_sst.sel(lat=slice(-70, 70)) 

##### apply the selected domain, `domain_def[domain]`

In [29]:
# Crop the SST field to the chosen domain.
# NOTE(review): label-based slicing assumes `lat` is ascending; the 51-latitude result below suggests it is.
lon_min, lon_max, lat_min, lat_max = domain_def[domain]
dset_sst = dset_sst.sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))

dset_sst

Unnamed: 0,Array,Chunk
Bytes,8.26 MB,17.54 kB
Shape,"(471, 51, 86)","(1, 51, 86)"
Count,2914 Tasks,471 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 8.26 MB 17.54 kB Shape (471, 51, 86) (1, 51, 86) Count 2914 Tasks 471 Chunks Type float32 numpy.ndarray",86  51  471,

Unnamed: 0,Array,Chunk
Bytes,8.26 MB,17.54 kB
Shape,"(471, 51, 86)","(1, 51, 86)"
Count,2914 Tasks,471 Chunks
Type,float32,numpy.ndarray


### shift the time index by `step` so that the SST observed at time t - `step` is aligned with the rainfall at time t (with `step` = 4: OND SST --> FMA precip)

In [31]:
# Lag the SSTs by `step` time steps: after the shift, the value labelled t is the SST observed
# at t - step, and the first `step` entries become NaN. `shift` already returns a new Dataset,
# so no separate copy of dset_sst is needed.
dset_sst_shift = dset_sst.shift(time=step)

### remove the first `step` time steps, which are all missing (NaN) after the shift

In [33]:
# Drop the leading `step` time steps (NaN after the shift), then flatten the (lat, lon) grid
# into a single stacked dimension `s`, so that every grid point becomes one predictor.
dset_sst_shift = (
    dset_sst_shift
    .isel(time=slice(step, None))
    .stack(s=('lat', 'lon'))
)

dset_sst_shift

Unnamed: 0,Array,Chunk
Bytes,8.19 MB,17.54 kB
Shape,"(467, 4386)","(1, 4386)"
Count,5261 Tasks,467 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 8.19 MB 17.54 kB Shape (467, 4386) (1, 4386) Count 5261 Tasks 467 Chunks Type float32 numpy.ndarray",4386  467,

Unnamed: 0,Array,Chunk
Bytes,8.19 MB,17.54 kB
Shape,"(467, 4386)","(1, 4386)"
Count,5261 Tasks,467 Chunks
Type,float32,numpy.ndarray


### drop the land points (every grid point with a missing value at any time)

In [36]:
# Drop every stacked grid point that is NaN at any time step (land, plus any ocean point with
# gaps), then pull the dask-backed data into memory; `load` returns the same Dataset.
dset_sst_shift = dset_sst_shift.dropna('s').load()

# Plain (time, s) numpy array of SST predictors
sst_data = dset_sst_shift['sst'].data

sst_data.shape

(467, 3981)

In [40]:
# One row per (shifted) time step, one integer-labelled column per retained ocean grid point.
time_index = dset_sst_shift['time'].to_index()
df_sst = pd.DataFrame(data=sst_data, index=time_index)

df_sst

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,3971,3972,3973,3974,3975,3976,3977,3978,3979,3980
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1981-05-31,0.425469,0.452932,0.445086,0.450316,0.475322,0.539737,0.542407,0.526550,0.528740,0.548263,...,0.509807,0.581705,0.607628,0.593648,0.533564,0.506667,-1.091065,-1.030591,-0.916759,-0.844929
1981-06-30,0.370928,0.395698,0.382532,0.366968,0.372283,0.411084,0.400267,0.383998,0.356813,0.354419,...,0.707175,0.779552,0.809687,0.793489,0.721863,0.687952,-0.684669,-0.676358,-0.618497,-0.563119
1981-07-31,0.192077,0.212412,0.199903,0.176881,0.180434,0.213022,0.209131,0.212409,0.171129,0.157672,...,0.511929,0.646268,0.762092,0.830067,0.803822,0.775673,-0.405702,-0.395685,-0.336811,-0.267130
1981-08-31,0.052004,0.073057,0.070495,0.067105,0.094715,0.151829,0.171223,0.195796,0.170889,0.170496,...,0.244747,0.355242,0.465580,0.521439,0.453079,0.374420,-0.331133,-0.338840,-0.296816,-0.207553
1981-09-30,-0.050947,-0.037916,-0.040765,-0.038309,-0.003810,0.060986,0.090720,0.125933,0.107915,0.111751,...,0.169777,0.224669,0.279029,0.283365,0.177876,0.083774,-0.697181,-0.630045,-0.490177,-0.306935
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2019-11-30,0.139806,0.165507,0.214928,0.269337,0.318541,0.333053,0.294879,0.246468,0.224448,0.224015,...,0.838111,0.582053,0.382717,0.262774,0.212553,0.241214,1.320290,1.193980,1.021117,0.785675
2019-12-31,0.083776,0.086358,0.105316,0.122533,0.134827,0.120190,0.063494,0.003334,-0.030574,-0.047638,...,1.710595,1.440660,1.152585,0.871248,0.617027,0.482767,1.465768,1.336002,1.164136,0.956868
2020-01-31,0.002080,-0.012536,-0.013261,-0.016449,-0.022417,-0.047975,-0.106200,-0.163115,-0.196911,-0.220356,...,2.115240,1.997016,1.809291,1.557510,1.260005,1.049830,1.361691,1.376684,1.384573,1.300473
2020-02-29,-0.013831,-0.035971,-0.043834,-0.052942,-0.064318,-0.092778,-0.150801,-0.206202,-0.240669,-0.267832,...,1.804538,1.746679,1.610256,1.391628,1.113494,0.911428,1.119526,1.197577,1.302717,1.292904


In [42]:
if detrend_sst:
    # Remove a linear trend (scipy's default type='linear') from each SST column independently.
    # NOTE(review): the trend is fitted on the full record, so later years inform earlier
    # values; keep this in mind when interpreting cross-validated skill.
    df_sst = pd.DataFrame(
        {column: detrend(df_sst[column]) for column in df_sst.columns},
        index=df_sst.index,
    )

### optionally add lagged SST predictors (lags 0 to `max_lag` - 1)

In [43]:
if lag_sst:
    # Place lagged copies of every SST column side by side: column "<k>_<lag>" holds grid point k
    # shifted down by `lag` rows, for lag = 0 .. max_lag - 1. Rows missing any lag
    # (the first max_lag - 1) are dropped.
    lagged_copies = [df_sst.shift(lag).add_suffix(f"_{lag}") for lag in range(max_lag)]
    dset_lagged = pd.concat(lagged_copies, axis=1).dropna()
    df_sst = dset_lagged

### make sure the target has a monthly (month-end) frequency

In [44]:
# Declare (and let pandas validate) a month-end frequency on the target index.
# The offset object is used instead of the 'M' alias, which recent pandas deprecates in favour of 'ME'.
target.index.freq = pd.offsets.MonthEnd()

target.index

DatetimeIndex(['1981-01-31', '1981-02-28', '1981-03-31', '1981-04-30',
               '1981-05-31', '1981-06-30', '1981-07-31', '1981-08-31',
               '1981-09-30', '1981-10-31',
               ...
               '2019-03-31', '2019-04-30', '2019-05-31', '2019-06-30',
               '2019-07-31', '2019-08-31', '2019-09-30', '2019-10-31',
               '2019-11-30', '2019-12-31'],
              dtype='datetime64[ns]', name='time', length=468, freq='M')

In [46]:
# Side-by-side outer join on the time index: SST predictors plus the categorical target.
# Dates present in only one of the two inputs get NaNs, which are removed next.
df = pd.concat([df_sst, target], axis='columns', join='outer')

df

Unnamed: 0_level_0,0_0,1_0,2_0,3_0,4_0,5_0,6_0,7_0,8_0,9_0,...,3972_5,3973_5,3974_5,3975_5,3976_5,3977_5,3978_5,3979_5,3980_5,cat3
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1981-01-31,,,,,,,,,,,...,,,,,,,,,,1.0
1981-02-28,,,,,,,,,,,...,,,,,,,,,,1.0
1981-03-31,,,,,,,,,,,...,,,,,,,,,,1.0
1981-04-30,,,,,,,,,,,...,,,,,,,,,,1.0
1981-05-31,,,,,,,,,,,...,,,,,,,,,,3.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2019-11-30,0.139806,0.165507,0.214928,0.269337,0.318541,0.333053,0.294879,0.246468,0.224448,0.224015,...,0.408832,0.502925,0.583108,0.620982,0.651341,-0.441316,-0.357568,-0.183989,-0.050078,2.0
2019-12-31,0.083776,0.086358,0.105316,0.122533,0.134827,0.120190,0.063494,0.003334,-0.030574,-0.047638,...,0.060864,0.189895,0.300710,0.336517,0.342920,-0.286510,-0.150973,0.076421,0.231688,3.0
2020-01-31,0.002080,-0.012536,-0.013261,-0.016449,-0.022417,-0.047975,-0.106200,-0.163115,-0.196911,-0.220356,...,-0.047256,0.069712,0.156612,0.161158,0.145299,0.191236,0.270978,0.402814,0.438559,
2020-02-29,-0.013831,-0.035971,-0.043834,-0.052942,-0.064318,-0.092778,-0.150801,-0.206202,-0.240669,-0.267832,...,-0.037295,0.023826,0.097780,0.140150,0.174880,0.688077,0.732987,0.788707,0.717870,


In [48]:
# Keep only dates where every SST predictor and the target are present
df = df.dropna(axis='index', how='any')

df.head()

Unnamed: 0_level_0,0_0,1_0,2_0,3_0,4_0,5_0,6_0,7_0,8_0,9_0,...,3972_5,3973_5,3974_5,3975_5,3976_5,3977_5,3978_5,3979_5,3980_5,cat3
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1981-10-31,-0.037903,-0.027635,-0.031161,-0.024755,0.01568,0.087785,0.127075,0.171974,0.161534,0.170722,...,0.581705,0.607628,0.593648,0.533564,0.506667,-1.091065,-1.030591,-0.916759,-0.844929,2.0
1981-11-30,-0.088452,-0.074617,-0.073852,-0.063112,-0.019391,0.057738,0.105502,0.159218,0.155047,0.168013,...,0.779552,0.809687,0.793489,0.721863,0.687952,-0.684669,-0.676358,-0.618497,-0.563119,2.0
1981-12-31,-0.083178,-0.057363,-0.040846,-0.009592,0.05498,0.15193,0.218309,0.288859,0.300203,0.328176,...,0.646268,0.762092,0.830067,0.803822,0.775673,-0.405702,-0.395685,-0.336811,-0.26713,3.0
1982-01-31,0.140003,0.196787,0.249345,0.320706,0.423519,0.548384,0.627574,0.705634,0.733831,0.789905,...,0.355242,0.46558,0.521439,0.453079,0.37442,-0.331133,-0.33884,-0.296816,-0.207553,2.0
1982-02-28,0.2626,0.328359,0.38881,0.466164,0.572077,0.695237,0.766572,0.8361,0.861905,0.921259,...,0.224669,0.279029,0.283365,0.177876,0.083774,-0.697181,-0.630045,-0.490177,-0.306935,1.0


In [50]:
# Last five dates of the modelling table
df.tail(n=5)

Unnamed: 0_level_0,0_0,1_0,2_0,3_0,4_0,5_0,6_0,7_0,8_0,9_0,...,3972_5,3973_5,3974_5,3975_5,3976_5,3977_5,3978_5,3979_5,3980_5,cat3
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2019-08-31,0.030762,0.075377,0.14982,0.236188,0.318111,0.362172,0.347685,0.314055,0.298182,0.303365,...,0.454943,0.460003,0.517762,0.62257,0.72861,1.166771,1.109354,1.10722,1.186688,2.0
2019-09-30,0.067156,0.120798,0.21021,0.316811,0.41967,0.48079,0.478895,0.456438,0.453703,0.4743,...,0.690517,0.697859,0.734446,0.809492,0.911131,0.242399,0.187085,0.2109,0.311823,3.0
2019-10-31,0.110709,0.15474,0.231048,0.32029,0.404579,0.448335,0.431706,0.39814,0.387827,0.401469,...,0.788737,0.809831,0.83152,0.856492,0.909876,-0.416757,-0.363843,-0.213745,-0.062147,2.0
2019-11-30,0.139806,0.165507,0.214928,0.269337,0.318541,0.333053,0.294879,0.246468,0.224448,0.224015,...,0.408832,0.502925,0.583108,0.620982,0.651341,-0.441316,-0.357568,-0.183989,-0.050078,2.0
2019-12-31,0.083776,0.086358,0.105316,0.122533,0.134827,0.12019,0.063494,0.003334,-0.030574,-0.047638,...,0.060864,0.189895,0.30071,0.336517,0.34292,-0.28651,-0.150973,0.076421,0.231688,3.0


### loop over each season (target month): train, evaluate and keep the best model

In [76]:
# Root directory for the per-season AutoGluon models.
saved_models = pathlib.Path('./saved_models/AUTOGLUON_v3/')

# Per-season results keyed by season (the month number of the target's index). Initialised in
# the same cell as the loop so that re-running it always starts from empty results.
seasonal_acc = {}             # validation score of the best model
seasonal_best_model = {}      # name of the best model
seasonal_predictor_info = {}  # predictor.info() summary

for season in range(1, 13): 
    print(f"\ntraining and evaluating for season {season}")
    # Rows whose month-end index falls in this month
    dfs = df.loc[df.index.month == season]
    # NOTE(review): the directory name says 1981_2010 and SKPCA, but `dfs` spans every year
    # left in `df` and no PCA is applied; renaming would orphan already-saved models — confirm.
    opath = saved_models.joinpath(f'./autogluon_exp_SKPCA_SSTobs_1981_2010_pred_{application}_reg_{varname}_targetvar_{target_type}_target_type_season_{season}')
    opath.mkdir(parents=True, exist_ok=True)
    # Shuffle rows with a fixed seed (presumably so AutoGluon's internal splits are not
    # chronological — TODO confirm).
    dfs = dfs.sample(frac=1., random_state=42)
    predictor = task.fit(train_data=dfs, label=target_type, auto_stack=True, presets='best_quality', output_directory=opath, verbosity=0)
    best_model = predictor.get_model_best()
    best_score = predictor.model_performance[best_model]
    seasonal_acc[season] = best_score
    # Store per season; assigning to the name itself would discard the dict.
    seasonal_best_model[season] = best_model
    seasonal_predictor_info[season] = predictor.info()
    print(f"best model is {best_model}, validation accuracy reaching {best_score}")


training and evaluating for season 1
best model is weighted_ensemble_k0_l1, validation accuracy reaching 0.6578947368421053

training and evaluating for season 2
best model is CatboostClassifier_STACKER_l0, validation accuracy reaching 0.47368421052631576

training and evaluating for season 3
