## CheckerBoardEnv Simulation with GridSearch

In [1]:
import sys
import os
# Absolute root of the RL_sim project checkout (machine-specific path).
base_path = "/home/ubuntu/Projects/RL_sim/"
# Put the project root first on the module search path — presumably so the
# project-local imports below (agents, checkerboard_env) resolve from it.
sys.path.insert(0, base_path)
# Make the project root the working directory, so relative paths used later
# are resolved against it.
os.chdir(base_path)

from agents.soft_q_learner import SoftQAgent
from agents.utils import *
from PIL.ImageOps import invert
from posixpath import join
from PIL import Image

import gymnasium as gym
import checkerboard_env
import pandas as pd
import numpy as np

In [2]:
def load_checkerboard(img, cross):
    """Open the board and cross images and build the colour-inverted board.

    Parameters
    ----------
    img
        Anything ``PIL.Image.open`` accepts for the checkerboard image.
    cross
        Anything ``PIL.Image.open`` accepts for the cross image.

    Returns
    -------
    tuple
        ``(board, inverse, cross)`` as PIL images. For an RGBA board only the
        colour bands are inverted and the original alpha band is re-attached;
        any other mode is passed to ``invert`` directly.
    """
    board = Image.open(img)
    cross_image = Image.open(cross)

    # boards without an alpha band can be inverted as-is;
    if board.mode != 'RGBA':
        return board, invert(board), cross_image

    # split off alpha, invert the colour bands only, then merge alpha back;
    *colour_bands, alpha = board.split()
    inverted_rgb = invert(Image.merge('RGB', tuple(colour_bands)))
    inverse = Image.merge('RGBA', (*inverted_rgb.split(), alpha))

    return board, inverse, cross_image


In [3]:
from itertools import product

# Hyper-parameter grid: every key maps to the list of values to sweep over;
parameters = dict(
    snr=[0.5, 1, 2, 5, 10],
    n_bins=[30],
    q_mean=[0],
    q_sigma=[0.1],
    kernel_size=[15],
    kernel_sigma=[1.0, 2.0, 3.0, 5.0],
    learning_rate=[0.1, 0.3, 0.5, 0.8, 0.9, 0.95, 0.99],
    temperature=[1e-5, 0.001, 0.01, 0.05, 0.1, 0.3, 0.5, 0.8, 1.0],
    min_temperature=[1e-5],
    max_temperature=[1.0],
    reduce_temperature=[False],
    decay_rate=[0.001],
)

# Cartesian product of all value lists, one tuple per grid point (key order);
combinations = list(product(*parameters.values()))

# Re-attach the parameter names: one {name: value} dict per grid point;
parameter_combinations = [
    dict(zip(parameters, grid_point)) for grid_point in combinations
]

In [4]:
from joblib import Parallel, delayed

# run sizes: environment steps per run, runs per parameter set, worker count;
n_epochs, n_iter, n_jobs = 100, 100, 50

# image assets live inside the project tree, relative to base_path;
assets_path = "checkerboard_env/assets/"
board, inverse, cross = load_checkerboard(
    join(base_path, assets_path, "checkerboard.png"),
    join(base_path, assets_path, "cross.png"),
)

