In [None]:
import torch
from sklearn.datasets import fetch_openml

from neural_blueprints.architectures import VariationalAutoEncoder
from neural_blueprints.config.architectures import AutoEncoderConfig
from neural_blueprints.config.components.composite import EncoderConfig, DecoderConfig
from neural_blueprints.config.components.core import DenseLayerConfig
from neural_blueprints.config.utils import TrainerConfig
from neural_blueprints.config.components.composite.projections.input import TabularInputProjectionConfig
from neural_blueprints.config.components.composite.projections.output import TabularOutputProjectionConfig
from neural_blueprints.utils import Trainer, infer_types
from neural_blueprints.preprocess import TabularPreprocessor
from neural_blueprints.datasets import TabularDataset, MaskedTabularDataset, TabularSingleLabelDataset

import logging
# Root logging configuration for the notebook.
# `force=True` removes any handlers already attached to the root logger before
# configuring it; without it `basicConfig` is a silent no-op when handlers exist,
# so re-running this cell (e.g. after changing the level) would have no effect.
logging.basicConfig(
    level=logging.INFO,  # or logging.DEBUG if you want even more detail
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    force=True,
)

In [None]:
# Fetch the UCI Adult census dataset from OpenML as pandas objects.
# The Bunch gets its own name so that `data` always refers to a DataFrame.
adult = fetch_openml(name="adult", version=2, as_frame=True)
X = adult.data
y = adult.target

# One frame holding the features plus the target, stored as the `income` column.
data = X.assign(income=y)

# Cast each column to the dtype inferred by the project's `infer_types` helper.
dtypes = infer_types(data)
data = data.astype(dtypes)
data.head()

In [None]:
# Run the tabular preprocessor over the frame. It hands back the processed
# frame along with the discrete and continuous feature groupings that the
# dataset classes below expect.
# NOTE(review): this rebinds `data` to its processed form, so re-running this
# cell without first re-running the loading cell would preprocess twice.
preprocessor = TabularPreprocessor()
(
    data,
    discrete_features,
    continuous_features,
) = preprocessor.run(data)
data.head()

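### Income Prediction Accuracy

Train the VAE to predict the `income` label from the other features, then evaluate it on a held-out 10% validation split.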

In [None]:
# Single-label view of the table: `income` is the prediction target and the
# listed discrete/continuous columns are the inputs.
dataset = TabularSingleLabelDataset(
    data=data,
    label_column='income',
    discrete_features=discrete_features,
    continuous_features=continuous_features,
)

# 90/10 train/validation split. A fixed-seed generator makes the split
# reproducible across kernel restarts; without one, random_split draws from
# the global torch RNG, which is not seeded anywhere in this notebook.
SPLIT_SEED = 42
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# VAE for income prediction.
#   input projection : table columns -> n_features * feature_width units (hidden width 128)
#   encoder          : n_features * {64, 32, 16} -> 2 * latent_dim
#                      (presumably split into mean / log-variance inside
#                       VariationalAutoEncoder — TODO confirm)
#   decoder          : latent_dim -> 2x -> 4x -> 8x latent_dim
#   output projection: 8 * latent_dim -> a single 2-way output (the binary `income` label)
latent_dim = 64
feature_width = 64
n_features = len(dataset.cardinalities)

# Normalization / activation / dropout settings shared by every sub-module.
common = dict(normalization="layernorm", activation="gelu", dropout_p=0.2)

encoder_dims = [n_features * feature_width, n_features * 32, n_features * 16, latent_dim * 2]
decoder_dims = [latent_dim, latent_dim * 2, latent_dim * 4, latent_dim * 8]

vae_config = AutoEncoderConfig(
    input_projection=TabularInputProjectionConfig(
        cardinalities=dataset.cardinalities,
        hidden_dims=[128],
        output_dim=[encoder_dims[0]],
        **common,
    ),
    output_projection=TabularOutputProjectionConfig(
        output_cardinalities=[2],
        input_dim=[decoder_dims[-1]],
        hidden_dims=[],
        **common,
    ),
    encoder_config=EncoderConfig(
        **common,
        layer_configs=[
            DenseLayerConfig(input_dim=d_in, output_dim=d_out)
            for d_in, d_out in zip(encoder_dims, encoder_dims[1:])
        ],
    ),
    decoder_config=DecoderConfig(
        **common,
        layer_configs=[
            DenseLayerConfig(input_dim=d_in, output_dim=d_out)
            for d_in, d_out in zip(decoder_dims, decoder_dims[1:])
        ],
    ),
)

model = VariationalAutoEncoder(vae_config)
model.blueprint(batch_size=256)

In [None]:
# Train the income-prediction VAE with a cross-entropy VAE objective.
# Its weights go to a dedicated file. The reconstruction VAE trained later in
# this notebook uses `./models/vae_adult.pth`, so sharing that path would let
# the second run overwrite this model's checkpoint.
trainer = Trainer(
    config=TrainerConfig(
        learning_rate=1e-3,
        weight_decay=1e-5,
        batch_size=256,
        early_stopping_patience=2,
        save_weights_path="./models/vae_adult_income.pth",
        criterion="vae_cross_entropy",
        optimizer='adam'
    ),
    model=model
)
trainer.train(train_dataset=train_dataset, val_dataset=val_dataset, epochs=5)

In [None]:
# Run the trained income model over the 10% validation split; presumably
# `predict` reports the accuracy this section is about — confirm.
trainer.predict(val_dataset)

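### Masked Feature Reconstruction Accuracy

Train a second VAE to reconstruct every feature, then measure how well it recovers values that are randomly masked out (`mask_prob=0.35`).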

In [None]:
# Plain tabular view over all discrete and continuous columns (no separate
# label column), used to train the reconstruction VAE below.
dataset = TabularDataset(
    data=data,
    discrete_features=discrete_features,
    continuous_features=continuous_features,
)

# 90/10 train/validation split. A fixed-seed generator makes the split
# reproducible across kernel restarts; without one, random_split draws from
# the global torch RNG, which is not seeded anywhere in this notebook.
SPLIT_SEED = 42
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# VAE for feature reconstruction. The layout matches the income model except
# for the output projection, which targets every input feature
# (`dataset.cardinalities`) instead of a single 2-way output.
#   input projection : table columns -> n_features * feature_width units (hidden width 128)
#   encoder          : n_features * {64, 32, 16} -> 2 * latent_dim
#                      (presumably split into mean / log-variance inside
#                       VariationalAutoEncoder — TODO confirm)
#   decoder          : latent_dim -> 2x -> 4x -> 8x latent_dim
#   output projection: 8 * latent_dim -> presumably one head per feature cardinality
latent_dim = 64
feature_width = 64
n_features = len(dataset.cardinalities)

# Normalization / activation / dropout settings shared by every sub-module.
common = dict(normalization="layernorm", activation="gelu", dropout_p=0.2)

encoder_dims = [n_features * feature_width, n_features * 32, n_features * 16, latent_dim * 2]
decoder_dims = [latent_dim, latent_dim * 2, latent_dim * 4, latent_dim * 8]

vae_config = AutoEncoderConfig(
    input_projection=TabularInputProjectionConfig(
        cardinalities=dataset.cardinalities,
        hidden_dims=[128],
        output_dim=[encoder_dims[0]],
        **common,
    ),
    output_projection=TabularOutputProjectionConfig(
        output_cardinalities=dataset.cardinalities,
        input_dim=[decoder_dims[-1]],
        hidden_dims=[],
        **common,
    ),
    encoder_config=EncoderConfig(
        **common,
        layer_configs=[
            DenseLayerConfig(input_dim=d_in, output_dim=d_out)
            for d_in, d_out in zip(encoder_dims, encoder_dims[1:])
        ],
    ),
    decoder_config=DecoderConfig(
        **common,
        layer_configs=[
            DenseLayerConfig(input_dim=d_in, output_dim=d_out)
            for d_in, d_out in zip(decoder_dims, decoder_dims[1:])
        ],
    ),
)

model = VariationalAutoEncoder(vae_config)
model.blueprint(batch_size=256)

In [None]:
# Train the reconstruction VAE on all features with the plain VAE
# reconstruction objective. Optimiser settings match the income model.
reconstruction_trainer_config = TrainerConfig(
    optimizer='adam',
    learning_rate=1e-3,
    weight_decay=1e-5,
    batch_size=256,
    early_stopping_patience=2,
    criterion="vae_reconstruction",
    save_weights_path="./models/vae_adult.pth",
)
trainer = Trainer(model=model, config=reconstruction_trainer_config)
trainer.train(train_dataset=train_dataset, val_dataset=val_dataset, epochs=5)

In [None]:
# Masked-reconstruction evaluation: values are randomly hidden
# (mask_prob=0.35) and the reconstruction VAE is asked to recover them.
#
# The masked dataset is built over the full frame so its feature metadata
# presumably matches the dataset the model was trained on. It is then
# restricted to the validation rows held out by the split above. Evaluating
# on all of `data` would include the 90% of rows the model was trained on and
# overstate accuracy.
# NOTE(review): assumes MaskedTabularDataset and TabularDataset index the rows
# of `data` in the same order — verify.
masked_dataset = MaskedTabularDataset(
    data=data,
    discrete_features=discrete_features,
    continuous_features=continuous_features,
    mask_prob=0.35
)
test_dataset = torch.utils.data.Subset(masked_dataset, val_dataset.indices)

# A separate trainer configured with the masked-reconstruction criterion; only
# `predict` is used here.
trainer = Trainer(
    config=TrainerConfig(
        learning_rate=1e-3,
        weight_decay=1e-5,
        batch_size=128,
        early_stopping_patience=2,
        save_weights_path="./models/vae_adult.pth",
        criterion="vae_masked_reconstruction",
        optimizer='adam'
    ),
    model=model
)
trainer.predict(test_dataset=test_dataset)