In [1]:
import numpy as np
from matplotlib import pyplot as plt

import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
env = gym.make('CartPole-v0')



In [3]:
env.reset()

array([-0.00223198,  0.03423203, -0.04109657,  0.02724357])

In [4]:
env.action_space.n

2

In [5]:
action_list = [0, 1]
input_dim = 4 #env.observation_space.shape[0]

In [6]:
NUM_THREADS = 1

In [7]:
class A3C(nn.Module):
    def __init__(self, input_dim=2, action_dim=11):
        super(A3C, self).__init__()
        self.input_dim = input_dim
        self.action_dim = action_dim
        
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_dim)
        self.fc4 = nn.Linear(128, 1)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        Q = self.fc3(x)
        #policy = F.softmax(self.fc3(x))
        V = self.fc4(x)
        return Q, V
        #return policy, value
    '''
    def get_action(self, obs):
        poilcy, value = self.forward(obs)
        action = np.random.choice(self.action_dim, 1, p=policy.detach.data())
        return action
    
    def get_value(self, obs):
        poilcy, value = self.forward(obs)
        return value
    '''

In [8]:
T_max = 1000000
t_max = 5

beta = 0.01   # entropy regularization
gamma = 0.99
alpha = 0.99   # RMSProb decay factor
learning_rate = 1e-3

In [9]:
globalNet = A3C(input_dim, len(action_list))
optimizer = optim.Adam(globalNet.parameters(), lr=learning_rate)
T = 0

# for each thread
t = 0
done = False
ep_return = 0
log_episode_return = []

localNet = A3C(input_dim, len(action_list))
env = gym.make('Pendulum-v0')

buff_value = []
buff_q = []
buff_reward = []
buff_logp = []
buff_entropy = []

obs = env.reset()

while T < T_max:
    t_start = t

    localNet.load_state_dict(globalNet.state_dict())
    
    while t_start-t < t_max:
        Q, V = localNet(torch.tensor(obs.astype(np.float32)))
        prob = F.softmax(Q).data
        [a] = np.random.choice(localNet.action_dim, 1, p=prob.detach().numpy())
        log_prob = F.log_softmax(Q)
        
        obs, reward, done, _ = env.step(action_list[a])
        ep_return += reward
        entropy = -log_prob*prob.sum()
        
        buff_q.append(Q)
        buff_value.append(V)
        buff_reward.append(reward)
        buff_logp.append(log_prob[a])
        buff_entropy.append(entropy)
        t += 1
        T += 1
        if done:
            obs = env.reset()
            log_episode_return.append(ep_return)
            #print(ep_return)
            ep_return = 0
            break
    
    R = V if not done else 0
    policy_loss = 0
    value_loss = 0
    entropy_loss = 0
    for i in range(-1, -(t-t_start)-1, -1): #range(t-1, t_start-1, -1):
        R = buff_reward[i] + gamma*R
        TD = R - buff_value[i]
        policy_loss += buff_logp[i] * TD.detach()
        value_loss += torch.pow(TD, 2)
        entropy_loss += buff_entropy[i].sum()
    loss = - policy_loss + value_loss - beta*entropy_loss
        
    optimizer.zero_grad()
    loss.backward()
    for local_param, global_param in zip(localNet.parameters(), globalNet.parameters()):
        global_param.grad = local_param.grad
    optimizer.step()
    localNet.load_state_dict(globalNet.state_dict())
    if len(log_episode_return)%100==0:
        print(len(log_episode_return), 'episodes. (%d/%d steps)'%(T, T_max))
        print('Total loss:', loss.data.numpy()[0])
        print('Entropy:', entropy_loss.data.numpy())
        print('Policy:', policy_loss.data.numpy()[0])
        print('Value:', value_loss.data.numpy()[0])
        print(np.mean(log_episode_return[-10:]))
        print()



100 episodes. (2122/1000000 steps)
Total loss: 456.61334
Entropy: 16.710861
Policy: 38.798264
Value: 495.57874
20.8
tensor([-0.4306, -0.2738,  0.4443, -0.2644])
tensor([-0.4306, -0.2738,  0.4443, -0.2644])

200 episodes. (4860/1000000 steps)
Total loss: 1428.5696
Entropy: 32.449112
Policy: 81.95138
Value: 1510.8455
25.7
tensor([-0.4145, -0.2684,  0.4413, -0.2596])
tensor([-0.4145, -0.2684,  0.4413, -0.2596])

300 episodes. (8086/1000000 steps)
Total loss: 2588.4807
Entropy: 36.36179
Policy: 136.66702
Value: 2725.5112
15.5
tensor([-0.4315, -0.2862,  0.4582, -0.2361])
tensor([-0.4315, -0.2862,  0.4582, -0.2361])

400 episodes. (11312/1000000 steps)
Total loss: 3707.5715
Entropy: 37.84905
Policy: 172.09406
Value: 3880.044
51.9
tensor([-0.4007, -0.2725,  0.4244, -0.2538])
tensor([-0.4007, -0.2725,  0.4244, -0.2538])

500 episodes. (15951/1000000 steps)
Total loss: 11634.997
Entropy: 92.18403
Policy: 201.58022
Value: 11837.499
62.4
tensor([-0.3972, -0.2658,  0.4850, -0.1968])
tensor([-0.397

KeyboardInterrupt: 