## CheckerBoardEnv Simulation with GridSearch

In [1]:
import sys
import os
base_path = ".."
sys.path.insert(0, base_path)
os.chdir(base_path)

from agents.soft_q_learner import SoftQAgent
from agents.utils import *
from PIL.ImageOps import invert
from posixpath import join
from PIL import Image

import gymnasium as gym
import checkerboard_env
import pandas as pd
import numpy as np

In [2]:
def load_checkerboard(img, cross):
    
    board = Image.open(img)
    cross = Image.open(cross)

    if board.mode == 'RGBA':
        r,g,b,a = board.split()
        rgb_image = Image.merge('RGB', (r,g,b))
        rgb_inverse = invert(rgb_image)

        r2,g2,b2 = rgb_inverse.split()
        inverse = Image.merge('RGBA', (r2,g2,b2,a))

    else:
        inverse = invert(board)
    
    return board, inverse, cross


In [3]:
from itertools import product

# Add parameter search;
parameters = {
    'snr': [2, 5, 10, 20],
    'n_bins': [10, 20, 30],
    'q_mean': [0],
    'q_sigma': [0.1],
    'kernel_size': [5, 10, 20],
    'kernel_sigma': [1.0, 2.0, 3.0],
    'learning_rate': [0.80, 0.85, 0.90, 0.95, 0.99],
    'temperature': [1e-5, 0.001, 0.01, 0.05, 0.1, 0.5, 1.0],
    'min_temperature': [1e-5],
    'max_temperature': [1.0],
    'reduce_temperature': [True, False],
    'decay_rate': [0.001, 0.005, 0.01, 0.05]
}

# Generate all possible combinations
combinations = list(product(*parameters.values()))

# Convert the combinations into a list of dictionaries
parameter_combinations = [dict(zip(parameters.keys(), values)) for values in combinations]

In [4]:
from joblib import Parallel, delayed

epochs= 100
n_jobs = 6

assets_path = "checkerboard_env/assets/"
board, inverse, cross = load_checkerboard(join(base_path, assets_path + "checkerboard.png"),
                                        join(base_path, assets_path + "cross.png"))

data = np.array([])

def fit_parameters(params, seed):
    # import environment for each thread;
    import checkerboard_env
    
    # set seed for the Q-table initialization;
    rng = np.random.default_rng(seed=seed)

    # load the checkerboard environment;
    env = gym.make('checkerboard-v0',
                render_mode=None,
                checkerboard=board,
                inverse=inverse,
                cross=cross,
                snr=params["snr"])

    NUM_BINS = params["n_bins"]
    BINS = create_bins(NUM_BINS)

    q_table_shape = (NUM_BINS, NUM_BINS)  # contrast * frequency;
    q_table = rng.normal(params["q_mean"], params["q_sigma"], q_table_shape)
    kernel = generate_gaussian_kernel(params["kernel_size"], params["kernel_sigma"])

    model = SoftQAgent(
        env,
        q_table,
        kernel,
        learning_rate=params["learning_rate"],
        temperature=params["temperature"],
        min_temperature=params["min_temperature"],
        max_temperature=params["max_temperature"],
        reduce_temperature=params["reduce_temperature"],
        decay_rate=params["decay_rate"],
        num_bins_per_obs=NUM_BINS
    )

    # fit the model;
    model.fit(epochs, BINS)

    # collect the data;
    v_max = np.unravel_index(np.argmax(model.q_table), q_table_shape)
    min_rw = np.min(model.reward_log)
    max_rw = np.max(model.reward_log)
    mean_rw = np.mean(model.reward_log)

    current_row = []
    current_row.extend(list(params.values()))
    current_row.extend([v_max, min_rw, max_rw, mean_rw])
    current_row = np.array(current_row, dtype=object)

    return current_row

# run the loop in parallel;
if __name__ == "__main__":
    data = np.array([])  # Initialize data array
    processed_data = Parallel(n_jobs=n_jobs)(delayed(fit_parameters)(params, seed) 
                                            for seed, params in enumerate(parameter_combinations))

    for current_row in processed_data:
        if len(data) == 0:
            data = current_row
        else:
            data = np.vstack((current_row, data))


In [10]:
# generate data columns;
param_columns = list(parameter_combinations[0])
add_columns = ["v_max", "min_rw", "max_rw", "mean_rw"]
columns = param_columns + add_columns

# save and print data;
data = pd.DataFrame(data, columns=columns)
data.to_pickle("./results/data.pkl")
data

Unnamed: 0,snr,n_bins,q_mean,q_sigma,kernel_size,kernel_sigma,learning_rate,temperature,min_temperature,max_temperature,reduce_temperature,decay_rate,v_max,min_rw,max_rw,mean_rw
0,20,30,0,0.1,20,3.0,0.99,1.0,0.00001,1.0,False,0.05,"(1, 0)",-0.088129,0.942034,0.327061
1,20,30,0,0.1,20,3.0,0.99,1.0,0.00001,1.0,False,0.01,"(23, 17)",-0.047451,0.859617,0.321
2,20,30,0,0.1,20,3.0,0.99,1.0,0.00001,1.0,False,0.005,"(26, 24)",-0.060756,0.924498,0.340701
3,20,30,0,0.1,20,3.0,0.99,1.0,0.00001,1.0,False,0.001,"(26, 20)",-0.104769,0.939127,0.36012
4,20,30,0,0.1,20,3.0,0.99,1.0,0.00001,1.0,True,0.05,"(18, 26)",-0.022512,0.959006,0.407238
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
30235,2,10,0,0.1,5,1.0,0.8,0.00001,0.00001,1.0,False,0.001,"(7, 3)",-1.32633,1.3453,0.298738
30236,2,10,0,0.1,5,1.0,0.8,0.00001,0.00001,1.0,True,0.05,"(7, 0)",-0.735798,1.954112,0.507503
30237,2,10,0,0.1,5,1.0,0.8,0.00001,0.00001,1.0,True,0.01,"(5, 6)",-1.042181,1.86545,0.351399
30238,2,10,0,0.1,5,1.0,0.8,0.00001,0.00001,1.0,True,0.005,"(7, 5)",-0.922577,1.422055,0.343417
