In [2]:
# 1*. Implement A2C using gym mujoco Swimmer environment.
# 2. Implement domain randomization in Swimmer environment.
# 3*. Use normal distribution for continuous action space.
# 4. Implement n-step return TD error without lambda GAE.
# 5. Use transformer network.
# 6. Excellent episode returns.
# 7*. Demonstrate the agent with best weights.

### gymnasium = 0.29.1
### mujoco = 2.3.7

from __future__ import annotations
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.distributions as D
from tqdm import tqdm
import gymnasium as gym

#### Action Space: Box(-1 ~ 1, (6,), float32)
0: Torque applied on the thigh rotor, "thigh_joint"  
1: Torque applied on the leg rotor, "leg_joint"  
2: Torque applied on the foot rotor, "foot_joint"  
3: Torque applied on the left thigh rotor, "thigh_left_joint"  
4: Torque applied on the left leg rotor, "leg_left_joint"  
5: Torque applied on the left foot rotor, "foot_left_joint"  

#### Observation Space: Box(-Inf, Inf, (17,), float64)
X: x-coordinate of the torso, "rootx", slide, position(m)  
0: z-coordinate of the torso, "rootz", slide, position(m)  
1: angle of the torso, "rooty", hinge, angle(rad)  
2: angle of the thigh joint, "thigh_joint", hinge, angle(rad)  
3: angle of the leg joint, "leg_joint", hinge, angle(rad)  
4: angle of the foot joint, "foot_joint", hinge, angle(rad)  
5: angle of the left thigh joint, "thigh_left_joint", hinge, angle(rad)  
6: angle of the left leg joint, "leg_left_joint", hinge, angle(rad)  
7: angle of the left foot joint, "foot_left_joint", hinge, angle(rad)  
8: velocity of the x-coordinate of the torso, "rootx", slide, velocity(m/s)  
9: velocity of the z-coordinate(height) of the torso, "rootz", slide, velocity(m/s)  
10: angular velocity of the angle of the torso, "rooty", hinge, angular velocity(rad/s)  
11: angular velocity of the thigh joint, "thigh_joint", hinge, angular velocity(rad/s)  
12: angular velocity of the leg hinge, "leg_joint", hinge, angular velocity(rad/s)  
13: angular velocity of the foot hinge, "foot_joint", hinge, angular velocity(rad/s)  
14: angular velocity of the left thigh hinge, "thigh_left_joint", hinge, angular velocity(rad/s)  
15: angular velocity of the left leg hinge, "leg_left_joint", hinge, angular velocity(rad/s)  
16: angular velocity of the left foot hinge, "foot_left_joint", hinge, angular velocity(rad/s)  

#### Rewards
healthy_reward: Every timestep that the walker is alive, it receives a fixed reward of value  
forward_reward: This reward would be positive if the walker walks forward (positive x direction).  
                forward_reward_weight * (x-coordinate before action - x-coordinate after action)/dt  
ctrl_cost: A cost for penalising the walker if it takes actions that are too large.  
           ctrl_cost_weight * sum(action^2)  
reward = healthy_reward bonus + forward_reward - ctrl_cost  

