In [1]:
import torch
from sklearn.datasets import fetch_openml

from neural_blueprints.architectures import VariationalAutoEncoder
from neural_blueprints.config.architectures import AutoEncoderConfig
from neural_blueprints.config.components.composite import EncoderConfig, DecoderConfig
from neural_blueprints.config.components.core import DenseLayerConfig
from neural_blueprints.config.utils import TrainerConfig
from neural_blueprints.config.components.composite.projections.input import TabularInputProjectionConfig
from neural_blueprints.config.components.composite.projections.output import TabularOutputProjectionConfig
from neural_blueprints.utils import Trainer, infer_types
from neural_blueprints.preprocess import TabularPreprocessor
from neural_blueprints.datasets import MaskedTabularDataset, TabularSingleLabelDataset

import logging
# Configure the root logger so that log records emitted while the notebook runs
# are printed with a timestamp, the logger name and the severity level.
logging.basicConfig(
    level=logging.DEBUG,  # most verbose standard level; use logging.INFO for less detail
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

In [2]:
# Fetch the Adult census dataset (OpenML, version 2) as pandas objects.
adult_bunch = fetch_openml(name="adult", version=2, as_frame=True)
X, y = adult_bunch.data, adult_bunch.target

# Combine features and target in one frame, with the target stored as the
# `income` column. `assign` returns a new frame, so `X` is left untouched.
data = X.assign(income=y)

# Cast the columns using whatever `infer_types` returns for this frame.
dtypes = infer_types(data)
data = data.astype(dtypes)
data.head()

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
0,25,Private,226802,11th,7,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,<=50K
1,38,Private,89814,HS-grad,9,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States,<=50K
2,28,Local-gov,336951,Assoc-acdm,12,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,>50K
3,44,Private,160323,Some-college,10,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States,>50K
4,18,,103497,Some-college,10,Never-married,,Own-child,White,Female,0,0,30,United-States,<=50K


In [3]:
# Run the tabular preprocessing step. `run` returns the transformed frame
# together with the discrete and the continuous feature lists.
# NOTE: `data` is rebound to the transformed frame, so re-running this cell
# without first re-running the loading cell feeds processed data back in.
preprocessor = TabularPreprocessor()
preprocessed = preprocessor.run(data)
data, discrete_features, continuous_features = preprocessed
data.head()

2025-12-24 21:19:51,841 - neural_blueprints.preprocess.tabular_preprocess - INFO - Identified 10 discrete features: ['workclass', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country', 'income']
2025-12-24 21:19:51,843 - neural_blueprints.preprocess.tabular_preprocess - INFO - Identified 5 continuous features: ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']


Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
0,0.109589,4,0.145129,2,14,5,7,4,3,2,0.0,0.0,0.397959,39,1
1,0.287671,4,0.052451,12,16,3,5,1,5,2,0.0,0.0,0.5,39,1
2,0.150685,2,0.219649,8,4,3,11,1,5,2,0.0,0.0,0.397959,39,2
3,0.369863,4,0.100153,16,2,3,7,1,3,2,0.076881,0.0,0.397959,39,2
4,0.013699,0,0.061708,16,2,5,0,4,5,1,0.0,0.0,0.295918,39,1


### Income Inference Accuracy

In [4]:
# Wrap the preprocessed frame as a single-label dataset whose target is the
# `income` column.
dataset = TabularSingleLabelDataset(
    data=data,
    label_column='income',
    discrete_features=discrete_features,
    continuous_features=continuous_features
)

# 90/10 train/validation split. The seeded generator pins which rows land in
# each subset, so the validation set is the same on every run instead of
# depending on the global RNG state of the kernel.
SPLIT_SEED = 42
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED)
)

In [5]:
latent_dim = 16

# Every projection/encoder width below is a multiple of the number of
# feature columns exposed by the dataset.
n_features = len(dataset.cardinalities)

# Keyword arguments shared by the projections and the dense layers.
regularisation = dict(normalization="batchnorm1d", dropout_p=0.1)
hidden_layer = dict(activation="gelu", **regularisation)

# Per-feature input projection, concatenated to n_features * 16 units.
input_projection = TabularInputProjectionConfig(
    cardinalities=dataset.cardinalities,
    hidden_dims=[64, 32],
    output_dim=[n_features * 16],
    **hidden_layer
)

# Single output head with 3 classes.
# NOTE(review): presumably the encoded income label — confirm against the preprocessor.
output_projection = TabularOutputProjectionConfig(
    output_cardinalities=[3],
    input_dim=[latent_dim * 8],
    hidden_dims=[n_features * 8],
    **hidden_layer
)

