In [21]:
import torch
import torch.nn as nn
import numpy as np
import random
from collections import namedtuple, deque
import gym

In [20]:
# For visualization
from gym.wrappers.monitoring import video_recorder
from IPython.display import HTML
from IPython import display 
import glob

In [22]:
# One seed shared by every random number generator the notebook draws from.
SEED = 0

# Seeding only the environment is not enough for a repeatable run: the agent
# explores through the `random` module and the Q-networks are initialised by
# torch's generator, so those are seeded here as well.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = gym.make('LunarLander-v2')
env.seed(SEED)
print('State shape: ', env.observation_space.shape)
print('Number of actions: ', env.action_space.n)

DependencyNotInstalled: box2D is not installed, run `pip install gym[box2d]`

In [14]:
# --- DQN hyperparameters ---
BUFFER_SIZE = 100_000    # maximum number of transitions kept in the replay memory
BATCH_SIZE = 64          # transitions drawn from the replay memory per learning step
GAMMA = 0.99             # discount factor applied to the bootstrapped next-state value
TAU = 0.001              # interpolation factor of the soft target-network update
LEARNING_RATE = 0.0005   # learning rate handed to the agent's optimizer
UPDATE_EVERY = 4         # number of environment steps between learning updates

In [15]:
class QNetwork(nn.Module):
    """Fully connected network that maps a state vector to one value per action.

    Two hidden layers of 64 units with ReLU activations feed a linear output
    layer of width ``action_size``.
    """

    def __init__(self, state_size, action_size):
        """Create the three linear layers.

        Args:
            state_size: number of input features.
            action_size: number of outputs (one per action).
        """
        super().__init__()
        hidden_units = 64
        self.fc1 = nn.Linear(state_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, action_size)

    def forward(self, state):
        """Return the action values computed for ``state``."""
        hidden = nn.functional.relu(self.fc1(state))
        hidden = nn.functional.relu(self.fc2(hidden))
        # No activation on the output layer: the raw linear outputs are returned.
        return self.fc3(hidden)


