In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# local modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import roc_auc_score, accuracy_score

import pandas as pd
import numpy as np
# Seed every random number generator the notebook relies on, not just NumPy:
# the NumPy seed fixes the train/valid/test split drawn later, while the torch
# seed makes the torch-side randomness (weight init, shuffling, dropout)
# repeatable across kernel restarts.
np.random.seed(0)
torch.manual_seed(0)

import scipy

import os
from pathlib import Path

import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

In [3]:
# Dataset identifier; it is also the stem of the CSV file and is echoed in
# the final score messages.
dataset_name = 'census-income'
# Location of the CSV: <current working directory>/data/<dataset_name>.csv
out = Path(f"{os.getcwd()}/data/{dataset_name}.csv")

In [4]:
# Load the full dataset into one frame; the train/valid/test membership of
# each row is tracked in the "Set" column.
train = pd.read_csv(out)
# NOTE(review): the column names printed by the encoding cell ("39",
# " State-gov", " <=50K", ...) look like data values, which suggests the CSV
# has no header row and read_csv is consuming the first record as the header.
# Confirm against the file before relying on these names.
# Label column; the leading space is part of the parsed column name.
target = ' <=50K'
if "Set" not in train.columns:
    # Draw the 80/10/10 split from a dedicated generator seeded here, instead
    # of the global NumPy RNG. This keeps the cell idempotent: re-running it
    # on its own reproduces the same split, regardless of how much global
    # random state other cells have consumed in the meantime.
    train["Set"] = np.random.RandomState(0).choice(
        ["train", "valid", "test"], p=[.8, .1, .1], size=(train.shape[0],)
    )

# Row index labels belonging to each split.
train_indices = train[train.Set=="train"].index
valid_indices = train[train.Set=="valid"].index
test_indices = train[train.Set=="test"].index

In [5]:
# Column statistics taken before any column is modified by the loop below.
nunique = train.nunique()
types = train.dtypes

# Names of the columns treated as categorical, and the number of distinct
# encoded classes for each of them.
categorical_columns = []
categorical_dims =  {}
for col in train.columns:
    # A column is categorical if it holds objects (strings) or has fewer than
    # 200 distinct values. Note that this also label-encodes the target and
    # the "Set" column.
    if types[col] == 'object' or nunique[col] < 200:
        print(col, train[col].nunique())
        l_enc = LabelEncoder()
        # Missing values become their own category before encoding.
        train[col] = train[col].fillna("VV_likely")
        train[col] = l_enc.fit_transform(train[col].values)
        categorical_columns.append(col)
        categorical_dims[col] = len(l_enc.classes_)
    else:
        # Impute ONLY this numeric column, using the mean of its training
        # rows. The previous `train.fillna(value, inplace=True)` filled the
        # NaNs of every column in the frame with this column's mean, which
        # corrupted unrelated columns (including categorical ones that had
        # not yet been encoded).
        train[col] = train[col].fillna(train.loc[train_indices, col].mean())

39 73
 State-gov 9
 Bachelors 16
 13 16
 Never-married 7
 Adm-clerical 15
 Not-in-family 6
 White 5
 Male 2
 2174 119
 0 92
 40 94
 United-States 42
 <=50K 2
Set 3


In [6]:
# Columns that are neither model inputs nor the label.
unused_feat = ['Set']

# Model input columns, in frame order: everything except the unused columns
# and the target.
features = [c for c in train.columns if c != target and c not in unused_feat]

# Positions (within `features`) of the categorical inputs ...
cat_idxs = [pos for pos, c in enumerate(features) if c in categorical_columns]

# ... and, aligned with `cat_idxs`, the number of classes of each one.
cat_dims = [categorical_dims[features[pos]] for pos in cat_idxs]

# Positions of the remaining (numeric) inputs, in ascending order.
num_features = [pos for pos in range(len(features)) if pos not in cat_idxs]

In [7]:
def _split_arrays(row_ids):
    """Return (float feature matrix, target vector) for the given rows.

    `row_ids` is used to index the underlying NumPy arrays, i.e. it is
    interpreted positionally.
    """
    return (
        train[features].values[row_ids].astype(float),
        train[target].values[row_ids],
    )


X_train, y_train = _split_arrays(train_indices)
X_valid, y_valid = _split_arrays(valid_indices)
X_test, y_test = _split_arrays(test_indices)

In [8]:
# Standardize the numeric columns with statistics computed on the training
# split only, then apply the same shift/scale to the validation and test
# splits.
mean = X_train[:, num_features].mean(axis=0)
std = X_train[:, num_features].std(axis=0)
# A column that is constant on the training split has std == 0; dividing by
# it would fill the column with NaN/inf. Scale such columns by 1 instead so
# they simply become centered.
std[std == 0] = 1.0

