In [1]:
import numpy as np
import pandas as pd

import tqdm

from scipy.stats import wilcoxon

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.metrics import roc_auc_score

from flipping_random_forest import FlippingRandomForestClassifier

from datasets import binclas_datasets

In [2]:
# Collect the dataset loader callables as a plain Python list
# (each one is invoked with no arguments by the experiment loop).
data_loaders = binclas_datasets['data_loader_function'].tolist()

In [3]:
results = []

validator = RepeatedStratifiedKFold(n_splits=5, n_repeats=400, random_state=5)

# Forest hyper-parameters shared by every compared model. They are constant
# across folds and datasets, so the dict is built once instead of per fold.
params = {'n_jobs': 1,
          'min_samples_leaf': 1,
          'max_depth': None}

# Per-fold score series collected for each dataset, in the order in which they
# appear in a results row.
SERIES = ['orig', 'flipped', 'baseline', 'baseline_flipped', 'flipping_full', 'flipping_coord']

# Row layout: dataset name, mean AUC per series, raw per-fold AUCs per series,
# then the two one-sided Wilcoxon p-values.
RESULT_COLUMNS = (['name']
                  + ['auc_' + key for key in SERIES]
                  + ['aucs_' + key for key in SERIES]
                  + ['p_full', 'p_coord'])

# Columns shown in the progress report printed after each dataset.
REPORT_COLUMNS = ['name', 'auc_orig', 'auc_baseline',
                  'auc_flipping_full', 'auc_flipping_coord', 'p_full', 'p_coord']


def _score_fold(X_train, y_train, X_test, y_test):
    """Fit every compared forest on one fold and return its test ROC AUC.

    Parameters
    ----------
    X_train, y_train : training features and labels of the fold.
    X_test, y_test : held-out features and labels of the fold.

    Returns
    -------
    dict
        Maps each name in ``SERIES`` to the score of that model on the fold.
        The ``'flipped'`` and ``'baseline_flipped'`` entries are placeholders
        (0): the evaluations on negated features (``-X``) are disabled, and
        the placeholders only keep the result/CSV column layout stable.

    Raises
    ------
    Whatever the estimators or ``roc_auc_score`` raise. Nothing is caught
    here, so the caller can discard the fold as a whole.
    """
    def test_auc(model):
        proba = model.fit(X_train, y_train).predict_proba(X_test)[:, 1]
        return roc_auc_score(y_test, proba)

    # Dict entries are evaluated top to bottom, so the models are fitted in
    # the order: plain forest, baseline, 'full' flipping, 'coordinate' flipping.
    return {
        'orig': test_auc(RandomForestClassifier(**params)),
        'flipped': 0,
        'baseline': test_auc(FlippingRandomForestClassifier(**params)),
        'baseline_flipped': 0,
        'flipping_full': test_auc(FlippingRandomForestClassifier(flipping='full', **params)),
        'flipping_coord': test_auc(FlippingRandomForestClassifier(flipping='coordinate', **params)),
    }


def _paired_p_value(baseline, candidate):
    """One-sided Wilcoxon signed-rank p-value for paired per-fold scores.

    Uses ``alternative='less'`` on ``baseline - candidate`` with
    ``zero_method='zsplit'``. Returns NaN when there are no paired
    observations (every fold of the dataset was skipped).
    """
    if len(baseline) == 0:
        return np.nan
    return wilcoxon(baseline, candidate, alternative='less', zero_method='zsplit').pvalue


