In [4]:
import os

# Root of all project data. Paths are assembled with os.path.join rather than
# string concatenation, so they are correct whether or not $DATA ends with a
# path separator (concatenation produced e.g. ".../dataPatImgXAI_data/..." when
# the trailing "/" was missing).
_DATA_ROOT = os.environ["DATA"]
_DB_ROOT = os.path.join(_DATA_ROOT, "PatImgXAI_data", "db3.0.0")

# Image databases (S / L / M variants) and the pattern database.
# The final "" component keeps the trailing separator the original paths had,
# in case later cells append file names to these strings directly.
db_S_dir = os.path.join(_DB_ROOT, "S", "")
db_L_dir = os.path.join(_DB_ROOT, "L", "")
db_M_dir = os.path.join(_DB_ROOT, "M", "")
db_patterns_dir = os.path.join(_DB_ROOT, "patterns", "")

model_dir_root = os.path.join(_DATA_ROOT, "models", "db3.0.0", "01_protov5", "")

# Static image assets stored at the database root.
shap_scale_img_path = os.path.join(_DB_ROOT, "shap_scale.png")
yes_pred_img_path = os.path.join(_DB_ROOT, "button_yes.png")
no_pred_img_path = os.path.join(_DB_ROOT, "button_no.png")
yes_small_pred_img_path = os.path.join(_DB_ROOT, "button_yes_small.png")
no_small_pred_img_path = os.path.join(_DB_ROOT, "button_no_small.png")
pos_pred_legend_path = os.path.join(_DB_ROOT, "cf_info_pos.png")
neg_pred_legend_path = os.path.join(_DB_ROOT, "cf_info_neg.png")
interface_dir = os.path.join(_DATA_ROOT, "webinterfaces", "int05_prototype", "")

XAI_DATASET_SIZE = 50

# Parallelism settings (CPU jobs / GPU jobs) — consumed by later cells.
N_JOBS = 20
N_JOBS_GPU = 4

RESNET_TYPE = "resnet50"

In [5]:

# Number of grid cells along x and y for the full images of each variant
# (S, M and L databases).
X_DIVISIONS_S, Y_DIVISIONS_S = 10, 10
X_DIVISIONS_M, Y_DIVISIONS_M = 12, 12
X_DIVISIONS_L, Y_DIVISIONS_L = 15, 15

# Number of grid cells along x and y inside a single pattern.
X_DIVISIONS_PATTERNS, Y_DIVISIONS_PATTERNS = 2, 2

# Probability that a geometrical shape is generated at each grid position.
SHAPE_PROB = 0.5

# Shapes and hex colors available to the generator.
SHAPES = ["circle", "square", "triangle"]
COLORS = ["#A33E9A", "#E0B000", "#0C90C0"]

# Human-readable name of each color in COLORS, in the same order.
# The variable keeps its original (misspelled) name since other cells may use it.
explict_colors_dict = dict(zip(COLORS, ["purple", "yellow", "blue"]))

In [6]:
from xaipatimg.datagen.dbimg import load_db

# Load the pattern database. It is used as a mapping (iterated with .items() when
# filtering patterns), whose values carry a "cnt" list of symbol entries.
db_patterns = load_db(db_patterns_dir)

In [7]:
import numpy as np

def _is_3sym_2col_off_diagonal(pattern):
    """Tell whether a pattern has 3 symbols, 3 distinct shapes and exactly 2 colors,
    with the two same-colored symbols NOT placed on a diagonal of the pattern grid.

    pattern: a value of db_patterns; each item of pattern["cnt"] is read for its
    "col", "shp" and "pos" keys ("pos" indexes the 2x2 grid as [pos[0]][pos[1]]).
    """
    symbols = pattern["cnt"]
    if len(symbols) != 3:
        return False

    distinct_colors = {sym["col"] for sym in symbols}
    distinct_shapes = {sym["shp"] for sym in symbols}

    # The 2x2 size matches X/Y_DIVISIONS_PATTERNS; the diagonal test below is only
    # meaningful for a 2x2 grid. Unfilled cells stay "", so (assuming the 3 symbols
    # occupy distinct cells) only the fully occupied diagonal can produce a match.
    grid = np.full((2, 2), "", dtype="U100")
    for sym in symbols:
        i, j = sym["pos"][0], sym["pos"][1]
        grid[i][j] = sym["col"]
    diagonal_clash = grid[0][0] == grid[1][1] or grid[0][1] == grid[1][0]

    return len(distinct_colors) == 2 and len(distinct_shapes) == 3 and not diagonal_clash


