# Timing of oab
## Run these once, at the start of the notebook

In [None]:
import logging
from time import perf_counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from rich import print

import oab

In [2]:
# Shared accumulator of per-image records; created once so that the rows produced by
# every personalization run (dataset / algorithm / class) end up in a single table.
results_table = list()

## Personalize the following cell

In [3]:
# personalize this for running the notebook in different ways
algo_name = "RF"
dataset_name = "mnist"

how_many_images = 10  # how many images each combination of the above?

## Run these after each "personalization"
Re-run the cells below after every change to the personalization cell to collect results for multiple datasets, algorithms, and classes.
### Non timed cells

In [None]:
# Load the selected dataset and build the explanation domain for (dataset, algorithm).
(X_train, Y_train), (X_test, Y_test), (X_tree, Y_tree) = oab.get_data(dataset_name)
my_dom = oab.Domain(dataset_name, algo_name)

# Test-set positions of every image, grouped by class label. The labels come from the
# test labels plus the domain's classes (which the timed loop looks up) instead of a
# hard-coded range(10), so datasets with more than ten classes do not hit a KeyError.
class_labels = set(np.unique(Y_test).tolist()) | {int(c) for c in my_dom.classes}
index = {label: np.where(Y_test == label)[0] for label in class_labels}

# NOTE(review): ``points`` appears unused by the cells of this notebook.
points = {}
# Deletion-experiment results keyed by test-set index, filled by the explanation
# loops; reset every time this cell runs.
myres = {}

### Timed cells

In [None]:
# Explain up to ``how_many_images`` test images of every class known to the domain,
# timing how long building the TestPoint and the Explainer takes for each image.
for my_class in [int(x) for x in my_dom.classes]:
    # Slicing (clamped at 0) instead of breaking after the N-th image means a
    # non-positive how_many_images processes no images rather than one.
    for my_index in index[my_class][: max(how_many_images, 0)]:
        start = perf_counter()
        testpoint = oab.TestPoint(X_test[my_index], my_dom)
        logging.info("start explanation of index %s", my_index)
        exp = oab.Explainer(testpoint, howmany=5)
        execution_time = perf_counter() - start

        if exp.target:
            # A target was found: count its counter-rules and run the deletion
            # experiment on the explainer's map (not part of the timed section).
            crules = len(exp.target.latentdt.counterrules)
            mymap = exp.get_map()
            experiment = oab.DeletionExperiment(exp, mymap)
            myres[my_index] = experiment.results
        else:
            # No target was found: flag the row in the results table.
            crules = "error"

        results_table.append(
            {
                "Dataset": dataset_name,
                "Algo": algo_name,
                "Class": my_class,
                "id": my_index,
                "Time": round(execution_time, 2),
                "factuals / 5": len(exp.factuals),
                "cfact": len(exp.counterfactuals),
                "crules": crules,
            }
        )

## Run this to show the results table

In [None]:
# Turn every record gathered so far (across all personalization runs) into one table.
results_dataframe = pd.DataFrame(results_table)
results_dataframe  # last expression of the cell: rendered by the notebook

## Graph

In [None]:
# Average the deletion-experiment "proba" curves over every instance stored in myres.
# NOTE(review): assumes each ``myres[idx]["proba"]`` is a pandas Series/DataFrame
# indexed by deletion step, so column-wise concatenation lines the curves up — confirm.
if myres:
    # pd.concat takes a list of objects (its second positional parameter is ``axis``);
    # one concat over all instances also avoids re-copying the frame on every step.
    averages = pd.concat([res["proba"] for res in myres.values()], axis=1)

    fig, ax = plt.subplots()
    # mean across columns = average over instances at each deletion step
    ax.plot(averages.index, averages.mean(axis=1))

    ax.set(
        xlabel="deletion",
        ylabel="avg probability",
        title="Avg proba of predicting all class labels",
    )
    ax.grid()

    plt.show()
else:
    # pd.concat([]) raises ValueError, so report the empty case instead.
    print("no deletion-experiment results to plot")

### Individual instances

In [None]:
# Explain up to ``how_many_images`` test images of a single class, for inspection.
# NOTE: rows are appended to the shared results_table as well, so running this after
# the timed cell may add duplicate rows for the same images.
my_class = 7

# Slicing (clamped at 0) instead of breaking after the N-th image means a
# non-positive how_many_images processes no images rather than one.
for my_index in index[my_class][: max(how_many_images, 0)]:
    start = perf_counter()
    testpoint = oab.TestPoint(X_test[my_index], my_dom)
    logging.info("start explanation of index %s", my_index)
    exp = oab.Explainer(testpoint, howmany=5)
    execution_time = perf_counter() - start

    if exp.target:
        # A target was found: count its counter-rules and run the deletion
        # experiment on the explainer's map (not part of the timed section).
        crules = len(exp.target.latentdt.counterrules)
        mymap = exp.get_map()
        experiment = oab.DeletionExperiment(exp, mymap)
        myres[my_index] = experiment.results
    else:
        # No target was found: flag the row in the results table.
        crules = "error"

    results_table.append(
        {
            "Dataset": dataset_name,
            "Algo": algo_name,
            "Class": my_class,
            "id": my_index,
            "Time": round(execution_time, 2),
            "factuals / 5": len(exp.factuals),
            "cfact": len(exp.counterfactuals),
            "crules": crules,
        }
    )

In [None]:
# plot

# Show the explanation ``exp`` left over from the last iteration of the cell above.


def _show_point(point, label):
    """Display ``point.a`` as a grayscale image titled with the black box's prediction."""
    plt.imshow(point.a.astype("uint8"), cmap="gray")
    plt.title(
        f"{label} - black box predicted class: {point.blackboxpd.predicted_class}"
    )
    plt.show()


print("for test", exp.testpoint.latent.a)
_show_point(exp.testpoint, "TestPoint")

if not exp.target:
    # exp.target is falsy (presumably None) when no target was found; everything
    # below reads attributes of the target, so report it instead of raising.
    print("no target tree point was found for this test point")
else:
    print("for treepoint", exp.target.latent.a)
    _show_point(exp.target, "TreePoint")

    print("counterrules")
    print(exp.target.latentdt.counterrules)

    print("for this target treepoint:")
    print(exp.target)
    print("# factuals")
    for point in exp.factuals:
        _show_point(point, "factual")
    print("## factuals BUT closest, instead of furthest")
    for point in exp._factuals_default(closest=True):
        _show_point(point, "factual")
    print("# counterfactuals")
    for point in exp.counterfactuals:
        _show_point(point, "counterfactual")

    print("## New method for **more** counterfactuals!")
    # more_counterfactuals() generates the additional counterfactuals right here
    for point in exp.more_counterfactuals():
        _show_point(point, "counterfactual")