In [None]:
import pygame
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
from collections import deque
import matplotlib.pyplot as plt

# --- Board geometry --------------------------------------------------------
# GRID_SIZE serves both as the pixel edge of one cell (drawing) and as the
# number of cells per side (bounds checks): 20 cells * 20 px = 400 px.
GRID_SIZE = 20
SCREEN_WIDTH = 400
SCREEN_HEIGHT = 450  # 400 px board plus a 50 px score bar on top

# --- Colours (RGB) ---------------------------------------------------------
BG_COLOR = (0, 128, 0)
SNAKE_COLOR = (52, 21, 57)
FOOD_COLOR = (255, 0, 0)
SCOREBOARD_COLOR = (0, 0, 0)
GRID_LINE_COLOR = (50, 50, 50)

# --- DQN hyper-parameters --------------------------------------------------
STATE_SIZE = 9      # 2 food-direction signs + 3 danger flags + 4 heading one-hot
HIDDEN_SIZE = 128
ACTION_SPACE = 3    # relative turns: left, straight, right
BATCH_SIZE = 128
MEMORY_SIZE = 100_000
LEARNING_RATE = 0.001

# --- Pygame display --------------------------------------------------------
# Created at cell execution time; the constants above must already exist.
pygame.init()
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption("Snake DQN")
clock = pygame.time.Clock()
font = pygame.font.SysFont('Arial', 24)



  from pkg_resources import resource_stream, resource_exists