for data_loader in data_loaders:
    dataset = data_loader()
    X = dataset['data']
    y = dataset['target']

    fold_aucs = {key: [] for key in SERIES}
    n_failed = 0
    first_error = None

    # `groups` is not passed: stratified splitting only needs the labels.
    folds = tqdm.tqdm(validator.split(X, y),
                      total=validator.get_n_splits(),
                      desc=str(dataset['name']))
    for train, test in folds:
        X_train = X[train]
        X_test = X[test]
        y_train = y[train]
        y_test = y[test]

        # Best-effort per fold, but:
        #  * `except Exception` (not a bare except) lets KeyboardInterrupt
        #    stop the long-running loop;
        #  * scores are committed only after *all* models succeeded, so the
        #    per-series lists stay the same length and remain properly paired
        #    for the Wilcoxon test.
        try:
            scores = _score_fold(X_train, y_train, X_test, y_test)
        except Exception as exc:
            n_failed += 1
            if first_error is None:
                first_error = exc
            continue

        for key in SERIES:
            fold_aucs[key].append(scores[key])

    if n_failed:
        print('{}: skipped {} fold(s); first error: {!r}'.format(dataset['name'], n_failed, first_error))

    row = ([dataset['name']]
           + [np.mean(fold_aucs[key]) if fold_aucs[key] else np.nan for key in SERIES]
           + [fold_aucs[key] for key in SERIES]
           + [_paired_p_value(fold_aucs['baseline'], fold_aucs['flipping_full']),
              _paired_p_value(fold_aucs['baseline'], fold_aucs['flipping_coord'])])
    results.append(row)

    # Rebuilt after every dataset so partial results can be inspected (and are
    # available in `results_pdf`) while the remaining datasets are running.
    results_pdf = pd.DataFrame(results, columns=RESULT_COLUMNS)
    baseline_pair = results_pdf[['auc_baseline', 'auc_baseline_flipped']]
    results_pdf['auc_baseline_min'] = baseline_pair.min(axis=1)
    results_pdf['auc_baseline_max'] = baseline_pair.max(axis=1)
    print(results_pdf[REPORT_COLUMNS])
        


2000it [12:50,  2.60it/s]


       name  auc_orig  auc_baseline  auc_flipping_full  auc_flipping_coord  \
0  haberman  0.668162      0.668696           0.670559            0.670639   

     p_full  p_coord  
0  0.000145  0.00004  


2000it [11:48,  2.82it/s]


           name  auc_orig  auc_baseline  auc_flipping_full  \
0      haberman  0.668162      0.668696           0.670559   
1  new_thyroid1  0.999173      0.999248           0.999202   

   auc_flipping_coord    p_full   p_coord  
0            0.670639  0.000145  0.000040  
1            0.999148  0.811868  0.903625  


2000it [10:01,  3.33it/s]


               name  auc_orig  auc_baseline  auc_flipping_full  \
0          haberman  0.668162      0.668696           0.670559   
1      new_thyroid1  0.999173      0.999248           0.999202   
2  shuttle-6_vs_2-3  1.000000      1.000000           1.000000   

   auc_flipping_coord    p_full   p_coord  
0            0.670639  0.000145  0.000040  
1            0.999148  0.811868  0.903625  
2            1.000000  0.517619  0.464497  


2000it [12:13,  2.73it/s]


               name  auc_orig  auc_baseline  auc_flipping_full  \
0          haberman  0.668162      0.668696           0.670559   
1      new_thyroid1  0.999173      0.999248           0.999202   
2  shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3              bupa  0.764281      0.764262           0.762579   

   auc_flipping_coord    p_full   p_coord  
0            0.670639  0.000145  0.000040  
1            0.999148  0.811868  0.903625  
2            1.000000  0.517619  0.464497  
3            0.763146  0.999517  0.978406  


2000it [09:57,  3.35it/s]


               name  auc_orig  auc_baseline  auc_flipping_full  \
0          haberman  0.668162      0.668696           0.670559   
1      new_thyroid1  0.999173      0.999248           0.999202   
2  shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3              bupa  0.764281      0.764262           0.762579   
4  cleveland-0_vs_4  0.971887      0.971914           0.973872   

   auc_flipping_coord    p_full   p_coord  
0            0.670639  0.000145  0.000040  
1            0.999148  0.811868  0.903625  
2            1.000000  0.517619  0.464497  
3            0.763146  0.999517  0.978406  
4            0.973434  0.000008  0.000206  


2000it [11:03,  3.01it/s]


               name  auc_orig  auc_baseline  auc_flipping_full  \
0          haberman  0.668162      0.668696           0.670559   
1      new_thyroid1  0.999173      0.999248           0.999202   
2  shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3              bupa  0.764281      0.764262           0.762579   
4  cleveland-0_vs_4  0.971887      0.971914           0.973872   
5            ecoli1  0.954372      0.954315           0.954413   

   auc_flipping_coord    p_full   p_coord  
