In [2]:
import torch
import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier
import pandas as pd
import kagglehub
from kagglehub import KaggleDatasetAdapter
from sklearn.model_selection import train_test_split,cross_validate
from sklearn.metrics import roc_auc_score

In [3]:
df = kagglehub.load_dataset(adapter=KaggleDatasetAdapter.PANDAS,handle="fedesoriano/stroke-prediction-dataset",path="healthcare-dataset-stroke-data.csv")

In [4]:
df.head()

Unnamed: 0,id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
0,9046,Male,67.0,0,1,Yes,Private,Urban,228.69,36.6,formerly smoked,1
1,51676,Female,61.0,0,0,Yes,Self-employed,Rural,202.21,,never smoked,1
2,31112,Male,80.0,0,1,Yes,Private,Rural,105.92,32.5,never smoked,1
3,60182,Female,49.0,0,0,Yes,Private,Urban,171.23,34.4,smokes,1
4,1665,Female,79.0,1,0,Yes,Self-employed,Rural,174.12,24.0,never smoked,1


In [5]:
display(df.dtypes)
print(f"Shape: {df.shape}")


id                     int64
gender                object
age                  float64
hypertension           int64
heart_disease          int64
ever_married          object
work_type             object
Residence_type        object
avg_glucose_level    float64
bmi                  float64
smoking_status        object
stroke                 int64
dtype: object

Shape: (5110, 12)


In [6]:
df.smoking_status.unique()

array(['formerly smoked', 'never smoked', 'smokes', 'Unknown'],
      dtype=object)

In [7]:
#Binary Categorical Features to encode
categorical_binary = ['ever_married', 'Residence_type']
categorical_multi = ['work_type', 'gender','smoking_status']
numerical_double = ['age', 'avg_glucose_level', 'bmi',]

#create numerical encodings for binary categorical features
df['ever_married'] = df['ever_married'].apply(lambda x: 1 if x == 'Yes' else 0)
df['Residence_type'] = df['Residence_type'].apply(lambda x: 1 if x == 'Urban' else 0)

#one hot encode multi categorical features
df = pd.concat([df, pd.get_dummies(data= df[categorical_multi],dtype=int)], axis=1)

In [8]:
df = df.drop(columns=categorical_multi)

In [9]:
#check for missing values
df.isnull().sum()

id                                  0
age                                 0
hypertension                        0
heart_disease                       0
ever_married                        0
Residence_type                      0
avg_glucose_level                   0
bmi                               201
stroke                              0
work_type_Govt_job                  0
work_type_Never_worked              0
work_type_Private                   0
work_type_Self-employed             0
work_type_children                  0
gender_Female                       0
gender_Male                         0
gender_Other                        0
smoking_status_Unknown              0
smoking_status_formerly smoked      0
smoking_status_never smoked         0
smoking_status_smokes               0
dtype: int64

In [10]:
#remove rows whith no bmi data
df = df.dropna()
X = df.drop(columns='stroke')
y = df['stroke']

#split data into train, test, and validation sets stratified by stroke
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42, stratify=y_train)


In [11]:
#split data into train, test, and validation sets stratified by stroke
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42, stratify=y_train)
X_train.shape

(3141, 20)

In [12]:
tabnet_model = TabNetClassifier(optimizer_fn=torch.optim.Adamax, optimizer_params=dict(lr=1e-3),n_d=16,n_a=16,n_steps=3) 

tabnet_model.fit(X_train=X_train.values, y_train=y_train.values,
                  eval_set=[(X_val.values, y_val.values)],
                  eval_metric=['logloss'], num_workers=8, max_epochs=15,virtual_batch_size=64,
                  batch_size=256)



