# Policy search
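A policy-search reinforcement-learning agent (REINFORCE-style policy gradient with normalized discounted returns) that learns to play Atari Space Invaders from preprocessed game frames.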

In [14]:
# Import the environment
import ale_py
import gymnasium as gym
gym.register_envs(ale_py) # needed to run atari games

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical

In [None]:
def preprocess_frame(frame):
    """Turn a raw frame into a downsampled, binary, single-channel image.

    The frame is cropped to rows 34..193, every second row and column is
    kept, and only channel 0 is used. Pixels whose value is 0, 144 or 109
    map to 0.0; every other pixel maps to 1.0.

    The input array is never modified. (Slicing yields views, so assigning
    into the sliced frame would overwrite the caller's observation.)

    Args:
        frame: Array indexable as ``[row, column, channel]``.

    Returns:
        ``np.float32`` array of shape ``(1, rows, cols)``; a 210x160 frame
        yields ``(1, 80, 80)``.
    """
    # Crop + 2x subsample + first channel in a single (read-only) view.
    channel = np.asarray(frame)[34:194:2, ::2, 0]
    # NOTE(review): 144 and 109 look like background colour values carried
    # over from a Pong preprocessing recipe -- confirm they are the right
    # values to erase for Space Invaders.
    foreground = ~np.isin(channel, (0, 144, 109))
    return np.expand_dims(foreground.astype(np.float32), axis=0)

In [15]:
class PolicyNetwork(nn.Module):
    """Convolutional policy mapping a single-channel frame to action probabilities.

    Two convolution + ReLU stages feed a two-layer fully connected head.
    The head's first linear layer expects 2048 flattened features, which is
    what the convolution stack produces for a ``1 x 80 x 80`` input
    (80 -> 19 -> 8 per side, 32 channels: 32 * 8 * 8 = 2048).

    Args:
        action_space: Number of discrete actions (size of the output vector).
    """

    def __init__(self, action_space):
        super().__init__()

        feature_layers = [
            nn.Conv2d(1, 16, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Flatten(),
            nn.Linear(2048, 256),  # 2048 = 32 channels * 8 * 8 after the convs
            nn.ReLU(),
            nn.Linear(256, action_space),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return softmax probabilities over the last dimension for batch ``x``."""
        logits = self.fc(self.conv(x))
        return nn.functional.softmax(logits, dim=-1)

def compute_discounted_rewards(rewards, gamma):
    """Compute the discounted return at every time step of one episode.

    For each index ``t`` the result holds
    ``rewards[t] + gamma * rewards[t+1] + gamma**2 * rewards[t+2] + ...``,
    accumulated by a single backward pass over the sequence.

    Args:
        rewards: Sequence (list or 1-D array) of per-step rewards.
        gamma: Discount factor applied once per step.

    Returns:
        Floating-point NumPy array with the same length as ``rewards``.
        Floating inputs keep their dtype; any other dtype (e.g. integer
        rewards) is promoted to ``float64`` so the fractional part of the
        discounted sum is not truncated.
    """
    rewards = np.asarray(rewards)
    # np.zeros_like on integer rewards would give an integer buffer and
    # silently drop the fractional part of every discounted return.
    if np.issubdtype(rewards.dtype, np.floating):
        out_dtype = rewards.dtype
    else:
        out_dtype = np.float64
    discounted_rewards = np.zeros(rewards.shape, dtype=out_dtype)

    cumulative = 0.0
    for t in reversed(range(len(rewards))):
        cumulative = rewards[t] + gamma * cumulative
        discounted_rewards[t] = cumulative
    return discounted_rewards

# --- Setup: environment, device, policy and optimizer ----------------------
env = gym.make('ALE/SpaceInvaders-v5', render_mode=None)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
policy = PolicyNetwork(env.action_space.n).to(device)
optimizer = optim.Adam(policy.parameters(), lr=1e-3)
gamma = 0.99

num_episodes = 500
episode_rewards = []

# --- Training: one policy-gradient update per completed episode ------------
# try/finally guarantees the environment is closed even when the (long) run
# is interrupted from the keyboard.
try:
    for episode in range(num_episodes):
        state, _ = env.reset()
        state = preprocess_frame(state)
        state = torch.tensor(state, dtype=torch.float32, device=device)

        log_probs = []
        rewards = []

        done = False
        while not done:
            # Add the batch dimension expected by the network.
            action_probs = policy(state.unsqueeze(0))
            dist = Categorical(action_probs)
            action = dist.sample()

            next_state, reward, terminated, truncated, _ = env.step(action.item())
            # An episode ends on termination OR truncation; stepping past a
            # truncated episode without reset() is not supported.
            done = terminated or truncated

            next_state = preprocess_frame(next_state)
            next_state = torch.tensor(next_state, dtype=torch.float32, device=device)

            log_probs.append(dist.log_prob(action))
            # Store plain floats so the return computation stays in floating point.
            rewards.append(float(reward))

            state = next_state

        discounted_rewards = compute_discounted_rewards(rewards, gamma)
        discounted_rewards = torch.tensor(discounted_rewards, dtype=torch.float32, device=device)

        # Normalize returns to reduce variance. std() of a single element is
        # NaN, so one-step episodes keep their raw return.
        if discounted_rewards.numel() > 1:
            discounted_rewards = (
                discounted_rewards - discounted_rewards.mean()
            ) / (discounted_rewards.std() + 1e-8)

        # Each log-prob has shape (1,); concatenating gives shape (T,), which
        # lines up element-wise with the returns. Stacking would give (T, 1)
        # and broadcast the product to (T, T), pairing every log-prob with
        # every return.
        loss = -torch.sum(torch.cat(log_probs) * discounted_rewards)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        episode_rewards.append(sum(rewards))
        print(f"Episode {episode + 1}: Total Reward = {sum(rewards)}")
finally:
    env.close()


Using device: mps
Episode 1: Total Reward = 75.0
Episode 2: Total Reward = 210.0
Episode 3: Total Reward = 300.0
Episode 4: Total Reward = 90.0
Episode 5: Total Reward = 240.0
Episode 6: Total Reward = 75.0
Episode 7: Total Reward = 155.0
Episode 8: Total Reward = 355.0
Episode 9: Total Reward = 55.0


KeyboardInterrupt: 