In [28]:
# How changing the dataset sample size effects the calibration methods
# Fix training dataset size and change the calib set samples - best method is one that gets max calib with least data

In [29]:
# imports
import sys
import pandas as pd
import numpy as np
sys.path.append('../../') # to access the files in higher directories
sys.path.append('../') # to access the files in higher directories
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)

import Data.data_provider as dp
import core as cal
from estimators.IR_RF_estimator import IR_RF

In [30]:
# params
calib_methods = cal.calib_methods.copy() #["RF", "Platt" , "ISO", "Rank", "CRF", "VA", "Beta", "Elkan", "tlr", "Line"]
metrics = cal.metrics.copy() #["acc", "auc", "brier", "logloss", "ece", "tce"]
metrics.remove("tce")

data_list = ["spambase", "climate", "QSAR", "bank", "climate", "parkinsons", "vertebral", "ionosphere", "diabetes", "breast", "blod"]
# data_list = ["spambase", "climate"]

params = {
    "runs": 5,
    "n_estimators": 10,
    "oob": False,
    "test_split": 0.3,
    "calib_split": 0.05
}

In [31]:
calib_results_dict = {}

for data_name in data_list:

    # Data
    X, y = dp.load_data(data_name, "../../")
    
    data_dict = {} # results for each data set will be saved in here.
    for seed in range(params["runs"]): # running the same dataset multiple times
        # split the data
        data = cal.split_train_calib_test(data_name, X, y, params["test_split"], params["calib_split"], seed)
        # print("train", len(data["x_train"]))
        # print("calib", len(data["x_calib"]))
        # print("test", len(data["x_test"]))
        # print("---------------------------------")

        # train model
        irrf = IR_RF(n_estimators=params["n_estimators"], oob_score=params["oob"], random_state=seed)
        irrf.fit(data["x_train"], data["y_train"])

        # calibration
        res = cal.calibration(irrf, data, calib_methods, metrics) # res is a dict with all the metrics results as well as RF probs and every calibration method decision for every test data point
        data_dict = cal.update_runs(data_dict, res) # calib results for every run for the same dataset is aggregated in data_dict (ex. acc of every run as an array)
    calib_results_dict.update(data_dict) # merge results of all datasets together
    
tables = cal.mean_and_ranking_table(calib_results_dict, metrics, calib_methods, data_list, mean_and_rank=False)

In [32]:
tables = cal.mean_and_ranking_table(calib_results_dict, metrics, calib_methods, data_list, mean_and_rank=True)

In [33]:
tables["brier"]

Unnamed: 0_level_0,RF,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
spambase,0.048545,0.042912,0.044717,0.047053,0.043994,0.044588,0.042552,0.214108,0.114739,0.046016
climate,0.07141,0.074259,0.07803,0.10888,0.071294,0.08313,0.081853,0.100543,0.104938,0.078358
QSAR,0.110396,0.115096,0.121554,0.125505,0.111768,0.116409,0.115323,0.335076,0.182442,0.112101
bank,0.013206,0.012564,0.014225,0.027895,0.012127,0.018927,0.015307,0.071282,0.030922,0.01238
climate,0.07141,0.074259,0.07803,0.10888,0.071294,0.08313,0.081853,0.100543,0.104938,0.078358
parkinsons,0.106156,0.121408,0.156588,0.16775,0.105923,0.131932,0.155862,0.206936,0.243695,0.126014
vertebral,0.110981,0.122243,0.139937,0.15563,0.116574,0.1237,0.151547,0.321941,0.255398,0.115298
ionosphere,0.066684,0.077507,0.074181,0.153039,0.06115,0.088641,0.075473,0.232749,0.187566,0.064708
diabetes,0.169517,0.179382,0.193537,0.188348,0.174371,0.186487,0.183654,0.409337,0.302216,0.178283
breast,0.049211,0.059382,0.055721,0.075415,0.049251,0.059595,0.054331,0.12238,0.077018,0.051075


In [34]:
tables["logloss"]

