In [1]:
import os

# All locations are derived from the DATA environment variable (a KeyError is raised if it is unset).
# os.path.join is used instead of string concatenation so the paths are valid whether or not DATA
# ends with a path separator; when it does, the resulting strings are identical to the concatenated form.
db_dir = os.path.join(os.environ["DATA"], "PatImgXAI_data/db2.0.0/")
model_dir_root = os.path.join(os.environ["DATA"], "models/db2.0.0/")

# Image assets stored inside the database directory and passed to the XAI generation functions.
shap_scale_img_path = os.path.join(db_dir, "shap_scale.png")
yes_pred_img_path = os.path.join(db_dir, "button_yes.png")
no_pred_img_path = os.path.join(db_dir, "button_no.png")
pos_pred_legend_path = os.path.join(db_dir, "cf_info_pos.png")
neg_pred_legend_path = os.path.join(db_dir, "cf_info_neg.png")

# Value passed as `dataset_size` to the XAI functions.
XAI_DATASET_SIZE = 100
# Value passed as `n_jobs` to the SHAP generation.
N_JOBS = 20
# Value passed as `n_jobs` to the counterfactual generation (which is given a list of CUDA devices).
N_JOBS_GPU = 6

In [2]:
# How many images are generated
NBGEN = 1_000_000

# Number of grid divisions of each image along x and along y
X_DIVISIONS, Y_DIVISIONS = 6, 6

# Image size in pixels
img_size = 700, 700

# Probability that a geometrical shape is generated at a given grid position
SHAPE_PROB = 0.5

# Available shapes, and available colors as hex codes
# (the rule questions refer to them as purple, yellow and blue, in that order)
SHAPES = ["circle", "square", "triangle"]
COLORS = [
    "#A33E9A",
    "#E0B000",
    "#0C90C0",
]

In [3]:
from xaipatimg.datagen.gendataset import generic_rule_exist_row_with_only_shape, generic_rule_N_times_color_exactly, \
    generic_rule_shape_color_plus_shape_equals_N, generic_rule_exist_row_with_only_color_and_col_with_only_shape, \
    generic_rule_shape_in_every_row

