In [1]:
import gym

# Probe environment: built once here so the observation/action sizes can be
# read off its spaces before the network is constructed.
env = gym.make("CartPole-v1")
# NOTE(review): private attribute; presumably the per-episode step cap of the
# wrapper returned by gym.make -- confirm against the installed gym version.
env._max_episode_steps = 400

# Length of the observation vector and number of discrete actions.
obs_n, act_n = env.observation_space.shape[0], env.action_space.n

In [2]:
from collections import namedtuple

# One recorded transition. As filled in by get_episode, `act_v` holds the raw
# actor outputs for `obs`, and `v_value` holds the bootstrapped one-step
# target (reward plus discounted next-state value), not a raw critic output.
step = namedtuple('step', 'obs act_v act next_obs v_value done')

def get_episode(env, net, batch, proc, train_queue):
    """Worker loop: roll out the policy and stream per-batch gradients.

    Runs forever. Every ``batch`` environment steps it computes the critic
    and actor losses on the collected transitions, back-propagates them
    through ``net`` and puts the resulting gradients on ``train_queue``.
    It never applies an optimizer step itself.

    Args:
        env: Environment using the classic gym API seen in this function:
            ``reset()`` returns an observation and ``step(act)`` returns
            ``(next_obs, reward, done, info)``.
        net: Module returning ``(action_logits, state_value)`` for a single
            observation or a batch of observations.
        batch: Number of transitions collected before each gradient
            computation. Transitions may span episode boundaries.
        proc: Worker identifier; only used in the progress print.
        train_queue: Queue-like object receiving, per batch, one list of
            gradient tensors ordered like ``net.parameters()``.

    Relies on the module-level ``GAMMA`` (discount), ``ALPHA`` (entropy
    weight) and the ``step`` namedtuple. The actor logits are still recorded
    in each ``step`` (field ``act_v``) but are not used by the loss.
    """
    # A forked worker inherits a copy of the parent's global NumPy RNG state,
    # so sampling from np.random would give every worker the same random
    # stream. Use a generator seeded from OS entropy in this process instead.
    rng = np.random.default_rng()

    obs = env.reset()
    ep = []
    count = 0  # steps taken in the current episode (progress print only)

    while True:
        with torch.no_grad():
            act_v, _ = net(torch.as_tensor(obs, dtype=torch.float32))
            act_prob = F.softmax(act_v, dim=0).numpy()
        # The number of actions comes from the logits themselves rather than
        # from a module-level global.
        act = int(rng.choice(len(act_prob), p=act_prob))

        next_obs, rew, done, _ = env.step(act)
        with torch.no_grad():
            _, next_v = net(torch.as_tensor(next_obs, dtype=torch.float32))
        count += 1

        # One-step TD target; no bootstrapping from a terminal state.
        td_target = rew + (0 if done else GAMMA * next_v.item())
        ep.append(step(obs, act_v.numpy(), act, next_obs, td_target, done))
        obs = next_obs

        if done:
            print(proc, count)
            count = 0
            obs = env.reset()

        if len(ep) < batch:
            continue

        steps = step(*zip(*ep))

        # Stack the observations into one ndarray first: building a tensor
        # from a tuple of separate arrays is slow.
        obs_ = torch.as_tensor(np.asarray(steps.obs), dtype=torch.float32)
        act_ = torch.as_tensor(steps.act, dtype=torch.long)
        v_value_ = torch.as_tensor(steps.v_value, dtype=torch.float32)

        # Reset gradients through the module that was passed in, not through
        # a global optimizer object.
        net.zero_grad()
        logit, v_pred = net(obs_)

        # critic update
        critic_loss = F.mse_loss(v_pred.squeeze(-1), v_value_)
        critic_loss.backward()

        # actor update
        prob = F.softmax(logit, dim=1)
        log_prob = F.log_softmax(logit, dim=1)

        # Advantage, shape (batch, 1); detached so the actor loss does not
        # back-propagate into the critic.
        adv = v_value_.unsqueeze(1) - v_pred.detach()
        # Policy gradient weights the log-probability of the action that was
        # actually taken, not the log-probabilities of all actions.
        taken_log_prob = log_prob.gather(1, act_.unsqueeze(1))
        policy_loss = -(taken_log_prob * adv).mean()

        # entropy loss (negative entropy: minimising it keeps the policy
        # stochastic)
        entropy_loss = (prob * log_prob).sum(dim=1).mean()

        actor_loss = policy_loss + ALPHA * entropy_loss
        actor_loss.backward()

        # Queue copies: the live .grad tensors would otherwise share storage
        # with the receiver and could be overwritten in place by the next
        # zero_grad()/backward() before the trainer has consumed them.
        train_queue.put([p.grad.detach().clone() for p in net.parameters()])
        ep = []

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp

