In [1]:
import os

# Every data / model / interface location lives under $DATA. os.path.join is used
# instead of string concatenation so that a DATA value without a trailing slash
# still yields valid paths (plain "+" would glue the two segments together).
DATA_ROOT = os.environ["DATA"]

db_dir = os.path.join(DATA_ROOT, "PatImgXAI_data/db2.0.0/")
datasets_dir_path = os.path.join(db_dir, "datasets/02_protov3/")
model_dir_root = os.path.join(DATA_ROOT, "models/db2.0.0/02_protov3/")

# Static image assets stored in the database directory
shap_scale_img_path = os.path.join(db_dir, "shap_scale.png")
yes_pred_img_path = os.path.join(db_dir, "button_yes.png")
no_pred_img_path = os.path.join(db_dir, "button_no.png")
yes_small_pred_img_path = os.path.join(db_dir, "button_yes_small.png")
no_small_pred_img_path = os.path.join(db_dir, "button_no_small.png")
pos_pred_legend_path = os.path.join(db_dir, "cf_info_pos.png")
neg_pred_legend_path = os.path.join(db_dir, "cf_info_neg.png")

# Experimental web interface; its res/tasks/<rule>_content.csv files select the
# samples that receive LLM explanations.
interface_dir = os.path.join(DATA_ROOT, "webinterfaces/int03_prototype/")

# Number of samples of each <rule>_test.csv processed by the XAI steps (dataset_size=...)
XAI_DATASET_SIZE = 100

# Worker counts: N_JOBS for SHAP generation, N_JOBS_GPU for counterfactual
# generation spread over the GPUs listed in that call.
N_JOBS = 20
N_JOBS_GPU = 6

In [2]:
# Number of images generated
NBGEN = 1_000_000

# Grid division of each image: X_DIVISIONS cells along x, Y_DIVISIONS along y
# (the "row" rules are parameterised with Y_DIVISIONS).
X_DIVISIONS = 6
Y_DIVISIONS = 6

# Size of the images in pixels
img_size = (700, 700)

# Probability to generate a geometrical shape at each position in the grid
SHAPE_PROB = 0.5

# Available shapes
SHAPES = ['circle', 'square', 'triangle']

# Hex color -> human-readable name (passed to generate_LLM_explanations so the
# LLM can talk about colors by name).
explicit_colors_dict = {
    "#A33E9A": "purple",
    "#E0B000": "yellow",
    "#0C90C0": "blue",
}
# Backward-compatible alias for the original, misspelled name used by later cells.
explict_colors_dict = explicit_colors_dict

# Available colors, derived from the mapping (insertion order preserved) so that
# every color always has a human-readable name.
COLORS = list(explicit_colors_dict)

In [3]:
from xaipatimg.datagen.gendataset import generic_rule_exist_row_with_only_shape, generic_rule_N_times_color_exactly, \
    generic_rule_shape_color_plus_shape_equals_N, generic_rule_shape_in_every_row, generic_rule_shape_color_times_2_shape_equals_shape

# One entry per classification rule. Keys:
#   name               - used for the model directory and the "<name>_test.csv" dataset file
#   gen_fun/gen_kwargs - ground-truth rule generator (and its kwargs), forwarded to the
#                        counterfactual generation
#   question           - natural-language question for the rule
#   question_llm       - (optional) less ambiguous wording of the question for the LLM
#   target_acc, shown_acc, samples_interface - not read in this notebook; presumably
#                        consumed by the interface generation -- TODO confirm
#   pos/neg_llm_scaffold - answer templates for YES / NO predictions given to the LLM

# YES scaffold of the "triangle in every row" rules: one bullet per grid row.
# Generated from Y_DIVISIONS so row numbers cannot drift (the hand-written version
# listed "Row 5" twice and never "Row 6").
_DISC_TRIANGLE_POS_SCAFFOLD = (
    "The AI predicts |YES| because every row contains at least one triangle : \n"
    + "\n".join(f"- Row {row} : XX, XX, XX" for row in range(1, Y_DIVISIONS + 1))
)
_DISC_TRIANGLE_NEG_SCAFFOLD = "The AI predicts |NO| because the rows X and X do not contain any triangle."


