In [52]:
import copy
import os
import pickle
import time

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
class TD3Agent:
    """Twin Delayed DDPG (TD3) agent for a Gymnasium environment with continuous actions.

    The agent deep-copies the networks it is given (so the caller's modules are
    never trained), keeps a fixed-capacity replay buffer as tensors on
    ``self.device``, and uses one Adam optimizer per trainable network.
    """

    def __init__(self,
                 env_id,
                 env_hardcore,
                 critic_network1,
                 critic_network2,
                 actor_network,
                 critic_learning_rate=1e-3,
                 actor_learning_rate=1e-4,
                 discount_factor=0.99,
                 minibatch_size=256,
                 tau=0.005,
                 exploratory_noise=0.1,
                 exploratory_noise_clip=0.3,
                 policy_noise=0.1,
                 policy_noise_clip=0.5,
                 policy_delay=2,
                 warm_up=10_000,
                 gradient_clip=0.0,
                 updates_per_step=1,
                 max_buffer_length=250_000):
        """Create the environment, networks, optimizers and replay buffer.

        Args:
            env_id: Gymnasium environment id passed to ``gym.make``.
            env_hardcore: forwarded as ``hardcore=`` to ``gym.make``; also selects the output folder.
            critic_network1, critic_network2: critic modules fed the concatenation of state and action.
            actor_network: policy module mapping a batch of states to a batch of actions.
            critic_learning_rate, actor_learning_rate: Adam learning rates.
            discount_factor: reward discount (gamma).
            minibatch_size: transitions sampled per gradient step.
            tau: Polyak coefficient for the target-network soft update.
            exploratory_noise, exploratory_noise_clip: std and clip of the Gaussian noise added
                to actions while acting in the environment.
            policy_noise, policy_noise_clip: std and clip of the target-policy smoothing noise.
            policy_delay: the actor and the target networks are updated once every
                ``policy_delay`` critic updates.
            warm_up: no gradient updates happen until the buffer holds this many transitions.
            gradient_clip: max gradient norm for the critics; ``0.0`` (or ``None``) disables clipping.
            updates_per_step: critic updates per environment step once learning has started.
            max_buffer_length: replay-buffer capacity; the oldest transitions are overwritten.
        """
        # CPU or GPU?
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Environment.
        self.env_id = env_id
        self.env_hardcore = env_hardcore
        self.env = gym.make(id=env_id, hardcore=env_hardcore, render_mode=None)
        self.state_size = self.env.observation_space.shape[0]
        self.action_size = self.env.action_space.shape[0]
        self.action_min = torch.tensor(self.env.action_space.low, device=self.device)
        self.action_max = torch.tensor(self.env.action_space.high, device=self.device)
        self.max_reward = self.env.spec.reward_threshold

        # Online networks and their target copies, all independent of the caller's modules.
        self.critic_network1 = copy.deepcopy(critic_network1).to(self.device)
        self.target_critic_network1 = copy.deepcopy(critic_network1).to(self.device)
        self.critic_network2 = copy.deepcopy(critic_network2).to(self.device)
        self.target_critic_network2 = copy.deepcopy(critic_network2).to(self.device)
        self.actor_network = copy.deepcopy(actor_network).to(self.device)
        self.target_actor_network = copy.deepcopy(actor_network).to(self.device)

        # Optimizers (targets are never optimized directly, only soft-updated).
        self.critic1_optimizer = optim.Adam(self.critic_network1.parameters(), lr=critic_learning_rate)
        self.critic2_optimizer = optim.Adam(self.critic_network2.parameters(), lr=critic_learning_rate)
        self.actor_optimizer = optim.Adam(self.actor_network.parameters(), lr=actor_learning_rate)

        # Hyperparameters.
        self.minibatch_size = minibatch_size
        self.discount_factor = discount_factor
        self.tau = tau
        self.warm_up = warm_up
        self.gradient_clip = gradient_clip
        self.exploratory_noise = exploratory_noise
        self.exploratory_noise_clip = exploratory_noise_clip
        self.policy_noise = policy_noise
        self.policy_noise_clip = policy_noise_clip
        self.policy_delay = policy_delay
        self.updates_per_step = updates_per_step

        # Total critic updates performed so far; drives the delayed policy update.
        self.update_count = 0

        # Replay buffer: one preallocated tensor per transition field, used as a ring buffer.
        # Width of one flattened (s, a, r, s', done) transition; informational only.
        self.buffer_width = 2 * self.state_size + self.action_size + 2
        self.max_buffer_length = max_buffer_length
        self.buffer_write_idx = 0
        self.buffer_fullness = 0
        self.buffer_states = torch.zeros((self.max_buffer_length, self.state_size),
                                         dtype=torch.float32, device=self.device)
        self.buffer_actions = torch.zeros((self.max_buffer_length, self.action_size),
                                          dtype=torch.float32, device=self.device)
        self.buffer_rewards = torch.zeros((self.max_buffer_length, 1),
                                          dtype=torch.float32, device=self.device)
        self.buffer_next_states = torch.zeros((self.max_buffer_length, self.state_size),
                                              dtype=torch.float32, device=self.device)
        self.buffer_terminals = torch.zeros((self.max_buffer_length, 1),
                                            dtype=torch.float32, device=self.device)

    def select_action(self, state: np.ndarray, explore: bool = True) -> np.ndarray:
        """Return the actor's action for a single state, clamped to the action bounds.

        Args:
            state: one environment observation.
            explore: when True (default), add clipped Gaussian exploratory noise; pass
                False for a deterministic evaluation action.
        """
        self.actor_network.eval()
        with torch.no_grad():
            state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            action_tensor = self.actor_network(state_tensor).squeeze(0)
            if explore:
                noise = torch.randn(self.action_size, device=self.device) * self.exploratory_noise
                noise = torch.clamp(noise, min=-self.exploratory_noise_clip, max=self.exploratory_noise_clip)
                action_tensor = action_tensor + noise
            action_tensor = torch.clamp(action_tensor, min=self.action_min, max=self.action_max)
        return action_tensor.cpu().numpy()

    def save_transition(self, state: np.ndarray, action: np.ndarray, reward: float, new_state: np.ndarray, terminal: bool):
        """Write one transition into the ring buffer, overwriting the oldest entry when full."""
        idx = self.buffer_write_idx
        self.buffer_states[idx] = torch.as_tensor(state, dtype=torch.float32, device=self.device)
        self.buffer_actions[idx] = torch.as_tensor(action, dtype=torch.float32, device=self.device)
        self.buffer_rewards[idx] = float(reward)
        self.buffer_next_states[idx] = torch.as_tensor(new_state, dtype=torch.float32, device=self.device)
        self.buffer_terminals[idx] = 1.0 if terminal else 0.0

        self.buffer_write_idx = (idx + 1) % self.max_buffer_length
        self.buffer_fullness = min(self.buffer_fullness + 1, self.max_buffer_length)

    def sample_minibatch(self):
        """Sample ``minibatch_size`` transitions uniformly (with replacement) from the filled part of the buffer.

        Returns:
            ``(states, actions, rewards, next_states, terminals)`` tensors indexed by the same rows.
        """
        indices = torch.randint(0, self.buffer_fullness, (self.minibatch_size,), device=self.device)
        return (self.buffer_states[indices],
                self.buffer_actions[indices],
                self.buffer_rewards[indices],
                self.buffer_next_states[indices],
                self.buffer_terminals[indices])

    def _clip_gradients(self, network: nn.Module):
        """Clip ``network``'s gradient norm in place, only when gradient clipping is enabled."""
        # clip_grad_norm_ with max_norm=0 would scale every gradient to zero, so 0.0 must mean "off".
        if self.gradient_clip and self.gradient_clip > 0.0:
            nn.utils.clip_grad_norm_(network.parameters(), self.gradient_clip)

    def update_critic_networks(self, minibatch: tuple):
        """Take one gradient step on each critic towards the clipped double-Q target.

        The target is ``r + gamma * (1 - terminal) * min(Q1'(s', a'), Q2'(s', a'))``, where
        ``a'`` is the target actor's action plus clipped Gaussian policy noise, clamped to
        the action bounds.

        Args:
            minibatch: ``(states, actions, rewards, next_states, terminals)`` as returned by
                ``sample_minibatch``.
        """
        mb_states, mb_actions, mb_rewards, mb_next_states, mb_terminals = minibatch
        mb_state_actions = torch.cat([mb_states, mb_actions], dim=1)

        self.target_actor_network.eval()
        self.target_critic_network1.eval()
        self.target_critic_network2.eval()
        self.critic_network1.train()
        self.critic_network2.train()

        with torch.no_grad():
            next_actions = self.target_actor_network(mb_next_states)
            noise = torch.randn_like(next_actions) * self.policy_noise
            noise = torch.clamp(noise, min=-self.policy_noise_clip, max=self.policy_noise_clip)
            next_actions = torch.clamp(next_actions + noise, min=self.action_min, max=self.action_max)
            next_state_actions = torch.cat((mb_next_states, next_actions), dim=1)

            q_min_next = torch.min(self.target_critic_network1(next_state_actions),
                                   self.target_critic_network2(next_state_actions))
            q_target = mb_rewards + self.discount_factor * (1 - mb_terminals) * q_min_next

        # Both critics regress onto the same target and go through the same clipping guard.
        for critic, optimizer in ((self.critic_network1, self.critic1_optimizer),
                                  (self.critic_network2, self.critic2_optimizer)):
            critic_loss = torch.mean((q_target - critic(mb_state_actions)) ** 2)
            optimizer.zero_grad()
            critic_loss.backward()
            self._clip_gradients(critic)
            optimizer.step()

    def update_actor_network(self, minibatch: tuple):
        """Take one gradient step on the actor to maximize critic 1's Q-value of its actions.

        Only the states of ``minibatch`` are used. Gradients that reach critic 1 here are
        cleared by its optimizer's ``zero_grad`` on the next critic update.
        """
        mb_states, *_ = minibatch

        self.actor_network.train()
        self.critic_network1.eval()

        raw_actions = self.actor_network(mb_states)
        raw_state_actions = torch.cat((mb_states, raw_actions), dim=1)
        actor_loss = -self.critic_network1(raw_state_actions).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

    def soft_update_target_weights(self):
        """Polyak-average each target network towards its online network.

        Applies ``target <- tau * online + (1 - tau) * target`` parameter-wise.
        """
        pairs = ((self.target_actor_network, self.actor_network),
                 (self.target_critic_network1, self.critic_network1),
                 (self.target_critic_network2, self.critic_network2))
        with torch.no_grad():
            for target_net, online_net in pairs:
                for w_target, w_online in zip(target_net.parameters(), online_net.parameters()):
                    w_target.lerp_(w_online, self.tau)

    def save_outputs(self, episode_rewards, episode_step_counts, episode_run_times):
        """Save each network's weights and the per-episode statistics.

        Files go to ``TD3 Outputs (hardcore)`` or ``TD3 Outputs (normal)``, which is created if missing.
        """
        folder = "TD3 Outputs (hardcore)" if self.env_hardcore else "TD3 Outputs (normal)"
        # torch.save and open() do not create missing directories.
        os.makedirs(folder, exist_ok=True)

        torch.save(self.critic_network1.state_dict(), os.path.join(folder, "critic_network1.pth"))
        torch.save(self.critic_network2.state_dict(), os.path.join(folder, "critic_network2.pth"))
        torch.save(self.actor_network.state_dict(), os.path.join(folder, "actor_network.pth"))

        for name, values in (("episode_rewards", episode_rewards),
                             ("episode_step_counts", episode_step_counts),
                             ("episode_run_times", episode_run_times)):
            with open(os.path.join(folder, f"{name}.pkl"), "wb") as fp:
                pickle.dump(values, fp)

    def show_test_episode(self):
        """Run one rendered episode with the noise-free policy and print its reward, steps and duration."""
        print("\n========TEST RUN========")
        test_env = gym.make(id=self.env_id, hardcore=self.env_hardcore, render_mode="human")
        test_episode_reward = 0
        test_episode_step_count = 0
        try:
            s, _ = test_env.reset()
            test_start_time = time.time()
            while True:
                a = self.select_action(s, explore=False)
                s, r, terminated, truncated, _ = test_env.step(a)
                test_episode_reward += r
                test_episode_step_count += 1
                if terminated or truncated:
                    break
            test_episode_run_time = time.time() - test_start_time
        finally:
            # Always release the render window, even if the episode raised.
            test_env.close()

        print(f"Reward: {test_episode_reward:.2f} - Step Count: {test_episode_step_count} - Run Time: {test_episode_run_time:.2f}s\n")

    def _reached_stop_criterion(self, episode_rewards, stop_after):
        """Return True when the last ``stop_after`` episodes all reached ``self.max_reward``.

        Requires at least ``stop_after`` completed episodes, and is always False when
        ``stop_after`` or the environment's reward threshold is ``None``.
        """
        if stop_after is None or self.max_reward is None or len(episode_rewards) < stop_after:
            return False
        return all(ep_rew >= self.max_reward for ep_rew in episode_rewards[-stop_after:])

    def learn(self, n_episodes=2000, display_every=50, stop_after=None):
        """Train for up to ``n_episodes`` episodes, then save networks and episode statistics.

        Args:
            n_episodes: maximum number of training episodes.
            display_every: run a rendered, noise-free test episode every this many episodes
                (starting with the first); ``0``/``None`` disables it.
            stop_after: stop early once the last ``stop_after`` episodes all reach the
                environment's reward threshold; ``None`` disables early stopping.

        Outputs are saved even if training is interrupted (e.g. by KeyboardInterrupt).
        """
        episode_rewards = []
        episode_step_counts = []
        episode_run_times = []
        # Before this many transitions, actions still come from the actor plus exploratory
        # noise, but no gradient updates are made.
        learning_starts = max(self.minibatch_size, self.warm_up)

        try:
            for n in range(n_episodes):
                print(f"Running Episode {n + 1}...")
                start_time = time.time()

                state, _ = self.env.reset()
                episode_reward = 0
                episode_step_count = 0

                while True:
                    action = self.select_action(state)
                    new_state, reward, terminal, truncated, _ = self.env.step(action)

                    # Only true terminals stop bootstrapping; time-limit truncation does not.
                    self.save_transition(state, action, reward, new_state, terminal)
                    episode_step_count += 1
                    episode_reward += reward

                    if self.buffer_fullness >= learning_starts:
                        for _ in range(self.updates_per_step):
                            minibatch = self.sample_minibatch()
                            self.update_critic_networks(minibatch)
                            self.update_count += 1
                            # Delay counted in critic updates, so it stays correct when updates_per_step > 1.
                            if self.update_count % self.policy_delay == 0:
                                self.update_actor_network(minibatch)
                                self.soft_update_target_weights()

                    if terminal or truncated:
                        break
                    state = new_state

                episode_run_time = time.time() - start_time

                print(f"Reward: {episode_reward:.2f} - Step Count: {episode_step_count} - Run Time: {episode_run_time:.2f}s")
                episode_rewards.append(episode_reward)
                episode_step_counts.append(episode_step_count)
                episode_run_times.append(episode_run_time)

                if self._reached_stop_criterion(episode_rewards, stop_after):
                    break

                if display_every and n % display_every == 0:
                    self.show_test_episode()
        finally:
            self.save_outputs(episode_rewards, episode_step_counts, episode_run_times)