# Keys of qualifying patterns, in db_patterns iteration order. Later cells pick
# patterns by position in this list, so that order determines which are used.
pattern_3sym_2col_keys = [
    key for key, pattern in db_patterns.items() if _is_3sym_2col_off_diagonal(pattern)
]

In [8]:
# Dataset directory of this experiment ("01_protov5") inside each image database.
datasets_path_L, datasets_path_S, datasets_path_M = (
    os.path.join(db_dir, "datasets", "01_protov5")
    for db_dir in (db_L_dir, db_S_dir, db_M_dir)
)

In [9]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly, generic_rule_shape_in_every_row

# Rule definitions for the L image database. Every rule is a dict with the keys:
#   name, gen_fun (generator from xaipatimg.datagen.gendataset), gen_kwargs (its
#   keyword arguments), question, target_acc, shown_acc, samples_interface,
#   pos_llm_scaffold, neg_llm_scaffold and, for pattern rules, pattern_id
#   (an entry of pattern_3sym_2col_keys).
# The scaffolds appear to be "<text> | <;-separated highlighted cells>".
# Commented-out entries are disabled rules kept for reference.
rules_data_L = [
    #
    # {"name": "hard1_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
    #                                                                                                  "y_division_full": Y_DIVISIONS_L,
    #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                  "consider_rotations": True},
    #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[0]},
    #
    # NOTE(review): the neg_llm_scaffold of hard2 below lists B3 and D3 twice — fix before re-enabling.
    # {"name": "hard2_blue_circle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "circle",
    #                                                                                             "color": "#0C90C0",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": True},
    #  "question": "Does the number of blue circles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "The image contains 13 blue circles. They are highlighted below. | A1;B1;C1;B2;C2;D2;E4;F4;G5;A6;B6;E6;D7", "neg_llm_scaffold": "The image contains 12 blue circles instead of 13. They are highlighted below. | A2;B3;D3;B3;C3;D3;E5;F6;A7;B7;E7;H9"},
    #
    # {"name": "hard2bis_blue_circle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "circle",
    #                                                                                             "color": "#0C90C0",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": False},
    #  "question": "Does the number of blue circles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard3_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
    #                                                                                                  "y_division_full": Y_DIVISIONS_L,
    #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                  "consider_rotations": True},
    #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[1]},

    # {"name": "hard4_purple_triangle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "triangle",
    #                                                                                             "color": "#A33E9A",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": True},
    #  "question": "Does the number of purple triangles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard4bis_purple_triangle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
    #                                                                                             "x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "triangle",
    #                                                                                             "color": "#A33E9A",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": False},
    #  "question": "Does the number of purple triangles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    # {"name": "hard4_yellow_triangle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "triangle",
    #                                                                                             "color": "#E0B000",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": True},
    #  "question": "Does the number of yellow triangles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard4bis_yellow_triangle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
    #                                                                                             "x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "triangle",
    #                                                                                             "color": "#E0B000",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": False},
    #  "question": "Does the number of yellow triangles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard4_yellow_square_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "square",
    #                                                                                             "color": "#E0B000",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": True},
    #  "question": "Does the number of yellow squares equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard4bis_yellow_square_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
    #                                                                                             "x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "square",
    #                                                                                             "color": "#E0B000",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": False},
    #  "question": "Does the number of yellow squares equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # NOTE(review): hard4_blue_circle_N has the same generator settings as hard2_blue_circle_N above.
    # {"name": "hard4_blue_circle_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "circle",
    #                                                                                             "color": "#0C90C0",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": True},
    #  "question": "Does the number of blue circles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard4bis_blue_circle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
    #                                                                                             "x_division": X_DIVISIONS_L,
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "circle",
    #                                                                                             "color": "#0C90C0",
    #                                                                                             "N": 13,
    #                                                                                             "restrict_plus_minus_1": False},
    #  "question": "Does the number of blue circles equal to 13 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    # Active rule: find the (non-rotated) pattern exactly once in the image.
    dict(
        name="hard5_find_pattern",
        gen_fun=generic_rule_pattern_exactly_1_time_exclude_more,
        gen_kwargs=dict(
            x_division_full=X_DIVISIONS_L,
            y_division_full=Y_DIVISIONS_L,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            consider_rotations=False,
        ),
        question="Is the pattern in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="The pattern is found in the image and is highlighted below. | B5;C5;B6",
        neg_llm_scaffold="The pattern was not found in the image. |",
        pattern_id=pattern_3sym_2col_keys[2],
    ),
    #
    # {"name": "hard6_find_pattern", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
    #                                                                                                  "y_division_full": Y_DIVISIONS_L,
    #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                  "consider_rotations": False},
    #  "question": "Is the pattern in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[3]},

    # NOTE(review): hard7 and hard8 below are identical apart from their names.
    # {"name": "hard7_square_allrows", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "square",
    #                                                                                             "exclude_two_rows_missing": True},
    #  "question": "Is there a square in every row (1, 2, ...) of the image ?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "hard8_square_allrows", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {
    #                                                                                             "y_division": Y_DIVISIONS_L,
    #                                                                                             "shape": "square",
    #                                                                                             "exclude_two_rows_missing": True},
    #  "question": "Is there a square in every row (1, 2, ...) of the image ?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""}

]

