In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import os, shutil, time, pdb, random
import scipy.stats as stats 
import scipy

from math import pi
from datetime import datetime
from collections import OrderedDict
import pickle

import torch
from torch.utils.data import TensorDataset, DataLoader

from plot_utils import table_of_predictions_for_metric

from importlib import reload
from models import *
from utils import *
from runmanager import *
from experiment import *
from plot_utils import *
from preprocessing_utils import *
from seasonal_analysis import * 

from sklearn.metrics import mean_squared_error as mse

import matplotlib
matplotlib.rc_file_defaults()
%matplotlib inline

from matplotlib import pyplot as plt
import seaborn as sns
from tabulate import tabulate

pd.options.display.max_columns = None

np.random.seed(4)

%load_ext autoreload
%autoreload 2

# Detect device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Import data

In [2]:
# Parameters
start="1998-01-01"
end="2015-12-31"

# TRAIN_PATH = "../data/pickle/df_stations_all_nonzero_extended.pkl"
# TEST_PATH = "../data/pickle/df_stations_val_all_nonzero_extended.pkl"

TRAIN_PATH = '../../data/norris/langtang/observations_with_WRF_norris.pkl'

In [3]:
# st = create_station_dataframe(TRAIN_PATH, start, end, add_yesterday=True, basin_filter = None, filter_incomplete_years = True)

data = DataPreprocessing(train_path=TRAIN_PATH, start=start, end=end, 
                add_yesterday = False, 
                basin_filter = None, 
                split_bias_corrected_only = False, 
                filter_incomplete_years = False, 
                include_non_bc_stations = True, 
                split_by = 'station')

In [4]:
data.split_stations()

In [5]:
data.st.rename(columns={'Longitude':'X','Latitude':'Y','Elevation':'Z'}, inplace=True)

In [12]:
data.st['doy'] = data.st['Date'].dt.dayofyear

In [14]:
data.st['doy_sin'] = data.st['doy'].transform(lambda x: np.sin(x))
data.st['doy_cos'] = data.st['doy'].transform(lambda x: np.cos(x))

### Split data into held out sets for K-fold cross validation

### Prepare data

In [15]:
predictors = [
                'wrf_prcp',
                'wrf_prcp_-2','wrf_prcp_-1','wrf_prcp_1','wrf_prcp_2',
                'Z',
                #'doy',
                'doy_sin',
                'doy_cos',
                'X',
                'Y',
                'aspect',
                'slope',
                'year',
                'era5_u','era5_u_-1','era5_u_-2','era5_u_1','era5_u_2',
                'era5_v','era5_v_-1','era5_v_-2','era5_v_1','era5_v_2',
             ]

predictors = [ 
                'doy_sin',
                'doy_cos',
                'Z',
                'X',
                'Y',
                #'aspect',
                #'slope',
                'year',
                'CWV_norris', 
                'RH2_norris', 'RH500_norris', 
                'T2_norris', 'T2max_norris', 'T2min_norris', 'Td2_norris', 
                'precip_norris', 'rain_norris', 
                'u10_norris', 'u500_norris', 'v10_norris', 'v500_norris',
              ]

# predictors.append('obs_yesterday')

predictand = ['Prec']

data.input_data(predictors, predictand, sort_by_quantile=False)

## Multi-Run: Train model with different hyperparameters

### Model run

In [16]:
params = OrderedDict(
    lr = [0.005]
    ,batch_size = [128]
    ,likelihood_fn = ['bgmm','bernoulli_loggaussian'] #, 'bernoulli_loggaussian', 'b2gmm'] #['bernoulli_loggaussian']
    ,hidden_channels = [[50]] #[[10],[30],[50],[100],[10,10],[30,30],[50,50],[100,100]]
    ,dropout_rate = [0]
    ,linear_model = [False, True] #['True','False']
    #,k = [0]
    ,k = list(range(10))
)

epochs = 5

In [17]:
st_test, predictions = multirun(data, predictors, params, epochs, split_by='station',
                                sequential_samples=False)

