## CheckerBoardEnv Simulation with GridSearch

In [None]:
import sys
import os
# Project root. This is an absolute path on the author's machine, so it has to
# be adapted before the notebook is run anywhere else.
base_path = "/Users/giuseppe/Documents/Projects/RL_sim/"
# Put the project root first on the import path; the `agents` and
# `checkerboard_env` imports below are presumably resolved from there.
sys.path.insert(0, base_path)
# Make the project root the working directory, so relative paths used later in
# the notebook (e.g. "./results/data.pkl") are resolved against it.
os.chdir(base_path)

from agents.soft_q_learner import SoftQAgent
from agents.utils import *
from PIL.ImageOps import invert
from posixpath import join
from PIL import Image

import gymnasium as gym
import checkerboard_env
import pandas as pd
import numpy as np

In [None]:
def load_checkerboard(img, cross):
    """Open the checkerboard and cross images and build the inverted board.

    Parameters
    ----------
    img
        Location of the checkerboard image, passed to ``Image.open``.
    cross
        Location of the cross image, passed to ``Image.open``.

    Returns
    -------
    tuple
        ``(board, inverse, cross)``: the checkerboard image, its
        colour-inverted counterpart and the cross image.
    """
    board_img = Image.open(img)
    cross_img = Image.open(cross)

    # Every mode except RGBA is handed to `invert` as it is.
    if board_img.mode != 'RGBA':
        return board_img, invert(board_img), cross_img

    # RGBA: invert the three colour bands only and re-attach the untouched
    # alpha band afterwards, so the transparency of the board is preserved.
    red, green, blue, alpha = board_img.split()
    inverted_rgb = invert(Image.merge('RGB', (red, green, blue)))
    inv_red, inv_green, inv_blue = inverted_rgb.split()
    inverse_img = Image.merge('RGBA', (inv_red, inv_green, inv_blue, alpha))

    return board_img, inverse_img, cross_img


In [None]:
from itertools import product

# Grid-search space: every hyper-parameter name maps to the list of candidate
# values to try. Entries with a single candidate are effectively held fixed.
parameters = {
    'snr': [0.5, 1, 2, 5, 10],
    'n_bins': [30],
    'q_mean': [0],
    'q_sigma': [0.1],
    'kernel_size': [3, 5, 8, 12, 15],
    'kernel_sigma': [1.0, 2.0, 3.0, 4.0, 5.0],
    'learning_rate': [0.1, 0.3, 0.5, 0.8, 0.9, 0.95, 0.99],
    'temperature': [1e-5, 0.001, 0.01, 0.05, 0.1, 0.3, 0.5, 0.8, 1.0],
    'min_temperature': [1e-5],
    'max_temperature': [1.0],
    'reduce_temperature': [False],
    'decay_rate': [0.001]
}

# Cartesian product of the candidate lists: one tuple of values per grid
# point, ordered like the keys of `parameters`.
combinations = [*product(*parameters.values())]

# The same grid points, re-keyed by parameter name (one dict per grid point).
parameter_combinations = [
    {name: value for name, value in zip(parameters, grid_point)}
    for grid_point in combinations
]

In [None]:
from joblib import Parallel, delayed

# Number of joblib workers used for the grid search, and the training length
# handed to `model.fit` for every agent.
n_jobs = 6
epochs = 100

# Image assets live in this folder, relative to the project root.
assets_path = "checkerboard_env/assets/"

# Load the checkerboard, its inverse and the cross once; they are reused for
# every environment that is created.
board, inverse, cross = load_checkerboard(
    *(join(base_path, assets_path + file_name)
      for file_name in ("checkerboard.png", "cross.png"))
)

