In [1]:
# !pip install -q numpy==1.19.5

In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)

In [28]:
import os 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.linear_model import LogisticRegressionCV, SGDClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GroupShuffleSplit
from sklearn import metrics

In [29]:
# setting to see all columns 
pd.set_option('display.max_columns', None)

# Load data and set input/output paths

In [30]:
from os.path import exists
# config: depends whether you're on Google Colab or local


# Get URL from github csv by clicking on Download > Copy Link Address

load_from_google_drive = False

if load_from_google_drive:
      # On google colab
      # Mount GDrive and attach it to the colab for data I/O
    from google.colab import drive
    drive.mount('/content/drive')
    input_dir = '/content/drive/My Drive/datum/vfp/data/input/'
    output_dir = '/content/drive/My Drive/datum/vfp/data/output/'
    os.makedirs(output_dir, exist_ok=True)

else:
  # If using jupyter-lab or jupyter notebook, load locally:
  input_dir = './data/input/'
  output_dir = './data/output/'
  output_dir = './data/output/bias_mitigation/'
  os.makedirs(output_dir, exist_ok = True)

import datetime
ts = datetime.datetime.utcnow().strftime('%y-%m-%dT%H-%M-%S')


In [31]:
df = pd.read_csv(input_dir + 'features/egemaps_vector_speech_duration.csv', index_col = 0)
df

