In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell execution so edits to imported source files take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from copy import deepcopy

import imodels

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

In [None]:
def decision_boundary(X, y, model, model_name, cols, filename, save=False):
    """Plot the decision boundary of a fitted two-feature classifier.

    The predicted class regions of ``model`` are drawn as a filled background,
    the samples in ``X`` are scattered on top (class 0 as blue squares, class 1
    as red triangles), and the ROC AUC of ``model`` on ``(X, y)`` is shown in
    the title.

    Parameters
    ----------
    X : ndarray of shape (n_samples, 2)
        Samples to display; column 0 is the x-axis, column 1 the y-axis.
    y : ndarray of shape (n_samples,)
        Binary labels (0 / 1) matching ``X``.
    model : fitted classifier
        Must expose ``predict``; ``predict_proba`` is used for the AUC when
        it is available.
    model_name : str
        Name shown in the plot title.
    cols : sequence of str
        ``cols[0]`` and ``cols[1]`` label the x- and y-axis.
    filename : str or path-like
        Destination of the figure when ``save`` is true.
    save : bool, default False
        Whether to write the figure to ``filename``.
    """
    # Local import: ``import sklearn`` alone does not guarantee that the
    # ``sklearn.metrics`` submodule has been loaded.
    from sklearn.metrics import roc_auc_score

    # Create one figure and draw on it explicitly. Without ``ax=`` the display
    # helper opens its own figure, which used to leave an empty one behind.
    fig, ax = plt.subplots()
    DecisionBoundaryDisplay.from_estimator(
        model,
        X,
        grid_resolution=200,
        response_method="predict",
        alpha=0.5,
        cmap="coolwarm",
        ax=ax,
    )
    ax.scatter(X[y == 0, 0], X[y == 0, 1], s=3, c="dodgerblue", marker="s", alpha=0.3, label=0)
    ax.scatter(X[y == 1, 0], X[y == 1, 1], s=3, c="orangered", marker="^", alpha=0.3, label=1)
    ax.set_xlabel(cols[0])
    ax.set_ylabel(cols[1])

    # ROC AUC is a ranking metric: feed it the positive-class scores rather
    # than hard 0/1 predictions whenever the model can provide them.
    if hasattr(model, "predict_proba"):
        scores = model.predict_proba(X)[:, 1]
    else:
        scores = model.predict(X)
    ax.set_title(f"{model_name} (AUC {roc_auc_score(y, scores):.3f})")
    ax.legend()

    if save:
        from pathlib import Path

        # Make sure the target directory exists so savefig does not fail.
        Path(filename).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(filename, bbox_inches=None, facecolor="white", edgecolor="auto")

In [None]:
def feature_importance(X, y, cols, N=10):
    """Rank features by their mean Gini importance over ``N`` decision trees.

    A ``DecisionTreeClassifier`` is fitted ``N`` times on ``(X, y)`` and the
    resulting ``feature_importances_`` are averaged. The (up to) ten highest
    ranked features are printed as a table.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        Training samples.
    y : array-like of shape (n_samples,)
        Target labels.
    cols : sequence of str
        Feature names, one per column of ``X``. Lists and arrays are accepted.
    N : int, default 10
        Number of trees to fit and average over. Must be at least 1.

    Returns
    -------
    ndarray
        Names of the two features with the highest mean importance, most
        important first.

    Raises
    ------
    ValueError
        If ``N`` is smaller than 1.
    """
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N}")
    # The fancy indexing on the last line needs an array, not a plain list.
    cols = np.asarray(cols)

    mean_feature_importance = np.zeros(X.shape[1])
    for _ in range(N):
        tree = DecisionTreeClassifier(criterion='gini')
        tree.fit(X, y)
        mean_feature_importance += tree.feature_importances_
    mean_feature_importance /= N

    # Feature indices sorted from most to least important.
    most_important = np.flip(np.argsort(mean_feature_importance))
    print(f"     Feature importance\n{34*'='}")
    for i in most_important[:10]:
        print(f"{cols[i]:>25s} | {mean_feature_importance[i]:.3f}")
    return cols[most_important[:2]]

