In [1]:
import os

# Root under which all databases and models live, read from the DATA environment
# variable (a KeyError is raised here if the variable is not set).
data_root = os.environ["DATA"]

# os.path.join only inserts a separator when data_root does not already end with
# one, so the paths are valid whether DATA is "/data" or "/data/".
db_S_dir = os.path.join(data_root, "PatImgXAI_data/db3.0.0_S/")
db_L_dir = os.path.join(data_root, "PatImgXAI_data/db3.0.0_L/")
db_patterns_dir = os.path.join(data_root, "PatImgXAI_data/db3.0.0/patterns/")

# Parent directory of the per-rule model directories
model_dir_root = os.path.join(data_root, "models/db3.0.0/01_protov5/")
# Devices handed to the parallel training jobs
devices = ["cuda:0", "cuda:1"]
# Passed to training as interval_batch for rules whose target accuracy is below 1.0
INTERVAL_BATCH = 1

In [2]:

# Number of grid cells along each axis of a full image (L and S variants)
X_DIVISIONS_L = Y_DIVISIONS_L = 15
X_DIVISIONS_S = Y_DIVISIONS_S = 10

# Number of grid cells along each axis of a pattern
X_DIVISIONS_PATTERNS = Y_DIVISIONS_PATTERNS = 2

# Probability of generating a geometrical shape at each position in the grid
SHAPE_PROB = 0.5

# Available shapes and colors
SHAPES = ["circle", "square", "triangle"]
COLORS = [
    "#A33E9A",  # Purple
    "#E0B000",  # Yellow
    "#0C90C0",  # Blue
]

In [3]:
from xaipatimg.datagen.dbimg import load_db

# Load the pattern database from db_patterns_dir. Later cells iterate the result
# with .items() and read the "content" field of each entry.
db_patterns = load_db(db_patterns_dir)

In [4]:
# Split the pattern keys by the length of each pattern's "content":
# length 2 -> pattern_2sym_keys, length 3 -> pattern_3sym_keys (database order is kept).
pattern_2sym_keys = [
    key for key, entry in db_patterns.items() if len(entry["content"]) == 2
]
pattern_3sym_keys = [
    key for key, entry in db_patterns.items() if len(entry["content"]) == 3
]


In [5]:
# Dataset directory of the L database; passed to train() as datasets_dir.
datasets_path_L = os.path.join(db_L_dir, "datasets", "01_protov5")

In [6]:
# Dataset directory of the S database; passed to train() as datasets_dir.
datasets_path_S = os.path.join(db_S_dir, "datasets", "01_protov5")

In [7]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_pattern_exactly_N_times

# Grid settings shared by every L rule; each rule copies them into its own
# gen_kwargs dict and appends its rule-specific option.
_grid_kwargs_L = {
    "x_division_full": X_DIVISIONS_L,
    "y_division_full": Y_DIVISIONS_L,
    "x_division_pattern": X_DIVISIONS_PATTERNS,
    "y_division_pattern": Y_DIVISIONS_PATTERNS,
}

# Rule definitions for the L database. Every entry carries the same set of keys;
# "name" is used to derive dataset filenames and the model directory, and
# "target_acc" is forwarded to training.
rules_data_L = [
    {
        "name": "hard1_2sym",
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {**_grid_kwargs_L, "consider_rotations": False},
        "question": "Is the pattern in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 10,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "filter_on_dim": [X_DIVISIONS_L, Y_DIVISIONS_L],
        "pattern_id": pattern_2sym_keys[3],
    },
    {
        "name": "hard2_3sym",
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {**_grid_kwargs_L, "consider_rotations": False},
        "question": "Is the pattern in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 10,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "filter_on_dim": [X_DIVISIONS_L, Y_DIVISIONS_L],
        "pattern_id": pattern_3sym_keys[1],
    },
    {
        "name": "hard3_2sym_rot",
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {**_grid_kwargs_L, "consider_rotations": True},
        "question": "Is the pattern or its left/right rotations in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 10,
        "pos_llm_scaffold": "",
        # Was mistakenly keyed "w"; every other rule uses "neg_llm_scaffold".
        "neg_llm_scaffold": "",
        "filter_on_dim": [X_DIVISIONS_L, Y_DIVISIONS_L],
        "pattern_id": pattern_2sym_keys[4],
    },
    {
        "name": "hard4_2sym_3times",
        "gen_fun": generic_rule_pattern_exactly_N_times,
        "gen_kwargs": {**_grid_kwargs_L, "N": 3},
        "question": "Is the pattern exactly 3 times in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 10,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "filter_on_dim": [X_DIVISIONS_L, Y_DIVISIONS_L],
        "pattern_id": pattern_2sym_keys[5],
    },
]

In [None]:
from xaipatimg.datagen.gendataset import generic_rule_pattern_exactly_1_time_exclude_more, \
    generic_rule_pattern_exactly_N_times

# Grid settings shared by every S rule; each rule copies them into its own
# gen_kwargs dict and appends its rule-specific option.
_grid_kwargs_S = {
    "x_division_full": X_DIVISIONS_S,
    "y_division_full": Y_DIVISIONS_S,
    "x_division_pattern": X_DIVISIONS_PATTERNS,
    "y_division_pattern": Y_DIVISIONS_PATTERNS,
}

