In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import deque

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

In [14]:
# Simple neural network for policy
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PolicyNetwork, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_dim, 20),
            nn.ReLU(),
            nn.Linear(20, 20),
            nn.ReLU(),
            nn.Linear(20, output_dim),
            nn.Softmax(dim=1)
        )
    
    def forward(self, x):
        return self.layer(x)

In [15]:
# Training settings
learning_rate = 0.01
gamma = 0.99  # Discount factor
episodes = 2_000

# Snake game
input_dim = 12
output_dim = 4

# initialize policy network
policy_network = PolicyNetwork(input_dim, output_dim)
optimizer = optim.Adam(policy_network.parameters(), lr=learning_rate)


In [16]:
# state = [danger + dir_one_hot + food_dir]
from game import SnakeGame

game = SnakeGame(render=False)

In [17]:
learning_rate = 0.001

In [18]:
# Training loop
for episode in range(episodes):
    state = game.reset()
    rewards = []
    log_probs = []
    
    # Collect experience
    done = False
    while not done:
        # Convert state to tensor
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        
        # Get action probabilities
        action_probs = policy_network(state_tensor)
        
        # Sample action from probability distribution
        action_distribution = torch.distributions.Categorical(action_probs)
        action = action_distribution.sample()
        
        # Take action in environment
        next_state, reward, done, _ = game.step(action.item())
        
        # Store reward and log probability
        rewards.append(reward)
        log_probs.append(action_distribution.log_prob(action))
        
        # Update state
        state = next_state
        
        if done:
            break
    
    # Calculate returns (discounted rewards)
    returns = []
    R = 0

    # Calculate the total return over all time steps. We do reversed because we value the most things closer to time t more (they get multiplied by gamma less)
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    returns = torch.FloatTensor(returns)
    
    # Normalize returns (optional but helps with training stability)
    returns = (returns - returns.mean()) / (returns.std() + 1e-9)
    
    # Calculate loss
    policy_loss = []
    for log_prob, R in zip(log_probs, returns):
        policy_loss.append(-log_prob * R)  # Negative because we're doing gradient ascent
    
    policy_loss = torch.cat(policy_loss).sum()
    
    # Update policy
    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()
    
    # Print episode results
    total_reward = sum(rewards)
    if episode % 10 == 0:
        print(f"Episode {episode}, Total Reward: {total_reward}")

print("Training completed!")

Episode 0, Total Reward: -10.19
Episode 10, Total Reward: -10.3
Episode 20, Total Reward: -10.73
Episode 30, Total Reward: -10.15
Episode 40, Total Reward: -10.38
Episode 50, Total Reward: -10.19
Episode 60, Total Reward: -11.14
Episode 70, Total Reward: -10.93
Episode 80, Total Reward: -11.010000000000002
Episode 90, Total Reward: -10.23
Episode 100, Total Reward: -10.940000000000001
Episode 110, Total Reward: -10.51
Episode 120, Total Reward: -10.780000000000001
Episode 130, Total Reward: -10.26
Episode 140, Total Reward: -11.15
Episode 150, Total Reward: -10.700000000000001
Episode 160, Total Reward: -10.33
Episode 170, Total Reward: -10.52
Episode 180, Total Reward: -12.0
Episode 190, Total Reward: -10.59
Episode 200, Total Reward: 9.29
Episode 210, Total Reward: -10.450000000000001
Episode 220, Total Reward: -10.73
Episode 230, Total Reward: -1.7999999999999992
Episode 240, Total Reward: -0.6400000000000002
Episode 250, Total Reward: -1.1499999999999997
Episode 260, Total Reward: 

In [16]:
# Save the trained model weights
torch.save(policy_network.state_dict(), 'snake_policy_weights.pth')
print("Model weights saved to 'snake_policy_weights.pth'")


Model weights saved to 'snake_policy_weights.pth'


In [7]:
# Test the trained policy
visual_game = SnakeGame(render=True)

def test_policy(policy, visual_game, episodes=1):
    for episode in range(episodes):
        state = visual_game.reset()
        done = False
        total_reward = 0
        
        while not done:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            action_probs = policy(state_tensor)
            action = torch.argmax(action_probs, dim=1).item()
            
            state, reward, done, _ = visual_game.step(action)
            total_reward += reward
            
        print(f"Test Episode {episode}, Total Reward: {total_reward}")

test_policy(policy_network, visual_game)

2025-03-11 16:22:09.599 python[96806:1655281] +[IMKClient subclass]: chose IMKClient_Modern
2025-03-11 16:22:09.599 python[96806:1655281] +[IMKInputSession subclass]: chose IMKInputSession_Modern


Test Episode 0, Total Reward: 254.2800000000014