pygame 2.6.1 (SDL 2.28.4, Python 3.13.0)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
class LinearModel(nn.Module):
    """Fully connected Q-network: two hidden ReLU layers and a linear output.

    Args:
        STATE_SIZE: number of input features.
        HIDDEN_SIZE: width of both hidden layers.
        ACTION_SPACE: number of outputs (one Q-value per action).
    """

    def __init__(self, STATE_SIZE, HIDDEN_SIZE, ACTION_SPACE):
        super().__init__()
        # Layers are created in input-to-output order so the Sequential
        # indices (and therefore the state_dict keys) are 0, 2 and 4.
        stack = [
            nn.Linear(STATE_SIZE, HIDDEN_SIZE),
            nn.ReLU(),
            nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
            nn.ReLU(),
            nn.Linear(HIDDEN_SIZE, ACTION_SPACE),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the network output for input tensor ``x``."""
        q_values = self.net(x)
        return q_values

    def save(self, file_name='model.pth'):
        """Write the model's state_dict to ``./model/<file_name>``.

        The ``./model`` directory is created when it does not exist yet.
        """
        target_dir = './model'
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        torch.save(self.state_dict(), os.path.join(target_dir, file_name))

class TrainerDQN:
    """One-step Q-learning trainer for a Q-network.

    Args:
        model: network mapping a batch of states to one Q-value per action.
        lr: learning rate for the Adam optimizer.
        discount: discount factor (gamma) applied to the bootstrapped value
            of the next state.
    """

    def __init__(self, model, lr, discount):
        self.lr = lr
        self.gamma = discount
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    @staticmethod
    def _to_tensor(value, dtype):
        """Convert a scalar, array, tensor or sequence of arrays to a tensor.

        Sequences of numpy arrays are stacked with numpy first, which avoids
        the slow element-wise path of ``torch.tensor`` on a tuple of arrays.
        """
        if torch.is_tensor(value):
            return value.detach().to(dtype)
        return torch.as_tensor(np.asarray(value), dtype=dtype)

    def train_step(self, state, action, reward, next_state, done):
        """Run one gradient step on a single transition or a batch.

        Args:
            state: one state vector, or a sequence of state vectors.
            action: action index, or a sequence of action indices.
            reward: scalar reward, or a sequence of rewards.
            next_state: successor state(s), same layout as ``state``.
            done: terminal flag, or a sequence of terminal flags.

        The regression target for the taken action is ``reward`` on terminal
        transitions and ``reward + gamma * max_a Q(next_state, a)`` otherwise;
        the other actions keep their current prediction as target, so they
        contribute no error.
        """
        state = self._to_tensor(state, torch.float)
        next_state = self._to_tensor(next_state, torch.float)
        action = self._to_tensor(action, torch.long)
        reward = self._to_tensor(reward, torch.float)
        done = self._to_tensor(done, torch.bool)

        if state.numel() == 0:
            # Nothing to learn from an empty batch.
            return

        if state.dim() == 1:
            # A single transition is promoted to a batch of one so the same
            # code path serves both short-term and replay (batch) training.
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)
            action = action.unsqueeze(0)
            reward = reward.unsqueeze(0)
            done = done.unsqueeze(0)

        pred = self.model(state)

        # The target must be a constant for the optimizer: without this,
        # gradients also flow through the bootstrapped next-state value and
        # through the cloned prediction, which distorts (and can completely
        # cancel) the update.
        with torch.no_grad():
            next_q_max = self.model(next_state).max(dim=1).values
            q_new = torch.where(done, reward, reward + self.gamma * next_q_max)
            target = pred.detach().clone()
            rows = torch.arange(target.shape[0])
            target[rows, action] = q_new

        # Standard PyTorch optimisation step.
        self.optimizer.zero_grad()
        loss = self.criterion(pred, target)
        loss.backward()
        self.optimizer.step()

In [3]:

class SnakeGameDQN:
    """Snake environment on a GRID_SIZE x GRID_SIZE board, rendered with pygame.

    Coordinates are grid cells with x growing to the right and y growing
    downwards. The snake is a list of ``[x, y]`` cells with the head first.
    Actions are relative to the current heading: 0 = turn left, 1 = keep
    going, 2 = turn right.
    """

    # Height in pixels of the score bar drawn above the board.
    SCOREBOARD_HEIGHT = 50

    def __init__(self):
        pygame.init()
        self.screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        pygame.display.set_caption("Q-Learning Snake")
        self.clock = pygame.time.Clock()
        self.font = pygame.font.SysFont('Arial', 20) # Using 'Arial' for consistency
        self._score_font = None  # built on first draw_game() call
        self.done = False
        self.reset_game()

    def draw_game(self, screen, snake, food, score):
        """Render the score bar, grid, snake and food, then refresh the display.

        Args:
            screen: pygame surface to draw on.
            snake: iterable of ``[x, y]`` cells to draw as the snake.
            food: ``[x, y]`` cell of the food.
            score: value shown in the score bar.
        """
        top = self.SCOREBOARD_HEIGHT
        screen.fill(BG_COLOR)

        # Scoreboard
        pygame.draw.rect(screen, SCOREBOARD_COLOR, (0, 0, SCREEN_WIDTH, top))
        pygame.draw.line(screen, (0, 0, 0), (0, top), (SCREEN_WIDTH, top), 2)

        # Score text. The font is created once and reused; looking up a
        # system font on every frame is needlessly expensive.
        score_font = getattr(self, '_score_font', None)
        if score_font is None:
            score_font = pygame.font.SysFont('freesanbold.ttf', 40)
            self._score_font = score_font
        text = score_font.render(f"Score: {score}", True, (255, 255, 255))
        text_rect = text.get_rect(center=(SCREEN_WIDTH//2, top // 2))
        screen.blit(text, text_rect)

        # Grid
        for x in range(0, SCREEN_WIDTH, GRID_SIZE):
            pygame.draw.line(screen, GRID_LINE_COLOR, (x, top), (x, SCREEN_HEIGHT))
        for y in range(top, SCREEN_HEIGHT, GRID_SIZE):
            pygame.draw.line(screen, GRID_LINE_COLOR, (0, y), (SCREEN_WIDTH, y))

        # Snake (uses the ``snake`` argument rather than silently ignoring it)
        for segment in snake:
            rect = pygame.Rect(segment[0]*GRID_SIZE, segment[1]*GRID_SIZE + top,
                               GRID_SIZE, GRID_SIZE)
            pygame.draw.rect(screen, SNAKE_COLOR, rect, 0)

        # Food
        food_rect = pygame.Rect(food[0]*GRID_SIZE, food[1]*GRID_SIZE + top,
                                GRID_SIZE, GRID_SIZE)
        pygame.draw.circle(screen, FOOD_COLOR, food_rect.center, 10)

        pygame.display.update()

    def _is_collision(self, cell):
        """Return True when ``cell`` is outside the board or on the snake.

        The comparison is done on tuples so it works whether the cell and the
        snake segments are lists or tuples (``(1, 2) == [1, 2]`` is False).
        """
        x, y = cell
        if x < 0 or x >= GRID_SIZE or y < 0 or y >= GRID_SIZE:
            return True
        target = (x, y)
        return any(tuple(segment) == target for segment in self.snake)

    def _random_free_cell(self):
        """Draw random cells until one is found that the snake does not occupy."""
        cell = [random.randint(0, GRID_SIZE-1), random.randint(0, GRID_SIZE-1)]
        while cell in self.snake:
            cell = [random.randint(0, GRID_SIZE-1), random.randint(0, GRID_SIZE-1)]
        return cell

    def get_state(self):
        """Build the 9-element float32 observation for the current position.

        Layout: ``[food_dx, food_dy, danger_left, danger_forward,
        danger_right, heading_right, heading_left, heading_down, heading_up]``
        where the food entries are signs (-1, 0, 1) of the offset from the
        head to the food and the danger flags say whether the cell reached by
        action 0 / 1 / 2 is a wall or part of the snake.
        """
        head = self.snake[0]

        # Food relative position
        food_dx = np.sign(self.food[0] - head[0])
        food_dy = np.sign(self.food[1] - head[1])

        # Danger detection. The probe directions come from the same mapping
        # play_step() uses, so "danger left" always refers to the cell that
        # action 0 would actually move into.
        dangers = []
        for action in (0, 1, 2):
            dx, dy = self.get_direction_from_action(action, self.direction)
            probe = (head[0] + dx, head[1] + dy)
            dangers.append(1 if self._is_collision(probe) else 0)

        # Current direction (one-hot encoded): right, left, down, up
        dir_one_hot = [0, 0, 0, 0]
        heading_index = {(1, 0): 0, (-1, 0): 1, (0, 1): 2}.get(tuple(self.direction), 3)
        dir_one_hot[heading_index] = 1

        return np.array([
            food_dx, food_dy,          # Food direction (2)
            *dangers,                  # Danger left, front, right (3)
            *dir_one_hot               # Current direction one-hot (4)
        ], dtype=np.float32)

    def get_direction_from_action(self, action, current_direction):
        """Translate a relative action into an absolute ``(dx, dy)`` heading.

        Args:
            action: 0 = left, 1 = forward, 2 = right (relative to heading).
            current_direction: current ``(dx, dy)`` heading.

        Returns:
            The new heading; ``current_direction`` unchanged when either the
            action or the heading is not recognised.
        """
        # Action: 0=left, 1=forward, 2=right
        if current_direction == (1, 0):
            if action == 0: return (0, -1)
            elif action == 1: return (1, 0)
            elif action == 2: return (0, 1)
        elif current_direction == (-1, 0):
            if action == 0: return (0, 1)
            elif action == 1: return (-1, 0)
            elif action == 2: return (0, -1)
        elif current_direction == (0, 1):
            if action == 0: return (1, 0)
            elif action == 1: return (0, 1)
            elif action == 2: return (-1, 0)
        elif current_direction == (0, -1):
            if action == 0: return (-1, 0)
            elif action == 1: return (0, -1)
            elif action == 2: return (1, 0)
        return current_direction

    def reset_game(self):
        """Start a new episode: 3-cell snake heading right, fresh food, zero score."""
        self.snake = [[10, 10], [9, 10], [8, 10]]
        self.food = self._random_free_cell()

        self.direction = (1, 0)  # Reset to initial direction
        self.score = 0           # Reset score
        self.steps = 0           # Reset steps
        self.done = False        # Reset done flag

    def play_step(self, action):
        """Advance the game by one move.

        Args:
            action: relative action (0 = left, 1 = forward, 2 = right).

        Returns:
            ``(reward, done, score)``. Reward is -100 for hitting a wall or
            the snake, 20 for eating, -50 when the episode is cut off after
            more than ``100 * len(snake)`` moves without eating, else 0.
        """
        # Get new direction based on action
        self.direction = self.get_direction_from_action(action, self.direction)

        # Move snake without wrapping around
        head = self.snake[0]
        new_head = [head[0] + self.direction[0], head[1] + self.direction[1]]

        # Check for collisions with walls or self
        if self._is_collision(new_head):
            reward = -100
            self.done = True
        else:
            self.snake.insert(0, new_head)

            if new_head == self.food:
                reward = 20
                self.score += 10
                self.steps = 0
                self.food = self._random_free_cell()
            else:
                reward = 0
                self.snake.pop()

            # Check if stuck (too many steps without eating)
            self.steps += 1
            if self.steps > 100 * len(self.snake):
                reward = -50
                self.done = True

        return reward, self.done, self.score
    



In [None]:
class QLearningAgent:
    """Epsilon-greedy DQN agent with a replay memory.

    Attributes set in ``__init__``: exploration schedule (``epsilon``,
    ``epsilon_min``, ``epsilon_decay``), ``learning_rate``,
    ``discount_factor``, the replay ``memory`` (bounded deque), the Q-network
    ``model``, its ``trainer`` and the ``n_games`` counter.
    """

    def __init__(self):
        self.epsilon = 0.9
        self.epsilon_min = 0.1
        self.epsilon_decay = 0.9999
        self.learning_rate = LEARNING_RATE
        self.discount_factor = 0.99
        self.memory = deque(maxlen=MEMORY_SIZE)
        self.model = LinearModel(STATE_SIZE, HIDDEN_SIZE, ACTION_SPACE)
        self.trainer = TrainerDQN(self.model, self.learning_rate, self.discount_factor)
        self.n_games = 0

    def get_state(self, game):
        """Return the observation produced by ``game.get_state()``."""
        return game.get_state()

    def get_action(self, state):
        """Pick an action index for ``state`` using an epsilon-greedy policy.

        With probability ``self.epsilon`` a uniformly random action is
        returned; otherwise the action with the highest predicted Q-value.
        """
        if random.random() < self.epsilon:
            return random.randint(0, ACTION_SPACE - 1)

        state_batch = torch.as_tensor(state, dtype=torch.float).unsqueeze(0)
        # Pure inference: no autograd graph is needed to pick an action.
        with torch.no_grad():
            q_values = self.model(state_batch)
        return int(torch.argmax(q_values).item())

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    def train_long_memory(self):
        """Train on a random mini-batch (or on everything stored, if smaller).

        Does nothing when the memory is empty; unpacking ``zip(*[])`` into
        five names would otherwise raise ``ValueError``.
        """
        if not self.memory:
            return

        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE) # list of tuples
        else:
            mini_sample = list(self.memory)

        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)

    def train_short_memory(self, state, action, reward, next_state, done):
        """Train immediately on a single transition."""
        self.trainer.train_step(state, action, reward, next_state, done)


def train_DQN(max_games=None, show_every=500):
    """Train a QLearningAgent on SnakeGameDQN, printing a line per game.

    Args:
        max_games: stop after this many games; ``None`` (default) trains
            until the pygame window is closed.
        show_every: render every ``show_every``-th game at up to 60 FPS;
            0 disables rendering.

    The model is saved through ``agent.model.save()`` whenever a game beats
    the best score so far. Closing the window quits pygame and returns.
    """
    scores = []             # per-game score (kept for later plotting)
    moving_avg_scores = []  # running mean of all scores so far
    total_score = 0
    record = 0
    agent = QLearningAgent()
    # The game object initialises pygame and owns the window and clock, so
    # they are reused here instead of being created a second time.
    game = SnakeGameDQN()
    screen = game.screen
    clock = game.clock

    while max_games is None or agent.n_games < max_games:
        game.reset_game()
        state_old = game.get_state()
        score = 0  # bound even if the episode loop body never runs
        render = bool(show_every) and agent.n_games % show_every == 0

        while not game.done:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    return

            action = agent.get_action(state_old)

            reward, done, score = game.play_step(action)
            state_new = game.get_state()

            agent.train_short_memory(state_old, action, reward, state_new, done)

            # Remember
            agent.remember(state_old, action, reward, state_new, done)

            state_old = state_new

            # Visualize (draw_game already refreshes the display)
            if render:
                game.draw_game(screen, game.snake, game.food, score)
                clock.tick(60)

        agent.train_long_memory()

        # Update game count and epsilon
        agent.n_games += 1
        agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)

        # Update scores
        scores.append(score)
        total_score += score
        mean_score = total_score / agent.n_games
        moving_avg_scores.append(mean_score)

        # Save model if new record
        if score > record:
            record = score
            agent.model.save()

        print('Game', agent.n_games, 'Score', score, 'Record:', record, 'Epsilon:', agent.epsilon)

    # Only reached when max_games was given: release the window.
    pygame.quit()

In [None]:


# Start training. Called without arguments, train_DQN() keeps playing games
# until the pygame window is closed.
train_DQN()
