In [1]:
import torch
from sklearn.datasets import fetch_openml

from neural_blueprints.utils import Trainer, infer_types
from neural_blueprints.config.architectures import BERTConfig
from neural_blueprints.config.utils import TrainerConfig
from neural_blueprints.architectures import BERT
from neural_blueprints.config.components.composite.projections import TabularProjectionConfig
from neural_blueprints.datasets import MaskedTabularDataset, TabularLabelDataset
from neural_blueprints.preprocess import TabularPreprocessor

import logging
import random

import numpy as np

# INFO shows the library's progress messages (feature typing, NaN handling,
# training summaries). Switch to logging.DEBUG for more detail; because this
# configures the *root* logger, DEBUG also enables debug output from every
# other library that logs.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

# Seed the global RNGs so Restart & Run All yields the same splits, masks and
# weight initialisation on every run.
# NOTE(review): assumes neural_blueprints draws its randomness (random_split,
# masking, init) from these global RNGs — confirm against the library.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# Adult census-income dataset from OpenML: feature frame in `X`, target in `y`.
# The raw Bunch gets its own name instead of reusing `data` for two different
# kinds of object.
adult = fetch_openml(name="adult", version=2, as_frame=True)
X = adult.data
y = adult.target

# Single frame holding the features plus the 'income' target (assign returns a
# new frame, so `X` is left untouched), with dtypes normalised by infer_types.
data = X.assign(income=y)
dtypes = infer_types(data)
data = data.astype(dtypes)
data.head()

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
0,25,Private,226802,11th,7,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,<=50K
1,38,Private,89814,HS-grad,9,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States,<=50K
2,28,Local-gov,336951,Assoc-acdm,12,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,>50K
3,44,Private,160323,Some-college,10,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States,>50K
4,18,,103497,Some-college,10,Never-married,,Own-child,White,Female,0,0,30,United-States,<=50K


In [3]:
preprocessor = TabularPreprocessor()

# `run` yields the processed frame plus the discrete / continuous column names
# (see the log output below for which columns land in each group).
# NOTE(review): this rebinds `data` to the processed frame, so re-running this
# cell on its own would feed already-processed data back into the preprocessor;
# re-run the loading cell first.
(
    data,
    discrete_features,
    continuous_features,
) = preprocessor.run(data)

