<a target="_blank" href="https://colab.research.google.com/github/echosprint/TabularTransformer/blob/main/income_analysis.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

---


**For more details about the [TabularTransformer](https://github.com/echosprint/TabularTransformer) model**,
check the online **[documentation](https://echosprint.github.io/TabularTransformer/)**.

---

- This notebook provides a usage example of the
  [TabularTransformer](https://github.com/echosprint/TabularTransformer)
  package.
- Hyperparameters are not tuned and may be suboptimal.

In [None]:
# Install the TabularTransformer package directly from its GitHub repository.
# The URL carries no tag or commit, so the repository's default branch is used.
%pip install git+https://github.com/echosprint/TabularTransformer.git

In [None]:
import tabular_transformer as ttf
import pandas as pd
import torch

In [None]:
# Obtain the location of the fish dataset; the value is handed to
# FishDataReader further down.
# NOTE(review): presumably this downloads or unpacks the data on first use —
# confirm against the package documentation.
fish_dataset_path = ttf.prepare_fish_dataset()

In [None]:
class FishDataReader(ttf.DataReader):
    """Data reader for the fish CSV file used in this notebook.

    Supplies the categorical and numerical column names to the base reader
    and loads the file with pandas.
    """

    # Column names declared as categorical.
    ensure_categorical_cols = ['Species']

    # Column names declared as numerical.
    ensure_numerical_cols = [
        'Weight',
        'Length1', 'Length2', 'Length3',
        'Height', 'Width',
    ]

    def read_data_file(self, file_path):
        """Read the CSV at ``file_path`` and return it as a DataFrame."""
        return pd.read_csv(file_path)

In [None]:
# Build a reader over the whole dataset and load it for a quick look.
fish_reader = FishDataReader(fish_dataset_path)
# NOTE(review): no path is passed here although the subclass method declares
# `file_path` — presumably ttf.DataReader supplies the path given at
# construction; confirm.
df = fish_reader.read_data_file()
# Leave the preview as the cell's last expression so Jupyter renders it as a
# rich table instead of the plain-text output produced by print().
df.head(3)

In [None]:
# Partition the dataset into named parts: 'train' is given 0.8 and 'test' -1.
# NOTE(review): -1 presumably means "whatever remains" — confirm in the
# split_data documentation.
split = fish_reader.split_data(dict(train=0.8, test=-1))
print(split)

In [None]:
# Pick the compute device and the precision string passed to TrainSettings.
if torch.cuda.is_available():
    device = 'cuda'
    # Use bfloat16 when the GPU supports it, otherwise fall back to float16.
    dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16'
else:
    device = 'cpu'
    # NOTE(review): without CUDA the dtype is still 'float16' — confirm the
    # trainer handles (or ignores) this setting on CPU.
    dtype = 'float16'

In [None]:
# Run settings: wandb logging switched off, device and dtype as selected above.
ts = ttf.TrainSettings(
    wandb_log=False,
    device=device,
    dtype=dtype,
)

# Model hyperparameters (not tuned).
hp = ttf.HyperParameters(
    dim=32,
    n_layers=4,
    n_heads=4,
    output_forward_dim=4,
    output_hidden_dim=64,
)

# The trainer is built from both configuration objects.
trainer = ttf.Trainer(hp=hp, ts=ts)

In [None]:
# Training configuration: MSE loss with a single output, constant learning
# rate, and the checkpoint written under the name 'fish_ckpt.pt'.
train_tp = ttf.TrainParameters(
    # optimisation
    learning_rate=5e-4,
    lr_scheduler='constant',
    warmup_iters=5,
    # schedule and batching
    train_epochs=50,
    batch_size=16,
    # objective
    loss_type='MSE',
    output_dim=1,
    # evaluation
    eval_interval=10,
    eval_iters=2,
    validate_split=0.23,
    # output
    output_checkpoint='fish_ckpt.pt',
)

# Train from scratch (resume=False) on the 'train' part of the split.
trainer.train(
    data_reader=FishDataReader(split['train']),
    tp=train_tp,
    resume=False,
)

In [None]:
# Load the trained checkpoint and run prediction on the 'test' part of the
# split, saving the result as fish_predictions.csv.
# NOTE(review): training names the checkpoint 'fish_ckpt.pt' while this loads
# 'out/fish_ckpt.pt' — presumably the trainer writes into 'out/'; confirm.
predictor = ttf.Predictor(checkpoint='out/fish_ckpt.pt')
predictor.predict(
    data_reader=FishDataReader(split['test']),
    save_as="fish_predictions.csv",
)