In [1]:
import os
os.chdir("../..")

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import pickle
import pandas as pd
import numpy as np
from datetime import datetime

from scripts.data_preparation import merge_station_data
from scripts.models.frank_model import fit_and_predict
from scripts.metrics import report_performance

In [4]:
# Get data
df = merge_station_data()

# Get optimal parameters
opt_params = pickle.load( open( "./cross_validation/frank/opt_par_cv.p", "rb" ) )
opt_params

['Processed data for station: KLO3']
['Processed data for station: ANV3']
['Processed data for station: GAN2']
['Processed data for station: DAV3']


{'classifier': ('xgb', 'logistic', 'tree'),
 'PCA': False,
 'drop_cat': True,
 'params': [{'scale_pos_weight': 1.0,
   'learning_rate': 0.1,
   'max_depth': 9,
   'gamma': 0.2,
   'colsample_bytree': 1.0,
   'verbosity': 0},
  {'max_iter': 1000, 'tol': 0.0001, 'class_weight': None, 'C': 2.0},
  {'class_weight': None,
   'min_samples_split': 8,
   'min_samples_leaf': 1,
   'max_depth': 6}]}

In [5]:
# Suppress XGB warnings
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Gather prediction data
nruns = 100
results = []
for i in range(nruns):
    y_pred, y_true, clf = fit_and_predict(df, verbose=False, **opt_params)
    results.append(report_performance(y_pred, y_true, i).T)
    if ((i % (nruns//100)) == 0):
        print(f'{i+1:3d} of {nruns} ({(i+1)/nruns:3.0%}) runs complete at {datetime.now().strftime("%H:%M:%S")}')
    
data = pd.concat(results)

  1 of 100 ( 1%) runs complete at 09:32:01
  2 of 100 ( 2%) runs complete at 09:32:14
  3 of 100 ( 3%) runs complete at 09:32:24
  4 of 100 ( 4%) runs complete at 09:32:35
  5 of 100 ( 5%) runs complete at 09:32:44
  6 of 100 ( 6%) runs complete at 09:32:52
  7 of 100 ( 7%) runs complete at 09:33:02
  8 of 100 ( 8%) runs complete at 09:33:12
  9 of 100 ( 9%) runs complete at 09:33:21
 10 of 100 (10%) runs complete at 09:33:31
 11 of 100 (11%) runs complete at 09:33:40
 12 of 100 (12%) runs complete at 09:33:54
 13 of 100 (13%) runs complete at 09:34:05
 14 of 100 (14%) runs complete at 09:34:15
 15 of 100 (15%) runs complete at 09:34:24
 16 of 100 (16%) runs complete at 09:34:37
 17 of 100 (17%) runs complete at 09:34:46
 18 of 100 (18%) runs complete at 09:34:56
 19 of 100 (19%) runs complete at 09:35:05
 20 of 100 (20%) runs complete at 09:35:15
 21 of 100 (21%) runs complete at 09:35:25
 22 of 100 (22%) runs complete at 09:35:35
 23 of 100 (23%) runs complete at 09:35:44
 24 of 100 

In [6]:
data.describe()

Unnamed: 0,Prec_1.0,Prec_2.0,Prec_3.0,Prec_4.0,Rec_1.0,Rec_2.0,Rec_3.0,Rec_4.0,MMSE,MMAD
count,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0
mean,0.93707,0.771,0.77907,0.29197,0.883,0.823,0.78387,0.17517,0.40447,0.35737
std,0.000256,6.694897e-16,0.001174,0.0245,4.463264e-16,1.227398e-15,0.002639,0.021208,0.005145,0.005029
min,0.937,0.771,0.776,0.231,0.883,0.823,0.778,0.102,0.397,0.35
25%,0.937,0.771,0.778,0.276,0.883,0.823,0.782,0.163,0.402,0.355
50%,0.937,0.771,0.779,0.29,0.883,0.823,0.783,0.184,0.403,0.356
75%,0.937,0.771,0.78,0.31,0.883,0.823,0.785,0.184,0.408,0.361
max,0.938,0.771,0.781,0.333,0.883,0.823,0.791,0.204,0.421,0.374


In [85]:
# Precision/Recall Summary
import scipy.stats as st
N = 100
alpha = 0.05
c_crit = st.t.ppf(1-alpha/2, N-1)

precisions = data[['Prec_1.0','Prec_2.0','Prec_3.0','Prec_4.0']]
recalls = data[['Rec_1.0','Rec_2.0','Rec_3.0','Rec_4.0']]


results = pd.concat([precisions.mean(), recalls.mean()], axis=1)
results.columns = ["Precision","Recall"]
results["CI_prec"] = c_crit * precisions.std()/np.sqrt(N)
results["CI_rec"] = c_crit * recalls.std()/np.sqrt(N)

avg_mamse = data["MMSE"].mean()
ci_mamse = c_crit * data["MMSE"].std() / np.sqrt(N)

In [90]:
display(results)
print(f'AVG. MAMSE: {avg_mamse}\nCI MAMSE: {ci_mamse}')

Unnamed: 0,Precision,Recall,CI_prec,CI_rec
Prec_1.0,0.93707,,5.088175e-05,
Prec_2.0,0.771,,1.328413e-16,
Prec_3.0,0.77907,,0.0002329989,
Prec_4.0,0.29197,,0.00486142,
Rec_1.0,,0.883,,8.856085000000001e-17
Rec_2.0,,0.823,,2.435423e-16
Rec_3.0,,0.78387,,0.0005235749
Rec_4.0,,0.17517,,0.00420822


AVG. MAMSE: 0.40446999999999994
CI MAMSE: 0.0010209338704879039


In [93]:
# Save results
res = {
    'raw_data': data,
    'results': {
        'precision_recall':results,
        'MMSE': avg_mamse,
        'MMSE_CI': ci_mamse
    }
}

fname = f"./cross_validation/frank/pred_stats.p"
pickle.dump(res, open(fname, "wb"))