In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import pygame
from pygame.math import Vector3
from OpenGL.GL import *
from OpenGL.GLU import *
from OpenGL.GLUT import *
import matplotlib.pyplot as plt

# Constants
GRID_SIZE = 10
DIMENSIONS = 6
N_ACTIONS = 2 * DIMENSIONS
N_AGENTS = 3
MEMORY_SIZE = 10000
BATCH_SIZE = 64
GAMMA = 0.99
EPSILON_START = 1.0
EPSILON_END = 0.1
EPSILON_DECAY = 0.995
TARGET_UPDATE = 100
LEARNING_RATE = 0.0005



pygame 2.6.0 (SDL 2.28.4, Python 3.11.9)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Define the DQN model
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)



In [3]:
# Define the Snake environment
class Snake6D:
    def __init__(self):
        self.reset()

    def reset(self):
        self.snakes = [deque([[random.randint(0, GRID_SIZE-1) for _ in range(DIMENSIONS)]]) for _ in range(N_AGENTS)]
        self.food = self.generate_food()
        self.scores = [0] * N_AGENTS
        return self.get_state()

    def generate_food(self):
        while True:
            food = [random.randint(0, GRID_SIZE-1) for _ in range(DIMENSIONS)]
            if not any(food in snake for snake in self.snakes):
                return food

    def get_state(self):
        state = []
        for snake in self.snakes:
            head = snake[0]
            state.extend(head)
        state.extend(self.food)
        return np.array(state, dtype=np.float32)

    def step(self, actions):
        rewards = [0] * N_AGENTS
        done = [False] * N_AGENTS  # Done condition per agent

        for i, action in enumerate(actions):
            direction = [0] * DIMENSIONS
            direction[action // 2] = 1 if action % 2 == 0 else -1
            
            new_head = [
                (self.snakes[i][0][j] + direction[j]) % GRID_SIZE
                for j in range(DIMENSIONS)
            ]

            # Check if new head is in the snake's body (self-collision)
            if new_head in self.snakes[i]:
                rewards[i] = -10  # Higher penalty for self-collision
                done[i] = True
            else:
                self.snakes[i].appendleft(new_head)
                if new_head == self.food:
                    self.scores[i] += 1
                    rewards[i] = 10  # Reward for eating food
                    self.food = self.generate_food()
                else:
                    self.snakes[i].pop()

            # Boundary check (handle collisions properly with wrapping)
            if any(coord == 0 or coord == GRID_SIZE-1 for coord in new_head):
                rewards[i] = -5  # Penalty for hitting boundaries (if applicable)

        # If any agent reaches a score of 50, end the game
        if max(self.scores) >= 50:
            done = [True] * N_AGENTS

        return self.get_state(), rewards, done



In [4]:
# Define the Agent
class Agent:
    def __init__(self, state_dim, action_dim):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net = DQN(state_dim, action_dim).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
        self.memory = deque(maxlen=MEMORY_SIZE)
        self.epsilon = EPSILON_START

    def select_action(self, state):
        if random.random() > self.epsilon:
            with torch.no_grad():
                return self.policy_net(torch.tensor(state, device=self.device)).max(0)[1].item()
        else:
            return random.randrange(N_ACTIONS)

    def update_epsilon(self):
        self.epsilon = max(EPSILON_END, self.epsilon * EPSILON_DECAY)

    def store_transition(self, state, action, reward, next_state):
        self.memory.append((state, action, reward, next_state))

    def update_model(self):
        if len(self.memory) < BATCH_SIZE:
            return

        batch = random.sample(self.memory, BATCH_SIZE)
        states, actions, rewards, next_states = zip(*batch)

        states = torch.tensor(states, device=self.device)
        actions = torch.tensor(actions, device=self.device).unsqueeze(1)
        rewards = torch.tensor(rewards, device=self.device)
        next_states = torch.tensor(next_states, device=self.device)

        current_q = self.policy_net(states).gather(1, actions)
        next_q = self.target_net(next_states).max(1)[0].detach()
        target_q = rewards + (GAMMA * next_q)

        loss = nn.functional.smooth_l1_loss(current_q, target_q.unsqueeze(1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())



In [5]:
# Visualization class
class Visualizer:
    def __init__(self):
        pygame.init()
        self.display = (800, 600)
        pygame.display.set_mode(self.display, pygame.DOUBLEBUF | pygame.OPENGL)
        gluPerspective(45, (self.display[0] / self.display[1]), 0.1, 50.0)
        glTranslatef(0.0, 0.0, -30)
        self.colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]

    def draw_cube(self, position, color):
        x, y, z = position[:3]
        glPushMatrix()
        glTranslatef(x - GRID_SIZE/2, y - GRID_SIZE/2, z - GRID_SIZE/2)
        glScale(0.5, 0.5, 0.5)
        glColor3fv(color)
        
        vertices = [
            (1, -1, -1), (1, 1, -1), (-1, 1, -1), (-1, -1, -1),
            (1, -1, 1), (1, 1, 1), (-1, -1, 1), (-1, 1, 1)
        ]
        edges = [
            (0,1), (0,3), (0,4), (2,1), (2,3), (2,7),
            (6,3), (6,4), (6,7), (5,1), (5,4), (5,7)
        ]
        
        glBegin(GL_LINES)
        for edge in edges:
            for vertex in edge:
                glVertex3fv(vertices[vertex])
        glEnd()
        
        glPopMatrix()

    def draw_enclosed_cube(self):
        glColor3f(1, 1, 1)  # White color for the enclosing cube
        glPushMatrix()
        glScale(GRID_SIZE/2, GRID_SIZE/2, GRID_SIZE/2)
        glutWireCube(2)
        glPopMatrix()

    def draw_scene(self, env):
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
        glRotatef(1, 3, 1, 1)

        self.draw_enclosed_cube()

        # Draw snakes
        for i, snake in enumerate(env.snakes):
            for segment in snake:
                self.draw_cube(segment, self.colors[i])

        # Draw food
        self.draw_cube(env.food, (1, 1, 1))  # White color for food

        pygame.display.flip()



In [6]:
# Training loop
def train(episodes):
    env = Snake6D()
    state_dim = DIMENSIONS * (N_AGENTS + 1)  # State includes all snake heads and food position
    agents = [Agent(state_dim, N_ACTIONS) for _ in range(N_AGENTS)]
    visualizer = Visualizer()

    scores_history = [[] for _ in range(N_AGENTS)]
    
    for episode in range(episodes):
        state = env.reset()
        total_rewards = [0] * N_AGENTS
        done = [False] * N_AGENTS

        while not all(done):
            actions = [agents[i].select_action(state) for i in range(N_AGENTS)]
            next_state, rewards, done = env.step(actions)

            for i in range(N_AGENTS):
                agents[i].store_transition(state, actions[i], rewards[i], next_state)
                agents[i].update_model()
                total_rewards[i] += rewards[i]

            state = next_state
            visualizer.draw_scene(env)
        
        for i in range(N_AGENTS):
            scores_history[i].append(total_rewards[i])
            agents[i].update_epsilon()

        if episode % TARGET_UPDATE == 0:
            for agent in agents:
                agent.update_target_network()

        print(f"Episode {episode}: {total_rewards}")

    return agents, scores_history



In [7]:
# Testing loop
def test_agents(agents, episodes=10):
    env = Snake6D()
    visualizer = Visualizer()
    test_scores_history = [[] for _ in range(N_AGENTS)]

    for episode in range(episodes):
        state = env.reset()
        total_rewards = [0] * N_AGENTS
        done = [False] * N_AGENTS

        while not all(done):
            # No exploration, only use the learned policy
            actions = [agents[i].select_action(state) for i in range(N_AGENTS)]
            next_state, rewards, done = env.step(actions)

            for i in range(N_AGENTS):
                total_rewards[i] += rewards[i]

            state = next_state
            visualizer.draw_scene(env)

        for i in range(N_AGENTS):
            test_scores_history[i].append(total_rewards[i])

        print(f"Test Episode {episode}: {total_rewards}")

    return test_scores_history



In [8]:
# Plot the scores
def plot_scores(scores_history, test_scores_history=None):
    plt.figure(figsize=(10, 6))
    for i in range(N_AGENTS):
        plt.plot(scores_history[i], label=f"Agent {i+1} (Training)")
    if test_scores_history:
        for i in range(N_AGENTS):
            plt.plot(test_scores_history[i], label=f"Agent {i+1} (Testing)", linestyle='--')
    plt.xlabel("Episode")
    plt.ylabel("Score")
    plt.title("6D Snake Game - Agent Scores Over Time")
    plt.legend()

    # Adjust plot limits to zoom in on the relevant score range
    plt.ylim([-20, 60])  # Adjust based on observed scores

    plt.savefig("training_results.png")
    plt.show()

# Train the agents
episodes = 1000
agents, scores_history = train(episodes)

# Test the agents after training
test_scores_history = test_agents(agents, episodes=10)

# Plot both training and testing scores
plot_scores(scores_history, test_scores_history)


2024-12-04 19:19:47.642 Python[74598:11137409] +[IMKClient subclass]: chose IMKClient_Modern
2024-12-04 19:19:47.642 Python[74598:11137409] +[IMKInputSession subclass]: chose IMKInputSession_Modern
  states = torch.tensor(states, device=self.device)
