In [1]:
import numpy as np
import pandas as pd
import torch
from sklearn.datasets import load_breast_cancer

from src.builtin.estimators import TabNetClassifier
from src.logger import init_logger
from src.utils import mkdir

## init logger

In [2]:
# Logger configuration: output directory, logger name and verbosity level.
logger_dir, logger_name, level = 'logs', 'TEST', 'INFO'

# The resulting logger is handed to the TabNet estimator further down.
logger = init_logger(
    logger_dir=logger_dir,
    logger_name=logger_name,
    level=level,
)

## load data

In [3]:
# Load the breast-cancer dataset as plain arrays (features X, labels y).
X, y = load_breast_cancer(return_X_y=True)

# Show both shapes, one per line.
print(X.shape, y.shape, sep='\n')

(569, 30)
(569,)


## Build Model

In [4]:
# Size the input layer from the loaded data instead of hard-coding 30, and only
# request CUDA when a GPU is actually visible so the notebook also runs on
# CPU-only machines.
# NOTE(review): is_cuda is assumed to control GPU placement of the model —
# confirm against TabNetClassifier.
tabnet = TabNetClassifier(
    input_dims=X.shape[1], output_dims=[1], logger=logger,
    is_cuda=torch.cuda.is_available(),
    reprs_dims=4, atten_dims=4, num_steps=3, num_indep=1, num_shared=1
)

In [5]:
# Build the estimator before fitting.
# NOTE(review): path=None presumably means "no saved weights to load" —
# confirm against TabNetClassifier.build.
tabnet.build(path=None)



TabNetClassifier(atten_dims=4, input_dims=30, is_cuda=True,
                 logger=<RootLogger root (INFO)>, num_indep=1, num_shared=1,
                 output_dims=[1], reprs_dims=4)

In [7]:
from torch.optim import Adam
from torch.optim import lr_scheduler

# Keyword arguments forwarded to tabnet.fit(): batching, epoch budget, the
# reported metric, and the optimizer / LR-scheduler classes with their kwargs.
training_params = dict(
    batch_size=512,
    max_epochs=200,
    metrics=['mse'],
    optimizer=Adam,
    optimizer_params=dict(lr=0.1),
    schedulers=[lr_scheduler.ExponentialLR],
    scheduler_params=dict(gamma=0.99),
)


In [8]:
# Train on the full dataset (no hold-out split), so the metrics logged below
# are training metrics. y is reshaped from (n_samples,) to a single column
# (n_samples, 1) before being passed to fit().
tabnet.fit(X, y.reshape(-1, 1), **training_params)

[2021-01-27 22:47:41,163][INFO][TabNet] start training.
[2021-01-27 22:47:41,164][INFO][TabNet] ******************** epoch : 1 ********************
[2021-01-27 22:47:44,642][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:47:44,644][INFO][TabNet] total_loss : 1.3006789684295654
[2021-01-27 22:47:44,644][INFO][TabNet] task_loss : 1.2993987798690796
[2021-01-27 22:47:44,644][INFO][TabNet] mask_loss : -1.2801618576049805
[2021-01-27 22:47:44,645][INFO][TabNet] time_cost : 0.8876254558563232
[2021-01-27 22:47:44,645][INFO][TabNet] mean_squared_error : 0.630859375
[2021-01-27 22:47:44,646][INFO][TabNet] ******************** epoch : 2 ********************
[2021-01-27 22:47:47,319][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:47:47,320][INFO][TabNet] total_loss : 0.7551138401031494
[2021-01-27 22:47:47,320][INFO][TabNet] task_loss : 0.7539939880371094
[2021-01-27 22:47:47,321][INFO][TabNet] mask_loss : -1.1198581457138062


