<a href="https://colab.research.google.com/github/dcownden/PerennialProblemsOfLifeWithABrain/blob/gw-refactor/utils/base_gridworld.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The following is part of a draft of an upcoming textbook on computational neuroscience from an optimization and learning perspective. The book will start with evolution because ultimately, all aspects of the brain are shaped by evolution and, as we will see, evolution can also be seen as an optimization algorithm. We are sharing it now to get feedback on what works, what does not, and what we should develop further.

___
# **Sequence 1.1.1: Gridworld Introduction**

### Objective: In this sequence, we will create a simple environment-organism system to demonstrate how an organism's **behaviour**, within an **environment**, can be evaluated using **rewards**. We will also see how intelligent behaviour can lead to better outcomes and how **randomness** can make evaluation of behaviour more difficult.

# Setup
Run the following cell to set up and install the various dependencies and helper functions for this sequence.

In [None]:
# @title Dependencies, Imports and Setup
# @markdown You don't need to worry about how this code works – but you do need to **run the cell**

!pip install ipympl vibecheck datatops jupyterquiz > /dev/null 2> /dev/null #google.colab

import logging
import os
import random
import time
from copy import copy
from enum import Enum
from pathlib import Path

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IPython.display import display, clear_output, Markdown
from jupyterquiz import display_quiz
from scipy.spatial.distance import cdist
from tabulate import tabulate
from vibecheck import DatatopsContentReviewContainer




# random seed settings and
# getting torch to use gpu if it's there


def set_seed(seed=None, seed_torch=True):
  """
  Seed Python's `random`, NumPy's global RNG and, optionally, PyTorch.

  Args:
    seed (int or None): non-negative seed. When None, a seed in [0, 2**32) is
      drawn from NumPy's global RNG.
    seed_torch (bool): when True, also seed the torch CPU and CUDA generators
      and put cuDNN into deterministic, non-benchmarking mode. Default True.

  Returns:
    None. The seed that was used is printed.
  """
  chosen = np.random.choice(2 ** 32) if seed is None else seed
  for seeder in (random.seed, np.random.seed):
    seeder(chosen)
  if seed_torch:
    torch.manual_seed(chosen)
    torch.cuda.manual_seed_all(chosen)
    torch.cuda.manual_seed(chosen)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
  print(f'Random seed {chosen} has been set.')


def seed_worker(worker_id):
  """
  Re-seed NumPy and `random` inside a DataLoader worker process.

  The seed is derived from torch's per-worker initial seed, reduced into the
  32-bit range NumPy accepts.

  Args:
    worker_id (int): ID of the subprocess being seeded (unused here; required
      by DataLoader's `worker_init_fn` signature). See
      https://pytorch.org/docs/stable/data.html#data-loading-randomness

  Returns:
    None
  """
  derived_seed = torch.initial_seed() % 2**32
  for seeder in (np.random.seed, random.seed):
    seeder(derived_seed)


def set_device():
  """
  Pick the torch device for this notebook and report it.

  Returns:
    "cuda" if torch can see a GPU, otherwise "cpu" (after printing a hint on
    how to enable a GPU runtime).
  """
  if torch.cuda.is_available():
    print("GPU is enabled in this notebook.")
    return "cuda"
  print("WARNING: For this notebook to perform best, "
      "if possible, in the menu under `Runtime` -> "
      "`Change runtime type.`  select `GPU` ")
  return "cpu"


# Global seed: fixes the Python/NumPy/torch RNGs now and is also the default
# seed of each board's numpy Generator.
SEED = 2021
set_seed(SEED)
DEVICE = set_device()  # "cuda" when a GPU is visible, otherwise "cpu"

def printmd(string):
  """Render `string` as Markdown in the notebook output area."""
  rendered = Markdown(string)
  display(rendered)


# the different utility .py files used in this notebook
filenames = ['gw_plotting.py', 'gw_board.py', 'gw_game.py', 'gw_widgets.py',
             'gw_NN_RL.py']
# A directory intended for these utility .py files. Note that the loop below
# exec's the downloaded code directly and does not write it into this folder.
Path('utils').mkdir(parents=True, exist_ok=True)
# Get the IPython interactive shell. It is not used in this cell; it is
# presumably kept for the downloaded code or later cells (TODO confirm).
ipython = get_ipython()

for filename in filenames:
  url = f'https://raw.githubusercontent.com/dcownden/PerennialProblemsOfLifeWithABrain/main/utils/{filename}'
  # A timeout keeps a stalled connection from hanging the setup cell forever.
  # Network errors are reported the same way as a bad status code instead of
  # aborting the whole cell.
  try:
    response = requests.get(url, timeout=30)
  except requests.RequestException as err:
    print(f'Failed to download {url} ({err})')
    continue
  # Check that we got a valid response
  if response.status_code == 200:
    code = response.content.decode()
    # SECURITY NOTE: this executes code fetched over the network, with the
    # notebook's full privileges, in this module's global namespace. It is
    # only acceptable because the source is the book's own repository.
    # Consider pinning a commit hash instead of the moving `main` branch.
    exec(code)
  else:
    print(f'Failed to download {url}')

# environment contingent imports
# Detect Colab by trying its module. Only the import is guarded, so the
# "Running in colab" message is printed only after the import has succeeded.
try:
  from google.colab import output
except ImportError:
  IN_COLAB = False
  print('Not running in colab')
else:
  IN_COLAB = True
  print('Running in colab')
  # let ipywidgets / ipympl widgets render inside Colab output cells
  output.enable_custom_widget_manager()

# NOTE(review): `%matplotlib inline` and the retina InlineBackend setting appear
# to be superseded by `%matplotlib widget` two lines later. They may be a
# deliberate Colab/ipympl workaround, so confirm before removing.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%matplotlib widget
# Shared plotting style for the book, fetched from the repository at run time
plt.style.use("https://raw.githubusercontent.com/dcownden/PerennialProblemsOfLifeWithABrain/main/pplb.mplstyle")
plt.ioff() #need to use plt.show() or display explicitly
# Only show ERROR-level (or worse) messages from matplotlib's font manager
logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR)


def content_review(notebook_section: str):
  """
  Build and render the feedback widget for one section of the notebook.

  Args:
    notebook_section (str): identifier of the section the feedback refers to.

  Returns:
    Whatever DatatopsContentReviewContainer.render() returns.
  """
  endpoint = {
    "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
    "name": "neuro_book",
    "user_key": "xuk960xj",
  }
  no_prompt = ""
  container = DatatopsContentReviewContainer(no_prompt, notebook_section, endpoint)
  return container.render()
feedback_prefix = 'P1C1_S1'  # section tag for feedback; not referenced in the cells shown here

In [None]:
# @title plotting functions
#################################################
# More plotting functions
#################################################


def plot_directions(fig, ax, loc_prob_dict, critter, deterministic=False,
                    name=None):
  """
  Plot a vector field of arrows showing a critter's direction probabilities.

  Args:
    fig, ax (matplotlib objects): Figure and axes objects for plotting.
    loc_prob_dict (dict): maps (row, col) location tuples to dicts that map a
      direction name ('right', 'down', 'left', 'up') to its probability.
    critter (int): 0, 1 or 2. Selects arrow colour and offset. Critters 1 and
      2 are offset in opposite directions so both can share one axes.
    deterministic (bool, optional): If True, arrows for the most probable
      direction(s) at each location are fully opaque and all others are
      hidden. Arrows are also drawn thicker and longer. Defaults to False.
    name (str, optional): legend label for critters 1 and 2. Defaults to
      'Critter {critter}'.

  Returns:
    fig, ax, and a list of legend handles (empty when critter == 0).
  """
  # With quiver's default angles='uv', U/V are drawn in screen orientation,
  # so the inverted y-axis does not affect them and 'down' is V = -1 ...
  direction_vectors = {'right': (1, 0), 'down': (0, -1),
                       'left': (-1, 0), 'up': (0, 1)}
  # ... but arrow positions are in data coordinates, so these offsets must
  # account for the inverted y-axis ('down' is +row).
  direction_offsets = {'right': (0.1, 0), 'down': (0, 0.1),
                       'left': (-0.1, 0), 'up': (0, -0.1)}
  # Offsets for critters 1 and 2 (drawn together); 0 is drawn by itself
  critter_offsets = {0: (0, 0), 1: (-0.05, -0.05), 2: (0.05, 0.05)}
  critter_colors = {0: 'black', 1: 'red', 2: 'blue'}
  critter_offset = critter_offsets[critter]
  critter_color = critter_colors[critter]

  # Legend entry only for critters drawn alongside another critter
  custom_leg_handles = []
  if critter != 0:
    # local import: matplotlib.patches is not imported at file level
    import matplotlib.patches as mpatches
    if name is None:
      name = f'Critter {critter}'
    custom_leg_handles.append(mpatches.Patch(color=critter_color, label=name))

  C, R, U, V, A = [], [], [], [], []
  for (row, col), probs in loc_prob_dict.items():
    # max computed once per location (and only when needed; an empty probs
    # dict simply contributes no arrows)
    best = max(probs.values()) if (deterministic and probs) else None
    for dir_key, prob in probs.items():
      dx, dy = direction_offsets[dir_key]
      C.append(col + critter_offset[0] + dx)
      R.append(row + critter_offset[1] + dy)
      U.append(direction_vectors[dir_key][0])
      V.append(direction_vectors[dir_key][1])
      if deterministic:
        A.append(1 if prob == best else 0)
      else:
        A.append(prob)

  linewidth = 1.5 if deterministic else 0.5
  # smaller quiver scale -> longer arrows
  scale = 15 if deterministic else 30

  ax.quiver(C, R, U, V, alpha=A, color=critter_color,
            scale=scale, linewidth=linewidth)
  return fig, ax, custom_leg_handles

