In [None]:
import torch
from sklearn.datasets import fetch_openml

from neural_blueprints.utils import Trainer, accuracy
from neural_blueprints.config import TabularBERTConfig, TransformerEncoderConfig, TrainerConfig
from neural_blueprints.architectures import TabularBERT
from neural_blueprints.datasets import MaskedTabularDataset
from neural_blueprints.preprocess import TabularPreprocessor

import logging
logging.basicConfig(
    level=logging.INFO,  # or DEBUG if you want even more detail
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

In [2]:
# Fetch Adult Income dataset from OpenML
adult = fetch_openml("adult", version=2, as_frame=True)

# X = features, y = target
X = adult.data
y = adult.target

data = X.copy()
data[y.name] = y

In [None]:
preprocessor = TabularPreprocessor(with_masking=True)
data, discrete_features, continuous_features = preprocessor.run(data)

# Create dataset
dataset = MaskedTabularDataset(
    data, 
    discrete_features, 
    continuous_features,
    mask_prob=0.25
)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

In [None]:
bert_config = TabularBERTConfig(
    cardinalities=dataset.cardinalities,
    encoder_config=TransformerEncoderConfig(
        input_dim=data.shape[1],
        hidden_dim=32,
        num_layers=4,
        num_heads=4,
        dropout=0.1,
        projection=None,
        final_activation=None
    ),
    with_input_projection=True,
    with_output_projection=True,
    dropout=0.1,
    final_activation=None
)

model = TabularBERT(bert_config)
model.blueprint()

BERT(
  (input_projections): ModuleList(
    (0): FeedForwardNetwork(
      (network): Sequential(
        (0): DenseLayer(
          (linear): Linear(in_features=1, out_features=32, bias=True)
          (normalization): Identity()
          (activation): ReLU()
        )
        (1): DenseLayer(
          (linear): Linear(in_features=32, out_features=32, bias=True)
          (normalization): Identity()
          (activation): ReLU()
        )
        (2): DenseLayer(
          (linear): Linear(in_features=32, out_features=32, bias=True)
          (normalization): Identity()
          (activation): Identity()
        )
      )
    )
    (1): EmbeddingLayer(
      (embedding): Embedding(10, 32)
    )
    (2): FeedForwardNetwork(
      (network): Sequential(
        (0): DenseLayer(
          (linear): Linear(in_features=1, out_features=32, bias=True)
          (normalization): Identity()
          (activation): ReLU()
        )
        (1): DenseLayer(
          (linear): Linear(in_feat

BERTConfig(cardinalities=[1, 10, 1, 18, 18, 9, 16, 8, 7, 4, 1, 1, 1, 43, 4], encoder_config=TransformerEncoderConfig(input_dim=15, hidden_dim=32, num_layers=4, num_heads=4, dropout=0.1, projection=None, final_normalization=NormalizationConfig(norm_type='layernorm', num_features=32), final_activation=None), dropout=0.1, with_input_projection=True, with_output_projection=True, final_normalization=None, final_activation=None)

In [5]:
trainer = Trainer(
    model=model,
    config=TrainerConfig(
        training_type='masked_label',
        criterion='tabular_masked_loss',
        optimizer='adam',
        learning_rate=1e-3,
        weight_decay=1e-5,
        save_weights_path="../models/bert_adult.pth",
        batch_size=64
    )
)
trainer.train(train_dataset, val_dataset, epochs=5, visualize=True)

Directory ../models already exists. Weights file may be overwritten.


Training Epochs:  20%|██        | 1/5 [00:28<01:55, 28.78s/epoch]

Epoch 1/5, Training Loss: 1.0449, Validation Loss: 0.8619


Training Epochs:  40%|████      | 2/5 [00:58<01:28, 29.43s/epoch]

Epoch 2/5, Training Loss: 0.8600, Validation Loss: 0.8100


Training Epochs:  60%|██████    | 3/5 [01:29<01:00, 30.21s/epoch]

Epoch 3/5, Training Loss: 0.8147, Validation Loss: 0.7747


Training Epochs:  80%|████████  | 4/5 [02:01<00:30, 30.60s/epoch]

Epoch 4/5, Training Loss: 0.7919, Validation Loss: 0.7782


Training Epochs: 100%|██████████| 5/5 [02:28<00:00, 29.69s/epoch]


Epoch 5/5, Training Loss: 0.7802, Validation Loss: 0.7747


In [6]:
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False)
X, y, mask = next(iter(val_loader))  # Get first 100 samples from validation set
predictions, _ = trainer.predict(X, y, mask)

dis_accuracy = 0
cont_accuracy = 0
for column_idx, column_name in enumerate(data.columns):
    print(f"\nFeature Column {column_name}:")
    predicted_attributes = predictions[column_idx]      # shape: (batch_size, num_classes)
    targets = y[:, column_idx]                     # shape: (batch_size,)

    feature_mask = mask[:, column_idx]                  # shape: (batch_size,)
    predicted_attributes = predicted_attributes[feature_mask]
    if predicted_attributes.size(1) > 1:
        predicted_attributes = predicted_attributes.softmax(dim=-1).argmax(dim=-1).cpu().numpy()
    else:
        predicted_attributes = predicted_attributes.squeeze(-1).cpu().numpy()
    targets = targets[feature_mask].cpu().numpy()

    print("Predicted attribute values:", predicted_attributes[:5])
    print("True attribute values:", targets[:5])

    accuracy_value = accuracy(torch.tensor(predicted_attributes), torch.tensor(targets))
    print(f"Accuracy: {accuracy_value:.4f}")
    if column_name in discrete_features:
        dis_accuracy += accuracy_value
    else:
        cont_accuracy += accuracy_value

avg_dis_accuracy = dis_accuracy / len(discrete_features) if len(discrete_features) > 0 else 0
avg_cont_accuracy = cont_accuracy / len(continuous_features) if len(continuous_features) > 0 else 0
print(f"\nAverage Discrete Accuracy: {avg_dis_accuracy:.4f}")
print(f"Average Continuous Accuracy: {avg_cont_accuracy:.4f}")
avg_accuracy = (dis_accuracy + cont_accuracy) / len(data.columns)
print(f"Overall Average Accuracy: {avg_accuracy:.4f}")


Feature Column age:
Predicted attribute values: [-0.86300075  0.3950443   0.49658772 -0.23111027  0.30063537]
True attribute values: [-0.77631646 -0.04694151  1.7764958  -0.411629   -0.33869147]
Accuracy: 0.0741

Feature Column workclass:
Predicted attribute values: [5 5 5 5 5]
True attribute values: [5. 5. 5. 5. 5.]
Accuracy: 0.7200

Feature Column fnlwgt:
Predicted attribute values: [-0.08535059  0.05779101 -0.06095454  0.01411958 -0.00652844]
True attribute values: [ 0.59962213  0.3915787   0.55691504  0.12289303 -0.443862  ]
Accuracy: 0.0968

Feature Column education:
Predicted attribute values: [13 13 17 14 17]
True attribute values: [13. 13. 17. 14. 17.]
Accuracy: 0.8000

Feature Column education-num:
Predicted attribute values: [17  6  6 12  7]
True attribute values: [12.  6.  6. 12.  7.]
Accuracy: 0.8182

Feature Column marital-status:
Predicted attribute values: [4 6 6 6 4]
True attribute values: [4. 7. 8. 7. 4.]
Accuracy: 0.7273

Feature Column occupation:
Predicted attribut