In [54]:
class Critic(nn.Module):
    """Q-network: maps a concatenated (state, action) vector to a single Q-value.

    The defaults (24-dim state + 4-dim action, 256 hidden units) reproduce the
    original hard-coded 28 -> 256 -> 256 -> 1 MLP, so ``Critic()`` is unchanged and
    existing ``state_dict`` files still load.
    """

    def __init__(self, state_size: int = 24, action_size: int = 4, hidden_size: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_size + action_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values of shape (batch, 1) for input of shape (batch, state_size + action_size)."""
        return self.net(x)

In [55]:
class Actor(nn.Module):
    """Deterministic policy: maps a state vector to an action in [-1, 1] per dimension (tanh output).

    The defaults (24-dim state, 4-dim action, 256 hidden units) reproduce the
    original hard-coded 24 -> 256 -> 256 -> 4 MLP, so ``Actor()`` is unchanged and
    existing ``state_dict`` files still load.
    """

    def __init__(self, state_size: int = 24, action_size: int = 4, hidden_size: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return actions of shape (batch, action_size) for states of shape (batch, state_size)."""
        return self.net(x)

In [56]:
# Train the robot.

# Seed torch so network initialisation, exploration noise and minibatch sampling are
# reproducible across kernel restarts. Environment resets are not seeded here, so
# episodes can still differ between runs.
SEED = 0
torch.manual_seed(SEED)

actor = Actor()
critic1 = Critic()
critic2 = Critic()

agent = TD3Agent(env_id="BipedalWalker-v3",
                 env_hardcore=False,
                 critic_network1=critic1,
                 critic_network2=critic2,
                 actor_network=actor)

agent.learn(stop_after=1)

Episode: 1
Episode reward: -93.66564
Episode: 2
Episode reward: -93.320244
Episode: 3
Episode reward: -93.28207
Episode: 4
Episode reward: -92.92917
Episode: 5
Episode reward: -93.10777
Episode: 6
Episode reward: -94.00552
Episode: 7
Episode reward: -92.5566
Episode: 8
Episode reward: -92.56265
Episode: 9
Episode reward: -92.022194
Episode: 10
Episode reward: -92.90067
Episode: 11
Episode reward: -103.77015
Episode: 12


KeyboardInterrupt: 