In [1]:
import torch
from sklearn.datasets import fetch_openml

from neural_blueprints.utils import Trainer, accuracy
from neural_blueprints.config.architectures import TabularBERTConfig
from neural_blueprints.config.utils import TrainerConfig
from neural_blueprints.architectures import TabularBERT
from neural_blueprints.datasets import MaskedTabularDataset
from neural_blueprints.preprocess import TabularPreprocessor

import logging
# Configure the root logger: each record is printed as
# "timestamp - logger name - level - message".
logging.basicConfig(
    level=logging.DEBUG,  # most verbose level; switch to logging.INFO for less detail
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

In [2]:
# Fetch the Adult Income dataset (OpenML "adult", version 2) as pandas objects.
adult = fetch_openml("adult", version=2, as_frame=True)

# Features and target, kept under their conventional names.
X, y = adult.data, adult.target

# One combined table: a copy of the features with the target appended as a
# column named after the target series.
data = X.assign(**{y.name: y})

In [3]:
# Preprocess the combined table. `run` returns the processed table plus the
# names of the columns it identified as discrete and as continuous.
# NOTE(review): this rebinds `data`, so re-running this cell without re-running
# the loading cell feeds already-processed data back in — confirm that is safe.
preprocessor = TabularPreprocessor()
data, discrete_features, continuous_features = preprocessor.run(data)

# Create the dataset with a masking probability of 0.35.
dataset = MaskedTabularDataset(
    data,
    discrete_features,
    continuous_features,
    mask_prob=0.35
)

# 90/10 train/validation split. A dedicated, seeded generator makes the split
# membership identical on every Restart & Run All, so validation numbers are
# comparable between runs.
SPLIT_SEED = 42
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

2025-12-22 00:13:32,707 - neural_blueprints.preprocess.tabular_preprocess - INFO - Identified 10 discrete features: ['workclass', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country', 'class']
2025-12-22 00:13:32,707 - neural_blueprints.preprocess.tabular_preprocess - INFO - Identified 5 continuous features: ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']


In [6]:
# Hyper-parameters for the model, gathered in one place; the per-column
# cardinalities are taken from the dataset built above.
bert_settings = {
    "cardinalities": dataset.cardinalities,
    "latent_dim": 64,
    "encoder_layers": 8,
    "dropout_p": 0.1,
    "normalization": "batchnorm1d",
    "activation": "gelu",
    "final_activation": None,
}
bert_config = TabularBERTConfig(**bert_settings)

# Instantiate the model from the config and display its structure.
model = TabularBERT(bert_config)
model.blueprint()

TabularBERT(
  (input_projection): TabularInputProjection(
    (input_projections): ModuleList(
      (0): FeedForwardNetwork(
        (network): Sequential(
          (0): DenseLayer(
            (layer): Sequential(
              (0): Linear(in_features=1, out_features=256, bias=True)
              (1): NormalizationLayer(
                (network): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (2): GELU(approximate='none')
              (3): DropoutLayer(
                (dropout): Dropout(p=0.0, inplace=False)
              )
            )
          )
          (1): DenseLayer(
            (layer): Sequential(
              (0): Linear(in_features=256, out_features=128, bias=True)
              (1): NormalizationLayer(
                (network): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (2): GELU(approximate='none')
              (3): DropoutLayer(
 

TabularBERTConfig(cardinalities=[1, 9, 1, 16, 16, 7, 15, 6, 5, 2, 1, 1, 1, 42, 2], latent_dim=64, encoder_layers=8, dropout_p=0.1, normalization='batchnorm1d', activation='gelu', final_activation=None)

In [7]:
# Training settings: loss, optimizer, early stopping, and where the weights
# are written.
trainer_config = TrainerConfig(
    criterion='mixed_type_reconstruction_loss',
    optimizer='adam',
    early_stopping_patience=3,
    learning_rate=1e-3,
    weight_decay=1e-5,
    save_weights_path="./models/bert_adult.pth",
    batch_size=128,
)

# Bind the model to the settings, then train for 5 epochs on the train split
# while validating on the held-out split.
trainer = Trainer(model=model, config=trainer_config)
trainer.train(train_dataset, val_dataset, epochs=5, visualize=True)

2025-12-22 00:13:44,800 - neural_blueprints.utils.trainer - INFO - Trainer initialized on device: cpu


Directory ./models already exists. Existing weights are overwritten.


Training Epochs:  20%|██        | 1/5 [00:34<02:19, 34.94s/epoch]

Epoch 1/5, Training Loss: 33.5590, Validation Loss: 25.5916


Training Epochs:  40%|████      | 2/5 [01:09<01:44, 34.91s/epoch]

Epoch 2/5, Training Loss: 25.1405, Validation Loss: 24.6938


Training Epochs:  60%|██████    | 3/5 [01:44<01:09, 34.95s/epoch]

Epoch 3/5, Training Loss: 24.5489, Validation Loss: 24.2016


Training Epochs:  80%|████████  | 4/5 [02:19<00:34, 34.98s/epoch]

Epoch 4/5, Training Loss: 24.2124, Validation Loss: 24.0425


Training Epochs: 100%|██████████| 5/5 [02:54<00:00, 34.97s/epoch]
2025-12-22 00:16:39,649 - neural_blueprints.utils.trainer - INFO - Training completed in 174.85 seconds.
2025-12-22 00:16:39,649 - neural_blueprints.utils.trainer - INFO - Best validation loss: 2.3718e+01


Epoch 5/5, Training Loss: 23.9635, Validation Loss: 23.7179


In [None]:
# Per-column evaluation of the model's reconstructions on the validation split.

# Index the validation split a single time so that X, y and mask are
# guaranteed to come from the same access (and the split is materialised once
# rather than three times).
val_batch = val_dataset[:]
X = val_batch[0]
y = val_batch[1]
mask = val_batch[2]

# Switch to inference behaviour before predicting: the model contains
# BatchNorm and Dropout layers, which behave differently in training mode.
model.eval()
with torch.no_grad():
    y_pred = model(x=X)

dis_accuracy = 0
cont_accuracy = 0
for column_idx, column_name in enumerate(data.columns):
    print(f"\nFeature Column {column_name}:")

    # Only rows selected by this column's mask are scored
    # (presumably the positions that were masked in the input — TODO confirm).
    feature_mask = mask[:, column_idx]                    # shape: (batch_size,)
    column_output = y_pred[column_idx][feature_mask]      # shape: (n_selected, num_outputs)

    if column_output.size(1) > 1:
        # Several outputs: take the highest-scoring class index. A softmax
        # beforehand is unnecessary because it does not change the argmax.
        predicted_attributes = column_output.argmax(dim=-1).cpu().numpy()
    else:
        # Single output: use the predicted value directly.
        predicted_attributes = column_output.squeeze(-1).cpu().numpy()
    targets = y[:, column_idx][feature_mask].cpu().numpy()

    print("Predicted attribute values:", predicted_attributes[:5])
    print("True attribute values:", targets[:5])

    accuracy_value = accuracy(torch.tensor(predicted_attributes), torch.tensor(targets))
    print(f"Accuracy: {accuracy_value:.4f}")

    # Accumulate separately for discrete and continuous columns.
    if column_name in discrete_features:
        dis_accuracy += accuracy_value
    else:
        cont_accuracy += accuracy_value

# Averages per feature kind (0 when a kind has no columns) and over all columns.
avg_dis_accuracy = dis_accuracy / len(discrete_features) if len(discrete_features) > 0 else 0
avg_cont_accuracy = cont_accuracy / len(continuous_features) if len(continuous_features) > 0 else 0
print(f"\nAverage Discrete Accuracy: {avg_dis_accuracy:.4f}")
print(f"Average Continuous Accuracy: {avg_cont_accuracy:.4f}")
avg_accuracy = (dis_accuracy + cont_accuracy) / len(data.columns)
print(f"Overall Average Accuracy: {avg_accuracy:.4f}")


Feature Column age:
Predicted attribute values: [0.31313652 0.18980552 0.04015766 0.579761   0.15299653]
True attribute values: [0.4520548  0.08219178 0.05479452 0.60273975 0.15068494]
Accuracy: 0.2860

Feature Column workclass:
Predicted attribute values: [4 4 4 4 4]
True attribute values: [6. 4. 4. 4. 4.]
Accuracy: 0.6898

Feature Column fnlwgt:
Predicted attribute values: [0.18872398 0.17175908 0.16113535 0.17457661 0.14949259]
True attribute values: [0.31154814 0.08694384 0.11687589 0.01234275 0.01189962]
Accuracy: 0.3201

Feature Column education:
Predicted attribute values: [ 1 15 16 12  2]
True attribute values: [ 1. 15. 16. 12.  2.]
Accuracy: 0.8500

Feature Column education-num:
Predicted attribute values: [16 16 16 16 16]
True attribute values: [16. 16. 16. 16. 16.]
Accuracy: 0.8429

Feature Column marital-status:
Predicted attribute values: [3 3 1 5 5]
True attribute values: [3. 3. 6. 5. 5.]
Accuracy: 0.7911

Feature Column occupation:
Predicted attribute values: [ 1  0  8 