<a href="https://colab.research.google.com/github/moodlep/MLC_A3C/blob/main/a3c.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip3 install box2d-py
!pip3 install gym[Box_2D]

Collecting box2d-py
  Downloading box2d_py-2.3.8-cp37-cp37m-manylinux1_x86_64.whl (448 kB)
[?25l[K     |▊                               | 10 kB 20.3 MB/s eta 0:00:01[K     |█▌                              | 20 kB 24.6 MB/s eta 0:00:01[K     |██▏                             | 30 kB 29.0 MB/s eta 0:00:01[K     |███                             | 40 kB 24.1 MB/s eta 0:00:01[K     |███▋                            | 51 kB 7.8 MB/s eta 0:00:01[K     |████▍                           | 61 kB 8.3 MB/s eta 0:00:01[K     |█████▏                          | 71 kB 7.2 MB/s eta 0:00:01[K     |█████▉                          | 81 kB 8.0 MB/s eta 0:00:01[K     |██████▋                         | 92 kB 8.3 MB/s eta 0:00:01[K     |███████▎                        | 102 kB 7.2 MB/s eta 0:00:01[K     |████████                        | 112 kB 7.2 MB/s eta 0:00:01[K     |████████▊                       | 122 kB 7.2 MB/s eta 0:00:01[K     |█████████▌                      | 133 kB 7.2 M

In [None]:
import os
import Box2D
import pyglet
import imageio
os.system("Xvfb :1 -screen 0 1024x768x24 &")
os.environ['DISPLAY'] = ':1'


In [1]:
import gym
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp


In [None]:
# The env - quick test: 

env = gym.make("LunarLander-v2")

s = env.reset()

for _ in range(5): 
  a = env.action_space.sample()
  next_state, reward, done, info = env.step(a)
  print(next_state, reward, a)


[-0.00978136  1.4278034  -0.48951548  0.36232132  0.00955719  0.07687708
  0.          0.        ] 1.066382897039149 3
[-0.01455765  1.4353637  -0.47948155  0.33599737  0.01138271  0.03651392
  0.          0.        ] 1.3802776703599886 3
[-0.01942892  1.4438708  -0.48857823  0.37808454  0.01280896  0.02852786
  0.          0.        ] -4.528536738815819 2
[-0.02430039  1.4517776  -0.48858204  0.35140312  0.01423492  0.02852177
  0.          0.        ] 0.6550461056471022 0
[-0.02908678  1.4590878  -0.477918    0.32490155  0.01352056 -0.0142885
  0.          0.        ] 1.6946848958115208 3


In [None]:
class SharedAdam(torch.optim.Adam):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
                 weight_decay=0):
        super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # State initialization
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)

                # share in memory
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

In [None]:
env.action_space

Discrete(4)

### Actor - policy NN and value NN 
### data collection -> batch
### train: calculate loss 

In [None]:

class Policy(nn.Module):
	def __init__(self, state_dim,action_dim,hidden=100):
		super(Policy, self).__init__()

		self.l1 = nn.Linear(state_dim, hidden)
		self.l2 = nn.Linear(hidden, hidden)
		self.l3 = nn.Linear(hidden,action_dim)

	def forward(self, state):
		q = F.leaky_relu(self.l1(state))
		q = F.leaky_relu(self.l2(q))
		return F.softmax(self.l3(q), dim = 1)
	
	def get_action(self,state):
		with torch.no_grad():
			pol = self.forward(state)
			dist = torch.distributions.Categorical(pol)
		return dist.sample() #returns a batch of values
	
	def log_prob(self, state, actions):
		  # Part of the loss term
			pol = self.forward(state)
			log_prob = torch.distributions.Categorical(pol).log_prob(actions)
			return log_prob
	
	def entropy(self, state):
			pol = self.forward(state)
			return torch.distributions.Categorical(pol).entropy()
    


In [None]:
# create batch of states 
batch_states = torch.rand(5, env.observation_space.shape[0])

policy = Policy(env.observation_space.shape[0], env.action_space.n)
policy(batch_states).data

tensor([[0.2348, 0.2314, 0.2917, 0.2421],
        [0.2365, 0.2435, 0.2770, 0.2430],
        [0.2343, 0.2392, 0.2871, 0.2394],
        [0.2327, 0.2390, 0.2851, 0.2433],
        [0.2407, 0.2420, 0.2838, 0.2335]])

In [None]:
batch_actions = policy.get_action(batch_states)
batch_actions

tensor([2, 0, 2, 3, 3])

In [None]:
policy.log_prob(batch_states, batch_actions)

tensor([-1.2319, -1.4420, -1.2478, -1.4137, -1.4546],
       grad_fn=<SqueezeBackward1>)

In [None]:
class Critic(nn.Module):
    def __init__(self, state_dim,hidden=100):
      super(Critic, self).__init__()
    
      self.l1 = nn.Linear(state_dim, hidden)
      self.l2 = nn.Linear(hidden, hidden)
      self.l3 = nn.Linear(hidden,1)

    def forward(self, state):
      q = F.leaky_relu(self.l1(state))
      q = F.leaky_relu(self.l2(q))
      return self.l3(q)



In [None]:
#testing the critic output
critic = Critic(env.observation_space.shape[0])
critic(batch_states)


tensor([[0.1183],
        [0.1011],
        [0.0787],
        [0.0686],
        [0.0949]], grad_fn=<AddmmBackward>)

In [None]:
class ActorCriticWorker(mp.Process):
	def __init__(self,env_name,glb_critic,glb_policy,opt_crt,opt_pol,T,lock,gamma = 0.99,max_step=100):
		self.env = gym.make(env_name)
		self.t = 0
		self.max_step = max_step
		self.T = T
		self.lock = lock
		self.gamma = gamma

		self.actor = Policy(self.env.observation_space.shape[0], self.env.action_space.n)
		self.critic = Critic(self.env.observation_space.shape[0])
		self.global_critic = global_critic
		self.global_policy = global_policy
	
	def run(self):

		# 1. Sync local from global
		self.actor.load_state_dict(self.global_policy.state_dict())
		self.critic.load_state_dict(self.global_critic.state_dict())
	
		# 2. Create a rollout
		t_start = self.t		
		state   = self.env.reset() #giving us a state from the gym env.
	  done    = False
		states  = []
		actions = []
		rewards = []
		returns = []
		while not done and (self.t - t_start+1)%self.max_step !=0:
			    action = self.actor.get_action(state)
			    next_state, reward,done, _info = self.env.step(action)
			    rewards.append(reward)
			    actions.append(action)
			 		states.append(state)
			    state = next_state
					self.t  += 1					
					# lock memory
					with self.lock:
						self.T.value +=1

		# Calculate reward
		with torch.no_grad():
			if not done:			
				R = self.critic(torch.tensor(state,dtype = torch.float64)).item() #calculating the value function
			else:
				R = 0.0
		
		for i in range(len(states)-1,-1,-1):  #Reverse because this is a bellman-type of calculation (you know all your rewards from t to the end)
			  R = rewards[i] + self.gamma*R
				returns.append(R)  #TODO: test whether container should be tensor
		returns.reverse() # list of returns
		
		#Calculating gradients
		states_t = torch.tensor(states, dtypes = torch.float64)
		actions_t = torch.tensor(actions, dtypes = torch.float64)
		returns_t = torch.tensor(returns, dtypes = torch.float64)
	
		td_error = returns_t - self.critic(states_t)	
		critic_loss = F.mse_loss(td_error)
		actor_loss = -1*td_error.detach()*self.actor.log_prob(states_t, actions_t)
		## Do MEAN check! 
		total_loss = critic_loss + actor_loss

		# 3. Calculate loss 
    


# T is a global counter
# Tmax is total steps overall
# t is the local counter per process
    
    


IndentationError: ignored

In [3]:
x = np.zeros((1,3))
list_x = [x,x,x]

x_t=torch.tensor(list_x)

In [5]:
x, list_x, x_t

(array([[0., 0., 0.]]),
 [array([[0., 0., 0.]]), array([[0., 0., 0.]]), array([[0., 0., 0.]])],
 tensor([[[0., 0., 0.]],
 
         [[0., 0., 0.]],
 
         [[0., 0., 0.]]], dtype=torch.float64))

In [None]:
lst = [torch.tensor([0]),torch.tensor([0])].data.numpy
torch.tensor(lst,dtype = torch.float64)

ValueError: ignored

In [None]:
# worker process
# Input: A2C network, env, no of steps, 

# 1. 
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n


global_critic = Critic(state_dim)
global_policy = Policy(state_dim,action_dim)
global_critic.share_memory()
global_policy.share_memory()

global_opt_crt = SharedAdam(global_critic.parameters())
global_opt_pol = SharedAdam(global_policy.parameters())


global_ctr = mp.Value('i',0)
lock = mp.Lock()

pr = [mp.Process(target=test,args=(a,)) for _ in range(5)]

for p in pr:
    p.start()

    
for p in pr:
    p.join()
