In [3]:
import wittgenstein as rule
import torch
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
from datasets.dataset import transform_dataset, kfold_dataset
from R2Ntab import train as train, R2Ntab

In [4]:
rule_learners = ['r2ntab', 'ripper', 'cart', 'c4.5']

# R2N-tab training settings that vary by dataset; every other argument passed
# to train() is the same for all datasets.
_R2NTAB_DEFAULT_SETTINGS = {'epochs': 2000, 'cancel_lam': 1e-5}
_R2NTAB_DATASET_SETTINGS = {
    'adult': {'epochs': 1000, 'cancel_lam': 1e-4},
    'heloc': {'epochs': 3000, 'cancel_lam': 1e-3},
}


def run_learner(rule_learner, dataset):
    """Fit one rule learner on the current split and score it on the test split.

    The data is not passed in: the function reads the notebook-level globals
    ``train_set``, ``test_set``, ``X_train``, ``X_test``, ``Y_train``,
    ``Y_test``, ``X_headers``, ``RX_train`` and ``RX_test``, so those must be
    defined before calling.

    Args:
        rule_learner: One of 'r2ntab', 'ripper', 'cart' or 'c4.5'.
        dataset: Dataset name. Only used to pick the R2N-tab ``epochs`` and
            ``cancel_lam`` values; names other than 'adult' and 'heloc' get
            the default settings.

    Returns:
        A tuple ``(acc, sparsity)``:
        - r2ntab: fraction of test predictions equal to ``Y_test``, and the
          summed length of the items returned by ``model.get_rules(X_headers)``.
        - ripper: ``model.score`` on the test split, and the summed length of
          the rules in ``model.ruleset_``.
        - cart / c4.5: ``model.score`` on the test split, and the number of
          '(' characters in the ``export_text`` dump of the fitted tree.

    Raises:
        ValueError: If ``rule_learner`` is not a supported learner name.
    """
    if rule_learner == 'r2ntab':
        model = R2Ntab(train_set[:][0].size(1), 50, 1)
        settings = _R2NTAB_DATASET_SETTINGS.get(dataset, _R2NTAB_DEFAULT_SETTINGS)
        train(model, train_set, test_set=test_set, device='cpu', lr_rules=1e-2, lr_cancel=1e-2,
              epochs=settings['epochs'], batch_size=400, and_lam=1e-2, or_lam=1e-5,
              cancel_lam=settings['cancel_lam'], num_alter=500)
        acc = (model.predict(np.array(X_test)) == Y_test).mean()
        sparsity = sum(map(len, model.get_rules(X_headers)))
    elif rule_learner == 'ripper':
        model = rule.RIPPER()
        model.fit(RX_train, Y_train)
        acc = model.score(RX_test, Y_test)
        # Loop variable deliberately not called `rule`, which is the module alias.
        sparsity = sum(len(learned_rule) for learned_rule in model.ruleset_)
    elif rule_learner in ('cart', 'c4.5'):
        # Both tree learners share one code path and differ only in the split
        # criterion; 'gini' is the scikit-learn default used for 'cart'.
        criterion = 'entropy' if rule_learner == 'c4.5' else 'gini'
        model = DecisionTreeClassifier(criterion=criterion)
        model.fit(X_train, Y_train)
        acc = model.score(X_test, Y_test)
        # NOTE(review): the count of '(' depends on how feature names are
        # formatted by transform_dataset - confirm it matches the intended
        # sparsity measure.
        sparsity = export_text(model, feature_names=X_train.columns.tolist()).count('(')
    else:
        # Fail loudly instead of hitting an UnboundLocalError on the return.
        raise ValueError(
            f'unknown rule learner {rule_learner!r}; expected one of {rule_learners}'
        )

    return acc, sparsity

In [5]:
runs = 10

# One flat list per learner. Results are appended dataset by dataset and run
# by run, so each list ends up with len(dataset_names) * runs entries.
accuracies = {learner: [] for learner in rule_learners}
sparsities = {learner: [] for learner in rule_learners}

