In [1]:
import numpy as np
import polars as pl
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from sklearn.model_selection import KFold, StratifiedKFold

from pytorch_tabnet.pretraining import TabNetPretrainer
from pytorch_tabnet.tab_model import TabNetClassifier
import torch
from sklearn.preprocessing import StandardScaler

In [2]:
feat = "feat00"
train = pl.read_csv(f"feat/feat_train_{feat}.csv")
test = pl.read_csv(f"feat/feat_test_{feat}.csv")
train_origin = pl.read_csv("data/train.csv").rename({"": "idx"})

# 説明変数のカラム
cols_exp = [c for c in test.columns if c != "idx"]

# カテゴリ特徴量のカラム
cols_notcat = ['idx', 'created_at', 'tree_dbh']
cols_cat = [c for c in test.columns if not c in cols_notcat] # カテゴリ特徴量

for col in cols_cat:
    # もし欠損値があればその特徴量のユニーク数(欠損値を除く)で埋める（ordinal encoding前提）
    num_null = train[col].n_unique() - 1
    train = train.with_columns(train[col].fill_null(num_null))
    test = test.with_columns(test[col].fill_null(num_null))

# 100以上のユニーク数をもつ特徴量を削除（train/valid split時にordinal encodingの連番の整合性が崩れるため）
# ※ pretrain実施しない場合はコメントアウトする
cols_exp = [c for c in cols_exp if not c in ("boro_ct", "spc_latin", "nta")]
cols_cat = [c for c in cols_cat if not c in ("boro_ct", "spc_latin", "nta")]

### Tabnet

In [3]:
def train_tabnet(train, cols_exp, cols_cat, col_target, params=None):
    if params is None:
        params = {}

    params_add = {"device_name": "cpu", "seed": 0}
    params |= params_add

    x = train[cols_exp].to_numpy()
    y = train[col_target].to_numpy()

    # cols_expにおけるカテゴリ変数のインデックス
    cols_cat_idxs = [i for i, c in enumerate(cols_exp) if c in cols_cat]
    cols_num_idxs = [i for i, c in enumerate(cols_exp) if not c in cols_cat]
    cols_cat_dims = train.approx_n_unique()[cols_exp].to_numpy().ravel()[cols_cat_idxs]

    # K-fold
    kf = KFold(n_splits=5, shuffle=True, random_state=0)
    y_valid_pred_lst = []
    idx_valid_lst = []
    clf_lst = []

    # cross validation
    for fold, (idx_train, idx_valid) in enumerate(kf.split(x)):
        print("fold", fold)
        x_train = x[idx_train, :]
        x_valid = x[idx_valid, :]
        y_train = y[idx_train]
        y_valid = y[idx_valid]
        
        # normalization
        scaler = StandardScaler()
        x_train[:, cols_num_idxs] = scaler.fit_transform(x_train[:, cols_num_idxs])
        x_valid[:, cols_num_idxs] = scaler.transform(x_valid[:, cols_num_idxs])

        # modeling
        pretrainer = TabNetPretrainer(**params)
        pretrainer.fit(x_train, eval_set=[x_valid])
        clf = TabNetClassifier(**params, cat_idxs=cols_cat_idxs, cat_dims=cols_cat_dims)
        clf.fit(
            x_train, y_train,
            eval_set=[(x_train, y_train), (x_valid, y_valid)],
            eval_name=['train', 'valid'],
            eval_metric=["logloss"], 
            from_unsupervised=pretrainer
        )   
        
        # oof
        y_valid_pred = clf.predict_proba(x_valid)
        y_valid_pred_lst.append(y_valid_pred)
        idx_valid_lst.append(idx_valid)
        clf_lst.append(clf)

    idx_valid = np.hstack(idx_valid_lst)
    y_valid_pred = np.vstack(y_valid_pred_lst)
    oof_pred = y_valid_pred[np.argsort(idx_valid)]

    return clf_lst, oof_pred

In [4]:
def predict_test(x_test, clf_lst):
    y_test_pred_lst = []

    for clf in clf_lst:
        y_test_pred = clf.predict_proba(x_test)
        y_test_pred_lst.append(y_test_pred)

    y_test_pred = np.mean(y_test_pred_lst, axis=0)
    return y_test_pred

