## Experiment 1: Comparison against the rule network

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import json

from datasets.dataset import transform_dataset, kfold_dataset
from R2Ntab import R2Ntab
from DRNet import train as train, DRNet

In [4]:
networks = ['drnet', 'r2ntab2', 'r2ntab4', 'r2ntab6']

# Value handed to R2Ntab.fit as `cancel_lam` for each R2Ntab variant.
# The numeric suffix of the variant name is the negative exponent of that value.
R2NTAB_CANCEL_LAM = {'r2ntab2': 1e-2, 'r2ntab4': 1e-4, 'r2ntab6': 1e-6}


def run_network(network):
    """Train one model variant on the current split and score it.

    This function reads the notebook-level globals ``train_set``, ``test_set``,
    ``X_test``, ``Y_test`` and ``X_headers``; they must be assigned (by the
    experiment loop) before it is called.

    Args:
        network: One of the names in ``networks``. ``'drnet'`` trains a DRNet
            through ``train``; every key of ``R2NTAB_CANCEL_LAM`` trains an
            R2Ntab with the matching ``cancel_lam``.

    Returns:
        A tuple ``(accuracy, sparsity)``. ``sparsity`` is the sum of the
        lengths of the items returned by the model's rule-extraction method.

    Raises:
        ValueError: If ``network`` is not a known variant. Previously this
            case surfaced as an UnboundLocalError at the return statement.
    """
    # Validate before touching any global so a typo fails fast with a clear message.
    if network != 'drnet' and network not in R2NTAB_CANCEL_LAM:
        raise ValueError(
            f"unknown network {network!r}; expected one of "
            f"{['drnet', *R2NTAB_CANCEL_LAM]}"
        )

    # Input width is taken from the feature tensor of the training set.
    n_features = train_set[:][0].size(1)

    if network == 'drnet':
        # NOTE(review): 50 and 1 look like hidden-rule count and output size -- confirm in DRNet.
        net = DRNet(n_features, 50, 1)
        train(net, train_set, test_set=test_set, device='cpu', epochs=1000, batch_size=400)
        # DRNet.predict returns predictions, so accuracy is computed here.
        accuracy = (net.predict(np.array(X_test)) == Y_test).mean()
        sparsity = sum(map(len, net.get_rules(X_headers)))
    else:
        # All R2Ntab variants share everything except cancel_lam.
        net = R2Ntab(n_features, 50, 1)
        net.fit(train_set, test_set=test_set, device='cpu', epochs=1000, batch_size=400,
                cancel_lam=R2NTAB_CANCEL_LAM[network])
        # R2Ntab.predict receives the labels as well; its return value is used as the accuracy.
        accuracy = net.predict(X_test, Y_test)
        sparsity = sum(map(len, net.extract_rules(X_headers)))

    return accuracy, sparsity

In [5]:
runs = 10
for name in ['adult', 'heloc', 'house', 'magic']:
    # Each dataset starts with one empty result list per network variant.
    accuracies = {network: [] for network in networks}
    sparsities = {network: [] for network in networks}

    print('dataset:', name)
    for run in range(runs):
        print('  run:', run+1)

        # Re-encode the dataset and use the first fold as this run's train/test split.
        # run_network reads X_headers, X_test, Y_test, train_set and test_set as
        # globals, so these names must be kept exactly as they are.
        X, Y, X_headers, Y_headers = transform_dataset(
            name, method='onehot-compare', negations=False, labels='binary'
        )
        datasets = kfold_dataset(X, Y, shuffle=1)
        X_train, X_test, Y_train, Y_test = datasets[0]

        train_set = torch.utils.data.TensorDataset(
            torch.Tensor(X_train.to_numpy()),
            torch.Tensor(Y_train),
        )
        test_set = torch.utils.data.TensorDataset(
            torch.Tensor(X_test.to_numpy()),
            torch.Tensor(Y_test),
        )

        # Train and score every variant on the same split.
        for network in networks:
            accuracy, sparsity = run_network(network)
            accuracies[network].append(accuracy)
            sparsities[network].append(sparsity)

    # Persist both metrics for this dataset once all runs have finished.
    outputs = {
        f'exp1-accuracies-{name}.json': accuracies,
        f'exp1-sparsities-{name}.json': sparsities,
    }
    for path, results in outputs.items():
        with open(path, 'w') as file:
            json.dump(results, file)

dataset: adult
  run: 1


