In [1]:
import os

# Root directory of all on-disk data. It must be supplied through the DATA
# environment variable; a KeyError is raised when the variable is not set.
data_root = os.environ["DATA"]

# os.path.join inserts the path separator when DATA has no trailing slash.
# Plain string concatenation would fuse the directory name to the root
# (e.g. "/data" + "models/..." -> "/datamodels/...").
# The trailing "/" of both directories is kept on purpose, so the resulting
# strings are identical to the concatenated ones whenever DATA already ends
# with a separator.
db_dir = os.path.join(data_root, "PatImgXAI_data/db2.0.0/")
model_dir_root = os.path.join(data_root, "models/db2.0.0/")

# Image assets located in the database directory.
shap_scale_img_path = os.path.join(db_dir, "shap_scale.png")
yes_pred_img_path = os.path.join(db_dir, "button_yes.png")
no_pred_img_path = os.path.join(db_dir, "button_no.png")
pos_pred_legend_path = os.path.join(db_dir, "cf_info_pos.png")
neg_pred_legend_path = os.path.join(db_dir, "cf_info_neg.png")

# Passed as `dataset_size` to the XAI generation and indexing calls.
XAI_DATASET_SIZE = 100
# Passed as `n_jobs` to the (currently commented-out) SHAP generation call.
N_JOBS = 20
# Passed as `n_jobs` to the counterfactual generation call.
N_JOBS_GPU = 6

In [2]:
# Total count of generated images
NBGEN = 1_000_000

# Each image is split into a grid of X_DIVISIONS by Y_DIVISIONS cells
X_DIVISIONS, Y_DIVISIONS = 6, 6

# Image dimensions, in pixels
img_size = 700, 700

# Chance that a given grid position receives a geometrical shape
SHAPE_PROB = 0.5

# Shapes and colors (hex codes) available for the symbols
SHAPES = ["circle", "square", "triangle"]
COLORS = [
    "#A33E9A",  # purple
    "#E0B000",  # yellow
    "#0C90C0",  # blue
]

In [3]:
from xaipatimg.datagen.gendataset import generic_rule_exist_row_with_only_shape, generic_rule_N_times_color_exactly, \
    generic_rule_shape_color_plus_shape_equals_N, generic_rule_exist_row_with_only_color_and_col_with_only_shape, \
    generic_rule_shape_in_every_row

