In [1]:
import numpy as np
import pandas as pd

import tqdm

from scipy.stats import wilcoxon

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.metrics import roc_auc_score

from flipping_random_forest import FlippingDecisionTreeClassifier

from datasets import binclas_datasets

In [2]:
data_loaders = binclas_datasets['data_loader_function'].values.tolist()

In [3]:
results = []

validator = RepeatedStratifiedKFold(n_splits=5, n_repeats=2000, random_state=5)

for data_loader in data_loaders:
    dataset = data_loader()
    X = dataset['data']
    y = dataset['target']
    
    aucs_orig = []
    aucs_flipped = []
    aucs_baseline = []
    aucs_baseline_flipped = []
    aucs_flipping_full = []
    aucs_flipping_coord = []
    
    for train, test in tqdm.tqdm(validator.split(X, y, y)):
        X_train = X[train]
        X_test = X[test]
        y_train = y[train]
        y_test = y[test]
        
        min_samples_leaf = 1#np.random.randint(1, 21)
        max_depth = None#np.random.choice([None, 1, 2, 3, 4, 5, 6, 7, 8, 9])
        
        params = {'min_samples_leaf': min_samples_leaf,
                    'max_depth': max_depth}
        
        
        pred = DecisionTreeClassifier(**params).fit(X_train, y_train).predict_proba(X_test)[:, 1]
        aucs_orig.append(roc_auc_score(y_test, pred))
        
        #pred = RandomForestClassifier(n_jobs=4).fit(-X_train, y_train).predict_proba(-X_test)[:, 1]
        #aucs_flipped.append(roc_auc_score(y_test, pred))
        aucs_flipped.append(0)
        
        pred = FlippingDecisionTreeClassifier(**params).fit(X_train, y_train).predict_proba(X_test)[:, 1]
        aucs_baseline.append(roc_auc_score(y_test, pred))
        
        #pred = FlippingRandomForestClassifier(n_jobs=4).fit(-X_train, y_train).predict_proba(-X_test)[:, 1]
        #aucs_baseline_flipped.append(roc_auc_score(y_test, pred))
        aucs_baseline_flipped.append(0)
        
        #pred = FlippingDecisionTreeClassifier(**params).fit(X_train, y_train).predict_proba(X_test)[:, 1]
        #aucs_flipping_full.append(roc_auc_score(y_test, pred))
        aucs_flipping_full.append(0)
        
        aucs_flipping_coord.append(0)
        
        #pred = FlippingDecisionTreeClassifier(flipping='coordinate', **params).fit(X_train, y_train).predict_proba(X_test)[:, 1]
        #aucs_flipping_coord.append(roc_auc_score(y_test, pred))
        
    
    print('aaa', len(aucs_baseline), len(aucs_flipping_full), len(aucs_flipping_coord))
    tmp = [dataset['name'], np.mean(aucs_orig), np.mean(aucs_flipped), np.mean(aucs_baseline), 
                            np.mean(aucs_baseline_flipped), np.mean(aucs_flipping_full), np.mean(aucs_flipping_coord), 
                            aucs_orig, aucs_flipped, aucs_baseline, aucs_baseline_flipped, aucs_flipping_full, aucs_flipping_coord]
    
    tmp = tmp + [wilcoxon(aucs_orig, aucs_baseline, alternative='less', zero_method='zsplit').pvalue,
                #wilcoxon(aucs_baseline, aucs_flipping_coord, alternative='less', zero_method='zsplit').pvalue
                ]
    
    results.append(tmp)
    
    results_pdf = pd.DataFrame(results, columns=['name', 'auc_orig', 'auc_flipped', 'auc_baseline', 'auc_baseline_flipped', 
                                'auc_flipping_full', 'auc_flipping_coord', 'aucs_orig', 'aucs_flipped', 'aucs_baseline', 
                                'aucs_baseline_flipped', 'aucs_flipping_full', 'aucs_flipping_coord', 'p_full'])
    results_pdf['auc_baseline_min'] = results_pdf[['auc_baseline', 'auc_baseline_flipped']].apply(lambda x: min(x), axis=1)
    results_pdf['auc_baseline_max'] = results_pdf[['auc_baseline', 'auc_baseline_flipped']].apply(lambda x: max(x), axis=1)
    print(results_pdf[['name', 
                        'auc_orig', 
                        #'auc_flipped', 
                        'auc_baseline', 
                        #'auc_baseline_flipped', 
                        #'auc_baseline_min', 
                        #'auc_baseline_max', 
                        'auc_flipping_full', 
                        'auc_flipping_coord', 
                        'p_full']])
        


10000it [01:26, 115.00it/s]


aaa 10000 10000 10000
       name  auc_orig  auc_baseline  auc_flipping_full  auc_flipping_coord  \
0  haberman  0.557615      0.561495                0.0                 0.0   

         p_full  
