In [None]:
import torch
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import roc_auc_score

import pandas as pd
import numpy as np
# Seed every RNG the notebook relies on so Restart & Run All is reproducible:
# numpy drives the train/valid/test assignment below, and torch is seeded too
# because NodeClassifier (lib.model) is presumably torch-based — NOTE(review):
# confirm it draws from torch's default generator.
np.random.seed(0)
torch.manual_seed(0)


import os
import wget
from pathlib import Path

from matplotlib import pyplot as plt
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Automatically reload edited modules (e.g. the local lib.data / lib.model
# imported further down) before each cell executes.
%load_ext autoreload
%autoreload 2

# Download the census-income (UCI Adult) dataset

In [None]:
# Source file: the UCI Machine Learning Repository "adult" data.
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
dataset_name = 'census-income'
# Local cache location: <cwd>/data/census-income.csv
out = Path(f"{os.getcwd()}/data/{dataset_name}.csv")

In [None]:
# Fetch the CSV only once; later runs reuse the cached copy under data/.
out.parent.mkdir(parents=True, exist_ok=True)
if not out.exists():
    print("Downloading file...")
    wget.download(url, out.as_posix())
else:
    print("File already exists.")

In [None]:
# adult.data has no header row, so header=None keeps the first record as data
# instead of silently consuming it as column names. The placeholder names a..o
# are assigned afterwards; 'o' is used as the prediction target further down.
train = pd.read_csv(out, header=None)
train.columns = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
                 'j', 'k', 'l', 'm', 'n', 'o']

In [None]:
# Randomly assign each row to train/valid/test (80/10/10).
# The 'Set' column is dropped at the end of this cell, so every re-run
# regenerates the split. A dedicated, explicitly seeded RNG makes that
# regeneration identical each time. Drawing from the global np.random stream
# would yield a different partition on every re-run.
# On a fresh kernel, RandomState(0) produces the same draws as the global
# np.random.seed(0) in the first cell, because no visible code draws from the
# global stream in between.
if "Set" not in train.columns:
    split_rng = np.random.RandomState(0)
    train["Set"] = split_rng.choice(["train", "valid", "test"], p=[.8, .1, .1], size=(train.shape[0],))

train_indices = train[train.Set == "train"].index
valid_indices = train[train.Set == "valid"].index
test_indices = train[train.Set == "test"].index

# Drop the helper column and replace NaNs with -1 (no in-place mutation).
train = train.drop(columns=['Set']).fillna(-1)

In [None]:
# Peek at the first rows of the split/filled frame.
train.head(n=5)

In [None]:
from lib.data import preprocess

In [None]:
# Row-index labels for each split, keyed by split name; passed to preprocess().
split_indices = {
    "train": train_indices,
    "valid": valid_indices,
    "test": test_indices,
}

In [None]:
# 'o' is the last placeholder column name assigned at load time and serves as the target.
# NOTE(review): the exact meaning of quantile_transform lives in lib.data.preprocess.
data = preprocess(
    train,
    target='o',
    split_indices=split_indices,
    quantile_transform=True,
)

In [None]:
from lib.model import NodeClassifier

In [None]:
# Show a compact summary (array shapes where available) instead of dumping every
# array in full. NOTE(review): assumes preprocess() returns a dict, as the
# data['X_train'] lookups in later cells suggest.
{key: getattr(value, 'shape', value) for key, value in data.items()}

In [None]:
# output_dim=2: two output columns; the evaluation cell scores preds[:, 1] as the
# positive class. input_dim follows the number of preprocessed feature columns.
# NOTE(review): categorical-embedding kwargs (cat_idxs / cat_dims from the
# preprocess output, cat_emb_dim) were tried here and left disabled — confirm
# NodeClassifier supports them before re-enabling.
model = NodeClassifier(
    layer_dim=128,
    input_dim=data['X_train'].shape[1],
    output_dim=2,
)

In [None]:
# Train on the training split while passing the validation split for monitoring.
# NOTE(review): plot=True presumably plots training curves — confirm in lib.model.
model.fit(
    data['X_train'],
    data['y_train'],
    X_valid=data['X_valid'],
    y_valid=data['y_valid'],
    plot=True,
)

In [None]:
# Restore the checkpoint saved under the "best" tag before predicting
# (presumably the best-validation weights from fit — confirm in lib.model).
model.trainer.load_checkpoint(tag="best")

In [None]:
# Per-class scores for the held-out test split; column 1 feeds the AUC below.
preds = model.predict(data["X_test"])

In [None]:
from sklearn.metrics import roc_auc_score

In [None]:
# Test-set ROC AUC, scoring column 1 of the predictions as the positive class.
roc_auc_score(data["y_test"], preds[:, 1])