<a href="https://colab.research.google.com/github/nomomon/drl-js/blob/main/snake/notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import numpy as np

In [None]:
# Cell codes stored in the board array:
# -1 - wall
#  0 - empty
#  1 - snake
#  2 - apple

def createBoard(width = 12, height = 12):
    """Build an empty playing field enclosed by a one-cell wall.

    Args:
        width: size of the first array axis.
        height: size of the second array axis.

    Returns:
        A float array of shape (width, height) holding -1 on its outer
        border (wall) and 0 everywhere else (free cell).
    """
    field = np.zeros((width, height))

    # The first and the last line along each axis become wall.
    for edge in (0, -1):
        field[edge, :] = -1
        field[:, edge] = -1

    return field

def initBoard(board):
    """Place one apple and a one-cell snake on random free cells of ``board``.

    The board is modified in place: the apple cell is set to 2 and the snake
    cell to 1. Only cells currently equal to 0 are candidates.

    Args:
        board: 2-D NumPy array of cell codes (-1 wall, 0 free, 1 snake,
            2 apple).

    Returns:
        A ``(board, snake)`` pair, where ``board`` is the same array that was
        passed in and ``snake`` is a list holding the single
        ``(first_index, second_index)`` tuple of the snake cell.

    Raises:
        ValueError: if fewer than two free cells are available.
    """
    if np.count_nonzero(board == 0) < 2:
        raise ValueError("board needs at least two empty cells for the apple and the snake")

    placed = None
    for marker in (2, 1):  # apple first, then the snake
        # argwhere yields one (row, col) pair per free cell, so drawing a
        # random row index picks a random free cell.
        free = np.argwhere(board == 0)
        row, col = free[np.random.randint(len(free))]
        placed = (int(row), int(col))
        board[placed[0], placed[1]] = marker

    # The last cell placed is the snake (marker 1).
    snake = [placed]

    return board, snake

def showBoard(board):
    """Print ``board`` to stdout, one text line per index of the first axis.

    Every cell is rendered as exactly two characters so that columns stay
    aligned:

        -1 (wall)  -> "■ "
         0 (free)  -> "  "
         1 (snake) -> "▢ "
         2 (apple) -> "● "

    Any other value is shown as "? " instead of being dropped, which would
    shift the rest of the line.

    Args:
        board: 2-D array of cell codes.
    """
    glyphs = {-1: "■ ", 0: "  ", 1: "▢ ", 2: "● "}

    # tolist() converts NumPy scalars to plain Python numbers, which makes the
    # dictionary lookup independent of the array dtype.
    for line in np.asarray(board).tolist():
        print("".join(glyphs.get(cell, "? ") for cell in line))

def play(board, snake, action):
    """Move the snake one cell in the direction given by ``action``.

    Actions and the offset they add to the head's
    ``(first_index, second_index)`` position:

        0 - up    (+1,  0)
        1 - right ( 0, +1)
        2 - down  (-1,  0)
        3 - left  ( 0, -1)

    ``snake[0]`` is the head and ``snake[-1]`` is the tail.

    Args:
        board: 2-D NumPy array of cell codes (-1 wall, 0 free, 1 snake,
            2 apple). It is modified in place.
        snake: sequence of ``(first_index, second_index)`` positions.
        action: one of 0, 1, 2, 3.

    Returns:
        ``(board, snake, alive)``. ``alive`` is 0 when the move runs into a
        wall, into the snake itself or off the board; in that case ``board``
        and ``snake`` are returned unchanged. Otherwise ``alive`` is 1 and
        ``snake`` is a new list with the new head in front.

    Raises:
        ValueError: if ``action`` is not one of 0, 1, 2, 3.
    """
    moves = {0: (1, 0), 1: (0, 1), 2: (-1, 0), 3: (0, -1)}
    if action not in moves:
        raise ValueError("action must be 0 (up), 1 (right), 2 (down) or 3 (left)")

    stepRow, stepCol = moves[action]
    nextState = (int(snake[0][0]) + stepRow, int(snake[0][1]) + stepCol)

    # Leaving the array counts as a collision; without this check a negative
    # index would silently wrap around to the opposite edge.
    rows, cols = board.shape
    if not (0 <= nextState[0] < rows and 0 <= nextState[1] < cols):
        return board, snake, 0

    target = board[nextState[0], nextState[1]]

    # hit a wall or itself
    if target == -1 or target == 1:
        return board, snake, 0

    # The head advances in both remaining cases.
    board[nextState[0], nextState[1]] = 1
    snake = [nextState, *snake]

    if target == 2:
        # Ate the apple: the snake keeps its tail (grows by one) and a new
        # apple is dropped on a random free cell, if there is one left.
        free = np.argwhere(board == 0)
        if len(free) > 0:
            appleRow, appleCol = free[np.random.randint(len(free))]
            board[appleRow, appleCol] = 2
    else:
        # Plain move: the tail cell is vacated so the length stays the same.
        tail = snake.pop()
        board[tail[0], tail[1]] = 0

    return board, snake, 1
                

