# Simple DRQN implementation in TF2
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jcformanek/Simple-TF2-DRQN/blob/main/tf2_drqn.ipynb)

In [1]:
!pip install dm-sonnet
!pip install trfl

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting dm-sonnet
  Downloading dm_sonnet-2.0.0-py3-none-any.whl (254 kB)
[K     |████████████████████████████████| 254 kB 17.0 MB/s 
Installing collected packages: dm-sonnet
Successfully installed dm-sonnet-2.0.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting trfl
  Downloading trfl-1.2.0-py3-none-any.whl (104 kB)
[K     |████████████████████████████████| 104 kB 15.7 MB/s 
Installing collected packages: trfl
Successfully installed trfl-1.2.0


In [2]:
import copy
import random
import gym
import tensorflow as tf 
import sonnet as snt
import numpy as np
import trfl

In [3]:
class MaskedVelocityCartPole:
  """CartPole-v1 whose observations have their velocity components zeroed.

  Observation indices 1 and 3 (the velocity terms) are multiplied by zero,
  while indices 0 and 2 pass through unchanged. Because velocities are no
  longer observable, an agent needs memory to solve the task.
  """

  # Element-wise float32 mask: keep indices 0 and 2, zero indices 1 and 3.
  _OBS_MASK = np.array([1, 0, 1, 0], "float32")

  def __init__(self):
    self.env = gym.make("CartPole-v1")

  def _mask(self, observation):
    """Return a masked copy of a raw environment observation."""
    return observation * self._OBS_MASK

  def reset(self):
    first_obs = self.env.reset()
    return self._mask(first_obs)

  def step(self, action):
    next_obs, reward, done, info = self.env.step(action)
    return self._mask(next_obs), reward, done, info

In [4]:
class SequenceReplayBuffer:

  def __init__(self, obs_size, max_size=2000, batch_size=32, max_sequence_len=20):
    """Replay buffer that stores fixed-length sequences rather than single transitions.

    There are `max_size` slots, each holding up to `max_sequence_len`
    consecutive timesteps. An episode that ends part-way through a slot has
    the remainder of that slot marked as padding (zero_mask == 0). An episode
    longer than a slot continues in the next slot. Slots are reused in ring
    order once `max_size` sequences have been written.

    Args:
      obs_size: length of each observation vector.
      max_size: number of sequence slots.
      batch_size: number of sequences returned by `sample`.
      max_sequence_len: number of timesteps per slot.
    """

    self.obs = np.zeros((max_size, max_sequence_len, obs_size), "float32")
    self.rewards = np.zeros((max_size, max_sequence_len), "float32")
    self.act = np.zeros((max_size, max_sequence_len), "int64")
    self.dones = np.zeros((max_size, max_sequence_len), "float32")
    # 1.0 for timesteps holding real data, 0.0 for padding / unwritten steps.
    self.zero_mask = np.zeros((max_size, max_sequence_len), "float32")

    # Observation the next pushed action was taken from.
    self.prev_obs = None

    # Counters: `t` is the write position inside the current slot and
    # `counter` is the total number of slots started so far (the current
    # slot is `counter % max_size`).
    self.t = 0
    self.counter = 0

    # Sizes
    self.max_size = max_size
    self.batch_size = batch_size
    self.max_sequence_len = max_sequence_len

  def push_first(self, first_obs):
    """Record the first observation of a new episode (no transition yet)."""
    self.prev_obs = first_obs

  def push(self, next_obs, action, reward, done):
    """Store one transition: (prev_obs, action, reward, done).

    `next_obs` becomes `prev_obs` for the following call. When the slot
    fills up, or the episode ends, writing moves on to the next slot.
    """
    idx = self.counter % self.max_size
    self.obs[idx, self.t] = self.prev_obs
    self.act[idx, self.t] = action
    self.rewards[idx, self.t] = reward
    self.dones[idx, self.t] = done
    self.zero_mask[idx, self.t] = 1

    self.t += 1
    self.prev_obs = next_obs

    # End of the episode: mark the rest of this slot as padding and force a
    # move to the next slot so the next episode starts a fresh sequence.
    if done and self.t < self.max_sequence_len:
        self.zero_mask[idx, self.t:] *= 0.0
        self.t = self.max_sequence_len

    # Move to next sequence; clear the mask of the slot about to be reused.
    if self.t >= self.max_sequence_len:
        self.counter += 1
        self.zero_mask[self.counter % self.max_size, :] *= 0.0
        self.t = 0

  def is_ready(self):
    """True once at least `batch_size` sequence slots have been started."""
    return self.counter >= self.batch_size

  def _sample_indices(self):
    """Return `batch_size` random indices (with replacement) of fully written slots.

    The slot at `counter % max_size` is still being written. Its entries
    beyond the current timestep are stale, left over from the sequence it is
    overwriting, and a learner that bootstraps step t from step t+1 would
    read them. That slot is therefore excluded both before and after the
    ring buffer wraps.
    """
    if self.counter < self.max_size:
      # Slots [0, counter) are complete; slot `counter` is still in progress.
      return np.random.randint(0, self.counter, self.batch_size)
    # Wrapped: every slot except the in-progress one is complete. Draw an
    # offset in [0, max_size - 1) and step past the in-progress slot.
    # max(..., 1) keeps the degenerate max_size == 1 case from raising.
    current = self.counter % self.max_size
    offsets = np.random.randint(0, max(self.max_size - 1, 1), self.batch_size)
    return (current + 1 + offsets) % self.max_size

  def sample(self):
    """Sample `batch_size` complete sequences uniformly with replacement.

    Returns:
      Tuple of tensors (obs [B, T, obs_size] float32, act [B, T] int64,
      rewards [B, T] float32, dones [B, T] float32, zero_mask [B, T] float32),
      where B = batch_size and T = max_sequence_len.
    """
    idxs = self._sample_indices()

    obs_batch = tf.convert_to_tensor(self.obs[idxs])
    act_batch = tf.convert_to_tensor(self.act[idxs])
    rew_batch = tf.convert_to_tensor(self.rewards[idxs])
    done_batch = tf.convert_to_tensor(self.dones[idxs])
    zero_mask_batch = tf.convert_to_tensor(self.zero_mask[idxs])

    return obs_batch, act_batch, rew_batch, done_batch, zero_mask_batch

In [5]:
class Agent:
  """Recurrent Double-DQN agent: Linear -> ReLU -> LSTM -> Linear Q-head.

  Args:
    obs_size: length of each observation vector.
    num_actions: number of discrete actions.
    lr: Adam learning rate.
    gamma: discount factor used in the Bellman target.
    target_update: the target network is overwritten with the online weights
      whenever `learn_step % target_update == 0` (including the first step).
    eps_min: lower bound on the exploration rate.
    eps_decay_steps: epsilon is max(eps_min, 1 - act_step / eps_decay_steps).
  """

  def __init__(self, obs_size, num_actions, lr=5e-4, gamma=0.99, target_update=200,
    eps_min=0.05, eps_decay_steps=20_000):

    # Parameters
    self.obs_size = obs_size
    self.num_actions = num_actions
    self.gamma = gamma
    self.target_update = target_update
    self.eps_decay_steps = eps_decay_steps
    self.eps_min = eps_min

    # Optimiser
    self.optimizer = snt.optimizers.Adam(lr)

    # Networks. The target is a copy of the online net made before either is
    # built; each gets its own variables on the dummy call below.
    self.q_net = snt.DeepRNN([
      snt.Linear(20),
      tf.nn.relu,
      snt.LSTM(20),
      snt.Linear(num_actions)
    ])
    self.target_q_net = copy.deepcopy(self.q_net)

    # Initialise network variables
    dummy_obs = tf.zeros((1,obs_size), "float32")
    self.q_net(dummy_obs, self.q_net.initial_state(1))
    self.target_q_net(dummy_obs, self.target_q_net.initial_state(1))

    # Counters
    self.learn_step = tf.Variable(0, dtype="float32")
    self.act_step = tf.Variable(0, dtype="float32")

    # Give the actor a valid recurrent state immediately, so that acting
    # before the first reset() does not raise AttributeError.
    self.reset()

  def reset(self):
    """Reset the online network's recurrent state (call at each episode start)."""
    self.hidden_state = self.q_net.initial_state(1)

  @tf.function
  def q_learning(self, obs, act, rew, done, mask):
    """Run one masked Double-DQN gradient step on a batch of sequences.

    Args:
      obs: batch-major observations, [B, T, obs_size].
      act: actions taken, [B, T].
      rew: rewards, [B, T].
      done: terminal flags, [B, T].
      mask: 1.0 for real timesteps, 0.0 for padding, [B, T].

    Both networks are unrolled over the full sequence from `initial_state(B)`.
    The target for step t bootstraps from step t+1 of the same sequence, so
    the last timestep of each sequence serves only as a bootstrap target.

    Returns:
      Scalar masked mean squared TD error.
    """
    B = act.shape[0] # static batch size

    obs = tf.transpose(obs, perm=[1,0,2]) # make time major for sonnet

    # Unroll target network
    initial_hidden_state = self.target_q_net.initial_state(B)
    target_q_values, _ = snt.static_unroll(self.target_q_net, obs, initial_hidden_state)
    target_q_values = tf.transpose(target_q_values, perm=[1,0,2]) # make batch major again

    with tf.GradientTape() as tape:
      # Unroll online network
      initial_hidden_state = self.q_net.initial_state(B)
      q_values, _ = snt.static_unroll(self.q_net, obs, initial_hidden_state)
      q_values = tf.transpose(q_values, perm=[1,0,2]) # make batch major again

      # Extract q_value of chosen action
      act_q_value = trfl.batched_index(q_values, act)

      # Double Q-learning: online net selects, target net evaluates
      target_act = tf.argmax(q_values, axis=-1)
      target_q_value = trfl.batched_index(target_q_values, target_act)

      # Align step t with its bootstrap value at step t+1
      act_q_value = act_q_value[:,:-1] # chop off last timestep
      target_q_value = target_q_value[:,1:] # chop off first timestep
      act = act[:,:-1]
      rew = rew[:,:-1]
      done = done[:,:-1]
      mask = mask[:,:-1]

      # Bellman target
      target = tf.stop_gradient(rew + self.gamma * (1-done) * target_q_value)

      # Squared TD Error
      squared_td_error = (target - act_q_value) ** 2

      # Masked mean loss. Clamp the denominator so an all-padding batch
      # yields a zero loss instead of 0/0 = NaN, which would corrupt every
      # weight through the gradients.
      loss = tf.reduce_sum(squared_td_error * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)

    # Compute and apply gradients
    variables = self.q_net.trainable_variables
    gradients = tape.gradient(loss, variables)
    self.optimizer.apply(gradients, variables)

    # Update target network (tensor condition, converted to a graph cond by AutoGraph)
    if self.learn_step % self.target_update == 0:
      online_variables = self.q_net.variables
      target_variables = self.target_q_net.variables
      for src, dest in zip(online_variables, target_variables):
        dest.assign(src)

    # Increment counter
    self.learn_step.assign_add(1)

    return loss

  @tf.function
  def greedy_action_selection(self, obs, hidden_state):
    """Return (argmax action with shape [1], new hidden state) for one unbatched observation."""
    obs = tf.expand_dims(obs, axis=0) # dummy batch dim

    q_values, hidden_state = self.q_net(obs, hidden_state)

    # Greedy action
    action = tf.argmax(q_values, axis=-1)

    return action, hidden_state

  def epsilon_greedy_action_selection(self, obs, evaluation=False):
    """Pick an action for `obs`, advancing the actor's recurrent state.

    The greedy action is always computed, so the LSTM state stays in sync
    even when a random action is taken. With `evaluation=True`, random
    exploration is disabled.
    """

    # Greedy action
    tf_obs = tf.convert_to_tensor(obs, "float32")
    action, self.hidden_state = self.greedy_action_selection(tf_obs, self.hidden_state)
    action = action.numpy()[0]

    # Get Epsilon
    epsilon = 1.0 - self.act_step / self.eps_decay_steps
    epsilon = self.eps_min if epsilon < self.eps_min else epsilon

    # Random action
    if np.random.random() < epsilon and not evaluation:
      action = np.random.randint(0, self.num_actions, size=1)[0]

    # Increment counter
    # NOTE(review): this also runs when evaluation=True, so evaluation
    # episodes consume the exploration schedule; confirm this is intended.
    self.act_step.assign_add(1)

    return action

In [None]:
# --- Training configuration -------------------------------------------------
num_episodes = 10_000
batch_size = 64

# Masked-velocity CartPole (4-dim observations, 2 actions), a recurrent agent,
# and a sequence replay memory sized for 10k sequences.
env = MaskedVelocityCartPole()
agent = Agent(4, 2)
mem = SequenceReplayBuffer(4, max_size=10_000, batch_size=batch_size)

ep_returns = []
for e in range(num_episodes):
  # Fresh episode: new observation, reset LSTM state, seed the replay buffer.
  obs = env.reset()
  agent.reset()
  mem.push_first(obs)

  ep_return = 0
  while True:
    action = agent.epsilon_greedy_action_selection(obs)
    obs, rew, done, _ = env.step(action)
    mem.push(obs, action, rew, done)
    ep_return += rew

    # One gradient step per environment step once enough sequences exist.
    if mem.is_ready():
      train_obs, train_act, train_rew, train_done, train_mask = mem.sample()
      agent.q_learning(train_obs, train_act, train_rew, train_done, train_mask)

    if done:
      break

  ep_returns.append(ep_return)

  # Every 100 episodes, report the mean return over the most recent 50.
  if e % 100 == 0:
    epsilon = max(agent.eps_min, 1.0 - agent.act_step / agent.eps_decay_steps)
    recent_mean = np.mean(ep_returns[-50:])
    print("Episode", e, "     Avg. Episode return:", recent_mean, "    Epsilon", round(float(epsilon), 3), "   Train steps:", int(agent.learn_step), "   Timesteps", int(agent.act_step))

print("Done")

Episode 0      Avg. Episode return: 21.0     Epsilon 0.999    Train steps: 0    Timesteps 21
Episode 100      Avg. Episode return: 18.66     Epsilon 0.897    Train steps: 1118    Timesteps 2060
Episode 200      Avg. Episode return: 21.16     Epsilon 0.798    Train steps: 3096    Timesteps 4038
Episode 300      Avg. Episode return: 23.12     Epsilon 0.678    Train steps: 5497    Timesteps 6439
Episode 400      Avg. Episode return: 29.8     Epsilon 0.545    Train steps: 8162    Timesteps 9104
Episode 500      Avg. Episode return: 26.38     Epsilon 0.404    Train steps: 10975    Timesteps 11917
Episode 600      Avg. Episode return: 103.04     Epsilon 0.058    Train steps: 17894    Timesteps 18836
Episode 700      Avg. Episode return: 369.14     Epsilon 0.05    Train steps: 43676    Timesteps 44618
Episode 800      Avg. Episode return: 309.72     Epsilon 0.05    Train steps: 78278    Timesteps 79220
