All necessary imports

In [7]:
import numpy as np
import torch
import torch.nn as nn
import pygame
import sys
import random
import cv2
from collections import deque
import time
import matplotlib.pyplot as plt
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Writing down the game, specifying the entities that will take place and defining the reward system

In [4]:
player_moves = {
    'L': np.array([-1.,0.]),
    'R': np.array([1.,0.]),
    'U': np.array([0.,-1.]),
    'D': np.array([0.,1.])
}
initial_playersize = 4

class snakeclass(object):
    def __init__(self, gridsize):
        self.pos = np.array([gridsize//2, gridsize//2]).astype('float')
        self.dir = np.array([1.,0.])
        self.len = initial_playersize
        self.prevpos = [np.array([gridsize//2, gridsize//2]).astype('float')]
        self.gridsize = gridsize

    def move(self):
        self.pos += self.dir
        self.prevpos.append(self.pos.copy())
        self.prevpos = self.prevpos[-self.len-1:]

    def checkdead(self, pos):
        if pos[0] <= -1 or pos[0] >= self.gridsize:
            return True
        elif pos[1] <= -1 or pos[1] >= self.gridsize:
            return True
        elif list(pos) in [list(item) for item in self.prevpos[:-1]]:
            return True
        else:
            return False

    def getproximity(self):
        L = self.pos - np.array([1,0])
        R = self.pos + np.array([1,0])
        U = self.pos - np.array([0,1])
        D = self.pos + np.array([0,1])
        possdirections = [L, R, U, D]
        proximity = [int(self.checkdead(x)) for x in possdirections]
        return proximity

    def __len__(self):
        return self.len + 1

class appleclass(object):
    def __init__(self, gridsize):
        self.pos = np.random.randint(1,gridsize,2)
        self.score = 0
        self.gridsize = gridsize

    def eaten(self):
        self.pos = np.random.randint(1,self.gridsize,2)
        self.score += 1

class GameEnvironment(object):
    def __init__(self, gridsize, nothing, dead, apple):
        self.snake = snakeclass(gridsize)
        self.apple = appleclass(gridsize)
        self.game_over = False
        self.gridsize = gridsize
        self.reward_nothing = nothing
        self.reward_dead = dead
        self.reward_apple = apple
        self.time_since_apple = 0

    def resetgame(self):
        self.apple.pos = np.random.randint(1, self.gridsize, 2).astype('float')
        self.apple.score = 0
        self.snake.pos = np.random.randint(1, self.gridsize, 2).astype('float')
        self.snake.prevpos = [self.snake.pos.copy().astype('float')]
        self.snake.len = initial_playersize
        self.game_over = False

    def get_boardstate(self):
        return [self.snake.pos, self.snake.dir, self.snake.prevpos, self.apple.pos, self.apple.score, self.game_over]

    def update_boardstate(self, move):
        reward = self.reward_nothing
        Done = False

        if move == 0:
            if not (self.snake.dir == player_moves['R']).all():
                self.snake.dir = player_moves['L']
        if move == 1:
            if not (self.snake.dir == player_moves['L']).all():
                self.snake.dir = player_moves['R']
        if move == 2:
            if not (self.snake.dir == player_moves['D']).all():
                self.snake.dir = player_moves['U']
        if move == 3:
            if not (self.snake.dir == player_moves['U']).all():
                self.snake.dir = player_moves['D']

        self.snake.move()
        self.time_since_apple += 1

        if self.time_since_apple == 100:
            self.game_over = True
            reward = self.reward_dead
            self.time_since_apple = 0


        if self.snake.checkdead(self.snake.pos) == True:
            self.game_over = True
            reward = self.reward_dead
            self.time_since_apple = 0
            Done = True

        elif (self.snake.pos == self.apple.pos).all():
            self.apple.eaten()
            self.snake.len += 1
            self.time_since_apple = 0
            reward = self.reward_apple

        len_of_snake = len(self.snake)

        return reward, Done, len_of_snake

Defining the model: a feedforward network used as a Q-network.
The input tensor is composed of:

*   apple.pos: Position of the apple.
*   player.dir: Direction in which the snake is facing.

*   proximity: Proximity pixels around the snake.
*   player.pos: Position of the snake.







In [5]:
class QNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        l1 = self.relu(self.fc1(x.float()))
        l2 = self.relu(self.fc2(l1))
        l3 = self.relu(self.fc3(l2))
        l4 = self.fc4(l3)
        return l4

def get_network_input(player, apple):
    proximity = player.getproximity()
    x = torch.cat([torch.from_numpy(player.pos).double(), torch.from_numpy(apple.pos).double(),
                   torch.from_numpy(player.dir).double(), torch.tensor(proximity).double()])
    return x

Storing and managing transitions

In [6]:
class ReplayMemory(object):
    def __init__(self, max_size):
        self.max_size = max_size
        self.buffer = []

    def push(self, state, action, reward, next_state, done):
        experience = (state, action, reward, next_state, done)
        self.buffer.append(experience)

    def sample(self, batch_size):
        state_batch = []
        action_batch = []
        reward_batch = []
        next_state_batch = []
        done_batch = []

        batch = random.sample(self.buffer, batch_size)

        for experience in batch:
            state, action, reward, next_state, done = experience
            state_batch.append(state)
            action_batch.append(action)
            reward_batch.append(reward)
            next_state_batch.append(next_state)
            done_batch.append(done)

        return (state_batch, action_batch, reward_batch, next_state_batch, done_batch)

    def truncate(self):
        self.buffer = self.buffer[-self.max_size:]

    def __len__(self):
        return len(self.buffer)

Model training

In [None]:
save_path = '/content/drive/MyDrive/checkpoints/'

model = QNetwork(input_dim=10, hidden_dim=20, output_dim=5)
epsilon = 0.1
gridsize = 15
GAMMA = 0.9

board = GameEnvironment(gridsize, nothing=0, dead=-1, apple=1)
memory = ReplayMemory(1000)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-5)

def run_episode(num_games):
    run = True
    move=0
    games_played = 0
    total_reward = 0
    episode_games = 0
    len_array = []

    while run:
        state = get_network_input(board.snake, board.apple)
        action_0 = model(state)
        rand = np.random.uniform(0,1)
        if rand > epsilon:
            action = torch.argmax(action_0)
        else:
            action = np.random.randint(0,5)

        reward, done, len_of_snake = board.update_boardstate(action)
        next_state = get_network_input(board.snake, board.apple)

        memory.push(state, action, reward, next_state, done)

        total_reward += reward

        episode_games += 1

        if board.game_over == True:
            games_played += 1
            len_array.append(len_of_snake)
            board.resetgame()

            if num_games == games_played:
                run = False

    avg_len_of_snake = np.mean(len_array)
    max_len_of_snake = np.max(len_array)
    return total_reward, avg_len_of_snake, max_len_of_snake
MSE = nn.MSELoss()
def learn(num_updates, batch_size):

    total_loss = 0

    for i in range(num_updates):

        optimizer.zero_grad()
        sample = memory.sample(batch_size)

        states, actions, rewards, next_states, dones = sample
        states = torch.cat([x.unsqueeze(0) for x in states], dim=0)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.cat([x.unsqueeze(0) for x in next_states])
        dones = torch.FloatTensor(dones)

        q_local = model.forward(states)
        next_q_value = model.forward(next_states)

        Q_expected  = q_local.gather(1, actions.unsqueeze(0).transpose(0,1)).transpose(0,1).squeeze(0)

        Q_targets_next  = torch.max(next_q_value, 1)[0]*(torch.ones(dones.size()) - dones)

        Q_targets  = rewards + GAMMA * Q_targets_next

        loss = MSE(Q_expected, Q_targets)

        total_loss += loss
        loss.backward()
        optimizer.step()

    return total_loss
num_episodes = 60000
num_updates = 500
print_every = 10
games_in_episode = 30
batch_size = 20


def train():

    scores_deque = deque(maxlen=100)
    scores_array = []
    avg_scores_array = []

    avg_len_array = []
    avg_max_len_array = []

    time_start = time.time()


    for i_episode in range(num_episodes+1):

        score, avg_len, max_len = run_episode(games_in_episode)

        scores_deque.append(score)
        scores_array.append(score)
        avg_len_array.append(avg_len)
        avg_max_len_array.append(max_len)


        avg_score = np.mean(scores_deque)
        avg_scores_array.append(avg_score)

        total_loss = learn(num_updates, batch_size)

        dt = (int)(time.time() - time_start)

        if i_episode % print_every == 0 and i_episode > 0:
            print('Ep.: {:6}, Loss: {:.3f}, Avg.Score: {:.2f}, Avg.LenOfSnake: {:.2f}, Max.LenOfSnake:  {:.2f} Time: {:02}:{:02}:{:02} '.\
                  format(i_episode, total_loss, score, avg_len, max_len, dt//3600, dt%3600//60, dt%60))

        memory.truncate()

        if i_episode % 250 == 0 and i_episode > 0:
            torch.save(model.state_dict(), save_path + 'Snake_{}'.format(i_episode))


    return scores_array, avg_scores_array, avg_len_array, avg_max_len_array

scores, avg_scores, avg_len_of_snake, max_len_of_snake = train()

Ep.:     10, Loss: 30.767, Avg.Score: -29.00, Avg.LenOfSnake: 5.03, Max.LenOfSnake:  6.00 Time: 00:00:18 
Ep.:     20, Loss: 11.130, Avg.Score: -27.00, Avg.LenOfSnake: 5.10, Max.LenOfSnake:  6.00 Time: 00:00:34 
Ep.:     30, Loss: 12.436, Avg.Score: -27.00, Avg.LenOfSnake: 5.10, Max.LenOfSnake:  6.00 Time: 00:00:52 
Ep.:     40, Loss: 7.188, Avg.Score: -30.00, Avg.LenOfSnake: 5.00, Max.LenOfSnake:  5.00 Time: 00:01:08 
Ep.:     50, Loss: 5.708, Avg.Score: -29.00, Avg.LenOfSnake: 5.03, Max.LenOfSnake:  6.00 Time: 00:01:25 
Ep.:     60, Loss: 6.213, Avg.Score: -28.00, Avg.LenOfSnake: 5.07, Max.LenOfSnake:  6.00 Time: 00:01:42 
Ep.:     70, Loss: 7.099, Avg.Score: -28.00, Avg.LenOfSnake: 5.07, Max.LenOfSnake:  6.00 Time: 00:01:58 
Ep.:     80, Loss: 6.804, Avg.Score: -27.00, Avg.LenOfSnake: 5.10, Max.LenOfSnake:  7.00 Time: 00:02:16 
Ep.:     90, Loss: 5.390, Avg.Score: -28.00, Avg.LenOfSnake: 5.07, Max.LenOfSnake:  6.00 Time: 00:02:33 
Ep.:    100, Loss: 4.781, Avg.Score: -30.00, Avg.Len

Performance visualization

In [None]:
%matplotlib inline

print('length of scores: ', len(scores), ', len of avg_scores: ', len(avg_scores))

fig = plt.figure()
ax = fig.add_subplot(111)
plt.plot(np.arange(1, len(scores)+1), scores, label="Score")
plt.plot(np.arange(1, len(avg_scores)+1), avg_scores, label="Avg score on 100 episodes")
plt.legend(bbox_to_anchor=(1.05, 1))
plt.ylabel('Score')
plt.xlabel('Episodes #')
plt.show()
ax1 = fig.add_subplot(121)
plt.plot(np.arange(1, len(avg_len_of_snake)+1), avg_len_of_snake, label="Avg Len of Snake")
plt.plot(np.arange(1, len(max_len_of_snake)+1), max_len_of_snake, label="Max Len of Snake")
plt.legend(bbox_to_anchor=(1.05, 1))
plt.ylabel('Length of Snake')
plt.xlabel('Episodes #')
plt.show()
n, bins, patches = plt.hist(max_len_of_snake, 45, density=1, facecolor='green', alpha=0.75)
l = plt.plot(np.arange(1, len(bins) + 1), 'r--', linewidth=1)
mu = round(np.mean(max_len_of_snake), 2)
sigma = round(np.std(max_len_of_snake), 2)
median = round(np.median(max_len_of_snake), 2)
print('mu: ', mu, ', sigma: ', sigma, ', median: ', median)
plt.xlabel('Max.Lengths, mu = {:.2f}, sigma={:.2f},  median: {:.2f}'.format(mu, sigma, median))
plt.ylabel('Probability')
plt.title('Histogram of Max.Lengths')
plt.axis([4, 44, 0, 0.15])
plt.grid(True)

plt.show()

Defining the watch agent with a video visualization of the played game

In [None]:
gridsize = 23
framerate = 10
block_size = 30

snake_name = 'Snake_59000'

model = QNetwork(input_dim=10, hidden_dim=20, output_dim=5)
model.load_state_dict(torch.load('./dir_chk_lr0.00001/' + snake_name))

board = GameEnvironment(gridsize, nothing=0, dead=-1, apple=1)
windowwidth = gridsize*block_size*2
windowheight = gridsize*block_size

pygame.init()
win = pygame.display.set_mode((windowwidth, windowheight))
pygame.display.set_caption("snake")
font = pygame.font.SysFont('Helvetica', 14)
clock = pygame.time.Clock()


VIDEO = []

def drawboard(snake, apple):
    win.fill((0,0,0))
    for pos in snake.prevpos:
        pygame.draw.rect(win, (0,255,0), (pos[0]*block_size, pos[1]*block_size, block_size, block_size))
    pygame.draw.rect(win, (255, 0, 0), (apple.pos[0]*block_size, apple.pos[1]*block_size, block_size, block_size))

runGame = True

prev_len_of_snake = 0

while runGame:
    clock.tick(framerate)

    state_0 = get_network_input(board.snake, board.apple)
    state = model(state_0)

    action = torch.argmax(state)

    reward, done, len_of_snake = board.update_boardstate(action)
    drawboard(board.snake, board.apple)

    lensnaketext     = font.render('          LEN OF SNAKE: ' + str(len_of_snake), False, (255, 255, 255))
    prevlensnaketext = font.render('          LEN OF PREVIOUS SNAKE: ' + str(prev_len_of_snake), False, (255, 255, 255))

    x_pos= (int)(0.75*windowwidth)
    win.blit(lensnaketext, (x_pos, 40))
    win.blit(prevlensnaketext, (x_pos, 80))

    VIDEO.append(pygame.image.tostring(win, 'RGB', False))

    for event in pygame.event.get():
        if event.type==pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
            runGame = False

    keys = pygame.key.get_pressed()
    if keys[pygame.K_r]:
        paused = True
        while paused == True:
            clock.tick(10)
            pygame.event.pump()
            for event in pygame.event.get():
                if event.type == pygame.KEYDOWN:
                    paused = False

    pygame.display.update()

    if board.game_over == True:
        prev_len_of_snake = len_of_snake
        board.resetgame()

fourcc = cv2.VideoWriter_fourcc(*'MPV4')
output_name = 'output_' + snake_name + '.mp4'
video_mp4 = cv2.VideoWriter(output_name,fourcc, 20.0, (windowwidth,windowheight))

for image in VIDEO:

    image = np.frombuffer(image, np.uint8).reshape(windowheight, windowwidth, 3)
    image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    video_mp4.write(image)

cv2.destroyAllWindows()
video_mp4.release()

pygame.quit()