<a href="https://colab.research.google.com/github/nomomon/drl-js/blob/main/snake/notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
import tensorflow as tf
import numpy as np

In [86]:
class snakeEnvironment:
    def __init__(self, boardWidth = 20, boardHeight = 20):
        self.width = boardWidth + 2
        self.height = boardHeight + 2

        self.reset()

    def reset(self):
        self.stepsLeft = self.width * self.height

        self.board = np.zeros((self.width, self.height))
        self.board[0, :]  = -1
        self.board[:, 0]  = -1
        self.board[:, -1] = -1
        self.board[-1, :] = -1
        
        self.snake = []
        self.snakeDirection = 0
        self.dead = 0

        self.initBoard()

    def initBoard(self):
        # place apple
        empty = np.stack(np.where(self.board == 0))
        emptyPointIndex = np.random.randint(empty.shape[0], size = 1)[0]
        emptyPoint = empty[emptyPointIndex]
        self.apple = emptyPoint
        self.board[emptyPoint[0]][emptyPoint[1]] = 2
        
        # place snake
        empty = np.stack(np.where(self.board == 0))
        emptyPointIndex = np.random.randint(empty.shape[0], size = 1)[0]
        emptyPoint = np.array(empty[emptyPointIndex])
        self.board[emptyPoint[0]][emptyPoint[1]] = 1

        self.snakeHead = emptyPoint
        self.snake = np.array([emptyPoint])

        print(self.snake)

    def getObservations(self):

        apple_x, apple_y = self.apple[0], self.apple[1]
        snake_x, snake_y = self.snakeHead[0], self.snakeHead[1]

        danger_straight = 0
        danger_right = 0
        danger_left =  0

        for action, danger in enumerate([danger_left, danger_straight, danger_right]):
            newSnakeDirection = (4 + self.snakeDirection + self.getActions()[action]) % 4

            if(newSnakeDirection == 0):
                nextState = (self.snake[0][0] + 1, self.snake[0][1])    # up
            elif(newSnakeDirection == 1):
                nextState = (self.snake[0][0], self.snake[0][1] + 1)    # right
            elif(newSnakeDirection == 2):
                nextState = (self.snake[0][0] - 1, self.snake[0][1])    # down
            elif(newSnakeDirection == 3):
                nextState = (self.snake[0][0], self.snake[0][1] - 1)    # left

            if(self.board[nextState[0]][nextState[1]] not in [0, 2]):
                danger = 1

        moving_up =    (self.snakeDirection == 0) + 0
        moving_right = (self.snakeDirection == 1) + 0
        moving_left =  (self.snakeDirection == 3) + 0
        moving_down =  (self.snakeDirection == 3) + 0
        food_left =    (snake_x > apple_x) + 0
        food_right =   (snake_x < apple_x) + 0
        food_up =      (snake_y > apple_y) + 0
        food_down =    (snake_y < apple_y) + 0

        return [danger_straight, danger_right, danger_left, moving_left, moving_right, moving_up, moving_down, food_left, food_right, food_up, food_down]
    
    def getActions(self):
        return [-1, 0, 1]

    def isDone(self) -> bool:
        return (self.stepsLeft == 0) or self.dead

    def executeAction(self, action):
        if(self.isDone()):
            raise Exception("Game is over, however tried to execute an action")
        
        self.stepsLeft -= 1

        self.rewardForAction = 0
        self.dead = 0

        # -1 - turn left
        #  0 - continue same direction
        #  1 - turn right

        # snake[0]  - head
        # snake[-1] - tail

        newSnakeDirection = (4 + self.snakeDirection + self.getActions()[action]) % 4

        if(newSnakeDirection == 0):
            nextState = (self.snake[0][0] + 1, self.snake[0][1])    # up
        elif(newSnakeDirection == 1):
            nextState = (self.snake[0][0], self.snake[0][1] + 1)    # right
        elif(newSnakeDirection == 2):
            nextState = (self.snake[0][0] - 1, self.snake[0][1])    # down
        elif(newSnakeDirection == 3):
            nextState = (self.snake[0][0], self.snake[0][1] - 1)    # left
        
        self.snakeDirection = newSnakeDirection

        # hit a wall or it self
        if(self.board[nextState[0]][nextState[1]] == -1 or 
           self.board[nextState[0]][nextState[1]] == 1):

            self.dead = 1
            self.rewardForAction = -1
        
        # ate the apple
        if(self.board[nextState[0]][nextState[1]] == 2):
            # update the snake
            self.snakeHead = np.array(nextState)
            self.board[nextState[0]][nextState[1]] = 1
            self.snake = np.concatenate((np.array(nextState, ndmin = 2), self.snake))

            # new apple
            empty = np.stack(np.where(self.board == 0))
            newPointIndex = np.random.randint(empty.shape[0], size = 1)[0]
            newPoint = empty[newPointIndex]
            self.apple = newPoint
            self.board[newPoint[0]][newPoint[1]] = 2

            self.rewardForAction = 1
        
        # nothing happened
        if(self.board[nextState[0]][nextState[1]] == 0):
            self.snakeHead = np.array(nextState)
            self.board[nextState[0]][nextState[1]] = 1

            self.snake = np.concatenate((np.array(nextState, ndmin = 2), self.snake))[:-1]

    def step(self, action):
        self.executeAction(action)

        return self.getObservations(), self.rewardForAction, self.isDone()

