In [1]:
import numpy as np
import polars as pl
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from sklearn.model_selection import KFold, StratifiedKFold

from pytorch_tabnet.pretraining import TabNetPretrainer
from pytorch_tabnet.tab_model import TabNetClassifier
import torch
from sklearn.preprocessing import StandardScaler

In [2]:
feat = "feat00"
train = pl.read_csv(f"feat/feat_train_{feat}.csv")
test = pl.read_csv(f"feat/feat_test_{feat}.csv")
train_origin = pl.read_csv("data/train.csv").rename({"": "idx"})

# 説明変数のカラム
cols_exp = [c for c in test.columns if c != "idx"]

# カテゴリ特徴量のカラム
cols_notcat = ['idx', 'created_at', 'tree_dbh']
cols_cat = [c for c in test.columns if not c in cols_notcat] # カテゴリ特徴量

for col in cols_cat:
    # もし欠損値があればその特徴量のユニーク数(欠損値を除く)で埋める（ordinal encoding前提）
    num_null = train[col].n_unique() - 1
    train = train.with_columns(train[col].fill_null(num_null))
    test = test.with_columns(test[col].fill_null(num_null))

# 100以上のユニーク数をもつ特徴量を削除（train/valid split時にordinal encodingの連番の整合性が崩れるため）
# ※ pretrain実施しない場合はコメントアウトする
cols_exp = [c for c in cols_exp if not c in ("boro_ct", "spc_latin", "nta")]
cols_cat = [c for c in cols_cat if not c in ("boro_ct", "spc_latin", "nta")]

### Tabnet

In [3]:
def train_tabnet(train, cols_exp, cols_cat, col_target, params=None):
    if params is None:
        params = {}

    params_add = {"device_name": "cpu", "seed": 0}
    params |= params_add

    x = train[cols_exp].to_numpy()
    y = train[col_target].to_numpy()

    # cols_expにおけるカテゴリ変数のインデックス
    cols_cat_idxs = [i for i, c in enumerate(cols_exp) if c in cols_cat]
    cols_num_idxs = [i for i, c in enumerate(cols_exp) if not c in cols_cat]
    cols_cat_dims = train.approx_n_unique()[cols_exp].to_numpy().ravel()[cols_cat_idxs]

    # K-fold
    kf = KFold(n_splits=4, shuffle=True, random_state=1)
    y_valid_pred_lst = []
    idx_valid_lst = []
    clf_lst = []

    # cross validation
    for fold, (idx_train, idx_valid) in enumerate(kf.split(x)):
        print("fold", fold)
        x_train = x[idx_train, :]
        x_valid = x[idx_valid, :]
        y_train = y[idx_train]
        y_valid = y[idx_valid]
        
        # normalization
        scaler = StandardScaler()
        x_train[:, cols_num_idxs] = scaler.fit_transform(x_train[:, cols_num_idxs])
        x_valid[:, cols_num_idxs] = scaler.transform(x_valid[:, cols_num_idxs])

        # modeling
        pretrainer = TabNetPretrainer(**params)
        pretrainer.fit(x_train, eval_set=[x_valid])
        clf = TabNetClassifier(**params, cat_idxs=cols_cat_idxs, cat_dims=cols_cat_dims)
        clf.fit(
            x_train, y_train,
            eval_set=[(x_train, y_train), (x_valid, y_valid)],
            eval_name=['train', 'valid'],
            eval_metric=["logloss"], 
            from_unsupervised=pretrainer
        )   
        
        # oof
        y_valid_pred = clf.predict_proba(x_valid)
        y_valid_pred_lst.append(y_valid_pred)
        idx_valid_lst.append(idx_valid)
        clf_lst.append(clf)

    idx_valid = np.hstack(idx_valid_lst)
    y_valid_pred = np.vstack(y_valid_pred_lst)
    oof_pred = y_valid_pred[np.argsort(idx_valid)]

    return clf_lst, oof_pred

In [4]:
def predict_test(x_test, clf_lst):
    y_test_pred_lst = []

    for clf in clf_lst:
        y_test_pred = clf.predict_proba(x_test)
        y_test_pred_lst.append(y_test_pred)

    y_test_pred = np.mean(y_test_pred_lst, axis=0)
    return y_test_pred

In [5]:
col_target = "health"

# train tabnet
clf_lst, oof_pred = train_tabnet(train, cols_exp, cols_cat, col_target)

