In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import os
import random
from copy import deepcopy
import _pickle as pickle
import gc
from multiprocess import Pool
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from tensorflow.keras.preprocessing.text import Tokenizer
from sklearn.preprocessing import KBinsDiscretizer

from tensorflow.keras.optimizers import Adam, SGD
def save(file,name, folder = ""):
    if folder != "":
        outfile = open('./'+folder+'/'+name+'.pickle', 'wb')
    else:
        outfile = open(name+'.pickle', 'wb')
    pickle.dump(file, outfile, protocol=4)
    outfile.close
    
def load(name, folder = ""):
    
    if folder != "":
        outfile = open('./'+folder+'/'+name+'.pickle', 'rb')
    else:
        outfile = open(name+'.pickle', 'rb')
    file = pickle.load(outfile)
    outfile.close
    return file

class Discretiser:
    def __init__(self, nbins):
        self.nbins = nbins-1
        self.map_to = np.arange(self.nbins)/self.nbins
        
    def fit(self, X):
        ## X is a one dimension np array
        self.map_from = np.quantile(X, self.map_to)
        
    def transform(self, X):
        X1 = (np.interp(X, self.map_from, self.map_to, left=0, right=1, period=None) * self.nbins).astype(int)
        return X1
    
from tf_transformers2 import *
from tensorflow.keras.layers import Input, Dense, Dropout, TimeDistributed, LSTM
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

In [None]:
## features present
exercise to predict    ## cat
part of exercise       ## cat
gtag of exercise       ## cat
cluster of exercise    ## cat
qmean on user          ## num
qmean on content
container              ## num
timestamp              ## num

## content / user

## feature past
avg correctness
number explanation
number question
number lecture
avg elapsed time

## per part
avg correct
number explained
number lecture
number question
avg elapsed time
time since first question
time since last question
time since last lecture

## Per cluster
avg correct
number explained
number lecture
number question
avg elapsed time
time since first question
time since last question
time since last lecture

In [None]:
## Segments          28
user                 1
content              1
parts                6
cluster              20

## Time windows      17 possibles
history
# based on id
first five / ten /twenty / 100
last five / ten / twenty / 100
# Based on time
last 1/5/12/24/168/720 hours
first 1/5/12/24/168/720 hours

## Kpis              4 possibles 
interaction_type
time spent
time elapsed
explanation

## Kpis type         7 possibles
count
mean
hmean
std
min
max
slope - % improvement



#####################################
## Total 13328 features

In [None]:
def get_user_features(user_dic):
    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    from tqdm.auto import tqdm
    from features_util import prepare, get_current_past, get_features
    import random
    user = prepare(user_dic)
    
    ##sampling
    seq_len = len(user['exercise_id'])
    r = random.randint(min(3, seq_len), min(10, seq_len))
    ids_to_fetch = np.random.choice(list(range(seq_len)), size = r, replace = False) 
               
    mat = np.zeros((r, 1049))
    for j, i in enumerate(tqdm(ids_to_fetch)):
        current, past = get_current_past(user, i)
        final = get_features(current, past)
        k = list(final.keys())
        v = list(final.values())
        mat[j] = v
    df = pd.DataFrame(mat, columns= k)
    return df

In [None]:
from multiprocess import Pool
import pandas as pd
from tqdm.notebook import tqdm
from features_util import load, save
import os

p = Pool(10)
for elt in tqdm(os.listdir('user_batch_2000')):
    if not(elt in os.listdir('user_batch_lgb_2000')):
        dico = load(elt.split('.')[0], 'user_batch_2000')
        list_dico = [dico[elt] for elt in dico]
        df = p.map(get_user_features, list_dico)
        df = pd.concat(df)
        save(df, elt.split('.')[0], 'user_batch_lgb_2000')
p.close()

In [2]:
df_train = []
for i in tqdm(range(1,30)):
    df_train.append(load('batch_'+str(i), 'user_batch_lgb_2000'))
df_train = pd.concat(df_train)  
df_test = []
for i in tqdm(range(40,50)):
    df_test.append(load('batch_'+str(i), 'user_batch_lgb_2000'))
df_test = pd.concat(df_test)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=29.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10.0), HTML(value='')))




In [3]:
df_train = df_train[df_train['correctness']!=-1]
df_test = df_test[df_test['correctness']!=-1]
X_train, y_train = df_train[[elt for elt in df_train.columns if not(elt == 'correctness')]], df_train['correctness'].values
X_test, y_test = df_test[[elt for elt in df_test.columns if not(elt == 'correctness')]], df_test['correctness'].values

del df_train
del df_test
gc.collect()

38

In [4]:
import lightgbm as lgb
clf = lgb.LGBMClassifier(max_depth = -1, n_estimators = 10000, n_jobs = 12, silent = False)
clf.fit(X_train, y_train, eval_set =(X_test, y_test), eval_metric = 'auc')

[LightGBM] [Info] Number of positive: 201628, number of negative: 167369
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 207203
[LightGBM] [Info] Number of data points in the train set: 368997, number of used features: 995
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.546422 -> initscore=0.186223
[LightGBM] [Info] Start training from score 0.186223
[1]	valid_0's auc: 0.739731	valid_0's binary_logloss: 0.67236
[2]	valid_0's auc: 0.744644	valid_0's binary_logloss: 0.658666
[3]	valid_0's auc: 0.745662	valid_0's binary_logloss: 0.647414
[4]	valid_0's auc: 0.74731	valid_0's binary_logloss: 0.637959
[5]	valid_0's auc: 0.747904	valid_0's binary_logloss: 0.63007
[6]	valid_0's auc: 0.748403	valid_0's binary_logloss: 0.623437
[7]	valid_0's auc: 0.748834	valid_0's binary_logloss: 0.617879
[8]	valid_0's auc: 0.749293	valid_0's binary_logloss: 0.613158
[9]	valid_0's auc: 0.749573	valid_0's binary_logloss: 0.609208
[10]	valid_0's auc: 0.75001	valid_0's bina