0  1.687998e-26  


10000it [00:42, 233.05it/s]


aaa 10000 10000 10000
           name  auc_orig  auc_baseline  auc_flipping_full  \
0      haberman  0.557615      0.561495                0.0   
1  new_thyroid1  0.930327      0.935953                0.0   

   auc_flipping_coord        p_full  
0                 0.0  1.687998e-26  
1                 0.0  5.336328e-34  


10000it [00:44, 225.50it/s]


aaa 10000 10000 10000
               name  auc_orig  auc_baseline  auc_flipping_full  \
0          haberman  0.557615      0.561495                0.0   
1      new_thyroid1  0.930327      0.935953                0.0   
2  shuttle-6_vs_2-3  1.000000      1.000000                0.0   

   auc_flipping_coord        p_full  
0                 0.0  1.687998e-26  
1                 0.0  5.336328e-34  
2                 0.0  5.000000e-01  


10000it [00:59, 166.86it/s]


aaa 10000 10000 10000
               name  auc_orig  auc_baseline  auc_flipping_full  \
0          haberman  0.557615      0.561495                0.0   
1      new_thyroid1  0.930327      0.935953                0.0   
2  shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3              bupa  0.624876      0.627844                0.0   

   auc_flipping_coord        p_full  
0                 0.0  1.687998e-26  
1                 0.0  5.336328e-34  
2                 0.0  5.000000e-01  
3                 0.0  5.535098e-17  


10000it [00:41, 241.68it/s]


aaa 10000 10000 10000
               name  auc_orig  auc_baseline  auc_flipping_full  \
0          haberman  0.557615      0.561495                0.0   
1      new_thyroid1  0.930327      0.935953                0.0   
2  shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3              bupa  0.624876      0.627844                0.0   
4  cleveland-0_vs_4  0.670715      0.674463                0.0   

   auc_flipping_coord        p_full  
0                 0.0  1.687998e-26  
1                 0.0  5.336328e-34  
2                 0.0  5.000000e-01  
3                 0.0  5.535098e-17  
4                 0.0  2.213045e-02  


10000it [00:47, 211.45it/s]


aaa 10000 10000 10000
               name  auc_orig  auc_baseline  auc_flipping_full  \
0          haberman  0.557615      0.561495                0.0   
1      new_thyroid1  0.930327      0.935953                0.0   
2  shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3              bupa  0.624876      0.627844                0.0   
4  cleveland-0_vs_4  0.670715      0.674463                0.0   
5            ecoli1  0.831813      0.833163                0.0   

   auc_flipping_coord        p_full  
0                 0.0  1.687998e-26  
1                 0.0  5.336328e-34  
2                 0.0  5.000000e-01  
3                 0.0  5.535098e-17  
4                 0.0  2.213045e-02  
5                 0.0  6.299310e-03  


10000it [00:39, 254.38it/s]


aaa 10000 10000 10000
               name  auc_orig  auc_baseline  auc_flipping_full  \
0          haberman  0.557615      0.561495                0.0   
1      new_thyroid1  0.930327      0.935953                0.0   
2  shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3              bupa  0.624876      0.627844                0.0   
4  cleveland-0_vs_4  0.670715      0.674463                0.0   
5            ecoli1  0.831813      0.833163                0.0   
6      poker-9_vs_7  0.607766      0.607903                0.0   

   auc_flipping_coord        p_full  
0                 0.0  1.687998e-26  
1                 0.0  5.336328e-34  
2                 0.0  5.000000e-01  
3                 0.0  5.535098e-17  
4                 0.0  2.213045e-02  
5                 0.0  6.299310e-03  
6                 0.0  6.830932e-08  


10000it [00:40, 249.80it/s]


aaa 10000 10000 10000
               name  auc_orig  auc_baseline  auc_flipping_full  \
0          haberman  0.557615      0.561495                0.0   
1      new_thyroid1  0.930327      0.935953                0.0   
2  shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3              bupa  0.624876      0.627844                0.0   
4  cleveland-0_vs_4  0.670715      0.674463                0.0   
5            ecoli1  0.831813      0.833163                0.0   
6      poker-9_vs_7  0.607766      0.607903                0.0   
7            monk-2  1.000000      1.000000                0.0   

   auc_flipping_coord        p_full  
0                 0.0  1.687998e-26  
1                 0.0  5.336328e-34  
2                 0.0  5.000000e-01  
3                 0.0  5.535098e-17  
4                 0.0  2.213045e-02  
5                 0.0  6.299310e-03  
6                 0.0  6.830932e-08  
7                 0.0  5.000000e-01  


10000it [00:37, 269.70it/s]


aaa 10000 10000 10000
               name  auc_orig  auc_baseline  auc_flipping_full  \
