In [1]:
# How changing the dataset sample size effects the calibration methods
# Fix training dataset size and change the calib set samples - best method is one that gets max calib with least data

In [2]:
# imports
import sys
import pandas as pd
import numpy as np
sys.path.append('../../') # to access the files in higher directories
sys.path.append('../') # to access the files in higher directories
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)

import Data.data_provider as dp
import core as cal
from estimators.IR_RF_estimator import IR_RF
from sklearn.model_selection import RandomizedSearchCV


In [3]:
# params
calib_methods = cal.calib_methods.copy() 
metrics = cal.metrics.copy()
metrics.remove("tce")

data_list = ["spambase", "climate", "QSAR", "bank", "climate", "parkinsons", "vertebral", "ionosphere", "diabetes", "breast", "blod"]
# data_list = ["spambase", "climate"]

params = {
    "runs": 5,
    "n_estimators": 10,
    "oob": False,
    "test_split": 0.3,
    "calib_split": 0.05
}

In [4]:
calib_results_dict = {}

for data_name in data_list:

    # Data
    X, y = dp.load_data(data_name, "../../")
    
    data_dict = {} # results for each data set will be saved in here.
    for seed in range(params["runs"]): # running the same dataset multiple times
        # split the data
        data = cal.split_train_calib_test(data_name, X, y, params["test_split"], params["calib_split"], seed)
        # print("train", len(data["x_train"]))
        # print("calib", len(data["x_calib"]))
        # print("test", len(data["x_test"]))
        # print("---------------------------------")

        # train model
        search_space = {
            "n_estimators": [20],
            "max_depth": [5, 10, 15, 20, 25],
            "criterion": ["gini", "entropy"],
            "min_samples_split": [2,3,4,5],
            "min_samples_leaf": [1,2,3],
        }
        rf = IR_RF(random_state=seed)

        RS = RandomizedSearchCV(rf, search_space, scoring=["accuracy"], refit="accuracy", cv=5, n_iter=10, random_state=0)
        RS.fit(data["x_train"], data["y_train"])
        rf_best = RS.best_estimator_
        # irrf.fit(data["x_train"], data["y_train"])

        # calibration
        res = cal.calibration(rf_best, data, calib_methods, metrics) # res is a dict with all the metrics results as well as RF probs and every calibration method decision for every test data point
        data_dict = cal.update_runs(data_dict, res) # calib results for every run for the same dataset is aggregated in data_dict (ex. acc of every run as an array)
    calib_results_dict.update(data_dict) # merge results of all datasets together

In [5]:
tables = cal.mean_and_ranking_table(calib_results_dict, metrics, calib_methods, data_list, mean_and_rank=True)

In [6]:
tables["brier"]

Unnamed: 0_level_0,RF,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
spambase,0.043107,0.04056,0.042899,0.045224,0.041825,0.042273,0.040901,0.163989,0.044171,0.042207
climate,0.066759,0.071511,0.083844,0.122188,0.067934,0.068474,0.069082,0.099914,0.080509,0.078701
QSAR,0.1046,0.108391,0.113164,0.123997,0.1074,0.109798,0.108911,0.296115,0.105942,0.106031
bank,0.009152,0.008493,0.007504,0.022992,0.007121,0.012405,0.00909,0.053668,0.009,0.008694
climate,0.066759,0.071511,0.083844,0.122188,0.067934,0.068474,0.069082,0.099914,0.080509,0.078701
parkinsons,0.089967,0.114326,0.130314,0.168517,0.089747,0.118272,0.129675,0.19139,0.099949,0.110745
vertebral,0.106286,0.121233,0.128788,0.144767,0.104979,0.125246,0.142987,0.274742,0.114306,0.113651
ionosphere,0.062172,0.076369,0.079017,0.107454,0.057677,0.08309,0.083587,0.197664,0.061906,0.064561
diabetes,0.166498,0.177363,0.189928,0.193911,0.172206,0.180971,0.179136,0.404045,0.173418,0.175166
breast,0.044349,0.052751,0.046595,0.063973,0.04603,0.05574,0.04611,0.118272,0.043746,0.044971


In [7]:
tables["logloss"]

