In [2]:
import os

# Root of all data for this project, taken from the DATA environment variable.
# os.path.join only inserts a separator when DATA does not already end with one,
# so the paths below are correct whether or not DATA has a trailing "/"
# (plain string concatenation silently produced e.g. "/dataPatImgXAI_data/..." otherwise).
DATA_ROOT = os.environ["DATA"]

# Image database, root directory for trained models, and dataset CSV directory.
db_dir = os.path.join(DATA_ROOT, "PatImgXAI_data/db2.0.0/")
model_dir_root = os.path.join(DATA_ROOT, "models/db2.0.0/02_protov3/")
datasets_dir = os.path.join(db_dir, "datasets/02_protov3/")

# Devices used to train models in parallel (one training job per device).
devices = ["cuda:0", "cuda:1"]
# interval_batch passed to training for rules whose target accuracy is below 1.0.
INTERVAL_BATCH = 1

In [3]:
# Total number of images to generate.
NBGEN = 1_000_000

# Every image is divided into an X_DIVISIONS x Y_DIVISIONS grid of cells.
X_DIVISIONS = 6
Y_DIVISIONS = 6

# Image size in pixels (square, so the axis order does not matter here).
img_size = (700, 700)

# Chance that a geometrical shape is drawn in any given grid cell.
SHAPE_PROB = 0.5

# Shapes and colours that symbols can take.
SHAPES = ['circle', 'square', 'triangle']
# Purple, yellow and blue, in that order (names as used in the rule questions).
COLORS = ["#A33E9A", "#E0B000", "#0C90C0"]

In [6]:
from xaipatimg.datagen.gendataset import generic_rule_exist_row_with_only_shape, generic_rule_N_times_color_exactly, \
    generic_rule_shape_color_plus_shape_equals_N, generic_rule_shape_in_every_row, generic_rule_shape_color_times_2_shape_equals_shape

