In [9]:
!pip -q install "gymnasium[accept-rom-license]"
!pip -q install "gymnasium[atari]"

In [10]:
!nvidia-smi

Sun Oct  6 09:06:50 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   38C    P8              9W /   70W |       1MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla T4                       Off |   00

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Categorical
import numpy as np
import gymnasium as gym
from gymnasium.wrappers import GrayScaleObservation, FrameStack
import wandb

In [12]:
# Get ROM
from ale_py import ALEInterface
ale = ALEInterface()

from ale_py.roms import BattleZone
ale.loadROM(BattleZone)

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]
Game console created:
  ROM file:  /opt/conda/lib/python3.10/site-packages/AutoROM/roms/battle_zone.bin
  Cart Name: Battlezone (1983) (Atari) [!]
  Cart MD5:  41f252a66c6301f1e8ab3612c19bc5d4
  Display Format:  AUTO-DETECT ==> NTSC
  ROM Size:        8192
  Bankswitch Type: AUTO-DETECT ==> F8

Running ROM file...
Random seed is 1728205611


In [13]:
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()

wandb_api = user_secrets.get_secret("WANDB_API_KEY") 

In [14]:
class PPOPolicy(nn.Module):
    def __init__(self, env):
        super().__init__()

        self.roi1_conv1 = nn.Conv2d(4, 32, 8, stride=4)
        self.roi1_conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.roi1_conv3 = nn.Conv2d(64, 64, 4, stride=2)

        self.roi2_conv1 = nn.Conv2d(4, 16, 4, stride=2)
        self.roi2_conv2 = nn.Conv2d(16, 32, 4, stride=2)

        self.linear1 = nn.Linear(1792, 512)
        self.actor = nn.Linear(512, 18)  # env.action_space
        self.critic = nn.Linear(512, 1)  # Value output

    def forward(self, x):
        x = x / 255
        x1, x2 = self.ROI(x)

        # process ROI1
        x1 = F.relu(self.roi1_conv1(x1))
        x1 = F.relu(self.roi1_conv2(x1))
        x1 = F.relu(self.roi1_conv3(x1))
        x1 = x1.view(x1.shape[0], -1)

        # process ROI2
        x2 = F.relu(self.roi2_conv1(x2))
        x2 = F.relu(self.roi2_conv2(x2))
        x2 = x2.view(x2.shape[0], -1)

        x3 = torch.cat((x1, x2), dim=1)
        x3 = F.relu(self.linear1(x3))
        return x3

    def get_action(self, x):
        x3 = self.forward(x)
        logits = self.actor(x3)
        probs = F.softmax(logits, dim=1)
        return probs

    def get_value(self, x):
        x3 = self.forward(x)
        value = self.critic(x3)
        return value

    def ROI(self, img):
        """
        :return: roi1 - contains horizontal zone with tanks
                 roi2 - contains radar
        """
        height, width = img.shape[2], img.shape[3]

        roi1 = img[:, :, int(height * 0.4) : int(height * 0.75), :]
        roi2 = img[
            :,
            :,
            int(height * 0.02) : int(height * 0.17),
            int(width * 0.465) : int(width * 0.6),
        ]

        return roi1, roi2

In [15]:
def make_env(env_id, seed):
    env = gym.make(env_id)
    env = GrayScaleObservation(env)
    env = FrameStack(env, 4)
    env.action_space.seed(seed)
    return env


def train(
    env,
    policy,
    optimizer,
    num_episodes=1000,
    gamma=0.99,
    clip_epsilon=0.2,
    ppo_epochs=4,
    wandb_name=wandb.util.generate_id()
):  
    wandb.init(
        project="ppo-training-battlezone-rl", 
        config={
            "num_episodes": num_episodes,
            "gamma": gamma,
            "clip_epsilon": clip_epsilon,
            "ppo_epochs": ppo_epochs,
            "learning_rate": optimizer.param_groups[0]['lr'],
        },
        name=wandb_name
    )
    policy.train()
    for episode in range(num_episodes):
        obs, _ = env.reset()
        log_probs = []
        values = []
        rewards = []
        dones = []
        states = []

        done = False
        total_reward = 0
        while not done:
            obs_tensor = torch.tensor(np.asarray(obs)).unsqueeze(0).float()
            states.append(obs_tensor)

            # Get action and value
            probs = policy.get_action(obs_tensor)
            value = policy.get_value(obs_tensor)
            cat = Categorical(probs)
            action = cat.sample()

            # Step in the environment
            obs, reward, done, _, _ = env.step(action.item())
            total_reward += reward

            # Store log probability, value, and reward
            log_prob = cat.log_prob(action)
            log_probs.append(log_prob)
            values.append(value)
            rewards.append(reward)
            dones.append(done)

        # Compute returns and advantages
        returns = []
        advantages = []
        G = 0
        for reward in reversed(rewards):
            G = reward + gamma * G
            returns.insert(0, G)
        returns = torch.tensor(returns, dtype=torch.float32)
        values = torch.cat(values).squeeze()
        advantages = returns - values.detach()

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-9)

        # PPO Update
        loss = None  # Initialize loss to prevent unbound variable error
        log_probs = torch.cat(
            log_probs
        )  # Concatenate log_probs outside of loop to avoid in-place modifications
        values = values.squeeze()  # Make sure values are properly squeezed

        for _ in range(ppo_epochs):
            new_log_probs = []
            new_values = []
            for state in states:
                probs = policy.get_action(state)
                value = policy.get_value(state)
                cat = Categorical(probs)
                action = cat.sample()
                new_log_prob = cat.log_prob(action)
                new_log_probs.append(new_log_prob)
                new_values.append(value)

            new_log_probs = torch.cat(new_log_probs)
            new_values = torch.cat(new_values).squeeze()
            ratios = torch.exp(
                new_log_probs - log_probs.detach()
            )  # Detach to avoid in-place modifications
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = F.mse_loss(new_values, returns)
            loss = actor_loss + 0.5 * critic_loss

            # Update policy
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Log episode results to wandb
        if loss is not None:
            wandb.log({
                "episode": episode + 1,
                "loss": loss.item(),
                "total_reward": total_reward
            })
            print(f"Episode {episode + 1}/{num_episodes}, Loss: {loss.item()}, Total Reward: {total_reward}")

In [16]:
# Initialize wandb
wandb.login(key=wandb_api)

[34m[1mwandb[0m: Currently logged in as: [33melskow[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
env = make_env("ALE/BattleZone-v5", 0)
policy = PPOPolicy(env)
optimizer = Adam(policy.parameters(), lr=1e-4)

# Train the policy
train(env, policy, optimizer)

# Save the trained model
torch.save(policy.state_dict(), "trained_ppo_policy.pth")

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113773366666161, max=1.0…

Episode 1/1000, Loss: 62277.3515625, Total Reward: 3000.0
Episode 2/1000, Loss: 93489.296875, Total Reward: 3000.0
Episode 3/1000, Loss: 22033.025390625, Total Reward: 1000.0
Episode 4/1000, Loss: 22119.015625, Total Reward: 1000.0
Episode 5/1000, Loss: 4.024988174438477, Total Reward: 0.0
Episode 6/1000, Loss: 55883.03515625, Total Reward: 3000.0


In [None]:
wandb.save("/kaggle/working/trained_ppo_policy.pth")

In [None]:
wandb.finish()