In [1]:
import sys
# Make the parent directory importable so the `tabnet` package imported in the next cell
# can be found without installing it. The path is relative to the kernel's current working
# directory, so this presumably assumes the notebook is launched from its own folder that
# sits next to the `tabnet/` package — TODO confirm.
sys.path.append('../')

In [2]:
from tabnet.utils.logger import init_logger
from tabnet.estimator import TabNetRegressor
from sklearn.datasets import load_boston

In [3]:
# Logging configuration for this run: target directory, logger name and level,
# handed to the project's init_logger helper.
logger_dir, logger_name, level = 'logs', 'TestRegression', 'INFO'

logger = init_logger(
    logger_dir=logger_dir,
    logger_name=logger_name,
    level=level,
)

In [4]:
# Boston housing data as plain arrays: X is the feature matrix, y the regression target.
# NOTE(review): `load_boston` was removed in scikit-learn 1.2, so this cell only runs
# with an older scikit-learn (or after switching to a different dataset).
X, y = load_boston(return_X_y=True)

# One shape per line: features first, then target.
print(X.shape, y.shape, sep='\n')

(506, 13)
(506,)


In [5]:
# TabNet regressor configured for this dataset.
# input_dims is derived from the data rather than hard-coded, so the cell stays correct
# if the feature set changes.
# NOTE(review): `logger=None` means the `logger` returned by init_logger earlier is not
# passed in; the built model's repr shows it falls back to the root logger. Use
# `logger=logger` if the named logger was intended.
tabnet = TabNetRegressor(
    input_dims=X.shape[1],  # number of feature columns in X (13 for Boston)
    output_dims=[1],        # presumably one output of width 1, matching y reshaped to (n, 1)
    logger=None,
    is_cuda=False,
    reprs_dims=4, atten_dims=4, num_steps=3, num_indep=1, num_shared=1,
)

In [6]:
# Build the network for the configuration above. `path=None` presumably means no saved
# weights are loaded — TODO confirm against TabNetRegressor.build. The call returns the
# estimator, which is why its repr is shown as this cell's output.
tabnet.build(path=None)



TabNetRegressor(atten_dims=4, batch_size=1024, cate_dims=None,
                cate_embed_dims=1, cate_indices=None, criterions=['mse'],
                gamma=1.3, input_dims=13, is_cuda=False, is_shuffle=True,
                logger=<RootLogger root (INFO)>, mask_type='sparsemax',
                momentum=0.03, num_indep=1, num_shared=1, num_steps=3,
                num_workers=4, output_dims=[1], pin_memory=True, reprs_dims=4,
                task_weights=1, virtual_batch_size=128)

In [7]:
from torch.optim import Adam
from torch.optim import lr_scheduler


# Keyword arguments unpacked into `tabnet.fit`: 256-sample batches for up to 200 epochs,
# Adam with lr=0.1, an ExponentialLR scheduler with gamma=0.99, and MSE as the metric.
training_params = dict(
    batch_size=256,
    max_epochs=200,
    metrics=['mse'],
    optimizer=Adam,
    optimizer_params={'lr': 0.1},
    schedulers=[lr_scheduler.ExponentialLR],
    scheduler_params={'gamma': 0.99},
)


In [8]:
# Train on the full dataset; the target is passed as a column vector of shape (n_samples, 1).
# NOTE(review): the logged timestamps show roughly 4 s per epoch, so max_epochs=200 takes
# over ten minutes; the saved run was stopped with KeyboardInterrupt at epoch 150.
tabnet.fit(
    X,
    y.reshape((-1, 1)),
    **training_params,
)

[2021-02-03 13:49:54,805][INFO][TabNet] start training.
[2021-02-03 13:49:54,807][INFO][TabNet] ******************** epoch : 0 ********************
[2021-02-03 13:49:58,629][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:49:58,636][INFO][TabNet] total_loss : 632.794189453125
[2021-02-03 13:49:58,641][INFO][TabNet] task_loss : 632.7933349609375
[2021-02-03 13:49:58,644][INFO][TabNet] mask_loss : -0.8424485325813293
[2021-02-03 13:49:58,647][INFO][TabNet] time_cost : 0.12095332145690918
[2021-02-03 13:49:58,650][INFO][TabNet] mean_squared_error : 23.378442764282227
[2021-02-03 13:49:58,652][INFO][TabNet] ******************** epoch : 1 ********************
[2021-02-03 13:50:02,344][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:50:02,347][INFO][TabNet] total_loss : 584.9020385742188
[2021-02-03 13:50:02,349][INFO][TabNet] task_loss : 584.9011840820312
[2021-02-03 13:50:02,351][INFO][TabNet] mask_loss : -0.82804435491561

