In [22]:
# import pygame
import random
from enum import Enum
from collections import namedtuple
import numpy as np

# pygame.init()
# font = pygame.font.Font('arial.ttf', 25)

class Direction(Enum):
    RIGHT = 1
    LEFT = 2
    UP = 3
    DOWN = 4

Point = namedtuple('Point', 'x, y')

# game graphic
HEAD = "@"
WALL = "#"
BODY = "O"
FOOD = "*"

class SnakeGame:
    def __init__(self, width, height):
        self.map = np.zeros((width, height))
        self.width = width
        self.height = height
        # init display
        # self.display = pygame.display.set_mode((self.w, self.h))
        # pygame.display.set_caption('Snake')
        # self.clock = pygame.time.Clock()
        self.reset()

    def reset(self):
        self.map[:, :] = 0
        self.direction = Direction.RIGHT
        self.head = Point(self.width // 2, self.height // 2)
        self.snake = [
            self.head,
            Point(self.head.x - 1, self.head.y),
            Point(self.head.x - 2, self.head.y)
        ]
        self.map[self.head.x, self.head.y] = 1
        for body in self.snake[1:]:
            self.map[body.x, body.y] = 2
        self.score = 0
        self.food = None
        self._place_food()
        self.frame_iteration = 0

    def _place_food(self):
        x = random.randint(0, self.width - 1) 
        y = random.randint(0, self.height - 1)
        self.food = Point(x, y)
        if self.food in self.snake:
            self._place_food()
        else:
            self.map[self.food.x, self.food.y] = 3

    def play_step(self, action):
        self.frame_iteration += 1

        # 2. move
        self._move(action) # update the head
        self.snake.insert(0, self.head)

        # 3. check if game over
        reward = 0
        game_over = False
        if self.is_collision() or self.frame_iteration > 100*len(self.snake):
            game_over = True
            reward = -10
            return reward, game_over, self.score
        
        self.map[self.head.x, self.head.y] = 1
        
        # 4. place new food or just move
        if self.head == self.food:
            self.score += 1
            reward = 10
            self._place_food()
        else:
            tail = self.snake.pop()
            self.map[tail.x, tail.y] = 0

        # 6. return game over and score
        return reward, game_over, self.score

    def is_collision(self, pt=None):
        if pt is None:
            pt = self.head
        # hits boundary
        if pt.x >= self.width or pt.x < 0 or pt.y >= self.height or pt.y < 0:
            return True
        # hits itself
        if pt in self.snake[1:]:
            return True

        return False
    
    def render(self):
        for h in range(self.height + 2):
            for w in range(self.width + 2):
                p = Point(w - 1, h - 1)
                if w == 0 or w == self.width + 1 or h == 0 or h == self.height + 1:
                    print(WALL, end = "\n" if w == self.width + 1 else "")
                elif self.map[p.x, p.y] == 1:
                    print(HEAD, end="")
                elif self.map[p.x, p.y] == 2:
                    print(BODY, end="")
                elif self.map[p.x, p.y] == 3:
                    print(FOOD, end="")
                else:
                    print(" ", end="")

    def _move(self, action):
        # [straight, right, left]

        clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
        idx = clock_wise.index(self.direction)

        if np.array_equal(action, [1, 0, 0]):
            new_dir = clock_wise[idx] # no change
        elif np.array_equal(action, [0, 1, 0]):
            next_idx = (idx + 1) % 4
            new_dir = clock_wise[next_idx] # right turn r -> d -> l -> u
        else: # [0, 0, 1]
            next_idx = (idx - 1) % 4
            new_dir = clock_wise[next_idx] # left turn r -> u -> l -> d

        self.direction = new_dir

        x = self.head.x
        y = self.head.y
        self.map[x, y] = 2
        if self.direction == Direction.RIGHT:
            x += 1
        elif self.direction == Direction.LEFT:
            x -= 1
        elif self.direction == Direction.DOWN:
            y += 1
        elif self.direction == Direction.UP:
            y -= 1

        self.head = Point(x, y)


In [23]:
import torch
from collections import deque
from model import Linear_QNet, QTrainer
from helper import plot

MAX_MEMORY = 100_000
BATCH_SIZE = 1000
LR = 0.001

class Agent:

    def __init__(self, game):
        self.n_games = 0
        self.epsilon = 0 # randomness
        self.gamma = 0.9 # discount rate
        self.memory = deque(maxlen=MAX_MEMORY) # popleft()
        self.model = Linear_QNet(game.width * game.height, 256, 3)
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)

    def get_state(self, game):
        return game.map.flatten()
    
    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done)) # popleft if MAX_MEMORY is reached

    def train_long_memory(self):
        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE) # list of tuples
        else:
            mini_sample = self.memory

        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)
        #for state, action, reward, nexrt_state, done in mini_sample:
        #    self.trainer.train_step(state, action, reward, next_state, done)

    def train_short_memory(self, state, action, reward, next_state, done):
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        # random moves: tradeoff exploration / exploitation
        self.epsilon = 80 - self.n_games
        final_move = [0,0,0]
        if random.randint(0, 200) < self.epsilon:
            move = random.randint(0, 2)
            final_move[move] = 1
        else:
            state0 = torch.tensor(state, dtype=torch.float)
            prediction = self.model(state0)
            move = torch.argmax(prediction).item()
            final_move[move] = 1

        return final_move

