<a target="_blank" href="https://colab.research.google.com/github/huoyushequ/TabularTransformer/blob/main/income_analysis.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

---

**See also** [RTDL](https://github.com/yandex-research/rtdl)
-- **other projects on tabular deep learning**.

---

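- This notebook provides a usage example of the
  [TabularTransformer](https://github.com/huoyushequ/TabularTransformer)
  package (`tabular_transformer`), trained on the Adult Census Income dataset.
  For related models, see also
  [rtdl_revisiting_models](https://github.com/yandex-research/rtdl-revisiting-models).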
- Hyperparameters are not tuned and may be suboptimal.

In [None]:
# Install tabular_transformer straight from GitHub. `%pip` (rather than `!pip`)
# installs into the running kernel's environment. No ref is given in the URL,
# so this tracks the repository's default branch — consider pinning a tag or
# commit (`...TabularTransformer.git@<ref>`) for reproducible runs.
%pip install git+https://github.com/huoyushequ/TabularTransformer.git

In [1]:
import pandas as pd
import os

import warnings
# Silence UserWarning messages to keep cell output readable.
warnings.filterwarnings('ignore', category=UserWarning)

# Source dataset: https://huggingface.co/datasets/scikit-learn/adult-census-income
# pandas reads the `hf://` URL through fsspec; this needs `huggingface_hub` installed.
data_url = "hf://datasets/scikit-learn/adult-census-income/adult.csv"
fname = "income.csv"

# Local cache layout: ./data/<fname without extension>/<fname>
data_cache_dir = os.path.join(os.getcwd(), 'data', os.path.splitext(fname)[0])
os.makedirs(data_cache_dir, exist_ok=True)
full_path = os.path.join(data_cache_dir, fname)

if not os.path.exists(full_path):
    print(f"Downloading {data_url} to {fname} ...")
    df = pd.read_csv(data_url)
    # Write to a temporary file and rename it into place. An interrupted write
    # therefore never leaves a truncated CSV at `full_path`, which later runs
    # would otherwise treat as a valid cache and silently load.
    tmp_path = full_path + ".tmp"
    df.to_csv(tmp_path, index=False)
    os.replace(tmp_path, full_path)
    print(f"save data at path: {full_path}")
else:
    df = pd.read_csv(full_path)
    print(f"{full_path} already exists, skipping download.")

/home/qiao/work/TabularTransformer/data/income/income.csv already exists, skipping download.


In [2]:
# Peek at the first three rows of the freshly loaded frame.
df.iloc[:3]

Unnamed: 0,age,workclass,fnlwgt,education,education.num,marital.status,occupation,relationship,race,sex,capital.gain,capital.loss,hours.per.week,native.country,income
0,90,?,77053,HS-grad,9,Widowed,?,Not-in-family,White,Female,0,4356,40,United-States,<=50K
1,82,Private,132870,HS-grad,9,Widowed,Exec-managerial,Not-in-family,White,Female,0,4356,18,United-States,<=50K
2,66,?,186061,Some-college,10,Widowed,?,Unmarried,Black,Female,0,4356,40,United-States,<=50K


In [3]:
from tabular_transformer import DataReader, TrainSettings, TrainParameters, HyperParameters, Trainer

class IncomeDataReader(DataReader):
    """DataReader for the Adult Census Income CSV.

    Lists which columns are categorical and which are numerical (presumably
    consumed by the DataReader base class to fix column types — confirm), and
    loads the raw file with pandas.
    """

    # Categorical columns, including the `income` column (the target).
    ensure_categorical_cols = [
        'workclass',
        'education',
        'marital.status',
        'occupation',
        'relationship',
        'race',
        'sex',
        'native.country',
        'income',
    ]
    # Numerical columns.
    ensure_numerical_cols = [
        'age',
        'fnlwgt',
        'education.num',
        'capital.gain',
        'capital.loss',
        'hours.per.week',
    ]

    def read_data_file(self, file_path):
        """Parse `file_path` as CSV and return the resulting DataFrame."""
        return pd.read_csv(file_path)

In [4]:
# Wrap the cached CSV in the custom reader and load it back through it to
# check that the subclass parses the file.
income_reader = IncomeDataReader(full_path)
# NOTE(review): `read_data_file` is declared with a `file_path` parameter but is
# called without one here; the DataReader base presumably supplies the path — confirm.
df = income_reader.read_data_file()
df.iloc[:3]

Unnamed: 0,age,workclass,fnlwgt,education,education.num,marital.status,occupation,relationship,race,sex,capital.gain,capital.loss,hours.per.week,native.country,income
0,90,?,77053,HS-grad,9,Widowed,?,Not-in-family,White,Female,0,4356,40,United-States,<=50K
1,82,Private,132870,HS-grad,9,Widowed,Exec-managerial,Not-in-family,White,Female,0,4356,18,United-States,<=50K
2,66,?,186061,Some-college,10,Widowed,?,Unmarried,Black,Female,0,4356,40,United-States,<=50K


In [6]:
# Library defaults for settings and hyperparameters; only the training
# schedule is customised below (hyperparameters are not tuned).
ts = TrainSettings()
tp = TrainParameters(
    train_epochs=10,
    # presumably the learning-rate warmup length; the logged lr rises linearly
    # over the first steps — confirm against TrainParameters docs
    warmup_iters=10,
)
hp = HyperParameters()

trainer = Trainer(ts=ts, hp=hp)
trainer.train(tp=tp, data_reader=income_reader)

load dataset from file: /home/qiao/work/TabularTransformer/data/income/income.csv
num parameter tensors: 62, with 356,400 parameters
Transformer num decayed parameter tensors: 43, with 323,584 parameters
Output num decayed parameter tensors: 5, with 31,872 parameters
Transformer num non-decayed parameter tensors: 13, with 832 parameters
Output num non-decayed parameter tensors: 1, with 112 parameters
using fused AdamW: True
step 0: train loss 0.6933, val loss 0.6933
0 | loss 0.6934 | lr 0.000000e+00 | 2121.54ms | mfu -100.00%
1 | loss 0.6933 | lr 5.000000e-05 | 14.80ms | mfu -100.00%
2 | loss 0.6932 | lr 1.000000e-04 | 13.67ms | mfu -100.00%
3 | loss 0.6925 | lr 1.500000e-04 | 13.33ms | mfu -100.00%
4 | loss 0.6917 | lr 2.000000e-04 | 13.48ms | mfu -100.00%
5 | loss 0.6902 | lr 2.500000e-04 | 13.14ms | mfu 0.10%
6 | loss 0.6883 | lr 3.000000e-04 | 13.56ms | mfu 0.10%
7 | loss 0.6880 | lr 3.500000e-04 | 13.22ms | mfu 0.10%
8 | loss 0.6833 | lr 4.000000e-04 | 13.47ms | mfu 0.10%
9 | loss

In [7]:
# Inspect what the trainer keeps after `train()` returns. The recorded output is
# NoneType, i.e. `trainer.model` is None at this point.
# NOTE(review): confirm whether the trainer intentionally drops the model after
# training, or whether it should be retained for evaluation/inference.
type(trainer.model)

NoneType

In [8]:
# Same check for the optimizer: the recorded output is NoneType, so
# `trainer.optimizer` is None after `train()` returns.
# NOTE(review): confirm this is intended (e.g. resources released after training).
type(trainer.optimizer)

NoneType