In [1]:
from pathlib import Path

import numpy as np
import polars as pl
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from sklearn.model_selection import KFold, StratifiedKFold

from pytorch_tabnet.tab_model import TabNetClassifier
import torch
from sklearn.preprocessing import StandardScaler

In [2]:
# Feature-set version: selects which engineered feature files are loaded below.
feat = "feat03"

# Engineered train / test feature tables for this version.
train = pl.read_csv(f"feat/feat_train_{feat}.csv")
test = pl.read_csv(f"feat/feat_test_{feat}.csv")

# Raw training data; its unnamed leading column is renamed to "idx".
# NOTE(review): train_origin is not referenced by any later cell shown here.
train_origin = pl.read_csv("data/train.csv").rename({"": "idx"})

# Model inputs: every test column except the "idx" row identifier.
cols_exp = list(filter(lambda col: col != "idx", test.columns))

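### TabNet

Train one TabNet classifier per CV fold, keep the out-of-fold probabilities, and average the fold models' probabilities on the test set.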

In [3]:
def train_tabnet(train, cols_exp, col_target, params=None, n_splits=5, random_state=0, stratified=False):
    """Cross-validate a TabNetClassifier and return the fold models plus OOF predictions.

    For each fold, features are standardized with a StandardScaler fit on that
    fold's training rows only. A TabNetClassifier is then fit and evaluated
    (logloss) on both the fold's training and validation rows.

    Parameters
    ----------
    train : pl.DataFrame
        Table holding the feature columns and the target column.
    cols_exp : list of str
        Feature (explanatory) column names.
    col_target : str
        Target column name.
    params : dict or None
        Extra keyword arguments for ``TabNetClassifier``. The dict is not
        modified. ``device_name="cpu"`` and ``seed=0`` always override any
        values given here.
    n_splits : int
        Number of CV folds (default 5).
    random_state : int
        Seed for the shuffled fold assignment (default 0).
    stratified : bool
        If True, use ``StratifiedKFold`` on the target so every fold keeps the
        class proportions. The default False keeps the plain ``KFold`` split.

    Returns
    -------
    clf_lst : list of TabNetClassifier
        Fitted model for each fold, in fold order.
    oof_pred : np.ndarray
        Out-of-fold ``predict_proba`` output, rows in the same order as ``train``.
    """
    # Build a fresh dict instead of ``params |= ...`` so the caller's dict is
    # never mutated. The fixed settings still take precedence over caller keys.
    params = {**(params or {}), "device_name": "cpu", "seed": 0}

    x = train[cols_exp].to_numpy()
    y = train[col_target].to_numpy()

    # K-fold (KFold ignores y in split(); StratifiedKFold uses it)
    splitter = StratifiedKFold if stratified else KFold
    kf = splitter(n_splits=n_splits, shuffle=True, random_state=random_state)
    y_valid_pred_lst = []
    idx_valid_lst = []
    clf_lst = []

    # cross validation
    for fold, (idx_train, idx_valid) in enumerate(kf.split(x, y)):
        print("fold", fold)
        x_train, x_valid = x[idx_train, :], x[idx_valid, :]
        y_train, y_valid = y[idx_train], y[idx_valid]

        # normalization: fit on this fold's training rows only, so validation
        # statistics do not leak into the training features
        scaler = StandardScaler()
        x_train = scaler.fit_transform(x_train)
        x_valid = scaler.transform(x_valid)

        # modeling
        # NOTE(review): per the pytorch_tabnet docs, the last eval_set entry
        # (here the validation fold) drives early stopping. OOF scores on that
        # same fold are therefore slightly optimistic.
        clf = TabNetClassifier(**params)
        clf.fit(
            x_train, y_train,
            eval_set=[(x_train, y_train), (x_valid, y_valid)],
            eval_name=["train", "valid"],
            eval_metric=["logloss"],
        )

        # oof
        y_valid_pred_lst.append(clf.predict_proba(x_valid))
        idx_valid_lst.append(idx_valid)
        clf_lst.append(clf)

    # The validation folds partition 0..n-1, so argsort of the concatenated
    # indices restores the original row order of ``train``.
    idx_valid = np.hstack(idx_valid_lst)
    y_valid_pred = np.vstack(y_valid_pred_lst)
    oof_pred = y_valid_pred[np.argsort(idx_valid)]

    return clf_lst, oof_pred

In [4]:
def predict_test(x_test, clf_lst):
    """Average class probabilities predicted by an ensemble of fitted classifiers.

    Parameters
    ----------
    x_test : array-like
        Feature matrix passed unchanged to every model's ``predict_proba``.
        It must already be preprocessed the way the models expect.
    clf_lst : iterable
        Fitted models exposing ``predict_proba`` (e.g. the fold models from
        ``train_tabnet``).

    Returns
    -------
    np.ndarray
        Element-wise mean of the models' ``predict_proba`` outputs.

    Raises
    ------
    ValueError
        If ``clf_lst`` contains no models. Previously ``np.mean([])``
        silently returned ``nan`` in this case.
    """
    y_test_pred_lst = [clf.predict_proba(x_test) for clf in clf_lst]
    if not y_test_pred_lst:
        raise ValueError("clf_lst must contain at least one fitted classifier")

    # Stack the per-model outputs and average over the model axis.
    return np.mean(y_test_pred_lst, axis=0)

In [5]:
col_target = "health"

# train tabnet: one model per CV fold plus out-of-fold probabilities
clf_lst, oof_pred = train_tabnet(train, cols_exp, col_target)

# normalization for test
# NOTE(review): this scaler is fit on the full training set. Each fold model in
# clf_lst was trained on features standardized by a scaler fit on that fold's
# training rows only (see train_tabnet). The statistics are close but not
# identical; scaling x_test with each fold's own scaler would match the
# models exactly.
scaler = StandardScaler()
scaler.fit(train[cols_exp].to_numpy())
x_test = scaler.transform(test[cols_exp].to_numpy())

# predict test with CV ensemble (mean of the fold models' probabilities)
y_test_pred = predict_test(x_test, clf_lst)

# record: one probability column per class. The count comes from the
# predictions rather than a hard-coded 3.
# NOTE(review): assumes class labels are 0..K-1 so the suffix matches the label.
pred_cols = [f"health_is_{h}" for h in range(oof_pred.shape[1])]
assert oof_pred.shape[0] == len(train), "OOF rows must align with train"
assert y_test_pred.shape == (len(test), len(pred_cols)), "test preds shape mismatch"

# orient="row" is explicit: for a square array (e.g. 3 rows x 3 classes) polars
# would otherwise infer column orientation and transpose the data.
oof_pred_df = pl.DataFrame(oof_pred, schema=pred_cols, orient="row")
test_pred_df = pl.DataFrame(y_test_pred, schema=pred_cols, orient="row")

fold 0




epoch 0  | loss: 0.81184 | train_logloss: 0.66039 | valid_logloss: 0.66231 |  0:00:01s
epoch 1  | loss: 0.64572 | train_logloss: 0.62549 | valid_logloss: 0.62412 |  0:00:02s
epoch 2  | loss: 0.6218  | train_logloss: 0.61621 | valid_logloss: 0.62384 |  0:00:03s
epoch 3  | loss: 0.62015 | train_logloss: 0.61344 | valid_logloss: 0.62138 |  0:00:05s
epoch 4  | loss: 0.61501 | train_logloss: 0.61009 | valid_logloss: 0.61645 |  0:00:06s
epoch 5  | loss: 0.61478 | train_logloss: 0.60977 | valid_logloss: 0.61783 |  0:00:07s
epoch 6  | loss: 0.60968 | train_logloss: 0.6086  | valid_logloss: 0.61463 |  0:00:08s
epoch 7  | loss: 0.6117  | train_logloss: 0.60705 | valid_logloss: 0.61662 |  0:00:10s
epoch 8  | loss: 0.61151 | train_logloss: 0.60884 | valid_logloss: 0.61782 |  0:00:11s
epoch 9  | loss: 0.6113  | train_logloss: 0.60814 | valid_logloss: 0.61307 |  0:00:12s
epoch 10 | loss: 0.61086 | train_logloss: 0.60836 | valid_logloss: 0.61404 |  0:00:13s
epoch 11 | loss: 0.61184 | train_logloss: 0



fold 1




epoch 0  | loss: 0.7947  | train_logloss: 0.66145 | valid_logloss: 0.66717 |  0:00:01s
epoch 1  | loss: 0.64153 | train_logloss: 0.62822 | valid_logloss: 0.63213 |  0:00:02s
epoch 2  | loss: 0.62133 | train_logloss: 0.61512 | valid_logloss: 0.62471 |  0:00:03s
epoch 3  | loss: 0.61774 | train_logloss: 0.61117 | valid_logloss: 0.62156 |  0:00:05s
epoch 4  | loss: 0.6157  | train_logloss: 0.60967 | valid_logloss: 0.61815 |  0:00:06s
epoch 5  | loss: 0.61376 | train_logloss: 0.60844 | valid_logloss: 0.61878 |  0:00:07s
epoch 6  | loss: 0.6123  | train_logloss: 0.60752 | valid_logloss: 0.61753 |  0:00:09s
epoch 7  | loss: 0.61061 | train_logloss: 0.60751 | valid_logloss: 0.61598 |  0:00:10s
epoch 8  | loss: 0.61211 | train_logloss: 0.60775 | valid_logloss: 0.6169  |  0:00:11s
epoch 9  | loss: 0.60904 | train_logloss: 0.60682 | valid_logloss: 0.61714 |  0:00:13s
epoch 10 | loss: 0.60927 | train_logloss: 0.60547 | valid_logloss: 0.6171  |  0:00:14s
epoch 11 | loss: 0.60562 | train_logloss: 0



fold 2




epoch 0  | loss: 0.81861 | train_logloss: 0.66041 | valid_logloss: 0.66651 |  0:00:01s
epoch 1  | loss: 0.64704 | train_logloss: 0.63803 | valid_logloss: 0.64444 |  0:00:02s
epoch 2  | loss: 0.63148 | train_logloss: 0.62515 | valid_logloss: 0.61653 |  0:00:03s
epoch 3  | loss: 0.62415 | train_logloss: 0.62087 | valid_logloss: 0.60562 |  0:00:05s
epoch 4  | loss: 0.6195  | train_logloss: 0.6162  | valid_logloss: 0.60021 |  0:00:06s
epoch 5  | loss: 0.62027 | train_logloss: 0.61491 | valid_logloss: 0.59122 |  0:00:07s
epoch 6  | loss: 0.61502 | train_logloss: 0.61498 | valid_logloss: 0.5911  |  0:00:09s
epoch 7  | loss: 0.61598 | train_logloss: 0.61461 | valid_logloss: 0.59223 |  0:00:10s
epoch 8  | loss: 0.61596 | train_logloss: 0.61244 | valid_logloss: 0.59124 |  0:00:11s
epoch 9  | loss: 0.6139  | train_logloss: 0.61275 | valid_logloss: 0.59166 |  0:00:12s
epoch 10 | loss: 0.61495 | train_logloss: 0.61222 | valid_logloss: 0.59123 |  0:00:14s
epoch 11 | loss: 0.61463 | train_logloss: 0



fold 3




epoch 0  | loss: 0.8045  | train_logloss: 0.65694 | valid_logloss: 0.68137 |  0:00:01s
epoch 1  | loss: 0.64184 | train_logloss: 0.62712 | valid_logloss: 0.6394  |  0:00:02s
epoch 2  | loss: 0.62054 | train_logloss: 0.61196 | valid_logloss: 0.61955 |  0:00:03s
epoch 3  | loss: 0.61498 | train_logloss: 0.61043 | valid_logloss: 0.61909 |  0:00:05s
epoch 4  | loss: 0.61383 | train_logloss: 0.60919 | valid_logloss: 0.6192  |  0:00:06s
epoch 5  | loss: 0.61053 | train_logloss: 0.60761 | valid_logloss: 0.61824 |  0:00:07s
epoch 6  | loss: 0.60788 | train_logloss: 0.60647 | valid_logloss: 0.61584 |  0:00:08s
epoch 7  | loss: 0.60781 | train_logloss: 0.60706 | valid_logloss: 0.62026 |  0:00:10s
epoch 8  | loss: 0.60982 | train_logloss: 0.60687 | valid_logloss: 0.61493 |  0:00:11s
epoch 9  | loss: 0.61128 | train_logloss: 0.60696 | valid_logloss: 0.61281 |  0:00:12s
epoch 10 | loss: 0.60969 | train_logloss: 0.60633 | valid_logloss: 0.6176  |  0:00:14s
epoch 11 | loss: 0.60586 | train_logloss: 0



fold 4




epoch 0  | loss: 0.8152  | train_logloss: 0.67626 | valid_logloss: 0.67423 |  0:00:01s
epoch 1  | loss: 0.64326 | train_logloss: 0.62548 | valid_logloss: 0.63865 |  0:00:02s
epoch 2  | loss: 0.62159 | train_logloss: 0.61459 | valid_logloss: 0.62492 |  0:00:03s
epoch 3  | loss: 0.61521 | train_logloss: 0.61119 | valid_logloss: 0.6234  |  0:00:05s
epoch 4  | loss: 0.61453 | train_logloss: 0.61076 | valid_logloss: 0.62538 |  0:00:06s
epoch 5  | loss: 0.61265 | train_logloss: 0.60912 | valid_logloss: 0.62315 |  0:00:07s
epoch 6  | loss: 0.61035 | train_logloss: 0.60847 | valid_logloss: 0.62224 |  0:00:09s
epoch 7  | loss: 0.61114 | train_logloss: 0.608   | valid_logloss: 0.62081 |  0:00:10s
epoch 8  | loss: 0.60923 | train_logloss: 0.60744 | valid_logloss: 0.61992 |  0:00:11s
epoch 9  | loss: 0.60807 | train_logloss: 0.6064  | valid_logloss: 0.62032 |  0:00:12s
epoch 10 | loss: 0.61035 | train_logloss: 0.60551 | valid_logloss: 0.62048 |  0:00:14s
epoch 11 | loss: 0.60769 | train_logloss: 0



In [6]:
# save
# Create the output directory if it is missing, so Restart & Run All works on a
# fresh checkout instead of failing inside write_csv.
Path("pred").mkdir(parents=True, exist_ok=True)
oof_pred_df.write_csv(f"pred/oof_pred_tabnet_{feat}.csv")
test_pred_df.write_csv(f"pred/test_pred_tabnet_{feat}.csv")