In [64]:
import os

# All data, model and interface paths hang off the DATA environment variable. They are built by plain
# string concatenation, so DATA is expected to end with a path separator (the recorded logs show
# "/home/docker/data/PatImgXAI_data/..."). A missing DATA variable raises KeyError here.
DATA_ROOT = os.environ["DATA"]
DB_VERSION_ROOT = DATA_ROOT + "PatImgXAI_data/db3.1.0"

# Image databases, one per grid size, plus the pattern database (all with a trailing slash).
db_S_dir = DB_VERSION_ROOT + "/S/"
db_L_dir = DB_VERSION_ROOT + "/L/"
db_M_dir = DB_VERSION_ROOT + "/M/"
db_XS_dir = DB_VERSION_ROOT + "/XS/"
db_patterns_dir = DB_VERSION_ROOT + "/patterns/"

# Root directory of the trained models of this experiment (one sub-directory per rule name).
model_dir_root = DATA_ROOT + "models/db3.1.0/01_expv1/"

# Static images used when rendering the explanations.
shap_scale_img_path = os.path.join(DB_VERSION_ROOT, "shap_scale.png")
yes_pred_img_path = os.path.join(DB_VERSION_ROOT, "button_yes.png")
no_pred_img_path = os.path.join(DB_VERSION_ROOT, "button_no.png")
yes_small_pred_img_path = os.path.join(DB_VERSION_ROOT, "button_yes_small.png")
no_small_pred_img_path = os.path.join(DB_VERSION_ROOT, "button_no_small.png")
pos_pred_legend_path = os.path.join(DB_VERSION_ROOT, "cf_info_pos.png")
neg_pred_legend_path = os.path.join(DB_VERSION_ROOT, "cf_info_neg.png")

# Web interface whose task content files select the samples explained by the LLM.
interface_dir = DATA_ROOT + "webinterfaces/int05_prototype/"

# Number of test samples explained per rule.
XAI_DATASET_SIZE = 200

# Parallel workers: CPU-bound jobs, and GPU-bound jobs (counterfactual generation).
N_JOBS = 20
N_JOBS_GPU = 4

# Backbone architecture of the trained classifiers.
RESNET_TYPE = "resnet18"

In [65]:
# Number of images generated: per full-image size, and in the pattern database.
NBGEN_full_per_size = 5_000_000
NBGEN_patterns = 1_000

# Grid divisions (X, Y) of the full images, one pair per image size.
X_DIVISIONS_L, Y_DIVISIONS_L = 15, 15
X_DIVISIONS_S, Y_DIVISIONS_S = 8, 8
X_DIVISIONS_M, Y_DIVISIONS_M = 11, 11
X_DIVISIONS_XS, Y_DIVISIONS_XS = 6, 6

# Grid divisions (X, Y) of the pattern images.
X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS = 2, 2

# Image sizes in pixels: full images, then pattern images.
img_size = (700, 700)
img_size_patterns = (300, 300)

# Probability of drawing a geometrical shape in each grid cell.
SHAPE_PROB = 0.5

# Available shape and color codes.
SHAPES = ['c', 's', 't']
COLORS = ["p", "y", "b"]

In [66]:
from xaipatimg.datagen.dbimg import load_db

# Pattern database. It is used below as a mapping (iterated with .items() and indexed by pattern key),
# and each record's "cnt" entry holds the pattern's symbols.
db_patterns = load_db(db_patterns_dir)

In [67]:
import numpy as np

def _is_3sym_2col_off_diagonal(pattern_content):
    """Tell whether a 3-symbol 2x2 pattern has 3 distinct shapes, 2 distinct colors, and no same-colored diagonal.

    Each entry of ``pattern_content`` is read through its "col", "shp" and "pos" keys. "pos" is used as a
    (row, column) index into a 2x2 grid.
    """
    distinct_colors = {entry["col"] for entry in pattern_content}
    distinct_shapes = {entry["shp"] for entry in pattern_content}

    # Lay the colors out on the 2x2 grid; the unused cell stays "".
    grid = np.full((2, 2), "", dtype="U100")
    for entry in pattern_content:
        grid[entry["pos"][0]][entry["pos"][1]] = entry["col"]

    # With 3 symbols, the empty cell sits on one diagonal. So only the fully occupied diagonal can
    # compare equal, which happens exactly when the two same-colored symbols lie on it.
    diagonal_match = grid[0][0] == grid[1][1] or grid[0][1] == grid[1][0]

    return len(distinct_colors) == 2 and len(distinct_shapes) == 3 and not diagonal_match


