# Build Customized Estimator for Classification Task

In [1]:
import sys
# Add the parent directory to the module search path. Presumably this is what
# lets the `tabnet` imports in the following cells resolve without installing
# the package -- confirm against the repository layout.
sys.path.append('../')

In [2]:
import torch
from tabnet.utils.logger import init_logger
from tabnet.estimator import CustomizedEstimator
from sklearn.datasets import load_breast_cancer

In [3]:
# Settings handed to `init_logger` below: the logger's name, the directory
# argument, and the logging level.
logger_name = 'Customized Estimator DEMO'
logger_dir = 'logs'
level = 'INFO'

logger = init_logger(
    logger_dir=logger_dir,
    logger_name=logger_name,
    level=level,
)

## Step 1
### Init a `CustomizedEstimator` object

In [4]:
# Enable CUDA only when a GPU is actually present, so the notebook also runs
# on CPU-only machines instead of failing on a hard-coded `is_cuda=True`.
my_model = CustomizedEstimator(
    input_dims=30,
    output_dims=[1],
    logger=logger,
    is_cuda=torch.cuda.is_available(),
)

## Step 2
### Define your customized `loss function` and `post-processor`

***Import base classes***

In [5]:
from tabnet.base import Loss, BaseTabNet, BasePostProcessor

***Define customized loss function***

In [6]:
class MyLoss(Loss):
    """Binary cross-entropy loss evaluated on raw logits.

    Delegates to ``torch.nn.BCEWithLogitsLoss``, so ``preds`` are expected to
    be un-activated scores (no sigmoid applied beforehand).
    """

    def __init__(self):
        super(MyLoss, self).__init__()
        self._loss_fn = torch.nn.BCEWithLogitsLoss()

    def score_func(self, preds, targets):
        """Return the BCE-with-logits loss between ``preds`` and ``targets``.

        Args:
            preds: tensor of logits.
            targets: tensor of 0/1 labels; it is moved to the device of
                ``preds`` and cast to the dtype of ``preds``.

        Returns:
            The loss tensor produced by ``torch.nn.BCEWithLogitsLoss``.
        """
        # BCEWithLogitsLoss requires floating-point targets that live on the
        # same device as the logits. Integer class labels would otherwise
        # raise a dtype error, so align both device and dtype here.
        targets = targets.to(device=preds.device, dtype=preds.dtype)
        return self._loss_fn(preds, targets)

***Define customized post-processor***

In [7]:
class MyPostProcessor(BasePostProcessor):
    """Post-processor for a single binary task.

    Applies a sigmoid to the first element of the input and thresholds the
    result at 0.5 to obtain 0/1 labels.
    """

    def __init__(self, is_cuda):
        # One task with one output unit; `is_cuda` is forwarded unchanged.
        super().__init__(num_tasks=1, is_cuda=is_cuda, output_dims=[1])

    def _build(self, num_tasks, output_dims):
        """Append the sigmoid processor; both arguments are unused here."""
        sigmoid = torch.nn.Sigmoid()
        self._processors.append(sigmoid)

    def forward(self, x, is_return_proba=False):
        """Convert ``x[0]`` into labels, optionally with probabilities.

        Returns ``[labels]`` by default, or the pair
        ``([labels], [probabilities])`` when ``is_return_proba`` is truthy.
        """
        probabilities = self._processors[0](x[0])
        # Boolean mask multiplied by 1 yields integer 0/1 labels.
        labels = (probabilities > 0.5) * 1

        if not is_return_proba:
            return [labels]
        return [labels], [probabilities]

## Step 3
### Register customized `objects`.

In [8]:
# Attach the custom loss and post-processor. The post-processor uses CUDA only
# when a GPU is available, so this cell also works on CPU-only machines.
(
    my_model
    .register_loss(MyLoss())
    .register_postprocessor(MyPostProcessor(is_cuda=torch.cuda.is_available()))
)

CustomizedEstimator(input_dims=30, output_dims=[1])

## Step 4
### Train your own model!

***Load data***

In [9]:
# Load the breast-cancer dataset as a (features, targets) pair of arrays.
X, y = load_breast_cancer(return_X_y=True)

# Show both shapes, one per line.
print(X.shape, y.shape, sep='\n')

(569, 30)
(569,)


***Build network architecture***

In [10]:
# NOTE(review): path=None presumably means "build a fresh network" rather than
# restoring one from disk -- confirm against CustomizedEstimator.build.
my_model.build(path=None)



CustomizedEstimator(input_dims=30, output_dims=[1])

In [11]:
# Log and display the module structure of the network built in the previous cell.
my_model.show_model()

[2021-02-23 21:48:37,368][INFO][TabNet] Show model architecture.


