In [1]:
import pandas as pd
import numpy as np
import time
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from imodels import OneRClassifier, GreedyRuleListClassifier
from sklearn.preprocessing import OrdinalEncoder

import jpype
import os

In [2]:
models = {
    "Decision Tree": DecisionTreeClassifier(),
    "Random Forest": RandomForestClassifier(),
}

datasets = [
    ("mushroom", 1),
    ("bank-marketing", 1),
    ("adult", 1),
    ("page-blocks", 1)
]

In [3]:
results = []

for name, version in datasets:
    X, y = fetch_openml(name, version=version, as_frame=True, return_X_y=True)
    X = X.astype(str).fillna("NA")
    y = y.astype(str)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.1, stratify=y, random_state=42
    )

    for model_name, model in models.items():
        try:
            X_train_enc, X_test_enc = X_train, X_test
            if model_name != "LORD":
                enc = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1)
                X_train_enc = enc.fit_transform(X_train)
                X_test_enc = enc.transform(X_test)

            clf = model
            if model_name == "LORD":
                clf = LocalRuleClassifier()

            start = time.time()
            clf.fit(X_train_enc, y_train)
            dt = time.time() - start

            y_pred = clf.predict(X_test_enc)
            acc = accuracy_score(y_test, y_pred)
            results.append((name, model_name, f"{acc:.3f}", f"{dt:.2f}", len(y_test)))

        except Exception as e:
            results.append((name, model_name, f"ERROR: {e}", "-", len(y_test)))

In [4]:
df_results = pd.DataFrame(results, columns=["Dataset", "Model", "Accuracy", "Time (s)", "Test Samples"])
df_results

Unnamed: 0,Dataset,Model,Accuracy,Time (s),Test Samples
0,mushroom,Decision Tree,1.0,0.01,813
1,mushroom,Random Forest,1.0,0.14,813
2,bank-marketing,Decision Tree,0.831,0.22,4522
3,bank-marketing,Random Forest,0.897,3.55,4522
4,adult,Decision Tree,0.789,0.16,4885
5,adult,Random Forest,0.834,2.43,4885
6,page-blocks,Decision Tree,0.816,0.03,548
7,page-blocks,Random Forest,0.83,0.64,548