Unnamed: 0,F0semitoneFrom27.5Hz_sma3nz_amean,F0semitoneFrom27.5Hz_sma3nz_stddevNorm,F0semitoneFrom27.5Hz_sma3nz_percentile20.0,F0semitoneFrom27.5Hz_sma3nz_percentile50.0,F0semitoneFrom27.5Hz_sma3nz_percentile80.0,F0semitoneFrom27.5Hz_sma3nz_pctlrange0-2,F0semitoneFrom27.5Hz_sma3nz_meanRisingSlope,F0semitoneFrom27.5Hz_sma3nz_stddevRisingSlope,F0semitoneFrom27.5Hz_sma3nz_meanFallingSlope,F0semitoneFrom27.5Hz_sma3nz_stddevFallingSlope,loudness_sma3_amean,loudness_sma3_stddevNorm,loudness_sma3_percentile20.0,loudness_sma3_percentile50.0,loudness_sma3_percentile80.0,loudness_sma3_pctlrange0-2,loudness_sma3_meanRisingSlope,loudness_sma3_stddevRisingSlope,loudness_sma3_meanFallingSlope,loudness_sma3_stddevFallingSlope,spectralFlux_sma3_amean,spectralFlux_sma3_stddevNorm,mfcc1_sma3_amean,mfcc1_sma3_stddevNorm,mfcc2_sma3_amean,mfcc2_sma3_stddevNorm,mfcc3_sma3_amean,mfcc3_sma3_stddevNorm,mfcc4_sma3_amean,mfcc4_sma3_stddevNorm,jitterLocal_sma3nz_amean,jitterLocal_sma3nz_stddevNorm,shimmerLocaldB_sma3nz_amean,shimmerLocaldB_sma3nz_stddevNorm,HNRdBACF_sma3nz_amean,HNRdBACF_sma3nz_stddevNorm,logRelF0-H1-H2_sma3nz_amean,logRelF0-H1-H2_sma3nz_stddevNorm,logRelF0-H1-A3_sma3nz_amean,logRelF0-H1-A3_sma3nz_stddevNorm,F1frequency_sma3nz_amean,F1frequency_sma3nz_stddevNorm,F1bandwidth_sma3nz_amean,F1bandwidth_sma3nz_stddevNorm,F1amplitudeLogRelF0_sma3nz_amean,F1amplitudeLogRelF0_sma3nz_stddevNorm,F2frequency_sma3nz_amean,F2frequency_sma3nz_stddevNorm,F2bandwidth_sma3nz_amean,F2bandwidth_sma3nz_stddevNorm,F2amplitudeLogRelF0_sma3nz_amean,F2amplitudeLogRelF0_sma3nz_stddevNorm,F3frequency_sma3nz_amean,F3frequency_sma3nz_stddevNorm,F3bandwidth_sma3nz_amean,F3bandwidth_sma3nz_stddevNorm,F3amplitudeLogRelF0_sma3nz_amean,F3amplitudeLogRelF0_sma3nz_stddevNorm,alphaRatioV_sma3nz_amean,alphaRatioV_sma3nz_stddevNorm,hammarbergIndexV_sma3nz_amean,hammarbergIndexV_sma3nz_stddevNorm,slopeV0-500_sma3nz_amean,slopeV0-500_sma3nz_stddevNorm,slopeV500-1500_sma3nz_amean,slopeV500-1500_sma3nz_stddevNorm,spectralFluxV_sma3nz_amean,spectralFluxV_sma3nz_stddevNorm,mfcc1V_sma3nz_amean,mfcc1V_sma3nz_stddevNorm,mfcc2V_sma3nz_amean,mfcc2V_sma3nz_stddevNorm,mfcc3V_sma3nz_amean,mfcc3V_sma3nz_stddevNorm,mfcc4V_sma3nz_amean,mfcc4V_sma3nz_stddevNorm,alphaRatioUV_sma3nz_amean,hammarbergIndexUV_sma3nz_amean,slopeUV0-500_sma3nz_amean,slopeUV500-1500_sma3nz_amean,spectralFluxUV_sma3nz_amean,loudnessPeaksPerSec,VoicedSegmentsPerSec,MeanVoicedSegmentLengthSec,StddevVoicedSegmentLengthSec,MeanUnvoicedSegmentLength,StddevUnvoicedSegmentLength,equivalentSoundLevel_dBp,sid,token,target,filename,duration,task
0,38.72469,0.074773,37.18876,38.47970,40.44019,3.251423,37.94022,15.84932,206.45270,360.75950,0.409657,0.516380,0.211142,0.385909,0.588256,0.377114,5.487913,2.772103,5.275577,2.848066,0.183393,0.628117,18.73326,0.802712,10.520690,1.304518,7.572035,1.962388,-4.375561,-3.278681,0.023441,1.147700,0.908820,0.858329,10.061590,0.348174,19.314520,0.650321,21.89401,0.478417,530.5170,0.394628,1226.236,0.232456,-79.00346,-1.113559,1590.281,0.167123,784.5121,0.407070,-65.53321,-1.063740,2636.545,0.095774,703.0072,0.460240,-67.30154,-1.014580,-16.121190,-0.486312,24.37623,0.353210,0.041570,0.867909,-0.014989,-0.643631,0.202319,0.572868,24.18615,0.417486,8.214763,1.525709,7.750072,2.055495,-7.210148,-2.001652,-6.364767,13.05844,-0.020086,-0.002822,0.131253,6.017192,2.325582,0.323750,0.279148,0.086250,0.056111,-36.04536,VFP10,Speech1,1,VFP10_Speech1,3.503,Speech
1,41.11026,0.121323,37.82011,40.29837,42.86220,5.042088,65.27183,67.44999,107.02950,136.12460,0.352636,0.772619,0.091338,0.276297,0.643402,0.552064,5.773097,3.043491,5.736364,3.096833,0.165227,1.078465,16.27722,0.886090,4.175593,2.418372,4.164841,4.175746,-1.238956,-10.893710,0.033043,0.955943,1.110513,1.005039,8.546047,0.610735,10.095680,1.019345,16.71051,0.952560,609.8846,0.475136,1317.956,0.256875,-141.02610,-0.608951,1612.576,0.212125,900.7050,0.395662,-111.65260,-0.791487,2706.250,0.155800,773.9065,0.492648,-113.50150,-0.762943,-9.636793,-1.066963,20.83702,0.549675,0.031475,1.370808,-0.011556,-0.940793,0.281724,0.656881,21.47814,0.785475,2.729456,4.197899,-2.206290,-9.858141,-11.921610,-0.857818,-10.492630,18.57218,-0.030845,-0.001601,0.057944,3.724928,1.749271,0.271667,0.260411,0.276667,0.294430,-37.11591,VFP10,Speech2,1,VFP10_Speech2,3.503,Speech
2,40.64695,0.103110,38.43892,40.61772,43.04973,4.610809,70.18970,41.75834,143.79800,158.40300,0.455030,0.595214,0.193809,0.414761,0.771635,0.577826,5.880630,3.201431,4.344426,2.605690,0.192152,0.685843,16.00158,0.908747,1.617499,10.016700,2.134295,6.868225,0.189451,55.625250,0.029924,1.759617,1.051071,0.979267,9.214264,0.533516,9.660620,1.029315,16.85061,0.659495,631.3687,0.335223,1214.071,0.301923,-101.13850,-0.902200,1638.225,0.144500,934.1242,0.369319,-82.60666,-0.987637,2679.333,0.106968,783.9202,0.406726,-86.69166,-0.905688,-10.902040,-0.925076,20.16452,0.549932,0.026402,1.569013,-0.007047,-1.733416,0.232238,0.547707,22.32532,0.455623,-3.063484,-4.477240,-0.415351,-36.927860,-2.411690,-4.179638,-8.437313,16.16711,-0.026844,-0.003204,0.118944,5.157593,3.498543,0.175000,0.168201,0.118889,0.079365,-36.20234,VFP10,Speech3,1,VFP10_Speech3,3.503,Speech
3,30.43643,0.271136,21.39732,35.07611,37.99580,16.598470,520.01150,754.83490,113.23290,72.07719,0.361249,0.447226,0.200881,0.350636,0.494258,0.293377,3.505941,1.580694,2.991900,1.599806,0.152279,0.561194,17.09886,0.571667,9.499130,0.938797,15.013220,1.094950,-1.632905,-6.753133,0.061652,1.476092,1.323190,0.675948,4.797133,1.021562,7.106359,2.069486,17.30584,0.715088,495.4891,0.446891,1164.115,0.214102,-108.62230,-0.835690,1611.846,0.172599,685.3447,0.616692,-96.16886,-0.857065,2652.961,0.097656,725.8063,0.541421,-96.83107,-0.847054,-14.949060,-0.359614,23.91982,0.319472,-0.003510,-9.079530,-0.015711,-0.624240,0.175582,0.453718,21.25350,0.390166,12.484160,0.547064,13.903480,1.263885,-0.754947,-14.337070,-9.696174,18.68855,-0.035889,-0.006141,0.116599,3.724928,3.197675,0.174545,0.154176,0.116364,0.093738,-40.33198,VFP11,Speech1,1,VFP11_Speech1,3.503,Speech
4,32.12006,0.241961,22.38828,35.72247,37.69841,15.310130,1028.64100,954.45980,159.64700,112.48600,0.383630,0.413864,0.241884,0.355205,0.527708,0.285824,4.263408,3.172111,3.038845,1.653884,0.151369,0.475207,12.93016,0.957830,9.103159,0.968121,20.513380,0.633031,-2.292660,-4.604483,0.054972,1.229439,1.649192,0.817109,5.639021,0.829598,13.270850,1.035262,17.27562,0.684350,444.5576,0.284360,1154.101,0.208867,-115.89510,-0.761690,1595.626,0.137739,596.8990,0.641410,-108.24170,-0.744009,2653.475,0.071088,772.4484,0.535277,-107.27790,-0.759111,-14.976390,-0.407176,22.52378,0.311171,-0.010057,-2.957649,-0.014513,-0.663797,0.176212,0.397830,19.65244,0.415014,12.042640,0.726816,21.610900,0.671831,-2.279478,-4.844434,-7.212510,14.82589,-0.046613,-0.006085,0.118889,5.730659,4.081633,0.121429,0.107628,0.120833,0.178814,-39.28465,VFP11,Speech2,1,VFP11_Speech2,3.503,Speech
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
448,32.82069,0.073434,30.89363,32.34020,34.59331,3.699677,182.45580,326.94090,58.27354,55.20131,0.163460,0.522306,0.075665,0.156032,0.241769,0.166104,1.885137,1.111351,1.351178,0.590421,0.046189,0.667091,27.36822,0.450612,15.831880,0.795290,18.915380,0.760017,6.812958,1.755961,0.037639,1.190368,0.964755,0.991380,9.131144,0.310881,8.164756,1.006656,28.56432,0.322647,496.2999,0.316247,1274.801,0.171952,-66.06810,-1.246946,1621.714,0.136256,937.1945,0.537158,-76.53618,-0.946130,2685.701,0.108204,984.9135,0.431121,-80.27860,-0.874510,-20.851850,-0.339330,31.00552,0.284932,0.026031,1.334945,-0.015354,-0.629295,0.056202,0.534305,31.05387,0.374250,16.661340,0.687249,20.445920,0.756702,5.125762,2.416586,-17.455410,28.20164,-0.019634,-0.001798,0.020782,4.871060,2.332362,0.302500,0.161071,0.106250,0.154186,-46.03365,VFPNorm8,Speech2,0,VFPNorm8_Speech2,3.503,Speech
449,32.18818,0.070383,30.34963,31.72725,33.34668,2.997049,28.74552,11.09664,26.86969,23.55212,0.144799,0.503324,0.069550,0.141103,0.205773,0.136223,1.478385,0.844526,1.186362,0.551909,0.038738,0.706269,23.21600,0.767392,15.108680,0.809788,14.349510,1.121815,8.470038,1.412676,0.032173,1.535001,0.937512,1.005167,9.108112,0.285027,10.054340,0.843113,32.12694,0.272087,543.8298,0.384883,1265.332,0.213014,-89.37212,-0.993368,1560.604,0.176490,931.4860,0.385068,-95.93403,-0.847097,2698.299,0.133518,883.1241,0.425145,-101.72920,-0.751511,-18.480710,-0.416476,29.56695,0.284443,0.014625,1.903894,-0.013567,-0.875202,0.046592,0.616653,31.61001,0.421032,14.857740,0.835053,13.938850,1.313820,5.178216,2.521554,-14.754670,24.56603,-0.028107,-0.001130,0.027894,3.724928,3.206997,0.177273,0.180811,0.127000,0.129232,-48.11572,VFPNorm8,Speech3,0,VFPNorm8_Speech3,3.503,Speech
450,24.48411,0.180009,22.78093,25.40892,28.27801,5.497078,197.19170,130.46350,45.02444,25.83455,0.117911,0.483579,0.063525,0.108178,0.171664,0.108140,1.420177,0.629638,1.054868,0.484441,0.033011,0.519085,26.10903,0.543691,16.517930,0.593427,11.507460,1.250198,5.021239,2.448002,0.079452,2.069847,1.355051,0.954663,5.394252,0.598302,2.956565,1.731019,30.86301,0.207399,448.8878,0.313844,1220.155,0.122200,-85.09564,-1.085957,1461.210,0.187567,939.6982,0.327941,-95.18096,-0.886000,2521.296,0.101952,810.5140,0.457890,-100.74440,-0.788425,-19.690540,-0.306501,30.82067,0.229403,0.011569,2.887912,-0.020267,-0.655963,0.041008,0.408967,34.67593,0.267692,17.423910,0.488991,11.429120,1.455805,-0.466549,-26.340980,-17.850630,28.72795,-0.044269,-0.001048,0.022439,5.444126,2.616279,0.216667,0.127715,0.145556,0.118144,-51.09922,VFPNorm9,Speech1,0,VFPNorm9_Speech1,3.503,Speech
451,25.80380,0.099246,24.82298,26.69952,27.15441,2.331427,169.48410,131.27800,36.31673,24.31405,0.103108,0.577983,0.043224,0.085041,0.167701,0.124477,1.190860,0.662923,0.784700,0.549603,0.028304,0.692012,23.85000,0.536352,11.695800,0.827275,11.894940,0.918881,7.271324,1.505499,0.061378,2.061115,1.270060,1.010738,5.545829,0.638101,4.890247,1.119272,31.56176,0.243222,493.9183,0.290058,1274.643,0.178984,-114.18080,-0.816936,1501.698,0.181778,1011.0570,0.350117,-121.17270,-0.705554,2551.819,0.116107,915.6561,0.436638,-124.99080,-0.650363,-20.627660,-0.337539,30.83374,0.230944,0.006369,3.985361,-0.021191,-0.621737,0.041655,0.438832,35.60788,0.214996,15.542970,0.686406,12.781720,1.142406,2.036459,6.614008,-15.938160,27.39670,-0.039000,-0.000503,0.017274,4.011461,1.749271,0.248333,0.106053,0.303333,0.242051,-51.94005,VFPNorm9,Speech2,0,VFPNorm9_Speech2,3.503,Speech