[2021-01-27 22:48:21,254][INFO][TabNet] ******************** epoch : 16 ********************
[2021-01-27 22:48:23,891][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:48:23,891][INFO][TabNet] total_loss : 0.15169787406921387
[2021-01-27 22:48:23,892][INFO][TabNet] task_loss : 0.15111666917800903
[2021-01-27 22:48:23,892][INFO][TabNet] mask_loss : -0.5812097787857056
[2021-01-27 22:48:23,893][INFO][TabNet] time_cost : 0.040917396545410156
[2021-01-27 22:48:23,893][INFO][TabNet] mean_squared_error : 0.048828125
[2021-01-27 22:48:23,894][INFO][TabNet] ******************** epoch : 17 ********************
[2021-01-27 22:48:26,593][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:48:26,594][INFO][TabNet] total_loss : 0.15201318264007568
[2021-01-27 22:48:26,594][INFO][TabNet] task_loss : 0.15143974125385284
[2021-01-27 22:48:26,595][INFO][TabNet] mask_loss : -0.5734440088272095
[2021-01-27 22:48:26,595][INFO][TabNet] time_cos

[2021-01-27 22:49:00,824][INFO][TabNet] ******************** epoch : 31 ********************
[2021-01-27 22:49:03,510][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:49:03,510][INFO][TabNet] total_loss : 0.1039140447974205
[2021-01-27 22:49:03,511][INFO][TabNet] task_loss : 0.10339317470788956
[2021-01-27 22:49:03,512][INFO][TabNet] mask_loss : -0.5208722949028015
[2021-01-27 22:49:03,512][INFO][TabNet] time_cost : 0.053827762603759766
[2021-01-27 22:49:03,513][INFO][TabNet] mean_squared_error : 0.037109375
[2021-01-27 22:49:03,514][INFO][TabNet] ******************** epoch : 32 ********************
[2021-01-27 22:49:06,187][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:49:06,188][INFO][TabNet] total_loss : 0.11278778314590454
[2021-01-27 22:49:06,189][INFO][TabNet] task_loss : 0.1122625395655632
[2021-01-27 22:49:06,190][INFO][TabNet] mask_loss : -0.525246262550354
[2021-01-27 22:49:06,191][INFO][TabNet] time_cost :

[2021-01-27 22:49:40,776][INFO][TabNet] ******************** epoch : 46 ********************
[2021-01-27 22:49:43,468][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:49:43,469][INFO][TabNet] total_loss : 0.09190632402896881
[2021-01-27 22:49:43,469][INFO][TabNet] task_loss : 0.09137611836194992
[2021-01-27 22:49:43,470][INFO][TabNet] mask_loss : -0.5302072763442993
[2021-01-27 22:49:43,470][INFO][TabNet] time_cost : 0.059839725494384766
[2021-01-27 22:49:43,471][INFO][TabNet] mean_squared_error : 0.033203125
[2021-01-27 22:49:43,472][INFO][TabNet] ******************** epoch : 47 ********************
[2021-01-27 22:49:46,145][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:49:46,146][INFO][TabNet] total_loss : 0.08857086300849915
[2021-01-27 22:49:46,147][INFO][TabNet] task_loss : 0.08803239464759827
[2021-01-27 22:49:46,147][INFO][TabNet] mask_loss : -0.5384663939476013
[2021-01-27 22:49:46,148][INFO][TabNet] time_cos

[2021-01-27 22:50:21,247][INFO][TabNet] ******************** epoch : 61 ********************
[2021-01-27 22:50:23,921][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:50:23,922][INFO][TabNet] total_loss : 0.06693506985902786
[2021-01-27 22:50:23,923][INFO][TabNet] task_loss : 0.06637364625930786
[2021-01-27 22:50:23,923][INFO][TabNet] mask_loss : -0.561427116394043
[2021-01-27 22:50:23,925][INFO][TabNet] time_cost : 0.05687594413757324
[2021-01-27 22:50:23,926][INFO][TabNet] mean_squared_error : 0.021484375
[2021-01-27 22:50:23,927][INFO][TabNet] ******************** epoch : 62 ********************
[2021-01-27 22:50:26,588][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:50:26,589][INFO][TabNet] total_loss : 0.08769950270652771
[2021-01-27 22:50:26,590][INFO][TabNet] task_loss : 0.08714921772480011
[2021-01-27 22:50:26,590][INFO][TabNet] mask_loss : -0.5502815842628479
[2021-01-27 22:50:26,591][INFO][TabNet] time_cost 

