# 🧰 Setup & Imports

In [None]:
!pip install shap eli5 scikit-learn tqdm alibi joblib

In [None]:
import os
import pandas as pd
import joblib
import numpy as np
import pickle
import shap
import re
import matplotlib.pyplot as plt
from google.colab import files
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.tree import export_text
from sklearn.inspection import PartialDependenceDisplay
from alibi.explainers import AnchorTabular

# 📂 Load Data and Artifacts

In [None]:
# Upload the dataset and the pickled artifacts through the Colab file picker.
uploaded = files.upload()

df = pd.read_csv('alien_dataset.csv')  # <--- Update if animals or aliens
X = df.drop(columns=['Class'])
y = df['Class']


# Load model, scaler, label encoder using joblib.
# NOTE(review): joblib.load unpickles arbitrary objects; only load artifacts
# that come from a trusted source.
model = joblib.load('model.pkl')
scaler = joblib.load('scaler.pkl')
label_encoder = joblib.load('label_encoder.pkl')

# Scale features with the already-fitted scaler (no re-fitting here).
X_scaled = scaler.transform(X)

# 🤖 Predict
print("Predicting classes...")
y_pred = model.predict(X_scaled)

feature_names = X.columns

# One plotting colour per feature. The palette is cycled so that every
# feature gets a colour even when there are more features than colours;
# a plain zip() would silently drop the extra features and later colour
# lookups would fail with KeyError.
_feature_palette = ['red', 'green', 'blue', 'orange', 'purple']
feature_colors = {
    feature: _feature_palette[position % len(_feature_palette)]
    for position, feature in enumerate(feature_names)
}

# 📊 SHAP Analysis: Per Class & Cluster

In [None]:
# ------------------------------------------
# Setup: Define variables
# ------------------------------------------

# Numbers of K-Means clusters to try inside each predicted class.
n_clusters_list = [2, 3, 4, 5]

# Result containers.
#   *_by_class*          are keyed by class label
#   *_by_class_cluster*  are keyed by (class_label, n_clusters, cluster_id)
#   *_average            hold mean |SHAP| per feature, shape (n_features,)
#   the others           hold raw SHAP values, shape (n_samples_in_group, n_features)
feature_importances_by_class_cluster_average = {}
feature_importances_by_class_average = {}
feature_importances_by_class = {}
feature_importances_by_class_cluster = {}

# ------------------------------------------
# Build SHAP Explainer Once
# ------------------------------------------

print("Building SHAP explainer...")
explainer = shap.TreeExplainer(model)

# SHAP values are computed once for all samples and sliced per class below.
shap_values = explainer.shap_values(X_scaled)

# Depending on the installed SHAP release, a multi-class tree model yields
# either a list of per-class (n_samples, n_features) arrays or one array of
# shape (n_samples, n_features, n_classes). Normalise to the 3-D layout that
# the indexing below relies on.
if isinstance(shap_values, list):
    shap_values = np.stack(shap_values, axis=-1)
shap_values = np.asarray(shap_values)
if shap_values.ndim != 3:
    raise ValueError(
        "Expected per-class SHAP values of shape "
        "(n_samples, n_features, n_classes); "
        f"got an array of shape {shap_values.shape}."
    )

# ------------------------------------------
# For each class
# ------------------------------------------