2026-01-22 16:40:13,048 - neural_blueprints.preprocess.tabular_preprocess - INFO - Identified 10 discrete features: ['workclass', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country', 'income']
2026-01-22 16:40:13,049 - neural_blueprints.preprocess.tabular_preprocess - INFO - Identified 5 continuous features: ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']
2026-01-22 16:40:13,049 - neural_blueprints.preprocess.tabular_preprocess - INFO - Discrete column 'workclass' has 2799/5.73% NaN values; these will be encoded as 0.
2026-01-22 16:40:13,112 - neural_blueprints.preprocess.tabular_preprocess - INFO - Discrete column 'occupation' has 2809/5.75% NaN values; these will be encoded as 0.
2026-01-22 16:40:13,167 - neural_blueprints.preprocess.tabular_preprocess - INFO - Discrete column 'native-country' has 857/1.75% NaN values; these will be encoded as 0.


### Income Inference Accuracy

In [4]:
# Supervised view of the table: 'income' is the single classification target.
dataset = TabularLabelDataset(data=data, label_columns=['income'])

# 90% train / 10% validation.
train_dataset, val_dataset = dataset.random_split([0.9, 0.1])

In [None]:
latent_dim = 32

# NOTE(review): this cell's execution count is None, so the saved blueprint
# output below predates the current config — it reports a [256, 3] output while
# output_spec / output_cardinalities here request 2 classes. Re-run to refresh
# it and confirm the label cardinality the cross_entropy criterion expects.
bert_config = BERTConfig(
    input_spec=(len(dataset.cardinalities),),
    output_spec=(2,),
    # Per-column projection of every input column into a latent_dim token.
    input_projection=TabularProjectionConfig(
        input_cardinalities=dataset.cardinalities,
        hidden_dims=[2 * latent_dim, 2 * latent_dim],
        projection_dim=latent_dim,
        output_dim=[len(dataset.cardinalities), latent_dim],
    ),
    # Classification head; its input is sized n_columns * latent_dim
    # (presumably the flattened encoder output).
    output_projection=TabularProjectionConfig(
        input_dim=[len(dataset.cardinalities) * latent_dim],
        hidden_dims=[8 * latent_dim, 4 * latent_dim, 2 * latent_dim],
        output_cardinalities=[2],
    ),
    encoder_layers=4,
    dropout_p=0.1,
    normalization="batchnorm1d",
    activation="gelu",
    final_activation=None,
)

model = BERT(bert_config)
model.blueprint(batch_size=256, with_graph=False)

Layer (type:depth-idx)                                                      Output Shape              Param #
BERT                                                                        [256, 3]                  --
├─TabularInputProjection: 1-1                                               [256, 14, 32]             --
│    └─ModuleList: 2-1                                                      --                        --
│    │    └─NumericalProjection: 3-1                                        [256, 32]                 6,624
│    │    └─DiscreteProjection: 3-2                                         [256, 32]                 11,424
│    │    └─NumericalProjection: 3-3                                        [256, 32]                 6,624
│    │    └─DiscreteProjection: 3-4                                         [256, 32]                 11,872
│    │    └─DiscreteProjection: 3-5                                         [256, 32]                 11,872
│    │    └─DiscreteProjection: 

In [6]:
# NOTE(review): early_stopping_patience (5) equals the epoch budget passed to
# train() below, so early stopping presumably can never fire in this run.
classification_trainer_config = TrainerConfig(
    criterion='cross_entropy',
    optimizer='adam',
    early_stopping_patience=5,
    learning_rate=3e-4,
    weight_decay=1e-4,
    batch_size=256,
)
trainer = Trainer(model=model, config=classification_trainer_config)

trainer.train(train_dataset, val_dataset, epochs=5, visualize=True)

# Classification accuracy on the validation split. val_dataset also drives
# best-model tracking (see the "Best validation loss" log), so this is not an
# untouched held-out test score.
trainer.predict(val_dataset)

2026-01-22 16:40:38,145 - neural_blueprints.utils.trainer - INFO - Trainer initialized on device: cpu
Training Epochs:  20%|██        | 1/5 [00:14<00:57, 14.27s/epoch]

Epoch 1/5, Training Loss: 0.8796, Validation Loss: 0.7642


Training Epochs:  40%|████      | 2/5 [00:28<00:43, 14.43s/epoch]

Epoch 2/5, Training Loss: 0.7500, Validation Loss: 0.7132


Training Epochs:  60%|██████    | 3/5 [00:43<00:28, 14.42s/epoch]

Epoch 3/5, Training Loss: 0.7090, Validation Loss: 0.6935


Training Epochs:  80%|████████  | 4/5 [00:57<00:14, 14.48s/epoch]

Epoch 4/5, Training Loss: 0.6930, Validation Loss: 0.6838


Training Epochs: 100%|██████████| 5/5 [01:12<00:00, 14.47s/epoch]
2026-01-22 16:41:50,486 - neural_blueprints.utils.trainer - INFO - Training completed in 72.34 seconds.
2026-01-22 16:41:50,486 - neural_blueprints.utils.trainer - INFO - Best validation loss: 6.8074e-01


Epoch 5/5, Training Loss: 0.6863, Validation Loss: 0.6807


2026-01-22 16:41:51,107 - neural_blueprints.utils.trainer - INFO - Inference completed in 0.38 seconds.


Classification Accuracy: 0.8622


0.8622313203684749

### Masked Feature Reconstruction Accuracy

In [9]:
# Self-supervised view of the full table (income included): mask_prob=0.15 is
# the masking probability used for the reconstruction objective.
dataset = MaskedTabularDataset(data, mask_prob=0.15)

# 90% train / 10% validation.
train_dataset, val_dataset = dataset.random_split([0.9, 0.1])

In [18]:
latent_dim = 32

# Reconstruction model: one output head per column, sized by that column's
# cardinality, mirroring the per-column input projection.
bert_config = BERTConfig(
    input_spec=(len(dataset.cardinalities),),
    output_spec=(len(dataset.cardinalities),),
    input_projection=TabularProjectionConfig(
        input_cardinalities=dataset.cardinalities,
        hidden_dims=[2 * latent_dim, 2 * latent_dim],
        projection_dim=latent_dim,
        output_dim=[len(dataset.cardinalities), latent_dim],
    ),
    # Unlike the classification head, the input here keeps the
    # (n_columns, latent_dim) layout rather than a single flat dimension.
    output_projection=TabularProjectionConfig(
        input_dim=[len(dataset.cardinalities), latent_dim],
        hidden_dims=[4 * latent_dim, 2 * latent_dim],
        output_cardinalities=dataset.cardinalities,
    ),
    encoder_layers=6,
    dropout_p=0.1,
    normalization="batchnorm1d",
    activation="gelu",
    final_activation=None,
)

model = BERT(bert_config)
model.blueprint(batch_size=256, with_graph=False)

Layer (type:depth-idx)                                                      Output Shape              Param #
BERT                                                                        [256, 1]                  --
├─TabularInputProjection: 1-1                                               [256, 15, 32]             --
│    └─ModuleList: 2-1                                                      --                        --
│    │    └─NumericalProjection: 3-1                                        [256, 32]                 6,624
│    │    └─DiscreteProjection: 3-2                                         [256, 32]                 11,424
│    │    └─NumericalProjection: 3-3                                        [256, 32]                 6,624
│    │    └─DiscreteProjection: 3-4                                         [256, 32]                 11,872
│    │    └─DiscreteProjection: 3-5                                         [256, 32]                 11,872
│    │    └─DiscreteProjection: 

In [19]:
# NOTE(review): early_stopping_patience (5) equals the epoch budget passed to
# train() below, so early stopping presumably can never fire in this run.
reconstruction_trainer_config = TrainerConfig(
    criterion='masked_reconstruction',
    optimizer='adam',
    early_stopping_patience=5,
    learning_rate=3e-4,
    weight_decay=1e-4,
    batch_size=256,
)
trainer = Trainer(model=model, config=reconstruction_trainer_config)

trainer.train(train_dataset, val_dataset, epochs=5, visualize=True)

# Per-column reconstruction report on the validation split; the return value
# is a dict of averaged discrete / continuous / overall accuracies.
trainer.predict(val_dataset)

2026-01-22 16:49:38,501 - neural_blueprints.utils.trainer - INFO - Trainer initialized on device: cpu
Training Epochs:  20%|██        | 1/5 [00:23<01:34, 23.65s/epoch]

Epoch 1/5, Training Loss: 2.8820, Validation Loss: 2.7040


Training Epochs:  40%|████      | 2/5 [00:47<01:11, 23.72s/epoch]

Epoch 2/5, Training Loss: 2.6755, Validation Loss: 2.6326


Training Epochs:  60%|██████    | 3/5 [01:11<00:47, 23.66s/epoch]

Epoch 3/5, Training Loss: 2.6214, Validation Loss: 2.5775


Training Epochs:  80%|████████  | 4/5 [01:34<00:23, 23.71s/epoch]

Epoch 4/5, Training Loss: 2.5751, Validation Loss: 2.5358


Training Epochs: 100%|██████████| 5/5 [01:58<00:00, 23.71s/epoch]
2026-01-22 16:51:37,048 - neural_blueprints.utils.trainer - INFO - Training completed in 118.55 seconds.
2026-01-22 16:51:37,048 - neural_blueprints.utils.trainer - INFO - Best validation loss: 2.5115e+00


Epoch 5/5, Training Loss: 2.5437, Validation Loss: 2.5115


2026-01-22 16:51:37,723 - neural_blueprints.utils.trainer - INFO - Inference completed in 0.62 seconds.


Feature Column 0:
Predicted attribute values: [0.22807363 0.13467433 0.1812306  0.2658026  0.3681539 ]
True attribute values: [0.08219178 0.31506848 0.1369863  0.369863   0.369863  ]
Accuracy: 0.2837

Feature Column 1:
Predicted attribute values: [4 4 4 4 4]
True attribute values: [4. 4. 4. 7. 7.]
Accuracy: 0.7080

Feature Column 2:
Predicted attribute values: [0.11791205 0.11456297 0.11317367 0.12096792 0.12521034]
True attribute values: [0.1427778  0.10495124 0.05160559 0.03963359 0.1089144 ]
Accuracy: 0.6117

Feature Column 3:
Predicted attribute values: [12 12 12 12  2]
True attribute values: [16. 12. 16. 16.  3.]
Accuracy: 0.5405

Feature Column 4:
Predicted attribute values: [16 16  3 16  3]
True attribute values: [ 2.  2. 14.  2.  4.]
Accuracy: 0.5575

Feature Column 5:
Predicted attribute values: [5 3 3 3 3]
True attribute values: [5. 3. 3. 3. 3.]
Accuracy: 0.7861

Feature Column 6:
Predicted attribute values: [ 3 10  8  1  4]
True attribute values: [12. 10.  1.  1.  4.]
Accura

{'avg_discrete_accuracy': np.float64(0.6887952808714944),
 'avg_continuous_accuracy': np.float64(0.6697745602443502),
 'overall_avg_accuracy': np.float64(0.6824550406624464)}