In [1]:
import os

# Root folder of every data artefact, read from the DATA environment variable
# (a missing variable raises KeyError).
DATA_ROOT = os.environ["DATA"]

# os.path.join inserts the separator when DATA has no trailing slash and produces exactly
# the same strings as plain concatenation when DATA already ends with one.
db_dir = os.path.join(DATA_ROOT, "PatImgXAI_data/db2.0.0/")
model_dir_root = os.path.join(DATA_ROOT, "models/db2.0.0/")
interface_dir = os.path.join(DATA_ROOT, "webinterfaces/int02_prototype/")

# Image assets located in the database folder; their paths are handed to the XAI
# generation functions called further down in the notebook.
shap_scale_img_path = os.path.join(db_dir, "shap_scale.png")
yes_pred_img_path = os.path.join(db_dir, "button_yes.png")
no_pred_img_path = os.path.join(db_dir, "button_no.png")
yes_small_pred_img_path = os.path.join(db_dir, "button_yes_small.png")
no_small_pred_img_path = os.path.join(db_dir, "button_no_small.png")
pos_pred_legend_path = os.path.join(db_dir, "cf_info_pos.png")
neg_pred_legend_path = os.path.join(db_dir, "cf_info_neg.png")

# Value passed as `dataset_size` to the XAI generation / indexing functions
XAI_DATASET_SIZE = 100

# Values passed as `n_jobs` to the SHAP (N_JOBS) and counterfactual (N_JOBS_GPU) generation calls
N_JOBS = 20
N_JOBS_GPU = 6

In [2]:
# How many images are generated
NBGEN = 1_000_000

# Number of grid divisions of each image along x and y
X_DIVISIONS = Y_DIVISIONS = 6

# Image size in pixels
img_size = 700, 700

# Probability that a geometrical shape is generated at a given grid position
SHAPE_PROB = 0.5

# Shapes and colors (hex codes) available to the generator
SHAPES = ["circle", "square", "triangle"]
COLORS = ["#A33E9A", "#E0B000", "#0C90C0"]

# Human-readable name of each hex color, listed in the same order as COLORS
explict_colors_dict = dict(zip(COLORS, ("purple", "yellow", "blue")))

In [3]:
from xaipatimg.datagen.gendataset import generic_rule_exist_row_with_only_shape, generic_rule_N_times_color_exactly, \
    generic_rule_shape_color_plus_shape_equals_N, generic_rule_exist_row_with_only_color_and_col_with_only_shape, \
    generic_rule_shape_in_every_row

