a very simple example to show how to implement deep q networks (dqn) using pytorch, it only require gym and pytorch installed to run this notebook, no external files or other libraries is needed, everything needed to work is contained within this notebook. I believe codes written in this way is the most readable

-freddy chua

In [1]:
import gym
from gym import wrappers
import torch
import torch.nn as nn
import torch.nn.init
import torch.nn.functional as F
from collections import namedtuple
from torch.autograd import Variable
import random

In [2]:
env = gym.make('CartPole-v0')
env = wrappers.Monitor(env, 'cartpole', force=True)

[2017-07-08 10:20:48,262] Making new env: CartPole-v0
[2017-07-08 10:20:48,287] Clearing 22 monitor files from previous run (because force=True was provided)


In [3]:
# try implementing dqn

# the action reward value function can be represented by a mlp
class Mlp(nn.Module):
  def __init__(self, input_size, output_size):
    super(Mlp, self).__init__() # this statement is always needed
    
    self.fc1 = nn.Linear(input_size, 10) # matrix multiplication
    self.fc2 = nn.Linear(10, output_size) # matrix multiplication
    
    # == parameters initialization ==
    nn.init.xavier_normal(self.fc1.weight)
    nn.init.xavier_normal(self.fc2.weight)
    
    nn.init.normal(self.fc1.bias)
    nn.init.normal(self.fc2.bias)
    # =============================== 
    
  def forward(self, x):
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return x
  
  # no backward function needed, awesome!
# end class

In [4]:
# the memory


Event = namedtuple('Event', ['state', 'action', 'next_state', 'reward'])

class Memory(object):
  def __init__(self, capacity):
    self.capacity = capacity
    self.idx = 0
    self.mem = []

  def add_event(self, event):
    if len(self.mem) < self.capacity:
      self.mem.append(event)
    else:
      self.mem[self.idx] = event
    self.idx = (self.idx + 1) % self.capacity
  
  def sample(self, batch_size):
    return random.sample(self.mem, batch_size)

# end class

In [5]:
input_size = env.observation_space.shape[0]
output_size = env.action_space.n
print('input_size = {0}, output_size = {1}'.format(input_size, output_size))

input_size = 4, output_size = 2


In [6]:
# create 2 Q-network

eval_Q   = Mlp(input_size, output_size)
target_Q = Mlp(input_size, output_size)
target_Q.load_state_dict(eval_Q.state_dict()) # set them to be similar

In [7]:
epsilon = 1.0 # the exploration decision parameter, will decay over time
batch_size = 100 # for batch processing, larger batch size -> faster computation
gamma = 0.99 # the parameter for discounting future rewards
decay = 0.999
C = 5 # the time delay in updating target_Q

In [8]:
optimizer = torch.optim.RMSprop(eval_Q.parameters()) # RMSprop for learning eval_Q parameters
criterion = nn.MSELoss() # mean squared error, similar to least squared error

In [9]:
replay_memory = Memory(10000) # create a replay memory of capacity 10
top_score = 0
c = 0
for i in range(100):
#   print('episode: {0}'.format(i+1))
  current_state = env.reset() # an array of 4 values
  done = False
  episode_reward = 0
  while not done:
    if random.random() < epsilon:
      # perform random action to explore the search space
      action = env.action_space.sample()
    else:
      # choose action with highest value
      state_tensor = torch.Tensor(current_state) # convert current_state into a torch tensor
      state_tensor = state_tensor.unsqueeze_(0) # unsqueeze to allow for batch processing
      # convert to a autograd Variable for automatic backpropagation
      state_tensor = Variable(state_tensor, volatile=True) # volatile is True for inference only
      action_values = eval_Q(state_tensor) # forward
      
      _, action = torch.max(action_values, 1)
      action = action.data[0,0]
    # end if
    next_state, reward, done, _ = env.step(action)
    episode_reward += reward
    if done:
      replay_memory.add_event(Event(current_state.copy(), action, None, reward))
    else:
      replay_memory.add_event(Event(current_state.copy(), action, next_state.copy(), reward))
    # end if
    current_state = next_state
    
    # train
    if len(replay_memory.mem) >= batch_size:
      # sample from replay memory
      mini_batch = replay_memory.sample(batch_size)
      mini_batch = Event(*zip(*mini_batch)) # do this for batch processing
      
      # calculate the estimated value
      estimated_value = eval_Q(Variable(torch.Tensor(mini_batch.state)))
      # select the value associated with the action taken
      estimated_value = estimated_value.gather(1, Variable(torch.LongTensor(mini_batch.action).unsqueeze_(1)))
      
      # calculate the actual value
      mask = torch.ByteTensor(tuple(map(lambda s: s is not None, mini_batch.next_state)))
      target_val = target_Q(Variable(torch.Tensor([
        next_state for next_state in mini_batch.next_state if next_state is not None])))
      target_val, _ = torch.max(target_val, 1)
      
      targetted_value = Variable(torch.zeros(batch_size, 1))
      targetted_value[mask] = gamma * target_val
      targetted_value += Variable(torch.Tensor(mini_batch.reward).unsqueeze_(1))
      
      # compute the loss between estimated value and actual value
      optimizer.zero_grad()
      loss = criterion(estimated_value, targetted_value.detach())      
      loss.backward()
      optimizer.step() # do a gradient descent on it
      
      c += 1
      if c == C:
        c = 0
        target_Q.load_state_dict(eval_Q.state_dict())
        epsilon = epsilon * decay
      # end if
    # end if
    
  # end while
  print('episode {0} reward = {1}, epsilon = {2:3g}'.format(i, episode_reward, epsilon))
  top_score = max(top_score, episode_reward)
