In [1]:
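# How changing the dataset sample size affects the calibration methods.
# Fix the training-set size and vary the number of calibration-set samples; the best method is the one that achieves the best calibration with the least data.
# NOTE(review): the experiment cell below currently sweeps over datasets (exp_key = "data_name"), not over calibration-set size.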

# imports / environment setup
import sys

import pandas as pd

# The project modules imported next (core_exp, core_calib) live in higher
# directories, so extend the module search path before importing them.
# extend() adds both entries in the same order as two successive appends.
sys.path.extend(['../../', '../'])

# Widen pandas display limits so the large per-dataset result tables are
# rendered without truncation.
pd.options.display.max_rows = 500
pd.options.display.max_columns = 500
pd.options.display.width = 1000

import core_exp as cx
import core_calib as cal

In [2]:
# Experiment configuration consumed by cx.run_exp (key order kept stable).
params = dict(
    # --- experiment ---
    seed=0,
    exp_name="real",
    cv_folds=10,
    plot=True,
    # Method names to compare; grouped here only for readability.
    calib_methods=(
        ["RF", "RF_CT", "RF_fulldata"]
        + ["Platt", "ISO", "Rank", "CRF", "VA", "Beta", "Elkan", "tlr", "Line"]
        + ["RF_boot"]
        + ["RF_ens_r", "RF_ens_line", "RF_ens_CRF", "RF_ens_Platt", "RF_ens_ISO", "RF_ens_Beta"]
        + ["RF_large", "RF_large_line", "RF_large_Platt", "RF_large_ISO", "RF_large_Beta"]
    ),
    metrics=["acc", "logloss", "brier", "ece", "auc"],

    # --- calibration parameters ---
    ece_bins=20,
    boot_size=1000,
    boot_count=10,

    # --- RF hyper-parameter optimisation ---
    hyper_opt=True,
    opt_cv=5,
    opt_n_iter=10,
    # Candidate values for the search; "min_samples_split" ([2, 3, 4, 5]) and
    # "min_samples_leaf" ([1, 2, 3]) are currently disabled.
    search_space=dict(
        n_estimators=[10],
        max_depth=[2, 3, 4, 5, 6, 7, 8, 10, 20, 50, 100],
        criterion=["gini", "entropy"],
    ),

    depth=4,
    n_estimators=10,
    oob=False,
)

# The experiment is repeated once per value of this parameter.
exp_key = "data_name"
# Datasets to run. Currently excluded: "hillvalley", "jm1".
exp_values = [
    "spambase", "QSAR", "bank", "parkinsons",
    "vertebral", "ionosphere", "diabetes", "breast",
    "madelon", "scene", "Sonar_Mine_Rock_Data", "Customer_Churn",
    "pc4", "eeg", "heart", "HRCompetencyScores",
    "phoneme", "SPF", "wdbc", "nomao",
    "wilt",
]
# For a quick single-dataset run, override with e.g.: exp_values = ["wilt"]

In [3]:
# Run the experiment once per value in `exp_values` (swept over `exp_key`),
# then summarise the results as one table per metric in params["metrics"];
# the tables are indexed by dataset and are read by metric name below.
calib_results_dict, data_list = cx.run_exp(exp_key, exp_values, params)

# NOTE(review): the std / mean_and_rank flags presumably add standard
# deviations and summary rows (the later cells read a "Rank" row) — confirm
# against core_calib.mean_and_ranking_table.
tables = cal.mean_and_ranking_table(
    calib_results_dict,
    params["metrics"],
    params["calib_methods"],
    data_list,
    std=True,
    mean_and_rank=True,
)

>>>>>>> data parkinsons NOT LEARNING - learnign diff is 0.0
>>>>>>> data diabetes NOT LEARNING - learnign diff is -2.6315789473684292
>>>>>>> data pc4 NOT LEARNING - learnign diff is 0.6849315068493178
>>>>>>> data pc4 NOT LEARNING - learnign diff is 0.6849315068493178
>>>>>>> data pc4 NOT LEARNING - learnign diff is 0.0


In [4]:
# Best method for Brier score: the column with the lowest value in the "Rank"
# row (assumes lower rank = better, as the previous argmin lookup did).
# idxmin() returns the column label itself, so the result no longer depends on
# the table's column order matching params["calib_methods"].
print(tables["brier"].loc["Rank"].idxmin())
tables["brier"]

RF_large_Platt


Unnamed: 0_level_0,RF,RF_CT,RF_fulldata,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line,RF_boot,RF_ens_r,RF_ens_line,RF_ens_CRF,RF_ens_Platt,RF_ens_ISO,RF_ens_Beta,RF_large,RF_large_line,RF_large_Platt,RF_large_ISO,RF_large_Beta
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
spambase,0.045081,0.045673,0.042847,0.042688,0.043594,0.044806,0.043392,0.043516,0.04917,0.102259,0.046084,0.044414,0.045111,0.038754,0.037438,0.03655,0.035282,0.037106,0.03525149,0.039044,0.037869,0.035854,0.036903,0.03617908
QSAR,0.106007,0.106709,0.108994,0.107005,0.111741,0.114924,0.106617,0.109932,0.117681,0.209683,0.108129,0.109203,0.106053,0.099836,0.10011,0.101776,0.101336,0.103979,0.1003616,0.098303,0.097513,0.099725,0.105534,0.09840387
bank,0.00804,0.00804,0.007687,0.006304,0.010201,0.021932,0.006729,0.009091,0.010836,0.023803,0.007843,0.007852,0.008059,0.006097,0.005879,0.005524,0.005401,0.005938,0.006866009,0.005688,0.005465,0.004851,0.005594,0.0053659
parkinsons,0.107609,0.112524,0.104109,0.113749,0.115681,0.140431,0.104585,0.118848,0.114049,0.162417,0.112818,0.11222,0.107749,0.097105,0.096818,0.097191,0.101575,0.102988,0.1018558,0.099044,0.10043,0.103504,0.09997,0.1031431
vertebral,0.115088,0.116581,0.114767,0.119546,0.12439,0.163383,0.120586,0.123621,0.123189,0.215061,0.110323,0.116736,0.11519,0.10908,0.112171,0.111443,0.113403,0.119526,0.116834,0.106384,0.109531,0.1102,0.125619,0.1157487
ionosphere,0.064019,0.06387,0.062722,0.066919,0.072815,0.082691,0.063245,0.069332,0.076193,0.146412,0.063802,0.065167,0.06397,0.055316,0.053988,0.05722,0.057614,0.073464,0.06956616,0.056759,0.055626,0.059175,0.07092,0.07016615
diabetes,0.171603,0.17723,0.163634,0.173117,0.180388,0.18138,0.171773,0.176502,0.175997,0.306485,0.182295,0.173311,0.171689,0.160736,0.161565,0.158813,0.161795,0.173305,0.1625066,0.158896,0.157807,0.158998,0.16692,0.1592913
breast,0.038454,0.038917,0.035933,0.039402,0.042216,0.054907,0.038309,0.039976,0.045171,0.078688,0.039727,0.038282,0.038577,0.033027,0.032294,0.032023,0.032106,0.034322,0.03617752,0.032467,0.031721,0.031708,0.035124,0.03400987
madelon,0.219836,0.22495,0.218746,0.216808,0.219617,0.232258,0.219263,0.218368,0.216864,0.37445,0.223996,0.217944,0.219745,0.216269,0.198534,0.211377,0.199278,0.205475,0.199452,0.218083,0.201004,0.202138,0.206919,0.2022897
scene,0.073341,0.073579,0.070752,0.065863,0.067791,0.066367,0.06875,0.067714,0.065925,0.186055,0.07414,0.070647,0.073291,0.063428,0.052787,0.058221,0.046362,0.048671,0.04619939,0.064262,0.053041,0.046227,0.049551,0.04602941


In [5]:
# Best method for log-loss: the column with the lowest value in the "Rank"
# row (assumes lower rank = better, as the previous argmin lookup did).
# idxmin() returns the column label itself, so the result no longer depends on
# the table's column order matching params["calib_methods"].
print(tables["logloss"].loc["Rank"].idxmin())
tables["logloss"]

RF_large_Platt


Unnamed: 0_level_0,RF,RF_CT,RF_fulldata,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line,RF_boot,RF_ens_r,RF_ens_line,RF_ens_CRF,RF_ens_Platt,RF_ens_ISO,RF_ens_Beta,RF_large,RF_large_line,RF_large_Platt,RF_large_ISO,RF_large_Beta
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
spambase,0.298374,0.3685,0.258215,0.157037,0.244292,0.269433,0.290702,0.158778,0.244794,0.437894,0.409226,0.281893,0.298445,0.176648,0.226653,0.16502,0.134179,0.285817,0.165248,0.169602,0.214509,0.134412,0.273209,0.159012
QSAR,0.407648,0.739129,0.836931,0.354829,0.66501,0.731248,0.411432,0.35752,0.384231,0.668542,0.861074,0.453609,0.407955,0.329558,0.477904,0.333468,0.334669,0.794507,0.335659,0.324984,0.496144,0.329092,0.737824,0.327419
bank,0.056843,0.056843,0.02979,0.036152,0.271007,0.355433,0.051039,0.039849,0.18916,0.093167,0.055634,0.034323,0.056882,0.028442,0.025165,0.024153,0.030051,0.066275,0.043237,0.027497,0.024122,0.029003,0.089584,0.053393
parkinsons,0.329222,0.85042,0.654397,0.375511,1.162025,1.130674,0.319711,0.374068,0.373159,0.539782,0.848583,0.354671,0.329425,0.311183,0.30691,0.306325,0.342205,0.973407,0.332821,0.324244,0.318115,0.348352,0.965309,0.347044
vertebral,0.455339,0.452384,0.667508,0.373832,1.638546,3.781747,0.469411,0.381422,0.506751,0.745458,0.537389,0.454807,0.455672,0.338781,0.444761,0.341885,0.358361,1.622948,0.493531,0.331777,0.54248,0.348148,1.961716,0.41742
ionosphere,0.314424,0.403138,0.236398,0.2452,1.074137,1.493439,0.307257,0.245792,0.723309,0.52699,0.400578,0.4149,0.314375,0.198809,0.186589,0.194505,0.214706,1.928719,1.148621,0.202778,0.286064,0.219859,1.824586,1.236953
diabetes,0.510406,1.078678,0.53271,0.520132,0.923307,0.873697,0.511318,0.525979,0.528203,0.881525,1.564805,0.727968,0.510475,0.485627,0.52971,0.481016,0.491702,0.820446,0.490955,0.482278,0.479991,0.485256,0.797961,0.484589
breast,0.241833,0.240149,0.244749,0.149481,0.651537,0.990338,0.238796,0.148361,0.490569,0.351894,0.240228,0.245764,0.242174,0.122066,0.175085,0.117991,0.129211,0.568079,0.406586,0.123262,0.172648,0.127794,0.512928,0.389572
madelon,0.630615,0.784251,0.628043,0.624067,0.793509,1.036034,0.628652,0.627331,0.625938,1.075661,0.831706,0.676589,0.63042,0.623736,0.575441,0.610719,0.578218,0.816679,0.578388,0.627563,0.657549,0.58704,0.809881,0.587253
scene,0.381304,0.422418,0.386452,0.228989,0.284591,0.439117,0.364558,0.234823,0.228819,0.64288,0.46971,0.384988,0.381161,0.227635,0.227705,0.203846,0.169016,0.384798,0.166374,0.230017,0.243458,0.168649,0.334815,0.167175


In [6]:
# Best method for accuracy: the column with the lowest value in the "Rank"
# row (assumes the ranking already treats higher accuracy as better, i.e.
# lower rank = better, as the previous argmin lookup did — TODO confirm).
# idxmin() returns the column label itself, so the result no longer depends on
# the table's column order matching params["calib_methods"].
print(tables["acc"].loc["Rank"].idxmin())
tables["acc"]

RF_ens_line


Unnamed: 0_level_0,RF,RF_CT,RF_fulldata,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line,RF_boot,RF_ens_r,RF_ens_line,RF_ens_CRF,RF_ens_Platt,RF_ens_ISO,RF_ens_Beta,RF_large,RF_large_line,RF_large_Platt,RF_large_ISO,RF_large_Beta
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
spambase,0.945664,0.942838,0.949796,0.944577,0.94284,0.940448,0.945664,0.942623,0.941318,0.865465,0.939579,0.944795,0.945014,0.954792,0.954792,0.954792,0.954792,0.951968,0.955011,0.954358,0.954575,0.954575,0.954576,0.952839
QSAR,0.859704,0.861599,0.852183,0.860656,0.847421,0.83894,0.859704,0.847421,0.835229,0.693037,0.857817,0.85973,0.860656,0.862516,0.865373,0.862516,0.861599,0.861671,0.861608,0.862552,0.860665,0.856873,0.860647,0.870135
bank,0.992711,0.992711,0.989067,0.993441,0.986163,0.973776,0.992711,0.985438,0.989083,0.975209,0.992711,0.993441,0.991987,0.993441,0.993441,0.993441,0.992711,0.991257,0.991262,0.994896,0.994896,0.994171,0.992717,0.992717
parkinsons,0.855789,0.855789,0.856316,0.865789,0.835789,0.805526,0.855789,0.841053,0.845526,0.799211,0.866053,0.846053,0.855789,0.851316,0.871579,0.851316,0.871316,0.861053,0.866316,0.830526,0.851053,0.860789,0.866579,0.866316
vertebral,0.819355,0.825806,0.835484,0.825806,0.825806,0.793548,0.819355,0.822581,0.822581,0.712903,0.829032,0.829032,0.816129,0.841935,0.835484,0.841935,0.829032,0.832258,0.835484,0.848387,0.841935,0.841935,0.819355,0.845161
ionosphere,0.905873,0.90873,0.92873,0.920079,0.905873,0.908889,0.905873,0.917222,0.903016,0.806111,0.920159,0.905873,0.911587,0.92873,0.931587,0.92873,0.925873,0.925873,0.923095,0.925873,0.923016,0.920159,0.923016,0.917302
diabetes,0.734296,0.729101,0.756511,0.736945,0.733083,0.734381,0.734296,0.735697,0.730468,0.54284,0.727837,0.734381,0.734296,0.770848,0.776042,0.770848,0.772146,0.75393,0.760424,0.773411,0.768199,0.769498,0.746104,0.762987
breast,0.945551,0.943797,0.954261,0.943797,0.942043,0.938471,0.945551,0.945551,0.940257,0.898058,0.940288,0.945551,0.945551,0.961341,0.959618,0.961341,0.959618,0.94906,0.952569,0.961341,0.959586,0.959586,0.945551,0.954292
madelon,0.647308,0.645385,0.660385,0.653846,0.652308,0.609231,0.647308,0.653077,0.652692,0.500385,0.634615,0.649615,0.651538,0.664615,0.668077,0.665,0.668077,0.671538,0.666923,0.666154,0.669231,0.671538,0.665769,0.668462
scene,0.907355,0.906938,0.910256,0.911081,0.911497,0.915247,0.907355,0.912329,0.911501,0.747372,0.896141,0.911081,0.912332,0.927701,0.943069,0.927286,0.943895,0.939739,0.943895,0.923964,0.943489,0.943071,0.935593,0.942657


In [7]:
# Best method for expected calibration error: the column with the lowest value
# in the "Rank" row (assumes lower rank = better, as the previous argmin
# lookup did).
# idxmin() returns the column label itself, so the result no longer depends on
# the table's column order matching params["calib_methods"].
print(tables["ece"].loc["Rank"].idxmin())
tables["ece"]

RF_large_Beta


Unnamed: 0_level_0,RF,RF_CT,RF_fulldata,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line,RF_boot,RF_ens_r,RF_ens_line,RF_ens_CRF,RF_ens_Platt,RF_ens_ISO,RF_ens_Beta,RF_large,RF_large_line,RF_large_Platt,RF_large_ISO,RF_large_Beta
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
spambase,0.038261,0.036601,0.038325,0.011798,0.013181,0.016738,0.017851,0.01123,0.026911,0.106575,0.031276,0.03281,0.037578,0.043949,0.032864,0.019346,0.01093,0.012454,0.010054,0.043043,0.033388,0.010667,0.012017,0.009644
QSAR,0.042965,0.041611,0.037079,0.031108,0.040554,0.056611,0.032999,0.038102,0.050947,0.198192,0.036209,0.051773,0.046178,0.036943,0.034853,0.032116,0.028039,0.044548,0.033023,0.039235,0.039402,0.036995,0.049458,0.034091
bank,0.016545,0.016545,0.00963,0.017206,0.008973,0.013973,0.009229,0.011451,0.006955,0.016542,0.015452,0.016781,0.015812,0.014869,0.01165,0.009336,0.014067,0.001928,0.004517,0.01578,0.012497,0.014077,0.002667,0.004127
parkinsons,0.066205,0.068205,0.080295,0.113724,0.078276,0.136006,0.052014,0.113817,0.089499,0.150504,0.087692,0.074708,0.076474,0.086025,0.079291,0.059778,0.083664,0.082358,0.051978,0.096697,0.075581,0.083742,0.066234,0.075223
vertebral,0.065114,0.045806,0.060099,0.050747,0.071996,0.138308,0.067553,0.059085,0.0533,0.209866,0.052903,0.040351,0.052075,0.057292,0.052515,0.042494,0.071511,0.083095,0.067292,0.060277,0.047621,0.063032,0.083681,0.060203
ionosphere,0.077807,0.05812,0.057407,0.071691,0.04936,0.091196,0.053833,0.064099,0.057344,0.18138,0.049003,0.071674,0.060413,0.06691,0.044346,0.045519,0.043623,0.077081,0.052606,0.044622,0.047455,0.056854,0.065189,0.057017
diabetes,0.057258,0.072396,0.038836,0.058196,0.057852,0.07693,0.059895,0.06024,0.054499,0.28825,0.099089,0.067421,0.058842,0.040175,0.054848,0.056846,0.044286,0.065723,0.048233,0.054318,0.04735,0.05147,0.067367,0.041682
breast,0.029553,0.023374,0.032545,0.041055,0.02597,0.045867,0.015067,0.031483,0.02647,0.064821,0.021617,0.030684,0.025948,0.033653,0.029718,0.020189,0.032878,0.018967,0.021463,0.032717,0.026359,0.041799,0.022324,0.021036
madelon,0.053232,0.050038,0.072929,0.034923,0.062976,0.086911,0.041172,0.051948,0.036301,0.376895,0.065346,0.048434,0.055332,0.116125,0.030051,0.062422,0.051815,0.064073,0.053792,0.118485,0.029561,0.032221,0.063875,0.042632
scene,0.056545,0.05941,0.055713,0.017494,0.017272,0.02266,0.026318,0.015238,0.014869,0.242643,0.036809,0.051769,0.06087,0.069822,0.054892,0.026992,0.016374,0.014582,0.012574,0.068199,0.058836,0.016236,0.020399,0.01028
