In [None]:
%pip install gym
%pip install numpy

# A game of Snake

Playing Snake in a command-line-friendly text format.

In [2]:
from copy import deepcopy
from gym import Env, spaces
import numpy as np
import random
from IPython.display import clear_output

class  SnakeEnv(Env):
    APPLE = "A"
    HEAD = "H"
    BODY = "S"
    EMPTY = "O"
    
    UP = 'UP'
    DOWN = 'DOWN'
    RIGHT = 'RIGHT'
    LEFT = 'LEFT'
    
    def __init__(self, tiles:int):
        super(SnakeEnv, self).__init__()
        self.TILES = tiles

        self.actions = [SnakeEnv.UP, SnakeEnv.DOWN, SnakeEnv.LEFT, SnakeEnv.RIGHT]
        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Discrete(self.TILES * self.TILES)
        
        # Inititialze game
        self.reset()
        
    def reset(self):
        # Initialize empty grid
        self.grid = np.array([[SnakeEnv.EMPTY for x in range(self.TILES)] for y in range(self.TILES)])
        
        # Add snake head to the grid at random position
        start_y = random.randint(0 , self.TILES - 1)
        start_x = random.randint(0 , self.TILES - 1)
        self.grid[start_y][start_x] = SnakeEnv.HEAD
        self.snake = np.array([[start_y, start_x]])
        
        # Add apple at random position
        apple_y, apple_x = self.__generate_apple()
        self.grid[apple_y][apple_x] = SnakeEnv.APPLE
        self.apple = (apple_y, apple_x)
        
        # Initialize game variables & last_action memory
        self.score = 0
        self.terminal = False
        self.last_action = None
        
        return self.grid
        
    def step(self, action):
        # Move the snake
        self.__move(direction=action)
        
        observation = self.grid
        done = self.terminal
        self.score = len(self.snake)
        score = self.score
        info = {"direction": action}
        
        return observation, score, done, info
    
    def get_copy(self):
        # Returns a full copy of the game
        instance = SnakeEnv(self.TILES)
        instance.reset()
        instance.grid = deepcopy(self.grid)
        instance.snake = deepcopy(self.snake)
        instance.apple = self.apple
        instance.score = self.score
        instance.terminal = self.terminal
        instance.last_action = self.last_action
        return instance

    def __generate_apple(self):
        all_grid_positions = {(y,x) for x in range(0, self.TILES) for y in  range(0, self.TILES)}
        allowed_grid_positions = list(all_grid_positions - set([tuple(e) for e in self.snake]))
        if allowed_grid_positions:
            return random.choice(allowed_grid_positions)
        else:
            # In case the board is full of snake the game will end anyway
            return random.choice(list(all_grid_positions))
        
        
    def __move(self, direction):
        head_y, head_x = self.snake[-1]
        match direction:
            case SnakeEnv.UP:
                new_head = (head_y -1, head_x)

            case SnakeEnv.DOWN:
                new_head = (head_y + 1, head_x)

            case SnakeEnv.RIGHT:
                new_head = (head_y, head_x + 1)

            case SnakeEnv.LEFT:
                new_head = (head_y, head_x - 1)
                

        if self.__did_wall_crash(new_head):
            self.terminal = True
            return
        
        if self.__did_self_crash(new_head):
            self.terminal = True
            return
        
        self.snake = np.append(self.snake, [new_head], axis=0)
        self.grid[head_y][head_x] = SnakeEnv.BODY
        head_y, head_x = self.snake[-1]
        self.grid[head_y][head_x] = SnakeEnv.HEAD
        
        if self.__did_eat():
            apple_y, apple_x = self.__generate_apple()
            self.grid[apple_y][apple_x] = SnakeEnv.APPLE
            self.apple = (apple_y, apple_x)   
        else:
            tail_y, tail_x = self.snake[0]
            self.snake = self.snake[1:]
            self.grid[tail_y][tail_x] = SnakeEnv.EMPTY
        
        self.last_action = direction
    
    def __did_eat(self):
        eaten = False
        head_y, head_x = self.snake[-1]
        apple_y, apple_x = self.apple
        if (apple_y == head_y) & (apple_x == head_x):
            eaten = True
        return eaten
    
    def __did_wall_crash(self, new_head):
        head_y, head_x = new_head
        if head_y >= self.TILES:
            return True
        if head_y < 0:
            return True
        if head_x >= self.TILES:
            return True
        if head_x < 0:
            return True
        return False
    
    def __did_self_crash(self, new_head):
        head_y, head_x = new_head
        for body_y, body_x in self.snake:
            if (body_y == head_y) & (body_x == head_x):
                return True
        return False
        
    def render(self, clear:bool = True):
        if clear:
            clear_output(wait=True)
        grid = ""
        for x in range(self.TILES):
            for y in range(self.TILES):
                grid += str(self.grid[x][y]) + " "
            grid += "\n"
        print(grid)