# end for
print('top_score = {0}'.format(top_score))

[2017-07-08 10:20:55,919] Starting new video recorder writing to /Users/fchua/Documents/torch_projects/pytorch_tutorials/pytorch-deep-rl/cartpole/openaigym.video.0.83386.video000000.mp4
[2017-07-08 10:20:56,808] Starting new video recorder writing to /Users/fchua/Documents/torch_projects/pytorch_tutorials/pytorch-deep-rl/cartpole/openaigym.video.0.83386.video000001.mp4


episode 0 reward = 12.0, epsilon =   1


[2017-07-08 10:20:57,148] Starting new video recorder writing to /Users/fchua/Documents/torch_projects/pytorch_tutorials/pytorch-deep-rl/cartpole/openaigym.video.0.83386.video000008.mp4


episode 1 reward = 12.0, epsilon =   1
episode 2 reward = 17.0, epsilon =   1
episode 3 reward = 13.0, epsilon =   1
episode 4 reward = 22.0, epsilon =   1
episode 5 reward = 21.0, epsilon =   1
episode 6 reward = 22.0, epsilon = 0.996006
episode 7 reward = 14.0, epsilon = 0.994015
episode 8 reward = 14.0, epsilon = 0.991036
episode 9 reward = 23.0, epsilon = 0.986091
episode 10 reward = 15.0, epsilon = 0.983135
episode 11 reward = 13.0, epsilon = 0.98117
episode 12 reward = 15.0, epsilon = 0.978229
episode 13 reward = 13.0, epsilon = 0.975298
episode 14 reward = 35.0, epsilon = 0.968491
episode 15 reward = 22.0, epsilon = 0.964623
episode 16 reward = 38.0, epsilon = 0.956933
episode 17 reward = 11.0, epsilon = 0.95502
episode 18 reward = 43.0, epsilon = 0.946459
episode 19 reward = 15.0, epsilon = 0.943623


[2017-07-08 10:20:57,777] Starting new video recorder writing to /Users/fchua/Documents/torch_projects/pytorch_tutorials/pytorch-deep-rl/cartpole/openaigym.video.0.83386.video000027.mp4


episode 20 reward = 18.0, epsilon = 0.940795
episode 21 reward = 23.0, epsilon = 0.9361
episode 22 reward = 29.0, epsilon = 0.930497
episode 23 reward = 30.0, epsilon = 0.924928
episode 24 reward = 24.0, epsilon = 0.920313
episode 25 reward = 43.0, epsilon = 0.912976
episode 26 reward = 12.0, epsilon = 0.91024
episode 27 reward = 14.0, epsilon = 0.90842
episode 28 reward = 18.0, epsilon = 0.904792
episode 29 reward = 31.0, epsilon = 0.899377
episode 30 reward = 14.0, epsilon = 0.896682
episode 31 reward = 17.0, epsilon = 0.893994
episode 32 reward = 18.0, epsilon = 0.890424
episode 33 reward = 12.0, epsilon = 0.888644
episode 34 reward = 23.0, epsilon = 0.884209
episode 35 reward = 53.0, epsilon = 0.874531
episode 36 reward = 21.0, epsilon = 0.871039
episode 37 reward = 33.0, epsilon = 0.865825
episode 38 reward = 27.0, epsilon = 0.860643
episode 39 reward = 28.0, epsilon = 0.856349
episode 40 reward = 30.0, epsilon = 0.851223
episode 41 reward = 14.0, epsilon = 0.848672
episode 42 rew

