In [1]:
import torch
import numpy as np

from datasets.dataset import transform_dataset, kfold_dataset
from R2Ntab import train as train_r2ntab, R2Ntab
from DRNet import train as train_drnet, DRNet

import sys

In [2]:
# Read the dataset and wrap the first cross-validation fold as tensor datasets.
def _as_tensor_dataset(features, targets):
    """Pair a feature frame with its labels in a TensorDataset.

    `features` must expose `.to_numpy()`; both arguments are converted with
    `torch.Tensor(...)`, exactly as the training code below expects.
    """
    return torch.utils.data.TensorDataset(
        torch.Tensor(features.to_numpy()),
        torch.Tensor(targets),
    )


name = 'magic'
# NOTE(review): 'onehot-compare' presumably binarizes numeric columns into
# threshold conditions (the rule printouts contain headers such as
# 'fLength<=...') - confirm against datasets.dataset.transform_dataset.
X, Y, X_headers, Y_headers = transform_dataset(
    name,
    method='onehot-compare',
    negations=False,
    labels='binary',
)

# Only the first fold returned by kfold_dataset is used in this notebook.
datasets = kfold_dataset(X, Y, shuffle=1)
X_train, X_test, Y_train, Y_test = datasets[0]

train_set = _as_tensor_dataset(X_train, Y_train)
test_set = _as_tensor_dataset(X_test, Y_test)

In [3]:
# Train DR-Net
# The default learning rate (1e-2), and_lam (1e-2) and or_lam (1e-5) usually
# work best. A sparse rule set needs a large number of epochs (e.g. 10000);
# this run uses 1000.
num_features = train_set[:][0].size(1)  # width of the feature tensor
# NOTE(review): 50 and 1 are presumably the number of candidate rules and the
# output size - confirm against the DRNet constructor.
net = DRNet(num_features, 50, 1)
train_drnet(
    net,
    train_set,
    test_set=test_set,
    device='cpu',
    lr=1e-2,
    epochs=1000,
    batch_size=400,
    and_lam=1e-2,
    or_lam=1e-5,
    num_alter=500,
)

Epoch: 100%|██████████| 1000/1000 [04:34<00:00,  3.65it/s, loss=0.495, epoch accu=0.837, test accu=0.849, num rules=34, sparsity=0.878]


In [4]:
# Measure test accuracy of the trained net and extract its rule set.
predictions = net.predict(np.array(X_test))
accu = (predictions == Y_test).mean()  # fraction of matching predictions

rules = net.get_rules(X_headers)
num_rules = len(rules)
num_conditions = sum(len(rule) for rule in rules)  # total conditions over all rules
print(f'Accuracy: {accu}, num rules: {num_rules}, num conditions: {num_conditions}')

Accuracy: 0.849106203995794, num rules: 34, num conditions: 372


In [5]:
# Dump the extracted rule set. Judging by the rendered output, it is a list of
# rules, each rule being a list of condition strings built from X_headers.
print(rules)