Unnamed: 0,run,epoch,loss,valid_loss,epoch duration,run duration,lr,batch_size,likelihood_fn,hidden_channels,dropout_rate,linear_model,k
0,1,1,1.789287,1.648339,0.137134,0.143348,0.005,128,bgmm,[50],0,False,0
1,1,2,1.492428,1.711310,0.437151,0.634552,0.005,128,bgmm,[50],0,False,0
2,1,3,1.452473,1.690260,0.098070,0.776647,0.005,128,bgmm,[50],0,False,0
3,1,4,1.433532,1.682061,0.095922,0.918693,0.005,128,bgmm,[50],0,False,0
4,1,5,1.417292,1.648132,0.096253,1.057243,0.005,128,bgmm,[50],0,False,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
195,40,1,1.463484,1.475279,0.089541,0.107246,0.005,128,bernoulli_loggaussian,[50],0,True,9
196,40,2,1.129613,1.337203,0.076969,0.261373,0.005,128,bernoulli_loggaussian,[50],0,True,9
197,40,3,1.099796,1.295955,0.132821,0.463060,0.005,128,bernoulli_loggaussian,[50],0,True,9
198,40,4,1.086291,1.276650,0.079702,0.615554,0.005,128,bernoulli_loggaussian,[50],0,True,9


In [18]:
for run in predictions.keys():
    for i in range(len(params['k'])):
        predictions[run][f'k{i}']['k_fold'] = i
        if i == 0:
            predictions[run]['k_all'] = predictions[run][f'k{i}']
        else:
            predictions[run]['k_all'] = predictions[run]['k_all'].append(predictions[run][f'k{i}'])

In [19]:
# Table of predictions
seasons = ['Winter (JFM)', 'Premonsoon (AM)', 'Monsoon (JJAS)','Postmonsoon (OND)']
# table_of_predictions(predictions, seasons, sample_cols=['sample'])

In [20]:
# for key in predictions.keys():
#     predictions[key]['k_all'].rename(columns={"sample": "sample_0"}, inplace=True)

In [21]:
# predictions['bgmm_[50]_False']['k_all'].rename(columns={"wrf_prcp": "bannister_wrf_prcp", "precip_norris": "wrf_prcp"}, inplace=True)

In [22]:
n_samples = 10
sample_cols = [f'sample_{i}' for i in range(n_samples)]
add_cols = []

columns = ['Prec','wrf_prcp','wrf_bc_prcp','precip_norris'] #+ sample_cols + add_cols

In [23]:
for p in predictions.keys(): 
    for k,v in predictions[p].items():
        v['wrf_prcp'] = v['precip_norris'] 
        v['wrf_bc_prcp'] = v['precip_norris'] 

In [24]:
table_of_predictions_ks_test(predictions, seasons, columns, sample_cols, add_cols)

bgmm_[50]_NL_B=128_D=0
bgmm_[50]_L_B=128_D=0
bernoulli_loggaussian_[50]_NL_B=128_D=0
bernoulli_loggaussian_[50]_L_B=128_D=0
Model                                    Winter (JFM) mean    Winter (JFM) median    Premonsoon (AM) mean    Premonsoon (AM) median    Monsoon (JJAS) mean    Monsoon (JJAS) median    Postmonsoon (OND) mean    Postmonsoon (OND) median
---------------------------------------  -------------------  ---------------------  ----------------------  ------------------------  ---------------------  -----------------------  ------------------------  --------------------------
Bann                                     0.2200               0.1545                 0.2291                  0.1441                    0.3291                 0.1354                   0.1545                    0.1630
BannCorr                                 0.2200               0.1545                 0.2291                  0.1441                    0.3291                 0.1354                   0.1545 

In [25]:
table_of_predictions_for_metric(predictions, seasons, columns, n_samples, sample_cols, add_cols, metric = 'smape', prefix='smape')

Model                                    Winter (JFM) mean    Winter (JFM) median    Premonsoon (AM) mean    Premonsoon (AM) median    Monsoon (JJAS) mean    Monsoon (JJAS) median    Postmonsoon (OND) mean    Postmonsoon (OND) median
---------------------------------------  -------------------  ---------------------  ----------------------  ------------------------  ---------------------  -----------------------  ------------------------  --------------------------
Bann                                     0.66                 0.70                   0.60                    0.62                      0.64                   0.66                     0.68                      0.77
BannCorr                                 0.66                 0.70                   0.60                    0.62                      0.64                   0.66                     0.68                      0.77
Norr                                     0.66                 0.70                   0.60             

In [26]:
table_of_predictions_for_metric(predictions, seasons, columns, n_samples, sample_cols, add_cols, metric = 'edd', prefix='edd')

Model                                    Winter (JFM) mean    Winter (JFM) median    Premonsoon (AM) mean    Premonsoon (AM) median    Monsoon (JJAS) mean    Monsoon (JJAS) median    Postmonsoon (OND) mean    Postmonsoon (OND) median
---------------------------------------  -------------------  ---------------------  ----------------------  ------------------------  ---------------------  -----------------------  ------------------------  --------------------------
Bann                                     16.00                12.00                  26.81                   21.00                     27.76                  19.00                    16.29                     16.00
BannCorr                                 16.00                12.00                  26.81                   21.00                     27.76                  19.00                    16.29                     16.00
Norr                                     16.00                12.00                  26.81          