In [3]:
# Testing it out: a uniformly random policy plays until the snake crashes.

import time

game = SnakeEnv(5)
game.render()
while not game.terminal:
    random_action = random.choice(game.actions)
    game.step(random_action)
    # Show the board first, then the move that produced it.
    game.render()
    print(random_action)
    time.sleep(1)

O O O O O 
O O O O O 
O O O O O 
A O O O O 
O H O O O 

DOWN


# The Monte Carlo Tree Search (MCTS) algorithm

Monte Carlo Tree Search (MCTS) is a search algorithm used in decision-making and planning. The search tree consists of nodes (snapshots of the state of the environment) and edges (actions taken to get from one state/node to the next).

MCTS consists of 4 phases:

1. **Selection**: Starting from the root node, MCTS traverses the search tree, repeatedly selecting the best child of a node (according to UCB), until it arrives at a leaf node.

2. **Expansion**: Once a leaf node (an unexplored or not fully expanded node) is reached, MCTS expands it by adding a child node for one of its untried actions.

3. **Simulation**: To estimate an initial value, MCTS performs a simulation (rollout) from the newly expanded node by selecting random actions until a terminal state is reached.

4. **Backpropagation**: The result of the simulation is backpropagated up the tree to update the statistics of the nodes in the path from the expanded node to the root.

**Making a decision**: Steps 1 to 4 are repeated for a specified number of iterations or until a time limit is reached.
Then MCTS decides on the best move by selecting the action leading to the child node with the highest estimated value.

The key formula of MCTS is the Upper Confidence Bound (UCB) which is used to decide which child to take during the **Selection** phase.
Here's the formula and its components:

**Upper Confidence Bound (UCB) Formula:**
   - The UCB score for a child node `i` of a parent node `j` is calculated as follows:

     ```
     UCB(i, j) = Value(i) + C * sqrt(ln(Visits(j)) / Visits(i))
     ```

   - `Value(i)`: The estimated value of node `i`, typically the average of simulation results.
   - `Visits(i)`: The number of times node `i` has been visited.
   - `Visits(j)`: The number of times the parent node `j` has been visited.
   - `C`: A constant

## Tasks:
- What does C do?
- Now with the formula and the algorithm, walk us through how this will lead to good decisions?
- Implement the algorithm

## Solution:

- C controls the exploration/exploitation trade-off. The higher C, the more the algorithm explores rarely visited nodes instead of continuously choosing the highest-value node.
- During selection, the current best option according to UCB is followed down the tree, and the node reached is expanded. The new node receives an initial value by simulating random actions until a terminal state is reached (estimating the potential of that state). This result is backpropagated up to the root, i.e. the current state. With each iteration the algorithm therefore learns more about the expected value of each action available in the current state, while the exploration term ensures that rarely tried actions are still revisited. Given a high enough number of iterations, good decisions can be made.

In [4]:
# Monte Carlo Tree Search (MCTS) implementation
from __future__ import annotations
import math

class Node():
    """One state in the MCTS search tree.

    Attributes:
        state: Game state this node represents.
        is_terminal: Copy of ``state.terminal`` taken at construction time.
        parent: Parent node, or ``None`` for the root.
        depth: Number of edges between this node and the root.
        num_visits: How many backpropagation passes went through this node.
        value: Sum of the rewards backpropagated through this node.
        children: Maps an action to the child node reached by taking it.
        ucb: Last UCB score computed for this node (``-inf`` until computed).
    """

    def __init__(self, state: SnakeEnv, parent: Node):
        self.state, self.parent = state, parent
        self.is_terminal = state.terminal
        self.depth = parent.depth + 1 if parent is not None else 0

        # Search statistics, filled in as the tree grows.
        self.num_visits, self.value = 0, 0
        self.children = {}
        self.ucb = float("-inf")

