In [1]:
# How changing the dataset sample size affects the calibration methods
# Fix the training dataset size and vary the number of calibration-set samples - the best method is the one that achieves the best calibration with the least data

# Standard-library / third-party imports
import sys
import pandas as pd

# Extend the module search path with the grandparent and parent directories
# (in that order) so that files living in higher directories can be imported.
sys.path.extend(['../../', '../'])

# Raise pandas' display limits so wide/long result tables are shown in full
# instead of being truncated.
pd.set_option(
    'display.max_rows', 500,
    'display.max_columns', 500,
    'display.width', 1000,
)

import core_exp as cx
import core_calib as cal

In [2]:
# Experiment configuration.  The whole dict is handed to cx.run_exp; the
# "metrics" and "calib_methods" entries are also passed to
# cal.mean_and_ranking_table when the result tables are built.
params = dict(
    # ---- experiment ----
    seed=0,
    exp_name="real",
    cv_folds=10,
    plot=True,
    # Method identifiers, one name-prefix family per line.  The order of this
    # list is preserved exactly as originally written.
    calib_methods=[
        "RF", "RF_CT", "RF_fulldata",
        "Platt", "ISO", "Rank", "CRF", "VA", "Beta", "Elkan", "tlr", "Line",
        "RF_boot",
        "RF_ens_r", "RF_ens_line", "RF_ens_CRF", "RF_ens_Platt", "RF_ens_ISO", "RF_ens_Beta",
        "RF_large", "RF_large_line", "RF_large_Platt", "RF_large_ISO", "RF_large_Beta",
    ],
    metrics=["acc", "logloss", "brier", "ece", "auc"],

    # ---- calibration parameters ----
    bin_strategy="quantile",
    ece_bins=20,
    boot_size=1000,
    boot_count=10,

    # ---- random-forest hyper-parameter optimisation ----
    hyper_opt=True,
    opt_cv=5,
    opt_n_iter=10,
    search_space=dict(
        n_estimators=[100],
        max_depth=[2, 3, 4, 5, 6, 7, 8, 10, 20, 50, 100],
        criterion=["gini", "entropy"],
        # Currently excluded from the search:
        #   min_samples_split: [2, 3, 4, 5]
        #   min_samples_leaf:  [1, 2, 3]
    ),
)

# exp_key / exp_values are passed to cx.run_exp together with params.
# Presumably each entry of exp_values is applied as params[exp_key] for one
# run - confirm against core_exp.run_exp.
exp_key = "data_name"

# Dataset identifiers, in the order they are handed to the experiment runner.
exp_values = [
    "spambase", "QSAR", "bank", "parkinsons",
    "vertebral", "ionosphere", "diabetes", "breast",
    "hillvalley", "madelon", "scene", "Sonar_Mine_Rock_Data",
    "Customer_Churn", "jm1", "pc4", "eeg",
    "heart", "HRCompetencyScores", "phoneme", "SPF",
    "wdbc", "nomao", "wilt",
]
# Quick single-dataset run (disabled):
# exp_values = ["wilt"]

In [3]:
# Run the experiment over every value in exp_values; the runner returns the
# per-run calibration results together with the list of datasets.
calib_results_dict, data_list = cx.run_exp(exp_key, exp_values, params)

# Build the result tables from those results.  `tables` is indexed by metric
# name in the cells that follow (e.g. tables["brier"]).
tables = cal.mean_and_ranking_table(
    calib_results_dict,
    params["metrics"],
    params["calib_methods"],
    data_list,
    mean_and_rank=True,
    std=True,
)



In [4]:
# Print the method with the smallest value in the "Rank" row of the Brier
# table, then display the table itself.
# idxmin() returns the column label directly, so the answer no longer relies
# on the table's column order matching params["calib_methods"] position by
# position (which indexing that list with argmin() silently assumed).
print(tables["brier"].loc["Rank"].idxmin())
tables["brier"]

RF_ens_Platt


