In [2]:
import os

# Root data directory, read from the DATA environment variable (raises KeyError if unset).
_data_root = os.environ["DATA"]

# Paths are assembled with os.path.join so that DATA works with or without a trailing
# separator; the final "" component keeps the trailing separator on directory paths.
db_dir = os.path.join(_data_root, "PatImgXAI_data", "db1.0.0", "")
test_dataset_filename = "rowcircles_test.csv"

model_dir = os.path.join(_data_root, "models", "db1.0.0", "rowcircles", "")

# Image assets stored in the database directory and passed to the XAI generation calls.
shap_scale_img_path = os.path.join(db_dir, "shap_scale.png")
yes_pred_img_path = os.path.join(db_dir, "button_yes.png")
no_pred_img_path = os.path.join(db_dir, "button_no.png")
pos_pred_legend_path = os.path.join(db_dir, "cf_info_pos.png")
neg_pred_legend_path = os.path.join(db_dir, "cf_info_neg.png")


In [3]:
# Probability of generating a geometrical shape at each position of the grid
SHAPE_PROB = 0.5

# Shapes that can be generated
SHAPES = [
    'circle',
    'square',
    'triangle',
]

# Available colours, as hexadecimal codes
COLORS = [
    "#F86C62",
    "#7AB0CD",
    "#F4D67B",
    "#87C09C",
]

In [4]:
from xaipatimg.datagen.dbimg import load_db

# Load the database stored under db_dir (db_dir must already be defined by the path configuration cell).
# NOTE(review): `db` is not referenced by any other cell visible here — confirm it is still needed.
db = load_db(db_dir)

In [7]:
import numpy as np
from xaipatimg.ml.xai import generate_shap_resnet18

# Generate SHAP explanations for the model stored in model_dir on the test dataset.
# dataset_size=2 keeps the run short (the captured output shows 2 loaded entries).
generate_shap_resnet18(
    db_dir,
    test_dataset_filename,
    model_dir,
    yes_pred_img_path=yes_pred_img_path,
    no_pred_img_path=no_pred_img_path,
    device="cuda:0",
    n_jobs=1,
    dataset_size=2,
    masker="ndarray",
    shap_scale_img_path=shap_scale_img_path,
)

Loading dataset content for rowcircles_test.csv


100%|██████████| 2/2 [00:00<00:00, 182.04it/s]
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values


KeyboardInterrupt: 

Counterfactuals

In [5]:
import numpy as np

def extract_rows_with_only_circles(img_content, nb_rows=6):
    """Flag the rows of an image that contain at least one symbol and only circles.

    :param img_content: iterable of symbol dictionaries. Each one must expose a
        "shape" key and a "pos" key whose element at index 1 is the row index.
    :param nb_rows: number of rows in the grid. Defaults to 6, the value that
        was previously hard-coded, so existing callers are unaffected.
    :return: boolean numpy array of length nb_rows; entry i is True when row i
        holds one or more circles and no other shape. Empty rows are False.
    :raises IndexError: if a symbol's row index is outside [-nb_rows, nb_rows).
    """
    # Integer counters: these are occurrence counts, not measurements.
    circles_counter = np.zeros(nb_rows, dtype=int)
    non_circles_counter = np.zeros(nb_rows, dtype=int)

    for symbol in img_content:
        row = symbol["pos"][1]
        if symbol["shape"] == "circle":
            circles_counter[row] += 1
        else:
            non_circles_counter[row] += 1

    # A row qualifies only if it is non-empty and free of non-circle shapes.
    return np.logical_and(circles_counter >= 1, non_circles_counter == 0)

def exist_row_with_only_circles(img_content):
    """Tell whether at least one row of the image contains circles and nothing else.

    :param img_content: symbol dictionaries, forwarded unchanged to
        extract_rows_with_only_circles.
    :return: True when one or more rows are flagged, False otherwise.
    """
    only_circle_rows = extract_rows_with_only_circles(img_content)
    nb_matching_rows = np.sum(only_circle_rows)
    return nb_matching_rows >= 1

In [6]:
from xaipatimg.datagen.utils import gen_rand_sym, PatImgObj
import numpy as np
# Colours passed to gen_rand_sym below; same values as the COLORS list of the configuration cell.
COLORS = [
    "#F86C62",
    "#7AB0CD",
    "#F4D67B",
    "#87C09C",
]

def rowcircles_counterfactuals(img_entry, is_pos, nb_cf):
    """Build counterfactual images for the "row with only circles" rule.

    :param img_entry: image entry; its "content" and "division" keys are read
        here and the whole entry is handed to PatImgObj.
        NOTE(review): division[0] is used as the range of x positions and
        division[1] as the number of lines — presumably (columns, rows); confirm.
    :param is_pos: truthy when the image is positive, in which case negative
        counterfactuals are produced; falsy for the opposite direction.
    :param nb_cf: number of counterfactuals requested.
    :return: list of image dictionaries obtained from PatImgObj.get_img_dict().
        For a non-positive image the list holds at most one entry per non-empty
        line, so it can be shorter than nb_cf.
    """
    counterfactuals = []

    if not is_pos:
        # Positive counterfactuals: turn every shape of one non-empty line into a circle.
        probe = PatImgObj(img_entry)
        all_lines = np.arange(img_entry["division"][1])
        candidate_lines = np.setdiff1d(all_lines, probe.get_empty_lines())
        np.random.shuffle(candidate_lines)

        # zip stops at the shorter input: at most nb_cf lines, in shuffled order.
        for _, line in zip(range(nb_cf), candidate_lines):
            variant = PatImgObj(img_entry)
            variant.change_shapes_of_line(line, "circle")
            counterfactuals.append(variant.get_img_dict())

        return counterfactuals

    # Negative counterfactuals: break every row that currently holds only circles.
    only_circle_rows = np.nonzero(extract_rows_with_only_circles(img_entry["content"]))[0]

    for _ in range(nb_cf):
        variant = PatImgObj(img_entry)

        for row in only_circle_rows:
            # Draw the column first, then the symbol, to keep the random call order.
            column = np.random.choice(np.arange(img_entry["division"][0]))
            intruder = gen_rand_sym(shapes=["square", "triangle"], colors=COLORS)
            variant.set_symbol(posx=column, posy=row, value=intruder)

        counterfactuals.append(variant.get_img_dict())

    return counterfactuals


In [None]:
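# NOTE: disabled call kept for reference only; the next code cell uses
# generate_counterfactuals_resnet18_random_approach instead.
# from xaipatimg.ml.xai import generate_counterfactuals_resnet18
#
# generate_counterfactuals_resnet18(db_dir, test_dataset_filename, model_dir, rowcircles_counterfactuals, 5, None, 1, None)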

Loading dataset content for rowcircles_test.csv


100%|██████████| 1000/1000 [00:16<00:00, 62.43it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


In [5]:
from xaipatimg.ml.xai import generate_counterfactuals_resnet18_random_approach

# Generate counterfactual images with the random approach for the first entries of the test dataset.
# NOTE(review): 1 - SHAPE_PROB is presumably the probability of leaving a grid position empty — confirm
# against the function's signature.
generate_counterfactuals_resnet18_random_approach(
    db_dir,
    test_dataset_filename,
    model_dir,
    yes_pred_img_path,
    no_pred_img_path,
    SHAPES,
    COLORS,
    1 - SHAPE_PROB,
    max_depth=10,
    nb_tries_per_depth=2000,
    device="cuda:0",
    n_jobs=1,
    dataset_size=10,
    pos_pred_legend_path=pos_pred_legend_path,
    neg_pred_legend_path=neg_pred_legend_path,
)

  from .autonotebook import tqdm as notebook_tqdm


Loading dataset content for rowcircles_test.csv


100%|██████████| 10/10 [00:00<00:00, 75.78it/s]
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Generating counterfactual images


  0%|          | 0/10 [00:00<?, ?it/s]

Loading dataset content for /tmp/tmpcxm51cx0/dataset.csv



  0%|          | 0/2000 [00:00<?, ?it/s][A
  0%|          | 8/2000 [00:00<00:26, 76.32it/s][A
  1%|          | 16/2000 [00:00<00:28, 69.83it/s][A
  1%|          | 24/2000 [00:00<00:35, 55.74it/s][A
  2%|▏         | 31/2000 [00:00<00:33, 59.64it/s][A
  2%|▏         | 38/2000 [00:00<00:38, 51.19it/s][A
  2%|▏         | 44/2000 [00:00<00:37, 52.67it/s][A
  3%|▎         | 52/2000 [00:00<00:33, 58.84it/s][A
  3%|▎         | 60/2000 [00:01<00:30, 63.86it/s][A
  3%|▎         | 67/2000 [00:01<00:30, 63.37it/s][A
  4%|▎         | 74/2000 [00:01<00:31, 61.27it/s][A
  4%|▍         | 81/2000 [00:01<00:36, 52.34it/s][A
  4%|▍         | 87/2000 [00:01<00:41, 46.63it/s][A
  5%|▍         | 92/2000 [00:01<00:48, 39.71it/s][A
  5%|▍         | 97/2000 [00:01<00:46, 41.17it/s][A
  5%|▌         | 102/2000 [00:01<00:46, 41.13it/s][A
  5%|▌         | 108/2000 [00:02<00:41, 45.56it/s][A
  6%|▌         | 113/2000 [00:02<00:42, 44.87it/s][A
  6%|▌         | 118/2000 [00:02<00:45, 41.79it/s][

Loading dataset content for /tmp/tmp_p_1mpb6/dataset.csv



  0%|          | 0/2000 [00:00<?, ?it/s][A
  0%|          | 6/2000 [00:00<00:34, 58.47it/s][A
  1%|          | 12/2000 [00:00<00:40, 49.39it/s][A
  1%|          | 21/2000 [00:00<00:30, 65.76it/s][A
  1%|▏         | 29/2000 [00:00<00:28, 69.92it/s][A
  2%|▏         | 37/2000 [00:00<00:30, 64.44it/s][A
  2%|▏         | 44/2000 [00:00<00:34, 57.42it/s][A
  2%|▎         | 50/2000 [00:00<00:35, 55.23it/s][A
  3%|▎         | 58/2000 [00:00<00:31, 61.21it/s][A
  3%|▎         | 67/2000 [00:01<00:28, 67.34it/s][A
  4%|▎         | 74/2000 [00:01<00:30, 62.72it/s][A
  4%|▍         | 82/2000 [00:01<00:28, 66.69it/s][A
  4%|▍         | 90/2000 [00:01<00:27, 69.43it/s][A
  5%|▍         | 98/2000 [00:01<00:26, 70.99it/s][A
  5%|▌         | 106/2000 [00:01<00:25, 72.93it/s][A
  6%|▌         | 114/2000 [00:01<00:31, 59.15it/s][A
  6%|▌         | 123/2000 [00:01<00:28, 65.76it/s][A
  7%|▋         | 131/2000 [00:02<00:30, 61.55it/s][A
  7%|▋         | 138/2000 [00:02<00:32, 57.30it/s]

Loading dataset content for /tmp/tmp_cxaj6uq/dataset.csv



  0%|          | 0/2000 [00:00<?, ?it/s][A
  0%|          | 7/2000 [00:00<00:28, 68.88it/s][A
  1%|          | 16/2000 [00:00<00:25, 76.68it/s][A
  1%|          | 24/2000 [00:00<00:26, 75.25it/s][A
  2%|▏         | 32/2000 [00:00<00:28, 68.06it/s][A
  2%|▏         | 40/2000 [00:00<00:27, 70.83it/s][A
  2%|▏         | 48/2000 [00:00<00:31, 62.10it/s][A
  3%|▎         | 55/2000 [00:00<00:32, 59.06it/s][A
  3%|▎         | 62/2000 [00:00<00:34, 55.73it/s][A
  3%|▎         | 68/2000 [00:01<00:35, 54.51it/s][A
  4%|▎         | 74/2000 [00:01<00:37, 51.65it/s][A
  4%|▍         | 80/2000 [00:01<00:37, 51.21it/s][A
  4%|▍         | 86/2000 [00:01<00:37, 51.58it/s][A
  5%|▍         | 92/2000 [00:01<00:36, 52.38it/s][A
  5%|▍         | 99/2000 [00:01<00:33, 56.66it/s][A
  5%|▌         | 105/2000 [00:01<00:34, 55.65it/s][A
  6%|▌         | 112/2000 [00:01<00:31, 59.07it/s][A
  6%|▌         | 120/2000 [00:02<00:29, 64.28it/s][A
  6%|▋         | 127/2000 [00:02<00:32, 57.99it/s][

Loading dataset content for /tmp/tmpta1n68pr/dataset.csv



  0%|          | 0/2000 [00:00<?, ?it/s][A
  0%|          | 8/2000 [00:00<00:28, 70.44it/s][A
  1%|          | 16/2000 [00:00<00:27, 71.90it/s][A
  1%|          | 24/2000 [00:00<00:27, 72.43it/s][A
  2%|▏         | 32/2000 [00:00<00:31, 63.28it/s][A
  2%|▏         | 40/2000 [00:00<00:29, 66.06it/s][A
  2%|▏         | 47/2000 [00:00<00:32, 59.91it/s][A
  3%|▎         | 54/2000 [00:00<00:33, 57.30it/s][A
  3%|▎         | 60/2000 [00:00<00:35, 55.23it/s][A
  3%|▎         | 68/2000 [00:01<00:31, 61.08it/s][A
  4%|▍         | 75/2000 [00:01<00:32, 60.13it/s][A
  4%|▍         | 82/2000 [00:01<00:33, 57.14it/s][A
  4%|▍         | 88/2000 [00:01<00:33, 56.84it/s][A
  5%|▍         | 96/2000 [00:01<00:30, 61.53it/s][A
  5%|▌         | 103/2000 [00:01<00:30, 62.12it/s][A
  6%|▌         | 111/2000 [00:01<00:28, 65.69it/s][A
  6%|▌         | 119/2000 [00:01<00:27, 68.67it/s][A
  6%|▋         | 127/2000 [00:01<00:26, 70.96it/s][A
  7%|▋         | 135/2000 [00:02<00:25, 72.57it/s]

Loading dataset content for /tmp/tmpn5mac126/dataset.csv



  0%|          | 0/2000 [00:00<?, ?it/s][A
  0%|          | 10/2000 [00:00<00:21, 93.35it/s][A
  1%|          | 20/2000 [00:00<00:23, 85.54it/s][A
  1%|▏         | 29/2000 [00:00<00:24, 80.65it/s][A
  2%|▏         | 38/2000 [00:00<00:29, 67.24it/s][A
  2%|▏         | 46/2000 [00:00<00:27, 69.86it/s][A
  3%|▎         | 54/2000 [00:00<00:29, 65.27it/s][A
  3%|▎         | 62/2000 [00:00<00:29, 66.76it/s][A
  4%|▎         | 71/2000 [00:00<00:27, 70.89it/s][A
  4%|▍         | 79/2000 [00:01<00:28, 67.71it/s][A
  4%|▍         | 86/2000 [00:01<00:30, 63.53it/s][A
  5%|▍         | 94/2000 [00:01<00:28, 67.09it/s][A
  5%|▌         | 101/2000 [00:01<00:31, 59.35it/s][A
  5%|▌         | 108/2000 [00:01<00:34, 54.23it/s][A
  6%|▌         | 117/2000 [00:01<00:30, 62.57it/s][A
  6%|▌         | 124/2000 [00:01<00:29, 63.71it/s][A
  7%|▋         | 131/2000 [00:02<00:31, 59.94it/s][A
  7%|▋         | 138/2000 [00:02<00:33, 56.24it/s][A
  7%|▋         | 146/2000 [00:02<00:30, 60.73it/

Loading dataset content for /tmp/tmp47q5o992/dataset.csv



  0%|          | 0/2000 [00:00<?, ?it/s][A
  0%|          | 7/2000 [00:00<00:31, 62.52it/s][A
  1%|          | 15/2000 [00:00<00:28, 69.56it/s][A
  1%|          | 22/2000 [00:00<00:35, 55.60it/s][A
  1%|▏         | 28/2000 [00:00<00:37, 52.13it/s][A
  2%|▏         | 34/2000 [00:00<00:39, 50.28it/s][A
  2%|▏         | 40/2000 [00:00<00:40, 47.97it/s][A
  2%|▏         | 48/2000 [00:00<00:35, 54.69it/s][A
  3%|▎         | 54/2000 [00:01<00:36, 52.80it/s][A
  3%|▎         | 61/2000 [00:01<00:34, 55.84it/s][A
  3%|▎         | 68/2000 [00:01<00:33, 57.74it/s][A
  4%|▎         | 74/2000 [00:01<00:37, 51.46it/s][A
  4%|▍         | 80/2000 [00:01<00:38, 49.81it/s][A
  4%|▍         | 86/2000 [00:01<00:41, 46.46it/s][A
  5%|▍         | 92/2000 [00:01<00:39, 48.06it/s][A
  5%|▍         | 98/2000 [00:01<00:37, 50.97it/s][A
  5%|▌         | 104/2000 [00:01<00:37, 50.25it/s][A
  6%|▌         | 110/2000 [00:02<00:38, 49.55it/s][A
  6%|▌         | 116/2000 [00:02<00:39, 48.21it/s][A

Loading dataset content for /tmp/tmpv9p2mmdk/dataset.csv



  0%|          | 0/2000 [00:00<?, ?it/s][A
  0%|          | 6/2000 [00:00<00:37, 53.09it/s][A
  1%|          | 12/2000 [00:00<00:36, 54.44it/s][A
  1%|          | 20/2000 [00:00<00:31, 63.68it/s][A
  1%|▏         | 27/2000 [00:00<00:33, 58.75it/s][A
  2%|▏         | 33/2000 [00:00<00:35, 54.95it/s][A
  2%|▏         | 39/2000 [00:00<00:37, 52.96it/s][A
  2%|▏         | 48/2000 [00:00<00:31, 62.05it/s][A
  3%|▎         | 55/2000 [00:00<00:32, 59.23it/s][A
  3%|▎         | 62/2000 [00:01<00:31, 61.57it/s][A
  4%|▎         | 70/2000 [00:01<00:29, 65.39it/s][A
  4%|▍         | 77/2000 [00:01<00:31, 61.14it/s][A
  4%|▍         | 84/2000 [00:01<00:32, 59.35it/s][A
  5%|▍         | 93/2000 [00:01<00:29, 65.71it/s][A
  5%|▌         | 101/2000 [00:01<00:27, 68.14it/s][A
  5%|▌         | 109/2000 [00:01<00:26, 70.42it/s][A
  6%|▌         | 117/2000 [00:01<00:26, 72.06it/s][A
  6%|▋         | 125/2000 [00:01<00:26, 72.01it/s][A
  7%|▋         | 133/2000 [00:02<00:25, 72.91it/s]

Loading dataset content for /tmp/tmpf076sb5q/dataset.csv



  0%|          | 0/2000 [00:00<?, ?it/s][A
  0%|          | 7/2000 [00:00<00:30, 66.41it/s][A
  1%|          | 14/2000 [00:00<00:39, 49.67it/s][A
  1%|          | 20/2000 [00:00<00:43, 45.65it/s][A
  1%|▏         | 27/2000 [00:00<00:37, 52.05it/s][A
  2%|▏         | 33/2000 [00:00<00:40, 48.76it/s][A
  2%|▏         | 40/2000 [00:00<00:37, 52.16it/s][A
  2%|▏         | 46/2000 [00:00<00:43, 45.01it/s][A
  3%|▎         | 51/2000 [00:01<00:43, 44.43it/s][A
  3%|▎         | 56/2000 [00:01<00:44, 43.80it/s][A
  3%|▎         | 61/2000 [00:01<00:44, 43.64it/s][A
  3%|▎         | 66/2000 [00:01<00:45, 42.06it/s][A
  4%|▎         | 71/2000 [00:01<00:43, 43.92it/s][A
  4%|▍         | 76/2000 [00:01<00:44, 43.14it/s][A
  4%|▍         | 81/2000 [00:01<00:44, 43.43it/s][A
  4%|▍         | 86/2000 [00:01<00:44, 42.85it/s][A
  5%|▍         | 91/2000 [00:01<00:43, 44.34it/s][A
  5%|▍         | 98/2000 [00:02<00:37, 50.35it/s][A
  5%|▌         | 104/2000 [00:02<00:36, 51.96it/s][A
 

Loading dataset content for /tmp/tmpfxbcnucp/dataset.csv



  0%|          | 0/2000 [00:00<?, ?it/s][A
  0%|          | 3/2000 [00:00<01:15, 26.52it/s][A
  0%|          | 6/2000 [00:00<01:39, 20.14it/s][A
  0%|          | 9/2000 [00:00<01:34, 21.00it/s][A
  1%|          | 15/2000 [00:00<01:01, 32.38it/s][A
  1%|          | 21/2000 [00:00<00:49, 39.99it/s][A
  1%|▏         | 26/2000 [00:00<00:48, 40.83it/s][A
  2%|▏         | 33/2000 [00:00<00:41, 47.82it/s][A
  2%|▏         | 38/2000 [00:01<00:48, 40.83it/s][A
  2%|▏         | 43/2000 [00:01<00:46, 41.66it/s][A
  2%|▏         | 48/2000 [00:01<00:45, 42.66it/s][A
  3%|▎         | 53/2000 [00:01<00:44, 43.90it/s][A
  3%|▎         | 58/2000 [00:01<00:47, 41.19it/s][A
  3%|▎         | 63/2000 [00:01<00:49, 39.01it/s][A
  3%|▎         | 69/2000 [00:01<00:43, 44.13it/s][A
  4%|▍         | 76/2000 [00:01<00:39, 48.77it/s][A
  4%|▍         | 82/2000 [00:02<00:41, 46.44it/s][A
  4%|▍         | 87/2000 [00:02<00:47, 39.89it/s][A
  5%|▍         | 92/2000 [00:02<00:49, 38.79it/s][A
  5%

Loading dataset content for /tmp/tmp_64hl6ye/dataset.csv



  0%|          | 0/2000 [00:00<?, ?it/s][A
  0%|          | 7/2000 [00:00<00:29, 67.32it/s][A
  1%|          | 14/2000 [00:00<00:42, 46.69it/s][A
  1%|          | 20/2000 [00:00<00:46, 42.85it/s][A
  1%|▏         | 27/2000 [00:00<00:39, 50.37it/s][A
  2%|▏         | 35/2000 [00:00<00:33, 57.97it/s][A
  2%|▏         | 42/2000 [00:00<00:40, 48.28it/s][A
  2%|▏         | 49/2000 [00:00<00:36, 53.07it/s][A
  3%|▎         | 56/2000 [00:01<00:33, 57.18it/s][A
  3%|▎         | 63/2000 [00:01<00:36, 52.50it/s][A
  3%|▎         | 69/2000 [00:01<00:37, 51.07it/s][A
  4%|▍         | 75/2000 [00:01<00:41, 46.14it/s][A
  4%|▍         | 80/2000 [00:01<00:43, 44.17it/s][A
  4%|▍         | 87/2000 [00:01<00:38, 49.77it/s][A
  5%|▍         | 93/2000 [00:01<00:38, 49.45it/s][A
  5%|▌         | 100/2000 [00:01<00:35, 53.18it/s][A
  5%|▌         | 107/2000 [00:02<00:33, 56.74it/s][A
  6%|▌         | 114/2000 [00:02<00:31, 59.11it/s][A
  6%|▌         | 121/2000 [00:02<00:35, 52.92it/s][

Loading dataset content for /tmp/tmpxjglmqtj/dataset.csv



  0%|          | 0/2000 [00:00<?, ?it/s][A
  0%|          | 5/2000 [00:00<00:42, 46.43it/s][A
  0%|          | 10/2000 [00:00<00:48, 40.91it/s][A
  1%|          | 16/2000 [00:00<00:40, 48.54it/s][A
  1%|          | 23/2000 [00:00<00:35, 55.23it/s][A
  1%|▏         | 29/2000 [00:00<00:38, 50.68it/s][A
  2%|▏         | 35/2000 [00:00<00:40, 48.50it/s][A
  2%|▏         | 42/2000 [00:00<00:37, 52.41it/s][A
  2%|▏         | 48/2000 [00:00<00:36, 53.50it/s][A
  3%|▎         | 55/2000 [00:01<00:34, 57.06it/s][A
  3%|▎         | 62/2000 [00:01<00:32, 60.37it/s][A
  3%|▎         | 69/2000 [00:01<00:36, 52.96it/s][A
  4%|▍         | 75/2000 [00:01<00:39, 48.36it/s][A
  4%|▍         | 83/2000 [00:01<00:34, 55.12it/s][A
  4%|▍         | 90/2000 [00:01<00:33, 56.80it/s][A
  5%|▍         | 96/2000 [00:01<00:36, 52.26it/s][A
  5%|▌         | 102/2000 [00:01<00:35, 53.04it/s][A
  5%|▌         | 108/2000 [00:02<00:37, 50.99it/s][A
  6%|▌         | 115/2000 [00:02<00:34, 54.45it/s][A