[2021-01-27 22:51:01,643][INFO][TabNet] ******************** epoch : 76 ********************
[2021-01-27 22:51:04,334][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:51:04,336][INFO][TabNet] total_loss : 0.08728860318660736
[2021-01-27 22:51:04,337][INFO][TabNet] task_loss : 0.08676479756832123
[2021-01-27 22:51:04,337][INFO][TabNet] mask_loss : -0.5238022804260254
[2021-01-27 22:51:04,338][INFO][TabNet] time_cost : 0.060835838317871094
[2021-01-27 22:51:04,339][INFO][TabNet] mean_squared_error : 0.03125
[2021-01-27 22:51:04,340][INFO][TabNet] ******************** epoch : 77 ********************
[2021-01-27 22:51:07,008][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:51:07,009][INFO][TabNet] total_loss : 0.05764273181557655
[2021-01-27 22:51:07,009][INFO][TabNet] task_loss : 0.057118140161037445
[2021-01-27 22:51:07,010][INFO][TabNet] mask_loss : -0.5245932936668396
[2021-01-27 22:51:07,010][INFO][TabNet] time_cost :

[2021-01-27 22:51:41,987][INFO][TabNet] ******************** epoch : 91 ********************
[2021-01-27 22:51:44,665][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:51:44,666][INFO][TabNet] total_loss : 0.04498175531625748
[2021-01-27 22:51:44,667][INFO][TabNet] task_loss : 0.04443535581231117
[2021-01-27 22:51:44,668][INFO][TabNet] mask_loss : -0.5464009046554565
[2021-01-27 22:51:44,668][INFO][TabNet] time_cost : 0.057845115661621094
[2021-01-27 22:51:44,669][INFO][TabNet] mean_squared_error : 0.021484375
[2021-01-27 22:51:44,670][INFO][TabNet] ******************** epoch : 92 ********************
[2021-01-27 22:51:47,376][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:51:47,377][INFO][TabNet] total_loss : 0.05940573662519455
[2021-01-27 22:51:47,378][INFO][TabNet] task_loss : 0.05885749310255051
[2021-01-27 22:51:47,379][INFO][TabNet] mask_loss : -0.5482423305511475
[2021-01-27 22:51:47,379][INFO][TabNet] time_cos

[2021-01-27 22:52:22,215][INFO][TabNet] ******************** epoch : 106 ********************
[2021-01-27 22:52:24,898][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:52:24,899][INFO][TabNet] total_loss : 0.045789968222379684
[2021-01-27 22:52:24,900][INFO][TabNet] task_loss : 0.04525812715291977
[2021-01-27 22:52:24,900][INFO][TabNet] mask_loss : -0.5318412184715271
[2021-01-27 22:52:24,901][INFO][TabNet] time_cost : 0.057845354080200195
[2021-01-27 22:52:24,901][INFO][TabNet] mean_squared_error : 0.017578125
[2021-01-27 22:52:24,902][INFO][TabNet] ******************** epoch : 107 ********************
[2021-01-27 22:52:27,575][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:52:27,576][INFO][TabNet] total_loss : 0.03795431926846504
[2021-01-27 22:52:27,576][INFO][TabNet] task_loss : 0.0374356284737587
[2021-01-27 22:52:27,577][INFO][TabNet] mask_loss : -0.5186896920204163
[2021-01-27 22:52:27,578][INFO][TabNet] time_c

[2021-01-27 22:53:02,829][INFO][TabNet] ******************** epoch : 121 ********************
[2021-01-27 22:53:05,506][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:53:05,507][INFO][TabNet] total_loss : 0.05057975649833679
[2021-01-27 22:53:05,508][INFO][TabNet] task_loss : 0.05007443577051163
[2021-01-27 22:53:05,509][INFO][TabNet] mask_loss : -0.5053219199180603
[2021-01-27 22:53:05,510][INFO][TabNet] time_cost : 0.047843217849731445
[2021-01-27 22:53:05,511][INFO][TabNet] mean_squared_error : 0.017578125
[2021-01-27 22:53:05,512][INFO][TabNet] ******************** epoch : 122 ********************
[2021-01-27 22:53:08,212][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:53:08,213][INFO][TabNet] total_loss : 0.044274359941482544
[2021-01-27 22:53:08,213][INFO][TabNet] task_loss : 0.043770089745521545
[2021-01-27 22:53:08,214][INFO][TabNet] mask_loss : -0.5042701959609985
[2021-01-27 22:53:08,214][INFO][TabNet] time