# Rules handled by this notebook. Each entry contains:
#   "name"       : identifier of the rule; the model directory and the "<name>_test.csv" dataset
#                  filename are derived from it, so renaming an entry changes which files are looked up.
#   "gen_fun"    : rule function imported from xaipatimg.datagen.gendataset.
#   "gen_kwargs" : keyword arguments associated with gen_fun.
#   "question"   : question text associated with the rule.
#   "target_acc" : presumably the accuracy targeted for the rule's model - TODO confirm.
rules_data = [
    # NOTE(review): the name says "circle" but the rule and the question are about triangles.
    # The name is left untouched because it is used to locate existing files on disk.
    {
        "name": "disc_1_circle_all",
        "gen_fun": generic_rule_shape_in_every_row,
        "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS},
        "question": "In the image, is there a triangle in every row (1, ..., 6)?",
        "target_acc": 1.0,
    },

    # ---- "easy" rules ----
    {
        "name": "easy_1_6_blue",
        "gen_fun": generic_rule_N_times_color_exactly,
        "gen_kwargs": {"color": "#0C90C0", "N": 6, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
        "question": "In the image, is there exactly 6 blue symbols?",
        "target_acc": 0.8,
    },
    {
        "name": "easy_2_row_circle",
        "gen_fun": generic_rule_exist_row_with_only_shape,
        "gen_kwargs": {"shape": "circle", "y_division": Y_DIVISIONS},
        "question": "In the image, is there at least one row (1, ..., 6) containing only circles?",
        "target_acc": 0.8,
    },
    {
        "name": "easy_3_7_purple",
        "gen_fun": generic_rule_N_times_color_exactly,
        "gen_kwargs": {"color": "#A33E9A", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
        "question": "In the image, is there exactly 7 purple symbols?",
        "target_acc": 0.8,
    },
    {
        "name": "easy_4_row_triangle",
        "gen_fun": generic_rule_exist_row_with_only_shape,
        "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS},
        "question": "In the image, is there at least one row (1, ..., 6) containing only triangles?",
        "target_acc": 0.8,
    },
    {
        "name": "easy_5_7_yellow",
        "gen_fun": generic_rule_N_times_color_exactly,
        # N was 5 (the rule index) while both the name and the question ask for 7 yellow symbols.
        # NOTE(review): if the dataset on disk was generated with N=5, it must be regenerated.
        "gen_kwargs": {"color": "#E0B000", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
        "question": "In the image, is there exactly 7 yellow symbols?",
        "target_acc": 0.8,
    },
    {
        "name": "easy_6_row_square",
        "gen_fun": generic_rule_exist_row_with_only_shape,
        "gen_kwargs": {"shape": "square", "y_division": Y_DIVISIONS},
        "question": "In the image, is there at least one row (1, ..., 6) containing only squares?",
        "target_acc": 0.8,
    },

    # ---- "hard" rules ----
    {
        "name": "hard_1_blue_square_plus_circle_8",
        "gen_fun": generic_rule_shape_color_plus_shape_equals_N,
        "gen_kwargs": {"color1": "#0C90C0", "shape1": "square", "shape2": "circle", "N": 8, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
        "question": "In the image, does the number of blue squares plus (+) the number of circles equal to 8?",
        "target_acc": 0.8,
    },
    {
        "name": "hard_2_row_purple_col_triangle",
        "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape,
        "gen_kwargs": {"color": "#A33E9A", "shape": "triangle", "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
        "question": "In the image, is there at least one row (1, ..., 6) containing only purple symbols, and one column (A, ..., F) containing only triangles?",
        "target_acc": 0.8,
    },
    {
        "name": "hard_3_yellow_circle_plus_triangle_9",
        "gen_fun": generic_rule_shape_color_plus_shape_equals_N,
        "gen_kwargs": {"color1": "#E0B000", "shape1": "circle", "shape2": "triangle", "N": 9, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
        "question": "In the image, does the number of yellow circles plus (+) the number of triangles equal to 9?",
        "target_acc": 0.8,
    },
    {
        "name": "hard_4_row_yellow_col_circle",
        "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape,
        "gen_kwargs": {"color": "#E0B000", "shape": "circle", "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
        "question": "In the image, is there at least one row (1, ..., 6) containing only yellow symbols, and one column (A, ..., F) containing only circles?",
        "target_acc": 0.8,
    },
    {
        "name": "hard_5_purple_triangle_plus_square_7",
        "gen_fun": generic_rule_shape_color_plus_shape_equals_N,
        "gen_kwargs": {"color1": "#A33E9A", "shape1": "triangle", "shape2": "square", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
        "question": "In the image, does the number of purple triangles plus (+) the number of squares equal to 7?",
        "target_acc": 0.8,
    },
    {
        "name": "hard_6_row_blue_col_square",
        "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape,
        "gen_kwargs": {"color": "#0C90C0", "shape": "square", "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
        "question": "In the image, is there at least one row (1, ..., 6) containing only blue symbols, and one column (A, ..., F) containing only squares?",
        "target_acc": 0.8,
    },
]

In [4]:
from xaipatimg.datagen.dbimg import load_db

# Load the database located at db_dir through load_db.
# NOTE(review): `db` is not referenced by the cells visible after this one - confirm the load is still needed.
db = load_db(db_dir)

In [5]:
from xaipatimg.ml.xai import generate_shap_resnet18, generate_counterfactuals_resnet18_random_approach, create_xai_index
from tqdm import tqdm

# Switches for the expensive generation steps. Both are off by default, in which case only the
# XAI index is (re)built for every rule. Set a switch to True to run the corresponding step
# instead of commenting code in and out.
RUN_SHAP = False
RUN_COUNTERFACTUALS = False

# Iterate over the rules directly; tqdm infers the total from the list length.
for rule in tqdm(rules_data):

    # Per-rule model directory and test dataset file, both derived from the rule name.
    model_dir = os.path.join(model_dir_root, rule["name"])
    dataset_filename = rule["name"] + "_test.csv"

    # Output directory name for each XAI method. Rebuilt on every iteration so that each call
    # below receives a fresh dict.
    xai_output_paths = {
        "shap": "shap",
        "cf": "cf",
    }

    if RUN_SHAP:
        generate_shap_resnet18(db_dir, dataset_filename=dataset_filename,
                               model_dir=model_dir, xai_output_path=os.path.join(model_dir, xai_output_paths["shap"]),
                               yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path, device="cuda:0", n_jobs=N_JOBS,
                               dataset_size=XAI_DATASET_SIZE, masker="ndarray", shap_scale_img_path=shap_scale_img_path)

    if RUN_COUNTERFACTUALS:
        generate_counterfactuals_resnet18_random_approach(db_dir, dataset_filename=dataset_filename,
                                                          model_dir=model_dir, xai_output_path=os.path.join(model_dir, xai_output_paths["cf"]),
                                                          yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path,
                                                          shapes=SHAPES, colors=COLORS, empty_probability=1-SHAPE_PROB,
                                                          max_depth=10, nb_tries_per_depth=2000, devices=["cuda:0", "cuda:1"], n_jobs=N_JOBS_GPU,
                                                          dataset_size=XAI_DATASET_SIZE, pos_pred_legend_path=pos_pred_legend_path,
                                                          neg_pred_legend_path=neg_pred_legend_path)

    # NOTE(review): xai_dirs receives the bare directory names, whereas the generation steps write
    # to os.path.join(model_dir, <name>) - presumably create_xai_index resolves them against model_dir; confirm.
    create_xai_index(db_dir, dataset_filename=dataset_filename, model_dir=model_dir, xai_dirs=xai_output_paths, dataset_size=XAI_DATASET_SIZE, device="cuda:0")


  from .autonotebook import tqdm as notebook_tqdm
  0%|          | 0/13 [00:00<?, ?it/s]

Loading dataset content for disc_1_circle_all_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 17%|█▋        | 17/100 [00:00<00:00, 163.24it/s][A
 34%|███▍      | 34/100 [00:00<00:00, 132.18it/s][A
 54%|█████▍    | 54/100 [00:00<00:00, 157.43it/s][A
 71%|███████   | 71/100 [00:00<00:00, 132.48it/s][A
100%|██████████| 100/100 [00:00<00:00, 143.17it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
  8%|▊         | 1/13 [00:03<00:44,  3.73s/it]

Loading dataset content for easy_1_6_blue_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 16%|█▌        | 16/100 [00:00<00:00, 158.43it/s][A
 32%|███▏      | 32/100 [00:00<00:00, 145.35it/s][A
 50%|█████     | 50/100 [00:00<00:00, 157.11it/s][A
 66%|██████▌   | 66/100 [00:00<00:00, 141.36it/s][A
 81%|████████  | 81/100 [00:00<00:00, 134.31it/s][A
100%|██████████| 100/100 [00:00<00:00, 142.21it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 15%|█▌        | 2/13 [00:04<00:24,  2.18s/it]

Loading dataset content for easy_2_row_circle_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 16%|█▌        | 16/100 [00:00<00:00, 156.09it/s][A
 32%|███▏      | 32/100 [00:00<00:00, 109.46it/s][A
 44%|████▍     | 44/100 [00:00<00:00, 103.70it/s][A
 55%|█████▌    | 55/100 [00:00<00:00, 90.51it/s] [A
 70%|███████   | 70/100 [00:00<00:00, 105.62it/s][A
 82%|████████▏ | 82/100 [00:00<00:00, 96.83it/s] [A
100%|██████████| 100/100 [00:00<00:00, 103.08it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 23%|██▎       | 3/13 [00:06<00:18,  1.81s/it]

Loading dataset content for easy_3_7_purple_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 15%|█▌        | 15/100 [00:00<00:00, 146.45it/s][A
 30%|███       | 30/100 [00:00<00:00, 144.27it/s][A
 45%|████▌     | 45/100 [00:00<00:00, 120.14it/s][A
 62%|██████▏   | 62/100 [00:00<00:00, 135.43it/s][A
 79%|███████▉  | 79/100 [00:00<00:00, 144.53it/s][A
100%|██████████| 100/100 [00:00<00:00, 132.75it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 31%|███       | 4/13 [00:07<00:13,  1.55s/it]

Loading dataset content for easy_4_row_triangle_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 17%|█▋        | 17/100 [00:00<00:00, 162.88it/s][A
 34%|███▍      | 34/100 [00:00<00:00, 152.22it/s][A
 51%|█████     | 51/100 [00:00<00:00, 157.43it/s][A
 67%|██████▋   | 67/100 [00:00<00:00, 156.26it/s][A
100%|██████████| 100/100 [00:00<00:00, 151.40it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 38%|███▊      | 5/13 [00:08<00:11,  1.38s/it]

Loading dataset content for easy_5_7_yellow_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 16%|█▌        | 16/100 [00:00<00:00, 159.04it/s][A
 32%|███▏      | 32/100 [00:00<00:00, 127.10it/s][A
 49%|████▉     | 49/100 [00:00<00:00, 143.10it/s][A
 64%|██████▍   | 64/100 [00:00<00:00, 136.40it/s][A
 79%|███████▉  | 79/100 [00:00<00:00, 139.53it/s][A
100%|██████████| 100/100 [00:00<00:00, 135.76it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 46%|████▌     | 6/13 [00:09<00:09,  1.30s/it]

Loading dataset content for easy_6_row_square_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 10%|█         | 10/100 [00:00<00:00, 98.68it/s][A
 26%|██▌       | 26/100 [00:00<00:00, 134.47it/s][A
 42%|████▏     | 42/100 [00:00<00:00, 145.87it/s][A
 57%|█████▋    | 57/100 [00:00<00:00, 119.91it/s][A
 73%|███████▎  | 73/100 [00:00<00:00, 131.84it/s][A
100%|██████████| 100/100 [00:00<00:00, 136.91it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 54%|█████▍    | 7/13 [00:10<00:07,  1.25s/it]

Loading dataset content for hard_1_blue_square_plus_circle_8_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  5%|▌         | 5/100 [00:00<00:01, 49.43it/s][A
 10%|█         | 10/100 [00:00<00:02, 30.76it/s][A
 14%|█▍        | 14/100 [00:00<00:02, 31.36it/s][A
 20%|██        | 20/100 [00:00<00:02, 38.00it/s][A
 27%|██▋       | 27/100 [00:00<00:01, 46.97it/s][A
 43%|████▎     | 43/100 [00:00<00:00, 80.02it/s][A
 52%|█████▏    | 52/100 [00:00<00:00, 75.89it/s][A
 69%|██████▉   | 69/100 [00:00<00:00, 100.45it/s][A
100%|██████████| 100/100 [00:01<00:00, 84.39it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 62%|██████▏   | 8/13 [00:12<00:06,  1.36s/it]

Loading dataset content for hard_2_row_purple_col_triangle_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 12%|█▏        | 12/100 [00:00<00:00, 119.31it/s][A
 24%|██▍       | 24/100 [00:00<00:00, 107.69it/s][A
 41%|████      | 41/100 [00:00<00:00, 132.16it/s][A
 58%|█████▊    | 58/100 [00:00<00:00, 143.70it/s][A
 73%|███████▎  | 73/100 [00:00<00:00, 141.86it/s][A
100%|██████████| 100/100 [00:00<00:00, 126.68it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 69%|██████▉   | 9/13 [00:13<00:05,  1.31s/it]

Loading dataset content for hard_3_yellow_circle_plus_triangle_9_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 13%|█▎        | 13/100 [00:00<00:00, 125.29it/s][A
 26%|██▌       | 26/100 [00:00<00:00, 118.37it/s][A
 41%|████      | 41/100 [00:00<00:00, 130.78it/s][A
 58%|█████▊    | 58/100 [00:00<00:00, 143.01it/s][A
 73%|███████▎  | 73/100 [00:00<00:00, 135.58it/s][A
100%|██████████| 100/100 [00:00<00:00, 122.81it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 77%|███████▋  | 10/13 [00:14<00:03,  1.28s/it]

Loading dataset content for hard_4_row_yellow_col_circle_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 16%|█▌        | 16/100 [00:00<00:00, 159.28it/s][A
 32%|███▏      | 32/100 [00:00<00:00, 154.16it/s][A
 48%|████▊     | 48/100 [00:00<00:00, 151.42it/s][A
 64%|██████▍   | 64/100 [00:00<00:00, 151.01it/s][A
 81%|████████  | 81/100 [00:00<00:00, 155.24it/s][A
100%|██████████| 100/100 [00:00<00:00, 148.37it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 85%|████████▍ | 11/13 [00:15<00:02,  1.22s/it]

Loading dataset content for hard_5_purple_triangle_plus_square_7_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 11%|█         | 11/100 [00:00<00:00, 108.72it/s][A
 25%|██▌       | 25/100 [00:00<00:00, 121.11it/s][A
 39%|███▉      | 39/100 [00:00<00:00, 129.02it/s][A
 56%|█████▌    | 56/100 [00:00<00:00, 142.24it/s][A
 71%|███████   | 71/100 [00:00<00:00, 126.72it/s][A
 84%|████████▍ | 84/100 [00:00<00:00, 108.35it/s][A
100%|██████████| 100/100 [00:00<00:00, 103.08it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 92%|█████████▏| 12/13 [00:17<00:01,  1.27s/it]

Loading dataset content for hard_6_row_blue_col_square_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 14%|█▍        | 14/100 [00:00<00:00, 137.03it/s][A
 28%|██▊       | 28/100 [00:00<00:00, 97.14it/s] [A
 43%|████▎     | 43/100 [00:00<00:00, 114.43it/s][A
 60%|██████    | 60/100 [00:00<00:00, 130.98it/s][A
 74%|███████▍  | 74/100 [00:00<00:00, 125.53it/s][A
100%|██████████| 100/100 [00:00<00:00, 122.96it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 13/13 [00:18<00:00,  1.42s/it]