In [24]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

# Rule definitions for the S image database (same dict layout as rules_data_L:
# name, gen_fun, gen_kwargs, question, target_acc, shown_acc, samples_interface,
# pos/neg_llm_scaffold and, for pattern rules, pattern_id).
# Commented-out entries are disabled rules kept for reference.
rules_data_S = [
    #
    # {"name": "easy1_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
    #                                                                                                  "y_division_full": Y_DIVISIONS_S,
    #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                  "consider_rotations": True},
    #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[4]},
    #
    # Count rule: exactly 6 yellow triangles.
    dict(
        name="easy2_yellow_triangle_N",
        gen_fun=generic_rule_N_times_color_shape_exactly,
        gen_kwargs=dict(
            x_division=X_DIVISIONS_S,
            y_division=Y_DIVISIONS_S,
            shape="triangle",
            color="#E0B000",
            N=6,
            restrict_plus_minus_1=True,
        ),
        question="Does the number of yellow triangles equal to 6 in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
    ),
    #
    # {"name": "easy2bis_yellow_triangle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
    #                                                                                             "x_division": X_DIVISIONS_S,
    #                                                                                             "y_division": Y_DIVISIONS_S,
    #                                                                                             "shape": "triangle",
    #                                                                                             "color": "#E0B000",
    #                                                                                             "N": 6,
    #                                                                                             "restrict_plus_minus_1": False},
    #  "question": "Does the number of yellow triangles equal to 6 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "easy3_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
    #                                                                                                  "y_division_full": Y_DIVISIONS_S,
    #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                  "consider_rotations": True},
    #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[5]},
    #
    # Count rule: exactly 6 purple circles.
    dict(
        name="easy4_purple_circle_N",
        gen_fun=generic_rule_N_times_color_shape_exactly,
        gen_kwargs=dict(
            x_division=X_DIVISIONS_S,
            y_division=Y_DIVISIONS_S,
            shape="circle",
            color="#A33E9A",
            N=6,
            restrict_plus_minus_1=True,
        ),
        question="Does the number of purple circles equal to 6 in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
    ),
    #
    # {"name": "easy4bis_purple_circle_N_norestrict", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {
    #                                                                                               "x_division": X_DIVISIONS_S,
    #                                                                                               "y_division": Y_DIVISIONS_S,
    #                                                                                               "shape": "circle",
    #                                                                                               "color": "#A33E9A",
    #                                                                                               "N": 6,
    #                                                                                               "restrict_plus_minus_1": False},
    #  "question": "Does the number of purple circles equal to 6 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #

    # Pattern rules (no rotations) using the 7th and 8th qualifying patterns.
    dict(
        name="easy5_find_pattern",
        gen_fun=generic_rule_pattern_exactly_1_time_exclude_more,
        gen_kwargs=dict(
            x_division_full=X_DIVISIONS_S,
            y_division_full=Y_DIVISIONS_S,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            consider_rotations=False,
        ),
        question="Is the pattern in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
        pattern_id=pattern_3sym_2col_keys[6],
    ),
    #
    dict(
        name="easy6_find_pattern",
        gen_fun=generic_rule_pattern_exactly_1_time_exclude_more,
        gen_kwargs=dict(
            x_division_full=X_DIVISIONS_S,
            y_division_full=Y_DIVISIONS_S,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            consider_rotations=False,
        ),
        question="Is the pattern in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
        pattern_id=pattern_3sym_2col_keys[7],
    ),

    # NOTE(review): easy7 sets consider_rotations=True but uses the no-rotation question text.
    # {"name": "easy7_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
    #                                                                                                  "y_division_full": Y_DIVISIONS_S,
    #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                  "consider_rotations": True},
    #  "question": "Is the pattern in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[10]},

    # {"name": "easy8_find_pattern", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
    #                                                                                                  "y_division_full": Y_DIVISIONS_S,
    #                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                  "consider_rotations": False},
    #  "question": "Is the pattern in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[11]},
    #
    # {"name": "easy9_blue_square_N", "gen_fun": generic_rule_N_times_color_shape_exactly, "gen_kwargs": {"x_division": X_DIVISIONS_S,
    #                                                                                               "y_division": Y_DIVISIONS_S,
    #                                                                                               "shape": "square",
    #                                                                                               "color": "#0C90C0",
    #                                                                                               "N": 6,
    #                                                                                               "restrict_plus_minus_1": True},
    #  "question": "Does the number of blue squares equal to 6 in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

    # NOTE(review): the *_allrows rules below use generic_rule_shape_in_every_row, which this
    # cell does not import (only the rules_data_L cell does) — add it to the import before re-enabling.
    # {"name": "easy10_circle_allrows", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {
    #                                                                                             "y_division": Y_DIVISIONS_S,
    #                                                                                             "shape": "circle",
    #                                                                                             "exclude_two_rows_missing": True},
    #  "question": "Is there a circle in every row (1, 2, ...) of the image ?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "easy11_circle_allrows", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {
    #                                                                                             "y_division": Y_DIVISIONS_S,
    #                                                                                             "shape": "circle",
    #                                                                                             "exclude_two_rows_missing": True},
    #  "question": "Is there a circle in every row (1, 2, ...) of the image ?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},
    #
    # {"name": "easy12_circle_allrows", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {
    #                                                                                             "y_division": Y_DIVISIONS_S,
    #                                                                                             "shape": "circle",
    #                                                                                             "exclude_two_rows_missing": True},
    #  "question": "Is there a circle in every row (1, 2, ...) of the image ?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": ""},

]