[2021-01-27 22:53:43,158][INFO][TabNet] ******************** epoch : 136 ********************
[2021-01-27 22:53:45,868][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:53:45,869][INFO][TabNet] total_loss : 0.061623286455869675
[2021-01-27 22:53:45,870][INFO][TabNet] task_loss : 0.0611075721681118
[2021-01-27 22:53:45,870][INFO][TabNet] mask_loss : -0.5157128572463989
[2021-01-27 22:53:45,871][INFO][TabNet] time_cost : 0.04590630531311035
[2021-01-27 22:53:45,872][INFO][TabNet] mean_squared_error : 0.025390625
[2021-01-27 22:53:45,873][INFO][TabNet] ******************** epoch : 137 ********************
[2021-01-27 22:53:48,582][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:53:48,583][INFO][TabNet] total_loss : 0.043721869587898254
[2021-01-27 22:53:48,584][INFO][TabNet] task_loss : 0.04320311173796654
[2021-01-27 22:53:48,584][INFO][TabNet] mask_loss : -0.5187592506408691
[2021-01-27 22:53:48,585][INFO][TabNet] time_c

[2021-01-27 22:54:23,289][INFO][TabNet] ******************** epoch : 151 ********************
[2021-01-27 22:54:25,987][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:54:25,988][INFO][TabNet] total_loss : 0.039077721536159515
[2021-01-27 22:54:25,988][INFO][TabNet] task_loss : 0.038585640490055084
[2021-01-27 22:54:25,989][INFO][TabNet] mask_loss : -0.4920797348022461
[2021-01-27 22:54:25,990][INFO][TabNet] time_cost : 0.04986572265625
[2021-01-27 22:54:25,990][INFO][TabNet] mean_squared_error : 0.021484375
[2021-01-27 22:54:25,991][INFO][TabNet] ******************** epoch : 152 ********************
[2021-01-27 22:54:28,720][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:54:28,721][INFO][TabNet] total_loss : 0.04102548211812973
[2021-01-27 22:54:28,722][INFO][TabNet] task_loss : 0.04053111746907234
[2021-01-27 22:54:28,723][INFO][TabNet] mask_loss : -0.494363397359848
[2021-01-27 22:54:28,723][INFO][TabNet] time_cost

[2021-01-27 22:55:03,794][INFO][TabNet] ******************** epoch : 166 ********************
[2021-01-27 22:55:06,473][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:55:06,474][INFO][TabNet] total_loss : 0.03189460188150406
[2021-01-27 22:55:06,475][INFO][TabNet] task_loss : 0.03137994557619095
[2021-01-27 22:55:06,476][INFO][TabNet] mask_loss : -0.5146576166152954
[2021-01-27 22:55:06,476][INFO][TabNet] time_cost : 0.05482602119445801
[2021-01-27 22:55:06,477][INFO][TabNet] mean_squared_error : 0.013671875
[2021-01-27 22:55:06,478][INFO][TabNet] ******************** epoch : 167 ********************
[2021-01-27 22:55:09,134][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:55:09,135][INFO][TabNet] total_loss : 0.033883415162563324
[2021-01-27 22:55:09,136][INFO][TabNet] task_loss : 0.03337927162647247
[2021-01-27 22:55:09,136][INFO][TabNet] mask_loss : -0.5041443705558777
[2021-01-27 22:55:09,137][INFO][TabNet] time_c

