In [None]:
import gym
import numpy as np
import matplotlib.pyplot as plt
from gym import spaces
import matplotlib.colors as mcolors
import seaborn as sns
import time
from itertools import product

### MDP solution approaches

In [None]:
class Sol_Env(gym.Env):
    """Tabular blackjack model used by the dynamic-programming solvers.

    A state is the tuple ``(player_sum, dealer_sum, usable_ace, stick_happened)``:

    * ``player_sum``     -- the player's current total (range 4..30),
    * ``dealer_sum``     -- the dealer's current total (range 4..28),
    * ``usable_ace``     -- True if the player holds an ace counted as 11,
    * ``stick_happened`` -- True once the player has stuck or reached >= 21.

    Cards are drawn with replacement: 2-9 and the ace (value 11) with
    probability 1/13 each, ten-valued cards with probability 4/13.  The dealer
    draws while below 17; in this model a dealer ace always counts as 11.
    """

    def __init__(self):
        # Ranges of each state component (upper bounds are exclusive).
        self.player_sum_1_range = range(4, 31)   # player's current sum
        self.dealer_sum_range = range(4, 29)     # dealer's current sum
        self.usable_ace_1 = [False, True]        # player holds a usable ace
        self.stick_happened = [False, True]      # player has stuck / game decided

        self.actions = ['hit', 'stick']

    @staticmethod
    def _is_terminal(player_sum, dealer_sum, stick_happened):
        """True when the game is over: a sum reached 21+, or the dealer stands (>= 17) after the player stuck."""
        return player_sum >= 21 or dealer_sum >= 21 or (dealer_sum >= 17 and stick_happened)

    @staticmethod
    def _add_player_card(player_sum, usable_ace, card):
        """Return ``(new_sum, new_usable_ace)`` after the player draws ``card`` (11 = ace)."""
        new_sum = player_sum + card
        new_ace = usable_ace
        if new_sum > 21 and new_ace:
            # Bust while holding a usable ace: one ace drops from 11 to 1.  If the
            # new card is itself an ace, that one is demoted and the old ace stays usable.
            new_sum -= 10
            if card != 11:
                new_ace = False
        if card == 11 and not new_ace:
            if player_sum < 11:
                new_ace = True      # the new ace fits as 11
            elif new_sum > 21:
                new_sum -= 10       # the new ace has to count as 1
        return new_sum, new_ace

    def get_all_states(self):
        """Return all state tuples except combinations the model never produces.

        Excluded: a usable ace with a player sum below 12, and any state where
        the player or dealer sum is >= 21 while ``stick_happened`` is False
        (the transitions always mark such states as stuck).
        """
        states = product(self.player_sum_1_range, self.dealer_sum_range,
                         self.usable_ace_1, self.stick_happened)
        return [
            (player, dealer, ace, stuck)
            for player, dealer, ace, stuck in states
            if not (player < 12 and ace)
            and not (player >= 21 and not stuck)
            and not (dealer >= 21 and not stuck)
        ]

    def get_possible_actions(self, state):
        """Return the actions available in ``state`` ([] for terminal states).

        After sticking only 'stick' (let the dealer draw) remains.  Before
        sticking, if the dealer already has >= 17 only 'hit' is offered.
        """
        player_sum, dealer_sum, _, stick_happened = state
        if self._is_terminal(player_sum, dealer_sum, stick_happened):
            return []
        if stick_happened:
            return ['stick']
        if dealer_sum >= 17:
            return ['hit']
        return ['hit', 'stick']

    def get_reward(self, player_sum_1, dealer_sum, stick_happened):
        """Reward for arriving in a state: +1 win, -1 loss, 0 draw or game not over.

        A bust player loses unless the dealer is bust too (draw); otherwise a
        bust dealer loses; if nobody is bust the total closer to 21 wins.
        """
        if not self._is_terminal(player_sum_1, dealer_sum, stick_happened):
            return 0

        player_bust = player_sum_1 > 21
        dealer_bust = dealer_sum > 21
        if player_bust:
            return 0 if dealer_bust else -1
        if dealer_bust:
            return 1

        # Nobody is bust: the smaller distance to 21 wins.
        diff_21_player = 21 - player_sum_1
        diff_21_dealer = 21 - dealer_sum
        if diff_21_player > diff_21_dealer:
            return -1
        if diff_21_player < diff_21_dealer:
            return 1
        return 0

    def get_transition_probabilities(self, state, action):
        """Return ``[(next_state, probability, reward), ...]`` for ``(state, action)``.

        'hit' deals the player one card; reaching >= 21 marks the player as stuck.
        'stick' lets the dealer draw one card while below 17; if the dealer
        already stands, the player becomes stuck and the game is decided.
        Terminal inputs map to themselves with probability 1.
        """
        # Probabilities for card values 2..11 (11 = ace, ten-valued cards 4/13).
        card_probs = dict(zip(range(2, 12), [1 / 13] * 8 + [4 / 13] + [1 / 13]))

        transitions = []
        player_sum_1, dealer_sum, ace_1, stick_happened = state

        if action == 'hit':
            if player_sum_1 >= 21:
                reward = self.get_reward(player_sum_1, dealer_sum, True)
                transitions.append((state, 1.0, reward))
            else:
                for card, prob in card_probs.items():
                    new_sum, new_ace_1 = self._add_player_card(player_sum_1, ace_1, card)
                    next_stick = True if new_sum >= 21 else stick_happened
                    next_state = (new_sum, dealer_sum, new_ace_1, next_stick)
                    reward = self.get_reward(new_sum, dealer_sum, next_stick)
                    transitions.append((next_state, prob, reward))

        elif action == 'stick':
            if dealer_sum >= 21:
                reward = self.get_reward(player_sum_1, dealer_sum, True)
                transitions.append((state, 1.0, reward))
            elif dealer_sum < 17:
                for card, prob in card_probs.items():
                    new_dealer_sum = dealer_sum + card
                    next_state = (player_sum_1, new_dealer_sum, ace_1, True)
                    reward = self.get_reward(player_sum_1, new_dealer_sum, True)
                    transitions.append((next_state, prob, reward))
            else:
                # The dealer already stands: sticking moves to the stuck (terminal)
                # state instead of looping on a not-yet-stuck copy of ``state``.
                next_state = (player_sum_1, dealer_sum, ace_1, True)
                reward = self.get_reward(player_sum_1, dealer_sum, True)
                transitions.append((next_state, 1.0, reward))

        return transitions

In [None]:
# Tabular model queried by the inspection cells below.
bj = Sol_Env()

In [None]:
all_states = bj.get_all_states()

# Sanity check: print the transition model for entries 500-999 of the state list.
for state in all_states[500:1000]:
    for action in bj.get_possible_actions(state):
        for next_state, prob, reward in bj.get_transition_probabilities(state, action):
            print(f"State: {state}, Action: {action}, Next State: {next_state}, Probability: {prob}, Reward: {reward}")

In [None]:
print(f"Total states: {len(bj.get_all_states())}")
# A state is terminal exactly when the model offers no action in it.
terminal_states = list(filter(lambda s: not bj.get_possible_actions(s), bj.get_all_states()))
print(f"Number of terminal states: {len(terminal_states)}")

In [None]:
# One terminal state per line.
for terminal_state in terminal_states:
    print(terminal_state)

Value iteration