# Encoder: n_features*16 -> n_features*8 -> n_features*4 -> 2*latent_dim.
# The last layer has no activation but keeps normalization and dropout.
encoder_config = EncoderConfig(
    layer_configs=[
        DenseLayerConfig(input_dim=n_features * 16, output_dim=n_features * 8, **hidden_layer),
        DenseLayerConfig(input_dim=n_features * 8, output_dim=n_features * 4, **hidden_layer),
        DenseLayerConfig(input_dim=n_features * 4, output_dim=latent_dim * 2, **regularisation),
    ]
)

# Decoder: the width doubles at every layer, latent_dim -> 2x -> 4x -> 8x.
decoder_widths = [latent_dim * factor for factor in (1, 2, 4, 8)]
decoder_config = DecoderConfig(
    layer_configs=[
        DenseLayerConfig(input_dim=width_in, output_dim=width_out, **hidden_layer)
        for width_in, width_out in zip(decoder_widths[:-1], decoder_widths[1:])
    ]
)

vae_config = AutoEncoderConfig(
    input_projection=input_projection,
    output_projection=output_projection,
    encoder_config=encoder_config,
    decoder_config=decoder_config
)

model = VariationalAutoEncoder(vae_config)
model.blueprint()

2025-12-24 20:09:05,778 - neural_blueprints.architectures.autoencoder - INFO - Using input projection: TabularInputProjection
2025-12-24 20:09:05,782 - neural_blueprints.architectures.autoencoder - INFO - Using output projection: TabularOutputProjection


