<a href="https://colab.research.google.com/github/moodlep/MLC_A3C/blob/main/a3c.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip3 install box2d-py
!pip3 install gym[Box_2D]

In [None]:
import os
import Box2D
import pyglet
import imageio
# Start a virtual X server (Xvfb) on display :1 in the background (the trailing
# "&" lets os.system return immediately), then point DISPLAY at it -- presumably
# so pyglet/gym rendering has a screen to draw to in a headless runtime.
os.system("Xvfb :1 -screen 0 1024x768x24 &")
os.environ.update(DISPLAY=':1')


In [None]:
import gym
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp


## Gym Env

In [None]:
# Quick environment check: reset once, then take 5 random actions and print
# (next observation, reward, sampled action) after each step.
env = gym.make("LunarLander-v2")

s = env.reset()

for _step in range(5):
    a = env.action_space.sample()
    next_state, reward, done, info = env.step(a)
    print(next_state, reward, a)


In [None]:
class SharedAdam(torch.optim.Adam):
    """Adam optimiser whose per-parameter state is created eagerly in shared memory.

    The first/second moment buffers (and, on PyTorch >= 1.12, the step counter)
    are allocated up front and moved to shared memory, so every worker process
    that receives this optimiser reads and updates the same statistics.

    Args:
        params: iterable of parameters (or param groups) to optimise.
        lr, betas, eps, weight_decay: forwarded unchanged to ``torch.optim.Adam``.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
                 weight_decay=0):
        super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # State initialization: pre-populate so Adam never lazily creates
        # (process-local) state on its first step.
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = self._initial_step()
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)

                # share in memory
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

    @staticmethod
    def _initial_step():
        """Return a zero step counter in the form this PyTorch's Adam expects.

        PyTorch >= 1.12 requires ``state['step']`` to be a singleton tensor
        (a plain int makes ``step()`` raise); older releases expect an int.
        """
        major, minor = (int(part) for part in torch.__version__.split(".")[:2])
        if (major, minor) >= (1, 12):
            step = torch.zeros(())
            # Shared so all workers advance (and bias-correct with) one counter.
            step.share_memory_()
            return step
        return 0

In [None]:
# Display the action space; later cells read env.action_space.n, i.e. they treat it as discrete.
env.action_space

## Actor-critic: policy network and value network
### Data collection -> batch
### Training: calculate the loss

In [None]:

class Policy(nn.Module):
    """Categorical policy network mapping states to action probabilities.

    Args:
        state_dim: length of a flat observation vector.
        action_dim: number of discrete actions.
        hidden: width of the two hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden=100):
        super(Policy, self).__init__()

        self.l1 = nn.Linear(state_dim, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.l3 = nn.Linear(hidden, action_dim)

    def forward(self, state):
        """Return action probabilities for ``state``.

        Accepts a single state of shape (state_dim,) or a batch of shape
        (batch, state_dim); the softmax is taken over the last (action) axis.
        """
        q = F.leaky_relu(self.l1(state))
        q = F.leaky_relu(self.l2(q))
        # dim=-1 (not dim=1) so an un-batched 1-D state works as well.
        return F.softmax(self.l3(q), dim=-1)

    def get_action(self, state):
        """Sample action index/indices from the policy without tracking gradients.

        ``state`` may be a list, ndarray or tensor, single or batched. Returns a
        tensor of sampled actions: 0-d for a single state, (batch,) for a batch.
        """
        # Cast to the network's own parameter dtype (float32 by default);
        # feeding float64 into float32 Linear layers raises a dtype error.
        state = torch.as_tensor(state, dtype=self.l1.weight.dtype)
        with torch.no_grad():
            pol = self.forward(state)
            dist = torch.distributions.Categorical(pol)
        return dist.sample()

    def log_prob(self, state, actions):
        """Return log pi(actions | state); used in the policy-gradient loss term."""
        pol = self.forward(state)
        return torch.distributions.Categorical(pol).log_prob(actions)

    def entropy(self, state):
        """Return the entropy of the action distribution for each state."""
        pol = self.forward(state)
        return torch.distributions.Categorical(pol).entropy()
    


In [None]:
# Smoke test for Policy: a random batch of 5 fake observations, shape
# (5, state_dim), and the resulting per-action probabilities.
batch_states = torch.rand(5, env.observation_space.shape[0])

policy = Policy(env.observation_space.shape[0], env.action_space.n)
# Displayed as the cell output; the policy's softmax over dim=1 makes each row sum to 1.
policy(batch_states).data

In [None]:
#batch_actions = policy.get_action(batch_states)
#batch_actions

In [None]:
# Sample one action per state straight from the policy's categorical
# distribution, so this check does not depend on a separately defined
# batch_actions (the cell that would create it is commented out).
with torch.no_grad():
    batch_actions = torch.distributions.Categorical(policy(batch_states)).sample()
# Displayed: log-probability of each sampled action, one entry per state.
policy.log_prob(batch_states, batch_actions)

In [None]:
class Critic(nn.Module):
    """State-value network: maps a state (or a batch of states) to V(s).

    Args:
        state_dim: length of a flat observation vector.
        hidden: width of the two hidden layers.
    """

    def __init__(self, state_dim, hidden=100):
        super().__init__()
        # The attribute names l1/l2/l3 form the state_dict keys that are copied
        # between global and local critics, so they must stay as they are.
        self.l1 = nn.Linear(state_dim, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.l3 = nn.Linear(hidden, 1)

    def forward(self, state):
        """Return V(state), with a trailing dimension of size 1."""
        features = state
        for layer in (self.l1, self.l2):
            features = F.leaky_relu(layer(features))
        return self.l3(features)



In [None]:
# Smoke test for Critic on the same random batch_states; expect one value
# estimate per state, i.e. an output of shape (5, 1).
critic = Critic(env.observation_space.shape[0])
critic(batch_states)


## Worker

In [None]:
class ActorCriticWorker(mp.Process):
    """A3C worker: collects n-step rollouts with local copies of the actor and
    critic and applies the resulting gradients to the shared global networks.

    Args:
        env_name: gym environment id; each worker builds its own env instance.
        glb_critic: global ``Critic`` whose parameters ``opt`` updates.
        global_actor: global ``Policy`` whose parameters ``opt`` updates.
        opt: optimiser over the global actor and critic parameters.
        T: shared integer counter (e.g. ``mp.Value('i', 0)``) of env steps
            taken by all workers.
        lock: lock guarding increments of ``T``.
        gamma: discount factor.
        max_step: maximum number of env steps per rollout (n of the n-step return).
        T_max: keep doing rollout/update cycles until ``T.value`` reaches this
            value. ``None`` (default) performs a single rollout and update.

    Raises:
        ValueError: if ``max_step`` < 1.
    """

    def __init__(self, env_name, glb_critic, global_actor, opt, T, lock,
                 gamma=0.99, max_step=100, T_max=None):
        super(ActorCriticWorker, self).__init__()
        if max_step < 1:
            raise ValueError(f"max_step must be >= 1, got {max_step}")
        self.env = gym.make(env_name)
        self.t = 0  # local step counter of this worker
        self.max_step = max_step
        self.T = T
        self.T_max = T_max
        self.lock = lock
        self.gamma = gamma
        self.opt = opt

        state_dim = self.env.observation_space.shape[0]
        self.actor = Policy(state_dim, self.env.action_space.n)
        self.critic = Critic(state_dim)
        # Use the critic that was passed in, not a module-level global.
        self.global_critic = glb_critic
        self.global_actor = global_actor

    def run(self):
        """Process entry point: alternate rollouts and global updates."""
        state = self.env.reset()
        while True:
            # 1. Sync local networks from the global ones.
            self._sync_from_global()
            # 2. Collect up to max_step transitions with the local actor.
            states, actions, rewards, state, done = self._rollout(state)
            # 3-4. Compute the loss and push the gradients to the global networks.
            self._update(states, actions, rewards, state, done)
            if self.T_max is None or self.T.value >= self.T_max:
                break
            if done:
                state = self.env.reset()

    def _sync_from_global(self):
        """Copy the current global weights into the local actor and critic."""
        self.actor.load_state_dict(self.global_actor.state_dict())
        self.critic.load_state_dict(self.global_critic.state_dict())

    @staticmethod
    def _to_batch(state):
        """Wrap one (assumed flat) observation as a float32 batch of size 1."""
        return torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)

    def _rollout(self, state):
        """Step the env from ``state`` for up to ``max_step`` steps.

        Returns (states, actions, rewards, last_state, done).
        """
        states, actions, rewards = [], [], []
        done = False
        t_start = self.t
        while not done and self.t - t_start < self.max_step:
            with torch.no_grad():
                probs = self.actor(self._to_batch(state))
            # Pass a plain Python int to env.step rather than a torch tensor.
            action = torch.distributions.Categorical(probs).sample().item()
            next_state, reward, done, _info = self.env.step(action)
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            state = next_state
            self.t += 1
            # lock memory
            with self.lock:
                self.T.value += 1
        return states, actions, rewards, state, done

    @staticmethod
    def _discounted_returns(rewards, gamma, bootstrap):
        """Return n-step discounted returns, computed backwards from ``bootstrap``."""
        returns = []
        R = bootstrap
        for r in reversed(rewards):
            R = r + gamma * R
            returns.append(R)
        returns.reverse()
        return returns

    def _update(self, states, actions, rewards, last_state, done):
        """Compute the actor-critic loss and apply it through the global optimiser."""
        # Bootstrap from V(last_state) unless the episode terminated.
        with torch.no_grad():
            R = 0.0 if done else self.critic(self._to_batch(last_state)).item()
        returns_t = torch.tensor(
            self._discounted_returns(rewards, self.gamma, R), dtype=torch.float32)
        # Stack into one ndarray first (fast path) and match the float32 weights.
        states_t = torch.as_tensor(np.asarray(states, dtype=np.float32))
        actions_t = torch.as_tensor(actions, dtype=torch.int64)

        # Squeeze the critic's trailing size-1 dim: (n,) - (n,), not an (n, n) broadcast.
        td_error = returns_t - self.critic(states_t).squeeze(-1)
        critic_loss = td_error.pow(2).mean()
        actor_loss = -(td_error.detach() * self.actor.log_prob(states_t, actions_t)).mean()
        total_loss = critic_loss + actor_loss

        # Clear local grads so they do not accumulate across rollouts.
        self.actor.zero_grad()
        self.critic.zero_grad()
        total_loss.backward()

        # Hand local gradients to the global parameters, then take a shared step.
        for gp, lp in zip(self.global_critic.parameters(), self.critic.parameters()):
            gp._grad = lp.grad
        for gp, lp in zip(self.global_actor.parameters(), self.actor.parameters()):
            gp._grad = lp.grad
        self.opt.step()



# T is a global counter
# Tmax is total steps overall
# t is the local counter per process
    
    


In [None]:
# Scratch check: turning a list of equally-shaped ndarrays (like the states the
# worker collects) into a single tensor.
x = np.zeros((1, 3))
list_x = [x, x, x]

# Stack into one ndarray first: torch.tensor(list_of_ndarrays) takes a slow
# element-by-element path and emits a UserWarning.
x_t = torch.tensor(np.array(list_x))

In [None]:
# Display the scratch ndarray, the list of ndarrays and the stacked tensor side by side.
x, list_x, x_t

In [None]:
# Launch A3C training: global actor/critic in shared memory, one shared
# optimiser over both, a shared global step counter T, and one
# ActorCriticWorker process per CPU core.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

global_critic = Critic(state_dim)
global_actor = Policy(state_dim, action_dim)
# Move parameters to shared memory so every worker updates the same tensors.
global_critic.share_memory()
global_actor.share_memory()

n_workers = mp.cpu_count()  # one worker per CPU core
env_name = "LunarLander-v2"

global_opt = SharedAdam(list(global_critic.parameters()) + list(global_actor.parameters()))

global_ctr = mp.Value('i', 0)  # global step counter T shared by all workers
lock = mp.Lock()

pr = [
    ActorCriticWorker(env_name, global_critic, global_actor, global_opt, global_ctr, lock)
    for _ in range(n_workers)
]

for p in pr:
    p.start()

for p in pr:
    p.join()