In [None]:
def value_iteration_to_get_opt_solution(env, n=1000, theta=1e-15):
    """Compute the reference optimum V* (and a greedy policy) by value iteration.

    Each sweep updates every non-terminal state from a frozen snapshot of the
    previous sweep's values (synchronous updates).  Iteration stops after
    ``n`` sweeps or once the largest per-state change is below ``theta``.

    Returns:
        (V_opt, policy): dicts keyed by every state; terminal states keep
        value 0 and policy None.
    """
    states = env.get_all_states()
    live_states = [s for s in states if not (s[0] >= 21 or s[1] >= 21 or (s[1] >= 17 and s[3]))]

    V_opt = dict.fromkeys(states, 0)
    policy = dict.fromkeys(states)

    def expected_return(snapshot, state, action):
        # One-step lookahead using the previous sweep's values.
        return sum(p * (r + snapshot[nxt]) for nxt, p, r in env.get_transition_probabilities(state, action))

    for _ in range(n):
        snapshot = dict(V_opt)
        largest_change = 0

        for state in live_states:
            previous = V_opt[state]
            best_value, best_action = float('-inf'), None
            for action in env.get_possible_actions(state):
                value = expected_return(snapshot, state, action)
                if value > best_value:          # strict: the first maximising action wins ties
                    best_value, best_action = value, action

            V_opt[state] = best_value
            policy[state] = best_action
            largest_change = max(largest_change, abs(previous - best_value))

        if largest_change < theta:
            break

    return V_opt, policy

In [None]:
# Reference optimum used by every convergence plot below.
V_opt, pol_opt = value_iteration_to_get_opt_solution(env=bj)

In [None]:
# Greedy policy from the value-iteration optimum: state -> action (None for terminal states).
pol_opt

In [None]:
# States whose optimal value is exactly 0; every terminal state is one of them,
# since terminal values are initialised to 0 and never updated.
zero_value_states = [s for s, v in V_opt.items() if v == 0]
for zero_state in zero_value_states:
    print(zero_state)

In [None]:
# True iff the zero-valued states are exactly the terminal ones (both lists
# follow get_all_states() order, so set equality matches list equality).
print(set(terminal_states) == set(zero_value_states))

In [None]:
def value_iteration(env, theta=1e-15, gamma=1, n=1000, V_ref=None):
    """Synchronous (Jacobi) value iteration that records the error to a reference.

    Every sweep computes each non-terminal state's new value from a frozen copy
    of the previous sweep, so updates within a sweep do not see each other.

    Args:
        env: model exposing ``get_all_states()``, ``get_possible_actions(state)``
            and ``get_transition_probabilities(state, action)`` (a list of
            ``(next_state, prob, reward)`` tuples), e.g. ``Sol_Env``.
        theta: stop once the largest per-state change within a sweep is below this.
        gamma: discount applied to successor values (1 = undiscounted).
        n: maximum number of sweeps.
        V_ref: reference values for the sup-norm error; defaults to the
            notebook-global ``V_opt``.

    Returns:
        (V, policy, iterations, sup_norm_errors) where ``sup_norm_errors[k]`` is
        max |V - V_ref| over the non-terminal states after sweep k+1 and
        ``iterations == len(sup_norm_errors)``.
    """
    if V_ref is None:
        V_ref = V_opt  # notebook global produced by value_iteration_to_get_opt_solution
    all_states = env.get_all_states()
    non_terminal_states = [s for s in all_states if not (s[0] >= 21 or s[1] >= 21 or (s[1] >= 17 and s[3]))]

    V = {state: 0 for state in all_states}
    policy = {state: None for state in all_states}
    sup_norm_errors = []

    for _ in range(n):
        delta = 0
        W = V.copy()  # previous sweep's values (Jacobi update)

        for state in non_terminal_states:
            old_value = V[state]
            max_value = float('-inf')
            best_action = None

            for action in env.get_possible_actions(state):
                transitions = env.get_transition_probabilities(state, action)
                action_value = sum(prob * (reward + gamma * W[next_state]) for next_state, prob, reward in transitions)

                if action_value > max_value:
                    max_value = action_value
                    best_action = action

            V[state] = max_value
            policy[state] = best_action
            delta = max(delta, abs(old_value - max_value))

        # default= keeps an env without non-terminal states from crashing max().
        sup_norm_errors.append(max((abs(V[s] - V_ref[s]) for s in non_terminal_states), default=0.0))
        if delta < theta:
            break

    return V, policy, len(sup_norm_errors), sup_norm_errors

In [None]:
bj = Sol_Env()

start_time = time.time()
V_vi, policy_vi, iterations_vi, deltas_vi = value_iteration(bj)
end_time = time.time()

