In [1]:
import os

db_S_dir = os.environ["DATA"] + "PatImgXAI_data/db3.2.0/S/"
db_L_dir = os.environ["DATA"] + "PatImgXAI_data/db3.2.0/L/"
db_M_dir = os.environ["DATA"] + "PatImgXAI_data/db3.2.0/M/"
db_XS_dir = os.environ["DATA"] + "PatImgXAI_data/db3.2.0/XS/"
db_patterns_dir = os.environ["DATA"] + "PatImgXAI_data/db3.2.0/patterns/"

model_dir_root = os.environ["DATA"] + "models/db3.2.0/01_expv1/"
shap_scale_img_path = os.path.join(os.environ["DATA"] + "PatImgXAI_data/db3.2.0", "shap_scale.png")
yes_pred_img_path = os.path.join(os.environ["DATA"] + "PatImgXAI_data/db3.2.0", "button_yes.png")
no_pred_img_path = os.path.join(os.environ["DATA"] + "PatImgXAI_data/db3.2.0", "button_no.png")
yes_small_pred_img_path = os.path.join(os.environ["DATA"] + "PatImgXAI_data/db3.2.0", "button_yes_small.png")
no_small_pred_img_path = os.path.join(os.environ["DATA"] + "PatImgXAI_data/db3.2.0", "button_no_small.png")
pos_pred_legend_path = os.path.join(os.environ["DATA"] + "PatImgXAI_data/db3.2.0", "cf_info_pos.png")
neg_pred_legend_path = os.path.join(os.environ["DATA"] + "PatImgXAI_data/db3.2.0", "cf_info_neg.png")
interface_dir = os.environ["DATA"] + "webinterfaces/exp02/"

XAI_DATASET_SIZE = 200
# XAI_DATASET_SIZE = 20

N_JOBS = 20
N_JOBS_GPU = 4

RESNET_TYPE = "resnet18"

In [2]:
# Number of images generated
NBGEN_full_per_size = 5000000
NBGEN_patterns = 1000

# Grid division for full image
X_DIVISIONS_L = 15
Y_DIVISIONS_L = 15
X_DIVISIONS_S = 9
Y_DIVISIONS_S = 9
X_DIVISIONS_M = 12
Y_DIVISIONS_M = 12
X_DIVISIONS_XS = 5
Y_DIVISIONS_XS = 5

# Grid division of patterns
X_DIVISIONS_PATTERNS = 2
Y_DIVISIONS_PATTERNS = 2

# Size of the images in pixels
img_size = (700, 700)
img_size_patterns = (300, 300)

# Probability to generate a geometrical shape at each position in the grid
SHAPE_PROB = 0.5

# Define available shapes
SHAPES = ['c', 's', 't']
COLORS  = ["p", "y", "b"]

In [3]:
from xaipatimg.datagen.dbimg import load_db

db_patterns = load_db(db_patterns_dir)

In [4]:
import numpy as np

pattern_3sym_2col_keys = []

# Extracting list of patterns that contain 3 symbols of 3 different shapes and 2 different colors. The two items of the same color cannot be
# on a diagonal.
for k, v in db_patterns.items():
    if len(v["cnt"]) == 3:
        img_col_d = {}
        img_shape_d = {}
        color_matrix = np.full((2, 2), "", dtype="U100")
        for entry in v["cnt"]:
            img_col_d[entry["col"]] = True
            img_shape_d[entry["shp"]] = True
            color_matrix[entry["pos"][0]][entry["pos"][1]] = entry["col"]

        same_color_on_diagonal = color_matrix[0][0] == color_matrix[1][1] or color_matrix[0][1] == color_matrix[1][0]

        if len(img_col_d.keys()) == 2 and len(img_shape_d.keys()) == 3 and not same_color_on_diagonal:
            pattern_3sym_2col_keys.append(k)

In [5]:
datasets_path_L = os.path.join(db_L_dir, "datasets", "01_expv1")
datasets_path_S = os.path.join(db_S_dir, "datasets", "01_expv1")
datasets_path_M = os.path.join(db_M_dir, "datasets", "01_expv1")
datasets_path_XS = os.path.join(db_XS_dir, "datasets", "01_expv1")