# One dict per rule (task). Keys read by the cells of this notebook:
#   - "name": sub-directory of model_dir_root and prefix of the "<name>_test.csv" / "<name>_content.csv" files
#   - "gen_fun" / "gen_kwargs": rule function and its keyword arguments (used by the counterfactual generation call)
#   - "question": question string passed to generate_LLM_explanations
#   - "pos_llm_scaffold" / "neg_llm_scaffold": scaffold strings passed to generate_LLM_explanations
# "target_acc" and "samples_interface" are not read by the cells shown here.
# NOTE(review): the disabled "hard_2/4/6" rules have no *_llm_scaffold keys, so re-enabling them as-is
# would raise KeyError in the LLM explanation cell.
rules_data = [
    # Disabled rule:
    # {"name": "disc_1_triangle_all", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS}, "question": "In the image, is there a triangle in every row (1, ..., 6)?", "target_acc" : 1.0, "samples_interface": 5, "pos_llm_scaffold": "The AI predicts |YES| because every row contains at least one triangle : \n - Row 1 : XX, XX, XX\n- Row 2 : XX, XX, XX\n- Row 3 : XX, XX, XX\n- Row 4 : XX, XX, XX\n- Row 5 : XX, XX, XX\n- Row 6 : XX, XX, XX", "neg_llm_scaffold": "The AI predicts |NO| because the rows X and X do not contain any triangle."},

    # ---- "easy" rules ----
    {
        "name": "easy_1_6_blue",
        "gen_fun": generic_rule_N_times_color_exactly,
        "gen_kwargs": {"color": "#0C90C0", "N": 6, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
        "question": "In the image, is there exactly 6 blue symbols?",
        "target_acc": 0.8,
        "samples_interface": 10,
        "pos_llm_scaffold": "The AI predicts |YES| because there is exactly 6 blue symbols, which are located at :\n- XX\n- XX\n- XX\n- XX\n- XX\n- XX",
        "neg_llm_scaffold": "The AI predicts |NO| because there is X blue symbols instead of 6. They are located at : \n- XX\n- XX\n- XX\n- XX\n- XX.",
    },
    {
        "name": "easy_2_row_circle",
        "gen_fun": generic_rule_exist_row_with_only_shape,
        "gen_kwargs": {"shape": "circle", "y_division": Y_DIVISIONS},
        "question": "In the image, is there at least one row (1, ..., 6) containing only circles?",
        "target_acc": 0.8,
        "samples_interface": 10,
        "pos_llm_scaffold": "The AI predicts |YES| because there is at least one row which contains only circles : \nRow X contains only circles which are located at XX, XX, XX",
        "neg_llm_scaffold": "The AI predicts |NO| because there is not a single row containing only circles :\nRow 1 contains a non-circle symbol at XX\nRow 2 contains non-circle symbols at XX, XX, XX.\nRow 3 does not contain any symbol at all\n ...",
    },
    {
        "name": "easy_3_7_purple",
        "gen_fun": generic_rule_N_times_color_exactly,
        "gen_kwargs": {"color": "#A33E9A", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
        "question": "In the image, is there exactly 7 purple symbols?",
        "target_acc": 0.8,
        "samples_interface": 10,
        "pos_llm_scaffold": "The AI predicts |YES| because there is exactly 7 purple symbols, which are located at :\n- XX\n- XX\n- XX\n- XX\n- XX\n- XX\n- XX",
        "neg_llm_scaffold": "The AI predicts |NO| because there is X purple symbols instead of 7. They are located at : \n- XX\n- XX\n- XX\n- XX\n- XX\n- XX.",
    },
    {
        "name": "easy_4_row_triangle",
        "gen_fun": generic_rule_exist_row_with_only_shape,
        "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS},
        "question": "In the image, is there at least one row (1, ..., 6) containing only triangles?",
        "target_acc": 0.8,
        "samples_interface": 10,
        "pos_llm_scaffold": "The AI predicts |YES| because there is at least one row which contains only triangles : \nRow X contains only triangles which are located at XX, XX, XX",
        "neg_llm_scaffold": "The AI predicts |NO| because there is not a single row containing only triangles :\nRow 1 contains a non-triangle symbol at XX\nRow 2 contains non-triangle symbols at XX, XX, XX.\nRow 3 does not contain any symbol at all\n ...",
    },

    # Disabled rule:
    # {"name": "easy_5_5_yellow", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#E0B000", "N": 5, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}, "question": "In the image, is there exactly 5 yellow symbols?", "target_acc": 0.8, "samples_interface": 10,  "pos_llm_scaffold": "The AI predicts |YES| because there is exactly 5 yellow symbols, which are located at :\n- XX\n- XX\n- XX\n- XX\n- XX", "neg_llm_scaffold": "The AI predicts |NO| because there is X yellow symbols instead of 5, which are located at : \n- XX\n- XX\n- XX\n- XX\n- XX\n- XX."},

    {
        "name": "easy_6_row_square",
        "gen_fun": generic_rule_exist_row_with_only_shape,
        "gen_kwargs": {"shape": "square", "y_division": Y_DIVISIONS},
        "question": "In the image, is there at least one row (1, ..., 6) containing only squares?",
        "target_acc": 0.8,
        "samples_interface": 10,
        "pos_llm_scaffold": "The AI predicts |YES| because there is at least one row which contains only squares : \nRow X contains only squares which are located at XX, XX, XX",
        "neg_llm_scaffold": "The AI predicts |NO| because there is not a single row containing only squares :\nRow 1 contains a non-square symbol at XX\nRow 2 contains non-square symbols at XX, XX, XX.\nRow 3 does not contain any symbol at all\n ...",
    },

    # ---- "hard" rules ----
    {
        "name": "hard_1_blue_square_plus_circle_8",
        "gen_fun": generic_rule_shape_color_plus_shape_equals_N,
        "gen_kwargs": {
            "color1": "#0C90C0",
            "shape1": "square",
            "shape2": "circle",
            "N": 8,
            "x_division": X_DIVISIONS,
            "y_division": Y_DIVISIONS,
        },
        "question": "In the image, does the number of blue squares plus (+) the number of circles equal to 8?",
        "target_acc": 0.8,
        "samples_interface": 10,
        "pos_llm_scaffold": "The AI predicts |YES| because \n\n There is a total of X blue squares at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X circles at positions : \n- XX\n- XX\n- XX\n- XX\n X + X = 8",
        "neg_llm_scaffold": "The AI predicts |NO| because \n\n There is a total of X blue squares at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X circles at positions : \n- XX\n- XX\n- XX\n- XX\n X + X = X ≠ 8",
    },

    # Disabled rule:
    # {"name": "hard_2_row_purple_col_triangle", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#A33E9A", "shape": "triangle" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
    #  "question": "In the image, is there at least one row (1, ..., 6) containing only purple symbols, and one column (A, ..., F) containing only triangles?", "target_acc": 0.8, "samples_interface": 10},

    {
        "name": "hard_3_yellow_circle_plus_triangle_9",
        "gen_fun": generic_rule_shape_color_plus_shape_equals_N,
        "gen_kwargs": {
            "color1": "#E0B000",
            "shape1": "circle",
            "shape2": "triangle",
            "N": 9,
            "x_division": X_DIVISIONS,
            "y_division": Y_DIVISIONS,
        },
        "question": "In the image, does the number of yellow circles plus (+) the number of triangles equal to 9?",
        "target_acc": 0.8,
        "samples_interface": 10,
        "pos_llm_scaffold": "The AI predicts |YES| because \n\n There is a total of X yellow circles at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X triangles at positions : \n- XX\n- XX\n- XX\n- XX\n X + X = 9",
        "neg_llm_scaffold": "The AI predicts |NO| because \n\n There is a total of X yellow circles at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X triangles at positions : \n- XX\n- XX\n- XX\n- XX\n X + X = X ≠ 9",
    },

    # Disabled rule:
    # {"name": "hard_4_row_yellow_col_circle", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#E0B000", "shape": "circle" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
    #  "question": "In the image, is there at least one row (1, ..., 6) containing only yellow symbols, and one column (A, ..., F) containing only circles?", "target_acc": 0.8, "samples_interface": 10},

    {
        "name": "hard_5_purple_triangle_plus_square_7",
        "gen_fun": generic_rule_shape_color_plus_shape_equals_N,
        "gen_kwargs": {
            "color1": "#A33E9A",
            "shape1": "triangle",
            "shape2": "square",
            "N": 7,
            "x_division": X_DIVISIONS,
            "y_division": Y_DIVISIONS,
        },
        "question": "In the image, does the number of purple triangles plus (+) the number of squares equal to 7?",
        "target_acc": 0.8,
        "samples_interface": 10,
        "pos_llm_scaffold": "The AI predicts |YES| because \n\n There is a total of X purple triangles at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X squares at positions : \n- XX\n- XX\n- XX\n- XX\n X + X = 7",
        "neg_llm_scaffold": "The AI predicts |NO| because \n\n There is a total of X purple triangles at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X squares at positions : \n- XX\n- XX\n- XX\n- XX\n X + X = X ≠ 7",
    },

    # Disabled rule:
    # {"name": "hard_6_row_blue_col_square", "gen_fun": generic_rule_exist_row_with_only_color_and_col_with_only_shape, "gen_kwargs": {"color": "#0C90C0", "shape": "square" ,"x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
    #  "question": "In the image, is there at least one row (1, ..., 6) containing only blue symbols, and one column (A, ..., F) containing only squares?", "target_acc": 0.8, "samples_interface": 10},
]

In [4]:
from xaipatimg.datagen.dbimg import load_db

# Load the database stored under db_dir; the resulting `db` object is passed to
# generate_LLM_explanations further down in the notebook.
db = load_db(db_dir)

In [None]:
from xaipatimg.ml.xai import generate_shap_resnet18, generate_counterfactuals_resnet18_random_approach, \
    create_xai_index
from tqdm import tqdm

# Per-rule generation of the SHAP / counterfactual explanations and of the XAI index.
# NOTE: every generation call below is currently commented out, so this loop only derives
# the per-rule settings and produces no output.
for rule_idx, rule in enumerate(tqdm(rules_data)):
    rule_name = rule["name"]

    # Folder of the model of this rule and name of its test dataset file
    model_dir = os.path.join(model_dir_root, rule_name)
    dataset_filename = f"{rule_name}_test.csv"

    # Rule function and keyword arguments, forwarded to the counterfactual generation
    generic_rule_fun = rule["gen_fun"]
    generic_rule_fun_kwargs = rule["gen_kwargs"]

    # Output sub-folder (relative to model_dir) of each explanation type
    xai_output_paths = dict(shap="shap", cf="cf")

    # generate_shap_resnet18(db_dir, dataset_filename=dataset_filename,
    #                        model_dir=model_dir, xai_output_path=os.path.join(model_dir, xai_output_paths["shap"]),
    #                        yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path, device="cuda:0", n_jobs=N_JOBS,
    #                        dataset_size=XAI_DATASET_SIZE, masker="ndarray", shap_scale_img_path=shap_scale_img_path)
    #
    # generate_counterfactuals_resnet18_random_approach(db_dir, dataset_filename=dataset_filename,
    #                                                   model_dir=model_dir,
    #                                                   xai_output_path=os.path.join(model_dir, xai_output_paths["cf"]),
    #                                                   yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path,
    #                                                   shapes=SHAPES, colors=COLORS, empty_probability=1-SHAPE_PROB,
    #                                                   max_depth=10, nb_tries_per_depth=2000, generic_rule_fun=generic_rule_fun,
    #                                                   devices=["cuda:0", "cuda:1"], n_jobs=N_JOBS_GPU,
    #                                                   dataset_size=XAI_DATASET_SIZE,pos_pred_legend_path=pos_pred_legend_path,
    #                                                   neg_pred_legend_path=neg_pred_legend_path,
    #                                                   **generic_rule_fun_kwargs)
    #
    # create_xai_index(db_dir, dataset_filename=dataset_filename, model_dir=model_dir, xai_dirs=xai_output_paths, dataset_size=XAI_DATASET_SIZE, device="cuda:0")


In [5]:
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
import csv
from xaipatimg.ml.xai import generate_LLM_explanations, create_xai_index
from tqdm import tqdm

# Hugging Face identifier of the LLM used to write the textual explanations
model_id = "openai/gpt-oss-20b"

# Tokenizer and causal-LM weights are both fetched from the same checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
llm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    # `torch_dtype` is deprecated in the installed transformers version (see the warning in
    # this cell's recorded output); `dtype` is its replacement.
    dtype="auto",
)

# For every rule: generate the LLM explanations of the samples selected for the experimental
# interface, then (re)build the XAI index of the rule's model directory.
for rule_idx, rule in enumerate(tqdm(rules_data)):
    # Bound to a local so that the f-string below does not nest identical quotes
    # (only valid syntax from Python 3.12 onwards).
    rule_name = rule["name"]

    model_dir = os.path.join(model_dir_root, rule_name)
    dataset_filename = rule_name + "_test.csv"

    # Extracting the subset of indices of samples selected for the experimental interface, in order to ease the cost of calculation
    interface_content_path = os.path.join(interface_dir, "res", "tasks", f"{rule_name}_content.csv")
    # The context manager closes the CSV file once the "og_idx" column has been read;
    # newline="" is the opening mode recommended by the csv module.
    with open(interface_content_path, newline="") as interface_content_file:
        interface_selected_idx = [
            int(row["og_idx"]) for row in csv.DictReader(interface_content_file, delimiter=',')
        ]

    # Output sub-folder (relative to model_dir) of each explanation type
    xai_output_paths = {
        "shap" : "shap",
        "cf" : "cf",
        "llm" : "llm",
    }
    generate_LLM_explanations(db_dir, db, dataset_filename=dataset_filename, model_dir=model_dir, llm_model=llm_model, llm_tokenizer=tokenizer,
                              xai_output_path=os.path.join(model_dir, xai_output_paths["llm"]),
                              explicit_colors_dict=explict_colors_dict, question=rule["question"],
                              yes_pred_img_path=yes_pred_img_path, no_pred_img_path=no_pred_img_path,
                              yes_pred_img_path_small=yes_small_pred_img_path, no_pred_img_path_small=no_small_pred_img_path,
                              device="cuda:0", dataset_size=XAI_DATASET_SIZE, only_for_index=interface_selected_idx,
                              path_to_counterfactuals_dir_for_model_errors=os.path.join(model_dir, xai_output_paths["cf"]),
                              pos_llm_scaffold=rule["pos_llm_scaffold"], neg_llm_scaffold=rule["neg_llm_scaffold"])

    create_xai_index(db_dir, dataset_filename=dataset_filename, model_dir=model_dir, xai_dirs=xai_output_paths, dataset_size=XAI_DATASET_SIZE, device="cuda:0")


  from .autonotebook import tqdm as notebook_tqdm
`torch_dtype` is deprecated! Use `dtype` instead!
Fetching 41 files: 100%|██████████| 41/41 [00:00<00:00, 178203.59it/s]
Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]
Fetching 41 files: 100%|██████████| 41/41 [00:00<00:00, 214958.08it/s]
Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.13s/it]
  0%|          | 0/8 [00:00<?, ?it/s]

