In [1]:
#| default_exp train_vpg

In [2]:
#| export
import torch
import torch.nn as nn
import numpy as np
import gym

import vpg.vpg_core as core
from vpg.vpg_buffer import VPGBuffer

In [3]:
# HYPERPARAMETERS

ENV_ID = "HalfCheetah-v2"
HIDDEN_SIZES = [64, 64]  # for both pi and v
PI_LR = 1e-3
V_LR = 1e-3
GAMMA = 0.99
GAE_LAMBDA = 0.98
ACTIVATION_FN = nn.Tanh
TRAIN_V_ITERS = 20
MAX_EPISODE_LENGTH = 1000
STEPS_PER_EPOCH = 4000
EPOCHS = 1 # 500

SEED = 0

In [4]:
# Set SEEDS

torch.manual_seed(SEED)
np.random.seed(SEED)

In [5]:
# Create the env
env = gym.make(ENV_ID)

# Create the Actor-Critic
ac = core.MLPActorCritic(env.observation_space, env.action_space, HIDDEN_SIZES, activation=nn.Tanh)

# Create optimizers for the actor and critic models
pi_optimizer = torch.optim.Adam(ac.pi.parameters(), lr=PI_LR)
vf_optimizer = torch.optim.Adam(ac.v.parameters(), lr=V_LR)

# Create the VPGBuffer Object
buf = VPGBuffer(env.observation_space.shape, env.action_space.shape, STEPS_PER_EPOCH, GAMMA, GAE_LAMBDA)

# Count Paramaters
var_counts = tuple(core.count_params(module) for module in [ac.pi, ac.v])
print("Number of Parameters. PI: {} V: {}".format(*var_counts))



Number of Parameters. PI: 5708 V: 5377


### Experiment

In [6]:
# Collect experience by running the agent in the environment for a number of steps.
# This is called an epoch. An epoch will to transverse N trajectories of agent-environment interactions.
# Per Epoch we'll run the agent for `STEPS_PER_EPOCH` number of steps

obs = env.reset()

ep_rets, ep_lens = [], []
ep_ret, ep_len = 0, 0

for step in range(STEPS_PER_EPOCH):
    # get action from the policy
    act, val, logp = ac.step(torch.as_tensor(obs, dtype=torch.float32))
    
    # step through the env with the action from the policy
    next_obs, rew, done, _ = env.step(act)
    ep_ret += rew
    ep_len += 1
    
    # store step in the VPG buffer
    buf.store(obs, act, rew, val, logp)
    
    # update the obs
    obs = next_obs
    
    # check for terminal or epoch end
    timeout = (ep_len == MAX_EPISODE_LENGTH)
    terminal = (done or timeout)
    epoch_end = (step == (STEPS_PER_EPOCH - 1))
    
    # if trajectory ends or epoch ends
    if terminal or epoch_end:
        # Log trajectory cut-off byb epoch end
        if epoch_end and not terminal:
            print(f"WARNING: Trajectory cut off by epoch end at step {ep_len}")
        
        # bootstrap value target if trajectory didn't reach terminal state
        if timeout or (epoch_end and not done):
            _, v, _ = ac.step(torch.as_tensor(obs, dtype=torch.float32))
        else:
            v = 0
        
        # Finish a trajectory
        buf.finish_path(v)
        
        # only save ep_rew and ep_len if trajectory finished
        if terminal:
            ep_lens.append(ep_len)
            ep_rets.append(ep_ret)
        ep_len, ep_ret = 0, 0
        obs = env.reset()

In [7]:
# At the end of an epoch perform one update step
# An update step will perform a step of gradient ascent on the policy performance of the PI (actor) network
# This is also equivalent to one step of gradient descent of the loss (-ve policy performance)
# An update step will also perform N gradient descent steps to fit the value network (V) on its MSE loss

# ------------------- Compute loss of PI network
# ------------------- loss = -Mean(logp*adv)

# get the observations, action taken, returns, advantages, ... for an epoch 
data = buf.get()

# obs, act, ret, adv, logp = data
obs, act, adv, logp_old = data["obs"], data["act"], data["adv"], data["logp"]

pi, logp = ac.pi(obs, act)
loss_pi = -(logp*adv).mean()

# Compute KL divergence and entropy
approx_kl = (logp_old - logp).mean().item()
ent = pi.entropy().mean().item()
pi_info = dict(kl=approx_kl, ent= ent)

logp_old, logp, loss_pi, -(logp_old*adv).mean(), approx_kl, ent

(tensor([-4.3119, -4.7736, -5.2322,  ..., -4.6005, -5.5176, -6.3941]),
 tensor([-4.3119, -4.7736, -5.2322,  ..., -4.6005, -5.5176, -6.3941],
        grad_fn=<SumBackward1>),
 tensor(0.0192, grad_fn=<NegBackward>),
 tensor(0.0192),
 -7.152557435219364e-10,
 0.9189439415931702)

