## Imports

In [18]:
from time import time
import os

import matplotlib.pyplot as plt
import numpy as np
from gymnasium.envs.toy_text.blackjack import BlackjackEnv

from agent import QLearningAgent
from src.q_learning.blackjack_util import get_policy_from_qtable, get_mdp_policy_from_gym_policy, get_values_from_qtable, map_to_state_indexes
from src.q_learning.blackjack_util import play_using_policy
from src.q_learning.visualize_policy import create_plots, create_grids
from q_learning import train_qlearning_agent as train_QLearning_agent

## Functions

In [19]:
def train_agent_on_blackjack(FL_ENV, learning_rate=0.01, n_episodes=100000, start_epsilon=1.0, epsilon_decay=0.001, final_epsilon=0.1, gamma=0.95, decay='linear'):
    """Create a QLearningAgent for ``FL_ENV`` and train it with ``train_QLearning_agent``.

    Parameters
    ----------
    FL_ENV : environment
        Environment to train on; only ``FL_ENV.action_space.n`` is read here
        (the name is historical — this notebook passes a ``BlackjackEnv``).
    learning_rate : float
        Forwarded to ``QLearningAgent(learning_rate=...)``.
    n_episodes : int
        Number of training episodes forwarded to ``train_QLearning_agent``.
    start_epsilon, epsilon_decay, final_epsilon : float
        Exploration schedule, forwarded as ``initial_epsilon``, ``epsilon_decay``
        and ``final_epsilon``.
    gamma : float
        Forwarded as ``discount_factor``.
    decay : str
        Decay mode string forwarded to ``QLearningAgent`` (e.g. ``'linear'``).

    Returns
    -------
    tuple
        ``(agent, env_wrapper)``: the trained agent and whatever
        ``train_QLearning_agent`` returns.
    """
    # Hyperparameters are gathered first so the constructor call stays short.
    agent_kwargs = dict(
        learning_rate=learning_rate,
        initial_epsilon=start_epsilon,
        epsilon_decay=epsilon_decay,
        final_epsilon=final_epsilon,
        discount_factor=gamma,
        decay=decay,
        init_q_values='zero',
    )
    q_agent = QLearningAgent(FL_ENV.action_space.n, **agent_kwargs)

    trained_env = train_QLearning_agent(q_agent, FL_ENV, n_episodes=n_episodes)
    return q_agent, trained_env


def get_policy_from_q_learning(q_table, states):
    """Return the greedy action index for each state, in the order of ``states``.

    ``q_table[s]`` must be an array-like of per-action values. Ties go to the
    lowest action index, as with ``np.argmax``.
    """
    greedy_actions = []
    for state in states:
        action_values = q_table[state]
        greedy_actions.append(np.argmax(action_values))
    return greedy_actions

## Hyperparameter grid

Each list below is one axis of the grid searched by the sweep at the end of the notebook.

In [27]:
# Axes of the hyperparameter grid; the sweep cell iterates over every combination.
LR = [
    0.01,
    0.1,
    0.15,
]                                   # learning rates (alpha)
GAMMAS = [0.8, 0.95, 0.99]          # discount factors
N_EPISODES = [120_000, 150_000]     # training episodes per run
# Each entry is (start_epsilon, final_epsilon, epsilon_step) — the sweep unpacks in this order.
EPSILON = [
    (0.9, 0.1, 0.002),
    (0.8, 0.2, 0.01),
]

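## Final hyperparameter sweep

Train one Q-learning agent for every combination of the grid above, save its policy/value plot to `plots/`, and play 1000 games with each learned policy.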

In [None]:
# Hyperparameter sweep: for every grid combination, train a Q-learning agent, save its
# policy/value plot under plots/, play 1000 games with the learned policy, and keep the
# run's artefacts in runs_dict for later comparison.
blackjack_env = BlackjackEnv()
# Keyed by (learning_rate, n_episodes, gamma, start_epsilon, final_epsilon, epsilon_step).
runs_dict = {}

os.makedirs('plots', exist_ok=True)
for learning_rate in LR:
    for n_episodes in N_EPISODES:
        for gamma in GAMMAS:
            for start_epsilon, final_epsilon, epsilon_step in EPSILON:
                start_time = time()
                # gamma must come from the sweep. It used to be hard-coded to 0.95, so every
                # "gamma" variant trained identically while titles/filenames claimed otherwise.
                agent, env_wrapper = train_agent_on_blackjack(
                    blackjack_env,
                    learning_rate=learning_rate,
                    n_episodes=n_episodes,
                    start_epsilon=start_epsilon,
                    epsilon_decay=epsilon_step,
                    final_epsilon=final_epsilon,
                    gamma=gamma,
                    decay='linear',
                )
                exec_time = time() - start_time

                policy = get_policy_from_qtable(agent.q_values)
                values = get_values_from_qtable(agent.q_values)
                values_mdp = map_to_state_indexes(values)
                mdp_policy = get_mdp_policy_from_gym_policy(policy)

                # Include every swept hyperparameter so headers of different runs are distinguishable.
                print(f'\n******************* Game plays for lr_{learning_rate}-gamma_{gamma}'
                      f'-episodes_{n_episodes}-epsilon_{start_epsilon}_{final_epsilon}_{epsilon_step}'
                      f' *******************')
                print(f'Training time: {exec_time:.1f} s')

                policy_grid, values_grid = create_grids(agent, usable_ace=True)
                fig = create_plots(values_grid, policy_grid, rf'Q-Learning Policy for Blackjack with $\gamma$={gamma}, $\alpha$={learning_rate}, start $\epsilon$={start_epsilon}, final $\epsilon$={final_epsilon}, $\epsilon$ step={epsilon_step} and {n_episodes} episodes')
                fig.savefig(f"plots/learning_rate_{learning_rate}__gamma_{gamma}__n_episodes_{n_episodes}__epsilon_start_step_final_{start_epsilon}_{final_epsilon}_{epsilon_step}.png", bbox_inches='tight')
                # The figure is already on disk; close it so the one-per-combination figures
                # created by this loop do not pile up in memory (view them from plots/).
                plt.close(fig)

                # NOTE(review): return type of play_using_policy is defined in blackjack_util
                # (may be None); it is stored as-is.
                game_results = play_using_policy(BlackjackEnv(), policy, games=1000)

                runs_dict[(learning_rate, n_episodes, gamma, start_epsilon, final_epsilon, epsilon_step)] = {
                    'agent': agent,
                    'env_wrapper': env_wrapper,
                    'exec_time': exec_time,
                    'policy': policy,
                    'values': values,
                    'values_mdp': values_mdp,
                    'mdp_policy': mdp_policy,
                    'game_results': game_results,
                }