elapsed = end_time - start_time
print(f"Value Iteration completed in {elapsed:.2f} seconds and {iterations_vi} iterations.")
# Plot convergence
def plot_convergence(deltas, label, name, step=1, ylabel='Delta'):
    """Plot one error value per iteration, save the figure as JPEG and show it.

    Args:
        deltas: sequence of error values, one per iteration.
        label: figure title; also used as the curve's legend entry.
        name: output path; the file is always written in JPEG format.
        step: spacing between x-axis ticks (values below 1 are treated as 1).
        ylabel: y-axis label.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    # Labelling the line gives legend() an entry; with no labelled artist
    # matplotlib warns "No artists with labels found to put in legend".
    ax.plot(range(len(deltas)), deltas, label=label)
    ax.set_xlabel('Iteration')
    ax.set_ylabel(ylabel)
    ax.set_title(label)
    ax.set_xticks(range(0, len(deltas), max(1, step)))  # range() rejects step 0
    ax.legend()
    ax.grid()
    fig.savefig(name, format="jpg", dpi=300)
    plt.show()
plot_convergence(deltas_vi, label='Convergence of Value Iteration', name="value_iteration.jpg", step=2)

In [None]:
# Exact equality is expected: both solvers run the same synchronous sweep from zero.
print(V_vi == V_opt)
print(policy_vi == pol_opt)

Gauss-Seidel

In [None]:
def gauss_seidel(env, V_opt, n = 100, theta=1e-15, gamma=1):
    """Gauss-Seidel (in-place) value iteration that records the error to ``V_opt``.

    Unlike the synchronous ``value_iteration``, each state's update immediately
    uses values already updated earlier in the same sweep.

    Args:
        env: model exposing ``get_all_states()``, ``get_possible_actions(state)``
            and ``get_transition_probabilities(state, action)`` (a list of
            ``(next_state, prob, reward)`` tuples).
        V_opt: reference value function.  The sup-norm error to it is recorded
            after every sweep and also serves as the stopping criterion.
        n: maximum number of sweeps.
        theta: stop once max |V_opt - V| over all states is below this.
        gamma: discount applied to successor values (1 = undiscounted).

    Returns:
        (V, policy, iterations, delta_list) where ``delta_list[k]`` is the
        sup-norm error after sweep k+1 and ``iterations == len(delta_list)``.
    """
    all_states = env.get_all_states()
    non_terminal_states = [state for state in all_states if not (state[0] >= 21 or state[1] >= 21 or (state[1] >= 17 and state[3]))]

    V = {state: 0 for state in all_states}
    policy = {state: None for state in all_states}
    delta_list = []

    for _ in range(n):
        for state in non_terminal_states:
            max_value = float('-inf')
            best_action = None
            for action in env.get_possible_actions(state):
                transitions = env.get_transition_probabilities(state, action)
                # Reads V in place, so successors updated earlier in this sweep are used.
                action_value = sum(
                    prob * (reward + gamma * V[next_state])
                    for next_state, prob, reward in transitions
                )
                if action_value > max_value:
                    max_value = action_value
                    best_action = action

            V[state] = max_value
            policy[state] = best_action

        diff = max((abs(V_opt[key] - V[key]) for key in V), default=0.0)
        delta_list.append(diff)
        if diff < theta:
            break

    return V, policy, len(delta_list), delta_list

In [None]:
# Main Execution: timed Gauss-Seidel sweeps, error measured against V_opt.
bj = Sol_Env()

start_time = time.time()
V_gs, policy_gs, iterations_gs, deltas_gs = gauss_seidel(bj, V_opt)
end_time = time.time()

gs_elapsed = end_time - start_time
print(f"Gauss-Seidel completed in {gs_elapsed:.2f} seconds and {iterations_gs} iterations.")
plot_convergence(deltas_gs, 'Convergence of Gauss-Seidel', "gauss-seidel.jpg", 2)

In [None]:
# Exact float equality with the synchronous optimum.
print(V_gs == V_opt)
print(policy_gs == pol_opt)

Optimistic Policy iteration

In [None]:
def opt_policy_iteration(env, n=100, theta=1e-15, V_ref=None, max_eval_sweeps=None):
    """Policy iteration that records the sup-norm error to a reference after each round.

    Each round evaluates the current policy with in-place sweeps, then makes
    it greedy with respect to the resulting values.  The loop stops when no
    state's action changes, or after ``n`` rounds.

    Args:
        env: model exposing ``get_all_states()``, ``get_possible_actions(state)``
            and ``get_transition_probabilities(state, action)`` (a list of
            ``(next_state, prob, reward)`` tuples).
        n: maximum number of evaluation/improvement rounds.
        theta: an evaluation pass stops once the largest change in a sweep is below this.
        V_ref: reference values for the error curve; defaults to the
            notebook-global ``V_opt``.
        max_eval_sweeps: optional cap on evaluation sweeps per round.  ``None``
            (default) evaluates until ``theta`` is met; a small integer gives
            optimistic policy iteration with partial evaluation.

    Returns:
        (V, policy, iterations, iterations2, sup_norm_errors): rounds performed,
        total evaluation sweeps across all rounds, and the error after each round.
    """
    if V_ref is None:
        V_ref = V_opt  # notebook global produced by value_iteration_to_get_opt_solution
    all_states = env.get_all_states()
    # Same terminal test as the other solvers and get_possible_actions (dealer stands on >= 17).
    non_terminal_states = [s for s in all_states if not (s[0] >= 21 or s[1] >= 21 or (s[1] >= 17 and s[3]))]

    V = {state: 0 for state in all_states}
    # Start from the first available action in every state (None when terminal).
    policy = {}
    for state in all_states:
        possible_actions = env.get_possible_actions(state)
        policy[state] = possible_actions[0] if possible_actions else None

    sup_norm_errors = []
    iterations = 0
    iterations2 = 0
    for _ in range(n):
        # Policy evaluation (in-place sweeps)
        sweeps = 0
        while True:
            delta = 0
            for state in non_terminal_states:
                action = policy[state]
                if action is None:
                    continue
                transitions = env.get_transition_probabilities(state, action)
                if not transitions:
                    continue
                old_value = V[state]
                V[state] = sum(prob * (reward + V[next_state]) for next_state, prob, reward in transitions)
                delta = max(delta, abs(old_value - V[state]))
            iterations2 += 1
            sweeps += 1
            if delta < theta or (max_eval_sweeps is not None and sweeps >= max_eval_sweeps):
                break

        # Policy improvement
        policy_stable = True
        for state in non_terminal_states:
            old_action = policy[state]
            best_action = None
            max_value = float('-inf')

            for action in env.get_possible_actions(state):
                transitions = env.get_transition_probabilities(state, action)
                action_value = sum(prob * (reward + V[next_state]) for next_state, prob, reward in transitions)

                if action_value > max_value:
                    max_value = action_value
                    best_action = action

            if best_action is not None:
                policy[state] = best_action
                if old_action != best_action:
                    policy_stable = False

        sup_norm_errors.append(max((abs(V[s] - V_ref[s]) for s in non_terminal_states), default=0.0))
        iterations += 1
        if policy_stable:
            break

    return V, policy, iterations, iterations2, sup_norm_errors

In [None]:
bj = Sol_Env()

start_time = time.time()
V_pi, policy_pi, iterations_pi, iterations_pi2, deltas_pi = opt_policy_iteration(bj)
end_time = time.time()

# Report the rounds up to (not including) the first exact-zero error.
first_zero = next((k for k, err in enumerate(deltas_pi) if err == 0), len(deltas_pi))
trimmed_deltas = deltas_pi[:first_zero]
print(f"Optimistic Policy Iteration completed in {end_time - start_time:.2f} seconds and {len(trimmed_deltas)} iterations.")
plot_convergence(deltas_pi, 'Convergence of Policy Iteration', 's.jpg')
print(iterations_pi2)  # total policy-evaluation sweeps

In [None]:
# Sup-norm error to V_opt after each policy-iteration round.
print(deltas_pi)
#print(policy_pi)

In [None]:
# Compare policy-iteration results with the value-iteration optimum (exact equality).
print(V_pi == V_opt)
print(policy_pi == pol_opt)

In [None]:
# Per-round sup-norm errors of policy iteration (same list as printed earlier).
print(deltas_pi)

In [None]:
# Overlay the three error curves on one set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(deltas_vi, label="Value Iteration", linewidth=2)
ax.plot(deltas_gs, label="Gauss-Seidel Value Iteration", linestyle="--", linewidth=2)
ax.plot(deltas_pi, label="Policy Iteration", linewidth=2)

ax.set_xlabel("Number of Iterations")
ax.set_ylabel("Sup Norm Difference")
ax.set_title("Convergence of MDP Methods")
ax.legend()
ax.grid(True)
fig.tight_layout()
fig.savefig('MDP_conv.jpg', format="jpg", dpi=300)
plt.show()

### RL (reinforcement learning)

In [None]:
class BlackJackEnv(gym.Env):
    """Simulated blackjack game used for the reinforcement-learning experiments.

    Observable state (see ``get_current_state``): the player's current sum,
    the dealer's current sum, whether the player holds a usable ace (an ace
    currently counted as 11) and whether the player has stuck.
    Actions: 0 = hit, 1 = stick.

    Cards are drawn with replacement: 2-9 and the ace (value 11) with
    probability 1/13 each, ten-valued cards with probability 4/13.  This is
    the same distribution ``Sol_Env.get_transition_probabilities`` uses.  An
    ace is demoted from 11 to 1 only when the hand would otherwise be bust.
    """

    metadata = {'render.modes': ['human']}

    # Card values (11 = ace) and their draw probabilities.
    CARD_VALUES = np.arange(2, 12)
    CARD_PROBS = np.array([1 / 13] * 8 + [4 / 13] + [1 / 13])
    # The dealer stands once their sum reaches this value.
    DEALER_STAND = 17

    def __init__(self):
        self.observation_space = spaces.Discrete(2688)
        self.action_space = spaces.Discrete(2)   # 0 = hit, 1 = stick
        self.step_count = 0                      # actions taken in the current game (reset() clears it)
        self.actions = ['hit', 'stick']

    def check_usable_ace(self, hand):
        """Return True if ``hand`` contains an ace counted as 11 AND totals more than 21.

        Despite the name, True means an ace has to be demoted to 1 (see ``use_ace``).
        """
        return bool(np.any(hand == 11) and hand.sum() > 21)

    def use_ace(self, hand):
        """Return a copy of ``hand`` with its first 11-valued ace changed to 1."""
        demoted = hand.copy()
        demoted[np.where(demoted == 11)[0][0]] = 1
        return demoted

    def _settle_hand(self, hand):
        """Demote aces until the hand is no longer bust or no 11 is left.

        Returns ``(hand, usable_ace)``; ``usable_ace`` is True if an ace is
        still counted as 11.
        """
        while self.check_usable_ace(hand):
            hand = self.use_ace(hand)
        return hand, bool(np.any(hand == 11))

    def _draw(self, size=None):
        """Draw ``size`` cards (a single card value when ``size`` is None)."""
        return np.random.choice(self.CARD_VALUES, size, p=self.CARD_PROBS)

    def reset(self):
        """Start a new game: deal two cards to each side and clear per-game state."""
        # step() treats step_count == 1 as the first action of a game (natural-21
        # check), so the counter must restart for every game, not only per instance.
        self.step_count = 0

        # Set once the player sticks or the game is decided; hits are refused afterwards.
        self.stick_happened = False

        self.current_hand, self.usable_ace = self._settle_hand(self._draw(2))
        self.current_sum = int(self.current_hand.sum())

        dealer_hand = self._draw(2)
        self.dealer_showing_card = dealer_hand[0]   # recorded before any ace demotion
        self.dealer_hand, _ = self._settle_hand(dealer_hand)
        self.dealer_sum = int(self.dealer_hand.sum())

    def take_turn(self, player):
        """Deal one card to ``player`` ('player' or 'dealer') and update that side's sum."""
        if player == 'dealer':
            new_dealer_hand = np.append(self.dealer_hand, self._draw())
            self.dealer_hand, _ = self._settle_hand(new_dealer_hand)
            self.dealer_sum = int(self.dealer_hand.sum())

        if player == 'player':
            new_player_hand = np.append(self.current_hand, self._draw())
            self.current_hand, self.usable_ace = self._settle_hand(new_player_hand)
            self.current_sum = int(self.current_hand.sum())

    def check_game_status(self, mode='normal'):
        '''
        Check whether the game is over and, if so, who won.

        ``mode`` is kept for backward compatibility and is not used.  Busts are
        checked first, then sums of exactly 21.  Once the player has stuck and
        the dealer stands (>= DEALER_STAND), the sum closer to 21 wins.

        Returns a dict with keys 'winner' ('player', 'dealer', 'draw' or ''),
        'is_done' and 'reward' (+1 / 0 / -1 from the player's point of view).
        '''
        result = {'winner': '',
                  'is_done': False,
                  'reward': 0}

        if self.current_sum > 21:
            winner, reward = 'dealer', -1
        elif self.dealer_sum > 21:
            winner, reward = 'player', 1
        elif self.current_sum == 21 and self.dealer_sum == 21:
            winner, reward = 'draw', 0
        elif self.current_sum == 21:
            winner, reward = 'player', 1
        elif self.dealer_sum == 21:
            winner, reward = 'dealer', -1
        elif self.stick_happened and self.dealer_sum >= self.DEALER_STAND:
            diff_21_player = 21 - self.current_sum
            diff_21_dealer = 21 - self.dealer_sum
            if diff_21_player > diff_21_dealer:
                winner, reward = 'dealer', -1
            elif diff_21_player < diff_21_dealer:
                winner, reward = 'player', 1
            else:
                winner, reward = 'draw', 0
        else:
            return result  # game continues

        # A decided game also counts as stuck, so no further hits are accepted.
        self.stick_happened = True
        result['winner'] = winner
        result['is_done'] = True
        result['reward'] = reward
        return result

    def step(self, action):
        '''
        Perform one action: 0 = hit, 1 = stick.

        On the first action of a game, two-card 21s ("naturals") are settled
        before the action is applied.  After sticking, the dealer draws at most
        one card per call.

        Returns a dict with keys 'winner', 'is_done' and 'reward'.
        '''
        self.step_count += 1

        result = {'winner': '',
                  'is_done': False,
                  'reward': 0}

        if self.step_count == 1:
            # Hands are already settled by reset(); only naturals need checking.
            if self.current_sum == 21 and self.dealer_sum == 21:
                self.stick_happened = True
                result['is_done'] = True
                result['reward'] = 0
                result['winner'] = 'draw'
                return result

            elif self.current_sum == 21 and self.dealer_sum < 21:
                self.stick_happened = True
                result['is_done'] = True
                result['reward'] = 1
                result['winner'] = 'player'
                return result

            elif self.dealer_sum == 21 and self.current_sum < 21:
                self.stick_happened = True
                result['is_done'] = True
                result['reward'] = -1
                result['winner'] = 'dealer'
                return result

            # If the dealer already stands, the player is treated as stuck (hits are ignored).
            if self.dealer_sum >= self.DEALER_STAND:
                self.stick_happened = True

        if action == 0 and not self.stick_happened:   # hit
            self.take_turn('player')
            result = self.check_game_status()
            if result['is_done']:
                return result

        if action == 1:   # stick
            self.stick_happened = True

            if self.dealer_sum < self.DEALER_STAND:
                # Dealer draws one card per step() call.
                self.take_turn('dealer')
                result = self.check_game_status()
            else:
                result = self.check_game_status()
                result['is_done'] = True
                return result

        return result

    def get_current_state(self):
        '''
        Return the observable state: current_sum, dealer_sum, usable_ace and stick_happened.
        '''
        current_state = {}

        current_state['current_sum'] = self.current_sum
        current_state['dealer_sum'] = self.dealer_sum
        current_state['usable_ace'] = self.usable_ace
        current_state['stick_happened'] = self.stick_happened

        return current_state

    def render(self):
        """Print the observable state followed by the full hands."""
        print('OBSERVABLE STATES')
        print('Current Sum - {}'.format(self.current_sum))
        print('Dealer Sum - {}'.format(self.dealer_sum))
        print('Usable Ace - {}'.format(self.usable_ace))
        print('Stick happened - {}'.format(self.stick_happened))

        print('AUXILLARY INFORMATION ------------------------------')
        print('Current Hand - {}'.format(self.current_hand))
        print('Dealer Hand - {}'.format(self.dealer_hand))
        print('Dealer Showing Card - {}'.format(self.dealer_showing_card))

