# Study of the explanation base size
## Run these once, at the start of the notebook

In [None]:
import logging
from time import perf_counter

import numpy as np
import pandas as pd
from rich import print

import oab

In [None]:
# Accumulator for the per-image result rows. The timed cell appends one
# dict per explained test point, and the last cell turns it into a table.
results_table = list()

## Personalize the following cell

In [None]:
# Experiment settings: edit these values, then re-run the cells below.

# Upper bound on the number of test images explained per class.
how_many_images = 50

# Names handed to oab.get_data / oab.Domain to pick the dataset and algorithm.
algo_name = "RF"
dataset_name = "mnist"

## Run these after each "personalization"
Re-run them for each configuration to collect results across multiple datasets, algorithms, and classes.

### Cells *not* timed
Data loading and construction of the per-class test-set indices.

In [None]:
# Load the dataset splits and build the domain for the selected
# dataset/algorithm pair.
(X_train, Y_train), (X_test, Y_test), (X_tree, Y_tree) = oab.get_data(dataset_name)
my_dom = oab.Domain(dataset_name, algo_name)

# NOTE(review): `points` is not used by any cell visible here; kept so that
# other cells relying on it keep working -- confirm before removing.
points = {}

# Map every class label to the positions in the test set that carry it.
# The labels are taken from the domain (the same source the timed loop
# iterates over) rather than a hard-coded range(10), so the lookup
# `index[my_class]` in the timed loop cannot miss a key when the dataset
# does not have exactly the labels 0..9.
index = {
    my_class: np.where(Y_test == my_class)[0]
    for my_class in (int(x) for x in my_dom.classes)
}

### Timed cells

In [None]:
# Explain at most `how_many_images` test points per class, timing each
# explanation and appending one result row per point to `results_table`.
for my_class in [int(x) for x in my_dom.classes]:
    # Limit the points up front by slicing instead of breaking after the
    # n-th one: a limit of 0 then processes nothing (the post-check used to
    # let one point through), and max() stops a negative limit from being
    # read as "all but the last n".
    for my_index in index[my_class][: max(how_many_images, 0)]:
        # The timed region covers test-point construction, the log call and
        # the explainer construction.
        start = perf_counter()
        testpoint = oab.TestPoint(X_test[my_index], my_dom)
        # Lazy %-style arguments: the message is only built if INFO is enabled.
        logging.info("start explanation of index %s", my_index)
        exp = oab.Explainer(testpoint, howmany=5)
        end = perf_counter()
        execution_time = end - start

        # When a target was found, count its counter-rules; otherwise (no
        # target found) flag the row with the string "error".
        if exp.target:
            crules = len(exp.target.latentdt.counterrules)
        else:
            crules = "error"

        # One record per explained test point.
        results_table.append(
            {
                "Dataset": dataset_name,
                "Algo": algo_name,
                "Class": my_class,
                "id": my_index,
                "Time": execution_time,
                "factuals / 5": len(exp.factuals),
                "cfact": len(exp.counterfactuals),
                "crules": crules,
            }
        )

## Run this to show the results table

In [None]:
# Turn the accumulated result rows into a table for display.
results_dataframe = pd.DataFrame.from_records(results_table)
# With no recorded rows the frame has no columns at all, so only round the
# timings when the "Time" column actually exists (avoids a KeyError when
# this cell is run before the timed cell).
if "Time" in results_dataframe.columns:
    results_dataframe["Time"] = results_dataframe["Time"].round(2)
# Last expression of the cell: the notebook renders the table.
results_dataframe