In [3]:
import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import norm
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder, StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
import pandas as pd
# !pip install openml
import openml
# !pip install cardinal
# !pip install scikit-learn==0.20.4
# !pip install sklearn.cluster._k_means_fast
from cardinal.uncertainty import MarginSampler
from cardinal.random import RandomSampler
from cardinal.zhdanov2019 import TwoStepKMeansSampler
from cardinal.plotting import plot_confidence_interval
import tqdm


import os
import tensorflow as tf
from sklearn.model_selection import KFold, StratifiedKFold
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.models import clone_model

In [4]:
np.random.seed(123)

# Selection of the best model

In [13]:
dataset_ids = [1471, 1502, 40922, 41162, 43551, 1461, 1590, 41138, 42395, 42803, 43439]  
#DONE : 1471, 1502, 40922  (arrays ?)     \   41162, 43551, 1461, 1590, 41138, 42395, 42803, 43439 

# Maybe : 42088
#TODO : 42750
#No : 42493, 42256

### Preprocessing

In [67]:
def getData(dataset_id):
    """
    Returns X, y corresponding to a specific OpenML dataset id with some additional preprocessing
    """

    dataset = openml.datasets.get_dataset(dataset_id)

    if dataset_id in [1471, 1502, 40922, 41162]:
        X, y, cat_indicator, names = dataset.get_data(dataset_format='array', target=dataset.default_target_attribute)  #dataframe / array
    else:
        X, y, cat_indicator, names = dataset.get_data(dataset_format='dataframe', target=dataset.default_target_attribute)  #dataframe / array
    
    cat_indicator = np.asarray(cat_indicator)


    #Special preprocessing for debugging
    if dataset_id == 42395:
        X=X.drop(['ID_code'], axis = 1)       #id = 42395
        cat_indicator = cat_indicator[1:] #TODO
    if dataset_id == 42088:
        X=X.drop(['brewery_name', 'review_profilename'], axis = 1)        #id = 42088
        cat_indicator = cat_indicator[2:] #TODO
    if dataset_id == 42256:
        X=X.drop(['full_name'], axis = 1)       #id = 42256
        cat_indicator = cat_indicator[1:] #TODO
    if dataset_id == 42803:
        X=X.drop(['Accident_Index', 'Date','Time', 'Local_Authority_(Highway)', 'LSOA_of_Accident_Location'], axis = 1)       #id = 42803
        cat_indicator = cat_indicator[5:] #TODO
    if dataset_id == 43439:
        X=X.drop(['Gender', 'ScheduledDay', 'AppointmentDay','Neighbourhood'], axis = 1)     
        cat_indicator = cat_indicator[4:] #TODO
    if dataset_id == 42088: 
        X=X.drop(['brewery_name', 'review_profilename', 'beer_name'], axis = 1)      #id = 42088
        cat_indicator = cat_indicator[3:] #TODO


    ct_cat = ColumnTransformer([
        ('normalizer', StandardScaler(), np.where(~cat_indicator)[0])
    ], remainder='passthrough')

    X_cat = pd.DataFrame(ct_cat.fit_transform(X)).convert_dtypes()

    ct = ColumnTransformer([
        ('encoder', OneHotEncoder(), np.where(cat_indicator)[0]),
        ('normalizer', StandardScaler(), np.where(~cat_indicator)[0])
    ], remainder='passthrough')

    X = ct.fit_transform(X)

    #TODO
    if dataset_id in [41162]:
        X = np.asarray(np.nan_to_num(X.todense()))    #X.todense()
    else:
        X = np.asarray(np.nan_to_num(X))

    #Shuffle
    # if dataset_id in [1471, 1502, 40922]:
    #     idx = np.arange(len(X)) 
    # else:
    print(X.shape)
    idx = np.arange(X.shape[0]) 
    np.random.shuffle(idx)
    X = X[idx]
    y = y[idx]
    
    #Reduce execution time
    X = X[:int(0.1 * X.shape[0])]
    y = y[:int(0.1 * y.shape[0])]
    
    return X, y

# X.shape, y.shape, batch_size

In [68]:
# for id in dataset_ids:
#     getData(id)

### Run

In [69]:
models = [
    ('GBC', GradientBoostingClassifier()),
    # ('Margin', MarginSampler(model, batch_size)),
    # ('Random', RandomSampler(batch_size)),
]

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
num_folds = 10