[123]	valid_0's auc: 0.758177	valid_0's binary_logloss: 0.580493
[124]	valid_0's auc: 0.758178	valid_0's binary_logloss: 0.580491
[125]	valid_0's auc: 0.758192	valid_0's binary_logloss: 0.580476
[126]	valid_0's auc: 0.758196	valid_0's binary_logloss: 0.580473
[127]	valid_0's auc: 0.758208	valid_0's binary_logloss: 0.580464
[128]	valid_0's auc: 0.758204	valid_0's binary_logloss: 0.580465
[129]	valid_0's auc: 0.758209	valid_0's binary_logloss: 0.580459
[130]	valid_0's auc: 0.758257	valid_0's binary_logloss: 0.580408
[131]	valid_0's auc: 0.758251	valid_0's binary_logloss: 0.580413
[132]	valid_0's auc: 0.758276	valid_0's binary_logloss: 0.580383
[133]	valid_0's auc: 0.758281	valid_0's binary_logloss: 0.580376
[134]	valid_0's auc: 0.758286	valid_0's binary_logloss: 0.580371
[135]	valid_0's auc: 0.758274	valid_0's binary_logloss: 0.580383
[136]	valid_0's auc: 0.758252	valid_0's binary_logloss: 0.5804
[137]	valid_0's auc: 0.758231	valid_0's binary_logloss: 0.580417
[138]	valid_0's auc: 0.7582

[250]	valid_0's auc: 0.758133	valid_0's binary_logloss: 0.580505
[251]	valid_0's auc: 0.758142	valid_0's binary_logloss: 0.580495
[252]	valid_0's auc: 0.758138	valid_0's binary_logloss: 0.580497
[253]	valid_0's auc: 0.758126	valid_0's binary_logloss: 0.580513
[254]	valid_0's auc: 0.758131	valid_0's binary_logloss: 0.580512
[255]	valid_0's auc: 0.758133	valid_0's binary_logloss: 0.580509
[256]	valid_0's auc: 0.758146	valid_0's binary_logloss: 0.580498
[257]	valid_0's auc: 0.758157	valid_0's binary_logloss: 0.580491
[258]	valid_0's auc: 0.758184	valid_0's binary_logloss: 0.580461
[259]	valid_0's auc: 0.758199	valid_0's binary_logloss: 0.580442
[260]	valid_0's auc: 0.758205	valid_0's binary_logloss: 0.580434
[261]	valid_0's auc: 0.758197	valid_0's binary_logloss: 0.580441
[262]	valid_0's auc: 0.758189	valid_0's binary_logloss: 0.580448
[263]	valid_0's auc: 0.75818	valid_0's binary_logloss: 0.580461
[264]	valid_0's auc: 0.758195	valid_0's binary_logloss: 0.580448
[265]	valid_0's auc: 0.758

[376]	valid_0's auc: 0.758104	valid_0's binary_logloss: 0.580555
[377]	valid_0's auc: 0.758097	valid_0's binary_logloss: 0.580559
[378]	valid_0's auc: 0.758094	valid_0's binary_logloss: 0.580561
[379]	valid_0's auc: 0.758076	valid_0's binary_logloss: 0.580579
[380]	valid_0's auc: 0.758066	valid_0's binary_logloss: 0.58059
[381]	valid_0's auc: 0.75808	valid_0's binary_logloss: 0.580578
[382]	valid_0's auc: 0.758085	valid_0's binary_logloss: 0.580575
[383]	valid_0's auc: 0.758065	valid_0's binary_logloss: 0.580593
[384]	valid_0's auc: 0.758079	valid_0's binary_logloss: 0.580586
[385]	valid_0's auc: 0.758062	valid_0's binary_logloss: 0.580597
[386]	valid_0's auc: 0.758062	valid_0's binary_logloss: 0.580597
[387]	valid_0's auc: 0.75805	valid_0's binary_logloss: 0.580607
[388]	valid_0's auc: 0.758043	valid_0's binary_logloss: 0.580614
[389]	valid_0's auc: 0.758027	valid_0's binary_logloss: 0.580627
[390]	valid_0's auc: 0.758018	valid_0's binary_logloss: 0.58063
[391]	valid_0's auc: 0.758025

[504]	valid_0's auc: 0.757578	valid_0's binary_logloss: 0.581017
[505]	valid_0's auc: 0.757579	valid_0's binary_logloss: 0.58102
[506]	valid_0's auc: 0.757578	valid_0's binary_logloss: 0.581025
[507]	valid_0's auc: 0.757576	valid_0's binary_logloss: 0.581029
[508]	valid_0's auc: 0.757571	valid_0's binary_logloss: 0.581033
[509]	valid_0's auc: 0.757566	valid_0's binary_logloss: 0.58104
[510]	valid_0's auc: 0.757555	valid_0's binary_logloss: 0.581047
[511]	valid_0's auc: 0.757552	valid_0's binary_logloss: 0.581055
[512]	valid_0's auc: 0.757555	valid_0's binary_logloss: 0.581049
[513]	valid_0's auc: 0.757551	valid_0's binary_logloss: 0.581053
[514]	valid_0's auc: 0.757552	valid_0's binary_logloss: 0.581054
[515]	valid_0's auc: 0.757556	valid_0's binary_logloss: 0.58105
[516]	valid_0's auc: 0.757551	valid_0's binary_logloss: 0.581052
[517]	valid_0's auc: 0.757534	valid_0's binary_logloss: 0.581069
[518]	valid_0's auc: 0.757542	valid_0's binary_logloss: 0.581063
[519]	valid_0's auc: 0.75753