# The X_* matrices are already float arrays, so no extra cast is needed.
X_train[:, num_features] = (X_train[:, num_features] - mean) / std
X_valid[:, num_features] = (X_valid[:, num_features] - mean) / std
X_test[:, num_features] = (X_test[:, num_features] - mean) / std

In [9]:
from tabr.model import TabRClassifier

# Build the TabR classifier. `cat_indices` / `cat_cardinalities` receive the
# positions and class counts of the categorical inputs computed earlier;
# `d_main` and `context_size` are model hyperparameters defined by
# TabRClassifier (their exact meaning is not visible in this notebook).
clf = TabRClassifier(
    cat_indices=cat_idxs,
    cat_cardinalities=cat_dims,
    # Use the GPU when one is visible and fall back to CPU otherwise, so the
    # notebook still runs on machines without CUDA.
    # NOTE(review): assumes TabRClassifier accepts "cpu" as device_name — confirm.
    device_name="cuda" if torch.cuda.is_available() else "cpu",
    d_main=96,
    context_size=96,
)

In [10]:
# Train with two evaluation sets: index 0 is the test split (logged as
# 'val_0_auc') and index 1 is the validation split (logged as 'val_1_auc').
# The training log reports early stopping on best_val_1_auc, i.e. on the
# validation split, which is listed last.
clf.fit(
    X_train,
    y_train,
    eval_set=[(X_test, y_test), (X_valid, y_valid)],
    max_epochs=100,
    batch_size=1024,
)

 epochs:   0%|          | 0/100 [00:00<?, ?it/s]

 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.8645661605760337}
{'val_1_auc': 0.8629436783887046}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.8832279758589148}
{'val_1_auc': 0.8838786286430265}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.8915904805220688}
{'val_1_auc': 0.893358435243252}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.8989350122391827}
{'val_1_auc': 0.9042659396062538}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9050111660185096}
{'val_1_auc': 0.9105398189167823}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9093549260794244}
{'val_1_auc': 0.9149891217954045}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9128258458833931}
{'val_1_auc': 0.9182366014826747}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.915405528328981}
{'val_1_auc': 0.9210738844246699}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9171593853323998}
{'val_1_auc': 0.9236490676281253}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9195183025607743}
{'val_1_auc': 0.9255056074427803}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9209256808205109}
{'val_1_auc': 0.9266248585620314}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9220586356505163}
{'val_1_auc': 0.9280435000330288}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.923345410689738}
{'val_1_auc': 0.9290301031138728}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9243807586760219}
{'val_1_auc': 0.9303938741111517}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9259031188197238}
{'val_1_auc': 0.930466324445382}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9263850206711876}
{'val_1_auc': 0.9314129732977899}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9275414829085817}
{'val_1_auc': 0.9309287280491468}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.92783174828677}
{'val_1_auc': 0.9322243104965618}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9284536725213737}
{'val_1_auc': 0.9324576431906275}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9290750857253824}
{'val_1_auc': 0.9325316916939954}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9288476771104286}
{'val_1_auc': 0.9326393017492494}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9290342032777502}
{'val_1_auc': 0.9326185255504629}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9299699002979309}
{'val_1_auc': 0.9330526948327994}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9301921986069306}
{'val_1_auc': 0.9330649474628532}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9303950777533052}
{'val_1_auc': 0.9332647186050328}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9301727794443054}
{'val_1_auc': 0.9329099250565219}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9306434386226704}
{'val_1_auc': 0.9331416595814502}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9306332180107624}
{'val_1_auc': 0.9334794059924948}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9312806937751363}
{'val_1_auc': 0.933394170305165}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9312444106028628}
{'val_1_auc': 0.9333542160767293}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9313823888636213}
{'val_1_auc': 0.9327612953267403}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9313415064159891}
{'val_1_auc': 0.9329594682997824}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9318939304896184}
{'val_1_auc': 0.9329930298516685}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9322036150304319}
{'val_1_auc': 0.9336109885848105}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9317319337908762}
{'val_1_auc': 0.9334879295612282}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9316660108440693}
{'val_1_auc': 0.9342134983496241}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9320487727600252}
{'val_1_auc': 0.9332402133449254}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.931908750376885}
{'val_1_auc': 0.9332194371461386}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9321448465119606}
{'val_1_auc': 0.9332609895437122}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9319741622930965}
{'val_1_auc': 0.9333046728334686}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9325833107628154}
{'val_1_auc': 0.9333840485672946}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9314968597169913}
{'val_1_auc': 0.933264185881987}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9323610124538155}
{'val_1_auc': 0.9334458444406089}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.931589356254759}
{'val_1_auc': 0.9329392248240416}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.9318898422448553}
{'val_1_auc': 0.9334772751003118}


 batches:   0%|          | 0/25 [00:00<?, ?it/s]

