In [2]:
import os
import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# --- Hyperparameters and run configuration -----------------------------------
# Optimiser / PPO update settings (forwarded to the PPO constructor below).
LR = 1e-4              # Adam learning rate, used for both policy and value nets
MAX_GRAD_NORM = 0.5    # gradient-norm clipping threshold
ENT_WEIGHT = 0.01      # weight of the entropy bonus subtracted from the policy loss
CLIP_VAL = 0.2         # clip range applied to the probability ratio and the value delta
N_EPOCHS = 4           # passes over each rollout per update
BATCH_SIZE = 64        # minibatch size within each pass

# Return computation.
GAMMA = 0.99           # discount factor used by EnvRunner
LAMBDA = 0.95          # NOTE(review): not referenced by any other cell visible in this notebook

# Rollout / training-loop settings.
MAX_STEPS = 2048       # maximum number of environment steps per rollout
MAX_EPISODES = 1500    # number of rollouts (one update each) in train_agent
SAVE_INTERVAL = 200    # checkpoint + summary period, in episodes
SAVE_DIR = "save"      # directory that receives model.pt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Used device: {device}")

In [4]:
class FixedNormal(torch.distributions.Normal):
    """Normal distribution with reductions over the last axis.

    ``log_probs`` and ``entropy`` sum the per-component values of the parent
    ``Normal`` over the final dimension, so each returns one number per
    leading index instead of one per component.
    """

    def log_probs(self, x):
        """Return the log-density of ``x`` summed over the last dimension."""
        per_component = super().log_prob(x)
        return per_component.sum(-1)

    def entropy(self):
        """Return the entropy summed over the last dimension."""
        per_component = super().entropy()
        return per_component.sum(-1)

    def mode(self):
        """Return the distribution mean (called as a method, not a property)."""
        return self.mean


In [5]:
class AddBias(nn.Module):
    """Module that adds a learnable bias vector to its input."""

    def __init__(self, bias):
        """Register ``bias`` (a 1-D tensor) as a parameter stored as a column."""
        super().__init__()
        as_column = bias.unsqueeze(1)
        self.bias = nn.Parameter(as_column)

    def forward(self, x):
        """Return ``x`` plus the bias laid out as a single row of shape (1, n)."""
        return x + self.bias.t().view(1, -1)

In [6]:
class DiagGaussian(nn.Module):
    """Head that maps features to a ``FixedNormal`` over ``out_dim`` components.

    The mean comes from a linear layer; the log standard deviation is a
    learnable vector (initialised to zeros) that does not depend on the input.
    """

    def __init__(self, inp_dim, out_dim):
        super().__init__()
        self.fc_mean = nn.Linear(inp_dim, out_dim)
        self.log_std = AddBias(torch.zeros(out_dim))

    def forward(self, x):
        """Return the action distribution for the feature batch ``x``."""
        mu = self.fc_mean(x)
        # Adding the bias to zeros just broadcasts log-std to the shape of mu.
        log_sigma = self.log_std(torch.zeros_like(mu))
        sigma = log_sigma.exp()
        return FixedNormal(mu, sigma)

