In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)

In [2]:
import os 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import r2_score
from sklearn.linear_model import LassoCV, Lasso
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split

In [3]:
# setting to see all columns 
pd.set_option('display.max_columns', None)

# Load data and set input/output paths

In [4]:
from os.path import exists
# config: depends whether you're on Google Colab or local


# Get URL from github csv by clicking on Download > Copy Link Address

load_from_google_drive = False

if load_from_google_drive:
      # On google colab
      # Mount GDrive and attach it to the colab for data I/O
    from google.colab import drive
    drive.mount('/content/drive')
    input_dir = '/content/drive/My Drive/datum/vfp/data/input/'
    output_dir = '/content/drive/My Drive/datum/vfp/data/output/'
    os.makedirs(output_dir, exist_ok=True)

else:
  # If using jupyter-lab or jupyter notebook, load locally:
  input_dir = './data/input/'
  output_dir = './data/output/'



In [5]:
df = pd.read_csv(input_dir + 'features/egemaps_vector_speech_cpp.csv', index_col = 0)
df

Unnamed: 0,F0semitoneFrom27.5Hz_sma3nz_amean,F0semitoneFrom27.5Hz_sma3nz_stddevNorm,F0semitoneFrom27.5Hz_sma3nz_percentile20.0,F0semitoneFrom27.5Hz_sma3nz_percentile50.0,F0semitoneFrom27.5Hz_sma3nz_percentile80.0,F0semitoneFrom27.5Hz_sma3nz_pctlrange0-2,F0semitoneFrom27.5Hz_sma3nz_meanRisingSlope,F0semitoneFrom27.5Hz_sma3nz_stddevRisingSlope,F0semitoneFrom27.5Hz_sma3nz_meanFallingSlope,F0semitoneFrom27.5Hz_sma3nz_stddevFallingSlope,loudness_sma3_amean,loudness_sma3_stddevNorm,loudness_sma3_percentile20.0,loudness_sma3_percentile50.0,loudness_sma3_percentile80.0,loudness_sma3_pctlrange0-2,loudness_sma3_meanRisingSlope,loudness_sma3_stddevRisingSlope,loudness_sma3_meanFallingSlope,loudness_sma3_stddevFallingSlope,spectralFlux_sma3_amean,spectralFlux_sma3_stddevNorm,mfcc1_sma3_amean,mfcc1_sma3_stddevNorm,mfcc2_sma3_amean,mfcc2_sma3_stddevNorm,mfcc3_sma3_amean,mfcc3_sma3_stddevNorm,mfcc4_sma3_amean,mfcc4_sma3_stddevNorm,jitterLocal_sma3nz_amean,jitterLocal_sma3nz_stddevNorm,shimmerLocaldB_sma3nz_amean,shimmerLocaldB_sma3nz_stddevNorm,HNRdBACF_sma3nz_amean,HNRdBACF_sma3nz_stddevNorm,logRelF0-H1-H2_sma3nz_amean,logRelF0-H1-H2_sma3nz_stddevNorm,logRelF0-H1-A3_sma3nz_amean,logRelF0-H1-A3_sma3nz_stddevNorm,F1frequency_sma3nz_amean,F1frequency_sma3nz_stddevNorm,F1bandwidth_sma3nz_amean,F1bandwidth_sma3nz_stddevNorm,F1amplitudeLogRelF0_sma3nz_amean,F1amplitudeLogRelF0_sma3nz_stddevNorm,F2frequency_sma3nz_amean,F2frequency_sma3nz_stddevNorm,F2bandwidth_sma3nz_amean,F2bandwidth_sma3nz_stddevNorm,F2amplitudeLogRelF0_sma3nz_amean,F2amplitudeLogRelF0_sma3nz_stddevNorm,F3frequency_sma3nz_amean,F3frequency_sma3nz_stddevNorm,F3bandwidth_sma3nz_amean,F3bandwidth_sma3nz_stddevNorm,F3amplitudeLogRelF0_sma3nz_amean,F3amplitudeLogRelF0_sma3nz_stddevNorm,alphaRatioV_sma3nz_amean,alphaRatioV_sma3nz_stddevNorm,hammarbergIndexV_sma3nz_amean,hammarbergIndexV_sma3nz_stddevNorm,slopeV0-500_sma3nz_amean,slopeV0-500_sma3nz_stddevNorm,slopeV500-1500_sma3nz_amean,slopeV500-1500_sma3nz_stddevNorm,spectralFluxV_sma3nz_amean,spectralFluxV_sma3nz_stddevNorm,mfcc1V_sma3nz_amean,mfcc1V_sma3nz_stddevNorm,mfcc2V_sma3nz_amean,mfcc2V_sma3nz_stddevNorm,mfcc3V_sma3nz_amean,mfcc3V_sma3nz_stddevNorm,mfcc4V_sma3nz_amean,mfcc4V_sma3nz_stddevNorm,alphaRatioUV_sma3nz_amean,hammarbergIndexUV_sma3nz_amean,slopeUV0-500_sma3nz_amean,slopeUV500-1500_sma3nz_amean,spectralFluxUV_sma3nz_amean,loudnessPeaksPerSec,VoicedSegmentsPerSec,MeanVoicedSegmentLengthSec,StddevVoicedSegmentLengthSec,MeanUnvoicedSegmentLength,StddevUnvoicedSegmentLength,equivalentSoundLevel_dBp,sid,token,target,filename,cpp_amean,cpp_stddevNorm,cpp_percentile20,cpp_percentile80
0,38.72469,0.074773,37.18876,38.47970,40.44019,3.251423,37.94022,15.84932,206.45270,360.75950,0.409657,0.516380,0.211142,0.385909,0.588256,0.377114,5.487913,2.772103,5.275577,2.848066,0.183393,0.628117,18.73326,0.802712,10.520690,1.304518,7.572035,1.962388,-4.375561,-3.278681,0.023441,1.147700,0.908820,0.858329,10.061590,0.348174,19.314520,0.650321,21.89401,0.478417,530.5170,0.394628,1226.236,0.232456,-79.00346,-1.113559,1590.281,0.167123,784.5121,0.407070,-65.53321,-1.063740,2636.545,0.095774,703.0072,0.460240,-67.30154,-1.014580,-16.121190,-0.486312,24.37623,0.353210,0.041570,0.867909,-0.014989,-0.643631,0.202319,0.572868,24.18615,0.417486,8.214763,1.525709,7.750072,2.055495,-7.210148,-2.001652,-6.364767,13.05844,-0.020086,-0.002822,0.131253,6.017192,2.325582,0.323750,0.279148,0.086250,0.056111,-36.04536,VFP10,Speech1,1,VFP10_Speech1,18.435575,0.228489,14.486988,22.229799
1,41.11026,0.121323,37.82011,40.29837,42.86220,5.042088,65.27183,67.44999,107.02950,136.12460,0.352636,0.772619,0.091338,0.276297,0.643402,0.552064,5.773097,3.043491,5.736364,3.096833,0.165227,1.078465,16.27722,0.886090,4.175593,2.418372,4.164841,4.175746,-1.238956,-10.893710,0.033043,0.955943,1.110513,1.005039,8.546047,0.610735,10.095680,1.019345,16.71051,0.952560,609.8846,0.475136,1317.956,0.256875,-141.02610,-0.608951,1612.576,0.212125,900.7050,0.395662,-111.65260,-0.791487,2706.250,0.155800,773.9065,0.492648,-113.50150,-0.762943,-9.636793,-1.066963,20.83702,0.549675,0.031475,1.370808,-0.011556,-0.940793,0.281724,0.656881,21.47814,0.785475,2.729456,4.197899,-2.206290,-9.858141,-11.921610,-0.857818,-10.492630,18.57218,-0.030845,-0.001601,0.057944,3.724928,1.749271,0.271667,0.260411,0.276667,0.294430,-37.11591,VFP10,Speech2,1,VFP10_Speech2,17.512815,0.266492,13.536888,22.317354
2,40.64695,0.103110,38.43892,40.61772,43.04973,4.610809,70.18970,41.75834,143.79800,158.40300,0.455030,0.595214,0.193809,0.414761,0.771635,0.577826,5.880630,3.201431,4.344426,2.605690,0.192152,0.685843,16.00158,0.908747,1.617499,10.016700,2.134295,6.868225,0.189451,55.625250,0.029924,1.759617,1.051071,0.979267,9.214264,0.533516,9.660620,1.029315,16.85061,0.659495,631.3687,0.335223,1214.071,0.301923,-101.13850,-0.902200,1638.225,0.144500,934.1242,0.369319,-82.60666,-0.987637,2679.333,0.106968,783.9202,0.406726,-86.69166,-0.905688,-10.902040,-0.925076,20.16452,0.549932,0.026402,1.569013,-0.007047,-1.733416,0.232238,0.547707,22.32532,0.455623,-3.063484,-4.477240,-0.415351,-36.927860,-2.411690,-4.179638,-8.437313,16.16711,-0.026844,-0.003204,0.118944,5.157593,3.498543,0.175000,0.168201,0.118889,0.079365,-36.20234,VFP10,Speech3,1,VFP10_Speech3,18.321217,0.245712,14.002137,23.058219
3,30.43643,0.271136,21.39732,35.07611,37.99580,16.598470,520.01150,754.83490,113.23290,72.07719,0.361249,0.447226,0.200881,0.350636,0.494258,0.293377,3.505941,1.580694,2.991900,1.599806,0.152279,0.561194,17.09886,0.571667,9.499130,0.938797,15.013220,1.094950,-1.632905,-6.753133,0.061652,1.476092,1.323190,0.675948,4.797133,1.021562,7.106359,2.069486,17.30584,0.715088,495.4891,0.446891,1164.115,0.214102,-108.62230,-0.835690,1611.846,0.172599,685.3447,0.616692,-96.16886,-0.857065,2652.961,0.097656,725.8063,0.541421,-96.83107,-0.847054,-14.949060,-0.359614,23.91982,0.319472,-0.003510,-9.079530,-0.015711,-0.624240,0.175582,0.453718,21.25350,0.390166,12.484160,0.547064,13.903480,1.263885,-0.754947,-14.337070,-9.696174,18.68855,-0.035889,-0.006141,0.116599,3.724928,3.197675,0.174545,0.154176,0.116364,0.093738,-40.33198,VFP11,Speech1,1,VFP11_Speech1,14.886433,0.150245,13.171427,16.589531
4,32.12006,0.241961,22.38828,35.72247,37.69841,15.310130,1028.64100,954.45980,159.64700,112.48600,0.383630,0.413864,0.241884,0.355205,0.527708,0.285824,4.263408,3.172111,3.038845,1.653884,0.151369,0.475207,12.93016,0.957830,9.103159,0.968121,20.513380,0.633031,-2.292660,-4.604483,0.054972,1.229439,1.649192,0.817109,5.639021,0.829598,13.270850,1.035262,17.27562,0.684350,444.5576,0.284360,1154.101,0.208867,-115.89510,-0.761690,1595.626,0.137739,596.8990,0.641410,-108.24170,-0.744009,2653.475,0.071088,772.4484,0.535277,-107.27790,-0.759111,-14.976390,-0.407176,22.52378,0.311171,-0.010057,-2.957649,-0.014513,-0.663797,0.176212,0.397830,19.65244,0.415014,12.042640,0.726816,21.610900,0.671831,-2.279478,-4.844434,-7.212510,14.82589,-0.046613,-0.006085,0.118889,5.730659,4.081633,0.121429,0.107628,0.120833,0.178814,-39.28465,VFP11,Speech2,1,VFP11_Speech2,15.198459,0.144177,13.505874,16.946842
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
448,32.82069,0.073434,30.89363,32.34020,34.59331,3.699677,182.45580,326.94090,58.27354,55.20131,0.163460,0.522306,0.075665,0.156032,0.241769,0.166104,1.885137,1.111351,1.351178,0.590421,0.046189,0.667091,27.36822,0.450612,15.831880,0.795290,18.915380,0.760017,6.812958,1.755961,0.037639,1.190368,0.964755,0.991380,9.131144,0.310881,8.164756,1.006656,28.56432,0.322647,496.2999,0.316247,1274.801,0.171952,-66.06810,-1.246946,1621.714,0.136256,937.1945,0.537158,-76.53618,-0.946130,2685.701,0.108204,984.9135,0.431121,-80.27860,-0.874510,-20.851850,-0.339330,31.00552,0.284932,0.026031,1.334945,-0.015354,-0.629295,0.056202,0.534305,31.05387,0.374250,16.661340,0.687249,20.445920,0.756702,5.125762,2.416586,-17.455410,28.20164,-0.019634,-0.001798,0.020782,4.871060,2.332362,0.302500,0.161071,0.106250,0.154186,-46.03365,VFPNorm8,Speech2,0,VFPNorm8_Speech2,17.564739,0.196077,14.454194,20.907656
449,32.18818,0.070383,30.34963,31.72725,33.34668,2.997049,28.74552,11.09664,26.86969,23.55212,0.144799,0.503324,0.069550,0.141103,0.205773,0.136223,1.478385,0.844526,1.186362,0.551909,0.038738,0.706269,23.21600,0.767392,15.108680,0.809788,14.349510,1.121815,8.470038,1.412676,0.032173,1.535001,0.937512,1.005167,9.108112,0.285027,10.054340,0.843113,32.12694,0.272087,543.8298,0.384883,1265.332,0.213014,-89.37212,-0.993368,1560.604,0.176490,931.4860,0.385068,-95.93403,-0.847097,2698.299,0.133518,883.1241,0.425145,-101.72920,-0.751511,-18.480710,-0.416476,29.56695,0.284443,0.014625,1.903894,-0.013567,-0.875202,0.046592,0.616653,31.61001,0.421032,14.857740,0.835053,13.938850,1.313820,5.178216,2.521554,-14.754670,24.56603,-0.028107,-0.001130,0.027894,3.724928,3.206997,0.177273,0.180811,0.127000,0.129232,-48.11572,VFPNorm8,Speech3,0,VFPNorm8_Speech3,17.078172,0.204099,13.790553,20.426678
450,24.48411,0.180009,22.78093,25.40892,28.27801,5.497078,197.19170,130.46350,45.02444,25.83455,0.117911,0.483579,0.063525,0.108178,0.171664,0.108140,1.420177,0.629638,1.054868,0.484441,0.033011,0.519085,26.10903,0.543691,16.517930,0.593427,11.507460,1.250198,5.021239,2.448002,0.079452,2.069847,1.355051,0.954663,5.394252,0.598302,2.956565,1.731019,30.86301,0.207399,448.8878,0.313844,1220.155,0.122200,-85.09564,-1.085957,1461.210,0.187567,939.6982,0.327941,-95.18096,-0.886000,2521.296,0.101952,810.5140,0.457890,-100.74440,-0.788425,-19.690540,-0.306501,30.82067,0.229403,0.011569,2.887912,-0.020267,-0.655963,0.041008,0.408967,34.67593,0.267692,17.423910,0.488991,11.429120,1.455805,-0.466549,-26.340980,-17.850630,28.72795,-0.044269,-0.001048,0.022439,5.444126,2.616279,0.216667,0.127715,0.145556,0.118144,-51.09922,VFPNorm9,Speech1,0,VFPNorm9_Speech1,16.552551,0.203620,13.647662,19.809361
451,25.80380,0.099246,24.82298,26.69952,27.15441,2.331427,169.48410,131.27800,36.31673,24.31405,0.103108,0.577983,0.043224,0.085041,0.167701,0.124477,1.190860,0.662923,0.784700,0.549603,0.028304,0.692012,23.85000,0.536352,11.695800,0.827275,11.894940,0.918881,7.271324,1.505499,0.061378,2.061115,1.270060,1.010738,5.545829,0.638101,4.890247,1.119272,31.56176,0.243222,493.9183,0.290058,1274.643,0.178984,-114.18080,-0.816936,1501.698,0.181778,1011.0570,0.350117,-121.17270,-0.705554,2551.819,0.116107,915.6561,0.436638,-124.99080,-0.650363,-20.627660,-0.337539,30.83374,0.230944,0.006369,3.985361,-0.021191,-0.621737,0.041655,0.438832,35.60788,0.214996,15.542970,0.686406,12.781720,1.142406,2.036459,6.614008,-15.938160,27.39670,-0.039000,-0.000503,0.017274,4.011461,1.749271,0.248333,0.106053,0.303333,0.242051,-51.94005,VFPNorm9,Speech2,0,VFPNorm9_Speech2,16.389359,0.211420,13.639057,19.368173