In [None]:
bj = BlackJackEnv() # (translated from Hungarian) because of the double down, this always has to be restarted

In [None]:
# Deal a fresh game and print both hands.
bj.reset()
bj.render()

In [None]:
# Stick (action 1) once, print the result dict, then print the table again.
print(bj.step(1))
bj.render()

#### Q-Learning

In [None]:
#### Lookup tables that convert state components and actions into indices of
#### the Q table of shape (28, 25, 2, 2, 2):
#### (current_sum, dealer_sum, usable_ace, stick_happened, action).

current_sum_to_index = dict(zip(range(4, 32), range(28)))   # player sums 4..31 -> 0..27
dealer_sum_to_index = dict(zip(range(4, 29), range(25)))    # dealer sums 4..28 -> 0..24 (all 25 Q slots)
usable_ace_index = {False: 0, True: 1}
stick_happened_index = {False: 0, True: 1}
action_index = {'hit': 0, 'stick': 1}

def get_state_q_indices(current_state):
    '''
    Map a state dict (keys 'current_sum', 'dealer_sum', 'usable_ace',
    'stick_happened') to the four leading indices of the Q table, in that order.
    '''
    lookups = (
        (current_sum_to_index, 'current_sum'),
        (dealer_sum_to_index, 'dealer_sum'),
        (usable_ace_index, 'usable_ace'),
        (stick_happened_index, 'stick_happened'),
    )
    return [table[current_state[key]] for table, key in lookups]

def get_max_action(Q_sa, current_state):
    '''
    Return the index of the greedy (highest-Q) action for ``current_state``;
    ties resolve to the lower index, as ``argmax`` does.
    '''
    action_values = Q_sa[tuple(get_state_q_indices(current_state))]
    return action_values.argmax()

def get_q_value(Q_sa, state, action):
    '''
    Look up Q(state, action) in the Q table ``Q_sa``.
    '''
    return Q_sa[(*get_state_q_indices(state), action)]
#print(current_sum_to_index2)

In [None]:
# Tabular Q-learning against BlackJackEnv.
# Q axes: (current_sum, dealer_sum, usable_ace, stick_happened, action).
Q = np.zeros((28, 25, 2, 2, 2))
# Number of updates applied to each Q entry; drives the step size 1 / (1 + visits).
visit_counts = np.zeros_like(Q)

# Value estimate per state, keyed exactly like V_opt so the two can be compared.
V_Q_dict = {}
for player_sum in range(4, 31):
    for dealer_sum in range(4, 29):
        for usable_ace in [False, True]:
            for stick_happened in [False, True]:
                state = (player_sum, dealer_sum, usable_ace, stick_happened)
                # Same filter as Sol_Env.get_all_states()
                if not (state[0] < 12 and state[2]) and not (state[0] >= 21 and not state[3]) and not (state[1] >= 21 and not state[3]):
                    V_Q_dict[state] = 0