def make_grid(num_rows, num_cols, figsize=(7,6), title=None):
  """Plots a num_rows by num_cols grid with cells centered on integer indices
  and returns fig and ax handles for further use.

  Row 0 is drawn at the top (the y-axis is inverted) and column labels run
  along the top edge, matching the [row, col] layout of a numpy array.

  Args:
    num_rows (int): number of rows in the grid (vertical dimension)
    num_cols (int): number of cols in the grid (horizontal dimension)
    figsize (tuple of float, optional): figure size in inches. Defaults to
      (7, 6).
    title (str, optional): figure title (suptitle). None sets no title.

  Returns:
    fig (matplotlib.figure.Figure): figure handle for the grid
    ax: (matplotlib.axes._axes.Axes): axes handle for the grid
  """
  fig, ax = plt.subplots(figsize=figsize, layout='constrained')
  # 4pt padding around the axes, no extra spacing between subplots
  fig.get_layout_engine().set(w_pad=4 / 72, h_pad=4 / 72, hspace=0, wspace=0)
  # Show right and top borders (spines) so the grid has a closed outline
  ax.spines[['right', 'top']].set_visible(True)
  # Major ticks and their labels sit on the cell centres (integer indices)
  ax.set_xticks(np.arange(0, num_cols, 1))
  ax.set_yticks(np.arange(0, num_rows, 1))
  ax.set_xticklabels(np.arange(0, num_cols, 1), fontsize=8)
  ax.set_yticklabels(np.arange(0, num_rows, 1), fontsize=8)
  # Minor ticks sit on the interior cell boundaries; grid lines are drawn there
  ax.set_xticks(np.arange(0.5, num_cols-0.5, 1), minor=True)
  ax.set_yticks(np.arange(0.5, num_rows-0.5, 1), minor=True)
  # Move x-axis ticks and labels to the top of the plot
  ax.xaxis.tick_top()
  # Solid, half-transparent grey grid lines on the minor (boundary) ticks
  ax.grid(which='minor', color='grey', linestyle='-', linewidth=2, alpha=0.5)
  # Hide the minor tick marks but keep their grid lines. tick_top() above
  # turned the x minor ticks on along the *top* edge, so switch top off too.
  ax.tick_params(which='minor', top=False, bottom=False, left=False,
                 right=False)
  # Axis limits span exactly the outer cell edges
  ax.set_xlim((-0.5, num_cols-0.5))
  ax.set_ylim((-0.5, num_rows-0.5))
  # Invert y so row 0 is at the top, as in a printed numpy array
  ax.invert_yaxis()
  if title is not None:
    fig.suptitle(title)
  # Hide the widget canvas header/footer/toolbar and disable resizing
  fig.canvas.header_visible = False
  fig.canvas.toolbar_visible = False
  fig.canvas.resizable = False
  fig.canvas.footer_visible = False
  # Redraw the figure with these settings
  fig.canvas.draw()
  return fig, ax


def plot_food(fig, ax, rc_food_loc, food=None):
  """
  Plots "food" on a grid implied by the given fig, ax arguments

  Args:
    fig, ax: matplotlib figure and axes objects
    rc_food_loc: array-like of int (row, col) pairs, shape (num_food, 2).
      An empty sequence is allowed and clears all food markers.
    food: a handle for the existing food matplotlib PathCollection object
      if one exists
  Returns:
    a handle for matplotlib PathCollection object of food scatter plot, either
    new if no handle was passed or updated if it was
  """
  if food is None:
    food = ax.scatter([], [], s=150, marker='o', color='red', label='Food')
  # reshape so that "no food" still gives a (0, 2) array; np.fliplr raises on
  # the 1-D array that np.array([]) would otherwise produce
  rc_food_loc = np.array(rc_food_loc, dtype=int).reshape(-1, 2)
  # matrix indexing convention is [row-vertical, col-horizontal]
  # plotting indexing convention is (x-horizontal, y-vertical), hence flip
  food.set_offsets(np.fliplr(rc_food_loc))
  return food


def plot_critters(fig, ax, critter_specs: "list[dict[str, object]]") -> "list[dict[str, object]]":
  """
  Plots multiple types of "critters" on a grid implied by the given
  fig, ax arguments.

  Args:
    fig, ax: matplotlib figure and axes objects.
    critter_specs: list of dicts. Required keys are 'marker', 'color', 'name'
      and 'rc_loc' (a (row, col) pair or an (N, 2) array of pairs). The
      optional key 'handle' holds an existing scatter PathCollection to update
      instead of creating a new one.

  Returns:
    The same list, with every spec's 'handle' set (specs are updated in place).

  Raises:
    ValueError: if a spec is missing a required key.
  """
  # Annotations are strings so typing names need not be in scope at def time
  required_keys = ('marker', 'color', 'name', 'rc_loc')
  for spec in critter_specs:
    for key in required_keys:
      if key not in spec:
        raise ValueError(f"Key '{key}' missing in critter spec.")
    handle_ = spec.get('handle')
    if handle_ is None:
      handle_ = ax.scatter([], [], s=250, marker=spec['marker'],
                           color=spec['color'], label=spec['name'])
    # (row, col) -> (x, y). Flip only the last axis so that, for several
    # locations, the order of the points is preserved.
    handle_.set_offsets(np.flip(np.asarray(spec['rc_loc']), axis=-1))
    spec['handle'] = handle_
  return critter_specs


def plot_critter(fig, ax, rc_critter_loc,
                 critter=None, critter_name='Critter'):
  """
  Plots "critter" on a grid implied by the given fig, ax arguments

  Args:
    fig, ax: matplotlib figure and axes objects
    rc_critter_loc: int (row, col) pair, or array of shape
      (N:num_critters x 2:row,col)
    critter: a handle for the existing critter matplotlib PathCollection
      object if one exists
    critter_name (str): legend label used when a new scatter is created
  Returns:
    a handle for matplotlib PathCollection object of critter scatter plot,
    either new if no handle was passed in or updated if it was.
  """
  if critter is None:
    critter = ax.scatter([], [], s=250, marker='h',
                         color='blue', label=critter_name)
  # matrix indexing convention is [row-vertical, col-horizontal]
  # plotting indexing convention is (x-horizontal, y-vertical), hence flip.
  # Only the last axis is flipped so the order of several critters is kept.
  critter.set_offsets(np.flip(np.asarray(rc_critter_loc), axis=-1))
  return critter


def plot_fov(fig, ax, rc_critter, n_rows, n_cols, radius, has_fov,
             opaque=False, fov=None):
  """
  Plots a field-of-view mask on a grid implied by the given fig, ax arguments.

  Cells within Manhattan distance `radius` of the critter are left fully
  transparent. All other cells are covered in grey.

  Args:
    fig, ax: matplotlib figure and axes objects
    rc_critter: (row, col) of the critter (indexable with [0] and [1])
    n_rows, n_cols (int): grid dimensions
    radius: Manhattan-distance radius of the visible region
    has_fov (bool): if falsy, the whole mask is fully transparent
    opaque (bool): out-of-view cells fully opaque (True) or 50% opaque (False)
    fov: a handle for the existing mask matplotlib AxesImage if one exists
  Returns:
    a handle for the matplotlib AxesImage of the mask, either new if no handle
    was passed in or updated if it was.
  """
  # RGBA image, uniform grey; alpha (channel 3) is decided below
  mask_array = np.ones((n_rows, n_cols, 4))
  mask_array[:, :, :3] = 0.5
  if has_fov:
    mask_array[:, :, 3] = 1.0 if opaque else 0.5  # fully / half opaque
    rows = np.arange(n_rows)[:, np.newaxis]
    cols = np.arange(n_cols)[np.newaxis, :]
    # Manhattan distance of every cell to the critter, broadcast to the grid
    dist = np.abs(rows - rc_critter[0]) + np.abs(cols - rc_critter[1])
    # cells within the radius become fully transparent
    mask_array[dist <= radius, 3] = 0
  else:
    mask_array[:, :, 3] = 0

  if fov is None:
    # origin='lower' maps array row i to y = i, i.e. grid row i
    fov = ax.imshow(mask_array, origin='lower', zorder=2)
  else:
    fov.set_data(mask_array)

  return fov


def remove_ip_clutter(fig):
  """Hide the widget canvas header, toolbar and footer, disable resizing,
  then redraw.

  Args:
    fig: matplotlib figure whose canvas exposes these widget attributes.
  """
  canvas = fig.canvas
  for attr in ('header_visible', 'toolbar_visible', 'resizable', 'footer_visible'):
    setattr(canvas, attr, False)
  canvas.draw()

In [None]:
# @title BaseGridworldBoard class
#######################################################################
# extend GridworldGame class locally before integrating in shared utils
#######################################################################