In [6]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly, generic_rule_shape_in_every_row

rules_data_L = [

    {"name": "hard1_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
                                                                                                                   "y_division_full": Y_DIVISIONS_L,
                                                                                                                   "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                                   "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                                   "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[0]},


    {"name": "hard3_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_L,
                                                                                                                   "y_division_full": Y_DIVISIONS_L,
                                                                                                                   "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                                   "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                                   "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[1]},

]

In [7]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly

rules_data_S = [

    # {"name": "easy1_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
    #                                                                                                                "y_division_full": Y_DIVISIONS_S,
    #                                                                                                                "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                                "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                                "consider_rotations": True},
    #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 6, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[4]},

    {"name": "easy3_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_S,
                                                                                                                   "y_division_full": Y_DIVISIONS_S,
                                                                                                                   "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                                   "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                                   "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 6, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[5]},

]

In [8]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly
rules_data_M = [

    {"name": "med1_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_M,
                                                                                                                  "y_division_full": Y_DIVISIONS_M,
                                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                                  "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[2]},

    {"name": "med3_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_M,
                                                                                                                  "y_division_full": Y_DIVISIONS_M,
                                                                                                                  "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                                  "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                                  "consider_rotations": True,
                                                                                                                  },
     "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[3]},

    # {"name": "med5_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_M,
    #                                                                                                               "y_division_full": Y_DIVISIONS_M,
    #                                                                                                               "x_division_pattern": X_DIVISIONS_PATTERNS,
    #                                                                                                               "y_division_pattern": Y_DIVISIONS_PATTERNS,
    #                                                                                                               "consider_rotations": True,
    #                                                                                                               },
    #  "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.85, "shown_acc" : 0.85, "samples_interface": 13, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[6]},


]


In [9]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_N_times_color_shape_exactly, generic_rule_shape_in_every_row

rules_data_XS = [
    {"name": "xeasy1_find_pattern_rot", "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more, "gen_kwargs": {"x_division_full": X_DIVISIONS_XS,
                                                                                                                    "y_division_full": Y_DIVISIONS_XS,
                                                                                                                    "x_division_pattern": X_DIVISIONS_PATTERNS,
                                                                                                                    "y_division_pattern": Y_DIVISIONS_PATTERNS,
                                                                                                                    "consider_rotations": True},
     "question": "Is the pattern or any of its left or right rotations in the image?", "target_acc" : 0.5, "shown_acc" : 0.5, "samples_interface": 6, "pos_llm_scaffold": "", "neg_llm_scaffold": "", "pattern_id": pattern_3sym_2col_keys[7]},

]

In [10]:
from xaipatimg.ml.xai import generate_shap_resnet, generate_counterfactuals_resnet_random_approach, \
    create_xai_index, generate_cam_resnet18
from tqdm import tqdm


def generate_explanations(rules_data, db_dir, datasets_dir_path):
    for rule_idx in tqdm(range(len(rules_data))):

        model_dir = os.path.join(model_dir_root, rules_data[rule_idx]["name"])
        dataset_filename = rules_data[rule_idx]["name"] + "_test.csv"
        generic_rule_fun = rules_data[rule_idx]["gen_fun"]
        generic_rule_fun_kwargs = rules_data[rule_idx]["gen_kwargs"]
        xai_output_paths = {
            "shap": "shap",
            "cf": "cf",
            "gradcam": "gradcam"
        }

        if "pattern_id" in rules_data[rule_idx]:
            generic_rule_fun_kwargs["pattern_content"] = db_patterns[rules_data[rule_idx]["pattern_id"]]["cnt"]

        # generate_shap_resnet(os.path.join(db_dir, "min/"), datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
        #                      model_dir=model_dir, xai_output_path=os.path.join(model_dir, xai_output_paths["shap"]),
        #                      yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path, device="cuda:0",
        #                      n_jobs=N_JOBS,
        #                      dataset_size=XAI_DATASET_SIZE, masker="ndarray", shap_scale_img_path=shap_scale_img_path,
        #                      resnet_type=RESNET_TYPE)

        generate_counterfactuals_resnet_random_approach(os.path.join(db_dir, "min/"), datasets_dir_path=datasets_dir_path,
                                                        dataset_filename=dataset_filename,
                                                        model_dir=model_dir,
                                                        xai_output_path=os.path.join(model_dir, xai_output_paths["cf"]),
                                                        yes_pred_img_path=yes_pred_img_path,
                                                        no_pred_img_path=no_pred_img_path,
                                                        shapes=SHAPES, colors=COLORS, empty_probability=1 - SHAPE_PROB,
                                                        max_depth=10, nb_tries_per_depth=2000,
                                                        generic_rule_fun=generic_rule_fun,
                                                        devices=["cuda:0", "cuda:1"], n_jobs=N_JOBS_GPU,
                                                        dataset_size=XAI_DATASET_SIZE,
                                                        pos_pred_legend_path=pos_pred_legend_path,
                                                        neg_pred_legend_path=neg_pred_legend_path,
                                                        **generic_rule_fun_kwargs, resnet_type=RESNET_TYPE)

        # generate_cam_resnet18(cam_technique="gradcam",
        #                       db_dir=os.path.join(db_dir, "min/"),
        #                       xai_output_path=os.path.join(model_dir, xai_output_paths["gradcam"]),
        #                       datasets_dir_path=datasets_dir_path,
        #                       dataset_filename=dataset_filename,
        #                       model_dir=model_dir,
        #                       yes_pred_img_path=yes_pred_img_path,
        #                       no_pred_img_path=no_pred_img_path,
        #                       dataset_size=XAI_DATASET_SIZE,
        #                       device="cuda:0")
        #
        # create_xai_index(os.path.join(db_dir, "min/"), datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
        #                  model_dir=model_dir,
        #                  xai_dirs=xai_output_paths, dataset_size=XAI_DATASET_SIZE, device="cuda:0",
        #                  resnet_type=RESNET_TYPE)


  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# generate_explanations(rules_data_S, db_S_dir, datasets_path_S)

In [12]:
# generate_explanations(rules_data_L, db_L_dir, datasets_path_L)

In [13]:
# generate_explanations(rules_data_M, db_M_dir, datasets_path_M)

In [14]:
# generate_explanations(rules_data_XS, db_XS_dir, datasets_path_XS)


In [15]:
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
import csv
from xaipatimg.ml.xai import generate_LLM_explanations, create_xai_index
from tqdm import tqdm

model_id = "openai/gpt-oss-20b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
llm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)


def gen_LLM_explanations(db_dir, rules_data, datasets_dir_path, X_divisions, Y_divisions):
    db = load_db(os.path.join(db_dir, "min/"))

    for rule_idx in tqdm(range(len(rules_data))):
        model_dir = os.path.join(model_dir_root, rules_data[rule_idx]["name"])
        dataset_filename = rules_data[rule_idx]["name"] + "_test.csv"

        # Extracting the subset of indices of samples selected for the experimental interface, in order to ease the cost of calculation
        interface_content_path = os.path.join(interface_dir, "res", "tasks",
                                              f"{rules_data[rule_idx]["name"]}_content.csv")
        interface_selected_idx = [int(row["og_idx"]) for row in
                                  list(csv.DictReader(open(interface_content_path), delimiter=','))]

        xai_output_paths = {
            "shap": "shap",
            "cf": "cf",
            "gradcam": "gradcam",
            "llm": "llm",
        }
        generate_LLM_explanations(os.path.join(db_dir, "min/"), db, datasets_dir_path=datasets_dir_path,
                                  dataset_filename=dataset_filename,
                                  model_dir=model_dir, llm_model=llm_model, llm_tokenizer=tokenizer,
                                  xai_output_path=os.path.join(model_dir, xai_output_paths["llm"]),
                                  question=rules_data[rule_idx]["question"],
                                  yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path,
                                  yes_pred_img_path_small=yes_small_pred_img_path,
                                  no_pred_img_path_small=no_small_pred_img_path,
                                  X_division=X_divisions, Y_division=Y_divisions,
                                  device="cuda:0", dataset_size=XAI_DATASET_SIZE,
                                  # only_for_index=interface_selected_idx,
                                  only_for_index=[30],
                                  path_to_counterfactuals_dir_for_model_errors=os.path.join(model_dir,
                                                                                            xai_output_paths["cf"]),
                                  pos_llm_scaffold=rules_data[rule_idx]["pos_llm_scaffold"],
                                  neg_llm_scaffold=rules_data[rule_idx]["neg_llm_scaffold"],
                                  pattern_dict=db_patterns[rules_data[rule_idx]["pattern_id"]]["cnt"] if "pattern_id" in
                                                                                                             rules_data[rule_idx] else None,
                                  resnet_type=RESNET_TYPE)

        create_xai_index(os.path.join(db_dir, "min/"), dataset_filename=dataset_filename,
                         datasets_dir_path=datasets_dir_path,
                         model_dir=model_dir,
                         xai_dirs=xai_output_paths, dataset_size=XAI_DATASET_SIZE, device="cuda:0",
                         resnet_type=RESNET_TYPE)


`torch_dtype` is deprecated! Use `dtype` instead!
Fetching 41 files: 100%|██████████| 41/41 [00:00<00:00, 182361.04it/s]
Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]
Fetching 41 files: 100%|██████████| 41/41 [00:00<00:00, 213357.90it/s]
Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.02s/it]


