# In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gymnasium as gym
from collections import deque

# Hyperparameters shared by the networks and the PPO update defined below.
GAMMA = 0.99  # discount factor applied to the next value / running advantage in compute_adv
LAMBDA = 0.95  # GAE factor; multiplies the running advantage together with GAMMA
CLIP_EPS = 0.2  # probability ratio is clamped to [1 - CLIP_EPS, 1 + CLIP_EPS] in the actor loss
EPOCHS = 20  # number of passes over the collected data per PPO.update call
BATCH_SIZE = 64  # length of each contiguous mini-batch slice inside PPO.update
ACTOR_LR = 5e-4  # Adam learning rate for the actor network
CRITIC_LR = 5e-4  # Adam learning rate for the critic network
HIDDEN = 256  # width of both hidden Linear layers in Actor and Critic

class Actor(nn.Module):
    """Policy network: observation -> probability vector over discrete actions.

    Two hidden ``Linear`` + ``ReLU`` stages of width ``HIDDEN`` feed an output
    ``Linear`` layer whose result is passed through a softmax over the last
    dimension.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        # Layers are created in forward order so parameter initialisation and
        # the ``net.<index>`` state-dict keys stay the same as a literal
        # nn.Sequential(...) listing.
        widths = [obs_dim, HIDDEN, HIDDEN]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.append(nn.Linear(HIDDEN, act_dim))
        stack.append(nn.Softmax(dim=-1))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the action probabilities produced by the network for ``x``."""
        action_probs = self.net(x)
        return action_probs