### Find distance correlation between intensity features and other features. 



In [33]:
import dcor
correlated_features_d = {}


features = ['duration']

other_features = list(set(df.columns)-set(features)-set(['sid', 'token', 'target', 'filename', 'task']))


for x1 in features:
    for x2 in other_features:
        # if x2 in correlated_features.keys():
        #     continue #skip
        # else:
            dcor_result = float(dcor.distance_correlation(df[x1], df[x2]))
            if x2 not in correlated_features_d.keys():
                correlated_features_d[x2] = dcor_result
            elif x2 in correlated_features_d.keys():
                if dcor_result>correlated_features_d.get(x2):
                    # save the highest correlation between intensity features and nonintensity ones
                    correlated_features_d[x2] = dcor_result
                else:
                    # save the prior correlation
                    continue
            
            
    


In [34]:
correlated_features = pd.DataFrame(correlated_features_d, index = ['dcor']).T.reset_index().sort_values('dcor')[::-1]
correlated_features.columns = ['other_features', 'dcor']
correlated_features


Unnamed: 0,other_features,dcor
57,equivalentSoundLevel_dBp,0.755525
60,spectralFlux_sma3_amean,0.724038
17,loudness_sma3_stddevRisingSlope,0.707983
18,loudness_sma3_amean,0.702836
25,spectralFluxV_sma3nz_amean,0.701224
...,...,...
22,mfcc3V_sma3nz_stddevNorm,0.079990
39,hammarbergIndexV_sma3nz_stddevNorm,0.075945
49,mfcc4_sma3_stddevNorm,0.073853
65,mfcc4V_sma3nz_amean,0.073674


