## CheckerBoardEnv Simulation with Grid Search

In [2]:
import sys
import os
# Project root, relative to the directory the notebook is started from
# (presumably the notebook's own folder — confirm).
base_path = ".."
# Make project packages importable (presumably agents/ and checkerboard_env/
# live under base_path) and run the rest of the notebook from there, so
# relative paths such as "./data/..." resolve against that directory.
sys.path.insert(0, base_path)
os.chdir(base_path)
# NOTE(review): after this chdir, later uses of base_path (e.g. the asset
# paths built with join(base_path, ...)) resolve one level further up than
# before — confirm this is intended.

from agents.soft_q_learner import SoftQAgent
# Star import; presumably supplies create_bins, discretize_observation and
# generate_gaussian_kernel used below — TODO confirm and import explicitly.
from agents.utils import *
from PIL.ImageOps import invert
from posixpath import join
from PIL import Image

import gymnasium as gym
# Imported for its side effect; presumably registers the 'checkerboard-v0'
# environment passed to gym.make below.
import checkerboard_env
import pandas as pd
import numpy as np

In [3]:
def load_checkerboard(img, cross):
    """Load the checkerboard stimulus, build its colour inverse, and load the cross.

    Parameters
    ----------
    img : str or path-like
        Path to the checkerboard image.
    cross : str or path-like
        Path to the cross image.

    Returns
    -------
    tuple of PIL.Image.Image
        ``(board, inverse, cross)``. ``inverse`` has the colour bands of
        ``board`` inverted; an alpha band (RGBA / LA), if present, is kept
        unchanged.
    """
    # Image.open is lazy and keeps the file handle open until the pixels are
    # read; copy inside a `with` so the data is loaded and the file is closed.
    with Image.open(img) as board_file:
        board = board_file.copy()
    with Image.open(cross) as cross_file:
        cross_image = cross_file.copy()

    if board.mode in ("RGBA", "LA"):
        # ImageOps.invert does not accept images with an alpha band, so invert
        # only the colour bands and re-attach the original alpha afterwards.
        alpha = board.getchannel("A")
        inverse = invert(board.convert(board.mode[:-1]))
        inverse.putalpha(alpha)
    else:
        inverse = invert(board)

    return board, inverse, cross_image


In [4]:
from itertools import product

# Hyper-parameter grid: each key maps to the candidate values to sweep;
# single-element lists pin that parameter to one value.
parameters = dict(
    snr=[0.5, 1, 2, 3, 5],
    n_bins=[10],
    kernel_size=[20],
    kernel_sigma=[0.5, 1.0, 2.0, 3.0, 4.0],
    learning_rate=[0.05, 0.1, 0.3, 0.4, 0.8, 0.9],
    temperature=[0.2],
    min_temperature=[1e-5],
    max_temperature=[1.0],
    reduce_temperature=[False],
    decay_rate=[0.001],
)

# Cartesian product of the value lists: one tuple per grid point, with entries
# ordered like the keys of `parameters`.
combinations = [*product(*parameters.values())]

# Label every tuple entry with its parameter name.
parameter_combinations = [
    {key: value for key, value in zip(parameters, values)}
    for values in combinations
]

In [8]:
from joblib import Parallel, delayed

# Search budget: environment steps per run, independent runs per parameter
# set, and the worker count handed to joblib.Parallel (-1 = all CPUs).
n_epochs = 100
n_iter = 100
n_jobs = -1

# Load the stimulus images once at module level; fit_parameters reads
# board / inverse / cross as globals when it builds the environment.
assets_path = "anonymised_RLsim/checkerboard_env/assets/"
board, inverse, cross = load_checkerboard(
    join(base_path, assets_path, "checkerboard.png"),
    join(base_path, assets_path, "cross.png"),
)