[632]	valid_0's auc: 0.757188	valid_0's binary_logloss: 0.581421
[633]	valid_0's auc: 0.757189	valid_0's binary_logloss: 0.581422
[634]	valid_0's auc: 0.75719	valid_0's binary_logloss: 0.581418
[635]	valid_0's auc: 0.757189	valid_0's binary_logloss: 0.58142
[636]	valid_0's auc: 0.757188	valid_0's binary_logloss: 0.58142
[637]	valid_0's auc: 0.757183	valid_0's binary_logloss: 0.581429
[638]	valid_0's auc: 0.757178	valid_0's binary_logloss: 0.581436
[639]	valid_0's auc: 0.757179	valid_0's binary_logloss: 0.581437
[640]	valid_0's auc: 0.757168	valid_0's binary_logloss: 0.581447
[641]	valid_0's auc: 0.757143	valid_0's binary_logloss: 0.581463
[642]	valid_0's auc: 0.757137	valid_0's binary_logloss: 0.581471
[643]	valid_0's auc: 0.757115	valid_0's binary_logloss: 0.581493
[644]	valid_0's auc: 0.757112	valid_0's binary_logloss: 0.581492
[645]	valid_0's auc: 0.757112	valid_0's binary_logloss: 0.581482
[646]	valid_0's auc: 0.757113	valid_0's binary_logloss: 0.581483
[647]	valid_0's auc: 0.75711

[760]	valid_0's auc: 0.756989	valid_0's binary_logloss: 0.581678
[761]	valid_0's auc: 0.756964	valid_0's binary_logloss: 0.581701
[762]	valid_0's auc: 0.756983	valid_0's binary_logloss: 0.581687
[763]	valid_0's auc: 0.756974	valid_0's binary_logloss: 0.581692
[764]	valid_0's auc: 0.75697	valid_0's binary_logloss: 0.581698
[765]	valid_0's auc: 0.756963	valid_0's binary_logloss: 0.581708
[766]	valid_0's auc: 0.756965	valid_0's binary_logloss: 0.581711
[767]	valid_0's auc: 0.756949	valid_0's binary_logloss: 0.581724
[768]	valid_0's auc: 0.756948	valid_0's binary_logloss: 0.581724
[769]	valid_0's auc: 0.756939	valid_0's binary_logloss: 0.58174
[770]	valid_0's auc: 0.756942	valid_0's binary_logloss: 0.58174
[771]	valid_0's auc: 0.756927	valid_0's binary_logloss: 0.581751
[772]	valid_0's auc: 0.756916	valid_0's binary_logloss: 0.581761
[773]	valid_0's auc: 0.756921	valid_0's binary_logloss: 0.58176
[774]	valid_0's auc: 0.756906	valid_0's binary_logloss: 0.581775
[775]	valid_0's auc: 0.756906

[887]	valid_0's auc: 0.756328	valid_0's binary_logloss: 0.582414
[888]	valid_0's auc: 0.756325	valid_0's binary_logloss: 0.582416
[889]	valid_0's auc: 0.756328	valid_0's binary_logloss: 0.582409
[890]	valid_0's auc: 0.756333	valid_0's binary_logloss: 0.582406
[891]	valid_0's auc: 0.756312	valid_0's binary_logloss: 0.582419
[892]	valid_0's auc: 0.75632	valid_0's binary_logloss: 0.582412
[893]	valid_0's auc: 0.75631	valid_0's binary_logloss: 0.582423
[894]	valid_0's auc: 0.756305	valid_0's binary_logloss: 0.582425
[895]	valid_0's auc: 0.756295	valid_0's binary_logloss: 0.582432
[896]	valid_0's auc: 0.756297	valid_0's binary_logloss: 0.582424
[897]	valid_0's auc: 0.756291	valid_0's binary_logloss: 0.582429
[898]	valid_0's auc: 0.756283	valid_0's binary_logloss: 0.582449
[899]	valid_0's auc: 0.75629	valid_0's binary_logloss: 0.582446
[900]	valid_0's auc: 0.7563	valid_0's binary_logloss: 0.582442
[901]	valid_0's auc: 0.756301	valid_0's binary_logloss: 0.582443
[902]	valid_0's auc: 0.756298	

[1013]	valid_0's auc: 0.75593	valid_0's binary_logloss: 0.582891
[1014]	valid_0's auc: 0.755927	valid_0's binary_logloss: 0.582893
[1015]	valid_0's auc: 0.755928	valid_0's binary_logloss: 0.582894
[1016]	valid_0's auc: 0.755919	valid_0's binary_logloss: 0.582903
[1017]	valid_0's auc: 0.755921	valid_0's binary_logloss: 0.582905
[1018]	valid_0's auc: 0.755928	valid_0's binary_logloss: 0.582898
[1019]	valid_0's auc: 0.755945	valid_0's binary_logloss: 0.582887
[1020]	valid_0's auc: 0.755943	valid_0's binary_logloss: 0.582889
[1021]	valid_0's auc: 0.755949	valid_0's binary_logloss: 0.582883
[1022]	valid_0's auc: 0.755935	valid_0's binary_logloss: 0.582897
[1023]	valid_0's auc: 0.75593	valid_0's binary_logloss: 0.5829
[1024]	valid_0's auc: 0.755927	valid_0's binary_logloss: 0.5829
[1025]	valid_0's auc: 0.755941	valid_0's binary_logloss: 0.582889
[1026]	valid_0's auc: 0.755959	valid_0's binary_logloss: 0.582879
[1027]	valid_0's auc: 0.755957	valid_0's binary_logloss: 0.582884
[1028]	valid_0's

