In [14]:
import os

# All data lives under the directory named by the DATA environment variable.
# os.path.join is used instead of string concatenation so that DATA works
# whether or not it ends with a path separator.
DATA_ROOT = os.environ["DATA"]
DB_ROOT = os.path.join(DATA_ROOT, "PatImgXAI_data", "db3.0.0")

# Image databases for the three grid sizes and for the patterns.
# The trailing "" component keeps the trailing separator the original paths had.
db_S_dir = os.path.join(DB_ROOT, "S", "")
db_L_dir = os.path.join(DB_ROOT, "L", "")
db_M_dir = os.path.join(DB_ROOT, "M", "")
db_patterns_dir = os.path.join(DB_ROOT, "patterns", "")

# Root of the trained models; each rule has its own sub-directory named after the rule.
model_dir_root = os.path.join(DATA_ROOT, "models", "db3.0.0", "01_protov5", "")

# Image assets passed to the explanation generators.
shap_scale_img_path = os.path.join(DB_ROOT, "shap_scale.png")
yes_pred_img_path = os.path.join(DB_ROOT, "button_yes.png")
no_pred_img_path = os.path.join(DB_ROOT, "button_no.png")
yes_small_pred_img_path = os.path.join(DB_ROOT, "button_yes_small.png")
no_small_pred_img_path = os.path.join(DB_ROOT, "button_no_small.png")
pos_pred_legend_path = os.path.join(DB_ROOT, "cf_info_pos.png")
neg_pred_legend_path = os.path.join(DB_ROOT, "cf_info_neg.png")
interface_dir = os.path.join(DATA_ROOT, "webinterfaces", "int05_prototype", "")

# Number of samples of each test dataset to explain.
XAI_DATASET_SIZE = 40

# Number of parallel jobs: N_JOBS for SHAP generation, N_JOBS_GPU for the
# GPU-based counterfactual generation.
N_JOBS = 20
N_JOBS_GPU = 6

RESNET_TYPE = "resnet50"

In [15]:

# Grid resolution (cells per axis) of the full images, per dataset size.
X_DIVISIONS_L, Y_DIVISIONS_L = 15, 15
X_DIVISIONS_S, Y_DIVISIONS_S = 10, 10
X_DIVISIONS_M, Y_DIVISIONS_M = 12, 12

# Grid resolution (cells per axis) of the patterns.
X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS = 2, 2

# Probability of generating a geometrical shape at each position of the grid.
SHAPE_PROB = 0.5

# Available shapes.
SHAPES = ['circle', 'square', 'triangle']

# Hex colour code -> human-readable colour name.
explict_colors_dict = {
    "#A33E9A": "purple",
    "#E0B000": "yellow",
    "#0C90C0": "blue",
}

# Available colours (purple, yellow, blue), in the same order as the dict above.
COLORS = list(explict_colors_dict)

In [16]:
from xaipatimg.datagen.dbimg import load_db

# Load the pattern database. It is used below as a mapping whose values carry a
# "content" list of entries with "color", "shape" and "pos" fields
# (presumably keyed by pattern id — TODO confirm against load_db).
db_patterns = load_db(db_patterns_dir)

In [17]:
import numpy as np

# The rule definitions in the following cells use pattern_3sym_2col_keys[0] .. [11],
# so at least this many qualifying patterns must exist.
N_PATTERNS_REQUIRED = 12


def _is_3sym_2col_pattern(content):
    """Return True if a 2x2 pattern qualifies for the "find the pattern" rules.

    A qualifying pattern contains exactly 3 symbols with 3 different shapes and
    2 different colours, and its two same-coloured symbols are not on a diagonal.
    """
    if len(content) != 3:
        return False

    colors = {entry["color"] for entry in content}
    shapes = {entry["shape"] for entry in content}

    # Colour of each cell of the 2x2 pattern grid; the single empty cell keeps ""
    # and therefore never equals a real colour in the diagonal comparison.
    color_matrix = np.full((2, 2), "", dtype="U100")
    for entry in content:
        color_matrix[entry["pos"][0]][entry["pos"][1]] = entry["color"]

    same_color_on_diagonal = (color_matrix[0][0] == color_matrix[1][1]
                              or color_matrix[0][1] == color_matrix[1][0])

    return len(colors) == 2 and len(shapes) == 3 and not same_color_on_diagonal