class BaseGridworldBoard():
  """
  A collection methods and parameters common to all the gridworld variants used
  in this book, core game logic will be different for different variants, but
  some functionality like checking legal moves, managing state, etc. is shared.
  Parameterization is overloaded, so it works with all variants, but
  core game logic is re-written almost entirely and expects very specific
  parameter values in order to work

  board state is represented by primarily by pieces, scores, moves_left and is_over

  pieces is a batch x n_rows x n_cols numpy array positive integers are critter
  locations 0's are empty space and negative integers are food. Each critter is
  unique and executing it's own policy so they are non-fungible, whereas food
  (of the same type) is always the same, so there can and typically will be
  duplicates of negative integers in the pieces array, but never of positive
  integers

  For pieces first dim is batch, second dim row , third is col,
  so pieces[0][1][7] is the square in row 2, in column 8 of the first board in
  the batch of boards.

  scores is a batchsize x num_critters numpy array giving the scores for each
  critter on each board in the batch (note off by one indexing)

  moves_left is how many moves are left in the game, for each critter.
  Usually this will be the same for every critter in every batch, but may not
  be for some variants.

  is_over just tracks whether each game in each batch has concluded, this allows
  for probabalistic end times, not just deterministic end times based on moves left

  other variant can add additional elements like active player, foraging histories
  for players etc.

  Note:
    In 2d np.array first dim is row (vertical), second dim is col (horizontal),
    i.e. top left corner is (0,0), so take care when visualizing/plotting
    as np.array visualization is aligned with typical tensor notation but at odds
    with conventional plotting where (0,0) is bottom left, first dim, x, is
    horizontal, second dim, y, is vertical, so we use invert y-axis when plotting
    with matplotlib
  """

  class CritterFoodType(Enum):
    """Labels for the kinds of entity on the board: food, prey and predator.

    NOTE(review): not referenced by the methods visible in this cell; it is
    presumably used by game variants or subclasses. Confirm.
    """
    FOOD = "food"
    PREY = "prey"
    PREDATOR = "predator"

  # NOTE(review): not referenced by the methods visible in this cell. It is
  # presumably a fill value for padding arrays (e.g. views past the board
  # edge). Confirm.
  ARRAY_PAD_VALUE = -200

  def __init__(self, batch_size=1,
               n_rows=5, n_cols=5,
               num_foragars=1,
               num_predators=0,
               max_moves_taken=20,
               end_prob=0.05,
               moves_cost=True,
               foraging_costs=False,
               base_food_patch_prob = 0.4,
               food_patch_type_probs = [1.0],
               food_regen_probs = [1.0],
               forage_success_prob = [1.0],
               food_extinct_prob = [1.0],
               rng=None,
               state_elements = ['pieces', 'scores', 'is_over', 'moves_left'],
               init_board_state = None
               ):

    """Set the parameters of the game."""
    # size of the board/world
    self.n_rows = n_rows
    self.n_cols = n_cols
    self.batch_size = batch_size

    #number and type of critters on the board
    self.num_foragers = num_foragers
    self.num_predators = num_predators
    # foragers will be indicated by lower valued positive integers, predators
    # by higher valued intagers
    self.forager_predator_threshold = self.num_foragers
    self.num_critters = num_foragers + num_predators

    # end conditions can be deterministic or stochastic
    # one of moving, or eating or both might take time, e.g. eating might be
    # automatic and free after moving, conversely, moving might be free, but
    # eating count towards the session/episode ending, or both might
    self.max_moves_taken = max_moves_taken
    self.end_prob = end_prob
    self.moves_cost = moves_cost
    self.foraging_costs = foraging_costs

    # what proportion of the (non-critter occupied) patches contain food.
    self.food_patch_prob = food_patch_prob
    # what are the different types of food and what are their properties
    self.food_patch_type_probs = food_patch_type_probs
    self.food_regen_prob = food_regen_probs
    self.forage_success_probs = forage_success_prob
    self.food_extinct_prob = food_extinct_prob
    self.num_food_types = len(food_patch_type_probs)
    food_property_lists = [food_patch_type_probs, food_regen_probs, forage_success_prob, food_extinct_prob]
    if not all(len(lst) == len(lists_to_check[0]) for lst in food_property_lists):
      raise ValueError("All food type parameter lists must have the same length.")

    # reproducible stochasticity
    if rng is None:
      self.rng = np.random.default_rng(seed=SEED)
    else:
      self.rng = rng

    self.state_elements = state_elements

    # initialize the board
    if init_board_state is None:
      init_board_state = self.get_init_board_state()

    self.set_state(init_board_state)


  def init_loc(self, n_rows, n_cols, num, rng=None):
    """
    Samples random 2d grid locations without replacement, useful for placing
    critters and food on the board.

    Args:
      n_rows: int, number of rows in the grid
      n_cols: int, number of columns in the grid
      num:    int, number of samples to generate. Should throw an error if num > n_rows x n_cols
      rng:    instance of numpy.random's default rng. Used for reproducibility.

    Returns:
      int_loc: ndarray(int) of shape (num,), flat indices for a 2D grid flattened into 1D
      rc_index: tuple(ndarray(int), ndarray(int)), a pair of arrays with the first giving
        the row indices and the second giving the col indices. Useful for indexing into
        an n_rows by n_cols numpy array.
      rc_plotting: ndarray(int) of shape (num, 2), 2D coordinates suitable for matplotlib plotting
    """

    # Set up default random generator, use the boards default if none explicitly given
    if rng is None:
      rng = self.rng
    # Choose 'num' unique random indices from a flat 1D array of size n_rows*n_cols
    int_loc = rng.choice(n_rows * n_cols, num, replace=False)
    # Convert the flat indices to 2D indices based on the original shape (n_rows, n_cols)
    rc_index = np.unravel_index(int_loc, (n_rows, n_cols))
    # Transpose indices to get num x 2 array for easy plotting with matplotlib
    rc_plotting = np.array(rc_index).T
    # Return 1D flat indices, 2D indices for numpy array indexing and 2D indices for plotting
    return int_loc, rc_index, rc_plotting


  def get_init_board_state(self):
    """
    Set up a starting board state using the game parameters.
    Needs to be adapted for some variants.

    NOTE(review): a later `get_init_board_state` definition in this class
    body replaces this one when the class is created, so this version is
    currently never called. TODO: reconcile the two.

    Returns:
      dict with keys
        'moves_left': float array (batch_size,) filled with max_moves_taken
        'is_over':    bool array (batch_size,), all False
        'scores':     float array (batch_size, num_critters), all zero
        'pieces':     int array (batch_size, n_rows, n_cols) where critters
                      are 1..num_critters, food cells are -1, empty cells 0
    """
    state = {}
    state['moves_left'] = (np.ones(self.batch_size) *
                           self.max_moves_taken)
    state['is_over'] = np.zeros(self.batch_size, dtype=bool)
    state['scores'] = np.zeros((self.batch_size, self.num_critters))

    pieces = np.zeros((self.batch_size, self.n_rows, self.n_cols),
                      dtype=int)
    # number of food items per board: each non-critter cell holds food with
    # probability food_patch_prob
    init_food_nums = self.rng.binomial(self.n_rows * self.n_cols - self.num_critters,
                                       self.food_patch_prob, size=self.batch_size)
    for ii in range(self.batch_size):
      # sample critter and food cells together so they never overlap
      _, (rows, cols), _ = self.init_loc(
        self.n_rows, self.n_cols, init_food_nums[ii] + self.num_critters)
      # the first num_critters samples are critters, labelled 1..num_critters
      pieces[ii, rows[:self.num_critters], cols[:self.num_critters]] = (
          np.arange(1, self.num_critters + 1))
      # The remaining samples are food. Food of one type is fungible, so every
      # item gets the same value -1 (matching the `== -1` food checks in
      # execute_moves / execute_drift) rather than a unique id.
      # TODO: encode multiple food types once food_patch_type_probs is used.
      pieces[ii, rows[self.num_critters:], cols[self.num_critters:]] = -1
    state['pieces'] = pieces
    return state


  def set_state(self, board, check=False):
    """
    Load a board state dict into this board (every value is copied).

    NOTE(review): a later `set_state` definition in this class body replaces
    this one, so this version is currently never called. TODO: reconcile.

    Args:
      board (dict): must contain every key in self.state_elements. Each value
        must support .copy().
      check (bool): if True, first validate the shapes of 'pieces', 'scores',
        'moves_left' and 'is_over'.

    Raises:
      ValueError: on a shape mismatch (check=True) or a missing state key.
    """
    if check:
      # scores are per board, per critter (the old code misspelled
      # num_critters here and raised AttributeError instead of validating)
      expected_shapes = {
        'pieces': (self.batch_size, self.n_rows, self.n_cols),
        'scores': (self.batch_size, self.num_critters),
        'moves_left': (self.batch_size,),
        'is_over': (self.batch_size,),
      }
      for key, shape in expected_shapes.items():
        if board[key].shape != shape:
          raise ValueError(f"Invalid shape for '{key}'")
    for key in self.state_elements:
      if key not in board:
        raise ValueError(f"Key '{key}' not found in the provided board state.")
      setattr(self, key, board[key].copy())


  def get_state(self):
    """
    Return a snapshot of the board state.

    NOTE(review): a later `get_state` definition in this class body replaces
    this one, so this version is currently never called. TODO: reconcile.

    Returns:
      dict mapping every name in self.state_elements (by default 'pieces',
      'scores', 'is_over', 'moves_left') to a copy of that attribute. This
      mirrors the state_elements-driven set_state(board, check) above.
    """
    # Built from state_elements rather than a hard-coded key list, so variants
    # that add extra state elements are captured as well.
    return {key: getattr(self, key).copy() for key in self.state_elements}


  def __getitem__(self, index):
    return self.pieces[index]




  def get_init_board_state(self):
    """Set up starting board using game parameters"""
    #set moves_left and score
    self.moves_left = (np.ones(self.batch_size) *
                        self.max_moves_taken * self.num_critters)
    # each players move counts down the clock so making this a multiple of the
    # number of critters ensures every player gets an equal number of turns
    self.scores = np.zeros((self.batch_size, self.num_critters))
    # create an empty board array.
    self.pieces = np.zeros((self.batch_size, self.n_rows, self.n_cols))
    # Place critter and initial food items on the board randomly
    for ii in np.arange(self.batch_size):
      # num_food+num_critter because we want critter and food locations
      int_loc, rc_idx, rc_plot = self.init_loc(
        self.n_rows, self.n_cols, self.num_food+self.num_critters)
      # critter random start locations
      for c_ in np.arange(self.num_critters):
        self.pieces[(ii, rc_idx[0][c_], rc_idx[1][c_])] = c_ + 1
      # food random start locations
      self.pieces[(ii, rc_idx[0][self.num_critters:],
                   rc_idx[1][self.num_critters:])] = -1
    self.active_player = 0
    state = {'pieces': self.pieces.copy(),
             'scores': self.scores.copy(),
             'moves_left': self.moves_left.copy(),
             'active_player': copy(self.active_player)}
    return state


  def set_state(self, board):
    """ board is dictionary giving game state a triple of np arrays
      pieces:        numpy array (batch_size x n_rows x n_cols),
      scores:        numpy array (batch_size x num_critters)
      moves_left:   numpy array (batch_size)
      active_player: int
    """
    self.pieces = board['pieces'].copy()
    self.scores = board['scores'].copy()
    self.moves_left = board['moves_left'].copy()
    self.active_player = copy(board['active_player'])


  def get_state(self):
    """ returns a board state, which is a triple of np arrays
    pieces,       - batch_size x n_rows x n_cols
    scores,       - batch_size
    moves_left   - batch_size
    """
    state = {'pieces': self.pieces.copy(),
             'scores': self.scores.copy(),
             'moves_left': self.moves_left.copy(),
             'active_player': copy(self.active_player)}
    return state


  def __getitem__(self, index):
    return self.pieces[index]


  def execute_moves(self, moves, critter):
    """
    Updates the state of the board given the moves made by one critter.

    Args:
      moves: a 3-tuple of 1-d arrays each of length batch_size,
        the first array gives the specific board within the batch,
        the second array in the tuple gives the new row coord for each critter
        on each board and the third gives the new col coord.
      critter: int in [1, num_critters], the piece value of the moving
        critter. It must equal self.active_player + 1.

    Raises:
      ValueError: if `critter` is not the active player.
      AssertionError: if the food/critter bookkeeping checks fail. These are
        `assert` statements, so they are skipped under `python -O`.

    Notes:
      Assumes that there is exactly one valid move for each board in the
      batch of boards for the critter type given. i.e. it doesn't check for
      bounce/reflection on edges or with other critters, or for multiple move
      made on the same board. It only checks for eating food (cells == -1).
      On each board where food was eaten it adds exactly one new food item in
      a random empty cell. Invalid moves could lead to illegal teleporting
      behavior, critter duplication, or index out of range errors.

      Side effects: updates self.pieces and self.scores, decrements every
      entry of self.moves_left by 1, and advances self.active_player to the
      next critter (cyclically).
    """
    if critter-1 != self.active_player:
      # note critter is [1 to num_critter] inclusive so that it can be used
      # directly in where statements on pieces but self.active_player is
      # [0 to numcritter-1] inclusive so that it can be used directly in
      # indexing player lists
      raise ValueError("Warning! The critter moving is not the expected active player")
    #critters leave their spots
    self.pieces[self.pieces==critter] = 0
    #which critters have food in their new spots
    eats_food = self.pieces[moves] == -1
    # some critters eat and their scores go up
    # note critter is +int so need to -1 for indexing
    self.scores[:,critter-1] = self.scores[:,critter-1] + eats_food

    # Empty cells on a board where food was just eaten, once the critter has
    # moved: all cells minus the (num_food - 1) remaining food minus the
    # num_critters critters. This assumes each board held num_food food before.
    num_empty_after_eat = (self.n_rows*self.n_cols - self.num_food -
                           self.num_critters + 1) # +1 for the food just eaten
    # which boards in the batch had eating happen
    g_eating = np.where(eats_food)[0]
    # put critters in new positions
    self.pieces[moves] = critter
    if np.any(eats_food):
      # add random food to replace what is eaten
      # np.where lists matches in row-major order, i.e. grouped by board, with
      # exactly num_empty_after_eat candidates per eating board
      possible_new_locs = np.where(np.logical_and(
          self.pieces == 0, #the spot is empty
          eats_food.reshape(self.batch_size, 1, 1))) #food eaten on that board
      # one draw per eating board, in [0, num_empty_after_eat)
      food_sample_ = self.rng.choice(num_empty_after_eat,
                                     size=np.sum(eats_food))
      # shift the k-th draw into the k-th eating board's block of candidates
      food_sample = food_sample_ + np.arange(len(g_eating))*num_empty_after_eat
      assert np.all(self.pieces[(possible_new_locs[0][food_sample],
                                 possible_new_locs[1][food_sample],
                                 possible_new_locs[2][food_sample])] == 0)
      #put new food on the board
      self.pieces[(possible_new_locs[0][food_sample],
                   possible_new_locs[1][food_sample],
                   possible_new_locs[2][food_sample])] = -1
    self.moves_left = self.moves_left - 1
    # Sanity check: with num_food cells at -1 and each critter 1..num_critters
    # present once, every board must sum to -num_food + sum(1..num_critters).
    # The values are printed for debugging before the assert below fails.
    if not np.all(self.pieces.sum(axis=(1,2)) ==
                  ((self.num_food * -1) + np.sum(np.arange(self.num_critters)+1))):
      print(self.pieces.sum(axis=(1,2)))
      print(((self.num_food * -1) + np.sum(np.arange(self.num_critters)+1)))
    assert np.all(self.pieces.sum(axis=(1,2)) ==
                  ((self.num_food * -1) + np.sum(np.arange(self.num_critters)+1)))
    # next player's turn
    self.active_player = (self.active_player + 1) % (self.num_critters)


  def execute_drift(self, offset_probs, wrapping=False):
    """
    Drift the food on the board based on the given offsets probabilities.
    Collisions handled by checking possible new locations in a random order and
    cancelling moves that result in a collision.

    Parameters:
    - offset_probs: Probabilities corresponding to each offset, note implicit
    order dependence here


    Returns:
    - nothing, just updates self.pieces
    """
    # Check the length of offset_probs
    #if len(offset_probs) != 5:
    #    raise ValueError("offset_probs should be of length 5.")
    # Check if values are non-negative
    #if any(p < 0 for p in offset_probs):
    #    raise ValueError("All probabilities in offset_probs should be non-negative.")
    # Normalize the probabilities
    #offset_probs = np.array(offset_probs) / np.sum(offset_probs)
    # Convert offsets to a 2D numpy array
    possible_offsets = np.array([[ 0, -1,  0], # up
                                 [ 0,  1,  0], # down
                                 [ 0,  0, -1], # left
                                 [ 0,  0,  1], # right
                                 [ 0,  0,  0]]) # still
    batch_size, n_rows, n_cols = self.pieces.shape
    # original food locations
    food_locations = np.argwhere(self.pieces == -1)
    # Sample offsets for each food location
    num_food = food_locations.shape[0]
    sampled_offsets = possible_offsets[self.rng.choice(
        np.arange(possible_offsets.shape[0]),
        size=num_food, replace=True, p=offset_probs)]
    # Possible new food locations
    possible_new_locations = food_locations + sampled_offsets
    possible_wrap_row_indexes = self.rng.choice(np.arange(n_rows),
                                                size=num_food)
    possible_wrap_col_indexes = self.rng.choice(np.arange(n_cols),
                                                size=num_food)

    # Randomly iterate through the possible new locations
    random_order = np.random.permutation(num_food)
    for idx in random_order:
      g, r, c = possible_new_locations[idx]
      # Check if the new location is inside the boundaries of the board
      if 0 <= r < self.pieces.shape[1] and 0 <= c < self.pieces.shape[2]:
        # Check if the new location is empty or contains a critter
        if self.pieces[g, r, c] == 0:
          # Update the board
          old_g, old_r, old_c = food_locations[idx]
          self.pieces[g, r, c] = -1
          self.pieces[old_g, old_r, old_c] = 0
      elif wrapping == True:
        # If wrapping is on then food can drift off the edge of the board and
        # 'new' food will appear in a random loc on the opposite side
        # Determine the opposite edge
        if r < 0:  # Top edge
          opposite_r = n_rows - 1
          opposite_c = possible_wrap_col_indexes[idx]
        elif r >= n_rows:  # Bottom edge
          opposite_r = 0
          opposite_c = possible_wrap_col_indexes[idx]
        elif c < 0:  # Left edge
          opposite_c = n_cols - 1
          opposite_r = possible_wrap_row_indexes[idx]
        elif c >= n_cols:  # Right edge
          opposite_c = 0
          opposite_r = possible_wrap_row_indexes[idx]

        # Check if the opposite location is unoccupied
        if self.pieces[g, opposite_r, opposite_c] == 0:
          old_g, old_r, old_c = food_locations[idx]
          self.pieces[g, opposite_r, opposite_c] = -1
          self.pieces[old_g, old_r, old_c] = 0


  def get_legal_moves(self, critter):
    """
    Identifies all legal one-step moves for the critter on every board.

    Candidate moves are one cell down, up, right or left of each cell that
    holds `critter`. A candidate is legal when it lands on the board and on a
    cell whose value is <= 0 (empty or food), i.e. not on another critter.
    Candidates that would leave the board are simply discarded; there is no
    bouncing, reflection or wrapping at the edges.

    Args:
      critter: int, the value marking this critter in self.pieces.

    Returns:
      A set of (g, r, c) tuples, one per legal destination cell, where g is
      the index of the board in the batch. Staying put is not included. The
      set is empty if the critter is not on any board.
    """
    _, n_rows, n_cols = self.pieces.shape
    # (k, 3) array of (g, r, c) for every cell holding this critter
    critter_locs = np.argwhere(self.pieces == critter)
    # can only move one cell down, up, right, or left from current location
    step_offsets = np.array([[0,  1,  0],   # down
                             [0, -1,  0],   # up
                             [0,  0,  1],   # right
                             [0,  0, -1]])  # left
    # every (offset, location) pair -> (4 * k, 3) candidate destinations;
    # broadcasting means k does not have to equal self.batch_size
    candidates = (step_offsets[:, np.newaxis, :] +
                  critter_locs[np.newaxis, :, :]).reshape(-1, 3)
    # must land on the board ...
    on_board = ((candidates[:, 1] >= 0) & (candidates[:, 1] < n_rows) &
                (candidates[:, 2] >= 0) & (candidates[:, 2] < n_cols))
    candidates = candidates[on_board]
    # ... and not on another critter (only cells <= 0 may be entered)
    unoccupied = self.pieces[candidates[:, 0], candidates[:, 1],
                             candidates[:, 2]] <= 0
    return {tuple(m_) for m_ in candidates[unoccupied]}


  def get_perceptions(self, radius, critter):
    """
    Generates a vector representation of the critter perceptions, oriented
    around the critter.

    Args:
      radius: int, how many grid squares the critter can see around it,
        using L1 (Manhattan/cityblock) distance
      critter: int, the value marking the focal critter in self.pieces

    Returns:
      An array (same dtype as self.pieces) of shape
      (num_critter_cells, 2*radius*(radius+1)): one row per cell holding
      `critter` (normally one per board, in board order), giving the values
      of the cells within L1 distance `radius` of the critter, excluding the
      critter's own cell, read left to right, top to bottom. Cells beyond the
      edge of the board read as -2.
    """
    # L1 ball mask over the (2r+1) x (2r+1) window centred on the critter
    abs_offsets = np.abs(np.arange(-radius, radius + 1))
    mask = (abs_offsets[:, np.newaxis] + abs_offsets[np.newaxis, :]) <= radius
    mask[radius, radius] = False  # exclude the center (the critter itself)
    n_cells = int(mask.sum())  # == 2*radius*(radius+1)

    # pad the array so windows near the edge stay in bounds; -2 marks "off board"
    padded_arr = np.pad(self.pieces, ((0, 0), (radius, radius),
                                      (radius, radius)), constant_values=-2)

    # get locations of critters (in padded coordinates)
    critter_locs = np.argwhere(padded_arr == critter)

    percepts = [padded_arr[b, r - radius:r + radius + 1,
                           c - radius:c + radius + 1][mask]
                for b, r, c in critter_locs]
    # reshape keeps a 2-D (0, n_cells) result when the critter is absent
    return np.array(percepts, dtype=padded_arr.dtype).reshape(len(percepts),
                                                              n_cells)