class A2C(nn.Module):
    """Actor-critic model made of two independent two-layer MLPs.

    Both heads read the same input; they share no parameters.

    Args:
        in_: Size of the input (observation) vector.
        out_: Number of outputs of the actor head (one score per action).
        hidden: Width of the hidden layer in each head.
    """

    def __init__(self, in_, out_, hidden=512):
        super().__init__()
        # Actor is built before the critic so parameter order (and the order
        # in which layers are initialised) stays the same.
        self.actor = self._two_layer_mlp(in_, hidden, out_)
        self.critic = self._two_layer_mlp(in_, hidden, 1)

    @staticmethod
    def _two_layer_mlp(n_in, n_hidden, n_out):
        """Return Linear -> ReLU -> Linear as an ``nn.Sequential``."""
        layers = [
            nn.Linear(n_in, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_out),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(actor_output, critic_output)`` for input ``x``."""
        actor_out = self.actor(x)
        critic_out = self.critic(x)
        return actor_out, critic_out
    
# --- Hyperparameters -------------------------------------------------------
GAMMA = 0.99       # discount applied to the bootstrapped next-state value
ALPHA = 0.01       # weight of the entropy term in the actor loss
LR = 1e-4          # learning rate handed to Adam
BATCH = 64         # transitions each worker collects per gradient batch
QUEUE_SIZE = 8     # gradient-queue capacity; also the number of workers started
TGT_UP = 8         # gradient batches accumulated before each optimizer step
EPOCH = 100_000    # gradient batches consumed by the training loop
EPS = 0.2          # not used by any active code in this notebook

# Model sized from the probe environment, and the optimizer that updates it.
net = A2C(in_=obs_n, out_=act_n)
opt = optim.Adam(net.parameters(), lr=LR)

In [4]:
if __name__ == "__main__":
    # Training driver: start QUEUE_SIZE worker processes running get_episode,
    # sum the gradient lists they send back, and apply one optimizer step for
    # every TGT_UP received lists.

    # Move the parameters to shared memory so the in-place updates made by
    # opt.step() here are seen by the worker processes.
    net.share_memory()

    train_queue = mp.Queue(maxsize=QUEUE_SIZE)
    process_list = []
    try:
        for i in range(QUEUE_SIZE):
            # Each worker gets its own environment instance.
            env = gym.make("CartPole-v1")
            env.reset()
            env._max_episode_steps = 400

            # daemon=True: a worker must not outlive this process.
            process = mp.Process(
                target=get_episode,
                args=(env, net, BATCH, i, train_queue),
                daemon=True,
            )
            process.start()
            process_list.append(process)

        grad_tgt = None
        for epoch in range(EPOCH):
            grad_ = train_queue.get()
            if grad_tgt is None:
                # Accumulate into private copies: tensors received from the
                # queue may share storage with the sending worker, so they
                # must not be modified in place.
                grad_tgt = [g.clone() for g in grad_]
            else:
                for tgt, g in zip(grad_tgt, grad_):
                    tgt += g

            if epoch % TGT_UP == TGT_UP - 1:
                grad_sum = []
                for param, tgt in zip(net.parameters(), grad_tgt):
                    # Install the summed gradient for the optimizer step.
                    param.grad = tgt
                    grad_sum.append(tgt.sum().item())
                opt.step()
                grad_tgt = None

                print("EPOCH %d]"%epoch, end= " ")
                for g in grad_sum:
                    print("%.5f"%g, end=' ')
                print()
    finally:
        # Workers loop forever; stop them even when training is interrupted
        # or raises, so no orphan processes are left behind.
        for p in process_list:
            p.terminate()
            p.join()

BrokenPipeError: [Errno 32] Broken pipe