def fit_parameters(params, n_epochs, n_iter):
    """Run repeated soft-Q simulations for one parameter set and log rewards.

    For each of ``n_iter`` runs a fresh Q-table is drawn from a normal
    distribution seeded with the run index, a new ``SoftQAgent`` is built on
    a shared ``checkerboard-v0`` environment, and the agent is stepped
    ``n_epochs`` times. One row is recorded per step.

    Parameters
    ----------
    params : dict
        One grid point. Must contain the keys ``snr``, ``n_bins``, ``q_mean``,
        ``q_sigma``, ``kernel_size``, ``kernel_sigma``, ``learning_rate``,
        ``temperature``, ``min_temperature``, ``max_temperature``,
        ``reduce_temperature`` and ``decay_rate``.
    n_epochs : int
        Number of environment steps per run.
    n_iter : int
        Number of independent runs (also used as Q-table seeds 0..n_iter-1).

    Returns
    -------
    list of list
        Rows of the form ``[*params.values(), reward, epoch, iteration]``,
        ``n_iter * n_epochs`` rows in total (empty when ``n_iter <= 0``).
    """
    # import environment for each thread;
    # presumably this registers 'checkerboard-v0' with gymnasium in the
    # worker process — TODO confirm against the checkerboard_env package.
    import checkerboard_env

    results = []
    if n_iter <= 0:
        # nothing to run, so do not create an environment at all;
        return results

    # values that do not change between runs or steps are read once;
    param_values = list(params.values())
    num_bins = params["n_bins"]
    q_table_shape = (num_bins, num_bins)  # contrast * frequency;

    # one environment is shared by all runs of this parameter set;
    env = gym.make('checkerboard-v0',
                   render_mode=None,
                   checkerboard=board,
                   inverse=inverse,
                   cross=cross,
                   snr=params["snr"])
    try:
        for current_iteration in range(n_iter):
            # set seed for the Q-table initialization;
            rng = np.random.default_rng(seed=current_iteration)

            bins = create_bins(num_bins)
            q_table = rng.normal(params["q_mean"], params["q_sigma"], q_table_shape)
            kernel = generate_gaussian_kernel(params["kernel_size"], params["kernel_sigma"])

            model = SoftQAgent(
                env,
                q_table,
                kernel,
                learning_rate=params["learning_rate"],
                temperature=params["temperature"],
                min_temperature=params["min_temperature"],
                max_temperature=params["max_temperature"],
                reduce_temperature=params["reduce_temperature"],
                decay_rate=params["decay_rate"],
                num_bins_per_obs=num_bins
            )
            initial_state = model.env.reset()[0]
            discrete_state = discretize_observation(initial_state, bins)

            # fit the model;
            for current_epoch in range(n_epochs):
                action = model.soft_q_action_selection()

                # NOTE(review): terminated/truncated are ignored, so stepping
                # continues even if the env reports the episode has ended —
                # confirm this is intended for this environment.
                (next_state,
                 reward,
                 terminated,
                 truncated,
                 info) = model.env.step(action)

                # grab current q-value;
                old_q_value = model.q_table[discrete_state]

                # discretize the next state;
                next_state_discrete = discretize_observation(next_state, bins)

                # compute next q-value and update q-table;
                model.q_table = model.update_q_table(reward, next_state_discrete, old_q_value)

                # update state;
                discrete_state = next_state_discrete

                # by default keep temperature constant;
                model.reduce_temperature(current_epoch, reduce=model.reduce_temp)
                model.reward_log.append(reward)

                # collect and save the data for this step;
                results.append([*param_values, reward, current_epoch, current_iteration])
    finally:
        # release the environment even if a run fails part-way through;
        env.close()

    return results

# run the loop in parallel;
if __name__ == "__main__":
    # one delayed task per parameter set; each returns its list of rows;
    tasks = (delayed(fit_parameters)(params, n_epochs, n_iter)
             for params in parameter_combinations)
    processed_data = Parallel(n_jobs=n_jobs)(tasks)

    # flatten the per-task row lists into a single list of rows;
    all_rows = [row for processed_set in processed_data for row in processed_set]

    # stack all rows into one array (one row per logged step);
    data = np.array(all_rows)

In [5]:
# generate data columns;
# parameter names come from the first grid point (all share the same keys),
# followed by the three per-step values appended to every row.
param_columns = list(parameter_combinations[0])
add_columns = ["reward", "epochs", "iteration"]
columns = param_columns + add_columns

# save and print data;
data = pd.DataFrame(data, columns=columns)

# make sure the output directory exists, otherwise to_pickle raises and the
# results of the whole sweep are lost;
output_path = "./results/data.pkl"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
data.to_pickle(output_path)
data

Unnamed: 0,snr,n_bins,q_mean,q_sigma,kernel_size,kernel_sigma,learning_rate,temperature,min_temperature,max_temperature,reduce_temperature,decay_rate,reward,epochs,iteration
0,0.5,30.0,0.0,0.1,15.0,1.0,0.10,0.00001,0.00001,1.0,0.0,0.001,0.538772,0.0,0.0
1,0.5,30.0,0.0,0.1,15.0,1.0,0.10,0.00001,0.00001,1.0,0.0,0.001,1.396779,1.0,0.0
2,0.5,30.0,0.0,0.1,15.0,1.0,0.10,0.00001,0.00001,1.0,0.0,0.001,-2.525203,2.0,0.0
3,0.5,30.0,0.0,0.1,15.0,1.0,0.10,0.00001,0.00001,1.0,0.0,0.001,0.529220,3.0,0.0
4,0.5,30.0,0.0,0.1,15.0,1.0,0.10,0.00001,0.00001,1.0,0.0,0.001,0.465493,4.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
12599995,10.0,30.0,0.0,0.1,15.0,5.0,0.99,1.00000,0.00001,1.0,0.0,0.001,0.151588,95.0,99.0
12599996,10.0,30.0,0.0,0.1,15.0,5.0,0.99,1.00000,0.00001,1.0,0.0,0.001,0.003628,96.0,99.0
12599997,10.0,30.0,0.0,0.1,15.0,5.0,0.99,1.00000,0.00001,1.0,0.0,0.001,-0.085264,97.0,99.0
12599998,10.0,30.0,0.0,0.1,15.0,5.0,0.99,1.00000,0.00001,1.0,0.0,0.001,-0.053025,98.0,99.0