Unnamed: 0_level_0,RF,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
spambase,0.048545,0.042912,0.044717,0.047053,0.043994,0.044588,0.042552,0.214108,0.114739,0.046016
climate,0.07141,0.074259,0.07803,0.10888,0.071294,0.08313,0.081853,0.100543,0.104938,0.078358
QSAR,0.110396,0.115096,0.121554,0.125505,0.111768,0.116409,0.115323,0.335076,0.182442,0.112101
bank,0.013206,0.012564,0.014225,0.027895,0.012127,0.018927,0.015307,0.071282,0.030922,0.01238
climate,0.07141,0.074259,0.07803,0.10888,0.071294,0.08313,0.081853,0.100543,0.104938,0.078358
parkinsons,0.106156,0.121408,0.156588,0.16775,0.105923,0.131932,0.155862,0.206936,0.243695,0.126014
vertebral,0.110981,0.122243,0.139937,0.15563,0.116574,0.1237,0.151547,0.321941,0.255398,0.115298
ionosphere,0.066684,0.077507,0.074181,0.153039,0.06115,0.088641,0.075473,0.232749,0.187566,0.064708
diabetes,0.169517,0.179382,0.193537,0.188348,0.174371,0.186487,0.183654,0.409337,0.302216,0.178283
breast,0.049211,0.059382,0.055721,0.075415,0.049251,0.059595,0.054331,0.12238,0.077018,0.051075


In [35]:
tables["acc"]

Unnamed: 0_level_0,RF,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
spambase,0.946271,0.945547,0.944243,0.942505,0.946271,0.944967,0.94685,0.690369,0.83244,0.945836
climate,0.902469,0.907407,0.896296,0.793827,0.902469,0.865432,0.9,0.895062,0.895062,0.907407
QSAR,0.846057,0.847319,0.842902,0.829653,0.846057,0.839117,0.849211,0.53123,0.752681,0.849211
bank,0.984466,0.983981,0.981553,0.966019,0.984466,0.974757,0.98301,0.906796,0.958738,0.984466
climate,0.902469,0.907407,0.896296,0.793827,0.902469,0.865432,0.9,0.895062,0.895062,0.907407
parkinsons,0.861017,0.847458,0.820339,0.732203,0.861017,0.79661,0.827119,0.752542,0.752542,0.833898
vertebral,0.836559,0.843011,0.823656,0.821505,0.836559,0.815054,0.821505,0.569892,0.686022,0.847312
ionosphere,0.937736,0.920755,0.90566,0.839623,0.937736,0.888679,0.916981,0.681132,0.732075,0.928302
diabetes,0.750649,0.725541,0.719481,0.727273,0.750649,0.724675,0.732468,0.447619,0.625108,0.729004
breast,0.938012,0.936842,0.932164,0.905263,0.938012,0.921637,0.933333,0.842105,0.891228,0.936842


In [36]:
tables["ece"]

Unnamed: 0_level_0,RF,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
spambase,0.060888,0.019655,0.023056,0.024379,0.028778,0.025522,0.018965,0.15504,0.07286,0.042575
climate,0.054069,0.074308,0.041095,0.114018,0.051985,0.074763,0.065632,0.100578,0.103704,0.067477
QSAR,0.043439,0.053102,0.065097,0.072873,0.044645,0.057444,0.064482,0.298593,0.143533,0.0512
bank,0.02857,0.034585,0.01026,0.022826,0.014685,0.054382,0.013892,0.040195,0.033107,0.015579
climate,0.054069,0.074308,0.041095,0.114018,0.051985,0.074763,0.065632,0.100578,0.103704,0.067477
parkinsons,0.088648,0.128171,0.134459,0.117031,0.088057,0.134705,0.133608,0.221703,0.197627,0.104939
vertebral,0.068456,0.107357,0.121433,0.139976,0.085098,0.12271,0.139327,0.317888,0.248602,0.089753
ionosphere,0.090423,0.122954,0.07233,0.147187,0.078431,0.134661,0.070806,0.246508,0.179057,0.077066
diabetes,0.06978,0.077833,0.118874,0.119687,0.086071,0.098571,0.095632,0.405092,0.260779,0.084711
breast,0.03923,0.098656,0.051836,0.0585,0.040748,0.09123,0.053228,0.093767,0.056257,0.048588
