In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gymnasium as gym
from collections import deque
from gymnasium import register
import ale_py
try:
    import cv2
    USE_OPENCV = True
except ImportError:
    from PIL import Image
    USE_OPENCV = False

# Register the Atari Tennis environment under a fixed id so gym.make() can
# construct it by name.
register(
    id="TennisDeterministic-v4",
    entry_point="ale_py.env:AtariEnv",
    max_episode_steps=10000,
    nondeterministic=False,
    kwargs=dict(game="tennis", mode=0, difficulty=0),
)

# === Hyperparameters ===

# -- Return / advantage estimation --
# Reward discount factor.
GAMMA = 0.99
# GAE smoothing parameter.
LAMBDA = 0.9

# -- PPO objective --
# Clip range applied to the probability ratio.
CLIP_EPS = 0.2
# Weight of the entropy bonus in the actor loss.
ENTROPY_COEF = 0.05

# -- Optimisation --
# Passes over each rollout batch per update.
EPOCHS = 3
# Samples per gradient step.
MINIBATCH_SIZE = 128
# Learning rate handed to the optimizer.
ACTOR_LR = 1e-4
# Critic learning rate (the PPO class in this cell builds one optimizer from ACTOR_LR).
CRITIC_LR = 1e-4

# -- Rollout collection --
# Number of environments stepped in lockstep.
NUM_ENVS = 4
# Steps collected per environment per update (4 * 256 = 1024 samples).
NUM_STEPS = 256
# Total number of policy updates.
NUM_UPDATES = 50000
# Checkpoint every this many updates.
SAVE_INTERVAL = 100

# -- Observation preprocessing --
# Number of consecutive frames stacked into one observation.
FRAME_STACK = 4
# Number of times each action is repeated by the wrapper.
FRAME_SKIP = 2

# === Environment Wrapper ===
class AtariWrapper(gym.Wrapper):
    """Preprocess Atari frames and add action repeat plus frame stacking.

    Each observation is a uint8 array of shape (84, 84, FRAME_STACK) holding
    the most recent preprocessed frames, oldest first along the last axis.
    """

    def __init__(self, env):
        super().__init__(env)
        self.env = env
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, FRAME_STACK), dtype=np.uint8)
        self.action_space = gym.spaces.Discrete(env.action_space.n)
        # Rolling buffer of the last FRAME_STACK preprocessed frames.
        self.frames = deque(maxlen=FRAME_STACK)

    def reset(self, **kwargs):
        """Reset the wrapped env and fill the stack with the first frame.

        Returns:
            (stacked observation, info dict reported by the wrapped env).
        """
        obs, info = self.env.reset(**kwargs)
        frame = self._preprocess(obs)
        # Start from an empty buffer so no frame of the previous episode can
        # leak into the new one, then repeat the first frame to fill it.
        self.frames.clear()
        self.frames.extend([frame] * FRAME_STACK)
        # Forward the wrapped env's info instead of replacing it with {}.
        return np.stack(self.frames, axis=-1), info

    def step(self, action):
        """Repeat `action` up to FRAME_SKIP times and sum the rewards.

        Returns:
            (stacked observation, summed reward, done, False, last info).
            `done` is True when the wrapped env reported termination OR
            truncation; the truncated slot is therefore always False, which
            is what the callers in this file rely on.
        """
        total_reward = 0.0
        done = False
        for _ in range(FRAME_SKIP):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            done = terminated or truncated
            if done:
                # Stop repeating the action once the episode has ended.
                break
        self.frames.append(self._preprocess(obs))
        return np.stack(self.frames, axis=-1), total_reward, done, False, info

    def _preprocess(self, frame):
        """Convert a colour frame to a cropped 84x84 uint8 grayscale image."""
        # Grayscale by averaging the colour channels.
        frame = frame.mean(axis=2).astype(np.uint8)
        # Keep rows 34..193 and the first 160 columns.
        frame = frame[34:34+160, :160]
        if USE_OPENCV:
            frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA)
        else:
            frame = np.array(Image.fromarray(frame).resize((84, 84), Image.BILINEAR))
        return frame