for class_idx, class_label in enumerate(label_encoder.classes_):
    print(f"\nAnalyzing class: {class_label}")

    # Rows the model assigned to this class (predicted, not ground-truth, labels).
    idx_class = np.where(y_pred == class_idx)[0]
    X_class = X_scaled[idx_class]

    if len(X_class) == 0:
        continue

    # SHAP values of this class's output for the rows predicted as this class.
    shap_vals_class = shap_values[idx_class, :, class_idx]  # (samples_in_class, features)

    # Save mean(abs(SHAP)) and the raw values for the whole class.
    feature_importances_by_class_average[class_label] = np.abs(shap_vals_class).mean(axis=0)
    feature_importances_by_class[class_label] = shap_vals_class

    # ------------------------------------------
    # Now cluster inside class
    # ------------------------------------------

    for n_clusters in n_clusters_list:

        # KMeans cannot form more clusters than there are samples.
        if n_clusters > len(X_class):
            print(
                f"  Skipping n_clusters={n_clusters}: "
                f"only {len(X_class)} sample(s) predicted as {class_label}."
            )
            continue

        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        clusters = kmeans.fit_predict(X_class)

        for cluster_id in range(n_clusters):

            # Positions are relative to X_class / shap_vals_class, not to X_scaled.
            idx_cluster = np.where(clusters == cluster_id)[0]

            if len(idx_cluster) == 0:
                continue

            shap_vals_cluster = shap_vals_class[idx_cluster, :]  # (samples_in_cluster, features)

            key = (class_label, n_clusters, cluster_id)
            feature_importances_by_class_cluster_average[key] = np.abs(shap_vals_cluster).mean(axis=0)
            feature_importances_by_class_cluster[key] = shap_vals_cluster


In [None]:
print("\nPlotting SHAP Feature Importance and Distribution Grid...")

# 1. Get class labels
class_labels = list(label_encoder.classes_)
n_classes = len(class_labels)

# 2. Create 2 rows (importance / boxplot), n_classes columns.
#    squeeze=False keeps `axes` two-dimensional even for a single class.
fig, axes = plt.subplots(2, n_classes, figsize=(6 * n_classes, 12), squeeze=False)

for idx, class_label in enumerate(class_labels):
    ax_importance = axes[0, idx]
    ax_distribution = axes[1, idx]

    # Classes that were never predicted have no stored SHAP values; show an
    # empty, labelled column instead of failing with KeyError.
    if class_label not in feature_importances_by_class:
        for empty_ax in (ax_importance, ax_distribution):
            empty_ax.set_title(f"{class_label} - no predicted samples", fontsize=14)
            empty_ax.axis('off')
        continue

    shap_vals_to_plot = feature_importances_by_class_average[class_label]
    shap_vals_class = feature_importances_by_class[class_label]  # (n_samples_in_class, n_features)

    feature_names_to_plot = list(feature_names[:len(shap_vals_to_plot)])
    # Fall back to gray for any feature that has no entry in feature_colors.
    bar_colors = [feature_colors.get(feat, 'gray') for feat in feature_names_to_plot]

    # 3. First row: absolute importance (mean |SHAP|).
    #    Bars are placed at explicit positions so the tick labels are attached
    #    to fixed ticks.
    bar_positions = np.arange(len(feature_names_to_plot))
    ax_importance.bar(bar_positions, shap_vals_to_plot, color=bar_colors)
    ax_importance.set_xticks(bar_positions)
    ax_importance.set_xticklabels(feature_names_to_plot, rotation=45, ha='right')
    ax_importance.set_title(f"{class_label} - Mean |SHAP|", fontsize=14)
    ax_importance.set_ylabel("Mean |SHAP| Value", fontsize=12)
    ax_importance.grid(True, axis='y')

    # 4. Second row: distribution of signed SHAP values, one box per feature.
    #    Tick labels are set afterwards rather than through boxplot's `labels`
    #    keyword, which newer matplotlib releases deprecate.
    bp = ax_distribution.boxplot(
        shap_vals_class,
        patch_artist=True,
        showfliers=True
    )

    box_names = list(feature_names[:shap_vals_class.shape[1]])
    for patch, feat in zip(bp['boxes'], box_names):
        patch.set_facecolor(feature_colors.get(feat, 'gray'))

    ax_distribution.axhline(0, color='black', linestyle='--')
    ax_distribution.set_title(f"{class_label} - SHAP Distribution (Boxplot)", fontsize=14)
    ax_distribution.set_ylabel("SHAP Value (Effect Direction)", fontsize=12)
    ax_distribution.set_xticklabels(box_names, rotation=45, ha='right')
    ax_distribution.grid(True, axis='y')

# Final layout
fig.suptitle("Full Class SHAP Feature Importance and Effect Distribution", fontsize=20)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

