In [1]:
import gymnasium
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import matplotlib.pyplot as plt
import numpy as np
import imageio

In [2]:
env = gymnasium.make("InvertedPendulum-v5", render_mode="rgb_array")

In [3]:
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

In [4]:
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Actor, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim)
        )
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def forward(self, state):
        mean = self.network(state)
        std_dev = torch.exp(self.log_std)
        return Normal(mean, std_dev)


class Critic(nn.Module):
    def __init__(self, state_dim):
        super(Critic, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

    def forward(self, state):
        return self.network(state)


def collect_data(env, actor, max_steps):
    states, actions, rewards, masks, log_probs = [], [], [], [], []
    step_count = 0
    state, _ = env.reset()
    while step_count < max_steps:
        state_tensor = torch.tensor(state, dtype=torch.float32)
        dist = actor(state_tensor)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        next_state, reward, terminated, truncated, info = env.step(action.numpy())

        states.append(state_tensor)
        actions.append(action)
        rewards.append(reward)
        done = terminated or truncated
        masks.append(1 - done)
        log_probs.append(log_prob)
        state = next_state
        step_count += 1
        if done:
            state, _ = env.reset()

    return states, actions, log_probs, rewards, masks


def compute_rewards_to_go(rewards, masks, gamma=0.99):
    rtg = []
    discounted_sum = 0
    for reward, mask in zip(reversed(rewards), reversed(masks)):
        discounted_sum = reward + gamma * discounted_sum * mask
        rtg.insert(0, discounted_sum)
    return rtg


def compute_advantages(rewards, values, masks, gamma=0.99, lam=0.95):
    T = len(rewards)
    advantages = torch.zeros(T, dtype=torch.float32)
    gae = 0.0
    for t in reversed(range(T - 1)):
        td_error = rewards[t] + gamma * masks[t] * values[t + 1] - values[t]
        gae = td_error + gamma * lam * masks[t] * gae
        advantages[t] = gae
    return advantages


def update_policy(actor, critic, optimizer_actor, optimizer_critic, states, actions, log_probs_old, returns, advantages,
                  clip_param=0.2):
    states = torch.stack(states)
    actions = torch.stack(actions)
    log_probs_old = torch.stack(log_probs_old)
    returns = torch.tensor(returns, dtype=torch.float32)

    dists = actor(states)
    values = critic(states).squeeze(-1)
    log_probs = dists.log_prob(actions)
    ratios = torch.exp(log_probs - log_probs_old)

    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1.0 - clip_param, 1.0 + clip_param) * advantages
    actor_loss = -torch.min(surr1, surr2).mean()
    critic_loss = (returns - values).pow(2).mean()

    optimizer_actor.zero_grad()
    optimizer_critic.zero_grad()
    actor_loss.backward()
    critic_loss.backward()
    optimizer_actor.step()
    optimizer_critic.step()


actor = Actor(state_dim, action_dim)
critic = Critic(state_dim)
optimizer_actor = optim.Adam(actor.parameters(), lr=3e-4)
optimizer_critic = optim.Adam(critic.parameters(), lr=3e-4)

num_iterations = 500
horizon = 2048
episode_rewards = []

for i in range(num_iterations):
    states, actions, log_probs, rewards, masks = collect_data(env, actor, horizon)
    episode_reward = sum(rewards)
    episode_rewards.append(episode_reward)

    returns = compute_rewards_to_go(rewards, masks)
    values = critic(torch.stack(states)).detach().squeeze(-1)
    advantages = compute_advantages(rewards, values, masks)
    update_policy(actor, critic, optimizer_actor, optimizer_critic, states, actions, log_probs, returns, advantages)

    print(f"\rEpisode: {i}\tReward: {episode_reward}", end="")

Episode: 8	Reward: 1788

KeyboardInterrupt: 

In [None]:
plt.plot(episode_rewards)
plt.show()