In [1]:
import json

from pathlib import Path

import pandas as pd
import numpy as np

import biopsykit as bp
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.impute import SimpleImputer


#Feature Selection
from sklearn.feature_selection import SelectKBest, RFE

#Classification
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC

# Regression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import AdaBoostRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import HistGradientBoostingRegressor

# Cross-Validation
from sklearn.model_selection import GroupKFold

from biopsykit.classification.model_selection import SklearnPipelinePermuter

import matplotlib.pyplot as plt

%matplotlib widget
%load_ext autoreload
%autoreload 2

In [2]:
# Notebook-wide switch for persisting outputs; presumably checked by later cells
# before writing anything under `models_path` — TODO confirm.
save_results: bool = False

In [3]:
# Folder holding the prepared tables (train_data.csv is read from here below).
# The path is relative, so it is resolved against the kernel's current working directory.
data_path = Path("..", "..", "results", "data")
data_path

WindowsPath('../../results/data')

In [4]:
# Sibling of the data folder under ../../results; presumably the target for
# fitted model artifacts when save_results is enabled — TODO confirm.
models_path = Path("../..") / "results" / "models"

In [6]:
# Load the training table with a four-level row index:
#   participant, phase (e.g. HoldingBreath, Talk), b_point_sample_reference
#   (later passed as `label_col`) and heartbeat_id_reference.
# The index columns are selected by name rather than by position
# (previously index_col=[0, 1, 2, 3]). If the CSV column order ever changes,
# pandas then raises instead of silently moving a feature column into the index
# and leaving the label/subject levels missing.
input_data = pd.read_csv(
    data_path.joinpath("train_data.csv"),
    index_col=["participant", "phase", "b_point_sample_reference", "heartbeat_id_reference"],
)
input_data

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,arbol2017-isoelectric-crossings_forouzanfar2018,arbol2017-isoelectric-crossings_linear-interpolation,arbol2017-isoelectric-crossings_none,arbol2017-second-derivative_forouzanfar2018,arbol2017-second-derivative_linear-interpolation,arbol2017-second-derivative_none,arbol2017-third-derivative_forouzanfar2018,arbol2017-third-derivative_linear-interpolation,arbol2017-third-derivative_none,debski1993-second-derivative_forouzanfar2018,...,lozano2007-linear-regression_none,lozano2007-quadratic-regression_forouzanfar2018,lozano2007-quadratic-regression_linear-interpolation,lozano2007-quadratic-regression_none,sherwood1990_forouzanfar2018,sherwood1990_linear-interpolation,sherwood1990_none,stern1985_forouzanfar2018,stern1985_linear-interpolation,stern1985_none
participant,phase,b_point_sample_reference,heartbeat_id_reference,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
GDN0005,HoldingBreath,388.0,0,438.0,438.0,438.0,398.0,398.0,398.0,394.0,394.0,394.0,452.0,...,412.0,384.0,384.0,384.0,442.0,442.0,442.0,388.0,388.0,388.0
GDN0005,HoldingBreath,404.0,1,420.0,422.0,340.0,350.0,350.0,350.0,396.0,386.0,244.0,426.0,...,404.0,384.0,384.0,384.0,422.0,420.0,330.0,402.0,402.0,402.0
GDN0005,HoldingBreath,376.0,3,382.0,382.0,382.0,296.0,296.0,296.0,386.0,386.0,386.0,366.0,...,366.0,348.0,348.0,348.0,382.0,382.0,382.0,374.0,374.0,374.0
GDN0005,HoldingBreath,390.0,4,394.0,394.0,394.0,344.0,344.0,344.0,396.0,396.0,396.0,376.0,...,372.0,354.0,358.0,348.0,394.0,394.0,394.0,388.0,388.0,388.0
GDN0005,HoldingBreath,386.0,5,398.0,398.0,398.0,312.0,312.0,312.0,388.0,388.0,388.0,418.0,...,378.0,360.0,366.0,354.0,400.0,400.0,400.0,384.0,384.0,384.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
VP_032,Talk,310.0,39,335.0,335.0,335.0,276.0,276.0,276.0,324.0,324.0,324.0,309.0,...,305.0,294.0,294.0,294.0,337.0,337.0,337.0,306.0,306.0,306.0
VP_032,Talk,322.0,40,331.0,331.0,331.0,298.0,298.0,298.0,342.0,345.0,330.0,318.0,...,321.0,316.0,317.0,303.0,329.0,329.0,329.0,287.0,287.0,287.0
VP_032,Talk,340.0,41,317.0,317.0,317.0,300.0,300.0,300.0,348.0,348.0,348.0,330.0,...,332.0,322.0,322.0,322.0,311.0,311.0,311.0,287.0,287.0,287.0
VP_032,Talk,311.0,42,365.0,365.0,365.0,324.0,324.0,324.0,366.0,366.0,366.0,347.0,...,351.0,337.0,337.0,337.0,368.0,368.0,368.0,305.0,305.0,305.0