# One entry per rule to train a model on. Keys:
#   name               - identifier; used as prefix of the <name>_{train,valid,test}.csv
#                        files and as the model sub-directory name.
#   gen_fun/gen_kwargs - rule generator and its arguments (from xaipatimg.datagen.gendataset).
#   question           - natural-language question the rule answers.
#   target_acc         - target accuracy passed to training (1.0 selects a larger interval_batch).
#   samples_interface  - presumably the number of samples shown in the study interface; TODO confirm.
#   pos_/neg_llm_scaffold - explanation templates for YES/NO predictions; X / XX are placeholders.
# Colours: "#A33E9A" = purple, "#E0B000" = yellow, "#0C90C0" = blue.
rules_data = [
    # Second-to-last line fixed: the original listed "Row 5" twice instead of "Row 6".
    {"name": "disc_1_triangle_all", "gen_fun": generic_rule_shape_in_every_row, "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS}, "question": "In the image, is there a triangle in every row (1, ..., 6)?", "target_acc" : 1.0, "samples_interface": 5, "pos_llm_scaffold": "The AI predicts |YES| because every row contains at least one triangle : \n - Row 1 : XX, XX, XX\n- Row 2 : XX, XX, XX\n- Row 3 : XX, XX, XX\n- Row 4 : XX, XX, XX\n- Row 5 : XX, XX, XX\n- Row 6 : XX, XX, XX", "neg_llm_scaffold": "The AI predicts |NO| because the rows X and X do not contain any triangle."},

    {"name": "easy_1_6_blue", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#0C90C0", "N": 6, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}, "question": "In the image, is there exactly 6 blue symbols?", "target_acc": 0.9, "samples_interface": 10, "pos_llm_scaffold": "The AI predicts |YES| because there is exactly 6 blue symbols, which are located at :\n- XX\n- XX\n- XX\n- XX\n- XX\n- XX", "neg_llm_scaffold": "The AI predicts |NO| because there is X blue symbols instead of 6. They are located at : \n- XX\n- XX\n- XX\n- XX\n- XX."},

    {"name": "easy_2_row_circle", "gen_fun": generic_rule_exist_row_with_only_shape, "gen_kwargs": {"shape": "circle", "y_division": Y_DIVISIONS},
     "question": "In the image, is there at least one row (1, ..., 6) containing only circles?", "target_acc": 0.9, "samples_interface": 10, "pos_llm_scaffold": "The AI predicts |YES| because there is at least one row which contains only circles : \nRow X contains only circles which are located at XX, XX, XX", "neg_llm_scaffold": "The AI predicts |NO| because there is not a single row containing only circles :\nRow 1 contains a non-circle symbol at XX\nRow 2 contains non-circle symbols at XX, XX, XX.\nRow 3 does not contain any symbol at all\n ..."},

    {"name": "easy_3_7_purple", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#A33E9A", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}, "question": "In the image, is there exactly 7 purple symbols?", "target_acc": 0.9, "samples_interface": 10, "pos_llm_scaffold": "The AI predicts |YES| because there is exactly 7 purple symbols, which are located at :\n- XX\n- XX\n- XX\n- XX\n- XX\n- XX\n- XX", "neg_llm_scaffold": "The AI predicts |NO| because there is X purple symbols instead of 7. They are located at : \n- XX\n- XX\n- XX\n- XX\n- XX\n- XX."},

    {"name": "easy_4_row_triangle", "gen_fun": generic_rule_exist_row_with_only_shape, "gen_kwargs": {"shape": "triangle", "y_division": Y_DIVISIONS},
     "question": "In the image, is there at least one row (1, ..., 6) containing only triangles?", "target_acc": 0.9, "samples_interface": 10, "pos_llm_scaffold": "The AI predicts |YES| because there is at least one row which contains only triangles : \nRow X contains only triangles which are located at XX, XX, XX", "neg_llm_scaffold": "The AI predicts |NO| because there is not a single row containing only triangles :\nRow 1 contains a non-triangle symbol at XX\nRow 2 contains non-triangle symbols at XX, XX, XX.\nRow 3 does not contain any symbol at all\n ..."},

    {"name": "easy_5_5_yellow", "gen_fun": generic_rule_N_times_color_exactly, "gen_kwargs": {"color": "#E0B000", "N": 5, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS}, "question": "In the image, is there exactly 5 yellow symbols?", "target_acc": 0.9, "samples_interface": 10, "pos_llm_scaffold": "The AI predicts |YES| because there is exactly 5 yellow symbols, which are located at :\n- XX\n- XX\n- XX\n- XX\n- XX", "neg_llm_scaffold": "The AI predicts |NO| because there is X yellow symbols instead of 5, which are located at : \n- XX\n- XX\n- XX\n- XX\n- XX\n- XX."},

    {"name": "easy_6_row_square", "gen_fun": generic_rule_exist_row_with_only_shape, "gen_kwargs": {"shape": "square", "y_division": Y_DIVISIONS},
     "question": "In the image, is there at least one row (1, ..., 6) containing only squares?", "target_acc": 0.9, "samples_interface": 10, "pos_llm_scaffold": "The AI predicts |YES| because there is at least one row which contains only squares : \nRow X contains only squares which are located at XX, XX, XX", "neg_llm_scaffold": "The AI predicts |NO| because there is not a single row containing only squares :\nRow 1 contains a non-square symbol at XX\nRow 2 contains non-square symbols at XX, XX, XX.\nRow 3 does not contain any symbol at all\n ..."},


    {"name": "hard_1_blue_square_plus_circle_8", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#0C90C0", "shape1": "square", "shape2": "circle", "N": 8, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS,},
     "question": "In the image, does the number of blue squares plus (+) the number of circles equal to 8?", "target_acc": 0.9, "samples_interface": 10, "pos_llm_scaffold": "The AI predicts |YES| because \n\n There is a total of X blue squares at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X circles at positions : \n- XX\n- XX\n- XX\n- XX\n\n X + X = 8", "neg_llm_scaffold": "The AI predicts |NO| because \n\n There is a total of X blue squares at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X circles at positions : \n- XX\n- XX\n- XX\n- XX\n X + X = X ≠ 8"},

    # The "times2" negative scaffolds now carry the "|NO|" verdict prefix like every other scaffold.
    {"name": "hard_2_yellow_triangles_times2_squares", "gen_fun": generic_rule_shape_color_times_2_shape_equals_shape, "gen_kwargs": {"color1": "#E0B000", "shape1": "triangle", "shape2": "square", "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS,},
     "question": "In the image, does the number of yellow triangles multiplied by 2 (×2) equal to the number of squares?", "target_acc": 0.9, "samples_interface": 10, "pos_llm_scaffold": "The AI predicts |YES| because \n\n There is a total of X yellow triangles at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X squares at positions : \n- XX\n- XX\n- XX\n- XX\n\n X × 2 = X", "neg_llm_scaffold": "The AI predicts |NO| because \n\n There is a total of X yellow triangles at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X squares at positions : \n- XX\n- XX\n- XX\n- XX\n\n X × 2 = X ≠ X"},

    {"name": "hard_3_yellow_circle_plus_triangle_9", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#E0B000", "shape1": "circle", "shape2": "triangle", "N": 9, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, does the number of yellow circles plus (+) the number of triangles equal to 9?", "target_acc": 0.9, "samples_interface": 10, "pos_llm_scaffold": "The AI predicts |YES| because \n\n There is a total of X yellow circles at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X triangles at positions : \n- XX\n- XX\n- XX\n- XX\n\n X + X = 9", "neg_llm_scaffold": "The AI predicts |NO| because \n\n There is a total of X yellow circles at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X triangles at positions : \n- XX\n- XX\n- XX\n- XX\n X + X = X ≠ 9"},

    {"name": "hard_4_purple_squares_times2_circles", "gen_fun": generic_rule_shape_color_times_2_shape_equals_shape, "gen_kwargs": {"color1": "#A33E9A", "shape1": "square", "shape2": "circle", "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS,},
     "question": "In the image, does the number of purple squares multiplied by 2 (×2) equal to the number of circles?", "target_acc": 0.9, "samples_interface": 10, "pos_llm_scaffold": "The AI predicts |YES| because \n\n There is a total of X purple squares at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X circles at positions : \n- XX\n- XX\n- XX\n- XX\n\n X × 2 = X", "neg_llm_scaffold": "The AI predicts |NO| because \n\n There is a total of X purple squares at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X circles at positions : \n- XX\n- XX\n- XX\n- XX\n\n X × 2 = X ≠ X"},

    {"name": "hard_5_purple_triangle_plus_square_7", "gen_fun": generic_rule_shape_color_plus_shape_equals_N, "gen_kwargs": {"color1": "#A33E9A", "shape1": "triangle", "shape2": "square", "N": 7, "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS},
     "question": "In the image, does the number of purple triangles plus (+) the number of squares equal to 7?", "target_acc": 0.9, "samples_interface": 10, "pos_llm_scaffold": "The AI predicts |YES| because \n\n There is a total of X purple triangles at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X squares at positions : \n- XX\n- XX\n- XX\n- XX\n\n X + X = 7", "neg_llm_scaffold": "The AI predicts |NO| because \n\n There is a total of X purple triangles at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X squares at positions : \n- XX\n- XX\n- XX\n- XX\n X + X = X ≠ 7"},

    # Fixed: colour was purple ("#A33E9A") and the question was copied from hard_4;
    # both now match the rule name and scaffolds (blue circles x 2 == triangles).
    {"name": "hard_6_blue_circles_times2_triangles", "gen_fun": generic_rule_shape_color_times_2_shape_equals_shape, "gen_kwargs": {"color1": "#0C90C0", "shape1": "circle", "shape2": "triangle", "x_division": X_DIVISIONS, "y_division": Y_DIVISIONS,},
     "question": "In the image, does the number of blue circles multiplied by 2 (×2) equal to the number of triangles?", "target_acc": 0.9, "samples_interface": 10, "pos_llm_scaffold": "The AI predicts |YES| because \n\n There is a total of X blue circles at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X triangles at positions : \n- XX\n- XX\n- XX\n- XX\n\n X × 2 = X", "neg_llm_scaffold": "The AI predicts |NO| because \n\n There is a total of X blue circles at positions : \n- XX\n- XX\n- XX\n- XX\n \nThere is a total of X triangles at positions : \n- XX\n- XX\n- XX\n- XX\n\n X × 2 = X ≠ X"},

]


In [7]:
from xaipatimg.ml.learning import train_resnet18_model, compute_resnet18_model_scores


def _train_model(db_dir, datasets_dir_path, train_dataset_filename, valid_dataset_filename,
                 test_dataset_filename, model_dir, target_accuracy, interval_batch, device):
    """Train a ResNet-18 for one rule, then compute its scores.

    Thin wrapper over ``xaipatimg.ml.learning`` so a single call can be
    dispatched per device with joblib.

    Args:
        db_dir: image database directory.
        datasets_dir_path: directory holding the dataset CSV files.
        train_dataset_filename: train split CSV filename.
        valid_dataset_filename: validation split CSV filename.
        test_dataset_filename: test split CSV filename (only used for scoring).
        model_dir: output directory for the trained model.
        target_accuracy: forwarded to ``train_resnet18_model``.
        interval_batch: forwarded to ``train_resnet18_model``.
        device: torch-style device string, e.g. "cuda:0".
    """
    train_resnet18_model(
        db_dir,
        datasets_dir_path,
        train_dataset_filename,
        valid_dataset_filename,
        model_dir,
        target_accuracy=target_accuracy,
        interval_batch=interval_batch,
        device=device,
    )

    # NOTE(review): test filename is passed before the validation filename here,
    # unlike the training call above — confirm against compute_resnet18_model_scores' signature.
    compute_resnet18_model_scores(
        db_dir,
        datasets_dir_path,
        train_dataset_filename,
        test_dataset_filename,
        valid_dataset_filename,
        model_dir,
        device=device,
    )




In [None]:
from tqdm import tqdm
from joblib import Parallel, delayed

# interval_batch used when a rule's target accuracy is the best possible performance (1.0);
# rules with a lower target use INTERVAL_BATCH instead.
INTERVAL_BATCH_PERFECT_ACC = 50

if not devices:
    raise ValueError("At least one training device must be listed in `devices`.")

# Train rules in chunks of len(devices), one rule per device running concurrently.
# The chunk size follows the configured device list instead of being hard-coded to 2,
# so one device (previously an IndexError) or more than two (previously idle GPUs) work too.
n_devices = len(devices)
for chunk_start in tqdm(range(0, len(rules_data), n_devices)):
    rules_chunk = rules_data[chunk_start:chunk_start + n_devices]  # last chunk may be shorter

    Parallel(n_jobs=len(rules_chunk))(
        delayed(_train_model)(
            db_dir,  # db_dir
            datasets_dir,  # datasets_dir_path
            rule["name"] + "_train.csv",  # train_dataset_filename
            rule["name"] + "_valid.csv",  # valid_dataset_filename
            rule["name"] + "_test.csv",  # test_dataset_filename
            os.path.join(model_dir_root, rule["name"]),  # model_dir
            rule["target_acc"],  # target_accuracy
            INTERVAL_BATCH if rule["target_acc"] < 1.0 else INTERVAL_BATCH_PERFECT_ACC,  # interval_batch
            device,
        )
        for rule, device in zip(rules_chunk, devices)
    )