class PatchyForageBoard():
  """
  A collection of methods and parameters of a patchy foraging game board that
  define the logic of the game, and allow for multiple critters on the same
  board.

  Game state is represented primarily by food locations, forager locations,
  (eventually) predator locations, scores, and moves taken.
  Food patch locations are stored in a batch x n_rows x n_cols numpy array,
  self.pieces, where -1 marks a food patch and 0 an empty cell.
  Forager and predator (when we have them) locations are stored as
  dictionaries with integer keys 1, 2, 3, ... identifying each
  forager/predator, whose values are np.where style tuples of arrays
  (batch_array, row_array, col_array) giving the locations.

  scores is a batch_size x num_foragers numpy array giving the score of each
  critter on each board in the batch (note off by one indexing: critter k is
  stored in column k-1). moves_taken has the same shape and counts the moves
  each critter has used; a board's game ends once any critter reaches
  max_moves_taken, or at random with probability end_prob per move.

  NOTE(review): no __init__ is defined in this class, yet the methods read
  self.batch_size, self.n_rows, self.n_cols, self.num_foragers, self.rng,
  self.food_patch_prob, self.forage_success_prob, self.food_extinct_prob,
  self.end_prob, self.max_moves_taken and self.moves_cost. These must be set
  somewhere else before the board is used -- confirm where.

  Note:
    In 2d np.array first dim is row (vertical), second dim is col (horizontal),
    i.e. top left corner is (0,0), so take care when visualizing/plotting
    as np.array visualization inline with typical tensor notation but at odds
    with conventional plotting where (0,0) is bottom left, first dim, x, is
    horizontal, second dim, y, is vertical
  """

  # value given to cells outside the board when the board is padded
  # (see get_neighbor_grc_indices / get_perceptions)
  ARRAY_PAD_VALUE = -200


  def init_loc(self, n_rows, n_cols, num, rng=None):
    """
    Samples random 2d grid locations without replacement

    Args:
      n_rows: int, number of rows in the grid
      n_cols: int, number of columns in the grid
      num:    int, number of samples to generate. rng.choice raises a
              ValueError if num > n_rows x n_cols
      rng:    instance of numpy.random's default rng. Used for reproducibility.
              Defaults to the board's own self.rng.

    Returns:
      int_loc: ndarray(int) of shape (num,), flat indices for a 2D grid flattened into 1D
      rc_index: tuple(ndarray(int), ndarray(int)), a pair of arrays with the first giving
        the row indices and the second giving the col indices. Useful for indexing into
        an n_rows by n_cols numpy array.
      rc_plotting: ndarray(int) of shape (num, 2), 2D coordinates suitable for matplotlib plotting
    """
    # Set up default random generator, use the boards default if none explicitly given
    if rng is None:
      rng = self.rng
    # Choose 'num' unique random indices from a flat 1D array of size n_rows*n_cols
    int_loc = rng.choice(n_rows * n_cols, num, replace=False)
    # Convert the flat indices to 2D indices based on the original shape (n_rows, n_cols)
    rc_index = np.unravel_index(int_loc, (n_rows, n_cols))
    # Transpose indices to get num x 2 array for easy plotting with matplotlib
    rc_plotting = np.array(rc_index).T
    # Return 1D flat indices, 2D indices for numpy array indexing and 2D indices for plotting
    return int_loc, rc_index, rc_plotting


  def get_init_board_state(self):
    """
    Set up a fresh starting board using the game parameters.

    Every forager starts in the top left corner (row 0, col 0) of every
    board. Each cell independently holds a food patch with probability
    self.food_patch_prob, so the number of patches per board is binomially
    distributed. The board's attributes are reset in place.

    Returns:
      state (dict): a copy of the new board state (same format as get_state):
        - 'pieces': Food patch locations as a batch x row x col numpy array
          (-1 = food patch, 0 = empty).
        - 'scores': batch x num_foragers array of critter scores.
        - 'moves_taken': batch x num_foragers array of moves used so far.
        - 'is_over': length-batch bool array, True where the game has ended.
        - 'forager_locs': Dictionary of current locations of the foragers.
        - 'misses_new_patch': batch x num_foragers counts of misses at a
          patch where the critter has not yet foraged successfully.
        - 'misses_known_patch': batch x num_foragers counts of misses at a
          patch where the critter has already foraged successfully.
        - 'at_new_patch': batch x num_foragers booleans, True if the critter
          has not yet foraged successfully at its current location.
    """
    # note that is_over applies at the batch level not the batch x forager level
    self.is_over = np.zeros(self.batch_size, dtype=bool)
    self.moves_taken = np.zeros((self.batch_size, self.num_foragers), dtype=int)
    self.scores = np.zeros((self.batch_size, self.num_foragers), dtype=int)
    # create an empty board array for food locs
    self.pieces = np.zeros((self.batch_size, self.n_rows, self.n_cols),
                           dtype=int)
    # Place critters in top left corner of the board
    self.forager_locs = {}
    for c in (np.arange(self.num_foragers)+1):
      self.forager_locs[c] = (np.arange(self.batch_size, dtype=int),
                              np.zeros(self.batch_size, dtype=int),
                              np.zeros(self.batch_size, dtype=int))
    # Each grid cell has an independent prob of being a patch (to make the
    # math easier later) so the total number of patches on a board is
    # binomially distributed
    num_foods = self.rng.binomial(n=self.n_rows * self.n_cols,
                                  p=self.food_patch_prob,
                                  size=self.batch_size)
    for ii in np.arange(self.batch_size):
      _, rc_idx, _ = self.init_loc(self.n_rows, self.n_cols, num_foods[ii])
      # place each patch separately, in case we later want to have
      # different kinds of patches
      for row_, col_ in zip(*rc_idx):
        self.pieces[ii, row_, col_] = -1
    # keep track of which foragers have missed how many times
    # at what kind of patch
    self.misses_new_patch = np.zeros((self.batch_size, self.num_foragers), dtype=int)
    self.misses_known_patch = np.zeros((self.batch_size, self.num_foragers), dtype=int)
    self.at_new_patch = np.ones((self.batch_size, self.num_foragers), dtype=bool)
    return self.get_state()


  def set_state(self, board):
    """
    Sets the state given a board dictionary (as returned by get_state).
    Every entry is copied, so later changes to `board` do not affect the
    board object and vice versa.

    Args:
      board (dict):
      The board dictionary contains:
        - 'pieces': Current food patch locations as a batch x row x col numpy array.
        - 'scores': The current scores of the critters.
        - 'moves_taken': The number of moves each critter has used.
        - 'is_over': Flags indicating if the game is over for each board in the batch.
        - 'forager_locs': Dictionary of current locations of the foragers on the board.
        - 'misses_new_patch': Counts of missed attempts at new patches for each critter.
        - 'misses_known_patch': Counts of missed attempts at known patches for each critter.
        - 'at_new_patch': Booleans indicating if each critter is at a new patch.
    """
    self.pieces = board['pieces'].copy()
    self.forager_locs = copy.deepcopy(board['forager_locs'])
    self.moves_taken = board['moves_taken'].copy()
    self.scores = board['scores'].copy()
    self.is_over = board['is_over'].copy()
    self.misses_new_patch = board['misses_new_patch'].copy()
    self.misses_known_patch = board['misses_known_patch'].copy()
    self.at_new_patch = board['at_new_patch'].copy()


  def get_state(self):
    """
    Returns a copy of the current board state.

    Returns:
      state (dict):
      The state dictionary contains:
        - 'pieces': Current food patch locations as a batch x row x col numpy array.
        - 'scores': The current scores of the critters.
        - 'moves_taken': The number of moves each critter has used.
        - 'is_over': Flags indicating if the game is over for each board in the batch.
        - 'forager_locs': Dictionary of current locations of the foragers on the board.
        - 'misses_new_patch': Counts of missed attempts at new patches for each critter.
        - 'misses_known_patch': Counts of missed attempts at known patches for each critter.
        - 'at_new_patch': Booleans indicating if each critter is at a new patch.
    """
    state = {'pieces': self.pieces.copy(),
             'scores': self.scores.copy(),
             'moves_taken': self.moves_taken.copy(),
             'is_over': self.is_over.copy(),
             'forager_locs': copy.deepcopy(self.forager_locs),
             'misses_new_patch': self.misses_new_patch.copy(),
             'misses_known_patch': self.misses_known_patch.copy(),
             'at_new_patch': self.at_new_patch.copy()}
    return state


  ################# CORE GAME STATE UPDATE LOGIC ##############################
  ################# execute_moves is main, uses these helper functions ########


  def execute_moves(self, moves, which_critter):
    """
    Execute the moves on the board. A move to the current location implies
    foraging. If foraging, check if foraging is successful, update scores,
    and check if the food goes extinct. If moving to a new location, simply
    update the critter's location.

    Args:
      moves (tuple): A tuple of three arrays:
        - batch_array: Specifies which board in the batch the move corresponds to.
        - row_array: Specifies the target row for each move.
        - col_array: Specifies the target column for each move.
        Each array in the tuple has the same length. A move is represented by
        the combination of a batch index, row index, and column index at the
        same position in their respective arrays.
      which_critter (int): Index to identify the critter. Starts from 1.

    Returns: Nothing, just updates state related attributes of the board object.
      Boards whose game is already over are not scored, but
      forager_locs[which_critter] is overwritten with `moves` for the whole
      batch (moves are assumed to be legal).
    """
    #expand moves tuple
    batch_moves, row_moves, col_moves = moves
    # Get current locations of the critter
    current_locs = self.forager_locs[which_critter]
    # column of this critter in the per-forager arrays (off by one indexing)
    f_idx = which_critter - 1

    # Iterate over each board in the batch
    for ii in np.arange(self.batch_size):
      # If the game is over for this board, skip
      if self.is_over[ii]:
        continue

      # Get new location directly from the moves
      new_row = int(row_moves[ii])
      new_col = int(col_moves[ii])

      if (new_row, new_col) != (current_locs[1][ii], current_locs[2][ii]):
        # Moved to a new patch: misses reset and the patch counts as new
        self.misses_new_patch[ii, f_idx] = 0
        self.misses_known_patch[ii, f_idx] = 0
        self.at_new_patch[ii, f_idx] = True
        if self.moves_cost:
          # in this variant moving also ticks down the clock
          self.moves_taken[ii, f_idx] += 1
      else:
        # Position unchanged: the critter is foraging, which always uses a move
        self.moves_taken[ii, f_idx] += 1
        has_food = self.pieces[ii, new_row, new_col] < 0
        # the success draw only happens when there is food (short-circuit)
        if has_food and self.rng.random() < self.forage_success_prob:
          # Successful foraging, increase critter's score
          self.scores[ii, f_idx] += 1
          # misses are zeroed and no longer at new patch
          self.misses_new_patch[ii, f_idx] = 0
          self.misses_known_patch[ii, f_idx] = 0
          self.at_new_patch[ii, f_idx] = False
          # Check if food goes extinct (only on success)
          if self.rng.random() < self.food_extinct_prob:
            self.pieces[ii, new_row, new_col] = 0  # Set it to empty
        elif self.at_new_patch[ii, f_idx]:
          # unsuccessful foraging (with or without food) at a new patch
          self.misses_new_patch[ii, f_idx] += 1
        else:
          # unsuccessful foraging (with or without food) at a known patch
          self.misses_known_patch[ii, f_idx] += 1

      # Always check if the session is over; it ends when any forager on this
      # board hits the fixed horizon, or at random with prob end_prob.
      # (moves_taken[ii] has one entry per forager, so reduce with np.any
      # rather than using the array directly as a condition.)
      if np.any(self.moves_taken[ii] >= self.max_moves_taken):
        self.is_over[ii] = True
      elif self.rng.random() < self.end_prob:
        self.is_over[ii] = True

    # assume moves are legal and update locs for whole batch at once
    self.forager_locs[which_critter] = (batch_moves, row_moves, col_moves)

  ###### Getting Legal Moves and Perceptions #########################
  ####################################################################
  def get_neighbor_grc_indices(self, which_critter, radius, pad=False):
    """
    Returns all grid positions within a certain cityblock distance radius from
    the place corresponding to which_critter.

    Args:
        which_critter (int): The index of the focal critter.
        radius (int): The cityblock distance.
        pad (bool): whether or not to pad the array. If padded, all row, col
          indexes are valid for the padded array, useful for getting percepts.
          If not, the indexes refer to the original array (some may fall off
          the board), useful for figuring out legal moves.

    Returns:
        an array of indices, each row is a g, r, c index for the neighborhoods
        around the critters, can use the g value to know which board you are in.
        If pad is truthy also returns the padded array (filled with
        ARRAY_PAD_VALUE off the board); the indices in that case are for the
        padded array, so won't work on self.pieces. If pad is falsy the
        indices are offsets in reference to the original self.pieces, but
        some of these will be invalid and need to be filtered out (as we do
        in get_legal_moves).
    """
    batch_size = self.pieces.shape[0]
    _, rows, cols = self.forager_locs[which_critter]
    rows = np.asarray(rows)
    cols = np.asarray(cols)
    if pad:
      padded_arr = np.pad(self.pieces, ((0, 0), (radius, radius),
        (radius, radius)), constant_values=self.ARRAY_PAD_VALUE)
      # shift critter coordinates into the padded array's frame
      rows = rows + radius
      cols = cols + radius

    # all (row, col) offsets in the (2r+1) x (2r+1) window, row-major
    row_offsets, col_offsets = np.meshgrid(
        np.arange(-radius, radius + 1),
        np.arange(-radius, radius + 1),
        indexing='ij')

    # Filter for valid cityblock distances
    mask = np.abs(row_offsets) + np.abs(col_offsets) <= radius
    # Compute all neighbors for each position in the batch (broadcast the
    # critter locations against the offsets)
    neighbors_rows = rows[:, np.newaxis] + row_offsets[mask]
    neighbors_cols = cols[:, np.newaxis] + col_offsets[mask]

    # NOTE(review): assumes one location per board, stored in board order
    indices = np.column_stack((np.repeat(np.arange(batch_size),
                                         neighbors_rows.shape[1]),
                               neighbors_rows.ravel(),
                               neighbors_cols.ravel()))
    if pad:
      return indices, padded_arr
    return indices


  def get_legal_moves(self, which_critter, radius=1):
    """
    Identifies all legal moves for the critter.

    A legal move is any cell within cityblock distance `radius` of the
    critter's current cell that lies on the board. This includes the current
    cell itself, which means foraging.

    Args:
      which_critter (int): Index of the critter, starting from 1.
      radius (int): Maximum cityblock distance of a move (default 1).

    Returns:
      A set of (g, r, c) tuples, one per legal destination, where g is the
      index of the board within the batch.
    """
    _, n_rows, n_cols = self.pieces.shape
    # candidate destinations, some of which may fall off the board
    candidates = self.get_neighbor_grc_indices(which_critter, radius)
    on_board = ((candidates[:, 1] >= 0) & (candidates[:, 1] < n_rows) &
                (candidates[:, 2] >= 0) & (candidates[:, 2] < n_cols))
    return {tuple(m_) for m_ in candidates[on_board]}


  def get_perceptions(self, critter_food, radius):
    """
    Returns what the critter perceives on every board in the batch.

    Args:
      critter_food (int): Index of the focal critter in self.forager_locs.
      radius (int): Cityblock radius of the perceptual field.

    Returns:
      np.ndarray of shape (batch_size, 2*radius*(radius+1) + 1): the board
      values within cityblock distance `radius` of the critter, including
      its own cell, read top to bottom, left to right. Cells off the board
      read as ARRAY_PAD_VALUE.
    """
    idx, pad_pieces = self.get_neighbor_grc_indices(critter_food,
                                                    radius, pad=True)
    percept = pad_pieces[idx[:, 0], idx[:, 1], idx[:, 2]]
    return percept.reshape(self.batch_size, -1)