In [None]:
# Example settings
# class = ["Duck", "Dog", "Cat"] or ["Zarnak", "Quorvian", "Bliptor"]
class_label = "Bliptor"
# clusters : 2, 3, 4, 5
n_clusters = 5
feature_names_to_plot = feature_names

# Colors for clusters
cluster_colors = ["#fd7f6f", "#7eb0d5", "#b2e061", "#bd7ebe", "#ffb55a", "#ffee65", "#beb9db", "#fdcce5", "#8bd3c7"]


# Collect the clusters that actually have stored importances. A cluster id can
# be missing (empty cluster, or the whole class / cluster count was skipped),
# so keep the id next to its values instead of assuming n_clusters rows.
available_clusters = [
    (cluster_id, feature_importances_by_class_cluster_average[(class_label, n_clusters, cluster_id)])
    for cluster_id in range(n_clusters)
    if (class_label, n_clusters, cluster_id) in feature_importances_by_class_cluster_average
]

if not available_clusters:
    raise ValueError(
        f"No cluster importances stored for class {class_label!r} with "
        f"n_clusters={n_clusters}; check the class label and rerun the SHAP cell."
    )

feature_importances_clusters = np.array([importances for _, importances in available_clusters])

# Plot
x = np.arange(len(feature_names_to_plot))  # Feature index
n_bars = len(available_clusters)
# 0.15 per bar, narrowed only if a group would otherwise spill into its neighbour.
width = min(0.15, 0.8 / n_bars)

fig, ax = plt.subplots(figsize=(12, 6))

for slot, (cluster_id, importances) in enumerate(available_clusters):
    ax.bar(
        x + slot * width,
        importances,
        width=width,
        label=f"Cluster {cluster_id}",
        color=cluster_colors[cluster_id % len(cluster_colors)]  # Color by cluster
    )

# X-axis setup: centre each tick under its group of bars.
ax.set_xticks(x + width * (n_bars - 1) / 2)
ax.set_xticklabels(feature_names_to_plot, rotation=45, ha='right')

# Labels and Title
ax.set_ylabel("Mean |SHAP| Value", fontsize=12)
ax.set_title(f"Feature Importance Across Clusters for Class '{class_label}'", fontsize=16)
ax.grid(True, axis='y')
ax.legend(title="Cluster")
plt.tight_layout()
plt.show()

# 🧠 Anchor Explanations

In [None]:
print("\nRunning Anchor explanations...")


def predict_fn(x):
    """Prediction function handed to the anchor explainer."""
    return model.predict(x)


# Create anchor explainer (fitted on the scaled features the model was fed).
anchor_feature_names = X.columns.tolist()
anchor_explainer = AnchorTabular(predict_fn, feature_names=anchor_feature_names)
anchor_explainer.fit(X_scaled)

n_anchors_per_class = 3

# Feature name -> column position (first occurrence wins, like list.index).
_feature_position = {}
for _position, _name in enumerate(anchor_feature_names):
    _feature_position.setdefault(_name, _position)

# Patterns for the textual anchor conditions. Feature names may contain any
# characters (digits included). Besides "name op value", a two-sided form
# "low op name op high" is accepted.
# NOTE(review): the two-sided form is what alibi appears to emit for range
# anchors; confirm against the installed alibi version.
_NUMBER = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
_RANGE_CONDITION = re.compile(
    rf"^(?P<low>{_NUMBER}) (?P<low_op>[<>=]+) (?P<name>.+?) (?P<high_op>[<>=]+) (?P<high>{_NUMBER})$"
)
_SIMPLE_CONDITION = re.compile(
    rf"^(?P<name>.+?) (?P<op>[<>=]+) (?P<value>{_NUMBER})$"
)


def _to_original_scale(feature_idx, value_scaled):
    """Map one scaled value of one feature back to the original units."""
    dummy = np.zeros((1, X_scaled.shape[1]))
    dummy[0, feature_idx] = value_scaled
    return scaler.inverse_transform(dummy)[0, feature_idx]


