# Snake Game

This is text

In [2]:
# standard library
import sys;
import os;
import time;
import asyncio;

# scientific
import numpy as np
%matplotlib inline
from matplotlib import pyplot as plt
import IPython
import ipywidgets

# machine learning
import gymnasium as gym

# add project files to path
%load_ext autoreload
%autoreload 2
sys.path.append(os.path.relpath('..'))
import gym_snakegame

  from pkg_resources import resource_stream, resource_exists


### Setup

In [17]:
# register custom snake environment
env = gym.make("gym_snakegame/SnakeGame-v0", board_size=10, n_channel=1, n_target=1, render_mode='human')

# check environment validity (optional)
# https://gymnasium.farama.org/introduction/create_custom_env/#check-environment-validity
from gymnasium.utils.env_checker import check_env
try:
    check_env(env)
    print("Environment passes all checks!")
except Exception as e:
    print(f"Environment has issues: {e}")

  logger.warn(


Environment passes all checks!


  logger.warn(


In [31]:
env = gym.make("gym_snakegame/SnakeGame-v0", board_size=10, n_channel=1, n_target=1, render_mode='ansi')
out = ipywidgets.Output(layout={'border': '1px solid black'})
display(out)

with out:
    obs, info = env.reset()
    for i in range(10000):
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        
        frame = env.render()
        print(frame)
        time.sleep(0.08)
        IPython.display.clear_output(wait=True)

        if terminated or truncated:
            obs, info = env.reset()
    env.close()

Output(layout=Layout(border_bottom='1px solid black', border_left='1px solid black', border_right='1px solid b…

## `SnakeAgent`

In [4]:
# based on https://gymnasium.farama.org/introduction/train_agent/#../tutorials/training_agents

import collections

class SnakeAgent:
  def __init__(
      self,
      env: gym.Env,
      learning_rate: float,
      initial_epsilon: float,
      epsilon_decay: float,
      epsilon_minimum: float,
      discount_factor: float = 0.95,
  ):
    """
    Initialize a Q-Learning agent.

    Args:
      `env`:             The training environment
      `learning_rate`:   How quickly to update Q-values (0-1)
      `initial_epsilon`: Starting exploration rate (usually 1.0)
      `epsilon_decay`:   How much to reduce epsilon each episode
      `epsilon_minimum`: Minimum exploration rate (usually 0.1)
      `discount_factor`: How much to value future rewards (0-1)
    """

    # keep a reference to the training environment
    self.env = env;

    # learning / exploration rates
    self.learning_rate = learning_rate;
    self.discount_factor = discount_factor;

    self.epsilon = initial_epsilon;
    self.epsilon_decay = epsilon_decay;
    self.epsilon_minimum = epsilon_minimum;

    # the Q-Table maps (state, action) pairs to expected reward
    # defaultdict automatically creates entries with zeros for new states
    num_actions = env.action_space.n;
    self.q_values = collections.defaultdict(lambda: np.zeros(num_actions));

    # track learning progress
    self.training_error = []
  
  def get_action(self, obs: tuple[int, int, bool]) -> int:
    """
    Choose an action using an epsilon-greedy strategy.

    Returns:
      `action`: Left / Right / Up / Down
    """

    if np.random.random() < self.epsilon:
      # EXPLORE, with probability epsilon
      return self.env.action_space.sample()
    else:
      # otherwise EXPLOIT!
      return int(np.argmax(self.q_values[obs]))
  
  def update(
    self,
    obs: tuple[int, int, bool],
    action: int,
    reward: float,
    terminated: bool,
    next_obs: tuple[int, int, bool] # what is next_obs?
  ):
    """
    Update Q-value based on experience.

    Args
      (`obs`, `action`): The current state and chosen action.
      `reward`: The reward received after taking the `action`.
      `terminated`: Whether the action caused termination.
      `next_obs`: Next observation after taking `action`.
    """

    # estimate our best expected reward from the next state
    if terminated:
      # no future rewards possible if we're terminated!
      future_q_value = 0;
    else:
      # look for the maximum possible reward from this state according to q-function
      future_q_value = np.max(self.q_values[next_obs])

    target_q_value = reward + self.discount_factor * future_q_value;

    # temporal difference
    temporal_difference = target_q_value - self.q_values[obs][action]

    # update our estimate in the direction of the error
    # learning rate controls step size
    self.q_values[obs][action] = (
      self.q_values[obs][action] + self.learning_rate * temporal_difference
    )

    # track learning progress (useful for debugging)
    self.training_error.append(temporal_difference);

In [8]:
# Training hyperparameters
learning_rate = 0.01        # How fast to learn (higher = faster but less stable)
n_episodes = 100_000        # Number of hands to practice
start_epsilon = 1.0         # Start with 100% random actions
epsilon_decay = start_epsilon / (n_episodes / 2)  # Reduce exploration over time
epsilon_minimum = 0.1         # Always keep some exploration

# Create environment and agent
env = gym.make("gym_snakegame/SnakeGame-v0", board_size=6, n_channel=1, n_target=1, render_mode='rgb_array')
env = gym.wrappers.RecordEpisodeStatistics(env, buffer_length=n_episodes)

agent = SnakeAgent(
    env=env,
    learning_rate=learning_rate,
    initial_epsilon=start_epsilon,
    epsilon_decay=epsilon_decay,
    epsilon_minimum=epsilon_minimum,
)

from tqdm import tqdm  # Progress bar

for episode in tqdm(range(n_episodes)):
    # Start a new hand
    obs, info = env.reset()
    done = False

    # Play one complete hand
    while not done:
        # Agent chooses action (initially random, gradually more intelligent)
        action = agent.get_action(obs)

        # Take action and observe result
        next_obs, reward, terminated, truncated, info = env.step(action)

        print(next_obs)

        # Learn from this experience
        agent.update(obs, action, reward, terminated, next_obs)

        # Move to next state
        done = terminated or truncated
        obs = next_obs

    # Reduce exploration rate (agent becomes less random over time)
    agent.decay_epsilon()

  0%|          | 0/100000 [00:00<?, ?it/s]

[[[ 0  0 37  0  0  0]
  [ 0  0  0  0  0  0]
  [ 0  0  0  1  0  0]
  [ 0  0  3  2  0  0]
  [ 0  0  0  0  0  0]
  [ 0  0  0  0  0  0]]]





TypeError: unhashable type: 'numpy.ndarray'

## Extra

In [33]:
# %matplotlib widget

# env = gym.make("gym_snakegame/SnakeGame-v0", board_size=10, n_channel=1, n_target=1, render_mode='rgb_array')

# fig,ax = plt.subplots(1,1)
# hdisplay = IPython.display.display("", display_id=True)

# obs, info = env.reset()
# for i in range(10000):
#     time.sleep(0.01)

#     action = env.action_space.sample()
#     obs, reward, terminated, truncated, info = env.step(action)
#     frame = env.render()

#     ax.imshow(frame);
#     hdisplay.update(fig);

#     if terminated or truncated:
#         obs, info = env.reset()

# env.close()
# plt.close(fig)

In [34]:
# %matplotlib widget

# from IPython import display

# # https://stackoverflow.com/a/65400882
# def pltsin(ax, *,hdisplay, colors=['b']):
#     x = np.linspace(0,1,100)
#     if ax.lines:
#         for line in ax.lines:
#             line.set_xdata(x)
#             y = np.random.random(size=(100,1))
#             line.set_ydata(y)
#     else:
#         for color in colors:
#             y = np.random.random(size=(100,1))
#             ax.plot(x, y, color)
#     hdisplay.update(fig)


# fig,ax = plt.subplots(1,1)
# hdisplay = IPython.display.display("", display_id=True)

# ax.set_xlabel('X')
# ax.set_ylabel('Y')
# ax.set_xlim(0,1)
# ax.set_ylim(0,1)
# for f in range(5):
#     pltsin(ax, colors=['b', 'r'], hdisplay=hdisplay)
#     time.sleep(1)
    
# plt.close(fig)

In [35]:
# import matplotlib.pyplot as plt
# from matplotlib import animation
# from IPython.display import HTML

# # (Your code to create the figure, axes, and initial plot objects)
# # ...

# # Define the animation function that updates the plot for each frame
# def animate(i):
#     # Update plot objects (e.g., line data, text) based on frame 'i'
#     # ...
#     return (line1, line2, txt_title) # Return the objects that were modified

# # Create the animation object
# anim = animation.FuncAnimation(
#     fig, animate, frames=100, interval=20, blit=True
# )

# # Display the animation as an HTML5 video
# HTML(anim.to_jshtml())

In [36]:
# %matplotlib widget
# import matplotlib.pyplot as plt
# import matplotlib.animation as animation
# import numpy as np
# from IPython.display import HTML
# plt.rcParams["animation.html"] = "jshtml"
# plt.ioff() #needed so the second time you run it you get only single plot

# fig, ax = plt.subplots()

# x = np.arange(0, 2*np.pi, 0.1)
# line, = ax.plot(x, np.sin(x))
# z = x.size

# def animate(i):
#     line.set_ydata(np.sin(x - 2*np.pi*i / z)) 
#     return line,

# ani = animation.FuncAnimation(
#     fig, animate,
#     frames = z,
#     blit=True)
# ani