In [26]:
import math
import random
import time
from jupyterthemes import jtplot
jtplot.style()

import gym
gym.logger.set_level(40)
import numpy as np


import nnabla as nn
import nnabla.logger as logger
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solver as S
from nnabla.contrib.context import extension_context
from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed

In [27]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [28]:
# Gym environment shared by the later cells: they read its observation/action
# space sizes to build the model and call reset()/step() on it in the rollout loop.
env = gym.make("CartPole-v0")


In [29]:
# Exploration schedule: epsilon starts at `epsilon_start` and decays
# exponentially towards `epsilon_final` with time constant `epsilon_decay` frames.
epsilon_start = 1.0
epsilon_final = 0.01
epsilon_decay = 500


def epsilon_by_frame(frame_idx):
    """Return the exploration rate for the given frame index.

    Computes ``epsilon_final + (epsilon_start - epsilon_final) * exp(-frame_idx / epsilon_decay)``
    using the module-level schedule constants at call time.
    """
    span = epsilon_start - epsilon_final
    decay_factor = math.exp(-1. * frame_idx / epsilon_decay)
    return epsilon_final + span * decay_factor

In [30]:
from collections import deque
class ReplayBuffer(object):
    """Bounded FIFO store of (state, action, reward, next_state, done) transitions."""

    def __init__(self, capacity):
        # deque(maxlen=...) drops the oldest transition once `capacity` is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition, storing both states with a leading axis of size 1."""
        transition = (
            np.expand_dims(state, 0),
            action,
            reward,
            np.expand_dims(next_state, 0),
            done,
        )
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions uniformly at random.

        Returns a 5-tuple: states and next states concatenated along axis 0 as
        arrays; actions, rewards and done flags as tuples in the same order.
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)
        stacked_states = np.concatenate(states)
        stacked_next_states = np.concatenate(next_states)
        return stacked_states, actions, rewards, stacked_next_states, dones

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.buffer)

In [35]:
# Hyperparameters read as module-level globals by the cells below.

# Output width of the two hidden affine layers in DQN.forward.
hidden_size = 128
# Upper bound (exclusive) of the frame counter in the rollout loop, which starts at 1.
max_frames = 50
# Leading dimension of the nn.Variable input placeholder created in DQN.__init__.
state_size = 1
# Replay buffer must hold more than this many transitions before compute_td_loss is called;
# the same value is passed to compute_td_loss.
batch_size = 32
# Capacity handed to ReplayBuffer.
replay_buffer_size = 1000

In [40]:
class DQN:
    """Fully connected Q-network (two ReLU hidden layers) with epsilon-greedy action selection.

    Relies on the module-level `state_size` and `hidden_size` settings and on
    nnabla's global parameter registry (parameters live under the "DQN" scope).

    Attributes:
        state_dim: number of features in one observation.
        action_dim: number of discrete actions (width of the output layer).
        state: nn.Variable of shape [state_size, state_dim] used as the input
            placeholder when acting greedily.
    """

    def __init__(self, num_states, num_actions):
        self.state_dim = num_states
        self.action_dim = num_actions
        self.state = nn.Variable([state_size, self.state_dim])

    def forward(self, x):
        """Build the Q-value graph on top of `x` and return its output variable.

        This only constructs the graph; the caller still has to run
        `.forward()` on the returned variable before reading `.d`.
        Parameters are created/reused under the "DQN/affine{1,2,3}" scopes.
        """
        with nn.parameter_scope("DQN"):
            with nn.parameter_scope("affine1"):
                h = F.relu(PF.affine(x, hidden_size))
            with nn.parameter_scope("affine2"):
                h = F.relu(PF.affine(h, hidden_size))
            with nn.parameter_scope("affine3"):
                y = PF.affine(h, self.action_dim)
        return y

    def act(self, state, epsilon):
        """Pick an action epsilon-greedily.

        Args:
            state: observation written into the `self.state` placeholder.
            epsilon: probability threshold for exploring; a uniform random
                action is returned when `random.random() <= epsilon`.

        Returns:
            int: index of the chosen action in `range(self.action_dim)`.
        """
        if random.random() > epsilon:
            self.state.d = state
            q_value = self.forward(self.state)
            # The graph is lazy: without an explicit forward pass `q_value.d`
            # holds uncomputed data and the argmax below would be meaningless.
            q_value.forward()
            # argmax is taken over the flattened output, which is the action
            # index as long as the placeholder holds a single state.
            action = int(np.argmax(q_value.d))
        else:
            action = random.randrange(self.action_dim)
        return action