[2021-02-03 13:50:49,216][INFO][TabNet] mean_squared_error : 11.399986267089844
[2021-02-03 13:50:49,218][INFO][TabNet] ******************** epoch : 15 ********************
[2021-02-03 13:50:53,088][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:50:53,091][INFO][TabNet] total_loss : 163.3161163330078
[2021-02-03 13:50:53,093][INFO][TabNet] task_loss : 163.31588745117188
[2021-02-03 13:50:53,094][INFO][TabNet] mask_loss : -0.23289121687412262
[2021-02-03 13:50:53,096][INFO][TabNet] time_cost : 0.029971837997436523
[2021-02-03 13:50:53,097][INFO][TabNet] mean_squared_error : 9.76811408996582
[2021-02-03 13:50:53,098][INFO][TabNet] ******************** epoch : 16 ********************
[2021-02-03 13:50:57,023][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:50:57,025][INFO][TabNet] total_loss : 153.5235595703125
[2021-02-03 13:50:57,027][INFO][TabNet] task_loss : 153.5233154296875
[2021-02-03 13:50:57,029][INFO][TabNet] m

[2021-02-03 13:51:45,372][INFO][TabNet] time_cost : 0.022998332977294922
[2021-02-03 13:51:45,373][INFO][TabNet] mean_squared_error : 5.6942458152771
[2021-02-03 13:51:45,375][INFO][TabNet] ******************** epoch : 30 ********************
[2021-02-03 13:51:49,164][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:51:49,167][INFO][TabNet] total_loss : 56.99422073364258
[2021-02-03 13:51:49,169][INFO][TabNet] task_loss : 56.99397659301758
[2021-02-03 13:51:49,171][INFO][TabNet] mask_loss : -0.24253641068935394
[2021-02-03 13:51:49,173][INFO][TabNet] time_cost : 0.018001556396484375
[2021-02-03 13:51:49,175][INFO][TabNet] mean_squared_error : 5.944016456604004
[2021-02-03 13:51:49,176][INFO][TabNet] ******************** epoch : 31 ********************
[2021-02-03 13:51:52,834][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:51:52,835][INFO][TabNet] total_loss : 76.60246276855469
[2021-02-03 13:51:52,836][INFO][TabNet] t