[['fLength<=30.984860000000005', 'fWidth<=11.14812', 'fWidth<=12.70988', 'fWidth<=15.135580000000001', 'NOT fSize<=2.4216', 'NOT fSize<=2.5296', 'NOT fConc1<=0.309', 'NOT fAlpha<=1.9325999999999999', 'NOT fAlpha<=4.1578800000000005', 'NOT fAlpha<=7.29494', 'fAlpha<=11.4574'], ['NOT fLength<=37.1477', 'NOT fLength<=46.561859999999996', 'NOT fLength<=61.08167999999999', 'NOT fLength<=80.00280000000001', 'NOT fSize<=2.2867', 'fSize<=2.4216', 'fM3Long<=8.313820000000002', 'fM3Long<=28.885389999999997', 'NOT fAlpha<=7.29494', 'NOT fAlpha<=11.4574', 'NOT fAlpha<=17.6795', 'fAlpha<=26.6302', 'fAlpha<=38.9075'], ['NOT fLength<=30.984860000000005', 'fLength<=37.1477', 'NOT fSize<=2.2867', 'fSize<=2.4216', 'NOT fAlpha<=11.4574', 'NOT fAlpha<=17.6795', 'NOT fAlpha<=26.6302', 'NOT fAlpha<=38.9075', 'NOT fAlpha<=53.61464', 'NOT fDist<=130.3766', 'NOT fDist<=154.14880000000002'], ['NOT fLength<=30.984860000000005', 'NOT fLength<=37.1477', 'NOT fLength<=46.561859999999996', 'NOT fLength<=61.081679999

In [6]:
# Train R2N-tab
# The default learning rate (1e-2), and_lam (1e-2) and or_lam (1e-5) usually
# work best. A sparse rule set needs a large number of epochs (e.g. 10000);
# this run uses 1000.
num_features = train_set[:][0].size(1)  # width of the feature tensor
# NOTE(review): 50 and 1 are presumably the number of candidate rules and the
# output size - confirm against the R2Ntab constructor.
net = R2Ntab(num_features, 50, 1)
train_r2ntab(
    net,
    train_set,
    test_set=test_set,
    device='cpu',
    lr=1e-2,
    epochs=1000,
    batch_size=400,
    and_lam=1e-2,
    or_lam=1e-5,
    num_alter=500,
)

Epoch: 100%|██████████| 1000/1000 [05:09<00:00,  3.23it/s, loss=0.497, epoch accu=0.839, test accu=0.843, num rules=21, sparsity=0.962]


In [7]:
# Measure test accuracy of the trained net and extract its rule set.
predictions = net.predict(np.array(X_test))
accu = (predictions == Y_test).mean()  # fraction of matching predictions

rules = net.get_rules(X_headers)
num_rules = len(rules)
num_conditions = sum(len(rule) for rule in rules)  # total conditions over all rules
print(f'Accuracy: {accu}, num rules: {num_rules}, num conditions: {num_conditions}')

Accuracy: 0.842534174553102, num rules: 21, num conditions: 71


In [8]:
# Dump the extracted rule set. Judging by the rendered output, it is a list of
# rules, each rule being a list of condition strings built from X_headers.
print(rules)

[['NOT fLength<=37.1477', 'NOT fWidth<=27.49262000000001', 'NOT fConc<=0.4678'], ['fWidth<=11.14812', 'NOT fSize<=2.5296', 'NOT fAlpha<=17.6795', 'fAlpha<=26.6302'], ['NOT fLength<=105.45940000000002', 'NOT fAlpha<=7.29494', 'fAlpha<=26.6302'], ['NOT fLength<=37.1477', 'fWidth<=15.135580000000001', 'NOT fSize<=2.5296', 'NOT fAlpha<=26.6302'], ['NOT fLength<=80.00280000000001', 'NOT fAlpha<=11.4574', 'NOT fAlpha<=17.6795', 'fAlpha<=26.6302'], ['fWidth<=15.135580000000001', 'NOT fSize<=2.5296', 'NOT fAlpha<=38.9075', 'NOT fDist<=173.6184'], ['NOT fLength<=61.08167999999999', 'NOT fAlpha<=26.6302'], ['NOT fWidth<=38.67311000000001', 'NOT fAlpha<=7.29494', 'NOT fAlpha<=11.4574', 'fAlpha<=17.6795'], ['fWidth<=11.14812', 'NOT fSize<=2.4216', 'NOT fAlpha<=38.9075'], ['NOT fLength<=61.08167999999999', 'fM3Long<=-34.85791', 'NOT fAlpha<=11.4574', 'fAlpha<=26.6302'], ['fWidth<=9.60832', 'NOT fSize<=2.4216'], ['NOT fLength<=46.561859999999996', 'NOT fAlpha<=26.6302', 'NOT fDist<=209.6559199999999