[1138]	valid_0's auc: 0.755492	valid_0's binary_logloss: 0.583434
[1139]	valid_0's auc: 0.755492	valid_0's binary_logloss: 0.583432
[1140]	valid_0's auc: 0.755485	valid_0's binary_logloss: 0.583437
[1141]	valid_0's auc: 0.755469	valid_0's binary_logloss: 0.583455
[1142]	valid_0's auc: 0.755461	valid_0's binary_logloss: 0.583462
[1143]	valid_0's auc: 0.755441	valid_0's binary_logloss: 0.583481
[1144]	valid_0's auc: 0.755446	valid_0's binary_logloss: 0.583479
[1145]	valid_0's auc: 0.755444	valid_0's binary_logloss: 0.58348
[1146]	valid_0's auc: 0.755453	valid_0's binary_logloss: 0.583475
[1147]	valid_0's auc: 0.75545	valid_0's binary_logloss: 0.583479
[1148]	valid_0's auc: 0.755447	valid_0's binary_logloss: 0.583484
[1149]	valid_0's auc: 0.755446	valid_0's binary_logloss: 0.58349
[1150]	valid_0's auc: 0.755449	valid_0's binary_logloss: 0.583491
[1151]	valid_0's auc: 0.75543	valid_0's binary_logloss: 0.583507
[1152]	valid_0's auc: 0.755429	valid_0's binary_logloss: 0.58351
[1153]	valid_0'

[1263]	valid_0's auc: 0.755102	valid_0's binary_logloss: 0.583965
[1264]	valid_0's auc: 0.755107	valid_0's binary_logloss: 0.583964
[1265]	valid_0's auc: 0.755097	valid_0's binary_logloss: 0.583974
[1266]	valid_0's auc: 0.75509	valid_0's binary_logloss: 0.583983
[1267]	valid_0's auc: 0.755091	valid_0's binary_logloss: 0.583978
[1268]	valid_0's auc: 0.755087	valid_0's binary_logloss: 0.583982
[1269]	valid_0's auc: 0.755085	valid_0's binary_logloss: 0.583986
[1270]	valid_0's auc: 0.755099	valid_0's binary_logloss: 0.583974
[1271]	valid_0's auc: 0.755092	valid_0's binary_logloss: 0.583984
[1272]	valid_0's auc: 0.75508	valid_0's binary_logloss: 0.583999
[1273]	valid_0's auc: 0.755068	valid_0's binary_logloss: 0.584011
[1274]	valid_0's auc: 0.755066	valid_0's binary_logloss: 0.584015
[1275]	valid_0's auc: 0.75506	valid_0's binary_logloss: 0.584024
[1276]	valid_0's auc: 0.75506	valid_0's binary_logloss: 0.584023
[1277]	valid_0's auc: 0.755056	valid_0's binary_logloss: 0.584026
[1278]	valid_0

[1388]	valid_0's auc: 0.754761	valid_0's binary_logloss: 0.584471
[1389]	valid_0's auc: 0.754762	valid_0's binary_logloss: 0.584471
[1390]	valid_0's auc: 0.754751	valid_0's binary_logloss: 0.584482
[1391]	valid_0's auc: 0.754756	valid_0's binary_logloss: 0.584482
[1392]	valid_0's auc: 0.754747	valid_0's binary_logloss: 0.584494
[1393]	valid_0's auc: 0.754744	valid_0's binary_logloss: 0.584496
[1394]	valid_0's auc: 0.754738	valid_0's binary_logloss: 0.584503
[1395]	valid_0's auc: 0.754736	valid_0's binary_logloss: 0.584504
[1396]	valid_0's auc: 0.754731	valid_0's binary_logloss: 0.584508
[1397]	valid_0's auc: 0.754735	valid_0's binary_logloss: 0.584503
[1398]	valid_0's auc: 0.75474	valid_0's binary_logloss: 0.584503
[1399]	valid_0's auc: 0.754738	valid_0's binary_logloss: 0.584509
[1400]	valid_0's auc: 0.754725	valid_0's binary_logloss: 0.58452
[1401]	valid_0's auc: 0.754726	valid_0's binary_logloss: 0.584521
[1402]	valid_0's auc: 0.754725	valid_0's binary_logloss: 0.58452
[1403]	valid_

[1513]	valid_0's auc: 0.754337	valid_0's binary_logloss: 0.58501
[1514]	valid_0's auc: 0.754332	valid_0's binary_logloss: 0.585019
[1515]	valid_0's auc: 0.754333	valid_0's binary_logloss: 0.585021
[1516]	valid_0's auc: 0.754321	valid_0's binary_logloss: 0.585035
[1517]	valid_0's auc: 0.754325	valid_0's binary_logloss: 0.585034
[1518]	valid_0's auc: 0.754324	valid_0's binary_logloss: 0.585035
[1519]	valid_0's auc: 0.754319	valid_0's binary_logloss: 0.585039
[1520]	valid_0's auc: 0.754325	valid_0's binary_logloss: 0.585038
[1521]	valid_0's auc: 0.75431	valid_0's binary_logloss: 0.585048
[1522]	valid_0's auc: 0.754303	valid_0's binary_logloss: 0.585059
[1523]	valid_0's auc: 0.754297	valid_0's binary_logloss: 0.585067
[1524]	valid_0's auc: 0.754292	valid_0's binary_logloss: 0.585075
[1525]	valid_0's auc: 0.754287	valid_0's binary_logloss: 0.585085
[1526]	valid_0's auc: 0.754291	valid_0's binary_logloss: 0.585082
[1527]	valid_0's auc: 0.754291	valid_0's binary_logloss: 0.585082
[1528]	valid

