In [None]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import pandas as pd

class DQN(nn.Module):
    """Fully connected Q-network: state vector in, one unnormalised value per action out.

    Architecture: Linear(state_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, hidden_dim)
    -> ReLU -> Linear(hidden_dim, action_dim).
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable so
        # saved weights and target-network syncing keep working.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return action values for ``x`` (last dimension must equal ``state_dim``)."""
        activations = x
        for hidden_layer in (self.fc1, self.fc2):
            activations = F.relu(hidden_layer(activations))
        # Output head is linear: Q-values are unbounded.
        return self.fc3(activations)

class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions for experience replay.

    Each transition is stored as a tuple ``(state, action, reward, next_state, done)``
    of tensors. Once ``capacity`` transitions are held, each new push overwrites the
    oldest one.
    """

    def __init__(self, capacity):
        if capacity <= 0:
            raise ValueError("capacity must be a positive integer, got {}".format(capacity))
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # slot the next push writes to

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest entry when the buffer is full."""
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random and batch them.

        Returns ``(state, action, reward, next_state, done)`` stacked along dim 0.
        ``action``, ``reward`` and ``done`` are reshaped to ``(batch_size,)`` so they
        line up element-wise with per-sample Q-values. This holds whether they were
        pushed as 1-element tensors or as 0-d tensors, and also when
        ``batch_size == 1``.

        Raises:
            ValueError: (from ``random.sample``) if ``batch_size`` exceeds ``len(self)``.
        """
        batch = random.sample(self.buffer, batch_size)
        n = len(batch)
        state, action, reward, next_state, done = map(torch.stack, zip(*batch))
        # A bare .squeeze() would collapse to 0-d when n == 1. Leaving reward as
        # (n, 1) would make it broadcast to (n, n) against (n,) Q-values.
        return state, action.reshape(n), reward.reshape(n), next_state, done.reshape(n)

    def __len__(self):
        return len(self.buffer)

def compute_loss(model, target, gamma, optimizer, batch_size, replay_buffer):
    """Sample a minibatch, take one optimizer step on the TD(0) MSE loss, return the loss.

    The regression target is ``reward + gamma * max_a' target(next_state)[a'] * (1 - done)``.
    It is computed without gradient tracking, so only ``model`` is updated through
    ``optimizer``.

    Args:
        model: online Q-network; ``model(state)`` must return ``(batch, n_actions)``.
        target: target Q-network with the same output layout.
        gamma: discount factor.
        optimizer: optimizer over ``model``'s parameters.
        batch_size: number of transitions to sample.
        replay_buffer: object whose ``sample(batch_size)`` returns
            ``(state, action, reward, next_state, done)`` tensors.

    Returns:
        The loss value computed before the parameter update, as a Python float.
    """
    state, action, reward, next_state, done = replay_buffer.sample(batch_size)
    # Flatten per-transition scalars to (batch,). A (batch, 1) reward would otherwise
    # broadcast with the (batch,) Q-values into a (batch, batch) target. The MSE would
    # then silently average over every pairwise combination.
    action = action.reshape(-1).long()
    reward = reward.reshape(-1)
    done = done.reshape(-1).float()

    q_values = model(state).gather(1, action.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        next_q_values = target(next_state).max(1).values
        expected_q_values = reward + gamma * next_q_values * (1 - done)
    loss = F.mse_loss(q_values, expected_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

def train(
        env_name='CartPole-v1',
        hidden_dim=64,
        gamma=0.99,
        eps_start=1.0,
        eps_end=0.01,
        eps_decay=0.995,
        lr=1e-3,
        capacity=100000,
        batch_size=64,
        target_update=10,
        num_episodes=500,
        max_episode_reward=2_000,
        save_path='dqn_mc.pth',
        seed=None,
        ):
    """Train a DQN agent with experience replay and a periodically synced target network.

    Args:
        env_name: Gym environment id. Its ``observation_space.shape[0]`` and
            ``action_space.n`` are read, so it needs a flat observation vector and
            a discrete action space.
        hidden_dim: width of the Q-network's hidden layers.
        gamma: discount factor.
        eps_start, eps_end, eps_decay: epsilon-greedy schedule. Epsilon is multiplied
            by ``eps_decay`` before every gradient step and floored at ``eps_end``.
        lr: Adam learning rate.
        capacity: replay-buffer capacity.
        batch_size: minibatch size. Gradient steps begin once the buffer holds
            more than this many transitions.
        target_update: copy policy weights into the target net every this many
            environment steps.
        num_episodes: number of training episodes.
        max_episode_reward: an episode is also cut short once its return reaches
            this value.
        save_path: file the trained network is written to with ``torch.save``
            (the whole module is pickled). ``None`` skips saving.
        seed: optional seed for ``random``, ``torch``, the action space and the
            first environment reset.

    Returns:
        The trained policy network.
    """
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    env = gym.make(env_name)
    try:
        if seed is not None:
            env.action_space.seed(seed)
        obs_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n

        policy_net = DQN(obs_dim, action_dim, hidden_dim)
        target_net = DQN(obs_dim, action_dim, hidden_dim)
        target_net.load_state_dict(policy_net.state_dict())
        target_net.eval()

        optimizer = optim.Adam(policy_net.parameters(), lr=lr)
        replay_buffer = ReplayBuffer(capacity)

        eps = eps_start
        total_steps = 0
        for episode in range(num_episodes):
            # Seed only the first reset; later resets continue the env's own RNG stream.
            obs = env.reset(seed=seed if episode == 0 else None)[0]
            done = False
            episode_reward = 0
            while not done and episode_reward < max_episode_reward:
                total_steps += 1
                if random.random() > eps:
                    with torch.no_grad():
                        action = policy_net(torch.FloatTensor(obs)).argmax().item()
                else:
                    action = env.action_space.sample()

                # Gym >= 0.26 step API: (obs, reward, terminated, truncated, info).
                # Ignoring `truncated` would run past the TimeLimit. With negative
                # rewards (e.g. Acrobot) the reward cap above never triggers, so the
                # episode would never end.
                next_obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                # Store `terminated` as the bootstrap mask. A time-limit truncation
                # is not a true terminal state, so its successor is still bootstrapped.
                replay_buffer.push(torch.FloatTensor(obs),
                                   torch.LongTensor([action]),
                                   torch.FloatTensor([reward]),
                                   torch.FloatTensor(next_obs),
                                   torch.FloatTensor([float(terminated)]))

                obs = next_obs
                episode_reward += reward

                if len(replay_buffer) > batch_size:
                    eps = max(eps * eps_decay, eps_end)
                    compute_loss(policy_net, target_net, gamma, optimizer, batch_size, replay_buffer)

                if total_steps % target_update == 0:
                    target_net.load_state_dict(policy_net.state_dict())

            print("Episode: {}, Episode reward: {:.2f}, Epsilon: {:.2f}".format(episode, episode_reward, eps))
    finally:
        env.close()

    if save_path is not None:
        torch.save(policy_net, save_path)
    return policy_net

def test(policy_net, env_name='Acrobot-v1', num_episodes=1000, log = False):
    env = gym.make(env_name)
    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    episode_rewards = []

    for episode in range(num_episodes):
        obs = env.reset()[0]
        done = False
        episode_reward = 0
        while not done and episode_reward <= 4_000:
            with torch.no_grad():
                state = torch.FloatTensor(obs)
                action = policy_net(state).max(0)[1].item()
            next_obs, reward, done, info, _ = env.step(action)
            obs = next_obs
            episode_reward += reward

        episode_rewards.append(episode_reward)
        if log: print("Episode: {}, Episode reward: {:.2f}".format(episode, episode_reward))

    env.close()

    return episode_rewards

In [None]:
# CARTPOLE

def test_func(
        gamma=0.99,
        eps_end=0.01,
        eps_decay=0.995,
        lr=1e-3,
):
    """Train a DQN on CartPole-v1 with the given hyperparameters, then evaluate it.

    All other training settings are fixed here: hidden_dim=64, eps_start=1.0,
    capacity=100000, batch_size=64 and target_update=10. Whatever ``train``
    returns is passed to ``test`` as the policy network on the same environment.
    ``test``'s per-episode evaluation returns are returned.
    """
    env_id = 'CartPole-v1'
    training_config = dict(
        env_name=env_id,
        hidden_dim=64,
        gamma=gamma,
        eps_start=1.0,
        eps_end=eps_end,
        eps_decay=eps_decay,
        lr=lr,
        capacity=100000,
        batch_size=64,
        target_update=10,
    )
    trained_net = train(**training_config)
    return test(trained_net, env_id)

In [None]:
# Train on CartPole and evaluate the greedy policy. Keep the raw per-episode returns
# in a named variable and display a summary rather than dumping the full list.
cartpole_eval_rewards = test_func(gamma=0.99, lr=1e-3, eps_end=0.01, eps_decay=0.995)
pd.Series(cartpole_eval_rewards, name='episode_reward').describe()