<a target="_blank" href="https://colab.research.google.com/github/echosprint/TabularTransformer/blob/main/notebooks/higgs_classification.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

---


**For more details about the [TabularTransformer](https://github.com/echosprint/TabularTransformer) model**,
check the online **[documentation](https://echosprint.github.io/TabularTransformer/)**.

---

- This notebook provides a usage example of the
  [TabularTransformer](https://github.com/echosprint/TabularTransformer)
  package.
- Hyperparameters are not tuned and may be suboptimal.

In [None]:
# Install the TabularTransformer package straight from GitHub, plus pyarrow,
# which this notebook uses to read the CSV and write/read the Parquet splits.
%pip install git+https://github.com/echosprint/TabularTransformer.git
%pip install pyarrow

In [None]:
from pathlib import Path
from typing import Dict

import pandas as pd
import pyarrow as pa
import tabular_transformer as ttf
import torch
from pyarrow import csv, parquet, compute

In [None]:
# Obtain the raw HIGGS dataset through the package helper. The returned value
# is later handed to split_data(), which treats it as the path of a CSV file.
# NOTE(review): presumably this downloads/extracts the file on first use — confirm.
higgs_path = ttf.prepare_higgs_dataset()

In [None]:
# Column names assigned to the HIGGS CSV when it is read: the class label
# first, followed by the 28 feature columns, in file order.
# NOTE(review): the three "lepton" names contain two consecutive spaces —
# confirm this is intentional before relying on exact-name lookups.
higgs_cols = ["label", "lepton  pT", "lepton  eta", "lepton  phi",
              "missing energy magnitude", "missing energy phi",
              "jet 1 pt", "jet 1 eta", "jet 1 phi", "jet 1 b-tag",
              "jet 2 pt", "jet 2 eta", "jet 2 phi", "jet 2 b-tag",
              "jet 3 pt", "jet 3 eta", "jet 3 phi", "jet 3 b-tag",
              "jet 4 pt", "jet 4 eta", "jet 4 phi", "jet 4 b-tag",
              "m_jj", "m_jjj", "m_lv", "m_jlv", "m_bb", "m_wbb", "m_wwbb"]

In [None]:
class HiggsDataReader(ttf.DataReader):
    """Data reader for the HIGGS Parquet splits produced by this notebook.

    The class attributes declare how columns are typed: ``label`` is listed as
    categorical and every other name in ``higgs_cols`` as numerical.
    NOTE(review): how ``ttf.DataReader`` consumes these two attributes is
    defined by the package — confirm against its documentation.
    """

    ensure_categorical_cols = ['label']
    ensure_numerical_cols = [name for name in higgs_cols if name != 'label']

    def read_data_file(self, file_path):
        """Read one Parquet file and return its contents as a pandas DataFrame.

        Args:
            file_path: Path of the Parquet file to load.

        Returns:
            The file's table converted with ``Table.to_pandas()``.
        """
        arrow_table = parquet.read_table(file_path)
        # Progress message is emitted once the Parquet read has finished,
        # before the conversion to pandas.
        print(f"complete dataset loading: {file_path}")
        return arrow_table.to_pandas()

In [None]:
def load_and_preprocess(file_path) -> pa.Table:
    """Load the raw HIGGS CSV and normalise its label column.

    The CSV is read with the column names from ``higgs_cols``. The ``label``
    column is cast to 32-bit integers and moved to the last position; the
    remaining columns keep their relative order.

    Args:
        file_path: Path of the HIGGS CSV file.

    Returns:
        A pyarrow Table with the feature columns followed by ``label``.
    """
    print(f"load dataset {file_path}, it may take a few minutes.")
    read_opts = csv.ReadOptions(column_names=higgs_cols)
    raw = csv.read_csv(file_path, read_options=read_opts)
    print("load dataset complete.")

    # Replace the label column in place (same index, same name) with its
    # int32 version.
    int_labels = compute.cast(raw['label'], pa.int32())
    label_idx = raw.column_names.index('label')
    typed = raw.set_column(label_idx, 'label', int_labels)

    # Reorder so that the label comes after all feature columns.
    feature_names = [name for name in typed.column_names if name != 'label']
    return typed.select(feature_names + ['label'])

In [None]:
def split_data(file_path) -> Dict[str, Path]:
    """Split the HIGGS CSV into train/test Parquet files stored next to it.

    The first 10,500,000 rows form the train split and the following 500,000
    rows the test split. The splits are written beside ``file_path`` as
    ``higgs_train.parquet`` and ``higgs_test.parquet``. If both files already
    exist, the CSV is not read and the existing paths are returned as-is.

    Args:
        file_path: Location of the raw HIGGS CSV (``str`` or ``Path``).

    Returns:
        A dict mapping ``'train'`` and ``'test'`` to the Parquet file paths.
    """
    # Number of rows per split; rows are taken in file order, train first.
    split = {'train': 10_500_000,
             'test': 500_000}

    file_path = Path(file_path)

    split_path = {'train': 'higgs_train.parquet',
                  'test': 'higgs_test.parquet', }

    # Place each split file in the same directory as the source CSV.
    split_path = {sp: file_path.with_name(fn) for sp, fn in split_path.items()}

    # Skip the CSV load when a previous run has already produced both files.
    if all(path.exists() for path in split_path.values()):
        print("split already exists, skip split.")
        return split_path

    table = load_and_preprocess(file_path)

    # Table.slice takes (offset, length): the test rows start right after
    # the last train row.
    train_table = table.slice(0, split['train'])
    test_table = table.slice(split['train'], split['test'])
    print("save split train on disk..")
    parquet.write_table(train_table, split_path['train'])
    print("save split test on disk..")
    parquet.write_table(test_table, split_path['test'])
    print("save split complete")
    return split_path

In [None]:
# Create (or reuse) the train/test Parquet files; `split` maps 'train' and
# 'test' to their paths and is used by the training and prediction cells.
split = split_data(higgs_path)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# bfloat16 only when CUDA is present and reports bf16 support; float16 in
# every other case (including CPU-only runs).
dtype = (
    'bfloat16'
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else 'float16'
)

In [None]:
# Runtime settings: where to run and in which precision.
# NOTE(review): wandb_log=True presumably enables Weights & Biases logging and
# may require a wandb login — confirm before running unattended.
ts = ttf.TrainSettings(
    device=device,
    dtype=dtype,
    unk_ratio_default=0,
    wandb_log=True,
)

# Training parameters: schedule, loss, batch size, evaluation cadence and the
# checkpoint file name.
tp = ttf.TrainParameters(
    train_epochs=5,
    learning_rate=5e-4,
    output_dim=1,
    loss_type='BINCE',
    batch_size=1024,
    eval_interval=1000,
    eval_iters=100,
    warmup_iters=1000,
    validate_split=0.2,
    output_checkpoint='higgs_r2_ckpt.pt',
)

# Model hyperparameters (not tuned; see the note at the top of the notebook).
hp = ttf.HyperParameters(
    dim=768,
    n_layers=12,
    n_heads=16,
    output_forward_dim=32,
    output_hidden_dim=256,
)

In [None]:
# Build the trainer from the model hyperparameters (hp) and runtime settings (ts).
trainer = ttf.Trainer(hp=hp, ts=ts)

# Train on the train split, using the training parameters in `tp`.
trainer.train(data_reader=HiggsDataReader(split['train']), tp=tp)

In [None]:
# Load the checkpoint written during training.
# NOTE(review): assumes the trainer saves 'higgs_r2_ckpt.pt' under 'out/' —
# confirm the package's output directory.
predictor = ttf.Predictor(checkpoint='out/higgs_r2_ckpt.pt')

# Run prediction on the test split and save the result as a CSV file.
# The test Parquet file still contains the 'label' column.
# NOTE(review): has_truth=True presumably tells the predictor to use that
# column as ground truth — confirm.
predictor.predict(
    data_reader=HiggsDataReader(split['test']),
    save_as="prediction_higgs.csv",
    has_truth=True
)