class MCTS():
    """Monte Carlo Tree Search for choosing ``SnakeEnv`` actions.

    ``find_best_action`` builds a fresh tree rooted at the given state, runs
    ``iterations`` rounds of selection, expansion, simulation and
    backpropagation, and returns the root action whose child has the highest
    average simulated reward.
    """

    def __init__(self, iterations: int = 1000, exploration_constant: float = math.sqrt(2), discount: float = 0.995, step_cost: float = 0.1):
        """Configure the search.

        Args:
            iterations: Search iterations run per call to ``find_best_action``.
            exploration_constant: ``C`` in ``UCB = Value + C * sqrt(ln(N_parent) / N_child)``;
                larger values explore more.
            discount: Per-step factor applied to rewards collected during a simulation.
            step_cost: Penalty subtracted from the reward for every simulated step.
        """
        self.iterations = iterations
        self.C = exploration_constant

        # Bonus: it is good practice to add a discount to diminish value of actions further out in the future
        self.discount = discount
        # Bonus: adding step cost will ensure that the algorithm does not prefer driving around surviving over risking to eat
        self.step_cost = step_cost

    def find_best_action(self, current_state: SnakeEnv) -> str:
        """Search from ``current_state`` and return the action to take.

        Raises:
            ValueError: If the search produced no child of the root, i.e. the
                state is terminal or ``iterations`` is not positive.
        """
        root = Node(current_state, None)

        for _ in range(self.iterations):
            # Execute the 4 phases of one iteration to update the tree
            selected_node = self.__selection(root)
            added_node = self.__expand(selected_node)
            reward = self.__simulation(added_node)
            self.__backpropagation(added_node, reward)

        if not root.children:
            raise ValueError("No action available: the state is terminal or iterations < 1")

        # Rank children by mean reward (Value(i) in the UCB formula). The summed
        # value is biased by visit counts and inverts the ranking when rewards
        # are negative. Every root child has been visited at least once.
        best_action, _ = max(
            root.children.items(),
            key=lambda item: item[1].value / item[1].num_visits,
        )
        return best_action

    def __selection(self, node: Node) -> Node:
        """Descend from ``node`` by repeatedly taking the highest-UCB child.

        Stops at the first node that is terminal or still has untried actions.
        Ties in UCB are broken uniformly at random.
        """
        search_node = node
        # Descend while the node is non-terminal and every action already has a child.
        while not search_node.is_terminal and len(search_node.children) >= len(search_node.state.actions):
            # Loop-invariant part of the exploration term for this parent.
            log_parent_visits = math.log(search_node.num_visits)
            best_ucb = float("-inf")
            # track multiple nodes in case of equal UCB values
            best_nodes = []
            for child in search_node.children.values():
                # UCB(i, j) = Value(i) + C * sqrt(ln(Visits(j)) / Visits(i)).
                # An extra factor of 2 under the root would apply sqrt(2) on top of C.
                child.ucb = child.value / child.num_visits + self.C * math.sqrt(log_parent_visits / child.num_visits)

                if child.ucb > best_ucb:
                    best_ucb = child.ucb
                    best_nodes = [child]
                elif child.ucb == best_ucb:
                    best_nodes.append(child)

            search_node = random.choice(best_nodes)
        return search_node

    def __expand(self, node: Node) -> Node:
        """Add a child for a random untried action of ``node`` and return it.

        Terminal nodes cannot be expanded and are returned unchanged.
        """
        if node.is_terminal:
            return node
        action = random.choice([a for a in node.state.actions if a not in node.children])
        # Step a copy so the parent's state stays intact.
        state_copy = node.state.get_copy()
        state_copy.step(action)
        new_node = Node(state_copy, node)
        node.children[action] = new_node
        return new_node

    def __simulation(self, node: Node) -> float:
        """Estimate the value of ``node`` with one random playout.

        Returns 0 for a terminal node. Otherwise plays uniformly random actions
        on a copy of the state until the game ends, and returns the starting
        score plus the discounted per-step score changes, each reduced by
        ``step_cost``.
        """
        if node.is_terminal:
            return 0

        state = node.state.get_copy()
        reward = state.score
        steps = 0
        while not state.terminal:
            # Bonus: remember the previous score to see if the action led to an improvement
            old_score = state.score

            action = random.choice(state.actions)
            state.step(action)

            # Bonus: the step cost punishes pure survival over improving the score,
            # and the discount weights rewards further in the future less.
            step_reward = state.score - old_score - self.step_cost
            reward += step_reward * (self.discount ** steps)

            # Simple solution (without bonus):
            # reward = state.score

            steps += 1
        return reward

    def __backpropagation(self, node: Node, reward: float) -> None:
        """Add one visit and ``reward`` to ``node`` and every ancestor up to the root."""
        while node is not None:
            node.num_visits += 1
            node.value += reward
            node = node.parent

## Test the algorithm

With these parameters the agent usually plays until the board is almost full (the loop stops once the snake leaves only two free cells).  
With the bonus (discount and step cost) implemented, larger environments also work, given more iterations.

In [5]:
# Let MCTS play one game on a 5x5 board.
game = SnakeEnv(5)
searcher = MCTS(iterations=700, exploration_constant=math.sqrt(2), discount=0.995, step_cost=0.1)
# Stop once the snake (score == its length) fills all but two cells of the board.
TARGET_SCORE = game.TILES ** 2 - 2

print("Game reset")
game.reset()
# Snapshot the starting board: step() mutates game.grid in place, so a plain
# reference would end up showing the final board instead.
initial_state = game.grid.copy()
action_history = []
while not game.terminal:
    action = searcher.find_best_action(current_state=game.get_copy())
    action_history.append(action)
    game.step(action)
    game.render()
    if game.score >= TARGET_SCORE:
        break
print(f"Final score: {game.score}")

S S S S S 
S S S S S 
S S S S S 
S S S H S 
S S S A O 

Final score: 23