# === PPO Network Architecture ===
class AtariCNN(nn.Module):
    """Actor-critic network: one convolutional trunk feeding two MLP heads.

    `forward` returns (action probabilities, state value) for a batch of
    channel-first frame stacks with pixel values in [0, 255].
    """

    def __init__(self, num_actions):
        super().__init__()
        # 1568 = 32 channels * 7 * 7, the flattened conv output for 84x84 input.
        flat_dim, hidden_dim = 1568, 256

        # Shared feature extractor.
        conv_stack = [
            nn.Conv2d(FRAME_STACK, 16, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.features = nn.Sequential(*conv_stack)

        # Policy head: probabilities over the discrete actions.
        policy_head = [
            nn.Linear(flat_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
            nn.Softmax(dim=-1),
        ]
        self.actor = nn.Sequential(*policy_head)

        # Value head: one scalar per observation.
        value_head = [
            nn.Linear(flat_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.critic = nn.Sequential(*value_head)

    def forward(self, x):
        # Scale raw pixel values to [0, 1] before the convolutions.
        scaled = x.float() / 255.0
        latent = self.features(scaled)
        action_probs = self.actor(latent)
        state_value = self.critic(latent)
        return action_probs, state_value

# === PPO Agent ===
class PPO:
    """Clipped-surrogate PPO agent around an AtariCNN actor-critic network.

    A single Adam optimizer (learning rate ACTOR_LR) updates the whole
    network, i.e. trunk, actor head and critic head together.
    """

    def __init__(self, num_actions, device):
        self.net = AtariCNN(num_actions).to(device)
        self.optimizer = optim.Adam(self.net.parameters(), lr=ACTOR_LR, eps=1e-5)
        self.device = device

    def get_action(self, x):
        """Sample one action per observation without tracking gradients.

        Args:
            x: batch of channel-last observations, indexed (N, H, W, C).

        Returns:
            NumPy arrays (actions, log_probs, values, entropies); `values`
            keeps the critic's trailing singleton dimension.
        """
        with torch.no_grad():
            # Channel-last -> channel-first, as the conv layers expect.
            x = torch.FloatTensor(x).permute(0, 3, 1, 2).to(self.device)
            probs, value = self.net(x)
            dist = Categorical(probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            entropy = dist.entropy()
        return action.cpu().numpy(), log_prob.cpu().numpy(), value.cpu().numpy(), entropy.cpu().numpy()

    def update(self, samples):
        """Run EPOCHS passes of minibatch PPO updates over one rollout batch.

        Args:
            samples: tuple (obs, actions, old_log_probs, returns, advantages)
                of flat arrays that all share the same leading batch length.
        """
        obs, actions, old_log_probs, returns, advantages = samples

        obs = torch.FloatTensor(obs).permute(0, 3, 1, 2).to(self.device, non_blocking=True)
        actions = torch.LongTensor(actions).to(self.device, non_blocking=True)
        old_log_probs = torch.FloatTensor(old_log_probs).to(self.device, non_blocking=True)
        # Keep returns and advantages 1-D so they line up element-wise with
        # the per-sample log-probabilities. A (B, 1) advantage multiplied by
        # a (B,) ratio broadcasts to a (B, B) matrix and pairs every ratio
        # with every advantage, which is not the PPO objective.
        returns = torch.FloatTensor(returns).reshape(-1).to(self.device, non_blocking=True)
        advantages = torch.FloatTensor(advantages).reshape(-1).to(self.device, non_blocking=True)

        # Normalize advantages over the whole batch.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Use the real batch length rather than assuming NUM_ENVS * NUM_STEPS.
        batch_size = obs.shape[0]
        for _ in range(EPOCHS):
            indices = torch.randperm(batch_size)
            for start in range(0, batch_size, MINIBATCH_SIZE):
                idx = indices[start:start + MINIBATCH_SIZE]

                mb_obs = obs[idx]
                mb_actions = actions[idx]
                mb_old_log_probs = old_log_probs[idx]
                mb_returns = returns[idx]
                mb_advantages = advantages[idx]

                probs, values = self.net(mb_obs)
                values = values.squeeze(-1)  # (B, 1) -> (B,)
                dist = Categorical(probs)
                log_probs = dist.log_prob(mb_actions)
                entropy = dist.entropy().mean()

                # Probability ratio between the current and the rollout policy.
                ratios = torch.exp(log_probs - mb_old_log_probs)
                surr1 = ratios * mb_advantages
                surr2 = torch.clamp(ratios, 1.0 - CLIP_EPS, 1.0 + CLIP_EPS) * mb_advantages
                actor_loss = -torch.min(surr1, surr2).mean() - ENTROPY_COEF * entropy
                critic_loss = 0.5 * (mb_returns - values).pow(2).mean()
                loss = actor_loss + 0.5 * critic_loss

                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
                self.optimizer.step()

# === Training Loop ===
def train():
    """Train a PPO agent on Atari Tennis with NUM_ENVS synchronous envs.

    Each update collects NUM_STEPS steps per environment, computes GAE
    advantages and runs `agent.update` on the flattened batch.

    Side effects: prints progress, writes periodic and final checkpoints to
    ``models/``, the best checkpoint to ``tennis_ppo_best.pth`` and recorded
    episodes to ``videos/``.
    """
    # Create environments
    envs = gym.vector.SyncVectorEnv([
        lambda: AtariWrapper(gym.make("TennisDeterministic-v4", render_mode="rgb_array"))
        for _ in range(NUM_ENVS)
    ])

    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        agent = PPO(envs.single_action_space.n, device)

        # Create the directories that are actually written to below.
        # (os.makedirs("") raises FileNotFoundError, and checkpoints go to
        # "models/", videos to "videos".)
        os.makedirs("models", exist_ok=True)
        os.makedirs("videos", exist_ok=True)

        # Training statistics
        episode_rewards = np.zeros(NUM_ENVS)
        # Only the last 100 episode returns are ever averaged.
        final_episode_rewards = deque(maxlen=100)
        best_mean_reward = -np.inf

        def evaluate(agent, env, num_episodes=5):
            """Play `num_episodes` episodes in a fresh env; return (mean, std).

            `env` is accepted for call compatibility but not used; actions are
            sampled from the policy exactly as during training.
            """
            eval_env = AtariWrapper(gym.make("TennisDeterministic-v4", render_mode="rgb_array"))
            rewards = []
            try:
                for _ in range(num_episodes):
                    eval_obs, _ = eval_env.reset()
                    eval_done = False
                    episode_reward = 0
                    while not eval_done:
                        action, _, _, _ = agent.get_action(np.expand_dims(eval_obs, 0))
                        eval_obs, reward, terminated, truncated, _ = eval_env.step(action[0])
                        eval_done = terminated or truncated
                        episode_reward += reward
                    rewards.append(episode_reward)
            finally:
                eval_env.close()
            return np.mean(rewards), np.std(rewards)

        # Initialize environments
        obs, _ = envs.reset()

        for update in range(1, NUM_UPDATES + 1):
            # Rollout storage, indexed (step, env, ...).
            mb_obs = np.zeros((NUM_STEPS, NUM_ENVS, 84, 84, FRAME_STACK), dtype=np.uint8)
            mb_actions = np.zeros((NUM_STEPS, NUM_ENVS), dtype=np.int32)
            mb_log_probs = np.zeros((NUM_STEPS, NUM_ENVS), dtype=np.float32)
            mb_values = np.zeros((NUM_STEPS, NUM_ENVS), dtype=np.float32)
            mb_rewards = np.zeros((NUM_STEPS, NUM_ENVS), dtype=np.float32)
            mb_dones = np.zeros((NUM_STEPS, NUM_ENVS), dtype=np.bool_)
            mb_entropy = np.zeros((NUM_STEPS, NUM_ENVS), dtype=np.float32)

            # Collect rollout
            for step in range(NUM_STEPS):
                mb_obs[step] = obs
                actions, log_probs, values, entropy = agent.get_action(obs)
                mb_actions[step] = actions
                mb_log_probs[step] = log_probs
                mb_values[step] = values.reshape(-1)
                mb_entropy[step] = entropy

                obs, rewards, terminateds, truncateds, infos = envs.step(actions)
                # An episode ends on termination or truncation.
                dones = np.logical_or(terminateds, truncateds)
                mb_rewards[step] = rewards
                mb_dones[step] = dones

                episode_rewards += rewards
                for i in np.flatnonzero(dones):
                    final_episode_rewards.append(episode_rewards[i])
                    episode_rewards[i] = 0

            # Bootstrap value of the observation that follows the last step.
            with torch.no_grad():
                last_values = agent.net(torch.FloatTensor(obs).permute(0, 3, 1, 2).to(device))[1].cpu().numpy()

            # Generalized advantage estimation. mb_dones[t] is recorded AFTER
            # the step taken at t, so it says whether transition t itself
            # ended the episode and must mask the bootstrap for that same t.
            mb_advantages = np.zeros_like(mb_rewards)
            gae = np.zeros(NUM_ENVS, dtype=np.float32)
            next_values = last_values.reshape(-1)
            for t in reversed(range(NUM_STEPS)):
                non_terminal = 1.0 - mb_dones[t].astype(np.float32)
                delta = mb_rewards[t] + GAMMA * next_values * non_terminal - mb_values[t]
                gae = delta + GAMMA * LAMBDA * non_terminal * gae
                mb_advantages[t] = gae
                next_values = mb_values[t]
            mb_returns = mb_advantages + mb_values

            # Flatten the batch
            mb_obs = mb_obs.reshape((-1, 84, 84, FRAME_STACK))
            mb_actions = mb_actions.flatten()
            mb_log_probs = mb_log_probs.flatten()
            mb_returns = mb_returns.flatten()
            mb_advantages = mb_advantages.flatten()

            # Update policy
            agent.update((mb_obs, mb_actions, mb_log_probs, mb_returns, mb_advantages))

            # Calculate statistics
            mean_reward = np.mean(final_episode_rewards) if final_episode_rewards else 0

            # Print progress
            if update % 10 == 0:
                print(f"Update: {update}/{NUM_UPDATES} | "
                      f"Mean Reward: {mean_reward:.2f} | "
                      f"Mean Entropy: {np.mean(mb_entropy):.4f} | "
                      f"Minibatch Size: {len(mb_obs)} | "
                      f"Total Frames: {update * NUM_STEPS * NUM_ENVS}")

            # Evaluate policy
            if update % 100 == 0:
                eval_mean, eval_std = evaluate(agent, envs)
                print(f"Eval Mean Reward: {eval_mean:.2f} ± {eval_std:.2f}")

            # Record video (less frequent to save time)
            if update % 1000 == 0:
                from gymnasium.wrappers import RecordVideo
                video_env = RecordVideo(
                    AtariWrapper(gym.make("TennisDeterministic-v4", render_mode="rgb_array")),
                    video_folder="videos",
                    name_prefix=f"tennis_update_{update}"
                )
                try:
                    # Use dedicated names here: reusing `obs` would overwrite
                    # the vector-env observation the next rollout starts from.
                    video_obs, _ = video_env.reset()
                    video_done = False
                    while not video_done:
                        action, _, _, _ = agent.get_action(np.expand_dims(video_obs, 0))
                        video_obs, _, terminated, truncated, _ = video_env.step(action[0])
                        video_done = terminated or truncated
                finally:
                    video_env.close()

            # Save model
            if update % SAVE_INTERVAL == 0:
                if mean_reward > best_mean_reward:
                    best_mean_reward = mean_reward
                    torch.save(agent.net.state_dict(), "tennis_ppo_best.pth")
                    print(f"New best model saved with mean reward: {best_mean_reward:.2f}")
                torch.save(agent.net.state_dict(), f"models/tennis_ppo_{update}.pth")

        # Final save
        torch.save(agent.net.state_dict(), "models/tennis_ppo_final.pth")
    finally:
        # Release the environments even if training is interrupted.
        envs.close()

if __name__ == "__main__":
    # NOTE: the previous bare `torch.cuda.amp.autocast(enabled=True)` call only
    # constructed a context manager and discarded it, so it never enabled
    # mixed precision. Autocast takes effect only inside a `with` block around
    # the forward pass and loss computation; training runs in full precision.
    train()

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gymnasium as gym
from collections import deque
import time
from gymnasium import register
import ale_py
from PIL import Image

# Register the Atari Tennis environment under a fixed id so gym.make() can
# construct it by name.
register(
    id="TennisDeterministic-v4",
    entry_point="ale_py.env:AtariEnv",
    max_episode_steps=10000,
    nondeterministic=False,
    kwargs=dict(game="tennis", mode=0, difficulty=0),
)

# === Hyperparameters ===

# -- Return / advantage estimation --
# Reward discount factor.
GAMMA = 0.99
# GAE smoothing parameter.
LAMBDA = 0.95

# -- PPO objective --
# Clip range applied to the probability ratio.
CLIP_EPS = 0.1
# Weight of the entropy bonus in the actor loss.
ENTROPY_COEF = 0.01

# -- Optimisation --
# Passes over each rollout batch per update.
EPOCHS = 4
# Samples per gradient step.
MINIBATCH_SIZE = 64
# Learning rate handed to the optimizer.
ACTOR_LR = 2.5e-4
# Critic learning rate (the PPO class in this cell builds one optimizer from ACTOR_LR).
CRITIC_LR = 2.5e-4

# -- Rollout collection --
# Number of environments stepped in lockstep.
NUM_ENVS = 8
# Steps collected per environment per update.
NUM_STEPS = 128
# Total number of policy updates.
NUM_UPDATES = 10000
# Checkpoint every this many updates.
SAVE_INTERVAL = 100

# -- Observation preprocessing --
# Number of consecutive frames stacked into one observation.
FRAME_STACK = 4
# Number of times each action is repeated by the wrapper.
FRAME_SKIP = 4

# === Environment Wrapper ===
class AtariWrapper(gym.Wrapper):
    """Preprocess Atari frames and add action repeat plus frame stacking.

    Each observation is a uint8 array of shape (84, 84, FRAME_STACK) holding
    the most recent preprocessed frames, oldest first along the last axis.
    """

    def __init__(self, env):
        super().__init__(env)
        self.env = env
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, FRAME_STACK), dtype=np.uint8)
        self.action_space = gym.spaces.Discrete(env.action_space.n)
        # Rolling buffer of the last FRAME_STACK preprocessed frames.
        self.frames = deque(maxlen=FRAME_STACK)

    def reset(self, **kwargs):
        """Reset the wrapped env and fill the stack with the first frame.

        Returns:
            (stacked observation, info dict reported by the wrapped env).
        """
        obs, info = self.env.reset(**kwargs)
        frame = self._preprocess(obs)
        # Start from an empty buffer so no frame of the previous episode can
        # leak into the new one, then repeat the first frame to fill it.
        self.frames.clear()
        self.frames.extend([frame] * FRAME_STACK)
        # Forward the wrapped env's info instead of replacing it with {}.
        return np.stack(self.frames, axis=-1), info

    def step(self, action):
        """Repeat `action` up to FRAME_SKIP times and sum the rewards.

        Returns:
            (stacked observation, summed reward, done, False, last info).
            `done` is True when the wrapped env reported termination OR
            truncation; the truncated slot is therefore always False, which
            is what the callers in this file rely on.
        """
        total_reward = 0.0
        done = False
        for _ in range(FRAME_SKIP):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            done = terminated or truncated
            if done:
                # Stop repeating the action once the episode has ended.
                break

        self.frames.append(self._preprocess(obs))
        stacked_frames = np.stack(self.frames, axis=-1)
        return stacked_frames, total_reward, done, False, info

    def _preprocess(self, frame):
        """Convert a colour frame to an 84x84 grayscale array via PIL."""
        grayscale = Image.fromarray(frame).convert('L')
        return np.array(grayscale.resize((84, 84), Image.BILINEAR))

# === PPO Network Architecture ===
class AtariCNN(nn.Module):
    """Actor-critic network: one convolutional trunk feeding two MLP heads.

    `forward` returns (action probabilities, state value) for a batch of
    channel-first frame stacks with pixel values in [0, 255].
    """

    def __init__(self, num_actions):
        super().__init__()
        # 3136 = 64 channels * 7 * 7, the flattened conv output for 84x84 input.
        flat_dim, hidden_dim = 3136, 512

        # Shared feature extractor.
        conv_stack = [
            nn.Conv2d(FRAME_STACK, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.features = nn.Sequential(*conv_stack)

        # Policy head: probabilities over the discrete actions.
        policy_head = [
            nn.Linear(flat_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
            nn.Softmax(dim=-1),
        ]
        self.actor = nn.Sequential(*policy_head)

        # Value head: one scalar per observation.
        value_head = [
            nn.Linear(flat_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.critic = nn.Sequential(*value_head)

    def forward(self, x):
        # Scale raw pixel values to [0, 1] before the convolutions.
        scaled = x.float() / 255.0
        latent = self.features(scaled)
        action_probs = self.actor(latent)
        state_value = self.critic(latent)
        return action_probs, state_value

# === PPO Agent ===
class PPO:
    """Clipped-surrogate PPO agent around an AtariCNN actor-critic network.

    A single Adam optimizer (learning rate ACTOR_LR) updates the whole
    network, i.e. trunk, actor head and critic head together.
    """

    def __init__(self, num_actions, device):
        self.net = AtariCNN(num_actions).to(device)
        self.optimizer = optim.Adam(self.net.parameters(), lr=ACTOR_LR, eps=1e-5)
        self.device = device

    def get_action(self, x):
        """Sample one action per observation without tracking gradients.

        Args:
            x: batch of channel-last observations, indexed (N, H, W, C).

        Returns:
            NumPy arrays (actions, log_probs, values, entropies); `values`
            keeps the critic's trailing singleton dimension.
        """
        with torch.no_grad():
            # Channel-last -> channel-first, as the conv layers expect.
            x = torch.FloatTensor(x).permute(0, 3, 1, 2).to(self.device)
            probs, value = self.net(x)
            dist = Categorical(probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            entropy = dist.entropy()
        return action.cpu().numpy(), log_prob.cpu().numpy(), value.cpu().numpy(), entropy.cpu().numpy()

    def update(self, samples):
        """Run EPOCHS passes of minibatch PPO updates over one rollout batch.

        Args:
            samples: tuple (obs, actions, old_log_probs, returns, advantages)
                of flat arrays that all share the same leading batch length.
        """
        obs, actions, old_log_probs, returns, advantages = samples

        # Convert to tensors
        obs = torch.FloatTensor(np.asarray(obs)).permute(0, 3, 1, 2).to(self.device)
        actions = torch.LongTensor(np.asarray(actions)).to(self.device)
        old_log_probs = torch.FloatTensor(np.asarray(old_log_probs)).to(self.device)
        # Keep returns and advantages 1-D so they line up element-wise with
        # the per-sample log-probabilities. A (B, 1) advantage multiplied by
        # a (B,) ratio broadcasts to a (B, B) matrix and pairs every ratio
        # with every advantage, which is not the PPO objective.
        returns = torch.FloatTensor(np.asarray(returns)).reshape(-1).to(self.device)
        advantages = torch.FloatTensor(np.asarray(advantages)).reshape(-1).to(self.device)

        # Normalize advantages over the whole batch.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        batch_size = len(obs)
        for _ in range(EPOCHS):
            # Shuffle indices
            indices = torch.randperm(batch_size)

            for start in range(0, batch_size, MINIBATCH_SIZE):
                idx = indices[start:start + MINIBATCH_SIZE]

                # Get minibatch
                mb_obs = obs[idx]
                mb_actions = actions[idx]
                mb_old_log_probs = old_log_probs[idx]
                mb_returns = returns[idx]
                mb_advantages = advantages[idx]

                # Get current policy
                probs, values = self.net(mb_obs)
                values = values.squeeze(-1)  # (B, 1) -> (B,)
                dist = Categorical(probs)
                log_probs = dist.log_prob(mb_actions)
                entropy = dist.entropy().mean()

                # Probability ratio between the current and the rollout policy.
                ratios = torch.exp(log_probs - mb_old_log_probs)

                # Actor loss
                surr1 = ratios * mb_advantages
                surr2 = torch.clamp(ratios, 1.0 - CLIP_EPS, 1.0 + CLIP_EPS) * mb_advantages
                actor_loss = -torch.min(surr1, surr2).mean() - ENTROPY_COEF * entropy

                # Critic loss
                critic_loss = 0.5 * (mb_returns - values).pow(2).mean()

                # Total loss
                loss = actor_loss + 0.5 * critic_loss

                # Update
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
                self.optimizer.step()

# === Training Loop ===
def train():
    """Train a PPO agent on Atari Tennis with NUM_ENVS synchronous envs.

    Each update collects NUM_STEPS steps per environment, computes GAE
    advantages and runs `agent.update` on the flattened batch.

    Side effects: prints progress, writes periodic and final checkpoints to
    ``models/`` and the best checkpoint to ``tennis_ppo_best.pth``.
    """
    # Create environments
    envs = gym.vector.SyncVectorEnv([
        lambda: AtariWrapper(gym.make("TennisDeterministic-v4", render_mode="rgb_array"))
        for _ in range(NUM_ENVS)
    ])

    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        agent = PPO(envs.single_action_space.n, device)

        # Create the directory the checkpoints below are written to.
        # (os.makedirs("") raises FileNotFoundError.)
        os.makedirs("models", exist_ok=True)

        # Training statistics
        episode_rewards = np.zeros(NUM_ENVS)
        # Only the last 100 episode returns are ever averaged.
        final_episode_rewards = deque(maxlen=100)
        best_mean_reward = -np.inf

        # Initialize environments
        obs, _ = envs.reset()

        for update in range(1, NUM_UPDATES + 1):
            # Rollout storage, indexed (step, env, ...).
            mb_obs = np.zeros((NUM_STEPS, NUM_ENVS, 84, 84, FRAME_STACK), dtype=np.uint8)
            mb_actions = np.zeros((NUM_STEPS, NUM_ENVS), dtype=np.int32)
            mb_log_probs = np.zeros((NUM_STEPS, NUM_ENVS), dtype=np.float32)
            mb_values = np.zeros((NUM_STEPS, NUM_ENVS), dtype=np.float32)
            mb_rewards = np.zeros((NUM_STEPS, NUM_ENVS), dtype=np.float32)
            mb_dones = np.zeros((NUM_STEPS, NUM_ENVS), dtype=np.bool_)
            mb_entropy = np.zeros((NUM_STEPS, NUM_ENVS), dtype=np.float32)

            # Collect rollout
            for step in range(NUM_STEPS):
                mb_obs[step] = obs
                actions, log_probs, values, entropy = agent.get_action(obs)
                mb_actions[step] = actions
                mb_log_probs[step] = log_probs
                mb_values[step] = values.reshape(-1)
                mb_entropy[step] = entropy

                obs, rewards, terminateds, truncateds, infos = envs.step(actions)
                # An episode ends on termination or truncation.
                dones = np.logical_or(terminateds, truncateds)
                mb_rewards[step] = rewards
                mb_dones[step] = dones

                # Track episode rewards
                episode_rewards += rewards
                for i in np.flatnonzero(dones):
                    final_episode_rewards.append(episode_rewards[i])
                    episode_rewards[i] = 0

            # Bootstrap value of the observation that follows the last step.
            with torch.no_grad():
                last_values = agent.net(torch.FloatTensor(obs).permute(0, 3, 1, 2).to(device))[1].cpu().numpy()

            # Generalized advantage estimation. mb_dones[t] is recorded AFTER
            # the step taken at t, so it says whether transition t itself
            # ended the episode and must mask the bootstrap for that same t.
            mb_advantages = np.zeros_like(mb_rewards)
            gae = np.zeros(NUM_ENVS, dtype=np.float32)
            next_values = last_values.reshape(-1)
            for t in reversed(range(NUM_STEPS)):
                non_terminal = 1.0 - mb_dones[t].astype(np.float32)
                delta = mb_rewards[t] + GAMMA * next_values * non_terminal - mb_values[t]
                gae = delta + GAMMA * LAMBDA * non_terminal * gae
                mb_advantages[t] = gae
                next_values = mb_values[t]
            mb_returns = mb_advantages + mb_values

            # Flatten the batch
            mb_obs = mb_obs.reshape((-1, 84, 84, FRAME_STACK))
            mb_actions = mb_actions.flatten()
            mb_log_probs = mb_log_probs.flatten()
            mb_returns = mb_returns.flatten()
            mb_advantages = mb_advantages.flatten()

            # Update policy
            agent.update((mb_obs, mb_actions, mb_log_probs, mb_returns, mb_advantages))

            # Calculate statistics
            mean_reward = np.mean(final_episode_rewards) if final_episode_rewards else 0

            # Print progress
            if update % 10 == 0:
                print(f"Update: {update}/{NUM_UPDATES} | "
                      f"Mean Reward: {mean_reward:.2f} | "
                      f"Minibatch Size: {len(mb_obs)} | "
                      f"Total Frames: {update * NUM_STEPS * NUM_ENVS}")

            # Save model
            if update % SAVE_INTERVAL == 0:
                if mean_reward > best_mean_reward:
                    best_mean_reward = mean_reward
                    torch.save(agent.net.state_dict(), "tennis_ppo_best.pth")
                    print(f"New best model saved with mean reward: {best_mean_reward:.2f}")

                torch.save(agent.net.state_dict(), f"models/tennis_ppo_{update}.pth")

        # Final save
        torch.save(agent.net.state_dict(), "models/tennis_ppo_final.pth")
    finally:
        # Release the environments even if training is interrupted.
        envs.close()

# Script entry point: start training only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    train()