<a href="https://colab.research.google.com/github/newmantic/PPO/blob/main/PPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

In [2]:
class ActorCritic(nn.Module):
    """Shared-trunk actor-critic network for a discrete action space.

    Two ReLU-activated fully connected layers feed two linear heads: an
    actor head emitting one logit per action and a critic head emitting a
    single state-value estimate.
    """

    def __init__(self, state_size, action_size, hidden_size=64):
        """Create the shared layers and the two output heads.

        Args:
            state_size: Number of input features per state.
            action_size: Number of discrete actions (size of the logit vector).
            hidden_size: Width of both hidden layers.
        """
        super().__init__()
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)

        # Actor network
        self.actor = nn.Linear(hidden_size, action_size)

        # Critic network
        self.critic = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return ``(policy_logits, state_value)`` for the input ``x``.

        The last dimension of ``x`` must be ``state_size``; the value output
        keeps a trailing dimension of size one.
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        policy_logits = self.actor(x)
        state_value = self.critic(x)
        return policy_logits, state_value

    def act(self, state):
        """Sample one action for a single state.

        Args:
            state: Tensor holding one state (``action.item()`` requires that
                exactly one action is sampled).

        Returns:
            ``(action, action_log_prob)`` where ``action`` is a Python int and
            ``action_log_prob`` is the tensor log-probability of that action.
            The log-probability is still attached to the autograd graph
            unless the caller disables gradient tracking.
        """
        policy_logits, _ = self.forward(state)
        dist = Categorical(logits=policy_logits)
        action = dist.sample()
        action_log_prob = dist.log_prob(action)
        return action.item(), action_log_prob

    def evaluate(self, state, action):
        """Score previously taken actions under the current policy.

        Args:
            state: Tensor of states whose last dimension is ``state_size``.
            action: Tensor of action indices matching the leading
                dimensions of ``state``.

        Returns:
            ``(action_log_probs, state_values, dist_entropy)``, all sharing
            the leading (batch) shape of ``action``.
        """
        policy_logits, state_value = self.forward(state)
        dist = Categorical(logits=policy_logits)
        action_log_probs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        # Drop only the trailing singleton value dimension. A bare
        # ``squeeze()`` would also collapse a batch of size one into a 0-dim
        # tensor, leaving the values with a different shape than the
        # log-probs and entropy.
        return action_log_probs, state_value.squeeze(-1), dist_entropy

In [3]:
class PPOAgent:
    """PPO agent with a clipped surrogate objective around an ``ActorCritic``.

    Holds the model, an Adam optimizer, and the loss hyper-parameters.
    """

    def __init__(self, state_size, action_size, hidden_size=64, lr=0.001, gamma=0.99, clip_epsilon=0.2, update_epochs=10, c1=0.5, c2=0.01):
        """Build the model and optimizer.

        Args:
            state_size: Number of input features per state.
            action_size: Number of discrete actions.
            hidden_size: Width of the model's hidden layers.
            lr: Adam learning rate.
            gamma: Discount factor used by ``compute_returns``.
            clip_epsilon: Half-width of the probability-ratio clipping range.
            update_epochs: Gradient steps taken per call to ``update``.
            c1: Value function loss coefficient.
            c2: Entropy bonus coefficient.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.gamma = gamma
        self.clip_epsilon = clip_epsilon
        self.update_epochs = update_epochs
        self.c1 = c1  # Value function loss coefficient
        self.c2 = c2  # Entropy bonus coefficient

        self.model = ActorCritic(state_size, action_size, hidden_size)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    def compute_returns(self, rewards, dones, next_value):
        """Compute discounted returns, bootstrapping from ``next_value``.

        Args:
            rewards: Sequence of per-step rewards in time order.
            dones: Sequence of per-step termination flags; a truthy flag
                zeroes the contribution of everything after that step.
            next_value: Value estimate for the state following the last step.

        Returns:
            List of returns in time order, one per reward.
        """
        returns = []
        R = next_value
        for reward, done in zip(reversed(rewards), reversed(dones)):
            R = reward + self.gamma * R * (1 - done)
            returns.append(R)
        # Accumulated back-to-front; one reversal restores time order in
        # O(n) instead of repeatedly inserting at the head of the list.
        returns.reverse()
        return returns

    def update(self, states, actions, log_probs, returns, advantages):
        """Run ``update_epochs`` PPO gradient steps on one batch.

        Args:
            states: Sequence of state vectors (e.g. a list of NumPy arrays).
            actions: Sequence of integer action indices.
            log_probs: Log-probabilities of ``actions`` under the policy that
                collected the data; scalars or one-element tensors.
            returns: Discounted return targets for the critic.
            advantages: Advantage estimates weighting the policy objective.
        """
        # Stack into a single ndarray first: building a tensor straight from
        # a list of ndarrays converts element by element and is very slow.
        states = torch.as_tensor(np.asarray(states, dtype=np.float32))
        actions = torch.as_tensor(actions, dtype=torch.long)
        # ``float()`` accepts Python scalars and one-element tensors alike and
        # guarantees the old log-probs carry no autograd history.
        old_log_probs = torch.as_tensor([float(lp) for lp in log_probs], dtype=torch.float32)
        returns = torch.as_tensor(returns, dtype=torch.float32)
        advantages = torch.as_tensor(advantages, dtype=torch.float32)

        for _ in range(self.update_epochs):
            new_log_probs, state_values, entropy = self.model.evaluate(states, actions)
            # Probability ratio pi_new(a|s) / pi_old(a|s), computed in log space.
            ratios = torch.exp(new_log_probs - old_log_probs)

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
            # Taking the element-wise minimum keeps the pessimistic estimate.
            actor_loss = -torch.min(surr1, surr2).mean()

            critic_loss = self.c1 * (returns - state_values).pow(2).mean()
            entropy_bonus = self.c2 * entropy.mean()

            # The entropy bonus is subtracted so that minimising the loss
            # increases policy entropy.
            loss = actor_loss + critic_loss - entropy_bonus

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def save(self, path):
        """Write the model's ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load a ``state_dict`` previously written by ``save`` from ``path``."""
        # NOTE(security): ``torch.load`` unpickles the file; only load
        # checkpoints from trusted sources.
        self.model.load_state_dict(torch.load(path))

