In [None]:
import os

# Root data directory, read from the DATA environment variable
# (raises KeyError if the variable is not set).
data_root = os.environ["DATA"]

# Build the directories with os.path.join instead of "+" so the result is a
# valid path whether or not DATA ends with a path separator. The trailing "/"
# of the original strings is kept on purpose so the values are unchanged when
# DATA already ends with a separator.
db_dir = os.path.join(data_root, "PatImgXAI_data/db2.0.0/")
model_dir_root = os.path.join(data_root, "models/db2.0.0/")

# Image assets located in the database directory; they are handed to the XAI
# generation calls further down in this notebook.
shap_scale_img_path = os.path.join(db_dir, "shap_scale.png")
yes_pred_img_path = os.path.join(db_dir, "button_yes.png")
no_pred_img_path = os.path.join(db_dir, "button_no.png")
pos_pred_legend_path = os.path.join(db_dir, "cf_info_pos.png")
neg_pred_legend_path = os.path.join(db_dir, "cf_info_neg.png")

# Parallelism settings.
# N_JOBS is passed as n_jobs to the SHAP generation call.
N_JOBS = 20
# N_JOBS_GPU is passed as n_jobs to the counterfactual generation call,
# which also receives the list of CUDA devices.
N_JOBS_GPU = 6

# Passed as dataset_size to the SHAP, counterfactual and index-creation calls.
# Presumably the number of test samples to explain -- TODO confirm.
XAI_DATASET_SIZE = 5

In [None]:
# --- Dataset generation settings -------------------------------------------

# How many images are generated
NBGEN = 1_000_000

# Every image is divided into a grid with this many cells along each axis
X_DIVISIONS = Y_DIVISIONS = 6

# Image dimensions, in pixels
img_size = (700, 700)

# Chance that a given grid position receives a geometrical shape
SHAPE_PROB = 0.5

# Shapes that can be drawn
SHAPES = [
    "circle",
    "square",
    "triangle",
]

# Colors that can be used, as hex RGB strings
COLORS = [
    "#A33E9A",
    "#E0B000",
    "#0C90C0",
]

In [None]:
from xaipatimg.datagen.gendataset import generic_rule_exist_row_with_only_shape

# Rules whose trained models are explained in this notebook.
# Each entry's "name" is used below to derive the model directory and the
# "<name>_test.csv" dataset filename. "gen_fun" / "gen_kwargs" are stored with
# the rule but are not read by the cells of this notebook.
rules_data = [
    {
        "name": rule_name,
        "gen_fun": generic_rule_exist_row_with_only_shape,
        "gen_kwargs": {"shape": shape, "y_division": Y_DIVISIONS},
    }
    for rule_name, shape in (
        ("easy_1_row_circles", "circle"),
        ("easy_1_row_triangles", "triangle"),
    )
]

In [None]:
from xaipatimg.datagen.dbimg import load_db

# Load the database found at db_dir through the project's load_db helper.
# NOTE(review): `db` is not referenced by any later cell visible in this
# notebook (the XAI helpers below receive db_dir directly) -- confirm whether
# this load is still needed.
db = load_db(db_dir)

In [None]:
from xaipatimg.ml.xai import generate_shap_resnet18, generate_counterfactuals_resnet18_random_approach, create_xai_index
from tqdm import tqdm

# For every rule: generate SHAP outputs, generate counterfactual outputs, then
# build the XAI index for the rule's model directory.
for rule in tqdm(rules_data):
    rule_name = rule["name"]

    # Per-rule locations derived from the rule name.
    model_dir = os.path.join(model_dir_root, rule_name)
    dataset_filename = f"{rule_name}_test.csv"

    # Sub-directory names (relative to model_dir) for each kind of explanation;
    # the same mapping is handed to create_xai_index as xai_dirs.
    xai_output_paths = {"shap": "shap", "cf": "cf"}
    shap_output_dir = os.path.join(model_dir, xai_output_paths["shap"])
    cf_output_dir = os.path.join(model_dir, xai_output_paths["cf"])

    # Keyword arguments shared by both generation calls.
    common_kwargs = dict(
        dataset_filename=dataset_filename,
        model_dir=model_dir,
        yes_pred_img_path=yes_pred_img_path,
        no_pred_img_path=no_pred_img_path,
        dataset_size=XAI_DATASET_SIZE,
    )

    generate_shap_resnet18(
        db_dir,
        xai_output_path=shap_output_dir,
        device="cuda:0",
        n_jobs=N_JOBS,
        masker="ndarray",
        shap_scale_img_path=shap_scale_img_path,
        **common_kwargs,
    )

    generate_counterfactuals_resnet18_random_approach(
        db_dir,
        xai_output_path=cf_output_dir,
        shapes=SHAPES,
        colors=COLORS,
        # Complement of the shape probability used for dataset generation.
        empty_probability=1 - SHAPE_PROB,
        max_depth=10,
        nb_tries_per_depth=2000,
        devices=["cuda:0", "cuda:1"],
        n_jobs=N_JOBS_GPU,
        pos_pred_legend_path=pos_pred_legend_path,
        neg_pred_legend_path=neg_pred_legend_path,
        **common_kwargs,
    )

    create_xai_index(
        db_dir,
        dataset_filename=dataset_filename,
        model_dir=model_dir,
        xai_dirs=xai_output_paths,
        dataset_size=XAI_DATASET_SIZE,
        device="cuda:0",
    )
