In [33]:
import gym
import numpy as np
import matplotlib.pyplot as plt

# Train on Blackjack with Speedy Q-learning (SQL)

In [34]:
# Create the Blackjack environment (the id passed to gym.make is 'Blackjack-v1')
env = gym.make('Blackjack-v1')

In [39]:
# Inspect the observation space to see how states are encoded
print(env.observation_space)

Tuple(Discrete(32), Discrete(11), Discrete(2))


In [48]:
class SpeedyQLearning:
    """Tabular epsilon-greedy Q-learning agent with a momentum-style correction.

    Q-values live in dictionaries keyed by ``str(state)``; every entry holds
    one value per action and is created with zeros the first time it is needed.

    NOTE(review): the update below is a standard Q-learning step plus
    ``beta * (Q - Q_prev)``. It has not been checked against the published
    Speedy Q-learning rule -- confirm this is the intended variant.
    """

    def __init__(self, action_space, alpha=0.1, gamma=0.99, epsilon=0.1, beta=0.5):
        """Store the hyper-parameters and create empty Q tables.

        Args:
            action_space: Object exposing ``n``, the number of discrete actions.
            alpha: Step size applied to the TD error.
            gamma: Discount applied to the bootstrapped next-state value.
            epsilon: Probability of picking a uniformly random action.
            beta: Weight of the ``Q - Q_prev`` correction term.
        """
        self.action_space = action_space
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.beta = beta
        self.Q = {}  # str(state) -> current per-action values
        self.Q_prev = {}  # str(state) -> per-action values before their last update

    def _values(self, table, state_key):
        """Return the per-action row of ``table`` for ``state_key`` (zeros if unseen)."""
        return table.setdefault(state_key, np.zeros(self.action_space.n))

    def choose_action(self, state):
        """Pick an action epsilon-greedily with respect to the current Q-values.

        Args:
            state: Current observation; only its ``str()`` form is used as a key.

        Returns:
            The index of the selected action.
        """
        if np.random.uniform(0, 1) < self.epsilon:
            return np.random.choice(self.action_space.n)
        # Unseen states get a zero row, so argmax falls back to action 0.
        return np.argmax(self._values(self.Q, str(state)))

    def learn(self, state, action, reward, next_state, done=False):
        """Update ``Q[state][action]`` from a single observed transition.

        Args:
            state: Observation the action was taken in.
            action: Index of the action that was taken.
            reward: Reward received for the transition.
            next_state: Observation that followed, or ``None`` if the episode
                terminated (there is no successor to bootstrap from).
            done: ``True`` when the transition terminated the episode; has the
                same effect as passing ``next_state=None``.
        """
        state_key = str(state)
        q_row = self._values(self.Q, state_key)
        prev_row = self._values(self.Q_prev, state_key)

        if done or next_state is None:
            # A terminal transition has no future return: the target is the
            # reward alone. Bootstrapping here would leak the value of an
            # observation that is never acted upon into the estimate.
            bootstrap = 0.0
        else:
            bootstrap = np.max(self._values(self.Q, str(next_state)))

        predict = q_row[action]
        target = reward + self.gamma * bootstrap
        q_row[action] = (
            predict
            + self.alpha * (target - predict)
            + self.beta * (predict - prev_row[action])
        )
        # Remember the pre-update value for the next correction term.
        prev_row[action] = predict

# Build the SQL agent on top of the environment's action space
action_space = env.action_space
sql_agent = SpeedyQLearning(action_space)


In [55]:
# Training loop: play epsilon-greedy episodes and update the agent after every step
num_episodes = 1000
total_rewards = []

for episode in range(num_episodes):
    # reset() returns (observation, info); only the observation is needed here.
    state, info = env.reset()
    done = False
    total_reward = 0
    while not done:
        action = sql_agent.choose_action(state)
        # step() returns five values; the third and fourth are separate
        # "terminated" and "truncated" flags and either one ends the episode.
        next_state, reward, terminated, truncated, info = env.step(action)
        # A terminated episode has no future return, so the update must not
        # bootstrap from the observation returned with the final step. None
        # stands in for "no successor": the agent never updates a value for it,
        # so the bootstrap term is 0. Truncation still bootstraps normally.
        sql_agent.learn(state, action, reward, None if terminated else next_state)
        state = next_state
        total_reward += reward
        done = terminated or truncated

    total_rewards.append(total_reward)

# Plot the undiscounted return of every training episode
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(total_rewards)
ax.set_ylabel('Total Reward')
ax.set_xlabel('Episode')
ax.set_title('Total Reward per Episode')
plt.show()

# Release the environment's resources
env.close()

ValueError: attempt to get argmax of an empty sequence