In [11]:
import numpy as np
import gym
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torch.nn.utils as torch_utils
import torch.multiprocessing as mp

In [12]:
class Actor_Net(nn.Module):
    """Gaussian policy MLP: n_in -> 64 -> 32 -> 16, then separate mean/std heads.

    Args:
        n_in: input (state) dimension.
        n_out: output (action) dimension.
        action_bound: scale applied to the tanh-squashed mean, so the returned
            mean lies in [-action_bound, action_bound].
    """

    def __init__(self, n_in, n_out, action_bound):
        super().__init__()
        # Attribute names are the state_dict keys copied between global and
        # worker networks; keep them stable.
        self.fc1 = nn.Linear(n_in, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 16)

        self.fc_mu = nn.Linear(16, n_out)
        self.fc_std = nn.Linear(16, n_out)
        self.action_bound = action_bound

    def forward(self, x):
        """Return (mean, std); std comes from softplus and is therefore non-negative."""
        features = x
        for layer in (self.fc1, self.fc2, self.fc3):
            features = F.relu(layer(features))

        mean = torch.tanh(self.fc_mu(features)) * self.action_bound
        spread = F.softplus(self.fc_std(features))
        return mean, spread

In [13]:
class Critic_Net(nn.Module):
    """Value MLP: n_in -> 64 -> 32 -> 16 -> n_out, ReLU after each hidden layer."""

    def __init__(self, n_in, n_out):
        super().__init__()
        # Attribute names are the state_dict keys copied between global and
        # worker critics; keep them stable.
        self.fc1 = nn.Linear(n_in, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 16)
        self.fc_value = nn.Linear(16, n_out)

    def forward(self, x):
        """Return the raw (unactivated) output of the final linear layer."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.fc_value(hidden)

In [14]:
class Worker_Actor(nn.Module):
    """Per-worker copy of the policy, used only to sample actions.

    Args:
        state_dim: size of a flattened state vector.
        action_dim: number of action dimensions.
        action_bound: passed to Actor_Net to scale the policy mean.
    """

    def __init__(self, state_dim, action_dim, action_bound):
        super(Worker_Actor, self).__init__()

        self.state_dim = state_dim
        # [min, max] std used when sampling; matches Global_Actor.std_bound.
        self.std_bound = [1e-2, 1.]
        self.network = Actor_Net(state_dim, action_dim, action_bound)

    def get_action(self, state):
        """Sample one action from N(mu, std) for a single state (no gradients tracked).

        Returns a 1-D tensor of length action_dim.
        """
        self.network.eval()
        with torch.no_grad():
            mu_a, std_a = self.network(state.view(1, self.state_dim))
            mu_a, std_a = mu_a[0], std_a[0]
            # Previously the clamped value went into a misspelled variable, so
            # sampling used the unclamped std, unlike Global_Actor.log_pdf.
            std_a = torch.clamp(std_a, self.std_bound[0], self.std_bound[1])
            action = torch.normal(mu_a, std_a)
        return action

In [15]:
class Global_Actor(nn.Module):
    """Shared Gaussian policy that workers update with an entropy-regularized policy gradient.

    Args:
        state_dim: size of a flattened state vector.
        action_dim: number of action dimensions.
        action_bound: passed to Actor_Net to scale the policy mean.
        lr_rate: Adam learning rate.
        entropy_beta: weight of the entropy bonus subtracted from the loss.
    """

    def __init__(self, state_dim, action_dim, action_bound, lr_rate, entropy_beta):
        super(Global_Actor, self).__init__()

        self.state_dim = state_dim
        # [min, max] std; keeps log(var) finite and bounds exploration noise.
        self.std_bound = [1e-2, 1.]
        self.entropy_beta = entropy_beta
        self.network = Actor_Net(state_dim, action_dim, action_bound)
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr_rate)

    def log_pdf(self, mu, std, action):
        """Log-density of `action` under a diagonal Gaussian N(mu, std^2).

        std is clamped to std_bound first. Per-dimension log-densities are
        summed over dim=1, giving one value per row with shape (rows, 1).
        """
        std = torch.clamp(std, min=self.std_bound[0], max=self.std_bound[1])
        var = std ** 2
        log_policy_pdf = -0.5 * (action - mu) ** 2 / var - 0.5 * torch.log(var * 2 * np.pi)
        return torch.sum(log_policy_pdf, dim=1, keepdim=True)

    def entropy(self, std):
        """Differential entropy of the diagonal Gaussian, 0.5*log(2*pi*e*var), summed over dim=1."""
        std = torch.clamp(std, min=self.std_bound[0], max=self.std_bound[1])
        return torch.sum(0.5 * torch.log(2 * np.pi * np.e * std ** 2), dim=1, keepdim=True)

    def update(self, states, actions, advantages):
        """One gradient step on the shared policy.

        loss = sum(-log_pi(a|s) * advantage) - entropy_beta * sum(entropy).
        Advantages are detached so no gradient flows into whatever produced them.
        """
        self.network.train()
        # Was `self.actor_network`, which does not exist (AttributeError on every update).
        mu_a, std_a = self.network(states)
        log_policy_pdf = self.log_pdf(mu_a, std_a, actions)
        policy_loss = torch.sum(-log_policy_pdf * advantages.detach())
        # entropy_beta was stored but never used; the bonus discourages premature std collapse.
        loss = policy_loss - self.entropy_beta * torch.sum(self.entropy(std_a))

        self.optimizer.zero_grad()
        loss.backward()
        torch_utils.clip_grad_norm_(self.network.parameters(), 40.0)
        self.optimizer.step()

In [16]:
class Worker_Critic(nn.Module):
    """Per-worker copy of the value network.

    Worker loads the global critic's parameters into `network`.
    """

    def __init__(self, state_dim):
        super().__init__()
        self.network = Critic_Net(n_in=state_dim, n_out=1)

In [17]:
class Global_Critic(nn.Module):
    """Shared state-value network V(s) and its Adam optimizer.

    Args:
        state_dim: size of a flattened state vector.
        lr_rate: Adam learning rate.
    """

    def __init__(self, state_dim, lr_rate):
        super().__init__()
        self.network = Critic_Net(state_dim, 1)
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr_rate)

    def get_value(self, states):
        """Evaluate V(states) in eval mode; the autograd graph is kept (no torch.no_grad)."""
        net = self.network
        net.eval()
        return net(states)

    def update(self, states, targets):
        """One MSE regression step of V(states) toward detached targets (grad-norm clipped to 40)."""
        net, opt = self.network, self.optimizer
        net.train()
        loss = F.mse_loss(net(states), targets.detach())

        opt.zero_grad()
        loss.backward()
        torch_utils.clip_grad_norm_(net.parameters(), 40.0)
        opt.step()

In [18]:
# Per-episode rewards reported by workers: Worker.run appends to this name via
# `global`. A3Cagent.__init__ rebinds it to a multiprocessing-managed list so
# appends made inside worker processes are visible to the parent process.
global_episode_reward = []


class A3Cagent():
    """Builds the shared global actor/critic and runs A3C Worker processes.

    Args:
        env_name: gym environment id. Its action space is read via
            action_space.shape / action_space.high, so a continuous (Box) space
            is assumed.
        n_workers: number of Worker processes started by train().
    """

    def __init__(self, env_name, n_workers):
        global global_episode_reward
        self.env_name = env_name
        self.n_workers = n_workers
        self.actor_learning_rate = 0.0001
        self.critic_learning_rate = 0.001
        self.entropy_beta = 0.01

        # Temporary env, only used to read the observation/action space sizes.
        env = gym.make(env_name)
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
        action_bound = env.action_space.high[0]
        env.close()

        self.global_actor = Global_Actor(state_dim, action_dim, action_bound,
                                         self.actor_learning_rate, self.entropy_beta)
        self.global_critic = Global_Critic(state_dim, self.critic_learning_rate)

        # Move parameters into shared memory so worker processes update the same tensors.
        self.global_actor.share_memory()
        self.global_critic.share_memory()

        self.global_episode_count = mp.Value('i', 0)
        self.global_step = mp.Value('i', 0)

        # A plain list appended to in a child process only changes that child's
        # copy, so the parent would never see any rewards. A Manager list is
        # shared across processes.
        # NOTE(review): workers see this through the rebound module global, which
        # relies on them inheriting it at start() (fork start method) — confirm
        # if running under 'spawn'.
        self._manager = mp.Manager()
        self.episode_rewards = self._manager.list()
        global_episode_reward = self.episode_rewards

    def train(self, max_episode_num):
        """Start n_workers Worker processes and block until all of them have exited."""
        workers = [
            Worker('worker%i' % i, self.env_name, self.global_actor, self.global_critic,
                   max_episode_num, self.global_episode_count, self.global_step)
            for i in range(self.n_workers)
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

    def plot_result(self):
        """Plot each finished episode's reward, in the order workers reported them."""
        # Previously this plotted self.global_episode_count, a shared mp.Value
        # counter rather than a sequence of results.
        rewards = list(self.episode_rewards)
        fig, ax = plt.subplots()
        ax.plot(rewards)
        ax.set(title='A3C on %s' % self.env_name, xlabel='Episode', ylabel='Episode reward')
        plt.show()