def fit_parameters(params, n_epochs, n_iter, optimum_index=(9, 6)):
    """Run repeated soft-Q-learning fits on the checkerboard env for one grid point.

    Parameters
    ----------
    params : dict
        One entry of ``parameter_combinations`` (keys ``snr``, ``n_bins``,
        ``kernel_size``, ``kernel_sigma``, ``learning_rate``, ``temperature``,
        ``min_temperature``, ``max_temperature``, ``reduce_temperature``,
        ``decay_rate``).
    n_epochs : int
        Environment steps per run.
    n_iter : int
        Number of runs. Each run builds a new agent with a fresh Q-table and
        resets the (shared) environment.
    optimum_index : tuple of int, optional
        Q-table cell whose value is logged as ``optimum`` after every update.
        Defaults to ``(9, 6)``, the cell previously hard-coded here.

    Returns
    -------
    list of list
        One row per step:
        ``[*params.values(), reward, optimum, current_epoch, current_iteration]``.
    """
    # Imported inside the function so each joblib worker process runs it too;
    # presumably this registers 'checkerboard-v0' with gymnasium.
    import checkerboard_env  # noqa: F401

    results = []
    if n_iter <= 0:
        return results

    num_bins = params["n_bins"]
    # Loop-invariant: the bin edges depend only on num_bins.
    bins = create_bins(num_bins)
    param_values = list(params.values())

    # One environment per call, reused across runs (reset at the start of
    # each run) and always released afterwards.
    env = gym.make('checkerboard-v0',
                   render_mode=None,
                   checkerboard=board,
                   inverse=inverse,
                   cross=cross,
                   snr=params["snr"])
    try:
        for current_iteration in range(n_iter):
            # Fresh Q-table (contrast x frequency), uniformly initialised to 0.5.
            q_table = np.ones((num_bins, num_bins)) * 0.5
            # Rebuilt per run so no agent can see another run's kernel state
            # (in case the agent modifies it in place).
            kernel = generate_gaussian_kernel(params["kernel_size"], params["kernel_sigma"])

            model = SoftQAgent(
                env,
                q_table,
                kernel,
                learning_rate=params["learning_rate"],
                temperature=params["temperature"],
                min_temperature=params["min_temperature"],
                max_temperature=params["max_temperature"],
                reduce_temperature=params["reduce_temperature"],
                decay_rate=params["decay_rate"],
                num_bins_per_obs=num_bins
            )
            initial_state = model.env.reset()[0]
            discrete_state = discretize_observation(initial_state, bins)
            # The first step is taken with the reset observation as the action.
            action = initial_state

            for current_epoch in range(n_epochs):
                # NOTE(review): terminated / truncated are ignored, so the env
                # is never reset mid-run — confirm the env never terminates.
                observation, reward, terminated, truncated, info = model.env.step(action)

                # Update the Q-value of the state that produced this reward.
                old_q_value = model.q_table[discrete_state]
                model.q_table = model.update_q_table(reward, discrete_state, old_q_value)
                optimum = model.q_table[optimum_index]

                # Sample the next action and discretise it as the next state.
                action = model.soft_q_action_selection()
                discrete_state = discretize_observation(action, bins)

                # Temperature stays constant unless reduce_temperature is set.
                model.reduce_temperature(current_epoch, reduce=model.reduce_temp)
                model.reward_log.append(reward)

                results.append([*param_values, reward, optimum, current_epoch, current_iteration])
    finally:
        env.close()

    return results

# Fan the grid out over joblib workers: one fit_parameters call per
# parameter combination.
if __name__ == "__main__":
    processed_data = Parallel(n_jobs=n_jobs)(
        delayed(fit_parameters)(params, n_epochs, n_iter)
        for params in parameter_combinations
    )

    # Concatenate the per-combination row lists, keeping their order.
    all_rows = [
        row
        for processed_set in processed_data
        for row in processed_set
    ]

    # NOTE: np.array coerces the mixed bool/int/float rows to one numeric
    # dtype; the saved table shows n_bins as 10.0 and reduce_temperature as 0.0.
    data = np.array(all_rows)

In [10]:
# Column names: the grid-parameter keys (same order as the values
# fit_parameters writes at the start of each row), then the logged metrics.
param_columns = list(parameter_combinations[0])
add_columns = ["reward", "optimum", "epochs", "iteration"]
columns = param_columns + add_columns

# Save and display the results. to_pickle does not create missing parent
# directories, so ensure ./data exists before writing.
data = pd.DataFrame(data, columns=columns)
os.makedirs("./data", exist_ok=True)
data.to_pickle("./data/hyperparameter_search.pkl")
data

Unnamed: 0,snr,n_bins,kernel_size,kernel_sigma,learning_rate,temperature,min_temperature,max_temperature,reduce_temperature,decay_rate,reward,optimum,epochs,iteration
0,0.5,10.0,20.0,0.5,0.05,0.2,0.00001,1.0,0.0,0.001,-0.572149,0.500000,0.0,0.0
1,0.5,10.0,20.0,0.5,0.05,0.2,0.00001,1.0,0.0,0.001,9.547629,0.500000,1.0,0.0
2,0.5,10.0,20.0,0.5,0.05,0.2,0.00001,1.0,0.0,0.001,3.690634,0.500007,2.0,0.0
3,0.5,10.0,20.0,0.5,0.05,0.2,0.00001,1.0,0.0,0.001,0.826893,0.500007,3.0,0.0
4,0.5,10.0,20.0,0.5,0.05,0.2,0.00001,1.0,0.0,0.001,6.295135,0.500007,4.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1499995,5.0,10.0,20.0,4.0,0.90,0.2,0.00001,1.0,0.0,0.001,3.658092,4.798190,95.0,99.0
1499996,5.0,10.0,20.0,4.0,0.90,0.2,0.00001,1.0,0.0,0.001,3.376558,3.763282,96.0,99.0
1499997,5.0,10.0,20.0,4.0,0.90,0.2,0.00001,1.0,0.0,0.001,2.550228,1.991180,97.0,99.0
1499998,5.0,10.0,20.0,4.0,0.90,0.2,0.00001,1.0,0.0,0.001,2.924869,3.747826,98.0,99.0