In [None]:
# Demo: build a default board, place an apple and a snake on it, and print it.
showBoard(initBoard(createBoard())[0])

■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ 
■ ▢                  ■ 
■                     ■ 
■                     ■ 
■                     ■ 
■                     ■ 
■                     ■ 
■                     ■ 
■                     ■ 
■                     ■ 
■                     ■ 
■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ 


In [6]:
class snakeEnvironment:
    """Snake game on a walled grid with a small step-based interface.

    Cell codes stored in ``self.board``: -1 wall, 0 free, 1 snake, 2 apple.

    Positions are ``(first_index, second_index)`` tuples into ``self.board``.
    Headings and the offset each one adds to the head position:

        0 - up    (+1,  0)
        1 - right ( 0, +1)
        2 - down  (-1,  0)
        3 - left  ( 0, -1)

    Attributes set by this class:
        width, height: board dimensions.
        stepsLeft: remaining step budget, starts at ``width * height``.
        board: 2-D float array of cell codes.
        snake: list of positions; ``snake[0]`` is the head, ``snake[-1]``
            the tail.
        snakeHead: position of the head.
        snakeDirection: current heading (0..3).
        apple: position of the apple.
        dead: 1 once the snake has collided, else 0.
        rewardForAction: reward produced by the most recent action.
    """

    # Heading -> offset added to (first_index, second_index).
    _MOVES = {0: (1, 0), 1: (0, 1), 2: (-1, 0), 3: (0, -1)}

    def __init__(self, boardWidth = 20, boardHeight = 20):
        """Create a walled board and place the apple and a one-cell snake.

        Raises:
            ValueError: if the board is too small to have at least two cells
                inside the walls.
        """
        if (boardWidth < 3 or boardHeight < 3
                or (boardWidth - 2) * (boardHeight - 2) < 2):
            raise ValueError("board is too small to hold walls, an apple and a snake")

        self.width = boardWidth
        self.height = boardHeight

        self.stepsLeft = self.width * self.height

        # The outer border is wall so that running off the field is detected
        # as a collision instead of indexing outside the array.
        self.board = np.zeros((self.width, self.height))
        for edge in (0, -1):
            self.board[edge, :] = -1
            self.board[:, edge] = -1

        self.snake = []
        self.snakeDirection = 0
        self.dead = 0
        self.rewardForAction = 0

        self.initBoard()

    def _randomEmptyCell(self):
        """Return a random free cell as an (int, int) tuple, or None if the board is full."""
        # argwhere yields one (row, col) pair per free cell.
        free = np.argwhere(self.board == 0)
        if len(free) == 0:
            return None
        row, col = free[np.random.randint(len(free))]
        return (int(row), int(col))

    def _headingAfter(self, action):
        """Return the heading that results from taking the action with index ``action``."""
        return (self.snakeDirection + self.getActions()[action]) % 4

    def _cellAhead(self, direction):
        """Return the cell one step from the head when heading in ``direction``."""
        stepRow, stepCol = self._MOVES[direction]
        return (self.snake[0][0] + stepRow, self.snake[0][1] + stepCol)

    def _isBlocked(self, cell):
        """Return True if ``cell`` is off the board or holds anything but free space or the apple."""
        row, col = cell
        if not (0 <= row < self.board.shape[0] and 0 <= col < self.board.shape[1]):
            return True
        return self.board[row, col] not in (0, 2)

    def initBoard(self):
        """Place the apple and a one-cell snake on random free cells.

        Raises:
            ValueError: if the board runs out of free cells.
        """
        # place apple
        applePoint = self._randomEmptyCell()
        if applePoint is None:
            raise ValueError("no empty cell left for the apple")
        self.apple = applePoint
        self.board[applePoint] = 2

        # place snake
        snakePoint = self._randomEmptyCell()
        if snakePoint is None:
            raise ValueError("no empty cell left for the snake")
        self.board[snakePoint] = 1

        self.snakeHead = snakePoint
        self.snake = [snakePoint]

    def getObservations(self):
        """Describe the current state as a list of eleven 0/1 flags.

        Order of the returned flags:
            danger_straight, danger_right, danger_left,
            moving_left, moving_right, moving_up, moving_down,
            food_left, food_right, food_up, food_down

        A ``danger_*`` flag is 1 when the cell the head would enter after
        that relative action is blocked. The ``food_*`` flags use the same
        axes as the headings: "up" means a larger first index and "right" a
        larger second index.
        """
        appleRow, appleCol = self.apple
        headRow, headCol = self.snakeHead

        # getActions() is ordered (turn left, straight, turn right).
        danger_left, danger_straight, danger_right = [
            int(self._isBlocked(self._cellAhead(self._headingAfter(action))))
            for action in range(len(self.getActions()))
        ]

        moving_up =    int(self.snakeDirection == 0)
        moving_right = int(self.snakeDirection == 1)
        moving_down =  int(self.snakeDirection == 2)
        moving_left =  int(self.snakeDirection == 3)

        food_up =      int(appleRow > headRow)
        food_down =    int(appleRow < headRow)
        food_right =   int(appleCol > headCol)
        food_left =    int(appleCol < headCol)

        return [danger_straight, danger_right, danger_left, moving_left, moving_right, moving_up, moving_down, food_left, food_right, food_up, food_down]

    def getActions(self):
        """Return the heading change for each action index: turn left, keep straight, turn right."""
        return [-1, 0, 1]

    def isDone(self) -> bool:
        """Return True when the step budget is used up or the snake is dead."""
        return bool(self.stepsLeft <= 0 or self.dead)

    def executeAction(self, action):
        """Apply one action and update board, snake and ``rewardForAction``.

        Args:
            action: index into ``getActions()``:
                0 - turn left, 1 - continue in the same direction,
                2 - turn right.

        The reward is -1 for a collision with a wall or the snake itself
        (which also sets ``dead``), 1 for eating the apple and 0 for a plain
        move.

        Raises:
            Exception: if the game is already over.
        """
        if(self.isDone()):
            raise Exception("Game is over, however tried to execute an action")

        self.stepsLeft -= 1

        self.rewardForAction = 0
        self.dead = 0

        self.snakeDirection = self._headingAfter(action)
        nextState = self._cellAhead(self.snakeDirection)

        # hit a wall or itself: the snake stays where it is
        if self._isBlocked(nextState):
            self.dead = 1
            self.rewardForAction = -1
            return

        ateApple = self.board[nextState] == 2

        # The head advances in both remaining cases.
        self.snakeHead = nextState
        self.board[nextState] = 1
        self.snake.insert(0, nextState)

        if ateApple:
            self.rewardForAction = 1

            # The tail is kept, so the snake grows by one cell.
            newApple = self._randomEmptyCell()
            if newApple is None:
                # Board completely filled: nothing left to eat, end the episode.
                self.stepsLeft = 0
            else:
                self.apple = newApple
                self.board[newApple] = 2
        else:
            # Plain move: vacate the tail cell so the length stays the same.
            tail = self.snake.pop()
            self.board[tail] = 0

    def step(self, action):
        """Execute ``action`` and return ``(observations, reward, done)``."""
        self.executeAction(action)

        return self.getObservations(), self.rewardForAction, self.isDone()