In [1]:
import gymnasium as gym
import ale_py
gym.register_envs(ale_py)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from gym.wrappers import StepAPICompatibility
from collections import deque
import random
import pickle
import tqdm

## Environment Setup

In [2]:
acrobot_env_name = 'Acrobot-v1'
acrobot_env = gym.make(acrobot_env_name)
# acrobot_env = StepAPICompatibility(acrobot_env)
print("Action space:", acrobot_env.action_space)
print("State space:", acrobot_env.observation_space)

Action space: Discrete(3)
State space: Box([ -1.        -1.        -1.        -1.       -12.566371 -28.274334], [ 1.        1.        1.        1.       12.566371 28.274334], (6,), float32)


In [3]:
assault_env_name = 'ALE/Assault-ram-v5'
assault_env = gym.make(assault_env_name)
# assault_env = StepAPICompatibility(assault_env)
print("Action space:", assault_env.action_space)
print("State space:", assault_env.observation_space)

Action space: Discrete(7)
State space: Box(0, 255, (128,), uint8)


## Implementation

In [4]:
def softmax(x, temp):
    z = np.exp(x / temp - np.max(x / temp))
    return z / np.sum(z)


class QNetwork(nn.Module):
    def __init__(self, input_dim, output_dim, device):
        super(QNetwork, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
        )

        self.mlp.apply(self.init_weights)

        self.device = device
        self.to(device)

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.uniform_(m.weight, -0.001, 0.001)
            nn.init.uniform_(m.bias, -0.001, 0.001)

    def forward(self, x):
        return self.mlp(x)


class DeepValueLearning:
    def __init__(self, env, step_size, epsilon, algorithm, gamma=0.99):
        self.env = env
        self.step_size = step_size
        self.epsilon = epsilon
        self.gamma = gamma
        self.algorithm = algorithm
        self.n_actions = env.action_space.n
        self.state_dim = env.observation_space.shape[0]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.Q = QNetwork(self.state_dim, self.n_actions, self.device)
        self.optimizer = optim.SGD(self.Q.parameters(), lr=self.step_size)
        self.loss_fn = nn.MSELoss()

    def select_action(self, s):
        if np.random.uniform() < self.epsilon:
            return np.random.choice(self.n_actions)
        else:
            state_input = torch.as_tensor(s).float().unsqueeze(0).to(self.device)
            return torch.argmax(self.Q(state_input))

    def update(
        self, state_batch, action_batch, reward_batch, next_state_batch, done_batch
    ):
        q_val_batch = self.Q(state_batch)
        q_val_batch = q_val_batch.gather(1, action_batch.unsqueeze(1))
        q_val_batch = q_val_batch.squeeze(1)

        with torch.no_grad():
            done_batch = 1.0 - done_batch
            next_q_val = self.Q(next_state_batch)
            greedy_next_q_val, _ = next_q_val.max(dim=1)
            if self.algorithm == "Q-Learning":
                target_batch = (
                    reward_batch + done_batch * self.gamma * greedy_next_q_val
                )
            else:
                random_next_q_val = next_q_val.mean(dim=1)
                exp_next_q_val = (
                    self.epsilon * random_next_q_val
                    + (1 - self.epsilon) * greedy_next_q_val
                )
                target_batch = reward_batch + done_batch * self.gamma * exp_next_q_val

        # Compute loss and update weights
        loss = self.loss_fn(q_val_batch, target_batch)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