def _readable_condition(cond):
    """Return `cond` with thresholds in original units, or None if it cannot be parsed."""
    text = cond.strip()

    # The range form must be tried first: the simple pattern would otherwise
    # swallow "low < name" as if it were the feature name.
    match = _RANGE_CONDITION.match(text)
    if match and match.group("name") in _feature_position:
        name = match.group("name")
        feature_idx = _feature_position[name]
        low = _to_original_scale(feature_idx, float(match.group("low")))
        high = _to_original_scale(feature_idx, float(match.group("high")))
        low_op = match.group("low_op")
        high_op = match.group("high_op")
        return f"{low:.2f} {low_op} {name} {high_op} {high:.2f}"

    match = _SIMPLE_CONDITION.match(text)
    if match and match.group("name") in _feature_position:
        name = match.group("name")
        operator = match.group("op")
        threshold_original = _to_original_scale(
            _feature_position[name], float(match.group("value"))
        )
        return f"{name} {operator} {threshold_original:.2f}"

    return None


for class_idx, class_label in enumerate(label_encoder.classes_):
    print(f"\nAnchors for class: {class_label}")
    idx_class = np.where(y_pred == class_idx)[0]

    if len(idx_class) == 0:
        continue

    # Draw distinct rows so the same sample is never explained twice.
    n_to_explain = min(n_anchors_per_class, len(idx_class))
    chosen_rows = np.random.choice(idx_class, size=n_to_explain, replace=False)

    for j, i in enumerate(chosen_rows):
        sample_scaled = X_scaled[i].reshape(1, -1)

        # Explain
        explanation = anchor_explainer.explain(sample_scaled)

        print(f"\nAnchor {j+1} for class {class_label}:")

        # Rebuild human-readable conditions (thresholds in original units).
        readable_conditions = []
        for cond in explanation.anchor:
            readable = _readable_condition(cond)
            if readable is None:
                print(f"Warning: Could not parse condition: {cond}")
                continue
            readable_conditions.append(readable)

        print('  Conditions:', readable_conditions)
        print('  Precision:', explanation.precision)
        print('  Coverage:', explanation.coverage)


# 📜 Rule Extraction

In [None]:
# ---- Start the Rule Extraction Process ----
print("\nExtracting Rules...")

# Define different extraction settings: Light, Medium, Strict
settings = [
    {'name': 'Light', 'n_trees': 5, 'threshold': 0.3},   # Light: Few trees, low frequency threshold
    {'name': 'Medium', 'n_trees': 10, 'threshold': 0.5}, # Medium: More trees, higher frequency threshold
    {'name': 'Strict', 'n_trees': 20, 'threshold': 0.7}, # Strict: Many trees, strict frequency threshold
]

rule_feature_names = list(X.columns)


def _descale_threshold(feat, thresh):
    """Map a split threshold on feature `feat` from scaled space to original units.

    Only a StandardScaler is inverted; with any other scaler the threshold is
    returned unchanged.
    """
    if isinstance(scaler, StandardScaler):
        return thresh * scaler.scale_[feat] + scaler.mean_[feat]
    return thresh


def _count_class_paths(tree, class_idx, names, feature_counts, node=0, conditions=()):
    """Walk `tree` and count root-to-leaf paths whose leaf predicts `class_idx`.

    For every such path, each split condition along it is tallied in
    `feature_counts` (condition text -> number of paths containing it).
    Returns the number of matching paths found below `node`.
    """
    left = tree.children_left[node]
    right = tree.children_right[node]

    # Leaf: both child pointers carry the same value.
    if left == right:
        if np.argmax(tree.value[node]) != class_idx:
            return 0
        for cond in conditions:
            feature_counts[cond] = feature_counts.get(cond, 0) + 1
        return 1

    feat = tree.feature[node]
    if feat < 0:
        return 0

    descaled = _descale_threshold(feat, tree.threshold[node])
    name = names[feat]
    # Left branch first, then right, so conditions are recorded in the same
    # order a depth-first walk of the tree visits them.
    return (
        _count_class_paths(tree, class_idx, names, feature_counts, left,
                           conditions + (f"{name} <= {descaled:.2f}",))
        + _count_class_paths(tree, class_idx, names, feature_counts, right,
                             conditions + (f"{name} > {descaled:.2f}",))
    )