# normalization for test
x_train = train[cols_exp].to_numpy()
x_test = test[cols_exp].to_numpy()
cols_num_idxs = [i for i, c in enumerate(cols_exp) if not c in cols_cat]
scaler = StandardScaler()
scaler.fit(x_train[:, cols_num_idxs])
x_test[:, cols_num_idxs] = scaler.transform(x_test[:, cols_num_idxs])

# predict test with CV ensemble
y_test_pred = predict_test(x_test, clf_lst)

# record
oof_pred_df = pl.DataFrame(oof_pred, schema=[f"health_is_{h}" for h in range(3)])
test_pred_df = pl.DataFrame(y_test_pred, schema=[f"health_is_{h}" for h in range(3)])

fold 0




epoch 0  | loss: 28.8181 | val_0_unsup_loss_numpy: 918.6209716796875|  0:00:00s
epoch 1  | loss: 3.33351 | val_0_unsup_loss_numpy: 71.92803955078125|  0:00:01s
epoch 2  | loss: 1.72215 | val_0_unsup_loss_numpy: 15.211400032043457|  0:00:02s
epoch 3  | loss: 1.3846  | val_0_unsup_loss_numpy: 7.321239948272705|  0:00:03s
epoch 4  | loss: 1.22386 | val_0_unsup_loss_numpy: 1.7450499534606934|  0:00:04s
epoch 5  | loss: 1.19924 | val_0_unsup_loss_numpy: 2.343100070953369|  0:00:05s
epoch 6  | loss: 1.18146 | val_0_unsup_loss_numpy: 2.298029899597168|  0:00:07s
epoch 7  | loss: 1.12498 | val_0_unsup_loss_numpy: 1.262120008468628|  0:00:07s
epoch 8  | loss: 1.03204 | val_0_unsup_loss_numpy: 2.7106399536132812|  0:00:08s
epoch 9  | loss: 1.22404 | val_0_unsup_loss_numpy: 2.265470027923584|  0:00:09s
epoch 10 | loss: 1.08398 | val_0_unsup_loss_numpy: 1.0026099681854248|  0:00:10s
epoch 11 | loss: 1.08066 | val_0_unsup_loss_numpy: 3.1536900997161865|  0:00:11s
epoch 12 | loss: 1.04316 | val_0_un



epoch 0  | loss: 0.74198 | train_logloss: 0.89261 | valid_logloss: 0.8958  |  0:00:01s
epoch 1  | loss: 0.6266  | train_logloss: 0.75247 | valid_logloss: 0.76859 |  0:00:02s
epoch 2  | loss: 0.6147  | train_logloss: 0.68914 | valid_logloss: 0.70327 |  0:00:03s
epoch 3  | loss: 0.61167 | train_logloss: 0.65762 | valid_logloss: 0.67227 |  0:00:04s
epoch 4  | loss: 0.61163 | train_logloss: 0.64534 | valid_logloss: 0.65692 |  0:00:05s
epoch 5  | loss: 0.61257 | train_logloss: 0.63846 | valid_logloss: 0.6497  |  0:00:06s
epoch 6  | loss: 0.61259 | train_logloss: 0.62375 | valid_logloss: 0.63746 |  0:00:07s
epoch 7  | loss: 0.6118  | train_logloss: 0.60888 | valid_logloss: 0.62084 |  0:00:08s
epoch 8  | loss: 0.6124  | train_logloss: 0.60846 | valid_logloss: 0.61898 |  0:00:09s
epoch 9  | loss: 0.61026 | train_logloss: 0.60796 | valid_logloss: 0.61975 |  0:00:10s
epoch 10 | loss: 0.61045 | train_logloss: 0.60714 | valid_logloss: 0.61905 |  0:00:11s
epoch 11 | loss: 0.60769 | train_logloss: 0



fold 1




