<a target="_blank" href="https://colab.research.google.com/github/huoyushequ/TabularTransformer/blob/main/income_analysis.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

---

**See also** [RTDL](https://github.com/yandex-research/rtdl)
-- **other projects on tabular deep learning**.

---

- This notebook provides a usage example of the
  [TabularTransformer](https://github.com/huoyushequ/TabularTransformer)
  package on the Adult Census Income dataset.
- Hyperparameters are not tuned and may be suboptimal.

In [None]:
# Install the TabularTransformer package directly from its GitHub repository
# (no version or commit is pinned, so the latest default branch is installed).
%pip install git+https://github.com/huoyushequ/TabularTransformer.git

In [None]:
import pandas as pd
import os

import warnings
# Silence UserWarning messages for the remainder of the session.
warnings.filterwarnings('ignore', category=UserWarning)

# Dataset page: https://huggingface.co/datasets/scikit-learn/adult-census-income
data_url = "hf://datasets/scikit-learn/adult-census-income/adult.csv"
fname = "income.csv"

# Local cache layout: <cwd>/data/<name before the first dot>/<fname>
data_cache_dir = os.path.join(os.getcwd(), 'data', fname.partition('.')[0])
os.makedirs(data_cache_dir, exist_ok=True)
full_path = os.path.join(data_cache_dir, fname)

if os.path.exists(full_path):
    # Cache hit: load the local copy instead of downloading again.
    df = pd.read_csv(full_path)
    print(f"{full_path} already exists, skipping download.")
else:
    # Cache miss: read from the remote URL and keep a local CSV copy.
    print(f"Downloading {data_url} to {fname} ...")
    df = pd.read_csv(data_url)
    df.to_csv(full_path, index=False)
    print(f"save data at path: {full_path}")

In [None]:
# Preview the first three rows of the loaded DataFrame.
df.head(3)

In [None]:
from tabular_transformer import DataReader, TrainSettings, TrainParameters, HyperParameters, Trainer

class IncomeDataReader(DataReader):
    """Reader for the Adult Census Income CSV file.

    Lists the column names to be treated as categorical and as numerical,
    and loads the file with pandas.
    """

    # Column names declared as categorical (this list includes ``income``).
    ensure_categorical_cols = [
        'workclass',
        'education',
        'marital.status',
        'occupation',
        'relationship',
        'race',
        'sex',
        'native.country',
        'income',
    ]

    # Column names declared as numerical.
    ensure_numerical_cols = [
        'age',
        'fnlwgt',
        'education.num',
        'capital.gain',
        'capital.loss',
        'hours.per.week',
    ]

    def read_data_file(self, file_path):
        """Load the CSV at ``file_path`` and return it as a pandas DataFrame."""
        return pd.read_csv(file_path)

In [None]:
# Create a reader for the cached CSV and preview the data it loads.
income_reader = IncomeDataReader(full_path)
# NOTE(review): IncomeDataReader.read_data_file is declared with a required
# `file_path` parameter, but no argument is passed here. Confirm that the
# DataReader base class supplies the path; otherwise this raises TypeError.
# This also rebinds `df`, replacing the frame loaded in the download cell.
df = income_reader.read_data_file()
df.head(3)

In [None]:
# Training configuration: settings left at their defaults, a short run of
# 3 epochs with 10 warmup iterations, and a model with dim=1024 and 12 layers.
# As the notebook introduction notes, these values are not tuned.
ts = TrainSettings()
tp = TrainParameters(train_epochs=3, warmup_iters=10)
hp = HyperParameters(dim=1024, n_layers=12)

# Build the trainer from the hyperparameters and settings, then train it on
# the data provided by the income reader.
trainer = Trainer(hp=hp, ts=ts)
trainer.train(data_reader=income_reader, tp=tp)

In [None]:
# Show the class of the model object held by the trainer.
type(trainer.model)

In [None]:
# Show the class of the optimizer object held by the trainer.
type(trainer.optimizer)

In [None]:
import torch

In [None]:
# Release unused GPU memory that PyTorch's caching allocator is holding.
torch.cuda.empty_cache()

In [None]:
# Report CUDA memory usage, one value per line:
# first the memory occupied by tensors, then the memory reserved by the allocator.
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved(), sep="\n")


In [None]:
import torch
import gc

# NOTE: `del trainer` is intentionally left disabled here because the final
# cell still references `trainer`. While that reference is alive, the objects
# it holds cannot be collected, so this step only reclaims unreachable objects.
gc.collect()  # Run Python's garbage collector

torch.cuda.empty_cache()  # Clear unused memory cached by PyTorch


In [None]:
# Report CUDA memory usage again after garbage collection, one value per line:
# first the memory occupied by tensors, then the memory reserved by the allocator.
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved(), sep="\n")

In [None]:
# Display the trainer object's representation as the cell output.
trainer