In [1]:
import torch
from torch import nn
from torch import optim
import numpy as np
from torch.nn import functional as F
import gym
import torch.multiprocessing as mp #A
from multiprocessing import set_start_method
set_start_method("fork")
mp.get_start_method()

'fork'

In [2]:
from torch.utils.tensorboard import SummaryWriter

In [3]:
writer = SummaryWriter('runs/test1')

In [4]:
class ActorCritic(nn.Module): #B
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.l1 = nn.Linear(4,25)
        self.l2 = nn.Linear(25,50)
        self.actor_lin1 = nn.Linear(50,2)
        self.l3 = nn.Linear(50,25)
        self.critic_lin1 = nn.Linear(25,1)
    def forward(self,x):
        x = F.normalize(x,dim=0)
        y = F.relu(self.l1(x))
        y = F.relu(self.l2(y))
        actor = F.log_softmax(self.actor_lin1(y),dim=0) #C
        c = F.relu(self.l3(y.detach()))
        critic = torch.tanh(self.critic_lin1(c)) #D
        return actor, critic #E

In [5]:
def worker(t, worker_model, counter, params):
    worker_env = gym.make("CartPole-v1")
    worker_env.reset()
    worker_opt = optim.Adam(lr=1e-4,params=worker_model.parameters()) #A
    worker_opt.zero_grad()
    for i in range(params['epochs']):
        worker_opt.zero_grad()
        values, logprobs, rewards = run_episode(worker_env,worker_model) #B 
        actor_loss,critic_loss,eplen = update_params(worker_opt,values,logprobs,rewards) #C
        
        counter.value = counter.value + 1 #D
        writer.add_scalar('Episode_length', eplen, counter.value)

In [6]:
def run_episode(worker_env, worker_model):
    state = torch.from_numpy(worker_env.env.state).float() #A
    values, logprobs, rewards = [],[],[] #B
    done = False
    j=0
    while (done == False): #C
        j+=1
        policy, value = worker_model(state) #D
        values.append(value)
        logits = policy.view(-1)
        action_dist = torch.distributions.Categorical(logits=logits)
        action = action_dist.sample() #E
        logprob_ = policy.view(-1)[action]
        logprobs.append(logprob_)
        state_, _, done, info = worker_env.step(action.detach().numpy())
        state = torch.from_numpy(state_).float()
        if done: #F
            reward = -10
            worker_env.reset()
        else:
            reward = 1.0
        rewards.append(reward)
    return values, logprobs, rewards

In [7]:
def update_params(worker_opt,values,logprobs,rewards,clc=0.1,gamma=0.95):
        rewards = torch.Tensor(rewards).flip(dims=(0,)).view(-1) #A
        logprobs = torch.stack(logprobs).flip(dims=(0,)).view(-1)
        values = torch.stack(values).flip(dims=(0,)).view(-1)
        Returns = []
        ret_ = torch.Tensor([0])
        for r in range(rewards.shape[0]): #B
            ret_ = rewards[r] + gamma * ret_
            Returns.append(ret_)
        Returns = torch.stack(Returns).view(-1)
        Returns = F.normalize(Returns,dim=0)
        actor_loss = -1*logprobs * (Returns - values.detach()) #C
        critic_loss = torch.pow(values - Returns,2) #D
        loss = actor_loss.sum() + clc*critic_loss.sum() #E
        loss.backward()
        worker_opt.step()
        return actor_loss, critic_loss, len(rewards)

In [9]:
MasterNode = ActorCritic() #A
MasterNode.share_memory() #B
processes = [] #C
params = {
    'epochs':1000,
    'n_workers':4,
}
counter = mp.Value('i',0) #D
for i in range(params['n_workers']):
    p = mp.Process(target=worker, args=(i,MasterNode,counter,params)) #E
    p.start() 
    processes.append(p)
for p in processes: #F
    p.join()
for p in processes: #G
    p.terminate()
    
print(counter.value,processes[1].exitcode) #H

Process Process-11:
Process Process-8:
Process Process-9:
Process Process-10:
Traceback (most recent call last):


KeyboardInterrupt: 

Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/multiprocessing/process.py", line

In [None]:
env = gym.make("CartPole-v1")
env.reset()

i_initial = 0
for i in range(5000):
    state_ = np.array(env.env.state)
    state = torch.from_numpy(state_).float()
    logits,value = MasterNode(state)
    action_dist = torch.distributions.Categorical(logits=logits)
    action = action_dist.sample()
    state2, reward, done, info = env.step(action.detach().numpy())
    
    if done:
        print("Lost")
        env.reset()
        n_steps = i - i_initial
        i_initial = i
        print('Number of steps {}'.format(n_steps))
    state_ = np.array(env.env.state)
    state = torch.from_numpy(state_).float()
    env.render()