# Preprocessing

# Descriptive statistics

In [6]:

def corrdot(*args, **kwargs):
    corr_r = args[0].corr(args[1], 'spearman')
    corr_text = f"{corr_r:2.2f}".replace("0.", ".")
    ax = plt.gca()
    ax.set_axis_off()
#     print(abs(corr_r) *10000)
    
    marker_size = abs(corr_r) * 10000
    
    ax.scatter([.5], [.5], marker_size, [corr_r], alpha=0.6, cmap="coolwarm",
               vmin=-1, vmax=1, transform=ax.transAxes)
    font_size = 40 #abs(corr_r) * 40 + 5
    ax.annotate(corr_text, [.5, .5,],  xycoords="axes fraction",
                ha='center', va='center', fontsize=font_size)

In [7]:
run_this = False

if run_this:
    # Pairwise correlation plot
    sns.set(style='white', font_scale=1.6)
    # iris = sns.load_dataset('iris')
    if run_toy:
      g = sns.PairGrid(df.sample(frac=0.1), aspect=1.4, diag_sharey=False)
    else:
      g = sns.PairGrid(df, aspect=1.4, diag_sharey=False)
    g.map_lower(sns.regplot, lowess=True, ci=True, line_kws={'color': 'black'}, fit_reg=True,
              x_jitter=.1, y_jitter=.1, 
                scatter_kws={"s": 1, "alpha":0.1}
                )
    g.map_diag(sns.distplot, kde_kws={'color': 'black'})
    g.map_upper(corrdot)
    plt.show()