[2017-07-08 10:20:59,098] Starting new video recorder writing to /Users/fchua/Documents/torch_projects/pytorch_tutorials/pytorch-deep-rl/cartpole/openaigym.video.0.83386.video000064.mp4


episode 58 reward = 69.0, epsilon = 0.723133
episode 59 reward = 48.0, epsilon = 0.715934
episode 60 reward = 56.0, epsilon = 0.708098
episode 61 reward = 68.0, epsilon = 0.698249
episode 62 reward = 44.0, epsilon = 0.69199
episode 63 reward = 11.0, epsilon = 0.690606
episode 64 reward = 14.0, epsilon = 0.688537
episode 65 reward = 115.0, epsilon = 0.672873
episode 66 reward = 16.0, epsilon = 0.670857
episode 67 reward = 58.0, epsilon = 0.663514
episode 68 reward = 22.0, epsilon = 0.660203
episode 69 reward = 100.0, epsilon = 0.647124
episode 70 reward = 13.0, epsilon = 0.64583
episode 71 reward = 75.0, epsilon = 0.63621
episode 72 reward = 87.0, epsilon = 0.624855
episode 73 reward = 18.0, epsilon = 0.622983
episode 74 reward = 45.0, epsilon = 0.617398
episode 75 reward = 63.0, epsilon = 0.60942
episode 76 reward = 200.0, epsilon = 0.585513
episode 77 reward = 84.0, epsilon = 0.575638
episode 78 reward = 23.0, epsilon = 0.573339
episode 79 reward = 58.0, epsilon = 0.566497
episode 80 

In [10]:
env.render(close=True)
env.close()

[2017-07-08 10:21:10,572] Finished writing results. You can upload them to the scoreboard via gym.upload('/Users/fchua/Documents/torch_projects/pytorch_tutorials/pytorch-deep-rl/cartpole')


In [11]:
mini_batch = replay_memory.sample(batch_size)

In [17]:
mini_batch = Event(*zip(*mini_batch)) # do this for batch processing

TypeError: __new__() takes 5 positional arguments but 101 were given

In [21]:
torch.Tensor(mini_batch.state)


-0.0868 -0.0374  0.0296  0.0092
 0.0254  0.0070 -0.0596 -0.0306
-0.3525 -0.3222 -0.0686  0.0349
-0.3865 -0.7546 -0.0371 -0.0355
-0.2143 -0.3172  0.0003  0.1146
-0.0758 -0.9735  0.1966  1.8280
-0.0357  0.1839 -0.1127 -0.8558
-0.4763 -0.7006 -0.0410  0.3576
 0.7877  1.5115  0.1855 -0.7046
 0.1615  0.0447 -0.2000 -0.5020
 0.0764  0.7824 -0.0922 -1.2270
-0.0255 -0.0475  0.0842  0.2408
-0.1200  0.1908 -0.0288 -0.8128
-0.0091  0.2321 -0.0420 -0.5012
 0.0019  0.2054 -0.0639 -0.5254
 0.0128  0.2171 -0.0254 -0.2923
-0.0186 -0.3655  0.1260  0.8130
 0.0624 -0.0375  0.0235  0.2471
-0.1513 -0.0071  0.0152 -0.1423
 0.0269 -0.1505 -0.0254 -0.0131
-0.4416 -0.7626 -0.0319  0.1966
 0.6420  0.3702  0.1589  0.4212
-0.0660 -0.4253  0.0903  0.6744
-0.0597 -0.0117  0.0526  0.0823
 0.0106 -0.4159 -0.1215  0.0879
-0.0069 -0.2170 -0.0158  0.3112
-0.1803 -0.1719  0.0004 -0.0966
-0.0844 -0.0382  0.0243  0.0279
 0.0004  0.2194 -0.0263 -0.3436
 0.4961  1.5143  0.1498 -0.8800
 0.0123  0.1645 -0.0962 -0.6823
-0.0500