In [25]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

# Rule definitions evaluated on the medium (M) image database. Each entry holds:
#   name               -- task id; used as the model sub-directory under model_dir_root and as
#                         the prefix of the "<name>_test.csv" dataset file.
#   gen_fun/gen_kwargs -- generic rule function and its keyword arguments.
#   question           -- question about the image, also passed to the LLM explainer.
#   target_acc, shown_acc, samples_interface -- interface parameters.
#   pos_llm_scaffold/neg_llm_scaffold -- scaffolds forwarded to the LLM explainer.
#   pattern_id (optional) -- key into db_patterns for pattern-search rules.
#
# Disabled variants (formerly kept here as commented-out entries). All use the
# X_DIVISIONS_M x Y_DIVISIONS_M grid, target_acc=shown_acc=0.85, samples_interface=13 and
# empty LLM scaffolds.
#   Pattern search with rotations (generic_rule_pattern_exactly_1_time_exclude_more,
#   x/y_division_pattern=X/Y_DIVISIONS_PATTERNS, consider_rotations=True,
#   question "Is the pattern or any of its left or right rotations in the image?"):
#     med1_find_pattern_rot -> pattern_id=pattern_3sym_2col_keys[8]
#     med3_find_pattern_rot -> pattern_id=pattern_3sym_2col_keys[9]
#   Exact counts (generic_rule_N_times_color_shape_exactly, N=8,
#   question "Does the number of <color> <shape>s equal to 8 in the image?"):
#     restrict_plus_minus_1=True:  med2_purple_triangle_N, med2_purple_square_N,
#                                  med2_yellow_triangle_N, med2_purple_circle_N
#     restrict_plus_minus_1=False: med2bis_yellow_square_N_norestrict,
#                                  med2bis_purple_square_N_norestrict,
#                                  med2bis_yellow_triangle_N_norestrict,
#                                  med2bis_purple_circle_N_norestrict,
#                                  med2bis_purple_triangle_N_norestrict,
#                                  med4bis_blue_triangle_N_norestrict
#   Shape in every row (generic_rule_shape_in_every_row, y_division=Y_DIVISIONS_M,
#   shape="triangle", exclude_two_rows_missing=True,
#   question "Is there a triangle in every row (1, 2, ...) of the image ?"):
#     med7_triangle_allrows, med8_triangle_allrows
#   Colors: purple "#A33E9A", yellow "#E0B000", blue "#0C90C0".
rules_data_M = [
    # Exactly 8 yellow squares.
    dict(
        name="med2_yellow_square_N",
        gen_fun=generic_rule_N_times_color_shape_exactly,
        gen_kwargs=dict(
            x_division=X_DIVISIONS_M,
            y_division=Y_DIVISIONS_M,
            shape="square",
            color="#E0B000",
            N=8,
            restrict_plus_minus_1=True,
        ),
        question="Does the number of yellow squares equal to 8 in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
    ),
    # Exactly 8 blue triangles.
    dict(
        name="med4_blue_triangle_N",
        gen_fun=generic_rule_N_times_color_shape_exactly,
        gen_kwargs=dict(
            x_division=X_DIVISIONS_M,
            y_division=Y_DIVISIONS_M,
            shape="triangle",
            color="#0C90C0",
            N=8,
            restrict_plus_minus_1=True,
        ),
        question="Does the number of blue triangles equal to 8 in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
    ),
    # Pattern search (no rotations) with pattern_3sym_2col_keys[10].
    dict(
        name="med5_find_pattern",
        gen_fun=generic_rule_pattern_exactly_1_time_exclude_more,
        gen_kwargs=dict(
            x_division_full=X_DIVISIONS_M,
            y_division_full=Y_DIVISIONS_M,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            consider_rotations=False,
        ),
        question="Is the pattern in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
        pattern_id=pattern_3sym_2col_keys[10],
    ),
    # Pattern search (no rotations) with pattern_3sym_2col_keys[11].
    dict(
        name="med6_find_pattern",
        gen_fun=generic_rule_pattern_exactly_1_time_exclude_more,
        gen_kwargs=dict(
            x_division_full=X_DIVISIONS_M,
            y_division_full=Y_DIVISIONS_M,
            x_division_pattern=X_DIVISIONS_PATTERNS,
            y_division_pattern=Y_DIVISIONS_PATTERNS,
            consider_rotations=False,
        ),
        question="Is the pattern in the image?",
        target_acc=0.85,
        shown_acc=0.85,
        samples_interface=13,
        pos_llm_scaffold="",
        neg_llm_scaffold="",
        pattern_id=pattern_3sym_2col_keys[11],
    ),
]