0          haberman  0.557615      0.561495                0.0   
1      new_thyroid1  0.930327      0.935953                0.0   
2  shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3              bupa  0.624876      0.627844                0.0   
4  cleveland-0_vs_4  0.670715      0.674463                0.0   
5            ecoli1  0.831813      0.833163                0.0   
6      poker-9_vs_7  0.607766      0.607903                0.0   
7            monk-2  1.000000      1.000000                0.0   
8         hepatitis  0.667475      0.668629                0.0   

   auc_flipping_coord        p_full  
0                 0.0  1.687998e-26  
1                 0.0  5.336328e-34  
2                 0.0  5.000000e-01  
3                 0.0  5.535098e-17  
4                 0.0  2.213045e-02  
5                 0.0  6.299310e-03  
6                 0.0  6.830932e-08  
7            

10000it [00:50, 197.26it/s]


aaa 10000 10000 10000
                   name  auc_orig  auc_baseline  auc_flipping_full  \
0              haberman  0.557615      0.561495                0.0   
1          new_thyroid1  0.930327      0.935953                0.0   
2      shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3                  bupa  0.624876      0.627844                0.0   
4      cleveland-0_vs_4  0.670715      0.674463                0.0   
5                ecoli1  0.831813      0.833163                0.0   
6          poker-9_vs_7  0.607766      0.607903                0.0   
7                monk-2  1.000000      1.000000                0.0   
8             hepatitis  0.667475      0.668629                0.0   
9  yeast-0-3-5-9_vs_7-8  0.640446      0.641646                0.0   

   auc_flipping_coord        p_full  
0                 0.0  1.687998e-26  
1                 0.0  5.336328e-34  
2                 0.0  5.000000e-01  
3                 0.0  5.535098e-17  
4                

10000it [01:02, 160.21it/s]


aaa 10000 10000 10000
                    name  auc_orig  auc_baseline  auc_flipping_full  \
0               haberman  0.557615      0.561495                0.0   
1           new_thyroid1  0.930327      0.935953                0.0   
2       shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3                   bupa  0.624876      0.627844                0.0   
4       cleveland-0_vs_4  0.670715      0.674463                0.0   
5                 ecoli1  0.831813      0.833163                0.0   
6           poker-9_vs_7  0.607766      0.607903                0.0   
7                 monk-2  1.000000      1.000000                0.0   
8              hepatitis  0.667475      0.668629                0.0   
9   yeast-0-3-5-9_vs_7-8  0.640446      0.641646                0.0   
10          mammographic  0.787445      0.791948                0.0   

    auc_flipping_coord        p_full  
0                  0.0  1.687998e-26  
1                  0.0  5.336328e-34  
2       

10000it [01:05, 153.43it/s]


aaa 10000 10000 10000
                    name  auc_orig  auc_baseline  auc_flipping_full  \
0               haberman  0.557615      0.561495                0.0   
1           new_thyroid1  0.930327      0.935953                0.0   
2       shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3                   bupa  0.624876      0.627844                0.0   
4       cleveland-0_vs_4  0.670715      0.674463                0.0   
5                 ecoli1  0.831813      0.833163                0.0   
6           poker-9_vs_7  0.607766      0.607903                0.0   
7                 monk-2  1.000000      1.000000                0.0   
8              hepatitis  0.667475      0.668629                0.0   
9   yeast-0-3-5-9_vs_7-8  0.640446      0.641646                0.0   
10          mammographic  0.787445      0.791948                0.0   
11               saheart  0.584617      0.585352                0.0   

    auc_flipping_coord        p_full  
0              

10000it [00:31, 318.58it/s]


aaa 10000 10000 10000
                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.557615      0.561495                0.0   
1                   new_thyroid1  0.930327      0.935953                0.0   
2               shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3                           bupa  0.624876      0.627844                0.0   
4               cleveland-0_vs_4  0.670715      0.674463                0.0   
5                         ecoli1  0.831813      0.833163                0.0   
6                   poker-9_vs_7  0.607766      0.607903                0.0   
7                         monk-2  1.000000      1.000000                0.0   
8                      hepatitis  0.667475      0.668629                0.0   
9           yeast-0-3-5-9_vs_7-8  0.640446      0.641646                0.0   
10                  mammographic  0.787445      0.791948                0.0   
11                       sahea

10000it [01:20, 123.56it/s]


aaa 10000 10000 10000
                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.557615      0.561495                0.0   
1                   new_thyroid1  0.930327      0.935953                0.0   
2               shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3                           bupa  0.624876      0.627844                0.0   
4               cleveland-0_vs_4  0.670715      0.674463                0.0   
5                         ecoli1  0.831813      0.833163                0.0   
6                   poker-9_vs_7  0.607766      0.607903                0.0   
7                         monk-2  1.000000      1.000000                0.0   
8                      hepatitis  0.667475      0.668629                0.0   
9           yeast-0-3-5-9_vs_7-8  0.640446      0.641646                0.0   
10                  mammographic  0.787445      0.791948                0.0   
11                       sahea