In [5]:
class ReplayBuffer:
    def __init__(self, capacity, device):
        self.device = device
        self.buffer = deque(maxlen=capacity)

    def push(self, transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

In [6]:
def run_trial(epsilon, step_size, seed, env, algorithm, use_buffer, env_name):
    torch.manual_seed(seed)
    np.random.seed(seed)
    
    agent = DeepValueLearning(env, step_size, epsilon, algorithm)
    if use_buffer:
        replay_buffer = ReplayBuffer(1_000_000, agent.device)
        replay_minibatch_size = 16
        
    max_steps_per_episode = 300
    
    torch.set_grad_enabled(True)
    
    episode_rewards = []
    for _ in tqdm.tqdm(range(1000)):
        state, _ = env.reset()
        if env_name == "ALE/Assault-ram-v5":
            state = state/255
        done = False
        total_reward = 0
        n_steps = 0
        while not done:
            action = agent.select_action(state)
            next_state, reward, done, truncated, _ = env.step(action)
             
            if env_name == "ALE/Assault-ram-v5":
                next_state = next_state/255
            done = done or truncated
            
            total_reward += reward
            
            if use_buffer: replay_buffer.push([torch.as_tensor(state).to(agent.device),
                                               torch.as_tensor(action).to(agent.device),
                                               torch.as_tensor(reward).to(agent.device),
                                               torch.as_tensor(next_state).to(agent.device),
                                               torch.as_tensor(done).float().to(agent.device)])
            
            if use_buffer:
                if n_steps % replay_minibatch_size and len(replay_buffer) > replay_minibatch_size:
                    transitions = replay_buffer.sample(replay_minibatch_size)
                    state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*transitions)
                    state_batch = torch.stack(state_batch)
                    action_batch = torch.stack(action_batch)
                    reward_batch = torch.stack(reward_batch)
                    next_state_batch = torch.stack(next_state_batch)
                    done_batch = torch.stack(done_batch)
                    agent.update(state_batch, action_batch, reward_batch, next_state_batch, done_batch)
            else:
                agent.update(torch.as_tensor(state).to(agent.device).unsqueeze(0),
                             torch.as_tensor(action).to(agent.device).unsqueeze(0),
                             torch.as_tensor(reward).to(agent.device).unsqueeze(0),
                             torch.as_tensor(next_state).to(agent.device).unsqueeze(0),
                             torch.as_tensor(done).float().to(agent.device).unsqueeze(0))
                
            state = next_state
            n_steps += 1
            if n_steps >= max_steps_per_episode:
                done = True
        
        episode_rewards.append(total_reward)
        
    return episode_rewards



## Experiment

In [7]:
epsilons = [0.01, 0.1, 0.5]
step_sizes = [1/4, 1/8, 1/16]
seeds = range(10)
envs = [acrobot_env, assault_env]
algorithms = ["Expected-SARSA", "Q-Learning"]

total_trials = len(envs) * len(seeds) * len(epsilons) * len(step_sizes) * len(algorithms) * 2


# load pickle file for results
# check if file exists
try:
    with open('results.pkl', 'rb') as f:
        results = pickle.load(f)
        trials_completed = len(results)
except:
    results = {}
    trials_completed = 0

for env in envs:
    env_name = env.env.spec.id
    for epsilon in epsilons:
        for step_size in step_sizes:
            for algorithm in algorithms:
                for use_buffer in [True, False]:
                    for seed in seeds:
                        print("Starting trial #", trials_completed + 1, "/", total_trials)
                        
                        episode_rewards = run_trial(epsilon, step_size, seed, env, algorithm, use_buffer, env_name)
                        results[(env_name, seed, epsilon, step_size, algorithm, use_buffer)] = episode_rewards
                        with open('results.pkl', 'wb') as f:
                            pickle.dump(results, f)
                        trials_completed += 1
                        print(f"Completed {trials_completed}/{total_trials} trials")

Starting trial # 6 / 720


100%|██████████| 1000/1000 [10:19<00:00,  1.61it/s]


Completed 6/720 trials
Starting trial # 7 / 720


100%|██████████| 1000/1000 [10:09<00:00,  1.64it/s]


Completed 7/720 trials
Starting trial # 8 / 720


100%|██████████| 1000/1000 [10:47<00:00,  1.54it/s]


Completed 8/720 trials
Starting trial # 9 / 720


100%|██████████| 1000/1000 [09:55<00:00,  1.68it/s]


Completed 9/720 trials
Starting trial # 10 / 720