In [3]:
import time
from IPython.display import clear_output

def train():
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0
    game = SnakeGame(16, 9)
    agent = Agent(game)
    while True:
        # get old state
        state_old = agent.get_state(game)

        # get move
        final_move = agent.get_action(state_old)

        # perform move and get new state
        reward, done, score = game.play_step(final_move)

        # clear_output(wait=True)
        # game.render()
        # time.sleep(0.01)

        state_new = agent.get_state(game)

        # train short memory
        agent.train_short_memory(state_old, final_move, reward, state_new, done)

        # remember
        agent.remember(state_old, final_move, reward, state_new, done)

        if done:
            # train long memory, plot result
            game.reset()
            agent.n_games += 1
            agent.train_long_memory()

            if score > record:
                record = score
                agent.model.save()

            print('Game', agent.n_games, 'Score', score, 'Record:', record)

            plot_scores.append(score)
            total_score += score
            # mean_score = total_score / agent.n_games
            mean_score = average_of_last_n_items(plot_scores, 20)
            plot_mean_scores.append(mean_score)
            # plot(plot_scores, plot_mean_scores)
        

def average_of_last_n_items(lst, n):
    # 边界情况：当n为0或负数时，返回None
    if n <= 0:
        return None
    
    # 边界情况：当列表为空时，返回None
    if not lst:
        return None

    # 如果n大于列表的长度，使用整个列表
    n = min(n, len(lst))
    
    # 使用切片获取末尾n项，并计算平均值
    return sum(lst[-n:]) / n

In [26]:
import pygame
import time
WHITE = (255, 255, 255)
RED = (200,0,0)
BLUE1 = (0, 0, 255)
BLUE2 = (0, 100, 255)
BLACK = (0,0,0)

BLOCK_SIZE = 20

plot_scores = []
plot_mean_scores = []
total_score = 0
record = 0
game = SnakeGame(16, 9)
agent = Agent(game)

pygame.init()
font = pygame.font.Font('arial.ttf', 25)
display = pygame.display.set_mode((game.width * BLOCK_SIZE, game.height * BLOCK_SIZE))
pygame.display.set_caption('Snake')



def update_ui(game):
    display.fill(BLACK)

    for pt in game.snake:
        pygame.draw.rect(display, BLUE1, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE))
        pygame.draw.rect(display, BLUE2, pygame.Rect(pt.x+4, pt.y+4, 12, 12))

    pygame.draw.rect(display, RED, pygame.Rect(game.food.x, game.food.y, BLOCK_SIZE, BLOCK_SIZE))

    text = font.render("Score: " + str(game.score), True, WHITE)
    display.blit(text, [0, 0])
    pygame.display.flip()

def average_of_last_n_items(lst, n):
    # 边界情况：当n为0或负数时，返回None
    if n <= 0:
        return None
    
    # 边界情况：当列表为空时，返回None
    if not lst:
        return None

    # 如果n大于列表的长度，使用整个列表
    n = min(n, len(lst))
    
    # 使用切片获取末尾n项，并计算平均值
    return sum(lst[-n:]) / n

while True:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            pygame.quit()

    update_ui(game)
    time.sleep(1)
            
    # get old state
    state_old = agent.get_state(game)

    # get move
    final_move = agent.get_action(state_old)

    # perform move and get new state
    reward, done, score = game.play_step(final_move)

    # clear_output(wait=True)
    # game.render()
    # time.sleep(0.01)

    state_new = agent.get_state(game)

    # train short memory
    agent.train_short_memory(state_old, final_move, reward, state_new, done)

    # remember
    agent.remember(state_old, final_move, reward, state_new, done)

    if done:
        # train long memory, plot result
        game.reset()
        agent.n_games += 1
        agent.train_long_memory()

        if score > record:
            record = score
            agent.model.save()

        print('Game', agent.n_games, 'Score', score, 'Record:', record)

        plot_scores.append(score)
        total_score += score
        # mean_score = total_score / agent.n_games
        mean_score = average_of_last_n_items(plot_scores, 20)
        plot_mean_scores.append(mean_score)
        plot(plot_scores, plot_mean_scores)

error: display Surface quit