InferenceModel(
  (embedding_encoder): EmbeddingEncoder()
  (tabnet_encoder): TabNetEncoder(
    (input_bn): BatchNorm1d(30, eps=1e-05, momentum=0.03, affine=True, track_running_stats=True)
    (input_splitter): FeatureTransformer(
      (shared_block): FeatureBlock(
        (shared_layers): ModuleList(
          (0): Linear(in_features=30, out_features=32, bias=False)
          (1): Linear(in_features=16, out_features=32, bias=False)
        )
        (glu_blocks): ModuleList(
          (0): GLUBlock(
            (fc): Linear(in_features=30, out_features=32, bias=False)
            (gbn): GhostBatchNorm(
              (bn): BatchNorm1d(32, eps=1e-05, momentum=0.03, affine=True, track_running_stats=True)
            )
          )
          (1): GLUBlock(
            (fc): Linear(in_features=16, out_features=32, bias=False)
            (gbn): GhostBatchNorm(
              (bn): BatchNorm1d(32, eps=1e-05, momentum=0.03, affine=True, track_running_stats=True)
            )
          )
   

***Start training!***

In [12]:
from torch.optim import Adam, lr_scheduler

# Keyword arguments unpacked into `my_model.fit(...)`: batch size, number of
# epochs, the optimizer class with its parameters, and the learning-rate
# scheduler classes with their parameters.
training_params = dict(
    batch_size=512,
    max_epochs=100,
    optimizer=Adam,
    optimizer_params=dict(lr=0.1),
    schedulers=[lr_scheduler.ExponentialLR],
    scheduler_params=dict(gamma=0.99),
)

In [13]:
# `y` is one-dimensional (n_samples,); reshape it into a single column so each
# row carries one target value -- presumably the layout output_dims=[1]
# expects (confirm against CustomizedEstimator.fit).
my_model.fit(X, y.reshape(-1, 1), **training_params)

[2021-02-23 21:48:37,401][INFO][TabNet] Convert to inference model.
[2021-02-23 21:48:37,403][INFO][TabNet] start training.
[2021-02-23 21:48:37,404][INFO][TabNet] ******************** epoch : 0 ********************
[2021-02-23 21:48:42,238][INFO][TabNet] -------------------- train info --------------------
[2021-02-23 21:48:42,239][INFO][TabNet] total_loss : 0.8457855582237244
[2021-02-23 21:48:42,239][INFO][TabNet] task_loss : 0.844632625579834
[2021-02-23 21:48:42,240][INFO][TabNet] mask_loss : -1.152915596961975
[2021-02-23 21:48:42,241][INFO][TabNet] time_cost : 0.945627
[2021-02-23 21:48:42,241][INFO][TabNet] ******************** epoch : 1 ********************
[2021-02-23 21:48:46,175][INFO][TabNet] -------------------- train info --------------------
[2021-02-23 21:48:46,176][INFO][TabNet] total_loss : 0.62274169921875
[2021-02-23 21:48:46,177][INFO][TabNet] task_loss : 0.6216962337493896
[2021-02-23 21:48:46,178][INFO][TabNet] mask_loss : -1.045465111732483
[2021-02-23 21:48:46

[2021-02-23 21:49:49,172][INFO][TabNet] task_loss : 0.11796693503856659
[2021-02-23 21:49:49,172][INFO][TabNet] mask_loss : -0.5092649459838867
[2021-02-23 21:49:49,173][INFO][TabNet] time_cost : 0.053834
[2021-02-23 21:49:49,174][INFO][TabNet] ******************** epoch : 18 ********************
[2021-02-23 21:49:53,108][INFO][TabNet] -------------------- train info --------------------
[2021-02-23 21:49:53,109][INFO][TabNet] total_loss : 0.1167355552315712
[2021-02-23 21:49:53,110][INFO][TabNet] task_loss : 0.11623270064592361
[2021-02-23 21:49:53,111][INFO][TabNet] mask_loss : -0.5028560161590576
[2021-02-23 21:49:53,111][INFO][TabNet] time_cost : 0.055845
[2021-02-23 21:49:53,112][INFO][TabNet] ******************** epoch : 19 ********************
[2021-02-23 21:49:57,034][INFO][TabNet] -------------------- train info --------------------
[2021-02-23 21:49:57,035][INFO][TabNet] total_loss : 0.11471997201442719
[2021-02-23 21:49:57,036][INFO][TabNet] task_loss : 0.11424538493156433
[

[2021-02-23 21:51:00,078][INFO][TabNet] total_loss : 0.0869387537240982
[2021-02-23 21:51:00,079][INFO][TabNet] task_loss : 0.0866905003786087
[2021-02-23 21:51:00,080][INFO][TabNet] mask_loss : -0.24825194478034973
[2021-02-23 21:51:00,080][INFO][TabNet] time_cost : 0.053805
[2021-02-23 21:51:00,081][INFO][TabNet] ******************** epoch : 36 ********************
[2021-02-23 21:51:04,013][INFO][TabNet] -------------------- train info --------------------
[2021-02-23 21:51:04,014][INFO][TabNet] total_loss : 0.0804620161652565
[2021-02-23 21:51:04,015][INFO][TabNet] task_loss : 0.08019043505191803
[2021-02-23 21:51:04,016][INFO][TabNet] mask_loss : -0.27158308029174805
[2021-02-23 21:51:04,016][INFO][TabNet] time_cost : 0.058864
[2021-02-23 21:51:04,018][INFO][TabNet] ******************** epoch : 37 ********************
[2021-02-23 21:51:07,951][INFO][TabNet] -------------------- train info --------------------
[2021-02-23 21:51:07,952][INFO][TabNet] total_loss : 0.0815630778670311
[

[2021-02-23 21:52:10,970][INFO][TabNet] -------------------- train info --------------------
[2021-02-23 21:52:10,971][INFO][TabNet] total_loss : 0.06692496687173843
[2021-02-23 21:52:10,972][INFO][TabNet] task_loss : 0.06660980731248856
[2021-02-23 21:52:10,973][INFO][TabNet] mask_loss : -0.3151610493659973
[2021-02-23 21:52:10,974][INFO][TabNet] time_cost : 0.060703
[2021-02-23 21:52:10,975][INFO][TabNet] ******************** epoch : 54 ********************
[2021-02-23 21:52:14,907][INFO][TabNet] -------------------- train info --------------------
[2021-02-23 21:52:14,908][INFO][TabNet] total_loss : 0.06917443871498108
[2021-02-23 21:52:14,909][INFO][TabNet] task_loss : 0.06884706020355225
[2021-02-23 21:52:14,909][INFO][TabNet] mask_loss : -0.32737916707992554
[2021-02-23 21:52:14,910][INFO][TabNet] time_cost : 0.053853
[2021-02-23 21:52:14,910][INFO][TabNet] ******************** epoch : 55 ********************
[2021-02-23 21:52:18,826][INFO][TabNet] -------------------- train info

[2021-02-23 21:53:17,751][INFO][TabNet] time_cost : 0.053881
[2021-02-23 21:53:17,752][INFO][TabNet] ******************** epoch : 71 ********************
[2021-02-23 21:53:21,668][INFO][TabNet] -------------------- train info --------------------
[2021-02-23 21:53:21,669][INFO][TabNet] total_loss : 0.05028965696692467
[2021-02-23 21:53:21,669][INFO][TabNet] task_loss : 0.050007425248622894
[2021-02-23 21:53:21,670][INFO][TabNet] mask_loss : -0.2822331190109253
[2021-02-23 21:53:21,670][INFO][TabNet] time_cost : 0.053858
[2021-02-23 21:53:21,671][INFO][TabNet] ******************** epoch : 72 ********************
[2021-02-23 21:53:25,602][INFO][TabNet] -------------------- train info --------------------
[2021-02-23 21:53:25,603][INFO][TabNet] total_loss : 0.04899831861257553
[2021-02-23 21:53:25,604][INFO][TabNet] task_loss : 0.04872569441795349
[2021-02-23 21:53:25,605][INFO][TabNet] mask_loss : -0.272623747587204
[2021-02-23 21:53:25,606][INFO][TabNet] time_cost : 0.05448
[2021-02-23 

[2021-02-23 21:54:28,648][INFO][TabNet] task_loss : 0.056136827915906906
[2021-02-23 21:54:28,648][INFO][TabNet] mask_loss : -0.2769050896167755
[2021-02-23 21:54:28,649][INFO][TabNet] time_cost : 0.054883
[2021-02-23 21:54:28,649][INFO][TabNet] ******************** epoch : 89 ********************
[2021-02-23 21:54:32,578][INFO][TabNet] -------------------- train info --------------------
[2021-02-23 21:54:32,579][INFO][TabNet] total_loss : 0.053862135857343674
[2021-02-23 21:54:32,580][INFO][TabNet] task_loss : 0.05357220768928528
[2021-02-23 21:54:32,580][INFO][TabNet] mask_loss : -0.2899284362792969
[2021-02-23 21:54:32,581][INFO][TabNet] time_cost : 0.054875
[2021-02-23 21:54:32,581][INFO][TabNet] ******************** epoch : 90 ********************
[2021-02-23 21:54:36,509][INFO][TabNet] -------------------- train info --------------------
[2021-02-23 21:54:36,510][INFO][TabNet] total_loss : 0.041292257606983185
[2021-02-23 21:54:36,511][INFO][TabNet] task_loss : 0.041006498038768

CustomizedEstimator(input_dims=30, output_dims=[1])

In [14]:
from sklearn.metrics import accuracy_score

# NOTE: predictions are made on the same samples used for fitting, so this is
# training accuracy, not an estimate of generalization performance.
y_pred = my_model.predict(X)

# accuracy_score's signature is (y_true, y_pred): pass the ground truth first.
# The first entry of the predictions is flattened so both arguments are
# one-dimensional.
accuracy_score(y, y_pred[0].reshape(-1))

0.9595782073813708

In [15]:
# Inspect the shape of the first entry returned by `predict`.
y_pred[0].shape

(569, 1)