In [7]:
# Convert the multi-indexed frame into sklearn-ready arrays:
#   X          – feature matrix (one row per heartbeat)
#   y          – target built from the "b_point_sample_reference" index level
#   groups     – per-row participant labels, presumably for GroupKFold — TODO confirm
#   group_keys – the distinct participant IDs
# print_summary=True prints shapes and the per-value prevalence of y.
X, y, groups, group_keys = bp.classification.utils.prepare_df_sklearn(
    data=input_data,
    label_col="b_point_sample_reference",
    subject_col="participant",
    print_summary=True,
)

Shape of X: (10385, 30); shape of y: (10385,); number of groups: 39, class prevalence: [ 2  1  1  4  1  1  1  1  1  2  1  1  2  1  1  2  3  2  6  3  9  4  8  7
 10  4  5 11  7 13  9 11 14 11 15  9 12 15 17 21 22 20 28 27 29 28 34 35
 44 37 37 39 30 34 46 40 50 47 41 38 23 21 26 34 34 27 33 37 26 40 20 26
 33 23 23 29 31 18 28 24 19 29 17 20 24 23 23 11 18 20 16 23 26 24 22 10
 15 14 17 14 17 20 16 13 17 17 25 16 12 19 21 18 19 19 22 17 22 19 21 24
 24 18 27 21 17 23 13 14 22 21 25 12 21 18 27 14 26 17 20 23 19  9 22 16
 24 11 18 21 25 16 31 20 23 25 26 24 27 10 36 15 36 12 36 21 29 16 32 23
 55 16 48 14 38 28 31 11 52  8 56 20 54 23 50 14 53 18 63 18 56 18 55 24
 48 21 59 14 57 11 73 19 44 18 56 12 72 17 56 14 53 16 63 18 77 12 84  9
 60 16 69  9 59  4 66 16 64  7 66  9 82  9 83  6 73  8 71 10 81 10 90  9
 69  7 68  4 74  5 71  3 61  7 78  5 77  2 82  7 73  8 64  8 61  1 81  2
 76  3 63 73  2 68  3 64  2 86  6 66  2 84  2 78  1 66  2 65 72 77  2 56
  1 71 69  1 59  1 76 49 67  1 68 63 

In [8]:
# Show the distinct participant IDs returned by prepare_df_sklearn (39 here,
# matching "number of groups" in the summary above). These presumably define the
# groups used for participant-wise cross-validation — TODO confirm.
group_keys

Index(['GDN0005', 'GDN0006', 'GDN0007', 'GDN0008', 'GDN0009', 'GDN0010',
       'GDN0011', 'GDN0012', 'GDN0013', 'GDN0014', 'GDN0016', 'GDN0017',
       'GDN0018', 'GDN0019', 'GDN0020', 'GDN0021', 'GDN0022', 'GDN0023',
       'GDN0024', 'GDN0025', 'GDN0027', 'GDN0028', 'GDN0029', 'GDN0030',
       'VP_001', 'VP_002', 'VP_003', 'VP_004', 'VP_005', 'VP_020', 'VP_022',
       'VP_023', 'VP_026', 'VP_027', 'VP_028', 'VP_029', 'VP_030', 'VP_031',
       'VP_032'],
      dtype='object')