0            0.670639  0.000145  0.000040  
1            0.999148  0.811868  0.903625  
2            1.000000  0.517619  0.464497  
3            0.763146  0.999517  0.978406  
4            0.973434  0.000008  0.000206  
5            0.954508  0.144733  0.086045  


2000it [09:59,  3.34it/s]


               name  auc_orig  auc_baseline  auc_flipping_full  \
0          haberman  0.668162      0.668696           0.670559   
1      new_thyroid1  0.999173      0.999248           0.999202   
2  shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3              bupa  0.764281      0.764262           0.762579   
4  cleveland-0_vs_4  0.971887      0.971914           0.973872   
5            ecoli1  0.954372      0.954315           0.954413   
6      poker-9_vs_7  0.985171      0.985666           0.986835   

   auc_flipping_coord    p_full   p_coord  
0            0.670639  0.000145  0.000040  
1            0.999148  0.811868  0.903625  
2            1.000000  0.517619  0.464497  
3            0.763146  0.999517  0.978406  
4            0.973434  0.000008  0.000206  
5            0.954508  0.144733  0.086045  
6            0.986424  0.000008  0.005947  


2000it [10:36,  3.14it/s]


               name  auc_orig  auc_baseline  auc_flipping_full  \
0          haberman  0.668162      0.668696           0.670559   
1      new_thyroid1  0.999173      0.999248           0.999202   
2  shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3              bupa  0.764281      0.764262           0.762579   
4  cleveland-0_vs_4  0.971887      0.971914           0.973872   
5            ecoli1  0.954372      0.954315           0.954413   
6      poker-9_vs_7  0.985171      0.985666           0.986835   
7            monk-2  1.000000      1.000000           0.999999   

   auc_flipping_coord    p_full   p_coord  
0            0.670639  0.000145  0.000040  
1            0.999148  0.811868  0.903625  
2            1.000000  0.517619  0.464497  
3            0.763146  0.999517  0.978406  
4            0.973434  0.000008  0.000206  
5            0.954508  0.144733  0.086045  
6            0.986424  0.000008  0.005947  
7            0.999999  0.153986  0.143223  


2000it [10:17,  3.24it/s]


               name  auc_orig  auc_baseline  auc_flipping_full  \
0          haberman  0.668162      0.668696           0.670559   
1      new_thyroid1  0.999173      0.999248           0.999202   
2  shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3              bupa  0.764281      0.764262           0.762579   
4  cleveland-0_vs_4  0.971887      0.971914           0.973872   
5            ecoli1  0.954372      0.954315           0.954413   
6      poker-9_vs_7  0.985171      0.985666           0.986835   
7            monk-2  1.000000      1.000000           0.999999   
8         hepatitis  0.876796      0.875910           0.876334   

   auc_flipping_coord    p_full   p_coord  
0            0.670639  0.000145  0.000040  
1            0.999148  0.811868  0.903625  
2            1.000000  0.517619  0.464497  
3            0.763146  0.999517  0.978406  
4            0.973434  0.000008  0.000206  
5            0.954508  0.144733  0.086045  
6            0.986424  0.000008

2000it [12:12,  2.73it/s]


                   name  auc_orig  auc_baseline  auc_flipping_full  \
0              haberman  0.668162      0.668696           0.670559   
1          new_thyroid1  0.999173      0.999248           0.999202   
2      shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3                  bupa  0.764281      0.764262           0.762579   
4      cleveland-0_vs_4  0.971887      0.971914           0.973872   
5                ecoli1  0.954372      0.954315           0.954413   
6          poker-9_vs_7  0.985171      0.985666           0.986835   
7                monk-2  1.000000      1.000000           0.999999   
8             hepatitis  0.876796      0.875910           0.876334   
9  yeast-0-3-5-9_vs_7-8  0.794598      0.795476           0.794322   

   auc_flipping_coord    p_full   p_coord  