env = snakeEnvironment()

[[ 1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  2  2  2  2  2
   2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  3  3  3  3  3  3  3  3  3
   3  3  3  3  3  3  3  3  3  3  3  4  4  4  4  4  4  4  4  4  4  4  4  4
   4  4  4  4  4  4  4  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5
   5  5  5  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  7
   7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  8  8  8  8  8
   8  8  8  8  8  8  8  8  8  8  8  8  8  8  8  9  9  9  9  9  9  9  9  9
   9  9  9  9  9  9  9  9  9  9  9 10 10 10 10 10 10 10 10 10 10 10 10 10
  10 10 10 10 10 10 10 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
  11 11 11 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 12 13
  13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 14 14 14 14 14
  14 14 14 14 14 14 14 14 14 14 14 14 14 14 14 15 15 15 15 15 15 15 15 15
  15 15 15 15 15 15 15 15 15 15 15 16 16 16 16 16 16 16 16 16 16 16 16 16
  16 16 16 16 16 16 16 17 17 17 17 17 

In [51]:
n_observations = len(env.getObservations())
print(f"Observation space is size {n_observations}")

Observation space is size 11


In [52]:
n_actions = len(env.getActions())
print(f"Action space is size {n_actions}")

Action space is size 3


# model


In [37]:
### Agent ###

def createPolicy():
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer((n_observations,)),
        tf.keras.layers.Dense(100, activation = "relu"),
        tf.keras.layers.Dense(100, activation = "relu"),
        tf.keras.layers.Dense(n_actions, activation = None)
    ])

    return model

In [33]:
def choose_action(model, observation, single=True):
    observation = np.expand_dims(observation, axis=0) if single else observation
    
    logits = model.predict(observation)
    action = tf.random.categorical(logits, num_samples=1)

    action = action.numpy().flatten()

    return action[0] if single else action

In [40]:
### Agent Memory ###

class Memory:
    def __init__(self): 
        self.clear()

    def clear(self): 
        self.observations = []
        self.actions = []
        self.rewards = []

    def add_to_memory(self, new_observation, new_action, new_reward): 
        self.observations.append(new_observation)
        self.actions.append(new_action)
        self.rewards.append(new_reward)

    def aggregate_memories(memories):
        batch_memory = Memory()
        
        for memory in memories:
            for step in zip(memory.observations, memory.actions, memory.rewards):
                batch_memory.add_to_memory(*step)

        return batch_memory

memory = Memory()

In [64]:
### Reward function ###

def normalize(x):
    x -= np.mean(x)
    x /= np.std(x)
    return x.astype(np.float32)

def discount_rewards(rewards, gamma=0.95): 
    discounted_rewards = np.zeros_like(rewards)
    R = 0
    for t in reversed(range(0, len(rewards))):
        R = R * gamma + rewards[t]
        discounted_rewards[t] = R

    print(discounted_rewards)
        
    return normalize(discounted_rewards)

In [34]:
### Loss function ###

def compute_loss(logits, actions, rewards):
  neg_logprob = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=logits, labels=actions)
  loss = tf.reduce_mean( neg_logprob * rewards )
  return loss

In [35]:
### Training step (forward and backpropagation) ###

def train_step(model, optimizer, observations, actions, discounted_rewards):
    with tf.GradientTape() as tape:
        logits = model(observations)

        loss = compute_loss(logits, actions, discounted_rewards)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

In [87]:
### Snake training! ###

learning_rate = 1e-3
optimizer = tf.keras.optimizers.Adam(learning_rate)

snake_model = createPolicy()

for i_episode in range(500):

    env.reset()
    observation = env.getObservations()
    memory.clear()

    while True:
        action = choose_action(snake_model, observation)
        next_observation, reward, done = env.step(action)
        
        memory.add_to_memory(observation, action, reward)
        
        if done:
            total_reward = sum(memory.rewards)
            
            train_step(snake_model, optimizer, 
                        observations=np.vstack(memory.observations),
                        actions=np.array(memory.actions),
                        discounted_rewards = discount_rewards(memory.rewards))
            
            memory.clear()
            break
        
        observation = next_observation

[[ 2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20  1  2  3  4  5
   6  7  8  9 10 11 12 13 14 15 16 17 18 19 20  1  2  3  4  5  6  7  8  9
  10 11 12 13 14 15 16 17 18 19 20  1  2  3  4  5  6  7  8  9 10 11 12 13
  14 15 16 17 18 19 20  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17
  18 19 20  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20  1
   2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20  1  2  3  4  5
   6  7  8  9 10 11 12 13 14 15 16 17 18 19 20  1  2  3  4  5  6  7  8  9
  10 11 12 13 14 15 16 17 18 19 20  1  2  3  4  5  6  7  8  9 10 11 12 13
  14 15 16 17 18 19 20  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17
  18 19 20  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20  1
   2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20  1  2  3  4  5
   6  7  8  9 10 11 12 13 14 15 16 17 18 19 20  1  2  3  4  5  6  7  8  9
  10 11 12 13 14 15 16 17 18 19 20  1  2  3  4  5  6  7  8  9 10 11 12 13
  14 15 16 17 18 19 20  1  2  3  4  5 

ValueError: ignored