In [11]:
# How changing the dataset sample size effects the calibration methods
# Fix training dataset size and change the calib set samples - best method is one that gets max calib with least data

In [12]:
# imports
import sys
import pandas as pd
import numpy as np
sys.path.append('../../') # to access the files in higher directories
sys.path.append('../') # to access the files in higher directories
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)

import Data.data_provider as dp
import core as cal
from estimators.IR_RF_estimator import IR_RF

In [13]:
# params
calib_methods = ["RF", "Platt" , "ISO", "Rank", "CRF", "VA", "Beta", "Elkan", "tlr", "Line"]
metrics = ["acc", "auc", "brier", "ece", "logloss"]
data_list = ["spambase", "climate", "QSAR", "bank", "climate", "parkinsons", "vertebral", "ionosphere", "diabetes", "breast", "blod"]
# data_list = ["spambase", "climate"]

params = {
    "runs": 3,
    "n_estimators": 10,
    "oob": False,
    "test_split": 0.3,
    "calib_split": 0.1
}

In [14]:
calib_results_dict = {}

for data_name in data_list:

    # Data
    X, y = dp.load_data(data_name, "../../")
    
    data_dict = {} # results for each data set will be saved in here.
    for seed in range(params["runs"]): # running the same dataset multiple times
        # split the data
        data = cal.split_train_calib_test(data_name, X, y, params["test_split"], params["calib_split"], seed)
        print("train", len(data["x_train"]))
        print("calib", len(data["x_calib"]))
        print("test", len(data["x_test"]))
        print("---------------------------------")

        # train model
        irrf = IR_RF(n_estimators=params["n_estimators"], oob_score=params["oob"], random_state=seed)
        irrf.fit(data["x_train"], data["y_train"])

        # calibration
        res = cal.calibration(irrf, data, calib_methods, metrics) # res is a dict with all the metrics results as well as RF probs and every calibration method decision for every test data point
        data_dict = cal.update_runs(data_dict, res) # calib results for every run for the same dataset is aggregated in data_dict (ex. acc of every run as an array)
    calib_results_dict.update(data_dict) # merge results of all datasets together
    
tables = cal.mean_and_rankinsg_table(calib_results_dict, metrics, calib_methods, data_list, mean_and_rank=False)

train 2898
calib 322
test 1381
---------------------------------
train 2898
calib 322
test 1381
---------------------------------
train 2898
calib 322
test 1381
---------------------------------
train 340
calib 38
test 162
---------------------------------
train 340
calib 38
test 162
---------------------------------
train 340
calib 38
test 162
---------------------------------
train 664
calib 74
test 317
---------------------------------
train 664
calib 74
test 317
---------------------------------
train 664
calib 74
test 317
---------------------------------
train 864
calib 96
test 412
---------------------------------
train 864
calib 96
test 412
---------------------------------
train 864
calib 96
test 412
---------------------------------
train 340
calib 38
test 162
---------------------------------
train 340
calib 38
test 162
---------------------------------
train 340
calib 38
test 162
---------------------------------
train 122
calib 14
test 59
---------------------------------


In [15]:
tables = cal.mean_and_ranking_table(calib_results_dict, metrics, calib_methods, data_list, mean_and_rank=True)

In [16]:
tables["brier"]

Unnamed: 0_level_0,RF,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
spambase,0.049218,0.044494,0.047523,0.05361,0.046533,0.044581,0.045123,0.140233,0.087024,0.045735
climate,0.049935,0.044406,0.048387,0.094029,0.050525,0.054752,0.046835,0.067065,0.074074,0.049749
QSAR,0.105838,0.105216,0.111093,0.116919,0.106476,0.119521,0.108181,0.24243,0.175584,0.105038
bank,0.012571,0.009975,0.009844,0.021876,0.01072,0.024718,0.009551,0.050436,0.024328,0.011841
climate,0.049935,0.044406,0.048387,0.094029,0.050525,0.054752,0.046835,0.067065,0.074074,0.049749
parkinsons,0.094118,0.092168,0.08914,0.116366,0.097574,0.134741,0.086594,0.168702,0.172147,0.086088
vertebral,0.129287,0.141214,0.159471,0.171256,0.131365,0.214941,0.166844,0.272962,0.177563,0.14168
ionosphere,0.069959,0.076616,0.081521,0.093947,0.065042,0.115019,0.07761,0.210474,0.217296,0.067079
diabetes,0.160104,0.164381,0.171203,0.164303,0.160076,0.179009,0.164685,0.318915,0.249481,0.162274
breast,0.041487,0.0414,0.046196,0.053383,0.040908,0.057439,0.041523,0.098416,0.065185,0.04084


