[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ogunlao/saint/blob/main/notebooks/Income_Dataset.ipynb)

N.B.: This notebook shows usage of the previous version of SAINT. That version is kept in the `saint_orig` branch of the repository, which is the branch cloned below.

# Cloning the repo & installing requirements

In [None]:
!git clone -b saint_orig --single-branch https://github.com/ogunlao/saint.git

In [None]:
import pandas as pd
import numpy as np

In [None]:
!pip3 install -r 'saint/requirements.txt' 

# Reading the data

In [None]:
# download income dataset (from kaggle) and unzip
# save in 'data' directory

!unzip income.zip -d data

In [None]:
# Load the raw Kaggle train/test files extracted by the previous cell
# (train first, then test).
train, test = (
    pd.read_csv(csv_path)
    for csv_path in ('/content/data/train.csv', '/content/data/test.csv')
)

In [None]:
# Display the raw column names; the label column `income_>50K` is split off below.
train.keys()

# Custom preprocessing

In [None]:
# Separate the label column from the training features.
# The guard makes the cell safe to re-run: after the first run `train` no longer
# contains the label column, and an unguarded `train[[...]]` would raise KeyError
# (the `train_y` produced by the first run is kept).
TARGET_COL = 'income_>50K'
if TARGET_COL in train.columns:
    # Double brackets keep `train_y` a one-column DataFrame rather than a Series.
    train_y = train[[TARGET_COL]]
    train = train.drop(columns=TARGET_COL)

In [None]:
# Stack train on top of test so both pass through `preprocess` together
# (presumably so the categorical encodings are shared — TODO confirm).
# Train rows come first; the split cell below relies on that ordering.
df = pd.concat((train, test), axis=0)


In [None]:
from saint.src.dataset import preprocess   

In [None]:
# Run the repo's preprocessing on the combined train+test frame. Judging by the
# names and the printout below, it yields the processed features, the processed
# labels, the numerical/categorical column counts and the per-column categories
# (NOTE(review): semantics of `cls_token_idx` not visible here — see saint.src.dataset).
processed_data, train_y, no_num, no_cat, cats = preprocess(
    df,
    train_y,
    cls_token_idx=0,
)

In [None]:
# These values must be copied into the dataset config under "configs/data/"
# before training. NOTE(review): the original note pointed at "configs/data/bank_*",
# which looks copied from the bank-dataset notebook — confirm the income config name.
for label, value in (
    ('no of numerical columns: ', no_num),
    ('no of categorical columns: ', no_cat),
):
    print(label, value)

print('list of categories in each categorical column: ', cats)

# Splitting the dataset

In [None]:
from saint.src.dataset import generate_splits

In [None]:
# Undo the concatenation: `df` stacked train before test, so the first `n_train`
# rows of `processed_data` are training rows and the remainder are test rows.
# Capture the length *before* `train` is rebound instead of relying on
# `len(train)` happening to be unchanged by the reassignment.
n_train = len(train)
# The positional split is only valid if preprocessing kept every row, in order.
assert len(processed_data) == n_train + len(test), 'preprocess changed the row count'
assert len(train_y) == n_train, 'labels and training rows are misaligned'
train = processed_data.iloc[:n_train]
test = processed_data.iloc[n_train:]

# Carve a validation set out of the labelled training rows
# (validation_split=0.25 is presumably the held-out fraction; test_split=0 because
# the Kaggle test file is kept separately above).
train_indices, val_indices = generate_splits(dataset_size=n_train,
                                             num_supervised_train_data='all',
                                             validation_split=0.25,
                                             test_split=0,
                                             random_seed=1234)

x_train, y_train = train.iloc[train_indices], train_y.iloc[train_indices]
x_val, y_val = train.iloc[val_indices], train_y.iloc[val_indices]

In [None]:
# Sanity-check the split rather than peeking at a hard-coded row position, which
# raises IndexError as soon as the training split is smaller than that index.
# Features and labels must stay row-aligned in both splits.
assert len(x_train) == len(y_train)
assert len(x_val) == len(y_val)
y_train.shape, y_val.shape

# Saving CSV files

In [None]:
# Write each split (features and labels separately, without the index) into the
# cloned repo's data/ directory; presumably main.py picks them up from there via
# its config — TODO confirm.
for frame, out_path in (
    (x_train, '/content/saint/data/train.csv'),
    (y_train, '/content/saint/data/train_y.csv'),
    (x_val, '/content/saint/data/val.csv'),
    (y_val, '/content/saint/data/val_y.csv'),
):
    frame.to_csv(out_path, index=False)


In [None]:
%cd '/content/saint/'

In [None]:
# run this cell to train saint model using config

!python main.py

In [None]:
!ls /content/saint/checkpoints/lightning_logs/version_0/checkpoints/