In [None]:
import logging

logging.basicConfig(
    filename="./data/mnist-oab-exp-seven-exp.log",
    filemode="a",
    format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
import pickle
import oab
import numpy as np
from rich import print
import matplotlib.pyplot as plt
from collections import Counter

(X_train, Y_train), (X_test, Y_test), (X_tree, Y_tree) = oab.get_data()
my_dom = oab.Domain("mnist", "RF")
N_CLASSES = 10  # labels 0..9 are indexed below
# Positions in the test set of every sample of each label. A dedicated
# comprehension variable is used so no stray global `my_class` is left
# behind; later cells must choose `my_class` explicitly.
index = {label: np.where(Y_test == label)[0] for label in range(N_CLASSES)}
# Explainer objects computed per class, filled in by the explanation cell.
points = {}
# NOTE(review): presumably loads the persisted explanation base read later as
# `my_dom.explanation_base` — confirm against oab.Domain.load.
my_dom.load()

In [None]:
# Class (label) whose test points are explained and displayed by later cells.
my_class = 9
# Optional cap on the explanation base size for quicker experiments.
# None keeps the full base (the default, matching previous behaviour).
# Assumes explanation_base supports slicing, as the former commented-out
# `[:500]` toggle did.
EXPLANATION_BASE_LIMIT = None
if EXPLANATION_BASE_LIMIT is not None:
    my_dom.explanation_base = my_dom.explanation_base[:EXPLANATION_BASE_LIMIT]
print(len(my_dom.explanation_base))

In [None]:
# Tally the black-box predicted label of every explanation-base point.
counting = Counter(
    point.blackboxpd.predicted_class for point in my_dom.explanation_base
)
print(counting)

# Compare how often `my_class` occurs in the test set vs. the explanation base.
print(
    f"There are {len(index[my_class])} points with class {my_class} in the [bold red]test[/] set"
)
# NOTE(review): predicted_class keys appear to be strings (hence str(my_class));
# if they are ever ints, Counter silently reports 0 here — confirm.
print(
    f"There are {counting[str(my_class)]} images with the same class in the [bold blue]explanation base[/]"
)

In [None]:
# How many test points of `my_class` to explain; each Explainer is costly.
MAX_EXPLAINED_PER_CLASS = 1
# Start this class from scratch so re-running the cell does not append
# duplicate explainers, and so the display cell finds an (possibly empty)
# list even when the class has no test points.
points[my_class] = []
for my_index in index[my_class][:MAX_EXPLAINED_PER_CLASS]:
    testpoint = oab.TestPoint(X_test[my_index], my_dom)
    logging.info("start explanation of index %s", my_index)
    # howmany=10 is passed straight to oab.Explainer; its meaning is defined there.
    exp = oab.Explainer(testpoint, howmany=10)
    points[my_class].append(exp)
    print("nice")  # progress marker: one explainer finished

In [None]:
def show_digit(image, title):
    """Display one image array as a grayscale figure with the given title.

    The array is cast to uint8 before plotting, exactly as the original
    inline plotting code did.
    """
    _, ax = plt.subplots()
    ax.imshow(image.astype("uint8"), cmap="gray")
    ax.set_title(title)
    plt.show()


def show_points(points_to_show, kind):
    """Plot every point in `points_to_show`, titled with `kind` and its black-box class."""
    for shown in points_to_show:
        show_digit(
            shown.a,
            f"{kind} - black box predicted class: {shown.blackboxpd.predicted_class}",
        )


print(f"showing class {my_class}")
for explainer in points[my_class]:
    print("for test", explainer.testpoint.latent.a)
    show_digit(
        explainer.testpoint.a,
        f"TestPoint - black box predicted class: {explainer.testpoint.blackboxpd.predicted_class}",
    )

    print("for treepoint", explainer.target.latent.a)
    show_digit(
        explainer.target.a,
        f"TreePoint - black box predicted class: {explainer.target.blackboxpd.predicted_class}",
    )

    print("counterrules")
    print(explainer.target.latentdt.counterrules)

    print("for this target treepoint:")
    print(explainer.target)
    print("# factuals")
    show_points(explainer.factuals, "factual")
    print("## factuals BUT closest, instead of furthest")
    show_points(explainer._factuals_default(closest=True), "factual")
    print("# counterfactuals")
    show_points(explainer.counterfactuals, "counterfactual")
    print("## New method for **more** counterfactuals!")
    # more_counterfactuals() generates the extra counterfactuals at call time.
    show_points(explainer.more_counterfactuals(), "counterfactual")