In [6]:
import gymnasium as gym
from tqdm import tqdm
import numpy as np

In [81]:
class BlackJackAgent:
    """Tabular Monte Carlo agent for a Blackjack-style episodic environment.

    Action values are learned with first-visit Monte Carlo control using
    exploring starts (the first action of every episode is sampled from the
    environment's action space; later actions follow the current policy).
    Returns are undiscounted. Training runs inside the constructor.

    Attributes:
        V:  dict mapping state -> mean first-visit return
            (populated by ``mc_prediction_state``).
        Q:  dict mapping (state, action) -> mean first-visit return
            (populated by ``train``).
        env: environment following the Gymnasium API, i.e. ``reset()``
            returns ``(obs, info)`` and ``step(a)`` returns
            ``(obs, reward, terminated, truncated, info)``.
        state_space: every ``(i, j, k)`` tuple with ``i < 32``, ``j < 11``,
            ``k < 2``; this mirrors the Blackjack-v1 observation space
            (player sum, dealer card, usable ace).
        pi: dict mapping state -> action currently chosen by the policy.
    """

    # Actions considered when making the policy greedy (Blackjack-v1: 0 = stick, 1 = hit).
    _ACTIONS = (0, 1)

    def __init__(self, env, num_iter):
        """Build the tables and immediately train for ``num_iter`` episodes."""
        self.V = {}
        self.Q = {}
        self.env = env
        self.state_space = [(i, j, k) for i in range(32) for j in range(11) for k in range(2)]
        self.pi = {state: self.initial_policy(state) for state in self.state_space}

        self.train(num_iter)

    def initial_policy(self, state):
        """Return the starting action for ``state``: 0 when ``state[0]`` is 20 or 21, else 1."""
        return 0 if state[0] in (20, 21) else 1

    def generate_episode(self):
        """Play one episode with an exploring start.

        Returns:
            List of ``(state, action, reward)`` tuples in time order.
        """
        episode = []
        state, _ = self.env.reset()
        first_step = True
        while True:
            if first_step:
                # Exploring starts: the first action ignores the policy.
                action = self.env.action_space.sample()
                first_step = False
            else:
                action = self.pi[state]
            next_state, reward, terminated, truncated, _ = self.env.step(action)
            episode.append((state, action, reward))
            state = next_state
            # An episode is over on either flag; ignoring `truncated` would
            # loop forever on a time-limited environment.
            if terminated or truncated:
                break
        return episode

    @staticmethod
    def _first_visit_returns(episode, with_action):
        """Return ``[(key, G), ...]`` for the first visit of every key in ``episode``.

        ``key`` is ``(state, action)`` when ``with_action`` is true, otherwise
        ``state``. ``G`` is the undiscounted sum of rewards from that step to
        the end of the episode. Entries are ordered from the last step to the
        first, i.e. the order in which the returns are accumulated.
        """
        keys = [(s, a) if with_action else s for s, a, _ in episode]
        # Earliest time index at which each key occurs.
        first_seen = {}
        for t, key in enumerate(keys):
            first_seen.setdefault(key, t)

        visits = []
        G = 0
        for t in range(len(episode) - 1, -1, -1):
            G = G + episode[t][2]
            # Only the earliest occurrence counts as the "first visit".
            if first_seen[keys[t]] == t:
                visits.append((keys[t], G))
        return visits

    def mc_prediction_state(self, num_episodes):
        """Estimate ``V`` by first-visit Monte Carlo over ``num_episodes`` episodes.

        Episodes come from ``generate_episode``, so each one begins with a
        randomly sampled action. Averages are taken over the returns seen in
        this call only; the result overwrites the matching entries of ``V``.
        """
        totals = {}
        counts = {}
        for _ in tqdm(range(num_episodes)):
            episode = self.generate_episode()
            for state, G in self._first_visit_returns(episode, with_action=False):
                # Running sum and count give the same mean as storing every
                # return, without re-summing a growing list.
                totals[state] = totals.get(state, 0) + G
                counts[state] = counts.get(state, 0) + 1
                self.V[state] = totals[state] / counts[state]

    def train(self, num_episodes):
        """Run first-visit Monte Carlo control for ``num_episodes`` episodes.

        After each first visit of ``(state, action)`` the entry ``Q[(state,
        action)]`` becomes the mean of the returns observed for it during this
        call, and ``pi[state]`` is made greedy with respect to ``Q``. Actions
        with no estimate yet count as 0 in the comparison.
        """
        totals = {}
        counts = {}
        for _ in tqdm(range(num_episodes)):
            episode = self.generate_episode()
            for (state, action), G in self._first_visit_returns(episode, with_action=True):
                key = (state, action)
                totals[key] = totals.get(key, 0) + G
                counts[key] = counts.get(key, 0) + 1
                self.Q[key] = totals[key] / counts[key]
                q_values = [self.Q.get((state, a), 0) for a in self._ACTIONS]
                # Store a plain int rather than numpy's integer type.
                self.pi[state] = int(np.argmax(q_values))

    def evaluate(self, num_steps):
        """Play ``num_steps`` episodes with the current policy (no exploration).

        Returns:
            Sum of all rewards collected over those episodes.
        """
        rewards = []
        for _ in range(num_steps):
            state, _ = self.env.reset()
            done = False
            episode_reward = 0
            while not done:
                action = self.pi[state]
                state, r, terminated, truncated, _ = self.env.step(action)
                episode_reward += r
                done = terminated or truncated
            rewards.append(episode_reward)
        return np.sum(rewards)

        

