In [1]:
from snake import Game
import random
import pygame
import math
import os

# torch imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

pygame 2.5.2 (SDL 2.28.3, Python 3.11.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
class DQN(nn.Module):
    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(n_observations, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, n_actions)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [12]:
policy = DQN(28, 4)

In [13]:
policy = torch.load(os.path.join('policies', 'policy1.pth'))

In [14]:
policy.eval()

DQN(
  (fc1): Linear(in_features=28, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=4, bias=True)
)

In [15]:
def dist(x1, y1, x2, y2):
    return math.sqrt((x2 - x1)**2 + (y2 - y1)**2)

In [16]:
def observation(game):
    red = pygame.Color(255, 0, 0)
    orange = pygame.Color(255, 127, 0)
    yellow =pygame.Color(255, 255, 0)
    green = pygame.Color(0, 255, 0)
    blue = pygame.Color(0, 0, 255)
    indigo = pygame.Color(75, 0, 130)
    violet = pygame.Color(148, 0, 211)
    rainbow = [red, orange, yellow, green, blue, indigo, violet]

    # [x_snake, y_snake, up, down, left, right, 
    # distance to top wall, d to bot wall, d to left wall, d to r wall
    # closest seg up, c seg down, c seg left, c seg right
    # red fruit x, red fruit y, o fruit x, o fruit y, y fruit x, y fruit y
    # g fruit x, g fruit y, ...]

    obs_vector = torch.zeros(28)

    snake_body = game.snake_body

    # x and y snake head pos
    obs_vector[0] = game.snake_position[0]
    obs_vector[1] = game.snake_position[1]

    # up down left right
    if game.direction == 'UP':
        obs_vector[2] = 1
    if game.direction == 'DOWN':
        obs_vector[3] = 1
    if game.direction == 'LEFT':
        obs_vector[4] = 1
    if game.direction == 'RIGHT':
        obs_vector[5] = 1

    obs_vector[6] = game.snake_position[0] - 0
    obs_vector[7] = game.window_x - game.snake_position[0]
    obs_vector[8] = game.snake_position[1] - 0
    obs_vector[9] = game.window_y - game.snake_position[1]

    # closest seg up, c seg down, c seg left, c seg right
    gx, gy = game.snake_position

    up_candidates = [p for p in snake_body if p[0] == gx and p[1] > gy]
    down_candidates = [p for p in snake_body if p[0] == gx and p[1] < gy]
    left_candidates = [p for p in snake_body if p[1] == gy and p[0] < gx]
    right_candidates = [p for p in snake_body if p[1] == gy and p[0] > gx]

    # print(up_candidates, down_candidates, left_candidates, right_candidates)

    if len(up_candidates) > 0:
        closest_up = up_candidates[0]
        for pos in up_candidates:
            if dist(gx, gy, closest_up[0], closest_up[1]) < dist(gx, gy, pos[0], pos[1]):
                closest_up = pos
        obs_vector[10] = dist(gx, gy, closest_up[0], closest_up[1])

    if len(down_candidates) > 0:
        closest_down = down_candidates[0]
        for pos in down_candidates:
            if dist(gx, gy, closest_down[0], closest_down[1]) < dist(gx, gy, pos[0], pos[1]):
                closest_down = pos
        obs_vector[11] = dist(gx, gy, closest_down[0], closest_down[1])

    if len(left_candidates) > 0:
        closest_left = left_candidates[0]
        for pos in left_candidates:
            if dist(gx, gy, closest_left[0], closest_left[1]) < dist(gx, gy, pos[0], pos[1]):
                closest_left = pos
        obs_vector[12] = dist(gx, gy, closest_left[0], closest_left[1])

    if len(right_candidates) > 0:
        closest_right = right_candidates[0]
        for pos in right_candidates:
            if dist(gx, gy, closest_right[0], closest_right[1]) < dist(gx, gy, pos[0], pos[1]):
                closest_right = pos
        obs_vector[13] = dist(gx, gy, closest_right[0], closest_right[1])

    for pos, ind in zip(game.fruit_position, range(14, 28, 2)):
        obs_vector[ind] = pos[0]
        obs_vector[ind + 1] = pos[1]
    
    return obs_vector

In [25]:
x_window_size = 720
y_window_size = 480
x_axis_size = x_window_size // 10
y_axis_size = y_window_size //10
episodes_done = 0

game = Game(x_window_size, y_window_size, episodes_done)
obs = observation(game)

print(obs)

while True:

    # select action
    action = torch.argmax(policy(obs)).item()

    # game step
    terminated = game.step(action)

    if terminated:
        break

tensor([100., 240.,   0.,   0.,   0.,   1., 100., 620., 240., 240.,   0.,   0.,
         30.,   0., 490., 160., 700., 100., 350., 370.,  20., 350., 640.,  60.,
        390., 450.,  20., 170.])


In [22]:
[0] * 14

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]