In [5]:
col_target = "health"

# train tabnet
clf_lst, oof_pred = train_tabnet(train, cols_exp, cols_cat, col_target)

# normalization for test
x_train = train[cols_exp].to_numpy()
x_test = test[cols_exp].to_numpy()
cols_num_idxs = [i for i, c in enumerate(cols_exp) if not c in cols_cat]
scaler = StandardScaler()
scaler.fit(x_train[:, cols_num_idxs])
x_test[:, cols_num_idxs] = scaler.transform(x_test[:, cols_num_idxs])

# predict test with CV ensemble
y_test_pred = predict_test(x_test, clf_lst)

# record
oof_pred_df = pl.DataFrame(oof_pred, schema=[f"health_is_{h}" for h in range(3)])
test_pred_df = pl.DataFrame(y_test_pred, schema=[f"health_is_{h}" for h in range(3)])

fold 0




epoch 0  | loss: 26.76198| val_0_unsup_loss_numpy: 20.42473030090332|  0:00:00s
epoch 1  | loss: 4.08857 | val_0_unsup_loss_numpy: 10.79423999786377|  0:00:01s
epoch 2  | loss: 1.95025 | val_0_unsup_loss_numpy: 5.092899799346924|  0:00:02s
epoch 3  | loss: 1.46901 | val_0_unsup_loss_numpy: 3.2443299293518066|  0:00:03s
epoch 4  | loss: 1.27537 | val_0_unsup_loss_numpy: 1.1792399883270264|  0:00:04s
epoch 5  | loss: 1.1298  | val_0_unsup_loss_numpy: 1.4030499458312988|  0:00:05s
epoch 6  | loss: 1.08237 | val_0_unsup_loss_numpy: 1.2402499914169312|  0:00:06s
epoch 7  | loss: 1.20298 | val_0_unsup_loss_numpy: 1.3411699533462524|  0:00:07s
epoch 8  | loss: 1.13437 | val_0_unsup_loss_numpy: 1.1023600101470947|  0:00:08s
epoch 9  | loss: 1.00452 | val_0_unsup_loss_numpy: 1.2569999694824219|  0:00:09s
epoch 10 | loss: 1.02615 | val_0_unsup_loss_numpy: 1.0763499736785889|  0:00:09s
epoch 11 | loss: 0.99021 | val_0_unsup_loss_numpy: 0.9975200295448303|  0:00:10s
epoch 12 | loss: 1.00271 | val_



epoch 0  | loss: 0.74412 | train_logloss: 0.66233 | valid_logloss: 0.67185 |  0:00:01s
epoch 1  | loss: 0.62741 | train_logloss: 0.63318 | valid_logloss: 0.64351 |  0:00:02s
epoch 2  | loss: 0.6168  | train_logloss: 0.61387 | valid_logloss: 0.61974 |  0:00:03s
epoch 3  | loss: 0.61381 | train_logloss: 0.61287 | valid_logloss: 0.61755 |  0:00:04s
epoch 4  | loss: 0.61177 | train_logloss: 0.61425 | valid_logloss: 0.62027 |  0:00:05s
epoch 5  | loss: 0.61214 | train_logloss: 0.61341 | valid_logloss: 0.61928 |  0:00:06s
epoch 6  | loss: 0.61252 | train_logloss: 0.61258 | valid_logloss: 0.61788 |  0:00:07s
epoch 7  | loss: 0.61112 | train_logloss: 0.60901 | valid_logloss: 0.6166  |  0:00:08s
epoch 8  | loss: 0.61022 | train_logloss: 0.60804 | valid_logloss: 0.61677 |  0:00:09s
epoch 9  | loss: 0.61043 | train_logloss: 0.60974 | valid_logloss: 0.61802 |  0:00:10s
epoch 10 | loss: 0.61256 | train_logloss: 0.61022 | valid_logloss: 0.61836 |  0:00:12s
epoch 11 | loss: 0.61066 | train_logloss: 0



fold 1




