In [1]:
import pygame
import random
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.optimizers import Adam
from collections import deque

# Initialize Pygame
pygame.init()

# Set up display
width, height = 640, 480
window = pygame.display.set_mode((width, height))
pygame.display.set_caption('Snake Game')

# Define colors
black = (0, 0, 0)
white = (255, 255, 255)
red = (255, 0, 0)
green = (0, 255, 0)

# Set up snake parameters
block_size = 20

# Set up clock
clock = pygame.time.Clock()

font = pygame.font.SysFont(None, 35)

# Define DQN Agent
class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)
        self.gamma = 0.95  # discount rate
        self.epsilon = 1.0  # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.model = self._build_model()

    def _build_model(self):
        model = Sequential()
        model.add(Flatten(input_shape=(self.state_size,)))
        model.add(Dense(24, activation='relu'))
        model.add(Dense(24, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss='mse', optimizer=Adam(learning_rate=0.001))
        return model

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        act_values = self.model.predict(state)
        return np.argmax(act_values[0])

    def replay(self, batch_size):
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                target = (reward + self.gamma * np.amax(self.model.predict(next_state)[0]))
            target_f = self.model.predict(state)
            target_f[0][action] = target
            self.model.fit(state, target_f, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

def message(msg, color, position):
    screen_text = font.render(msg, True, color)
    window.blit(screen_text, position)

def get_state(snake_list, foodx, foody, x1, y1):
    head_x, head_y = snake_list[-1] if snake_list else (x1, y1)
    state = [
        int(head_x == foodx),  # Food is to the left
        int(head_x < foodx),   # Food is to the right
        int(head_y == foody),  # Food is below
        int(head_y < foody),   # Food is above
        int(x1 == head_x),     # Moving left
        int(x1 > head_x),      # Moving right
        int(y1 == head_y),     # Moving up
        int(y1 > head_y)       # Moving down
    ]
    return np.array(state, dtype=int)

def game_loop():
    state_size = 8  # State size based on our state representation
    action_size = 4  # Four possible actions
    agent = DQNAgent(state_size, action_size)
    batch_size = 32
    num_episodes = 100

    high_score = 0

    for e in range(num_episodes):
        x1 = width / 2
        y1 = height / 2
        x1_change = 0
        y1_change = 0
        snake_List = []
        Length_of_snake = 1
        foodx = round(random.randrange(0, width - block_size) / block_size) * block_size
        foody = round(random.randrange(0, height - block_size) / block_size) * block_size
        snake_speed = 15
        state = get_state(snake_List, foodx, foody, x1, y1)
        state = np.reshape(state, [1, state_size])

        game_over = False
        game_close = False

        while not game_over:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    game_over = True
                    game_close = False

            while game_close:
                window.fill(black)
                message("You Lost! Press Q-Quit or C-Play Again", red, [width / 6, height / 3])
                pygame.display.update()

                for event in pygame.event.get():
                    if event.type == pygame.KEYDOWN:
                        if event.key == pygame.K_q:
                            game_over = True
                            game_close = False
                        if event.key == pygame.K_c:
                            game_loop()
                            return  # Important to exit the current loop

            action = agent.act(state)
            if action == 0:
                x1_change = -block_size
                y1_change = 0
            elif action == 1:
                x1_change = block_size
                y1_change = 0
            elif action == 2:
                y1_change = -block_size
                x1_change = 0
            elif action == 3:
                y1_change = block_size
                x1_change = 0

            if x1 >= width or x1 < 0 or y1 >= height or y1 < 0:
                game_close = True
            x1 += x1_change
            y1 += y1_change
            window.fill(black)
            pygame.draw.rect(window, green, [foodx, foody, block_size, block_size])
            snake_Head = [x1, y1]
            snake_List.append(snake_Head)
            if len(snake_List) > Length_of_snake:
                del snake_List[0]

            for x in snake_List[:-1]:
                if x == snake_Head:
                    game_close = True

            for x in snake_List:
                pygame.draw.rect(window, white, [x[0], x[1], block_size, block_size])

            if x1 == foodx and y1 == foody:
                foodx = round(random.randrange(0, width - block_size) / block_size) * block_size
                foody = round(random.randrange(0, height - block_size) / block_size) * block_size
                Length_of_snake += 1
                snake_speed += 5  # Increase the snake speed more significantly

            next_state = get_state(snake_List, foodx, foody, x1, y1)
            next_state = np.reshape(next_state, [1, state_size])
            reward = 1 if not game_close else -10
            agent.remember(state, action, reward, next_state, game_close)
            state = next_state

            # Display running score
            message(f"Score: {Length_of_snake - 1}", white, [0, 0])
            message(f"High Score: {high_score}", white, [width - 200, 0])

            if game_close:
                high_score = max(high_score, Length_of_snake - 1)
                print(f"Episode: {e+1}/{num_episodes}, Score: {Length_of_snake-1}, Epsilon: {agent.epsilon:.2}")
                break

            if len(agent.memory) > batch_size:
                agent.replay(batch_size)

            pygame.display.update()
            clock.tick(snake_speed)

    pygame.quit()
    quit()

# Run the game loop
game_loop()


pygame 2.6.0 (SDL 2.28.4, Python 3.12.4)
Hello from the pygame community. https://www.pygame.org/contribute.html


  super().__init__(**kwargs)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 50ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12