epoch 0  | loss: 31.16613| val_0_unsup_loss_numpy: 13.487449645996094|  0:00:00s
epoch 1  | loss: 3.44688 | val_0_unsup_loss_numpy: 6.697360038757324|  0:00:01s
epoch 2  | loss: 1.89554 | val_0_unsup_loss_numpy: 4.012030124664307|  0:00:02s
epoch 3  | loss: 1.43937 | val_0_unsup_loss_numpy: 1.5916199684143066|  0:00:03s
epoch 4  | loss: 1.25375 | val_0_unsup_loss_numpy: 1.3181899785995483|  0:00:04s
epoch 5  | loss: 1.18857 | val_0_unsup_loss_numpy: 1.3203599452972412|  0:00:05s
epoch 6  | loss: 1.12348 | val_0_unsup_loss_numpy: 1.1567200422286987|  0:00:06s
epoch 7  | loss: 1.06686 | val_0_unsup_loss_numpy: 1.0848100185394287|  0:00:07s
epoch 8  | loss: 1.05477 | val_0_unsup_loss_numpy: 1.1174800395965576|  0:00:08s
epoch 9  | loss: 1.03106 | val_0_unsup_loss_numpy: 1.2991700172424316|  0:00:08s
epoch 10 | loss: 1.03883 | val_0_unsup_loss_numpy: 1.2236100435256958|  0:00:09s
epoch 11 | loss: 1.01266 | val_0_unsup_loss_numpy: 1.1615400314331055|  0:00:10s
epoch 12 | loss: 1.02552 | val



epoch 0  | loss: 0.6875  | train_logloss: 0.67518 | valid_logloss: 0.67746 |  0:00:01s
epoch 1  | loss: 0.62248 | train_logloss: 0.63321 | valid_logloss: 0.64365 |  0:00:02s
epoch 2  | loss: 0.61886 | train_logloss: 0.62167 | valid_logloss: 0.63158 |  0:00:03s
epoch 3  | loss: 0.61476 | train_logloss: 0.61495 | valid_logloss: 0.62449 |  0:00:04s
epoch 4  | loss: 0.61046 | train_logloss: 0.61515 | valid_logloss: 0.62565 |  0:00:05s
epoch 5  | loss: 0.60681 | train_logloss: 0.60999 | valid_logloss: 0.61884 |  0:00:06s
epoch 6  | loss: 0.61036 | train_logloss: 0.60823 | valid_logloss: 0.61683 |  0:00:07s
epoch 7  | loss: 0.60674 | train_logloss: 0.60791 | valid_logloss: 0.61657 |  0:00:09s
epoch 8  | loss: 0.60942 | train_logloss: 0.60723 | valid_logloss: 0.61532 |  0:00:10s
epoch 9  | loss: 0.6068  | train_logloss: 0.60714 | valid_logloss: 0.61636 |  0:00:11s
epoch 10 | loss: 0.61044 | train_logloss: 0.60807 | valid_logloss: 0.61513 |  0:00:12s
epoch 11 | loss: 0.60643 | train_logloss: 0



fold 2




epoch 0  | loss: 33.50326| val_0_unsup_loss_numpy: 77.64251708984375|  0:00:00s
epoch 1  | loss: 4.38511 | val_0_unsup_loss_numpy: 6.991390228271484|  0:00:01s
epoch 2  | loss: 1.8272  | val_0_unsup_loss_numpy: 110.569091796875|  0:00:02s
epoch 3  | loss: 1.36532 | val_0_unsup_loss_numpy: 5.118649959564209|  0:00:03s
epoch 4  | loss: 1.16673 | val_0_unsup_loss_numpy: 3.9049999713897705|  0:00:04s
epoch 5  | loss: 1.20438 | val_0_unsup_loss_numpy: 1.3443700075149536|  0:00:05s
epoch 6  | loss: 1.15774 | val_0_unsup_loss_numpy: 1.1534700393676758|  0:00:06s
epoch 7  | loss: 1.08048 | val_0_unsup_loss_numpy: 1.1506799459457397|  0:00:07s
epoch 8  | loss: 1.00594 | val_0_unsup_loss_numpy: 1.1089199781417847|  0:00:08s
epoch 9  | loss: 1.01416 | val_0_unsup_loss_numpy: 1.2882399559020996|  0:00:09s
epoch 10 | loss: 1.37113 | val_0_unsup_loss_numpy: 9.371789932250977|  0:00:10s
epoch 11 | loss: 1.10733 | val_0_unsup_loss_numpy: 1.0227199792861938|  0:00:11s
epoch 12 | loss: 1.00037 | val_0_u