# Keys (in db_patterns iteration order) of the patterns with 3 symbols of 3 different shapes and 2 different
# colors, where the two symbols of the same color are not on a diagonal. The rule cells below pick patterns
# from this list by position, so its order matters.
pattern_3sym_2col_keys = [
    pattern_key
    for pattern_key, pattern in db_patterns.items()
    if len(pattern["cnt"]) == 3 and _is_3sym_2col_off_diagonal(pattern["cnt"])
]

In [68]:
# Each grid size keeps this experiment's datasets under <db dir>/datasets/01_expv1.
EXPERIMENT_DATASETS_SUBPATH = ("datasets", "01_expv1")
datasets_path_L = os.path.join(db_L_dir, *EXPERIMENT_DATASETS_SUBPATH)
datasets_path_S = os.path.join(db_S_dir, *EXPERIMENT_DATASETS_SUBPATH)
datasets_path_M = os.path.join(db_M_dir, *EXPERIMENT_DATASETS_SUBPATH)
datasets_path_XS = os.path.join(db_XS_dir, *EXPERIMENT_DATASETS_SUBPATH)


In [69]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly, generic_rule_shape_in_every_row

# Rules explained on the large (L) grid. generate_explanations reads "name", "gen_fun", "gen_kwargs" and
# "pattern_id"; gen_LLM_explanations additionally reads "question" and the two LLM scaffolds.
rules_data_L = [

    dict(
        name="hard1_find_pattern_rot",
        gen_fun=generic_rule_pattern_exactly_1_time_exclude_more,
        gen_kwargs=dict(
            x_division_full=X_DIVISIONS_L,
            y_division_full=Y_DIVISIONS_L,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            consider_rotations=True,
        ),
        question="Is the pattern or any of its left or right rotations in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
        pattern_id=pattern_3sym_2col_keys[0],
    ),

    dict(
        name="hard3_find_pattern_rot",
        gen_fun=generic_rule_pattern_exactly_1_time_exclude_more,
        gen_kwargs=dict(
            x_division_full=X_DIVISIONS_L,
            y_division_full=Y_DIVISIONS_L,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            consider_rotations=True,
        ),
        question="Is the pattern or any of its left or right rotations in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
        pattern_id=pattern_3sym_2col_keys[1],
    ),


    # Disabled rules, kept for reference.
    # NOTE(review): the hard2 neg_llm_scaffold below lists B3 and D3 twice, so it names only 10 distinct
    # cells while claiming 12 blue circles; fix it before re-enabling the rule.
    # {"name": "hard2_blue_circle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
    #                                                                                                     "y_division": Y_DIVISIONS_L,
    #                                                                                                     "shape": "c",
    #                                                                                                     "color": "b",
    #                                                                                                     "N": 13,
    #                                                                                                     "restrict_plus_minus_1": True},
    #  "question": "Does the number of blue circles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "The image contains 13 blue circles. They are highlighted below. | A1;B1;C1;B2;C2;D2;E4;F4;G5;A6;B6;E6;D7", "neg_llm_scaffold": "The image contains 12 blue circles instead of 13. They are highlighted below. | A2;B3;D3;B3;C3;D3;E5;F6;A7;B7;E7;H9"},

    # {"name": "hard2bis_blue_circle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
    #                                                                                                                   "y_division": Y_DIVISIONS_L,
    #                                                                                                                   "shape": "c",
    #                                                                                                                   "color": "b",
    #                                                                                                                   "N": 13,
    #                                                                                                                   "restrict_plus_minus_1": False},
    #  "question": "Does the number of blue circles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},


    # {"name": "hard4_purple_triangle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
    #                                                                                                         "y_division": Y_DIVISIONS_L,
    #                                                                                                         "shape": "t",
    #                                                                                                         "color": "p",
    #                                                                                                         "N": 13,
    #                                                                                                         "restrict_plus_minus_1": True},
    #  "question": "Does the number of purple triangles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard4bis_purple_triangle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
    #     "x_division": X_DIVISIONS_L,
    #     "y_division": Y_DIVISIONS_L,
    #     "shape": "t",
    #     "color": "p",
    #     "N": 13,
    #     "restrict_plus_minus_1": False},
    #  "question": "Does the number of purple triangles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

]