In [8]:
(logp_old - logp).sum()

tensor(-2.8610e-06, grad_fn=<SumBackward0>)

In [9]:
# --------------- Compute the loss of the value function 
obs, ret = data["obs"], data["ret"]
loss_v = ((ac.v(obs) - ret)**2).mean() # the MSE of the V prediction and the actual reward-to-go
loss_v

tensor(993.5066, grad_fn=<MeanBackward0>)

Define a function to compute loss of the pi network

In [10]:
#| export
def compute_loss_pi(data, ac):
    # obs, act, ret, adv, logp = data
    obs, act, adv, logp_old = data["obs"], data["act"], data["adv"], data["logp"]

    pi, logp = ac.pi(obs, act)
    loss_pi = -(logp*adv).mean()

    # Compute KL divergence and entropy
    approx_kl = (logp_old - logp).mean().item()
    ent = pi.entropy().mean().item()
    pi_info = dict(kl=approx_kl, ent=ent)
    
    return loss_pi, pi_info

Define a function to compute the loss of the value network

In [11]:
#| export
def compute_loss_v(data, ac):
    # obs, act, ret, adv, logp = data
    obs, ret = data["obs"], data["ret"]
    return ((ac.v(obs) - ret)**2).mean()

In [12]:
# Perform gradient descent on the loss of the PI and Value networks

pi_optimizer.zero_grad()
loss_p, pi_info = compute_loss_pi(data, ac)
loss_p.backward()
pi_optimizer.step()
# the loss of the PI network is only valid for one gradient descent step, on the data generated by that 
# policy (the previous parameters). The value of the loss cannot be use to track the performance of the
# policy

print(f"Old Value function loss: {compute_loss_v(data, ac)}")
for _  in range(TRAIN_V_ITERS):
    vf_optimizer.zero_grad()
    loss_v = compute_loss_v(data, ac)
    loss_v.backward()
    vf_optimizer.step()
print(f"Final Value function loss: {compute_loss_v(data, ac)}")

Old Value function loss: 993.506591796875
Final Value function loss: 963.981689453125


Define a function to update the policy and the value function at the end of each epoch

In [13]:
#| export
def update(data, ac, pi_optimizer, vf_optimizer, train_v_iters):
    # Get loss and info values before update
    loss_pi_old, pi_info_old = compute_loss_pi(data, ac)
    loss_pi_old = loss_pi_old.item()
    loss_v_old = compute_loss_v(data, ac).item()
    
    # Train the policy with a single step of gradient descent
    pi_optimizer.zero_grad()
    loss_pi, pi_info = compute_loss_pi(data, ac)
    loss_pi.backward()
    pi_optimizer.step()

    # Fit value function
    for _  in range(train_v_iters):
        vf_optimizer.zero_grad()
        loss_v = compute_loss_v(data, ac)
        loss_v.backward()
        vf_optimizer.step()
        
    # Log changes from update
    kl, ent = pi_info["kl"], pi_info_old["ent"]
    return dict(LossPi=loss_pi_old, LossV=loss_v_old, KL=kl, Entropy=ent, 
                DeltaLossPi=loss_pi.item() - loss_pi_old, DeltaLossV=loss_v.item() - loss_v_old)

In [14]:
update(data, ac, pi_optimizer, vf_optimizer, TRAIN_V_ITERS)

{'LossPi': 0.003244533436372876,
 'LossV': 963.981689453125,
 'KL': 0.010360285639762878,
 'Entropy': 0.9192747473716736,
 'DeltaLossPi': 0.0,
 'DeltaLossV': -57.24359130859375}

### Training Loop

Define the train_vpg function

In [15]:
#| export
from torch.utils.tensorboard import SummaryWriter