In [82]:
class RandomAgent:
    """Baseline agent that chooses every action by sampling the action space.

    Attributes:
        env: environment following the Gymnasium API, i.e. ``reset()``
            returns ``(obs, info)`` and ``step(a)`` returns
            ``(obs, reward, terminated, truncated, info)``.
    """

    def __init__(self, env):
        self.env = env

    def evaluate(self, num_steps):
        """Play ``num_steps`` episodes with random actions.

        Returns:
            Sum of all rewards collected over those episodes.
        """
        rewards = []
        for _ in range(num_steps):
            # The observation is irrelevant to a random policy, so it is discarded.
            self.env.reset()
            done = False
            episode_reward = 0
            while not done:
                action = self.env.action_space.sample()
                _, r, terminated, truncated, _ = self.env.step(action)
                episode_reward += r
                # Stop on either flag; ignoring `truncated` would loop forever
                # on a time-limited environment.
                done = terminated or truncated
            rewards.append(episode_reward)
        return np.sum(rewards)

In [83]:
# Experiment configuration, named so the numbers can be tuned in one place.
SEED = 0
NUM_TRAIN_EPISODES = 1_000_000
NUM_EVAL_EPISODES = 500

env = gym.make('Blackjack-v1')
# Seed the environment and its action-space sampler once so that re-running
# the notebook from a fresh kernel draws the same cards and random actions.
env.reset(seed=SEED)
env.action_space.seed(SEED)

# Training happens inside the BlackJackAgent constructor.
mc_agent = BlackJackAgent(env, NUM_TRAIN_EPISODES)
random_agent = RandomAgent(env)

# Each score is the summed reward over NUM_EVAL_EPISODES episodes.
print("Random Agent: ", random_agent.evaluate(NUM_EVAL_EPISODES))
print("MC Agent: ", mc_agent.evaluate(NUM_EVAL_EPISODES))


  0%|          | 0/1000000 [00:00<?, ?it/s]

100%|██████████| 1000000/1000000 [05:33<00:00, 2998.11it/s]


Random Agent:  -180.0
MC Agent:  -44.0


In [91]:
# Re-run the 500-episode evaluation for both agents; the scores vary from run
# to run because each evaluation plays fresh episodes.
for label, agent in (("Random Agent: ", random_agent), ("MC Agent: ", mc_agent)):
    print(label, agent.evaluate(500))

Random Agent:  -219.0
MC Agent:  -53.0