Unnamed: 0_level_0,RF,RF_CT,RF_fulldata,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line,RF_boot,RF_ens_r,RF_ens_line,RF_ens_CRF,RF_ens_Platt,RF_ens_ISO,RF_ens_Beta,RF_large,RF_large_line,RF_large_Platt,RF_large_ISO,RF_large_Beta
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
spambase,0.039139,0.039248,0.038079,0.035682,0.036817,0.040588,0.036347,0.036367,0.03587143,0.10274,0.039572,0.037894,0.039102,0.038744,0.037253,0.036839,0.035391,0.037108,0.03525854,0.038431,0.036875,0.034982,0.036684,0.03489114
QSAR,0.09768,0.097809,0.096792,0.099424,0.103171,0.104232,0.10171,0.100339,0.09850557,0.200983,0.098552,0.097591,0.097658,0.096323,0.09607,0.098572,0.097997,0.103528,0.09670026,0.096493,0.097107,0.098146,0.102502,0.09709563
bank,0.005378,0.005379,0.005438,0.004679,0.005874,0.020482,0.004874,0.00804,0.005141954,0.019699,0.005416,0.005153,0.005404,0.005843,0.005641,0.005283,0.005432,0.006781,0.006473707,0.00574,0.005535,0.005248,0.006443,0.006289009
parkinsons,0.102644,0.105929,0.094346,0.109211,0.109446,0.142878,0.100989,0.115277,0.1069685,0.168713,0.10921,0.109763,0.102593,0.102937,0.107991,0.102213,0.109107,0.111394,0.1068284,0.103045,0.106077,0.109002,0.109827,0.1065579
vertebral,0.105931,0.105282,0.104856,0.109878,0.119774,0.152756,0.104921,0.113498,0.1144655,0.205266,0.104169,0.106897,0.105857,0.106838,0.110265,0.107419,0.111252,0.121197,0.114793,0.107321,0.109595,0.111905,0.122451,0.116009
ionosphere,0.055743,0.055927,0.056198,0.0581,0.06964,0.086679,0.055363,0.065282,0.07085496,0.138243,0.055407,0.055317,0.055873,0.054065,0.052611,0.052016,0.056324,0.071682,0.06655596,0.053641,0.051546,0.056379,0.067812,0.06595518
diabetes,0.160145,0.162018,0.161495,0.161137,0.168071,0.173805,0.158798,0.16564,0.1619002,0.30023,0.16537,0.160743,0.160184,0.15831,0.159972,0.157124,0.158974,0.16817,0.1589449,0.158638,0.15961,0.159473,0.169132,0.1597741
breast,0.032365,0.032383,0.031898,0.031818,0.032984,0.04895,0.030891,0.033831,0.03378235,0.0754,0.032367,0.031745,0.032319,0.030697,0.030223,0.029845,0.029736,0.033059,0.03150654,0.031118,0.030522,0.030285,0.032812,0.03323005
hillvalley,0.253449,0.258175,0.247504,0.251131,0.255434,0.256945,0.25345,0.252078,0.2519566,0.389836,0.264715,0.252807,0.253611,0.250354,0.250004,0.250658,0.250385,0.25762,0.2508091,0.250917,0.252021,0.251195,0.257505,0.2517418
madelon,0.202876,0.202887,0.199147,0.184868,0.189453,0.188236,0.191954,0.188387,0.1850176,0.367074,0.200794,0.186108,0.20292,0.201023,0.17555,0.185474,0.175372,0.180732,0.1758382,0.202405,0.17643,0.175784,0.179104,0.1761145


In [5]:
# Print the method with the smallest value in the "Rank" row of the log-loss
# table, then display the table itself.
# idxmin() returns the column label directly, so the answer no longer relies
# on the table's column order matching params["calib_methods"] position by
# position (which indexing that list with argmin() silently assumed).
print(tables["logloss"].loc["Rank"].idxmin())
tables["logloss"]

RF_ens_CRF


Unnamed: 0_level_0,RF,RF_CT,RF_fulldata,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line,RF_boot,RF_ens_r,RF_ens_line,RF_ens_CRF,RF_ens_Platt,RF_ens_ISO,RF_ens_Beta,RF_large,RF_large_line,RF_large_Platt,RF_large_ISO,RF_large_Beta
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
spambase,0.170129,0.190225,0.167312,0.134049,0.258861,0.255278,0.156573,0.136395,0.160919,0.33692,0.208343,0.207971,0.170031,0.156892,0.23925,0.14588,0.133274,0.266409,0.148055,0.155903,0.237876,0.131639,0.251455,0.146729
QSAR,0.349624,0.350371,0.378255,0.329191,0.733003,0.704934,0.357315,0.330361,0.331455,0.627861,0.349847,0.473379,0.349738,0.315577,0.433308,0.318174,0.324441,0.705389,0.322465,0.316675,0.380023,0.325016,0.797414,0.325421
bank,0.027201,0.027189,0.026169,0.028582,0.11384,0.372886,0.022981,0.041602,0.061422,0.068496,0.02644,0.023627,0.027227,0.028376,0.024283,0.024163,0.030298,0.140607,0.063285,0.028155,0.024067,0.029946,0.139696,0.054934
parkinsons,0.32104,0.33914,0.30736,0.360575,1.158246,0.62241,0.314038,0.369605,0.333303,0.537408,0.355099,0.34193,0.320896,0.319863,0.334362,0.31669,0.359219,1.316905,0.326493,0.320692,0.505494,0.35925,0.987854,0.328276
vertebral,0.32981,0.32758,0.329059,0.350616,1.631424,3.220739,0.32536,0.35591,0.418006,0.618448,0.321858,0.433159,0.329957,0.331711,0.339038,0.331563,0.353304,1.414949,0.401307,0.33299,0.436142,0.355157,1.52638,0.416553
ionosphere,0.202702,0.293165,0.207349,0.218728,1.63683,1.691817,0.195208,0.2329,1.388162,0.42293,0.287923,0.287989,0.203201,0.198729,0.183077,0.186773,0.212877,1.641093,0.750473,0.197606,0.179596,0.213174,1.625045,1.037829
diabetes,0.482899,0.491769,0.490916,0.490422,0.928407,0.982529,0.478442,0.496897,0.487769,0.854747,0.587392,0.487029,0.482974,0.479055,0.653647,0.475352,0.484917,0.849999,0.479694,0.480671,0.524097,0.486308,0.889872,0.483418
breast,0.172665,0.172099,0.175137,0.127462,0.509676,0.622042,0.167041,0.133705,0.386885,0.287635,0.170821,0.175471,0.172571,0.117297,0.166422,0.114048,0.122358,0.507027,0.383612,0.115357,0.165832,0.123617,0.506265,0.371895
hillvalley,0.704247,0.716579,0.691291,0.695741,0.985413,0.876118,0.704248,0.700657,0.698019,1.156053,0.737351,0.700113,0.704545,0.697322,0.693973,0.698001,0.69443,1.1547,0.695873,0.698294,0.698053,0.69613,1.129176,0.697682
madelon,0.5947,0.594707,0.586721,0.546655,0.844095,0.842265,0.567382,0.557555,0.547711,1.036536,0.589966,0.661228,0.594794,0.591513,0.583966,0.553019,0.522314,0.808901,0.523674,0.59454,0.583401,0.523491,0.70664,0.524865