episode_count = 0
total_episodes = 5000
gamma = 1                 #### the discount factor
theta = 1e-15
report_every = 1000       #### progress message every this many finished episodes
bj = BlackJackEnv()

# Initialize variables for tracking runtime and errors
start_time = time.time()
diffs_ql = []

while episode_count < total_episodes:
    # NOTE(review): every iteration restarts the game and takes a single greedy
    # step, so only start-of-game states are ever updated.
    bj.reset()

    current_state = bj.get_current_state()
    current_action = get_max_action(Q, current_state)

    step_result = bj.step(current_action)
    next_state = bj.get_current_state()
    immediate_reward = step_result['reward']

    # Bootstrap from the greedy next action unless the game has just ended.
    if step_result['is_done']:
        td_target = immediate_reward
    else:
        next_max_action = get_max_action(Q, next_state)
        td_target = immediate_reward + gamma * get_q_value(Q, next_state, next_max_action)

    sa_idx = (*get_state_q_indices(current_state), current_action)
    # Robbins-Monro step size from the visit count.  Deriving it from the Q
    # value itself could give alpha > 1, alpha <= 0 or a division by zero.
    alpha = 1.0 / (1 + visit_counts[sa_idx])
    visit_counts[sa_idx] += 1
    Q[sa_idx] += alpha * (td_target - Q[sa_idx])

    if step_result['is_done']:
        episode_count += 1

        # V(s) = max_a Q(s, a), consistent with the greedy policy and the TD target.
        for key in V_Q_dict:
            i, j, k, l = key
            V_Q_dict[key] = Q[i - 4, j - 4, int(k), int(l)].max()

        sup_norm = max(abs(V_opt[key] - V_Q_dict[key]) for key in V_opt)
        diffs_ql.append(sup_norm)

        if sup_norm < theta:
            break

        if episode_count % report_every == 0:
            print('---------Episode - {} -----------'.format(episode_count))

# Calculate total runtime
end_time = time.time()
runtime = end_time - start_time

print(f"Total runtime: {runtime} seconds")
print(len(diffs_ql))

plt.figure(figsize=(10, 6))
plt.plot(diffs_ql, label='Q-learning')
plt.xlabel('Finished episode')
plt.ylabel('Sup norm difference to V_opt')
plt.title('Convergence of Q Learning')
plt.legend()
plt.grid()
plt.savefig("q_learning.jpg", format="jpg", dpi=300)
plt.show()

In [None]:
def get_min_action(Q_sa, current_state):

    '''
    Return the index of the action with the LOWEST Q value in ``current_state``
    (``argmin`` over the action axis; ties resolve to the lower index).

    This is the counterpart of ``get_max_action``, used by the argmin variant
    of the Q-learning loop.
    '''

    state_q_idxs = get_state_q_indices(current_state)
    action = Q_sa[state_q_idxs[0], state_q_idxs[1], state_q_idxs[2], state_q_idxs[3], :].argmin()

    return action

# Fresh Q table for the argmin variant; axes are
# (current_sum, dealer_sum, usable_ace, stick_happened, action), indexed via get_state_q_indices.
Q = np.zeros((28, 25, 2, 2, 2))
#Q = np.full((28, 25, 2, 2, 2), -1.1)

# Value estimate per state, built with the same ranges and filter as
# Sol_Env.get_all_states() so its keys match V_opt.
V_Q_dict = {}
# Update V_Q_dict using correct ranges
for player_sum in range(4, 31):   # Player's sum
  for dealer_sum in range(4, 29):  # Dealer's sum
    for usable_ace in [False, True]:  # Boolean usable ace
      for stick_happened in [False, True]:  # Boolean stick_happened
          state = (player_sum, dealer_sum, usable_ace, stick_happened)
          # Apply the filtering conditions
          if not (state[0] < 12 and state[2]) and not (state[0] >= 21 and not state[3]) and not (state[1] >= 21 and not state[3]):
              V_Q_dict[state] = 0
#print(V_Q_dict)

episode_count = 0
total_episodes = 5000
gamma = 1             #### the discount factor
# Initial learning rate.  NOTE(review): the loop below reassigns alpha from the
# Q value itself (1 / (1 + Q[s, a])), which can exceed 1, go negative or divide by zero.
alpha = 0.1             #### learning rate
theta=1e-15  # stop early once the sup-norm error to V_opt drops below this
bj = BlackJackEnv()

# Initialize variables for tracking runtime and errors
start_time = time.time()
diffs_ql = []


while episode_count < total_episodes:


    bj.reset()  ### Initialize S (the environment's starting state)


    current_state = bj.get_current_state()
    current_action = get_min_action(Q, current_state)


    ### Take Action
    step_result = bj.step(current_action)

    next_state = bj.get_current_state()
    next_max_action = get_min_action(Q, next_state)
    immediate_reward = step_result['reward']

    next_state_q_idxs = get_state_q_indices(next_state)

    #### Getting Q value for the current state and action
    q_current_s_a = get_q_value(Q, current_state, current_action)

    #### Get Q value for the next state and max action in the next state
    q_max_s_a = get_q_value(Q, next_state, next_max_action)

    state_q_idxs = get_state_q_indices(current_state)

    #### Updating current Q(S,A)
    Q[state_q_idxs[0],state_q_idxs[1],state_q_idxs[2],state_q_idxs[3],current_action] = (1-alpha)*q_current_s_a + alpha*(immediate_reward + q_max_s_a)

    current_state = next_state  ### S=S'

    alpha = 1.0 / (1 + Q[state_q_idxs[0], state_q_idxs[1], state_q_idxs[2], state_q_idxs[3], current_action])

    if step_result['is_done']:
        episode_count+=1
        #print(current_state)

        for (i, j, k, l), _ in V_Q_dict.items():
            #print((i, j, k, l))
            k = int(k)
            l = int(l)
            V_Q_dict[(i, j, k, l)] = Q[i-4, j-4, k, l].min()

        # Compute sup norm
        sup_norm = max(abs(V_opt[key] - V_Q_dict[key]) for key in V_opt)

        diffs_ql.append(sup_norm)

        if sup_norm < theta:
            break

        if episode_count%10000 == 0:
            print('---------Episode - {} -----------'.format(episode_count))

# Calculate total runtime
end_time = time.time()
runtime = end_time - start_time

if diffs_ql and diffs_ql[-1] < 0:
    diffs_ql[-1] = 0
print(len(diffs_ql))
print(f"Total runtime: {runtime} seconds")

plt.figure(figsize=(10, 6))
plt.plot(diffs_ql)
#plt.plot(range(1, len(deltas)+1), deltas)
plt.xlabel('Iteration')
plt.ylabel('Delta')
plt.title('Convergence of Q Learning'),
plt.legend()
plt.grid()
plt.savefig("q_learning.jpg", format="jpg", dpi=300)
plt.show()

In [None]:
# Collect the tracked states whose learned value is still exactly 0 (possibly
# states the learner never updated -- worth checking against V_opt).
zero_value_states = []
for state_key, state_value in V_Q_dict.items():
    if state_value == 0:
        zero_value_states.append(state_key)

for state_key in zero_value_states:
    print(state_key)

#### SARSA

In [None]:
def get_action_epsilon_greedy(Q_sa, current_state, epsilon):
    '''
    Choose an action epsilon-greedily with respect to a cost-minimising Q table.

    With probability ``epsilon`` a uniformly random action from {0, 1} is
    returned; otherwise the greedy action, i.e. the one with the *smallest*
    Q-value (``get_min_action``). State values in this notebook are
    ``Q[...].min()``, so exploiting with the arg-max picked the worst action.

    Args:
        Q_sa: Q table indexed as ``Q_sa[i0, i1, i2, i3, action]``.
        current_state: state understood by ``get_state_q_indices``.
        epsilon: exploration probability in [0, 1].

    Returns:
        Action index (0 or 1).
    '''
    if np.random.rand() < epsilon:
        return np.random.choice([0, 1])  # explore: uniform random action
    return get_min_action(Q_sa, current_state)  # exploit: lowest expected cost

