In [1]:
import torch
import numpy as np

from datasets.dataset import transform_dataset, kfold_dataset
from R2Ntab import train, R2Ntab

import sys

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read datasets
name = 'adult'
X, Y, X_headers, Y_headers = transform_dataset(name, method='onehot-compare', negations=False, labels='binary')
datasets = kfold_dataset(X, Y, shuffle=1)
X_train, X_test, Y_train, Y_test = datasets[0]
train_set = torch.utils.data.TensorDataset(torch.Tensor(X_train.to_numpy()), torch.Tensor(Y_train))
test_set = torch.utils.data.TensorDataset(torch.Tensor(X_test.to_numpy()), torch.Tensor(Y_test))

In [3]:
# Train DR-Net
# Default learning rate (1e-2), and_lam (1e-2), and and_lam (1e-5) usually work the best. A large epochs number is necessary for a sparse rule set i.e 10000 epochs.
net = R2Ntab(train_set[:][0].size(1), 50, 1)
train(net, train_set, test_set=test_set, device='cpu', lr=1e-2, epochs=1000, batch_size=400,
      and_lam=1e-2, or_lam=1e-5, num_alter=500)

Epoch: 100%|██████████| 1000/1000 [10:38<00:00,  1.57it/s, loss=0.551, epoch accu=0.824, test accu=0.827, num rules=6, sparsity=0.966]


In [4]:
# Get accuracy and the rule net
accu = (net.predict(np.array(X_test)) == Y_test).mean()
rules = net.get_rules(X_headers)
print(f'Accuracy: {accu}, num rules: {len(rules)}, num conditions: {sum(map(len, rules))}')

Accuracy: 0.8271175203049892, num rules: 6, num conditions: 26


In [5]:
print(rules)

[['NOT education_num<=9.0', 'marital-status==Married-civ-spouse', 'NOT sex', 'NOT capital-gain<=0.0'], ['NOT education_num<=9.0', 'relationship==Husband', 'NOT capital-gain<=0.0'], ['NOT age<=30.0', 'education==Prof-school', 'NOT education_num<=13.0', 'NOT hours-per-week<=42.0'], ['NOT education_num<=13.0', 'NOT capital-gain<=0.0', 'NOT hours-per-week<=42.0'], ['NOT age<=30.0', 'NOT education_num<=11.0', 'marital-status==Married-civ-spouse', 'NOT occupation==Farming-fishing', 'NOT occupation==Handlers-cleaners', 'NOT occupation==Machine-op-inspct', 'NOT occupation==Other-service', 'NOT occupation==Transport-moving', 'NOT hours-per-week<=25.0'], ['NOT education_num<=11.0', 'marital-status==Married-civ-spouse', 'NOT capital-loss<=0.0']]