In [10]:
def update_params(worker_opt,values,logprobs,rewards,G,clc=0.1,gamma=0.95):
    rewards = torch.Tensor(rewards).flip(dims=(0,)).view(-1) #A
    logprobs = torch.stack(logprobs).flip(dims=(0,)).view(-1)
    values = torch.stack(values).flip(dims=(0,)).view(-1)
    Returns = []
    ret_ = G
    for r in range(rewards.shape[0]): #B
        ret_ = rewards[r] + gamma * ret_
        Returns.append(ret_)
    Returns = torch.stack(Returns).view(-1)
    Returns = F.normalize(Returns,dim=0)
    actor_loss = -1*logprobs * (Returns - values.detach()) #C
    critic_loss = torch.pow(values - Returns,2) #D
    loss = actor_loss.sum() + clc*critic_loss.sum() #E
    loss.backward()
    worker_opt.step()
    return actor_loss, critic_loss, len(rewards)

In [11]:
def run_episode(worker_env, worker_model, N_steps=20):
    raw_state = np.array(worker_env.env.state)
    state = torch.from_numpy(raw_state).float()
    values, logprobs, rewards = [], [], []
    done = False
    j = 0 
    G = torch.Tensor([0])
    while (j < N_steps and done == False):
        j+=1
        policy, value = worker_model(state)
        values.append(value)
        logits = policy.flatten()
        action_dist = torch.distributions.Categorical(logits=logits)
        action = action_dist.sample()
        logprob_ = logits[action]
        logprobs.append(logprob_)
        state_, _, done, info = worker_env.step(action.detach().numpy())
        state = torch.from_numpy(state_).float()
        if done:
            reward = -10
            worker_env.reset()
        else:
            reward = 1.0
            G = value.detach()
        rewards.append(reward)
    return values, logprobs, rewards, G 


In [12]:
def worker(t, worker_model, counter, params):
    worker_env = gym.make("CartPole-v1")
    worker_env.reset()
    worker_opt = optim.Adam(lr=1e-4,params=worker_model.parameters()) # Each process runs its own isolated environment and optimizer but shares the model.!!
    worker_opt.zero_grad()
    for i in range(params['epochs']):
        worker_opt.zero_grad()
        values, logprobs, rewards, G = run_episode(worker_env,worker_model) #B 
        actor_loss,critic_loss,eplen = update_params(worker_opt,values,logprobs,rewards,G) #C
        counter.value = counter.value + 1 #D
        writer.add_scalar('Episode_length', eplen, counter.value)

In [16]:
MasterNode = ActorCritic()      #  Creates a global, shared actor-critic model
MasterNode.share_memory()       #  Allow the parameters of the model to be shared across processes rather than being copied
processes = []
params = {                      
    'epochs': 1000,
    'n_workers': 7,
}
counter = mp.Value('i',0)   #  A Shared global counter using multiprocessing's built-in shared object. The i parameter indicates the type is integer.

for i in range(params['n_workers']):
    
    p = mp.Process(target=worker, args=(i, MasterNode, counter, params)) #  Starts a new process that runs the worker function
    p.start()
    processes.append(p)

for p in processes:     #  "Joins" each process to wait for it to finish before returning to the main process
    p.join()

for p in processes:     #  Makes sure each process is terminated
    p.terminate()

print(counter.value,processes[1].exitcode)    

Process Process-21:
Process Process-24:
Process Process-19:
Process Process-23:
Process Process-25:


KeyboardInterrupt: 

Process Process-20:
Process Process-22:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/multiprocessing/process.py"

  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/site-packages/tensorboard/summary/writer/event_file_writer.py", line 113, in add_event
    self._async_writer.write(event.SerializeToString())
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/site-packages/tensorboard/summary/writer/event_file_writer.py", line 113, in add_event
    self._async_writer.write(event.SerializeToString())
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/site-packages/tensorboard/summary/writer/event_file_writer.py", line 166, in write
    self._byte_queue.put(bytestring)
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/site-packages/tensorboard/summary/writer/event_file_writer.py", line 166, in write
    self._byte_queue.put(bytestring)
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/site-packages/tensorboard/summary/writer/event_file_writer.py", line 166, in write
    self._byte_queue.put(bytestring)
  File "/Users/madiaz/miniforge3/envs/rllib/lib/python3.8/site-p

In [18]:
counter.value

77