# SNAKE GAME

![SnakeGame](https://upload.wikimedia.org/wikipedia/en/0/04/Snake_trs-80.jpg)

The [Snake Game](https://en.wikipedia.org/wiki/Snake_(video_game_genre)) is a classic video game genre that has been popular since the 1970s. The player controls a snake that moves around the screen, eating food and growing longer while avoiding collisions with itself.

Try it out [here](https://snake.onl/).

In [None]:
#@title GAME CODE
from typing import Tuple
import random
import time
import copy
from collections import deque
from IPython.display import clear_output
import ipywidgets as widgets
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.colors import ListedColormap
from IPython.display import HTML

try:
    # Enable Colab's custom widget manager so the ipywidgets below render;
    # this only needs to run once per session.
    from google.colab import output
    output.enable_custom_widget_manager()
except ImportError:
    pass  # not running in Google Colab: there is nothing to enable

class SnakeMap:
    """A snapshot of the Snake board: an ``n x n`` grid whose edges wrap around.

    Positions are ``[row, col]`` lists. ``move_snake`` returns a new
    ``SnakeMap`` instead of modifying this one.

    Attributes:
        n: side length of the square board.
        head: ``[row, col]`` of the snake's head.
        body: body segments, ordered from the one next to the head to the tail.
        fruit: ``[row, col]`` of the current fruit.
    """

    def __init__(self, n, head=None, body=None, fruit=None):
        """Create a board.

        Omitted parts default to a length-3 snake along the top row (head at
        ``[0, 2]``, which assumes ``n >= 3``) and a fruit in the bottom-right
        corner.
        """
        self.n = n
        self.head = head
        self.body = body
        self.fruit = fruit

        if self.head is None:
            self.head = [0, 2]
        if self.body is None:
            self.body = [[0, 1], [0, 0]]
        if self.fruit is None:
            self.fruit = [self.n-1, self.n-1]

    def move_snake(self, direction: str, rng) -> Tuple[int, "SnakeMap"]:
        """Apply one move and return ``(result, new_map)``.

        Args:
            direction: ``"U"``, ``"D"``, ``"L"`` or ``"R"`` (case-insensitive).
            rng: object with a ``choice`` method (e.g. ``random.Random``),
                used to place a new fruit.

        Returns:
            ``result`` is ``1`` if the fruit was eaten, ``-1`` if the head ran
            into the body, and ``0`` otherwise. ``new_map`` is the next state.

        Raises:
            ValueError: if ``direction`` is not one of the four letters.
            Exception: "Board is full! You win!" when a fruit is eaten and no
                free cell is left for the next one.
        """
        deltas = {"U": (-1, 0), "D": (1, 0), "L": (0, -1), "R": (0, 1)}
        d = direction.upper()
        if d not in deltas:
            raise ValueError(f"{direction!r} is not valid (use U, D, L or R)")

        # Move head (wrap-around)
        dr, dc = deltas[d]
        new_head = [(self.head[0] + dr) % self.n, (self.head[1] + dc) % self.n]
        fruit_eaten = new_head == list(self.fruit)

        # The old head becomes the first body segment. The tail only advances
        # when no fruit was eaten, which is how the snake grows.
        new_body = [list(self.head)] + [list(seg) for seg in self.body]
        if not fruit_eaten:
            new_body.pop()

        # Checked after the tail moved, so stepping into the cell the tail
        # just left is legal.
        if new_head in new_body:
            result = -1
        elif fruit_eaten:
            result = 1
        else:
            result = 0

        new_map = SnakeMap(self.n, new_head, new_body, list(self.fruit))
        if fruit_eaten:
            # Place the next fruit relative to the *new* snake so it can never
            # appear under the head that just ate the previous one.
            new_map.fruit = new_map.create_fruit(rng)
        return result, new_map

    def _arrow_towards(self, src: list[int], dst: list[int]) -> str:
        """Return the arrow pointing from ``src`` to its neighbour ``dst``.

        Neighbours are taken with wrap-around. ``"·"`` is returned when
        ``dst`` is not an orthogonal neighbour of ``src``.
        """
        r, c = src
        n = self.n
        target = tuple(dst)
        for arrow, cell in (("^", ((r - 1) % n, c)),
                            ("v", ((r + 1) % n, c)),
                            ("<", (r, (c - 1) % n)),
                            (">", (r, (c + 1) % n))):
            if target == cell:
                return arrow
        return "·"

    def to_array(self):
        """Return an ``n x n`` int array: 0 empty, 1 head, 2 body, 3 fruit."""
        grid = np.zeros((self.n, self.n), dtype=int)
        grid[self.head[0], self.head[1]] = 1
        for r, c in self.body:
            grid[r, c] = 2
        fr, fc = self.fruit
        grid[fr, fc] = 3
        return grid

    def _render(self) -> str:
        """Return an ASCII drawing of the board inside a ``+--+`` border.

        ``O`` is the head, ``X`` is the fruit, and each body segment is an
        arrow pointing at its neighbour one step closer to the head.
        """
        symbols = {0: " ", 1: "O", 2: ".", 3: "X"}
        grid = [[symbols[int(v)] for v in row] for row in self.to_array()]

        # Body (arrows) – each segment points to the segment closer to head
        for idx, seg in enumerate(self.body):
            prev = self.head if idx == 0 else self.body[idx - 1]
            r, c = seg
            grid[r][c] = self._arrow_towards(seg, prev)

        border = "+" + "-" * (2 * self.n) + "+"
        rows = [border]
        rows.extend("|" + "".join(f"{ch} " for ch in row) + "|" for row in grid)
        rows.append(border)
        return "\n".join(rows)

    def __str__(self):
        return self._render()

    def create_fruit(self, rng):
        """Return a random free ``[row, col]`` for a new fruit.

        Free cells are listed in row-major order and one is picked with
        ``rng.choice``, so a seeded ``rng`` gives reproducible placement.

        Raises:
            Exception: "Board is full! You win!" when the snake covers every cell.
        """
        occupied = {tuple(self.head), *(tuple(seg) for seg in self.body)}
        free = [
            (r, c)
            for r in range(self.n)
            for c in range(self.n)
            if (r, c) not in occupied
        ]
        if not free:
            raise Exception("Board is full! You win!")
        return list(rng.choice(free))



class SnakeFrame:
    """One recorded moment of a game, as built by ``SnakeGame.interact_loop``.

    Attributes:
        map: board state at this moment.
        step: number of single moves made so far.
        score: number of fruits eaten so far.
        movements: move letters played so far (policy calls separated by spaces).
        info: optional message, e.g. why the game ended.
    """

    def __init__(self, map: SnakeMap, step: int, score: int, movements: str = "", info: str = ""):
        self.map, self.step, self.score = map, step, score
        self.movements, self.info = movements, info


class SnakeGame:
    """Run a snake policy on a board and present the result.

    A *policy* is a callable that receives the current ``SnakeMap`` and
    returns the next move as a string of ``U``/``D``/``L``/``R`` letters.
    """

    def __init__(self, initial_map = None, seed = 42):
        """Store the starting board and the seed used for fruit placement.

        Args:
            initial_map: starting ``SnakeMap``. When ``None``, a 4x4 board with
                a length-3 snake on the top row and a fruit at ``[3, 3]`` is used.
            seed: seed for the ``random.Random`` that places fruit, so repeated
                runs of the same policy are reproducible.
        """
        if initial_map is None:
            initial_map = SnakeMap(n=4, head=[0, 2], body=[[0, 1], [0, 0]], fruit=[3, 3])
        self.initial_map = initial_map
        self.seed = seed

    def str_score(self, score: int, step: int):
        """Return the current step count and score as one line of text."""
        return f"Steps: {step}  Score: {score}"

    def str_separator(self, n=50):
        """Return a separator line made of ``n`` ``=`` characters."""
        return "=" * n

    # ------------------------------------------------------------------ #
    # Real-time loop (animated)                                          #
    # ------------------------------------------------------------------ #
    def game_loop(self, policy, max_steps: int = 50, delay: float = 0.5, history_size: int = 1):
        """Play the game live, redrawing the board in the cell output.

        Each iteration asks ``policy`` for one move, applies it, clears the
        output, and prints the previous ``history_size`` boards followed by
        the new one.

        Args:
            policy: callable ``SnakeMap -> str`` returning a single move letter.
            max_steps: stop after this many moves.
            delay: seconds to sleep between frames.
            history_size: number of previous boards kept on screen.

        Raises:
            Exception: anything raised by ``policy`` or ``move_snake`` (for
                example an invalid move or a full board) is printed and re-raised.
        """
        history = deque(maxlen=history_size)
        map_ = copy.deepcopy(self.initial_map)   # `map` is a Python builtin
        rng = random.Random(self.seed)
        score = 0
        step = 0

        try:
            while step < max_steps:
                move = policy(map_)
                result, map_ = map_.move_snake(move, rng)

                if result == 1:
                    score += 1
                elif result == -1:
                    print(self.str_separator())
                    print("Game over! Snake collided with itself.")
                    break

                step += 1

                clear_output(wait=True)
                print("\n\n".join(history))
                print(self.str_separator())

                render = map_._render()
                print(move)
                print(render)
                print(self.str_score(score, step))

                history.append(render)
                time.sleep(delay)
            else:
                # Only reached when the step budget ran out, not after a collision.
                print(self.str_separator())
                print("Maximum steps achieved!")

        except Exception as exc:
            print(self.str_separator())
            print(exc)
            raise

        finally:
            print(self.str_score(score, step))

    # ------------------------------------------------------------------ #
    # Interactive slider loop                                            #
    # ------------------------------------------------------------------ #
    def interact_loop(self, policy, max_steps: int = 50):
        """Simulate a whole game and record one frame per single move.

        ``policy`` may return several letters at once (e.g. ``"UUR"``). Each
        letter is applied as its own step, and a space is appended to the
        recorded ``movements`` after every policy call.

        Args:
            policy: callable ``SnakeMap -> str`` returning one or more move letters.
            max_steps: maximum number of single-letter moves to simulate.

        Returns:
            List of ``SnakeFrame``: the initial state followed by one frame per
            move. When the game stops early (collision, invalid or empty move,
            full board, or an error raised by ``policy``), the last frame's
            ``info`` holds the reason.
        """
        map_ = copy.deepcopy(self.initial_map)   # `map` is a Python builtin
        rng = random.Random(self.seed)
        score = 0
        step = 0
        movements = ""
        frames = [SnakeFrame(map_, step, score)]

        while step < max_steps:
            try:
                move = policy(map_)
                if not move:
                    # An empty move would never advance `step`, looping forever.
                    raise ValueError(f"{move!r} is not valid (use U, D, L or R)")

                for letter in move:
                    if step >= max_steps:
                        break  # a multi-letter move must not overshoot max_steps
                    result, map_ = map_.move_snake(letter, rng)
                    movements += letter
                    step += 1

                    if result == 1:
                        score += 1

                    if result == -1:
                        frames.append(SnakeFrame(
                            map_, step, score, movements,
                            "Game over! Snake collided with itself."))
                        return frames

                    frames.append(SnakeFrame(map_, step, score, movements))

                movements += " "

            except Exception as exc:
                frames.append(SnakeFrame(
                    map_, step, score, movements, f"{exc}"))
                break

        return frames

    @staticmethod
    def display_frames_ascii(frames, interval):
        """Show recorded frames as ASCII boards with a Play button and a slider.

        Args:
            frames: non-empty list of ``SnakeFrame`` (e.g. from ``interact_loop``).
            interval: milliseconds between frames while playing.

        Raises:
            ValueError: if ``frames`` is empty.
        """
        # `display` is otherwise only available as an IPython-kernel builtin.
        from IPython.display import display

        if not frames:
            raise ValueError("frames must contain at least one SnakeFrame")
        last = len(frames) - 1

        slider = widgets.IntSlider(
            value=0, min=0, max=last, step=1,
            description="step", readout=True, layout=widgets.Layout(width="50%")
        )
        play = widgets.Play(
            value=0, min=0, max=last, step=1,
            interval=interval,              # milliseconds between frames
            description="Play",             # button text
            disabled=False
        )

        # keep the two widgets in sync
        widgets.jslink((play, 'value'), (slider, 'value'))

        out = widgets.Output()

        def _show(change):
            """Redraw `out` with the frame whose index is `change['new']`."""
            with out:
                out.clear_output(wait=True)
                f = frames[change['new']]
                print(f"Steps: {f.step}  Score: {f.score}")
                print(f.movements)
                print(f.map._render())
                print(f.info)

        slider.observe(_show, names='value')      # respond to slider (and play)
        # observe() only fires on changes, so draw the first frame explicitly.
        _show({'new': slider.value})

        controls = widgets.HBox([play, slider])
        display(widgets.VBox([controls, out]))

    @staticmethod
    def animate_frames(frames, interval=500, max_movements=30):
        """
        Turn the recorded frames into a matplotlib.animation.FuncAnimation.
        Returns the Animation object so that Jupyter/Colab can display it.

        Args:
            frames: non-empty list of ``SnakeFrame`` (e.g. from ``interact_loop``).
            interval: milliseconds between frames.
            max_movements: the movement history in the title is cut to its
                last ``max_movements`` characters, prefixed with "...".

        Raises:
            ValueError: if ``frames`` is empty.
        """
        if not frames:
            raise ValueError("frames must contain at least one SnakeFrame")

        # colour map: 0-empty, 1-head, 2-body, 3-fruit
        cmap = ListedColormap(["white",      # 0 empty
                               "#1e90ff",    # 1 head: dodger blue
                               "#87cefa",    # 2 body: light sky-blue
                               "#63ff47"])   # 3 fruit: bright green

        fig, ax = plt.subplots(figsize=(4, 4))
        im = ax.imshow(frames[0].map.to_array(), cmap=cmap,
                       vmin=0, vmax=3, interpolation="nearest")
        ax.set_xticks([])
        ax.set_yticks([])

        def update(idx):
            """Draw frame `idx`: board colours, title with score, xlabel with info."""
            f = frames[idx]

            mov = f.movements
            if len(mov) > max_movements:
                # Same limit for the test and the slice, so "..." always means
                # text was actually dropped.
                mov = "..." + mov[len(mov) - max_movements:]

            ax.set_title(f"Step {f.step}  Score {f.score}\n{mov}")
            ax.set_xlabel(f.info)

            im.set_data(f.map.to_array())
            return (im,)

        ani = animation.FuncAnimation(fig, update,
                                      frames=len(frames),
                                      interval=interval,
                                      blit=False, repeat=False)

        plt.close(fig)          # keep the raw figure from showing twice
        return ani

# CREATE YOUR OWN SNAKE STRATEGY

Your task is to create a Python function that plays the Snake Game for you.
Try to eat as many fruits as possible without running into your own body. The board wraps around: leaving one edge brings the snake back in on the opposite edge.

Implement your strategy in the function `my_snake`.
It receives the current map state (a `SnakeMap`) and must return one of 4 possible actions:
- `U`P
- `D`OWN
- `L`EFT
- `R`IGHT

In [None]:
#@title MY STRATEGY
def my_snake(map):
    """Choose the snake's next move.

    Args:
        map: the current ``SnakeMap``. The parameter is named ``map`` to
            match the template, even though that shadows the builtin.

    Returns:
        One of ``"U"``, ``"D"``, ``"L"`` or ``"R"``. This placeholder always
        returns ``"U"``; replace it with your own strategy.
    """
    # IMPLEMENT YOUR CODE HERE
    next_move = "U"  # Example: always move up
    return next_move

In [None]:
#@title PARAMETERS
# Game length and playback speed
MAX_STEPS = 50  # stop the simulation after this many moves
INTERVAL = 500  # animation delay between frames, in milliseconds

# Board set-up
N = 4           # side length of the square board (N x N cells)
SEED = 42       # seed for fruit placement, so runs are reproducible

In [None]:
#@title RUNNING THE GAME
# Simulate the whole game with `my_snake`, then render it as an HTML animation
# (the last expression of the cell is what the notebook displays).
game = SnakeGame(seed=SEED, initial_map=SnakeMap(n=N))
frames = game.interact_loop(my_snake, max_steps=MAX_STEPS)
HTML(SnakeGame.animate_frames(frames, interval=INTERVAL).to_jshtml())