In [3]:
class A2C(nn.Module):
    """
    (Synchronous) Advantage Actor-Critic agent class

    Args:
        n_features: The number of features of the input state.
        n_actions: The number of actions the agent can take.
        device: The device to run the computations on (running on a GPU might be quicker for larger Neural Nets,
                for this code CPU is totally fine).
        critic_lr: The learning rate for the critic network (should usually be larger than the actor_lr).
        actor_lr: The learning rate for the actor network.
        n_envs: The number of environments that run in parallel (on multiple CPUs) to collect experiences.
    """

    def __init__(
        self,
        n_features: int,
        n_actions: int,
        device: torch.device,
        critic_lr: float,
        actor_lr: float,
        n_envs: int,
    ) -> None:
        """Initializes the actor and critic networks and their respective optimizers."""
        super().__init__()
        self.device = device
        self.n_envs = n_envs

        critic_layers = [
            nn.Linear(n_features, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 1),  # estimate V(s)
        ]

        actor_layers = [
            nn.Linear(n_features, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(
                32, n_actions
            ),  # estimate action logits (will be fed into a softmax later)
        ]

        # define actor and critic networks
        self.critic = nn.Sequential(*critic_layers).to(self.device)
        self.actor = nn.Sequential(*actor_layers).to(self.device)

        # define optimizers for actor and critic
        self.critic_optim = optim.RMSprop(self.critic.parameters(), lr=critic_lr)
        self.actor_optim = optim.RMSprop(self.actor.parameters(), lr=actor_lr)

    def forward(self, x: np.ndarray) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the networks.

        Args:
            x: A batched vector of states.

        Returns:
            state_values: A tensor with the state values, with shape [n_envs,].
            action_logits_vec: A tensor with the action logits, with shape [n_envs, n_actions].
        """
        x = torch.Tensor(x).to(self.device)
        state_values = self.critic(x)  # shape: [n_envs,]
        action_logits_vec = self.actor(x)  # shape: [n_envs, n_actions]
        return (state_values, action_logits_vec)

    def select_action(
        self, x: np.ndarray
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns a tuple of the chosen actions and the log-probs of those actions.

        Args:
            x: A batched vector of states.

        Returns:
            actions: A tensor with the actions, with shape [n_steps_per_update, n_envs].
            action_log_probs: A tensor with the log-probs of the actions, with shape [n_steps_per_update, n_envs].
            state_values: A tensor with the state values, with shape [n_steps_per_update, n_envs].
        """
        
        state_values, action_logits = self.forward(x)
        mu = torch.tensor([0.0])
        sigma = torch.tensor([1.0])
        normal_dist = D.Normal(mu, sigma)# implicitly uses softmax
        sampled_value = normal_dist.rsample()
        action_log_probs = normal_dist.log_prob(sampled_value)
        entropy = normal_dist.entropy()
        return (sampled_value, action_log_probs, state_values, entropy)

    def get_losses(
        self,
        rewards: torch.Tensor,
        action_log_probs: torch.Tensor,
        value_preds: torch.Tensor,
        entropy: torch.Tensor,
        masks: torch.Tensor,
        gamma: float,
        lam: float,
        ent_coef: float,
        device: torch.device,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the loss of a minibatch (transitions collected in one sampling phase) for actor and critic
        using Generalized Advantage Estimation (GAE) to compute the advantages (https://arxiv.org/abs/1506.02438).

        Args:
            rewards: A tensor with the rewards for each time step in the episode, with shape [n_steps_per_update, n_envs].
            action_log_probs: A tensor with the log-probs of the actions taken at each time step in the episode, with shape [n_steps_per_update, n_envs].
            value_preds: A tensor with the state value predictions for each time step in the episode, with shape [n_steps_per_update, n_envs].
            masks: A tensor with the masks for each time step in the episode, with shape [n_steps_per_update, n_envs].
            gamma: The discount factor.
            lam: The GAE hyperparameter. (lam=1 corresponds to Monte-Carlo sampling with high variance and no bias,
                                          and lam=0 corresponds to normal TD-Learning that has a low variance but is biased
                                          because the estimates are generated by a Neural Net).
            device: The device to run the computations on (e.g. CPU or GPU).

        Returns:
            critic_loss: The critic loss for the minibatch.
            actor_loss: The actor loss for the minibatch.
        """
        T = len(rewards)
        advantages = torch.zeros(T, self.n_envs, device=device)

        # compute the advantages using GAE
        gae = 0.0
        for t in reversed(range(T - 1)):
            td_error = (
                rewards[t] + gamma * masks[t] * value_preds[t + 1] - value_preds[t]
            )
            gae = td_error + gamma * lam * masks[t] * gae
            advantages[t] = gae

        # calculate the loss of the minibatch for actor and critic
        critic_loss = advantages.pow(2).mean()

        # give a bonus for higher entropy to encourage exploration
        actor_loss = (
            -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean()
        )
        return (critic_loss, actor_loss)

    def update_parameters(
        self, critic_loss: torch.Tensor, actor_loss: torch.Tensor
    ) -> None:
        """
        Updates the parameters of the actor and critic networks.

        Args:
            critic_loss: The critic loss.
            actor_loss: The actor loss.
        """
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()


In [4]:
class ContinuousA2C(nn.Module):
    def __init__(self, n_features, n_actions, device, critic_lr, actor_lr, n_envs):
        super().__init__()
        self.device = device
        self.n_envs = n_envs

        # Critic network
        self.critic = nn.Sequential(
            nn.Linear(n_features, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 1)  # Estimates V(s)
        ).to(device)

        # Actor network now predicts mean and standard deviation for each action
        self.actor = nn.Sequential(
            nn.Linear(n_features, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, n_actions),  # Outputs mean of the action distribution
            nn.Tanh()  # Optional: use Tanh to bound the actions
        ).to(device)

        # Additional layer to predict the standard deviation
        self.std_dev = nn.Sequential(
            nn.Linear(32, n_actions),
            nn.Softplus()  # Ensures that the standard deviation is positive
        ).to(device)

        # Optimizers
        self.critic_optim = optim.AdamW(self.critic.parameters(), lr=critic_lr)
        self.actor_optim = optim.AdamW(list(self.actor.parameters()) + list(self.std_dev.parameters()), lr=actor_lr)

    def forward(self, x):
        x = torch.Tensor(x).to(self.device)
        state_values = self.critic(x)  # shape: [n_envs,]
        action_means = self.actor(x)   # shape: [n_envs, n_actions]
        action_std_devs = self.std_dev(x)  # shape: [n_envs, n_actions]
        return state_values, action_means, action_std_devs

    def select_action(self, x):
        state_values, action_means, action_std_devs = self.forward(x)
        
        # Create normal distributions and sample actions using reparameterization trick
        dists = D.Normal(action_means, action_std_devs)
        actions = dists.rsample()  # Reparameterized sample
        log_probs = dists.log_prob(actions).sum(axis=-1)  # Sum log probs for multi-action cases
        entropy = dists.entropy().sum(axis=-1)
        
        return actions, log_probs, state_values, entropy

    def get_losses(self, rewards, log_probs, state_values, entropy, masks, gamma, lam, ent_coef):
        # Implementation of the loss computation goes here.
        # It would be similar to the discrete version but needs to account for the continuous nature of actions.
        # Use log_probs and entropy from the new select_action method.
        pass

    def update_parameters(self, critic_loss, actor_loss):
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()



In [5]:
def train_A2C(envs, 
              n_updates = 1000, 
              n_steps_per_update = 128, 
              gamma = 0.999, 
              lam = 0.95, 
              ent_coef = 0.01, 
              actor_lr = 0.001, 
              critic_lr = 0.005, 
              n_envs = 10
              ):
    
    obs_shape = envs.single_observation_space.shape[0]
    action_shape = envs.single_action_space.shape[0]

    # set the device
    use_cuda = True
    if use_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")

    # init the agent
    agent = ContinuousA2C(obs_shape, action_shape, device, critic_lr, actor_lr, n_envs)
    # create a wrapper environment to save episode returns and episode lengths
    envs_wrapper = gym.wrappers.RecordEpisodeStatistics(envs, deque_size=n_envs * n_updates)

    # use tqdm to get a progress bar for training
    for sample_phase in tqdm(range(n_updates)):
        # we don't have to reset the envs, they just continue playing
        # until the episode is over and then reset automatically

        # reset lists that collect experiences of an episode (sample phase)
        ep_value_preds = torch.zeros(n_steps_per_update, n_envs, device=device)
        ep_rewards = torch.zeros(n_steps_per_update, n_envs, device=device)
        ep_action_log_probs = torch.zeros(n_steps_per_update, n_envs, device=device)
        masks = torch.zeros(n_steps_per_update, n_envs, device=device)

        # at the start of training reset all envs to get an initial state
        if sample_phase == 0:
            states, info = envs_wrapper.reset(seed=42)

        # play n steps in our parallel environments to collect data
        for step in range(n_steps_per_update):
            # select an action A_{t} using S_{t} as input for the agent
            actions, action_log_probs, state_value_preds, entropy = agent.select_action(
                states
            )

            # perform the action A_{t} in the environment to get S_{t+1} and R_{t+1}
            states, rewards, terminated, truncated, infos = envs_wrapper.step(
                actions.cpu().numpy()
            )

            ep_value_preds[step] = torch.squeeze(state_value_preds)
            ep_rewards[step] = torch.tensor(rewards, device=device)
            ep_action_log_probs[step] = action_log_probs

            # add a mask (for the return calculation later);
            # for each env the mask is 1 if the episode is ongoing and 0 if it is terminated (not by truncation!)
            masks[step] = torch.tensor([not term for term in terminated])

        # calculate the losses for actor and critic
        critic_loss, actor_loss = agent.get_losses(
            ep_rewards,
            ep_action_log_probs,
            ep_value_preds,
            entropy,
            masks,
            gamma,
            lam,
            ent_coef,
            device,
        )

        # update the actor and critic networks
        agent.update_parameters(critic_loss, actor_loss)

        # log the losses and entropy
        critic_losses.append(critic_loss.detach().cpu().numpy())
        actor_losses.append(actor_loss.detach().cpu().numpy())
        entropies.append(entropy.detach().mean().cpu().numpy())
    return agent, envs_wrapper


In [6]:
def plot_result():

    """ plot the results """

    # %matplotlib inline

    rolling_length = 20
    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12, 5))
    fig.suptitle(
        f"Training plots for {agent.__class__.__name__} in the Swimmer-v4 environment \n \
                (n_envs={n_envs}, n_steps_per_update={n_steps_per_update}, randomize_domain={randomize_domain})"
    )

    # episode return
    axs[0][0].set_title("Episode Returns")
    episode_returns_moving_average = (
        np.convolve(
            np.array(envs_wrapper.return_queue).flatten(),
            np.ones(rolling_length),
            mode="valid",
        )
        / rolling_length
    )
    axs[0][0].plot(
        np.arange(len(episode_returns_moving_average)) / n_envs,
        episode_returns_moving_average,
    )
    axs[0][0].set_xlabel("Number of episodes")

    # entropy
    axs[1][0].set_title("Entropy")
    entropy_moving_average = (
        np.convolve(np.array(entropies), np.ones(rolling_length), mode="valid")
        / rolling_length
    )
    axs[1][0].plot(entropy_moving_average)
    axs[1][0].set_xlabel("Number of updates")


    # critic loss
    axs[0][1].set_title("Critic Loss")
    critic_losses_moving_average = (
        np.convolve(
            np.array(critic_losses).flatten(), np.ones(rolling_length), mode="valid"
        )
        / rolling_length
    )
    axs[0][1].plot(critic_losses_moving_average)
    axs[0][1].set_xlabel("Number of updates")


    # actor loss
    axs[1][1].set_title("Actor Loss")
    actor_losses_moving_average = (
        np.convolve(np.array(actor_losses).flatten(), np.ones(rolling_length), mode="valid")
        / rolling_length
    )
    axs[1][1].plot(actor_losses_moving_average)
    axs[1][1].set_xlabel("Number of updates")

    plt.tight_layout()
    plt.show()


In [7]:
def save_weights():

    if not os.path.exists("weights"):
        os.mkdir("weights")
    """ save network weights """
    if save_weights:
        torch.save(agent.actor.state_dict(), actor_weights_path)
        torch.save(agent.critic.state_dict(), critic_weights_path)


In [8]:
def load_weights():
    """ load network weights """
    if load_weights:
        agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr)
        agent.actor.load_state_dict(torch.load(actor_weights_path))
        agent.critic.load_state_dict(torch.load(critic_weights_path))
        agent.actor.eval()
        agent.critic.eval()


In [25]:
n_envs = 1
n_updates = 10
n_steps_per_update = 128
envs = gym.vector.make("Walker2d-v4", num_envs=n_envs)

  gym.logger.warn(


In [26]:
randomize_domain = False
if randomize_domain:
    envs = gym.vector.AsyncVectorEnv(
        [
            lambda: gym.make(
                "Walker2d-v4",
                forward_reward_weight=1.0,
                ctrl_cost_weight=1e-3,
                healthy_reward=1.0,
                terminate_when_unhealthy=True,
                healthy_z_range=(0.8, 2),
                healthy_angle_range=(-1, 1),
                reset_noise_scale=5e-3,
                exclude_current_positions_from_observation=True,
                max_episode_steps=600,
            )
            for i in range(n_envs)
        ]
    )

In [27]:
critic_losses = []
actor_losses = []
entropies = []
agent, envs_wrapper = train_A2C(envs, n_updates, n_steps_per_update, n_envs=n_envs)
plot_result()

save_weights = True
load_weights = False
actor_weights_path = "weights/actor_weights.h5"
critic_weights_path = "weights/critic_weights.h5"

save_weights()
load_weights()

  0%|          | 0/10 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x17 and 32x6)

In [28]:
env = gym.make("Walker2d-v4")

In [33]:
n_steps_per_update = 128
gamma = 0.999
lam = 0.95
ent_coef = 0.01
actor_lr = 0.001
critic_lr = 0.005
obs_shape = envs.single_observation_space.shape[0]
action_shape = envs.single_action_space.shape[0]
print(obs_shape, action_shape)
device = torch.device("cpu")
agent = ContinuousA2C(obs_shape, action_shape, device, critic_lr, actor_lr, n_envs)
envs_wrapper = gym.wrappers.RecordEpisodeStatistics(envs, deque_size=n_envs * n_updates)
sample_phase = 0
print(f"Update {sample_phase}/{n_updates}")
ep_value_preds = torch.zeros(n_steps_per_update, n_envs, device=device)
ep_rewards = torch.zeros(n_steps_per_update, n_envs, device=device)
ep_action_log_probs = torch.zeros(n_steps_per_update, n_envs, device=device)
masks = torch.zeros(n_steps_per_update, n_envs, device=device)
if sample_phase == 0:
    states, info = envs_wrapper.reset(seed=42)
step = 0

17 6
Update 0/10