[1639]	valid_0's auc: 0.753837	valid_0's binary_logloss: 0.585678
[1640]	valid_0's auc: 0.753841	valid_0's binary_logloss: 0.585675
[1641]	valid_0's auc: 0.753852	valid_0's binary_logloss: 0.585665
[1642]	valid_0's auc: 0.75387	valid_0's binary_logloss: 0.58565
[1643]	valid_0's auc: 0.753868	valid_0's binary_logloss: 0.585648
[1644]	valid_0's auc: 0.753876	valid_0's binary_logloss: 0.585642
[1645]	valid_0's auc: 0.753865	valid_0's binary_logloss: 0.585658
[1646]	valid_0's auc: 0.753862	valid_0's binary_logloss: 0.585662
[1647]	valid_0's auc: 0.753864	valid_0's binary_logloss: 0.58566
[1648]	valid_0's auc: 0.753866	valid_0's binary_logloss: 0.585665
[1649]	valid_0's auc: 0.753869	valid_0's binary_logloss: 0.585664
[1650]	valid_0's auc: 0.753863	valid_0's binary_logloss: 0.585678
[1651]	valid_0's auc: 0.753873	valid_0's binary_logloss: 0.585673
[1652]	valid_0's auc: 0.753859	valid_0's binary_logloss: 0.585687
[1653]	valid_0's auc: 0.753857	valid_0's binary_logloss: 0.585692
[1654]	valid_

[1764]	valid_0's auc: 0.753472	valid_0's binary_logloss: 0.586244
[1765]	valid_0's auc: 0.753474	valid_0's binary_logloss: 0.586243
[1766]	valid_0's auc: 0.75347	valid_0's binary_logloss: 0.58625
[1767]	valid_0's auc: 0.75347	valid_0's binary_logloss: 0.586252
[1768]	valid_0's auc: 0.753464	valid_0's binary_logloss: 0.586258
[1769]	valid_0's auc: 0.753459	valid_0's binary_logloss: 0.586262
[1770]	valid_0's auc: 0.753466	valid_0's binary_logloss: 0.586262
[1771]	valid_0's auc: 0.753466	valid_0's binary_logloss: 0.586262
[1772]	valid_0's auc: 0.753467	valid_0's binary_logloss: 0.586264
[1773]	valid_0's auc: 0.753465	valid_0's binary_logloss: 0.586268
[1774]	valid_0's auc: 0.753454	valid_0's binary_logloss: 0.586279
[1775]	valid_0's auc: 0.753456	valid_0's binary_logloss: 0.58629
[1776]	valid_0's auc: 0.753458	valid_0's binary_logloss: 0.586292
[1777]	valid_0's auc: 0.753452	valid_0's binary_logloss: 0.586298
[1778]	valid_0's auc: 0.753452	valid_0's binary_logloss: 0.586298
[1779]	valid_0

[1889]	valid_0's auc: 0.753212	valid_0's binary_logloss: 0.586703
[1890]	valid_0's auc: 0.753197	valid_0's binary_logloss: 0.586715
[1891]	valid_0's auc: 0.7532	valid_0's binary_logloss: 0.586711
[1892]	valid_0's auc: 0.753204	valid_0's binary_logloss: 0.58671
[1893]	valid_0's auc: 0.753212	valid_0's binary_logloss: 0.586698
[1894]	valid_0's auc: 0.753208	valid_0's binary_logloss: 0.586705
[1895]	valid_0's auc: 0.753223	valid_0's binary_logloss: 0.586692
[1896]	valid_0's auc: 0.753224	valid_0's binary_logloss: 0.586699
[1897]	valid_0's auc: 0.753221	valid_0's binary_logloss: 0.586701
[1898]	valid_0's auc: 0.753221	valid_0's binary_logloss: 0.586698
[1899]	valid_0's auc: 0.753229	valid_0's binary_logloss: 0.586691
[1900]	valid_0's auc: 0.753222	valid_0's binary_logloss: 0.586701
[1901]	valid_0's auc: 0.75321	valid_0's binary_logloss: 0.586708
[1902]	valid_0's auc: 0.753201	valid_0's binary_logloss: 0.586719
[1903]	valid_0's auc: 0.7532	valid_0's binary_logloss: 0.586719
[1904]	valid_0's

[2014]	valid_0's auc: 0.752771	valid_0's binary_logloss: 0.587291
[2015]	valid_0's auc: 0.752765	valid_0's binary_logloss: 0.587301
[2016]	valid_0's auc: 0.752757	valid_0's binary_logloss: 0.587307
[2017]	valid_0's auc: 0.752759	valid_0's binary_logloss: 0.587311
[2018]	valid_0's auc: 0.752754	valid_0's binary_logloss: 0.587313
[2019]	valid_0's auc: 0.752753	valid_0's binary_logloss: 0.587317
[2020]	valid_0's auc: 0.752756	valid_0's binary_logloss: 0.587316
[2021]	valid_0's auc: 0.75275	valid_0's binary_logloss: 0.587325
[2022]	valid_0's auc: 0.752764	valid_0's binary_logloss: 0.587319
[2023]	valid_0's auc: 0.752762	valid_0's binary_logloss: 0.587321
[2024]	valid_0's auc: 0.752768	valid_0's binary_logloss: 0.587321
[2025]	valid_0's auc: 0.752761	valid_0's binary_logloss: 0.587334
[2026]	valid_0's auc: 0.752762	valid_0's binary_logloss: 0.587334
[2027]	valid_0's auc: 0.752756	valid_0's binary_logloss: 0.58734
[2028]	valid_0's auc: 0.75275	valid_0's binary_logloss: 0.587346
[2029]	valid_