In [26]:
from xaipatimg.ml.xai import generate_shap_resnet, generate_counterfactuals_resnet_random_approach, \
    create_xai_index
from tqdm import tqdm


def generate_explanations(rules_data, db_dir, datasets_dir_path):
    """Compute SHAP and counterfactual explanations, then build the XAI index, for each rule.

    For every rule, the model in ``os.path.join(model_dir_root, rule["name"])`` is explained on
    the dataset file ``rule["name"] + "_test.csv"``. SHAP output goes to the model directory's
    ``shap`` sub-directory, counterfactuals to ``cf``, and ``create_xai_index`` is then called
    with both sub-directories.

    Args:
        rules_data: list of rule dicts providing "name", "gen_fun" and "gen_kwargs", and
            optionally "pattern_id" (a key of ``db_patterns``). The dicts are not modified.
        db_dir: directory of the image database the datasets come from.
        datasets_dir_path: directory containing the ``<name>_test.csv`` dataset files.

    Returns:
        None; results are produced by the called helpers via their output-path arguments.
    """
    for rule in tqdm(rules_data):

        model_dir = os.path.join(model_dir_root, rule["name"])
        dataset_filename = rule["name"] + "_test.csv"
        generic_rule_fun = rule["gen_fun"]
        # Shallow copy: injecting "pattern_content" below must not mutate the shared rule
        # definition, otherwise the extra key would persist in rules_data across re-runs.
        generic_rule_fun_kwargs = dict(rule["gen_kwargs"])
        xai_output_paths = {
            "shap": "shap",
            "cf": "cf",
        }

        # Rules tied to a pattern forward its content to the counterfactual generator
        # through the rule kwargs.
        if "pattern_id" in rule:
            generic_rule_fun_kwargs["pattern_content"] = db_patterns[rule["pattern_id"]]["content"]

        generate_shap_resnet(db_dir, datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
                             model_dir=model_dir, xai_output_path=os.path.join(model_dir, xai_output_paths["shap"]),
                             yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path, device="cuda:0",
                             n_jobs=N_JOBS,
                             dataset_size=XAI_DATASET_SIZE, masker="ndarray", shap_scale_img_path=shap_scale_img_path,
                             resnet_type=RESNET_TYPE)

        generate_counterfactuals_resnet_random_approach(db_dir, datasets_dir_path=datasets_dir_path,
                                                        dataset_filename=dataset_filename,
                                                        model_dir=model_dir,
                                                        xai_output_path=os.path.join(model_dir, xai_output_paths["cf"]),
                                                        yes_pred_img_path=yes_pred_img_path,
                                                        no_pred_img_path=no_pred_img_path,
                                                        shapes=SHAPES, colors=COLORS, empty_probability=1 - SHAPE_PROB,
                                                        max_depth=10, nb_tries_per_depth=2000,
                                                        generic_rule_fun=generic_rule_fun,
                                                        devices=["cuda:0", "cuda:1"], n_jobs=N_JOBS_GPU,
                                                        dataset_size=XAI_DATASET_SIZE,
                                                        pos_pred_legend_path=pos_pred_legend_path,
                                                        neg_pred_legend_path=neg_pred_legend_path,
                                                        **generic_rule_fun_kwargs, resnet_type=RESNET_TYPE)

        create_xai_index(db_dir, datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
                         model_dir=model_dir,
                         xai_dirs=xai_output_paths, dataset_size=XAI_DATASET_SIZE, device="cuda:0",
                         resnet_type=RESNET_TYPE)