epoch 0  | loss: 44.57027| val_0_unsup_loss_numpy: 39.23863983154297|  0:00:00s
epoch 1  | loss: 3.43539 | val_0_unsup_loss_numpy: 6.510300159454346|  0:00:01s
epoch 2  | loss: 1.66558 | val_0_unsup_loss_numpy: 15.489069938659668|  0:00:02s
epoch 3  | loss: 1.27852 | val_0_unsup_loss_numpy: 2.172830104827881|  0:00:03s
epoch 4  | loss: 1.24417 | val_0_unsup_loss_numpy: 1.311750054359436|  0:00:04s
epoch 5  | loss: 1.19865 | val_0_unsup_loss_numpy: 1.9211100339889526|  0:00:05s
epoch 6  | loss: 1.14775 | val_0_unsup_loss_numpy: 1.2566399574279785|  0:00:06s
epoch 7  | loss: 1.1991  | val_0_unsup_loss_numpy: 1.7197599411010742|  0:00:07s
epoch 8  | loss: 1.33253 | val_0_unsup_loss_numpy: 6.3994598388671875|  0:00:08s
epoch 9  | loss: 1.19769 | val_0_unsup_loss_numpy: 1.115339994430542|  0:00:08s
epoch 10 | loss: 0.99937 | val_0_unsup_loss_numpy: 1.0732500553131104|  0:00:09s
epoch 11 | loss: 1.00555 | val_0_unsup_loss_numpy: 1.053760051727295|  0:00:10s
epoch 12 | loss: 1.02236 | val_0_u



epoch 0  | loss: 0.70838 | train_logloss: 0.68116 | valid_logloss: 0.68349 |  0:00:01s
epoch 1  | loss: 0.6201  | train_logloss: 0.62628 | valid_logloss: 0.63706 |  0:00:02s
epoch 2  | loss: 0.6135  | train_logloss: 0.61671 | valid_logloss: 0.62663 |  0:00:03s
epoch 3  | loss: 0.61093 | train_logloss: 0.61271 | valid_logloss: 0.62107 |  0:00:04s
epoch 4  | loss: 0.61016 | train_logloss: 0.60978 | valid_logloss: 0.62184 |  0:00:05s
epoch 5  | loss: 0.60988 | train_logloss: 0.6104  | valid_logloss: 0.62135 |  0:00:06s
epoch 6  | loss: 0.61089 | train_logloss: 0.60916 | valid_logloss: 0.62    |  0:00:07s
epoch 7  | loss: 0.60917 | train_logloss: 0.60842 | valid_logloss: 0.61975 |  0:00:08s
epoch 8  | loss: 0.609   | train_logloss: 0.6087  | valid_logloss: 0.62054 |  0:00:09s
epoch 9  | loss: 0.61044 | train_logloss: 0.60788 | valid_logloss: 0.61993 |  0:00:11s
epoch 10 | loss: 0.60801 | train_logloss: 0.60836 | valid_logloss: 0.62118 |  0:00:12s
epoch 11 | loss: 0.60731 | train_logloss: 0



fold 2




epoch 0  | loss: 36.11608| val_0_unsup_loss_numpy: 75.12001037597656|  0:00:00s
epoch 1  | loss: 3.49095 | val_0_unsup_loss_numpy: 55.85118103027344|  0:00:01s
epoch 2  | loss: 1.65772 | val_0_unsup_loss_numpy: 7.563960075378418|  0:00:02s
epoch 3  | loss: 1.47043 | val_0_unsup_loss_numpy: 2.963409900665283|  0:00:03s
epoch 4  | loss: 1.34936 | val_0_unsup_loss_numpy: 3.136539936065674|  0:00:04s
epoch 5  | loss: 1.22492 | val_0_unsup_loss_numpy: 2.168299913406372|  0:00:05s
epoch 6  | loss: 1.13605 | val_0_unsup_loss_numpy: 1.23239004611969|  0:00:06s
epoch 7  | loss: 1.09811 | val_0_unsup_loss_numpy: 1.046180009841919|  0:00:07s
epoch 8  | loss: 1.09654 | val_0_unsup_loss_numpy: 1.6035300493240356|  0:00:08s
epoch 9  | loss: 1.09853 | val_0_unsup_loss_numpy: 1.6876599788665771|  0:00:09s
epoch 10 | loss: 1.08077 | val_0_unsup_loss_numpy: 1.2708200216293335|  0:00:09s
epoch 11 | loss: 1.36933 | val_0_unsup_loss_numpy: 6.931879997253418|  0:00:10s
epoch 12 | loss: 1.07511 | val_0_unsup