[2140]	valid_0's auc: 0.752597	valid_0's binary_logloss: 0.58767
[2141]	valid_0's auc: 0.752601	valid_0's binary_logloss: 0.587669
[2142]	valid_0's auc: 0.75261	valid_0's binary_logloss: 0.587666
[2143]	valid_0's auc: 0.752612	valid_0's binary_logloss: 0.587669
[2144]	valid_0's auc: 0.752609	valid_0's binary_logloss: 0.587671
[2145]	valid_0's auc: 0.752613	valid_0's binary_logloss: 0.587664
[2146]	valid_0's auc: 0.752608	valid_0's binary_logloss: 0.587675
[2147]	valid_0's auc: 0.752608	valid_0's binary_logloss: 0.587675
[2148]	valid_0's auc: 0.75261	valid_0's binary_logloss: 0.587676
[2149]	valid_0's auc: 0.752611	valid_0's binary_logloss: 0.58768
[2150]	valid_0's auc: 0.752613	valid_0's binary_logloss: 0.58768
[2151]	valid_0's auc: 0.75261	valid_0's binary_logloss: 0.587681
[2152]	valid_0's auc: 0.752607	valid_0's binary_logloss: 0.587685
[2153]	valid_0's auc: 0.752606	valid_0's binary_logloss: 0.587679
[2154]	valid_0's auc: 0.752612	valid_0's binary_logloss: 0.587676
[2155]	valid_0's

[2266]	valid_0's auc: 0.752343	valid_0's binary_logloss: 0.588159
[2267]	valid_0's auc: 0.752337	valid_0's binary_logloss: 0.588162
[2268]	valid_0's auc: 0.752327	valid_0's binary_logloss: 0.588177
[2269]	valid_0's auc: 0.752333	valid_0's binary_logloss: 0.588175
[2270]	valid_0's auc: 0.752337	valid_0's binary_logloss: 0.588177
[2271]	valid_0's auc: 0.752326	valid_0's binary_logloss: 0.588188
[2272]	valid_0's auc: 0.752325	valid_0's binary_logloss: 0.588193
[2273]	valid_0's auc: 0.752324	valid_0's binary_logloss: 0.588196
[2274]	valid_0's auc: 0.752324	valid_0's binary_logloss: 0.588193
[2275]	valid_0's auc: 0.752333	valid_0's binary_logloss: 0.588184
[2276]	valid_0's auc: 0.752336	valid_0's binary_logloss: 0.588185
[2277]	valid_0's auc: 0.75233	valid_0's binary_logloss: 0.588191
[2278]	valid_0's auc: 0.752336	valid_0's binary_logloss: 0.588182
[2279]	valid_0's auc: 0.752344	valid_0's binary_logloss: 0.588176
[2280]	valid_0's auc: 0.752333	valid_0's binary_logloss: 0.588185
[2281]	vali

[2391]	valid_0's auc: 0.752063	valid_0's binary_logloss: 0.588633
[2392]	valid_0's auc: 0.752049	valid_0's binary_logloss: 0.588649
[2393]	valid_0's auc: 0.752053	valid_0's binary_logloss: 0.588658
[2394]	valid_0's auc: 0.752055	valid_0's binary_logloss: 0.588657
[2395]	valid_0's auc: 0.752058	valid_0's binary_logloss: 0.588659
[2396]	valid_0's auc: 0.752058	valid_0's binary_logloss: 0.58866
[2397]	valid_0's auc: 0.752049	valid_0's binary_logloss: 0.588665
[2398]	valid_0's auc: 0.752049	valid_0's binary_logloss: 0.588663
[2399]	valid_0's auc: 0.752043	valid_0's binary_logloss: 0.588668
[2400]	valid_0's auc: 0.752033	valid_0's binary_logloss: 0.588679
[2401]	valid_0's auc: 0.752033	valid_0's binary_logloss: 0.588683
[2402]	valid_0's auc: 0.752029	valid_0's binary_logloss: 0.588689
[2403]	valid_0's auc: 0.752023	valid_0's binary_logloss: 0.588695
[2404]	valid_0's auc: 0.752023	valid_0's binary_logloss: 0.588698
[2405]	valid_0's auc: 0.752017	valid_0's binary_logloss: 0.588707
[2406]	vali

[2516]	valid_0's auc: 0.751744	valid_0's binary_logloss: 0.589154
[2517]	valid_0's auc: 0.751746	valid_0's binary_logloss: 0.589155
[2518]	valid_0's auc: 0.75175	valid_0's binary_logloss: 0.58915
[2519]	valid_0's auc: 0.751749	valid_0's binary_logloss: 0.58915
[2520]	valid_0's auc: 0.751747	valid_0's binary_logloss: 0.589153
[2521]	valid_0's auc: 0.751751	valid_0's binary_logloss: 0.589155
[2522]	valid_0's auc: 0.751739	valid_0's binary_logloss: 0.589174
[2523]	valid_0's auc: 0.751737	valid_0's binary_logloss: 0.589179
[2524]	valid_0's auc: 0.751744	valid_0's binary_logloss: 0.589173
[2525]	valid_0's auc: 0.751741	valid_0's binary_logloss: 0.589179
[2526]	valid_0's auc: 0.751746	valid_0's binary_logloss: 0.589174
[2527]	valid_0's auc: 0.751741	valid_0's binary_logloss: 0.589187
[2528]	valid_0's auc: 0.751748	valid_0's binary_logloss: 0.589181
[2529]	valid_0's auc: 0.751739	valid_0's binary_logloss: 0.589191
[2530]	valid_0's auc: 0.751731	valid_0's binary_logloss: 0.589196
[2531]	valid_