# 1.1.1.1: Initializing Gridworld

Before we introduce an organism with **behaviour**, we're going to build an **environment** for it to behave in. To start, this world will consist of a 7 x 7 grid. Let's draw that grid and see what it looks like.

In [None]:
############################################################################
## TODO for students: Replace ... with the correct arguments (inputs) in the
## call to make_grid at the bottom of this cell to make our grid the right
## size and shape (7x7).
## The function definition is duplicated here for convenience and hackability,
## but in general you can use the tool tip by hovering over the word make_grid,
## when this is the active cell, to find out how to use the make_grid function.
## You can also use the tool tip to view the source code. How does it work?
## Comment out or remove these next two lines (the raise statement stops the
## cell from running until you have completed the exercise).
raise NotImplementedError(
  "Exercise: make a grid using the make_grid function")
############################################################################


def make_grid(num_rows, num_cols, figsize=(7,6), title=None):
  """Plots an n_rows by n_cols grid with cells centered on integer indices and
  returns fig and ax handles for further use

  Args:
    num_rows (int): number of rows in the grid (vertical dimension)
    num_cols (int): number of cols in the grid (horizontal dimension)
    figsize (tuple of float): (width, height) of the figure in inches
    title (str or None): optional title placed above the figure

  Returns:
    fig (matplotlib.figure.Figure): figure handle for the grid
    ax: (matplotlib.axes._axes.Axes): axes handle for the grid
  """
  # Create a new figure and axes with given figsize
  fig, ax = plt.subplots(figsize=figsize, layout='constrained')
  # Set width and height padding, remove horizontal and vertical spacing
  fig.get_layout_engine().set(w_pad=4 / 72, h_pad=4 / 72, hspace=0, wspace=0)
  # Show right and top borders (spines) of the plot
  ax.spines[['right', 'top']].set_visible(True)
  # Major ticks sit at the cell centres, on the integer indices
  ax.set_xticks(np.arange(0, num_cols, 1))
  ax.set_yticks(np.arange(0, num_rows, 1))
  # Label the major ticks with the cell index, font size 8
  ax.set_xticklabels(np.arange(0, num_cols, 1),fontsize=8)
  ax.set_yticklabels(np.arange(0, num_rows, 1),fontsize=8)
  # Minor ticks sit halfway between cell centres, i.e. on the cell borders;
  # this is where the grid lines are drawn
  ax.set_xticks(np.arange(0.5, num_cols-0.5, 1), minor=True)
  ax.set_yticks(np.arange(0.5, num_rows-0.5, 1), minor=True)
  # Move x-axis ticks to the top of the plot
  ax.xaxis.tick_top()
  # Draw grid lines at the minor ticks: grey, solid, 2pt wide, half transparent
  ax.grid(which='minor', color='grey', linestyle='-', linewidth=2, alpha=0.5)
  # Hide the minor tick marks themselves (not the grid lines) on every side.
  # tick_top() above turned ticks on at the top for both major and minor
  # ticks, so top=False is needed as well as bottom/left.
  ax.tick_params(which='minor', bottom=False, left=False, top=False,
                 right=False)
  # Set limits of x and y axes
  ax.set_xlim(( -0.5, num_cols-0.5))
  ax.set_ylim(( -0.5, num_rows-0.5))
  # Invert y axis direction so row 0 is drawn at the top, like a numpy array
  ax.invert_yaxis()
  # If title is provided, set it as the figure title
  if title is not None:
    fig.suptitle(title)
  # Hide header and footer, disable toolbar and resizing of the figure
  fig.canvas.header_visible = False
  fig.canvas.toolbar_visible = False
  fig.canvas.resizable = False
  fig.canvas.footer_visible = False
  # Redraw the figure with these settings
  fig.canvas.draw()
  # Return figure and axes handles for further customization
  return fig, ax

