<a target="_blank" href="https://colab.research.google.com/github/echosprint/TabularTransformer/blob/main/notebooks/classification.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

---


**For more details about the [TabularTransformer](https://github.com/echosprint/TabularTransformer) model**,
see the online **[documentation](https://echosprint.github.io/TabularTransformer/)**.

---

- This notebook provides a usage example of the
  [TabularTransformer](https://github.com/echosprint/TabularTransformer)
  package.
- Hyperparameters are not tuned and may be suboptimal.

In [None]:
%pip install git+https://github.com/echosprint/TabularTransformer.git

In [10]:
import tabular_transformer as ttf
import pandas as pd
import torch

In [11]:
iris_dataset_path = ttf.prepare_iris_dataset()

more details see website: https://huggingface.co/datasets/scikit-learn/iris
/home/qiao/work/TabularTransformer/notebooks/data/iris/iris.csv already exists, skipping download.


In [12]:
class IrisDataReader(ttf.DataReader):
    ensure_categorical_cols = ['Id', 'Species']
    ensure_numerical_cols = [
        'SepalLengthCm', 'SepalWidthCm',
        'PetalLengthCm', 'PetalWidthCm'
    ]

    def read_data_file(self, file_path):
        df = pd.read_csv(file_path)
        return df

In [13]:
iris_reader = IrisDataReader(iris_dataset_path)
df = iris_reader.read_data_file()
print(df.head(3))

  Id  SepalLengthCm  SepalWidthCm  PetalLengthCm  PetalWidthCm      Species
0  1            5.1           3.5            1.4           0.2  Iris-setosa
1  2            4.9           3.0            1.4           0.2  Iris-setosa
2  3            4.7           3.2            1.3           0.2  Iris-setosa


In [14]:
split = iris_reader.split_data(
    {'train': 0.8, 'test': -1})
print(split)

split: train, n_samples: 120
/home/qiao/work/TabularTransformer/notebooks/data/iris/iris_train.csv *exists*, delete old split `train`
save split `train` at path: /home/qiao/work/TabularTransformer/notebooks/data/iris/iris_train.csv
split: test, n_samples: 30
/home/qiao/work/TabularTransformer/notebooks/data/iris/iris_test.csv *exists*, delete old split `test`
save split `test` at path: /home/qiao/work/TabularTransformer/notebooks/data/iris/iris_test.csv
{'train': PosixPath('/home/qiao/work/TabularTransformer/notebooks/data/iris/iris_train.csv'), 'test': PosixPath('/home/qiao/work/TabularTransformer/notebooks/data/iris/iris_test.csv')}
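The output above shows the 150 Iris rows divided into 120 training rows and 30 test rows, which suggests that `0.8` is read as a fraction of the data and `-1` as "whatever is left". The sketch below illustrates that reading in plain Python. The helper `resolve_split_sizes` is hypothetical and is not part of the `tabular_transformer` API; it only reproduces the row counts printed above.

```python
def resolve_split_sizes(n_rows, spec):
    """Turn a {'name': fraction} spec into row counts.

    Assumption (inferred from the notebook output, not from the library
    source): a value of -1 means "all rows not claimed by other splits".
    """
    sizes = {}
    remainder_key = None
    for name, frac in spec.items():
        if frac == -1:
            if remainder_key is not None:
                raise ValueError("only one split may use -1")
            remainder_key = name
        else:
            sizes[name] = round(n_rows * frac)
    if remainder_key is not None:
        # The remainder split receives every row that is still unassigned.
        sizes[remainder_key] = n_rows - sum(sizes.values())
    return sizes


print(resolve_split_sizes(150, {'train': 0.8, 'test': -1}))
```

With 150 rows this yields 120 rows for `train` and 30 for `test`, matching the `n_samples` values in the log.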


In [15]:
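# Run on the GPU when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Prefer bfloat16 where the GPU supports it; otherwise use float16.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
dtype = 'bfloat16' if use_bf16 else 'float16'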

In [16]:
ts = ttf.TrainSettings(wandb_log=False,
                       device=device,
                       dtype=dtype,
                       )


hp = ttf.HyperParameters(dim=32,
                         n_layers=4,
                         n_heads=4,
                         output_forward_dim=4,
                         output_hidden_dim=64)

trainer = ttf.Trainer(hp=hp, ts=ts)

In [17]:
train_tp = ttf.TrainParameters(
    learning_rate=5e-4,
    train_epochs=100,
    loss_type='MULCE',
    batch_size=16,
    output_dim=3,
    eval_interval=10,
    eval_iters=2,
    warmup_iters=5,
    validate_split=0.2,
    output_checkpoint='iris_ckpt.pt')

trainer.train(
    data_reader=IrisDataReader(split['train']),
    tp=train_tp,
    resume=False)

load dataset from file: /home/qiao/work/TabularTransformer/notebooks/data/iris/iris_train.csv
num parameter tensors: 44, with 59,668 parameters
Transformer num decayed parameter tensors: 29, with 53,536 parameters
Transformer num non-decayed parameter tensors: 9, with 288 parameters
Output num decayed parameter tensors: 5, with 5,824 parameters
Output num non-decayed parameter tensors: 1, with 20 parameters
using fused AdamW: True
step 0: train loss 1.1016, val loss 1.1016
0 | loss 1.1016 | lr 0.000000e+00 | 26.76ms | mfu -100.00%
1 | loss 1.1016 | lr 1.000000e-04 | 8.63ms | mfu -100.00%
2 | loss 1.1016 | lr 2.000000e-04 | 8.94ms | mfu -100.00%
3 | loss 1.1016 | lr 3.000000e-04 | 8.51ms | mfu -100.00%
4 | loss 1.1016 | lr 4.000000e-04 | 8.53ms | mfu -100.00%
5 | loss 1.1011 | lr 5.000000e-04 | 8.54ms | mfu  0.00%
6 | loss 1.0996 | lr 4.999965e-04 | 8.60ms | mfu  0.00%
7 | loss 1.1006 | lr 4.999861e-04 | 9.30ms | mfu  0.00%
8 | loss 1.0986 | lr 4.999686e-04 | 8.59ms | mfu  0.00%
9 | los
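The `lr` column in the log is consistent with a linear warmup over `warmup_iters=5` steps followed by a cosine decay. The sketch below is an inference from the logged values, not the library's actual scheduler. It assumes a decay horizon of 600 iterations (120 training rows, minus the 20% validation split, leaves 96 rows; at `batch_size=16` that is 6 iterations per epoch over 100 epochs) and a minimum learning rate of 0. The function `lr_at` and its parameter names are hypothetical.

```python
import math


def lr_at(it, base_lr=5e-4, warmup_iters=5, decay_iters=600, min_lr=0.0):
    """Linear warmup followed by cosine decay.

    decay_iters and min_lr are assumptions chosen to match the log above.
    """
    if it < warmup_iters:
        # Ramp linearly from 0 up to base_lr.
        return base_lr * it / warmup_iters
    if it >= decay_iters:
        return min_lr
    # Fraction of the decay phase completed, in [0, 1).
    ratio = (it - warmup_iters) / (decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))
    return min_lr + coeff * (base_lr - min_lr)


for it in range(9):
    print(f"{it} | lr {lr_at(it):e}")
```

Under these assumptions the printed values agree with the logged learning rates for steps 0 through 8, including `4.999965e-04` at step 6.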

In [18]:
predictor = ttf.Predictor(checkpoint='out/iris_ckpt.pt')
predictor.predict(data_reader=IrisDataReader(split['test']),
                  save_as="iris_predictions.csv")

load checkpoint from out/iris_ckpt.pt
cross entropy loss: 0.1503
auc score: 1.0000
f1 macro score: 0.9167
samples: 30, accuracy: 0.9333
save prediction output to file: out/iris_predictions.csv
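An accuracy of 0.9333 on 30 samples corresponds to 28 correct predictions. The macro F1 score is lower because it averages the per-class F1 values with equal weight, so errors concentrated in one class pull it down more. The sketch below shows how both metrics can be computed in plain Python. The labels are made-up toy data, not the notebook's predictions, and the helper names are hypothetical rather than part of `tabular_transformer`.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def macro_f1(y_true, y_pred):
    """Unweighted mean of the per-class F1 scores."""
    scores = []
    for c in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); defined as 0 when the class never occurs.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Toy labels for illustration only (three classes, one mistake).
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 2, 2, 2]
print(f"accuracy: {accuracy(y_true, y_pred):.4f}")
print(f"f1 macro: {macro_f1(y_true, y_pred):.4f}")
```

In this toy case a single misclassification gives an accuracy of 5/6 while the per-class F1 scores are 1, 2/3 and 0.8, so the macro average comes out slightly lower than the accuracy.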