In [None]:
# On-policy SARSA on BlackJackEnv with an epsilon-greedy behaviour policy and a
# constant step size. After every episode the greedy values V(s) = min_a Q(s, a)
# are compared with V_opt (sup norm) to trace convergence.
#
# Fixes relative to the earlier version of this cell:
#   * bj.reset() ran at the top of every step, so each episode was cut off after
#     its first transition and A' was thrown away; the environment is now reset
#     once per episode and (S, A) <- (S', A') carries through until is_done.
#   * alpha was overwritten with 1 / (1 + Q(S, A')) -- indexed with the *old*
#     state and the *new* action -- which is not a learning-rate schedule; the
#     constant `alpha` is used instead.
#   * a finished episode does not bootstrap.
Q = np.zeros((28, 25, 2, 2, 2))
#Q = np.full((28, 25, 2, 2, 2), -1.1)

V_Q_dict = {}
# Register every state kept by the same filter as Sol_Env.get_all_states.
for player_sum in range(4, 31):   # Player's sum
    for dealer_sum in range(4, 29):  # Dealer's sum
        for usable_ace in [False, True]:  # Boolean usable ace
            for stick_happened in [False, True]:  # Boolean stick_happened
                state = (player_sum, dealer_sum, usable_ace, stick_happened)
                # Apply the filtering conditions
                if not (state[0] < 12 and state[2]) and not (state[0] >= 21 and not state[3]) and not (state[1] >= 21 and not state[3]):
                    V_Q_dict[state] = 0

episode_count = 0
total_episodes = 5000
gamma = 1             #### the discount factor
alpha = 0.1           #### learning rate (constant step size)
epsilon = 0.1         #### epsilon for epsilon-greedy policy
theta = 1e-15         #### sup-norm convergence threshold
bj = BlackJackEnv()

#epsilon_min = 0.01       # Lower bound for epsilon (to maintain some exploration)
#epsilon_decay = 0.9999

# Initialize variables for tracking runtime and errors
start_time = time.time()
diffs_sarsa = []

while episode_count < total_episodes:
    bj.reset()  ### Initialize S once per episode
    current_state = bj.get_current_state()
    current_action = get_action_epsilon_greedy(Q, current_state, epsilon)  ### choose A from S
    is_done = False

    while not is_done:
        step_result = bj.step(current_action)  ### observe R, S'
        next_state = bj.get_current_state()
        immediate_reward = step_result['reward']
        is_done = step_result['is_done']

        q_current_s_a = get_q_value(Q, current_state, current_action)
        # SARSA bootstraps from the action it will actually take next; a
        # finished episode has no successor value.
        if is_done:
            next_action = None
            q_next_s_a = 0.0
        else:
            next_action = get_action_epsilon_greedy(Q, next_state, epsilon)
            q_next_s_a = get_q_value(Q, next_state, next_action)

        td_target = immediate_reward + gamma * q_next_s_a
        td_error = td_target - q_current_s_a

        Q_state_idxs = get_state_q_indices(current_state)
        Q[Q_state_idxs[0], Q_state_idxs[1], Q_state_idxs[2], Q_state_idxs[3], current_action] = q_current_s_a + alpha * td_error

        current_state = next_state    ### S=S'
        current_action = next_action  ### A=A'

    episode_count += 1
    #epsilon = max(epsilon_min, epsilon * epsilon_decay)

    # Greedy state values: V(s) = min_a Q(s, a).
    for key in V_Q_dict:
        V_Q_dict[key] = Q[key[0] - 4, key[1] - 4, int(key[2]), int(key[3])].min()

    # Compute sup norm
    sup_norm = max(abs(V_opt[key] - V_Q_dict[key]) for key in V_opt)
    diffs_sarsa.append(sup_norm)

    if sup_norm < theta:
        break

    if episode_count % 10000 == 0:
        print('---------Episode - {} -----------'.format(episode_count))

# Calculate total runtime
end_time = time.time()
runtime = end_time - start_time

print(f"Total runtime: {runtime} seconds")
print(len(diffs_sarsa))

plt.figure(figsize=(10, 6))
plt.plot(diffs_sarsa, label="SARSA")  # labelled so plt.legend() has an entry
plt.xlabel('Iteration')
plt.ylabel('Delta')
plt.title('Convergence of SARSA')
plt.legend()
plt.grid()
plt.savefig("sarsa.jpg", format="jpg", dpi=300)
plt.show()

In [None]:
# SARSA again, with a smaller constant step size (alpha = 0.05).
#
# This cell previously ran the *Q-learning* update (greedy get_min_action and a
# min-backup; `epsilon` was defined but never used) while titling the plot
# "SARSA", saving it as sarsa.jpg and overwriting diffs_sarsa. It now runs SARSA
# with an epsilon-greedy policy, resets the environment once per episode (not
# once per step), keeps the constant step size instead of the Q-dependent
# 1 / (1 + Q) overwrite, and does not bootstrap past the end of an episode.
Q = np.zeros((28, 25, 2, 2, 2))
#Q = np.full((28, 25, 2, 2, 2), -1.1)

V_Q_dict = {}
# Register every state kept by the same filter as Sol_Env.get_all_states.
for player_sum in range(4, 31):   # Player's sum
    for dealer_sum in range(4, 29):  # Dealer's sum
        for usable_ace in [False, True]:  # Boolean usable ace
            for stick_happened in [False, True]:  # Boolean stick_happened
                state = (player_sum, dealer_sum, usable_ace, stick_happened)
                # Apply the filtering conditions
                if not (state[0] < 12 and state[2]) and not (state[0] >= 21 and not state[3]) and not (state[1] >= 21 and not state[3]):
                    V_Q_dict[state] = 0

episode_count = 0
total_episodes = 5000
gamma = 1             #### the discount factor
alpha = 0.05          #### learning rate (constant step size)
epsilon = 0.1         #### epsilon for epsilon-greedy policy
theta = 1e-15         #### sup-norm convergence threshold
bj = BlackJackEnv()

#epsilon_min = 0.01       # Lower bound for epsilon (to maintain some exploration)
#epsilon_decay = 0.9999

# Initialize variables for tracking runtime and errors
start_time = time.time()
diffs_sarsa = []

while episode_count < total_episodes:
    bj.reset()  ### Initialize S once per episode
    current_state = bj.get_current_state()
    current_action = get_action_epsilon_greedy(Q, current_state, epsilon)  ### choose A from S
    is_done = False

    while not is_done:
        step_result = bj.step(current_action)  ### observe R, S'
        next_state = bj.get_current_state()
        immediate_reward = step_result['reward']
        is_done = step_result['is_done']

        q_current_s_a = get_q_value(Q, current_state, current_action)
        # SARSA bootstraps from the action it will actually take next; a
        # finished episode has no successor value.
        if is_done:
            next_action = None
            q_next_s_a = 0.0
        else:
            next_action = get_action_epsilon_greedy(Q, next_state, epsilon)
            q_next_s_a = get_q_value(Q, next_state, next_action)

        td_target = immediate_reward + gamma * q_next_s_a
        td_error = td_target - q_current_s_a

        Q_state_idxs = get_state_q_indices(current_state)
        Q[Q_state_idxs[0], Q_state_idxs[1], Q_state_idxs[2], Q_state_idxs[3], current_action] = q_current_s_a + alpha * td_error

        current_state = next_state    ### S=S'
        current_action = next_action  ### A=A'

    episode_count += 1
    #epsilon = max(epsilon_min, epsilon * epsilon_decay)

    # Greedy state values: V(s) = min_a Q(s, a).
    for key in V_Q_dict:
        V_Q_dict[key] = Q[key[0] - 4, key[1] - 4, int(key[2]), int(key[3])].min()

    # Compute sup norm
    sup_norm = max(abs(V_opt[key] - V_Q_dict[key]) for key in V_opt)
    diffs_sarsa.append(sup_norm)

    if sup_norm < theta:
        break

    if episode_count % 10000 == 0:
        print('---------Episode - {} -----------'.format(episode_count))