# Loop over each setting
for setting in settings:
    print(f"\n### Setting: {setting['name']} ###")
    n_trees = setting['n_trees']
    threshold = setting['threshold']
    estimators = model.estimators_[:n_trees]

    # The textual dump of a tree does not depend on the class, so render each
    # tree once per setting instead of once per class.
    # NOTE: export_text reports thresholds in scaled units, while the
    # aggregated rules below are converted back to original units.
    tree_rule_texts = [
        export_text(estimator, feature_names=rule_feature_names)
        for estimator in estimators
    ]

    # Loop over each class label (e.g., class names like 'Apple', 'Banana', etc.)
    for class_idx, class_label in enumerate(label_encoder.classes_):
        print(f"\nClass: {class_label}")

        # ---- SIMPLE RULES: Show rules from individual trees ----
        print("\nSimple Rules (Single Trees):")
        for i, tree_rules in enumerate(tree_rule_texts):
            print(f"\nTree {i} Rules:\n", tree_rules)

        # ---- AGGREGATED RULES: Combine rules across multiple trees ----
        print("\nAggregated Rules (Across Forest):")

        feature_counts = {}  # condition text -> number of class paths containing it
        path_counts = [0]    # single item: total number of paths predicting this class

        for estimator in estimators:
            path_counts[0] += _count_class_paths(
                estimator.tree_, class_idx, rule_feature_names, feature_counts
            )

        # Print conditions that appear in a large enough share of the class's paths.
        for cond, count in feature_counts.items():
            freq = count / path_counts[0] if path_counts[0] else 0
            if freq >= threshold:
                print(f"{cond} appeared in {freq*100:.1f}% of paths.")


In [None]:

# ---- Start the Rule Extraction Process ----
print("\nExtracting Rules...")

# Destination of the logged rules.
output_file = "extracted_rules.txt"

# Truncate the file so this run starts from an empty log.
with open(output_file, "w") as f:
    pass

# Extraction settings (Light, Medium, Strict): how many trees to inspect and
# the minimum path frequency a condition needs in order to be reported.
settings = [
    {'name': preset, 'n_trees': tree_budget, 'threshold': min_frequency}
    for preset, tree_budget, min_frequency in (
        ('Light', 5, 0.3),
        ('Medium', 10, 0.5),
        ('Strict', 20, 0.7),
    )
]

# Feature names as a plain list.
feature_names = [*X.columns]

# Echo a message to the console and append it to the rules file.
def log(text):
    """Print `text` and append it, followed by a newline, to `output_file`."""
    print(text)
    line = text + "\n"
    with open(output_file, mode="a") as handle:
        handle.write(line)

def _descale_threshold(feat, thresh):
    """Map a split threshold on feature `feat` from scaled space to original units.

    Only a StandardScaler is inverted; with any other scaler the threshold is
    returned unchanged.
    """
    if isinstance(scaler, StandardScaler):
        return thresh * scaler.scale_[feat] + scaler.mean_[feat]
    return thresh


def _tree_rule_lines(tree, node=0, depth=0):
    """Return the nested if/else rule text of `tree` as a list of lines."""
    indent = "  " * depth
    left = tree.children_left[node]
    right = tree.children_right[node]

    # Leaf: both child pointers carry the same value.
    if left == right:
        predicted_class = label_encoder.classes_[np.argmax(tree.value[node])]
        return [indent + f"Predict: {predicted_class}"]

    feat = tree.feature[node]
    if feat < 0:
        return []

    descaled_thresh = _descale_threshold(feat, tree.threshold[node])
    name = feature_names[feat]
    return [
        indent + f"if {name} <= {descaled_thresh:.2f}:",
        *_tree_rule_lines(tree, left, depth + 1),
        indent + f"else ({name} > {descaled_thresh:.2f}):",
        *_tree_rule_lines(tree, right, depth + 1),
    ]