In [6]:
# Print the method with the smallest value in the "Rank" row of the accuracy
# table, then display the table itself.
# idxmin() returns the column label directly, so the answer no longer relies
# on the table's column order matching params["calib_methods"] position by
# position (which indexing that list with argmin() silently assumed).
print(tables["acc"].loc["Rank"].idxmin())
tables["acc"]

RF_ens_line


Unnamed: 0_level_0,RF,RF_CT,RF_fulldata,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line,RF_boot,RF_ens_r,RF_ens_line,RF_ens_CRF,RF_ens_Platt,RF_ens_ISO,RF_ens_Beta,RF_large,RF_large_line,RF_large_Platt,RF_large_ISO,RF_large_Beta
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
spambase,0.954793,0.954358,0.955226,0.954575,0.953489,0.944793,0.954793,0.95414,0.953491,0.849387,0.952185,0.95501,0.955227,0.953705,0.954792,0.953705,0.953706,0.950446,0.953489,0.954793,0.95501,0.953489,0.951316,0.953707
QSAR,0.854025,0.854025,0.862606,0.854025,0.86257,0.859721,0.854025,0.863513,0.864429,0.697736,0.857808,0.854034,0.853082,0.855948,0.857853,0.855948,0.856891,0.871123,0.86354,0.860683,0.864456,0.862561,0.870171,0.865409
bank,0.994166,0.994166,0.991987,0.993441,0.993447,0.968666,0.994166,0.988353,0.993447,0.976663,0.994166,0.994166,0.994166,0.992711,0.992711,0.992711,0.992711,0.991987,0.992717,0.993436,0.993436,0.992711,0.991987,0.992717
parkinsons,0.820526,0.850789,0.871579,0.855526,0.850789,0.805789,0.820526,0.835263,0.851053,0.768947,0.866053,0.825789,0.830789,0.820526,0.835789,0.820526,0.850526,0.840526,0.830526,0.820526,0.835,0.855789,0.850789,0.830526
vertebral,0.851613,0.854839,0.854839,0.83871,0.832258,0.809677,0.851613,0.832258,0.835484,0.729032,0.83871,0.848387,0.851613,0.841935,0.83871,0.841935,0.825806,0.819355,0.832258,0.83871,0.825806,0.832258,0.819355,0.832258
ionosphere,0.931587,0.931587,0.92873,0.92873,0.925873,0.894444,0.931587,0.925873,0.920159,0.780556,0.934444,0.925873,0.931587,0.940159,0.92873,0.940159,0.92873,0.923016,0.92873,0.940159,0.925873,0.931587,0.92873,0.92873
diabetes,0.773411,0.772112,0.769532,0.765584,0.75393,0.725342,0.773411,0.757809,0.759108,0.537833,0.773411,0.766883,0.772095,0.774744,0.769532,0.774744,0.769532,0.760407,0.768216,0.770813,0.77083,0.77083,0.752563,0.769515
breast,0.961341,0.961341,0.959586,0.957832,0.952569,0.934994,0.961341,0.952569,0.954292,0.878665,0.959586,0.959586,0.961341,0.966604,0.963095,0.966604,0.963127,0.952569,0.961341,0.96485,0.963095,0.961341,0.94906,0.954323
hillvalley,0.545407,0.542948,0.575139,0.530545,0.553658,0.51573,0.545407,0.554484,0.529698,0.500826,0.555318,0.514029,0.542941,0.559423,0.538775,0.559423,0.538768,0.560229,0.539568,0.554464,0.512336,0.523899,0.561022,0.532157
madelon,0.717692,0.716923,0.720769,0.718846,0.71,0.716154,0.717692,0.710385,0.715769,0.5,0.717692,0.720385,0.718846,0.737308,0.736154,0.737308,0.736538,0.730385,0.737308,0.742692,0.741154,0.744615,0.738462,0.742308