# Keys of the patterns that contain 3 symbols of 3 different shapes and 2 different
# colours, where the two items of the same colour are not on a diagonal.
pattern_3sym_2col_keys = [k for k, v in db_patterns.items() if _is_3sym_2col_pattern(v["content"])]

# Fail here with a clear message rather than with an IndexError in the rule cells.
assert len(pattern_3sym_2col_keys) >= N_PATTERNS_REQUIRED, (
    "Only {} qualifying patterns found, {} are required by the rule definitions".format(
        len(pattern_3sym_2col_keys), N_PATTERNS_REQUIRED))

In [18]:
# Each database keeps its generated dataset CSVs (e.g. <rule name>_test.csv)
# under <db dir>/datasets/01_protov5.
datasets_path_L, datasets_path_S, datasets_path_M = (
    os.path.join(db_root, "datasets", "01_protov5") for db_root in (db_L_dir, db_S_dir, db_M_dir)
)

In [19]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

def _build_rules_data_L():
    """Return the rule configurations evaluated on the large (L) image grid.

    Every rule is a dict with the keys "name", "gen_fun", "gen_kwargs", "question",
    "target_acc", "shown_acc", "samples_interface", "pos_llm_scaffold",
    "neg_llm_scaffold" and, for pattern rules only, "pattern_id".
    """
    # Settings common to every rule; unpacked into a fresh dict for each rule.
    shared = {"target_acc": 0.85, "shown_acc": 0.85, "samples_interface": 13,
              "pos_llm_scaffold": "", "neg_llm_scaffold": ""}

    def pattern_rule(name, consider_rotations, pattern_idx):
        # Pattern-search rule; pattern_idx selects an entry of pattern_3sym_2col_keys.
        if consider_rotations:
            question = "Is the pattern or any of its left or right rotations in the image?"
        else:
            question = "Is the pattern in the image?"
        return {
            "name": name,
            "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
            "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
                           "y_division_full": Y_DIVISIONS_L,
                           "x_division_pattern": X_DIVISIONS_PATTERNS,
                           "y_division_pattern": Y_DIVISIONS_PATTERNS,
                           "consider_rotations": consider_rotations},
            "question": question,
            **shared,
            "pattern_id": pattern_3sym_2col_keys[pattern_idx],
        }

    def count_rule(name, shape, color, n, question, restrict_plus_minus_1=True):
        # Counting rule: the question asks whether the number of <color> <shape>s equals n.
        return {
            "name": name,
            "gen_fun": generic_rule_N_times_color_shape_exactly,
            "gen_kwargs": {"x_division": X_DIVISIONS_L,
                           "y_division": Y_DIVISIONS_L,
                           "shape": shape,
                           "color": color,
                           "N": n,
                           "restrict_plus_minus_1": restrict_plus_minus_1},
            "question": question,
            **shared,
        }

    return [
        pattern_rule("hard1_find_pattern_rot", True, 0),
        count_rule("hard2_blue_circle_N", "circle", "#0C90C0", 13,
                   "Does the number of blue circles equal to 13 in the image?"),
        # Disabled variant:
        # count_rule("hard2bis_blue_circle_N_norestrict", "circle", "#0C90C0", 13,
        #            "Does the number of blue circles equal to 13 in the image?", restrict_plus_minus_1=False),
        pattern_rule("hard3_find_pattern_rot", True, 1),
        count_rule("hard4_purple_square_N", "square", "#A33E9A", 13,
                   "Does the number of purple squares equal to 13 in the image?"),
        # Disabled variant:
        # count_rule("hard4bis_purple_square_N_norestrict", "square", "#A33E9A", 13,
        #            "Does the number of purple squares equal to 13 in the image?", restrict_plus_minus_1=False),
        pattern_rule("hard5_find_pattern", False, 2),
        pattern_rule("hard6_find_pattern", False, 3),
    ]


