In [1]:
import torch
import numpy as np

from datasets.dataset import transform_dataset, kfold_dataset
from R2Ntab import train as train_r2ntab, R2Ntab
from DRNet import train as train_drnet, DRNet

import sys

import matplotlib.pyplot as plt

In [2]:
# Load the dataset, take the first k-fold split and wrap both halves as
# torch TensorDatasets for training and evaluation.
def _to_tensor_dataset(features, targets):
    """Wrap a feature table and its labels in a ``TensorDataset``.

    features: object exposing ``.to_numpy()`` (presumably a pandas
        DataFrame -- TODO confirm against ``kfold_dataset``).
    targets: labels in any form accepted by ``torch.Tensor(...)``.
    """
    feature_tensor = torch.Tensor(features.to_numpy())
    target_tensor = torch.Tensor(targets)
    return torch.utils.data.TensorDataset(feature_tensor, target_tensor)


name = 'adult'
X, Y, X_headers, Y_headers = transform_dataset(
    name,
    method='onehot-compare',
    negations=False,
    labels='binary',
)

# Only the first split returned by kfold_dataset is used in this notebook.
datasets = kfold_dataset(X, Y, shuffle=1)
X_train, X_test, Y_train, Y_test = datasets[0]

train_set = _to_tensor_dataset(X_train, Y_train)
test_set = _to_tensor_dataset(X_test, Y_test)

In [None]:
# Train R2N-tab
# Default learning rate (1e-2), and_lam (1e-2), and or_lam (1e-5) usually work the best.
# A large number of epochs is necessary for a sparse rule set, e.g. 10000 epochs
# (this run uses epochs=1000).
n_input_features = train_set[:][0].size(1)  # width of the feature tensor
net = R2Ntab(n_input_features, 50, 1)

# Keyword arguments forwarded unchanged to the training routine.
r2ntab_train_kwargs = dict(
    test_set=test_set,
    device='cpu',
    lr_rules=1e-2,
    lr_cancel=1e-2,
    epochs=1000,
    batch_size=400,
    and_lam=1e-2,
    or_lam=1e-5,
    cancel_lam=1,
    num_alter=500,
    track_performance=True,
)
acc, rules = train_r2ntab(net, train_set, **r2ntab_train_kwargs)

Epoch:  97%|█████████▋| 968/1000 [13:53<00:27,  1.18it/s, rules cancelled=61, loss=1.08, epoch accu=0.817, test accu=0.824, num rules=6, sparsity=0.883] 

In [4]:
# Evaluate the trained network on the held-out split and extract its rules.
test_predictions = net.predict(np.array(X_test))
prediction_hits = test_predictions == Y_test
accu = prediction_hits.mean()  # fraction of test rows predicted correctly

# Rules are rendered using the feature names in X_headers.
rules = net.get_rules(X_headers)

print(f'Accuracy: {accu}, num rules: {len(rules)}, num conditions: {sum(map(len, rules))}')

Accuracy: 0.8234709099950274, num rules: 6, num conditions: 106


In [5]:
# Dump the extracted rule set as-is: `rules` is whatever net.get_rules(X_headers)
# returned in the preceding cell. In the recorded output this is a list of rules,
# each a list of condition strings such as 'NOT age<=26.0'.
# NOTE(review): how conditions combine within a rule (presumably AND) is not
# visible from this notebook -- confirm against R2Ntab.get_rules.
print(rules)

[['NOT age<=26.0', 'age<=30.0', 'age<=33.0', 'NOT workclass==State-gov', 'NOT education_num<=9.0', 'NOT education_num<=11.0', 'NOT marital-status==Divorced', 'NOT marital-status==Never-married', 'NOT marital-status==Separated', 'NOT marital-status==Widowed', 'NOT occupation==Adm-clerical', 'NOT occupation==Craft-repair', 'NOT occupation==Farming-fishing', 'NOT occupation==Other-service', 'NOT occupation==Protective-serv', 'NOT occupation==Sales', 'NOT occupation==Transport-moving', 'NOT relationship==Not-in-family', 'NOT relationship==Other-relative', 'hours-per-week<=25.0', 'NOT native-country==Puerto-Rico'], ['NOT age<=26.0', 'NOT age<=30.0', 'NOT education==Assoc-voc', 'NOT education_num<=10.0', 'NOT education_num<=11.0', 'NOT marital-status==Divorced', 'NOT marital-status==Married-spouse-absent', 'NOT marital-status==Never-married', 'NOT occupation==Adm-clerical', 'NOT occupation==Farming-fishing', 'NOT occupation==Handlers-cleaners', 'NOT occupation==Other-service', 'NOT occupatio

In [7]:
# Inspect the cancel-out layer: locate every weight that is negative, then show
# those positions followed by the full weight parameter.
# NOTE(review): a negative weight presumably marks a cancelled input feature --
# confirm against the R2Ntab cancel-out layer implementation.
cancelout_weights = net.cancelout_layer.weight
indices = torch.where(cancelout_weights < 0)
print(indices, cancelout_weights, sep='\n')

(tensor([  0,   1,   2,   3,   8,  11,  12,  13,  15,  19,  21,  24,  25,  26,
         27,  28,  29,  30,  31,  34,  35,  36,  37,  38,  39,  41,  42,  43,
         44,  45,  46,  48,  49,  50,  51,  52,  54,  55,  56,  57,  58,  59,
         60,  61,  62,  66,  67,  68,  69,  70,  71,  72,  73,  77,  78,  81,
         85,  88,  90,  92,  94,  98,  99, 100, 109, 111, 112, 113, 114, 115,
        117, 118, 119, 121, 124, 125, 126, 127]),)
Parameter containing:
tensor([-1.5587e-02, -1.9227e-02, -1.8510e-02, -9.4457e-03,  7.2455e+01,
         5.8961e+01,  5.2783e+00,  3.7613e+01, -5.7379e-03,  1.0435e+01,
         5.0213e+01, -1.4973e-02, -1.1789e-02, -4.0100e-03,  1.4281e+01,
        -1.0465e-02,  3.0106e+01,  2.5027e+01,  1.4388e+01, -5.5581e-03,
         1.2089e+01, -1.9776e-02,  2.3992e+00,  4.7402e+00, -1.2652e-02,
        -1.6692e-02, -8.4564e-03, -1.2752e-02, -3.0186e-02, -1.5103e-02,
        -1.2372e-02, -5.8284e-03,  5.0415e+01,  1.3727e+01, -1.7763e-02,
        -1.0110e-02, -2.1