def fit_parameters(params, n_iter, seed=None):
    """Train ``n_iter`` SoftQ agents for one parameter combination.

    Parameters
    ----------
    params : dict
        One entry of ``parameter_combinations``; it must provide every key
        read below (``snr``, ``n_bins``, ``q_mean``, ``q_sigma``,
        ``kernel_size``, ``kernel_sigma``, ``learning_rate``, ``temperature``,
        ``min_temperature``, ``max_temperature``, ``reduce_temperature`` and
        ``decay_rate``).
    n_iter : int
        Number of agents (replicates) to train with these parameters.
    seed : int, optional
        Seed of the generator that draws the initial Q-tables. When omitted,
        ``n_iter`` is used, which is the seed value used before this argument
        existed.

    Returns
    -------
    list of numpy.ndarray
        One 1-D object array per replicate: the values of ``params`` in dict
        order, followed by ``v_max`` (index of the largest Q-table entry) and
        the minimum, maximum and mean of the agent's ``reward_log``.
    """
    if n_iter <= 0:
        return []

    # import environment for each worker; presumably this is what registers
    # 'checkerboard-v0' with gymnasium inside the worker process.
    import checkerboard_env  # noqa: F401

    # A single generator for the whole call. Re-creating it with the same seed
    # inside the loop (as was done before) hands every replicate the very same
    # initial Q-table; drawing successively from one generator keeps the run
    # reproducible while making the replicates differ.
    rng = np.random.default_rng(seed=n_iter if seed is None else seed)

    num_bins = params["n_bins"]
    q_table_shape = (num_bins, num_bins)  # contrast * frequency;
    param_values = list(params.values())

    # load the checkerboard environment once and reuse it for all replicates;
    env = gym.make('checkerboard-v0',
                   render_mode=None,
                   checkerboard=board,
                   inverse=inverse,
                   cross=cross,
                   snr=params["snr"])

    results = []
    try:
        for _ in range(n_iter):
            # The helpers are opaque from here, so the bins and the kernel are
            # rebuilt per replicate in case the agent modifies them in place.
            bins = create_bins(num_bins)
            q_table = rng.normal(params["q_mean"], params["q_sigma"], q_table_shape)
            kernel = generate_gaussian_kernel(params["kernel_size"], params["kernel_sigma"])

            model = SoftQAgent(
                env,
                q_table,
                kernel,
                learning_rate=params["learning_rate"],
                temperature=params["temperature"],
                min_temperature=params["min_temperature"],
                max_temperature=params["max_temperature"],
                reduce_temperature=params["reduce_temperature"],
                decay_rate=params["decay_rate"],
                num_bins_per_obs=num_bins
            )

            # fit the model;
            model.fit(epochs, bins)

            # collect the data;
            v_max = np.unravel_index(np.argmax(model.q_table), q_table_shape)
            min_rw = np.min(model.reward_log)
            max_rw = np.max(model.reward_log)
            mean_rw = np.mean(model.reward_log)

            # dtype=object keeps the (row, col) tuple `v_max` in a single cell.
            results.append(
                np.array([*param_values, v_max, min_rw, max_rw, mean_rw], dtype=object)
            )
    finally:
        # Release the environment even when a replicate fails.
        env.close()

    return results

# run the loop in parallel;
if __name__ == "__main__":
    # The position of a combination in the grid doubles as its seed; every
    # combination is trained 100 times.
    processed_data = Parallel(n_jobs=n_jobs)(
        delayed(fit_parameters)(params, 100, seed=seed)
        for seed, params in enumerate(parameter_combinations)
    )

    rows = [row for combination_rows in processed_data for row in combination_rows]

    # Stack once instead of growing the array row by row (which is quadratic).
    # The rows are reversed because the previous loop put each new row on top
    # of the stack, i.e. the last result ended up first; that order is kept.
    # The result is always 2-D, also when there is only a single row.
    data = np.vstack(rows[::-1]) if rows else np.array([])


In [None]:
# generate data columns: the parameter names (in the order of the grid
# dictionaries) followed by the summary values appended to every result row;
param_columns = list(parameter_combinations[0])
add_columns = ["v_max", "min_rw", "max_rw", "mean_rw"]
columns = param_columns + add_columns

# save and print data;
data = pd.DataFrame(data, columns=columns)
# Create the output folder first: without it `to_pickle` raises and the
# results of the whole grid search would be lost.
os.makedirs("./results", exist_ok=True)
data.to_pickle("./results/data.pkl")
data