Unnamed: 0_level_0,RF,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
spambase,0.043107,0.04056,0.042899,0.045224,0.041825,0.042273,0.040901,0.163989,0.044171,0.042207
climate,0.066759,0.071511,0.083844,0.122188,0.067934,0.068474,0.069082,0.099914,0.080509,0.078701
QSAR,0.1046,0.108391,0.113164,0.123997,0.1074,0.109798,0.108911,0.296115,0.105942,0.106031
bank,0.009152,0.008493,0.007504,0.022992,0.007121,0.012405,0.00909,0.053668,0.009,0.008694
climate,0.066759,0.071511,0.083844,0.122188,0.067934,0.068474,0.069082,0.099914,0.080509,0.078701
parkinsons,0.089967,0.114326,0.130314,0.168517,0.089747,0.118272,0.129675,0.19139,0.099949,0.110745
vertebral,0.106286,0.121233,0.128788,0.144767,0.104979,0.125246,0.142987,0.274742,0.114306,0.113651
ionosphere,0.062172,0.076369,0.079017,0.107454,0.057677,0.08309,0.083587,0.197664,0.061906,0.064561
diabetes,0.166498,0.177363,0.189928,0.193911,0.172206,0.180971,0.179136,0.404045,0.173418,0.175166
breast,0.044349,0.052751,0.046595,0.063973,0.04603,0.05574,0.04611,0.118272,0.043746,0.044971


In [8]:
tables["acc"]

Unnamed: 0_level_0,RF,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
spambase,0.947864,0.947574,0.945112,0.939464,0.947574,0.945257,0.947429,0.712238,0.941781,0.947429
climate,0.908642,0.907407,0.897531,0.798765,0.906173,0.912346,0.907407,0.895062,0.893827,0.904938
QSAR,0.849842,0.849842,0.844164,0.821451,0.852366,0.842902,0.849211,0.540694,0.845426,0.857413
bank,0.991748,0.991748,0.990291,0.973301,0.991262,0.984951,0.98932,0.893689,0.991748,0.991262
climate,0.908642,0.907407,0.897531,0.798765,0.906173,0.912346,0.907407,0.895062,0.893827,0.904938
parkinsons,0.871186,0.840678,0.837288,0.759322,0.864407,0.833898,0.833898,0.759322,0.861017,0.833898
vertebral,0.843011,0.84086,0.834409,0.819355,0.843011,0.802151,0.832258,0.597849,0.827957,0.858065
ionosphere,0.932075,0.916981,0.90566,0.873585,0.933962,0.881132,0.909434,0.701887,0.926415,0.920755
diabetes,0.750649,0.736797,0.716883,0.729004,0.753247,0.705628,0.733333,0.432035,0.742857,0.745455
breast,0.938012,0.94152,0.946199,0.918129,0.936842,0.921637,0.94386,0.819883,0.938012,0.936842


In [9]:
tables["ece"]

Unnamed: 0_level_0,RF,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
spambase,0.039757,0.019792,0.02636,0.021221,0.029028,0.02396,0.01849,0.173822,0.029044,0.033143
climate,0.047044,0.068441,0.067066,0.120673,0.051355,0.07082,0.050597,0.10021,0.073519,0.074105
QSAR,0.040074,0.049982,0.060158,0.085824,0.053586,0.062649,0.048629,0.323771,0.042429,0.050158
bank,0.023288,0.039263,0.005913,0.016767,0.008654,0.041151,0.008596,0.060077,0.020437,0.01894
climate,0.047044,0.068441,0.067066,0.120673,0.051355,0.07082,0.050597,0.10021,0.073519,0.074105
parkinsons,0.103522,0.130546,0.112748,0.116476,0.099935,0.132252,0.126211,0.20926,0.074237,0.10723
vertebral,0.093538,0.116436,0.103373,0.114432,0.084787,0.128397,0.121744,0.270962,0.077312,0.09228
ionosphere,0.072125,0.104701,0.061736,0.112104,0.050507,0.118945,0.074625,0.219011,0.064906,0.054886
diabetes,0.067655,0.091573,0.109791,0.130103,0.079365,0.086101,0.090855,0.425752,0.092338,0.09178
breast,0.031304,0.096776,0.039347,0.053799,0.036633,0.079282,0.043838,0.114603,0.023918,0.04924
