In [2]:
import os

# All inputs/outputs live under the directory named by the DATA environment
# variable. os.path.join keeps the paths correct whether or not $DATA ends with
# a separator (plain "+" concatenation silently glued the segments together when
# it did not). The trailing "" preserves the trailing "/" of the original
# literals, in case downstream code concatenates file names onto these dirs.
db_dir = os.path.join(os.environ["DATA"], "PatImgXAI_data", "db1.0.0", "")
test_dataset_filename = "rowcircles_test.csv"

model_dir = os.path.join(os.environ["DATA"], "models", "db1.0.0", "rowcircles", "")



In [3]:
from xaipatimg.datagen.dbimg import load_db

# Load the pattern-image database stored in db_dir.
# NOTE(review): `db` is not referenced by any later cell shown here; the
# generate_* helpers below receive db_dir and print their own
# "Loading dataset content ..." message. Confirm whether this cell is still needed.
db = load_db(db_dir)

In [4]:
import numpy as np
from xaipatimg.ml.xai import generate_shap_resnet18

# Compute SHAP values for the model stored in model_dir on the rowcircles test
# split, then render them as images ("Computing shap values" /
# "Generating shap images" in the output below).
# NOTE(review): the meaning of n_jobs, dataset_size and masker="ndarray" is
# defined by xaipatimg.ml.xai.generate_shap_resnet18 — confirm there.
# "cuda:0" presumably selects the first CUDA GPU as the torch device.
# NOTE(review): the recorded output shows 20 items loaded/explained even though
# dataset_size=1000 — verify how dataset_size is applied.
generate_shap_resnet18(db_dir, test_dataset_filename, model_dir, "cuda:0", n_jobs=20, dataset_size=1000, masker="ndarray")

  from .autonotebook import tqdm as notebook_tqdm


Loading dataset content for rowcircles_test.csv


100%|██████████| 20/20 [00:00<00:00, 66.23it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values


PartitionExplainer explainer: 21it [01:00,  3.35s/it]                        


Generating shap images


100%|██████████| 20/20 [00:06<00:00,  2.92it/s]


Counterfactuals

In [5]:
import numpy as np

def extract_rows_with_only_circles(img_content, n_rows=6):
    """Flag the grid rows that contain at least one circle and no other shape.

    Parameters
    ----------
    img_content : iterable of dict
        Symbols of one image. Each symbol is read through ``c["shape"]`` and
        ``c["pos"][1]``; the latter is used as the row index.
    n_rows : int, default 6
        Number of rows in the image grid. Row indices must lie in
        ``[0, n_rows)``. The default keeps the previous hard-coded 6-row
        behavior; pass ``img_entry["division"][1]`` for other grid sizes.

    Returns
    -------
    numpy.ndarray of bool, shape (n_rows,)
        ``True`` at index ``r`` iff row ``r`` holds at least one circle and no
        non-circle symbol. Empty rows are ``False``.
    """
    circles_counter = np.zeros(n_rows, dtype=int)
    non_circles_counter = np.zeros(n_rows, dtype=int)

    for c in img_content:
        row = c["pos"][1]
        if c["shape"] == "circle":
            circles_counter[row] += 1
        else:
            non_circles_counter[row] += 1

    return np.logical_and(circles_counter >= 1, non_circles_counter == 0)


def exist_row_with_only_circles(img_content, n_rows=6):
    """Return whether at least one row of the image contains only circles.

    Parameters
    ----------
    img_content : iterable of dict
        Symbols of one image (see ``extract_rows_with_only_circles``).
    n_rows : int, default 6
        Number of rows in the image grid.

    Returns
    -------
    numpy.bool_
        ``True`` if some row has >= 1 circle and no other shape.
    """
    return np.any(extract_rows_with_only_circles(img_content, n_rows=n_rows))

In [6]:
from xaipatimg.datagen.utils import gen_rand_sym, PatImgObj
import numpy as np
# Palette used when drawing replacement (non-circle) symbols.
COLORS = ["#F86C62", "#7AB0CD", "#F4D67B", "#87C09C"]


def _break_circle_rows(img_entry, circle_rows):
    """Build one negative counterfactual of ``img_entry``.

    For every row index in ``circle_rows``, a random square/triangle is set at a
    uniformly drawn column index in ``[0, img_entry["division"][0])``.
    """
    cf_img = PatImgObj(img_entry)
    for row in circle_rows:
        # Draw the column before the symbol: this keeps the global RNG
        # consumption order fixed (column first, then symbol).
        column = np.random.choice(np.arange(img_entry["division"][0]))
        symbol = gen_rand_sym(shapes=["square", "triangle"], colors=COLORS)
        cf_img.set_symbol(posx=column, posy=row, value=symbol)
    return cf_img.get_img_dict()


def _circle_out_line(img_entry, line):
    """Build one positive counterfactual: every symbol on ``line`` becomes a circle."""
    cf_img = PatImgObj(img_entry)
    cf_img.change_shapes_of_line(line, "circle")
    return cf_img.get_img_dict()


def rowcircles_counterfactuals(img_entry, is_pos, nb_cf):
    """Generate counterfactual variants of one rowcircles image.

    Parameters
    ----------
    img_entry : dict
        Image description; ``img_entry["content"]`` is the symbol list and
        ``img_entry["division"]`` gives (columns, lines) of the grid.
    is_pos : bool
        Whether the source image is positive. A positive image yields negative
        counterfactuals and vice versa.
    nb_cf : int
        Requested number of counterfactuals.

    Returns
    -------
    list of dict
        Counterfactual image dicts (``PatImgObj.get_img_dict()`` results).
        For negative inputs, at most one counterfactual per non-empty line is
        produced, so fewer than ``nb_cf`` may be returned.
    """
    if not is_pos:
        # Positive counterfactuals: pick distinct non-empty lines in random
        # order and turn all their symbols into circles.
        base_img = PatImgObj(img_entry)
        candidate_lines = np.setdiff1d(np.arange(img_entry["division"][1]), base_img.get_empty_lines())
        np.random.shuffle(candidate_lines)
        n_out = min(nb_cf, len(candidate_lines))
        return [_circle_out_line(img_entry, candidate_lines[k]) for k in range(n_out)]

    # Negative counterfactuals: in each copy, drop a non-circle symbol into
    # every row that currently contains only circles.
    circle_rows = np.nonzero(extract_rows_with_only_circles(img_entry["content"]))[0]
    return [_break_circle_rows(img_entry, circle_rows) for _ in range(nb_cf)]


In [None]:
from xaipatimg.ml.xai import generate_counterfactuals_resnet18

# Generate counterfactual examples for the rowcircles test split, using
# rowcircles_counterfactuals (defined above) as the per-image generator.
# NOTE(review): the positional arguments 5, None, 1, None are not
# self-describing; their meaning is fixed by the signature of
# xaipatimg.ml.xai.generate_counterfactuals_resnet18 (5 is presumably the
# number of counterfactuals per image, i.e. nb_cf). Consider passing them as
# keywords once confirmed.
generate_counterfactuals_resnet18(db_dir, test_dataset_filename, model_dir, rowcircles_counterfactuals, 5, None, 1, None, n_jobs=20)

Loading dataset content for rowcircles_test.csv


100%|██████████| 1000/1000 [00:16<00:00, 62.43it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