{'val_0_auc': 0.931427870586612}
{'val_1_auc': 0.9333526179075917}

Early stopping occurred at epoch 45 with best_epoch = 35 and best_val_1_auc = 0.93421




In [11]:
# Score the fitted model on the test and validation splits. Column 1 of the
# prediction array is used as the ranking score for ROC AUC (presumably the
# positive-class probability — confirm against TabRClassifier.predict).
preds = clf.predict(X_test)
preds_valid = clf.predict(X_valid)

test_auc = roc_auc_score(y_test, preds[:, 1])
valid_auc = roc_auc_score(y_valid, preds_valid[:, 1])

print(f"FINAL VALID SCORE FOR {dataset_name} : {valid_auc}")
print(f"FINAL TEST SCORE FOR {dataset_name} : {test_auc}")

FINAL VALID SCORE FOR census-income : 0.9342134983496241
FINAL TEST SCORE FOR census-income : 0.9316660108440693


# CatBoost

In [12]:
from catboost import CatBoostClassifier

In [13]:
# Wrap each NumPy feature matrix in a DataFrame (default integer column
# labels) so individual columns can be converted to strings for CatBoost.
df_X_train, df_X_valid, df_X_test = (
    pd.DataFrame(split) for split in (X_train, X_valid, X_test)
)

In [14]:
# Display the positions (within `features`) of the categorical columns.
cat_idxs

[0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]

In [15]:
# The matrices were cast to float earlier, so the label-encoded categorical
# columns hold values like 3.0. Convert them back to integers and then to
# strings ("3") in every split before they are passed as cat_features.
for frame in (df_X_train, df_X_valid, df_X_test):
    for cat_id in cat_idxs:
        frame[cat_id] = frame[cat_id].astype(int).astype(str)

In [16]:
# CatBoost model with default hyperparameters, trained on the same split.
# NOTE: this rebinds `clf`, so the TabR model fitted earlier is no longer
# reachable under that name after this cell runs.
clf = CatBoostClassifier()
# Only the validation split is used as eval_set; training stops once the
# eval metric has not improved for 50 consecutive iterations.
clf.fit(
    X=df_X_train,
    y=y_train,
    cat_features=cat_idxs,
    eval_set=[(df_X_valid, y_valid)],
    early_stopping_rounds=50,
)

Learning rate set to 0.070909
0:	learn: 0.6242844	test: 0.6223686	best: 0.6223686 (0)	total: 109ms	remaining: 1m 49s
1:	learn: 0.5742973	test: 0.5707437	best: 0.5707437 (1)	total: 131ms	remaining: 1m 5s
2:	learn: 0.5248284	test: 0.5205044	best: 0.5205044 (2)	total: 158ms	remaining: 52.4s
3:	learn: 0.4870848	test: 0.4824703	best: 0.4824703 (3)	total: 177ms	remaining: 44.2s
4:	learn: 0.4565065	test: 0.4517592	best: 0.4517592 (4)	total: 210ms	remaining: 41.9s
5:	learn: 0.4296383	test: 0.4247726	best: 0.4247726 (5)	total: 227ms	remaining: 37.7s
6:	learn: 0.4095931	test: 0.4043557	best: 0.4043557 (6)	total: 246ms	remaining: 35s
7:	learn: 0.3924569	test: 0.3869281	best: 0.3869281 (7)	total: 264ms	remaining: 32.7s
8:	learn: 0.3792239	test: 0.3736618	best: 0.3736618 (8)	total: 288ms	remaining: 31.7s
9:	learn: 0.3660921	test: 0.3608975	best: 0.3608975 (9)	total: 310ms	remaining: 30.7s
10:	learn: 0.3568991	test: 0.3515171	best: 0.3515171 (10)	total: 328ms	remaining: 29.5s
11:	learn: 0.3492019	te

<catboost.core.CatBoostClassifier at 0x7fc92f6e6a10>

In [17]:
# ROC AUC on the validation split, scored with column 1 of predict_proba.
roc_auc_score(y_true=y_valid, y_score=clf.predict_proba(df_X_valid)[:, 1])

0.9349715632438146

In [18]:
# ROC AUC on the test split, scored with column 1 of predict_proba.
roc_auc_score(y_true=y_test, y_score=clf.predict_proba(df_X_test)[:, 1])

0.9325201984842832