10000it [00:40, 244.13it/s]


aaa 10000 10000 10000
                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.557615      0.561495                0.0   
1                   new_thyroid1  0.930327      0.935953                0.0   
2               shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3                           bupa  0.624876      0.627844                0.0   
4               cleveland-0_vs_4  0.670715      0.674463                0.0   
5                         ecoli1  0.831813      0.833163                0.0   
6                   poker-9_vs_7  0.607766      0.607903                0.0   
7                         monk-2  1.000000      1.000000                0.0   
8                      hepatitis  0.667475      0.668629                0.0   
9           yeast-0-3-5-9_vs_7-8  0.640446      0.641646                0.0   
10                  mammographic  0.787445      0.791948                0.0   
11                       sahea

10000it [01:09, 144.84it/s]


aaa 10000 10000 10000
                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.557615      0.561495                0.0   
1                   new_thyroid1  0.930327      0.935953                0.0   
2               shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3                           bupa  0.624876      0.627844                0.0   
4               cleveland-0_vs_4  0.670715      0.674463                0.0   
5                         ecoli1  0.831813      0.833163                0.0   
6                   poker-9_vs_7  0.607766      0.607903                0.0   
7                         monk-2  1.000000      1.000000                0.0   
8                      hepatitis  0.667475      0.668629                0.0   
9           yeast-0-3-5-9_vs_7-8  0.640446      0.641646                0.0   
10                  mammographic  0.787445      0.791948                0.0   
11                       sahea

10000it [01:04, 155.69it/s]


aaa 10000 10000 10000
                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.557615      0.561495                0.0   
1                   new_thyroid1  0.930327      0.935953                0.0   
2               shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3                           bupa  0.624876      0.627844                0.0   
4               cleveland-0_vs_4  0.670715      0.674463                0.0   
5                         ecoli1  0.831813      0.833163                0.0   
6                   poker-9_vs_7  0.607766      0.607903                0.0   
7                         monk-2  1.000000      1.000000                0.0   
8                      hepatitis  0.667475      0.668629                0.0   
9           yeast-0-3-5-9_vs_7-8  0.640446      0.641646                0.0   
10                  mammographic  0.787445      0.791948                0.0   
11                       sahea

10000it [01:23, 119.95it/s]


aaa 10000 10000 10000
                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.557615      0.561495                0.0   
1                   new_thyroid1  0.930327      0.935953                0.0   
2               shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3                           bupa  0.624876      0.627844                0.0   
4               cleveland-0_vs_4  0.670715      0.674463                0.0   
5                         ecoli1  0.831813      0.833163                0.0   
6                   poker-9_vs_7  0.607766      0.607903                0.0   
7                         monk-2  1.000000      1.000000                0.0   
8                      hepatitis  0.667475      0.668629                0.0   
9           yeast-0-3-5-9_vs_7-8  0.640446      0.641646                0.0   
10                  mammographic  0.787445      0.791948                0.0   
11                       sahea

10000it [01:12, 137.10it/s]


aaa 10000 10000 10000
                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.557615      0.561495                0.0   
1                   new_thyroid1  0.930327      0.935953                0.0   
2               shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3                           bupa  0.624876      0.627844                0.0   
4               cleveland-0_vs_4  0.670715      0.674463                0.0   
5                         ecoli1  0.831813      0.833163                0.0   
6                   poker-9_vs_7  0.607766      0.607903                0.0   
7                         monk-2  1.000000      1.000000                0.0   
8                      hepatitis  0.667475      0.668629                0.0   
9           yeast-0-3-5-9_vs_7-8  0.640446      0.641646                0.0   
10                  mammographic  0.787445      0.791948                0.0   
11                       sahea

10000it [01:19, 126.07it/s]

aaa 10000 10000 10000
                            name  auc_orig  auc_baseline  auc_flipping_full  \
0                       haberman  0.557615      0.561495                0.0   
1                   new_thyroid1  0.930327      0.935953                0.0   
2               shuttle-6_vs_2-3  1.000000      1.000000                0.0   
3                           bupa  0.624876      0.627844                0.0   
4               cleveland-0_vs_4  0.670715      0.674463                0.0   
5                         ecoli1  0.831813      0.833163                0.0   
6                   poker-9_vs_7  0.607766      0.607903                0.0   
7                         monk-2  1.000000      1.000000                0.0   
8                      hepatitis  0.667475      0.668629                0.0   
9           yeast-0-3-5-9_vs_7-8  0.640446      0.641646                0.0   
10                  mammographic  0.787445      0.791948                0.0   
11                       sahea




In [4]:
results_pdf.to_csv("classification-tree.csv")