rules_data_L = _build_rules_data_L()

In [20]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

def _build_rules_data_S():
    """Return the rule configurations evaluated on the small (S) image grid.

    Every rule is a dict with the keys "name", "gen_fun", "gen_kwargs", "question",
    "target_acc", "shown_acc", "samples_interface", "pos_llm_scaffold",
    "neg_llm_scaffold" and, for pattern rules only, "pattern_id".
    """
    # Settings common to every rule; unpacked into a fresh dict for each rule.
    shared = {"target_acc": 0.85, "shown_acc": 0.85, "samples_interface": 13,
              "pos_llm_scaffold": "", "neg_llm_scaffold": ""}

    def pattern_rule(name, consider_rotations, pattern_idx):
        # Pattern-search rule; pattern_idx selects an entry of pattern_3sym_2col_keys.
        if consider_rotations:
            question = "Is the pattern or any of its left or right rotations in the image?"
        else:
            question = "Is the pattern in the image?"
        return {
            "name": name,
            "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
            "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
                           "y_division_full": Y_DIVISIONS_S,
                           "x_division_pattern": X_DIVISIONS_PATTERNS,
                           "y_division_pattern": Y_DIVISIONS_PATTERNS,
                           "consider_rotations": consider_rotations},
            "question": question,
            **shared,
            "pattern_id": pattern_3sym_2col_keys[pattern_idx],
        }

    def count_rule(name, shape, color, n, question, restrict_plus_minus_1=True):
        # Counting rule: the question asks whether the number of <color> <shape>s equals n.
        return {
            "name": name,
            "gen_fun": generic_rule_N_times_color_shape_exactly,
            "gen_kwargs": {"x_division": X_DIVISIONS_S,
                           "y_division": Y_DIVISIONS_S,
                           "shape": shape,
                           "color": color,
                           "N": n,
                           "restrict_plus_minus_1": restrict_plus_minus_1},
            "question": question,
            **shared,
        }

    return [
        pattern_rule("easy1_find_pattern_rot", True, 4),
        count_rule("easy2_yellow_triangle_N", "triangle", "#E0B000", 6,
                   "Does the number of yellow triangles equal to 6 in the image?"),
        # Disabled variant:
        # count_rule("easy2bis_yellow_triangle_N_norestrict", "triangle", "#E0B000", 6,
        #            "Does the number of yellow triangles equal to 6 in the image?", restrict_plus_minus_1=False),
        pattern_rule("easy3_find_pattern_rot", True, 5),
        count_rule("easy4_purple_circle_N", "circle", "#A33E9A", 6,
                   "Does the number of purple circles equal to 6 in the image?"),
        # Disabled variant:
        # count_rule("easy4bis_purple_circle_N_norestrict", "circle", "#A33E9A", 6,
        #            "Does the number of purple circles equal to 6 in the image?", restrict_plus_minus_1=False),
        pattern_rule("easy5_find_pattern", False, 6),
        pattern_rule("easy6_find_pattern", False, 7),
    ]


rules_data_S = _build_rules_data_S()