dataset_names = ['adult', 'heloc', 'house', 'magic']
for dataset in dataset_names:
    X, Y, X_headers, Y_headers = transform_dataset(dataset, method='onehot-compare', negations=False, labels='binary')
    datasets = kfold_dataset(X, Y, shuffle=1)
    # Only the first split is used; every run below trains and tests on it.
    X_train, X_test, Y_train, Y_test = datasets[0]
    train_set = torch.utils.data.TensorDataset(torch.Tensor(X_train.to_numpy()), torch.Tensor(Y_train))
    test_set = torch.utils.data.TensorDataset(torch.Tensor(X_test.to_numpy()), torch.Tensor(Y_test))

    # Column-sorted copies of the features, used by the RIPPER learner.
    RX_train = pd.DataFrame(X_train).sort_index(axis=1)
    RX_test = pd.DataFrame(X_test).sort_index(axis=1)

    for run in range(runs):
        print(f'run {run+1}')
        # Seed the torch and numpy global RNGs so a rerun of the notebook
        # repeats the same sequence of runs. The seed differs per run so the
        # repetitions are not copies of each other.
        # NOTE(review): libraries drawing from other RNG sources (e.g. the
        # stdlib `random` module) are not covered by this - confirm whether
        # any learner needs its own seed.
        torch.manual_seed(run)
        np.random.seed(run)
        for learner in rule_learners:
            acc, sparsity = run_learner(learner, dataset)

            accuracies[learner].append(acc)
            sparsities[learner].append(sparsity)

run 1