[2641]	valid_0's auc: 0.75161	valid_0's binary_logloss: 0.589571
[2642]	valid_0's auc: 0.751614	valid_0's binary_logloss: 0.58957
[2643]	valid_0's auc: 0.75161	valid_0's binary_logloss: 0.58958
[2644]	valid_0's auc: 0.751606	valid_0's binary_logloss: 0.589589
[2645]	valid_0's auc: 0.751604	valid_0's binary_logloss: 0.589591
[2646]	valid_0's auc: 0.751602	valid_0's binary_logloss: 0.589591
[2647]	valid_0's auc: 0.751604	valid_0's binary_logloss: 0.589585
[2648]	valid_0's auc: 0.751608	valid_0's binary_logloss: 0.589586
[2649]	valid_0's auc: 0.751608	valid_0's binary_logloss: 0.589589
[2650]	valid_0's auc: 0.751617	valid_0's binary_logloss: 0.58958
[2651]	valid_0's auc: 0.751614	valid_0's binary_logloss: 0.589589
[2652]	valid_0's auc: 0.751626	valid_0's binary_logloss: 0.589581
[2653]	valid_0's auc: 0.751636	valid_0's binary_logloss: 0.589572
[2654]	valid_0's auc: 0.751629	valid_0's binary_logloss: 0.589587
[2655]	valid_0's auc: 0.75163	valid_0's binary_logloss: 0.589584
[2656]	valid_0's

[2767]	valid_0's auc: 0.751408	valid_0's binary_logloss: 0.590069
[2768]	valid_0's auc: 0.751411	valid_0's binary_logloss: 0.59007
[2769]	valid_0's auc: 0.751408	valid_0's binary_logloss: 0.590072
[2770]	valid_0's auc: 0.751407	valid_0's binary_logloss: 0.590073
[2771]	valid_0's auc: 0.751397	valid_0's binary_logloss: 0.590082
[2772]	valid_0's auc: 0.751394	valid_0's binary_logloss: 0.59009
[2773]	valid_0's auc: 0.751383	valid_0's binary_logloss: 0.5901
[2774]	valid_0's auc: 0.751384	valid_0's binary_logloss: 0.590103
[2775]	valid_0's auc: 0.751384	valid_0's binary_logloss: 0.590107
[2776]	valid_0's auc: 0.751388	valid_0's binary_logloss: 0.590102
[2777]	valid_0's auc: 0.75139	valid_0's binary_logloss: 0.590099
[2778]	valid_0's auc: 0.751389	valid_0's binary_logloss: 0.590104
[2779]	valid_0's auc: 0.751377	valid_0's binary_logloss: 0.590118
[2780]	valid_0's auc: 0.751379	valid_0's binary_logloss: 0.590119
[2781]	valid_0's auc: 0.751392	valid_0's binary_logloss: 0.590105
[2782]	valid_0'

[2891]	valid_0's auc: 0.751173	valid_0's binary_logloss: 0.590556
[2892]	valid_0's auc: 0.751174	valid_0's binary_logloss: 0.590561
[2893]	valid_0's auc: 0.751183	valid_0's binary_logloss: 0.590549
[2894]	valid_0's auc: 0.751179	valid_0's binary_logloss: 0.590556
[2895]	valid_0's auc: 0.751181	valid_0's binary_logloss: 0.590555
[2896]	valid_0's auc: 0.751186	valid_0's binary_logloss: 0.590552
[2897]	valid_0's auc: 0.751185	valid_0's binary_logloss: 0.590559
[2898]	valid_0's auc: 0.751179	valid_0's binary_logloss: 0.590565
[2899]	valid_0's auc: 0.751185	valid_0's binary_logloss: 0.590563
[2900]	valid_0's auc: 0.751181	valid_0's binary_logloss: 0.590564
[2901]	valid_0's auc: 0.751189	valid_0's binary_logloss: 0.590557
[2902]	valid_0's auc: 0.751186	valid_0's binary_logloss: 0.590562
[2903]	valid_0's auc: 0.751165	valid_0's binary_logloss: 0.590582
[2904]	valid_0's auc: 0.751157	valid_0's binary_logloss: 0.590594
[2905]	valid_0's auc: 0.751162	valid_0's binary_logloss: 0.590587
[2906]	val

[3016]	valid_0's auc: 0.751008	valid_0's binary_logloss: 0.591009
[3017]	valid_0's auc: 0.751007	valid_0's binary_logloss: 0.591012
[3018]	valid_0's auc: 0.751006	valid_0's binary_logloss: 0.591017
[3019]	valid_0's auc: 0.750994	valid_0's binary_logloss: 0.591027
[3020]	valid_0's auc: 0.750995	valid_0's binary_logloss: 0.591033
[3021]	valid_0's auc: 0.750985	valid_0's binary_logloss: 0.591044
[3022]	valid_0's auc: 0.75098	valid_0's binary_logloss: 0.591051
[3023]	valid_0's auc: 0.750974	valid_0's binary_logloss: 0.591059
[3024]	valid_0's auc: 0.750972	valid_0's binary_logloss: 0.591062
[3025]	valid_0's auc: 0.750974	valid_0's binary_logloss: 0.591065
[3026]	valid_0's auc: 0.750973	valid_0's binary_logloss: 0.591067
[3027]	valid_0's auc: 0.750966	valid_0's binary_logloss: 0.591076
[3028]	valid_0's auc: 0.750962	valid_0's binary_logloss: 0.591082
[3029]	valid_0's auc: 0.750972	valid_0's binary_logloss: 0.591071
[3030]	valid_0's auc: 0.750971	valid_0's binary_logloss: 0.591077
[3031]	vali