def _count_class_paths(tree, class_idx, names, feature_counts, node=0, conditions=()):
    """Walk `tree` and count root-to-leaf paths whose leaf predicts `class_idx`.

    For every such path, each split condition along it is tallied in
    `feature_counts` (condition text -> number of paths containing it).
    Returns the number of matching paths found below `node`.
    """
    left = tree.children_left[node]
    right = tree.children_right[node]

    if left == right:
        if np.argmax(tree.value[node]) != class_idx:
            return 0
        for cond in conditions:
            feature_counts[cond] = feature_counts.get(cond, 0) + 1
        return 1

    feat = tree.feature[node]
    if feat < 0:
        return 0

    descaled_thresh = _descale_threshold(feat, tree.threshold[node])
    name = names[feat]
    # Left branch first, then right, so conditions are recorded in the same
    # order a depth-first walk of the tree visits them.
    return (
        _count_class_paths(tree, class_idx, names, feature_counts, left,
                           conditions + (f"{name} <= {descaled_thresh:.2f}",))
        + _count_class_paths(tree, class_idx, names, feature_counts, right,
                             conditions + (f"{name} > {descaled_thresh:.2f}",))
    )


# Loop over each setting
for setting in settings:
    log(f"\n### Setting: {setting['name']} (n_trees={setting['n_trees']}, threshold={setting['threshold']}) ###")
    n_trees = setting['n_trees']
    threshold = setting['threshold']
    estimators = model.estimators_[:n_trees]

    # A tree's rule text is the same for every class, so build it once per
    # setting and replay it under each class heading.
    rule_lines_per_tree = [_tree_rule_lines(estimator.tree_) for estimator in estimators]

    # Loop over each class label
    for class_idx, class_label in enumerate(label_encoder.classes_):
        log(f"\nClass: {class_label}")

        # ---- SIMPLE RULES: Show rules from individual trees ----
        log("\nSimple Rules (Single Trees):")
        for i, rule_lines in enumerate(rule_lines_per_tree):
            log(f"\nTree {i} Rules:")
            for rule_line in rule_lines:
                log(rule_line)

        # ---- AGGREGATED RULES: Combine rules across multiple trees ----
        log("\nAggregated Rules (Across Forest):")

        feature_counts = {}  # condition text -> number of class paths containing it
        path_counts = [0]    # single item: total number of paths predicting this class

        for estimator in estimators:
            path_counts[0] += _count_class_paths(
                estimator.tree_, class_idx, feature_names, feature_counts
            )

        # Report conditions that appear in a large enough share of the class's paths.
        for cond, count in feature_counts.items():
            freq = count / path_counts[0] if path_counts[0] else 0
            if freq >= threshold:
                log(f"{cond} appeared in {freq*100:.1f}% of paths.")

# Name the file actually written rather than repeating a literal.
log(f"\nFinished rule extraction. Output saved to '{output_file}'.")


## PDP (Partial Dependence Plot)