In [None]:
def test_simpler_boundary(dataset_name, database_name, cols_name=None, save=False):
    """Compare the decision boundaries of a random forest and its HS variant.

    The dataset is loaded through ``imodels.util.data_util.get_clean_dataset``,
    reduced to two features, and split 67/33 into train and test parts. A
    50-tree ``RandomForestClassifier`` and an ``imodels.HSTreeClassifierCV``
    wrapped around a deep copy of it are fitted on the train part, and both
    decision boundaries are plotted on the test part.

    Parameters
    ----------
    dataset_name : str
        Dataset identifier passed to ``get_clean_dataset``; also used in the
        output file names.
    database_name : str
        Data source passed to ``get_clean_dataset`` (e.g. "imodels", "pmlb").
    cols_name : sequence of two str, optional
        Names of the two features to keep. When omitted (or empty), the two
        features ranked highest by ``feature_importance`` are used and the
        output files get a "_reproduced" suffix.
    save : bool, default False
        Whether the figures are written below "../graphs/claim_4/boundaries/".

    Raises
    ------
    IndexError
        If a requested column name does not exist in the dataset.
    """
    X, y, cols = imodels.util.data_util.get_clean_dataset(dataset_name, database_name)
    cols = np.array(cols)

    if cols_name:
        # Use the columns provided by the caller.
        _save_name = ""
    else:
        # Pick the two most important features automatically.
        _save_name = "_reproduced"
        cols_name = feature_importance(X, y, cols)

    # Resolve both names to column indices once, with a readable error when a
    # name is unknown (IndexError is kept as the exception type).
    col_idx = []
    for name in (cols_name[0], cols_name[1]):
        matches = np.flatnonzero(cols == name)
        if matches.size == 0:
            raise IndexError(f"Column {name!r} not found in dataset {dataset_name!r}")
        col_idx.append(matches[0])
    new_X = X[:, np.array(col_idx)]
    X_train, X_test, y_train, y_test = train_test_split(new_X, y, test_size=0.33)

    # Train the random forest
    RF = RandomForestClassifier(n_estimators=50)
    RF.fit(X_train, y_train)
    # The HS model gets its own copy so the plain forest stays untouched.
    hsRF = imodels.HSTreeClassifierCV(deepcopy(RF))
    hsRF.fit(X_train, y_train)
    print(f"Optimal lambda: {hsRF.reg_param}")

    decision_boundary(X_test, y_test, RF, "RF", cols_name, "../graphs/claim_4/boundaries/"+dataset_name+"_RF"+_save_name, save)
    decision_boundary(X_test, y_test, hsRF, "hsRF", cols_name, "../graphs/claim_4/boundaries/"+dataset_name+"_hsRF"+_save_name, save)

In [None]:
# heart (imodels): hand-picked feature pair first, then (None) the pair
# selected automatically inside test_simpler_boundary.
np.random.seed(42)
for cols_name in (["att_8", "att_10"], None):
    test_simpler_boundary("heart", "imodels", cols_name)

In [None]:
# breast_cancer (imodels): hand-picked feature pair first, then (None) the
# pair selected automatically inside test_simpler_boundary.
np.random.seed(42)
for cols_name in (["age", "tumor-size"], None):
    test_simpler_boundary("breast_cancer", "imodels", cols_name)

In [None]:
# haberman (imodels): hand-picked feature pair first, then (None) the pair
# selected automatically inside test_simpler_boundary.
np.random.seed(42)
for cols_name in (
    ["Age_of_patient_at_time_of_operation", "Number_of_positive_axillary_nodes_detected"],
    None,
):
    test_simpler_boundary("haberman", "imodels", cols_name)

In [None]:
# ionosphere (pmlb): hand-picked feature pair first, then (None) the pair
# selected automatically inside test_simpler_boundary.
np.random.seed(42)
for cols_name in (["X_4", "X_6"], None):
    test_simpler_boundary("ionosphere", "pmlb", cols_name)

In [None]:
# diabetes (pmlb): hand-picked feature pair first, then (None) the pair
# selected automatically inside test_simpler_boundary.
np.random.seed(42)
for cols_name in (["A2", "A6"], None):
    test_simpler_boundary("diabetes", "pmlb", cols_name)

In [None]:
# german (pmlb): hand-picked feature pair first, then (None) the pair
# selected automatically inside test_simpler_boundary.
np.random.seed(42)
for cols_name in (["Credit", "Age"], None):
    test_simpler_boundary("german", "pmlb", cols_name)

In [None]:
# juvenile_clean (imodels): hand-picked feature pair first, then (None) the
# pair selected automatically inside test_simpler_boundary.
np.random.seed(42)
for cols_name in (["friends_broken_in_steal:1", "fr_suggest_agnts_law:2"], None):
    test_simpler_boundary("juvenile_clean", "imodels", cols_name)

In [None]:
# compas_two_year_clean (imodels): hand-picked feature pair first, then (None)
# the pair selected automatically inside test_simpler_boundary.
np.random.seed(42)
for cols_name in (["age", "priors_count"], None):
    test_simpler_boundary("compas_two_year_clean", "imodels", cols_name)