In [None]:
from IPython import display
from networks import PolicyNetwork, ValueNetwork
from torch import nn, optim
from torch.distributions import Normal
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from utils import check_nan, GAE, get_env, get_weights, norm_adv, test_agent, StateNormalizer
import gymnasium as gym
import numpy as np
import os
import torch

In [2]:
# Output directories, relative to the notebook's working directory:
# VIDS receives the recorded environment videos (passed to get_env), and
# WEIGHTS receives the periodically saved actor/critic checkpoints.
VIDS = "./vids"
WEIGHTS = "./weights"
for _out_dir in (VIDS, WEIGHTS):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

In [3]:
def _select_device() -> torch.device:
    """Pick the compute device: first CUDA GPU if present, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda:0')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')


device = _select_device()

# VPG with GAE and Actor-Critic

In [None]:
def _collect_rollout(env, device, actor, critic, state_normalizer, batch_size, γ):
    """Run the current policy until `batch_size` transitions have been collected.

    Episodes are reset whenever they terminate OR are truncated, so one rollout
    may span several episodes.

    Returns:
        - states (list) : normalized states s_0 ... s_{K-1}.
        - actions (list[np.ndarray]) : sampled actions a_0 ... a_{K-1}.
        - rewards (list[float]) : rewards r_0 ... r_{K-1}. On a time-limit
                                  truncation the reward also carries the
                                  bootstrap term γ·V(s_{t+1}) (see below).
        - state_vals (list[float]) : critic values V(s_0) ... V(s_K), i.e. one
                                     entry more than `rewards`; the last entry
                                     bootstraps the unfinished tail.
        - dones (list[bool]) : True where an episode ended (terminated or truncated).
    """
    states, actions, rewards, state_vals, dones = [], [], [], [], []

    def value_of(normalized_state) -> float:
        # Critic estimate V(s) for one normalized state, without building a graph.
        with torch.no_grad():
            state_tensor = torch.tensor(normalized_state, dtype=torch.float32).to(device)
            return critic(state_tensor).item()

    obs, _ = env.reset()
    state = state_normalizer(obs)

    while len(rewards) < batch_size:
        state_tensor = torch.tensor(state, dtype=torch.float32).to(device)
        with torch.no_grad():
            action, _ = actor.act(state_tensor)                                       # a_t
            state_val = critic(state_tensor).item()                                   # V(s_t)
        action = action.detach().cpu().numpy()

        next_obs, reward, terminated, truncated, _ = env.step(action)                # s_{t+1}, r_t
        episode_over = terminated or truncated

        if truncated and not terminated:
            # A time-limit cut is not a real terminal state, but it is still an
            # episode boundary in `dones`. Fold the bootstrap γ·V(s_{t+1}) into
            # the reward so the advantage of this step is not biased towards 0.
            # NOTE(review): assumes GAE drops the γ·V(s_{t+1}) term wherever
            # dones[t] is True -- confirm in utils.GAE.
            reward = reward + γ * value_of(state_normalizer(next_obs))

        states.append(state)
        actions.append(action)
        rewards.append(reward)
        state_vals.append(state_val)
        dones.append(episode_over)

        if episode_over:
            # Gymnasium requires reset() once terminated or truncated is returned.
            obs, _ = env.reset()
            state = state_normalizer(obs)
        else:
            state = state_normalizer(next_obs)

    # V(s_K): 0 if the rollout ended exactly on an episode boundary, otherwise
    # bootstrap from `state`, which already holds the normalized next observation
    # (re-normalizing next_obs here would feed it to the normalizer twice).
    state_vals.append(0.0 if dones and dones[-1] else value_of(state))
    return states, actions, rewards, state_vals, dones


def _update_actor(actor, π_opt, trajectory_loader) -> np.ndarray:
    """One pass of policy-gradient steps over `trajectory_loader`.

    Returns the per-mini-batch pseudolosses as a numpy array.
    """
    pseudolosses = np.zeros(shape=len(trajectory_loader))
    batch_bar = tqdm(trajectory_loader, desc='Actor Mini Batches', position=0)
    for i, (state_batch, action_batch, advantage_batch, _) in enumerate(batch_bar):
        μs, σs = actor(state_batch)
        # Joint log-probability of the action vector (sum over the last axis).
        log_prob_batch = Normal(loc=μs, scale=σs).log_prob(action_batch).sum(dim=-1)

        # Surrogate whose gradient is -E[∇log π(a|s) · A(s, a)].
        pseudoloss = -(log_prob_batch * advantage_batch).mean()
        pseudolosses[i] = pseudoloss.item()

        π_opt.zero_grad()
        pseudoloss.backward()
        π_opt.step()

        batch_bar.set_postfix_str(f"Actor Pseudoloss:{pseudolosses[i]:.5e}")
    batch_bar.close()
    display.clear_output(wait=True)
    return pseudolosses


def _update_critic(critic, V_opt, loss_fn, trajectory_loader) -> np.ndarray:
    """One pass of value-regression steps over `trajectory_loader`.

    Returns the per-mini-batch losses as a numpy array.
    """
    losses = np.zeros(shape=len(trajectory_loader))
    batch_bar = tqdm(trajectory_loader, desc='Critic Mini Batches', position=0)
    for i, (state_batch, _, _, returns_batch) in enumerate(batch_bar):
        state_val_batch = critic(state_batch).squeeze(dim=-1)
        loss = loss_fn(state_val_batch, returns_batch)
        losses[i] = loss.item()

        V_opt.zero_grad()
        loss.backward()
        V_opt.step()

        batch_bar.set_postfix_str(f"Critic MSE Loss:{losses[i]:.5e}")
    batch_bar.close()
    display.clear_output(wait=True)
    return losses


def VPG(env:gym.Env, device:torch.device, actor:PolicyNetwork, critic:ValueNetwork, 
        α:float=1e-4, γ:float=0.99, λ:float=0.95, epochs:int=400, batch_size:int=5_000, 
        mini_batch_size:int=32, save_freq:int=100) -> None:
    """Implementation of a modified version of OpenAI's VPG that uses GAE
    and an actor-critic framework to train and optimize an agent to learn
    a specified environment.

    Each epoch starts from a fresh env.reset(), collects `batch_size`
    transitions with the current policy, computes GAE advantages, and then
    makes one pass over shuffled mini-batches for the actor and one for the
    critic. Episodes that terminate or are truncated are reset mid-rollout.

    Parameters:
        - env (gym.Env) : the environment to simulate. It is closed when
                          training finishes or fails.
        - device (torch.device) : the device to put tensors on.
        - actor (PolicyNetwork) : the neural net that learns a policy.
        - critic (ValueNetwork) : the neural net that learns the value 
                                  function to evaluate a learned policy.
        - α (float) : the learning-rate for the Adam optimizers. Default
                      is 1e-4.
        - γ (float) : the discount factor for GAE. Default is 0.99.
        - λ (float) : the bias-variance tradeoff weight for GAE. Default
                      is 0.95.
        - epochs (int) : the total number of epochs to simulate. Default is 
                         400.
        - batch_size (int) : the total number of (state, action, reward, done)
                             tuples to collect per epoch for estimating the
                             policy gradient and value function. Default is 5_000.
        - mini_batch_size (int) : the number of tuples per optimization step
                                  for the actor and critic. Default is 32.
        - save_freq (int) : the number of epochs between successive savings
                            of the actor and critic networks (the final epoch
                            is always saved). Default is 100.

    Returns:
        - None

    Raises:
        - ValueError : if batch_size or mini_batch_size is smaller than 1.

    Notes:
        - If get_weights finds saved weights they are loaded, but epoch
          numbering still restarts at 1, so checkpoints with the same epoch
          number in WEIGHTS are overwritten.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if mini_batch_size < 1:
        raise ValueError(f"mini_batch_size must be >= 1, got {mini_batch_size}")

    # NaN check on the policy's base MLP outputs (check_nan comes from utils).
    # Keep the handle so repeated calls do not stack duplicate hooks.
    nan_hook = actor.baseMLP.register_forward_hook(check_nan)
    try:
        recent_weights = get_weights(dir=WEIGHTS, device=device)
        if recent_weights is None:
            print("No recent weights found. Starting from scratch.")
            display.clear_output(wait=True)
        else:
            print("Loading most recent weights")
            display.clear_output(wait=True)
            actor.load_state_dict(recent_weights['actor'])
            critic.load_state_dict(recent_weights['critic'])

        gae = GAE(gamma=γ, lamb=λ)
        state_normalizer = StateNormalizer(state_dim=env.observation_space.shape[0])
        π_opt = optim.Adam(actor.parameters(), lr=α)                # Actor/Policy optimizer
        V_opt = optim.Adam(critic.parameters(), lr=α)               # Critic/State Value function optimizer
        MSE = nn.MSELoss()

        pbar = tqdm(range(1, epochs+1), desc='Epochs', position=0)
        for epoch in pbar:
            states, actions, rewards, state_vals, dones = _collect_rollout(
                env, device, actor, critic, state_normalizer, batch_size, γ)

            advantages = gae(rewards=rewards, state_vals=state_vals, dones=dones)
            # Critic targets R_t = A_t + V(s_t); zip drops the extra bootstrap V(s_K).
            returns = [advantage + state_val for advantage, state_val in zip(advantages, state_vals)]

            states_tensor = torch.tensor(np.array(states), dtype=torch.float32).to(device)
            actions_tensor = torch.tensor(np.array(actions), dtype=torch.float32).to(device)
            returns_tensor = torch.tensor(np.array(returns), dtype=torch.float32).to(device)
            advantages_tensor = norm_adv(advantages, return_tensor=True, device=device)
            trajectories = TensorDataset(states_tensor, actions_tensor, advantages_tensor, returns_tensor)
            trajectory_loader = DataLoader(dataset=trajectories, batch_size=mini_batch_size, shuffle=True)

            pseudolosses = _update_actor(actor, π_opt, trajectory_loader)
            losses = _update_critic(critic, V_opt, MSE, trajectory_loader)

            pbar.set_postfix_str(f"Averages - Actor Pseudoloss:{pseudolosses.mean():.5e} Critic MSE Loss:{losses.mean():.5e}")
            display.clear_output(wait=True)

            if epoch % save_freq == 0 or epoch == epochs:
                torch.save(actor.state_dict(), os.path.join(WEIGHTS, f"actor_weights_vpg_epoch{epoch}.pth"))
                torch.save(critic.state_dict(), os.path.join(WEIGHTS, f"critic_weights_vpg_epoch{epoch}.pth"))
        pbar.close()
    finally:
        nan_hook.remove()
        env.close()

In [None]:
# Humanoid environment (videos are recorded into VIDS), plus an actor and a
# critic sized from its observation and action spaces.
robot = get_env("Humanoid-v5", vid_dir=VIDS)
obs_dim = robot.observation_space.shape[0]
act_dim = robot.action_space.shape[0]

policy_net = PolicyNetwork(obs_dim, act_dim).to(device)
value_net = ValueNetwork(obs_dim, hidden_size=100).to(device)

  logger.warn(


In [None]:
# Train with VPG's default hyper-parameters. VPG first loads the most recent
# weights that get_weights finds under WEIGHTS, if there are any.
VPG(
    env=robot,
    device=device,
    actor=policy_net,
    critic=value_net,
)

Epochs: 100%|██████████| 5/5 [00:01<00:00,  2.90it/s, Averages - Actor Pseudoloss:-7.66850e-01 Critic MSE Loss:3.03769e+02]
Actor Mini Batches: 100%|██████████| 4/4 [00:00<00:00, 20.59it/s, Actor Pseudoloss:-2.71356e+00] 
Critic Mini Batches: 100%|██████████| 4/4 [00:00<00:00, 23.76it/s, Critic MSE Loss:3.24883e+02] 
