In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from collections import deque
import random
from snakeenv import *

pygame 2.1.0 (SDL 2.0.16, Python 3.9.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
class PolicyNetwork(nn.Module):
    def __init__(self, input_size, output_size):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [3]:
def train(env, policy_net, num_episodes=1000, max_steps=200, gamma=0.99, lr=0.001, eps=1e-8):
    optimizer = optim.Adam(policy_net.parameters(), lr=lr)
    replay_buffer = deque(maxlen=10000)

    for episode in range(num_episodes):
        state = env.reset()
        episode_reward = 0

        for step in range(max_steps):
            state_tensor = torch.tensor(state, dtype=torch.float32)
            logits = policy_net(state_tensor)

            # Replace nan and infinity values in logits
            logits = torch.where(torch.isnan(
                logits), torch.zeros_like(logits), logits)
            logits = torch.where(torch.isinf(
                logits), torch.zeros_like(logits), logits)

            log_probs = torch.log_softmax(logits, dim=0)
            action_dist = Categorical(logits=log_probs)
            action = action_dist.sample()

            next_state, reward, done, _ = env.step(action.item())
            replay_buffer.append((state, action, reward, next_state, done))
            state = next_state
            episode_reward += reward

            if done:
                break

        if episode % 100 == 0:
            print(f"Episode: {episode}, Reward: {episode_reward}")

        optimize_model(policy_net, optimizer, replay_buffer, gamma)

    env.close()
    return policy_net


def optimize_model(policy_net, optimizer, replay_buffer, gamma):
    if len(replay_buffer) < 1000:
        return

    transitions = random.sample(replay_buffer, 1000)
    batch = list(zip(*transitions))  # Convert the iterator to a list

    state_batch = torch.tensor(batch[0], dtype=torch.float32)
    action_batch = torch.tensor(batch[1], dtype=torch.int64)
    reward_batch = torch.tensor(batch[2], dtype=torch.float32)
    next_state_batch = torch.tensor(batch[3], dtype=torch.float32)
    done_batch = torch.tensor(
        [float(done) for done in batch[4]], dtype=torch.float32)  # Convert bool to float

    logits = policy_net(state_batch)
    action_probs = torch.softmax(logits, dim=1)
    log_probs = torch.log(action_probs.gather(1, action_batch.unsqueeze(1)))

    next_state_logits = policy_net(next_state_batch)
    next_state_values = next_state_logits.max(1).values
    expected_values = (reward_batch + gamma *
                       next_state_values * (1 - done_batch)).detach()

    loss = -(log_probs * (expected_values - logits.gather(1,
             action_batch.unsqueeze(1)).squeeze())).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [11]:
import time


def evaluate(env, policy_net, num_episodes=10, render=True):
    total_rewards = []
    env.render_mode = 'human' if render else None

    for episode in range(num_episodes):
        state = env.reset()
        episode_reward = 0
        done = False

        while not done:
            state_tensor = torch.tensor(state, dtype=torch.float32)
            logits = policy_net(state_tensor)

            # Replace nan values in logits with zeros
            logits = torch.where(torch.isnan(
                logits), torch.zeros_like(logits), logits)

            # Clip logits to a valid range
            logits = torch.clamp(logits, -10, 10)

            log_probs = torch.log_softmax(logits, dim=0)
            action_dist = Categorical(logits=log_probs)
            action = action_dist.sample()

            next_state, reward, done, _ = env.step(action.item())
            state = next_state
            episode_reward += reward

            if render:
                env.render()
                time.sleep(0.05)

            if done:
                total_rewards.append(episode_reward)
                break

    avg_reward = sum(total_rewards) / num_episodes
    print(f"Average reward over {num_episodes} episodes: {avg_reward}")

    env.close()
    return avg_reward

In [14]:
env = SnakeEnv()
policy_net = PolicyNetwork(env.observation_space.shape[0], env.action_space.n)
trained_policy_net = train(env, policy_net, num_episodes=50000)
torch.save(trained_policy_net.state_dict(), 'snake_policy.pth')

Episode: 0, Reward: -111
Episode: 100, Reward: -107
Episode: 200, Reward: -107
Episode: 300, Reward: -113
Episode: 400, Reward: -103
Episode: 500, Reward: -111
Episode: 600, Reward: -105
Episode: 700, Reward: -105
Episode: 800, Reward: -107
Episode: 900, Reward: -103
Episode: 1000, Reward: -109
Episode: 1100, Reward: -107
Episode: 1200, Reward: -103
Episode: 1300, Reward: -105
Episode: 1400, Reward: -103
Episode: 1500, Reward: -111
Episode: 1600, Reward: -105
Episode: 1700, Reward: -103
Episode: 1800, Reward: -103
Episode: 1900, Reward: -103
Episode: 2000, Reward: -105
Episode: 2100, Reward: -103
Episode: 2200, Reward: -103
Episode: 2300, Reward: -109
Episode: 2400, Reward: -105
Episode: 2500, Reward: -103
Episode: 2600, Reward: -115
Episode: 2700, Reward: -105
Episode: 2800, Reward: -113
Episode: 2900, Reward: -103
Episode: 3000, Reward: -107
Episode: 3100, Reward: -92
Episode: 3200, Reward: -103
Episode: 3300, Reward: -103
Episode: 3400, Reward: -111
Episode: 3500, Reward: -103
Episo

KeyboardInterrupt: 

In [13]:
avg_reward = evaluate(env, trained_policy_net, render=True)

Average reward over 10 episodes: -102.5
