<h1> Advantage Actor Critic on continuous actions </h1>


<h3> Import dependencies </h3>

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import gymnasium
from tqdm import tqdm
from matplotlib.axes import Axes
from matplotlib.figure import Figure
import torch
from torch import nn
import cloudpickle
import sklearn
from sklearn import preprocessing
# Switch autograd on globally for this kernel session; the training cell below
# relies on gradients being tracked when it calls .backward().
torch.set_grad_enabled(True) 

<torch.autograd.grad_mode.set_grad_enabled at 0x137998610>

<h3> Helper functions </h3>

In [2]:
def get_grads(model:torch.nn.Module):
    """Return the L2 norm of all gradients currently stored on ``model``.

    Every parameter whose ``.grad`` is populated contributes; parameters
    without a gradient are skipped.

    Args:
        model: module whose parameter gradients are measured.

    Returns:
        A 0-dim tensor holding the norm of the concatenated, flattened
        gradients. Zero when no parameter has a gradient yet (for example
        before the first ``backward()`` call).
    """
    flat_grads = [param.grad.flatten() for param in model.parameters() if param.grad is not None]
    if not flat_grads:
        # torch.cat refuses an empty list; "no gradients" is reported as a zero norm.
        return torch.zeros(())
    return torch.cat(flat_grads).norm()

In [3]:
def dump():
    """Pickle the notebook-level ``actor`` and ``critic`` objects to disk.

    Writes ``./models/actor.pkl`` and ``./models/critic.pkl`` with cloudpickle.
    The ``./models`` directory is created first when it does not exist, so the
    call no longer fails with FileNotFoundError on a fresh checkout.
    """
    import os  # local import keeps this helper self-contained

    os.makedirs("./models", exist_ok=True)
    with (open("./models/actor.pkl","wb") as f1, 
          open("./models/critic.pkl","wb") as f2):
        cloudpickle.dump(actor,f1)
        cloudpickle.dump(critic,f2)

In [4]:
def load():
    """Read the pickled actor and critic back from ``./models``.

    Returns:
        A 2-tuple ``(actor, critic)`` unpickled from ``./models/actor.pkl``
        and ``./models/critic.pkl`` respectively.
    """
    # NOTE: unpickling executes arbitrary code; only load files you created yourself.
    with open("./models/actor.pkl","rb") as actor_file, open("./models/critic.pkl","rb") as critic_file:
        restored_actor = cloudpickle.load(actor_file)
        restored_critic = cloudpickle.load(critic_file)
    return restored_actor, restored_critic

<h3> Actor critic helper classes </h3>

In [5]:
class Actor:
    """Gaussian policy over a scalar continuous action.

    ``pi`` maps a 2-feature state to two numbers: the mean of a Normal
    distribution and an unconstrained value that is turned into its standard
    deviation. States may be a single vector of shape ``(2,)`` or a batch of
    shape ``(..., 2)``.
    """

    # Lower bound added to the standard deviation so it never collapses to zero.
    MIN_SIGMA = 1e-1

    def __init__(self):
        self.pi = nn.Sequential(nn.Linear(2,50),
                  nn.ReLU(),
                  nn.Linear(50,2))

    def log_prob(self,state,action):
        """Log-density of ``action`` under the policy evaluated at ``state``."""
        return self._get_dist(state).log_prob(action)

    def sample(self,state):
        """Draw one action for ``state``; the result has a leading axis of size 1."""
        return self._get_dist(state).sample((1,))

    def _get_dist(self,state):
        """Build the Normal action distribution for ``state``."""
        out = self.pi(state)
        # Split on the LAST axis so batched states work too; tuple-unpacking the
        # output would iterate over the batch axis instead of (mean, spread).
        mu, raw_sigma = out[..., 0], out[..., 1]
        # Softplus keeps the standard deviation positive.
        sigma = nn.functional.softplus(raw_sigma) + self.MIN_SIGMA
        return torch.distributions.Normal(mu,sigma)

In [6]:
class Critic:
    """State-value estimator: a one-hidden-layer MLP from 2 state features to 1 value."""

    def __init__(self):
        hidden_units = 100
        layers = [
            nn.Linear(2, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),
        ]
        self.v = nn.Sequential(*layers)

    def val(self,state):
        """Return the network's value estimate for ``state``."""
        return self.v(state)
    

<h3> Implementation of A2C </h3>

In [5]:
# ---- Configuration -------------------------------------------------------
NUM_EPISODES = 1000
# NOTE(review): 0.1 is a very large step size for AdamW; kept as-is, consider tuning.
ALPHA_V = ALPHA_PI = 0.1
GAMMA = 1

# ---- Environment, models, optimisers -------------------------------------
env = gymnasium.make("MountainCarContinuous-v0")
mse_loss = nn.MSELoss()
actor = Actor()
critic = Critic()
pi_op = torch.optim.AdamW(actor.pi.parameters(),lr=ALPHA_PI)
v_op = torch.optim.AdamW(critic.v.parameters(),lr=ALPHA_V)

# ---- Bookkeeping (one entry per finished episode / update) ---------------
returns = []        # undiscounted return of each episode
ep_lens = []        # number of environment steps of each episode
losses = [[], []]   # [critic losses, actor losses]
grads = [[], []]    # [critic grad norms, actor grad norms]

# Action bounds are constant for the environment, so read them once.
action_low = float(env.action_space.low[0])
action_high = float(env.action_space.high[0])

