In [1]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random


In [2]:
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_dim)
        self.max_action = max_action

    def forward(self, state):
        x = torch.relu(self.l1(state))
        x = torch.relu(self.l2(x))
        x = torch.tanh(self.l3(x))
        return self.max_action * x

In [3]:
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.l1 = nn.Linear(state_dim + action_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)

    def forward(self, state, action):
        x = torch.relu(self.l1(torch.cat([state, action], 1)))
        x = torch.relu(self.l2(x))
        x = self.l3(x)
        return x


In [4]:
class ReplayBuffer:
    def __init__(self, max_size):
        self.buffer = deque(maxlen=max_size)

    def add(self, transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (np.array(states), np.array(actions), np.array(rewards).reshape(-1, 1),
                np.array(next_states), np.array(dones).reshape(-1, 1))

    def size(self):
        return len(self.buffer)


In [5]:
class TD3:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-3)

        self.critic_1 = Critic(state_dim, action_dim).to(device)
        self.critic_2 = Critic(state_dim, action_dim).to(device)
        self.critic_target_1 = Critic(state_dim, action_dim).to(device)
        self.critic_target_2 = Critic(state_dim, action_dim).to(device)
        self.critic_target_1.load_state_dict(self.critic_1.state_dict())
        self.critic_target_2.load_state_dict(self.critic_2.state_dict())
        self.critic_optimizer = optim.Adam(list(self.critic_1.parameters()) + list(self.critic_2.parameters()), lr=1e-3)

        self.replay_buffer = ReplayBuffer(max_size=1_000_000)
        self.max_action = max_action
        self.discount = 0.99
        self.tau = 0.005
        self.policy_noise = 0.2
        self.noise_clip = 0.5
        self.policy_freq = 2
        self.total_it = 0

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor(state).cpu().data.numpy().flatten()

    def train(self, batch_size=100):
        self.total_it += 1

        # Sample replay buffer
        state, action, reward, next_state, done = self.replay_buffer.sample(batch_size)
        state = torch.FloatTensor(state).to(device)
        action = torch.FloatTensor(action).to(device)
        reward = torch.FloatTensor(reward).to(device)
        next_state = torch.FloatTensor(next_state).to(device)
        done = torch.FloatTensor(done).to(device)

        # Select action according to policy and add clipped noise
        noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
        next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)

        # Compute the target Q value
        target_q1 = self.critic_target_1(next_state, next_action)
        target_q2 = self.critic_target_2(next_state, next_action)
        target_q = reward + (1 - done) * self.discount * torch.min(target_q1, target_q2)

        # Get current Q estimates
        current_q1 = self.critic_1(state, action)
        current_q2 = self.critic_2(state, action)

        # Compute critic loss
        critic_loss = nn.MSELoss()(current_q1, target_q.detach()) + nn.MSELoss()(current_q2, target_q.detach())

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_it % self.policy_freq == 0:
            # Compute actor loss
            actor_loss = -self.critic_1(state, self.actor(state)).mean()

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            for param, target_param in zip(self.critic_1.parameters(), self.critic_target_1.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.critic_2.parameters(), self.critic_target_2.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, filename):
        torch.save(self.actor.state_dict(), filename + "_actor.pth")
        torch.save(self.critic_1.state_dict(), filename + "_critic1.pth")
        torch.save(self.critic_2.state_dict(), filename + "_critic2.pth")

    def load(self, filename):
        self.actor.load_state_dict(torch.load(filename + "_actor.pth"))
        self.critic_1.load_state_dict(torch.load(filename + "_critic1.pth"))
        self.critic_2.load_state_dict(torch.load(filename + "_critic2.pth"))


In [6]:
env = gym.make("BipedalWalker-v3")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

td3 = TD3(state_dim, action_dim, max_action)

episodes = 1000
batch_size = 100

  deprecation(
  deprecation(


In [7]:
episode_rewards = []

In [8]:
for episode in range(episodes):
    state = env.reset()
    episode_reward = 0

    for t in range(2000):
        action = td3.select_action(np.array(state))
        next_state, reward, done, _ = env.step(action)
        td3.replay_buffer.add((state, action, reward, next_state, float(done)))
        state = next_state
        episode_reward += reward

        if td3.replay_buffer.size() > batch_size:
            td3.train(batch_size)

        if done:
                break
        episode_rewards.append(episode_reward)

    print(f"Episode {episode+1}, Reward: {episode_reward}")

    if (episode + 1) % 10 == 0:
        td3.save(f"models/td3_bipedal_")


  if not isinstance(terminated, (bool, np.bool8)):


Episode 1, Reward: -92.49592465158769
Episode 2, Reward: -112.57023937397916
Episode 3, Reward: -112.40137849643982
Episode 4, Reward: -117.06204470005198
Episode 5, Reward: -161.70714885553383
Episode 6, Reward: -113.29521534495117
Episode 7, Reward: -111.5597152640475
Episode 8, Reward: -116.95648149309741
Episode 9, Reward: -103.67498456785796
Episode 10, Reward: -114.75361160729749
Episode 11, Reward: -109.95334705100271
Episode 12, Reward: -124.52026673044713
Episode 13, Reward: -120.34094985260084
Episode 14, Reward: -131.0564814728065
Episode 15, Reward: -116.52354605286786
Episode 16, Reward: -117.33954139348049
Episode 17, Reward: -125.16829287693105
Episode 18, Reward: -117.4775504859152
Episode 19, Reward: -123.7216497857233
Episode 20, Reward: -105.68786284397419
Episode 21, Reward: -202.29751357059277
Episode 22, Reward: -144.37612675505858
Episode 23, Reward: -152.3177740434995
Episode 24, Reward: -119.34675116563083
Episode 25, Reward: -130.6652879583659
Episode 26, Rewa