# Snake 10x10

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
import gymnasium as gym
import gym_snakegame
from scipy.signal import step
import pygame
import sys

# seed
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Pytorch device:", device)

Pytorch device: cuda


## Test the gymnasium env

In [6]:
PLAY_MANUAL = True  # set to False if u wanna see random moves
env = gym.make(
    "gym_snakegame/SnakeGame-v0",
    board_size=5,
    n_channel=1,
    n_target=1,
    render_mode="human"
)

obs, info = env.reset()

KEY_MAP = {
    pygame.K_UP: 2,
    pygame.K_RIGHT: 1,
    pygame.K_DOWN: 0,
    pygame.K_LEFT: 3
}

total_reward = 0
steps = 0

print("hit arrows to play! or close window to quit")
if not PLAY_MANUAL:
    print("auto mode: doing 250 random steps...")

try:
    while True:
        action = None

        if PLAY_MANUAL:
            # wait for key press
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    env.close()
                if event.type == pygame.KEYDOWN:
                    if event.key in KEY_MAP:
                        action = KEY_MAP[event.key]
            if action is None:
                continue  # wait till key pressed
        else:
            if steps >= 250:
                break
            action = env.action_space.sample()
            steps += 1

        # step
        obs, reward, terminated, truncated, info = env.step(action)
        env.render()
        total_reward += reward

        print(f"step {steps}/100, score: {total_reward} ({reward} | trunc:{truncated} | trmt:{terminated}) | {info} | {obs}")

        if terminated or truncated:
            print(f"GameOver!  total score: {total_reward}")
            # reset everything
            obs, info = env.reset()
            total_reward = 0
            steps = 0

except KeyboardInterrupt:
    pass
finally:
    env.close()

hit arrows to play! or close window to quit
step 0/100, score: 0 (0 | trunc:False | trmt:False) | {'snake_length': 3, 'prev_action': 1} | [[[ 0  0  0  0  0]
  [26  0  1  0  0]
  [ 0  3  2  0  0]
  [ 0  0  0  0  0]
  [ 0 26  0  0  0]]]
step 0/100, score: 0 (0 | trunc:False | trmt:False) | {'snake_length': 3, 'prev_action': 2} | [[[ 0  0  0  0  0]
  [26  1  2  0  0]
  [ 0  0  3  0  0]
  [ 0  0  0  0  0]
  [ 0 26  0  0  0]]]
step 0/100, score: 1 (1 | trunc:False | trmt:False) | {'snake_length': 4, 'prev_action': 3} | [[[ 0 26  0  0  0]
  [ 1  2  3  0  0]
  [ 0  0  4  0  0]
  [ 0  0  0  0  0]
  [ 0 26  0  0  0]]]
step 0/100, score: 1 (0 | trunc:False | trmt:False) | {'snake_length': 4, 'prev_action': 3} | [[[ 0 26  0  0  0]
  [ 2  3  4  0  0]
  [ 1  0  0  0  0]
  [ 0  0  0  0  0]
  [ 0 26  0  0  0]]]
step 0/100, score: 1 (0 | trunc:False | trmt:False) | {'snake_length': 4, 'prev_action': 0} | [[[ 0 26  0  0  0]
  [ 3  4  0  0  0]
  [ 2  0  0  0  0]
  [ 1  0  0  0  0]
  [ 0 26  0  0  0]]]
s

error: video system not initialized

# Approach 1

In [8]:
class SnakeV0(nn.Module):
    def __init__(self, state_size, action_size, hidden_size=64):
        super(SnakeV0, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size)
        )

    def forward(self, x):
        return self.net(x)


In [9]:
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