# Calculate total runtime
end_time = time.time()
runtime = end_time - start_time

print(f"Total runtime: {runtime} seconds")
print(len(diffs_sarsa))

plt.figure(figsize=(10, 6))
plt.plot(diffs_sarsa, label="SARSA")  # labelled so plt.legend() has an entry
plt.xlabel('Iteration')
plt.ylabel('Delta')
plt.title('Convergence of SARSA')
plt.legend()
plt.grid()
plt.savefig("sarsa.jpg", format="jpg", dpi=300)
plt.show()

In [None]:
# Overlay the per-episode sup-norm traces of both TD methods on one axis.
fig, ax = plt.subplots(figsize=(10, 6))
for trace, trace_label in ((diffs_ql, "Q-Learning"), (diffs_sarsa, "SARSA")):
    ax.plot(trace, label=trace_label, linewidth=2)

ax.set_xlabel("Number of Iterations")
ax.set_ylabel("Sup Norm Difference")
ax.set_title("Convergence of RL Methods")
ax.legend()
ax.grid(True)
fig.tight_layout()
fig.savefig('RL_conv.jpg', format="jpg", dpi=300)
plt.show()

### Simulation

In [None]:
def get_min_action(Q_sa, current_state):
    '''
    Return the greedy action for ``current_state`` under a cost-minimising Q table.

    State values in this notebook are formed as ``Q[...].min()``, so the greedy
    action is the one with the *smallest* Q-value (``argmin``), not the largest.

    Args:
        Q_sa: Q table indexed as ``Q_sa[i0, i1, i2, i3, action]``.
        current_state: state understood by ``get_state_q_indices``.

    Returns:
        Index (numpy integer) of the action with the minimum Q-value; ties go to
        the lowest index, as with ``numpy.argmin``.
    '''
    state_q_idxs = get_state_q_indices(current_state)
    action = Q_sa[state_q_idxs[0], state_q_idxs[1], state_q_idxs[2], state_q_idxs[3], :].argmin()

    return action

def q_learning(total_episodes=5000, alpha=0.1, gamma=1, theta=1e-15):
    '''
    Run tabular Q-learning (greedy, cost-minimising) on BlackJackEnv and return
    its convergence trace.

    Each episode resets the environment once and steps it until ``is_done``,
    applying Q(S,A) += alpha * (R + gamma * min_a Q(S',a) - Q(S,A)) (no
    bootstrap on the final transition). After every episode the greedy values
    V(s) = min_a Q(s, a) are compared with the notebook-global ``V_opt``.

    Relies on notebook globals: ``BlackJackEnv``, ``get_min_action``,
    ``get_q_value``, ``get_state_q_indices`` and ``V_opt``.

    Args:
        total_episodes: maximum number of episodes to run.
        alpha: constant step size.
        gamma: discount factor.
        theta: stop early once the sup-norm error drops below this value.

    Returns:
        List with one sup-norm error per completed episode. The last entry is
        snapped to exactly 0.0 only if the run converged (error < theta); a run
        that merely used up ``total_episodes`` keeps its true final error.
    '''
    Q = np.zeros((28, 25, 2, 2, 2))
    #Q = np.full((28, 25, 2, 2, 2), -1.1)

    V_Q_dict = {}
    # Register every state kept by the same filter as Sol_Env.get_all_states.
    for player_sum in range(4, 31):   # Player's sum
        for dealer_sum in range(4, 29):  # Dealer's sum
            for usable_ace in [False, True]:  # Boolean usable ace
                for stick_happened in [False, True]:  # Boolean stick_happened
                    state = (player_sum, dealer_sum, usable_ace, stick_happened)
                    # Apply the filtering conditions
                    if not (state[0] < 12 and state[2]) and not (state[0] >= 21 and not state[3]) and not (state[1] >= 21 and not state[3]):
                        V_Q_dict[state] = 0

    episode_count = 0
    bj = BlackJackEnv()
    diffs_ql = []

    while episode_count < total_episodes:
        # Initialise S once per episode (resetting every step would cut each
        # episode off after its first transition).
        bj.reset()
        current_state = bj.get_current_state()
        is_done = False

        while not is_done:
            current_action = get_min_action(Q, current_state)  # greedy A in S

            step_result = bj.step(current_action)
            next_state = bj.get_current_state()
            immediate_reward = step_result['reward']
            is_done = step_result['is_done']

            q_current_s_a = get_q_value(Q, current_state, current_action)
            # Back up the best (lowest-cost) successor action; nothing follows
            # the final transition of an episode.
            if is_done:
                q_next_best = 0.0
            else:
                q_next_best = get_q_value(Q, next_state, get_min_action(Q, next_state))

            td_error = immediate_reward + gamma * q_next_best - q_current_s_a
            state_q_idxs = get_state_q_indices(current_state)
            Q[state_q_idxs[0], state_q_idxs[1], state_q_idxs[2], state_q_idxs[3], current_action] = q_current_s_a + alpha * td_error

            current_state = next_state  # S = S'

        episode_count += 1

        # Greedy state values: V(s) = min_a Q(s, a).
        for key in V_Q_dict:
            V_Q_dict[key] = Q[key[0] - 4, key[1] - 4, int(key[2]), int(key[3])].min()

        sup_norm = max(abs(V_opt[key] - V_Q_dict[key]) for key in V_opt)
        diffs_ql.append(sup_norm)

        if sup_norm < theta:
            break

    # Only a converged run may be reported as reaching zero error.
    if diffs_ql and diffs_ql[-1] < theta:
        diffs_ql[-1] = 0.0
    return diffs_ql

In [None]:
def sarsa(total_episodes=5000, alpha=0.1, gamma=1, epsilon=0.1, theta=1e-15):
    '''
    Run on-policy SARSA (epsilon-greedy, cost-minimising) on BlackJackEnv and
    return its convergence trace.

    Each episode resets the environment once, picks A from S, then repeatedly
    steps, picks A' from S' and applies
    Q(S,A) += alpha * (R + gamma * Q(S',A') - Q(S,A)) (no bootstrap on the final
    transition) before moving (S, A) <- (S', A'). After every episode the greedy
    values V(s) = min_a Q(s, a) are compared with the notebook-global ``V_opt``.

    Relies on notebook globals: ``BlackJackEnv``, ``get_action_epsilon_greedy``,
    ``get_q_value``, ``get_state_q_indices`` and ``V_opt``.

    Args:
        total_episodes: maximum number of episodes to run.
        alpha: constant step size.
        gamma: discount factor.
        epsilon: exploration probability of the behaviour policy.
        theta: stop early once the sup-norm error drops below this value.

    Returns:
        List with one sup-norm error per completed episode. The last entry is
        snapped to exactly 0.0 only if the run converged (error < theta).
    '''
    Q = np.zeros((28, 25, 2, 2, 2))
    #Q = np.full((28, 25, 2, 2, 2), -1.1)
    #Q = np.random.uniform(low=-0.01, high=0.01, size=(28, 25, 2, 2, 2))

    V_Q_dict = {}
    # Register every state kept by the same filter as Sol_Env.get_all_states.
    for player_sum in range(4, 31):   # Player's sum
        for dealer_sum in range(4, 29):  # Dealer's sum
            for usable_ace in [False, True]:  # Boolean usable ace
                for stick_happened in [False, True]:  # Boolean stick_happened
                    state = (player_sum, dealer_sum, usable_ace, stick_happened)
                    # Apply the filtering conditions
                    if not (state[0] < 12 and state[2]) and not (state[0] >= 21 and not state[3]) and not (state[1] >= 21 and not state[3]):
                        V_Q_dict[state] = 0

    episode_count = 0
    bj = BlackJackEnv()

    #epsilon_min = 0.01       # Lower bound for epsilon (to maintain some exploration)
    #epsilon_decay = 0.9999

    diffs_sarsa = []

    while episode_count < total_episodes:
        # Initialise S (and choose A) once per episode.
        bj.reset()
        current_state = bj.get_current_state()
        current_action = get_action_epsilon_greedy(Q, current_state, epsilon)
        is_done = False

        while not is_done:
            step_result = bj.step(current_action)
            next_state = bj.get_current_state()
            immediate_reward = step_result['reward']
            is_done = step_result['is_done']

            q_current_s_a = get_q_value(Q, current_state, current_action)
            # Bootstrap from the action actually taken next; nothing follows the
            # final transition of an episode.
            if is_done:
                next_action = None
                q_next_s_a = 0.0
            else:
                next_action = get_action_epsilon_greedy(Q, next_state, epsilon)
                q_next_s_a = get_q_value(Q, next_state, next_action)

            td_error = immediate_reward + gamma * q_next_s_a - q_current_s_a

            # Index Q with *this* block's state indices (the old code read an
            # undefined / unrelated global `state_q_idxs` here).
            Q_state_idxs = get_state_q_indices(current_state)
            Q[Q_state_idxs[0], Q_state_idxs[1], Q_state_idxs[2], Q_state_idxs[3], current_action] = q_current_s_a + alpha * td_error

            current_state = next_state    # S = S'
            current_action = next_action  # A = A'

        episode_count += 1
        #epsilon = max(epsilon_min, epsilon * epsilon_decay)

        # Greedy state values: V(s) = min_a Q(s, a).
        for key in V_Q_dict:
            V_Q_dict[key] = Q[key[0] - 4, key[1] - 4, int(key[2]), int(key[3])].min()

        sup_norm = max(abs(V_opt[key] - V_Q_dict[key]) for key in V_opt)
        diffs_sarsa.append(sup_norm)

        if sup_norm < theta:
            break

    # Only a converged run may be reported as reaching zero error.
    if diffs_sarsa and diffs_sarsa[-1] < theta:
        diffs_sarsa[-1] = 0.0
    return diffs_sarsa