VariationalAutoEncoder(
  (input_projection): TabularInputProjection(
    (input_projections): ModuleList(
      (0): FeedForwardNetwork(
        (network): Sequential(
          (0): DenseLayer(
            (layer): Sequential(
              (0): Linear(in_features=1, out_features=64, bias=True)
              (1): NormalizationLayer(
                (network): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (2): GELU(approximate='none')
              (3): DropoutLayer(
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
          (1): DenseLayer(
            (layer): Sequential(
              (0): Linear(in_features=64, out_features=32, bias=True)
              (1): NormalizationLayer(
                (network): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (2): GELU(approximate='none')
              (3): DropoutLa

AutoEncoderConfig(input_projection=TabularInputProjectionConfig(cardinalities=[1, 9, 1, 16, 16, 7, 15, 6, 5, 2, 1, 1, 1, 42], hidden_dims=[64, 32], output_dim=[224], dropout_p=0.1, normalization='batchnorm1d', activation='gelu'), output_projection=TabularOutputProjectionConfig(input_cardinalities=None, output_cardinalities=[3], input_dim=[128], hidden_dims=[112], activation='gelu', normalization='batchnorm1d', dropout_p=0.1), encoder_config=EncoderConfig(layer_configs=[DenseLayerConfig(input_dim=224, output_dim=112, normalization=None, activation=None, dropout_p=None), DenseLayerConfig(input_dim=112, output_dim=56, normalization=None, activation=None, dropout_p=None), DenseLayerConfig(input_dim=56, output_dim=32, normalization=None, activation=None, dropout_p=None)], normalization=None, activation=None, dropout_p=None, final_activation=None), decoder_config=DecoderConfig(layer_configs=[DenseLayerConfig(input_dim=16, output_dim=32, normalization='batchnorm1d', activation='gelu', dropout

In [6]:
# Train the income classifier.
# The weights go to their own checkpoint file: the masked-reconstruction
# experiment further down saves to ./models/vae_adult.pth, and sharing that
# path would let the later run overwrite this model's weights.
trainer = Trainer(
    config=TrainerConfig(
        learning_rate=1e-3,
        weight_decay=1e-5,
        batch_size=128,
        early_stopping_patience=2,  # stop after 2 epochs without validation improvement
        save_weights_path="./models/vae_adult_income.pth",
        criterion="vae_cross_entropy",
        optimizer='adam'
    ),
    model=model
)
trainer.train(train_dataset=train_dataset, val_dataset=val_dataset, epochs=20)

2025-12-24 20:09:06,138 - neural_blueprints.utils.trainer - INFO - Trainer initialized on device: cpu


Directory ./models already exists. Existing weights are overwritten.


Training Epochs:   5%|▌         | 1/20 [00:02<00:47,  2.53s/epoch]

Epoch 1/20, Training Loss: 1.4547, Validation Loss: 0.9895


Training Epochs:  10%|█         | 2/20 [00:04<00:44,  2.46s/epoch]

Epoch 2/20, Training Loss: 0.9938, Validation Loss: 0.9805


Training Epochs:  15%|█▌        | 3/20 [00:07<00:41,  2.46s/epoch]

Epoch 3/20, Training Loss: 0.9866, Validation Loss: 0.9792


Training Epochs:  20%|██        | 4/20 [00:09<00:39,  2.45s/epoch]

Epoch 4/20, Training Loss: 0.9849, Validation Loss: 0.9787


Training Epochs:  25%|██▌       | 5/20 [00:12<00:36,  2.46s/epoch]

Epoch 5/20, Training Loss: 0.9843, Validation Loss: 0.9786


Training Epochs:  30%|███       | 6/20 [00:14<00:34,  2.45s/epoch]

Epoch 6/20, Training Loss: 0.9839, Validation Loss: 0.9787


2025-12-24 20:09:23,342 - neural_blueprints.utils.trainer - INFO - No improvement in validation loss for 2 consecutive epochs. Early stopping at epoch 7.
Training Epochs:  30%|███       | 6/20 [00:17<00:40,  2.87s/epoch]
2025-12-24 20:09:23,343 - neural_blueprints.utils.trainer - INFO - Training completed in 17.20 seconds.
2025-12-24 20:09:23,343 - neural_blueprints.utils.trainer - INFO - Best validation loss: 9.7862e-01


In [8]:
# Evaluate on the held-out validation split. The call prints a few sample
# predictions with their ground truth and the prediction accuracy; its return
# value (the accuracy) is shown as the cell output.
trainer.predict(val_dataset)

2025-12-24 20:09:48,085 - neural_blueprints.utils.trainer - INFO - Inference completed in 0.06 seconds.


Predictions: tensor([1, 1, 1, 1, 1]), 
 Ground Truth: tensor([1, 2, 1, 1, 1])
Prediction Accuracy: 0.7652


np.float64(0.7651995905834186)

### Masked Dataset Inference Accuracy

In [4]:
# Wrap the preprocessed frame as a masked dataset for the reconstruction task;
# `mask_prob` is passed through to the dataset unchanged.
dataset = MaskedTabularDataset(
    data=data,
    discrete_features=discrete_features,
    continuous_features=continuous_features,
    mask_prob=0.35
)

# 90/10 train/validation split. The seeded generator pins which rows land in
# each subset, so the validation set is the same on every run instead of
# depending on the global RNG state of the kernel.
SPLIT_SEED = 42
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED)
)

In [5]:
latent_dim = 16

# The same cardinalities drive both the input and the output projection:
# the model reconstructs every feature column it receives.
feature_cardinalities = dataset.cardinalities
n_features = len(feature_cardinalities)

# Keyword arguments shared by the projections and the dense layers.
regularisation = dict(normalization="batchnorm1d", dropout_p=0.1)
hidden_layer = dict(activation="gelu", **regularisation)

# Encoder: n_features*16 -> n_features*8 -> n_features*4 -> 2*latent_dim.
# The last layer has no activation but keeps normalization and dropout.
encoder_layers = [
    DenseLayerConfig(input_dim=n_features * 16, output_dim=n_features * 8, **hidden_layer),
    DenseLayerConfig(input_dim=n_features * 8, output_dim=n_features * 4, **hidden_layer),
    DenseLayerConfig(input_dim=n_features * 4, output_dim=latent_dim * 2, **regularisation),
]

# Decoder: the width doubles at every layer, latent_dim -> 2x -> 4x -> 8x.
decoder_layers = [
    DenseLayerConfig(
        input_dim=latent_dim * 2 ** depth,
        output_dim=latent_dim * 2 ** (depth + 1),
        **hidden_layer
    )
    for depth in range(3)
]

vae_config = AutoEncoderConfig(
    input_projection=TabularInputProjectionConfig(
        cardinalities=feature_cardinalities,
        hidden_dims=[64, 32],
        output_dim=[n_features * 16],
        **hidden_layer
    ),
    output_projection=TabularOutputProjectionConfig(
        output_cardinalities=feature_cardinalities,
        input_dim=[latent_dim * 8],
        hidden_dims=[n_features * 8],
        **hidden_layer
    ),
    encoder_config=EncoderConfig(layer_configs=encoder_layers),
    decoder_config=DecoderConfig(layer_configs=decoder_layers)
)

model = VariationalAutoEncoder(vae_config)
model.blueprint()

2025-12-24 21:19:55,335 - neural_blueprints.architectures.autoencoder - INFO - Using input projection: TabularInputProjection
2025-12-24 21:19:55,349 - neural_blueprints.architectures.autoencoder - INFO - Using output projection: TabularOutputProjection


VariationalAutoEncoder(
  (input_projection): TabularInputProjection(
    (input_projections): ModuleList(
      (0): FeedForwardNetwork(
        (network): Sequential(
          (0): DenseLayer(
            (layer): Sequential(
              (0): Linear(in_features=1, out_features=64, bias=True)
              (1): NormalizationLayer(
                (network): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (2): GELU(approximate='none')
              (3): DropoutLayer(
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
          (1): DenseLayer(
            (layer): Sequential(
              (0): Linear(in_features=64, out_features=32, bias=True)
              (1): NormalizationLayer(
                (network): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (2): GELU(approximate='none')
              (3): DropoutLa

AutoEncoderConfig(input_projection=TabularInputProjectionConfig(cardinalities=[1, 9, 1, 16, 16, 7, 15, 6, 5, 2, 1, 1, 1, 42, 2], hidden_dims=[64, 32], output_dim=[240], dropout_p=0.1, normalization='batchnorm1d', activation='gelu'), output_projection=TabularOutputProjectionConfig(input_cardinalities=None, output_cardinalities=[1, 9, 1, 16, 16, 7, 15, 6, 5, 2, 1, 1, 1, 42, 2], input_dim=[128], hidden_dims=[120], activation='gelu', normalization='batchnorm1d', dropout_p=0.1), encoder_config=EncoderConfig(layer_configs=[DenseLayerConfig(input_dim=240, output_dim=120, normalization=None, activation=None, dropout_p=None), DenseLayerConfig(input_dim=120, output_dim=60, normalization=None, activation=None, dropout_p=None), DenseLayerConfig(input_dim=60, output_dim=32, normalization=None, activation=None, dropout_p=None)], normalization=None, activation=None, dropout_p=None, final_activation=None), decoder_config=DecoderConfig(layer_configs=[DenseLayerConfig(input_dim=16, output_dim=32, normal

In [6]:
# Optimisation settings for the masked-reconstruction run.
trainer_config = TrainerConfig(
    optimizer='adam',
    learning_rate=1e-3,
    weight_decay=1e-5,
    batch_size=128,
    criterion="vae_reconstruction",
    early_stopping_patience=2,
    save_weights_path="./models/vae_adult.pth"
)

# Bind the model to the settings above and train for at most 20 epochs.
trainer = Trainer(model=model, config=trainer_config)
trainer.train(val_dataset=val_dataset, train_dataset=train_dataset, epochs=20)

2025-12-24 21:19:55,662 - neural_blueprints.utils.trainer - INFO - Trainer initialized on device: cpu


Directory ./models already exists. Existing weights are overwritten.


Training Epochs:   5%|▌         | 1/20 [00:04<01:26,  4.56s/epoch]

Epoch 1/20, Training Loss: 191.4667, Validation Loss: 182.4451


Training Epochs:  10%|█         | 2/20 [00:09<01:22,  4.58s/epoch]

Epoch 2/20, Training Loss: 182.7516, Validation Loss: 182.1017


Training Epochs:  15%|█▌        | 3/20 [00:13<01:17,  4.57s/epoch]

Epoch 3/20, Training Loss: 182.6056, Validation Loss: 182.0340


Training Epochs:  20%|██        | 4/20 [00:18<01:12,  4.56s/epoch]

Epoch 4/20, Training Loss: 182.5681, Validation Loss: 182.0088


Training Epochs:  25%|██▌       | 5/20 [00:22<01:09,  4.62s/epoch]

Epoch 5/20, Training Loss: 182.5518, Validation Loss: 181.9973


Training Epochs:  30%|███       | 6/20 [00:27<01:05,  4.69s/epoch]

Epoch 6/20, Training Loss: 182.5450, Validation Loss: 181.9912


Training Epochs:  35%|███▌      | 7/20 [00:32<01:01,  4.69s/epoch]

Epoch 7/20, Training Loss: 182.5401, Validation Loss: 181.9874


Training Epochs:  40%|████      | 8/20 [00:37<00:56,  4.68s/epoch]

Epoch 8/20, Training Loss: 182.5371, Validation Loss: 181.9870


Training Epochs:  45%|████▌     | 9/20 [00:41<00:51,  4.66s/epoch]

Epoch 9/20, Training Loss: 182.5357, Validation Loss: 181.9832


Training Epochs:  50%|█████     | 10/20 [00:46<00:46,  4.66s/epoch]

Epoch 10/20, Training Loss: 182.5339, Validation Loss: 181.9821


Training Epochs:  55%|█████▌    | 11/20 [00:51<00:41,  4.67s/epoch]

Epoch 11/20, Training Loss: 182.5326, Validation Loss: 181.9817


Training Epochs:  60%|██████    | 12/20 [00:55<00:37,  4.65s/epoch]

Epoch 12/20, Training Loss: 182.5313, Validation Loss: 181.9808


Training Epochs:  65%|██████▌   | 13/20 [01:00<00:32,  4.65s/epoch]

Epoch 13/20, Training Loss: 182.5320, Validation Loss: 181.9803


Training Epochs:  70%|███████   | 14/20 [01:05<00:27,  4.64s/epoch]

Epoch 14/20, Training Loss: 182.5305, Validation Loss: 181.9800


Training Epochs:  75%|███████▌  | 15/20 [01:09<00:23,  4.67s/epoch]

Epoch 15/20, Training Loss: 182.5292, Validation Loss: 181.9797


Training Epochs:  80%|████████  | 16/20 [01:14<00:18,  4.68s/epoch]

Epoch 16/20, Training Loss: 182.5311, Validation Loss: 181.9796


Training Epochs:  85%|████████▌ | 17/20 [01:19<00:14,  4.67s/epoch]

Epoch 17/20, Training Loss: 182.5304, Validation Loss: 181.9794


Training Epochs:  90%|█████████ | 18/20 [01:23<00:09,  4.67s/epoch]

Epoch 18/20, Training Loss: 182.5316, Validation Loss: 181.9793


Training Epochs:  95%|█████████▌| 19/20 [01:28<00:04,  4.66s/epoch]

Epoch 19/20, Training Loss: 182.5315, Validation Loss: 181.9793


Training Epochs: 100%|██████████| 20/20 [01:33<00:00,  4.65s/epoch]
2025-12-24 21:21:28,716 - neural_blueprints.utils.trainer - INFO - Training completed in 93.05 seconds.
2025-12-24 21:21:28,716 - neural_blueprints.utils.trainer - INFO - Best validation loss: 1.8198e+02


Epoch 20/20, Training Loss: 182.5294, Validation Loss: 181.9791


In [7]:
# Evaluate on the held-out validation split. The call prints, per feature
# column, sample predicted and true values with an accuracy; its return value
# (a dict of average discrete / continuous / overall accuracies) is shown as
# the cell output.
trainer.predict(val_dataset)

2025-12-24 21:21:28,912 - neural_blueprints.utils.trainer - INFO - Inference completed in 0.07 seconds.



Feature Column 0:
Predicted attribute values: [7.1778645e-06 6.4152682e-06 5.8408928e-06 7.5200251e-06 9.0421672e-06]
True attribute values: [0.12328767 0.3561644  0.21917808 0.32876712 0.32876712]
Accuracy: 0.0641

Feature Column 1:
Predicted attribute values: [0 0 0 0 0]
True attribute values: [2. 4. 6. 4. 4.]
Accuracy: 0.0595

Feature Column 2:
Predicted attribute values: [8.2712959e-06 7.4442182e-06 6.0232223e-06 6.2188469e-06 6.0164020e-06]
True attribute values: [0.15964793 0.12129503 0.17448507 0.11925865 0.26580274]
Accuracy: 0.1351

Feature Column 3:
Predicted attribute values: [0 0 0 0 0]
True attribute values: [16. 12.  8. 12. 16.]
Accuracy: 0.0000

Feature Column 4:
Predicted attribute values: [0 0 0 0 0]
True attribute values: [16.  4. 12.  2.  5.]
Accuracy: 0.0000

Feature Column 5:
Predicted attribute values: [0 0 0 0 0]
True attribute values: [5. 5. 3. 3. 3.]
Accuracy: 0.0000

Feature Column 6:
Predicted attribute values: [0 0 0 0 0]
True attribute values: [12. 12.  5.

{'avg_discrete_accuracy': np.float64(0.013866323678904242),
 'avg_continuous_accuracy': np.float64(0.4218169563879652),
 'overall_avg_accuracy': np.float64(0.14984986791525792)}