def _counting_scaffolds(first_term, second_term, pos_equation, neg_equation):
    """Build the YES/NO LLM scaffolds shared by the "hard" counting rules.

    Both scaffolds list the positions of ``first_term`` and ``second_term`` and end
    with an arithmetic line; only the opening sentence and the final equation
    differ. Sharing the body guarantees that every NO scaffold also opens with
    "The AI predicts |NO| because" (three hand-written entries had lost it).

    Args:
        first_term: counted objects of the first operand, e.g. "blue squares".
        second_term: counted objects of the second operand, e.g. "circles".
        pos_equation: closing equation of the YES scaffold.
        neg_equation: closing equation of the NO scaffold.

    Returns:
        dict with the "pos_llm_scaffold" and "neg_llm_scaffold" keys, meant to be
        unpacked into a rule entry.
    """
    positions = "\n- XX\n- XX\n- XX\n- XX"
    body = (f"There is a total of X {first_term} at positions : {positions}\n \n"
            f"There is a total of X {second_term} at positions : {positions}\n\n ")
    return {
        "pos_llm_scaffold": "The AI predicts |YES| because \n\n " + body + pos_equation,
        "neg_llm_scaffold": "The AI predicts |NO| because \n\n " + body + neg_equation,
    }


rules_data = [
    {"name": "disc_1_triangle_all", "gen_fun": generic_rule_shape_in_every_row,
     "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS},
     "question": "In the image, is there a triangle in every row (1, ..., 6)?",
     "target_acc": 1.0, "shown_acc": 1.0, "samples_interface": 5,
     "pos_llm_scaffold": _DISC_TRIANGLE_POS_SCAFFOLD, "neg_llm_scaffold": _DISC_TRIANGLE_NEG_SCAFFOLD},

    {"name": "disc_1_triangle_all_2", "gen_fun": generic_rule_shape_in_every_row,
     "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS},
     "question": "In the image, is there a triangle in every row (1, ..., 6)?",
     "target_acc": 1.0, "shown_acc": 1.0, "samples_interface": 5,
     "pos_llm_scaffold": _DISC_TRIANGLE_POS_SCAFFOLD, "neg_llm_scaffold": _DISC_TRIANGLE_NEG_SCAFFOLD},

    {"name": "disc_1_triangle_all_3", "gen_fun": generic_rule_shape_in_every_row,
     "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS},
     "question": "In the image, is there a triangle in every row (1, ..., 6)?",
     "target_acc": 1.0, "shown_acc": 1.0, "samples_interface": 5,
     "pos_llm_scaffold": _DISC_TRIANGLE_POS_SCAFFOLD, "neg_llm_scaffold": _DISC_TRIANGLE_NEG_SCAFFOLD},


    {"name": "easy_1_6_blue", "gen_fun": generic_rule_N_times_color_exactly,
     "gen_kwargs": {"color": "#0C90C0", "N": 6, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, are there exactly 6 blue symbols?",
     "target_acc": 0.9, "shown_acc": 0.8, "samples_interface": 10,
     "pos_llm_scaffold": "The AI predicts |YES| because there is exactly 6 blue symbols, which are located at :\n- XX\n- XX\n- XX\n- XX\n- XX\n- XX",
     "neg_llm_scaffold": "The AI predicts |NO| because there is X blue symbols instead of 6. They are located at : \n- XX\n- XX\n- XX\n- XX\n- XX."},

    {"name": "easy_2_row_circle", "gen_fun": generic_rule_exist_row_with_only_shape,
     "gen_kwargs": {"shape": "circle", "y_division": Y_DIVISIONS},
     "question": "In the image, is there at least one row (1, ..., 6) containing only circles?",
     "target_acc": 0.9, "shown_acc": 0.8, "samples_interface": 10,
     "pos_llm_scaffold": "The AI predicts |YES| because there is at least one row which contains only circles : \nRow X contains only circles which are located at XX, XX, XX",
     "neg_llm_scaffold": "The AI predicts |NO| because there is not a single row containing only circles :\nRow 1 contains a non-circle symbol at XX\nRow 2 contains non-circle symbols at XX, XX, XX.\nRow 3 does not contain any symbol at all\n ..."},

    {"name": "easy_3_7_purple", "gen_fun": generic_rule_N_times_color_exactly,
     "gen_kwargs": {"color": "#A33E9A", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, are there exactly 7 purple symbols?",
     "target_acc": 0.9, "shown_acc": 0.8, "samples_interface": 10,
     "pos_llm_scaffold": "The AI predicts |YES| because there is exactly 7 purple symbols, which are located at :\n- XX\n- XX\n- XX\n- XX\n- XX\n- XX\n- XX",
     "neg_llm_scaffold": "The AI predicts |NO| because there is X purple symbols instead of 7. They are located at : \n- XX\n- XX\n- XX\n- XX\n- XX\n- XX."},

    {"name": "easy_4_row_triangle", "gen_fun": generic_rule_exist_row_with_only_shape,
     "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS},
     "question": "In the image, is there at least one row (1, ..., 6) containing only triangles?",
     "target_acc": 0.9, "shown_acc": 0.8, "samples_interface": 10,
     "pos_llm_scaffold": "The AI predicts |YES| because there is at least one row which contains only triangles : \nRow X contains only triangles which are located at XX, XX, XX",
     "neg_llm_scaffold": "The AI predicts |NO| because there is not a single row containing only triangles :\nRow 1 contains a non-triangle symbol at XX\nRow 2 contains non-triangle symbols at XX, XX, XX.\nRow 3 does not contain any symbol at all\n ..."},

    {"name": "easy_5_5_yellow", "gen_fun": generic_rule_N_times_color_exactly,
     "gen_kwargs": {"color": "#E0B000", "N": 5, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, are there exactly 5 yellow symbols?",
     "target_acc": 0.9, "shown_acc": 0.8, "samples_interface": 10,
     "pos_llm_scaffold": "The AI predicts |YES| because there is exactly 5 yellow symbols, which are located at :\n- XX\n- XX\n- XX\n- XX\n- XX",
     "neg_llm_scaffold": "The AI predicts |NO| because there is X yellow symbols instead of 5, which are located at : \n- XX\n- XX\n- XX\n- XX\n- XX\n- XX."},

    {"name": "easy_6_row_square", "gen_fun": generic_rule_exist_row_with_only_shape,
     "gen_kwargs": {"shape": "square", "y_division": Y_DIVISIONS},
     "question": "In the image, is there at least one row (1, ..., 6) containing only squares?",
     "target_acc": 0.9, "shown_acc": 0.8, "samples_interface": 10,
     "pos_llm_scaffold": "The AI predicts |YES| because there is at least one row which contains only squares : \nRow X contains only squares which are located at XX, XX, XX",
     "neg_llm_scaffold": "The AI predicts |NO| because there is not a single row containing only squares :\nRow 1 contains a non-square symbol at XX\nRow 2 contains non-square symbols at XX, XX, XX.\nRow 3 does not contain any symbol at all\n ..."},


    {"name": "hard_1_blue_square_plus_circle_8", "gen_fun": generic_rule_shape_color_plus_shape_equals_N,
     "gen_kwargs": {"color1": "#0C90C0", "shape1": "square", "shape2": "circle", "N": 8,
                    "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, does the number of blue squares plus (+) the number of circles equal to 8?",
     "question_llm": "In the image, does the number of blue squares plus (+) the number of circles of any color equal to 8?",
     "target_acc": 0.9, "shown_acc": 0.8, "samples_interface": 10,
     **_counting_scaffolds("blue squares", "circles", "X + X = 8", "X + X = X ≠ 8")},

    {"name": "hard_2_yellow_triangles_times2_squares", "gen_fun": generic_rule_shape_color_times_2_shape_equals_shape,
     "gen_kwargs": {"color1": "#E0B000", "shape1": "triangle", "shape2": "square",
                    "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, does the number of yellow triangles multiplied by 2 (×2) equal to the number of squares?",
     "question_llm": "In the image, does the number of yellow triangles multiplied by 2 (×2) equal to the number of squares of any color?",
     "target_acc": 0.9, "shown_acc": 0.8, "samples_interface": 10,
     **_counting_scaffolds("yellow triangles", "squares", "X × 2 = X", "X × 2 = X ≠ X")},

    {"name": "hard_3_yellow_circle_plus_triangle_9", "gen_fun": generic_rule_shape_color_plus_shape_equals_N,
     "gen_kwargs": {"color1": "#E0B000", "shape1": "circle", "shape2": "triangle", "N": 9,
                    "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, does the number of yellow circles plus (+) the number of triangles equal to 9?",
     "question_llm": "In the image, does the number of yellow circles plus (+) the number of triangles of any color equal to 9?",
     "target_acc": 0.9, "shown_acc": 0.8, "samples_interface": 10,
     **_counting_scaffolds("yellow circles", "triangles", "X + X = 9", "X + X = X ≠ 9")},

    {"name": "hard_4_purple_squares_times2_circles", "gen_fun": generic_rule_shape_color_times_2_shape_equals_shape,
     "gen_kwargs": {"color1": "#A33E9A", "shape1": "square", "shape2": "circle",
                    "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, does the number of purple squares multiplied by 2 (×2) equal to the number of circles?",
     "question_llm": "In the image, does the number of purple squares multiplied by 2 (×2) equal to the number of circles of any color?",
     "target_acc": 0.9, "shown_acc": 0.8, "samples_interface": 10,
     **_counting_scaffolds("purple squares", "circles", "X × 2 = X", "X × 2 = X ≠ X")},

    {"name": "hard_5_purple_triangle_plus_square_7", "gen_fun": generic_rule_shape_color_plus_shape_equals_N,
     "gen_kwargs": {"color1": "#A33E9A", "shape1": "triangle", "shape2": "square", "N": 7,
                    "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, does the number of purple triangles plus (+) the number of squares equal to 7?",
     "question_llm": "In the image, does the number of purple triangles plus (+) the number of squares of any color equal to 7?",
     "target_acc": 0.9, "shown_acc": 0.8, "samples_interface": 10,
     **_counting_scaffolds("purple triangles", "squares", "X + X = 7", "X + X = X ≠ 7")},

    {"name": "hard_6_blue_circles_times2_triangles", "gen_fun": generic_rule_shape_color_times_2_shape_equals_shape,
     "gen_kwargs": {"color1": "#0C90C0", "shape1": "circle", "shape2": "triangle",
                    "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, does the number of blue circles multiplied by 2 (×2) equal to the number of triangles?",
     "question_llm": "In the image, does the number of blue circles multiplied by 2 (×2) equal to the number of triangles of any color?",
     "target_acc": 0.9, "shown_acc": 0.8, "samples_interface": 10,
     **_counting_scaffolds("blue circles", "triangles", "X × 2 = X", "X × 2 = X ≠ X")},
]

# Rule names become model directory / dataset file names, so they must be unique.
assert len({rule["name"] for rule in rules_data}) == len(rules_data), "duplicate rule name in rules_data"


In [4]:
from xaipatimg.datagen.dbimg import load_db

# Load the image database stored under db_dir (opaque project object -- TODO
# confirm its structure). In this notebook `db` is only consumed by
# generate_LLM_explanations; the SHAP/CF functions receive db_dir instead.
db = load_db(db_dir)

In [None]:
from xaipatimg.ml.xai import generate_shap_resnet18, generate_counterfactuals_resnet18_random_approach, \
    create_xai_index, generate_cam_resnet18
from tqdm import tqdm

# For every rule: SHAP explanations, counterfactual explanations, then an index of
# both, all written under that rule's model directory.
for rule in tqdm(rules_data):
    rule_name = rule["name"]
    model_dir = os.path.join(model_dir_root, rule_name)
    dataset_filename = rule_name + "_test.csv"

    # Sub-directory of model_dir for each XAI method; the same mapping is handed to
    # create_xai_index. Rebuilt per rule so a callee can never leak changes into the
    # next iteration.
    xai_output_paths = {
        "shap": "shap",
        "cf": "cf",
    }

    # SHAP explanations for dataset_size=XAI_DATASET_SIZE samples of the test set.
    generate_shap_resnet18(db_dir, datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
                           model_dir=model_dir, xai_output_path=os.path.join(model_dir, xai_output_paths["shap"]),
                           yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path,
                           device="cuda:0", n_jobs=N_JOBS, dataset_size=XAI_DATASET_SIZE,
                           masker="ndarray", shap_scale_img_path=shap_scale_img_path)

    # Counterfactuals ("random approach"). A grid cell is empty with probability
    # 1 - SHAP_PROB, mirroring image generation. The rule's generator function and
    # its kwargs are forwarded; presumably the ground-truth rule labels the candidate
    # images -- TODO confirm in xaipatimg.ml.xai.
    generate_counterfactuals_resnet18_random_approach(
        db_dir, datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
        model_dir=model_dir, xai_output_path=os.path.join(model_dir, xai_output_paths["cf"]),
        yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path,
        shapes=SHAPES, colors=COLORS, empty_probability=1 - SHAPE_PROB,
        max_depth=10, nb_tries_per_depth=2000, generic_rule_fun=rule["gen_fun"],
        devices=["cuda:0", "cuda:1"], n_jobs=N_JOBS_GPU,
        dataset_size=XAI_DATASET_SIZE, pos_pred_legend_path=pos_pred_legend_path,
        neg_pred_legend_path=neg_pred_legend_path,
        **rule["gen_kwargs"])

    # Index the SHAP and CF outputs generated above for this model.
    create_xai_index(db_dir, datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
                     model_dir=model_dir, xai_dirs=xai_output_paths, dataset_size=XAI_DATASET_SIZE,
                     device="cuda:0")


  from .autonotebook import tqdm as notebook_tqdm
  0%|          | 0/1 [00:00<?, ?it/s]

Loading dataset content for hard_6_blue_circles_times2_triangles_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  8%|▊         | 8/100 [00:00<00:01, 73.92it/s][A
 16%|█▌        | 16/100 [00:00<00:01, 68.17it/s][A
 23%|██▎       | 23/100 [00:00<00:01, 54.66it/s][A
 30%|███       | 30/100 [00:00<00:01, 59.00it/s][A
 37%|███▋      | 37/100 [00:00<00:01, 57.89it/s][A
 45%|████▌     | 45/100 [00:00<00:00, 62.91it/s][A
 52%|█████▏    | 52/100 [00:00<00:00, 54.89it/s][A
 59%|█████▉    | 59/100 [00:00<00:00, 58.03it/s][A
 66%|██████▌   | 66/100 [00:01<00:00, 59.18it/s][A
 73%|███████▎  | 73/100 [00:01<00:00, 59.81it/s][A
 80%|████████  | 80/100 [00:01<00:00, 53.32it/s][A
 86%|████████▌ | 86/100 [00:01<00:00, 54.80it/s][A
 93%|█████████▎| 93/100 [00:01<00:00, 57.78it/s][A
100%|██████████| 100/100 [00:01<00:00, 58.50it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


Computing shap values



  0%|          | 0/9998 [00:00<?, ?it/s][A
 18%|█▊        | 1792/9998 [00:00<00:00, 12712.25it/s][A
 31%|███       | 3092/9998 [00:03<00:10, 673.88it/s]  [A
 36%|███▋      | 3642/9998 [00:05<00:11, 555.81it/s][A
 40%|███▉      | 3992/9998 [00:06<00:11, 503.59it/s][A
 42%|████▏     | 4242/9998 [00:07<00:12, 469.62it/s][A
 44%|████▍     | 4392/9998 [00:07<00:12, 450.86it/s][A
 45%|████▌     | 4542/9998 [00:07<00:12, 431.00it/s][A
 46%|████▋     | 4642/9998 [00:08<00:12, 418.28it/s][A
 47%|████▋     | 4742/9998 [00:08<00:12, 406.60it/s][A
 48%|████▊     | 4842/9998 [00:08<00:13, 395.52it/s][A
 49%|████▉     | 4942/9998 [00:09<00:13, 384.91it/s][A
 50%|████▉     | 4992/9998 [00:09<00:13, 379.79it/s][A
 50%|█████     | 5042/9998 [00:09<00:13, 374.56it/s][A
 51%|█████     | 5092/9998 [00:09<00:13, 369.60it/s][A
 51%|█████▏    | 5142/9998 [00:09<00:13, 365.30it/s][A
 52%|█████▏    | 5192/9998 [00:09<00:13, 362.18it/s][A
 52%|█████▏    | 5242/9998 [00:10<00:13, 358.69it/s][A

Generating shap images



  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:00<00:22,  4.40it/s][A
 20%|██        | 20/100 [00:04<00:19,  4.16it/s][A
 40%|████      | 40/100 [00:14<00:22,  2.61it/s][A
 60%|██████    | 60/100 [00:15<00:09,  4.19it/s][A
 80%|████████  | 80/100 [00:16<00:03,  6.00it/s][A
100%|██████████| 100/100 [00:18<00:00,  5.53it/s][A


Loading dataset content for hard_6_blue_circles_times2_triangles_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  6%|▌         | 6/100 [00:00<00:01, 54.36it/s][A
 13%|█▎        | 13/100 [00:00<00:01, 61.20it/s][A
 20%|██        | 20/100 [00:00<00:01, 51.34it/s][A
 28%|██▊       | 28/100 [00:00<00:01, 58.81it/s][A
 35%|███▌      | 35/100 [00:00<00:01, 56.24it/s][A
 41%|████      | 41/100 [00:00<00:01, 55.70it/s][A
 48%|████▊     | 48/100 [00:00<00:00, 54.87it/s][A
 54%|█████▍    | 54/100 [00:00<00:00, 55.61it/s][A
 60%|██████    | 60/100 [00:01<00:00, 55.05it/s][A
 66%|██████▌   | 66/100 [00:01<00:00, 54.07it/s][A
 73%|███████▎  | 73/100 [00:01<00:00, 53.56it/s][A
 81%|████████  | 81/100 [00:01<00:00, 59.09it/s][A
 89%|████████▉ | 89/100 [00:01<00:00, 60.00it/s][A
100%|██████████| 100/100 [00:01<00:00, 56.07it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


Generating counterfactual images



  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:04<07:47,  4.72s/it][A
  6%|▌         | 6/100 [00:07<01:41,  1.08s/it][A

Loading dataset content for /tmp/tmpilthtuln/dataset.csv
Loading dataset content for /tmp/tmp37k9f3hj/dataset.csv


  0%|          | 0/2000 [00:00<?, ?it/s]24.14it/s]

Loading dataset content for /tmp/tmp1dscynmm/dataset.csv
Loading dataset content for /tmp/tmpfndrqjae/dataset.csv
Loading dataset content for /tmp/tmpbq6wgsh9/dataset.csv


  5%|▌         | 104/2000 [00:00<00:14, 126.65it/s]

Loading dataset content for /tmp/tmpttq65bav/dataset.csv


100%|██████████| 2000/2000 [00:16<00:00, 119.48it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 119.39it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:17<00:00, 116.86it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 118.37it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 117.78it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 119.03it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0

 12%|█▏        | 12/100 [01:21<11:33,  7.88s/it][A

Loading dataset content for /tmp/tmpwdaa3p_n/dataset.csv


  0%|          | 0/2000 [00:00<?, ?it/s]21.76it/s]

Loading dataset content for /tmp/tmpto15p0we/dataset.csv


  3%|▎         | 52/2000 [00:00<00:15, 126.62it/s]]

Loading dataset content for /tmp/tmpuuslhcnq/dataset.csv


  5%|▌         | 104/2000 [00:00<00:15, 126.06it/s]

Loading dataset content for /tmp/tmp_u_u8nab/dataset.csv
Loading dataset content for /tmp/tmpy7pgykop/dataset.csv


  8%|▊         | 156/2000 [00:01<00:14, 125.09it/s]

Loading dataset content for /tmp/tmp1e4i3_nt/dataset.csv


100%|██████████| 2000/2000 [00:16<00:00, 120.79it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:15<00:00, 125.19it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 120.50it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 121.07it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 118.19it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 2000/2000 [00:16<00:00, 120.11it/s]
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0


In [5]:
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
import csv
from xaipatimg.ml.xai import generate_LLM_explanations, create_xai_index
from tqdm import tqdm

# Open-weight LLM used to write the natural-language explanations.
model_id = "openai/gpt-oss-20b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# `dtype` replaces `torch_dtype`, which the installed transformers version reports
# as deprecated; "auto" keeps the precision stored with the checkpoint.
llm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    dtype="auto",
)

# For every rule: LLM explanations for the interface samples, then a refreshed index
# covering the SHAP, CF and LLM outputs.
for rule in tqdm(rules_data):
    rule_name = rule["name"]
    model_dir = os.path.join(model_dir_root, rule_name)
    dataset_filename = rule_name + "_test.csv"

    # Only the samples selected for the experimental interface get an LLM explanation,
    # which keeps the (slow) generation affordable. `rule_name` avoids nesting the same
    # quote type inside the f-string, which is a SyntaxError before Python 3.12, and the
    # `with` block closes the CSV file (newline="" as the csv module requires).
    interface_content_path = os.path.join(interface_dir, "res", "tasks", f"{rule_name}_content.csv")
    with open(interface_content_path, newline="") as interface_content_file:
        interface_selected_idx = [int(row["og_idx"])
                                  for row in csv.DictReader(interface_content_file, delimiter=',')]

    xai_output_paths = {
        "shap": "shap",
        "cf": "cf",
        "llm": "llm",
    }

    # Rules may define a less ambiguous wording for the LLM ("... of any color");
    # fall back to the generic question otherwise.
    llm_question = rule.get("question_llm", rule["question"])

    generate_LLM_explanations(db_dir, db, datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
                              model_dir=model_dir, llm_model=llm_model, llm_tokenizer=tokenizer,
                              xai_output_path=os.path.join(model_dir, xai_output_paths["llm"]),
                              explicit_colors_dict=explict_colors_dict, question=llm_question,
                              yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path,
                              yes_pred_img_path_small=yes_small_pred_img_path, no_pred_img_path_small=no_small_pred_img_path,
                              device="cuda:0", dataset_size=XAI_DATASET_SIZE, only_for_index=interface_selected_idx,
                              path_to_counterfactuals_dir_for_model_errors=os.path.join(model_dir, xai_output_paths["cf"]),
                              pos_llm_scaffold=rule["pos_llm_scaffold"],
                              neg_llm_scaffold=rule["neg_llm_scaffold"])

    # Rebuild the index so it also lists the new LLM explanations.
    create_xai_index(db_dir, datasets_dir_path=datasets_dir_path, dataset_filename=dataset_filename,
                     model_dir=model_dir, xai_dirs=xai_output_paths,
                     dataset_size=XAI_DATASET_SIZE, device="cuda:0")


  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
Fetching 41 files: 100%|██████████| 41/41 [00:00<00:00, 176014.80it/s]
Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]
Fetching 41 files: 100%|██████████| 41/41 [00:00<00:00, 211001.80it/s]
Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.17s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

/home/docker/data/models/db2.0.0/02_protov3/hard_6_blue_circles_times2_triangles/llm
Loading dataset content for hard_6_blue_circles_times2_triangles_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 10%|█         | 10/100 [00:00<00:00, 97.14it/s][A
 20%|██        | 20/100 [00:00<00:01, 67.35it/s][A
 28%|██▊       | 28/100 [00:00<00:01, 68.66it/s][A
 36%|███▌      | 36/100 [00:00<00:00, 67.32it/s][A
 43%|████▎     | 43/100 [00:00<00:00, 58.95it/s][A
 50%|█████     | 50/100 [00:00<00:00, 60.51it/s][A
 57%|█████▋    | 57/100 [00:00<00:00, 61.40it/s][A
 64%|██████▍   | 64/100 [00:01<00:00, 61.39it/s][A
 71%|███████   | 71/100 [00:01<00:00, 63.48it/s][A
 78%|███████▊  | 78/100 [00:01<00:00, 62.95it/s][A
 85%|████████▌ | 85/100 [00:01<00:00, 55.98it/s][A
 92%|█████████▏| 92/100 [00:01<00:00, 58.66it/s][A
100%|██████████| 100/100 [00:01<00:00, 62.88it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0

  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [03:32<31:53, 212.63s/it][A
 20%|██        | 2/10 [05:58<23:07, 173.42s/it][A
 30%|███       | 3/10 [08:36<19:22, 166.13s/it][A
 40%|████

Loading dataset content for hard_6_blue_circles_times2_triangles_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  9%|▉         | 9/100 [00:00<00:01, 83.35it/s][A
 18%|█▊        | 18/100 [00:00<00:01, 74.34it/s][A
 26%|██▌       | 26/100 [00:00<00:01, 59.63it/s][A
 33%|███▎      | 33/100 [00:00<00:01, 60.67it/s][A
 40%|████      | 40/100 [00:00<00:00, 61.66it/s][A
 47%|████▋     | 47/100 [00:00<00:00, 63.73it/s][A
 54%|█████▍    | 54/100 [00:00<00:00, 55.45it/s][A
 61%|██████    | 61/100 [00:00<00:00, 58.31it/s][A
 68%|██████▊   | 68/100 [00:01<00:00, 60.13it/s][A
 75%|███████▌  | 75/100 [00:01<00:00, 61.52it/s][A
 82%|████████▏ | 82/100 [00:01<00:00, 62.45it/s][A
 89%|████████▉ | 89/100 [00:01<00:00, 63.17it/s][A
100%|██████████| 100/100 [00:01<00:00, 61.53it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 1/1 [30:04<00:00, 1804.54s/it]