In [41]:
# Report the environment's dimensions, then create the objects the rollout loop mutates.
print(f"state_num  {env.observation_space.shape[0]}")
print(f"action_num  {env.action_space.n}")

replay_buffer = ReplayBuffer(replay_buffer_size)
model = DQN(env.observation_space.shape[0], env.action_space.n)

# Per-update losses, per-episode returns, and the running return of the current episode.
losses, overall_rewards = [], []
episode_reward = 0

state_num  4
action_num  2


In [42]:
# Rollout loop: act epsilon-greedily, store every transition in the replay buffer,
# and record the return of each finished episode in `overall_rewards`.
state = env.reset()
for frame in range(1, max_frames):
    epsilon = epsilon_by_frame(frame)
    print(state,frame)
    action = model.act(state, epsilon)
    next_state, reward, done, _ = env.step(action)
    replay_buffer.push(state, action, reward, next_state, done)
    state = next_state
    episode_reward += reward
    if done:
        print("done")
        # Start a fresh episode and bank the finished episode's return.
        state = env.reset()
        overall_rewards.append(episode_reward)
        episode_reward = 0

    if len(replay_buffer) > batch_size:
        # NOTE(review): compute_td_loss is not defined in any cell visible above this
        # one; it must be defined (and that cell run) before this loop can get past
        # the first `batch_size` frames.
        loss = compute_td_loss(batch_size)
        # NOTE(review): `loss.data[0]` assumes the return value exposes an indexable
        # `.data` - confirm against compute_td_loss once it exists.
        losses.append(loss.data[0])

    if frame % 200 == 0:
        # The episode returns are collected in `overall_rewards`; the previous
        # `all_rewards` name was never defined.
        # NOTE(review): `plot` is not defined in any visible cell either.
        plot(frame, overall_rewards, losses)


[-0.0429882  -0.03804648 -0.04936629 -0.04296004] 1
[-0.04374913 -0.23242705 -0.05022549  0.23374781] 2
[-0.04839767 -0.42679674 -0.04555053  0.51017459] 3
[-0.05693361 -0.23106368 -0.03534704  0.20349232] 4
[-0.06155488 -0.42566277 -0.03127719  0.48481876] 5
[-0.07006813 -0.23011372 -0.02158082  0.18244464] 6
[-0.07467041 -0.03468972 -0.01793192 -0.11696736] 7
[-0.0753642  -0.22955021 -0.02027127  0.17000468] 8
[-0.07995521 -0.03414407 -0.01687118 -0.12900366] 9
[-0.08063809  0.16121545 -0.01945125 -0.42696113] 10
[-0.07741378 -0.03362568 -0.02799047 -0.14047303] 11
[-0.07808629  0.16188575 -0.03079993 -0.44185341] 12
[-0.07484858  0.3574297  -0.039637   -0.74408398] 13
[-0.06769998  0.55307562 -0.05451868 -1.04897259] 14
[-0.05663847  0.74887687 -0.07549813 -1.35825863] 15
[-0.04166093  0.5547786  -0.10266331 -1.09011597] 16
[-0.03056536  0.36114866 -0.12446563 -0.83133034] 17
[-0.02334239  0.55773186 -0.14109223 -1.16042252] 18
[-0.01218775  0.75438124 -0.16430068 -1.49380906] 19
[ 

NameError: name 'compute_td_loss' is not defined