In [35]:
print(correlated_features[correlated_features['dcor']>0.3].shape)
print(correlated_features[correlated_features['dcor']>0.4].shape)


(35, 2)
(26, 2)


In [36]:
print('remaining vars', 88-35)
print('remaining vars', 88-26)

remaining vars 53
remaining vars 62


In [37]:
correlated_with_duration = correlated_features[correlated_features['dcor']>0.3].round(2)
correlated_with_duration = correlated_with_duration.reset_index(drop = True)

correlated_with_duration.to_csv(output_dir+'uncorrelated_dcor-030_duration.csv')
correlated_with_duration

Unnamed: 0,other_features,dcor
0,equivalentSoundLevel_dBp,0.76
1,spectralFlux_sma3_amean,0.72
2,loudness_sma3_stddevRisingSlope,0.71
3,loudness_sma3_amean,0.7
4,spectralFluxV_sma3nz_amean,0.7
5,spectralFluxUV_sma3nz_amean,0.7
6,loudness_sma3_percentile80.0,0.69
7,slopeUV500-1500_sma3nz_amean,0.67
8,loudness_sma3_percentile50.0,0.67
9,loudness_sma3_meanRisingSlope,0.65


In [38]:



correlated_with_intensity = pd.read_csv(output_dir+'uncorrelated_dcor-030_intensity.csv', index_col = 0)

