In [9]:
import gym
import collections
import numpy as np
def discretize(observation, bins):
    discrete_obs = [np.digitize(obs, bins[i]) for i, obs in enumerate(observation)]
    return tuple(discrete_obs)

def choose_action(Q, state, epsilon, env):
    if np.random.rand() < epsilon:
        return env.action_space.sample()
    else:
        return np.argmax(Q[state])


def update_Q_SARSA(Q, alpha, state, action, reward, gamma, next_state, next_action):
    state_q = Q[state+(action,)]
    next_state_q = Q[next_state+(next_action,)]
    Q[state+(action,)] = state_q + alpha * (reward + gamma * next_state_q-state_q)


def train(env, episodes=50000, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.05, alpha=0.1):
    num_bins = [6, 12, 12, 24]  # Define the number of bins for each observation
    bins = [
        np.linspace(-2.4, 2.4, num_bins[0]),
        np.linspace(-5, 5, num_bins[1]),
        np.linspace(-0.418, 0.418, num_bins[2]),
        np.linspace(-5, 5, num_bins[3])
    ]

    Q = np.random.uniform(low=0, high=1, size=(num_bins + [env.action_space.n]))
    rewards = []

    for episode in range(episodes):
        observation = env.reset()[0]
        state = discretize(tuple(observation), bins)
        total_reward = 0
        done = False
        action = choose_action(Q, state, epsilon, env)
        while not done:
            obs, reward, done, i, j = env.step(action)
            next_state = discretize(obs, bins)
            next_action = choose_action(Q, next_state, epsilon,env)
            update_Q_SARSA(Q, alpha, state, action, reward, gamma, next_state,next_action)
            state = next_state
            action = next_action
            total_reward += reward

        rewards.append(total_reward)

        epsilon = max(epsilon_min, epsilon * epsilon_decay)


        if episode >= 100 and np.mean(rewards[-100:]) >= 195.0:
            print(f"Solved in {episode + 1} episodes!")
            break

        if (episode + 1) % 100 == 0:
            print(f"Episode {episode + 1}, Average Reward (last 100 episodes): {np.mean(rewards[-100:]):.2f}")

env = gym.make('CartPole-v1')

train(env)

Episode 100, Average Reward (last 100 episodes): 27.47
Episode 200, Average Reward (last 100 episodes): 33.43
Episode 300, Average Reward (last 100 episodes): 35.00
Episode 400, Average Reward (last 100 episodes): 53.45
Episode 500, Average Reward (last 100 episodes): 91.76
Episode 600, Average Reward (last 100 episodes): 124.24
Episode 700, Average Reward (last 100 episodes): 115.31
Episode 800, Average Reward (last 100 episodes): 167.51
Solved in 828 episodes!