epoch 0  | loss: 0.7308  | train_logloss: 0.6739  | valid_logloss: 0.64184 |  0:00:01s
epoch 1  | loss: 0.63446 | train_logloss: 0.63321 | valid_logloss: 0.60459 |  0:00:02s
epoch 2  | loss: 0.6247  | train_logloss: 0.62651 | valid_logloss: 0.60049 |  0:00:03s
epoch 3  | loss: 0.61764 | train_logloss: 0.6194  | valid_logloss: 0.59891 |  0:00:04s
epoch 4  | loss: 0.61806 | train_logloss: 0.61714 | valid_logloss: 0.59522 |  0:00:05s
epoch 5  | loss: 0.61509 | train_logloss: 0.61794 | valid_logloss: 0.59646 |  0:00:06s
epoch 6  | loss: 0.61586 | train_logloss: 0.6158  | valid_logloss: 0.59511 |  0:00:07s
epoch 7  | loss: 0.61695 | train_logloss: 0.61839 | valid_logloss: 0.59599 |  0:00:08s
epoch 8  | loss: 0.61742 | train_logloss: 0.61597 | valid_logloss: 0.59829 |  0:00:09s
epoch 9  | loss: 0.61764 | train_logloss: 0.6147  | valid_logloss: 0.59459 |  0:00:11s
epoch 10 | loss: 0.61501 | train_logloss: 0.61584 | valid_logloss: 0.59514 |  0:00:12s
epoch 11 | loss: 0.61613 | train_logloss: 0



fold 3




epoch 0  | loss: 35.71334| val_0_unsup_loss_numpy: 177.01922607421875|  0:00:00s
epoch 1  | loss: 3.20178 | val_0_unsup_loss_numpy: 14.371740341186523|  0:00:01s
epoch 2  | loss: 1.74427 | val_0_unsup_loss_numpy: 24.39349937438965|  0:00:02s
epoch 3  | loss: 1.31387 | val_0_unsup_loss_numpy: 3.773940086364746|  0:00:03s
epoch 4  | loss: 1.20053 | val_0_unsup_loss_numpy: 4.367929935455322|  0:00:04s
epoch 5  | loss: 1.14815 | val_0_unsup_loss_numpy: 3.998610019683838|  0:00:05s
epoch 6  | loss: 1.16136 | val_0_unsup_loss_numpy: 1.8066400289535522|  0:00:06s
epoch 7  | loss: 1.08228 | val_0_unsup_loss_numpy: 1.2938300371170044|  0:00:07s
epoch 8  | loss: 1.0377  | val_0_unsup_loss_numpy: 2.125309944152832|  0:00:08s
epoch 9  | loss: 1.13309 | val_0_unsup_loss_numpy: 5.820350170135498|  0:00:08s
epoch 10 | loss: 1.2986  | val_0_unsup_loss_numpy: 2.0673599243164062|  0:00:09s
epoch 11 | loss: 1.02789 | val_0_unsup_loss_numpy: 0.9876000285148621|  0:00:10s
epoch 12 | loss: 0.9812  | val_0_u