correlated_with_intensity



correlated_with_both = correlated_with_duration.merge(correlated_with_intensity, on='other_features', suffixes=('_duration', '_intensity'), how='outer')

correlated_with_both


Unnamed: 0,other_features,dcor_duration,dcor_intensity
0,equivalentSoundLevel_dBp,0.76,
1,spectralFlux_sma3_amean,0.72,0.9
2,loudness_sma3_stddevRisingSlope,0.71,
3,loudness_sma3_amean,0.7,
4,spectralFluxV_sma3nz_amean,0.7,0.88
5,spectralFluxUV_sma3nz_amean,0.7,0.8
6,loudness_sma3_percentile80.0,0.69,
7,slopeUV500-1500_sma3nz_amean,0.67,0.58
8,loudness_sma3_percentile50.0,0.67,
9,loudness_sma3_meanRisingSlope,0.65,


In [39]:
# insert 1 in dcor with intensity features
correlated_with_both = correlated_with_both.reset_index(drop = True)
correlated_with_both.index = correlated_with_both['other_features'].values
correlated_with_both =  correlated_with_both.drop('other_features', axis=1)
intensity_features = ['loudness_sma3_amean',
       'loudness_sma3_stddevNorm', 'loudness_sma3_percentile20.0',
       'loudness_sma3_percentile50.0', 'loudness_sma3_percentile80.0',
       'loudness_sma3_pctlrange0-2', 'loudness_sma3_meanRisingSlope',
       'loudness_sma3_stddevRisingSlope', 'loudness_sma3_meanFallingSlope',
       'loudness_sma3_stddevFallingSlope','loudnessPeaksPerSec','equivalentSoundLevel_dBp',
                      'HNRdBACF_sma3nz_amean','HNRdBACF_sma3nz_stddevNorm',]