# Rules for which XAI outputs are produced. Keys of each entry:
#   name       - identifier; the XAI generation loop builds the model directory
#                and the "<name>_test.csv" dataset filename from it
#   gen_fun    - rule function imported from xaipatimg.datagen.gendataset
#   gen_kwargs - keyword arguments stored alongside gen_fun
#                (presumably forwarded to it - TODO confirm)
#   question   - question text associated with the rule
#   target_acc - presumably the accuracy targeted for the rule's model - TODO confirm
#
# NOTE(review): the active rule is named "disc_1_circle_all" while its shape and
# question are about triangles. The name is left untouched because file and
# directory names are derived from it.
rules_data = [
    {
        "name": "disc_1_circle_all",
        "gen_fun": generic_rule_shape_in_every_row,
        "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS},
        "question": "In the image, is there a triangle in every row (1, ..., 6)?",
        "target_acc": 1.0,
    },

    # Disabled rules, kept for reference.
    # NOTE(review): "easy_5_7_yellow" uses "N": 5 while its name and question say 7.
    #
    # {"name": "easy_1_6_blue", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#0C90C0", "N": 6, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}, "question": "In the image, is there exactly 6 blue symbols?", "target_acc": 0.8},
    # {"name": "easy_2_row_circle", "gen_fun": generic_rule_exist_row_with_only_shape, "gen_kwargs": {"shape": "circle", "y_division": Y_DIVISIONS},
    #  "question": "In the image, is there at least one row (1, ..., 6) containing only circles?", "target_acc": 0.8},
    # {"name": "easy_3_7_purple", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#A33E9A", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}, "question": "In the image, is there exactly 7 purple symbols?", "target_acc": 0.8},
    # {"name": "easy_4_row_triangle", "gen_fun": generic_rule_exist_row_with_only_shape, "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS},
    #  "question": "In the image, is there at least one row (1, ..., 6) containing only triangles?", "target_acc": 0.8},
    # {"name": "easy_5_7_yellow", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#E0B000", "N": 5, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}, "question": "In the image, is there exactly 7 yellow symbols?", "target_acc": 0.8},
    # {"name": "easy_6_row_square", "gen_fun": generic_rule_exist_row_with_only_shape, "gen_kwargs": {"shape": "square", "y_division": Y_DIVISIONS},
    #  "question": "In the image, is there at least one row (1, ..., 6) containing only squares?", "target_acc": 0.8},
    #
    # {"name": "hard_1_blue_square_plus_circle_8", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#0C90C0", "shape1": "square", "shape2": "circle", "N": 8, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS,},
    #  "question": "In the image, does the number of blue squares plus (+) the number of circles equal to 8?", "target_acc": 0.8},
    # {"name": "hard_2_row_purple_col_triangle", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#A33E9A", "shape": "triangle" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
    #  "question": "In the image, is there at least one row (1, ..., 6) containing only purple symbols, and one column (A, ..., F) containing only triangles?", "target_acc": 0.8},
    # {"name": "hard_3_yellow_circle_plus_triangle_9", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#E0B000", "shape1": "circle", "shape2": "triangle", "N": 9, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
    #  "question": "In the image, does the number of yellow circles plus (+) the number of triangles equal to 9?", "target_acc": 0.8},
    # {"name": "hard_4_row_yellow_col_circle", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#E0B000", "shape": "circle" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
    #  "question": "In the image, is there at least one row (1, ..., 6) containing only yellow symbols, and one column (A, ..., F) containing only circles?", "target_acc": 0.8},
    # {"name": "hard_5_purple_triangle_plus_square_7", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#A33E9A", "shape1": "triangle", "shape2": "square", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
    #  "question": "In the image, does the number of purple triangles plus (+) the number of squares equal to 7?", "target_acc": 0.8},
    # {"name": "hard_6_row_blue_col_square", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#0C90C0", "shape": "square" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
    #  "question": "In the image, is there at least one row (1, ..., 6) containing only blue symbols, and one column (A, ..., F) containing only squares?", "target_acc": 0.8},
]

In [4]:
from xaipatimg.datagen.dbimg import load_db

# Call load_db on the database directory and keep the result in `db`.
# NOTE(review): `db` is not referenced by the other cells visible here (they
# pass `db_dir` directly) - confirm whether this load is still needed.
db = load_db(db_dir)

In [None]:
from xaipatimg.ml.xai import generate_shap_resnet18, generate_counterfactuals_resnet18_random_approach, create_xai_index
from tqdm import tqdm

# For every configured rule: generate counterfactual explanations on the rule's
# test dataset, then build the XAI index for the rule's model directory.
for rule_idx, rule in enumerate(tqdm(rules_data)):

    rule_name = rule["name"]
    model_dir = os.path.join(model_dir_root, rule_name)
    dataset_filename = f"{rule_name}_test.csv"

    # Sub-directory of model_dir associated with each explanation method
    xai_output_paths = dict(shap="shap", cf="cf")

    # SHAP generation is currently disabled; the call is kept for reference.
    # generate_shap_resnet18(db_dir, dataset_filename=dataset_filename,
    #                        model_dir=model_dir, xai_output_path=os.path.join(model_dir, xai_output_paths["shap"]),
    #                        yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path, device="cuda:0", n_jobs=N_JOBS,
    #                        dataset_size=XAI_DATASET_SIZE, masker="ndarray", shap_scale_img_path=shap_scale_img_path)

    cf_output_dir = os.path.join(model_dir, xai_output_paths["cf"])
    generate_counterfactuals_resnet18_random_approach(
        db_dir,
        dataset_filename=dataset_filename,
        model_dir=model_dir,
        xai_output_path=cf_output_dir,
        yes_pred_img_path=yes_pred_img_path,
        no_pred_img_path=no_pred_img_path,
        shapes=SHAPES,
        colors=COLORS,
        # complement of the probability of placing a shape in a grid cell
        empty_probability=1 - SHAPE_PROB,
        max_depth=10,
        nb_tries_per_depth=2000,
        devices=["cuda:0", "cuda:1"],
        n_jobs=N_JOBS_GPU,
        dataset_size=XAI_DATASET_SIZE,
        pos_pred_legend_path=pos_pred_legend_path,
        neg_pred_legend_path=neg_pred_legend_path,
    )

    # NOTE(review): xai_dirs still lists "shap" although the SHAP call above is
    # commented out - confirm the index handles a directory this run did not produce.
    create_xai_index(
        db_dir,
        dataset_filename=dataset_filename,
        model_dir=model_dir,
        xai_dirs=xai_output_paths,
        dataset_size=XAI_DATASET_SIZE,
        device="cuda:0",
    )


  from .autonotebook import tqdm as notebook_tqdm
  0%|          | 0/1 [00:00<?, ?it/s]