# Define the K-fold Cross Validator
kfold = KFold(n_splits=num_folds, shuffle=True)
skf = StratifiedKFold(n_splits=num_folds)


for dataset_id in tqdm.tqdm(dataset_ids, desc = 'DATASETS'):

    # print(dataset_id)

    X, y = getData(dataset_id)

    for model_name, base_model in tqdm.tqdm(models, desc = f'models dataset id={dataset_id}'):

        #Check if model already studied in this dataset
        filePath = f'./results/{dataset_id}-{model_name}.csv'
        if os.path.isfile(filePath) == False :

            all_accuracies = []

            #Train/test split : le "test set" sera utilisé plus tard dans le benchmark pour entrainer initialement le modèle (donc pas utilisé ici)
            X_train, X_test, y_train, y_test = \
                    train_test_split(X, y, test_size=int(.2 * X.shape[0]))
        
            for train, validation in kfold.split(X_train, y_train): # or kfold.split(X_train, y_train):

                print(X_train[train].shape, X_train[validation].shape)

                #Training with EarlyStopping based on the crossvalidation validation set
                model = base_model    #clone_model(base_model)  #TODO : verifier qu'il s'agit bien d'une nouvelle instance
                if model_name in []:
                    model.fit(X_train[train], y_train[train], callbacks=[callback])
                else:
                    # No need for Early Stopping callback for GBC : https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_early_stopping.html
                    model.fit(X_train[train], y_train[train])

                # Record metrics
                all_accuracies.append(model.score(X_train[validation], y_train[validation]))

            #Save model results
            results = np.array([[dataset_id, model_name, acc] for acc in all_accuracies])
            df = pd.DataFrame(results, columns = ['datasetId', 'modelName', 'accuracy'])
            save_path = f'./results/{dataset_id}-{model_name}.csv'
            df.to_csv(save_path, index=False)


DATASETS:   0%|          | 0/11 [00:00<?, ?it/s]

(14980, 14)


models dataset id=1471: 100%|██████████| 1/1 [00:00<00:00, 8128.50it/s]
DATASETS:   9%|▉         | 1/11 [00:00<00:01,  6.43it/s]

(245057, 3)


models dataset id=1502: 100%|██████████| 1/1 [00:00<00:00, 4629.47it/s]
DATASETS:  18%|█▊        | 2/11 [00:00<00:01,  7.62it/s]

(88588, 6)


models dataset id=40922: 100%|██████████| 1/1 [00:00<00:00, 4447.83it/s]


(72983, 2437)




(5255, 2437) (584, 2437)
(5255, 2437) (584, 2437)
(5255, 2437) (584, 2437)
(5255, 2437) (584, 2437)
(5255, 2437) (584, 2437)
(5255, 2437) (584, 2437)
(5255, 2437) (584, 2437)
(5255, 2437) (584, 2437)
(5255, 2437) (584, 2437)
(5256, 2437) (583, 2437)


models dataset id=41162: 100%|██████████| 1/1 [05:08<00:00, 308.61s/it]
DATASETS:  36%|███▋      | 4/11 [05:13<11:27, 98.22s/it]

(34452, 9)


models dataset id=43551:   0%|          | 0/1 [00:00<?, ?it/s]
DATASETS:  36%|███▋      | 4/11 [05:13<09:09, 78.46s/it]

(2480, 9) (276, 9)