epoch 0  | loss: 0.716   | val_0_logloss: 0.42658 |  0:01:16s
epoch 1  | loss: 0.57266 | val_0_logloss: 0.37631 |  0:02:32s
epoch 2  | loss: 0.47032 | val_0_logloss: 0.35578 |  0:03:48s
epoch 3  | loss: 0.40502 | val_0_logloss: 0.3336  |  0:05:05s
epoch 4  | loss: 0.34663 | val_0_logloss: 0.30335 |  0:06:20s
epoch 5  | loss: 0.31236 | val_0_logloss: 0.29246 |  0:07:37s
epoch 6  | loss: 0.28363 | val_0_logloss: 0.28225 |  0:08:53s
epoch 7  | loss: 0.26622 | val_0_logloss: 0.27634 |  0:10:09s
epoch 8  | loss: 0.24951 | val_0_logloss: 0.27397 |  0:11:25s
epoch 9  | loss: 0.23705 | val_0_logloss: 0.25783 |  0:12:41s
epoch 10 | loss: 0.22569 | val_0_logloss: 0.25077 |  0:13:56s
epoch 11 | loss: 0.21983 | val_0_logloss: 0.23779 |  0:15:12s
epoch 12 | loss: 0.21503 | val_0_logloss: 0.22773 |  0:16:28s
epoch 13 | loss: 0.20335 | val_0_logloss: 0.22204 |  0:17:43s
epoch 14 | loss: 0.20604 | val_0_logloss: 0.21632 |  0:18:59s
Stop training because you reached max_epochs = 15 with best_epoch = 14



In [13]:
# Convert model to PyTorch format
pytorch_model = tabnet_model.network
pytorch_model.eval()

TabNet(
  (embedder): EmbeddingGenerator()
  (tabnet): TabNetNoEmbeddings(
    (initial_bn): BatchNorm1d(20, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (encoder): TabNetEncoder(
      (initial_bn): BatchNorm1d(20, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      (initial_splitter): FeatTransformer(
        (shared): GLU_Block(
          (shared_layers): ModuleList(
            (0): Linear(in_features=20, out_features=64, bias=False)
            (1): Linear(in_features=32, out_features=64, bias=False)
          )
          (glu_layers): ModuleList(
            (0): GLU_Layer(
              (fc): Linear(in_features=20, out_features=64, bias=False)
              (bn): GBN(
                (bn): BatchNorm1d(64, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)
              )
            )
            (1): GLU_Layer(
              (fc): Linear(in_features=32, out_features=64, bias=False)
              (bn): GBN(
               

In [14]:
dummy_input = torch.randn(3141, 20).float()  # 3141 sample, 20 features

# Export model to ONNX
torch.onnx.export(
    pytorch_model, dummy_input, "tabnet.onnx",
    export_params=True,    # Store the trained parameters
    opset_version=11,      # Use ONNX opset 11+
    do_constant_folding=True,  # Optimize graph by folding constants
    input_names=["input"], output_names=["output"]
)

print("TabNet model has been converted to ONNX format!")

  chunks = x.chunk(int(np.ceil(x.shape[0] / self.virtual_batch_size)), 0)


TabNet model has been converted to ONNX format!


In [27]:
#Use onnx model to make predictions
import onnxruntime as ort


# Load the ONNX model
onnx_model = ort.InferenceSession("tabnet.onnx")

# Get the input name for the ONNX model
input_name = onnx_model.get_inputs()[0].name

# Get the output name for the ONNX model
output_name = onnx_model.get_outputs()[0].name

# Convert the input data to ONNX format
input_data = X_train.values

# Make predictions using ONNX model
predictions_onnx = onnx_model.run([output_name], {input_name: input_data.astype(np.float32)})[0]

#show stats
print(f"ONNX model predictions: {predictions_onnx}")

ONNX model predictions: [[ 3.3532372  -0.7668635 ]
 [ 2.6243238  -0.8884481 ]
 [ 0.9245977  -1.1461244 ]
 ...
 [ 1.3486643  -0.48778722]
 [ 0.7718145  -0.98826957]
 [ 0.88292843 -0.64041334]]
