## AIPI 590 - XAI | Assignment #4
### Interpretable ML II
#### Author: Tal Erez
#### Colab Notebook:
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/notthattal/InterpretableML_II/blob/main/imodels_demo.ipynb)

### Introduction

This notebook demonstrates the CART, FIGS, and Rule-Fit algorithms as implemented in [imodels](https://github.com/csinva/imodels?tab=readme-ov-file), a Python library for interpretable machine learning.

### Install required dependencies and import packages

In [None]:
import os

# Remove Colab default sample_data if it exists
!rm -rf sample_data

# Clone GitHub files to colab workspace
repo_name = "InterpretableML_II"

# Check if the repo already exists
if not os.path.exists("/content/" + repo_name):
    git_path = 'https://github.com/notthattal/InterpretableML_II.git'
    !git clone "{git_path}"
else:
    print(f"{repo_name} already exists.")

# Change working directory to location of notebook
path_to_notebook = os.path.join("/content", repo_name)
%cd "{path_to_notebook}"
%ls

!pip install -r requirements.txt

from imodels import get_clean_dataset, GreedyTreeClassifier, FIGSClassifier, RuleFitRegressor
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris, load_breast_cancer
from sklearn.tree import plot_tree
from sklearn.metrics import classification_report, mean_squared_error

### Load Iris Dataset
- load the dataset
- assign the features and target
- assign the feature names and target names
- split the data into training and test sets

In [None]:
# Load the iris dataset from sklearn
dataset = load_iris()

# assign the features and target
X = dataset.data
y = dataset.target

# assign the feature and target names
feature_names = dataset.feature_names
target_names = dataset.target_names

# split the data into a training and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

## CART (Classification And Regression Tree)
The CART approach:
1. CART grows the tree greedily. At each node, candidate split points are scored with a cost function (typically the sum of squared errors for regression and the Gini index for classification).
2. The node is split at the candidate with the lowest cost.
3. Splitting repeats until a stopping criterion is met. The most common criterion for CART is a minimum number of training samples per leaf node.
4. Pruning the tree before producing the final model is recommended but not required.
5. Output the final tree.

![CART Image](https://github.com/notthattal/InterpretableML_II/blob/main/img/CART.png?raw=1)
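
To make steps 1 and 2 concrete, here is a minimal pure-Python sketch of a greedy split search on a single numeric feature, scored with weighted Gini impurity. It is an illustration only, not the imodels or scikit-learn implementation. The function names `gini` and `best_split` and the toy data are hypothetical.

```python
from collections import Counter

def gini(labels):
    """Gini impurity of a list of class labels: 1 - sum(p_k ** 2)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())

def best_split(values, labels):
    """Return (threshold, cost) minimizing the weighted Gini impurity.

    Candidate thresholds are the midpoints between consecutive
    distinct values of the (sorted) feature.
    """
    pairs = sorted(zip(values, labels))
    n = len(pairs)
    best_threshold, best_cost = None, float("inf")
    for i in range(1, n):
        lo, hi = pairs[i - 1][0], pairs[i][0]
        if lo == hi:
            continue  # identical values cannot be separated by a threshold
        threshold = (lo + hi) / 2
        left = [label for value, label in pairs if value <= threshold]
        right = [label for value, label in pairs if value > threshold]
        # cost of the split = size-weighted average impurity of the children
        cost = (len(left) * gini(left) + len(right) * gini(right)) / n
        if cost < best_cost:
            best_threshold, best_cost = threshold, cost
    return best_threshold, best_cost

# Hypothetical toy data: one feature, two classes
toy_values = [1.4, 1.3, 1.5, 4.7, 4.5, 5.1]
toy_labels = [0, 0, 0, 1, 1, 1]
threshold, cost = best_split(toy_values, toy_labels)
print(f"best threshold: {threshold}, weighted Gini: {cost}")
```

A full CART implementation repeats this search over every feature at every node and recurses into the two children until the stopping criterion is met.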

### CART Algorithm

In [None]:
# initialize the CART classifier
model = GreedyTreeClassifier(max_depth=3)

# fit the model
model.fit(X_train, y_train, feature_names=feature_names)

# get the predictions
preds = model.predict(X_test)

# create and display the plot of the created tree
plt.figure(figsize=(15, 10))
plot_tree(model, filled=True, feature_names=feature_names, class_names=target_names)
plt.show()

# display metrics
print("Classification Report:\n", classification_report(y_test, preds, target_names=target_names))

### Load Heart Dataset
- load the dataset
- assign the features, target, and feature names
- split the data into training and test sets

In [None]:
# get the features, target and feature names from imodels' heart dataset
X, y, feature_names = get_clean_dataset('heart')

# split the data into a training and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

## FIGS (Fast Interpretable Greedy-Tree Sums)
The FIGS approach:
1. At each iteration, FIGS can either split a leaf of an existing tree or start a new tree. It greedily picks whichever option most reduces the total unexplained variance (or best satisfies another splitting criterion).
2. After the split is made (or the new tree is created), each tree is updated to predict the residuals that remain after summing the predictions of all the other trees.
3. This process repeats until a stopping criterion is met. Common stopping criteria include:
    - A threshold based on the model's predictive performance
    - Domain knowledge on how interpretable the model needs to be
    - Impurity decrease threshold

![FIGS Image](https://github.com/notthattal/InterpretableML_II/blob/main/img/FIGS.png?raw=1)
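
The residual step (step 2) can be sketched in plain Python with two one-split stumps. This is a simplified illustration, not the imodels implementation: real FIGS searches all features and all candidate splits greedily, whereas here each tree is assigned one feature by hand. The helpers `fit_stump` and `predict_stump` and the toy data are hypothetical.

```python
def fit_stump(x, targets):
    """Fit a one-split regression stump on one feature by minimizing SSE.

    Returns (threshold, left_mean, right_mean).
    """
    best, best_sse = None, float("inf")
    # Dropping the largest value guarantees both sides are non-empty
    for threshold in sorted(set(x))[:-1]:
        left = [t for v, t in zip(x, targets) if v <= threshold]
        right = [t for v, t in zip(x, targets) if v > threshold]
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        sse = (sum((t - left_mean) ** 2 for t in left)
               + sum((t - right_mean) ** 2 for t in right))
        if sse < best_sse:
            best_sse, best = sse, (threshold, left_mean, right_mean)
    return best

def predict_stump(stump, x):
    """Predict with a stump returned by fit_stump."""
    threshold, left_mean, right_mean = stump
    return [left_mean if v <= threshold else right_mean for v in x]

# Hypothetical additive toy data: y = 2 * x1 + x2
x1 = [0, 0, 1, 1]
x2 = [0, 1, 0, 1]
y_toy = [0.0, 1.0, 2.0, 3.0]

# Tree 1 is fit to the raw target
tree1 = fit_stump(x1, y_toy)
pred1 = predict_stump(tree1, x1)

# Tree 2 is fit to what tree 1 leaves unexplained (the residuals)
residuals = [t - p for t, p in zip(y_toy, pred1)]
tree2 = fit_stump(x2, residuals)
pred2 = predict_stump(tree2, x2)

# The model's prediction is the sum over trees
combined = [a + b for a, b in zip(pred1, pred2)]
print("residuals after tree 1:", residuals)
print("summed prediction:     ", combined)
```

Because the toy target is additive in the two features, the sum of the two stumps reproduces it exactly, which is the kind of structure a sum of small trees can capture more compactly than a single deep tree.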

### FIGS Algorithm

In [None]:
# initialize the FIGS Classifier
model = FIGSClassifier(max_rules=4)

# fit the model
model.fit(X_train, y_train)

# visualize the model
model.plot(feature_names=feature_names, fig_size=5)

### Load Breast Cancer Dataset
- load the dataset
- assign the features and target
- assign the feature names and target names
- split the data into training and test sets

In [None]:
# Load the breast cancer dataset from sklearn
dataset = load_breast_cancer()

# get the features, target, feature names and target names from the breast cancer dataset
X = dataset.data
y = dataset.target
feature_names = dataset.feature_names
target_names = dataset.target_names

# split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

## Rule-Fit Regression
The Rule-Fit approach:
1. Fit a decision-tree ensemble to the dataset.
2. Extract a binary rule from each node of every tree in the ensemble. Each rule is the conjunction of the split conditions on the path from the root to that node.
3. Add the generated rules to the original dataset as new features.
4. Fit a LASSO-regularized linear regression model on the augmented dataset. The rule features account for interactions between the original features, and the L1 penalty shrinks the coefficients of unhelpful rules to zero.

![RuleFit Image](https://github.com/notthattal/InterpretableML_II/blob/main/img/RuleFit.png?raw=1)
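
Steps 2 and 3 can be sketched as follows: each rule is a conjunction of threshold conditions, and evaluating it on a row yields a 0/1 feature that is appended to the original columns. This is a minimal illustration under assumed names, not the imodels implementation. The rule encoding, the helpers `rule_fires` and `add_rule_features`, and the toy data are all hypothetical.

```python
# A rule is a list of conditions (feature_index, operator, threshold)
# that must all hold, e.g. "feature 0 > 2.5 AND feature 1 <= 1.0".
toy_rules = [
    [(0, "<=", 2.5)],
    [(0, ">", 2.5), (1, "<=", 1.0)],
]

def rule_fires(rule, row):
    """Return 1 if every condition of the rule holds for the row, else 0."""
    for feature_index, op, threshold in rule:
        value = row[feature_index]
        satisfied = value <= threshold if op == "<=" else value > threshold
        if not satisfied:
            return 0
    return 1

def add_rule_features(rows, rules):
    """Append one binary column per rule to each row of the dataset."""
    return [list(row) + [rule_fires(rule, row) for rule in rules]
            for row in rows]

X_toy = [[1.0, 0.5], [3.0, 0.8], [4.0, 2.0]]
X_augmented = add_rule_features(X_toy, toy_rules)
for row in X_augmented:
    print(row)
```

The augmented matrix is what the sparse linear model in step 4 would be fit on, so each surviving coefficient corresponds either to an original feature or to a readable rule.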

### Rule-Fit Algorithm

In [None]:
# initialize the Rule-Fit Regressor (the 0/1 breast cancer target is treated as a continuous value)
model = RuleFitRegressor(random_state=42)

# fit the model
model.fit(X_train, y_train, feature_names=feature_names)

# get the predictions on the test set
preds = model.predict(X_test)

# calculate and display the mean-squared error
mse = mean_squared_error(y_test, preds)
print(f'Mean Squared Error: {mse:.2f}\n')

# visualize the rules created by the model
model.visualize()
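
Because the breast cancer target is binary, the regressor's continuous predictions can also be read as class scores. One common option, shown here as a sketch and not something the notebook above does, is to threshold them at 0.5 and compute accuracy. The helpers `to_labels` and `accuracy`, the cutoff, and the example values are hypothetical.

```python
def to_labels(scores, threshold=0.5):
    """Map continuous regression outputs to 0/1 labels with a fixed cutoff."""
    return [1 if score >= threshold else 0 for score in scores]

def accuracy(y_true, y_pred):
    """Fraction of positions where the predicted label equals the true label."""
    matches = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return matches / len(y_true)

# Hypothetical regressor outputs and true labels
toy_scores = [0.92, 0.10, 0.55, 0.48]
toy_truth = [1, 0, 1, 1]

toy_labels = to_labels(toy_scores)
print("labels:  ", toy_labels)
print("accuracy:", accuracy(toy_truth, toy_labels))
```

With the notebook's variables, the same helpers could be applied as `accuracy(list(y_test), to_labels(preds))`.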