In [10]:
class DQNAgent:
    def __init__(self, modelclass, state_size, action_size, lr=1e-3, gamma=0.99,
                 buffer_size=10000, batch_size=64, target_update=100):
        self.state_size = state_size
        self.action_size = action_size
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update = target_update

        self.q_net = modelclass(state_size, action_size).to(device)
        self.target_net = modelclass(state_size, action_size).to(device)
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)

        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()

        self.memory = ReplayBuffer(buffer_size)
        self.step_count = 0

    def act(self, state, epsilon=0.0):
        if random.random() < epsilon:
            return random.randrange(self.action_size)
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            q_values = self.q_net(state)
        return q_values.argmax().item()

    def remember(self, state, action, reward, next_state, done):
        self.memory.push(state, action, reward, next_state, done)

    def replay(self):
        if len(self.memory) < self.batch_size:
            return

        states, actions, rewards, next_states, dones = self.memory.sample(self.batch_size)
        states = torch.FloatTensor(states).to(device)
        actions = torch.LongTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        next_states = torch.FloatTensor(next_states).to(device)
        dones = torch.BoolTensor(dones).to(device)

        current_q_values = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
            target_q_values = rewards + (self.gamma * next_q_values * (~dones))

        loss = nn.MSELoss()(current_q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.step_count += 1
        if self.step_count % self.target_update == 0:
            self.target_net.load_state_dict(self.q_net.state_dict())

## utils

In [11]:
def compute_reward(
    raw_reward,
    terminated,
    snake_length,
    board_size,
    steps_since_last_food,
    max_steps_without_food
):
    """
    Unified reward function with shaped incentives.

    Args:
        raw_reward (float): Reward from environment (typically +1 for food).
        terminated (bool): True if snake died.
        snake_length (int): Current length of the snake.
        board_size (int): Side length of square board.
        steps_since_last_food (int): Number of steps since last food eaten.
        max_steps_without_food (int): Threshold for looping penalty (set to board_size**2).

    Returns:
        float: Shaped reward for the current step.
    """
    # Full board completion
    if snake_length == board_size * board_size:
        return 100.0

    # Death penalty
    if terminated:
        return -10.0

    reward = 0.0

    # Food reward
    if raw_reward > 0:
        reward += 1.0

    # Small step penalty to encourage efficiency
    reward -= 0.01

    # Looping penalty: no food in too long
    if steps_since_last_food >= max_steps_without_food:
        reward -= 0.1

    return reward


def train_dqn(
    modelclass,
    board_size=10,
    env_name="gym_snakegame/SnakeGame-v0",
    episodes=20000,
    max_steps=500
):
    env = gym.make(
        env_name,
        board_size=board_size,
        n_channel=1,
        n_target=1,
        render_mode=None
    )
    state_size = np.prod(env.observation_space.shape)
    action_size = env.action_space.n
    agent = DQNAgent(modelclass=modelclass, state_size=state_size, action_size=action_size)

    scores = deque(maxlen=100)
    epsilon = 1.0
    epsilon_decay = 0.995
    epsilon_end = 0.01
    max_steps_without_food = board_size * board_size

    print("Starting training...")
    for episode in range(episodes):
        obs, info = env.reset()
        state = obs.flatten().astype(np.float32)
        total_reward = 0.0
        steps_since_last_food = 0
        initial_length = info.get('snake_length', 1)

        for t in range(max_steps):
            action = agent.act(state, epsilon)
            next_obs, raw_reward, terminated, truncated, next_info = env.step(action)
            done = terminated or truncated

            current_length = next_info.get('snake_length', initial_length)
            ate_food = (raw_reward > 0) or (current_length > initial_length + steps_since_last_food)

            if ate_food:
                steps_since_last_food = 0
            else:
                steps_since_last_food += 1

            reward = compute_reward(
                raw_reward=raw_reward,
                terminated=terminated,
                snake_length=current_length,
                board_size=board_size,
                steps_since_last_food=steps_since_last_food,
                max_steps_without_food=max_steps_without_food
            )

            next_state = next_obs.flatten().astype(np.float32)
            agent.remember(state, action, reward, next_state, done)
            agent.replay()

            state = next_state
            total_reward += reward
            initial_length = current_length

            if done:
                break

        scores.append(total_reward)
        epsilon = max(epsilon_end, epsilon_decay * epsilon)

        if episode % 100 == 0:
            avg_score = np.mean(scores)
            print(f"Episode {episode}, Avg Reward (last 100): {avg_score:.3f}, Epsilon: {epsilon:.3f}")

    env.close()
    return agent


def evaluate_agent(agent, board_size=10, episodes=100, max_steps=500):
    env = gym.make(
        "gym_snakegame/SnakeGame-v0",
        board_size=board_size,
        n_channel=1,
        n_target=1,
        render_mode="human"
    )
    max_steps_without_food = board_size * board_size
    total_reward = 0.0

    for _ in range(episodes):
        obs, info = env.reset()
        state = obs.flatten().astype(np.float32)
        steps_since_last_food = 0
        initial_length = info.get('snake_length', 1)
        episode_reward = 0.0

        for _ in range(max_steps):
            action = agent.act(state, epsilon=0.0)
            next_obs, raw_reward, terminated, truncated, next_info = env.step(action)
            done = terminated or truncated

            current_length = next_info.get('snake_length', initial_length)
            ate_food = (raw_reward > 0) or (current_length > initial_length + steps_since_last_food)

            if ate_food:
                steps_since_last_food = 0
            else:
                steps_since_last_food += 1

            reward = compute_reward(
                raw_reward=raw_reward,
                terminated=terminated,
                snake_length=current_length,
                board_size=board_size,
                steps_since_last_food=steps_since_last_food,
                max_steps_without_food=max_steps_without_food
            )

            state = next_obs.flatten().astype(np.float32)
            episode_reward += reward
            initial_length = current_length

            if done:
                break

        total_reward += episode_reward

    avg_reward = total_reward / episodes
    print(f"Evaluation over {episodes} episodes: Average Reward = {avg_reward:.2f}")
    env.close()

## Training loop

In [12]:
trained_agent = train_dqn(modelclass=SnakeV0, board_size=10, episodes=20_000)

Starting training...
Episode 0, Avg Reward (last 100): -10.060, Epsilon: 0.995
Episode 100, Avg Reward (last 100): -9.977, Epsilon: 0.603
Episode 200, Avg Reward (last 100): -10.030, Epsilon: 0.365
Episode 300, Avg Reward (last 100): -9.994, Epsilon: 0.221
Episode 400, Avg Reward (last 100): -10.035, Epsilon: 0.134
Episode 500, Avg Reward (last 100): -10.011, Epsilon: 0.081
Episode 600, Avg Reward (last 100): -10.030, Epsilon: 0.049
Episode 700, Avg Reward (last 100): -10.023, Epsilon: 0.030
Episode 800, Avg Reward (last 100): -10.009, Epsilon: 0.018
Episode 900, Avg Reward (last 100): -9.991, Epsilon: 0.011
Episode 1000, Avg Reward (last 100): -9.995, Epsilon: 0.010
Episode 1100, Avg Reward (last 100): -9.996, Epsilon: 0.010
Episode 1200, Avg Reward (last 100): -9.883, Epsilon: 0.010
Episode 1300, Avg Reward (last 100): -9.997, Epsilon: 0.010
Episode 1400, Avg Reward (last 100): -9.956, Epsilon: 0.010
Episode 1500, Avg Reward (last 100): -10.005, Epsilon: 0.010
Episode 1600, Avg Rewar

In [41]:
# evaluate_agent(trained_agent, board_size=10, episodes=10)
eval_env = gym.make("gym_snakegame/SnakeGame-v0",
                    board_size=10,
                    n_channel=1,
                    n_target=1,
                    render_mode="human")
evaluate_agent(agent, eval_env, episodes=200)
eval_env.close()


Success rate over 200 episodes: -200/200 (-100.0%)