KeyError: '[0, 3, 4, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 45, 46, 47, 48, 51, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 69, 70, 73, 74, 75, 76, 77, 79, 80, 81, 82, 83, 84, 86, 87, 88, 91, 94, 95, 96, 97, 98, 99, 100, 101, 102, 104, 105, 106, 107, 108, 109, 111, 112, 113, 114, 115, 116, 117, 120, 121, 122, 123, 125, 127, 128, 131, 132, 133, 134, 135, 136, 137, 139, 140, 141, 142, 144, 145, 146, 147, 148, 149, 150, 152, 153, 155, 156, 157, 160, 161, 163, 164, 165, 166, 167, 169, 170, 171, 172, 174, 175, 177, 179, 180, 181, 182, 183, 184, 185, 186, 187, 189, 190, 191, 193, 194, 195, 197, 198, 199, 200, 201, 203, 204, 205, 206, 207, 208, 209, 211, 212, 213, 214, 215, 216, 219, 220, 222, 223, 224, 225, 226, 228, 229, 230, 231, 232, 233, 235, 236, 237, 239, 240, 241, 242, 244, 245, 246, 248, 251, 252, 253, 257, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 271, 272, 273, 274, 275, 276, 277, 278, 279, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 304, 305, 306, 307, 309, 310, 311, 312, 313, 314, 317, 318, 319, 320, 321, 322, 323, 325, 326, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 343, 344, 345, 346, 347, 348, 349, 350, 351, 353, 354, 355, 356, 357, 358, 359, 360, 361, 365, 366, 367, 368, 369, 370, 371, 375, 376, 377, 378, 379, 380, 382, 383, 384, 385, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 404, 405, 406, 408, 409, 410, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 425, 427, 430, 431, 432, 433, 434, 435, 436, 437, 440, 441, 443, 444, 445, 446, 447, 448, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 461, 462, 463, 464, 465, 466, 467, 468, 470, 472, 473, 475, 476, 477, 478, 479, 480, 481, 482, 483, 485, 486, 487, 488, 490, 491, 492, 493, 495, 496, 497, 498, 499, 501, 502, 503, 504, 506, 507, 508, 509, 510, 511, 512, 514, 516, 517, 519, 520, 521, 522, 523, 525, 526, 527, 528, 529, 530, 532, 534, 535, 536, 537, 538, 539, 540, 543, 544, 546, 547, 549, 550, 551, 553, 555, 558, 560, 562, 564, 565, 566, 567, 568, 569, 570, 572, 573, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 602, 603, 604, 605, 606, 607, 609, 610, 611, 612, 613, 614, 616, 617, 618, 620, 621, 622, 623, 625, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 641, 642, 644, 645, 646, 647, 648, 649, 651, 654, 655, 656, 658, 660, 661, 664, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 677, 678, 679, 680, 681, 683, 684, 686, 687, 689, 691, 692, 694, 695, 696, 697, 700, 701, 702, 704, 705, 706, 707, 708, 710, 711, 712, 713, 714, 715, 716, 718, 719, 721, 722, 723, 724, 725, 726, 727, 728, 730, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 756, 758, 759, 761, 762, 763, 764, 765, 766, 769, 770, 771, 772, 774, 778, 779, 780, 782, 785, 786, 787, 788, 789, 791, 793, 795, 796, 797, 798, 799, 800, 801, 802, 803, 804, 805, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 819, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 832, 833, 834, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 859, 860, 861, 862, 864, 865, 866, 867, 869, 870, 871, 872, 874, 875, 876, 877, 879, 880, 881, 882, 883, 884, 885, 886, 888, 889, 890, 891, 892, 893, 894, 895, 897, 898, 900, 901, 902, 903, 904, 905, 906, 907, 908, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 924, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 945, 946, 947, 948, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 967, 968, 970, 971, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 985, 986, 988, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1003, 1004, 1006, 1007, 1008, 1011, 1012, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1033, 1034, 1036, 1037, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049, 1050, 1051, 1052, 1053, 1055, 1056, 1057, 1059, 1060, 1061, 1062, 1063, 1064, 1065, 1067, 1068, 1069, 1070, 1071, 1072, 1073, 1074, 1075, 1076, 1078, 1079, 1080, 1081, 1082, 1083, 1085, 1086, 1087, 1091, 1092, 1093, 1094, 1095, 1096, 1098, 1099, 1101, 1102, 1103, 1105, 1106, 1107, 1108, 1109, 1110, 1112, 1113, 1114, 1115, 1116, 1117, 1119, 1121, 1123, 1124, 1125, 1126, 1127, 1128, 1129, 1130, 1131, 1132, 1133, 1134, 1136, 1137, 1138, 1139, 1140, 1141, 1142, 1143, 1144, 1145, 1147, 1148, 1149, 1150, 1151, 1152, 1153, 1154, 1155, 1156, 1157, 1158, 1159, 1160, 1161, 1162, 1163, 1164, 1165, 1166, 1167, 1168, 1169, 1170, 1171, 1172, 1173, 1174, 1175, 1176, 1177, 1179, 1180, 1181, 1182, 1184, 1185, 1186, 1187, 1188, 1190, 1191, 1193, 1194, 1195, 1197, 1198, 1200, 1201, 1202, 1203, 1204, 1205, 1206, 1208, 1209, 1210, 1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1221, 1222, 1223, 1224, 1225, 1226, 1227, 1228, 1229, 1230, 1231, 1232, 1233, 1234, 1235, 1237, 1238, 1239, 1240, 1242, 1243, 1244, 1245, 1246, 1247, 1248, 1249, 1250, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259, 1260, 1261, 1262, 1263, 1264, 1266, 1268, 1269, 1270, 1271, 1272, 1273, 1274, 1275, 1276, 1277, 1278, 1279, 1280, 1281, 1282, 1283, 1284, 1285, 1286, 1287, 1289, 1290, 1291, 1292, 1296, 1297, 1298, 1300, 1301, 1302, 1303, 1304, 1305, 1306, 1307, 1308, 1309, 1310, 1311, 1312, 1313, 1314, 1315, 1316, 1317, 1318, 1321, 1322, 1323, 1324, 1325, 1327, 1328, 1330, 1331, 1332, 1333, 1334, 1335, 1336, 1337, 1339, 1341, 1342, 1344, 1345, 1346, 1347, 1348, 1349, 1350, 1351, 1352, 1353, 1354, 1355, 1356, 1358, 1359, 1360, 1361, 1362, 1363, 1364, 1365, 1366, 1367, 1368, 1369, 1370, 1371, 1372, 1373, 1374, 1375, 1376, 1378, 1379, 1380, 1381, 1382, 1383, 1384, 1385, 1386, 1387, 1388, 1389, 1390, 1391, 1392, 1393, 1394, 1396, 1397, 1398, 1399, 1401, 1402, 1403, 1404, 1405, 1406, 1407, 1408, 1409, 1410, 1411, 1413, 1414, 1415, 1417, 1418, 1419, 1420, 1421, 1422, 1423, 1424, 1426, 1427, 1429, 1430, 1432, 1433, 1434, 1436, 1437, 1438, 1439, 1440, 1442, 1443, 1444, 1445, 1447, 1448, 1449, 1450, 1451, 1452, 1453, 1454, 1455, 1456, 1457, 1458, 1460, 1461, 1462, 1463, 1464, 1465, 1466, 1467, 1468, 1469, 1470, 1471, 1472, 1473, 1475, 1476, 1477, 1478, 1479, 1480, 1481, 1483, 1484, 1485, 1486, 1487, 1488, 1489, 1490, 1491, 1493, 1495, 1496, 1497, 1498, 1499, 1500, 1503, 1504, 1505, 1507, 1508, 1509, 1510, 1511, 1512, 1513, 1514, 1516, 1517, 1518, 1519, 1520, 1521, 1524, 1525, 1527, 1528, 1529, 1530, 1531, 1532, 1533, 1534, 1535, 1536, 1537, 1538, 1540, 1541, 1542, 1543, 1544, 1545, 1546, 1547, 1548, 1549, 1550, 1552, 1553, 1554, 1558, 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1571, 1572, 1573, 1574, 1575, 1577, 1579, 1580, 1582, 1583, 1584, 1586, 1587, 1589, 1590, 1593, 1594, 1595, 1596, 1597, 1599, 1600, 1602, 1603, 1604, 1606, 1607, 1608, 1609, 1610, 1611, 1612, 1613, 1614, 1617, 1618, 1619, 1620, 1621, 1622, 1623, 1624, 1625, 1626, 1627, 1628, 1630, 1631, 1632, 1633, 1636, 1637, 1638, 1639, 1640, 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1651, 1652, 1653, 1654, 1655, 1656, 1658, 1660, 1661, 1662, 1663, 1665, 1666, 1667, 1668, 1669, 1670, 1671, 1672, 1673, 1674, 1676, 1677, 1678, 1679, 1680, 1681, 1682, 1683, 1684, 1685, 1686, 1688, 1689, 1690, 1691, 1692, 1693, 1694, 1695, 1696, 1697, 1698, 1699, 1700, 1701, 1702, 1703, 1704, 1705, 1706, 1707, 1709, 1711, 1714, 1716, 1718, 1719, 1720, 1721, 1722, 1723, 1726, 1727, 1728, 1729, 1730, 1731, 1732, 1734, 1735, 1736, 1737, 1739, 1741, 1744, 1745, 1746, 1747, 1748, 1749, 1750, 1752, 1753, 1754, 1755, 1756, 1757, 1758, 1759, 1760, 1762, 1763, 1764, 1765, 1766, 1767, 1768, 1769, 1770, 1772, 1773, 1774, 1775, 1776, 1777, 1778, 1779, 1780, 1781, 1782, 1783, 1784, 1785, 1786, 1787, 1788, 1790, 1792, 1793, 1794, 1796, 1798, 1799, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1809, 1810, 1811, 1812, 1813, 1815, 1816, 1817, 1818, 1820, 1821, 1823, 1824, 1826, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1834, 1835, 1837, 1838, 1840, 1841, 1842, 1843, 1844, 1845, 1846, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1858, 1859, 1860, 1862, 1863, 1864, 1865, 1866, 1868, 1869, 1870, 1871, 1872, 1873, 1874, 1875, 1876, 1877, 1878, 1879, 1880, 1881, 1882, 1883, 1884, 1885, 1886, 1888, 1889, 1890, 1891, 1892, 1893, 1894, 1895, 1896, 1897, 1898, 1900, 1901, 1902, 1904, 1905, 1906, 1907, 1908, 1909, 1910, 1912, 1913, 1914, 1915, 1917, 1918, 1919, 1923, 1924, 1925, 1926, 1927, 1928, 1929, 1930, 1931, 1932, 1933, 1934, 1935, 1936, 1937, 1938, 1939, 1940, 1941, 1942, 1943, 1944, 1945, 1946, 1947, 1951, 1952, 1954, 1956, 1958, 1959, 1960, 1961, 1962, 1963, 1964, 1965, 1966, 1968, 1969, 1970, 1971, 1972, 1973, 1974, 1975, 1976, 1977, 1981, 1984, 1985, 1986, 1987, 1989, 1990, 1991, 1992, 1993, 1995, 1996, 1998, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2013, 2014, 2015, 2016, 2018, 2019, 2020, 2021, 2022, 2023, 2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2032, 2033, 2034, 2036, 2038, 2039, 2040, 2041, 2042, 2043, 2044, 2045, 2046, 2047, 2048, 2049, 2050, 2051, 2052, 2054, 2055, 2056, 2057, 2058, 2060, 2061, 2062, 2063, 2064, 2065, 2066, 2067, 2068, 2069, 2070, 2072, 2073, 2077, 2080, 2081, 2082, 2083, 2085, 2086, 2087, 2088, 2090, 2091, 2092, 2093, 2094, 2095, 2096, 2097, 2098, 2100, 2101, 2102, 2103, 2104, 2105, 2106, 2107, 2108, 2110, 2112, 2113, 2114, 2115, 2116, 2117, 2118, 2119, 2120, 2121, 2122, 2123, 2124, 2125, 2126, 2127, 2128, 2130, 2131, 2133, 2134, 2135, 2136, 2137, 2138, 2139, 2140, 2141, 2142, 2143, 2144, 2145, 2146, 2147, 2148, 2149, 2150, 2151, 2152, 2153, 2154, 2155, 2156, 2157, 2159, 2160, 2161, 2162, 2163, 2164, 2165, 2166, 2167, 2168, 2169, 2170, 2171, 2172, 2173, 2174, 2175, 2176, 2177, 2178, 2179, 2180, 2181, 2182, 2183, 2184, 2185, 2186, 2187, 2188, 2189, 2190, 2191, 2193, 2194, 2195, 2197, 2198, 2199, 2201, 2203, 2204, 2205, 2207, 2208, 2212, 2213, 2214, 2215, 2216, 2217, 2219, 2220, 2221, 2222, 2224, 2225, 2226, 2227, 2228, 2229, 2230, 2231, 2233, 2234, 2235, 2236, 2237, 2238, 2240, 2241, 2242, 2243, 2245, 2247, 2248, 2249, 2250, 2251, 2252, 2253, 2255, 2256, 2257, 2258, 2261, 2262, 2263, 2265, 2266, 2267, 2268, 2269, 2270, 2271, 2272, 2275, 2276, 2277, 2279, 2280, 2281, 2282, 2283, 2284, 2285, 2286, 2287, 2288, 2289, 2290, 2291, 2292, 2295, 2296, 2297, 2298, 2300, 2301, 2302, 2303, 2304, 2306, 2307, 2309, 2311, 2312, 2313, 2314, 2315, 2316, 2317, 2318, 2319, 2321, 2322, 2323, 2324, 2327, 2328, 2329, 2330, 2331, 2332, 2334, 2336, 2337, 2338, 2339, 2340, 2341, 2342, 2343, 2346, 2347, 2348, 2349, 2350, 2351, 2352, 2354, 2355, 2356, 2357, 2358, 2359, 2360, 2361, 2362, 2363, 2364, 2365, 2366, 2367, 2369, 2370, 2371, 2372, 2375, 2376, 2377, 2378, 2380, 2381, 2382, 2383, 2384, 2385, 2386, 2387, 2388, 2389, 2390, 2391, 2392, 2393, 2396, 2397, 2398, 2399, 2400, 2401, 2402, 2403, 2404, 2405, 2406, 2407, 2408, 2409, 2410, 2412, 2413, 2414, 2415, 2416, 2418, 2419, 2420, 2422, 2423, 2424, 2425, 2427, 2428, 2429, 2431, 2433, 2435, 2436, 2437, 2438, 2440, 2441, 2442, 2443, 2444, 2445, 2446, 2448, 2449, 2450, 2451, 2452, 2454, 2455, 2456, 2458, 2459, 2460, 2461, 2464, 2465, 2467, 2468, 2469, 2470, 2471, 2474, 2475, 2476, 2478, 2479, 2480, 2481, 2483, 2484, 2485, 2486, 2487, 2489, 2490, 2491, 2492, 2493, 2494, 2495, 2496, 2497, 2498, 2500, 2501, 2503, 2504, 2505, 2506, 2507, 2509, 2510, 2511, 2512, 2514, 2515, 2517, 2518, 2519, 2521, 2522, 2523, 2524, 2525, 2526, 2527, 2528, 2529, 2532, 2533, 2536, 2537, 2538, 2540, 2541, 2542, 2543, 2545, 2546, 2547, 2549, 2550, 2552, 2553, 2554, 2555, 2556, 2557, 2558, 2559, 2562, 2563, 2565, 2566, 2567, 2568, 2569, 2570, 2571, 2572, 2573, 2574, 2575, 2576, 2578, 2579, 2580, 2581, 2582, 2583, 2584, 2585, 2586, 2587, 2588, 2589, 2590, 2591, 2592, 2593, 2594, 2595, 2596, 2597, 2598, 2599, 2600, 2601, 2602, 2603, 2604, 2605, 2606, 2607, 2609, 2610, 2611, 2612, 2613, 2614, 2616, 2617, 2618, 2619, 2620, 2621, 2624, 2625, 2626, 2627, 2630, 2631, 2632, 2633, 2634, 2635, 2637, 2638, 2639, 2640, 2641, 2642, 2643, 2644, 2645, 2646, 2647, 2648, 2649, 2650, 2652, 2653, 2654, 2655, 2656, 2657, 2658, 2660, 2661, 2662, 2663, 2664, 2666, 2668, 2669, 2671, 2672, 2673, 2674, 2676, 2677, 2678, 2679, 2680, 2681, 2682, 2683, 2684, 2685, 2687, 2688, 2689, 2690, 2691, 2692, 2693, 2695, 2696, 2697, 2698, 2699, 2700, 2701, 2702, 2703, 2704, 2705, 2706, 2707, 2709, 2710, 2711, 2712, 2713, 2714, 2715, 2717, 2718, 2719, 2720, 2721, 2722, 2723, 2724, 2726, 2727, 2728, 2729, 2730, 2731, 2732, 2733, 2734, 2735, 2736, 2740, 2741, 2743, 2744, 2745, 2746, 2747, 2748, 2749, 2751, 2752, 2753, 2754, 2755] not in index'

### Summarizing results

In [None]:
all_results = pd.DataFrame(columns = ['datasetId', 'modelName', 'accuracy'])
avg_results = pd.DataFrame(columns = ['datasetId', 'modelName', 'accuracy'])    #Averaged accuracy of K-fold model results
pd.DataFrame(results, columns = ['datasetId', 'modelName', 'accuracy'])
for dataset_id in dataset_ids:
    for model_name, base_model in models:

        #Load model results
        load_path = f'./results/{dataset_id}-{model_name}.csv'
        df = pd.read_csv(load_path)

        #Join results go global dataframes
        all_results.append(df)

        mean_acc = df['accuracy'].mean()
        avg_results.append(\
            pd.DataFrame(\
                np.array([dataset_id, model_name, mean_acc]),
                columns = ['datasetId', 'modelName', 'accuracy']
                )
            )

#Saving grouped results 
save_path = f'./results/ALL-RESULTS.csv'
all_results.to_csv(save_path, index=False)

save_path = f'./results/RESULTS.csv'
avg_results.to_csv(save_path, index=False)        