In [21]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly
def _build_rules_data_M():
    """Return the rule configurations evaluated on the medium (M) image grid.

    Every rule is a dict with the keys "name", "gen_fun", "gen_kwargs", "question",
    "target_acc", "shown_acc", "samples_interface", "pos_llm_scaffold",
    "neg_llm_scaffold" and, for pattern rules only, "pattern_id".
    """
    # Settings common to every rule; unpacked into a fresh dict for each rule.
    shared = {"target_acc": 0.85, "shown_acc": 0.85, "samples_interface": 13,
              "pos_llm_scaffold": "", "neg_llm_scaffold": ""}

    def pattern_rule(name, consider_rotations, pattern_idx):
        # Pattern-search rule; pattern_idx selects an entry of pattern_3sym_2col_keys.
        if consider_rotations:
            question = "Is the pattern or any of its left or right rotations in the image?"
        else:
            question = "Is the pattern in the image?"
        return {
            "name": name,
            "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
            "gen_kwargs": {"x_division_full": X_DIVISIONS_M,
                           "y_division_full": Y_DIVISIONS_M,
                           "x_division_pattern": X_DIVISIONS_PATTERNS,
                           "y_division_pattern": Y_DIVISIONS_PATTERNS,
                           "consider_rotations": consider_rotations},
            "question": question,
            **shared,
            "pattern_id": pattern_3sym_2col_keys[pattern_idx],
        }

    def count_rule(name, shape, color, n, question, restrict_plus_minus_1=True):
        # Counting rule: the question asks whether the number of <color> <shape>s equals n.
        return {
            "name": name,
            "gen_fun": generic_rule_N_times_color_shape_exactly,
            "gen_kwargs": {"x_division": X_DIVISIONS_M,
                           "y_division": Y_DIVISIONS_M,
                           "shape": shape,
                           "color": color,
                           "N": n,
                           "restrict_plus_minus_1": restrict_plus_minus_1},
            "question": question,
            **shared,
        }

    return [
        pattern_rule("med1_find_pattern_rot", True, 8),
        count_rule("med2_yellow_square_N", "square", "#E0B000", 8,
                   "Does the number of yellow squares equal to 8 in the image?"),
        # Disabled variant:
        # count_rule("med2bis_yellow_square_N_norestrict", "square", "#E0B000", 8,
        #            "Does the number of yellow squares equal to 8 in the image?", restrict_plus_minus_1=False),
        pattern_rule("med3_find_pattern_rot", True, 9),
        count_rule("med4_blue_triangle_N", "triangle", "#0C90C0", 8,
                   "Does the number of blue triangles equal to 8 in the image?"),
        # Disabled variant:
        # count_rule("med4bis_blue_triangle_N_norestrict", "triangle", "#0C90C0", 8,
        #            "Does the number of blue triangles equal to 8 in the image?", restrict_plus_minus_1=False),
        pattern_rule("med5_find_pattern", False, 10),
        pattern_rule("med6_find_pattern", False, 11),
    ]


rules_data_M = _build_rules_data_M()


In [22]:
from xaipatimg.ml.xai import generate_shap_resnet, generate_counterfactuals_resnet_random_approach, \
    create_xai_index
from tqdm import tqdm