In [30]:
table_of_predictions_for_metric(predictions, seasons, columns, n_samples, sample_cols, add_cols, metric = 'ae', prefix='ae')

Model                                    Winter (JFM) mean    Winter (JFM) median    Premonsoon (AM) mean    Premonsoon (AM) median    Monsoon (JJAS) mean    Monsoon (JJAS) median    Postmonsoon (OND) mean    Postmonsoon (OND) median
---------------------------------------  -------------------  ---------------------  ----------------------  ------------------------  ---------------------  -----------------------  ------------------------  --------------------------
Bann                                     187.84               160.49                 351.35                  323.67                    2287.97                2020.47                  318.22                    346.11
BannCorr                                 187.84               160.49                 351.35                  323.67                    2287.97                2020.47                  318.22                    346.11
Norr                                     187.84               160.49                 351.35       

In [31]:
table_of_predictions_for_metric(predictions, seasons, columns, n_samples, sample_cols, add_cols, metric = 'se', prefix='se')

Model                                    Winter (JFM) mean    Winter (JFM) median    Premonsoon (AM) mean    Premonsoon (AM) median    Monsoon (JJAS) mean    Monsoon (JJAS) median    Postmonsoon (OND) mean    Postmonsoon (OND) median
---------------------------------------  -------------------  ---------------------  ----------------------  ------------------------  ---------------------  -----------------------  ------------------------  --------------------------
Bann                                     42690.41             25755.70               139591.84               104761.02                 5995093.26             4082278.88               119194.94                 119789.33
BannCorr                                 42690.41             25755.70               139591.84               104761.02                 5995093.26             4082278.88               119194.94                 119789.33
Norr                                     42690.41             25755.70               139591.

In [27]:
a = pd.read_csv('results.csv')

b = a.groupby(['k','run']).agg({'valid_loss': 'min', 
                                 'hidden_channels': 'first', 
                                 'likelihood_fn': 'first',
                                 'lr':'first',
                                 'batch_size':'first',
                                 'dropout_rate':'first',
                                 'linear_model':'first'})

c = b.groupby(['run']).agg({'valid_loss': 'mean', 
                        'hidden_channels': 'first', 
                        'likelihood_fn': 'first',
                        'lr':'first',
                        'batch_size':'first',
                        'dropout_rate':'first',
                         'linear_model':'first'}
                   ).sort_values('valid_loss').reset_index()

c.groupby(['hidden_channels',
           'likelihood_fn',
           'lr',
           'batch_size',
           'dropout_rate',
           'linear_model']).agg({'valid_loss': 'mean'}
                   ).sort_values('valid_loss').reset_index()

Unnamed: 0,hidden_channels,likelihood_fn,lr,batch_size,dropout_rate,linear_model,valid_loss
0,[50],bernoulli_loggaussian,0.005,128,0,False,1.216784
1,[50],bernoulli_loggaussian,0.005,128,0,True,1.281871
2,[50],bgmm,0.005,128,0,False,1.671055
3,[50],bgmm,0.005,128,0,True,1.888055


In [62]:
a

Unnamed: 0.1,Unnamed: 0,run,epoch,loss,valid_loss,epoch duration,run duration,lr,batch_size,likelihood_fn,hidden_channels,dropout_rate,linear_model,k
0,0,1,1,1.834609,1.693409,0.123530,0.129121,0.005,128,bgmm,[50],0,False,0
1,1,1,2,1.503547,1.688267,0.083380,0.275722,0.005,128,bgmm,[50],0,False,0
2,2,1,3,1.457541,1.656358,0.088912,0.416786,0.005,128,bgmm,[50],0,False,0
3,3,1,4,1.438161,1.665433,0.098454,0.564409,0.005,128,bgmm,[50],0,False,0
4,4,1,5,1.421409,1.660284,0.102358,0.717939,0.005,128,bgmm,[50],0,False,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
295,295,60,1,1.540624,1.452504,0.116836,0.119695,0.005,128,bernoulli_loggaussian,[50],0,True,9
296,296,60,2,1.188366,1.348620,0.103595,0.261039,0.005,128,bernoulli_loggaussian,[50],0,True,9
297,297,60,3,1.139284,1.309553,0.098859,0.394304,0.005,128,bernoulli_loggaussian,[50],0,True,9
298,298,60,4,1.115526,1.294975,0.098492,0.528290,0.005,128,bernoulli_loggaussian,[50],0,True,9
