In [10]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification, make_regression
from sklearn.metrics import accuracy_score, root_mean_squared_error
from pytorch_tabnet.tab_model import TabNetClassifier, TabNetRegressor
import xgboost as xgb

In [11]:
def compare_classification(X_train, y_train, X_valid, y_valid, X_test, y_test, num_classes):
    """Fit a default TabNet and a default XGBoost classifier and report test accuracy.

    Both models are constructed without tuned hyperparameters (TabNet with
    ``verbose=0``), fitted on the training split with the validation split
    supplied as ``eval_set``, and scored with ``accuracy_score`` on the test
    split. The scores are printed and also returned.

    Parameters
    ----------
    X_train, y_train
        Training features and labels.
    X_valid, y_valid
        Validation split handed to each model's ``eval_set``.
    X_test, y_test
        Held-out split used for the reported accuracy.
    num_classes : int
        Only selects the printed heading: ``2`` gives "Binary Classification",
        any other value gives "Multiclass Classification".

    Returns
    -------
    dict
        ``{"tabnet": <accuracy>, "xgboost": <accuracy>}`` so callers can
        collect results instead of reading them off the printed report.
    """
    # The models are built the same way for binary and multiclass problems;
    # only the heading depends on num_classes.
    task_label = "Binary Classification" if num_classes == 2 else "Multiclass Classification"

    tabnet_model = TabNetClassifier(verbose=0)
    xgboost_model = xgb.XGBClassifier()

    tabnet_model.fit(
        X_train, y_train,
        eval_set=[(X_valid, y_valid)],
    )
    xgboost_model.fit(
        X_train, y_train,
        eval_set=[(X_valid, y_valid)],
        verbose=False,
    )

    pred_tabnet = tabnet_model.predict(X_test)
    pred_xgboost = xgboost_model.predict(X_test)

    accuracy_tabnet = float(accuracy_score(y_test, pred_tabnet))
    accuracy_xgboost = float(accuracy_score(y_test, pred_xgboost))

    print(f"\n === {task_label} ===\n")
    print(f"TabNet Accuracy with vanilla parameters: {accuracy_tabnet}\n")
    print(f"XGBoost Accuracy with vanilla parameters: {accuracy_xgboost}\n")

    return {"tabnet": accuracy_tabnet, "xgboost": accuracy_xgboost}

In [12]:
def compare_regression(X_train, y_train, X_valid, y_valid, X_test, y_test):
    """Fit a default TabNet and a default XGBoost regressor and report test RMSE.

    Both models are constructed without tuned hyperparameters (TabNet with
    ``verbose=0``), fitted on the training split with the validation split
    supplied as ``eval_set``, and scored with ``root_mean_squared_error`` on
    the test split. The scores are printed and also returned.

    Parameters
    ----------
    X_train, y_train
        Training features and targets. Targets may be any array-like
        (ndarray, list, pandas Series); they are converted with ``np.asarray``.
    X_valid, y_valid
        Validation split handed to each model's ``eval_set``.
    X_test, y_test
        Held-out split used for the reported RMSE.

    Returns
    -------
    dict
        ``{"tabnet": <rmse>, "xgboost": <rmse>}`` so callers can collect
        results instead of reading them off the printed report.
    """
    tabnet_model = TabNetRegressor(verbose=0)
    xgboost_model = xgb.XGBRegressor()

    # np.asarray lets lists / Series through; a bare ``.reshape`` only works on ndarrays.
    y_train_arr = np.asarray(y_train)
    y_valid_arr = np.asarray(y_valid)

    # TabNetRegressor takes targets as a 2-D (n_samples, n_targets) array, so
    # only its copy is reshaped into a column; XGBoost gets the targets as given.
    tabnet_model.fit(
        X_train, y_train_arr.reshape(-1, 1),
        eval_set=[(X_valid, y_valid_arr.reshape(-1, 1))],
    )
    xgboost_model.fit(
        X_train, y_train_arr,
        eval_set=[(X_valid, y_valid_arr)],
        verbose=False,
    )

    # TabNet predictions come back as a column; flatten so both models are
    # scored against y_test with the same 1-D shape.
    pred_tabnet = np.asarray(tabnet_model.predict(X_test)).ravel()
    pred_xgboost = xgboost_model.predict(X_test)

    rmse_tabnet = float(root_mean_squared_error(y_test, pred_tabnet))
    rmse_xgboost = float(root_mean_squared_error(y_test, pred_xgboost))

    print("\n=== Regression ===\n")
    print(f"TabNet RMSE with vanilla parameters: {rmse_tabnet}\n")
    print(f"XGBoost RMSE with vanilla parameters: {rmse_xgboost}\n")

    return {"tabnet": rmse_tabnet, "xgboost": rmse_xgboost}