epoch 0  | loss: 0.71659 | train_logloss: 0.68033 | valid_logloss: 0.69825 |  0:00:07s
epoch 1  | loss: 0.62568 | train_logloss: 0.62483 | valid_logloss: 0.62634 |  0:00:13s
epoch 2  | loss: 0.61738 | train_logloss: 0.61572 | valid_logloss: 0.61875 |  0:00:20s
epoch 3  | loss: 0.61249 | train_logloss: 0.61579 | valid_logloss: 0.61703 |  0:00:24s
epoch 4  | loss: 0.61265 | train_logloss: 0.61118 | valid_logloss: 0.61414 |  0:00:27s
epoch 5  | loss: 0.61146 | train_logloss: 0.61122 | valid_logloss: 0.61612 |  0:00:31s
epoch 6  | loss: 0.60922 | train_logloss: 0.60869 | valid_logloss: 0.6135  |  0:00:34s
epoch 7  | loss: 0.60761 | train_logloss: 0.60876 | valid_logloss: 0.61541 |  0:00:37s
epoch 8  | loss: 0.60924 | train_logloss: 0.60797 | valid_logloss: 0.61147 |  0:00:40s
epoch 9  | loss: 0.61042 | train_logloss: 0.6079  | valid_logloss: 0.61107 |  0:00:43s
epoch 10 | loss: 0.6086  | train_logloss: 0.60657 | valid_logloss: 0.61256 |  0:00:47s
epoch 11 | loss: 0.6085  | train_logloss: 0



fold 3




epoch 0  | loss: 23.10862| val_0_unsup_loss_numpy: 111.81297302246094|  0:00:05s
epoch 1  | loss: 3.57288 | val_0_unsup_loss_numpy: 2.5347800254821777|  0:00:08s
epoch 2  | loss: 1.72308 | val_0_unsup_loss_numpy: 2.92438006401062|  0:00:11s
epoch 3  | loss: 1.33641 | val_0_unsup_loss_numpy: 1.649899959564209|  0:00:14s
epoch 4  | loss: 1.23878 | val_0_unsup_loss_numpy: 1.967919945716858|  0:00:17s
epoch 5  | loss: 1.19382 | val_0_unsup_loss_numpy: 1.538100004196167|  0:00:19s
epoch 6  | loss: 1.08358 | val_0_unsup_loss_numpy: 2.203579902648926|  0:00:22s
epoch 7  | loss: 1.08237 | val_0_unsup_loss_numpy: 3.623990058898926|  0:00:25s
epoch 8  | loss: 1.19017 | val_0_unsup_loss_numpy: 2.926690101623535|  0:00:28s
epoch 9  | loss: 1.33867 | val_0_unsup_loss_numpy: 1.0580300092697144|  0:00:31s
epoch 10 | loss: 1.16182 | val_0_unsup_loss_numpy: 1.158560037612915|  0:00:34s
epoch 11 | loss: 1.0244  | val_0_unsup_loss_numpy: 1.731279969215393|  0:00:37s
epoch 12 | loss: 1.03484 | val_0_unsup



epoch 0  | loss: 0.71304 | train_logloss: 0.72079 | valid_logloss: 0.70695 |  0:00:03s
epoch 1  | loss: 0.62946 | train_logloss: 0.63984 | valid_logloss: 0.63003 |  0:00:06s
epoch 2  | loss: 0.62045 | train_logloss: 0.64003 | valid_logloss: 0.62979 |  0:00:10s
epoch 3  | loss: 0.61566 | train_logloss: 0.63207 | valid_logloss: 0.62059 |  0:00:14s
epoch 4  | loss: 0.61522 | train_logloss: 0.62878 | valid_logloss: 0.61804 |  0:00:19s
epoch 5  | loss: 0.61425 | train_logloss: 0.61946 | valid_logloss: 0.61124 |  0:00:23s
epoch 6  | loss: 0.61482 | train_logloss: 0.62101 | valid_logloss: 0.61543 |  0:00:27s
epoch 7  | loss: 0.61292 | train_logloss: 0.61572 | valid_logloss: 0.60796 |  0:00:30s
epoch 8  | loss: 0.61479 | train_logloss: 0.61301 | valid_logloss: 0.60448 |  0:00:34s
epoch 9  | loss: 0.61481 | train_logloss: 0.61695 | valid_logloss: 0.60831 |  0:00:38s
epoch 10 | loss: 0.61623 | train_logloss: 0.61388 | valid_logloss: 0.60613 |  0:00:41s
epoch 11 | loss: 0.61168 | train_logloss: 0



In [None]:
# save
oof_pred_df.write_csv(f"pred/oof_pred_tabnet_{feat}.csv")
test_pred_df.write_csv(f"pred/test_pred_tabnet_{feat}.csv")