In [23]:
import numpy as np
from matplotlib import pyplot as plt

import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [39]:
# Discretize the continuous action into N_ACTIONS evenly spaced values in
# [-ACTION_LIMIT, ACTION_LIMIT]; the policy network picks one of them by index.
# NOTE(review): ACTION_LIMIT presumably should match env.action_space.high for
# Pendulum -- confirm. N_ACTIONS must be >= 2.
ACTION_LIMIT = 2.0
N_ACTIONS = 11
action_list = [2 * ACTION_LIMIT / (N_ACTIONS - 1) * i - ACTION_LIMIT for i in range(N_ACTIONS)]
input_dim = 3  # observation size; hard-coded because env is created later (env.observation_space.shape[0])

In [25]:
# Intended number of asynchronous A3C workers.
# NOTE(review): the visible training cell runs a single worker and does not read this.
NUM_THREADS: int = 4

In [26]:
class A3C(nn.Module):
    """Actor-critic network with a shared two-layer trunk and two heads.

    Two ReLU hidden layers of width 128 feed ``fc3``, which produces one
    unnormalized policy logit per discrete action, and ``fc4``, which produces
    a single state-value estimate.

    Args:
        input_dim: size of the observation vector (last dimension of the input).
        action_dim: number of discrete actions (width of the policy head).
    """

    def __init__(self, input_dim=2, action_dim=11):
        super(A3C, self).__init__()
        self.input_dim = input_dim
        self.action_dim = action_dim

        # Layer names are kept as fc1..fc4 so previously saved state_dicts still load.
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_dim)  # policy head: logits
        self.fc4 = nn.Linear(128, 1)           # value head

    def forward(self, x):
        """Return ``(logits, value)`` for observation(s) ``x``.

        ``logits`` has trailing dimension ``action_dim`` and is NOT normalized;
        callers apply ``softmax`` / ``log_softmax`` themselves. ``value`` has
        trailing dimension 1.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        value = self.fc4(x)
        return logits, value

    def get_action(self, obs):
        """Sample an action index (``int``) for a single, unbatched observation.

        The probabilities are detached before sampling, so no gradient flows
        through this call.
        """
        logits, _ = self.forward(obs)
        prob = F.softmax(logits, dim=-1).detach().numpy()
        return int(np.random.choice(self.action_dim, p=prob))

    def get_value(self, obs):
        """Return the value-head output for ``obs`` (trailing dimension 1)."""
        _, value = self.forward(obs)
        return value

In [83]:
# ---- training hyperparameters ----
T_max = 10**5          # total environment steps over the whole run
t_max = 5              # maximum steps per rollout before each update

beta = 0.01            # weight of the entropy regularization term
gamma = 0.99           # discount factor
alpha = 0.99           # RMSProp decay factor (only meaningful with an RMSprop optimizer)
learning_rate = 0.0001

In [85]:
# Single-worker advantage actor-critic training with n-step returns.
# Only one worker runs here, so the "local" network is the global network
# itself: no parameter copy or gradient hand-off is needed.
globalNet = A3C(input_dim, len(action_list))
optimizer = optim.Adam(globalNet.parameters(), lr=learning_rate)
T = 0  # global environment-step counter

# per-worker state
t = 0  # this worker's step counter
done = False
ep_return = 0
log_episode_return = []

env = gym.make('Pendulum-v0')
obs = env.reset()

while T < T_max:
    t_start = t
    localNet = globalNet  # single worker: act directly with the shared parameters

    # Fresh buffers per rollout: keeps memory bounded and releases the
    # autograd graphs of rollouts that have already been back-propagated.
    buff_value = []
    buff_reward = []
    buff_logp = []
    buff_entropy = []

    # Collect at most t_max transitions, stopping early when the episode ends.
    while t - t_start < t_max:
        logits, V = localNet(torch.tensor(obs.astype(np.float32)))
        prob = F.softmax(logits, dim=-1)
        log_prob = F.log_softmax(logits, dim=-1)
        # Sample from a detached numpy copy; gradients flow through log_prob/prob.
        a = int(np.random.choice(localNet.action_dim, p=prob.detach().numpy()))

        obs, reward, done, _ = env.step([action_list[a]])
        reward = float(reward)
        ep_return += reward
        # Policy entropy: -sum_a pi(a) * log pi(a)  (parenthesized on purpose).
        entropy = -(prob * log_prob).sum()

        buff_value.append(V.squeeze(-1))
        buff_reward.append(reward)
        buff_logp.append(log_prob[a])
        buff_entropy.append(entropy)
        t += 1
        T += 1
        if done:
            obs = env.reset()
            log_episode_return.append(ep_return)
            print(ep_return)
            ep_return = 0
            break

    # Bootstrap from the value of the state reached AFTER the rollout (not the
    # last state acted in), detached so the critic is not trained through its
    # own target. Rollouts that ended an episode bootstrap from 0.
    # NOTE(review): Pendulum episodes presumably end only via gym's time limit,
    # so treating `done` as terminal biases the targets -- consider bootstrapping.
    if done:
        R = torch.tensor(0.0)
    else:
        with torch.no_grad():
            _, R = localNet(torch.tensor(obs.astype(np.float32)))
        R = R.squeeze(-1)

    # Accumulate the loss over the whole rollout, newest transition first.
    loss = 0
    for step_reward, step_value, step_logp, step_entropy in zip(
            reversed(buff_reward), reversed(buff_value),
            reversed(buff_logp), reversed(buff_entropy)):
        R = step_reward + gamma * R
        advantage = R - step_value
        policy_loss = -step_logp * advantage.detach()  # ascend log pi * A
        value_loss = advantage.pow(2)                  # descend squared TD error
        loss = loss + policy_loss + value_loss - beta * step_entropy  # entropy bonus is maximized

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

KeyboardInterrupt: 