In [None]:
# --- Imports, logging and data --------------------------------------------------
# Logging is configured before the remaining imports (as originally written), so the
# file handler is installed before any imported module could attach handlers to the
# root logger: logging.basicConfig does nothing once the root logger has handlers.
import logging
import pickle  # not used in the cells below; presumably kept for saving explainers to disk
from pathlib import Path

LOG_FILE = Path("./data/mnist-oab-exp-third-exp.log")
# basicConfig opens the log file immediately and raises FileNotFoundError when the
# parent directory is missing, so make sure it exists first.
LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
logging.basicConfig(
    filename=LOG_FILE,
    filemode="a",  # append: successive runs share one log file
    format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)

import matplotlib.pyplot as plt
import numpy as np
from rich import print  # shadows the builtin print with rich's pretty printer

import oab

(X_train, Y_train), (X_test, Y_test), (X_tree, Y_tree) = oab.get_data()
# "RF" presumably selects a random-forest black box — TODO confirm in oab.Domain.
my_dom = oab.Domain("mnist", "RF")

# For each digit class 0-9: indices (first axis of X_test / Y_test) of the test
# points carrying that label.
index = {digit: np.where(Y_test == digit)[0] for digit in range(10)}
# Explainers collected per class by the selection cell further down.
points = {}

In [None]:
# Load the domain's saved components. Presumably this includes the black box and the
# autoencoder, since my_dom.ae is used further down — TODO confirm in oab.Domain.load.
my_dom.load()

In [None]:
# Classes processed by the cells below: range(class_to_work, class_to_work + how_many_classes_todo).
class_to_work = 3  # first class to explain
how_many_classes_todo = 1  # how many consecutive classes, starting at class_to_work

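For each selected class, collect up to three explainers whose test point satisfies the factual rule of its target tree point. They are kept in memory in `points`; saving them to disk is not implemented yet.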

In [None]:
for my_class in range(class_to_work, class_to_work + how_many_classes_todo):
    print(f"There are {len(index[my_class])} points with class {my_class} in the test set")
    for my_index in index[my_class]:
        testpoint = oab.TestPoint(X_test[my_index], my_dom)
        logging.info(f"start explanation of index {my_index}")
        exp = oab.Explainer(testpoint, howmany=10)
        if testpoint in exp.target.latentdt.rule:
            try:
                points[my_class].append(exp)
            except KeyError:
                points[my_class] = [exp]
            print("nice")
            if len(points[my_class]) >= 3:
                # how many points to do per class
                break

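Now we plot, for each collected explainer, the test point and its target tree point. We then plot the factual exemplars (furthest and closest) and the counterfactual exemplars (standard and "more").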

In [None]:
for my_class in range(class_to_work, class_to_work + how_many_classes_todo):
    
    print(f"showing class {my_class}")
    for explainer in points[my_class]:
        print("for test", explainer.testpoint.latent.a)
        plt.imshow(explainer.testpoint.a.astype("uint8"), cmap="gray")
        plt.title(f"TestPoint - black box predicted class: {explainer.testpoint.blackboxpd.predicted_class}")
        plt.show()

        print("for treepoint", explainer.target.latent.a)
        plt.imshow(explainer.target.a.astype("uint8"), cmap="gray")
        plt.title(f"TreePoint - black box predicted class: {explainer.target.blackboxpd.predicted_class}")
        plt.show()
        
        print("counterrules")
        print(explainer.target.latentdt.counterrules)

        print("for this target treepoint:")
        print(explainer.target)
        print("# factuals")
        for point in explainer.factuals:
            plt.imshow(point.a.astype("uint8"), cmap="gray")
            plt.title(f"factual - black box predicted class: {point.blackboxpd.predicted_class}")
            plt.show()
        print("## factuals BUT closest, instead of furthest")
        for point in explainer._factuals_default(closest=True):
            plt.imshow(point.a.astype("uint8"), cmap="gray")
            plt.title(f"factual - black box predicted class: {point.blackboxpd.predicted_class}")
            plt.show()
        print("# counterfactuals")
        for point in explainer.counterfactuals:
            plt.imshow(point.a.astype("uint8"), cmap="gray")
            plt.title(f"counterfactual - black box predicted class: {point.blackboxpd.predicted_class}")
            plt.show()
        print("## New method for **more** counterfactuals!")
        for point in explainer.more_counterfactuals():
            # this is generating right here the "more" counterfactuals
            plt.imshow(point.a.astype("uint8"), cmap="gray")
            plt.title(f"counterfactual - black box predicted class: {point.blackboxpd.predicted_class}")
            plt.show()

In [None]:
import copy
print("Counterrules applied on treepoint itself (like Abele)")
print(explainer.target.latentdt.counterrules[1])
print(explainer.target.latent.a)

explainer = points[my_class][-1]

eps = 0.04
i = 1
geq = {">=", ">"}

value_to_overwrite = (
    explainer.target.latentdt.counterrules[1].value + eps * i if explainer.target.latentdt.counterrules[1].operator in geq else explainer.target.latentdt.counterrules[1].value - eps * i
)

# THIS IS THE IMAGEEXPLANATION GENERATION
new_point = oab.ImageExplanation(
    latent=oab.Latent(a=copy.deepcopy(explainer.target.latent.a)),
)
new_point.latent.a[explainer.target.latentdt.counterrules[1].feature] = value_to_overwrite

# static set discriminator probability at 0.35
# passes discriminator? Return it immediately.
# No? start again with entire point generation
if my_dom.ae.discriminate(new_point) >= 0.35:
    print(new_point)
else:
    print(f"ahi {my_dom.ae.discriminate(new_point)}")


In [None]:
# Show the point produced by applying the counterrule above. It is derived from the
# target tree point but is not itself a tree point, so label it accordingly.
plt.imshow(new_point.a.astype("uint8"), cmap="gray")
plt.title(f"New point (counterrule applied) - black box predicted class: {new_point.blackboxpd.predicted_class}")
plt.show()

In [None]:
# Show the generated point's representation (print here is rich's, imported in the first cell).
print(new_point)

In [None]:
# Sweep one latent feature of new_point over evenly spaced values and show the
# resulting image and black box prediction for each value.
sweep_feature = 0  # index of the latent feature to vary
sweep_start = -2  # first value assigned to the feature
sweep_step = 0.2  # increment between consecutive values
sweep_count = 20  # number of values (one figure each)

base_latent = new_point.latent.a
for k in range(sweep_count):
    # Computed from k (not accumulated) to avoid floating-point drift across steps.
    feature_value = sweep_start + k * sweep_step
    swept_latent = copy.deepcopy(base_latent)  # never mutate new_point's own latent
    swept_latent[sweep_feature] = feature_value
    swept_point = oab.ImageExplanation(oab.Latent(swept_latent))
    plt.imshow(swept_point.a.astype("uint8"), cmap="gray")
    plt.title(
        f"latent[{sweep_feature}] = {feature_value:.2f} - "
        f"black box predicted class: {swept_point.blackboxpd.predicted_class}"
    )
    plt.show()