def train_vpg(env_fn, actor_critic, ac_kwargs, pi_lr, vf_lr,
              epochs, steps_per_epoch, gamma, gae_lambda, train_v_iters,
              max_ep_len, log_freq=10, seed=0, exp_name="vpg"):
    # set seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    
    # Create the training environment
    env = env_fn()
    
    # Create the actor-critic
    ac = actor_critic(env.observation_space, env.action_space, **ac_kwargs)
    param_counts = tuple(core.count_params(module) for module in [ac.pi, ac.v])
    
    
    # Create optimizers for the policy and value function
    pi_optimizer = torch.optim.Adam(ac.pi.parameters(), lr=pi_lr)
    vf_optimizer = torch.optim.Adam(ac.v.parameters(), lr=vf_lr)
    
    # Create the VPG Buffer
    buffer = VPGBuffer(env.observation_space.shape, env.action_space.shape, 
                       steps_per_epoch, gamma, gae_lambda)
    
    # Tensorboard Writer
    writer = SummaryWriter(f"./logs/tensorboard/{exp_name}")
    
    # Run `epochs` number of epochs
    obs, ep_ret, ep_len = env.reset(), 0, 0
    for epoch in range(epochs):
        epoch_rets, epoch_lens = [], []
        for step in range(steps_per_epoch):
            # get action from the policy
            act, val, logp = ac.step(torch.as_tensor(obs, dtype=torch.float32))

            # step through the env with the action from the policy
            next_obs, rew, done, _ = env.step(act)
            ep_ret += rew
            ep_len += 1

            # store step in the VPG buffer
            buffer.store(obs, act, rew, val, logp)

            # update the obs
            obs = next_obs

            # check for terminal or epoch end
            timeout = (ep_len == max_ep_len)
            terminal = (done or timeout)
            epoch_end = (step == (steps_per_epoch - 1))

            # if trajectory ends or epoch ends
            if terminal or epoch_end:
                # Log trajectory cut-off byb epoch end
                if epoch_end and not terminal:
                    pass
                    # print(f"WARNING: Trajectory cut off by epoch end at step {ep_len}")

                # bootstrap value target if trajectory didn't reach terminal state
                if timeout or epoch_end: # change to if not done
                    _, v, _ = ac.step(torch.as_tensor(obs, dtype=torch.float32))
                else:
                    v = 0

                # Finish a trajectory
                buffer.finish_path(v)

                # only save ep_rew and ep_len if trajectory finished
                if terminal:
                    epoch_lens.append(ep_len)
                    epoch_rets.append(ep_ret)
                obs, ep_len, ep_ret = env.reset(), 0, 0
            
        # Perform VPG update
        data = buffer.get()
        res = update(data, ac, pi_optimizer, vf_optimizer, train_v_iters)
        
        # Log Result
        if (epoch % log_freq == 0) or (epoch == epochs - 1):
            print(f"Epoch: {epoch} Mean Reward: {np.mean(epoch_rets):.2f}, Mean Length: {np.mean(epoch_lens):.1f} LossV: {res['LossV']:.3f}")
            
        writer.add_scalar("Mean Return", np.mean(epoch_rets), global_step=epoch)
        writer.add_scalar("Mean Length", np.mean(epoch_lens), global_step=epoch)
        writer.add_scalar("Value Loss", res['LossV'], global_step=epoch)
    
    return ac

In [16]:
import time

train_kwargs = {"env_fn": lambda: gym.make(ENV_ID), 
                "actor_critic": core.MLPActorCritic,
                "ac_kwargs": {"hidden_sizes": HIDDEN_SIZES, "activation": ACTIVATION_FN},
                "pi_lr": PI_LR,
                "vf_lr": V_LR,
                "epochs": EPOCHS,
                "steps_per_epoch": STEPS_PER_EPOCH,
                "gamma": GAMMA,
                "gae_lambda": GAE_LAMBDA,
                "train_v_iters": TRAIN_V_ITERS,
                "max_ep_len": MAX_EPISODE_LENGTH, 
                "log_freq": 10, 
                "seed": SEED, 
                "exp_name": f"vpg_{ENV_ID}_{time.time()}"}

In [17]:
model = train_vpg(**train_kwargs)

Epoch: 0 Mean Reward: -293.87, Mean Length: 1000.0 LossV: 939.325


In [18]:
# SAVE THE MODEL
torch.save(model.state_dict(), f"./logs/checkpoints/{ENV_ID}_{time.time()}.pth")

In [19]:
# LOAD A SAVED
ac_kwargs = {"hidden_sizes": HIDDEN_SIZES, "activation": ACTIVATION_FN}
model = core.MLPActorCritic(env.observation_space, env.action_space, **ac_kwargs)

model.load_state_dict(torch.load(f"./logs/checkpoints/{ENV_ID}.pth"))

<All keys matched successfully>

In [20]:
# # RECORD AGENT IN ENVIRONMENT
# from colabgymrender.recorder import Recorder

# eval_env = gym.make(ENV_ID)
# env = Recorder(env, "./logs/videos", fps=30)

# NUM_EPISODES = 10
# done = False
# obs = env.reset()

# for _ in range(NUM_EPISODES):
#     while not done:
#         act = model.act(torch.as_tensor(obs, dtype=torch.float32))
#         obs, rew, done, _ = env.step(act)
#     done = False
#     obs = env.reset()
# # env.close()
# env.play()

In [21]:
# EXPORT VIDEO
# video_path attempt to get the last recorded video.
# If video_path is not correct, provide the correct path
# video_path = sorted(os.listdir("./logs/videos"), reverse=True)[0]

# !ffmpeg -i {video_path} -vcodec h264 replay.mp4
# !rm {video_path}

In [22]:
#| hide
# import nbdev

# nbdev.nbdev_export()