In [None]:
def plot_pdp_grid_for_class(model, X_scaled, feature_names, label_encoder, class_label, scaler, grid_resolution=50):
    """
    Plot, for every feature, PDP + ICE (left) next to the average PDP (right).

    Args:
        model: Fitted estimator passed to PartialDependenceDisplay.from_estimator.
        X_scaled: 2-D array of scaled features the curves are computed on.
        feature_names: Feature names, one per column of X_scaled.
        label_encoder: Object whose `classes_` lists the class labels; the
            position of `class_label` in it is used as the PDP target.
        class_label: Class whose partial dependence is plotted.
        scaler: Object with `inverse_transform`, used to label the x axis in
            original feature units.
        grid_resolution: Number of grid points per feature.

    Raises:
        ValueError: If `class_label` is not in `label_encoder.classes_`.
    """

    print(f"Plotting PDP grid for class '{class_label}'...")

    class_idx = list(label_encoder.classes_).index(class_label)
    feature_names = list(feature_names)
    n_features = len(feature_names)
    n_columns = X_scaled.shape[1]

    # squeeze=False keeps `axes` two-dimensional even for a single feature.
    fig, axes = plt.subplots(nrows=n_features, ncols=2, figsize=(14, 3 * n_features), squeeze=False)

    def draw_panel(feature_idx, kind, target_ax, ylabel, title):
        """Draw one partial-dependence panel and relabel its x axis in original units."""
        display = PartialDependenceDisplay.from_estimator(
            model,
            X_scaled,
            features=[feature_idx],
            target=class_idx,
            feature_names=feature_names,
            kind=kind,
            grid_resolution=grid_resolution,
            ax=target_ax,
        )
        panel = display.axes_[0, 0]

        # The located ticks can lie outside the visible range, and fixing them
        # with set_xticks would widen the axis. Remember the limits and
        # restore them after relabelling.
        x_limits = panel.get_xlim()
        ticks_scaled = panel.get_xticks()

        # Invert the scaling of this one feature via a dummy matrix whose
        # other columns are irrelevant.
        dummy = np.zeros((len(ticks_scaled), n_columns))
        dummy[:, feature_idx] = ticks_scaled
        ticks_original = scaler.inverse_transform(dummy)[:, feature_idx]

        panel.set_xticks(ticks_scaled)
        panel.set_xticklabels([f"{tick:.2f}" for tick in ticks_original])
        panel.set_xlim(x_limits)

        panel.set_ylabel(ylabel, fontsize=10)
        panel.set_xlabel(f"{feature_names[feature_idx]} (Original Scale)", fontsize=10)
        panel.set_title(title, fontsize=12)
        panel.grid(True)

    for feature_idx, feature_name in enumerate(feature_names):
        # --- LEFT: PDP + ICE ---
        # kind='both' draws the individual (ICE) curves together with their
        # average, matching the panel title; kind='individual' would omit
        # the PDP line.
        draw_panel(
            feature_idx,
            'both',
            axes[feature_idx, 0],
            "Partial Dependence (Individual)",
            f"{feature_name} - PDP + ICE",
        )

        # --- RIGHT: PDP Average only ---
        draw_panel(
            feature_idx,
            'average',
            axes[feature_idx, 1],
            "Partial Dependence (Average)",
            f"{feature_name} - PDP Average",
        )

    fig.suptitle(f"Partial Dependence for Class '{class_label}'", fontsize=18)
    plt.tight_layout(rect=[0, 0.03, 1, 0.97])
    plt.show()


# -----

def plot_2d_pdp_on_ax(model, X_scaled, feature_names, label_encoder, class_label, feature1, feature2, scaler, ax, grid_resolution=30):
    """
    Plot a 2D PDP for two features on a given matplotlib axis.

    Args:
        model: Fitted estimator passed to PartialDependenceDisplay.from_estimator.
        X_scaled: 2-D array of scaled features the surface is computed on.
        feature_names: Feature names, one per column of X_scaled.
        label_encoder: Object whose `classes_` lists the class labels.
        class_label: Class whose partial dependence is plotted.
        feature1: Name of the feature shown on the x axis.
        feature2: Name of the feature shown on the y axis.
        scaler: Object with `inverse_transform`, used to label both axes in
            original feature units.
        ax: Matplotlib axis to draw into.
        grid_resolution: Number of grid points per feature.

    Raises:
        ValueError: If `class_label`, `feature1` or `feature2` is unknown.
    """

    class_idx = list(label_encoder.classes_).index(class_label)

    names = list(feature_names)
    feature1_idx = names.index(feature1)
    feature2_idx = names.index(feature2)
    n_columns = X_scaled.shape[1]

    display = PartialDependenceDisplay.from_estimator(
        model,
        X_scaled,
        features=[(feature1_idx, feature2_idx)],
        target=class_idx,
        feature_names=feature_names,
        kind='average',
        grid_resolution=grid_resolution,
        ax=ax,
    )

    ax_disp = display.axes_[0, 0]

    def original_units(ticks_scaled, feature_idx):
        """Invert the scaling of one feature's tick positions via a dummy matrix."""
        dummy = np.zeros((len(ticks_scaled), n_columns))
        dummy[:, feature_idx] = ticks_scaled
        return scaler.inverse_transform(dummy)[:, feature_idx]

    # Fixing the located ticks can widen the view beyond the plotted surface
    # (ticks may lie outside the visible range), so remember both limits and
    # restore them after relabelling.
    x_limits = ax_disp.get_xlim()
    y_limits = ax_disp.get_ylim()

    # Fix X ticks
    x_ticks_scaled = ax_disp.get_xticks()
    x_ticks_original = original_units(x_ticks_scaled, feature1_idx)
    ax_disp.set_xticks(x_ticks_scaled)
    ax_disp.set_xticklabels([f"{tick:.2f}" for tick in x_ticks_original])

    # Fix Y ticks
    y_ticks_scaled = ax_disp.get_yticks()
    y_ticks_original = original_units(y_ticks_scaled, feature2_idx)
    ax_disp.set_yticks(y_ticks_scaled)
    ax_disp.set_yticklabels([f"{tick:.2f}" for tick in y_ticks_original])

    ax_disp.set_xlim(x_limits)
    ax_disp.set_ylim(y_limits)

    ax_disp.set_xlabel(f"{feature1} (original scale)", fontsize=8)
    ax_disp.set_ylabel(f"{feature2} (original scale)", fontsize=8)
    ax_disp.set_title(f"{feature1} vs {feature2}", fontsize=9)
    ax_disp.tick_params(axis='both', labelsize=6)
    ax_disp.grid(True)
