# TabNet

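TabNet is a high-performance, interpretable deep learning architecture for tabular data, developed by Sercan Ö. Arık and Tomas Pfister.
- GitHub (PyTorch implementation, `pytorch-tabnet`): https://github.com/dreamquark-ai/tabnet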
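TabNet's interpretability comes from its attentive transformer, which picks the features used at each decision step with a sparse mask. The TabNet paper normalizes these masks with sparsemax, and pytorch-tabnet exposes this as `mask_type="sparsemax"` (with `"entmax"` as an alternative). Below is a minimal NumPy sketch of sparsemax on a single score vector, for illustration only. It is not the library's implementation, and the function name `sparsemax` is our own.

```python
import numpy as np


def sparsemax(z):
    """Sparsemax of a 1D score vector.

    Like softmax, it returns non-negative weights that sum to 1,
    but it can assign exactly zero weight to low-scoring entries.
    """
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]            # scores in descending order
    cumsum = np.cumsum(z_sorted)           # running sums of the sorted scores
    k = np.arange(1, z.size + 1)
    support = 1 + k * z_sorted > cumsum    # True for entries that stay non-zero
    k_max = k[support][-1]                 # size of the support
    tau = (cumsum[k_max - 1] - 1) / k_max  # threshold subtracted from every score
    return np.maximum(z - tau, 0.0)


# Hypothetical feature scores for one sample -> sparse feature-selection mask
scores = np.array([1.0, 1.2, 0.1])
mask = sparsemax(scores)
print(mask)
```

Here the third feature receives exactly zero weight, whereas softmax would give every feature some positive weight. This sparsity is what makes the per-step masks readable as feature selections.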


<a href="https://colab.research.google.com/github/fuyu-quant/data-science-wiki/blob/develop/tabledata/regression/tabnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

```python
%%capture
!pip install pytorch-tabnet
```

```python
from pytorch_tabnet.tab_model import TabNetRegressor

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.datasets import fetch_california_housing
```

### Preparing the data

```python
california = fetch_california_housing()
df = pd.DataFrame(california.data, columns=california.feature_names)
# Target: median house value per district, in units of $100,000
df['target'] = pd.Series(california.target, name='MedHouseVal')
df.head()
```

| | MedInc | HouseAge | AveRooms | AveBedrms | Population | AveOccup | Latitude | Longitude | target |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 0 | 8.3252 | 41.0 | 6.984127 | 1.02381 | 322.0 | 2.555556 | 37.88 | -122.23 | 4.526 |
| 1 | 8.3014 | 21.0 | 6.238137 | 0.97188 | 2401.0 | 2.109842 | 37.86 | -122.22 | 3.585 |
| 2 | 7.2574 | 52.0 | 8.288136 | 1.073446 | 496.0 | 2.80226 | 37.85 | -122.24 | 3.521 |
| 3 | 5.6431 | 52.0 | 5.817352 | 1.073059 | 558.0 | 2.547945 | 37.85 | -122.25 | 3.413 |
| 4 | 3.8462 | 52.0 | 6.281853 | 1.081081 | 565.0 | 2.181467 | 37.85 | -122.25 | 3.422 |


```python
X = df.drop('target', axis=1)
y = df['target']

# Hold out 20% of the rows for validation
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=4)
```

### Training TabNet

```python
# TabNetRegressor expects 2D targets of shape (n_samples, n_targets)
y_train_reshaped = y_train.values.reshape(-1, 1)
```
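pytorch-tabnet's `TabNetRegressor` takes regression targets as a 2D array of shape `(n_samples, n_targets)` (this is also how it handles multi-output regression), which is why `y_train` is reshaped. A quick NumPy illustration of what `reshape(-1, 1)` does, using the first three `target` values from the table above; `y_demo` is just an illustrative name:

```python
import numpy as np

# Stand-in for y_train.values: a 1D array of targets
y_demo = np.array([4.526, 3.585, 3.521])

y_demo_2d = y_demo.reshape(-1, 1)  # -1 lets NumPy infer the number of rows
print(y_demo.shape, y_demo_2d.shape)  # (3,) (3, 1)

# The inverse step: back to 1D, as is done with the predictions later
y_demo_back = y_demo_2d.flatten()
```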

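```python
model = TabNetRegressor()
model.fit(
    X_train.values, y_train_reshaped,
    # The validation set drives early stopping
    eval_set=[(X_valid.values, y_valid.values.reshape(-1, 1))],
    # eval_metric=['rmse'],  # uncomment to log RMSE instead of the default MSE
)
```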



```
epoch 0  | loss: 2.71265 | val_0_mse: 211.47251|  0:00:01s
epoch 1  | loss: 0.71385 | val_0_mse: 39.78079|  0:00:02s
epoch 2  | loss: 0.54929 | val_0_mse: 4.27326 |  0:00:04s
epoch 3  | loss: 0.47337 | val_0_mse: 1.3551  |  0:00:05s
epoch 4  | loss: 0.44223 | val_0_mse: 1.34245 |  0:00:07s
epoch 5  | loss: 0.4196  | val_0_mse: 4.27685 |  0:00:08s
epoch 6  | loss: 0.40035 | val_0_mse: 3.30382 |  0:00:10s
epoch 7  | loss: 0.39117 | val_0_mse: 63.64466|  0:00:11s
epoch 8  | loss: 0.37853 | val_0_mse: 117.64524|  0:00:12s
epoch 9  | loss: 0.36998 | val_0_mse: 77.13009|  0:00:14s
epoch 10 | loss: 0.38171 | val_0_mse: 61.39398|  0:00:15s
epoch 11 | loss: 0.35656 | val_0_mse: 47.18377|  0:00:17s
epoch 12 | loss: 0.35908 | val_0_mse: 49.20412|  0:00:18s
epoch 13 | loss: 0.35002 | val_0_mse: 28.94247|  0:00:20s
epoch 14 | loss: 0.36427 | val_0_mse: 38.68645|  0:00:21s

Early stopping occurred at epoch 14 with best_epoch = 4 and best_val_0_mse = 1.34245
```
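The log reflects patience-based early stopping: the best validation MSE (1.34245) was reached at epoch 4, and training stopped at epoch 14 after 10 consecutive epochs without improvement. The sketch below replays the logged `val_0_mse` values through a simple patience rule. It assumes `patience=10`, which we take to be the default of pytorch-tabnet's `fit`, and `replay_early_stopping` is a hypothetical helper, not part of the library.

```python
def replay_early_stopping(val_scores, patience=10):
    """Return (stop_epoch, best_epoch, best_score) for a lower-is-better metric.

    Hypothetical helper: stops once `patience` consecutive epochs
    fail to improve on the best score seen so far.
    """
    best_score = float("inf")
    best_epoch = 0
    wait = 0  # epochs since the last improvement
    for epoch, score in enumerate(val_scores):
        if score < best_score:
            best_score, best_epoch, wait = score, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch, best_score
    # Ran out of epochs without triggering early stopping
    return len(val_scores) - 1, best_epoch, best_score


# val_0_mse values copied from the training log (epochs 0-14)
val_mse = [211.47251, 39.78079, 4.27326, 1.3551, 1.34245, 4.27685, 3.30382,
           63.64466, 117.64524, 77.13009, 61.39398, 47.18377, 49.20412,
           28.94247, 38.68645]
print(replay_early_stopping(val_mse))  # (14, 4, 1.34245)
```

The replay matches the log line: stop at epoch 14, best epoch 4, best MSE 1.34245.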




### Prediction

```python
# Predict on the validation set so the predictions line up with y_valid
y_pred = model.predict(X_valid.values)
```

```python
# predict returns shape (n_samples, 1); flatten to a 1D array
y_pred = y_pred.flatten()
```

```python
# Compute the mean squared error on the validation set
mse = mean_squared_error(y_valid, y_pred)
print(f'Mean Squared Error: {mse}')
```
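`mean_squared_error` averages the squared residuals between targets and predictions. Because the target is measured in units of $100,000, its square root (RMSE) is often easier to read, since it is back in the target's own units. A minimal NumPy sketch of both metrics; the names `mse_np` and `rmse_np` are our own, and the arrays are toy values, not model output:

```python
import numpy as np


def mse_np(y_true, y_pred):
    """Mean squared error: mean((y_true - y_pred) ** 2)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean((y_true - y_pred) ** 2))


def rmse_np(y_true, y_pred):
    """Root mean squared error, in the same units as the target."""
    return float(np.sqrt(mse_np(y_true, y_pred)))


# Toy values for illustration only
y_true_toy = [3.0, 2.0, 4.0]
y_hat_toy = [2.5, 2.0, 5.0]
print(mse_np(y_true_toy, y_hat_toy))
print(rmse_np(y_true_toy, y_hat_toy))
```

With the notebook's variables, `rmse_np(y_valid, y_pred)` should equal `np.sqrt(mean_squared_error(y_valid, y_pred))`.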