In [4]:
import os
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import LabelEncoder
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from pytorch_tabnet.tab_model import TabNetClassifier

### 데이터를 불러오자

In [5]:
dir = os.getcwd()
train = pd.read_csv(dir + "/data/train_data.csv")
test = pd.read_csv(dir + "/data/test_data.csv")
valid = pd.read_csv(dir + "/data/valid_data.csv")

### 데이터 확인

In [6]:
train.head()

Unnamed: 0,cust_no,label,E1,E2,E3,E4,E5,E6,E10,E14,...,농축업,무직,사무원,생산직,서비스직,은퇴,전문직,정치인,판매원,프리랜서
0,0xb2d283b6,1.0,965,965.0,965.0,1,0,209.0,210.0,18169,...,0,0,0,0,0,0,0,0,0,1
1,0xb2d62fab,1.0,368,368.0,368.0,0,1,173.0,275.0,18260,...,0,0,0,0,0,0,0,0,1,0
2,0xb2d69cdb,1.0,199,199.0,199.0,0,1,41.0,6.0,18168,...,0,0,0,0,0,0,0,0,1,0
3,0xb2d942e8,-1.0,120,120.0,120.0,0,1,11.0,1.0,18261,...,0,0,0,0,0,0,0,0,1,0
4,0xb2d9156f,1.0,40,40.0,40.0,0,0,18.0,1.0,18169,...,0,0,0,0,0,0,0,0,1,0


In [8]:
pd.set_option('display.max_rows',None)
train.isnull().sum()

cust_no     0
label       0
E1          0
E2          0
E3          0
E4          0
E5          0
E6          0
E10         0
E14         0
E15         0
E16         0
E17         0
E18         0
I1         54
I2          0
I3          0
I4          0
I6          0
I7          0
I11         0
I15         0
I16         0
I17         0
I18         0
I19         0
I20         0
X1_m1       0
X2_m1       0
X3_m1       0
X4_m1       0
X5_m1       0
X6_m1       0
X7_m1       0
X8_m1       0
B1_m1       0
B2_m1       0
B3_m1       0
B4_m1       0
B5_m1       0
C1_m1       0
C2_m1       0
X1_m2       0
X2_m2       0
X3_m2       0
X4_m2       0
X5_m2       0
X6_m2       0
X7_m2       0
X8_m2       0
B1_m2       0
B2_m2       0
B3_m2       0
B4_m2       0
B5_m2       0
C1_m2       0
C2_m2       0
X1_m3       0
X2_m3       0
X3_m3       0
X4_m3       0
X5_m3       0
X6_m3       0
X7_m3       0
X8_m3       0
B1_m3       0
B2_m3       0
B3_m3       0
B4_m3       0
B5_m3       0
B6_m3       0
B7_m3 

In [9]:
test.isnull().sum()

cust_no    0
label      0
E1         0
E2         0
E3         0
E4         0
E5         0
E6         0
E10        0
E14        0
E15        0
E16        0
E17        0
E18        0
I1         7
I2         0
I3         0
I4         0
I6         0
I7         0
I11        0
I15        0
I16        0
I17        0
I18        0
I19        0
I20        0
X1_m1      0
X2_m1      0
X3_m1      0
X4_m1      0
X5_m1      0
X6_m1      0
X7_m1      0
X8_m1      0
B1_m1      0
B2_m1      0
B3_m1      0
B4_m1      0
B5_m1      0
C1_m1      0
C2_m1      0
X1_m2      0
X2_m2      0
X3_m2      0
X4_m2      0
X5_m2      0
X6_m2      0
X7_m2      0
X8_m2      0
B1_m2      0
B2_m2      0
B3_m2      0
B4_m2      0
B5_m2      0
C1_m2      0
C2_m2      0
X1_m3      0
X2_m3      0
X3_m3      0
X4_m3      0
X5_m3      0
X6_m3      0
X7_m3      0
X8_m3      0
B1_m3      0
B2_m3      0
B3_m3      0
B4_m3      0
B5_m3      0
B6_m3      0
B7_m3      0
C1_m3      0
C2_m3      0
군인         0
농축업        0
무직         0