for feature in correlated_with_both.index:
    if feature in intensity_features:
        correlated_with_both.loc[feature, 'dcor_intensity'] = 1

correlated_with_both

Unnamed: 0,dcor_duration,dcor_intensity
equivalentSoundLevel_dBp,0.76,1.0
spectralFlux_sma3_amean,0.72,0.9
loudness_sma3_stddevRisingSlope,0.71,1.0
loudness_sma3_amean,0.7,1.0
spectralFluxV_sma3nz_amean,0.7,0.88
spectralFluxUV_sma3nz_amean,0.7,0.8
loudness_sma3_percentile80.0,0.69,1.0
slopeUV500-1500_sma3nz_amean,0.67,0.58
loudness_sma3_percentile50.0,0.67,1.0
loudness_sma3_meanRisingSlope,0.65,1.0


In [40]:
correlated_with_both.to_csv(output_dir+'uncorrelated_dcor-030_both.csv')

In [41]:
all_features = correlated_features['other_features']

In [42]:
threshold = 0.3
filename = 'uncorrelated_dcor-030_both'


correlated_both_030 = correlated_with_both[(correlated_with_both['dcor_duration']>=threshold) | (correlated_with_both['dcor_intensity']>=threshold)].index

# all_features
# uncorrelated_features 
len(correlated_both_030)

58

In [43]:
correlated_with_both.shape

(58, 2)

In [44]:
88-45

43

In [45]:
threshold = 0.4
filename = 'uncorrelated_dcor-040_both'


correlated_both_040 = correlated_with_both[(correlated_with_both['dcor_duration']>=threshold) | (correlated_with_both['dcor_intensity']>=threshold)].index

# all_features
# uncorrelated_features 
len(correlated_both_040)


45

In [46]:
output_dir

'./data/output/bias_mitigation/'

In [24]:
egemaps_features = correlated_features['other_features'].values
uncorrelated_above_03 = set(egemaps_features) - set(correlated_both_030)
len(uncorrelated_above_03)

30

In [48]:
%%time


models = [
    LogisticRegressionCV(solver='liblinear', penalty = 'l1', max_iter = 100),
    MLPClassifier(alpha = 1, max_iter= 1000),
    RandomForestClassifier(n_estimators= 100),
    SGDClassifier(loss='log', penalty="elasticnet", early_stopping=True, max_iter = 5000),
]