# Rule definitions for the S database. "name" is used to derive dataset filenames
# and the model directory, and "target_acc" is forwarded to training.
rules_data_S = [
    {
        "name": "easy1_2sym",
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {**_grid_kwargs_S, "consider_rotations": False},
        "question": "Is the pattern in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 10,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "filter_on_dim": [X_DIVISIONS_S, Y_DIVISIONS_S],
        "pattern_id": pattern_2sym_keys[0],
    },
    {
        "name": "easy2_3sym",
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {**_grid_kwargs_S, "consider_rotations": False},
        "question": "Is the pattern in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 10,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "filter_on_dim": [X_DIVISIONS_S, Y_DIVISIONS_S],
        "pattern_id": pattern_3sym_keys[0],
    },
    {
        "name": "easy3_2sym_rot",
        "gen_fun": generic_rule_pattern_exactly_1_time_exclude_more,
        "gen_kwargs": {**_grid_kwargs_S, "consider_rotations": True},
        "question": "Is the pattern or its left/right rotations in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 10,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "filter_on_dim": [X_DIVISIONS_S, Y_DIVISIONS_S],
        "pattern_id": pattern_2sym_keys[1],
    },
    {
        "name": "easy4_2sym_2times",
        "gen_fun": generic_rule_pattern_exactly_N_times,
        "gen_kwargs": {**_grid_kwargs_S, "N": 2},
        "question": "Is the pattern exactly 2 times in the image?",
        "target_acc": 0.85,
        "shown_acc": 0.85,
        "samples_interface": 10,
        "pos_llm_scaffold": "",
        "neg_llm_scaffold": "",
        "filter_on_dim": [X_DIVISIONS_S, Y_DIVISIONS_S],
        "pattern_id": pattern_2sym_keys[2],
    },
]

In [None]:
from xaipatimg.ml.learning import train_resnet18_model, compute_resnet18_model_scores


def _train_model(db_dir, datasets_dir_path, train_dataset_filename, valid_dataset_filename, test_dataset_filename, model_dir, target_accuracy, interval_batch, device):
    """Train a ResNet-18 model for one dataset, then compute its scores.

    Calls ``train_resnet18_model`` followed by ``compute_resnet18_model_scores``
    with the same database, dataset directory, model directory and device.

    :param db_dir: database directory forwarded to both calls
    :param datasets_dir_path: directory containing the dataset files
    :param train_dataset_filename: filename of the training dataset
    :param valid_dataset_filename: filename of the validation dataset
    :param test_dataset_filename: filename of the test dataset (scoring only)
    :param model_dir: model directory forwarded to both calls
    :param target_accuracy: forwarded to training as ``target_accuracy``
    :param interval_batch: forwarded to training as ``interval_batch``
    :param device: device identifier forwarded to both calls
    """
    train_resnet18_model(
        db_dir,
        datasets_dir_path,
        train_dataset_filename,
        valid_dataset_filename,
        model_dir,
        target_accuracy=target_accuracy,
        interval_batch=interval_batch,
        device=device,
    )

    # NOTE(review): the test filename is passed before the validation filename here,
    # unlike in the training call -- confirm this matches the scoring function's signature.
    compute_resnet18_model_scores(
        db_dir,
        datasets_dir_path,
        train_dataset_filename,
        test_dataset_filename,
        valid_dataset_filename,
        model_dir,
        device=device,
    )




In [None]:
from tqdm import tqdm
from joblib import Parallel, delayed

def train(rules_data, datasets_dir, db_dir):
    """Run ``_train_model`` once for every rule, one parallel job per device.

    Rules are processed in consecutive batches of ``len(devices)``; within a batch
    the i-th rule runs on ``devices[i]``. The last batch may be shorter when the
    number of rules is not a multiple of the number of devices.

    :param rules_data: list of rule dicts; each must provide "name" and "target_acc"
    :param datasets_dir: directory containing the ``<name>_train.csv``,
        ``<name>_valid.csv`` and ``<name>_test.csv`` dataset files
    :param db_dir: database directory forwarded to ``_train_model``
    :raises ValueError: if the ``devices`` list is empty
    """
    n_devices = len(devices)
    if n_devices == 0:
        raise ValueError("At least one device is required to train the models")

    # Step by the number of devices so that each rule belongs to exactly one batch.
    # Stepping by 1 with two jobs per step trained most rules twice, the duplicates
    # writing to the same model directory.
    for batch_start in tqdm(range(0, len(rules_data), n_devices)):
        batch = rules_data[batch_start:batch_start + n_devices]

        Parallel(n_jobs=len(batch))(
            delayed(_train_model)(
                db_dir,  # db_dir
                datasets_dir,  # datasets_dir_path
                rule["name"] + "_train.csv",  # train_dataset_filename
                rule["name"] + "_valid.csv",  # valid_dataset_filename
                rule["name"] + "_test.csv",  # test_dataset_filename
                os.path.join(model_dir_root, rule["name"]),  # model_dir
                rule["target_acc"],  # target_accuracy
                # interval_batch (higher if the target accuracy is the best possible performance)
                INTERVAL_BATCH if rule["target_acc"] < 1.0 else 50,
                device,
            )
            # zip stops at the shorter sequence, so a short final batch uses only
            # the first devices.
            for rule, device in zip(batch, devices)
        )

In [None]:
train(rules_data_S, datasets_path_S, db_S_dir)

In [None]:
train(rules_data_L, datasets_path_L, db_L_dir)