In [None]:
from itertools import zip_longest
q_l = []

# Independent Q-learning runs; each returns its per-episode sup-norm trace.
for run in range(1000):
    diffs_ql = q_learning()
    q_l.append(diffs_ql)

# Pad runs of unequal length with NaN and transpose: sim_array_ql[t] is the
# tuple of t-th sup-norm values across all runs (one row per step).
sim_array_ql = list(zip_longest(*q_l, fillvalue=np.nan))

# Per-step mean and sample std *across runs* (axis=1 of the step x run table),
# ignoring the NaN padding of runs that already stopped.
averages_ql = np.nanmean(sim_array_ql, axis=1).tolist()
std_deviations_ql = np.nanstd(sim_array_ql, axis=1, ddof=1).tolist()

In [None]:
from itertools import zip_longest
q_s = []

# Independent SARSA runs. (This cell previously called q_learning(), so every
# curve labelled "SARSA" downstream was really a second Q-learning sample.)
for run in range(1000):
    diffs_sa = sarsa()
    q_s.append(diffs_sa)

# Pad runs of unequal length with NaN and transpose: sim_array_sa[t] is the
# tuple of t-th sup-norm values across all runs (one row per step).
sim_array_sa = list(zip_longest(*q_s, fillvalue=np.nan))

# Per-step mean and sample std *across runs* (axis=1 of the step x run table),
# ignoring the NaN padding of runs that already stopped.
averages_sa = np.nanmean(sim_array_sa, axis=1).tolist()
std_deviations_sa = np.nanstd(sim_array_sa, axis=1, ddof=1).tolist()

In [None]:
# Promote the per-step statistics (plain lists from .tolist()) to NumPy arrays
# so the plotting code below can form mean ± std element-wise.
averages_ql, std_deviations_ql, averages_sa, std_deviations_sa = [
    np.array(stat)
    for stat in (averages_ql, std_deviations_ql, averages_sa, std_deviations_sa)
]

def find_last_valid_index(avg_array, std_array):
    """Return the largest index where both the mean and the std are non-zero.

    Scans from the end of ``avg_array`` backwards; ``std_array`` is indexed at
    the same positions. Returns ``None`` when no such index exists.
    """
    return next(
        (
            idx
            for idx in reversed(range(len(avg_array)))
            if avg_array[idx] != 0 and std_array[idx] != 0
        ),
        None,
    )

# Last step at which both mean and std are still non-zero, per method
# (computed for reference; the plot below uses a fixed 400-step window).
last_valid_index_ql = find_last_valid_index(averages_ql, std_deviations_ql)
last_valid_index_sa = find_last_valid_index(averages_sa, std_deviations_sa)
#last_valid_index_sa = find_last_valid_index(rep_averages_sa, rep_std_deviations_sa)

# x-axis: step index for each method's averaged curve
steps_ql = np.arange(len(averages_ql))
steps_sa = np.arange(len(averages_sa))

# Mean curve plus a ±1 std band for each method, first 400 steps only.
plt.figure(figsize=(10, 6))
for curve_x, curve_mean, curve_std, mean_label, band_label, curve_colour in (
    (steps_ql[:400], averages_ql[:400], std_deviations_ql[:400],
     'Q-learning Mean', 'Q-learning ±1 Std Dev', 'blue'),
    (steps_sa[:400], averages_sa[:400], std_deviations_sa[:400],
     'SARSA Mean', 'SARSA ±1 Std Dev', 'green'),
):
    plt.plot(curve_x, curve_mean, label=mean_label, color=curve_colour)
    plt.fill_between(curve_x, curve_mean - curve_std, curve_mean + curve_std,
                     color=curve_colour, alpha=0.3, label=band_label)

plt.xlabel('Steps')
plt.ylabel('Sup Norm')
plt.title('Learning Curve of Q-learning and SARSA')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig("learning_curve.png", dpi=300)
plt.show()

In [None]:
# Persist the SARSA learning-curve statistics to .npy files.
# (Q-learning saves are disabled:
#  np.save('averages_ql.npy', averages_ql)
#  np.save('std_deviations_ql.npy', std_deviations_ql)
#  np.save('sim_array_ql.npy', sim_array_ql))
for npy_path, npy_data in (
    ('averages_sa.npy', averages_sa),
    ('std_deviations_sa.npy', std_deviations_sa),
    ('sim_array_sa.npy', sim_array_sa),
):
    np.save(npy_path, npy_data)

In [None]:
# Reload the SARSA statistics written by the save cell.
# (Q-learning reloads are disabled:
#  averages_ql = np.load('averages_ql.npy')
#  std_deviations_ql = np.load('std_deviations_ql.npy'))
averages_sa, std_deviations_sa = [
    np.load(npy_path) for npy_path in ('averages_sa.npy', 'std_deviations_sa.npy')
]

In [None]:
# Run lengths (number of episodes each independent run lasted).
# sim_array_* is the zip_longest transpose of the runs: one row per episode
# index, one column per run, NaN-padded after a run stops. Counting the finite
# entries down each column (axis=0) therefore yields one length per run.
# (Counting per row, as before, measured how many runs were still active at
# each episode; the summary also mixed the SARSA mean with the Q-learning std.)
lengths_ql = np.sum(~np.isnan(np.asarray(sim_array_ql, dtype=float)), axis=0)
lengths_sa = np.sum(~np.isnan(np.asarray(sim_array_sa, dtype=float)), axis=0)

# Population statistics (np.std default ddof=0), computed per method.
mean_length_ql, std_length_ql = np.mean(lengths_ql), np.std(lengths_ql)
mean_length, std_length = np.mean(lengths_sa), np.std(lengths_sa)

print(f"Q-learning mean length: {mean_length_ql}")
print(f"Q-learning standard deviation: {std_length_ql}")
print(f"SARSA mean length: {mean_length}")
print(f"SARSA standard deviation: {std_length}")