In [1]:
import torch
import numpy as np
import sys

# Put the project's source directories at the front of the import search path
# so the local modules imported below resolve. Resulting order matches two
# successive insert(0, ...) calls: './src/include' first, then './src'.
sys.path[:0] = ['./src/include', './src']

from datasets.dataset import transform_dataset, kfold_dataset
from sklearn.metrics import roc_auc_score
from R2Ntab import R2Ntab
from DRNet import train as train_drnet, DRNet

## Prepare data

In [2]:
# Load and encode the dataset selected by `name`.
name = 'adult'
X, Y, X_headers, Y_headers = transform_dataset(
    name,
    method='onehot-compare',
    negations=False,
    labels='binary',
)

# Split into folds and work with the first one only.
# NOTE(review): `shuffle=1` is passed through as-is; presumably a seed or a
# flag -- confirm against kfold_dataset.
datasets = kfold_dataset(X, Y, shuffle=1)
X_train, X_test, Y_train, Y_test = datasets[0]

# Wrap each (features, labels) pair in a TensorDataset. The feature frames are
# converted with .to_numpy() first; the labels are handed to torch.Tensor directly.
train_set, test_set = (
    torch.utils.data.TensorDataset(torch.Tensor(features.to_numpy()), torch.Tensor(targets))
    for features, targets in ((X_train, Y_train), (X_test, Y_test))
)

## Run DR-Net

In [3]:
# The network's input width is the size of dimension 1 of the training
# feature tensor.
n_features = train_set[:][0].size(1)

# NOTE(review): the positional arguments 20 and 1 are presumably the number of
# rules and the output size -- confirm against the DRNet constructor.
drnet = DRNet(n_features, 20, 1)
train_drnet(
    drnet,
    train_set,
    test_set,
    device='cpu',
    epochs=1000,
    batch_size=400,
)

Epoch: 100%|██████████| 1000/1000 [03:11<00:00,  5.22it/s, loss=0.552, epoch acc


In [7]:
# Evaluate DR-Net on the held-out fold and list the rules it learned.
# roc_auc_score expects (y_true, y_score): the ground-truth labels go first and
# the model's outputs second. Passing them the other way round treats the
# predictions as the truth and yields a different (and misleading) AUC.
auc_drnet = roc_auc_score(Y_test, drnet.predict(np.array(X_test)))
rules_drnet = drnet.get_rules(X_headers)
print(rules_drnet)
# Summary line: AUC, number of rules, and total number of conditions across rules.
print(f'AUC: {auc_drnet}, num rules: {len(rules_drnet)}, num conditions: {sum(map(len, rules_drnet))}')

[['NOT age<=26.0', 'NOT education==7th-8th', 'NOT education==9th', 'NOT marital-status==Divorced', 'marital-status==Married-civ-spouse', 'NOT occupation==Farming-fishing', 'NOT occupation==Handlers-cleaners', 'NOT occupation==Machine-op-inspct', 'NOT occupation==Other-service', 'NOT occupation==Sales', 'NOT relationship==Not-in-family', 'NOT relationship==Own-child', 'NOT capital-gain<=0.0', 'NOT hours-per-week<=25.0', 'NOT hours-per-week<=40.0', 'NOT hours-per-week<=42.0', 'NOT native-country==Poland'], ['NOT age<=26.0', 'NOT workclass==Local-gov', 'NOT education==Assoc-voc', 'NOT education==Bachelors', 'NOT education==Some-college', 'NOT education_num<=7.0', 'marital-status==Married-civ-spouse', 'NOT occupation==Adm-clerical', 'NOT occupation==Craft-repair', 'NOT occupation==Farming-fishing', 'NOT occupation==Handlers-cleaners', 'NOT occupation==Machine-op-inspct', 'NOT occupation==Other-service', 'NOT occupation==Priv-house-serv', 'NOT occupation==Protective-serv', 'NOT occupation==

## Run R2N-Tab

In [5]:
# Build an R2N-Tab model with one input per encoded feature header, then fit
# it on the training fold.
# NOTE(review): the positional arguments 20 and 1 are presumably the number of
# rules and the output size -- confirm against the R2Ntab constructor.
n_inputs = len(X_headers)
r2ntab = R2Ntab(n_inputs, 20, 1)
r2ntab.fit(
    train_set,
    epochs=1000,
    batch_size=400,
    cancel_lam=1e-2,
)

100%|█████████| 1000/1000 [03:42<00:00,  4.50it/s]


In [6]:
# Predict on the held-out fold, extract (and print) the learned rule list, then
# report AUC together with the rule and condition counts.
# NOTE(review): X_test is passed here as-is, whereas the DR-Net cell wraps it in
# np.array(...) -- confirm R2Ntab.predict accepts this input type.
Y_pred = r2ntab.predict(X_test)
# print_rules=True makes extract_rules print the rule list before the summary line below.
rules = r2ntab.extract_rules(X_headers, print_rules=True)
# NOTE(review): score is called as (Y_pred, Y_test); verify this matches the
# parameter order R2Ntab.score expects for the "auc" metric.
print(f'AUC: {r2ntab.score(Y_pred, Y_test, metric="auc")}, num rules: {len(rules)}, num conditions: {sum(map(len, rules))}')

Rulelist:
if [ not age<=26.0 && not education_num<=7.0 && not marital-status==Divorced && marital-status==Married-civ-spouse && not marital-status==Never-married && not occupation==Farming-fishing && not occupation==Machine-op-inspct && not occupation==Other-service && not relationship==Not-in-family && not capital-gain<=0.0 && not hours-per-week<=25.0 && native-country==United-States ]:
  prediction = true
else if [ not age<=26.0 && not age<=30.0 && marital-status==Married-civ-spouse && occupation==Exec-managerial && not capital-loss<=0.0 && not hours-per-week<=25.0 && hours-per-week<=42.0 ]:
  prediction = true
else if [ not age<=26.0 && not education_num<=10.0 && not education_num<=11.0 && not marital-status==Married-civ-spouse && not occupation==Other-service && not capital-gain<=0.0 && not hours-per-week<=25.0 ]:
  prediction = true
else if [ not age<=30.0 && age<=57.0 && not education_num<=13.0 && not marital-status==Divorced && marital-status==Married-civ-spouse && not marital-s