In [70]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

# Rules for the small (S) grid. Every entry is currently commented out, so rules_data_S is an empty list
# and looping over it (as generate_explanations does) performs no work. Uncomment an entry to enable it.
rules_data_S = [

        # {"name": "easy1_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
        #                                                                                                  "y_division_full": Y_DIVISIONS_S,
        #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
        #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
        #                                                                                                  "consider_rotations": True},
        #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[2]},

    #     {"name": "easy2_yellow_triangle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_S,
    #                                                                                                             "y_division": Y_DIVISIONS_S,
    #                                                                                                             "shape": "t",
    #                                                                                                             "color": "y",
    #                                                                                                             "N": 6,
    #                                                                                                             "restrict_plus_minus_1": True},
    #      "question": "Does the number of yellow triangles equal to 6 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    #     {"name": "easy2bis_yellow_triangle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
    #         "x_division": X_DIVISIONS_S,
    #         "y_division": Y_DIVISIONS_S,
    #         "shape": "t",
    #         "color": "y",
    #         "N": 6,
    #         "restrict_plus_minus_1": False},
    #      "question": "Does the number of yellow triangles equal to 6 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "easy3_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
    #                                                                                                                "y_division_full": Y_DIVISIONS_S,
    #                                                                                                                "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                                "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                                "consider_rotations": True},
    #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[3]},
    #
    #     {"name": "easy4_purple_circle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_S,
    #                                                                                                           "y_division": Y_DIVISIONS_S,
    #                                                                                                           "shape": "c",
    #                                                                                                           "color": "p",
    #                                                                                                           "N": 6,
    #                                                                                                           "restrict_plus_minus_1": True},
    #      "question": "Does the number of purple circles equal to 6 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    #     {"name": "easy4bis_purple_circle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
    #         "x_division": X_DIVISIONS_S,
    #         "y_division": Y_DIVISIONS_S,
    #         "shape": "c",
    #         "color": "p",
    #         "N": 6,
    #         "restrict_plus_minus_1": False},
    #      "question": "Does the number of purple circles equal to 6 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    # {"name": "easy5_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
    #                                                                                                                "y_division_full": Y_DIVISIONS_S,
    #                                                                                                                "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                                "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                                "consider_rotations": True},
    #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[6]},

]

In [71]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly
# Rules explained on the medium (M) grid; same entry layout as the other rules_data_* lists.
rules_data_M = [

    dict(
        name="med1_find_pattern_rot",
        gen_fun=generic_rule_pattern_exactly_1_time_exclude_more,
        gen_kwargs=dict(
            x_division_full=X_DIVISIONS_M,
            y_division_full=Y_DIVISIONS_M,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            consider_rotations=True,
        ),
        question="Is the pattern or any of its left or right rotations in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
        pattern_id=pattern_3sym_2col_keys[4],
    ),

    dict(
        name="med3_find_pattern_rot",
        gen_fun=generic_rule_pattern_exactly_1_time_exclude_more,
        gen_kwargs=dict(
            x_division_full=X_DIVISIONS_M,
            y_division_full=Y_DIVISIONS_M,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            consider_rotations=True,
        ),
        question="Is the pattern or any of its left or right rotations in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
        pattern_id=pattern_3sym_2col_keys[5],
    ),


    # Disabled rules, kept for reference.
    # {"name": "med2_yellow_square_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_M,
    #                                                                                                      "y_division": Y_DIVISIONS_M,
    #                                                                                                      "shape": "s",
    #                                                                                                      "color": "y",
    #                                                                                                      "N": 8,
    #                                                                                                      "restrict_plus_minus_1": True},
    #  "question": "Does the number of yellow squares equal to 8 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    #
    # {"name": "med2bis_yellow_square_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_M,
    #                                                                                                                    "y_division": Y_DIVISIONS_M,
    #                                                                                                                    "shape": "s",
    #                                                                                                                    "color": "y",
    #                                                                                                                    "N": 8,
    #                                                                                                                    "restrict_plus_minus_1": False},
    #  "question": "Does the number of yellow squares equal to 8 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    #

    #
    # {"name": "med4_blue_triangle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_M,
    #                                                                                                      "y_division": Y_DIVISIONS_M,
    #                                                                                                      "shape": "t",
    #                                                                                                      "color": "b",
    #                                                                                                      "N": 8,
    #                                                                                                      "restrict_plus_minus_1": True},
    #  "question": "Does the number of blue triangles equal to 8 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    # {"name": "med4bis_blue_triangle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
    #     "x_division": X_DIVISIONS_M,
    #     "y_division": Y_DIVISIONS_M,
    #     "shape": "t",
    #     "color": "b",
    #     "N": 8,
    #     "restrict_plus_minus_1": False},
    #  "question": "Does the number of blue triangles equal to 8 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

]