In [10]:
valid.isnull().sum()

cust_no    0
label      0
E1         0
E2         0
E3         0
E4         0
E5         0
E6         0
E10        0
E14        0
E15        0
E16        0
E17        0
E18        0
I1         3
I2         0
I3         0
I4         0
I6         0
I7         0
I11        0
I15        0
I16        0
I17        0
I18        0
I19        0
I20        0
X1_m1      0
X2_m1      0
X3_m1      0
X4_m1      0
X5_m1      0
X6_m1      0
X7_m1      0
X8_m1      0
B1_m1      0
B2_m1      0
B3_m1      0
B4_m1      0
B5_m1      0
C1_m1      0
C2_m1      0
X1_m2      0
X2_m2      0
X3_m2      0
X4_m2      0
X5_m2      0
X6_m2      0
X7_m2      0
X8_m2      0
B1_m2      0
B2_m2      0
B3_m2      0
B4_m2      0
B5_m2      0
C1_m2      0
C2_m2      0
X1_m3      0
X2_m3      0
X3_m3      0
X4_m3      0
X5_m3      0
X6_m3      0
X7_m3      0
X8_m3      0
B1_m3      0
B2_m3      0
B3_m3      0
B4_m3      0
B5_m3      0
B6_m3      0
B7_m3      0
C1_m3      0
C2_m3      0
군인         0
농축업        0
무직         0

In [14]:
pd.reset_option('max_rows')

##### 전부 I1열에만 결측치가 있다. 일단 모든 결측치를 0.5로 대체하자.

In [15]:
pd.reset_option('max_rows')
train_fil = train.fillna(0.5)
test_fil = test.fillna(0.5)
valid_fil = valid.fillna(0.5)

##### data split

In [16]:
x_train, y_train = train_fil.drop(['label','cust_no'], axis=1), train_fil["label"]
x_valid, y_valid = valid_fil.drop(['label','cust_no'],axis=1), valid_fil['label']
x_test, y_test = test_fil.drop(['label','cust_no'], axis=1), test_fil['label']

### 모델

In [17]:
clf = TabNetClassifier(optimizer_fn=torch.optim.Adam,
                        mask_type='sparsemax')
clf.fit(x_train.values, y_train.values, 
        eval_set=[(x_train.values, y_train.values), (x_valid.values, y_valid.values)],
        eval_name=['train', 'valid'],
        eval_metric=['accuracy'],
        drop_last=False)



epoch 0  | loss: 0.88959 | train_accuracy: 0.63802 | valid_accuracy: 0.6433  |  0:00:11s
epoch 1  | loss: 0.83149 | train_accuracy: 0.62542 | valid_accuracy: 0.6307  |  0:00:23s
epoch 2  | loss: 0.81024 | train_accuracy: 0.63866 | valid_accuracy: 0.6438  |  0:00:34s
epoch 3  | loss: 0.79305 | train_accuracy: 0.63754 | valid_accuracy: 0.6435  |  0:00:46s
epoch 4  | loss: 0.78455 | train_accuracy: 0.63855 | valid_accuracy: 0.6437  |  0:00:58s
epoch 5  | loss: 0.77649 | train_accuracy: 0.63861 | valid_accuracy: 0.6438  |  0:01:10s
epoch 6  | loss: 0.77274 | train_accuracy: 0.63946 | valid_accuracy: 0.6467  |  0:01:21s
epoch 7  | loss: 0.75979 | train_accuracy: 0.63959 | valid_accuracy: 0.6451  |  0:01:34s
epoch 8  | loss: 0.75394 | train_accuracy: 0.64109 | valid_accuracy: 0.6476  |  0:01:46s
epoch 9  | loss: 0.75742 | train_accuracy: 0.63869 | valid_accuracy: 0.6439  |  0:01:58s


KeyboardInterrupt: 