names = ['LogisticRegressionCV', "MLPClassifier","RandomForestClassifier",'SGDClassifier']


from sklearn.model_selection import train_test_split




toy = False


for threshold in [0.3, 0.4]:

  for null_model in [True, False]:
      print('\npermute', null_model)
      # Create DFs for each independent variable

      if threshold == 0.3:
        variables = correlated_both_030
        filename = 'uncorrelated_below_dcor-030_both'
      elif threshold == 0.4:
        variables = correlated_both_040
        filename = 'uncorrelated_below_dcor-040_both'

          

      X = df[variables].values
      y = df['target'].values


      if toy:
        n_bootstraps = 3
      else:
        n_bootstraps = 50

      if null_model:
          y = np.random.permutation(y) #CHECK


      y_pred_all = {}
      roc_auc_all = {}
      for model, name in zip(models, names):
        y_pred_all[name] = []
        roc_auc_all[name] = []
        pipe = Pipeline(steps=[
                ('scaler', StandardScaler()), 
                ('model', model)])

        ## Performing bootstrapping
        for i in range(n_bootstraps):
            #Split the data into training and testing set

            # Chaning the seed value for each iteration
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42+i)

        # for train_index, test_index in bs:
        #   X_train, X_test, y_train, y_test = X[train_index], X[test_index], y[train_index], y[test_index]
            pipe.fit(X_train,y_train)

            # # Evaluate
            # y_proba = pipe.predict_proba(X_test)       # Get predicted probabilities
            # y_proba_1 = y_proba[:,1]
            # y_pred = np.argmax(y_proba, axis=1) 
            # roc_auc = metrics.roc_auc_score(y_test, y_proba_1)  
                      
            y_pred = pipe.predict(X_test) 
            roc_auc = metrics.roc_auc_score(y_test, y_pred)  # ROC AUC takes probabilities but here we match what pydra-ml does: https://github.com/nipype/pydra-ml/issues/56

            y_pred_all[name].append(y_pred)
            roc_auc_all[name].append(roc_auc)

      results_i = []
      for name in ['LogisticRegressionCV','MLPClassifier','RandomForestClassifier','SGDClassifier']:
        scores = roc_auc_all.get(name)
        roc_auc_median = np.round(np.median(scores),2)
        roc_auc_5 = np.round(np.percentile(scores, 5),2)
        roc_auc_95 = np.round(np.percentile(scores, 95),2)
        results_str = f'{roc_auc_median} ({roc_auc_5}–{roc_auc_95}; )'
        results_str = results_str.replace('0.', '.')
        results_i.append([name, results_str])

        if null_model:
          print(name, str(roc_auc_median).replace('0.', '.'))
      if not null_model:
          results_i_df = pd.DataFrame(results_i, ).T
          display(results_i_df)
          results_i_df.to_csv(output_dir+f'results_{filename}_permute-{null_model}_duration_{ts}.csv')




      # pd.DataFrame(y_pred_all)



permute True




LogisticRegressionCV .5
MLPClassifier .51
RandomForestClassifier .52
SGDClassifier .5

permute False




Unnamed: 0,0,1,2,3
0,LogisticRegressionCV,MLPClassifier,RandomForestClassifier,SGDClassifier
1,.88 (.83–.91; ),.89 (.86–.93; ),.89 (.84–.92; ),.86 (.78–.9; )



permute True




LogisticRegressionCV .5
MLPClassifier .52
RandomForestClassifier .54
SGDClassifier .51

permute False


Unnamed: 0,0,1,2,3
0,LogisticRegressionCV,MLPClassifier,RandomForestClassifier,SGDClassifier
1,.85 (.79–.9; ),.9 (.86–.93; ),.88 (.82–.93; ),.8 (.7–.88; )


CPU times: user 25min 21s, sys: 7min 52s, total: 33min 14s
Wall time: 23min