0            0.670639  0.000145  0.000040  
1            0.999148  0.811868  0.903625  
2            1.000000  0.517619  0.464497  
3            0.763146  0.999517  0.978406  
4        

2000it [13:52,  2.40it/s]


                    name  auc_orig  auc_baseline  auc_flipping_full  \
0               haberman  0.668162      0.668696           0.670559   
1           new_thyroid1  0.999173      0.999248           0.999202   
2       shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3                   bupa  0.764281      0.764262           0.762579   
4       cleveland-0_vs_4  0.971887      0.971914           0.973872   
5                 ecoli1  0.954372      0.954315           0.954413   
6           poker-9_vs_7  0.985171      0.985666           0.986835   
7                 monk-2  1.000000      1.000000           0.999999   
8              hepatitis  0.876796      0.875910           0.876334   
9   yeast-0-3-5-9_vs_7-8  0.794598      0.795476           0.794322   
10          mammographic  0.867540      0.867568           0.868364   

    auc_flipping_coord        p_full       p_coord  
0             0.670639  1.451637e-04  4.038552e-05  
1             0.999148  8.118676e-01  9.0

2000it [15:40,  2.13it/s]


                    name  auc_orig  auc_baseline  auc_flipping_full  \
0               haberman  0.668162      0.668696           0.670559   
1           new_thyroid1  0.999173      0.999248           0.999202   
2       shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3                   bupa  0.764281      0.764262           0.762579   
4       cleveland-0_vs_4  0.971887      0.971914           0.973872   
5                 ecoli1  0.954372      0.954315           0.954413   
6           poker-9_vs_7  0.985171      0.985666           0.986835   
7                 monk-2  1.000000      1.000000           0.999999   
8              hepatitis  0.876796      0.875910           0.876334   
9   yeast-0-3-5-9_vs_7-8  0.794598      0.795476           0.794322   
10          mammographic  0.867540      0.867568           0.868364   
11               saheart  0.722119      0.722297           0.721944   

    auc_flipping_coord        p_full       p_coord  
0             0.670639 

2000it [11:06,  3.00it/s]


                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.668162      0.668696           0.670559   
1                   new_thyroid1  0.999173      0.999248           0.999202   
2               shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3                           bupa  0.764281      0.764262           0.762579   
4               cleveland-0_vs_4  0.971887      0.971914           0.973872   
5                         ecoli1  0.954372      0.954315           0.954413   
6                   poker-9_vs_7  0.985171      0.985666           0.986835   
7                         monk-2  1.000000      1.000000           0.999999   
8                      hepatitis  0.876796      0.875910           0.876334   
9           yeast-0-3-5-9_vs_7-8  0.794598      0.795476           0.794322   
10                  mammographic  0.867540      0.867568           0.868364   
11                       saheart  0.722119      0.72

2000it [17:04,  1.95it/s]


                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.668162      0.668696           0.670559   
1                   new_thyroid1  0.999173      0.999248           0.999202   
2               shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3                           bupa  0.764281      0.764262           0.762579   
4               cleveland-0_vs_4  0.971887      0.971914           0.973872   
5                         ecoli1  0.954372      0.954315           0.954413   
6                   poker-9_vs_7  0.985171      0.985666           0.986835   
7                         monk-2  1.000000      1.000000           0.999999   
8                      hepatitis  0.876796      0.875910           0.876334   
9           yeast-0-3-5-9_vs_7-8  0.794598      0.795476           0.794322   
10                  mammographic  0.867540      0.867568           0.868364   
11                       saheart  0.722119      0.72

2000it [11:56,  2.79it/s]


                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.668162      0.668696           0.670559   
1                   new_thyroid1  0.999173      0.999248           0.999202   
2               shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3                           bupa  0.764281      0.764262           0.762579   
4               cleveland-0_vs_4  0.971887      0.971914           0.973872   
5                         ecoli1  0.954372      0.954315           0.954413   
6                   poker-9_vs_7  0.985171      0.985666           0.986835   
7                         monk-2  1.000000      1.000000           0.999999   
8                      hepatitis  0.876796      0.875910           0.876334   
9           yeast-0-3-5-9_vs_7-8  0.794598      0.795476           0.794322   
10                  mammographic  0.867540      0.867568           0.868364   
11                       saheart  0.722119      0.72

