In [14]:
import pandas as pd
import pandas_profiling
import random

import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

import pickle

from sklearn.preprocessing import StandardScaler, MinMaxScaler, LabelEncoder, OneHotEncoder
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix

import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch import nn
from pytorch_tabnet.pretraining import TabNetPretrainer
from pytorch_tabnet.tab_model import TabNetClassifier
import torch

pd.options.display.max_columns = 200
random_state = 123

In [15]:
#pytorchのランダムシード固定
def torch_fix_seed(seed=123):
    # Python random
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Pytorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms = True
    
torch_fix_seed()

## データ修正

In [16]:
#データ読み込み
train = pd.read_csv("data/train_df.csv")
train2 = pd.read_csv("labeling_data/lgb_2nd.csv")

test = pd.read_csv("data/test_df.csv")
sample= pd.read_csv("data/submission.csv",header = None)

In [18]:
train = train2.drop('sum_glasgow', axis=1)

In [19]:
#カテゴリと考えられる変数を変更
ids = ['id', 'personal_id_1', 'personal_id_2',]
# cats = ['facility_id', 'icu_5', 'icu_7', 'icu_8', 'icu_id', 'situation_1', 'situation_2', 'glasgow_coma_scale_3']
cats = ['facility_id', 'icu_5', 'icu_7', 'icu_8', 'icu_id', 'situation_1', 'situation_2', 'glasgow_coma_scale_3', 'blood_oxy']
#cats = ['facility_id', 'icu_5', 'icu_id'] #best_cats
dis_name = ['aids', 'cirrhosis', 'diabetes', 'hepatic_issue', 'immunosuppression', 'leukemia', 'lymphoma', 'carcinoma'] 
for c in ids:
    train[c] = train[c].astype(object)
    test[c] = test[c].astype(object)

for c in cats:
    train[c] = train[c].astype(object)
    test[c] = test[c].astype(object)
    
for c in dis_name:
    train[c] = train[c].astype(object)
    test[c] = test[c].astype(object)

In [20]:
# データ確認
y_train=train[['target_label']]
x_train=train.drop('target_label',axis=1)
col_num = x_train.columns[x_train.dtypes!='object'].values.tolist()
print('数値データ')
print(col_num)
print('-'*100)
col_cat = x_train.columns[x_train.dtypes=='object'].values.tolist()
print('カテゴリ変数')
print(col_cat)