def generate_explanations(rules_data, db_dir, datasets_dir_path, device="cuda:0", shap_max_evals=1,
                          compute_cf=False):
    """Generate explanations and the XAI index for every rule of a dataset size.

    For each rule, the model is read from ``model_dir_root/<rule name>`` and the
    dataset from ``datasets_dir_path/<rule name>_test.csv``. Explanations are written
    to sub-directories of the model directory.

    Args:
        rules_data: list of rule dicts; the keys read are "name", "gen_fun",
            "gen_kwargs" and (optionally) "pattern_id". The dicts are not modified.
        db_dir: image database directory the dataset belongs to.
        datasets_dir_path: directory containing the dataset CSV files.
        device: device passed to SHAP generation and to index creation.
        shap_max_evals: value forwarded as ``max_evals`` to ``generate_shap_resnet``.
        compute_cf: if True, also generate counterfactual explanations and register
            them in the index (disabled by default, as before).
    """
    for rule in tqdm(rules_data):
        model_dir = os.path.join(model_dir_root, rule["name"])
        dataset_filename = rule["name"] + "_test.csv"

        # Output sub-directory (relative to model_dir) per explanation type; the same
        # mapping is given to create_xai_index.
        xai_output_paths = {"shap": "shap"}
        if compute_cf:
            xai_output_paths["cf"] = "cf"

        # NOTE(review): shap_max_evals defaults to 1, which is a very small evaluation
        # budget if it is forwarded to a SHAP explainer — confirm this is intentional.
        generate_shap_resnet(db_dir, datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
                             model_dir=model_dir, xai_output_path=os.path.join(model_dir, xai_output_paths["shap"]),
                             yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path,
                             device=device, n_jobs=N_JOBS, dataset_size=XAI_DATASET_SIZE, masker="ndarray",
                             shap_scale_img_path=shap_scale_img_path, resnet_type=RESNET_TYPE,
                             max_evals=shap_max_evals)

        if compute_cf:
            # Copy the rule's generator kwargs so that adding "pattern_content" does not
            # mutate the caller's rule configuration.
            generic_rule_fun_kwargs = dict(rule["gen_kwargs"])
            if "pattern_id" in rule:
                generic_rule_fun_kwargs["pattern_content"] = db_patterns[rule["pattern_id"]]["content"]

            generate_counterfactuals_resnet_random_approach(
                db_dir, datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
                model_dir=model_dir,
                xai_output_path=os.path.join(model_dir, xai_output_paths["cf"]),
                yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path,
                shapes=SHAPES, colors=COLORS, empty_probability=1 - SHAPE_PROB,
                max_depth=10, nb_tries_per_depth=2000, generic_rule_fun=rule["gen_fun"],
                devices=["cuda:0", "cuda:1"], n_jobs=N_JOBS_GPU,
                dataset_size=XAI_DATASET_SIZE, pos_pred_legend_path=pos_pred_legend_path,
                neg_pred_legend_path=neg_pred_legend_path,
                resnet_type=RESNET_TYPE, **generic_rule_fun_kwargs)

        create_xai_index(db_dir, datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename, model_dir=model_dir,
                         xai_dirs=xai_output_paths, dataset_size=XAI_DATASET_SIZE, device=device, resnet_type=RESNET_TYPE)


In [23]:
# Explain every rule of the small (S, 10x10 grid) dataset.
generate_explanations(rules_data=rules_data_S, db_dir=db_S_dir, datasets_dir_path=datasets_path_S)

  0%|          | 0/6 [00:00<?, ?it/s]