Loading dataset content for disc_1_circle_all_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  5%|▌         | 5/100 [00:00<00:01, 47.73it/s][A
 12%|█▏        | 12/100 [00:00<00:01, 51.60it/s][A
 18%|█▊        | 18/100 [00:00<00:01, 52.63it/s][A
 24%|██▍       | 24/100 [00:00<00:01, 53.46it/s][A
 30%|███       | 30/100 [00:00<00:01, 54.22it/s][A
 38%|███▊      | 38/100 [00:00<00:01, 60.04it/s][A
 44%|████▍     | 44/100 [00:00<00:00, 60.00it/s][A
 50%|█████     | 50/100 [00:00<00:00, 59.24it/s][A
 56%|█████▌    | 56/100 [00:00<00:00, 59.02it/s][A
 62%|██████▏   | 62/100 [00:01<00:00, 57.81it/s][A
 69%|██████▉   | 69/100 [00:01<00:00, 60.52it/s][A
 76%|███████▌  | 76/100 [00:01<00:00, 54.75it/s][A
 83%|████████▎ | 83/100 [00:01<00:00, 57.42it/s][A
 89%|████████▉ | 89/100 [00:01<00:00, 57.95it/s][A
100%|██████████| 100/100 [00:01<00:00, 57.62it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


Generating counterfactual images



  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:00<00:29,  3.37it/s][A
  6%|▌         | 6/100 [00:03<00:48,  1.94it/s][A

Loading dataset content for /tmp/tmp9kfuqlm2/dataset.csv
Loading dataset content for /tmp/tmpjdt3vd31/dataset.csv


  0%|          | 0/2000 [00:00<?, ?it/s]22.28it/s]

Loading dataset content for /tmp/tmp70fcbb7y/dataset.csv
Loading dataset content for /tmp/tmpvzpyxl5h/dataset.csv


  1%|▏         | 26/2000 [00:00<00:16, 121.68it/s]

Loading dataset content for /tmp/tmpw4lkmbs_/dataset.csv


  4%|▍         | 78/2000 [00:00<00:15, 124.92it/s]]

Loading dataset content for /tmp/tmpmewoibr5/dataset.csv


100%|██████████| 2000/2000 [00:16<00:00, 123.79it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 121.71it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:15<00:00, 125.92it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 123.19it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:15<00:00, 127.16it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 122.65it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0

 12%|█▏        | 12/100 [01:05<09:28,  6.47s/it][A

Loading dataset content for /tmp/tmphb6_nlmx/dataset.csv


  0%|          | 0/2000 [00:00<?, ?it/s]122.53it/s]

Loading dataset content for /tmp/tmpvz5dhdop/dataset.csv
Loading dataset content for /tmp/tmpxpedn3t4/dataset.csv


  0%|          | 0/2000 [00:00<?, ?it/s]121.15it/s]

Loading dataset content for /tmp/tmp2ja4ylyr/dataset.csv


  5%|▍         | 97/2000 [00:00<00:16, 116.23it/s]]

Loading dataset content for /tmp/tmp45xfacq6/dataset.csv


  1%|          | 13/2000 [00:00<00:15, 126.30it/s]]

Loading dataset content for /tmp/tmphdwsihud/dataset.csv


100%|██████████| 2000/2000 [00:15<00:00, 125.32it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 122.90it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 121.65it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 119.43it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:15<00:00, 126.15it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 123.41it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0

 18%|█▊        | 18/100 [02:03<10:52,  7.96s/it][A