In [16]:
class Agent:
    """DQN agent: stores transitions, acts epsilon-greedily and learns from replay.

    Two ``QNetwork`` instances are kept: ``qnetwork_local`` is trained by the
    optimizer, ``qnetwork_target`` supplies the bootstrap targets and is moved
    towards the local network by a soft update after every learning step.
    """

    def __init__(self, state_size, action_size, learning_rate):
        """Build both networks, the Adam optimizer and the replay memory.

        Args:
            state_size: number of features in a state vector.
            action_size: number of discrete actions.
            learning_rate: learning rate passed to ``torch.optim.Adam``.
        """
        self.state_size = state_size
        self.action_size = action_size

        self.qnetwork_local = QNetwork(state_size, action_size)
        self.qnetwork_target = QNetwork(state_size, action_size)
        self.optimizer = torch.optim.Adam(
            self.qnetwork_local.parameters(), lr=learning_rate)
        # replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE)

        # Counts environment steps modulo UPDATE_EVERY (see ``step``).
        self.time_step = 0

    def step(self, state, action, reward, next_state, done):
        """Record one transition and learn on every ``UPDATE_EVERY``-th call.

        Learning only happens once the memory holds more than ``BATCH_SIZE``
        transitions.
        """
        self.memory.add(state, action, reward, next_state, done)
        # Advance the counter with wrap-around. The previous code tried to call
        # the integer and never stored the result.
        self.time_step = (self.time_step + 1) % UPDATE_EVERY
        if self.time_step == 0 and len(self.memory) > BATCH_SIZE:
            experience = self.memory.sample()
            self.learn(experience, GAMMA)

    def act(self, state, eps=0.0):
        """Choose an action for ``state`` with an epsilon-greedy rule.

        Args:
            state: NumPy array holding a single state.
            eps: probability threshold for picking a uniformly random action.

        Returns:
            Index of the selected action.
        """
        state = torch.from_numpy(state).float().unsqueeze(0)
        # Evaluate without gradient tracking, then restore training mode.
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, experience, gamma):
        """Update the local network from a batch of experience tuples.

        Args:
            experience: tuple ``(states, actions, rewards, next_states, dones)``
                of tensors; presumably column tensors as produced by the
                replay memory's ``sample`` -- confirm against the caller.
            gamma: discount factor.
        """
        states, actions, rewards, next_states, dones = experience
        # Highest target-network value of each next state, kept as a column.
        q_targets_next = self.qnetwork_target(
            next_states).detach().max(1)[0].unsqueeze(1)
        # (1 - dones) removes the bootstrap term for terminal transitions.
        q_targets = rewards+gamma*q_targets_next*(1-dones)
        # ``gather`` only accepts an int64 index, so cast in case the batch
        # carries the actions as floats.
        q_expected = self.qnetwork_local(states).gather(1, actions.long())
        # Compute the loss and gradient
        loss = torch.nn.functional.mse_loss(q_expected, q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)

    def soft_update(self, local_model, target_model, tau):
        """Blend the local weights into the target model in place.

        θ_target = τ*θ_local + (1 - τ)*θ_target

        Args:
            local_model: model the weights are read from.
            target_model: model whose parameters are overwritten.
            tau: interpolation factor.
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            # Mix parameter by parameter; a module object has no ``.data``.
            target_param.data.copy_(
                tau*local_param.data+(1.0-tau)*target_param.data)

In [17]:
class ReplayBuffer:
    """Fixed-size buffer of experience tuples with uniform random sampling."""

    def __init__(self, action_size, buffer_size, batch_size):
        """Create an empty buffer.

        Args:
            action_size: number of actions (stored, not otherwise used here).
            buffer_size: maximum number of stored experiences; once full, the
                oldest entries are discarded by the underlying ``deque``.
            batch_size: number of experiences drawn by ``sample``.
        """
        self.action_size = action_size
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.experience = namedtuple('experience', field_names=[
                                     'state', 'action', 'reward', 'next_state', 'done'])
        # Holds ``self.experience`` tuples.
        self.memory = deque(maxlen=buffer_size)

    def add(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        e = self.experience(state, action, reward, next_state, done)
        self.memory.append(e)

    def sample(self):
        """Draw ``batch_size`` stored experiences at random and stack them.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` of tensors
            with one row per sampled experience. ``actions`` is int64 so it can
            be used as a ``gather`` index; the other tensors are float32, with
            ``dones`` encoded as 0.0 / 1.0.

        Raises:
            ValueError: raised by ``random.sample`` when fewer than
                ``batch_size`` experiences are stored.
        """
        batch = [e for e in random.sample(self.memory, k=self.batch_size)
                 if e is not None]

        def column(field):
            # Stack one field of every sampled experience, one row per experience.
            return np.vstack([getattr(e, field) for e in batch])

        states = torch.from_numpy(column('state')).float()
        actions = torch.from_numpy(column('action')).long()
        rewards = torch.from_numpy(column('reward')).float()
        next_states = torch.from_numpy(column('next_state')).float()
        # Cast the flags on the NumPy side: ``from_numpy`` returns a tensor,
        # which has no ``astype``.
        dones = torch.from_numpy(column('done').astype(np.uint8)).float()
        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        """Return the number of experiences currently stored."""
        return len(self.memory)

In [19]:
def dqn(agent: Agent, n_episodes, max_time_step, eps_start, eps_end, eps_decay):
    """Train ``agent`` on the module-level ``env`` with epsilon-greedy exploration.

    Args:
        agent: agent providing ``act``, ``step`` and ``qnetwork_local``.
        n_episodes: number of episodes to run.
        max_time_step: upper bound on the steps taken within one episode.
        eps_start: epsilon used for the first episode.
        eps_end: lower bound for epsilon.
        eps_decay: amount subtracted from epsilon after every episode.

    Returns:
        List with the accumulated reward of every episode, in order.
    """
    all_scores = []
    recent_scores = deque(maxlen=100)  # rolling window for the progress printout
    epsilon = eps_start
    for episode_idx in range(n_episodes):
        observation = env.reset()
        episode_return = 0
        for _ in range(max_time_step):
            chosen_action = agent.act(observation, epsilon)
            next_observation, reward, done, _info = env.step(chosen_action)
            agent.step(observation, chosen_action, reward, next_observation, done)
            observation = next_observation
            episode_return += reward
            if done:
                break
        recent_scores.append(episode_return)
        all_scores.append(episode_return)
        # Linear decay, clipped at eps_end.
        epsilon = max(eps_end, epsilon - eps_decay)
        if episode_idx % 100 != 0:
            continue
        # Every 100th episode (starting with episode 0): report the mean of the
        # rolling window and checkpoint the local network.
        print(episode_idx, np.mean(recent_scores))
        torch.save(agent.qnetwork_local.state_dict(), 'checkpoint.pt')
    return all_scores

In [None]:
class 

In [None]:
# Build the agent. The sizes are hard-coded; presumably they match
# env.observation_space / env.action_space of the environment -- confirm.
agent = Agent(
    state_size=8,
    action_size=4,
    learning_rate=LEARNING_RATE,
)

# Run training and keep the per-episode scores.
scores = dqn(
    agent,
    n_episodes=2000,
    max_time_step=1000,
    eps_start=1.0,
    eps_end=0.01,
    eps_decay=0.01,
)
