<a target="_blank" href="https://colab.research.google.com/github/echosprint/TabularTransformer/blob/main/notebooks/classification.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

---


**For more details about the [TabularTransformer](https://github.com/echosprint/TabularTransformer) model**,
check the online **[documentation](https://echosprint.github.io/TabularTransformer/)**.

---

- This notebook provides a usage example of the
  [TabularTransformer](https://github.com/echosprint/TabularTransformer)
  package.
- Hyperparameters are not tuned and may be suboptimal.

In [None]:
# Install the package into the running kernel's environment (%pip, unlike !pip, targets the kernel).
# Consider pinning a version (e.g. tabular-transformer==X.Y.Z) for reproducible results.
%pip install tabular-transformer

In [18]:
import tabular_transformer as ttf
import pandas as pd
import torch

In [None]:
# Obtain the Iris dataset through the package helper; the returned path is the file the DataReader below reads.
iris_dataset_path = ttf.prepare_iris_dataset()

In [20]:
# Column roles for the Iris CSV: the row id and the label are categorical,
# the four flower measurements are numerical.
categorical_cols = [
    'Id',
    'Species',
]
numerical_cols = [
    'SepalLengthCm',
    'SepalWidthCm',
    'PetalLengthCm',
    'PetalWidthCm',
]

In [21]:
# Reader for the raw Iris file: which column is the target, which is the row id,
# and explicit type hints for every remaining column.
iris_reader = ttf.DataReader(
    file_path=iris_dataset_path,
    header=True,
    # target and identifier columns
    label='Species',
    id='Id',
    # column type declarations (defined in the previous cell)
    ensure_categorical_cols=categorical_cols,
    ensure_numerical_cols=numerical_cols,
)

In [None]:
# Load the whole dataset through the reader and convert it to pandas for inspection.
df = iris_reader.read().to_pandas()
# Last expression -> rich (HTML) table in the notebook instead of plain-text print().
df.head(3)

In [None]:
# Split the raw file into a training part (80%) and a test part.
# NOTE(review): 'test': -1 appears to mean "whatever fraction remains" — confirm against DataReader.split_data.
# The result maps each part name to a file path (used as `file_path=` in the training and prediction cells).
split = iris_reader.split_data({'train': 0.8, 'test': -1})
split

In [24]:
# Choose where to train and which reduced-precision dtype to request.
# bfloat16 is used only on GPUs that support it; every other case gets 'float16'.
# NOTE(review): on CPU this still selects 'float16' — confirm how TrainSettings handles that combination.
if torch.cuda.is_available():
    device = 'cuda'
    dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16'
else:
    device = 'cpu'
    dtype = 'float16'

In [25]:
# Runtime settings: Weights & Biases logging off; device and dtype come from the previous cell.
ts = ttf.TrainSettings(
    wandb_log=False,
    device=device,
    dtype=dtype,
)

# Model architecture — untuned values (see the note at the top of the notebook).
hp = ttf.HyperParameters(
    dim=32,
    n_layers=4,
    n_heads=4,
    output_forward_dim=4,
    output_hidden_dim=64,
)

trainer = ttf.Trainer(hp=hp, ts=ts)

In [None]:
# Training-run configuration (illustrative values, not tuned).
train_tp = ttf.TrainParameters(
    # optimisation schedule
    learning_rate=5e-4,
    max_iters=200,
    # objective: 'MULCE' presumably selects multi-class cross-entropy;
    # output_dim=3 matches the three Iris species in the 'Species' label
    loss_type='MULCE',
    batch_size=16,
    output_dim=3,
    # periodic evaluation
    eval_interval=10,
    eval_iters=2,
    warmup_iters=5,
    validate_split=0.2,  # presumably the fraction of the training file held out for validation
    output_checkpoint='iris_ckpt.pt',
)

# Train on the 'train' part of the split; resume=False presumably starts a fresh run.
trainer.train(
    data_reader=iris_reader(file_path=split['train']),
    tp=train_tp,
    resume=False,
)

In [None]:
# Load the trained checkpoint and predict on the held-out 'test' part of the split.
# NOTE(review): 'out/' appears to be the trainer's default output directory; keep this path
# in sync with `output_checkpoint='iris_ckpt.pt'` in the training cell.
predictor = ttf.Predictor(checkpoint='out/iris_ckpt.pt')
prediction = predictor.predict(
    data_reader=iris_reader(file_path=split['test']),
    save_as="iris_predictions.csv",  # presumably also written to disk as CSV
)

In [None]:
# Preview the first three prediction rows.
prediction.head(3)