# Draw the grid (replace ... with the number of rows and columns) and show it
fig, ax = make_grid(...)
plt.show()

In [None]:
# to_remove solution


def make_grid(num_rows, num_cols, figsize=(7,6), title=None):
  """Plots an n_rows by n_cols grid with cells centered on integer indices and
  returns fig and ax handles for further use

  Args:
    num_rows (int): number of rows in the grid (vertical dimension)
    num_cols (int): number of cols in the grid (horizontal dimension)
    figsize (tuple of float): (width, height) of the figure in inches
    title (str or None): optional title placed above the figure

  Returns:
    fig (matplotlib.figure.Figure): figure handle for the grid
    ax: (matplotlib.axes._axes.Axes): axes handle for the grid
  """
  # Create a new figure and axes with given figsize
  fig, ax = plt.subplots(figsize=figsize, layout='constrained')
  # Set width and height padding, remove horizontal and vertical spacing
  fig.get_layout_engine().set(w_pad=4 / 72, h_pad=4 / 72, hspace=0, wspace=0)
  # Show right and top borders (spines) of the plot
  ax.spines[['right', 'top']].set_visible(True)
  # Major ticks sit at the cell centres, on the integer indices
  ax.set_xticks(np.arange(0, num_cols, 1))
  ax.set_yticks(np.arange(0, num_rows, 1))
  # Label the major ticks with the cell index, font size 8
  ax.set_xticklabels(np.arange(0, num_cols, 1),fontsize=8)
  ax.set_yticklabels(np.arange(0, num_rows, 1),fontsize=8)
  # Minor ticks sit halfway between cell centres, i.e. on the cell borders;
  # this is where the grid lines are drawn
  ax.set_xticks(np.arange(0.5, num_cols-0.5, 1), minor=True)
  ax.set_yticks(np.arange(0.5, num_rows-0.5, 1), minor=True)
  # Move x-axis ticks to the top of the plot
  ax.xaxis.tick_top()
  # Draw grid lines at the minor ticks: grey, solid, 2pt wide, half transparent
  ax.grid(which='minor', color='grey', linestyle='-', linewidth=2, alpha=0.5)
  # Hide the minor tick marks themselves (not the grid lines) on every side.
  # tick_top() above turned ticks on at the top for both major and minor
  # ticks, so top=False is needed as well as bottom/left.
  ax.tick_params(which='minor', bottom=False, left=False, top=False,
                 right=False)
  # Set limits of x and y axes
  ax.set_xlim(( -0.5, num_cols-0.5))
  ax.set_ylim(( -0.5, num_rows-0.5))
  # Invert y axis direction so row 0 is drawn at the top, like a numpy array
  ax.invert_yaxis()
  # If title is provided, set it as the figure title
  if title is not None:
    fig.suptitle(title)
  # Hide header and footer, disable toolbar and resizing of the figure
  fig.canvas.header_visible = False
  fig.canvas.toolbar_visible = False
  fig.canvas.resizable = False
  fig.canvas.footer_visible = False
  # Redraw the figure with these settings
  fig.canvas.draw()
  # Return figure and axes handles for further customization
  return fig, ax

# Draw the 7 x 7 gridworld in the hand-drawn xkcd style, then display it
with plt.xkcd():
  fig, ax = make_grid(num_rows=7, num_cols=7)
  plt.show()

***Bonus: change the function definition:***

Tweak the make_grid function in the cell above to make the grid lines green.