In [72]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly, generic_rule_shape_in_every_row

# Rules explained on the extra-small (XS) grid. This grid uses a lower accuracy target (0.5) and fewer
# interface samples (6) than the other grids.
rules_data_XS = [
    dict(
        name="xeasy1_find_pattern_rot",
        gen_fun=generic_rule_pattern_exactly_1_time_exclude_more,
        gen_kwargs=dict(
            x_division_full=X_DIVISIONS_XS,
            y_division_full=Y_DIVISIONS_XS,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            consider_rotations=True,
        ),
        question="Is the pattern or any of its left or right rotations in the image?",
        target_acc=0.5,
        shown_acc=0.5,
        samples_interface=6,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
        pattern_id=pattern_3sym_2col_keys[7],
    ),

]

In [73]:
from xaipatimg.ml.xai import generate_shap_resnet, generate_counterfactuals_resnet_random_approach, \
    create_xai_index
from tqdm import tqdm


def generate_explanations(rules_data, db_dir, datasets_dir_path):
    for rule_idx in tqdm(range(len(rules_data))):

        model_dir = os.path.join(model_dir_root, rules_data[rule_idx]["name"])
        dataset_filename = rules_data[rule_idx]["name"] + "_test.csv"
        generic_rule_fun = rules_data[rule_idx]["gen_fun"]
        generic_rule_fun_kwargs = rules_data[rule_idx]["gen_kwargs"]
        xai_output_paths = {
            "shap": "shap",
            "cf": "cf",
        }

        if "pattern_id" in rules_data[rule_idx]:
            generic_rule_fun_kwargs["pattern_content"] = db_patterns[rules_data[rule_idx]["pattern_id"]]["cnt"]

        generate_shap_resnet(os.path.join(db_dir, "min/"), datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
                             model_dir=model_dir, xai_output_path=os.path.join(model_dir, xai_output_paths["shap"]),
                             yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path, device="cuda:0",
                             n_jobs=N_JOBS,
                             dataset_size=XAI_DATASET_SIZE, masker="ndarray", shap_scale_img_path=shap_scale_img_path,
                             resnet_type=RESNET_TYPE)

        generate_counterfactuals_resnet_random_approach(os.path.join(db_dir, "min/"), datasets_dir_path=datasets_dir_path,
                                                        dataset_filename=dataset_filename,
                                                        model_dir=model_dir,
                                                        xai_output_path=os.path.join(model_dir, xai_output_paths["cf"]),
                                                        yes_pred_img_path=yes_pred_img_path,
                                                        no_pred_img_path=no_pred_img_path,
                                                        shapes=SHAPES, colors=COLORS, empty_probability=1 - SHAPE_PROB,
                                                        max_depth=10, nb_tries_per_depth=2000,
                                                        generic_rule_fun=generic_rule_fun,
                                                        devices=["cuda:0", "cuda:1"], n_jobs=N_JOBS_GPU,
                                                        dataset_size=XAI_DATASET_SIZE,
                                                        pos_pred_legend_path=pos_pred_legend_path,
                                                        neg_pred_legend_path=neg_pred_legend_path,
                                                        **generic_rule_fun_kwargs, resnet_type=RESNET_TYPE)

        create_xai_index(os.path.join(db_dir, "min/"), datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
                         model_dir=model_dir,
                         xai_dirs=xai_output_paths, dataset_size=XAI_DATASET_SIZE, device="cuda:0",
                         resnet_type=RESNET_TYPE)