In [7]:
# Print the method with the smallest value in the "Rank" row of the ECE
# table, then display the table itself.
# idxmin() returns the column label directly, so the answer no longer relies
# on the table's column order matching params["calib_methods"] position by
# position (which indexing that list with argmin() silently assumed).
print(tables["ece"].loc["Rank"].idxmin())
tables["ece"]

RF_ens_Beta


Unnamed: 0_level_0,RF,RF_CT,RF_fulldata,Platt,ISO,Rank,CRF,VA,Beta,Elkan,tlr,Line,RF_boot,RF_ens_r,RF_ens_line,RF_ens_CRF,RF_ens_Platt,RF_ens_ISO,RF_ens_Beta,RF_large,RF_large_line,RF_large_Platt,RF_large_ISO,RF_large_Beta
Data,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
spambase,0.003762,0.004085,0.003943,0.00022,0.000198,0.000248,0.000877,0.000155,0.0003872364,0.070424,0.003863,0.003679,0.003819,0.003731,0.003408,0.000806,0.000243,0.0005267326,0.0002614442,0.003734,0.003103,0.000321,0.0003316598,0.000192945
QSAR,0.002301,0.002601,0.002287,0.001817,0.002086,0.007006,0.003701,0.001619,0.001530862,0.105677,0.003702,0.002922,0.002492,0.002078,0.001835,0.002687,0.001782,0.007415883,0.001236659,0.001945,0.003679,0.002514,0.005529805,0.00133577
bank,0.002313,0.002319,0.002436,0.000254,8e-06,0.008719,0.000936,0.002282,0.0002016994,0.030855,0.00217,0.001742,0.00201,0.000678,0.000814,0.000583,0.000206,1.029419e-05,0.0001432254,0.000926,0.001223,0.000213,9.028214e-06,0.000181078
parkinsons,0.014283,0.020284,0.01388,0.022499,0.053355,0.057447,0.017362,0.008443,0.01763497,0.079957,0.031274,0.015945,0.013307,0.009112,0.016026,0.008785,0.013397,0.005997357,0.008776471,0.010021,0.012955,0.016784,0.01623842,0.009778325
vertebral,0.005527,0.005191,0.005285,0.005735,0.010076,0.051905,0.004695,0.005218,0.008822493,0.111765,0.009567,0.002976,0.005986,0.005816,0.020052,0.003332,0.008474,0.01065692,0.01259785,0.006823,0.007752,0.007026,0.01251121,0.008627474
ionosphere,0.007213,0.007307,0.010459,0.006441,0.004858,0.023797,0.004112,0.004145,0.008901834,0.099986,0.00696,0.005434,0.007174,0.006483,0.005269,0.002613,0.006507,0.004357746,0.003895623,0.005809,0.007886,0.005412,0.009013796,0.005070408
diabetes,0.003913,0.004322,0.004424,0.004582,0.00708,0.007563,0.003956,0.006048,0.002760645,0.145528,0.009045,0.003812,0.004128,0.003458,0.003214,0.002445,0.002798,0.01026426,0.00612805,0.00375,0.002568,0.004393,0.00828238,0.004835496
breast,0.002365,0.002331,0.003746,0.002235,0.001443,0.02369,0.001491,0.002543,0.002525642,0.06799,0.001897,0.002137,0.002634,0.002637,0.003251,0.001366,0.001702,0.001011044,0.001053143,0.001848,0.002988,0.002134,0.004383593,0.001519808
hillvalley,0.009329,0.014736,0.00684,0.004382,0.009997,0.00994,0.009623,0.006881,0.00588796,0.144898,0.024314,0.019658,0.010165,0.009749,0.007566,0.010995,0.00368,0.01275097,0.004164671,0.007872,0.007325,0.006657,0.0139734,0.008719061
madelon,0.019077,0.019419,0.020538,0.001331,0.002768,0.004854,0.004372,0.002392,0.000756541,0.182548,0.018331,0.001753,0.01944,0.027362,0.001137,0.005551,0.00139,0.004018039,0.001181614,0.029385,0.002145,0.001785,0.003136255,0.00157527