epoch 0  | loss: 0.72688 | train_logloss: 0.69214 | valid_logloss: 0.69    |  0:00:01s
epoch 1  | loss: 0.62343 | train_logloss: 0.65199 | valid_logloss: 0.65389 |  0:00:02s
epoch 2  | loss: 0.61114 | train_logloss: 0.64245 | valid_logloss: 0.64615 |  0:00:03s
epoch 3  | loss: 0.61264 | train_logloss: 0.62725 | valid_logloss: 0.63039 |  0:00:04s
epoch 4  | loss: 0.61227 | train_logloss: 0.62245 | valid_logloss: 0.62698 |  0:00:05s
epoch 5  | loss: 0.61029 | train_logloss: 0.61817 | valid_logloss: 0.62253 |  0:00:06s
epoch 6  | loss: 0.60948 | train_logloss: 0.61227 | valid_logloss: 0.61977 |  0:00:07s
epoch 7  | loss: 0.61072 | train_logloss: 0.61194 | valid_logloss: 0.61828 |  0:00:08s
epoch 8  | loss: 0.612   | train_logloss: 0.60993 | valid_logloss: 0.61871 |  0:00:09s
epoch 9  | loss: 0.61341 | train_logloss: 0.60933 | valid_logloss: 0.61686 |  0:00:11s
epoch 10 | loss: 0.61248 | train_logloss: 0.60933 | valid_logloss: 0.61491 |  0:00:12s
epoch 11 | loss: 0.61175 | train_logloss: 0



fold 4




epoch 0  | loss: 28.38081| val_0_unsup_loss_numpy: 116.68115234375|  0:00:00s
epoch 1  | loss: 3.30085 | val_0_unsup_loss_numpy: 6.321189880371094|  0:00:01s
epoch 2  | loss: 1.67137 | val_0_unsup_loss_numpy: 3.409219980239868|  0:00:02s
epoch 3  | loss: 1.36336 | val_0_unsup_loss_numpy: 2.843679904937744|  0:00:03s
epoch 4  | loss: 1.2589  | val_0_unsup_loss_numpy: 2.1197500228881836|  0:00:04s
epoch 5  | loss: 1.27764 | val_0_unsup_loss_numpy: 1.4433799982070923|  0:00:05s
epoch 6  | loss: 1.12332 | val_0_unsup_loss_numpy: 2.4402499198913574|  0:00:06s
epoch 7  | loss: 1.06985 | val_0_unsup_loss_numpy: 1.2499099969863892|  0:00:07s
epoch 8  | loss: 1.15865 | val_0_unsup_loss_numpy: 2.115540027618408|  0:00:08s
epoch 9  | loss: 1.01649 | val_0_unsup_loss_numpy: 1.3405499458312988|  0:00:08s
epoch 10 | loss: 1.03331 | val_0_unsup_loss_numpy: 1.1784499883651733|  0:00:09s
epoch 11 | loss: 1.02818 | val_0_unsup_loss_numpy: 1.2469899654388428|  0:00:10s
epoch 12 | loss: 0.98553 | val_0_un



epoch 0  | loss: 0.73888 | train_logloss: 0.66243 | valid_logloss: 0.68047 |  0:00:01s
epoch 1  | loss: 0.62156 | train_logloss: 0.62205 | valid_logloss: 0.64172 |  0:00:02s
epoch 2  | loss: 0.61597 | train_logloss: 0.61653 | valid_logloss: 0.63132 |  0:00:03s
epoch 3  | loss: 0.61316 | train_logloss: 0.61374 | valid_logloss: 0.6285  |  0:00:04s
epoch 4  | loss: 0.61226 | train_logloss: 0.61406 | valid_logloss: 0.63053 |  0:00:05s
epoch 5  | loss: 0.612   | train_logloss: 0.612   | valid_logloss: 0.6232  |  0:00:06s
epoch 6  | loss: 0.61168 | train_logloss: 0.61004 | valid_logloss: 0.628   |  0:00:07s
epoch 7  | loss: 0.6106  | train_logloss: 0.60701 | valid_logloss: 0.62141 |  0:00:08s
epoch 8  | loss: 0.60685 | train_logloss: 0.60753 | valid_logloss: 0.62426 |  0:00:09s
epoch 9  | loss: 0.60839 | train_logloss: 0.60518 | valid_logloss: 0.62114 |  0:00:11s
epoch 10 | loss: 0.60828 | train_logloss: 0.60609 | valid_logloss: 0.62484 |  0:00:12s
epoch 11 | loss: 0.60835 | train_logloss: 0



In [6]:
# save
oof_pred_df.write_csv(f"pred/oof_pred_tabnet_{feat}.csv")
test_pred_df.write_csv(f"pred/test_pred_tabnet_{feat}.csv")