Loading dataset content for easy1_find_pattern_rot_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 35%|███▌      | 14/40 [00:00<00:00, 138.75it/s][A
100%|██████████| 40/40 [00:00<00:00, 140.86it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
 50%|█████     | 20/40 [00:00<00:00, 54.16it/s][A
100%|██████████| 40/40 [00:08<00:00,  4.63it/s][A


Loading dataset content for easy1_find_pattern_rot_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 48%|████▊     | 19/40 [00:00<00:00, 182.94it/s][A
100%|██████████| 40/40 [00:00<00:00, 139.70it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 17%|█▋        | 1/6 [00:16<01:21, 16.39s/it]

Loading dataset content for easy2_yellow_triangle_N_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:00<00:00, 164.16it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 24.84it/s][A


Loading dataset content for easy2_yellow_triangle_N_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 42%|████▎     | 17/40 [00:00<00:00, 164.84it/s][A
100%|██████████| 40/40 [00:00<00:00, 188.61it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 33%|███▎      | 2/6 [00:25<00:47, 11.83s/it]

Loading dataset content for easy3_find_pattern_rot_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:00<00:00, 210.39it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 23.92it/s][A


Loading dataset content for easy3_find_pattern_rot_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 40%|████      | 16/40 [00:00<00:00, 155.78it/s][A
100%|██████████| 40/40 [00:00<00:00, 142.22it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 50%|█████     | 3/6 [00:33<00:30, 10.25s/it]

Loading dataset content for easy4_purple_circle_N_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 35%|███▌      | 14/40 [00:00<00:00, 134.20it/s][A
100%|██████████| 40/40 [00:00<00:00, 143.65it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 24.57it/s][A


Loading dataset content for easy4_purple_circle_N_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 48%|████▊     | 19/40 [00:00<00:00, 185.06it/s][A
100%|██████████| 40/40 [00:00<00:00, 146.74it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 67%|██████▋   | 4/6 [00:41<00:19,  9.56s/it]

Loading dataset content for easy5_find_pattern_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 40%|████      | 16/40 [00:00<00:00, 158.27it/s][A
100%|██████████| 40/40 [00:00<00:00, 133.30it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 24.59it/s][A


Loading dataset content for easy5_find_pattern_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 35%|███▌      | 14/40 [00:00<00:00, 130.78it/s][A
100%|██████████| 40/40 [00:00<00:00, 148.69it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 83%|████████▎ | 5/6 [00:50<00:09,  9.17s/it]

Loading dataset content for easy6_find_pattern_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 32%|███▎      | 13/40 [00:00<00:00, 125.48it/s][A
 65%|██████▌   | 26/40 [00:00<00:00, 120.60it/s][A
100%|██████████| 40/40 [00:00<00:00, 120.93it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 23.81it/s][A


Loading dataset content for easy6_find_pattern_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 48%|████▊     | 19/40 [00:00<00:00, 187.86it/s][A
100%|██████████| 40/40 [00:00<00:00, 185.64it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 6/6 [00:58<00:00,  9.82s/it]


In [24]:
# Generate explanations for every rule dataset of the "L" image set (15x15 grid, see
# X_DIVISIONS_L / Y_DIVISIONS_L in the configuration cell). Per the output log below, each of
# the 6 test datasets is loaded, SHAP values are computed and SHAP images generated, and the
# dataset content is then loaded a second time (presumably for a later explanation step
# inside generate_explanations — TODO confirm against its definition).
generate_explanations(rules_data_L, db_L_dir, datasets_path_L)

  0%|          | 0/6 [00:00<?, ?it/s]

Loading dataset content for hard1_find_pattern_rot_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 32%|███▎      | 13/40 [00:00<00:00, 126.59it/s][A
100%|██████████| 40/40 [00:00<00:00, 140.45it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 24.65it/s][A


Loading dataset content for hard1_find_pattern_rot_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 45%|████▌     | 18/40 [00:00<00:00, 174.82it/s][A
100%|██████████| 40/40 [00:00<00:00, 171.17it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 17%|█▋        | 1/6 [00:08<00:42,  8.48s/it]

Loading dataset content for hard2_blue_circle_N_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 48%|████▊     | 19/40 [00:00<00:00, 186.10it/s][A
100%|██████████| 40/40 [00:00<00:00, 186.91it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 23.52it/s][A


Loading dataset content for hard2_blue_circle_N_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 35%|███▌      | 14/40 [00:00<00:00, 133.97it/s][A
100%|██████████| 40/40 [00:00<00:00, 151.22it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 33%|███▎      | 2/6 [00:17<00:34,  8.52s/it]

Loading dataset content for hard3_find_pattern_rot_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 32%|███▎      | 13/40 [00:00<00:00, 125.02it/s][A
100%|██████████| 40/40 [00:00<00:00, 137.19it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 23.98it/s][A


Loading dataset content for hard3_find_pattern_rot_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 40%|████      | 16/40 [00:00<00:00, 155.53it/s][A
100%|██████████| 40/40 [00:00<00:00, 171.47it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 50%|█████     | 3/6 [00:25<00:25,  8.62s/it]

Loading dataset content for hard4_purple_square_N_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 50%|█████     | 20/40 [00:00<00:00, 193.10it/s][A
100%|██████████| 40/40 [00:00<00:00, 186.45it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 24.30it/s][A


Loading dataset content for hard4_purple_square_N_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 50%|█████     | 20/40 [00:00<00:00, 197.94it/s][A
100%|██████████| 40/40 [00:00<00:00, 154.29it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 67%|██████▋   | 4/6 [00:34<00:17,  8.57s/it]

Loading dataset content for hard5_find_pattern_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 48%|████▊     | 19/40 [00:00<00:00, 185.20it/s][A
100%|██████████| 40/40 [00:00<00:00, 153.83it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 23.99it/s][A


Loading dataset content for hard5_find_pattern_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 45%|████▌     | 18/40 [00:00<00:00, 177.14it/s][A
100%|██████████| 40/40 [00:00<00:00, 176.56it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 83%|████████▎ | 5/6 [00:43<00:08,  8.65s/it]

Loading dataset content for hard6_find_pattern_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 32%|███▎      | 13/40 [00:00<00:00, 129.38it/s][A
100%|██████████| 40/40 [00:00<00:00, 135.61it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 24.08it/s][A


Loading dataset content for hard6_find_pattern_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 48%|████▊     | 19/40 [00:00<00:00, 188.44it/s][A
100%|██████████| 40/40 [00:00<00:00, 191.75it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 6/6 [00:51<00:00,  8.63s/it]


In [25]:
# Generate explanations for every rule dataset of the "M" image set (12x12 grid, see
# X_DIVISIONS_M / Y_DIVISIONS_M in the configuration cell). Per the output log below, each of
# the 6 test datasets is loaded, SHAP values are computed and SHAP images generated, and the
# dataset content is then loaded a second time (presumably for a later explanation step
# inside generate_explanations — TODO confirm against its definition).
generate_explanations(rules_data_M, db_M_dir, datasets_path_M)

  0%|          | 0/6 [00:00<?, ?it/s]

Loading dataset content for med1_find_pattern_rot_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 32%|███▎      | 13/40 [00:00<00:00, 127.29it/s][A
100%|██████████| 40/40 [00:00<00:00, 131.52it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 24.26it/s][A


Loading dataset content for med1_find_pattern_rot_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 48%|████▊     | 19/40 [00:00<00:00, 186.99it/s][A
100%|██████████| 40/40 [00:00<00:00, 183.28it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 17%|█▋        | 1/6 [00:09<00:46,  9.36s/it]

Loading dataset content for med2_yellow_square_N_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 35%|███▌      | 14/40 [00:00<00:00, 139.96it/s][A
100%|██████████| 40/40 [00:00<00:00, 123.68it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 24.06it/s][A


Loading dataset content for med2_yellow_square_N_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:00<00:00, 199.85it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 33%|███▎      | 2/6 [00:17<00:35,  8.87s/it]

Loading dataset content for med3_find_pattern_rot_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 45%|████▌     | 18/40 [00:00<00:00, 167.12it/s][A
100%|██████████| 40/40 [00:00<00:00, 139.65it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 23.10it/s][A


Loading dataset content for med3_find_pattern_rot_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 48%|████▊     | 19/40 [00:00<00:00, 181.84it/s][A
100%|██████████| 40/40 [00:00<00:00, 142.05it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 50%|█████     | 3/6 [00:26<00:26,  8.78s/it]

Loading dataset content for med4_blue_triangle_N_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 32%|███▎      | 13/40 [00:00<00:00, 126.48it/s][A
100%|██████████| 40/40 [00:00<00:00, 147.74it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 25.42it/s][A


Loading dataset content for med4_blue_triangle_N_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 42%|████▎     | 17/40 [00:00<00:00, 143.15it/s][A
100%|██████████| 40/40 [00:00<00:00, 104.10it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 67%|██████▋   | 4/6 [00:35<00:17,  8.84s/it]

Loading dataset content for med5_find_pattern_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 32%|███▎      | 13/40 [00:00<00:00, 128.99it/s][A
100%|██████████| 40/40 [00:00<00:00, 136.89it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 24.71it/s][A


Loading dataset content for med5_find_pattern_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 32%|███▎      | 13/40 [00:00<00:00, 127.92it/s][A
 65%|██████▌   | 26/40 [00:00<00:00, 121.18it/s][A
100%|██████████| 40/40 [00:00<00:00, 115.72it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
 83%|████████▎ | 5/6 [00:44<00:08,  8.88s/it]

Loading dataset content for med6_find_pattern_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
 38%|███▊      | 15/40 [00:00<00:00, 144.54it/s][A
100%|██████████| 40/40 [00:00<00:00, 137.19it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values
Generating shap images



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:01<00:00, 24.26it/s][A


Loading dataset content for med6_find_pattern_test.csv



  0%|          | 0/40 [00:00<?, ?it/s][A
100%|██████████| 40/40 [00:00<00:00, 206.09it/s][A
Using cache found in /home/jleguy/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 6/6 [00:52<00:00,  8.82s/it]


In [26]:
# Optional step (currently disabled): generate LLM-based textual explanations with a
# Hugging Face causal LM ("openai/gpt-oss-20b"), restricted to the samples listed in the
# experimental interface's task content CSV, then build an index over the "shap", "cf" and
# "llm" explanation directories of each model.
#
# NOTE(review): before re-enabling, check that the names used below exist in this notebook.
# The active cells above call generate_explanations once per dataset size
# (rules_data_L / db_L_dir / datasets_path_L, rules_data_M / db_M_dir / datasets_path_M),
# whereas this block refers to rules_data, db_dir, db and datasets_dir_path — presumably
# left over from a single-dataset version of the notebook; TODO confirm and loop over the
# S/M/L variants if needed.
#
# from transformers import AutoTokenizer
# from transformers import AutoModelForCausalLM
# import csv
# from xaipatimg.ml.xai import generate_LLM_explanations, create_xai_index
# from tqdm import tqdm
#
# model_id = "openai/gpt-oss-20b"
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# llm_model = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     device_map="auto",
#     torch_dtype="auto",
# )
#
# for rule_idx in tqdm(range(len(rules_data))):
#
#     model_dir = os.path.join(model_dir_root, rules_data[rule_idx]["name"])
#     dataset_filename = rules_data[rule_idx]["name"] + "_test.csv"
#
#     # Extracting the subset of indices of samples selected for the experimental interface, in order to ease the cost of calculation
#     # NOTE(review): the nested double quotes inside the f-string below only parse on
#     # Python >= 3.12 (PEP 701); use rules_data[rule_idx]['name'] for older interpreters.
#     interface_content_path = os.path.join(interface_dir, "res", "tasks", f"{rules_data[rule_idx]["name"]}_content.csv")
#     # NOTE(review): the file opened below is never closed; prefer `with open(...) as f:`.
#     interface_selected_idx = [int(row["og_idx"]) for row in list(csv.DictReader(open(interface_content_path), delimiter=','))]
#
#     # Sub-directory names (relative to model_dir) for each explanation type.
#     xai_output_paths = {
#         "shap" : "shap",
#         "cf" : "cf",
#         "llm" : "llm",
#     }
#     generate_LLM_explanations(db_dir, db, datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
#                               model_dir=model_dir, llm_model=llm_model, llm_tokenizer=tokenizer,
#                               xai_output_path=os.path.join(model_dir, xai_output_paths["llm"]),
#                               explicit_colors_dict=explict_colors_dict, question=rules_data[rule_idx]["question"],
#                               yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path,
#                               yes_pred_img_path_small=yes_small_pred_img_path, no_pred_img_path_small=no_small_pred_img_path,
#                               device="cuda:0", dataset_size=XAI_DATASET_SIZE, only_for_index=interface_selected_idx,
#                               path_to_counterfactuals_dir_for_model_errors=os.path.join(model_dir, xai_output_paths["cf"]),
#                               pos_llm_scaffold=rules_data[rule_idx]["pos_llm_scaffold"], neg_llm_scaffold=rules_data[rule_idx]["neg_llm_scaffold"])
#
#     create_xai_index(db_dir, dataset_filename=dataset_filename, model_dir=model_dir, xai_dirs=xai_output_paths, dataset_size=XAI_DATASET_SIZE, device="cuda:0")