In [13]:
# Synthetic data for BINARY CLASSIFICATION
X_binary, y_binary = make_classification(
    n_samples=10000,
    n_features=100,
    n_informative=8,
    n_classes=2,
    random_state=42,
)
# Hold out 30% of the rows, then split that hold-out in half: validation and test.
X_train_bin, X_temp_bin, y_train_bin, y_temp_bin = train_test_split(
    X_binary, y_binary, test_size=0.3, random_state=42
)
X_valid_bin, X_test_bin, y_valid_bin, y_test_bin = train_test_split(
    X_temp_bin, y_temp_bin, test_size=0.5, random_state=42
)

In [14]:
# Synthetic data for MULTICLASS CLASSIFICATION (three classes, one cluster per class)
X_multi, y_multi = make_classification(
    n_samples=10000,
    n_features=100,
    n_informative=8,
    n_classes=3,
    n_clusters_per_class=1,
    random_state=42,
)
# Hold out 30% of the rows, then split that hold-out in half: validation and test.
X_train_multi, X_temp_multi, y_train_multi, y_temp_multi = train_test_split(
    X_multi, y_multi, test_size=0.3, random_state=42
)
X_valid_multi, X_test_multi, y_valid_multi, y_test_multi = train_test_split(
    X_temp_multi, y_temp_multi, test_size=0.5, random_state=42
)

In [15]:
# Synthetic data for REGRESSION
X_reg, y_reg = make_regression(
    n_samples=10000,
    n_features=100,
    n_informative=8,
    noise=0.1,
    random_state=42,
)
# Hold out 30% of the rows, then split that hold-out in half: validation and test.
X_train_reg, X_temp_reg, y_train_reg, y_temp_reg = train_test_split(
    X_reg, y_reg, test_size=0.3, random_state=42
)
X_valid_reg, X_test_reg, y_valid_reg, y_test_reg = train_test_split(
    X_temp_reg, y_temp_reg, test_size=0.5, random_state=42
)

In [16]:
# Binary classification: TabNet vs. XGBoost on the two-class synthetic splits
compare_classification(
    X_train=X_train_bin,
    y_train=y_train_bin,
    X_valid=X_valid_bin,
    y_valid=y_valid_bin,
    X_test=X_test_bin,
    y_test=y_test_bin,
    num_classes=2,
)


Early stopping occurred at epoch 79 with best_epoch = 69 and best_val_0_auc = 0.95724





 === Binary Classification ===

TabNet Accuracy with vanilla parameters: 0.896

XGBoost Accuracy with vanilla parameters: 0.912



In [17]:
# Multiclass classification: TabNet vs. XGBoost on the three-class synthetic splits
compare_classification(
    X_train=X_train_multi,
    y_train=y_train_multi,
    X_valid=X_valid_multi,
    y_valid=y_valid_multi,
    X_test=X_test_multi,
    y_test=y_test_multi,
    num_classes=3,
)


Early stopping occurred at epoch 90 with best_epoch = 80 and best_val_0_accuracy = 0.91733





 === Multiclass Classification ===

TabNet Accuracy with vanilla parameters: 0.928

XGBoost Accuracy with vanilla parameters: 0.944



In [18]:
# Regression: TabNet vs. XGBoost on the synthetic regression splits
compare_regression(
    X_train=X_train_reg,
    y_train=y_train_reg,
    X_valid=X_valid_reg,
    y_valid=y_valid_reg,
    X_test=X_test_reg,
    y_test=y_test_reg,
)


Early stopping occurred at epoch 69 with best_epoch = 59 and best_val_0_mse = 101.57551





=== Regression ===

TabNet RMSE with vanilla parameters: 11.138994901732492

XGBoost RMSE with vanilla parameters: 43.010909116965536

