In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import gym

import numpy as np

import random
import math

import collections
from collections import namedtuple

In [2]:
step = namedtuple("step", ("state", "action", "next_state", "reward", "done"))

class Replay:
    def __init__(self, size):
        self.memory = collections.deque(maxlen = size)
        
    def push(self, data):
        self.memory.append(data)
        
    def prepare(self, env):
        pass
        
    def sample(self, size):
        if len(self.memory) >= size:
            return random.sample(self.memory, size)

In [3]:
class Actor(nn.Module):
    def __init__(self, state_n, action_n, hidden = 256):
        super(Actor, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_n, hidden),
            nn.ReLU(),
            nn.Linear(hidden, int(hidden/2)),
            nn.ReLU(),
            nn.Linear(int(hidden/2), action_n),
            nn.Tanh()
        )
        
    def forward(self,x):
        return self.net(x)
    
class Critic(nn.Module):
    def __init__(self, state_n, action_n, hidden =256):
        super(Critic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_n, hidden),
            nn.ReLU(),
        )
        self.out = nn.Sequential(
            nn.Linear(hidden+action_n, int(hidden/2)),
            nn.ReLU(),
            nn.Linear(int(hidden/2), 1)
        )
        
    def forward(self, state, act):
        temp = self.net(state)
        return self.out(torch.cat([temp, act], dim=1))

In [4]:
EPOCH = 1000
GAME_NAME = "CartPole-v1"

env = gym.make(GAME_NAME)
obs_n = env.observation_space.shape[0]
act_n = env.action_space.n

LR = 0.001
TAU = 0.005
GAMMA = 0.99

actor = Actor(obs_n, act_n)
actor_optim = optim.Adam(actor.parameters(), lr = LR)
actor_tgt = Actor(obs_n, act_n)
actor_tgt.load_state_dict(actor.state_dict())

critic = Critic(obs_n, act_n)
critic_optim = optim.Adam(critic.parameters(), lr = LR)
critic_tgt = Critic(obs_n, act_n)
critic_tgt.load_state_dict(critic.state_dict())

MAX_MEMORY = 20000
MEM_INIT = 2000
BATCH = 256
storage = Replay(MAX_MEMORY)

In [None]:
for epoch in range(EPOCH):
    obs = env.reset()
    env.render()
    
    count = 0
    while True:
        with torch.no_grad():
            act_v = actor(torch.FloatTensor(obs)).numpy()
            noise = np.random.random(act_n)/(epoch+1)
            act_v += noise
            act = act_v.argmax().item()
            
        next_obs, rew, done, _ = env.step(act)
        env.render()
        count += 1
        
        storage.push(step(obs, act_v, next_obs, rew, done))
        obs = next_obs
        
        sample = storage.sample(BATCH)
        if sample:
            sample = step(*zip(*sample))
            
            states = torch.FloatTensor(sample.state)
            actions = torch.FloatTensor(sample.action)
            next_states = torch.FloatTensor(sample.next_state)
            rewards = torch.FloatTensor(sample.reward).unsqueeze(-1)
            dones = torch.BoolTensor(sample.done).unsqueeze(-1)
            
            # critic learning
            critic_optim.zero_grad()
            q_pred = critic(states, actions)
            
            next_action_v = actor_tgt(next_states)
            q_next = critic_tgt(next_states, next_action_v)
            q_next[dones] = 0
            q_target = rewards + GAMMA * q_next
            
            critic_loss = F.mse_loss(q_pred, q_target.detach())
            critic_loss.backward()
            critic_optim.step()
            
            # actor learning
            actor_optim.zero_grad()
            actor_loss = -critic(states, actor(states))
            actor_loss = actor_loss.mean()
            actor_loss.backward()
            actor_optim.step()
            
            # tgt soft update
            for tgt, real  in zip(actor_tgt.parameters(), actor.parameters()):
                tgt.data.copy_(TAU*real.data + (1-TAU)*tgt.data)
                
            for tgt, real  in zip(critic_tgt.parameters(),critic.parameters()):
                tgt.data.copy_(TAU*real.data + (1-TAU)*tgt.data)
            
        if done:
            break
    print("epoch %d count %d"%(epoch, count))

epoch 0 count 500
epoch 1 count 500
epoch 2 count 500
epoch 3 count 500
epoch 4 count 274
epoch 5 count 312
epoch 6 count 97
epoch 7 count 155
epoch 8 count 141
epoch 9 count 112
epoch 10 count 128
epoch 11 count 118
epoch 12 count 107
epoch 13 count 190
epoch 14 count 131
epoch 15 count 194
epoch 16 count 148