[2021-02-03 13:52:43,215][INFO][TabNet] mask_loss : -0.24501705169677734
[2021-02-03 13:52:43,217][INFO][TabNet] time_cost : 0.02295231819152832
[2021-02-03 13:52:43,219][INFO][TabNet] mean_squared_error : 4.117396831512451
[2021-02-03 13:52:43,220][INFO][TabNet] ******************** epoch : 45 ********************
[2021-02-03 13:52:48,100][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:52:48,102][INFO][TabNet] total_loss : 31.375768661499023
[2021-02-03 13:52:48,104][INFO][TabNet] task_loss : 31.375524520874023
[2021-02-03 13:52:48,107][INFO][TabNet] mask_loss : -0.2441369742155075
[2021-02-03 13:52:48,109][INFO][TabNet] time_cost : 0.019999980926513672
[2021-02-03 13:52:48,113][INFO][TabNet] mean_squared_error : 4.04672384262085
[2021-02-03 13:52:48,118][INFO][TabNet] ******************** epoch : 46 ********************
[2021-02-03 13:52:54,454][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:52:54,456][INFO][TabNet

[2021-02-03 13:53:46,277][INFO][TabNet] task_loss : 21.520273208618164
[2021-02-03 13:53:46,278][INFO][TabNet] mask_loss : -0.21806275844573975
[2021-02-03 13:53:46,280][INFO][TabNet] time_cost : 0.02799534797668457
[2021-02-03 13:53:46,281][INFO][TabNet] mean_squared_error : 3.3651249408721924
[2021-02-03 13:53:46,282][INFO][TabNet] ******************** epoch : 60 ********************
[2021-02-03 13:53:50,021][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:53:50,023][INFO][TabNet] total_loss : 23.88224220275879
[2021-02-03 13:53:50,026][INFO][TabNet] task_loss : 23.882028579711914
[2021-02-03 13:53:50,027][INFO][TabNet] mask_loss : -0.21427935361862183
[2021-02-03 13:53:50,029][INFO][TabNet] time_cost : 0.026004791259765625
[2021-02-03 13:53:50,031][INFO][TabNet] mean_squared_error : 3.6638693809509277
[2021-02-03 13:53:50,032][INFO][TabNet] ******************** epoch : 61 ********************
[2021-02-03 13:53:53,793][INFO][TabNet] -----------------

[2021-02-03 13:54:43,718][INFO][TabNet] total_loss : 17.445505142211914
[2021-02-03 13:54:43,719][INFO][TabNet] task_loss : 17.445255279541016
[2021-02-03 13:54:43,721][INFO][TabNet] mask_loss : -0.24970334768295288
[2021-02-03 13:54:43,723][INFO][TabNet] time_cost : 0.024960041046142578
[2021-02-03 13:54:43,724][INFO][TabNet] mean_squared_error : 3.0862293243408203
[2021-02-03 13:54:43,726][INFO][TabNet] ******************** epoch : 75 ********************
[2021-02-03 13:54:47,455][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:54:47,456][INFO][TabNet] total_loss : 14.411949157714844
[2021-02-03 13:54:47,458][INFO][TabNet] task_loss : 14.411722183227539
[2021-02-03 13:54:47,460][INFO][TabNet] mask_loss : -0.22668789327144623
[2021-02-03 13:54:47,462][INFO][TabNet] time_cost : 0.027993202209472656
[2021-02-03 13:54:47,464][INFO][TabNet] mean_squared_error : 2.8484411239624023
[2021-02-03 13:54:47,467][INFO][TabNet] ******************** epoch : 76 ****

[2021-02-03 13:55:43,133][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:55:43,137][INFO][TabNet] total_loss : 16.905546188354492
[2021-02-03 13:55:43,139][INFO][TabNet] task_loss : 16.90530014038086
[2021-02-03 13:55:43,141][INFO][TabNet] mask_loss : -0.24674147367477417
[2021-02-03 13:55:43,142][INFO][TabNet] time_cost : 0.021999835968017578
[2021-02-03 13:55:43,144][INFO][TabNet] mean_squared_error : 2.9090943336486816
[2021-02-03 13:55:43,146][INFO][TabNet] ******************** epoch : 90 ********************
[2021-02-03 13:55:46,880][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:55:46,882][INFO][TabNet] total_loss : 17.463722229003906
[2021-02-03 13:55:46,884][INFO][TabNet] task_loss : 17.46347427368164
[2021-02-03 13:55:46,886][INFO][TabNet] mask_loss : -0.2481185644865036
[2021-02-03 13:55:46,888][INFO][TabNet] time_cost : 0.02499556541442871
[2021-02-03 13:55:46,890][INFO][TabNet] mean_squared_error : 2.9528

[2021-02-03 13:56:35,691][INFO][TabNet] ******************** epoch : 104 ********************
[2021-02-03 13:56:39,400][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:56:39,402][INFO][TabNet] total_loss : 15.460806846618652
[2021-02-03 13:56:39,404][INFO][TabNet] task_loss : 15.460545539855957
[2021-02-03 13:56:39,406][INFO][TabNet] mask_loss : -0.2609458565711975
[2021-02-03 13:56:39,407][INFO][TabNet] time_cost : 0.02095317840576172
[2021-02-03 13:56:39,409][INFO][TabNet] mean_squared_error : 2.8335843086242676
[2021-02-03 13:56:39,411][INFO][TabNet] ******************** epoch : 105 ********************
[2021-02-03 13:56:43,104][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:56:43,107][INFO][TabNet] total_loss : 16.49391746520996
[2021-02-03 13:56:43,109][INFO][TabNet] task_loss : 16.49365997314453
[2021-02-03 13:56:43,111][INFO][TabNet] mask_loss : -0.25751611590385437
[2021-02-03 13:56:43,113][INFO][TabNet] time_

[2021-02-03 13:57:32,875][INFO][TabNet] mean_squared_error : 2.931251287460327
[2021-02-03 13:57:32,877][INFO][TabNet] ******************** epoch : 119 ********************
[2021-02-03 13:57:36,620][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:57:36,622][INFO][TabNet] total_loss : 14.832707405090332
[2021-02-03 13:57:36,624][INFO][TabNet] task_loss : 14.832435607910156
[2021-02-03 13:57:36,626][INFO][TabNet] mask_loss : -0.27163681387901306
[2021-02-03 13:57:36,627][INFO][TabNet] time_cost : 0.021959781646728516
[2021-02-03 13:57:36,629][INFO][TabNet] mean_squared_error : 2.8603992462158203
[2021-02-03 13:57:36,631][INFO][TabNet] ******************** epoch : 120 ********************
[2021-02-03 13:57:40,312][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:57:40,314][INFO][TabNet] total_loss : 15.651037216186523
[2021-02-03 13:57:40,316][INFO][TabNet] task_loss : 15.650775909423828
[2021-02-03 13:57:40,318][INFO][Tab

[2021-02-03 13:58:30,315][INFO][TabNet] time_cost : 0.02500176429748535
[2021-02-03 13:58:30,317][INFO][TabNet] mean_squared_error : 2.664052963256836
[2021-02-03 13:58:30,319][INFO][TabNet] ******************** epoch : 134 ********************
[2021-02-03 13:58:34,202][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:58:34,205][INFO][TabNet] total_loss : 22.309329986572266
[2021-02-03 13:58:34,210][INFO][TabNet] task_loss : 22.30907440185547
[2021-02-03 13:58:34,213][INFO][TabNet] mask_loss : -0.25511041283607483
[2021-02-03 13:58:34,216][INFO][TabNet] time_cost : 0.02399277687072754
[2021-02-03 13:58:34,218][INFO][TabNet] mean_squared_error : 3.3356738090515137
[2021-02-03 13:58:34,221][INFO][TabNet] ******************** epoch : 135 ********************
[2021-02-03 13:58:38,424][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:58:38,425][INFO][TabNet] total_loss : 17.701335906982422
[2021-02-03 13:58:38,427][INFO][TabN

[2021-02-03 13:59:28,535][INFO][TabNet] mask_loss : -0.2629603445529938
[2021-02-03 13:59:28,537][INFO][TabNet] time_cost : 0.029003143310546875
[2021-02-03 13:59:28,538][INFO][TabNet] mean_squared_error : 2.769136905670166
[2021-02-03 13:59:28,540][INFO][TabNet] ******************** epoch : 149 ********************
[2021-02-03 13:59:32,611][INFO][TabNet] -------------------- train info --------------------
[2021-02-03 13:59:32,613][INFO][TabNet] total_loss : 15.585423469543457
[2021-02-03 13:59:32,615][INFO][TabNet] task_loss : 15.58518123626709
[2021-02-03 13:59:32,617][INFO][TabNet] mask_loss : -0.24261583387851715
[2021-02-03 13:59:32,621][INFO][TabNet] time_cost : 0.023003816604614258
[2021-02-03 13:59:32,622][INFO][TabNet] mean_squared_error : 2.848806142807007
[2021-02-03 13:59:32,624][INFO][TabNet] ******************** epoch : 150 ********************


KeyboardInterrupt: 

In [None]:
# Explain every row of X. Judging by how they are plotted later, `importance` is a
# per-sample feature-importance tensor and `masks` an indexable collection of per-step
# attention masks, each presumably (n_samples, n_features) — TODO confirm.
importance, masks = tabnet.explain(X)

In [None]:
import matplotlib.pyplot as plt

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Visualise the first N_SHOW samples: one heatmap per attention mask, followed by the
# aggregated feature importance with feature names on the x-axis.
N_SHOW = 20  # number of leading samples (rows) shown in each heatmap

# `feature_names` was not defined anywhere else in the notebook; take it from the same
# loader that produced X so the tick labels line up with X's columns.
feature_names = list(load_boston().feature_names)
n_features = X.shape[1]

# Assumes `masks` holds one entry per decision step (3 here, matching num_steps=3) — TODO confirm.
n_masks = len(masks)
n_rows = n_masks + 1  # one panel per mask plus the importance panel

# squeeze=False keeps `axs` indexable even when there is only a single row.
fig, axs = plt.subplots(n_rows, 1, figsize=(5, 3.75 * n_rows), squeeze=False)
axs = axs[:, 0]

for i in range(n_masks):
    ax = axs[i]
    ax.imshow(masks[i].cpu().numpy()[:N_SHOW])
    ax.set_ylabel('samples')
    ax.set_title(f"mask {i}")
    ax.set_yticks(range(N_SHOW))

ax_imp = axs[-1]
ax_imp.imshow(importance.cpu().numpy()[:N_SHOW, :])
ax_imp.set_xlabel('features')
ax_imp.set_ylabel('samples')
ax_imp.set_title('importance')
ax_imp.set_yticks(range(N_SHOW))
ax_imp.set_xticks(range(n_features))
ax_imp.set_xticklabels(feature_names, rotation=90)

# Keep the rotated feature labels from overlapping neighbouring panels.
fig.tight_layout()
plt.show()