In [1]:
import os

# Root of the data volume. Paths are built with os.path.join rather than `+`
# so that DATA works whether or not it ends with a path separator.
data_root = os.environ["DATA"]

# Image database and the root under which one model directory per rule lives.
db_dir = os.path.join(data_root, "PatImgXAI_data/db2.0.0/")
model_dir_root = os.path.join(data_root, "models/db2.0.0/")

# Decoration images stored in the database directory and handed to the XAI
# generators (SHAP colour scale, yes/no prediction buttons, CF legends).
shap_scale_img_path = os.path.join(db_dir, "shap_scale.png")
yes_pred_img_path = os.path.join(db_dir, "button_yes.png")
no_pred_img_path = os.path.join(db_dir, "button_no.png")
pos_pred_legend_path = os.path.join(db_dir, "cf_info_pos.png")
neg_pred_legend_path = os.path.join(db_dir, "cf_info_neg.png")

# Passed as `dataset_size` to every XAI generator call.
XAI_DATASET_SIZE = 100
# Parallelism: N_JOBS for the SHAP step, N_JOBS_GPU for the counterfactual step.
N_JOBS = 20
N_JOBS_GPU = 6

In [2]:
# Total count of synthetic images to produce.
NBGEN = 1_000_000

# Every image is laid out on an X_DIVISIONS x Y_DIVISIONS grid of cells.
X_DIVISIONS, Y_DIVISIONS = 6, 6

# Image size in pixels.
img_size = (700, 700)

# Chance that any given grid cell receives a geometric shape.
SHAPE_PROB = 0.5

# Vocabulary of shapes and the hex colour palette they are drawn with.
SHAPES = ["circle", "square", "triangle"]
COLORS = [
    "#A33E9A",  # purple
    "#E0B000",  # yellow
    "#0C90C0",  # blue
]

In [3]:
from xaipatimg.datagen.gendataset import generic_rule_exist_row_with_only_shape, generic_rule_N_times_color_exactly, \
    generic_rule_shape_color_plus_shape_equals_N, generic_rule_exist_row_with_only_color_and_col_with_only_shape