for j in range(NUM_EPISODES):
    log_probs = []
    values = []
    rewards = []
    state,*_ = env.reset()
    state = torch.as_tensor(state).float()

    # ---- Roll out one episode with the current policy --------------------
    while True:
        raw_action = actor.sample(state)
        # The environment only accepts actions inside its bounds, and expects
        # a NumPy array rather than a torch tensor.
        action = torch.clamp(raw_action,action_low,action_high)
        next_state,reward,terminated,truncated,*_ = env.step(action.numpy())
        next_state = torch.as_tensor(next_state).float()
        # Score the action that was actually sampled from the policy; the
        # clamping is treated as part of the environment.
        log_probs.append(actor.log_prob(state,raw_action))
        values.append(critic.val(state))
        rewards.append(float(reward))
        state = next_state
        if terminated or truncated:
            ret = np.sum(rewards)
            if terminated:
                print("Achieved target in",len(values),"steps at ep",j,"with ret",ret)
            break

    returns.append(ret)
    ep_lens.append(len(rewards))

    # ---- One-step TD targets for every transition of the episode ---------
    values_t = torch.cat(values)            # V(s_t), still attached to the graph
    log_probs_t = torch.cat(log_probs)      # log pi(a_t | s_t)
    rewards_t = torch.tensor(rewards,dtype=torch.float)
    with torch.no_grad():
        # A terminal state has value 0; a time-limit truncation is bootstrapped.
        bootstrap = torch.zeros(1) if terminated else critic.val(next_state)
    next_values = torch.cat([values_t[1:].detach(),bootstrap])   # V(s_{t+1})
    targets = rewards_t + GAMMA * next_values
    advantages = targets - values_t.detach()

    # ---- Critic update: regress V(s_t) onto the TD target ----------------
    v_op.zero_grad()
    critic_loss = mse_loss(values_t,targets)
    critic_loss.backward()
    grads[0].append(get_grads(critic.v).item())
    v_op.step()
    losses[0].append(critic_loss.item())

    # ---- Actor update: policy gradient weighted by the advantage ---------
    pi_op.zero_grad()
    actor_loss = -(advantages * log_probs_t).mean()
    actor_loss.backward()
    grads[1].append(get_grads(actor.pi).item())
    pi_op.step()
    losses[1].append(actor_loss.item())
        

KeyboardInterrupt: 

In [None]:
def _plot_series(ax: Axes, series, title: str, xlabel: str, ylabel: str) -> None:
    """Draw ``series`` as a line on ``ax`` and set its title and axis labels."""
    ax.plot(series)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)


# Gradient norms, losses and entropies are optional diagnostics: fall back to
# empty series when the training cell did not record them, so this cell still
# renders instead of raising NameError.
_recorded = globals()
_grads = _recorded.get("grads", [[], []])        # [critic, actor]
_losses = _recorded.get("losses", [[], []])      # [critic, actor]
_entropies = _recorded.get("entropies", [])

fig:Figure
fig,axes = plt.subplots(4,2,figsize=(20,20))
ax1,ax2,ax3,ax4,ax5,ax6,ax7,ax8 = axes.flat

# (series, title, x label, y label) for each panel, in row-major order.
_panels = [
    (returns, "Return", "Episode", "Return"),
    (ep_lens, "Episode Length", "Episode", "Episode length"),
    (_grads[0], "Critic Grad Norm", "Grad Index", "Grad Norm"),
    (_losses[0], "Critic Loss", "Loss Index", "Loss"),
    (_grads[1], "Actor Grad Norm", "Grad Index", "Grad Norm"),
    (_losses[1], "Actor Loss", "Loss Index", "Loss"),
    (_entropies, "Entropy Loss", "Index", "Entropy"),
    (rewards, "Rewards", "Index", "Rewards"),
]
for _ax, _panel in zip(axes.flat, _panels):
    _plot_series(_ax, *_panel)

# Lay out after all labels exist so the padding accounts for them.
fig.tight_layout(pad=5.0);

In [None]:
# Peek at the first parameter tensor registered on the critic network.
next(iter(critic.v.parameters()))

In [None]:
# Peek at the first parameter tensor registered on the actor network.
next(iter(actor.pi.parameters()))

In [None]:
# `returns` is a flat Python list with one entry per finished episode, so the
# 2-D index `[:, :j]` raises TypeError. Show only the most recent episodes
# rather than dumping the whole history.
returns[-10:]

In [None]:
# Persist the current actor and critic to ./models/actor.pkl and ./models/critic.pkl.
dump()

<h3> Evaluate in human render mode </h3>

In [None]:
# load() returns exactly (actor, critic). No state scaler is fitted or saved
# anywhere in this notebook, so states are fed to the policy unscaled, exactly
# as during training.
actor,critic = load()
env = gymnasium.make("MountainCarContinuous-v0",render_mode="human")
action_low = float(env.action_space.low[0])
action_high = float(env.action_space.high[0])
try:
    for _ in range(10):
        state,*_ = env.reset()
        while True:
            state = torch.as_tensor(state).float()
            with torch.no_grad():
                action = actor.sample(state)
            # Keep the action inside the environment's bounds and hand it over
            # as a NumPy array, which is what env.step expects.
            action = torch.clamp(action,action_low,action_high)
            next_state,reward,terminated,truncated,*_ = env.step(action.numpy())
            # Stop on success (terminated) as well as on the time limit (truncated).
            if terminated or truncated:
                break
            state = next_state
finally:
    # Always release the render window, even if the loop is interrupted.
    env.close()