In [None]:
import numpy as np
import random
import tensorflow as tf
from collections import deque
import pygame

# Define SnakeGame environment (simplified)
class SnakeGame:
    """Simplified snake environment: a single head cell chasing one food cell.

    The board is ``width`` x ``height`` pixels and the head moves in steps of
    10 pixels. There is no snake body; an episode ends when the head leaves
    the board.

    Attributes:
        width, height: Board dimensions in pixels.
        snake_pos: ``[x, y]`` position of the snake's head.
        food_pos: ``[x, y]`` position of the food.
        game_over: ``True`` once the head has left the board.
    """

    def __init__(self):
        self.width, self.height = 400, 400
        # reset() defines snake_pos, food_pos and game_over, so the initial
        # state lives in exactly one place.
        self.reset()

    def step(self, action):
        """Apply ``action`` and return ``(reward, game_over)``.

        Actions: 0 = up, 1 = down, 2 = left, 3 = right. Any other value
        leaves the head where it is.

        Rewards: ``-1`` when the game is over, ``1`` when the food was eaten
        on this step, ``-0.1`` otherwise.
        """
        self._ate_food = False
        self._move_snake(action)
        self._check_collision()

        # The "ate food" event is recorded by _check_collision *before* the
        # food is respawned. Comparing snake_pos with food_pos here would be
        # too late: the food has already moved, so the reward would be lost.
        if self.game_over:
            reward = -1
        elif self._ate_food:
            reward = 1
        else:
            reward = -0.1

        return reward, self.game_over

    def reset(self):
        """Reset the game to its initial state."""
        self.snake_pos = [100, 50]   # Initial position of the snake's head
        self.food_pos = [200, 200]   # Initial position of the food
        self.game_over = False       # Flag indicating if the game is over
        self._ate_food = False       # Whether the last step ate the food

    def _move_snake(self, action):
        """Move the head 10 pixels in the direction selected by ``action``."""
        if action == 0:  # Up
            self.snake_pos[1] -= 10
        elif action == 1:  # Down
            self.snake_pos[1] += 10
        elif action == 2:  # Left
            self.snake_pos[0] -= 10
        elif action == 3:  # Right
            self.snake_pos[0] += 10

    def _check_collision(self):
        """Flag wall collisions and handle the head reaching the food."""
        x, y = self.snake_pos

        # Collision with walls: the head is outside the board.
        if not (0 <= x < self.width and 0 <= y < self.height):
            self.game_over = True

        # Collision with food: remember the event, then respawn the food.
        if self.snake_pos == self.food_pos:
            self._ate_food = True
            self._generate_new_food()

    def _generate_new_food(self):
        """Place the food on a random grid cell other than the head's cell."""
        while True:
            candidate = [
                random.randrange(0, self.width, 10),
                random.randrange(0, self.height, 10),
            ]
            # Food spawning under the head could never be "reached" by a
            # move, so draw again in that case.
            if candidate != self.snake_pos:
                self.food_pos = candidate
                return


# Define DQN agent
class DQNAgent:
    """Deep Q-learning agent with an epsilon-greedy policy and replay memory.

    Attributes:
        state_size: Number of features in a state vector.
        action_size: Number of discrete actions.
        memory: Bounded buffer of ``(state, action, reward, next_state, done)``
            transitions (the oldest entries are dropped beyond 2000).
        gamma: Discount factor applied to the estimated future value.
        epsilon: Current probability of choosing a random action.
        epsilon_decay: Factor ``epsilon`` is multiplied by after each replay.
        epsilon_min: ``epsilon`` stops decaying once it is at or below this.
        learning_rate: Learning rate of the Adam optimizer.
        model: Keras network mapping a state to one Q-value per action.
    """

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)
        self.gamma = 0.95
        self.epsilon = 1.0
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01
        self.learning_rate = 0.001
        self.model = self._build_model()

    def _build_model(self):
        """Build and compile the Q-network (two hidden layers of 24 units)."""
        model = tf.keras.Sequential([
            tf.keras.Input(shape=(self.state_size,)),
            tf.keras.layers.Dense(24, activation='relu'),
            tf.keras.layers.Dense(24, activation='relu'),
            tf.keras.layers.Dense(self.action_size, activation='linear')
        ])
        # `learning_rate` is the supported keyword; the old `lr` alias is
        # deprecated and rejected by newer Keras releases.
        model.compile(
            loss='mse',
            optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate),
        )
        return model

    def remember(self, state, action, reward, next_state, done):
        """Store one transition in the replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Choose an action for ``state`` using an epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the action with the highest predicted Q-value.
        ``state`` is passed straight to ``model.predict`` and only the first
        row of the prediction is used.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        q_values = self.model.predict(state, verbose=0)[0]
        # Plain int so both branches return the same type.
        return int(np.argmax(q_values))

    def replay(self, batch_size):
        """Train on ``batch_size`` random transitions, then decay epsilon.

        Does nothing (no training, no epsilon decay) while the memory holds
        fewer than ``batch_size`` transitions.
        """
        # random.sample raises ValueError when asked for more items than are
        # available, so wait until enough experience has been collected.
        if len(self.memory) < batch_size:
            return

        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                # Bellman target: reward plus discounted best next Q-value.
                next_q = self.model.predict(next_state, verbose=0)[0]
                target = reward + self.gamma * np.amax(next_q)
            # Only the taken action's Q-value is moved towards the target;
            # the other outputs keep their current predictions.
            target_f = self.model.predict(state, verbose=0)
            target_f[0][action] = target
            self.model.fit(state, target_f, epochs=1, verbose=0)

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

# Initialize Pygame
pygame.init()

# Set up game variables
width, height = 400, 400
screen = pygame.display.set_mode((width, height))
clock = pygame.time.Clock()

# Initialize SnakeGame environment
env = SnakeGame()

# Initialize DQNAgent
state_size = 4  # Example: snake x, snake y, food x, food y
action_size = 4  # Example: up, down, left, right
agent = DQNAgent(state_size, action_size)


def _observe(game, size):
    """Return the state as a (1, size) array: snake x, snake y, food x, food y."""
    features = [game.snake_pos[0], game.snake_pos[1],
                game.food_pos[0], game.food_pos[1]]
    return np.array(features).reshape(1, size)


# Main DQN training loop
batch_size = 32
num_episodes = 1000
running = True  # Cleared when the user closes the Pygame window

try:
    for episode in range(num_episodes):
        # Start every episode from the environment's initial state.
        env.reset()
        state = _observe(env, state_size)
        done = False
        total_reward = 0

        while not done:
            # Drain the event queue on every step: without this the window
            # stops responding and can never be closed.
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
            if not running:
                break

            action = agent.act(state)
            reward, done = env.step(action)
            total_reward += reward
            next_state = _observe(env, state_size)
            agent.remember(state, action, reward, next_state, done)
            state = next_state

            if len(agent.memory) > batch_size:
                agent.replay(batch_size)

        if not running:
            break

        print(f"Episode: {episode + 1}, Reward: {total_reward}")
finally:
    # Release Pygame resources even if training is interrupted or fails.
    pygame.quit()


pygame 2.5.0 (SDL 2.28.0, Python 3.11.3)
Hello from the pygame community. https://www.pygame.org/contribute.html




