# Explainable AI 

## Imports


In [16]:
import csv
import io
from pathlib import Path

import gymnasium as gym
import imageio.v2 as imageio
import matplotlib.pyplot as plt
import numpy as np
import pydotplus
from gymnasium import spaces
from gymnasium.utils import seeding
from IPython.display import Image
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_graphviz
from stable_baselines3 import DQN

## Create Instance of Environment 

This code creates an instance of the `AirlineSeatInventoryEnv` class and sets its parameters to the values and ranges written below.

In [18]:
# Build the environment that the model is queried in when sampling.
# NOTE(review): AirlineSeatInventoryEnv is not defined or imported in the cells
# visible here -- it has to be in scope before this cell runs.
env = AirlineSeatInventoryEnv(
    # scalar settings
    num_seats=150,
    time_till_flight=20,
    demand_lambda=10,
    # [low, high] style settings -- presumably bounds; confirm against the env class
    price_range=[500, 1000],
    season_range=[0.5, 1.5],
)

## Load the model

In [21]:
# Train a model separately, then load it here if you wish to generate new state-action pairs.
#model_path = "insert_name_here"
#model = DQN.load(model_path)

## Sample State-Action Pairs

We collect state-action pairs by giving the model states from the environment and recording the action it picks for each one. These pairs are then used to train the decision tree.

In [22]:
# Sample the environment to collect data for the decision tree.
# The state-action pairs are written to a CSV file (nothing is returned) so the
# same data can be reloaded later to train the decision tree.

def sample(n_samples=1000, output_path="state_action_pairs.csv", deterministic=False):
    """Collect (state, action) pairs from the model and save them as CSV.

    For each sample the model is asked for an action on the environment's
    current state, that state and action are recorded together, and the action
    is then applied to the environment. The environment is reset only when an
    episode ends, so the samples follow complete episodes.

    Uses the notebook-level ``env`` and ``model`` objects.

    Parameters
    ----------
    n_samples : int, default 1000
        Number of state-action pairs to collect.
    output_path : str, default "state_action_pairs.csv"
        File the pairs are written to (overwritten if it already exists).
    deterministic : bool, default False
        Forwarded to ``model.predict``. With ``False`` the model may return
        exploratory actions; pass ``True`` to record only its greedy choices.

    Returns
    -------
    None
        One row per sample is written to ``output_path``: the state features
        followed by the action in the last column.
    """
    states = []   # model inputs
    actions = []  # action chosen for the state at the same index
    state, _ = env.reset()

    for _ in range(n_samples):
        action, _ = model.predict(state, deterministic=deterministic)

        # Record the state the action was chosen for *before* stepping:
        # env.step returns the next state, which belongs to the next sample.
        states.append(state)
        actions.append(action)

        state, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            # Start a new episode only once the current one is over.
            state, _ = env.reset()

    X = np.array(states)
    y = np.array(actions)

    # Append the action as the last column so a single 2-D table can be saved.
    # This assumes each state is a flat vector and each action a scalar.
    combined = np.column_stack((X, y))

    np.savetxt(output_path, combined, delimiter=",")


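We only need to sample state-action pairs once, so that our trees remain consistent with the explanation in the notebook. The pairs are saved to `state_action_pairs.csv` and loaded back from that file afterwards. If you wish to generate new state-action pairs, load a trained model first (see "Load the model") and run `sample()` in the code cell below.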

In [23]:
# Sampling needs a loaded `model`, and the pairs only have to be generated once,
# so reuse the saved CSV when it already exists. Delete the file (or call
# sample() directly) to force a fresh set of pairs.
if not Path("state_action_pairs.csv").exists():
    sample()

In [24]:
# Load the saved pairs back from disk. ndmin=2 keeps the array two-dimensional
# even if the file holds a single row, so the column slicing below still works.
state_action_pairs = np.loadtxt("state_action_pairs.csv", delimiter=',', ndmin=2)

X = state_action_pairs[:, :-1]  # All rows, all columns except the last one (states)
# All rows, just the last column (actions). The CSV stores actions as floats;
# cast them back to integers so the tree is trained on integer class labels.
y = state_action_pairs[:, -1].astype(int)


In [25]:
# Train the decision tree using the state-action pairs collected from the model.
# A fixed random_state makes repeated fits on the same data give the same tree,
# so the plots stay consistent between runs of the notebook.
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# Optional display names for the tree visualisations.
# None lets scikit-learn fall back to its default labels. An empty list is not a
# valid "unset" value: the exporters index into it for every node and fail.
# To use custom labels, replace None with one name per class / per feature.
class_names = None
feature_names = None

## Visualize the Decision Tree

In [1]:
# Render the fitted tree on a large canvas and store it as a PDF.
save_path = "decision_tree.pdf"
fig = plt.figure(figsize=(25, 25))
plot_tree(
    clf,
    feature_names=feature_names,
    class_names=class_names,
    filled=True,     # colour the nodes
    rounded=True,    # draw boxes with rounded corners
    node_ids=False,  # do not print node ids
)
fig.savefig(save_path)

NameError: name 'plt' is not defined

## Traversing the Decision Tree

In [27]:
def decision_tree_parser(clf, sample, feature_names, class_names, save_path=None):
    """Animate the path a sample takes through a fitted decision tree.

    One frame is rendered per node on the sample's decision path, with that
    node filled red, and the frames are assembled into a GIF.

    Parameters
    ----------
    clf : fitted tree estimator
        Passed to ``export_graphviz``; its ``decision_path`` selects the nodes
        to highlight.
    sample : array-like
        Passed to ``clf.decision_path``. Intended to hold a single sample
        (e.g. ``X[[i]]``); with several rows the visited nodes of all rows are
        highlighted one after another.
    feature_names, class_names : sequence or None
        Labels passed to ``export_graphviz``. ``None`` or an empty sequence
        selects the exporter's default labels.
    save_path : str, optional
        Output path without extension; ``.gif`` is appended. When omitted the
        GIF is built in memory only.

    Returns
    -------
    IPython.display.Image
        The GIF, loaded from the saved file or from the in-memory bytes.
    """
    # Ids of the nodes visited by the sample.
    visited_nodes = clf.decision_path(sample).indices

    # Export the tree to DOT and parse it into an editable graph.
    dot_data = export_graphviz(
        clf,
        out_file=None,
        feature_names=_names_or_none(feature_names),
        class_names=_names_or_none(class_names),
        filled=True,
        rounded=True,
        special_characters=True,
    )
    graph = pydotplus.graph_from_dot_data(dot_data)

    frames = []
    for node_id in visited_nodes:
        node = graph.get_node(str(node_id))[0]

        # Remember the node's own colour so it can be put back afterwards;
        # otherwise visited nodes lose their class colouring in later frames.
        original_fill = node.get_fillcolor()
        node.set_fillcolor("red")

        # Render the whole graph with this node highlighted.
        png_bytes = graph.create_png(prog='dot')
        frames.append(imageio.imread(io.BytesIO(png_bytes)))

        # Restore the colour for the next frame (white if the node had none).
        node.set_fillcolor(original_fill if original_fill is not None else "white")

    if save_path:
        # Write the GIF to disk and display it from the file.
        gif_name = f'{save_path}.gif'
        imageio.mimsave(gif_name, frames, fps=0.5)
        return Image(filename=gif_name)

    # No path given: build the GIF in memory and display the raw bytes.
    buffer = io.BytesIO()
    imageio.mimsave(buffer, frames, fps=0.5, format='gif')
    buffer.seek(0)
    return Image(data=buffer.getvalue())


def _names_or_none(names):
    """Return ``names`` unchanged, or None when it is an empty sequence."""
    if names is None or isinstance(names, bool):
        return names
    return names if len(names) > 0 else None



In [2]:
# Animate the decision path of one example; row 710 is an arbitrary pick
# used only to illustrate the path.
decision_tree_parser(
    clf,
    X[[710]],
    feature_names,
    class_names,
    save_path='decision_tree_path',
)

NameError: name 'decision_tree_parser' is not defined