In [8]:
def add_top_column(df, top_col, inplace=True):
    if not inplace:
        df = df.copy()
    
    df.columns = pd.MultiIndex.from_product([[top_col], df.columns])
    return df

In [9]:
df.columns

Index(['F0semitoneFrom27.5Hz_sma3nz_amean',
       'F0semitoneFrom27.5Hz_sma3nz_stddevNorm',
       'F0semitoneFrom27.5Hz_sma3nz_percentile20.0',
       'F0semitoneFrom27.5Hz_sma3nz_percentile50.0',
       'F0semitoneFrom27.5Hz_sma3nz_percentile80.0',
       'F0semitoneFrom27.5Hz_sma3nz_pctlrange0-2',
       'F0semitoneFrom27.5Hz_sma3nz_meanRisingSlope',
       'F0semitoneFrom27.5Hz_sma3nz_stddevRisingSlope',
       'F0semitoneFrom27.5Hz_sma3nz_meanFallingSlope',
       'F0semitoneFrom27.5Hz_sma3nz_stddevFallingSlope', 'loudness_sma3_amean',
       'loudness_sma3_stddevNorm', 'loudness_sma3_percentile20.0',
       'loudness_sma3_percentile50.0', 'loudness_sma3_percentile80.0',
       'loudness_sma3_pctlrange0-2', 'loudness_sma3_meanRisingSlope',
       'loudness_sma3_stddevRisingSlope', 'loudness_sma3_meanFallingSlope',
       'loudness_sma3_stddevFallingSlope', 'spectralFlux_sma3_amean',
       'spectralFlux_sma3_stddevNorm', 'mfcc1_sma3_amean',
       'mfcc1_sma3_stddevNorm', 'mfcc2_

In [10]:
# Create DFs for each independent variable



variables = ["cpp_amean","cpp_stddevNorm","cpp_percentile20", "cpp_percentile80",]

X = df[variables].values
y = df['target'].values


In [11]:
# Observe the range of the covariates
df.describe()

Unnamed: 0,F0semitoneFrom27.5Hz_sma3nz_amean,F0semitoneFrom27.5Hz_sma3nz_stddevNorm,F0semitoneFrom27.5Hz_sma3nz_percentile20.0,F0semitoneFrom27.5Hz_sma3nz_percentile50.0,F0semitoneFrom27.5Hz_sma3nz_percentile80.0,F0semitoneFrom27.5Hz_sma3nz_pctlrange0-2,F0semitoneFrom27.5Hz_sma3nz_meanRisingSlope,F0semitoneFrom27.5Hz_sma3nz_stddevRisingSlope,F0semitoneFrom27.5Hz_sma3nz_meanFallingSlope,F0semitoneFrom27.5Hz_sma3nz_stddevFallingSlope,loudness_sma3_amean,loudness_sma3_stddevNorm,loudness_sma3_percentile20.0,loudness_sma3_percentile50.0,loudness_sma3_percentile80.0,loudness_sma3_pctlrange0-2,loudness_sma3_meanRisingSlope,loudness_sma3_stddevRisingSlope,loudness_sma3_meanFallingSlope,loudness_sma3_stddevFallingSlope,spectralFlux_sma3_amean,spectralFlux_sma3_stddevNorm,mfcc1_sma3_amean,mfcc1_sma3_stddevNorm,mfcc2_sma3_amean,mfcc2_sma3_stddevNorm,mfcc3_sma3_amean,mfcc3_sma3_stddevNorm,mfcc4_sma3_amean,mfcc4_sma3_stddevNorm,jitterLocal_sma3nz_amean,jitterLocal_sma3nz_stddevNorm,shimmerLocaldB_sma3nz_amean,shimmerLocaldB_sma3nz_stddevNorm,HNRdBACF_sma3nz_amean,HNRdBACF_sma3nz_stddevNorm,logRelF0-H1-H2_sma3nz_amean,logRelF0-H1-H2_sma3nz_stddevNorm,logRelF0-H1-A3_sma3nz_amean,logRelF0-H1-A3_sma3nz_stddevNorm,F1frequency_sma3nz_amean,F1frequency_sma3nz_stddevNorm,F1bandwidth_sma3nz_amean,F1bandwidth_sma3nz_stddevNorm,F1amplitudeLogRelF0_sma3nz_amean,F1amplitudeLogRelF0_sma3nz_stddevNorm,F2frequency_sma3nz_amean,F2frequency_sma3nz_stddevNorm,F2bandwidth_sma3nz_amean,F2bandwidth_sma3nz_stddevNorm,F2amplitudeLogRelF0_sma3nz_amean,F2amplitudeLogRelF0_sma3nz_stddevNorm,F3frequency_sma3nz_amean,F3frequency_sma3nz_stddevNorm,F3bandwidth_sma3nz_amean,F3bandwidth_sma3nz_stddevNorm,F3amplitudeLogRelF0_sma3nz_amean,F3amplitudeLogRelF0_sma3nz_stddevNorm,alphaRatioV_sma3nz_amean,alphaRatioV_sma3nz_stddevNorm,hammarbergIndexV_sma3nz_amean,hammarbergIndexV_sma3nz_stddevNorm,slopeV0-500_sma3nz_amean,slopeV0-500_sma3nz_stddevNorm,slopeV500-1500_sma3nz_amean,slopeV500-1500_sma3nz_stddevNorm,spectralFluxV_sma3nz_amean,spectralFluxV_sma3nz_stddevNorm,mfcc1V_sma3nz_amean,mfcc1V_sma3nz_stddevNorm,mfcc2V_sma3nz_amean,mfcc2V_sma3nz_stddevNorm,mfcc3V_sma3nz_amean,mfcc3V_sma3nz_stddevNorm,mfcc4V_sma3nz_amean,mfcc4V_sma3nz_stddevNorm,alphaRatioUV_sma3nz_amean,hammarbergIndexUV_sma3nz_amean,slopeUV0-500_sma3nz_amean,slopeUV500-1500_sma3nz_amean,spectralFluxUV_sma3nz_amean,loudnessPeaksPerSec,VoicedSegmentsPerSec,MeanVoicedSegmentLengthSec,StddevVoicedSegmentLengthSec,MeanUnvoicedSegmentLength,StddevUnvoicedSegmentLength,equivalentSoundLevel_dBp,target,cpp_amean,cpp_stddevNorm,cpp_percentile20,cpp_percentile80
count,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0,453.0
mean,27.773086,0.180886,24.501736,28.227031,30.741257,6.239521,259.302151,293.475657,95.229031,112.744889,0.334331,0.643889,0.122457,0.283488,0.528974,0.406517,4.68477,2.935852,3.735202,2.353543,0.193536,0.876727,20.37686,0.731166,10.776994,0.853266,12.454899,1.560939,5.576572,1.476789,0.055273,1.487167,1.313254,0.872286,4.889798,0.78495,5.72722,0.913835,21.730246,0.601766,524.183745,0.384546,1207.86384,0.196818,-93.12478,-1.018919,1555.705461,0.175466,876.139507,0.407901,-97.22107,-0.85503,2598.852155,0.112206,805.464425,0.444826,-99.577595,-0.806093,-15.878447,-0.492537,25.11642,0.41997,0.005709,9.00528,-0.015427,-0.798278,0.250961,0.678841,27.278677,0.388581,11.152781,0.680789,12.93603,3.31691,3.407402,2.746749,-11.896896,21.266978,-0.03468,-0.002237,0.116521,3.685114,2.302538,0.26819,0.232763,0.216193,0.194593,-40.218405,0.490066,17.003533,0.221112,13.624324,20.631712
std,5.72734,0.108714,7.109427,6.33457,6.119611,5.740371,233.846926,308.099106,145.585508,167.002546,0.217935,0.129142,0.073159,0.191295,0.363352,0.314665,4.071452,3.639947,3.159179,2.37898,0.289213,0.411603,4.704048,0.377598,4.995327,5.726374,5.577447,2.408659,5.41381,26.055441,0.031972,0.452772,0.371451,0.179732,3.257428,4.182756,5.099448,52.39921,6.281371,0.649668,90.863766,0.081074,134.822094,0.045086,34.446862,0.341749,177.671276,0.032387,128.618558,0.108751,31.189685,0.208976,268.971056,0.022042,131.697392,0.11511,30.458584,0.186828,3.920372,0.247867,4.555246,1.345527,0.025963,210.652117,0.006658,2.081297,0.387468,0.354488,5.547402,0.185532,6.544443,11.980173,8.108624,30.224763,8.121887,44.540852,3.95383,4.881126,0.015876,0.004791,0.185645,0.829925,0.666932,0.159867,0.171055,0.312,0.125715,9.304772,0.500454,1.801019,0.057476,0.533749,3.403467
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-441.4793,0.0,0.04189,0.103423,0.029156,0.040626,0.044929,0.00627,0.175077,0.229224,0.112962,0.15254,0.00908,0.351456,1.821961,0.188744,-2.866461,-78.58298,-7.121749,-9.103332,-14.7408,-435.1182,0.0,0.0,0.0,0.0,-6.178151,-56.27238,-30.87626,-1088.493,-28.148,-8.293115,0.0,0.0,0.0,0.0,-201.0,-2.255436,0.0,0.0,0.0,0.0,-201.0,-1.328574,0.0,0.0,0.0,0.0,-201.0,-1.238518,-24.30405,-1.880323,0.0,0.0,-0.113657,-487.7057,-0.036467,-12.44503,0.0,0.0,-16.32981,-1.100615,-9.374919,-133.3555,-16.90892,-44.49627,-26.08918,-275.7294,-26.14068,10.02642,-0.083885,-0.019222,0.009083,1.712483,0.0,0.0,0.0,0.0275,0.0,-69.17478,0.0,14.056408,0.110981,12.580905,15.355588
25%,24.17795,0.098649,19.46681,24.27154,26.72837,2.588263,112.0131,94.01752,33.35004,25.98172,0.167986,0.564292,0.06955,0.141103,0.261552,0.179364,2.174997,1.1452,1.808901,1.040513,0.046269,0.63601,17.17762,0.560696,7.846159,0.706274,8.546057,0.803981,2.235135,0.864667,0.032075,1.169063,1.082655,0.764868,3.081957,0.427521,2.75191,1.071203,18.19321,0.362436,478.0557,0.334315,1176.238,0.171949,-113.3196,-1.213154,1496.903,0.15783,819.6248,0.342813,-114.9155,-0.990652,2543.167,0.10017,742.8032,0.36559,-116.5816,-0.925613,-18.70609,-0.553224,22.56676,0.292049,-0.01033,-1.344008,-0.019224,-1.024222,0.062908,0.490611,24.73734,0.309626,7.333704,0.593549,7.861609,0.708069,-0.862195,-1.362136,-14.21167,17.83577,-0.044269,-0.004882,0.02836,3.151862,1.816118,0.179167,0.134536,0.124444,0.115515,-47.6288,0.0,15.449793,0.172552,13.264555,17.318851
50%,27.63382,0.150999,24.5925,28.71903,30.98429,3.799084,188.5835,184.9389,69.79765,70.1663,0.264598,0.643075,0.102379,0.225703,0.414051,0.301807,3.592666,1.980574,2.902293,1.722978,0.105108,0.760844,20.22096,0.681601,10.73927,0.968121,12.06348,1.09495,5.69286,1.461549,0.046282,1.476092,1.269155,0.881893,5.151922,0.674065,5.395566,1.873109,22.12645,0.493931,508.3525,0.395157,1222.283,0.192762,-92.46409,-0.985871,1557.735,0.176463,890.2863,0.399022,-96.75571,-0.861373,2614.511,0.112229,804.1733,0.436676,-99.58066,-0.813755,-16.14501,-0.463387,25.61967,0.349932,0.006369,0.957554,-0.015043,-0.797808,0.129728,0.592609,27.71767,0.365915,11.7006,0.88956,13.15574,1.06928,3.18086,1.139261,-11.69449,21.08169,-0.034342,-0.001455,0.057008,3.724928,2.332362,0.237778,0.193776,0.176667,0.175619,-40.93425,0.0,16.990854,0.234159,13.615555,20.674922
75%,31.61681,0.248335,30.04253,32.4717,34.99295,7.419657,330.8329,390.6895,127.1552,131.3925,0.455637,0.721746,0.161954,0.385909,0.72981,0.540669,6.04021,3.526557,4.72824,2.843472,0.226346,0.95609,23.69091,0.832952,14.44822,1.442621,16.1141,1.58251,9.348991,2.787834,0.072565,1.784315,1.50077,0.991683,7.287317,1.046216,8.748993,3.239242,25.78078,0.703408,560.1079,0.435684,1264.474,0.216718,-72.93134,-0.810123,1618.664,0.194366,953.9945,0.461946,-79.59825,-0.744695,2679.301,0.126046,879.2843,0.505434,-81.87444,-0.705348,-13.56808,-0.386336,27.9745,0.40882,0.022644,2.339598,-0.011386,-0.641194,0.292596,0.746487,30.40254,0.438311,15.55342,1.403709,18.61148,1.767408,8.725277,2.865714,-9.195477,24.26268,-0.025549,0.000828,0.130639,4.297994,2.647059,0.309429,0.284144,0.227778,0.25734,-33.38766,1.0,18.250852,0.266162,13.896243,23.320696
max,49.16956,0.744891,51.95441,53.8151,56.81369,32.29908,1802.76,2025.72,2434.991,2427.714,1.216163,1.117512,0.444798,1.178534,2.035199,1.771311,33.74147,41.01346,19.99471,23.7948,2.258952,3.221594,36.32645,6.606424,23.59145,25.35354,28.93164,27.72505,18.4247,160.64,0.159942,2.946238,2.914299,1.240887,11.74264,33.25824,25.60875,118.9058,37.17638,3.317974,998.6246,0.645759,1399.877,0.560359,-14.21301,0.0,2053.354,0.283605,1174.822,0.895936,-28.44161,0.0,3145.756,0.183249,1429.579,0.961035,-32.27789,0.0,5.805379,2.596702,33.77319,28.91362,0.081327,4443.285,0.007087,26.24671,3.283502,3.593985,41.8151,2.751229,25.3492,106.3512,37.25217,607.3204,24.054,743.7161,-1.811306,37.2087,0.015021,0.009007,1.686988,6.590258,4.081633,1.275555,1.759411,3.43,1.157648,-14.6006,1.0,22.406372,0.33621,16.307431,28.654261


In [20]:
import random
from sklearn.inspection import permutation_importance
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV 
from sklearn.naive_bayes import MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, roc_auc_score

### Bootstrapping

In [21]:
from sklearn.linear_model import LogisticRegressionCV, SGDClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier


models = [
    LogisticRegressionCV(solver='liblinear', penalty = 'l1', max_iter = 100),
    MLPClassifier(alpha = 1, max_iter= 1000),
    RandomForestClassifier(n_estimators= 100),
    SGDClassifier(loss='log', penalty="elasticnet", early_stopping=True, max_iter = 5000),
]


names = ['LogisticRegressionCV', "MLPClassifier","RandomForestClassifier",'SGDClassifier']

In [23]:
import random


In [27]:
y_train_shuffled = random.shuffle(y_train)

In [32]:
%%time

from sklearn.model_selection import train_test_split

toy = False

filenames = ['speech_cpp', 'vowel_cpp', 'both_cpp']



for filename in filenames:

    for null_model in [True, False]:
        df = pd.read_csv(input_dir + f'features/egemaps_vector_{filename}.csv', index_col = 0)
        variables = ["cpp_amean","cpp_stddevNorm","cpp_percentile20", "cpp_percentile80",]
        X = df[variables].values
        y = df['target'].values
        print('\npermute', null_model)

        if toy:
          n_bootstraps = 3
        else:
          n_bootstraps = 50

        if null_model:
            y = np.random.permutation(y) #CHECK


        y_pred_all = {}
        roc_auc_all = {}
        for model, name in zip(models, names):
          y_pred_all[name] = []
          roc_auc_all[name] = []
          pipe = Pipeline(steps=[
                  ('scaler', StandardScaler()), 
                  ('model', model)])

          ## Performing bootstrapping
          for i in range(n_bootstraps):
              #Split the data into training and testing set

              # Chaning the seed value for each iteration
              X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42+i)

          # for train_index, test_index in bs:
          #   X_train, X_test, y_train, y_test = X[train_index], X[test_index], y[train_index], y[test_index]
              pipe.fit(X_train,y_train)

              # Test     
              y_pred = pipe.predict(X_test)
              roc_auc = roc_auc_score(y_test,y_pred)
              # print(name, roc_auc)
              y_pred_all[name].append(y_pred)
              roc_auc_all[name].append(roc_auc)

        results_i = []
        for name in ['LogisticRegressionCV','MLPClassifier','RandomForestClassifier','SGDClassifier']:
          scores = roc_auc_all.get(name)
          roc_auc_median = np.round(np.median(scores),2)
          roc_auc_5 = np.round(np.percentile(scores, 5),2)
          roc_auc_95 = np.round(np.percentile(scores, 95),2)
          results_str = f'{roc_auc_median} ({roc_auc_5}–{roc_auc_95}; )'
          results_str = results_str.replace('0.', '.')
          results_i.append([name, results_str])

          if null_model:
            print(name, str(roc_auc_median).replace('0.', '.'))
        if not null_model:
            results_i_df = pd.DataFrame(results_i, ).T
            display(results_i_df)
            results_i_df.to_csv(output_dir+f'results_{filename}_permute-{null_model}.csv')




        # pd.DataFrame(y_pred_all)



permute True
LogisticRegressionCV .5
MLPClassifier .5
RandomForestClassifier .52
SGDClassifier .5

permute False


Unnamed: 0,0,1,2,3
0,LogisticRegressionCV,MLPClassifier,RandomForestClassifier,SGDClassifier
1,.77 (.69–.81; ),.76 (.69–.82; ),.73 (.66–.8; ),.74 (.33–.81; )



permute True




LogisticRegressionCV .5
MLPClassifier .49
RandomForestClassifier .51
SGDClassifier .5

permute False


Unnamed: 0,0,1,2,3
0,LogisticRegressionCV,MLPClassifier,RandomForestClassifier,SGDClassifier
1,.81 (.75–.87; ),.8 (.74–.87; ),.77 (.7–.81; ),.76 (.32–.86; )



permute True
LogisticRegressionCV .5
MLPClassifier .5
RandomForestClassifier .53
SGDClassifier .5

permute False


Unnamed: 0,0,1,2,3
0,LogisticRegressionCV,MLPClassifier,RandomForestClassifier,SGDClassifier
1,.73 (.69–.78; ),.77 (.72–.81; ),.75 (.69–.79; ),.7 (.5–.76; )


CPU times: user 3min 11s, sys: 3min 11s, total: 6min 22s
Wall time: 2min 38s