100%|██████████| 1000/1000 [09:52<00:00,  1.69it/s]


Completed 10/720 trials
Starting trial # 11 / 720


100%|██████████| 1000/1000 [09:24<00:00,  1.77it/s]


Completed 11/720 trials
Starting trial # 12 / 720


100%|██████████| 1000/1000 [09:23<00:00,  1.77it/s]


Completed 12/720 trials
Starting trial # 13 / 720


100%|██████████| 1000/1000 [09:16<00:00,  1.80it/s]


Completed 13/720 trials
Starting trial # 14 / 720


100%|██████████| 1000/1000 [08:39<00:00,  1.93it/s]


Completed 14/720 trials
Starting trial # 15 / 720


100%|██████████| 1000/1000 [10:10<00:00,  1.64it/s]


Completed 15/720 trials
Starting trial # 16 / 720


100%|██████████| 1000/1000 [08:22<00:00,  1.99it/s]


Completed 16/720 trials
Starting trial # 17 / 720


100%|██████████| 1000/1000 [15:40:35<00:00, 56.44s/it]       


Completed 17/720 trials
Starting trial # 18 / 720


100%|██████████| 1000/1000 [15:32<00:00,  1.07it/s]


Completed 18/720 trials
Starting trial # 19 / 720


100%|██████████| 1000/1000 [16:21<00:00,  1.02it/s]


Completed 19/720 trials
Starting trial # 20 / 720


100%|██████████| 1000/1000 [12:54<00:00,  1.29it/s]


Completed 20/720 trials
Starting trial # 21 / 720


100%|██████████| 1000/1000 [08:35<00:00,  1.94it/s]


Completed 21/720 trials
Starting trial # 22 / 720


100%|██████████| 1000/1000 [09:07<00:00,  1.83it/s]


Completed 22/720 trials
Starting trial # 23 / 720


100%|██████████| 1000/1000 [09:05<00:00,  1.83it/s]


Completed 23/720 trials
Starting trial # 24 / 720


100%|██████████| 1000/1000 [08:36<00:00,  1.94it/s]


Completed 24/720 trials
Starting trial # 25 / 720


100%|██████████| 1000/1000 [09:33<00:00,  1.74it/s]


Completed 25/720 trials
Starting trial # 26 / 720


100%|██████████| 1000/1000 [09:24<00:00,  1.77it/s]


Completed 26/720 trials
Starting trial # 27 / 720


100%|██████████| 1000/1000 [09:18<00:00,  1.79it/s]


Completed 27/720 trials
Starting trial # 28 / 720


100%|██████████| 1000/1000 [09:49<00:00,  1.70it/s]


Completed 28/720 trials
Starting trial # 29 / 720


100%|██████████| 1000/1000 [09:11<00:00,  1.81it/s]


Completed 29/720 trials
Starting trial # 30 / 720


100%|██████████| 1000/1000 [09:04<00:00,  1.84it/s]


Completed 30/720 trials
Starting trial # 31 / 720


100%|██████████| 1000/1000 [08:52<00:00,  1.88it/s]


Completed 31/720 trials
Starting trial # 32 / 720


100%|██████████| 1000/1000 [09:50<00:00,  1.69it/s]


Completed 32/720 trials
Starting trial # 33 / 720


100%|██████████| 1000/1000 [09:54<00:00,  1.68it/s]


Completed 33/720 trials
Starting trial # 34 / 720


100%|██████████| 1000/1000 [09:04<00:00,  1.84it/s]


Completed 34/720 trials
Starting trial # 35 / 720


100%|██████████| 1000/1000 [09:26<00:00,  1.76it/s]


Completed 35/720 trials
Starting trial # 36 / 720


100%|██████████| 1000/1000 [08:53<00:00,  1.87it/s]


Completed 36/720 trials
Starting trial # 37 / 720


 17%|█▋        | 168/1000 [01:46<08:46,  1.58it/s]


KeyboardInterrupt: 