[2021-01-27 22:55:43,945][INFO][TabNet] ******************** epoch : 181 ********************
[2021-01-27 22:55:46,629][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:55:46,631][INFO][TabNet] total_loss : 0.04940275847911835
[2021-01-27 22:55:46,632][INFO][TabNet] task_loss : 0.048902228474617004
[2021-01-27 22:55:46,632][INFO][TabNet] mask_loss : -0.5005303025245667
[2021-01-27 22:55:46,633][INFO][TabNet] time_cost : 0.050863027572631836
[2021-01-27 22:55:46,634][INFO][TabNet] mean_squared_error : 0.01953125
[2021-01-27 22:55:46,635][INFO][TabNet] ******************** epoch : 182 ********************
[2021-01-27 22:55:49,357][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:55:49,358][INFO][TabNet] total_loss : 0.03571540117263794
[2021-01-27 22:55:49,359][INFO][TabNet] task_loss : 0.03520404174923897
[2021-01-27 22:55:49,359][INFO][TabNet] mask_loss : -0.5113601684570312
[2021-01-27 22:55:49,360][INFO][TabNet] time_c

[2021-01-27 22:56:24,508][INFO][TabNet] ******************** epoch : 196 ********************
[2021-01-27 22:56:27,191][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:56:27,192][INFO][TabNet] total_loss : 0.04163944348692894
[2021-01-27 22:56:27,193][INFO][TabNet] task_loss : 0.04112957790493965
[2021-01-27 22:56:27,194][INFO][TabNet] mask_loss : -0.5098642110824585
[2021-01-27 22:56:27,195][INFO][TabNet] time_cost : 0.05186128616333008
[2021-01-27 22:56:27,196][INFO][TabNet] mean_squared_error : 0.01953125
[2021-01-27 22:56:27,197][INFO][TabNet] ******************** epoch : 197 ********************
[2021-01-27 22:56:29,847][INFO][TabNet] -------------------- train info --------------------
[2021-01-27 22:56:29,848][INFO][TabNet] total_loss : 0.03986820578575134
[2021-01-27 22:56:29,848][INFO][TabNet] task_loss : 0.03936133161187172
[2021-01-27 22:56:29,849][INFO][TabNet] mask_loss : -0.5068746209144592
[2021-01-27 22:56:29,850][INFO][TabNet] time_cos

In [9]:
# explain() returns two values. The plotting cells below index `masks` by
# integer (masks[0], masks[1], masks[2]) and plot `m_explain` under the title
# 'importance'; both are converted with .cpu().numpy(), so they are presumably
# torch tensors of shape (samples, features) — TODO confirm.
m_explain, masks = tabnet.explain(X)

In [10]:
import matplotlib.pyplot as plt

In [None]:
# Heatmaps for the first 30 samples, stacked vertically: one panel per
# attention mask plus one for the aggregated importance.
# range(3) matches num_steps=3 passed to TabNetClassifier above.
vertical_panels = [(masks[i], f"mask {i}") for i in range(3)]
vertical_panels.append((m_explain, 'importance'))

fig, axs = plt.subplots(len(vertical_panels), 1, figsize=(15, 15))

for ax, (values, title) in zip(axs, vertical_panels):
    # detach() keeps the conversion safe if the tensor still tracks gradients.
    ax.imshow(values.detach().cpu().numpy()[:30])
    ax.set_xlabel('features')
    ax.set_ylabel('samples')
    ax.set_title(title)

# Explicit show() so the cell does not end on a Text(...) repr.
plt.show()

In [None]:
# Heatmaps for the first 20 samples, side by side, with the dataset's feature
# names as x tick labels.
# `feature_names` was never defined in this notebook; take it from the dataset.
feature_names = load_breast_cancer().feature_names

# range(3) matches num_steps=3 passed to TabNetClassifier above. The last panel
# is the aggregated importance returned by tabnet.explain() as `m_explain`
# (the original referenced an undefined name `importance`).
horizontal_panels = [(masks[i], f"mask {i}") for i in range(3)]
horizontal_panels.append((m_explain, 'importance'))

fig, axs = plt.subplots(1, len(horizontal_panels), figsize=(25, 25))

for ax, (values, title) in zip(axs, horizontal_panels):
    # detach() keeps the conversion safe if the tensor still tracks gradients.
    ax.imshow(values.detach().cpu().numpy()[:20])
    ax.set_xlabel('features')
    ax.set_ylabel('samples')
    ax.set_title(title)
    ax.set_yticks(range(20))
    # One tick per feature, derived from the data rather than a hard-coded 30.
    ax.set_xticks(range(len(feature_names)))
    ax.set_xticklabels(feature_names, rotation=90)