In [None]:
# generate_explanations(rules_data_S, db_S_dir, datasets_path_S)

In [74]:
# NOTE(review): the last recorded run of this cell ran out of CUDA memory during counterfactual generation,
# with GPU 0 shared by several processes. Consider lowering N_JOBS_GPU before re-running.
generate_explanations(rules_data=rules_data_L, db_dir=db_L_dir, datasets_dir_path=datasets_path_L)

  0%|          | 0/2 [00:00<?, ?it/s]Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


loading keys for /home/docker/data/PatImgXAI_data/db3.1.0/L/min/db.json




0it [00:00, ?it/s][A[A

272957it [00:00, 2725621.58it/s][A[A

580450it [00:00, 2930955.73it/s][A[A

904078it [00:00, 3044126.65it/s][A[A

1228190it [00:00, 3109912.96it/s][A[A

1552333it [00:00, 3142596.63it/s][A[A

1876459it [00:00, 3170007.44it/s][A[A

2201381it [00:00, 3195742.63it/s][A[A

2526799it [00:00, 3214262.78it/s][A[A

2855094it [00:00, 3235667.35it/s][A[A

3179957it [00:01, 3239650.56it/s][A[A

3504755it [00:01, 3242185.53it/s][A[A

3828980it [00:01, 3234803.70it/s][A[A

4155523it [00:01, 3244047.60it/s][A[A

4480723it [00:01, 3246437.45it/s][A[A

4809306it [00:01, 3258290.13it/s][A[A

5135141it [00:01, 3232815.65it/s][A[A

5462930it [00:01, 3246277.24it/s][A[A

5791236it [00:01, 3257270.40it/s][A[A

6119416it [00:01, 3264607.48it/s][A[A

6447953it [00:02, 3270819.86it/s][A[A

6775055it [00:02, 3247160.96it/s][A[A

7101255it [00:02, 3251578.91it/s][A[A

7429255it [00:02, 3260062.71it/s][A[A

7756324it [00:02, 3263233.55it/s

Generating counterfactual images




  0%|          | 0/200 [00:00<?, ?it/s][A[A

  0%|          | 1/200 [00:04<16:10,  4.88s/it][A[A

  1%|          | 2/200 [00:21<38:53, 11.78s/it][A[A

  2%|▏         | 3/200 [00:40<49:06, 14.96s/it][A[A

  2%|▏         | 4/200 [00:42<32:54, 10.07s/it][A[A

  2%|▎         | 5/200 [00:44<23:28,  7.22s/it][A[A

  3%|▎         | 6/200 [00:56<28:08,  8.70s/it][A[A

  4%|▎         | 7/200 [00:57<19:48,  6.16s/it][A[A

  4%|▍         | 8/200 [01:13<29:24,  9.19s/it][A[A

  4%|▍         | 9/200 [01:25<32:21, 10.16s/it][A[AUsing cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


  5%|▌      

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 14.57 GiB of which 4.00 MiB is free. Process 432106 has 2.42 GiB memory in use. Process 569398 has 3.70 GiB memory in use. Process 614116 has 1.66 GiB memory in use. Process 614120 has 144.00 MiB memory in use. Process 614123 has 1.66 GiB memory in use. Process 614115 has 1.66 GiB memory in use. Process 614126 has 1.66 GiB memory in use. Process 614117 has 1.66 GiB memory in use. Of the allocated memory 33.68 MiB is allocated by PyTorch, and 10.32 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Explanations for the medium (M) grid rules.
generate_explanations(rules_data=rules_data_M, db_dir=db_M_dir, datasets_dir_path=datasets_path_M)

In [49]:
# generate_explanations(rules_data_XS, db_XS_dir, datasets_path_XS)


  0%|          | 0/1 [00:00<?, ?it/s]Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values





 59%|█████▉    | 5942/9998 [00:23<00:11, 340.05it/s][A[A[A

Generating shap images



  0%|          | 0/50 [00:00<?, ?it/s][A
 40%|████      | 20/50 [00:01<00:02, 13.20it/s][A
100%|██████████| 50/50 [00:08<00:00,  5.86it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


loading keys for /home/docker/data/PatImgXAI_data/db3.1.0/XS/min/db.json



0it [00:00, ?it/s][A
447006it [00:00, 2798113.96it/s][A


Generating counterfactual images



  0%|          | 0/50 [00:00<?, ?it/s][A
  2%|▏         | 1/50 [00:01<01:17,  1.58s/it][A

                                                    [A[A
  6%|▌         | 3/50 [00:02<00:27,  1.72it/s][A
  8%|▊         | 4/50 [00:02<00:22,  2.01it/s][AUsing cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
Using ca

In [None]:
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
import csv
from xaipatimg.ml.xai import generate_LLM_explanations, create_xai_index
from tqdm import tqdm

# LLM and tokenizer used to phrase the natural-language explanations; both are handed to
# generate_LLM_explanations by gen_LLM_explanations below. Loading happens once, when this cell runs.
LLM_LOAD_KWARGS = dict(device_map="auto", torch_dtype="auto")
model_id = "openai/gpt-oss-20b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
llm_model = AutoModelForCausalLM.from_pretrained(model_id, **LLM_LOAD_KWARGS)


def gen_LLM_explanations(db_dir, rules_data, datasets_dir_path, X_divisions, Y_divisions):
    db = load_db(db_dir)

    for rule_idx in tqdm(range(len(rules_data))):
        model_dir = os.path.join(model_dir_root, rules_data[rule_idx]["name"])
        dataset_filename = rules_data[rule_idx]["name"] + "_test.csv"

        # Extracting the subset of indices of samples selected for the experimental interface, in order to ease the cost of calculation
        interface_content_path = os.path.join(interface_dir, "res", "tasks",
                                              f"{rules_data[rule_idx]["name"]}_content.csv")
        interface_selected_idx = [int(row["og_idx"]) for row in
                                  list(csv.DictReader(open(interface_content_path), delimiter=','))]

        xai_output_paths = {
            "shap": "shap",
            "cf": "cf",
            "llm": "llm",
        }
        generate_LLM_explanations(db_dir, db, datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
                                  model_dir=model_dir, llm_model=llm_model, llm_tokenizer=tokenizer,
                                  xai_output_path=os.path.join(model_dir, xai_output_paths["llm"]),
                                  question=rules_data[rule_idx]["question"],
                                  yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path,
                                  yes_pred_img_path_small=yes_small_pred_img_path,
                                  no_pred_img_path_small=no_small_pred_img_path,
                                  X_division=X_divisions, Y_division=Y_divisions,
                                  device="cuda:0", dataset_size=XAI_DATASET_SIZE, only_for_index=interface_selected_idx,
                                  path_to_counterfactuals_dir_for_model_errors=os.path.join(model_dir,
                                                                                            xai_output_paths["cf"]),
                                  pos_llm_scaffold=rules_data[rule_idx]["pos_llm_scaffold"],
                                  neg_llm_scaffold=rules_data[rule_idx]["neg_llm_scaffold"],
                                  pattern_dict=db_patterns[rules_data[rule_idx]["pattern_id"]]["cnt"] if "pattern_id" in
                                                                                                             rules_data[rule_idx] else None,
                                  resnet_type=RESNET_TYPE)

        create_xai_index(db_dir, dataset_filename=dataset_filename, datasets_dir_path=datasets_dir_path,
                         model_dir=model_dir,
                         xai_dirs=xai_output_paths, dataset_size=XAI_DATASET_SIZE, device="cuda:0",
                         resnet_type=RESNET_TYPE)


In [None]:
# gen_LLM_explanations(db_L_dir, rules_data_L, datasets_path_L, X_divisions=X_DIVISIONS_L, Y_divisions=Y_DIVISIONS_L)