<h1> Advantage Actor Critic on continuous actions </h1>


<h3> Import dependencies </h3>

In [7]:
import numpy as np
import matplotlib.pyplot as plt
import gymnasium
from tqdm import tqdm
from matplotlib.axes import Axes
from matplotlib.figure import Figure
import torch
from torch import nn
import cloudpickle
import sklearn
from sklearn import preprocessing

<h3> Helper functions </h3>

In [8]:
def get_grads(model:torch.nn.Module):
    g = [param.grad.detach().flatten() for param in model.parameters() if param.grad is not None]
    g = torch.cat(g).norm()
    return g
def dump():
    with (open("./models/actor.pkl","wb") as f1, 
          open("./models/critic.pkl","wb") as f2, 
          open("./models/scaler.pkl","wb") as f3):
        cloudpickle.dump(actor,f1)
        cloudpickle.dump(critic,f2)
        cloudpickle.dump(scaler,f3)
def load():
    with (open("./models/actor.pkl","rb") as f1, 
          open("./models/critic.pkl","rb") as f2, 
          open("./models/scaler.pkl","rb") as f3):
        actor = cloudpickle.load(f1)
        critic = cloudpickle.load(f2)
        scaler = cloudpickle.load(f3)
        return actor,critic,scaler

<h3> Actor critic helper classes </h3>

In [9]:
class Actor:
    def __init__(self):
        self.pi = nn.Sequential(nn.Linear(2,50),
                  nn.ReLU(),
                  nn.Linear(50,2))
    def log_prob(self,state,action):
        dist =self._get_dist(state)
        return dist.log_prob(action)
    
    def sample(self,state):
        dist =self._get_dist(state)
        _sample = dist.sample((1,))
        if np.isnan(_sample):
            raise Exception("Action is nan")
    
    def _get_dist(self,state):
        mu,sigma = self.pi(state)
        sigma = torch.nn.Softplus()(sigma)
        dist = torch.distributions.Normal(mu,sigma+1e-5)
        return dist

class Critic:
    def __init(self):
        self.v = nn.Sequential(nn.Linear(2,100),
                  nn.ReLU(),
                  nn.Linear(100,1))
    
    def val(self,state):
        _val = self.v(state)
        return _val
    

<h3> Implementation of A2C </h3>

In [10]:
NUM_TRIALS = 1
NUM_EPISODES = 1000
ALPHA_V = 56e-5
ALPHA_PI = 1e-5
env = gymnasium.make("MountainCarContinuous-v0")
returns = np.zeros((NUM_TRIALS,NUM_EPISODES))
ep_lens = np.zeros((NUM_TRIALS,NUM_EPISODES))
grads = [[],[]]
losses = [[],[]]
for i in range(NUM_TRIALS):
    actor = Actor()
    critic = Critic()
    scaler = sklearn.preprocessing.StandardScaler()
    scaler.fit(np.array([env.observation_space.sample() for _ in range(10000)]))
    mse_loss = nn.MSELoss()
    pi_op = torch.optim.SGD(actor.pi.parameters(),lr=ALPHA_PI)
    v_op = torch.optim.SGD(critic.v.parameters(),lr=ALPHA_V)
    for j in tqdm(range(NUM_EPISODES)):
        ret = 0.
        ep_len=0.
        state,*_ = env.reset()
        state = np.squeeze(scaler.transform([state]))
        state = torch.from_numpy(state).detach().float()
        while True:
            with torch.no_grad():
                action = actor.sample(state)
            next_state,reward,terminated,truncated,*_ = env.step(action)
            next_state = np.squeeze(scaler.transform([next_state]))
            next_state = torch.from_numpy(next_state).detach().float()
            with torch.no_grad():
                target = torch.tensor([reward]) + (critic.val(next_state) if not terminated else 0)
            v_op.zero_grad()
            est_v = critic.val(state)
            loss = mse_loss(target,est_v)
            loss.backward()
            v_op.step()
            losses[0].append(loss.detach().item())
            grads[0].append(get_grads(critic.v))
            with torch.no_grad():
                critic_est = torch.tensor([reward]) + (critic.val(next_state) if not terminated else 0) - critic.val(state)
            pi_op.zero_grad()
            prob = actor.log_prob(state,action)
            loss = critic_est*prob*-1 #The -1 is needed coz pytorch always does SGD and not ascent
            loss.backward()
            pi_op.step()
            losses[1].append(loss.detach().item())
            grads[1].append(get_grads(actor.pi))
            state = next_state
            ret+=reward
            ep_len+=1
            if terminated or truncated:
                if terminated:
                    print("Achieved target in ",ep_len)
                break
        returns[i,j] = ret
        ep_lens[i,j] = ep_len
        

AttributeError: 'Critic' object has no attribute 'v'

<h3> Plot metrics </h3>

In [None]:
fig:Figure
ax1:Axes
ax2:Axes
ax3:Axes
ax4:Axes
ax5:Axes
ax6:Axes
fig,((ax1,ax2),(ax3,ax4),(ax5,ax6)) = plt.subplots(3,2)
fig.set_figwidth(20)
fig.set_figheight(20)
fig.tight_layout(pad=5.0)
ax1.plot(np.mean(returns[:][:j],axis=0))
ax1.set_ylabel(f"Mean return {NUM_TRIALS} trials")
ax1.set_xlabel("Episode")
ax1.set_title("Mean Return")
ax2.plot(np.mean(ep_lens[:][:j],axis=0))
ax2.set_ylabel(f"Mean episode length {NUM_TRIALS} trials")
ax2.set_xlabel("Episode")
ax2.set_title("Mean Episode Length")
ax3.plot(grads[0])
ax3.set_ylabel(f"Grad Norm")
ax3.set_xlabel("Grad Index")
ax3.set_title("Critic Grad Norm")
ax4.plot(losses[0])
ax4.set_ylabel(f"Loss")
ax4.set_xlabel("Loss Index")
ax4.set_title("Critic Loss")
ax5.plot(grads[1])
ax5.set_ylabel(f"Grad Norm")
ax5.set_xlabel("Grad Index")
ax5.set_title("Actor Grad Norm")
ax6.plot(losses[1])
ax6.set_ylabel(f"Loss")
ax6.set_xlabel("Loss Index")
ax6.set_title("Actor Loss")

In [None]:
list(critic.v.parameters())[0]

In [None]:
list(actor.pi.parameters())[0]

In [None]:
returns[:][:j]

In [None]:
dump()

<h3> Evaluate in human render mode </h3>

In [None]:
actor,critic,scaler = load()
env = gymnasium.make("MountainCarContinuous-v0",render_mode="human")
state,*_ = env.reset()
for _ in range(1):
    state = np.squeeze(scaler.transform([state]))
    state = torch.from_numpy(state).detach().float()
    with torch.no_grad():
        action = actor.sample(state)
    next_state,reward,terminated = env.step(action)
    state = next_state