In [4]:
class Simple1DEnv:
    """Walk along a line of integer positions until the goal cell is reached.

    Action ``0`` steps left (never below position 0), action ``1`` steps
    right (never above ``length - 1``); any other action leaves the position
    unchanged. Standing on the goal yields reward ``1`` and ends the episode;
    every other step yields ``-0.1``.
    """

    def __init__(self, length=10, start=0, goal=9):
        """Store the track length, start cell and goal cell."""
        self.length = length
        self.start = start
        self.goal = goal
        self.state = start

    def _observation(self):
        """Return the current position as a one-element float32 array."""
        return np.array([self.state], dtype=np.float32)

    def reset(self):
        """Move back to the start cell and return the initial observation."""
        self.state = self.start
        return self._observation()

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done)``."""
        if action == 0:
            # Left move, stopped by the wall at position 0.
            self.state = max(0, self.state - 1)
        elif action == 1:
            # Right move, stopped by the wall at the last cell.
            self.state = min(self.length - 1, self.state + 1)

        reached_goal = self.state == self.goal
        if reached_goal:
            reward = 1
        else:
            reward = -0.1
        return self._observation(), reward, reached_goal

In [6]:
def train_ppo(n_episodes=500, max_steps=100, log_interval=10):
    """Train a PPO agent on ``Simple1DEnv`` and return it.

    Each episode is rolled out for at most ``max_steps`` steps, after which
    discounted returns and advantages (return minus value estimate) are
    computed and passed to ``agent.update``.

    Args:
        n_episodes: Number of training episodes.
        max_steps: Maximum number of environment steps per episode.
        log_interval: Print the episode's total reward every this many
            episodes; a falsy value disables logging.

    Returns:
        The trained ``PPOAgent``.
    """
    env = Simple1DEnv()
    state_size = 1  # since the state is just the position in the 1D space
    action_size = 2  # two possible actions: move left or right

    agent = PPOAgent(state_size, action_size, hidden_size=64, lr=0.001, gamma=0.99, clip_epsilon=0.2, update_epochs=10, c1=0.5, c2=0.01)

    for episode in range(n_episodes):
        state = env.reset()
        log_probs = []
        values = []
        rewards = []
        dones = []
        actions = []
        states = []

        for step in range(max_steps):
            state_tensor = torch.FloatTensor(state)
            # Rollout data is only consumed as plain numbers, so skip autograd
            # bookkeeping instead of retaining one graph per step.
            with torch.no_grad():
                action, log_prob = agent.model.act(state_tensor)
                _, value = agent.model(state_tensor)

            next_state, reward, done = env.step(action)

            states.append(state)
            actions.append(action)
            log_probs.append(log_prob.item())
            rewards.append(reward)
            values.append(value.item())
            dones.append(done)

            state = next_state

            if done or step == max_steps - 1:
                if done:
                    # compute_returns multiplies the bootstrap value by
                    # (1 - done), so it is unused after a terminal step.
                    next_value = 0.0
                else:
                    # Episode was cut off: bootstrap from the critic.
                    with torch.no_grad():
                        _, next_value = agent.model(torch.FloatTensor(next_state))
                    next_value = next_value.item()

                returns = agent.compute_returns(rewards, dones, next_value)
                advantages = [ret - val for ret, val in zip(returns, values)]
                agent.update(states, actions, log_probs, returns, advantages)

                if log_interval and episode % log_interval == 0:
                    print(f"Episode {episode+1}/{n_episodes}, Total Reward: {sum(rewards)}")
                break

    return agent


# Run the full training loop, then persist the learned weights to disk.
checkpoint_path = "ppo_model.pth"
ppo_agent = train_ppo()
ppo_agent.save(checkpoint_path)

Episode 1/500, Total Reward: -0.7000000000000004
Episode 11/500, Total Reward: 0.20000000000000007
Episode 21/500, Total Reward: 0.20000000000000007
Episode 31/500, Total Reward: 1.1102230246251565e-16
Episode 41/500, Total Reward: 0.20000000000000007
Episode 51/500, Total Reward: 0.20000000000000007
Episode 61/500, Total Reward: 0.20000000000000007
Episode 71/500, Total Reward: 0.20000000000000007
Episode 81/500, Total Reward: 0.20000000000000007
Episode 91/500, Total Reward: 0.20000000000000007
Episode 101/500, Total Reward: 0.20000000000000007
Episode 111/500, Total Reward: 0.20000000000000007
Episode 121/500, Total Reward: 0.20000000000000007
Episode 131/500, Total Reward: 0.20000000000000007
Episode 141/500, Total Reward: 0.20000000000000007
Episode 151/500, Total Reward: 0.20000000000000007
Episode 161/500, Total Reward: 0.20000000000000007
Episode 171/500, Total Reward: 0.20000000000000007
Episode 181/500, Total Reward: 0.20000000000000007
Episode 191/500, Total Reward: 0.200000