In [7]:
class PolicyNet(nn.Module):
    """Gaussian policy: a two-layer ReLU MLP followed by a ``DiagGaussian`` head."""

    def __init__(self, s_dim, a_dim):
        super().__init__()
        hidden = 128
        self.main = nn.Sequential(
            nn.Linear(s_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.dist = DiagGaussian(hidden, a_dim)

    def _distribution(self, state):
        """Build the action distribution for ``state``."""
        return self.dist(self.main(state))

    def forward(self, state, deterministic=False):
        """Return ``(action, log_prob)``.

        The action is the distribution mode when ``deterministic`` is true,
        otherwise a sample.
        """
        pi = self._distribution(state)
        if deterministic:
            action = pi.mode()
        else:
            action = pi.sample()
        return action, pi.log_probs(action)

    def choose_action(self, state, deterministic=False):
        """Return only the action, computed with gradient tracking disabled."""
        with torch.no_grad():
            chosen, _ = self.forward(state, deterministic)
        return chosen

    def evaluate(self, state, action):
        """Return ``(log_prob, entropy)`` of ``action`` under the current policy."""
        pi = self._distribution(state)
        return pi.log_probs(action), pi.entropy()

In [8]:
class ValueNet(nn.Module):
    """State-value estimator: a two-hidden-layer ReLU MLP with a scalar output."""

    def __init__(self, s_dim):
        super().__init__()
        stack = [
            nn.Linear(s_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        ]
        self.main = nn.Sequential(*stack)

    def forward(self, state):
        """Return the value estimate with the trailing size-1 axis removed."""
        estimate = self.main(state)
        return estimate.squeeze(-1)

In [9]:
class EnvRunner:
    """Collects a single rollout (at most ``max_step`` steps) from an environment.

    Per-step data is written into preallocated NumPy buffers that are reused
    across calls to :meth:`run`; the arrays handed back to the caller are
    copies, so they stay valid after the next rollout.
    """

    def __init__(self, s_dim, a_dim, max_step=MAX_STEPS, gamma=GAMMA, device=device):
        """
        Args:
            s_dim: size of the observation vector.
            a_dim: size of the action vector.
            max_step: maximum number of environment steps per rollout.
            gamma: discount factor used for the returns.
            device: torch device on which the networks are evaluated.
        """
        self.s_dim = s_dim
        self.a_dim = a_dim
        self.max_step = max_step
        self.gamma = gamma
        self.device = device

        # Prepare buffers for a single rollout
        self.mb_states = np.zeros((max_step, s_dim), dtype=np.float32)
        self.mb_actions = np.zeros((max_step, a_dim), dtype=np.float32)
        self.mb_values = np.zeros((max_step,), dtype=np.float32)
        self.mb_rewards = np.zeros((max_step,), dtype=np.float32)
        self.mb_a_logps = np.zeros((max_step,), dtype=np.float32)

    def compute_discounted_return(self, rewards, last_value):
        """Return the discounted reward-to-go for every step of ``rewards``.

        Args:
            rewards: 1-D array of per-step rewards.
            last_value: bootstrap value for the state following the last
                reward; may be a Python/NumPy scalar or a one-element array.

        Returns:
            Array shaped like ``rewards`` with
            ``returns[t] = rewards[t] + gamma * returns[t + 1]`` and
            ``last_value`` standing in for the entry past the end.
        """
        returns = np.zeros_like(rewards)
        # Reduce to a plain scalar so a (1,)-shaped array is never assigned
        # into a single array slot (deprecated by recent NumPy versions).
        running = float(np.asarray(last_value).reshape(-1)[0])
        for t in reversed(range(len(rewards))):
            running = rewards[t] + self.gamma * running
            returns[t] = running
        return returns

    def run(self, env, policy_net, value_net):
        """Roll the policy out until the episode ends or ``max_step`` is hit.

        Args:
            env: environment exposing ``reset()`` -> ``(obs, info)`` and
                ``step(action)`` -> ``(obs, reward, terminated, truncated, info)``.
            policy_net: callable mapping a state batch to ``(action, log_prob)``.
            value_net: callable mapping a state batch to value estimates.

        Returns:
            Tuple ``(states, actions, action_log_probs, values, returns,
            rewards)`` of NumPy arrays, each holding one entry per step taken.

        The final return is bootstrapped with the value of the last observed
        state unless the environment reported ``terminated``. A truncation
        (e.g. a time limit) or hitting the step cap is not a real terminal
        state, so the value estimate is used there instead of zero.
        """
        state, _ = env.reset()
        episode_len = self.max_step
        terminated = False

        # Rollout collection never needs gradients; disabling them here makes
        # the method safe to call whether or not the caller did so already.
        with torch.no_grad():
            for step in range(self.max_step):
                state_t = torch.tensor(
                    state[None], dtype=torch.float32, device=self.device
                )
                action, a_logp = policy_net(state_t)
                value = value_net(state_t)

                # Move data back to CPU numpy
                action_np = action.cpu().numpy()[0]
                self.mb_states[step] = state
                self.mb_actions[step] = action_np
                self.mb_a_logps[step] = a_logp.item()
                self.mb_values[step] = value.item()

                next_state, reward, terminated, truncated, _ = env.step(action_np)
                self.mb_rewards[step] = reward
                # Always advance so `state` is the latest observation, which
                # is the one needed for bootstrapping below.
                state = next_state

                if terminated or truncated:
                    episode_len = step + 1
                    break

            if terminated:
                last_value = 0.0
            else:
                last_state_t = torch.tensor(
                    state[None], dtype=torch.float32, device=self.device
                )
                last_value = value_net(last_state_t).item()

        mb_returns = self.compute_discounted_return(
            self.mb_rewards[:episode_len], last_value
        )

        # Copy out of the reusable buffers so the result is not overwritten
        # by the next call to run().
        return (
            self.mb_states[:episode_len].copy(),
            self.mb_actions[:episode_len].copy(),
            self.mb_a_logps[:episode_len].copy(),
            self.mb_values[:episode_len].copy(),
            mb_returns,
            self.mb_rewards[:episode_len].copy(),
        )

In [10]:
class PPO:
    """Clipped-objective PPO updater with separate policy and value optimisers."""

    def __init__(
        self,
        policy_net,
        value_net,
        lr=LR,
        max_grad_norm=MAX_GRAD_NORM,
        ent_weight=ENT_WEIGHT,
        clip_val=CLIP_VAL,
        sample_n_epoch=N_EPOCHS,
        sample_mb_size=BATCH_SIZE,
        device=device,
    ):
        """Store the networks and hyperparameters and build one Adam per network."""
        self.policy_net = policy_net
        self.value_net = value_net
        self.max_grad_norm = max_grad_norm
        self.ent_weight = ent_weight
        self.clip_val = clip_val
        self.sample_n_epoch = sample_n_epoch
        self.sample_mb_size = sample_mb_size
        self.device = device

        self.opt_policy = torch.optim.Adam(policy_net.parameters(), lr)
        self.opt_value = torch.optim.Adam(value_net.parameters(), lr)

    def train(
        self, mb_states, mb_actions, mb_old_values, mb_advs, mb_returns, mb_old_a_logps
    ):
        """Run ``sample_n_epoch`` passes of minibatch updates over one rollout.

        Each pass shuffles the sample indices and processes
        ``max(1, n // sample_mb_size)`` consecutive slices of that shuffle,
        so samples beyond the last full slice are skipped in that pass.

        Returns:
            ``(policy_loss, value_loss, mean_entropy)`` as Python floats,
            taken from the last minibatch processed.
        """

        def as_tensor(array):
            # All rollout arrays share the same dtype/device conversion.
            return torch.tensor(array, dtype=torch.float32, device=self.device)

        states_all = as_tensor(mb_states)
        actions_all = as_tensor(mb_actions)
        old_values_all = as_tensor(mb_old_values)
        advs_all = as_tensor(mb_advs)
        returns_all = as_tensor(mb_returns)
        old_logps_all = as_tensor(mb_old_a_logps)

        n_samples = len(states_all)
        order = np.arange(n_samples)
        n_batches = max(1, n_samples // self.sample_mb_size)

        for _ in range(self.sample_n_epoch):
            np.random.shuffle(order)
            for start in range(0, n_batches * self.sample_mb_size, self.sample_mb_size):
                pick = order[start : start + self.sample_mb_size]
                s_batch = states_all[pick]
                a_batch = actions_all[pick]
                old_v_batch = old_values_all[pick]
                adv_batch = advs_all[pick]
                ret_batch = returns_all[pick]
                old_logp_batch = old_logps_all[pick]

                # Forward both networks on the minibatch.
                curr_logp, entropy = self.policy_net.evaluate(s_batch, a_batch)
                curr_values = self.value_net(s_batch)

                # Value loss: worse of the raw and the clipped squared error.
                delta_v = torch.clamp(
                    curr_values - old_v_batch, -self.clip_val, self.clip_val
                )
                v_clipped = old_v_batch + delta_v
                err_raw = (ret_batch - curr_values).pow(2)
                err_clipped = (ret_batch - v_clipped).pow(2)
                value_loss = torch.max(err_raw, err_clipped).mean()

                # Policy loss: clipped surrogate minus the entropy bonus.
                ratio = torch.exp(curr_logp - old_logp_batch)
                clipped_ratio = torch.clamp(
                    ratio, 1.0 - self.clip_val, 1.0 + self.clip_val
                )
                surrogate = torch.max(-adv_batch * ratio, -adv_batch * clipped_ratio)
                policy_loss = surrogate.mean() - self.ent_weight * entropy.mean()

                # Policy step.
                self.opt_policy.zero_grad()
                policy_loss.backward()
                nn.utils.clip_grad_norm_(
                    self.policy_net.parameters(), self.max_grad_norm
                )
                self.opt_policy.step()

                # Value step.
                self.opt_value.zero_grad()
                value_loss.backward()
                nn.utils.clip_grad_norm_(
                    self.value_net.parameters(), self.max_grad_norm
                )
                self.opt_value.step()

        return policy_loss.item(), value_loss.item(), entropy.mean().item()

In [11]:
def play(policy_net, env, device=device):
    """Run one episode with deterministic actions and print its reward and length.

    ``env.render()` is called before every step.
    """
    obs, _ = env.reset()
    episode_return = 0
    n_steps = 0
    finished = False

    while not finished:
        env.render()
        obs_t = torch.tensor(obs[None], dtype=torch.float32, device=device)
        action_t = policy_net.choose_action(obs_t, deterministic=True)
        action = action_t.cpu().numpy()

        # The policy returns a batch of one; the environment takes a single action.
        obs, reward, terminated, truncated, _ = env.step(action[0])
        episode_return += reward
        n_steps += 1
        finished = terminated or truncated

    print(f"[Evaluation] Total reward: {episode_return:.2f}, steps: {n_steps}")

In [12]:
def train_agent(
    env_train, policy_net, value_net, runner, agent, max_episodes=MAX_EPISODES
):
    """Train for ``max_episodes`` rollouts, checkpointing and plotting progress.

    Each iteration collects one rollout with ``runner``, computes normalised
    advantages (return minus value estimate), and performs one ``agent.train``
    update. Every ``SAVE_INTERVAL`` episodes, and also after the final
    episode, a summary is printed and both networks' state dicts are written
    to ``SAVE_DIR/model.pt`` (overwriting the previous checkpoint). After
    training, reward and length curves are plotted.

    Args:
        env_train: environment passed to ``runner.run``.
        policy_net: policy network being trained.
        value_net: value network being trained.
        runner: object whose ``run(env, policy_net, value_net)`` returns
            ``(states, actions, log_probs, values, returns, rewards)``.
        agent: object whose ``train(...)`` returns
            ``(policy_loss, value_loss, entropy)``.
        max_episodes: number of rollouts / updates to perform.
    """
    rewards_history = []
    lengths_history = []

    # Running totals since the last summary, plus how many episodes they cover
    # (the final window may be shorter than SAVE_INTERVAL).
    sum_reward, sum_length, n_since_report = 0.0, 0, 0

    os.makedirs(SAVE_DIR, exist_ok=True)

    for ep in range(1, max_episodes + 1):
        with torch.no_grad():
            s_mb, a_mb, old_logp_mb, v_mb, ret_mb, r_mb = runner.run(
                env_train, policy_net, value_net
            )

            adv_mb = ret_mb - v_mb
            # Normalise advantages; the epsilon guards against a zero std.
            adv_mb = (adv_mb - adv_mb.mean()) / (adv_mb.std() + 1e-8)

        pg_loss, v_loss, entropy = agent.train(
            s_mb, a_mb, v_mb, adv_mb, ret_mb, old_logp_mb
        )

        episode_reward = r_mb.sum()
        episode_length = len(s_mb)
        rewards_history.append(episode_reward)
        lengths_history.append(episode_length)
        sum_reward += episode_reward
        sum_length += episode_length
        n_since_report += 1

        print(
            f"[Episode {ep:4d}] Reward: {episode_reward:.2f}  Steps: {episode_length}"
        )

        # Also checkpoint on the last episode so the final weights are saved
        # even when max_episodes is not a multiple of SAVE_INTERVAL.
        if ep % SAVE_INTERVAL == 0 or ep == max_episodes:
            avg_reward = sum_reward / n_since_report
            avg_length = sum_length / n_since_report
            print(f"\nEpisode {ep}/{max_episodes} - Saving model...")
            print("----------------------------------")
            print(f"Actor Loss   = {pg_loss:.6f}")
            print(f"Critic Loss  = {v_loss:.6f}")
            print(f"Entropy      = {entropy:.6f}")
            print(f"Avg Reward   = {avg_reward:.2f}")
            print(f"Avg Length   = {avg_length:.2f}")

            torch.save(
                {
                    "episode": ep,
                    "policy_net": policy_net.state_dict(),
                    "value_net": value_net.state_dict(),
                },
                os.path.join(SAVE_DIR, "model.pt"),
            )

            sum_reward, sum_length, n_since_report = 0.0, 0, 0

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(rewards_history)
    plt.title("Total Reward per Episode")
    plt.xlabel("Episode")
    plt.ylabel("Total Reward")

    plt.subplot(1, 2, 2)
    plt.plot(lengths_history)
    plt.title("Episode Length per Episode")
    plt.xlabel("Episode")
    plt.ylabel("Length")

    plt.tight_layout()
    plt.show()
    print("leaving train agent")


In [None]:
# Build the training environment, networks, rollout collector and PPO agent,
# then train and finally watch one rendered evaluation episode.
env_train = gym.make("BipedalWalker-v3", render_mode="rgb_array")
env_eval = None  # created only after training; tracked so `finally` can close it
try:
    s_dim = env_train.observation_space.shape[0]
    a_dim = env_train.action_space.shape[0]
    print(f"State Dimension: {s_dim}, Action Dimension: {a_dim}")

    policy = PolicyNet(s_dim, a_dim).to(device)
    value = ValueNet(s_dim).to(device)
    runner = EnvRunner(s_dim, a_dim, max_step=MAX_STEPS, gamma=GAMMA, device=device)
    agent = PPO(
        policy,
        value,
        lr=LR,
        max_grad_norm=MAX_GRAD_NORM,
        ent_weight=ENT_WEIGHT,
        clip_val=CLIP_VAL,
        sample_n_epoch=N_EPOCHS,
        sample_mb_size=BATCH_SIZE,
        device=device,
    )

    print("\nEvaluating the untrained agent:")
    # Baseline rollout: no gradients are needed, and the result is reported
    # so there is a reference point for the trained agent.
    with torch.no_grad():
        untrained_rollout = runner.run(env_train, policy, value)
    # runner.run returns (states, ..., rewards): first and last entries.
    print(
        f"[Untrained] Reward: {untrained_rollout[-1].sum():.2f}  "
        f"Steps: {len(untrained_rollout[0])}"
    )

    print("\nStarting training...")
    train_agent(env_train, policy, value, runner, agent, max_episodes=MAX_EPISODES)
    print("\nTraining complete")

    env_eval = gym.make("BipedalWalker-v3", render_mode="human")
    print("\nFinal evaluation with rendering:")
    play(policy, env_eval, device=device)
finally:
    # Release the environments even if training or evaluation raised.
    env_train.close()
    if env_eval is not None:
        env_eval.close()