In [27]:
class Worker(mp.Process):
    """A3C worker process.

    Each worker owns a gym environment and local copies of the actor and critic.
    It collects up to ``t_max`` transitions and computes n-step TD targets and
    advantages with the shared global critic. It then applies updates directly
    to the shared global critic/actor and re-syncs its local networks.

    NOTE(review): the process object holds a gym env and notebook-defined
    classes. Under the 'spawn' start method it must be pickled, which likely
    fails when run from a notebook (a probable source of the BrokenPipeError
    recorded in this notebook). Fork (Linux) or a .py script is presumably
    required — confirm on the target platform.
    """

    def __init__(self, worker_name, env_name, global_actor, global_critic,
                 max_episode_num, global_episode_count, global_step):
        super(Worker, self).__init__()
        self.worker_name = worker_name
        self.env = gym.make(env_name)
        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.shape[0]
        self.action_bound = self.env.action_space.high[0]

        self.gamma = 0.95   # discount factor
        self.t_max = 4      # max transitions per update segment (the "n" in n-step TD)
        self.max_episode = max_episode_num

        self.global_actor = global_actor
        self.global_critic = global_critic

        self.worker_actor = Worker_Actor(self.state_dim, self.action_dim, self.action_bound)
        self.worker_critic = Worker_Critic(self.state_dim)

        # Start from the current global parameters.
        self._sync_local_networks()

        # Shared mp.Value counters; mutate only via _increment (holds the lock).
        self.global_episode_count = global_episode_count
        self.global_step = global_step

    def _sync_local_networks(self):
        """Copy the global actor/critic parameters into this worker's local networks."""
        self.worker_actor.network.load_state_dict(self.global_actor.network.state_dict())
        self.worker_critic.network.load_state_dict(self.global_critic.network.state_dict())

    @staticmethod
    def _increment(counter):
        """Add 1 to a shared mp.Value under its lock and return the new value.

        A bare ``counter.value += 1`` is a separate read and write, so
        concurrent workers could lose increments.
        """
        with counter.get_lock():
            counter.value += 1
            return counter.value

    def n_step_td_target(self, rewards, next_v_value, done):
        """Discounted n-step returns for one segment.

        Args:
            rewards: rewards for the segment, oldest first (run() passes a (T, 1) tensor).
            next_v_value: critic value of the state after the last transition;
                used as the bootstrap value unless the episode ended.
            done: True if the segment ended the episode (bootstrap value is 0).

        Returns:
            Tensor shaped like ``rewards`` with G_k = r_k + gamma * G_{k+1}.
        """
        td_targets = torch.zeros_like(rewards)
        cumulative = 0 if done else next_v_value
        for k in reversed(range(len(rewards))):
            cumulative = self.gamma * cumulative + rewards[k]
            td_targets[k] = cumulative
        return td_targets

    def unpack_batch(self, batch):
        """Concatenate a list of per-step tensors along dim 0."""
        return torch.cat(list(batch), dim=0)

    def _train_on_batch(self, batch_state, batch_action, batch_reward, next_state, done):
        """Push one collected segment into the global networks, then re-sync local copies."""
        states = self.unpack_batch(batch_state)
        actions = self.unpack_batch(batch_action)
        rewards = self.unpack_batch(batch_reward)  # was `batch_rewad` -> NameError

        v_values = self.global_critic.get_value(states)
        next_v_value = self.global_critic.get_value(next_state)
        n_step_td_targets = self.n_step_td_target(rewards, next_v_value, done)
        advantages = n_step_td_targets - v_values

        self.global_critic.update(states, n_step_td_targets)
        self.global_actor.update(states, actions, advantages)

        self._sync_local_networks()
        self._increment(self.global_step)

    def run(self):
        """Process entry point: play episodes until the shared episode budget is used up."""
        global global_episode_reward
        print(self.worker_name, "starts ---")
        # `<` (was `<=`) so the budget is max_episode episodes rather than one more.
        # Workers check independently, so a few extra episodes may still start concurrently.
        while self.global_episode_count.value < int(self.max_episode):
            batch_state, batch_action, batch_reward = [], [], []
            step_count, episode_reward, done = 0, 0.0, False
            state = torch.from_numpy(self.env.reset()).type(torch.FloatTensor)

            while not done:
                action = self.worker_actor.get_action(state)
                # .numpy() handles any action_dim (the old `.item()` only allowed one).
                action = np.clip(action.numpy(), -self.action_bound, self.action_bound)

                next_state, reward, done, _ = self.env.step(action)

                state = state.view(1, self.state_dim)
                next_state = torch.from_numpy(next_state).type(torch.FloatTensor).view(1, self.state_dim)
                action = torch.from_numpy(action).type(torch.FloatTensor).view(1, self.action_dim)
                reward = torch.FloatTensor([reward]).view(1, 1)

                # Reward rescaling for training. Assumes Pendulum's reward range
                # (roughly [-16.3, 0]), which this maps to about [-1, 1] — revisit
                # for other environments.
                train_reward = (reward + 8) / 8

                batch_state.append(state)
                batch_action.append(action)
                batch_reward.append(train_reward)

                state = next_state[0]
                episode_reward += reward.item()
                step_count += 1

                if len(batch_state) == self.t_max or done:
                    self._train_on_batch(batch_state, batch_action, batch_reward, next_state, done)
                    batch_state, batch_action, batch_reward = [], [], []

            # The inner loop only exits once the episode is done.
            episode_index = self._increment(self.global_episode_count)
            global_episode_reward.append(episode_reward)
            print('Worker:', self.worker_name,
                  ', Episode:', episode_index,
                  ', Step:', step_count, ', Reward:', np.mean(list(global_episode_reward)))
        self.env.close()

In [28]:
max_episode_num = 1000
env_name = 'Pendulum-v0'  # v0 id: assumes an older gym release that still registers it
n_workers = 4

# multiprocessing's 'spawn'/'forkserver' start methods re-import the main module
# in each child; the guard keeps children from re-running training.
# NOTE(review): under 'spawn' (Windows/macOS default) notebook-defined classes
# generally cannot be used by child processes, which likely explains the
# BrokenPipeError recorded below — run as a .py script or with fork on Linux.
if __name__ == '__main__':
    agent = A3Cagent(env_name, n_workers)
    agent.train(max_episode_num)
    agent.plot_result()

aa
aa
aa
aa


BrokenPipeError: [Errno 32] Broken pipe