<a target="_blank" href="https://colab.research.google.com/github/echosprint/TabularTransformer/blob/main/notebooks/supervised_training.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

---


**For more details about the [TabularTransformer](https://github.com/echosprint/TabularTransformer) model**,
check the online **[documentation](https://echosprint.github.io/TabularTransformer/)**.

---

- This notebook provides a usage example of the
  [TabularTransformer](https://github.com/echosprint/TabularTransformer)
  package.
- Hyperparameters are not tuned and may be suboptimal.

In [None]:
# Install TabularTransformer straight from its GitHub repository.
# No git ref is pinned, so this installs whatever is currently on the
# repository's default branch; results may differ between runs over time.
%pip install git+https://github.com/echosprint/TabularTransformer.git

In [None]:
import tabular_transformer as ttf
import torch

In [None]:
# Obtain the income dataset through the package's helper (it may download
# the file on first use — TODO confirm). The returned path is passed as
# `file_path` to the DataReader in the next cell.
income_dataset_path = ttf.prepare_income_dataset()

In [None]:
# Column types are pinned explicitly through the `ensure_*_cols` arguments.
# Note that the target column 'income' is listed with the categorical
# columns and is also passed as `label`.
categorical_cols = [
    'workclass',
    'education',
    'marital.status',
    'occupation',
    'relationship',
    'race',
    'sex',
    'native.country',
    'income',
]

numerical_cols = [
    'age',
    'fnlwgt',
    'education.num',
    'capital.gain',
    'capital.loss',
    'hours.per.week',
]

income_reader = ttf.DataReader(
    label='income',
    file_path=income_dataset_path,
    ensure_numerical_cols=numerical_cols,
    ensure_categorical_cols=categorical_cols,
)

In [None]:
# Split the dataset into 'test' and 'train' parts. `split[name]` is later
# passed as a `file_path`, one entry per part.
# NOTE(review): 0.2 presumably means a 20% test fraction and -1 presumably
# means "all remaining rows" — confirm with DataReader.split_data.
split = income_reader.split_data(dict(test=0.2, train=-1))

In [None]:
# Pick the compute device and the precision string handed to TrainSettings.
# On CUDA, bfloat16 is used when the GPU supports it; every other case
# (including CPU) falls back to the 'float16' string.
if torch.cuda.is_available():
    device = 'cuda'
    dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16'
else:
    device = 'cpu'
    dtype = 'float16'

In [None]:
# Runtime settings: the device and precision chosen in the previous cell.
ts = ttf.TrainSettings(dtype=dtype, device=device)

# Optimisation schedule; values are untuned demo settings.
tp = ttf.TrainParameters(
    max_iters=3000,
    batch_size=128,
    learning_rate=5e-4,
    warmup_iters=100,
    eval_interval=100,
    eval_iters=20,
    # NOTE(review): presumably the fraction of the training data held out
    # for validation — confirm against the TrainParameters docs.
    validate_split=0.2,
    output_checkpoint='ckpt.pt',
)

# Model size (presumably embedding width and number of layers — confirm).
hp = ttf.HyperParameters(n_layers=6, dim=64)

In [None]:
# Build the trainer from the model and runtime configs, then fit it on the
# 'train' part of the split. Calling `income_reader(file_path=...)`
# presumably yields a reader with the same column setup pointed at that
# file — confirm.
trainer = ttf.Trainer(ts=ts, hp=hp)
trainer.train(tp=tp, data_reader=income_reader(file_path=split['train']))

In [None]:
# NOTE(review): training was configured with output_checkpoint='ckpt.pt',
# but the checkpoint is loaded here from 'out/ckpt.pt'. This presumably
# relies on the trainer writing into an 'out/' directory by default — confirm.
predictor = ttf.Predictor(checkpoint='out/ckpt.pt')

# Run the trained model on the held-out 'test' part and save the output
# as "prediction_income.csv".
predictor.predict(
    save_as="prediction_income.csv",
    data_reader=income_reader(file_path=split['test']),
)