2023-06-30 01:53:56.409675: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-06-30 01:53:56.440974: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-06-30 01:53:56.441424: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Epoch: 100%|██████████| 1000/1000 [04:15<00:00,  3.92it/s, rules cancelled=77, l


run 2


Epoch: 100%|██████████| 1000/1000 [04:13<00:00,  3.95it/s, rules cancelled=76, l


run 3


Epoch: 100%|██████████| 1000/1000 [04:10<00:00,  3.99it/s, rules cancelled=67, l


run 4


Epoch: 100%|██████████| 1000/1000 [04:11<00:00,  3.98it/s, rules cancelled=76, l


run 5


Epoch: 100%|██████████| 1000/1000 [04:11<00:00,  3.97it/s, rules cancelled=78, l


run 6


Epoch: 100%|██████████| 1000/1000 [04:11<00:00,  3.98it/s, rules cancelled=78, l


run 7


Epoch: 100%|██████████| 1000/1000 [04:11<00:00,  3.98it/s, rules cancelled=77, l


run 8


Epoch: 100%|██████████| 1000/1000 [04:12<00:00,  3.95it/s, rules cancelled=73, l


run 9


Epoch: 100%|██████████| 1000/1000 [04:11<00:00,  3.98it/s, rules cancelled=77, l


run 10


Epoch: 100%|██████████| 1000/1000 [04:11<00:00,  3.97it/s, rules cancelled=72, l


run 1


Epoch: 100%|██████████| 3000/3000 [04:27<00:00, 11.22it/s, rules cancelled=56, l


run 2


Epoch: 100%|██████████| 3000/3000 [04:24<00:00, 11.36it/s, rules cancelled=56, l


run 3


Epoch: 100%|██████████| 3000/3000 [04:26<00:00, 11.26it/s, rules cancelled=57, l


run 4


Epoch: 100%|██████████| 3000/3000 [04:26<00:00, 11.28it/s, rules cancelled=56, l


run 5


Epoch: 100%|██████████| 3000/3000 [04:26<00:00, 11.25it/s, rules cancelled=58, l


run 6


Epoch: 100%|██████████| 3000/3000 [04:26<00:00, 11.27it/s, rules cancelled=59, l


run 7


Epoch: 100%|██████████| 3000/3000 [04:27<00:00, 11.23it/s, rules cancelled=56, l


run 8


Epoch: 100%|██████████| 3000/3000 [04:27<00:00, 11.23it/s, rules cancelled=56, l


run 9


Epoch: 100%|██████████| 3000/3000 [04:27<00:00, 11.20it/s, rules cancelled=57, l


run 10


Epoch: 100%|██████████| 3000/3000 [04:26<00:00, 11.24it/s, rules cancelled=58, l


run 1


Epoch: 100%|██████████| 2000/2000 [06:19<00:00,  5.26it/s, rules cancelled=63, l


run 2


Epoch: 100%|██████████| 2000/2000 [06:20<00:00,  5.26it/s, rules cancelled=61, l


run 3


Epoch: 100%|██████████| 2000/2000 [06:19<00:00,  5.27it/s, rules cancelled=60, l


run 4


Epoch: 100%|██████████| 2000/2000 [06:19<00:00,  5.27it/s, rules cancelled=63, l


run 5


Epoch: 100%|██████████| 2000/2000 [06:19<00:00,  5.27it/s, rules cancelled=63, l


run 6


Epoch: 100%|██████████| 2000/2000 [06:18<00:00,  5.28it/s, rules cancelled=58, l


run 7


Epoch: 100%|██████████| 2000/2000 [06:21<00:00,  5.25it/s, rules cancelled=61, l


run 8


Epoch: 100%|██████████| 2000/2000 [06:21<00:00,  5.24it/s, rules cancelled=64, l


run 9


Epoch: 100%|██████████| 2000/2000 [06:20<00:00,  5.25it/s, rules cancelled=61, l


run 10


Epoch: 100%|██████████| 2000/2000 [06:23<00:00,  5.22it/s, rules cancelled=66, l


run 1


Epoch: 100%|██████████| 2000/2000 [05:05<00:00,  6.56it/s, rules cancelled=40, l


run 2


Epoch: 100%|██████████| 2000/2000 [05:04<00:00,  6.57it/s, rules cancelled=37, l


run 3


Epoch: 100%|██████████| 2000/2000 [05:04<00:00,  6.57it/s, rules cancelled=38, l


run 4


Epoch: 100%|██████████| 2000/2000 [05:03<00:00,  6.58it/s, rules cancelled=37, l


run 5


Epoch: 100%|██████████| 2000/2000 [05:06<00:00,  6.52it/s, rules cancelled=39, l


run 6


Epoch: 100%|██████████| 2000/2000 [05:04<00:00,  6.57it/s, rules cancelled=40, l


run 7


Epoch: 100%|██████████| 2000/2000 [05:05<00:00,  6.55it/s, rules cancelled=43, l


run 8


Epoch: 100%|██████████| 2000/2000 [05:06<00:00,  6.53it/s, rules cancelled=45, l


run 9


Epoch: 100%|██████████| 2000/2000 [05:05<00:00,  6.54it/s, rules cancelled=36, l


run 10


Epoch: 100%|██████████| 2000/2000 [05:06<00:00,  6.52it/s, rules cancelled=40, l


In [38]:
# For each learner print: name, accuracy mean, accuracy std, sparsity mean,
# sparsity std - computed over every value recorded for that learner.
for learner in rule_learners:
    summary = [
        stat(results[learner])
        for results in (accuracies, sparsities)
        for stat in (np.mean, np.std)
    ]
    print(learner, *summary)

ripper 0.8053543998244459 0.0 234.0 0.0
cart 0.8240070221637041 0.0 1792.0 0.0
c4.5 0.82642089093702 0.0 1690.0 0.0


In [8]:
import json

# Write both result dictionaries to JSON files in the working directory.
for out_path, results in (('exp4accs.json', accuracies), ('exp4spars.json', sparsities)):
    with open(out_path, 'w') as file:
        json.dump(results, file)

In [9]:
# Show the raw per-learner result lists: accuracies first, then sparsities.
print(accuracies, sparsities, sep='\n')

{'r2ntab': [0.8044090833747721, 0.8173379744737278, 0.8314271506713078, 0.8322559257417537, 0.8196585446709763, 0.8191612796287088, 0.8208188297696005, 0.831592905685397, 0.8140228741919443, 0.8352395159953588, 0.7174952198852772, 0.6964627151051626, 0.6945506692160612, 0.7146271510516252, 0.6902485659655831, 0.7122370936902486, 0.6998087954110899, 0.6988527724665392, 0.7179732313575525, 0.7050669216061185, 0.8433179723502304, 0.826201448321264, 0.8246653500109721, 0.8389291200351108, 0.836515251261795, 0.8389291200351108, 0.8283958744788238, 0.8347597103357473, 0.8305903006363836, 0.8235681369321922, 0.8228180862250263, 0.8454258675078864, 0.8309674027339643, 0.833596214511041, 0.823080967402734, 0.8338590956887487, 0.8286014721345951, 0.8046792849631966, 0.8251840168243953, 0.8299158780231335], 'ripper': [0.827780540361346, 0.8310956406431295, 0.8286093154317918, 0.8264545002486325, 0.8307641306149511, 0.8276147853472567, 0.8267860102768109, 0.8267860102768109, 0.8302668655726836, 0.