In [16]:
#gen_LLM_explanations(db_L_dir, rules_data_L, datasets_path_L, X_divisions=X_DIVISIONS_L, Y_divisions=Y_DIVISIONS_L)

In [17]:
# gen_LLM_explanations(db_M_dir, rules_data_M, datasets_path_M, X_divisions=X_DIVISIONS_M, Y_divisions=Y_DIVISIONS_M)

In [18]:
gen_LLM_explanations(db_S_dir, rules_data_S, datasets_path_S, X_divisions=X_DIVISIONS_S, Y_divisions=Y_DIVISIONS_S)

  0%|          | 0/1 [00:00<?, ?it/s]Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


/home/docker/data/models/db3.2.0/01_expv1/easy3_find_pattern_rot/llm



  0%|          | 0/1 [00:00<?, ?it/s][A

You are the explainability system of an AI model. Your role is to justify the decisions of the model. The role of the model is to answer questions about the content of images of symbols of colors. The predictions of the model are always correct. The user will provide you the prediction of the AI model for a given image and the corresponding data. You need to give an explanation of the prediction. The explanation is expected to be a very short sentence which introduces a list of all coordinates that are involved in the model's prediction and which will be highlighted from your output. The justification sentence and the list of coordinates must be separated by the character '|'. The coordinates are separated with the symbol ';', and there is no need to sort them. Do not use escape characters or markdown syntax. The question the model must answer is 'Is the pattern or any of its left or right rotations in the image?'. The pattern to search for is 'A yellow square. Just on its bottom, a pu


100%|██████████| 1/1 [04:36<00:00, 276.97s/it][A

<|channel|>analysis<|message|>We need to provide explanation: coordinates involved in model prediction. The prediction is Yes: meaning pattern found. Pattern: 'A yellow square. Just on its bottom, a purple triangle. Just on its right, a yellow circle.'

We need to find coordinates that match this pattern in the image data. Let's parse the dataset: Each line has coordinate (X,Y): shape.

Coordinate system: X is left-right, Y is top-bottom. So leftmost X = 0. Coordinates like (x,y).

We need to identify a yellow square with its bottom adjacent to a purple triangle, and right adjacent a yellow circle. Adjacent means directly below? It says "Just on its bottom" meaning same X? "Just on its right" meaning same Y? The pattern examples earlier show coordinates like (2,5);(3,5);(2,6) for pattern found at location highlighted. Let's interpret that example: They listed 3 coordinates: (2,5) ; (3,5) ; (2,6). That suggests yellow square at (2,5), purple triangle at (3,5) (right?), yellow circle at 


Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 1/1 [04:51<00:00, 291.95s/it]


In [19]:
# gen_LLM_explanations(db_XS_dir, rules_data_XS, datasets_path_XS, X_divisions=X_DIVISIONS_XS, Y_divisions=Y_DIVISIONS_XS)