2000it [15:55,  2.09it/s]


                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.668162      0.668696           0.670559   
1                   new_thyroid1  0.999173      0.999248           0.999202   
2               shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3                           bupa  0.764281      0.764262           0.762579   
4               cleveland-0_vs_4  0.971887      0.971914           0.973872   
5                         ecoli1  0.954372      0.954315           0.954413   
6                   poker-9_vs_7  0.985171      0.985666           0.986835   
7                         monk-2  1.000000      1.000000           0.999999   
8                      hepatitis  0.876796      0.875910           0.876334   
9           yeast-0-3-5-9_vs_7-8  0.794598      0.795476           0.794322   
10                  mammographic  0.867540      0.867568           0.868364   
11                       saheart  0.722119      0.72

2000it [13:28,  2.47it/s]


                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.668162      0.668696           0.670559   
1                   new_thyroid1  0.999173      0.999248           0.999202   
2               shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3                           bupa  0.764281      0.764262           0.762579   
4               cleveland-0_vs_4  0.971887      0.971914           0.973872   
5                         ecoli1  0.954372      0.954315           0.954413   
6                   poker-9_vs_7  0.985171      0.985666           0.986835   
7                         monk-2  1.000000      1.000000           0.999999   
8                      hepatitis  0.876796      0.875910           0.876334   
9           yeast-0-3-5-9_vs_7-8  0.794598      0.795476           0.794322   
10                  mammographic  0.867540      0.867568           0.868364   
11                       saheart  0.722119      0.72

2000it [14:51,  2.24it/s]


                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.668162      0.668696           0.670559   
1                   new_thyroid1  0.999173      0.999248           0.999202   
2               shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3                           bupa  0.764281      0.764262           0.762579   
4               cleveland-0_vs_4  0.971887      0.971914           0.973872   
5                         ecoli1  0.954372      0.954315           0.954413   
6                   poker-9_vs_7  0.985171      0.985666           0.986835   
7                         monk-2  1.000000      1.000000           0.999999   
8                      hepatitis  0.876796      0.875910           0.876334   
9           yeast-0-3-5-9_vs_7-8  0.794598      0.795476           0.794322   
10                  mammographic  0.867540      0.867568           0.868364   
11                       saheart  0.722119      0.72

2000it [17:04,  1.95it/s]


                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.668162      0.668696           0.670559   
1                   new_thyroid1  0.999173      0.999248           0.999202   
2               shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3                           bupa  0.764281      0.764262           0.762579   
4               cleveland-0_vs_4  0.971887      0.971914           0.973872   
5                         ecoli1  0.954372      0.954315           0.954413   
6                   poker-9_vs_7  0.985171      0.985666           0.986835   
7                         monk-2  1.000000      1.000000           0.999999   
8                      hepatitis  0.876796      0.875910           0.876334   
9           yeast-0-3-5-9_vs_7-8  0.794598      0.795476           0.794322   
10                  mammographic  0.867540      0.867568           0.868364   
11                       saheart  0.722119      0.72

2000it [13:59,  2.38it/s]

                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.668162      0.668696           0.670559   
1                   new_thyroid1  0.999173      0.999248           0.999202   
2               shuttle-6_vs_2-3  1.000000      1.000000           1.000000   
3                           bupa  0.764281      0.764262           0.762579   
4               cleveland-0_vs_4  0.971887      0.971914           0.973872   
5                         ecoli1  0.954372      0.954315           0.954413   
6                   poker-9_vs_7  0.985171      0.985666           0.986835   
7                         monk-2  1.000000      1.000000           0.999999   
8                      hepatitis  0.876796      0.875910           0.876334   
9           yeast-0-3-5-9_vs_7-8  0.794598      0.795476           0.794322   
10                  mammographic  0.867540      0.867568           0.868364   
11                       saheart  0.722119      0.72




In [4]:
# Persist the accumulated results table to a CSV file (relative path, so it
# lands in the current working directory).
results_pdf.to_csv(path_or_buf="classification.csv")