#------

def plot_2d_pdp_grid(model, X_scaled, feature_names, label_encoder, class_label, scaler, grid_resolution=30):
    """
    Plot an n_features x n_features grid of 2D PDPs for all feature pairs.

    Off-diagonal cell (i, j) shows feature i on the x axis against feature j
    on the y axis; diagonal cells only show the feature name.

    Args:
        model: Fitted estimator forwarded to plot_2d_pdp_on_ax.
        X_scaled: 2-D array of scaled features.
        feature_names: Feature names, one per column of X_scaled.
        label_encoder: Object whose `classes_` lists the class labels.
        class_label: Class whose partial dependence is plotted.
        scaler: Object with `inverse_transform`, forwarded for axis relabelling.
        grid_resolution: Number of grid points per feature in every cell.
    """

    n_features = len(feature_names)
    # squeeze=False keeps `axes` two-dimensional even for a single feature.
    fig, axes = plt.subplots(
        nrows=n_features,
        ncols=n_features,
        figsize=(3 * n_features, 3 * n_features),
        squeeze=False,
    )

    for i, feature1 in enumerate(feature_names):
        for j, feature2 in enumerate(feature_names):
            ax = axes[i, j]

            if i == j:
                # A feature against itself is not a pair; use the cell as a label.
                ax.axis('off')
                ax.text(0.5, 0.5, feature1, horizontalalignment='center', verticalalignment='center', fontsize=10)
                continue

            plot_2d_pdp_on_ax(
                model=model,
                X_scaled=X_scaled,
                feature_names=feature_names,
                label_encoder=label_encoder,
                class_label=class_label,
                feature1=feature1,
                feature2=feature2,
                scaler=scaler,
                ax=ax,
                grid_resolution=grid_resolution,
            )

    fig.suptitle(f"2D Partial Dependence Grid for Class '{class_label}'", fontsize=20)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


In [None]:
# class = ["Duck", "Dog", "Cat"] or ["Zarnak", "Quorvian", "Bliptor"]
class_label = "Cat"

# Fail early with an actionable message when the chosen label does not belong
# to the loaded label encoder (e.g. an animal label with the alien artifacts),
# instead of a bare "'Cat' is not in list" raised inside the plotting helpers.
known_classes = list(label_encoder.classes_)
if class_label not in known_classes:
    raise ValueError(
        f"class_label {class_label!r} is not one of the model's classes: {known_classes}"
    )

# One row per feature: PDP + ICE next to the average PDP.
plot_pdp_grid_for_class(
    model=model,
    X_scaled=X_scaled,
    feature_names=feature_names,
    label_encoder=label_encoder,
    class_label=class_label,
    scaler=scaler,
    grid_resolution=50
)

# Pairwise 2D partial dependence for every feature pair.
plot_2d_pdp_grid(
    model=model,
    X_scaled=X_scaled,
    feature_names=feature_names,
    label_encoder=label_encoder,
    class_label=class_label,
    scaler=scaler
)