/home/docker/data/models/db2.0.0/easy_1_6_blue/llm
Loading dataset content for easy_1_6_blue_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  6%|▌         | 6/100 [00:00<00:01, 59.07it/s][A
 13%|█▎        | 13/100 [00:00<00:01, 64.69it/s][A
 20%|██        | 20/100 [00:00<00:01, 64.34it/s][A
 27%|██▋       | 27/100 [00:00<00:01, 62.87it/s][A
 34%|███▍      | 34/100 [00:00<00:01, 64.95it/s][A
 41%|████      | 41/100 [00:00<00:00, 63.64it/s][A
 48%|████▊     | 48/100 [00:00<00:00, 55.75it/s][A
 55%|█████▌    | 55/100 [00:00<00:00, 58.98it/s][A
 62%|██████▏   | 62/100 [00:01<00:00, 60.90it/s][A
 69%|██████▉   | 69/100 [00:01<00:00, 61.97it/s][A
 76%|███████▌  | 76/100 [00:01<00:00, 61.81it/s][A
 83%|████████▎ | 83/100 [00:01<00:00, 62.99it/s][A
 90%|█████████ | 90/100 [00:01<00:00, 56.17it/s][A
100%|██████████| 100/100 [00:01<00:00, 61.10it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0

  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [02:03<18:33, 123.68s/it][A
 20%|██        | 2/10 [03:14<12:21, 92.71s/it] [A
 30%|███  

Loading dataset content for easy_1_6_blue_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  7%|▋         | 7/100 [00:00<00:01, 67.01it/s][A
 14%|█▍        | 14/100 [00:00<00:01, 66.90it/s][A
 21%|██        | 21/100 [00:00<00:01, 65.75it/s][A
 28%|██▊       | 28/100 [00:00<00:01, 65.53it/s][A
 35%|███▌      | 35/100 [00:00<00:01, 63.86it/s][A
 42%|████▏     | 42/100 [00:00<00:00, 63.73it/s][A
 49%|████▉     | 49/100 [00:00<00:00, 56.82it/s][A
 56%|█████▌    | 56/100 [00:00<00:00, 59.33it/s][A
 63%|██████▎   | 63/100 [00:01<00:00, 60.80it/s][A
 70%|███████   | 70/100 [00:01<00:00, 61.56it/s][A
 77%|███████▋  | 77/100 [00:01<00:00, 62.72it/s][A
 84%|████████▍ | 84/100 [00:01<00:00, 63.04it/s][A
 91%|█████████ | 91/100 [00:01<00:00, 56.25it/s][A
100%|██████████| 100/100 [00:01<00:00, 61.38it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
 12%|█▎        | 1/8 [16:38<1:56:30, 998.58s/it]

/home/docker/data/models/db2.0.0/easy_2_row_circle/llm
Loading dataset content for easy_2_row_circle_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  9%|▉         | 9/100 [00:00<00:01, 72.81it/s][A
 17%|█▋        | 17/100 [00:00<00:01, 70.30it/s][A
 25%|██▌       | 25/100 [00:00<00:01, 70.56it/s][A
 33%|███▎      | 33/100 [00:00<00:01, 60.02it/s][A
 40%|████      | 40/100 [00:00<00:00, 61.42it/s][A
 47%|████▋     | 47/100 [00:00<00:00, 62.28it/s][A
 54%|█████▍    | 54/100 [00:00<00:00, 62.87it/s][A
 61%|██████    | 61/100 [00:00<00:00, 62.36it/s][A
 68%|██████▊   | 68/100 [00:01<00:00, 63.99it/s][A
 75%|███████▌  | 75/100 [00:01<00:00, 62.08it/s][A
 82%|████████▏ | 82/100 [00:01<00:00, 57.44it/s][A
 89%|████████▉ | 89/100 [00:01<00:00, 60.11it/s][A
100%|██████████| 100/100 [00:01<00:00, 61.51it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0

  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [02:23<21:33, 143.71s/it][A
 20%|██        | 2/10 [04:45<18:59, 142.38s/it][A
 30%|███       | 3/10 [06:12<13:39, 117.11s/it][A
 40%|████  

Loading dataset content for easy_2_row_circle_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  8%|▊         | 8/100 [00:00<00:01, 70.58it/s][A
 16%|█▌        | 16/100 [00:00<00:01, 69.83it/s][A
 23%|██▎       | 23/100 [00:00<00:01, 67.64it/s][A
 30%|███       | 30/100 [00:00<00:01, 65.20it/s][A
 37%|███▋      | 37/100 [00:00<00:00, 65.37it/s][A
 44%|████▍     | 44/100 [00:00<00:00, 65.37it/s][A
 51%|█████     | 51/100 [00:00<00:00, 57.37it/s][A
 57%|█████▋    | 57/100 [00:00<00:00, 57.26it/s][A
 64%|██████▍   | 64/100 [00:01<00:00, 59.80it/s][A
 71%|███████   | 71/100 [00:01<00:00, 61.27it/s][A
 78%|███████▊  | 78/100 [00:01<00:00, 62.16it/s][A
 85%|████████▌ | 85/100 [00:01<00:00, 62.50it/s][A
 92%|█████████▏| 92/100 [00:01<00:00, 55.33it/s][A
100%|██████████| 100/100 [00:01<00:00, 61.30it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
 25%|██▌       | 2/8 [39:52<2:03:07, 1231.29s/it]

/home/docker/data/models/db2.0.0/easy_3_7_purple/llm
Loading dataset content for easy_3_7_purple_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  9%|▉         | 9/100 [00:00<00:01, 72.13it/s][A
 17%|█▋        | 17/100 [00:00<00:01, 70.35it/s][A
 25%|██▌       | 25/100 [00:00<00:01, 67.71it/s][A
 32%|███▏      | 32/100 [00:00<00:01, 66.79it/s][A
 39%|███▉      | 39/100 [00:00<00:01, 57.46it/s][A
 45%|████▌     | 45/100 [00:00<00:00, 57.77it/s][A
 52%|█████▏    | 52/100 [00:00<00:00, 60.96it/s][A
 59%|█████▉    | 59/100 [00:00<00:00, 62.04it/s][A
 66%|██████▌   | 66/100 [00:01<00:00, 62.62it/s][A
 73%|███████▎  | 73/100 [00:01<00:00, 63.15it/s][A
 80%|████████  | 80/100 [00:01<00:00, 62.45it/s][A
 88%|████████▊ | 88/100 [00:01<00:00, 58.15it/s][A
100%|██████████| 100/100 [00:01<00:00, 61.44it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0

  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [02:08<19:19, 128.80s/it][A
 20%|██        | 2/10 [04:08<16:25, 123.23s/it][A
 30%|███       | 3/10 [05:30<12:10, 104.37s/it][A
 40%|████  

Loading dataset content for easy_3_7_purple_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  8%|▊         | 8/100 [00:00<00:01, 68.26it/s][A
 15%|█▌        | 15/100 [00:00<00:01, 67.93it/s][A
 22%|██▏       | 22/100 [00:00<00:01, 66.21it/s][A
 29%|██▉       | 29/100 [00:00<00:01, 66.64it/s][A
 36%|███▌      | 36/100 [00:00<00:00, 65.31it/s][A
 43%|████▎     | 43/100 [00:00<00:00, 65.37it/s][A
 50%|█████     | 50/100 [00:00<00:00, 65.92it/s][A
 57%|█████▋    | 57/100 [00:00<00:00, 57.17it/s][A
 64%|██████▍   | 64/100 [00:01<00:00, 60.36it/s][A
 71%|███████   | 71/100 [00:01<00:00, 60.79it/s][A
 78%|███████▊  | 78/100 [00:01<00:00, 61.76it/s][A
 85%|████████▌ | 85/100 [00:01<00:00, 62.00it/s][A
 92%|█████████▏| 92/100 [00:01<00:00, 62.39it/s][A
100%|██████████| 100/100 [00:01<00:00, 61.44it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
 38%|███▊      | 3/8 [58:08<1:37:27, 1169.59s/it]

/home/docker/data/models/db2.0.0/easy_4_row_triangle/llm
Loading dataset content for easy_4_row_triangle_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 13%|█▎        | 13/100 [00:00<00:00, 121.39it/s][A
 26%|██▌       | 26/100 [00:00<00:00, 80.18it/s] [A
 35%|███▌      | 35/100 [00:00<00:00, 66.63it/s][A
 43%|████▎     | 43/100 [00:00<00:00, 67.04it/s][A
 51%|█████     | 51/100 [00:00<00:00, 66.65it/s][A
 58%|█████▊    | 58/100 [00:00<00:00, 59.71it/s][A
 65%|██████▌   | 65/100 [00:00<00:00, 61.17it/s][A
 72%|███████▏  | 72/100 [00:01<00:00, 62.44it/s][A
 79%|███████▉  | 79/100 [00:01<00:00, 62.69it/s][A
 86%|████████▌ | 86/100 [00:01<00:00, 63.73it/s][A
 93%|█████████▎| 93/100 [00:01<00:00, 64.10it/s][A
100%|██████████| 100/100 [00:01<00:00, 63.79it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0

  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [01:36<14:29, 96.56s/it][A
 20%|██        | 2/10 [02:40<10:20, 77.57s/it][A
 30%|███       | 3/10 [05:00<12:22, 106.04s/it][A
 40%|████      | 4/10 [06:50<10:45, 107.61s/it][A
 50%|█████

Loading dataset content for easy_4_row_triangle_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 10%|█         | 10/100 [00:00<00:00, 92.93it/s][A
 20%|██        | 20/100 [00:00<00:01, 67.19it/s][A
 28%|██▊       | 28/100 [00:00<00:01, 67.89it/s][A
 36%|███▌      | 36/100 [00:00<00:00, 68.37it/s][A
 43%|████▎     | 43/100 [00:00<00:00, 67.23it/s][A
 50%|█████     | 50/100 [00:00<00:00, 64.28it/s][A
 57%|█████▋    | 57/100 [00:00<00:00, 57.81it/s][A
 64%|██████▍   | 64/100 [00:00<00:00, 61.04it/s][A
 71%|███████   | 71/100 [00:01<00:00, 61.80it/s][A
 78%|███████▊  | 78/100 [00:01<00:00, 62.31it/s][A
 85%|████████▌ | 85/100 [00:01<00:00, 63.28it/s][A
 92%|█████████▏| 92/100 [00:01<00:00, 64.68it/s][A
100%|██████████| 100/100 [00:01<00:00, 62.89it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
 50%|█████     | 4/8 [1:18:40<1:19:35, 1193.97s/it]

/home/docker/data/models/db2.0.0/easy_6_row_square/llm
Loading dataset content for easy_6_row_square_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  8%|▊         | 8/100 [00:00<00:01, 70.44it/s][A
 16%|█▌        | 16/100 [00:00<00:01, 69.46it/s][A
 23%|██▎       | 23/100 [00:00<00:01, 67.24it/s][A
 30%|███       | 30/100 [00:00<00:01, 65.58it/s][A
 37%|███▋      | 37/100 [00:00<00:00, 63.87it/s][A
 44%|████▍     | 44/100 [00:00<00:00, 57.08it/s][A
 51%|█████     | 51/100 [00:00<00:00, 59.60it/s][A
 58%|█████▊    | 58/100 [00:00<00:00, 61.24it/s][A
 65%|██████▌   | 65/100 [00:01<00:00, 62.50it/s][A
 72%|███████▏  | 72/100 [00:01<00:00, 64.22it/s][A
 79%|███████▉  | 79/100 [00:01<00:00, 64.28it/s][A
 86%|████████▌ | 86/100 [00:01<00:00, 64.54it/s][A
 93%|█████████▎| 93/100 [00:01<00:00, 56.50it/s][A
100%|██████████| 100/100 [00:01<00:00, 61.56it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0

  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [02:06<18:56, 126.33s/it][A
 20%|██        | 2/10 [03:59<15:49, 118.67s/it][A
 30%|███  

Loading dataset content for easy_6_row_square_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 10%|█         | 10/100 [00:00<00:01, 72.39it/s][A
 18%|█▊        | 18/100 [00:00<00:01, 73.73it/s][A
 26%|██▌       | 26/100 [00:00<00:01, 71.94it/s][A
 34%|███▍      | 34/100 [00:00<00:00, 68.98it/s][A
 42%|████▏     | 42/100 [00:00<00:00, 61.58it/s][A
 49%|████▉     | 49/100 [00:00<00:00, 63.19it/s][A
 56%|█████▌    | 56/100 [00:00<00:00, 63.81it/s][A
 63%|██████▎   | 63/100 [00:00<00:00, 63.18it/s][A
 70%|███████   | 70/100 [00:01<00:00, 63.91it/s][A
 77%|███████▋  | 77/100 [00:01<00:00, 65.17it/s][A
 84%|████████▍ | 84/100 [00:01<00:00, 64.10it/s][A
 91%|█████████ | 91/100 [00:01<00:00, 56.72it/s][A
100%|██████████| 100/100 [00:01<00:00, 63.47it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
 62%|██████▎   | 5/8 [1:38:08<59:14, 1184.73s/it]  

/home/docker/data/models/db2.0.0/hard_1_blue_square_plus_circle_8/llm
Loading dataset content for hard_1_blue_square_plus_circle_8_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 11%|█         | 11/100 [00:00<00:00, 106.37it/s][A
 22%|██▏       | 22/100 [00:00<00:01, 72.33it/s] [A
 30%|███       | 30/100 [00:00<00:00, 71.46it/s][A
 38%|███▊      | 38/100 [00:00<00:00, 70.64it/s][A
 46%|████▌     | 46/100 [00:00<00:00, 61.56it/s][A
 53%|█████▎    | 53/100 [00:00<00:00, 63.35it/s][A
 60%|██████    | 60/100 [00:00<00:00, 64.75it/s][A
 67%|██████▋   | 67/100 [00:00<00:00, 65.07it/s][A
 74%|███████▍  | 74/100 [00:01<00:00, 64.41it/s][A
 81%|████████  | 81/100 [00:01<00:00, 64.47it/s][A
 88%|████████▊ | 88/100 [00:01<00:00, 57.18it/s][A
100%|██████████| 100/100 [00:01<00:00, 63.75it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0

  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [04:25<39:48, 265.35s/it][A
 20%|██        | 2/10 [07:11<27:37, 207.23s/it][A
 30%|███       | 3/10 [09:36<20:49, 178.45s/it][A
 40%|████      | 4/10 [12:16<17:08, 171.39s/it][A
 50%|████

Loading dataset content for hard_1_blue_square_plus_circle_8_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 12%|█▏        | 12/100 [00:00<00:00, 110.87it/s][A
 24%|██▍       | 24/100 [00:00<00:00, 76.37it/s] [A
 33%|███▎      | 33/100 [00:00<00:00, 67.35it/s][A
 41%|████      | 41/100 [00:00<00:00, 66.74it/s][A
 48%|████▊     | 48/100 [00:00<00:00, 66.87it/s][A
 55%|█████▌    | 55/100 [00:00<00:00, 66.08it/s][A
 62%|██████▏   | 62/100 [00:00<00:00, 65.47it/s][A
 69%|██████▉   | 69/100 [00:01<00:00, 57.80it/s][A
 75%|███████▌  | 75/100 [00:01<00:00, 58.02it/s][A
 82%|████████▏ | 82/100 [00:01<00:00, 60.29it/s][A
 89%|████████▉ | 89/100 [00:01<00:00, 61.37it/s][A
100%|██████████| 100/100 [00:01<00:00, 64.08it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
 75%|███████▌  | 6/8 [2:06:04<45:03, 1351.57s/it]

/home/docker/data/models/db2.0.0/hard_3_yellow_circle_plus_triangle_9/llm
Loading dataset content for hard_3_yellow_circle_plus_triangle_9_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 12%|█▏        | 12/100 [00:00<00:01, 83.28it/s][A
 21%|██        | 21/100 [00:00<00:01, 78.09it/s][A
 29%|██▉       | 29/100 [00:00<00:00, 73.15it/s][A
 37%|███▋      | 37/100 [00:00<00:01, 62.83it/s][A
 44%|████▍     | 44/100 [00:00<00:00, 63.28it/s][A
 51%|█████     | 51/100 [00:00<00:00, 63.65it/s][A
 58%|█████▊    | 58/100 [00:00<00:00, 62.92it/s][A
 65%|██████▌   | 65/100 [00:00<00:00, 64.43it/s][A
 72%|███████▏  | 72/100 [00:01<00:00, 64.58it/s][A
 79%|███████▉  | 79/100 [00:01<00:00, 57.14it/s][A
 87%|████████▋ | 87/100 [00:01<00:00, 62.64it/s][A
100%|██████████| 100/100 [00:01<00:00, 64.70it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0

  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [02:01<18:16, 121.80s/it][A
 20%|██        | 2/10 [04:32<18:28, 138.51s/it][A
 30%|███       | 3/10 [07:25<18:00, 154.35s/it][A
 40%|████      | 4/10 [10:44<17:12, 172.05s/it][A
 50%|█████ 

Loading dataset content for hard_3_yellow_circle_plus_triangle_9_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
  8%|▊         | 8/100 [00:00<00:01, 76.54it/s][A
 16%|█▌        | 16/100 [00:00<00:01, 72.22it/s][A
 24%|██▍       | 24/100 [00:00<00:01, 70.24it/s][A
 32%|███▏      | 32/100 [00:00<00:01, 59.79it/s][A
 39%|███▉      | 39/100 [00:00<00:00, 61.11it/s][A
 46%|████▌     | 46/100 [00:00<00:00, 62.13it/s][A
 53%|█████▎    | 53/100 [00:00<00:00, 62.90it/s][A
 60%|██████    | 60/100 [00:00<00:00, 63.11it/s][A
 67%|██████▋   | 67/100 [00:01<00:00, 62.24it/s][A
 74%|███████▍  | 74/100 [00:01<00:00, 62.59it/s][A
 81%|████████  | 81/100 [00:01<00:00, 57.34it/s][A
 87%|████████▋ | 87/100 [00:01<00:00, 57.94it/s][A
 93%|█████████▎| 93/100 [00:01<00:00, 58.13it/s][A
100%|██████████| 100/100 [00:01<00:00, 61.91it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
 88%|████████▊ | 7/8 [2:34:25<24:25, 1465.90s/it]

/home/docker/data/models/db2.0.0/hard_5_purple_triangle_plus_square_7/llm
Loading dataset content for hard_5_purple_triangle_plus_square_7_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 12%|█▏        | 12/100 [00:00<00:00, 110.46it/s][A
 24%|██▍       | 24/100 [00:00<00:00, 76.04it/s] [A
 33%|███▎      | 33/100 [00:00<00:01, 66.93it/s][A
 41%|████      | 41/100 [00:00<00:00, 67.71it/s][A
 48%|████▊     | 48/100 [00:00<00:00, 65.61it/s][A
 56%|█████▌    | 56/100 [00:00<00:00, 68.62it/s][A
 63%|██████▎   | 63/100 [00:00<00:00, 66.81it/s][A
 70%|███████   | 70/100 [00:01<00:00, 59.29it/s][A
 77%|███████▋  | 77/100 [00:01<00:00, 60.25it/s][A
 84%|████████▍ | 84/100 [00:01<00:00, 61.81it/s][A
 91%|█████████ | 91/100 [00:01<00:00, 61.46it/s][A
100%|██████████| 100/100 [00:01<00:00, 64.17it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0

  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [01:28<13:12, 88.11s/it][A
 20%|██        | 2/10 [03:19<13:35, 101.88s/it][A
 30%|███       | 3/10 [05:43<14:06, 120.95s/it][A
 40%|████      | 4/10 [08:02<12:49, 128.31s/it][A
 50%|█████

Loading dataset content for hard_5_purple_triangle_plus_square_7_test.csv



  0%|          | 0/100 [00:00<?, ?it/s][A
 12%|█▏        | 12/100 [00:00<00:01, 79.14it/s][A
 20%|██        | 20/100 [00:00<00:01, 75.09it/s][A
 28%|██▊       | 28/100 [00:00<00:00, 73.24it/s][A
 36%|███▌      | 36/100 [00:00<00:00, 70.50it/s][A
 44%|████▍     | 44/100 [00:00<00:00, 62.40it/s][A
 51%|█████     | 51/100 [00:00<00:00, 63.69it/s][A
 58%|█████▊    | 58/100 [00:00<00:00, 64.54it/s][A
 65%|██████▌   | 65/100 [00:00<00:00, 64.81it/s][A
 72%|███████▏  | 72/100 [00:01<00:00, 66.27it/s][A
 79%|███████▉  | 79/100 [00:01<00:00, 65.28it/s][A
 86%|████████▌ | 86/100 [00:01<00:00, 56.07it/s][A
 93%|█████████▎| 93/100 [00:01<00:00, 58.88it/s][A
100%|██████████| 100/100 [00:01<00:00, 63.89it/s][A
Using cache found in /home/docker/.cache/torch/hub/pytorch_vision_v0.10.0
100%|██████████| 8/8 [2:54:20<00:00, 1307.55s/it]
