In [None]:
import torch
from sklearn.datasets import fetch_openml

from neural_blueprints.utils import Trainer, infer_types
from neural_blueprints.config.architectures import TabularBERTConfig
from neural_blueprints.config.utils import TrainerConfig
from neural_blueprints.architectures import TabularBERT
from neural_blueprints.datasets import MaskedTabularDataset, TabularSingleLabelDataset
from neural_blueprints.preprocess import TabularPreprocessor

import logging
# Configure the root logger once for the whole notebook. Because this is the
# root logger, records from every imported library are emitted too.
logging.basicConfig(
    level=logging.DEBUG,  # most verbose standard level; switch to logging.INFO for less detail
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

In [None]:
# Download the OpenML "adult" dataset (version 2) as pandas objects.
adult = fetch_openml(name="adult", version=2, as_frame=True)
X, y = adult.data, adult.target

# Build a single frame holding the features plus the target as an 'income'
# column. assign() returns a new frame, so X itself is left untouched.
data = X.assign(income=y)

# Cast the columns to the dtypes returned by infer_types, then preview.
dtypes = infer_types(data)
data = data.astype(dtypes)
data.head()

In [None]:
# Run the tabular preprocessor over the combined frame. It hands back the
# processed data together with the discrete and continuous feature lists.
# NOTE: `data` is rebound to the processed result, so re-running this cell
# alone would feed already-processed data back into run().
preprocessor = TabularPreprocessor()
preprocessed = preprocessor.run(data)
data, discrete_features, continuous_features = preprocessed

### Income Inference Accuracy

In [None]:
# Supervised dataset: the 'income' column is the label, the remaining
# discrete/continuous features are the inputs.
dataset = TabularSingleLabelDataset(
    data=data,
    label_column='income',              # Specify the label column for single-label classification
    discrete_features=discrete_features,
    continuous_features=continuous_features
)

# 90/10 train/validation split. A seeded generator makes the partition
# identical on every "Restart & Run All"; without it random_split draws from
# the global RNG and the validation set (and reported metrics) change per run.
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Architecture for the income classifier. Input cardinalities come from the
# dataset built above; output_cardinalities=[2] presumably means one output
# head with two classes (the binary income label) - confirm against
# TabularBERTConfig.
bert_config = TabularBERTConfig(
    input_cardinalities=dataset.cardinalities,
    output_cardinalities=[2],
    latent_dim=64,
    encoder_layers=8,
    dropout_p=0.2,
    normalization="batchnorm1d",
    activation="gelu",
    final_activation=None,
)

# Instantiate the model and print its blueprint for a batch of 256 rows.
model = TabularBERT(bert_config)
model.blueprint(batch_size=256)

In [None]:
# Inspect the freshly constructed (not yet trained) classifier's weights,
# grouped by component and with statistics enabled.
model.show_weights(group_by="component", show_stats=True)

In [None]:
# Train the income classifier with cross-entropy loss and Adam.
trainer = Trainer(
    model=model,
    config=TrainerConfig(
        criterion='cross_entropy',
        optimizer='adam',
        early_stopping_patience=5,
        learning_rate=1e-5,
        weight_decay=1e-5,
        # Dedicated checkpoint file for the classifier. The masked-reconstruction
        # run later in this notebook saves to "./models/bert_adult.pth"; sharing
        # that path would let the later run overwrite these weights.
        save_weights_path="./models/bert_adult_income.pth",
        batch_size=256
    )
)
# Up to 10 epochs on the train split, validating on the held-out split.
trainer.train(train_dataset, val_dataset, epochs=10, visualize=True)

In [None]:
# Same weight inspection as before training, now on the trainer's model after
# trainer.train() has run, so the two outputs can be compared.
trainer.model.show_weights(group_by="component", show_stats=True)

In [None]:
# Run the trained classifier on the held-out validation split. Presumably this
# reports the income accuracy named in the section heading - confirm against
# Trainer.predict.
trainer.predict(val_dataset)

### Masked Dataset Inference Accuracy

In [None]:
# Masked dataset built from the same preprocessed `data` frame, with
# mask_prob=0.35 passed through to MaskedTabularDataset.
# NOTE: this rebinds `dataset`, `train_dataset` and `val_dataset`, replacing
# the classification datasets created earlier in the notebook.
dataset = MaskedTabularDataset(
    data,
    discrete_features,
    continuous_features,
    mask_prob=0.35,
)

# 90/10 train/validation split. A seeded generator makes the partition
# identical on every "Restart & Run All"; without it random_split draws from
# the global RNG and the validation set (and reported metrics) change per run.
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Architecture for the masked-reconstruction model. output_cardinalities is
# left unset here, and layer normalization is used instead of the batch
# normalization chosen for the classifier.
bert_config = TabularBERTConfig(
    input_cardinalities=dataset.cardinalities,
    latent_dim=64,
    encoder_layers=8,
    dropout_p=0.2,
    normalization="layernorm",
    activation="gelu",
    final_activation=None,
)

# Instantiate the model and print its blueprint with the default arguments.
model = TabularBERT(bert_config)
model.blueprint()

In [None]:
# Inspect the freshly constructed (not yet trained) masked model's weights,
# grouped by component and with statistics enabled.
model.show_weights(group_by="component", show_stats=True)

In [None]:
# Training settings for the masked-reconstruction objective: Adam at 1e-3,
# early stopping after 2 epochs without improvement, batches of 256 rows.
masked_trainer_config = TrainerConfig(
    criterion='masked_reconstruction',
    optimizer='adam',
    early_stopping_patience=2,
    learning_rate=1e-3,
    weight_decay=1e-5,
    save_weights_path="./models/bert_adult.pth",
    batch_size=256,
)

# Rebinds `trainer` to a new Trainer wrapping the masked model.
trainer = Trainer(model=model, config=masked_trainer_config)

# Up to 5 epochs on the train split, validating on the held-out split.
trainer.train(train_dataset, val_dataset, epochs=5, visualize=True)

In [None]:
# Run the trained masked model on the held-out validation split. Presumably
# this reports reconstruction accuracy on the masked entries - confirm against
# Trainer.predict.
trainer.predict(val_dataset)