In [1]:
import torch
import numpy as np

from datasets.dataset import transform_dataset, kfold_dataset
from R2Ntab import train as train_r2ntab, R2Ntab
from DRNet import train as train_drnet, DRNet

import sys

In [2]:
# Load the dataset and wrap one train/test split as tensor datasets.
name = 'adult'
X, Y, X_headers, Y_headers = transform_dataset(
    name,
    method='onehot-compare',
    negations=False,
    labels='binary',
)

# Only the first split returned by kfold_dataset is used in this notebook.
datasets = kfold_dataset(X, Y, shuffle=1)
X_train, X_test, Y_train, Y_test = datasets[0]

# The features expose .to_numpy(); the labels are handed to torch.Tensor as-is.
train_features = torch.Tensor(X_train.to_numpy())
train_labels = torch.Tensor(Y_train)
train_set = torch.utils.data.TensorDataset(train_features, train_labels)

test_features = torch.Tensor(X_test.to_numpy())
test_labels = torch.Tensor(Y_test)
test_set = torch.utils.data.TensorDataset(test_features, test_labels)

In [3]:
# Train DR-Net.
# The learning rate (1e-2), and_lam (1e-2) and or_lam (1e-5) used here are the
# values noted as usually working best. A large number of epochs (e.g. 10000)
# is needed for a sparse rule set; this run uses 1000.
n_features = train_set[:][0].size(1)  # width of the feature tensor
net = DRNet(n_features, 50, 1)
train_drnet(
    net,
    train_set,
    test_set=test_set,
    device='cpu',
    lr=1e-2,
    epochs=1000,
    batch_size=400,
    and_lam=1e-2,
    or_lam=1e-5,
    num_alter=500,
)

Epoch: 100%|██████████| 1000/1000 [07:33<00:00,  2.21it/s, loss=0.544, epoch accu=0.831, test accu=0.833, num rules=14, sparsity=0.883]


In [4]:
# Evaluate the trained DR-Net on the held-out split and extract its rules.
test_predictions = net.predict(np.array(X_test))
accu = (test_predictions == Y_test).mean()  # fraction of predictions equal to the labels

# Each rule is a sized collection of conditions, so the lengths sum to the condition count.
rules = net.get_rules(X_headers)
print(f'Accuracy: {accu}, num rules: {len(rules)}, num conditions: {sum(map(len, rules))}')

Accuracy: 0.8332504558262888, num rules: 14, num conditions: 209


In [5]:
# Print the DR-Net rule set obtained from net.get_rules(X_headers) in the previous cell.
print(rules)

[['NOT age<=30.0', 'NOT workclass==Self-emp-not-inc', 'NOT education_num<=7.0', 'education_num<=9.0', 'marital-status==Married-civ-spouse', 'NOT occupation==Adm-clerical', 'NOT occupation==Craft-repair', 'NOT occupation==Machine-op-inspct', 'NOT occupation==Other-service', 'NOT occupation==Protective-serv', 'NOT occupation==Sales', 'NOT occupation==Transport-moving', 'race==Amer-Indian-Eskimo', 'NOT capital-gain<=0.0', 'NOT hours-per-week<=25.0'], ['education==Doctorate', 'occupation==Prof-specialty', 'NOT capital-loss<=0.0'], ['NOT age<=33.0', 'age<=57.0', 'NOT workclass==Self-emp-not-inc', 'NOT education_num<=13.0', 'marital-status==Married-civ-spouse', 'NOT occupation==Adm-clerical', 'NOT occupation==Craft-repair', 'NOT occupation==Machine-op-inspct', 'NOT occupation==Other-service', 'NOT occupation==Sales', 'NOT occupation==Transport-moving', 'NOT relationship==Not-in-family', 'NOT relationship==Own-child', 'NOT relationship==Unmarried', 'NOT hours-per-week<=25.0', 'NOT native-coun

In [6]:
# Train R2N-tab.
# The learning rate (1e-2), and_lam (1e-2) and or_lam (1e-5) used here are the
# values noted as usually working best. A large number of epochs (e.g. 10000)
# is needed for a sparse rule set; this run uses 1000.
n_features = train_set[:][0].size(1)  # width of the feature tensor
net = R2Ntab(n_features, 50, 1)
train_r2ntab(
    net,
    train_set,
    test_set=test_set,
    device='cpu',
    lr=1e-2,
    epochs=1000,
    batch_size=400,
    and_lam=1e-2,
    or_lam=1e-5,
    num_alter=500,
)

Epoch: 100%|██████████| 1000/1000 [08:45<00:00,  1.90it/s, loss=0.552, epoch accu=0.822, test accu=0.827, num rules=9, sparsity=0.971]


In [7]:
# Evaluate the trained R2N-tab model on the held-out split and extract its rules.
test_predictions = net.predict(np.array(X_test))
accu = (test_predictions == Y_test).mean()  # fraction of predictions equal to the labels

# Each rule is a sized collection of conditions, so the lengths sum to the condition count.
rules = net.get_rules(X_headers)
print(f'Accuracy: {accu}, num rules: {len(rules)}, num conditions: {sum(map(len, rules))}')

Accuracy: 0.8267860102768109, num rules: 9, num conditions: 33


In [8]:
# Print the R2N-tab rule set obtained from net.get_rules(X_headers) in the previous cell.
print(rules)

[['education==Prof-school', 'NOT hours-per-week<=40.0'], ['NOT education_num<=9.0', 'marital-status==Married-civ-spouse', 'occupation==Exec-managerial', 'NOT sex'], ['NOT education_num<=13.0', 'NOT capital-gain<=0.0'], ['NOT education_num<=9.0', 'relationship==Husband', 'NOT capital-gain<=0.0'], ['NOT education_num<=13.0', 'marital-status==Married-civ-spouse', 'NOT sex', 'NOT hours-per-week<=42.0'], ['marital-status==Married-civ-spouse', 'NOT capital-loss<=0.0'], ['NOT age<=26.0', 'NOT education_num<=11.0', 'NOT occupation==Craft-repair', 'NOT occupation==Handlers-cleaners', 'NOT occupation==Other-service', 'relationship==Husband', 'NOT hours-per-week<=25.0'], ['NOT age<=33.0', 'NOT education_num<=9.0', 'marital-status==Married-civ-spouse', 'NOT relationship==Husband', 'NOT capital-gain<=0.0'], ['marital-status==Married-civ-spouse', 'occupation==Prof-specialty', 'NOT sex', 'NOT capital-gain<=0.0']]