# Palette hex codes referenced by the rules (same values as in COLORS).
_PURPLE, _YELLOW, _BLUE = "#A33E9A", "#E0B000", "#0C90C0"
# Grid geometry for rules whose generator needs both axes.
_GRID = {"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}

# Each entry: a rule name, the generator implementing it, and the keyword
# arguments forwarded to that generator. Only uncommented rules are processed.
rules_data = [
    dict(name="easy_1_6_blue",
         gen_fun=generic_rule_N_times_color_exactly,
         gen_kwargs={"color": _BLUE, "N": 6, **_GRID}),
    dict(name="easy_2_row_circle",
         gen_fun=generic_rule_exist_row_with_only_shape,
         gen_kwargs={"shape": "circle", "y_division": Y_DIVISIONS}),
    # --- disabled rules (uncomment to include) ---
    # dict(name="easy_3_7_purple", gen_fun=generic_rule_N_times_color_exactly,
    #      gen_kwargs={"color": _PURPLE, "N": 7, **_GRID}),
    # dict(name="easy_4_row_triangle", gen_fun=generic_rule_exist_row_with_only_shape,
    #      gen_kwargs={"shape": "triangle", "y_division": Y_DIVISIONS}),
    # NOTE(review): the name says 7 but N is 5 — confirm which one is intended.
    # dict(name="easy_5_7_yellow", gen_fun=generic_rule_N_times_color_exactly,
    #      gen_kwargs={"color": _YELLOW, "N": 5, **_GRID}),
    # dict(name="easy_6_row_square", gen_fun=generic_rule_exist_row_with_only_shape,
    #      gen_kwargs={"shape": "square", "y_division": Y_DIVISIONS}),
    #
    # dict(name="hard_1_blue_square_plus_circle_8", gen_fun=generic_rule_shape_color_plus_shape_equals_N,
    #      gen_kwargs={"color1": _BLUE, "shape1": "square", "shape2": "circle", "N": 8, **_GRID}),
    # dict(name="hard_2_row_purple_col_triangle", gen_fun=generic_rule_exist_row_with_only_color_and_col_with_only_shape,
    #      gen_kwargs={"color": _PURPLE, "shape": "triangle", **_GRID}),
    # dict(name="hard_3_yellow_circle_plus_triangle_9", gen_fun=generic_rule_shape_color_plus_shape_equals_N,
    #      gen_kwargs={"color1": _YELLOW, "shape1": "circle", "shape2": "triangle", "N": 9, **_GRID}),
    # dict(name="hard_4_row_yellow_col_circle", gen_fun=generic_rule_exist_row_with_only_color_and_col_with_only_shape,
    #      gen_kwargs={"color": _YELLOW, "shape": "circle", **_GRID}),
    # dict(name="hard_5_purple_triangle_plus_square_7", gen_fun=generic_rule_shape_color_plus_shape_equals_N,
    #      gen_kwargs={"color1": _PURPLE, "shape1": "triangle", "shape2": "square", "N": 7, **_GRID}),
    # dict(name="hard_6_row_blue_col_square", gen_fun=generic_rule_exist_row_with_only_color_and_col_with_only_shape,
    #      gen_kwargs={"color": _BLUE, "shape": "square", **_GRID}),
]

In [4]:
from xaipatimg.datagen.dbimg import load_db

# Load the database located at db_dir (presumably the image metadata DB — confirm).
# NOTE(review): `db` is not referenced by any cell shown here; the XAI calls
# below receive db_dir directly. Confirm this load is still needed.
db = load_db(db_dir)

In [5]:
from xaipatimg.ml.xai import generate_shap_resnet18, generate_counterfactuals_resnet18_random_approach, create_xai_index
from tqdm.auto import tqdm  # notebook widget when available, plain-text bar otherwise

# Sub-directory names (inside each rule's model directory) for the generated
# explanations. The same mapping is given to create_xai_index as `xai_dirs`.
# Loop-invariant, so it is built once rather than on every iteration.
xai_output_paths = {
    "shap": "shap",
    "cf": "cf",
}

# Devices and counterfactual search budget, named once instead of repeated inline.
XAI_DEVICE = "cuda:0"              # used by the SHAP step and by the index step
CF_DEVICES = ["cuda:0", "cuda:1"]  # devices handed to the counterfactual generator
CF_MAX_DEPTH = 10
CF_TRIES_PER_DEPTH = 2000

# For every active rule, explain the rule's ResNet-18 model on its `_test.csv`
# dataset in three steps: SHAP, random-approach counterfactuals, then an index
# over both outputs. The number of samples is governed by XAI_DATASET_SIZE.
for rule in tqdm(rules_data, desc="rules"):
    rule_name = rule["name"]
    model_dir = os.path.join(model_dir_root, rule_name)
    dataset_filename = rule_name + "_test.csv"

    generate_shap_resnet18(
        db_dir, dataset_filename=dataset_filename, model_dir=model_dir,
        xai_output_path=os.path.join(model_dir, xai_output_paths["shap"]),
        yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path,
        device=XAI_DEVICE, n_jobs=N_JOBS, dataset_size=XAI_DATASET_SIZE,
        masker="ndarray", shap_scale_img_path=shap_scale_img_path,
    )

    generate_counterfactuals_resnet18_random_approach(
        db_dir, dataset_filename=dataset_filename, model_dir=model_dir,
        xai_output_path=os.path.join(model_dir, xai_output_paths["cf"]),
        yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path,
        shapes=SHAPES, colors=COLORS,
        # empty_probability is the complement of SHAPE_PROB (chance a cell holds a shape).
        empty_probability=1 - SHAPE_PROB,
        max_depth=CF_MAX_DEPTH, nb_tries_per_depth=CF_TRIES_PER_DEPTH,
        devices=CF_DEVICES, n_jobs=N_JOBS_GPU, dataset_size=XAI_DATASET_SIZE,
        pos_pred_legend_path=pos_pred_legend_path,
        neg_pred_legend_path=neg_pred_legend_path,
    )

    create_xai_index(
        db_dir, dataset_filename=dataset_filename, model_dir=model_dir,
        xai_dirs=xai_output_paths, dataset_size=XAI_DATASET_SIZE, device=XAI_DEVICE,
    )


  from .autonotebook import tqdm as notebook_tqdm
  0%|          | 0/2 [00:00<?, ?it/s]

Loading dataset content for easy_1_6_blue_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 10%|█         | 10/100 [00:00<00:01, 70.42it/s][A
 18%|█▊        | 18/100 [00:00<00:01, 69.05it/s][A
 25%|██▌       | 25/100 [00:00<00:01, 65.19it/s][A
 32%|███▏      | 32/100 [00:00<00:01, 65.06it/s][A
 39%|███▉      | 39/100 [00:00<00:00, 64.91it/s][A
 46%|████▌     | 46/100 [00:00<00:00, 56.40it/s][A
 53%|█████▎    | 53/100 [00:00<00:00, 58.68it/s][A
 60%|██████    | 60/100 [00:00<00:00, 61.18it/s][A
 67%|██████▋   | 67/100 [00:01<00:00, 62.51it/s][A
 74%|███████▍  | 74/100 [00:01<00:00, 61.81it/s][A
 81%|████████  | 81/100 [00:01<00:00, 62.65it/s][A
 88%|████████▊ | 88/100 [00:01<00:00, 56.01it/s][A
100%|██████████| 100/100 [00:01<00:00, 60.80it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values



  0%|          | 0/9998 [00:00<?, ?it/s][A
 18%|█▊        | 1842/9998 [00:00<00:00, 13105.02it/s][A
 32%|███▏      | 3192/9998 [00:03<00:10, 674.82it/s]  [A
 38%|███▊      | 3792/9998 [00:05<00:11, 549.82it/s][A
 41%|████▏     | 4142/9998 [00:06<00:11, 501.78it/s][A
 44%|████▍     | 4392/9998 [00:07<00:11, 471.68it/s][A
 45%|████▌     | 4542/9998 [00:07<00:12, 453.92it/s][A
 47%|████▋     | 4692/9998 [00:08<00:12, 435.57it/s][A
 48%|████▊     | 4792/9998 [00:08<00:12, 423.95it/s][A
 49%|████▉     | 4892/9998 [00:08<00:12, 410.99it/s][A
 50%|████▉     | 4992/9998 [00:09<00:12, 395.73it/s][A
 51%|█████     | 5092/9998 [00:09<00:12, 386.27it/s][A
 51%|█████▏    | 5142/9998 [00:09<00:12, 380.84it/s][A
 52%|█████▏    | 5192/9998 [00:09<00:12, 375.88it/s][A
 52%|█████▏    | 5242/9998 [00:09<00:12, 370.16it/s][A
 53%|█████▎    | 5292/9998 [00:09<00:12, 365.86it/s][A
 53%|█████▎    | 5342/9998 [00:10<00:12, 361.76it/s][A
 54%|█████▍    | 5392/9998 [00:10<00:12, 358.67it/s][A

KeyboardInterrupt: 