[3141]	valid_0's auc: 0.750777	valid_0's binary_logloss: 0.591511
[3142]	valid_0's auc: 0.750779	valid_0's binary_logloss: 0.591513
[3143]	valid_0's auc: 0.750784	valid_0's binary_logloss: 0.591514
[3144]	valid_0's auc: 0.750789	valid_0's binary_logloss: 0.591512
[3145]	valid_0's auc: 0.750793	valid_0's binary_logloss: 0.59151
[3146]	valid_0's auc: 0.750772	valid_0's binary_logloss: 0.591536
[3147]	valid_0's auc: 0.750768	valid_0's binary_logloss: 0.591548
[3148]	valid_0's auc: 0.750775	valid_0's binary_logloss: 0.591546
[3149]	valid_0's auc: 0.750765	valid_0's binary_logloss: 0.591558
[3150]	valid_0's auc: 0.750757	valid_0's binary_logloss: 0.591569
[3151]	valid_0's auc: 0.750758	valid_0's binary_logloss: 0.591569
[3152]	valid_0's auc: 0.750741	valid_0's binary_logloss: 0.591585
[3153]	valid_0's auc: 0.750749	valid_0's binary_logloss: 0.591579
[3154]	valid_0's auc: 0.750755	valid_0's binary_logloss: 0.59158
[3155]	valid_0's auc: 0.750746	valid_0's binary_logloss: 0.591592
[3156]	valid

[3267]	valid_0's auc: 0.750584	valid_0's binary_logloss: 0.592035
[3268]	valid_0's auc: 0.75058	valid_0's binary_logloss: 0.592037
[3269]	valid_0's auc: 0.75059	valid_0's binary_logloss: 0.592029
[3270]	valid_0's auc: 0.75058	valid_0's binary_logloss: 0.592043
[3271]	valid_0's auc: 0.750564	valid_0's binary_logloss: 0.592058
[3272]	valid_0's auc: 0.750545	valid_0's binary_logloss: 0.59208
[3273]	valid_0's auc: 0.750534	valid_0's binary_logloss: 0.592094
[3274]	valid_0's auc: 0.750526	valid_0's binary_logloss: 0.592109
[3275]	valid_0's auc: 0.750526	valid_0's binary_logloss: 0.592111
[3276]	valid_0's auc: 0.75053	valid_0's binary_logloss: 0.592105
[3277]	valid_0's auc: 0.75052	valid_0's binary_logloss: 0.592115
[3278]	valid_0's auc: 0.750514	valid_0's binary_logloss: 0.592122
[3279]	valid_0's auc: 0.750506	valid_0's binary_logloss: 0.592133
[3280]	valid_0's auc: 0.750502	valid_0's binary_logloss: 0.592136
[3281]	valid_0's auc: 0.750493	valid_0's binary_logloss: 0.592145
[3282]	valid_0's

KeyboardInterrupt: 

In [None]:
pred = clf.predict_proba(X_test)[:,1]
roc_auc_score(y_test, pred)

In [None]:
sorted_f = np.argsort(clf.feature_importances_)

for elt in range(len(sorted_f)):
    print(clf.feature_importances_[sorted_f[-elt-1]], clf.feature_name_[sorted_f[-elt-1]])

In [None]:
import tensorflow as tf
from tabnet import *

tabnet_encoder = TabNet(
        num_features = X_test.shape[1],
        feature_dim = 256,
        output_dim = 128,
        feature_columns = None,
        n_step = 6,
        n_total = 4,
        n_shared = 2,
        relaxation_factor = 1.5,
        bn_epsilon = 1e-5,
        bn_momentum = 0.7,
        bn_virtual_divider = 10,
    )


# inputs = [tf.keras.Input(shape=(1,)) for i in range(4)] + [tf.keras.Input(shape=(327,))]

# in_size = [14000,8,100,20]
# out_size = [128,8,30,20]

# agg = [tf.squeeze(tf.keras.layers.Embedding(in_size[i], out_size[i])(inputs[i]), axis = 1) for i in range(len(in_size))] + [inputs[-1]]
# agg = tf.keras.layers.Concatenate(axis = 1)(agg)

inputs = tf.keras.Input(shape = (X_test.shape[1]))

enc, masks = tabnet_encoder(inputs)

out = tf.keras.layers.Dense(1, activation = 'sigmoid')(enc)
model = tf.keras.Model(inputs, out)

In [None]:
model.summary()

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.02, beta_1=0.9, beta_2=0.999)

model.compile(
        loss = 'binary_crossentropy',
        optimizer = optimizer,
        metrics = ['accuracy', 'AUC'])

In [None]:
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import ReduceLROnPlateau

early = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=9, verbose=1, 
                                                mode='auto', restore_best_weights=True)

reduce = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=4, verbose=1, 
                           mode='auto', min_delta=0.0001, cooldown=0, min_lr=0)

callbacks =[early, reduce]

epochs = 1000
batch_size = 5000

ls = X_train.shape[0]//batch_size * batch_size
lt = X_test.shape[0]//batch_size * batch_size

# model.fit([elt[:ls] for elt in X_train], y_train[:ls], 
#           validation_data = ([elt[:lt] for elt in X_test], y_test[:lt]), 
#           batch_size = batch_size, epochs = epochs, callbacks=callbacks)

model.fit(X_train.values[:ls], y_train[:ls], 
          validation_data = (X_test.values[:lt], y_test[:lt]), 
          batch_size = batch_size, epochs = epochs, callbacks=callbacks)