<a target="_blank" href="https://colab.research.google.com/github/echosprint/TabularTransformer/blob/main/notebooks/supervised_training.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

---


**For more details about the [TabularTransformer](https://github.com/echosprint/TabularTransformer) model**,
check the online **[documentation](https://echosprint.github.io/TabularTransformer/)**.

---

- This notebook provides a usage example of the
  [TabularTransformer](https://github.com/echosprint/TabularTransformer)
  package.
- Hyperparameters are not tuned and may be suboptimal.

In [None]:
%pip install tabular-transformer

In [1]:
import tabular_transformer as ttf
import torch

In [2]:
income_dataset_path = ttf.prepare_income_dataset()

more details see website: https://huggingface.co/datasets/scikit-learn/adult-census-income
/home/qiao/work/TabularTransformer/notebooks/data/income/income.csv already exists, skipping download.


In [3]:
categorical_cols = [
    'workclass', 'education',
    'marital.status', 'occupation',
    'relationship', 'race', 'sex',
    'native.country', 'income']

# all remaining columns are numerical, so they need not be listed explicitly
numerical_cols = []


In [4]:
income_reader = ttf.DataReader(
    file_path=income_dataset_path,
    ensure_categorical_cols=categorical_cols,
    ensure_numerical_cols=numerical_cols,
    label='income',
)

In [5]:
split = income_reader.split_data({'test': 0.2, 'train': -1})

start reading file, it may take a while..
read file completed.
split: test, n_samples: 6512
/home/qiao/work/TabularTransformer/notebooks/data/income/income_test.csv *exists*, delete old split `test`
save split `test` at path: /home/qiao/work/TabularTransformer/notebooks/data/income/income_test.csv
split: train, n_samples: 26049
/home/qiao/work/TabularTransformer/notebooks/data/income/income_train.csv *exists*, delete old split `train`
save split `train` at path: /home/qiao/work/TabularTransformer/notebooks/data/income/income_train.csv
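
The split sizes in the log above add up to 32,561 rows (6,512 + 26,049), which is consistent with reading `-1` as "take whatever is left". The sketch below reproduces those counts under that assumption. `split_sizes` is a hypothetical helper written for illustration, not part of the `tabular_transformer` API, and truncating the fractional count is also an assumption.

```python
def split_sizes(n_samples, ratios):
    """Illustrative only: turn split ratios into sample counts.

    A ratio of -1 is assumed to mean "all remaining samples".
    """
    sizes = {}
    remaining = n_samples
    rest_key = None
    for name, ratio in ratios.items():
        if ratio == -1:
            rest_key = name          # filled in after the fixed-ratio splits
            continue
        sizes[name] = int(n_samples * ratio)  # assumed truncation
        remaining -= sizes[name]
    if rest_key is not None:
        sizes[rest_key] = remaining
    return sizes


# 32561 is the sum of the two sample counts printed in the log
sizes = split_sizes(32561, {'test': 0.2, 'train': -1})
print(sizes)  # {'test': 6512, 'train': 26049}
```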


In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_available() \
    and torch.cuda.is_bf16_supported() else 'float16'

In [7]:
ts = ttf.TrainSettings(device=device, dtype=dtype)

hp = ttf.HyperParameters(dim=64, n_layers=6)

In [8]:
tp = ttf.TrainParameters(max_iters=3000, learning_rate=5e-4,
                         output_dim=1, loss_type='BINCE',
                         batch_size=128, eval_interval=100,
                         eval_iters=20, warmup_iters=100,
                         validate_split=0.2)
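
With `learning_rate=5e-4` and `warmup_iters=100`, the learning rates printed in the training log (0 at step 0, 5e-06 at step 1, 1e-05 at step 2) match a linear warmup. The sketch below shows that relationship. `warmup_lr` is a hypothetical function, not the library's scheduler, and the schedule after warmup is not visible in the log, so it is not modeled here.

```python
def warmup_lr(step, learning_rate=5e-4, warmup_iters=100):
    """Illustrative linear warmup: scale the peak rate by step / warmup_iters."""
    if step < warmup_iters:
        return learning_rate * step / warmup_iters
    # Placeholder: the post-warmup schedule is unknown, so hold the peak rate.
    return learning_rate


# Values for the first three steps, formatted like the training log
for step in range(3):
    print(f"{step} | lr {warmup_lr(step):e}")
```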

In [9]:
trainer = ttf.Trainer(hp=hp, ts=ts)

trainer.train(
    data_reader=income_reader(file_path=split['train']),
    tp=tp)

start reading file, it may take a while..
read file completed.
num parameter tensors: 62, with 356,464 parameters
Transformer num decayed parameter tensors: 43, with 323,648 parameters
Transformer num non-decayed parameter tensors: 13, with 832 parameters
Output num decayed parameter tensors: 5, with 31,872 parameters
Output num non-decayed parameter tensors: 1, with 112 parameters




using fused AdamW: True
step 0: train loss 0.6931, val loss 0.6932
0 | epoch 0.0000 | loss 0.6931 |lr 0.000000e+00 | 549.59ms | mfu -100.00%
1 | epoch 0.0062 | loss 0.6932 |lr 5.000000e-06 | 8.87ms | mfu -100.00%
2 | epoch 0.0123 | loss 0.6931 |lr 1.000000e-05 | 7.72ms | mfu -100.00%
3 | epoch 0.0185 | loss 0.6930 |lr 1.500000e-05 | 7.49ms | mfu -100.00%
4 | epoch 0.0247 | loss 0.6930 |lr 2.000000e-05 | 7.22ms | mfu -100.00%
5 | epoch 0.0309 | loss 0.6929 |lr 2.500000e-05 | 8.07ms | mfu  0.16%
6 | epoch 0.0370 | loss 0.6926 |lr 3.000000e-05 | 7.56ms | mfu  0.16%
7 | epoch 0.0432 | loss 0.6923 |lr 3.500000e-05 | 7.33ms | mfu  0.16%
8 | epoch 0.0494 | loss 0.6920 |lr 4.000000e-05 | 7.35ms | mfu  0.16%
9 | epoch 0.0556 | loss 0.6917 |lr 4.500000e-05 | 7.01ms | mfu  0.16%
10 | epoch 0.0617 | loss 0.6917 |lr 5.000000e-05 | 6.97ms | mfu  0.16%
11 | epoch 0.0679 | loss 0.6918 |lr 5.500000e-05 | 7.29ms | mfu  0.17%
12 | epoch 0.0741 | loss 0.6900 |lr 6.000000e-05 | 7.38ms | mfu  0.17%
13 | epo
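
The starting loss of 0.6931 in the log above is what binary cross entropy gives when the model outputs a probability of 0.5 for every sample, since -ln(0.5) = ln 2. The snippet below checks that value. `binary_cross_entropy` is a plain-Python helper written for this illustration, and it is an assumption that `loss_type='BINCE'` corresponds to this standard formula.

```python
import math


def binary_cross_entropy(p, y):
    """Standard binary cross entropy for one sample with label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


# An untrained classifier that predicts 0.5 gets the same loss for either label
loss = binary_cross_entropy(0.5, 1)
print(f"{loss:.4f}")  # 0.6931
```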

In [13]:
predictor = ttf.Predictor(checkpoint='out/ckpt.pt')

prediction = predictor.predict(
    data_reader=income_reader(file_path=split['test']),
    save_as="prediction_income.csv"
)


load checkpoint from out/ckpt.pt
start reading file, it may take a while..
read file completed.


100%|██████████| 7/7 [00:00<00:00, 441.91it/s]

binary cross entropy loss: 0.279975
auc score: 0.926880
f1 macro score: 0.817514
samples: 6512, accuracy: 0.8722
save prediction output to file: out/prediction_income.csv





In [14]:
prediction.head(3)

Unnamed: 0,prediction,probability
0,<=50K,0.02063
1,<=50K,0.474609
2,>50K,0.816406
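
The three rows above are consistent with labeling a sample `>50K` when its probability reaches 0.5. The sketch below applies that rule to the same probabilities. The 0.5 threshold and the `to_label` helper are assumptions made for illustration; the output does not confirm how `Predictor` chooses the label.

```python
def to_label(probability, threshold=0.5):
    """Illustrative only: map a positive-class probability to an income label."""
    return '>50K' if probability >= threshold else '<=50K'


# Probabilities copied from the first three prediction rows
probabilities = [0.02063, 0.474609, 0.816406]
labels = [to_label(p) for p in probabilities]
print(labels)  # ['<=50K', '<=50K', '>50K']
```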