Epoch: 100%|██████████| 1000/1000 [03:36<00:00,  4.63it/s, loss=0.549, epoch acc
100%|███████████████████| 1000/1000 [05:38<00:00,  2.96it/s]
100%|███████████████████| 1000/1000 [04:37<00:00,  3.60it/s]
100%|███████████████████| 1000/1000 [04:23<00:00,  3.80it/s]


  run: 2


Epoch: 100%|██████████| 1000/1000 [03:38<00:00,  4.57it/s, loss=0.546, epoch acc
100%|███████████████████| 1000/1000 [05:35<00:00,  2.98it/s]
100%|███████████████████| 1000/1000 [04:44<00:00,  3.51it/s]
100%|███████████████████| 1000/1000 [04:46<00:00,  3.49it/s]


  run: 3


Epoch: 100%|██████████| 1000/1000 [03:46<00:00,  4.42it/s, loss=0.544, epoch acc
100%|███████████████████| 1000/1000 [05:52<00:00,  2.84it/s]
100%|███████████████████| 1000/1000 [04:48<00:00,  3.47it/s]
100%|███████████████████| 1000/1000 [04:43<00:00,  3.52it/s]


  run: 4


Epoch: 100%|██████████| 1000/1000 [03:43<00:00,  4.48it/s, loss=0.549, epoch acc
100%|███████████████████| 1000/1000 [05:42<00:00,  2.92it/s]
100%|███████████████████| 1000/1000 [04:45<00:00,  3.50it/s]
100%|███████████████████| 1000/1000 [04:30<00:00,  3.69it/s]


  run: 5


Epoch: 100%|██████████| 1000/1000 [03:37<00:00,  4.60it/s, loss=0.544, epoch acc
100%|███████████████████| 1000/1000 [05:26<00:00,  3.06it/s]
100%|███████████████████| 1000/1000 [04:30<00:00,  3.70it/s]
100%|███████████████████| 1000/1000 [04:29<00:00,  3.72it/s]


  run: 6


Epoch: 100%|██████████| 1000/1000 [03:32<00:00,  4.71it/s, loss=0.545, epoch acc
100%|███████████████████| 1000/1000 [05:26<00:00,  3.06it/s]
100%|███████████████████| 1000/1000 [04:33<00:00,  3.66it/s]
100%|███████████████████| 1000/1000 [04:26<00:00,  3.75it/s]


  run: 7


Epoch: 100%|██████████| 1000/1000 [03:33<00:00,  4.67it/s, loss=0.546, epoch acc
100%|███████████████████| 1000/1000 [05:22<00:00,  3.10it/s]
100%|███████████████████| 1000/1000 [04:44<00:00,  3.51it/s]
100%|███████████████████| 1000/1000 [04:22<00:00,  3.81it/s]


  run: 8


Epoch: 100%|██████████| 1000/1000 [03:34<00:00,  4.66it/s, loss=0.55, epoch accu
100%|███████████████████| 1000/1000 [05:22<00:00,  3.10it/s]
100%|███████████████████| 1000/1000 [04:34<00:00,  3.64it/s]
100%|███████████████████| 1000/1000 [04:23<00:00,  3.80it/s]


  run: 9


Epoch: 100%|██████████| 1000/1000 [03:34<00:00,  4.66it/s, loss=0.546, epoch acc
100%|███████████████████| 1000/1000 [05:33<00:00,  3.00it/s]
100%|███████████████████| 1000/1000 [04:38<00:00,  3.59it/s]
100%|███████████████████| 1000/1000 [04:25<00:00,  3.77it/s]


  run: 10


Epoch: 100%|██████████| 1000/1000 [03:33<00:00,  4.68it/s, loss=0.55, epoch accu
100%|███████████████████| 1000/1000 [05:26<00:00,  3.07it/s]
100%|███████████████████| 1000/1000 [04:39<00:00,  3.58it/s]
100%|███████████████████| 1000/1000 [04:31<00:00,  3.68it/s]


dataset: heloc
  run: 1


Epoch: 100%|██████████| 1000/1000 [01:16<00:00, 13.16it/s, loss=0.625, epoch acc
100%|███████████████████| 1000/1000 [01:23<00:00, 11.95it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.87it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.86it/s]


  run: 2


Epoch: 100%|██████████| 1000/1000 [01:15<00:00, 13.29it/s, loss=0.634, epoch acc
100%|███████████████████| 1000/1000 [01:24<00:00, 11.85it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.78it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.84it/s]


  run: 3


Epoch: 100%|██████████| 1000/1000 [01:14<00:00, 13.35it/s, loss=0.629, epoch acc
100%|███████████████████| 1000/1000 [01:24<00:00, 11.77it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.90it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.80it/s]


  run: 4


Epoch: 100%|██████████| 1000/1000 [01:16<00:00, 13.15it/s, loss=0.636, epoch acc
100%|███████████████████| 1000/1000 [01:24<00:00, 11.90it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.84it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.86it/s]


  run: 5


Epoch: 100%|██████████| 1000/1000 [01:14<00:00, 13.34it/s, loss=0.627, epoch acc
100%|███████████████████| 1000/1000 [01:24<00:00, 11.87it/s]
100%|███████████████████| 1000/1000 [01:25<00:00, 11.70it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.78it/s]


  run: 6


Epoch: 100%|██████████| 1000/1000 [01:15<00:00, 13.26it/s, loss=0.637, epoch acc
100%|███████████████████| 1000/1000 [01:24<00:00, 11.87it/s]
100%|███████████████████| 1000/1000 [01:25<00:00, 11.74it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.83it/s]


  run: 7


Epoch: 100%|██████████| 1000/1000 [01:14<00:00, 13.34it/s, loss=0.619, epoch acc
100%|███████████████████| 1000/1000 [01:24<00:00, 11.89it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.78it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.83it/s]


  run: 8


Epoch: 100%|██████████| 1000/1000 [01:15<00:00, 13.24it/s, loss=0.637, epoch acc
100%|███████████████████| 1000/1000 [01:24<00:00, 11.87it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.79it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.79it/s]


  run: 9


Epoch: 100%|██████████| 1000/1000 [01:15<00:00, 13.28it/s, loss=0.642, epoch acc
100%|███████████████████| 1000/1000 [01:24<00:00, 11.78it/s]
100%|███████████████████| 1000/1000 [01:23<00:00, 11.91it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.83it/s]


  run: 10


Epoch: 100%|██████████| 1000/1000 [01:15<00:00, 13.25it/s, loss=0.634, epoch acc
100%|███████████████████| 1000/1000 [01:24<00:00, 11.82it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.86it/s]
100%|███████████████████| 1000/1000 [01:24<00:00, 11.77it/s]


dataset: house
  run: 1


Epoch: 100%|██████████| 1000/1000 [02:42<00:00,  6.14it/s, loss=0.353, epoch acc
100%|███████████████████| 1000/1000 [03:57<00:00,  4.20it/s]
100%|███████████████████| 1000/1000 [03:46<00:00,  4.41it/s]
100%|███████████████████| 1000/1000 [03:48<00:00,  4.38it/s]


  run: 2


Epoch: 100%|██████████| 1000/1000 [02:55<00:00,  5.71it/s, loss=0.353, epoch acc
100%|███████████████████| 1000/1000 [04:15<00:00,  3.91it/s]
100%|███████████████████| 1000/1000 [03:45<00:00,  4.43it/s]
100%|███████████████████| 1000/1000 [03:40<00:00,  4.53it/s]


  run: 3


Epoch: 100%|██████████| 1000/1000 [02:41<00:00,  6.18it/s, loss=0.349, epoch acc
100%|███████████████████| 1000/1000 [03:52<00:00,  4.30it/s]
100%|███████████████████| 1000/1000 [03:32<00:00,  4.71it/s]
100%|███████████████████| 1000/1000 [03:31<00:00,  4.73it/s]


  run: 4


Epoch: 100%|██████████| 1000/1000 [02:42<00:00,  6.14it/s, loss=0.348, epoch acc
100%|███████████████████| 1000/1000 [03:52<00:00,  4.30it/s]
100%|███████████████████| 1000/1000 [03:28<00:00,  4.79it/s]
100%|███████████████████| 1000/1000 [03:30<00:00,  4.75it/s]


  run: 5


Epoch: 100%|██████████| 1000/1000 [02:42<00:00,  6.16it/s, loss=0.355, epoch acc
100%|███████████████████| 1000/1000 [03:52<00:00,  4.31it/s]
100%|███████████████████| 1000/1000 [03:31<00:00,  4.73it/s]
100%|███████████████████| 1000/1000 [03:31<00:00,  4.72it/s]


  run: 6


Epoch: 100%|██████████| 1000/1000 [02:43<00:00,  6.13it/s, loss=0.348, epoch acc
100%|███████████████████| 1000/1000 [03:43<00:00,  4.47it/s]
100%|███████████████████| 1000/1000 [03:32<00:00,  4.71it/s]
100%|███████████████████| 1000/1000 [03:30<00:00,  4.75it/s]


  run: 7


Epoch: 100%|██████████| 1000/1000 [02:41<00:00,  6.21it/s, loss=0.354, epoch acc
100%|███████████████████| 1000/1000 [03:52<00:00,  4.30it/s]
100%|███████████████████| 1000/1000 [03:31<00:00,  4.73it/s]
100%|███████████████████| 1000/1000 [03:32<00:00,  4.70it/s]


  run: 8


Epoch: 100%|██████████| 1000/1000 [02:42<00:00,  6.15it/s, loss=0.347, epoch acc
100%|███████████████████| 1000/1000 [03:46<00:00,  4.41it/s]
100%|███████████████████| 1000/1000 [03:32<00:00,  4.72it/s]
100%|███████████████████| 1000/1000 [03:29<00:00,  4.76it/s]


  run: 9


Epoch: 100%|██████████| 1000/1000 [02:40<00:00,  6.21it/s, loss=0.357, epoch acc
100%|███████████████████| 1000/1000 [03:50<00:00,  4.34it/s]
100%|███████████████████| 1000/1000 [03:31<00:00,  4.73it/s]
100%|███████████████████| 1000/1000 [03:32<00:00,  4.71it/s]


  run: 10


Epoch: 100%|██████████| 1000/1000 [02:42<00:00,  6.14it/s, loss=0.347, epoch acc
100%|███████████████████| 1000/1000 [03:54<00:00,  4.26it/s]
100%|███████████████████| 1000/1000 [03:31<00:00,  4.72it/s]
100%|███████████████████| 1000/1000 [03:32<00:00,  4.71it/s]


dataset: magic
  run: 1


Epoch: 100%|██████████| 1000/1000 [02:10<00:00,  7.65it/s, loss=0.491, epoch acc
100%|███████████████████| 1000/1000 [02:58<00:00,  5.61it/s]
100%|███████████████████| 1000/1000 [02:33<00:00,  6.51it/s]
100%|███████████████████| 1000/1000 [02:37<00:00,  6.34it/s]


  run: 2


Epoch: 100%|██████████| 1000/1000 [02:10<00:00,  7.63it/s, loss=0.494, epoch acc
100%|███████████████████| 1000/1000 [02:59<00:00,  5.58it/s]
100%|███████████████████| 1000/1000 [02:36<00:00,  6.41it/s]
100%|███████████████████| 1000/1000 [02:36<00:00,  6.41it/s]


  run: 3


Epoch: 100%|██████████| 1000/1000 [02:11<00:00,  7.61it/s, loss=0.493, epoch acc
100%|███████████████████| 1000/1000 [02:54<00:00,  5.72it/s]
100%|███████████████████| 1000/1000 [02:34<00:00,  6.45it/s]
100%|███████████████████| 1000/1000 [02:34<00:00,  6.47it/s]


  run: 4


Epoch: 100%|██████████| 1000/1000 [02:11<00:00,  7.63it/s, loss=0.493, epoch acc
100%|███████████████████| 1000/1000 [02:58<00:00,  5.60it/s]
100%|███████████████████| 1000/1000 [02:36<00:00,  6.37it/s]
100%|███████████████████| 1000/1000 [02:37<00:00,  6.35it/s]


  run: 5


Epoch: 100%|██████████| 1000/1000 [02:11<00:00,  7.58it/s, loss=0.49, epoch accu
100%|███████████████████| 1000/1000 [02:56<00:00,  5.67it/s]
100%|███████████████████| 1000/1000 [02:36<00:00,  6.40it/s]
100%|███████████████████| 1000/1000 [02:37<00:00,  6.34it/s]


  run: 6


Epoch: 100%|██████████| 1000/1000 [02:11<00:00,  7.62it/s, loss=0.489, epoch acc
100%|███████████████████| 1000/1000 [03:03<00:00,  5.45it/s]
100%|███████████████████| 1000/1000 [02:37<00:00,  6.36it/s]
100%|███████████████████| 1000/1000 [02:37<00:00,  6.36it/s]


  run: 7


Epoch: 100%|██████████| 1000/1000 [02:14<00:00,  7.44it/s, loss=0.492, epoch acc
100%|███████████████████| 1000/1000 [03:04<00:00,  5.43it/s]
100%|███████████████████| 1000/1000 [02:38<00:00,  6.32it/s]
100%|███████████████████| 1000/1000 [02:40<00:00,  6.22it/s]


  run: 8


Epoch: 100%|██████████| 1000/1000 [02:13<00:00,  7.48it/s, loss=0.492, epoch acc
100%|███████████████████| 1000/1000 [03:01<00:00,  5.51it/s]
100%|███████████████████| 1000/1000 [02:37<00:00,  6.33it/s]
100%|███████████████████| 1000/1000 [02:37<00:00,  6.34it/s]


  run: 9


Epoch: 100%|██████████| 1000/1000 [02:15<00:00,  7.39it/s, loss=0.493, epoch acc
100%|███████████████████| 1000/1000 [02:59<00:00,  5.56it/s]
100%|███████████████████| 1000/1000 [02:39<00:00,  6.28it/s]
100%|███████████████████| 1000/1000 [02:37<00:00,  6.33it/s]


  run: 10


Epoch: 100%|██████████| 1000/1000 [02:15<00:00,  7.38it/s, loss=0.494, epoch acc
100%|███████████████████| 1000/1000 [03:01<00:00,  5.52it/s]
100%|███████████████████| 1000/1000 [02:38<00:00,  6.33it/s]
100%|███████████████████| 1000/1000 [02:40<00:00,  6.21it/s]
