In [1]:
import os
import numpy as np
import threading
from matplotlib import pyplot as plt

import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
gym.logger.set_level(40)

In [3]:
ENV_NAME = 'CartPole-v0'
input_dim = 4
action_list = [0,1]

In [4]:
class A3C(nn.Module):
    def __init__(self, input_dim, action_dim, max_ep=0, is_global=False):
        super(A3C, self).__init__()
        self.input_dim = input_dim
        self.action_dim = action_dim
        self.max_ep = max_ep
        
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_dim)
        self.fc4 = nn.Linear(128, 1)
        
        self.ep_counter = 0
        self.ep_returns = []
        self.average_returns = []
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        Q = self.fc3(x)
        V = self.fc4(x)
        return Q, V
    
    def log_episode(self, ep_return):
        c = self.ep_counter
        self.ep_returns.append(ep_return)
        self.average_returns.append(np.mean(self.ep_returns[max(0, c-99):c+1]))
        self.ep_counter += 1
        return self.ep_counter

In [5]:
NUM_THREADS = 8

#T_max = 10000
MAX_EP = 20000
t_max = 5
print_freq = 1000

beta = 0.01   # entropy regularization
gamma = 0.99
alpha = 0.99   # RMSProb decay factor
learning_rate = 1e-4
decay_rate = 0.996

In [6]:
def train(lock, globalNet, optimizer, tmax, tid):
    t = 0
    done = False
    ep_return = 0
    log_episode_return = []
    cur_ep = 0
    
    localNet = A3C(input_dim, len(action_list))
    localNet.load_state_dict(globalNet.state_dict())
    env = gym.make(ENV_NAME)
    obs = env.reset()
    
    while globalNet.ep_counter < MAX_EP:
        t_start = t
        buff_value = []
        buff_q = []
        buff_reward = []
        buff_logp = []
        buff_entropy = []

        while t_start-t < t_max:
            Q, V = localNet(torch.tensor(obs.astype(np.float32)))
            prob = F.softmax(Q, dim=0).data
            [a] = np.random.choice(localNet.action_dim, 1, p=prob.detach().numpy())
            log_prob = F.log_softmax(Q, dim=0)

            obs, reward, done, _ = env.step(action_list[a])
            ep_return += reward
            entropy = -log_prob*prob.sum()

            buff_q.append(Q)
            buff_value.append(V)
            buff_reward.append(reward)
            buff_logp.append(log_prob[a])
            buff_entropy.append(entropy)
            t += 1
            
            if done:
                cur_ep = globalNet.log_episode(ep_return)
                obs = env.reset()
                ep_return = 0
                break

        R = V if not done else 0
        policy_loss = 0
        value_loss = 0
        entropy_loss = 0
        for i in range(-1, -(t-t_start)-1, -1): #range(t-1, t_start-1, -1):
            R = buff_reward[i] + gamma*R
            TD = R - buff_value[i]
            policy_loss += buff_logp[i] * TD.detach()
            value_loss += torch.pow(TD, 2)
            entropy_loss += buff_entropy[i].sum()
        loss = - policy_loss + value_loss - beta*entropy_loss
        
        optimizer.zero_grad()
        loss.backward()
        lock.acquire()
        try:
            for local_param, global_param in zip(localNet.parameters(), globalNet.parameters()):
                global_param.grad = local_param.grad
            optimizer.step()
        finally:
            lock.release()
        localNet.load_state_dict(globalNet.state_dict())
        
        if cur_ep%print_freq==0: #globalNet.ep_counter.value%100==0:
            print('[%d] Thread'%tid)
            print('%d/%d episodes. (%.2f%%)'%(cur_ep, MAX_EP, cur_ep/MAX_EP*100))
            #print(globalNet.ep_counter.value-1, 'episodes.')
            print('Total loss:\t', loss.data.numpy()[0])
            print('Entropy\t\tPolicy\t\tValue')
            print('%.2f\t\t%.2f\t\t%.2f'%(entropy_loss.data.numpy(), policy_loss.data.numpy()[0], \
                  value_loss.data.numpy()[0]))
            print('Epside Return: [%.1f]'%globalNet.average_returns[globalNet.ep_counter-1])
            print()

In [7]:
globalNet = A3C(input_dim, len(action_list), MAX_EP, is_global=True)
#globalNet.share_memory()
optimizer = optim.Adam(globalNet.parameters(), lr=learning_rate)
#scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decay_rate)
lock = threading.Lock()

threads = []
for t_idx in range(NUM_THREADS):
    t = threading.Thread(target=train, args=(lock, globalNet, optimizer, t_max, t_idx))
    t.start()
    threads.append(t)
for t in threads:
    t.join()

1583591895.1710887
10.155836582183838

[2] Thread
1000/20000 episodes. (5.00%)
Total loss:	 1481.8181
Entropy		Policy		Value
35.52		55.96		1538.13
Epside Return: [32.8]

1583591905.3269253
14.60524034500122

[2] Thread
2000/20000 episodes. (10.00%)
Total loss:	 21027.176
Entropy		Policy		Value
122.52		-247.05		20781.35
Epside Return: [53.1]

1583591919.9321656
19.61437153816223

[1] Thread
3000/20000 episodes. (15.00%)
Total loss:	 22294.602
Entropy		Policy		Value
161.96		-376.18		21920.04
Epside Return: [56.7]

1583591939.5465372
33.5636625289917

[3] Thread
4000/20000 episodes. (20.00%)
Total loss:	 108015.484
Entropy		Policy		Value
329.35		-977.58		107041.20
Epside Return: [124.9]

1583591973.1101997
50.31988191604614

[7] Thread
5000/20000 episodes. (25.00%)
Total loss:	 54194.82
Entropy		Policy		Value
342.72		-1182.71		53015.54
Epside Return: [138.2]

1583592023.4300816
56.36374855041504

[4] Thread
6000/20000 episodes. (30.00%)
Total loss:	 49855.07
Entropy		Policy		Value
340.81	

KeyboardInterrupt: 

## No Learning Rate Decay

In [None]:
plt.figure(figsize=(10, 5))
plt.plot(globalNet.ep_returns[:], color='orange')
plt.plot(globalNet.average_returns[:], color='red')
fignum = len([f for f in os.listdir() if 'CartPole' in f and 'png' in f])
plt.savefig('A3C_CartPole_threads_%d.png'%fignum)