class Critic(nn.Module):
    """Value network: observation -> single scalar output per input row.

    Two hidden ``Linear`` + ``ReLU`` stages of width ``HIDDEN`` followed by a
    ``Linear`` layer with one output unit.
    """

    def __init__(self, obs_dim):
        super().__init__()
        # Built in forward order so initialisation order and the
        # ``net.<index>`` state-dict keys are unchanged.
        widths = [obs_dim, HIDDEN, HIDDEN]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.append(nn.Linear(HIDDEN, 1))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the network output for ``x`` (last dimension has size 1)."""
        value = self.net(x)
        return value

class PPO:
    """PPO agent with a clipped surrogate objective and separate actor/critic.

    The actor and the critic each have their own Adam optimizer
    (``ACTOR_LR`` / ``CRITIC_LR``). ``update`` runs ``EPOCHS`` passes over the
    supplied data in contiguous mini-batches of ``BATCH_SIZE`` samples.
    """

    def __init__(self, obs_dim, act_dim):
        """Create the networks and their optimizers.

        Args:
            obs_dim: Size of the observation vector fed to both networks.
            act_dim: Number of discrete actions (actor output size).
        """
        self.actor = Actor(obs_dim, act_dim)
        self.critic = Critic(obs_dim)
        self.opt_actor = optim.Adam(self.actor.parameters(), lr=ACTOR_LR)
        self.opt_critic = optim.Adam(self.critic.parameters(), lr=CRITIC_LR)

    def get_action(self, state):
        """Sample an action for a single observation.

        Args:
            state: One observation (array-like of length ``obs_dim``).

        Returns:
            ``(action, log_prob)`` where ``action`` is a Python int and
            ``log_prob`` is a tensor of shape ``(1,)`` holding the log
            probability of that action under the current policy. The tensor
            carries no autograd history: it is only used as the fixed
            "old policy" term of the PPO ratio.
        """
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        # No gradient is ever taken through rollout-time log-probs, so skip
        # building the autograd graph for every environment step.
        with torch.no_grad():
            probs = self.actor(state)
            dist = Categorical(probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)
        return action.item(), log_prob

    def compute_adv(self, rewards, values, dones, last_value=0.0):
        """Compute GAE advantages and the matching value targets.

        Args:
            rewards: Per-step rewards, in time order.
            values: Per-step critic values, same length as ``rewards``.
            dones: Per-step terminal flags (truthy where the step ended the
                episode); a terminal step cuts both the bootstrap value and
                the running advantage.
            last_value: Value estimate for the state following the final
                step. Defaults to ``0.0``, i.e. no bootstrap past the end of
                the sequence.

        Returns:
            ``(advantages, returns)`` as float32 tensors of shape ``(T,)``,
            where ``returns[t] = advantages[t] + values[t]``.
        """
        advantages, returns = [], []
        gae, next_val = 0.0, last_value
        for r, v, d in zip(reversed(rewards), reversed(values), reversed(dones)):
            not_done = 1.0 - float(d)
            delta = r + GAMMA * next_val * not_done - v
            gae = delta + GAMMA * LAMBDA * gae * not_done
            # Append while walking backwards and reverse once at the end;
            # inserting at index 0 would be quadratic in episode length.
            advantages.append(gae)
            returns.append(gae + v)
            next_val = v
        advantages.reverse()
        returns.reverse()
        return (
            torch.tensor(advantages, dtype=torch.float32),
            torch.tensor(returns, dtype=torch.float32),
        )

    def update(self, states, actions, old_logps, advs, rets):
        """Run the PPO optimisation on one batch of collected experience.

        Args:
            states: Sequence of observations, one per step.
            actions: Sequence of integer actions taken.
            old_logps: Per-step log-probabilities recorded at rollout time
                (floats or single-element tensors).
            advs: Per-step advantages (tensor or sequence of floats).
            rets: Per-step value targets (tensor or sequence of floats).

        Returns:
            None. The actor and critic parameters are updated in place.
        """
        if len(states) == 0:
            # Nothing was collected; there is nothing to optimise.
            return

        states = torch.tensor(np.array(states), dtype=torch.float32)
        actions = torch.as_tensor(actions, dtype=torch.long)
        # Flatten every recorded log-prob to a scalar and detach, so the old
        # policy term is a constant of shape (T,) whatever shape the
        # individual entries had.
        old_logps = torch.stack(
            [torch.as_tensor(lp, dtype=torch.float32).reshape(()) for lp in old_logps]
        ).detach()
        advs = torch.as_tensor(advs, dtype=torch.float32)
        rets = torch.as_tensor(rets, dtype=torch.float32)

        mse = nn.MSELoss()
        n_samples = states.shape[0]

        for _ in range(EPOCHS):
            for start in range(0, n_samples, BATCH_SIZE):
                idx = slice(start, start + BATCH_SIZE)

                # The actor ends in a softmax, so its output is probabilities.
                probs = self.actor(states[idx])
                dist = Categorical(probs)
                logps = dist.log_prob(actions[idx])
                ratio = torch.exp(logps - old_logps[idx])
                s1 = ratio * advs[idx]
                s2 = torch.clamp(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS) * advs[idx]
                loss_actor = -torch.min(s1, s2).mean()

                # squeeze(-1) keeps the batch dimension even when the last
                # mini-batch holds a single sample; a bare squeeze() would
                # yield a 0-d tensor that mismatches the (1,) target.
                vals = self.critic(states[idx]).squeeze(-1)
                loss_critic = mse(vals, rets[idx])

                self.opt_actor.zero_grad()
                loss_actor.backward()
                self.opt_actor.step()

                self.opt_critic.zero_grad()
                loss_critic.backward()
                self.opt_critic.step()

def normalize_state(state):
    """Linearly rescale a (position, velocity) pair.

    Position is mapped from the range [-1.2, 0.6] and velocity from
    [-0.07, 0.07] onto [-1, 1]. Returns a NumPy array of the two rescaled
    components, position first.
    """
    pos_low, pos_high = -1.2, 0.6
    vel_low, vel_high = -0.07, 0.07

    raw_pos, raw_vel = state
    # First bring each component to [0, 1], then stretch to [-1, 1].
    pos_unit = (raw_pos - pos_low) / (pos_high - pos_low)
    vel_unit = (raw_vel - vel_low) / (vel_high - vel_low)
    return np.array([pos_unit * 2 - 1, vel_unit * 2 - 1])

def _save_checkpoint(agent, path):
    """Write the actor and critic weights of ``agent`` to ``path``."""
    torch.save({
        'actor_state_dict': agent.actor.state_dict(),
        'critic_state_dict': agent.critic.state_dict(),
    }, path)


def train():
    """Train a PPO agent on MountainCar-v0 and return it.

    Each episode is collected with the current policy, then used for one
    ``agent.update`` call. Learning uses a shaped reward (base reward plus a
    bonus proportional to the absolute normalized velocity); the reported
    score, the best-model checkpoint and the stopping condition use the
    unshaped base reward averaged over the last 100 episodes.

    Checkpoints are written to ``ppo_mountaincar_v0_best.pth``,
    ``ppo_mountaincar_v0_solved.pth`` and ``ppo_mountaincar_v0_final.pth``.

    Returns:
        The trained ``PPO`` agent.
    """
    number_episodes = 2000
    max_timesteps = 1000

    # Make the episode time limit explicit and equal to the per-episode step
    # cap. With the registered default limit the environment signals
    # truncation well before ``max_timesteps``, and an environment must not
    # be stepped again after it reports terminated or truncated.
    env = gym.make("MountainCar-v0", max_episode_steps=max_timesteps)
    try:
        agent = PPO(env.observation_space.shape[0], env.action_space.n)

        scores_on_100_episodes = deque(maxlen=100)
        best_avg_score = -np.inf

        for ep in range(number_episodes):
            state, _ = env.reset()
            state = normalize_state(state)
            logps, vals, rewards, states, actions, dones = [], [], [], [], [], []
            total = 0  # Tracks base reward for stopping condition

            for _ in range(max_timesteps):
                action, logp = agent.get_action(state)
                # Only the scalar value is kept, so no autograd graph is needed.
                with torch.no_grad():
                    value = agent.critic(
                        torch.tensor(state, dtype=torch.float32).unsqueeze(0)
                    ).item()
                next_state, base_reward, terminated, truncated, _ = env.step(action)
                next_state = normalize_state(next_state)

                # Reward shaping for policy learning
                shaped_reward = base_reward + 5 * abs(next_state[1])  # Reduced bonus

                states.append(state)
                actions.append(action)
                rewards.append(shaped_reward)  # Use shaped reward for learning
                logps.append(logp)
                vals.append(value)
                # Only a true terminal state is recorded as "done"; a time
                # limit hit is not a terminal state of the task.
                dones.append(terminated)

                state = next_state
                total += base_reward  # Accumulate base reward for scoring

                if terminated or truncated:
                    break

            advs, rets = agent.compute_adv(rewards, vals, dones)
            agent.update(states, actions, logps, advs, rets)

            scores_on_100_episodes.append(total)
            # The deque is never empty here: a score was just appended.
            avg_score = np.mean(scores_on_100_episodes)
            window_full = len(scores_on_100_episodes) == 100

            # Print progress
            print(f'\rEpisode {ep}\tAverage Score: {avg_score:.2f}', end="")
            if ep % 100 == 0:
                print(f'\rEpisode {ep}\tAverage Score: {avg_score:.2f}')

            # Save best model (based on base reward)
            if window_full and avg_score > best_avg_score:
                best_avg_score = avg_score
                _save_checkpoint(agent, 'ppo_mountaincar_v0_best.pth')

            # Check if solved (based on base reward)
            if window_full and avg_score >= -110:
                # ``ep`` is zero-based, so ``ep + 1`` episodes have been run;
                # the last 100 of them form the averaging window.
                print(f'\nEnvironment solved in {ep + 1 - 100} episodes!\tAverage Score: {avg_score:.2f}')
                _save_checkpoint(agent, 'ppo_mountaincar_v0_solved.pth')
                break

        # Save final model
        _save_checkpoint(agent, 'ppo_mountaincar_v0_final.pth')
    finally:
        # Release the environment even if training raises.
        env.close()

    return agent

if __name__ == "__main__":
    # Script entry point: ensure a "models" directory exists, then run training.
    # NOTE(review): train() saves its checkpoints under bare filenames (i.e. in
    # the current working directory), not under "models/" -- confirm whether the
    # directory or the save paths reflect the intended location.
    os.makedirs("models", exist_ok=True)
    agent = train()