In [None]:
# SHAP + counterfactual explanations for the small (S) database rules.
generate_explanations(rules_data=rules_data_S, db_dir=db_S_dir, datasets_dir_path=datasets_path_S)

  0%|          | 0/4 [00:00<?, ?it/s]

Loading dataset content for easy2_yellow_triangle_N_test.csv



  0%|          | 0/50 [00:00<?, ?it/s][A
  8%|▊         | 4/50 [00:00<00:01, 37.15it/s][A
 20%|██        | 10/50 [00:00<00:00, 50.05it/s][A
 34%|███▍      | 17/50 [00:00<00:00, 57.29it/s][A
 46%|████▌     | 23/50 [00:00<00:00, 57.95it/s][A
 58%|█████▊    | 29/50 [00:00<00:00, 55.27it/s][A
 70%|███████   | 35/50 [00:00<00:00, 49.08it/s][A

In [None]:
# generate_explanations(rules_data_L, db_L_dir, datasets_path_L)

In [None]:
# generate_explanations(rules_data_M, db_M_dir, datasets_path_M)

In [10]:
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
import csv
from xaipatimg.ml.xai import generate_LLM_explanations, create_xai_index
from tqdm import tqdm

# Open-weight LLM and tokenizer used by gen_LLM_explanations to produce textual explanations.
model_id = "openai/gpt-oss-20b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
llm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # place the weights automatically on the available devices
    # `dtype` replaces the deprecated `torch_dtype` argument; the installed transformers
    # warned "`torch_dtype` is deprecated! Use `dtype` instead!" when this cell ran.
    dtype="auto",
)


def gen_LLM_explanations(db_dir, rules_data, datasets_dir_path, X_divisions, Y_divisions):
    """Generate LLM explanations for each rule, then rebuild its XAI index.

    To limit the computation cost, explanations are only requested for the samples selected
    for the experimental interface, i.e. the ``og_idx`` column of
    ``<interface_dir>/res/tasks/<name>_content.csv``.

    Args:
        db_dir: directory of the image database; loaded once with ``load_db``.
        rules_data: list of rule dicts providing "name", "question", "pos_llm_scaffold",
            "neg_llm_scaffold" and optionally "pattern_id" (a key of ``db_patterns``).
        datasets_dir_path: directory containing the ``<name>_test.csv`` dataset files.
        X_divisions: number of grid columns, forwarded as ``X_division``.
        Y_divisions: number of grid rows, forwarded as ``Y_division``.

    Returns:
        None; results are produced by the called helpers under each model directory.
    """
    db = load_db(db_dir)

    for rule in tqdm(rules_data):
        rule_name = rule["name"]
        model_dir = os.path.join(model_dir_root, rule_name)
        dataset_filename = rule_name + "_test.csv"

        # Extracting the subset of indices of samples selected for the experimental interface,
        # in order to ease the cost of calculation. `with` closes the file (it was previously
        # left open); newline="" is what the csv module expects for its readers.
        interface_content_path = os.path.join(interface_dir, "res", "tasks", f"{rule_name}_content.csv")
        with open(interface_content_path, newline="") as interface_file:
            interface_selected_idx = [int(row["og_idx"])
                                      for row in csv.DictReader(interface_file, delimiter=',')]

        xai_output_paths = {
            "shap": "shap",
            "cf": "cf",
            "llm": "llm",
        }

        # NOTE(review): generate_explanations reads the "content" key of the same db_patterns
        # entries, while this reads "cnt"; a KeyError: 'cnt' was recorded when running this
        # cell — confirm which key the pattern entries expose.
        pattern_dict = db_patterns[rule["pattern_id"]]["cnt"] if "pattern_id" in rule else None

        generate_LLM_explanations(db_dir, db, datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
                                  model_dir=model_dir, llm_model=llm_model, llm_tokenizer=tokenizer,
                                  xai_output_path=os.path.join(model_dir, xai_output_paths["llm"]),
                                  question=rule["question"],
                                  yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path,
                                  yes_pred_img_path_small=yes_small_pred_img_path,
                                  no_pred_img_path_small=no_small_pred_img_path,
                                  X_division=X_divisions, Y_division=Y_divisions,
                                  device="cuda:0", dataset_size=XAI_DATASET_SIZE, only_for_index=interface_selected_idx,
                                  # Same "cf" sub-directory that generate_explanations writes counterfactuals to.
                                  path_to_counterfactuals_dir_for_model_errors=os.path.join(model_dir,
                                                                                            xai_output_paths["cf"]),
                                  pos_llm_scaffold=rule["pos_llm_scaffold"],
                                  neg_llm_scaffold=rule["neg_llm_scaffold"],
                                  pattern_dict=pattern_dict,
                                  resnet_type=RESNET_TYPE)

        create_xai_index(db_dir, dataset_filename=dataset_filename, datasets_dir_path=datasets_dir_path,
                         model_dir=model_dir,
                         xai_dirs=xai_output_paths, dataset_size=XAI_DATASET_SIZE, device="cuda:0",
                         resnet_type=RESNET_TYPE)


  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
Fetching 41 files: 100%|██████████| 41/41 [00:00<00:00, 175297.11it/s]
Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]
Fetching 41 files: 100%|██████████| 41/41 [00:00<00:00, 195638.75it/s]
Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.05s/it]


In [11]:
# LLM explanations for the large (L) database rules, on the L grid size.
gen_LLM_explanations(
    db_dir=db_L_dir,
    rules_data=rules_data_L,
    datasets_dir_path=datasets_path_L,
    X_divisions=X_DIVISIONS_L,
    Y_divisions=Y_DIVISIONS_L,
)

  0%|          | 0/1 [00:00<?, ?it/s]

/home/docker/data/models/db3.0.0/01_protov5/hard5_find_pattern/llm
Loading dataset content for hard5_find_pattern_test.csv



  0%|          | 0/50 [00:00<?, ?it/s][A
 14%|█▍        | 7/50 [00:00<00:00, 57.18it/s][A
 28%|██▊       | 14/50 [00:00<00:00, 60.09it/s][A
 42%|████▏     | 21/50 [00:00<00:00, 61.22it/s][A
 56%|█████▌    | 28/50 [00:00<00:00, 59.20it/s][A
 68%|██████▊   | 34/50 [00:00<00:00, 57.93it/s][A
 80%|████████  | 40/50 [00:00<00:00, 56.70it/s][A
100%|██████████| 50/50 [00:00<00:00, 57.05it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0

  0%|          | 0/12 [00:00<?, ?it/s][A

You are the explainability system of an AI model. Your role is to justify the decisions of the model. The role of the model is to answer questions about the content of images of symbols of colors. The images are described in a JSON data structure. The coordinates system uses letters from A to F for the columns and numbers from 1 to 6 for the rows. The user will provide you the prediction of the AI model for a given image and the corresponding JSON data. You need to give an explanation of the prediction. The explanation is expected to be a very short sentence which introduces a list of all coordinates that are involved in the model's prediction. The justification sentence and the list of coordinates must be separated by the character '|'. The coordinates are separated with the symbol ';', and there is no need to sort them. Do not use escape characters or markdown syntax. The question the model must answer is 'Is the pattern in the image?'. The pattern to search for is [{'shape': 'circle


  8%|▊         | 1/12 [01:48<19:50, 108.23s/it][A

<|channel|>analysis<|message|>We need to explain that no pattern found. Justification: "The pattern was not found in the image. |" According to instructions: no coordinates. Ensure no markdown. Output exactly that.<|end|><|start|>assistant<|channel|>final<|message|>The pattern was not found in the image. |<|return|>
You are the explainability system of an AI model. Your role is to justify the decisions of the model. The role of the model is to answer questions about the content of images of symbols of colors. The images are described in a JSON data structure. The coordinates system uses letters from A to F for the columns and numbers from 1 to 6 for the rows. The user will provide you the prediction of the AI model for a given image and the corresponding JSON data. You need to give an explanation of the prediction. The explanation is expected to be a very short sentence which introduces a list of all coordinates that are involved in the model's prediction. The justification sentence an


 17%|█▋        | 2/12 [03:15<15:56, 95.61s/it] [A

<|channel|>analysis<|message|>We need to check if pattern exists: pattern: circle purple at X N+1; triangle blue at X+1 N; square purple at X+1 N+1. X any column letter coordinate, N any row number coordinate. So for each column letter X and row number N, the pattern is:

- circle purple at (X, N+1)
- triangle blue at (X+1, N)
- square purple at (X+1, N+1).

Essentially a 2x2 block offset: top-left: ?, unknown; top-right: circle purple at (X,N+1). bottom-left: triangle blue at (X+1, N). bottom-right: square purple at (X+1, N+1). So pattern forms a rectangle 2 columns wide and 2 rows high, starting at X+1 column and N+1 row are the bottom-right is square purple, top-right circle purple, bottom-left triangle blue. The top-left element at (X,N) unspecified.

We must search in the large list. The predictions says No, presumably correct. We must provide justification: "The pattern was not found in the image. |". So simply. According to examples. Provide justification sentence and coordinate


 25%|██▌       | 3/12 [04:31<13:03, 87.10s/it][A

<|channel|>analysis<|message|>We need to produce justification for model predicting "No" meaning pattern not found.

Pattern: three shapes with relative positions: circle at X N+1 (so circle two rows below? Actually N+1 means increase row number by 1, same column X. So coordinate of circle is at same column X, row N+1 where N is some number coordinate. Then triangle at X+1N (same row N), column one letter to the right. Then square at X+1N+1 (column right, row N+1). So arrangement like:

Column X: circle (row N+1)
Column X+1: triangle (row N), square (row N+1). So basically a 2x2 block with rows N (top row has triangle at column X+1), row N+1 has circle at column X and square at column X+1.

We need to check if any such arrangement exists in the given dataset.

We need to list all coordinates involved in model's prediction? Wait justification sentence must introduce list of all coordinates that are involved in model's prediction. For negative case, example: "The pattern was not found in


 33%|███▎      | 4/12 [07:36<15:12, 114.00s/it][A
  0%|          | 0/1 [07:38<?, ?it/s]

<|channel|>analysis<|message|>We need create justification: says pattern is found and list of coordinates involved. Need check pattern: three shapes: circle purple at position X N+1, triangle blue at X+1 N, square purple at X+1 N+1. That means a 2x2 block: coordinates: X column letter, N row number. Purple circle at (X, N+1). Then triangle blue at (X+1, N). Square purple at (X+1, N+1). So arrangement like:

circle purple at left column same column as triangle? Actually circle at X, N+1. Triangle at X+1, N. Square at X+1, N+1. So offset: circle above square? N+1 is below N. So circle is just below (above?) Actually N is smaller row number? Coordinates rows numbers increase downwards maybe. In any case pattern exists somewhere.

We need to find any coordinates in data that satisfy this.

Let's parse possible candidate sets. Since dataset large, but pattern likely exists. Let's look at early entries:

We need find a pair of adjacent columns, X and X+1, and adjacent rows N and N+1.

We nee




KeyError: 'cnt'