数値データ
['age', 'bmi', 'height', 'weight', 'icu_4', 'icu_6', 'glasgow_coma_scale_1', 'glasgow_coma_scale_2', 'glasgow_coma_scale_4', 'heart_rate', 'arterial_pressure', 'respiratory_rate', 'temp', 'blood_pressure_1', 'blood_pressure_2', 'blood_pressure_3', 'blood_pressure_4', 'v1_heartrate_max', 'v2', 'v3', 'v4', 'v5', 'v6', 'v7', 'v8', 'v9', 'v10', 'v11', 'v12', 'v13', 'v14', 'v15', 'v16', 'w1', 'w2', 'w3', 'w4', 'w5', 'w6', 'w7', 'w8', 'w9', 'w10', 'w11', 'w12', 'w13', 'w14', 'w15', 'w16', 'w17', 'w18', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6']
----------------------------------------------------------------------------------------------------
カテゴリ変数
['id', 'personal_id_1', 'personal_id_2', 'facility_id', 'situation_1', 'situation_2', 'ethnicity', 'gender', 'icu_id', 'icu_1', 'icu_2', 'icu_3', 'icu_5', 'icu_7', 'icu_8', 'glasgow_coma_scale_3', 'blood_oxy', 'aids', 'cirrhosis', 'diabetes', 'hepatic_issue', 'immunosuppression', 'leukemia', 'lymphoma', 'carcinoma', 'body_system_1', 'body_system_

In [21]:
"""
エンコーディングの方針
nunique数が少ない変数はOne Hot Encoding
nunique数が多い変数はLabel encoeding してから　tabnet内でembeding処理

body_systemsは同一診療科のものがあるがリスクが上がるのか疑問に思い、合算してOne Hot Encording
"""

low_cat_cols = [] # nunique が少ない変数
high_cat_cols = [] # nunique が多い変数

for c in col_cat:
    nunq = x_train[c].nunique()
    if nunq>30:
        high_cat_cols.append(c)
    else:
        low_cat_cols.append(c)
        
body = ['body_system_1', 'body_system_2']
for r in body:
    low_cat_cols.remove(r)
        
print('-'*20, 'unique_low', '-'*20)
print(low_cat_cols)
print('-'*20, 'unique_high', '-'*20)
print(high_cat_cols)
print('-'*20, 'body_systems', '-'*20)
print(body)

-------------------- unique_low --------------------
['situation_1', 'situation_2', 'ethnicity', 'gender', 'icu_1', 'icu_2', 'icu_3', 'icu_7', 'icu_8', 'glasgow_coma_scale_3', 'blood_oxy', 'aids', 'cirrhosis', 'diabetes', 'hepatic_issue', 'immunosuppression', 'leukemia', 'lymphoma', 'carcinoma']
-------------------- unique_high --------------------
['id', 'personal_id_1', 'personal_id_2', 'facility_id', 'icu_id', 'icu_5']
-------------------- body_systems --------------------
['body_system_1', 'body_system_2']


In [22]:
#少ない要素は確認してみる
for c in low_cat_cols:
    print(f"{c}:\t{x_train[c].unique()}")

situation_1:	[0 1]
situation_2:	[1.0 0.0 nan]
ethnicity:	['Caucasian' 'African American' 'Other/Unknown' 'Hispanic' nan 'Asian'
 'Native American']
gender:	['M' 'F' nan]
icu_1:	['Floor' 'Accident & Emergency' 'Operating Room / Recovery'
 'Other Hospital' 'Other ICU' nan]
icu_2:	['admit' 'readmit' 'transfer']
icu_3:	['MICU' 'CCU-CTICU' 'Med-Surg ICU' 'Neuro ICU' 'CSICU' 'SICU' 'CTICU'
 'Cardiac ICU']
icu_7:	[0 1]
icu_8:	[0.0 1.0 nan]
glasgow_coma_scale_3:	[0.0 1.0 nan]
blood_oxy:	[0.0 1.0 nan]
aids:	[0.0 nan 1.0]
cirrhosis:	[0.0 1.0 nan]
diabetes:	[1.0 0.0 nan]
hepatic_issue:	[0.0 1.0 nan]
immunosuppression:	[0.0 1.0 nan]
leukemia:	[0.0 1.0 nan]
lymphoma:	[0.0 nan 1.0]
carcinoma:	[0.0 1.0 nan]


In [23]:
#nunique が少ない変数はOne Hot

dict_low_cat = {}
for col in low_cat_cols:
    print(col)
    value_fillna = 'unknown'
    x_train[col] = x_train[col].fillna(value_fillna)
    x_train[col] = x_train[col].astype(str)
    
    dict_low_cat[col] = {}
    dict_low_cat[col]['fillna'] = value_fillna
    
x_train = pd.get_dummies(x_train, dummy_na=False, drop_first=False, columns=low_cat_cols)
print('Done')

situation_1
situation_2
ethnicity
gender
icu_1
icu_2
icu_3
icu_7
icu_8
glasgow_coma_scale_3
blood_oxy
aids
cirrhosis
diabetes
hepatic_issue
immunosuppression
leukemia
lymphoma
carcinoma
Done


In [24]:
# nunique が多い変数はlabel encoding

ids = ['id', 'personal_id_1', 'personal_id_2']
for i in ids:
    high_cat_cols.remove(i)
    
dict_high_cat = {}
for col in high_cat_cols:
    print(col)
    value_fillna = 'unknown'
    x_train[col] = x_train[col].fillna(value_fillna)
    x_train[col] = x_train[col].astype(str)
    # strに変換
    le = LabelEncoder()
    le.fit(x_train[col])
    list_label = sorted(list(set(le.classes_) | set(['unknown'])))
    # print(list_label)
    map_label = {j:i for i,j in enumerate(list_label)}
    x_train[col] = x_train[col].map(map_label)
    # categorical_dims = len(le.classes_)

    
    dict_high_cat[col] = {}
    dict_high_cat[col]['fillna'] = value_fillna
    dict_high_cat[col]['map_label'] = map_label
    dict_high_cat[col]['num_label'] = len(list_label)
    # dict_high_cat[col]['categorical_dims'] = categorical_dims

print('Done')

facility_id
icu_id
icu_5
Done


In [25]:
# 数値データは標準化
dict_num = {}
for col in col_num:
    print(col)
    # 欠損値を0へ
    value_fillna = 0 
    x_train[col] = x_train[col].fillna(value_fillna)
    # 正規化
    value_min = x_train[col].min()
    value_max = x_train[col].max()
    value_mean = x_train[col].mean()
    value_std = x_train[col].std()
    #train[col] = (train[col] - value_min) / (value_max - value_min)
    x_train[col] = (x_train[col] - value_mean) / value_std
    
    dict_num[col] = {}
    dict_num[col]['fillna'] = value_fillna
    dict_num[col]['min'] = value_min
    dict_num[col]['max'] = value_max
    dict_num[col]['mean'] = value_mean    
    dict_num[col]['std'] = value_std    
    
print('Done')

age
bmi
height
weight
icu_4
icu_6
glasgow_coma_scale_1
glasgow_coma_scale_2
glasgow_coma_scale_4
heart_rate
arterial_pressure
respiratory_rate
temp
blood_pressure_1
blood_pressure_2
blood_pressure_3
blood_pressure_4
v1_heartrate_max
v2
v3
v4
v5
v6
v7
v8
v9
v10
v11
v12
v13
v14
v15
v16
w1
w2
w3
w4
w5
w6
w7
w8
w9
w10
w11
w12
w13
w14
w15
w16
w17
w18
x1
x2
x3
x4
x5
x6
Done


In [26]:
def transform_data(input_x):
    output_x = input_x.copy()
    
    for col in col_num:
        value_fillna = dict_num[col]['fillna']
        output_x[col] = output_x[col].fillna(value_fillna)
        
        value_min = dict_num[col]['min']
        value_max = dict_num[col]['max']
        value_mean = dict_num[col]['mean']
        value_std = dict_num[col]['std']
        
        # output_x[col]  = (output_x[col] - value_min ) / (value_max - value_min)
        output_x[col]  = (output_x[col] - value_mean ) / (value_std)
        
        
    for col in low_cat_cols:
        value_fillna = dict_low_cat[col]['fillna']
        output_x[col] = output_x[col].fillna(value_fillna)
        
        output_x[col] = output_x[col].astype(str)
        
    for col in high_cat_cols:
        value_fillna = dict_high_cat[col]['fillna']
        output_x[col] = output_x[col].fillna(value_fillna)
        
        output_x[col] = output_x[col].astype(str)
        
        map_label = dict_high_cat[col]['map_label']#辞書からlabel_encordの辞書を取り出す
        output_x[col] = output_x[col].map(map_label)
        
        #対応するものがない場合はunkoumn
        output_x[col] = output_x[col].fillna(map_label['unknown'])
        
    output_x = pd.get_dummies(output_x, dummy_na=False, drop_first=False, columns=low_cat_cols)
        
    return output_x

x_test = transform_data(test)

In [27]:
#OneHotencorder(doby_system)
#body_system1と2の診療科を被りを除いて、One Hot Encoding

data = [x_train, x_test]
#被っているものはbody_system_2をNanに
for d in data:   
    d['body_system_2'] = d['body_system_2'].where(d['body_system_1'] != d['body_system_2'])
#Nanはunknoun
    d.fillna('unknown', inplace=True)
    d['body_system_2'] = d['body_system_2'].replace('Undefined diagnoses', 'Undefined Diagnoses')
print(x_train.shape, x_test.shape)

body_train = pd.get_dummies(x_train[body].stack(), dummy_na=False, prefix='ohe').groupby(level=0).sum()
body_test = pd.get_dummies(x_test[body].stack(), dummy_na=False, prefix='ohe').groupby(level=0).sum()

x_train = pd.concat([x_train.drop(body, axis=1), body_train], axis=1).drop('ohe_unknown', axis=1)
x_test = pd.concat([x_test.drop(body, axis=1), body_test], axis=1).drop('ohe_unknown', axis=1)

(60628, 132) (12840, 132)


In [28]:
# idを別のdfに、
id_train = x_train[['id']]
x_train = x_train.drop(['id'], axis=1)

id_test = x_test[id_train.columns]
x_test = x_test.drop(['id'], axis=1)

print('学習用データ: {},   学習ラベル：{},  テストデータ：{}、学習IDデータ:{}, テストIDデータ:{}'.format(
    x_train.shape, y_train.shape, x_test.shape, id_train.shape, id_test.shape))

print('陽性ラベルの割合: {}'.format(
    y_train.value_counts()[1]  / len(train)))

学習用データ: (60628, 144),   学習ラベル：(60628, 1),  テストデータ：(12840, 144)、学習IDデータ:(60628, 1), テストIDデータ:(12840, 1)
陽性ラベルの割合: 0.07625189681335356


In [68]:
# テストデータをtargetencording
def tranform_data_TE(cat_cols, input_x, train_label):#catslist, testdata(df), labels(series) =>df
    output_x = input_x.copy()
    
    for c in cat_cols :
        data_tmp = pd.DataFrame({c: output_x[c], 'target': train_label})
        target_mean = data_tmp.groupby(c)['target'].mean()
        output_x.loc[:, 'TE_'+c] = output_x[c].map(target_mean)
        output_x = output_x.fillna(0)
        
        
    return  output_x

# target_encordするカテゴリ変数を指定
cat_cols = ['facility_id', 'icu_id']
x_test = tranform_data_TE(cat_cols, x_test, y_train['target_label'])

In [30]:
# 学習データをtargetencordingする関数、実際にはCV内で実行
def target_encoding(cat_cols, tr_x, tr_y, va_x):#list_cat df*3 => df*2
    # クロスバリデーションの中で実行し、出力されたデータでモデル学習する
    list_nfold=[0,1,2,3,4]
    for c in cat_cols:
        # 学習データ全体で各カテゴリにおけるtargetの平均を計算
        data_tmp = pd.DataFrame({c: tr_x[c], 'target': tr_y['target_label']})
        target_mean = data_tmp.groupby(c)['target'].mean()
        # バリデーションデータのカテゴリを置換
        va_x.loc[:, 'TE_'+c] = va_x[c].map(target_mean)
        va_x = va_x.fillna(0)
        

        # 学習データの変換後の値を格納する配列を準備
        tmp = np.repeat(np.nan, tr_x.shape[0])
        
        cv_encoding = list(StratifiedKFold(n_splits=5, shuffle=True, random_state=123).split(tr_x,tr_y['target_label']))
        for  nfold in list_nfold:
            idx_1, idx_2 = cv_encoding[nfold][0], cv_encoding[nfold][1]
            # out-of-foldで各カテゴリにおける目的変数の平均を計算
            target_mean = data_tmp.iloc[idx_1].groupby(c)['target'].mean()
            # 変換後の値を一時配列に格納
            tmp[idx_2] = tr_x[c].iloc[idx_2].map(target_mean)
            tmp[idx_2] = np.nan_to_num(tmp[idx_2])
            
        
        tr_x.loc[:, 'TE_'+c] = tmp
        
    return tr_x, va_x

In [31]:
#初期データ
display(train)
display(test)

Unnamed: 0,id,personal_id_1,personal_id_2,facility_id,age,bmi,situation_1,situation_2,ethnicity,gender,height,weight,icu_id,icu_1,icu_2,icu_3,icu_4,icu_5,icu_6,icu_7,icu_8,glasgow_coma_scale_1,glasgow_coma_scale_2,glasgow_coma_scale_3,glasgow_coma_scale_4,heart_rate,blood_oxy,arterial_pressure,respiratory_rate,temp,blood_pressure_1,blood_pressure_2,blood_pressure_3,blood_pressure_4,v1_heartrate_max,v2,v3,v4,v5,v6,v7,v8,v9,v10,v11,v12,v13,v14,v15,v16,w1,w2,w3,w4,w5,w6,w7,w8,w9,w10,w11,w12,w13,w14,w15,w16,w17,w18,x1,x2,x3,x4,x5,x6,aids,cirrhosis,diabetes,hepatic_issue,immunosuppression,leukemia,lymphoma,carcinoma,body_system_1,body_system_2,target_label
0,0,114501,58009,51,69.0,24.731460,0,1.0,Caucasian,M,175.30,76.0,698,Floor,admit,MICU,25.801389,302.0,109.09,0,0.0,3.0,6.0,0.0,3.0,100.0,0.0,50.0,33.0,,59.0,46.0,59.0,46.0,96.0,91.0,84.0,53.0,84.0,53.0,16.0,14.0,100.0,97.0,124.0,67.0,124.0,67.0,37.10,36.80,46.0,46.0,46.0,46.0,96.0,96.0,53.0,53.0,53.0,53.0,16.0,16.0,100.0,100.0,75.0,67.0,75.0,67.0,243.0,76.0,3.5,3.5,0.25,0.07,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,Cardiovascular,Cardiovascular,0.0
1,1,44353,112590,19,64.0,28.666129,0,1.0,Caucasian,M,183.00,96.0,657,Floor,admit,CCU-CTICU,3.639583,,0.19,0,0.0,1.0,1.0,0.0,1.0,117.0,0.0,145.0,4.0,36.72,73.0,48.0,73.0,48.0,111.0,62.0,100.0,59.0,100.0,59.0,30.0,0.0,97.0,87.0,178.0,99.0,178.0,99.0,37.38,36.72,,,,,83.0,80.0,,,,,17.0,8.0,94.0,93.0,,,,,158.0,109.0,4.2,4.2,0.42,0.25,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,,0.0
2,2,8023,1677,16,74.0,18.144869,0,0.0,Caucasian,F,166.00,50.0,482,Accident & Emergency,admit,MICU,0.059028,304.0,307.01,0,0.0,4.0,6.0,0.0,5.0,53.0,0.0,50.0,6.0,36.10,65.0,39.0,65.0,39.0,71.0,55.0,73.0,50.0,73.0,50.0,18.0,11.0,100.0,100.0,122.0,69.0,122.0,69.0,37.00,36.10,65.0,65.0,65.0,65.0,62.0,59.0,73.0,73.0,73.0,73.0,18.0,13.0,100.0,100.0,100.0,100.0,100.0,100.0,73.0,62.0,4.2,4.1,0.07,0.03,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Gastrointestinal,Gastrointestinal,0.0
3,3,106340,74166,188,60.0,23.047667,0,0.0,Caucasian,M,182.90,77.1,855,Accident & Emergency,admit,CCU-CTICU,0.014583,123.0,702.01,0,0.0,4.0,6.0,0.0,5.0,102.0,0.0,127.0,4.0,37.00,87.0,69.0,87.0,69.0,99.0,75.0,114.0,86.0,114.0,86.0,21.0,15.0,100.0,96.0,153.0,123.0,153.0,123.0,37.10,36.60,80.0,73.0,80.0,73.0,99.0,96.0,97.0,86.0,97.0,86.0,18.0,17.0,98.0,97.0,124.0,123.0,124.0,123.0,373.0,46.0,4.2,3.2,0.01,0.00,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,Metabolic,Metabolic,0.0
4,4,118467,52717,168,75.0,20.190265,0,0.0,Caucasian,F,160.02,51.7,136,Accident & Emergency,admit,Med-Surg ICU,0.004861,304.0,308.01,0,0.0,4.0,6.0,0.0,5.0,90.0,0.0,68.0,28.0,36.80,87.0,87.0,87.0,87.0,90.0,90.0,105.0,105.0,105.0,105.0,14.0,11.0,97.0,92.0,147.0,147.0,147.0,147.0,36.80,36.80,100.0,56.0,100.0,56.0,67.0,61.0,104.0,68.0,104.0,68.0,28.0,24.0,94.0,90.0,142.0,114.0,142.0,114.0,,,,,0.08,0.02,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,Gastrointestinal,Gastrointestinal,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
60623,64191,121586,101483,139,57.0,49.170233,0,0.0,Caucasian,F,162.60,130.0,684,Accident & Emergency,admit,Cardiac ICU,0.063889,307.0,704.01,0,0.0,4.0,6.0,0.0,5.0,63.0,0.0,55.0,4.0,37.10,53.0,52.0,53.0,52.0,77.0,67.0,76.0,64.0,76.0,64.0,19.0,14.0,98.0,97.0,112.0,90.0,112.0,90.0,37.10,37.10,53.0,53.0,53.0,53.0,77.0,77.0,76.0,76.0,76.0,76.0,19.0,19.0,97.0,97.0,112.0,112.0,112.0,112.0,125.0,119.0,3.8,3.7,0.02,0.01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Metabolic,Metabolic,0.0
60624,64194,79880,56511,19,39.0,23.147277,0,0.0,Caucasian,F,163.00,61.5,657,Accident & Emergency,admit,CCU-CTICU,0.155556,304.0,301.01,0,0.0,4.0,6.0,0.0,5.0,113.0,0.0,62.0,8.0,36.33,98.0,52.0,98.0,52.0,111.0,84.0,106.0,62.0,106.0,62.0,21.0,10.0,97.0,88.0,159.0,85.0,159.0,85.0,37.38,36.33,63.0,54.0,63.0,54.0,103.0,100.0,75.0,68.0,75.0,68.0,17.0,15.0,94.0,90.0,120.0,109.0,120.0,109.0,117.0,89.0,6.4,5.5,0.03,0.02,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Gastrointestinal,Gastrointestinal,0.0
60625,64195,97405,32055,136,79.0,27.759515,0,0.0,Caucasian,M,175.50,85.5,374,Accident & Emergency,admit,Med-Surg ICU,0.195139,112.0,107.01,0,0.0,4.0,6.0,0.0,5.0,103.0,0.0,46.0,13.0,36.30,77.0,43.0,77.0,43.0,99.0,69.0,82.0,51.0,82.0,51.0,19.0,14.0,97.0,89.0,108.0,77.0,108.0,77.0,37.70,36.30,67.0,59.0,67.0,59.0,74.0,69.0,74.0,67.0,74.0,67.0,16.0,14.0,97.0,95.0,97.0,90.0,97.0,90.0,180.0,180.0,4.5,4.5,0.04,0.02,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Cardiovascular,Cardiovascular,0.0
60626,64196,31970,117733,70,56.0,28.331661,0,1.0,Caucasian,F,172.70,84.5,451,Other Hospital,admit,Neuro ICU,0.012500,301.0,403.01,0,0.0,3.0,6.0,0.0,4.0,93.0,0.0,70.0,27.0,36.60,83.0,55.0,83.0,55.0,89.0,71.0,98.0,70.0,98.0,70.0,26.0,13.0,100.0,92.0,145.0,113.0,145.0,113.0,37.10,36.60,78.0,78.0,78.0,78.0,81.0,81.0,94.0,94.0,94.0,94.0,26.0,26.0,96.0,96.0,145.0,145.0,145.0,145.0,161.0,146.0,4.1,4.1,0.08,0.03,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Neurological,Neurologic,0.0


Unnamed: 0,id,personal_id_1,personal_id_2,facility_id,age,bmi,situation_1,situation_2,ethnicity,gender,height,weight,icu_id,icu_1,icu_2,icu_3,icu_4,icu_5,icu_6,icu_7,icu_8,glasgow_coma_scale_1,glasgow_coma_scale_2,glasgow_coma_scale_3,glasgow_coma_scale_4,heart_rate,blood_oxy,arterial_pressure,respiratory_rate,temp,blood_pressure_1,blood_pressure_2,blood_pressure_3,blood_pressure_4,v1_heartrate_max,v2,v3,v4,v5,v6,v7,v8,v9,v10,v11,v12,v13,v14,v15,v16,w1,w2,w3,w4,w5,w6,w7,w8,w9,w10,w11,w12,w13,w14,w15,w16,w17,w18,x1,x2,x3,x4,x5,x6,aids,cirrhosis,diabetes,hepatic_issue,immunosuppression,leukemia,lymphoma,carcinoma,body_system_1,body_system_2
0,51359,12058,66446,83,37.0,,0,0.0,Caucasian,M,182.9,,95,Floor,readmit,Med-Surg ICU,0.902778,113.0,501.02,0,0.0,4.0,6.0,0.0,5.0,123.0,0.0,76.0,4.0,37.00,74.0,56.0,74.0,56.0,120.0,103.0,93.0,74.0,93.0,74.0,19.0,6.0,99.0,90.0,127.0,106.0,127.0,106.0,37.30,36.90,65.0,60.0,65.0,60.0,112.0,104.0,84.0,84.0,84.0,84.0,19.0,16.0,99.0,92.0,115.0,111.0,115.0,111.0,160.0,122.0,3.5,3.5,-1.00,0.03,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Sepsis,Cardiovascular
1,51360,92348,32311,185,60.0,32.961764,1,0.0,Caucasian,M,185.4,113.30,679,Operating Room / Recovery,admit,Neuro ICU,0.406944,,0.25,0,0.0,3.0,6.0,0.0,4.0,60.0,0.0,151.0,5.0,37.20,85.0,81.0,85.0,81.0,83.0,68.0,132.0,107.0,132.0,107.0,14.0,11.0,97.0,92.0,191.0,146.0,191.0,146.0,37.40,37.20,85.0,85.0,85.0,85.0,68.0,68.0,132.0,132.0,132.0,132.0,11.0,11.0,97.0,97.0,191.0,191.0,191.0,191.0,259.0,184.0,4.4,4.4,0.05,0.01,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,,
2,51361,68371,20639,157,70.0,19.295957,0,1.0,Caucasian,M,177.8,61.00,697,Floor,admit,SICU,0.977083,303.0,211.09,0,0.0,3.0,5.0,0.0,1.0,106.0,0.0,58.0,39.0,36.50,78.0,51.0,78.0,51.0,87.0,69.0,98.0,73.0,98.0,73.0,20.0,14.0,100.0,100.0,127.0,103.0,127.0,103.0,36.60,36.50,78.0,78.0,78.0,78.0,87.0,87.0,98.0,98.0,98.0,98.0,14.0,14.0,100.0,100.0,127.0,127.0,127.0,127.0,113.0,93.0,4.1,4.1,0.13,0.06,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Respiratory,Respiratory
3,51362,19544,116026,60,54.0,27.900747,0,1.0,Caucasian,M,180.3,90.70,538,Accident & Emergency,admit,Med-Surg ICU,0.172917,122.0,703.03,0,0.0,4.0,6.0,0.0,4.0,118.0,0.0,189.0,53.0,,144.0,73.0,144.0,73.0,118.0,59.0,184.0,81.0,181.0,81.0,53.0,0.0,100.0,94.0,232.0,101.0,232.0,101.0,35.10,34.50,112.0,95.0,112.0,95.0,102.0,88.0,136.0,116.0,136.0,116.0,29.0,15.0,98.0,97.0,183.0,156.0,183.0,156.0,101.0,101.0,3.7,3.7,0.03,0.02,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,Metabolic,Metabolic
4,51363,85588,102404,196,85.0,39.414062,0,1.0,Caucasian,F,160.0,100.90,809,Accident & Emergency,admit,CSICU,0.031944,117.0,106.01,0,0.0,4.0,6.0,0.0,5.0,165.0,0.0,63.0,37.0,36.70,84.0,52.0,84.0,52.0,163.0,114.0,91.0,63.0,91.0,63.0,32.0,18.0,98.0,89.0,119.0,90.0,119.0,90.0,37.10,36.70,66.0,66.0,66.0,66.0,160.0,144.0,77.0,77.0,77.0,77.0,27.0,27.0,96.0,96.0,119.0,119.0,119.0,119.0,110.0,110.0,3.9,3.9,0.15,0.07,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Cardiovascular,Cardiovascular
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
12835,64194,79880,56511,19,39.0,23.147277,0,0.0,Caucasian,F,163.0,61.50,657,Accident & Emergency,admit,CCU-CTICU,0.155556,304.0,301.01,0,0.0,4.0,6.0,0.0,5.0,113.0,0.0,62.0,8.0,36.33,98.0,52.0,98.0,52.0,111.0,84.0,106.0,62.0,106.0,62.0,21.0,10.0,97.0,88.0,159.0,85.0,159.0,85.0,37.38,36.33,63.0,54.0,63.0,54.0,103.0,100.0,75.0,68.0,75.0,68.0,17.0,15.0,94.0,90.0,120.0,109.0,120.0,109.0,117.0,89.0,6.4,5.5,0.03,0.02,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Gastrointestinal,Gastrointestinal
12836,64195,97405,32055,136,79.0,27.759515,0,0.0,Caucasian,M,175.5,85.50,374,Accident & Emergency,admit,Med-Surg ICU,0.195139,112.0,107.01,0,0.0,4.0,6.0,0.0,5.0,103.0,0.0,46.0,13.0,36.30,77.0,43.0,77.0,43.0,99.0,69.0,82.0,51.0,82.0,51.0,19.0,14.0,97.0,89.0,108.0,77.0,108.0,77.0,37.70,36.30,67.0,59.0,67.0,59.0,74.0,69.0,74.0,67.0,74.0,67.0,16.0,14.0,97.0,95.0,97.0,90.0,97.0,90.0,180.0,180.0,4.5,4.5,0.04,0.02,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Cardiovascular,Cardiovascular
12837,64196,31970,117733,70,56.0,28.331661,0,1.0,Caucasian,F,172.7,84.50,451,Other Hospital,admit,Neuro ICU,0.012500,301.0,403.01,0,0.0,3.0,6.0,0.0,4.0,93.0,0.0,70.0,27.0,36.60,83.0,55.0,83.0,55.0,89.0,71.0,98.0,70.0,98.0,70.0,26.0,13.0,100.0,92.0,145.0,113.0,145.0,113.0,37.10,36.60,78.0,78.0,78.0,78.0,81.0,81.0,94.0,94.0,94.0,94.0,26.0,26.0,96.0,96.0,145.0,145.0,145.0,145.0,161.0,146.0,4.1,4.1,0.08,0.03,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Neurological,Neurologic
12838,64197,76051,93359,189,79.0,24.578812,0,1.0,African American,F,170.2,71.20,543,Operating Room / Recovery,admit,Med-Surg ICU,0.331944,213.0,1405.05,1,0.0,1.0,1.0,0.0,1.0,164.0,0.0,41.0,12.0,35.80,101.0,13.0,101.0,13.0,162.0,102.0,104.0,22.0,104.0,22.0,16.0,12.0,100.0,7.0,135.0,43.0,135.0,43.0,36.60,35.80,101.0,79.0,101.0,79.0,125.0,102.0,104.0,89.0,104.0,89.0,16.0,12.0,99.0,56.0,135.0,116.0,135.0,116.0,134.0,114.0,5.0,4.5,0.71,0.53,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Gastrointestinal,Gastrointestinal


In [32]:
len(x_train.columns)

144

In [33]:
len(x_test.columns)

146

In [34]:
# 処理後
display(x_train)
display(x_test)

Unnamed: 0,personal_id_1,personal_id_2,facility_id,age,bmi,height,weight,icu_id,icu_4,icu_5,icu_6,glasgow_coma_scale_1,glasgow_coma_scale_2,glasgow_coma_scale_4,heart_rate,arterial_pressure,respiratory_rate,temp,blood_pressure_1,blood_pressure_2,blood_pressure_3,blood_pressure_4,v1_heartrate_max,v2,v3,v4,v5,v6,v7,v8,v9,v10,v11,v12,v13,v14,v15,v16,w1,w2,w3,w4,w5,w6,w7,w8,w9,w10,w11,w12,w13,w14,w15,w16,w17,w18,x1,x2,x3,x4,x5,x6,situation_1_0,situation_1_1,situation_2_0.0,situation_2_1.0,situation_2_unknown,ethnicity_African American,ethnicity_Asian,ethnicity_Caucasian,ethnicity_Hispanic,ethnicity_Native American,ethnicity_Other/Unknown,ethnicity_unknown,gender_F,gender_M,gender_unknown,icu_1_Accident & Emergency,icu_1_Floor,icu_1_Operating Room / Recovery,icu_1_Other Hospital,icu_1_Other ICU,icu_1_unknown,icu_2_admit,icu_2_readmit,icu_2_transfer,icu_3_CCU-CTICU,icu_3_CSICU,icu_3_CTICU,icu_3_Cardiac ICU,icu_3_MICU,icu_3_Med-Surg ICU,icu_3_Neuro ICU,icu_3_SICU,icu_7_0,icu_7_1,icu_8_0.0,icu_8_1.0,icu_8_unknown,glasgow_coma_scale_3_0.0,glasgow_coma_scale_3_1.0,glasgow_coma_scale_3_unknown,blood_oxy_0.0,blood_oxy_1.0,blood_oxy_unknown,aids_0.0,aids_1.0,aids_unknown,cirrhosis_0.0,cirrhosis_1.0,cirrhosis_unknown,diabetes_0.0,diabetes_1.0,diabetes_unknown,hepatic_issue_0.0,hepatic_issue_1.0,hepatic_issue_unknown,immunosuppression_0.0,immunosuppression_1.0,immunosuppression_unknown,leukemia_0.0,leukemia_1.0,leukemia_unknown,lymphoma_0.0,lymphoma_1.0,lymphoma_unknown,carcinoma_0.0,carcinoma_1.0,carcinoma_unknown,ohe_Cardiovascular,ohe_Gastrointestinal,ohe_Genitourinary,ohe_Gynecological,ohe_Haematologic,ohe_Hematological,ohe_Metabolic,ohe_Musculoskeletal/Skin,ohe_Neurologic,ohe_Neurological,ohe_Renal/Genitourinary,ohe_Respiratory,ohe_Sepsis,ohe_Trauma,ohe_Undefined Diagnoses
0,114501,58009,108,0.470442,-0.350090,0.353541,-0.202549,176,10.195277,37,-0.958632,-0.404470,0.415078,-0.596473,0.061541,-0.880434,0.510495,-4.634794,-1.471486,-0.328340,-1.318881,-0.280109,-0.280637,1.229265,-0.964250,-0.764635,-0.782049,-0.645723,-1.166409,0.245212,0.184046,0.599524,-0.925298,-1.456925,-0.757391,-1.288413,0.126135,0.248973,-1.146986,-0.730204,-0.871995,-0.530322,0.267034,0.625034,-1.209381,-0.893798,-0.874574,-0.630496,-0.620241,-0.042250,0.309311,0.434988,-1.434449,-1.324479,-1.055653,-0.994568,0.856075,-0.664261,-0.198955,-0.009586,0.765213,0.171629,1,0,0,1,0,0,0,1,0,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,44353,112590,69,0.230309,0.052854,0.692325,0.503980,160,1.155726,44,-1.191409,-2.322288,-2.995229,-1.832223,0.592493,1.350692,-1.386594,0.253124,-0.769379,-0.179744,-0.673137,-0.140400,0.398688,-0.470148,-0.211417,-0.387485,-0.122957,-0.304093,0.132933,-2.528611,-0.313065,-0.303441,1.137908,0.077014,1.048046,0.108906,0.174620,0.234759,-3.135770,-3.017362,-2.559895,-2.494001,-0.226017,-0.031969,-3.083341,-2.973526,-2.442478,-2.384239,-0.505072,-1.205932,0.017305,0.097317,-3.450720,-3.279206,-2.722877,-2.635169,-0.048067,0.049529,0.283477,0.516369,1.491374,1.045558,1,0,0,1,0,0,0,1,0,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,8023,1677,45,0.710575,-1.024614,-0.055640,-1.121037,97,-0.304742,39,-0.535572,0.554440,0.415078,0.639277,-1.406387,-0.880434,-1.255760,0.170594,-1.170583,-0.848423,-1.042134,-0.769091,-1.412847,-0.880352,-1.481823,-0.953210,-1.235175,-0.816538,-0.980788,-0.349179,0.184046,0.870413,-1.001713,-1.361054,-0.824259,-1.201081,0.108819,0.124596,-0.325532,0.214491,-0.174819,0.280763,-1.022483,-0.894286,-0.502226,-0.108995,-0.282911,0.031294,-0.389903,-0.478631,0.309311,0.434988,-0.762359,-0.361703,-0.499911,-0.186512,-0.952208,-0.967081,0.283477,0.441232,-0.003664,-0.022577,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
3,106340,74166,67,0.038203,-0.522525,0.687925,-0.163690,206,-0.322870,21,0.308754,0.554440,0.415078,0.639277,0.124006,0.927952,-1.386594,0.290396,-0.067272,1.380507,-0.027394,1.326544,-0.144772,0.291658,0.447312,1.309688,0.453749,1.233243,-0.702358,0.443342,0.184046,0.509228,0.182720,1.227469,0.212196,1.156896,0.126135,0.213437,0.322984,0.612258,0.375583,0.622273,0.380815,0.625034,0.346361,0.401127,0.427083,0.461457,-0.389903,0.103210,0.211976,0.290272,-0.117152,0.309323,0.033600,0.376680,2.238879,-1.313161,0.283477,-0.234995,-0.259956,-0.168232,1,0,1,0,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0
4,118467,52717,50,0.758601,-0.815148,-0.318748,-1.060982,14,-0.326836,39,-0.533434,0.554440,0.415078,0.639277,-0.250784,-0.457695,0.183411,0.263773,-0.067272,2.717865,-0.027394,2.583925,-0.552368,1.170665,0.023843,2.503995,0.083010,2.315071,-1.352029,-0.349179,-0.313065,0.148042,-0.046525,2.377923,0.011592,2.204885,0.074186,0.248973,1.187672,-0.232996,1.109452,-0.103435,-0.832848,-0.812161,0.593865,-0.305196,0.634165,-0.134154,0.761788,1.121431,0.017305,-0.047398,0.366753,0.046748,0.433734,0.156300,-1.728706,-2.308141,-2.611118,-2.639360,0.039052,-0.071129,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
60623,121586,101483,30,-0.105877,2.152652,-0.205233,1.705080,170,-0.302759,42,0.313029,0.554440,0.415078,0.639277,-1.094062,-0.763007,-1.386594,0.303707,-1.772390,0.117446,-1.595628,0.139018,-1.141116,-0.177146,-1.340666,-0.073194,-1.111595,-0.019401,-0.887978,0.245212,-0.147362,0.599524,-1.383788,-0.354406,-1.158599,-0.284090,0.126135,0.302278,-0.844346,-0.382158,-0.615141,-0.231501,-0.453578,-0.155158,-0.396152,0.008725,-0.194162,0.130562,-0.274734,0.394130,0.163308,0.290272,-0.439755,-0.011603,-0.233155,0.107327,-0.399086,0.265829,0.007802,0.140687,-0.217240,-0.119680,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0
60624,79880,56511,69,-0.970355,-0.512324,-0.187634,-0.714783,160,-0.265369,39,-0.548397,0.554440,0.415078,0.639277,0.467563,-0.598608,-1.124927,0.201210,0.484384,0.117446,0.479976,0.139018,0.398688,0.819062,0.070896,-0.198911,0.124203,-0.133277,-0.702358,-0.547309,-0.313065,-0.213144,0.411965,-0.594084,0.412800,-0.502421,0.174620,0.165463,-0.412001,-0.332438,-0.248206,-0.188812,0.532523,0.789285,-0.431510,-0.305196,-0.223745,-0.134154,-0.505072,-0.187711,0.017305,-0.047398,-0.224686,-0.099128,-0.055318,0.033868,-0.484182,-0.383071,1.799694,1.493142,-0.174525,-0.071129,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
60625,97405,32055,27,0.950708,-0.039991,0.362341,0.133052,59,-0.249224,10,-0.963078,0.554440,0.415078,0.639277,0.155238,-0.974376,-0.797842,0.197216,-0.568777,-0.551233,-0.488639,-0.489673,-0.144772,-0.059945,-1.058354,-0.890352,-0.864436,-0.759599,-0.887978,0.245212,-0.313065,-0.122848,-1.536618,-0.977569,-1.292335,-0.851751,0.230033,0.160132,-0.239064,-0.083834,-0.101432,0.024631,-0.567359,-0.483659,-0.466868,-0.344436,-0.253328,-0.167243,-0.620241,-0.333171,0.163308,0.193795,-0.843009,-0.653453,-0.566600,-0.431377,0.185946,1.585259,0.490234,0.741778,-0.131810,-0.071129,1,0,1,0,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0
60626,31970,117733,121,-0.153903,0.018602,0.239146,0.097726,86,-0.323720,36,-0.330369,-0.404470,0.415078,0.021402,-0.157087,-0.410723,0.117994,0.237150,-0.267874,0.340339,-0.211892,0.348581,-0.597656,0.057256,-0.305521,0.303956,-0.205343,0.322229,-0.238308,0.047082,0.184046,0.148042,-0.122940,0.748113,-0.055276,0.720233,0.126135,0.213437,0.236515,0.860862,0.302196,0.835716,-0.301871,0.009093,0.240287,0.715048,0.338334,0.726173,0.531450,1.412351,0.114641,0.242033,0.447404,0.951174,0.500423,0.915384,-0.016156,0.849839,0.214558,0.441232,0.039052,-0.022577,1,0,0,1,0,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,1,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0


Unnamed: 0,personal_id_1,personal_id_2,facility_id,age,bmi,height,weight,icu_id,icu_4,icu_5,icu_6,glasgow_coma_scale_1,glasgow_coma_scale_2,glasgow_coma_scale_4,heart_rate,arterial_pressure,respiratory_rate,temp,blood_pressure_1,blood_pressure_2,blood_pressure_3,blood_pressure_4,v1_heartrate_max,v2,v3,v4,v5,v6,v7,v8,v9,v10,v11,v12,v13,v14,v15,v16,w1,w2,w3,w4,w5,w6,w7,w8,w9,w10,w11,w12,w13,w14,w15,w16,w17,w18,x1,x2,x3,x4,x5,x6,situation_1_0,situation_1_1,situation_2_0.0,situation_2_1.0,situation_2_unknown,ethnicity_African American,ethnicity_Asian,ethnicity_Caucasian,ethnicity_Hispanic,ethnicity_Native American,ethnicity_Other/Unknown,ethnicity_unknown,gender_F,gender_M,gender_unknown,icu_1_Accident & Emergency,icu_1_Floor,icu_1_Operating Room / Recovery,icu_1_Other Hospital,icu_1_Other ICU,icu_1_unknown,icu_2_admit,icu_2_readmit,icu_2_transfer,icu_3_CCU-CTICU,icu_3_CSICU,icu_3_CTICU,icu_3_Cardiac ICU,icu_3_MICU,icu_3_Med-Surg ICU,icu_3_Neuro ICU,icu_3_SICU,icu_7_0,icu_7_1,icu_8_0.0,icu_8_1.0,icu_8_unknown,glasgow_coma_scale_3_0.0,glasgow_coma_scale_3_1.0,glasgow_coma_scale_3_unknown,blood_oxy_0.0,blood_oxy_1.0,blood_oxy_unknown,aids_0.0,aids_1.0,aids_unknown,cirrhosis_0.0,cirrhosis_1.0,cirrhosis_unknown,diabetes_0.0,diabetes_1.0,diabetes_unknown,hepatic_issue_0.0,hepatic_issue_1.0,hepatic_issue_unknown,immunosuppression_0.0,immunosuppression_1.0,immunosuppression_unknown,leukemia_0.0,leukemia_1.0,leukemia_unknown,lymphoma_0.0,lymphoma_1.0,lymphoma_unknown,carcinoma_0.0,carcinoma_1.0,carcinoma_unknown,ohe_Cardiovascular,ohe_Gastrointestinal,ohe_Genitourinary,ohe_Gynecological,ohe_Haematologic,ohe_Hematological,ohe_Metabolic,ohe_Musculoskeletal/Skin,ohe_Neurologic,ohe_Neurological,ohe_Renal/Genitourinary,ohe_Respiratory,ohe_Sepsis,ohe_Trauma,ohe_Undefined Diagnoses,TE_facility_id,TE_icu_id
0,12058,66446,132,-1.066408,-2.882806,0.687925,-2.887360,236.0,0.039414,11,-0.120869,0.554440,0.415078,0.639277,0.779889,-0.269810,-1.386594,0.290396,-0.719229,0.414637,-0.627013,0.418436,0.806284,1.932471,-0.540781,0.555389,-0.411310,0.549983,-0.887978,-1.339830,0.018342,-0.032551,-0.810675,0.412563,-0.657089,0.414570,0.160767,0.266741,-0.325532,-0.034113,-0.174819,0.067320,0.873865,0.953535,-0.113290,0.322647,0.042503,0.395278,-0.274734,-0.042250,0.260644,0.049079,-0.359104,-0.040778,-0.166466,0.082841,-0.026793,0.330719,-0.198955,-0.009586,-4.574208,-0.022577,1,0,1,0,0,0,0,1,0,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0.094488,0.094488
1,92348,32311,64,0.038203,0.492764,0.797920,1.115128,167.0,-0.162831,44,-1.191281,-0.404470,0.415078,0.021402,-1.187760,1.491605,-1.321177,0.317018,-0.167573,2.272079,-0.119643,2.164798,-0.869386,-0.118545,1.294249,2.629712,1.195228,2.428948,-1.352029,-0.349179,-0.313065,0.148042,1.634606,2.329988,1.482688,2.161219,0.178084,0.320046,0.539156,1.208908,0.559050,1.134537,-0.794921,-0.524722,1.583882,2.206174,1.462492,1.983574,-1.196087,-0.769551,0.163308,0.290272,1.684050,2.293225,1.522987,2.041766,1.026266,1.671779,0.421315,0.666641,-0.089094,-0.119680,0,1,1,0,0,0,0,1,0,0,0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,1,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.137339,0.212121
2,68371,20639,42,0.518469,-0.906733,0.463536,-0.732446,175.0,0.069723,38,-0.740604,-0.404470,-0.266984,-1.832223,0.248936,-0.692550,0.902997,0.223839,-0.518626,0.043149,-0.442515,0.069163,-0.688233,-0.059945,-0.305521,0.492530,-0.205343,0.493044,-0.795168,0.245212,0.184046,0.870413,-0.810675,0.268757,-0.657089,0.283571,0.039554,0.195668,0.236515,0.860862,0.302196,0.835716,-0.074309,0.255469,0.381718,0.872009,0.456666,0.858531,-0.850580,-0.333171,0.309311,0.434988,-0.036501,0.426023,0.100289,0.474626,-0.526730,-0.296551,0.214558,0.441232,0.252629,0.123077,1,0,0,1,0,0,0,1,0,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0.085859,0.100000
3,19544,116026,114,-0.249956,-0.025528,0.573531,0.316750,122.0,-0.258288,20,0.310934,0.554440,0.415078,0.021402,0.623726,2.384056,1.818833,-4.634794,2.791309,1.677698,2.601704,1.605962,0.715707,-0.645950,3.740957,0.995397,3.213699,0.948551,2.267565,-2.528611,0.184046,0.328635,3.201114,0.172885,2.853483,0.196239,-0.220191,-0.159696,1.706485,1.706116,1.549774,1.561423,0.494596,0.296532,1.725313,1.578331,1.580825,1.454142,0.876957,-0.187711,0.211976,0.290272,1.468981,1.272099,1.345150,1.184736,-0.654373,-0.123511,-0.061117,0.140687,-0.174525,-0.071129,1,0,0,1,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0.050847,0.050847
4,85588,102404,73,1.238867,1.153536,-0.319628,0.677080,191.0,-0.315789,15,-0.965216,0.554440,0.415078,0.639277,2.091654,-0.575122,0.772163,0.250462,-0.217723,0.117446,-0.165768,0.139018,2.753684,2.577076,-0.634885,-0.136052,-0.493696,-0.076339,0.318553,1.037733,-0.147362,-0.122848,-1.116335,-0.354406,-0.924561,-0.284090,0.126135,0.231205,-0.282298,0.264212,-0.138126,0.323452,2.694359,2.596043,-0.360795,0.047965,-0.164579,0.163652,0.646619,1.557812,0.114641,0.242033,-0.251570,0.192623,-0.077548,0.278733,-0.558641,0.071159,0.076720,0.290959,0.338059,0.171629,1,0,0,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.058824,0.082569
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
12835,79880,56511,69,-0.970355,-0.512324,-0.187634,-0.714783,160.0,-0.265369,39,-0.548397,0.554440,0.415078,0.639277,0.467563,-0.598608,-1.124927,0.201210,0.484384,0.117446,0.479976,0.139018,0.398688,0.819062,0.070896,-0.198911,0.124203,-0.133277,-0.702358,-0.547309,-0.313065,-0.213144,0.411965,-0.594084,0.412800,-0.502421,0.174620,0.165463,-0.412001,-0.332438,-0.248206,-0.188812,0.532523,0.789285,-0.431510,-0.305196,-0.223745,-0.134154,-0.505072,-0.187711,0.017305,-0.047398,-0.224686,-0.099128,-0.055318,0.033868,-0.484182,-0.383071,1.799694,1.493142,-0.174525,-0.071129,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0.104673,0.123077
12836,97405,32055,27,0.950708,-0.039991,0.362341,0.133052,59.0,-0.249224,10,-0.963078,0.554440,0.415078,0.639277,0.155238,-0.974376,-0.797842,0.197216,-0.568777,-0.551233,-0.488639,-0.489673,-0.144772,-0.059945,-1.058354,-0.890352,-0.864436,-0.759599,-0.887978,0.245212,-0.313065,-0.122848,-1.536618,-0.977569,-1.292335,-0.851751,0.230033,0.160132,-0.239064,-0.083834,-0.101432,0.024631,-0.567359,-0.483659,-0.466868,-0.344436,-0.253328,-0.167243,-0.620241,-0.333171,0.163308,0.193795,-0.843009,-0.653453,-0.566600,-0.431377,0.185946,1.585259,0.490234,0.741778,-0.131810,-0.071129,1,0,1,0,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.096386,0.096386
12837,31970,117733,121,-0.153903,0.018602,0.239146,0.097726,86.0,-0.323720,36,-0.330369,-0.404470,0.415078,0.021402,-0.157087,-0.410723,0.117994,0.237150,-0.267874,0.340339,-0.211892,0.348581,-0.597656,0.057256,-0.305521,0.303956,-0.205343,0.322229,-0.238308,0.047082,0.184046,0.148042,-0.122940,0.748113,-0.055276,0.720233,0.126135,0.213437,0.236515,0.860862,0.302196,0.835716,-0.301871,0.009093,0.240287,0.715048,0.338334,0.726173,0.531450,1.412351,0.114641,0.242033,0.447404,0.951174,0.500423,0.915384,-0.016156,0.849839,0.214558,0.441232,0.039052,-0.022577,1,0,0,1,0,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,1,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0.108527,0.126437
12838,76051,93359,68,0.950708,-0.365723,0.129151,-0.372116,123.0,-0.193422,29,1.811525,-2.322288,-2.995229,-1.832223,2.060421,-1.091804,-0.863259,0.130660,0.634836,-2.780163,0.618349,-2.585308,2.708395,1.873870,-0.023209,-2.713242,0.041816,-2.410811,-1.166409,-0.151049,0.184046,-7.527159,-0.505015,-2.607380,-0.389616,-2.336403,0.039554,0.071291,1.230907,0.910583,1.146146,0.878405,1.366916,0.871410,0.593865,0.518847,0.634165,0.560726,-0.620241,-0.624091,0.260644,-1.687511,0.178568,0.105098,0.278127,0.205274,-0.303354,0.157679,0.834829,0.741778,2.730120,2.405002,1,0,0,1,0,1,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,1,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0.111111,0.111111


##　重み

In [35]:
#不均衡データ用の重み
from sklearn.utils import class_weight

class_weights = list(class_weight.compute_class_weight('balanced', 
                                                        classes=np.unique(y_train['target_label']),
                                                        y=y_train['target_label'])
                        )
weights = torch.from_numpy(np.array(class_weights)).float()

# lossを指定し重みを加える 重み無しならLossはNone
cross_entropy_loss_wight = CrossEntropyLoss(weight=weights)

In [36]:
# tabnetのカテゴリカル変数の登録用に
# カテゴリカル変数はcat_idxに登録することでembedding処理されると考えています。
# 事前にhigh_cat_colsはlabel encoding済み
cat_idxs = [ i for i, f in enumerate(x_train) if f in high_cat_cols]
cat_dims = [dict_high_cat[i]['num_label'] for i in high_cat_cols]

In [37]:
import os
max_epochs = 100 if not os.getenv("CI", False) else 2 # 1000
max_epochs

100

In [38]:
# TabNetPretrainer
unsupervised_model = TabNetPretrainer(
    cat_idxs=cat_idxs,
    cat_dims=cat_dims,
    cat_emb_dim=3,
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    mask_type='entmax', # "sparsemax",
    n_shared_decoder=1, # nb shared glu for decoding
    n_indep_decoder=1, # nb independent glu for decoding
#     grouped_features=[[0, 1]], # you can group features together here
    verbose=5,
)

# PreTrain

In [39]:
x_tr, x_va, y_tr, y_va = train_test_split(x_train,
                                           y_train,
                                           test_size=0.2,
                                           shuffle=True,
                                           stratify=y_train,
                                           random_state=random_state)
x_tr, x_va = target_encoding(cat_cols, x_tr, y_tr, x_va)
# print(x_tr.isnull().any())
# print(x_va.isnull().any())

y_tr=np.squeeze(y_tr.values)
y_va=np.squeeze(y_va.values)
x_tr=x_tr.values
x_va=x_va.values

print('訓練データ: ',x_tr.shape, y_tr.shape)
print('検証データ: ',x_va.shape, y_va.shape)

訓練データ:  (48502, 146) (48502,)
検証データ:  (12126, 146) (12126,)


In [40]:
unsupervised_model.fit(
    X_train=x_tr,
    eval_set=[x_va],
    max_epochs=max_epochs , patience=5,
    batch_size=2048, virtual_batch_size=128,
    num_workers=0,
    drop_last=False,
    pretraining_ratio=0.5,
) 


unsupervised_model.save_model('./model/tabnet/pretrain')


epoch 0  | loss: 50.15713| val_0_unsup_loss_numpy: 3.050729990005493|  0:00:02s
epoch 5  | loss: 0.95165 | val_0_unsup_loss_numpy: 0.9874600172042847|  0:00:13s
epoch 10 | loss: 0.89627 | val_0_unsup_loss_numpy: 0.8735499978065491|  0:00:24s
epoch 15 | loss: 0.84563 | val_0_unsup_loss_numpy: 0.8115699887275696|  0:00:35s
epoch 20 | loss: 0.83212 | val_0_unsup_loss_numpy: 0.788129985332489|  0:00:46s
epoch 25 | loss: 0.78928 | val_0_unsup_loss_numpy: 0.7581800222396851|  0:00:57s
epoch 30 | loss: 0.77313 | val_0_unsup_loss_numpy: 0.7488200068473816|  0:01:08s
epoch 35 | loss: 0.77367 | val_0_unsup_loss_numpy: 0.7456799745559692|  0:01:19s
epoch 40 | loss: 0.76707 | val_0_unsup_loss_numpy: 0.7400400042533875|  0:01:30s
epoch 45 | loss: 0.76352 | val_0_unsup_loss_numpy: 0.7373300194740295|  0:01:41s
epoch 50 | loss: 0.76483 | val_0_unsup_loss_numpy: 0.7331200242042542|  0:01:52s
epoch 55 | loss: 0.7507  | val_0_unsup_loss_numpy: 0.7183499932289124|  0:02:03s
epoch 60 | loss: 0.74705 | val

'./model/tabnet/pretrain.zip'

In [41]:
loaded_pretrain = TabNetPretrainer()
loaded_pretrain.load_model('./model/tabnet/pretrain.zip')

## validation方法（ベースライン作成へ）

In [42]:
random_state = 123
params = {'n_d': 47, #値が大きいほど表現力と過学習のリスクがあがる
          'n_a': 24, # n_dと同じ値にしておくのが良いらしい
          'n_steps': 3,#TabNetEncoderのstepを何回繰り返すか
          'gamma': 1.3,
          'n_independent': 2,
          'n_shared': 2,
          'seed':random_state,
          'lambda_sparse': 1e-3,
          'optimizer_fn': torch.optim.Adam, 
          'optimizer_params': {'lr':2e-2},
          'mask_type': "entmax",#AttentiveTransformerでマスク作るのにどっちの関数を使うか'sparsemax'or'entmax'
          'scheduler_params':{'mode': "min",'patience': 5,'min_lr': 1e-5,'factor': 0.9},
          'scheduler_fn': torch.optim.lr_scheduler.ReduceLROnPlateau,
          'verbose':10
         }

In [49]:
# cvでの評価用 iuput_yはsereis
def train_tabnet(
    input_x,
    input_y,
    input_id,
    params,
    list_nfold=[0,1,2,3,4],
    n_splits=5,
    random_state=123
            ):
    train_oof = np.zeros(len(input_x))
    # foldごとの推論値
    metrics = []
    imp = pd.DataFrame()

    cv = list(StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state).split(input_x, input_y ))
    for  nfold in list_nfold :
        print('-'*20, nfold, '-'*20)
        
        idx_tr, idx_va = cv[nfold][0], cv[nfold][1]
        x_tr, y_tr = input_x.loc[idx_tr, :], input_y.loc[idx_tr, :]
        x_va, y_va = input_x.loc[idx_va, :], input_y.loc[idx_va, :]
        #CV内でCVしてtargetencording
        x_tr, x_va = target_encoding(cat_cols, x_tr, y_tr, x_va)
        column_name = x_tr.columns
        
        # print(x_tr.isnull().any())
        # print(x_va.isnull().any())
        
        print(x_tr.shape, y_tr.shape)
        print(x_va.shape, y_va.shape)
        print('y_train:{:.3f}, y_tr:{:.3f}, y_va{:.3f}'.
              format(y_train['target_label'].mean(), y_tr['target_label'].mean(), y_va['target_label'].mean(),))
        
        y_tr=np.squeeze(y_tr.values)
        y_va=np.squeeze(y_va.values)
        x_tr=x_tr.values
        x_va=x_va.values
        
        model = TabNetClassifier(**params)
        model.fit(
            X_train=x_tr,
            y_train=y_tr,
            eval_set=[(x_va, y_va)],
            eval_name = ["valid"],
            eval_metric = ["auc"],
            loss_fn = cross_entropy_loss_wight,
            max_epochs=200,
            patience=20, 
            batch_size=256,
            virtual_batch_size=128,
            num_workers=0, 
            drop_last=False,
            from_unsupervised=loaded_pretrain,
        )
        
        # モデルの保存
        fname_tabnet = 'model/tabnet/model_tabnet_fold{}.pickle'.format(nfold)
        with open(fname_tabnet, 'wb')as f:
            pickle.dump(model, f, protocol=4)
            
            
        # 評価
        y_tr_pred = model.predict_proba(x_tr)[:,1]
        y_va_pred = model.predict_proba(x_va)[:,1]
        metric_tr = roc_auc_score(y_tr, y_tr_pred)
        metric_va = roc_auc_score(y_va, y_va_pred)
        print('[auc] tr: {:.2f}, va: {:2f}'.
             format(metric_tr, metric_va))
        metrics.append([nfold, metric_tr, metric_va])
        
        # oof
        train_oof[idx_va] = y_va_pred
        
        _imp = pd.DataFrame({'col':column_name, 'imp':model.feature_importances_,'nfold':nfold})
        imp = pd.concat([imp, _imp], axis=0, ignore_index=False)

    
    print('-'*20, 'result', '-'*20)
    
    # metrix出力
    metrics = np.array(metrics)
    print(metrics)
    print('[cv] tr: {:.2f}+-{:.2f}, va: {:.2f}'.format(
        metrics[:,1].mean(), metrics[:,1].std(),
        metrics[:,2].mean(), metrics[:,2].std()
    ))
    print('[oof] {:.4f}'.format(
        roc_auc_score(input_y, train_oof)))
    # oof出力  
    train_oof = pd.concat([
        input_id,
        pd.DataFrame({'pred':train_oof})]
        ,axis=1)
    
        # imp出力
    imp = imp.groupby('col')['imp'].agg(['mean', 'std']).reset_index(drop=False)
    imp.columns = ['col', 'imp', 'imp_std']


    print('Done')
    
    return train_oof, imp, metrics

In [50]:
train_oof, imp, metrics = train_tabnet(x_train, y_train, id_train, params,list_nfold=[0,1,2,3,4], n_splits=5, random_state=123)

-------------------- 0 --------------------
(48502, 146) (48502, 1)
(12126, 146) (12126, 1)
y_train:0.076, y_tr:0.076, y_va0.076
epoch 0  | loss: 0.49827 | valid_auc: 0.87847 |  0:00:03s
epoch 10 | loss: 0.36278 | valid_auc: 0.88349 |  0:00:33s
epoch 20 | loss: 0.30831 | valid_auc: 0.87648 |  0:01:04s

Early stopping occurred at epoch 27 with best_epoch = 7 and best_valid_auc = 0.89219
[auc] tr: 0.92, va: 0.892192
-------------------- 1 --------------------
(48502, 146) (48502, 1)
(12126, 146) (12126, 1)
y_train:0.076, y_tr:0.076, y_va0.076
epoch 0  | loss: 0.49881 | valid_auc: 0.88467 |  0:00:03s
epoch 10 | loss: 0.36286 | valid_auc: 0.88477 |  0:00:33s
epoch 20 | loss: 0.30829 | valid_auc: 0.87572 |  0:01:02s

Early stopping occurred at epoch 25 with best_epoch = 5 and best_valid_auc = 0.89485
[auc] tr: 0.91, va: 0.894854
-------------------- 2 --------------------
(48502, 146) (48502, 1)
(12126, 146) (12126, 1)
y_train:0.076, y_tr:0.076, y_va0.076
epoch 0  | loss: 0.49781 | valid_au

## 推論

In [51]:
display(x_test)

Unnamed: 0,personal_id_1,personal_id_2,facility_id,age,bmi,height,weight,icu_id,icu_4,icu_5,icu_6,glasgow_coma_scale_1,glasgow_coma_scale_2,glasgow_coma_scale_4,heart_rate,arterial_pressure,respiratory_rate,temp,blood_pressure_1,blood_pressure_2,blood_pressure_3,blood_pressure_4,v1_heartrate_max,v2,v3,v4,v5,v6,v7,v8,v9,v10,v11,v12,v13,v14,v15,v16,w1,w2,w3,w4,w5,w6,w7,w8,w9,w10,w11,w12,w13,w14,w15,w16,w17,w18,x1,x2,x3,x4,x5,x6,situation_1_0,situation_1_1,situation_2_0.0,situation_2_1.0,situation_2_unknown,ethnicity_African American,ethnicity_Asian,ethnicity_Caucasian,ethnicity_Hispanic,ethnicity_Native American,ethnicity_Other/Unknown,ethnicity_unknown,gender_F,gender_M,gender_unknown,icu_1_Accident & Emergency,icu_1_Floor,icu_1_Operating Room / Recovery,icu_1_Other Hospital,icu_1_Other ICU,icu_1_unknown,icu_2_admit,icu_2_readmit,icu_2_transfer,icu_3_CCU-CTICU,icu_3_CSICU,icu_3_CTICU,icu_3_Cardiac ICU,icu_3_MICU,icu_3_Med-Surg ICU,icu_3_Neuro ICU,icu_3_SICU,icu_7_0,icu_7_1,icu_8_0.0,icu_8_1.0,icu_8_unknown,glasgow_coma_scale_3_0.0,glasgow_coma_scale_3_1.0,glasgow_coma_scale_3_unknown,blood_oxy_0.0,blood_oxy_1.0,blood_oxy_unknown,aids_0.0,aids_1.0,aids_unknown,cirrhosis_0.0,cirrhosis_1.0,cirrhosis_unknown,diabetes_0.0,diabetes_1.0,diabetes_unknown,hepatic_issue_0.0,hepatic_issue_1.0,hepatic_issue_unknown,immunosuppression_0.0,immunosuppression_1.0,immunosuppression_unknown,leukemia_0.0,leukemia_1.0,leukemia_unknown,lymphoma_0.0,lymphoma_1.0,lymphoma_unknown,carcinoma_0.0,carcinoma_1.0,carcinoma_unknown,ohe_Cardiovascular,ohe_Gastrointestinal,ohe_Genitourinary,ohe_Gynecological,ohe_Haematologic,ohe_Hematological,ohe_Metabolic,ohe_Musculoskeletal/Skin,ohe_Neurologic,ohe_Neurological,ohe_Renal/Genitourinary,ohe_Respiratory,ohe_Sepsis,ohe_Trauma,ohe_Undefined Diagnoses,TE_facility_id,TE_icu_id
0,12058,66446,132,-1.066408,-2.882806,0.687925,-2.887360,236.0,0.039414,11,-0.120869,0.554440,0.415078,0.639277,0.779889,-0.269810,-1.386594,0.290396,-0.719229,0.414637,-0.627013,0.418436,0.806284,1.932471,-0.540781,0.555389,-0.411310,0.549983,-0.887978,-1.339830,0.018342,-0.032551,-0.810675,0.412563,-0.657089,0.414570,0.160767,0.266741,-0.325532,-0.034113,-0.174819,0.067320,0.873865,0.953535,-0.113290,0.322647,0.042503,0.395278,-0.274734,-0.042250,0.260644,0.049079,-0.359104,-0.040778,-0.166466,0.082841,-0.026793,0.330719,-0.198955,-0.009586,-4.574208,-0.022577,1,0,1,0,0,0,0,1,0,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0.094488,0.094488
1,92348,32311,64,0.038203,0.492764,0.797920,1.115128,167.0,-0.162831,44,-1.191281,-0.404470,0.415078,0.021402,-1.187760,1.491605,-1.321177,0.317018,-0.167573,2.272079,-0.119643,2.164798,-0.869386,-0.118545,1.294249,2.629712,1.195228,2.428948,-1.352029,-0.349179,-0.313065,0.148042,1.634606,2.329988,1.482688,2.161219,0.178084,0.320046,0.539156,1.208908,0.559050,1.134537,-0.794921,-0.524722,1.583882,2.206174,1.462492,1.983574,-1.196087,-0.769551,0.163308,0.290272,1.684050,2.293225,1.522987,2.041766,1.026266,1.671779,0.421315,0.666641,-0.089094,-0.119680,0,1,1,0,0,0,0,1,0,0,0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,1,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.137339,0.212121
2,68371,20639,42,0.518469,-0.906733,0.463536,-0.732446,175.0,0.069723,38,-0.740604,-0.404470,-0.266984,-1.832223,0.248936,-0.692550,0.902997,0.223839,-0.518626,0.043149,-0.442515,0.069163,-0.688233,-0.059945,-0.305521,0.492530,-0.205343,0.493044,-0.795168,0.245212,0.184046,0.870413,-0.810675,0.268757,-0.657089,0.283571,0.039554,0.195668,0.236515,0.860862,0.302196,0.835716,-0.074309,0.255469,0.381718,0.872009,0.456666,0.858531,-0.850580,-0.333171,0.309311,0.434988,-0.036501,0.426023,0.100289,0.474626,-0.526730,-0.296551,0.214558,0.441232,0.252629,0.123077,1,0,0,1,0,0,0,1,0,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0.085859,0.100000
3,19544,116026,114,-0.249956,-0.025528,0.573531,0.316750,122.0,-0.258288,20,0.310934,0.554440,0.415078,0.021402,0.623726,2.384056,1.818833,-4.634794,2.791309,1.677698,2.601704,1.605962,0.715707,-0.645950,3.740957,0.995397,3.213699,0.948551,2.267565,-2.528611,0.184046,0.328635,3.201114,0.172885,2.853483,0.196239,-0.220191,-0.159696,1.706485,1.706116,1.549774,1.561423,0.494596,0.296532,1.725313,1.578331,1.580825,1.454142,0.876957,-0.187711,0.211976,0.290272,1.468981,1.272099,1.345150,1.184736,-0.654373,-0.123511,-0.061117,0.140687,-0.174525,-0.071129,1,0,0,1,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0.050847,0.050847
4,85588,102404,73,1.238867,1.153536,-0.319628,0.677080,191.0,-0.315789,15,-0.965216,0.554440,0.415078,0.639277,2.091654,-0.575122,0.772163,0.250462,-0.217723,0.117446,-0.165768,0.139018,2.753684,2.577076,-0.634885,-0.136052,-0.493696,-0.076339,0.318553,1.037733,-0.147362,-0.122848,-1.116335,-0.354406,-0.924561,-0.284090,0.126135,0.231205,-0.282298,0.264212,-0.138126,0.323452,2.694359,2.596043,-0.360795,0.047965,-0.164579,0.163652,0.646619,1.557812,0.114641,0.242033,-0.251570,0.192623,-0.077548,0.278733,-0.558641,0.071159,0.076720,0.290959,0.338059,0.171629,1,0,0,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.058824,0.082569
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
12835,79880,56511,69,-0.970355,-0.512324,-0.187634,-0.714783,160.0,-0.265369,39,-0.548397,0.554440,0.415078,0.639277,0.467563,-0.598608,-1.124927,0.201210,0.484384,0.117446,0.479976,0.139018,0.398688,0.819062,0.070896,-0.198911,0.124203,-0.133277,-0.702358,-0.547309,-0.313065,-0.213144,0.411965,-0.594084,0.412800,-0.502421,0.174620,0.165463,-0.412001,-0.332438,-0.248206,-0.188812,0.532523,0.789285,-0.431510,-0.305196,-0.223745,-0.134154,-0.505072,-0.187711,0.017305,-0.047398,-0.224686,-0.099128,-0.055318,0.033868,-0.484182,-0.383071,1.799694,1.493142,-0.174525,-0.071129,1,0,1,0,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0.104673,0.123077
12836,97405,32055,27,0.950708,-0.039991,0.362341,0.133052,59.0,-0.249224,10,-0.963078,0.554440,0.415078,0.639277,0.155238,-0.974376,-0.797842,0.197216,-0.568777,-0.551233,-0.488639,-0.489673,-0.144772,-0.059945,-1.058354,-0.890352,-0.864436,-0.759599,-0.887978,0.245212,-0.313065,-0.122848,-1.536618,-0.977569,-1.292335,-0.851751,0.230033,0.160132,-0.239064,-0.083834,-0.101432,0.024631,-0.567359,-0.483659,-0.466868,-0.344436,-0.253328,-0.167243,-0.620241,-0.333171,0.163308,0.193795,-0.843009,-0.653453,-0.566600,-0.431377,0.185946,1.585259,0.490234,0.741778,-0.131810,-0.071129,1,0,1,0,0,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.096386,0.096386
12837,31970,117733,121,-0.153903,0.018602,0.239146,0.097726,86.0,-0.323720,36,-0.330369,-0.404470,0.415078,0.021402,-0.157087,-0.410723,0.117994,0.237150,-0.267874,0.340339,-0.211892,0.348581,-0.597656,0.057256,-0.305521,0.303956,-0.205343,0.322229,-0.238308,0.047082,0.184046,0.148042,-0.122940,0.748113,-0.055276,0.720233,0.126135,0.213437,0.236515,0.860862,0.302196,0.835716,-0.301871,0.009093,0.240287,0.715048,0.338334,0.726173,0.531450,1.412351,0.114641,0.242033,0.447404,0.951174,0.500423,0.915384,-0.016156,0.849839,0.214558,0.441232,0.039052,-0.022577,1,0,0,1,0,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,1,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0.108527,0.126437
12838,76051,93359,68,0.950708,-0.365723,0.129151,-0.372116,123.0,-0.193422,29,1.811525,-2.322288,-2.995229,-1.832223,2.060421,-1.091804,-0.863259,0.130660,0.634836,-2.780163,0.618349,-2.585308,2.708395,1.873870,-0.023209,-2.713242,0.041816,-2.410811,-1.166409,-0.151049,0.184046,-7.527159,-0.505015,-2.607380,-0.389616,-2.336403,0.039554,0.071291,1.230907,0.910583,1.146146,0.878405,1.366916,0.871410,0.593865,0.518847,0.634165,0.560726,-0.620241,-0.624091,0.260644,-1.687511,0.178568,0.105098,0.278127,0.205274,-0.303354,0.157679,0.834829,0.741778,2.730120,2.405002,1,0,0,1,0,1,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,1,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0.111111,0.111111


In [52]:
def predict_tabnet(input_x,
               input_id,
               list_nfold=[0,1,2,3,4],
               ):
    pred = np.zeros((len(input_x), len(list_nfold)))
    for nfold in list_nfold:
        print('-'*20, nfold, '-'*20)
        fname_tabnet = 'model/tabnet/model_tabnet_fold{}.pickle'.format(nfold)
        with open(fname_tabnet, 'rb')as f:
            model = pickle.load(f)
        pred[:,nfold] = model.predict_proba(input_x.values)[:,1]
        
    pred = pd.concat([
        input_id,
        pd.DataFrame({'target_label':pred.mean(axis=1)}),], axis=1)
    
    print('Done')
    
    return pred

In [53]:
test_pred_proba = predict_tabnet(
        x_test,
        id_test,
        list_nfold=[0,1,2,3,4],
        )

-------------------- 0 --------------------
-------------------- 1 --------------------
-------------------- 2 --------------------
-------------------- 3 --------------------
-------------------- 4 --------------------
Done


In [54]:
test_pred_proba

Unnamed: 0,id,target_label
0,51359,0.287734
1,51360,0.049686
2,51361,0.513013
3,51362,0.131906
4,51363,0.656096
...,...,...
12835,64194,0.157725
12836,64195,0.306261
12837,64196,0.432315
12838,64197,0.985214


In [55]:
test_pred_proba.to_csv('sub/submission_tabnet.csv', index=None, header=True,)

In [56]:
sample= pd.read_csv("data/submission.csv")

## アンサンブル用データ

In [57]:
import pickle
 
with open('ensamble/tabnet_train.pickle', mode='wb') as fo:
    pickle.dump(train_oof, fo)
    
with open('ensamble/tabnet_test.pickle', mode='wb') as fo:
    pickle.dump(test_pred_proba, fo)
    

## ベースライン

In [69]:
x_tr, x_va2, y_tr, y_va2 = train_test_split(x_train,
                                           y_train,
                                           test_size=0.2,
                                           shuffle=True,
                                           stratify=y_train,
                                           random_state=random_state)
print('検証データ: ',x_tr.shape, y_tr.shape)
print('ベースライン検証データ: ',x_va2.shape, y_va2.shape)

x_tr1, x_va1, y_tr1, y_va1 = train_test_split(x_tr,
                                              y_tr,
                                              test_size=0.2,
                                              shuffle=True,
                                              stratify=y_tr,
                                              random_state=random_state)
print('検証データ(train): ',x_tr1.shape, y_tr1.shape)
print('検証データ(test): ',x_va1.shape, y_va1.shape)


cat_cols = ['facility_id', 'icu_id']
x_tr1, x_va1 = target_encoding(cat_cols, x_tr1, y_tr1, x_va1)
# x_va1 = tranform_data_TE(cat_cols, x_va1, y_tr1['target_label'])
x_va2 = tranform_data_TE(cat_cols, x_va2, y_tr['target_label'])
print(x_tr1.isnull().any())
print(x_va1.isnull().any())
print(x_va2.isnull().any())


y_tr1=np.squeeze(y_tr1.values)
y_va1=np.squeeze(y_va1.values)
y_va2=np.squeeze(y_va2.values)

x_tr1=x_tr1.values
x_va1=x_va1.values
x_va2=x_va2.values




print('-'*20,'tabnet用にunsqueze','-'*20)
print('ベースライン検証データ: ',x_va2.shape, y_va2.shape)
print('検証データ(train): ',x_tr1.shape, y_tr1.shape)
print('検証データ(test): ',x_va1.shape, y_va1.shape)


検証データ:  (48502, 144) (48502, 1)
ベースライン検証データ:  (12126, 144) (12126, 1)
検証データ(train):  (38801, 144) (38801, 1)
検証データ(test):  (9701, 144) (9701, 1)
personal_id_1              False
personal_id_2              False
facility_id                False
age                        False
bmi                        False
                           ...  
ohe_Sepsis                 False
ohe_Trauma                 False
ohe_Undefined Diagnoses    False
TE_facility_id             False
TE_icu_id                  False
Length: 146, dtype: bool
personal_id_1              False
personal_id_2              False
facility_id                False
age                        False
bmi                        False
                           ...  
ohe_Sepsis                 False
ohe_Trauma                 False
ohe_Undefined Diagnoses    False
TE_facility_id             False
TE_icu_id                  False
Length: 146, dtype: bool
personal_id_1              False
personal_id_2              False
facility_id  

In [None]:


#validation結果
loaded_pretrain = TabNetPretrainer()
loaded_pretrain.load_model('./model/tabnet/pretrain.zip')

model = TabNetClassifier(**params)
model.fit(
    X_train=x_tr1,
    y_train=y_tr1,
    eval_set=[(x_va1, y_va1)],
    eval_name = ["valid"],
    eval_metric = ["auc"],
    loss_fn = cross_entropy_loss_wight,
    max_epochs=200,
    patience=20, 
    batch_size=256,
    virtual_batch_size=128,
    num_workers=0, 
    drop_last=False,
    from_unsupervised=loaded_pretrain
)

epoch 0  | loss: 0.50509 | valid_auc: 0.88264 |  0:00:11s


In [None]:
#評価指標の差
y_va1_pred = model.predict(x_va1)
y_va2_pred = model.predict(x_va2)

print('[検証データ] auc: {:.4f}'.format(roc_auc_score(y_va1, y_va1_pred)))
print('[ベースライン検証データ] auc: {:.4f}'.format(roc_auc_score(y_va2, y_va2_pred)))

y_va1_pred_proba = model.predict_proba(x_va1)
y_va2_pred_proba = model.predict_proba(x_va2)
print('[検証データ] auc: {:.4f}'.format(roc_auc_score(y_va1, y_va1_pred_proba[:,1])))
print('[ベースライン検証データ] auc: {:.4f}'.format(roc_auc_score(y_va2, y_va2_pred_proba[:,1])))


for param in ['loss', 'valid_auc']:
    plt.plot(model.history[param], label=param)
    plt.xlabel('epoch')
    plt.grid()
    plt.legend()
    plt.show()

In [None]:
#誤分類の分布
print('検証データ')
print(confusion_matrix(y_va1, np.where(y_va1_pred>0.5,1,0)))
print(confusion_matrix(y_va1, np.where(y_va1_pred>0.5,1,0), normalize='all'))

print('ベースライン検証データ')
print(confusion_matrix(y_va2, np.where(y_va2_pred>0.5,1,0)))
print(confusion_matrix(y_va2, np.where(y_va2_pred>0.5,1,0), normalize='all'))

In [None]:
# 予測値の分布
y_va1_pred_prob = model.predict_proba(x_va1)[:,1]
y_va2_pred_prob = model.predict_proba(x_va2)[:,1]

fig = plt.figure(figsize=(10,8))


fig.add_subplot(2,1,1)
plt.title('validation_data')
plt.hist(y_va1_pred_prob[np.array(y_va1).reshape(-1)==1], bins=10, alpha=0.5, label='1')
plt.hist(y_va1_pred_prob[np.array(y_va1).reshape(-1)==0], bins=10, alpha=0.5, label='0')
plt.axis([0, 1, 0, 2000])

plt.grid()
plt.legend()

fig.add_subplot(2,1,2)
plt.title('basreline_validation_data')
plt.hist(y_va2_pred_prob[np.array(y_va2).reshape(-1)==1], bins=10, alpha=0.5, label='1')
plt.hist(y_va2_pred_prob[np.array(y_va2).reshape(-1)==0], bins=10, alpha=0.5, label='0')
plt.axis([0, 1, 0, 2000])

plt.grid()
plt.legend()

## チューニング

In [60]:
import optuna

params = {'n_d': 8, #値が大きいほど表現力と過学習のリスクがあがる
          'n_a': 8, # n_dと同じ値にしておくのが良いらしい
          'n_steps': 3,#TabNetEncoderのstepを何回繰り返すか
          'gamma': 1.3,
          'n_independent': 2,
          'n_shared': 2,
          'seed':random_state,
          'lambda_sparse': 1e-3,
          'optimizer_fn': torch.optim.Adam, 
          'optimizer_params': {'lr':2e-2},
          'mask_type': "entmax",#AttentiveTransformerでマスク作るのにどっちの関数を使うか'sparsemax'or'entmax'
          'scheduler_params':{'mode': "min",'patience': 5,'min_lr': 1e-5,'factor': 0.9},
          'scheduler_fn': torch.optim.lr_scheduler.ReduceLROnPlateau,
          'verbose':10
         }
         
         n_d, n_a	8-64	8
n_steps	1-10	3
gamma	1.0-2.0	1.3
mask_type	"entmatx" or "sparsemax"

In [29]:
# 探索するパラメータ
# 探索しないパラメータ

params_base = {
   'optimizer_fn': torch.optim.Adam,
   'optimizer_params': {'lr':2e-2,'weight_decay':1e-5},
   'mask_type': "entmax",#AttentiveTransformerでマスク作るのにどっちの関数を使うか'sparsemax'or'entmax'
   'scheduler_params':{'mode': "min",'patience': 5,'min_lr': 1e-5,'factor': 0.9, 'scheduler_fn': torch.optim.lr_scheduler.ReduceLROnPlateau,},
   'verbose':10,
   'seed': 123,
}

def objective(trial):
    # 探索するパラメータ
    params_tuning = {
        'n_d': trial.suggest_int('n_d',8,64),
        'n_a': trial.suggest_int('n_a',8,64),
        'n_steps': trial.suggest_int('n_steps', 1, 10),
        'gamma': trial.suggest_float('gamma', 1.0, 2.0),
        'mask_type': trial.suggest_categorical('mask_type', ['entmatx','sparsemax']),
    }
    params_tuning.update(params_base)
    
    # モデル学習・評価
    list_metrics = []
    cv = list(StratifiedKFold(n_splits=4, shuffle=True, random_state=random_state).split(X_train, y_train))
    for nfold in np.arange(4):
        idx_tr, idx_va = cv[nfold][0], cv[nfold][1]
        x_tr, y_tr = X_train.loc[idx_tr, :], y_train.loc[idx_tr, :]
        x_va, y_va = X_train.loc[idx_va, :], y_train.loc[idx_va, :]
        y_tr=np.squeeze(y_tr.values)
        y_va=np.squeeze(y_va.values)
        x_tr=x_tr.values
        x_va=x_va.values
        pretrainer = TabNetPretrainer(**params)
        pretrainer.fit(
            X_train=x_tr,
            eval_set=[x_va],
            max_epochs=200,
            patience=20, batch_size=256, virtual_batch_size=128,
            num_workers=1, drop_last=True)
        model = TabNetClassifier(**params)
        model.fit(
            X_train=x_tr,
            y_train=y_tr,
            eval_set=[(x_va, y_va)],
            eval_name = ["valid"],
            eval_metric = ["auc"],
            max_epochs=200,
            patience=20, 
            batch_size=256,
            virtual_batch_size=128,
            num_workers=0, 
            drop_last=False,
            from_unsupervised=pretrainer
        )
        y_va_pred = model.predict_proba(x_va)[:,1]
        metric_va = accuracy_score(y_va, np.where(y_va_pred>0.5, 1, 0))
        list_metrics.append(metric_va)
        
    # 評価値の計算
    metrics = np.mean(list_metrics)
    
    return metrics

In [30]:
sampler = optuna.samplers.TPESampler(seed=random_state)
study = optuna.create_study(sampler=sampler, direction='maximize')
study.optimize(objective, n_trials=30)

[32m[I 2023-03-27 08:10:52,385][0m A new study created in memory with name: no-name-6d507e69-96d2-4c1b-a0eb-d5eb0aa38957[0m




[32m[I 2023-03-27 08:12:12,219][0m Trial 0 finished with value: 0.881111394495306 and parameters: {'num_leaves': 181, 'min_data_in_leaf': 61, 'min_sum_hessian_in_leaf': 4.792414358623587e-05, 'feature_fraction': 0.7756573845414456, 'bagging_fraction': 0.8597344848927815, 'reg_alpha': 0.492522233779106, 'reg_lambda': 83.76388146302445}. Best is trial 0 with value: 0.881111394495306.[0m




[32m[I 2023-03-27 08:12:56,070][0m Trial 1 finished with value: 0.9033858426149493 and parameters: {'num_leaves': 178, 'min_data_in_leaf': 99, 'min_sum_hessian_in_leaf': 0.00015009027543233888, 'feature_fraction': 0.6715890080754348, 'bagging_fraction': 0.8645248536920208, 'reg_alpha': 0.567922374174008, 'reg_lambda': 0.01732652966363563}. Best is trial 1 with value: 0.9033858426149493.[0m




[32m[I 2023-03-27 08:13:31,604][0m Trial 2 finished with value: 0.8843433513033084 and parameters: {'num_leaves': 107, 'min_data_in_leaf': 149, 'min_sum_hessian_in_leaf': 3.52756635172055e-05, 'feature_fraction': 0.5877258780737462, 'bagging_fraction': 0.7657756869209191, 'reg_alpha': 1.3406343673102123, 'reg_lambda': 3.4482904089131434}. Best is trial 1 with value: 0.9033858426149493.[0m




[32m[I 2023-03-27 08:14:09,266][0m Trial 3 finished with value: 0.8842266213801941 and parameters: {'num_leaves': 219, 'min_data_in_leaf': 146, 'min_sum_hessian_in_leaf': 0.0006808799287054756, 'feature_fraction': 0.8612216912851107, 'bagging_fraction': 0.6614794569265892, 'reg_alpha': 0.2799978022399009, 'reg_lambda': 0.08185645330667264}. Best is trial 1 with value: 0.9033858426149493.[0m




[32m[I 2023-03-27 08:14:42,903][0m Trial 4 finished with value: 0.8770422102729732 and parameters: {'num_leaves': 81, 'min_data_in_leaf': 128, 'min_sum_hessian_in_leaf': 1.889360449174926e-05, 'feature_fraction': 0.7168505863397641, 'bagging_fraction': 0.7154313816648219, 'reg_alpha': 0.9434967110751797, 'reg_lambda': 0.5050346330980694}. Best is trial 1 with value: 0.9033858426149493.[0m




[32m[I 2023-03-27 08:15:31,907][0m Trial 5 finished with value: 0.8778987276282181 and parameters: {'num_leaves': 85, 'min_data_in_leaf': 88, 'min_sum_hessian_in_leaf': 0.004788147156768277, 'feature_fraction': 0.9720800091019398, 'bagging_fraction': 0.7509183379421682, 'reg_alpha': 3.1319282717196035, 'reg_lambda': 0.029005047452739414}. Best is trial 1 with value: 0.9033858426149493.[0m




[32m[I 2023-03-27 08:16:24,251][0m Trial 6 finished with value: 0.8301758702735702 and parameters: {'num_leaves': 87, 'min_data_in_leaf': 86, 'min_sum_hessian_in_leaf': 0.003971252247766701, 'feature_fraction': 0.6252276826982534, 'bagging_fraction': 0.7415171321313522, 'reg_alpha': 87.54657140659076, 'reg_lambda': 1.1965765212602313}. Best is trial 1 with value: 0.9033858426149493.[0m




[32m[I 2023-03-27 08:17:17,600][0m Trial 7 finished with value: 0.9035224235857501 and parameters: {'num_leaves': 160, 'min_data_in_leaf': 28, 'min_sum_hessian_in_leaf': 0.0030131614432849746, 'feature_fraction': 0.8015300642054637, 'bagging_fraction': 0.7725340032332324, 'reg_alpha': 0.23499322154972468, 'reg_lambda': 0.1646202117975735}. Best is trial 7 with value: 0.9035224235857501.[0m




[32m[I 2023-03-27 08:17:53,432][0m Trial 8 finished with value: 0.8767498084351151 and parameters: {'num_leaves': 111, 'min_data_in_leaf': 138, 'min_sum_hessian_in_leaf': 0.00423029374725911, 'feature_fraction': 0.7552111687390055, 'bagging_fraction': 0.8346568914811361, 'reg_alpha': 2.206714812711709, 'reg_lambda': 3.1594683442464033}. Best is trial 7 with value: 0.9035224235857501.[0m




[32m[I 2023-03-27 08:18:30,038][0m Trial 9 finished with value: 0.8768086033257799 and parameters: {'num_leaves': 175, 'min_data_in_leaf': 170, 'min_sum_hessian_in_leaf': 1.7765808030254076e-05, 'feature_fraction': 0.8818414207216692, 'bagging_fraction': 0.6218331872684371, 'reg_alpha': 0.05982625838323253, 'reg_lambda': 1.9490717640641542}. Best is trial 7 with value: 0.9035224235857501.[0m




[32m[I 2023-03-27 08:18:56,459][0m Trial 10 finished with value: 0.8473101602909165 and parameters: {'num_leaves': 32, 'min_data_in_leaf': 6, 'min_sum_hessian_in_leaf': 0.0010167214653943027, 'feature_fraction': 0.5040305717020102, 'bagging_fraction': 0.9940542446575642, 'reg_alpha': 0.010612397212799423, 'reg_lambda': 0.1661409929489422}. Best is trial 7 with value: 0.9035224235857501.[0m




[32m[I 2023-03-27 08:22:12,364][0m Trial 11 finished with value: 0.9028212812451548 and parameters: {'num_leaves': 165, 'min_data_in_leaf': 23, 'min_sum_hessian_in_leaf': 0.0002546304993969339, 'feature_fraction': 0.6893072883706839, 'bagging_fraction': 0.5643816257078462, 'reg_alpha': 0.10108607276304732, 'reg_lambda': 0.010211649165953098}. Best is trial 7 with value: 0.9035224235857501.[0m




[32m[I 2023-03-27 08:25:11,608][0m Trial 12 finished with value: 0.8963765711292914 and parameters: {'num_leaves': 255, 'min_data_in_leaf': 41, 'min_sum_hessian_in_leaf': 0.000153402164637483, 'feature_fraction': 0.8175314403750262, 'bagging_fraction': 0.8626569725135499, 'reg_alpha': 8.37591328058755, 'reg_lambda': 0.01045117234553353}. Best is trial 7 with value: 0.9035224235857501.[0m




[32m[I 2023-03-27 08:27:38,684][0m Trial 13 finished with value: 0.8689615751656204 and parameters: {'num_leaves': 160, 'min_data_in_leaf': 190, 'min_sum_hessian_in_leaf': 0.008932219618102614, 'feature_fraction': 0.6800062881915976, 'bagging_fraction': 0.5111428766066355, 'reg_alpha': 0.17052736553266273, 'reg_lambda': 0.0988056140193565}. Best is trial 7 with value: 0.9035224235857501.[0m




[32m[I 2023-03-27 08:30:09,640][0m Trial 14 finished with value: 0.9137250356014663 and parameters: {'num_leaves': 215, 'min_data_in_leaf': 70, 'min_sum_hessian_in_leaf': 0.00011414918234944389, 'feature_fraction': 0.8041563305513639, 'bagging_fraction': 0.9449350856793115, 'reg_alpha': 0.045693362410957825, 'reg_lambda': 0.038934140387364174}. Best is trial 14 with value: 0.9137250356014663.[0m




[32m[I 2023-03-27 08:31:17,831][0m Trial 15 finished with value: 0.9143091750480853 and parameters: {'num_leaves': 219, 'min_data_in_leaf': 57, 'min_sum_hessian_in_leaf': 0.001215432268892957, 'feature_fraction': 0.798973943231285, 'bagging_fraction': 0.9883163331495053, 'reg_alpha': 0.03988616968038969, 'reg_lambda': 0.04384568349950474}. Best is trial 15 with value: 0.9143091750480853.[0m




[32m[I 2023-03-27 08:32:24,169][0m Trial 16 finished with value: 0.9105513853695867 and parameters: {'num_leaves': 225, 'min_data_in_leaf': 65, 'min_sum_hessian_in_leaf': 0.0007547212191327465, 'feature_fraction': 0.8974152759983636, 'bagging_fraction': 0.9915442564871022, 'reg_alpha': 0.020436805378618597, 'reg_lambda': 0.03433559455816001}. Best is trial 15 with value: 0.9143091750480853.[0m




[32m[I 2023-03-27 08:33:27,286][0m Trial 17 finished with value: 0.9135693330217827 and parameters: {'num_leaves': 215, 'min_data_in_leaf': 51, 'min_sum_hessian_in_leaf': 8.935790929091197e-05, 'feature_fraction': 0.8330744197268897, 'bagging_fraction': 0.9308943929336889, 'reg_alpha': 0.04250632026048154, 'reg_lambda': 0.042677457059612324}. Best is trial 15 with value: 0.9143091750480853.[0m




[32m[I 2023-03-27 08:34:23,871][0m Trial 18 finished with value: 0.9061118388312092 and parameters: {'num_leaves': 253, 'min_data_in_leaf': 113, 'min_sum_hessian_in_leaf': 0.0004380338011497303, 'feature_fraction': 0.7474499041095187, 'bagging_fraction': 0.9422115556402889, 'reg_alpha': 0.029505797157008604, 'reg_lambda': 0.3613816693235936}. Best is trial 15 with value: 0.9143091750480853.[0m




[32m[I 2023-03-27 08:35:26,251][0m Trial 19 finished with value: 0.9095580989969473 and parameters: {'num_leaves': 207, 'min_data_in_leaf': 73, 'min_sum_hessian_in_leaf': 0.0014854275432688256, 'feature_fraction': 0.9170263295138465, 'bagging_fraction': 0.9252732251354118, 'reg_alpha': 0.010027969419668254, 'reg_lambda': 0.044612610488335765}. Best is trial 15 with value: 0.9143091750480853.[0m




[32m[I 2023-03-27 08:36:06,717][0m Trial 20 finished with value: 0.8284818206865326 and parameters: {'num_leaves': 8, 'min_data_in_leaf': 113, 'min_sum_hessian_in_leaf': 0.0003715334734853126, 'feature_fraction': 0.8309878210063437, 'bagging_fraction': 0.9921935443390247, 'reg_alpha': 0.07760481503455585, 'reg_lambda': 0.023199708121192097}. Best is trial 15 with value: 0.9143091750480853.[0m




[32m[I 2023-03-27 08:37:14,049][0m Trial 21 finished with value: 0.915029580032509 and parameters: {'num_leaves': 205, 'min_data_in_leaf': 45, 'min_sum_hessian_in_leaf': 6.893857790440731e-05, 'feature_fraction': 0.8091273410784455, 'bagging_fraction': 0.9273977096274312, 'reg_alpha': 0.04667317333314764, 'reg_lambda': 0.050582355537352366}. Best is trial 21 with value: 0.915029580032509.[0m




[32m[I 2023-03-27 08:38:41,768][0m Trial 22 finished with value: 0.9142896485324238 and parameters: {'num_leaves': 234, 'min_data_in_leaf': 40, 'min_sum_hessian_in_leaf': 7.391470668175935e-05, 'feature_fraction': 0.7995472472441396, 'bagging_fraction': 0.9104388655687001, 'reg_alpha': 0.026721044041056047, 'reg_lambda': 0.07451032272050791}. Best is trial 21 with value: 0.915029580032509.[0m




[32m[I 2023-03-27 08:39:42,581][0m Trial 23 finished with value: 0.9076305880471762 and parameters: {'num_leaves': 240, 'min_data_in_leaf': 35, 'min_sum_hessian_in_leaf': 1.0257826229045205e-05, 'feature_fraction': 0.7731462798990519, 'bagging_fraction': 0.9006517574979365, 'reg_alpha': 0.02312152610279702, 'reg_lambda': 0.07467868849739885}. Best is trial 21 with value: 0.915029580032509.[0m




[32m[I 2023-03-27 08:40:44,720][0m Trial 24 finished with value: 0.9112910438988101 and parameters: {'num_leaves': 198, 'min_data_in_leaf': 11, 'min_sum_hessian_in_leaf': 6.294576802299786e-05, 'feature_fraction': 0.8554704582704863, 'bagging_fraction': 0.8157420611929699, 'reg_alpha': 0.1000586051654101, 'reg_lambda': 0.19997413091609148}. Best is trial 21 with value: 0.915029580032509.[0m




[32m[I 2023-03-27 08:41:28,745][0m Trial 25 finished with value: 0.8993166690081501 and parameters: {'num_leaves': 142, 'min_data_in_leaf': 50, 'min_sum_hessian_in_leaf': 0.00027797661617380033, 'feature_fraction': 0.7373471207171156, 'bagging_fraction': 0.8967208581588109, 'reg_alpha': 0.017634090175269762, 'reg_lambda': 0.055486302297669474}. Best is trial 21 with value: 0.915029580032509.[0m




[32m[I 2023-03-27 08:42:40,161][0m Trial 26 finished with value: 0.9139977198440596 and parameters: {'num_leaves': 234, 'min_data_in_leaf': 23, 'min_sum_hessian_in_leaf': 6.356298381849394e-05, 'feature_fraction': 0.9359202818582051, 'bagging_fraction': 0.9637477392254032, 'reg_alpha': 0.03150389776898005, 'reg_lambda': 0.09144538998867943}. Best is trial 21 with value: 0.915029580032509.[0m




[32m[I 2023-03-27 08:43:39,462][0m Trial 27 finished with value: 0.9119336885836791 and parameters: {'num_leaves': 192, 'min_data_in_leaf': 49, 'min_sum_hessian_in_leaf': 0.00015169247729063443, 'feature_fraction': 0.7880733307854664, 'bagging_fraction': 0.8995181385779787, 'reg_alpha': 0.12484292645725906, 'reg_lambda': 0.016607645835344912}. Best is trial 21 with value: 0.915029580032509.[0m




[32m[I 2023-03-27 08:45:17,879][0m Trial 28 finished with value: 0.9118166705246548 and parameters: {'num_leaves': 235, 'min_data_in_leaf': 85, 'min_sum_hessian_in_leaf': 0.00021981389524911843, 'feature_fraction': 0.8586951177330127, 'bagging_fraction': 0.9538875680745383, 'reg_alpha': 0.05392514580590439, 'reg_lambda': 0.020621949248373524}. Best is trial 21 with value: 0.915029580032509.[0m




[32m[I 2023-03-27 08:46:16,098][0m Trial 29 finished with value: 0.913102046335166 and parameters: {'num_leaves': 198, 'min_data_in_leaf': 61, 'min_sum_hessian_in_leaf': 4.073757643365435e-05, 'feature_fraction': 0.7777756551083584, 'bagging_fraction': 0.8182312770946693, 'reg_alpha': 0.33236916530073274, 'reg_lambda': 0.34068251293058605}. Best is trial 21 with value: 0.915029580032509.[0m


In [31]:
trial = study.best_trial
print('acc(best)={:.4f}'.format(trial.value))
display(trial.params)

acc(best)=0.9150


{'num_leaves': 205,
 'min_data_in_leaf': 45,
 'min_sum_hessian_in_leaf': 6.893857790440731e-05,
 'feature_fraction': 0.8091273410784455,
 'bagging_fraction': 0.9273977096274312,
 'reg_alpha': 0.04667317333314764,
 'reg_lambda': 0.050582355537352366}

In [32]:
params_best = trial.params
params_best.update(params_base)
display(params_best)

{'num_leaves': 205,
 'min_data_in_leaf': 45,
 'min_sum_hessian_in_leaf': 6.893857790440731e-05,
 'feature_fraction': 0.8091273410784455,
 'bagging_fraction': 0.9273977096274312,
 'reg_alpha': 0.04667317333314764,
 'reg_lambda': 0.050582355537352366,
 'boosting_type': 'gbdt',
 'objective': 'binary',
 'metrics': 'auc',
 'learning_rate': 0.02,
 'n_estimators': 100000,
 'bagging_freq': 1,
 'seed': 123}