plt.show()

In [None]:
from scipy.sparse import csc_matrix
import scipy

In [None]:
# NOTE: create_explain_matrix is defined in the NEXT cell; on a fresh kernel
# that cell must be executed first or this call raises NameError.
# With no categorical features (empty lists) and equal input / post-embedding
# dimensions, each of the 30 input features maps to exactly one row.
create_explain_matrix(30, [], [], 30)

In [None]:
def create_explain_matrix(input_dim, cat_emb_dim, cat_idxs, post_embed_dim):
    """
    Build the 0/1 matrix that folds post-embedding columns back onto the
    original input features.

    Every non-categorical feature owns one post-embedding position; every
    categorical feature owns as many consecutive positions as its embedding
    size. Multiplying importances of shape (n, post_embed_dim) by the result
    sums the positions belonging to the same original feature.

    Parameters
    ----------
    input_dim : int
        Number of original input features.
    cat_emb_dim : int or list of int
        if int : embedding size shared by all categorical features
        if list of int : embedding size of each categorical feature, in the
        order the categorical features are met when scanning 0..input_dim-1
    cat_idxs : list of int
        Positions of the categorical features among the original inputs.
    post_embed_dim : int
        Input dimension after embedding.

    Returns
    -------
    reducing_matrix : scipy.sparse.csc_matrix
        Sparse matrix of shape (post_embed_dim, input_dim) used to perform
        the reduction.
    """
    # Extra positions (beyond its own slot) taken by each categorical feature.
    if not isinstance(cat_emb_dim, int):
        extra_positions = [width - 1 for width in cat_emb_dim]
    else:
        extra_positions = [cat_emb_dim - 1] * len(cat_idxs)

    shift = 0       # how far later features have been pushed by embeddings
    cat_seen = 0    # number of categorical features handled so far
    rows_per_feature = []
    for feature in range(input_dim):
        first_row = feature + shift
        if feature in cat_idxs:
            extra = extra_positions[cat_seen]
            rows_per_feature.append(range(first_row, first_row + extra + 1))
            shift += extra
            cat_seen += 1
        else:
            rows_per_feature.append([first_row])

    # Fill a dense 0/1 matrix column by column, then hand back a sparse copy.
    reducing_matrix = np.zeros((post_embed_dim, input_dim))
    for feature, rows in enumerate(rows_per_feature):
        reducing_matrix[rows, feature] = 1

    return scipy.sparse.csc_matrix(reducing_matrix)

In [None]:
# Fold the masks and the aggregated importance back onto the original input
# features using the reducing matrix. The original cell was a fragment pasted
# from a class method: its indentation was a syntax error and it referenced
# `self`, `res_explain` and `M_explain`, none of which exist in this notebook.
# Requires the cell defining create_explain_matrix to have been run.
n_features = X.shape[1]
# No categorical features here, so the post-embedding dimension equals the
# input dimension.
reducing_matrix = create_explain_matrix(n_features, [], [], n_features)

# `masks` is indexed by integer elsewhere in this notebook; accept either a
# dict keyed by step or a plain sequence.
mask_items = masks.items() if isinstance(masks, dict) else enumerate(masks)

# Store the results under new names instead of overwriting `masks`, so the
# plotting cells (which call .cpu() on the tensors) stay re-runnable.
masks_reduced = {
    key: csc_matrix.dot(value.cpu().detach().numpy(), reducing_matrix)
    for key, value in mask_items
}

res_explain = [
    csc_matrix.dot(m_explain.cpu().detach().numpy(), reducing_matrix)
]


In [None]:
# Take the first element of what predict() returns and flatten it to 1-D so it
# can be compared by eye with y in the next cell.
# NOTE(review): predict() presumably returns one entry per output head
# (output_dims=[1] above) — confirm against TabNetClassifier.predict.
tabnet.predict(X)[0].reshape(-1)

In [None]:
# Ground-truth labels, displayed for comparison with the predictions above.
y

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score 