Wow, what a boring environment. Let's add an organism and something for that organism to interact with. We'll start with 10 food items scattered randomly throughout the grid, with never more than one food item per grid cell. To plot these food items we need their locations, which we will choose by randomly sampling grid coordinates [without replacement](## "never picking the same (row,col) coordinate pair twice"). We'll place the organism the same way, sampling its location together with the food locations, so that it never starts on a food item. (We will use [blue, underlined text](## "example tool tip") to indicate tooltips, i.e. places where more information appears when you hover the mouse over the text.)

In [None]:
################################################################################
# TODO for students: Replace ... in init_loc(...) to initialize the right
# number of food item locations and critter locations in coordinates that make
# sense for our grid environment. Then replace the ... in rc_plotting[...] to
# index the plotting coordinates for the food locations.
# Hint: The syntax for indexing elements of numpy arrays using [] can be
# confusing at first. If you're lost, read the docs,
# https://numpy.org/doc/stable/user/basics.indexing.html and add some code
# cells below to play around with indexing and displaying different sub-arrays
# of the rc_plotting array.
# Comment out or remove this next line (it stops the cell from running until
# you have completed the exercise).
raise NotImplementedError("Exercise: initialize food and critter locations")
################################################################################


def init_loc(n_rows, n_cols, num, rng=None):
  """
  Draw `num` distinct cells of an n_rows x n_cols grid uniformly at random.

  Args:
    n_rows: int, how many rows the grid has
    n_cols: int, how many columns the grid has
    num:    int, how many distinct cells to draw; rng.choice raises a
            ValueError if this exceeds n_rows * n_cols
    rng:    optional numpy.random Generator, for reproducibility. When it is
            omitted, a fresh generator seeded with the global SEED is created,
            so repeated calls return the same cells.

  Returns:
    int_loc: ndarray(int) of shape (num,), flat (row-major) indices of the cells
    rc_index: tuple(ndarray(int), ndarray(int)), (row indices, col indices),
      ready for indexing into an n_rows by n_cols numpy array
    rc_plotting: ndarray(int) of shape (num, 2), one (row, col) pair per cell,
      convenient for plotting
  """
  generator = rng if rng is not None else np.random.default_rng(seed=SEED)
  # sample flat cell numbers without replacement, so no cell is picked twice
  flat_cells = generator.choice(n_rows * n_cols, num, replace=False)
  # split each flat cell number into its (row, col) pair
  rows, cols = np.unravel_index(flat_cells, (n_rows, n_cols))
  return flat_cells, (rows, cols), np.column_stack((rows, cols))

# Create a grid for the plot
fig, ax = make_grid(7, 7)
# Generate 11 unique locations on the grid: 1 for the critter + 10 for food.
# Sampling them together without replacement means the critter never starts
# on a food item.
int_locs, rc_index, rc_plotting = init_loc(..., ..., ...)
# The first location is for the "critter"
rc_critter = rc_plotting[0]
plot_critter(fig, ax, rc_critter)
# Remaining locations are for "food"
rc_food = rc_plotting[...]
plot_food(fig, ax, rc_food)
# Add legend outside the upper right corner
fig.legend(loc='outside right upper')
plt.show()

In [None]:
#to_remove solution


def init_loc(n_rows, n_cols, num, rng=None):
  """
  Pick `num` distinct cells of an n_rows x n_cols grid uniformly at random.

  Args:
    n_rows: int, number of rows in the grid
    n_cols: int, number of columns in the grid
    num:    int, how many distinct cells to draw. Must not exceed
      n_rows * n_cols; numpy raises a ValueError otherwise.
    rng:    numpy.random.Generator (e.g. from np.random.default_rng), used for
      reproducibility. When None, a fresh generator seeded with the global SEED
      is created, so repeated calls without rng return the same cells.

  Returns:
    int_loc: ndarray(int) of shape (num,), flat (row-major) indices of the cells
    rc_index: tuple(ndarray(int), ndarray(int)), row indices and column
      indices; can be used directly to index an n_rows by n_cols numpy array.
    rc_plotting: ndarray(int) of shape (num, 2), one (row, col) pair per cell,
      convenient for plotting.
  """
  generator = np.random.default_rng(seed=SEED) if rng is None else rng
  n_cells = n_rows * n_cols
  # replace=False guarantees that no two samples land on the same cell
  flat_cells = generator.choice(n_cells, num, replace=False)
  # Row-major conversion of flat indices into (rows, cols) index arrays
  row_col = np.unravel_index(flat_cells, (n_rows, n_cols))
  # Stack to (2, num) then transpose to (num, 2): one (row, col) pair per cell
  pairs = np.transpose(np.array(row_col))
  return flat_cells, row_col, pairs

# Draw the solution in matplotlib's hand-drawn 'xkcd' style
with plt.xkcd():
  fig, ax = make_grid(7, 7)
  # 11 distinct cells on the 7x7 grid: one for the critter, ten for food
  int_locs, rc_index, rc_plotting = init_loc(n_rows=7, n_cols=7, num=11)
  # Split the sampled coordinates: row 0 -> critter, remaining rows -> food
  rc_critter, rc_food = rc_plotting[0], rc_plotting[1:]
  plot_critter(fig, ax, rc_critter)
  plot_food(fig, ax, rc_food)
  # Figure-level legend placed outside the axes, upper right
  fig.legend(loc='outside right upper')
  plt.show()

In [None]:
# @title Submit your feedback
# Show the content-review widget tagged with the suffix "_M1"; content_review
# and feedback_prefix are presumably defined in the setup cell — confirm
content_review(f"{feedback_prefix}_M1")

---
# 1.1.1.2: Random Eating

Now that we have an environment scattered with food and an organism, let's introduce some behaviour. The organism drifts around the environment randomly and eats any food it happens to stumble upon. (Can you think of any organisms that employ this strategy? [hint](## "think about the way very very small living things move around")). When food is eaten, the organism gets a **reward**, in this case a *Food Eaten* point, and a new food item appears at a random location in the environment that doesn't already have food. Run the code cell below to see what this looks like.

In [None]:
# @title Random Movement
# @markdown Don't worry about how this code works – just **run this cell** then click the start button and watch what happens.

# Seeded generator handed to the game so its random draws are reproducible
rng = np.random.default_rng(seed=420)
# A single 7x7 world with 10 food items; the episode ends after 30 moves
gwg = GridworldGame(batch_size=1, n_rows=7, n_cols=7, num_food=10,
                    max_moves_taken=30, rng=rng)
# player=None: no controlling policy is supplied — presumably the widget then
# moves the critter at random (matches the cell title); confirm
random_igwg = InteractiveGridworld(gwg, player=None, figsize=(5,4))
# NOTE(review): the board canvas is displayed and then immediately cleared
# before showing final_display — presumably needed so the ipympl canvas is
# initialised inside the composite widget; confirm before removing
display(random_igwg.b_fig.canvas)
clear_output()
display(random_igwg.final_display)

*Question:* When the organism is just drifting around randomly, how good is it at eating food? What is its efficiency in terms of food eaten per move? Now click the start button again and run the simulation a few more times. Does the organism always eat the same amount of food, or does it change between simulation runs? [explanation](## "The amount of food eaten varies from run to run. Usually the organism manages to eat one, two or three pieces of food, sometimes more, sometimes less.")

*Bonus: see how the effectiveness of a strategy depends on the environment:*

Before we move on, it's important to test that our simulation is running as we expect. Randomness can make testing hard, but this can be overcome in part by setting up the environment so that the outcome becomes deterministic. In the code cells below, change how the Gridworld is initialized. By altering the size and shape of the grid and the number of food items available, create one scenario where the organism will always achieve perfect efficiency and another where it will always fail completely.

We will do this here by either providing food everywhere or nowhere.

In [None]:
###############################################################################
# TODO for students: Replace the ...'s in GridworldGame(...) to initialize a
# Gridworld where the organism is always 100% efficient. Food. Everywhere.
raise NotImplementedError("Exercise: make random movement 100% efficient")
################################################################################

# Hint: choose a grid small enough that, once the critter and the food are
# placed, every cell the critter can move into already holds food.
gwg100 = GridworldGame(batch_size=1, n_cols=..., n_rows=..., num_food=...,
                       max_moves_taken=30)
# player=None: no policy is supplied — presumably random movement, as in the
# Random Movement cell; confirm
random_igwg_100 = InteractiveGridworld(gwg100, player=None, figsize=(5,4))
# Same display -> clear_output -> display sequence as the other Gridworld cells
display(random_igwg_100.b_fig.canvas)
clear_output()
display(random_igwg_100.final_display)

In [None]:
#to_remove solution
# 2x2 grid = 4 cells: the critter takes one and the 3 food items fill the rest,
# so (assuming food never respawns under the critter) every neighbouring cell
# holds food. NOTE(review): 100% also assumes the random player never picks a
# move that leaves it in place (e.g. into a wall) — confirm in GridworldGame
gwg100 = GridworldGame(batch_size=1, n_cols=2, n_rows=2, num_food=3,
                       max_moves_taken=30)
random_igwg_100 = InteractiveGridworld(gwg100, player=None, figsize=(5,4))
# Same display -> clear_output -> display sequence as the other Gridworld cells
display(random_igwg_100.b_fig.canvas)
clear_output()
display(random_igwg_100.final_display)

OK. We have just seen a super successful (albeit completely dumb) organism. Let's see if we can create an environment where any organism would fail (as it turns out, intelligence cannot make food out of nothing).

In [None]:
###############################################################################
# TODO for students: Replace the ...'s in GridworldGame(...) to initialize a
# Gridworld where the organism is always 0% efficient.
raise NotImplementedError("Exercise: make random movement 0% efficient")
################################################################################

# Hint: no policy, however clever, can eat food that does not exist.
gwg0 = GridworldGame(batch_size=1, n_cols=..., n_rows=..., num_food=...,
                     max_moves_taken=30)
# player=None: no policy is supplied — presumably random movement; confirm
random_igwg_0 = InteractiveGridworld(gwg0, player=None, figsize=(5,4))
# Same display -> clear_output -> display sequence as the other Gridworld cells
display(random_igwg_0.b_fig.canvas)
clear_output()
display(random_igwg_0.final_display)

In [None]:
#to_remove solution
# num_food=0: there is nothing to eat, so efficiency is 0% whatever the policy
# (the grid size is irrelevant here)
gwg0 = GridworldGame(batch_size=1, n_cols=2, n_rows=2, num_food=0,
                     max_moves_taken=30)
random_igwg_0 = InteractiveGridworld(gwg0, player=None, figsize=(5,4))
# Same display -> clear_output -> display sequence as the other Gridworld cells
display(random_igwg_0.b_fig.canvas)
clear_output()
display(random_igwg_0.final_display)

In [None]:
# @title Submit your feedback
# Show the content-review widget tagged with the suffix "_M2"; content_review
# and feedback_prefix are presumably defined in the setup cell — confirm
content_review(f"{feedback_prefix}_M2")

---
# 1.1.1.3: Better Than Random Eating
Now it's your turn to actually control the organism with some level of intelligence (give it your all). Run the next cell and see how much more efficient your control of the organism is than random drifting, in terms of food eaten per move. Does intelligence help?

In [None]:
# @title Controlled Movement
# @markdown Don't worry about how this code works – just **run the cell** and then use the buttons to guide the organism

# user in control
# NOTE(review): the positional args presumably mean batch_size=2, n_rows=7,
# n_cols=7, num_food=10, max_moves_taken=30 (the keyword order used elsewhere
# in this notebook) — confirm against GridworldGame's signature
gwg_c = GridworldGame(2, 7, 7, 10, 30,
                    rng=np.random.default_rng(seed=9))
# Two boards side by side: player0 is driven by the human via buttons;
# player1=None presumably moves randomly, for comparison — confirm
h2h_igwg = Head2HeadGridworld(gwg_c, player0='human',
                              player1=None, figsize=(3,3),
                              )
# Display each canvas, clear them, then show the assembled widget
# (same display sequence as the other Gridworld cells)
display(h2h_igwg.b_fig0.canvas)
display(h2h_igwg.b_fig1.canvas)
display(h2h_igwg.b_fig_legend.canvas)
clear_output()
display(h2h_igwg.final_display)

Hopefully your performance was more successful than random flailing (if not, reset to the safe point). Even in this relatively simple and contrived foraging scenario intelligence can help a lot. What kinds of strategies and heuristics did you use to guide your choice of direction? A fundamental purpose of nervous systems and brains is to solve problems of this kind — choosing which actions to take based on environmental inputs to maximize rewards.

In [None]:
# @title Submit your feedback
# Show the content-review widget tagged with the suffix "_M3"; content_review
# and feedback_prefix are presumably defined in the setup cell — confirm
content_review(f"{feedback_prefix}_M3")

---
# 1.1.1.4: Optimized Eating


Let's welcome a special guest, GW7x7-10-30, from the final chapter of this book. Utilizing a blend of deep reinforcement learning and Monte-Carlo search based on the AlphaZero optimization algorithm, GW7x7-10-30 has achieved mastery of the 7x7 Gridworld environment, with 10 food items and a game duration of 30 rounds.

The AlphaZero optimization algorithm draws inspiration from our understanding of the brain, as well as from various concepts in machine learning. That said, our specific computer implementation of the algorithm is unlikely to mirror the learning algorithms used by the brain in any particular detail.

Despite this lack of immediate correspondence, we can still gain significant insight by identifying the general form of the learning problems that the brain solves, together with the classes of optimization algorithms capable of feasibly solving these learning problems subject to biological constraints. These constraints derive from a multitude of factors, such as the evolution, ecology, physiology and development of the organism. These insights will enable us to deduce the most probable types of learning algorithms employed by brains.

Subsequently, these deductions can guide us in seeking out the specific mechanisms and intricate details of the learning algorithms found in brains. Throughout this book, we will focus on introducing the general learning problems encountered by living organisms. We will identify different machine learning techniques that can presently solve these problems under various conditions. Furthermore, we will link the feasible machine learning solutions of these broader learning problems to our current empirical understanding of how a brain might implement similar solutions.

Our aim with this approach is to foster a principled, systematic, and integrative groundwork for neuroscience research. Now, let's run the next code cell to see who is more efficient – you or GW7x7-10-30. Reading this book will empower you to design the next generation of GW7x7-10-30.


In [None]:
# @title Optimized Movement
# @markdown Don't worry about how this code works – **run this cell** to set up the superorganism and an environment for it and you. Note, the superorganism will be **slow** to compute its moves if GPU acceleration is not enabled for this runtime. If possible, in the menu under `Runtime` -> `Change runtime type`, select `GPU`.

# Initialize the game, the policy-value network, and the Monte Carlo player
# that uses the network to choose its moves.
gwg = GridworldGame(batch_size=1, n_rows=7, n_cols=7, num_food=10,
                    max_moves_taken=30)
pvnetMC = PolicyValueNetwork(gwg)
mcp = MonteCarloBasedPlayer(gwg, pvnetMC, default_depth=3,
                            default_rollouts=80, default_temp=0.02)

# Fetch the pre-trained network weights from the repo (or wherever they end
# up being hosted) into the current working directory. If the download fails
# but a copy from an earlier run is already there, fall back to that copy;
# otherwise stop here with a clear error instead of failing later inside
# load_checkpoint on a missing file.
url = "https://raw.githubusercontent.com/dcownden/PerennialProblemsOfLifeWithABrain/main/sequences/P1C1_BehaviourAsPolicy/data/pvnetMC.pth.tar"
filename = os.path.basename(url)  # 'pvnetMC.pth.tar'
checkpoint_path = os.path.join(os.getcwd(), filename)
try:
  # A timeout keeps the cell from hanging forever on a stalled connection
  r = requests.get(url, timeout=60)
  r.raise_for_status()
except requests.RequestException as err:
  if not os.path.exists(checkpoint_path):
    raise RuntimeError(
        f'Error occurred while downloading the file from {url} '
        'and no local copy is available.') from err
  print('Error occurred while downloading the file; '
        'using the previously downloaded copy.')
else:
  # Write the contents to a file in the current working directory
  with open(checkpoint_path, 'wb') as file:
    file.write(r.content)

# load the saved model (same file name as the one written above)
pvnetMC.load_checkpoint(folder=os.getcwd(), filename=filename)

# user in control versus mc player: a fresh 2-board game (one board per
# player) with a fixed seed so the starting layout is reproducible
gwg = GridworldGame(2, 7, 7, 10, 30,
                    rng=np.random.default_rng(seed=2000))
h2h_igwg = Head2HeadGridworld(gwg, player0='human', player1=mcp, figsize=(4,4),
                              p0_long_name='The Human',
                              p1_long_name='gw7x7-10-30')
# Display each canvas, clear them, then show the assembled widget
# (same display sequence as the other Gridworld cells)
display(h2h_igwg.b_fig0.canvas)
display(h2h_igwg.b_fig1.canvas)
display(h2h_igwg.b_fig_legend.canvas)
clear_output()
display(h2h_igwg.final_display)

Who was more efficient in this environment, you or gw7x7-10-30? If gw7x7-10-30 was better, you really have to read this book 😉 (If you can't beat the AIs, at least learn how to program them.) Even if you were about as good as gw7x7-10-30, you still might want to read this book. A deep understanding of the optimization processes that shape behaviour in simple organism-environment systems like this one will allow for generalization to more intricate systems, and specifically to a rich understanding of how brains generate adaptive behaviour as a result of optimization processes.

In [None]:
# @title Submit your feedback
# Show the content-review widget tagged with the suffix "_M4"; content_review
# and feedback_prefix are presumably defined in the setup cell — confirm
content_review(f"{feedback_prefix}_M4")

---
# Comprehension Quiz

In [None]:
# @title Quiz
# @markdown **Run this cell** to take the quiz

# Each question has exactly one answer marked correct. Answers that are
# partially right but not the best choice are marked incorrect and explain
# why in their feedback (the same convention for every question).
comprehension_quiz = [
  {
    "question": "What does a policy represent in the context of behaviour?",
    "type": "multiple_choice",
    "answers": [
      {
        "answer": "The evolutionary history of an organism",
        "correct": False,
        "feedback": "This is true in a broad and abstract sense, but there is a more precise answer here."
      },
      {
        "answer": "The environment in which an organism lives",
        "correct": False,
        "feedback": "There is a sense in which a policy shaped by evolution can reflect aspects of an organism's environment, but there is a more precise answer here."
      },
      {
        "answer": "The formal description of behaviour as a function that maps experiences to actions",
        "correct": True,
        "feedback": "Correct."
      },
      {
        "answer": "The randomness present in an organism's behavior",
        "correct": False,
        "feedback": "The policy might have randomness in it, but that's not what it is."
      },
    ],
  },
  {
    "question": "How is a policy evaluated in terms of its goodness?",
    "type": "multiple_choice",
    "answers": [
      {
        "answer": "By integrating rewards and environmental signals into a loss/objective function",
        "correct": True,
        "feedback": "Correct, 'goodness' needs to be formalized in a loss/objective function"
      },
      {
        # Partially right: the feedback points to the more general answer
        "answer": "By measuring the organism's fitness in the environment",
        "correct": False,
        "feedback": "This is one important way of evaluating a policy, but there is a more generally correct answer here."
      },
      {
        "answer": "By determining the amount of randomness present in the policy",
        "correct": False,
        "feedback": "Incorrect."
      },
      {
        "answer": "By analyzing the organism's evolutionary adaptations",
        "correct": False,
        "feedback": "Incorrect."
      },
    ],
  },
  {
    "question": "What is stochasticity in the context of behaviour?",
    "type": "multiple_choice",
    "answers": [
      {
        "answer": "The specific niche an organism occupies within its environment",
        "correct": False,
        "feedback": "Incorrect."
      },
      {
        "answer": "The ability of an organism to adapt to changing environmental conditions",
        "correct": False,
        "feedback": "Incorrect."
      },
      {
        "answer": "The random elements present in both the environment and an organism's behavior",
        "correct": True,
        "feedback": "Correct."
      },
      {
        "answer": "The process of optimizing a policy to achieve better outcomes",
        "correct": False,
        "feedback": "Incorrect."
      },
    ],
  },
  {
    "question": "What is the main difference between random eating and controlled movement in the environment?",
    "type": "multiple_choice",
    "answers": [
      {
        "answer": "Random eating involves unpredictable movements, while controlled movement can be planned and strategic.",
        "correct": True,
        "feedback": "Correct."
      },
      {
        "answer": "Random eating leads to higher efficiency, while controlled movement leads to lower efficiency.",
        "correct": False,
        "feedback": "Incorrect."
      },
      {
        "answer": "Random eating relies on external cues, while controlled movement relies on internal motivations.",
        "correct": False,
        "feedback": "Incorrect."
      },
      {
        "answer": "Random eating results in adaptive behavior, while controlled movement leads to stagnation.",
        "correct": False,
        "feedback": "Incorrect."
      },
    ],
  },
  {
    "question": "What is the significance of GW7x7-10-30 in the context of optimized eating?",
    "type": "multiple_choice",
    "answers": [
      {
        "answer": "It represents a time-traveling superorganism with advanced cognitive abilities.",
        "correct": False,
        "feedback": "Incorrect."
      },
      {
        "answer": "It demonstrates the limitations of optimized behavior in a simple environment.",
        "correct": False,
        "feedback": "Incorrect. GW7x7-10-30 was introduced to show how much an optimized policy can achieve, not to highlight its limitations."
      },
      {
        "answer": "It showcases the potential efficiency achievable through optimized behavior.",
        "correct": True,
        "feedback": "Correct."
      },
      {
        # Partially right: the feedback points to the main reason instead
        "answer": "It serves as a benchmark for comparing different organisms' performance.",
        "correct": False,
        "feedback": "It could serve this purpose, though that wasn't the main reason we introduced it here."
      },
    ],
  },
]


display_quiz(comprehension_quiz)