In [17]:
tables["logloss"]

Unnamed: 0_level_0,RF,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
spambase,0.049218,0.044494,0.047523,0.05361,0.046533,0.044581,0.045123,0.140233,0.087024,0.045735
climate,0.049935,0.044406,0.048387,0.094029,0.050525,0.054752,0.046835,0.067065,0.074074,0.049749
QSAR,0.105838,0.105216,0.111093,0.116919,0.106476,0.119521,0.108181,0.24243,0.175584,0.105038
bank,0.012571,0.009975,0.009844,0.021876,0.01072,0.024718,0.009551,0.050436,0.024328,0.011841
climate,0.049935,0.044406,0.048387,0.094029,0.050525,0.054752,0.046835,0.067065,0.074074,0.049749
parkinsons,0.094118,0.092168,0.08914,0.116366,0.097574,0.134741,0.086594,0.168702,0.172147,0.086088
vertebral,0.129287,0.141214,0.159471,0.171256,0.131365,0.214941,0.166844,0.272962,0.177563,0.14168
ionosphere,0.069959,0.076616,0.081521,0.093947,0.065042,0.115019,0.07761,0.210474,0.217296,0.067079
diabetes,0.160104,0.164381,0.171203,0.164303,0.160076,0.179009,0.164685,0.318915,0.249481,0.162274
breast,0.041487,0.0414,0.046196,0.053383,0.040908,0.057439,0.041523,0.098416,0.065185,0.04084


In [18]:
tables["acc"]

Unnamed: 0_level_0,RF,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
spambase,0.941588,0.942795,0.942795,0.933623,0.941588,0.401883,0.942071,0.771904,0.869177,0.943036
climate,0.936214,0.95679,0.938272,0.853909,0.936214,0.925926,0.950617,0.925926,0.925926,0.95679
QSAR,0.858044,0.867508,0.858044,0.831756,0.858044,0.321767,0.859096,0.628812,0.761304,0.869611
bank,0.989482,0.989482,0.9911,0.97411,0.989482,0.446602,0.988673,0.932039,0.976537,0.989482
climate,0.936214,0.95679,0.938272,0.853909,0.936214,0.925926,0.950617,0.925926,0.925926,0.95679
parkinsons,0.853107,0.870056,0.858757,0.80226,0.853107,0.751412,0.870056,0.774011,0.79661,0.864407
vertebral,0.817204,0.802867,0.759857,0.795699,0.817204,0.293907,0.777778,0.587814,0.759857,0.827957
ionosphere,0.924528,0.91195,0.91195,0.899371,0.924528,0.603774,0.908805,0.672956,0.68239,0.915094
diabetes,0.761905,0.761905,0.737374,0.759019,0.761905,0.339105,0.751804,0.515152,0.681097,0.767677
breast,0.94347,0.94347,0.951267,0.931774,0.94347,0.376218,0.94347,0.847953,0.896686,0.945419


In [19]:
tables["ece"]

Unnamed: 0_level_0,RF,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
spambase,0.061628,0.017882,0.026423,0.033695,0.032816,0.025974,0.02064,0.065908,0.044412,0.037919
climate,0.049,0.068053,0.067306,0.119869,0.0481,0.040667,0.05718,0.064627,0.074074,0.076485
QSAR,0.054851,0.07413,0.071408,0.07861,0.05487,0.111013,0.060845,0.171501,0.13817,0.065924
bank,0.034851,0.0244,0.009069,0.012879,0.026115,0.061656,0.008778,0.033263,0.038269,0.02227
climate,0.049,0.068053,0.067306,0.119869,0.0481,0.040667,0.05718,0.064627,0.074074,0.076485
parkinsons,0.099972,0.086613,0.054841,0.098614,0.096173,0.112147,0.078134,0.178063,0.138418,0.048677
vertebral,0.095517,0.122281,0.141846,0.144508,0.098174,0.242587,0.134988,0.220414,0.182796,0.117781
ionosphere,0.093003,0.100983,0.081856,0.089041,0.069066,0.164924,0.08593,0.223154,0.212579,0.080951
diabetes,0.068545,0.083685,0.09395,0.086141,0.060209,0.12096,0.076665,0.283654,0.219769,0.078576
breast,0.03302,0.046861,0.040097,0.042351,0.027251,0.088828,0.02737,0.079491,0.039376,0.030487
