<a href="https://colab.research.google.com/github/dcownden/PerennialProblemsOfLifeWithABrain/blob/P1C2-Sequence1/sequences/P1C2_OptimizationAndEnvironment/P1C2_Sequence1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The following is a fourth test for an upcoming text book on computational neuroscience from an optimization/learning perspective. The book will start with evolution because ultimately, all aspects of the brain are shaped by evolution and, as we will see, evolution can also be seen as a learning algorithm. We are sharing it now to get feedback on what works and what does not and the developments we should do.

# **Part 1 Behaviour, Environments and Optimization: Evolution and Learning**

### **Animals are adapted to their specific environments; their behaviour is best understood within the context of their evolutionary environment.**

### Objective: Part 1 of the book aims to introduce the fundamental concepts of
* **environment**, the (statistical) properties of where an organism lives
* **behaviour**, the statistics of what the organism does
* **optimization**, how learning and evolution shape an organism's behaviour to make it better suited to its environment

This very much is the core of why we are writing this book: we can view pretty much anything happening in the brain (and biology) as being part of a process that brings about improvement in this sense.

___
# Chapter 1.2 Optimizing Behaviour - Environment is Central

### Objective: In the previous chapter we learned to describe and think about behaviour as a policy and to evaluate an organism's behaviour relative to the goals implied by the organism's environmental niche. In this next chapter we are going to really focus on ***Optimization***, different ways of making a policy better, different ways an organism might become better suited to its niche.

You will learn:
*   How to formally define an objective function. In the last chapter our policy goals were loosely stated as 'eat lots of food'. In this chapter we will better formalize goals, eventually using mathematic and the Markov Decision Process framework, so that these goals can directly guide policy improvement.
*   Different optimization methods. There are many different methods for modifying a policy to improve its performance, and we will begin to explore several of these (guess and check, grid search,  propose and test, evolutionary algorithms).
*   How evolutionary processes - selection applied to variation in population - are a form of optimization.
*   How competition and interaction complicate optimization. Optimization in a strict sense does not directly apply to situations with multiple interacting organisms, each with their own goals that they are simultaneously optimizing for. Ideas from Game Theory and dynamical systems are required to do multi-organism 'optimization' the right way.
*   The conditions where competition and interaction 'average out' and a naive optimization perspective is appropriate. As we will see, for neural systems, a naive optimization perspective, without the complications of Game Theory, is often sufficient.

___
# **Sequence 1.2.1: Optimization makes you better**

### Objective: In this sequence, we will use our 'Parameterized Weights' policy and the solo Gridworld environment to learn about policy optimization. In this sequence you will:
* Mathematically define an objective function.
* Hand tune parameters to optimize that objective function
* Implement and examine a grid search optimization procedure
* Implement and examine a propose and reject optimization procedure.

This is our first foray into real optimization, the foundations of almost everything else in this book. There isn't much explcitly about brains here at the moment, but once we known a bit about the different kinds of optimization that are possible mathematically, we can start to think about which of those are biologically feasible from an evolutionary and neuro-physiological prespetive.

Picture showing the "wiring diagram" of our parameterized weights policy organism, the twelve input receptors, the brain compute is "just" the connective weights, the four direction outputs. We could embedded this in a more 'realistic' neural system, i.e. real-ish nerves distill into some abstract spatial input, and those abstract spatial outputs hook up to some real-ish nerves and muscles) just a thought open to other suggestions. This is just a place holder for our actual image, same as last time)
![TBD_Parameterized_Weights.png](https://github.com/dcownden/PerennialProblemsOfLifeWithABrain/blob/P1C2-Sequence1/sequences/P1C2_OptimizationAndEnvironment/?raw=1##)

# Setup

Run the following cell to setup and install the various dependencies and helper functions for this sequence.

In [1]:
# @title Dependencies, Imports and Setup
# @markdown You don't need to worry about how this code works – but you do need to **run the cell**

!pip install ipympl vibecheck datatops jupyterquiz > /dev/null 2> /dev/null #google.colab

import requests
import numpy as np
import itertools
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import pandas as pd
import ipywidgets as widgets
import time
import logging
import random
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from copy import copy
from scipy.spatial.distance import cdist
from scipy.stats import norm
from scipy.optimize import minimize
from tabulate import tabulate
from IPython.display import display, clear_output, Markdown
from jupyterquiz import display_quiz
from vibecheck import DatatopsContentReviewContainer
from pathlib import Path
from typing import List, Dict
from tqdm.notebook import tqdm


# random seed settings and
# getting torch to use gpu if it's there


def set_seed(seed=None, seed_torch=True):
  """
  Function that controls randomness. NumPy and random modules must be imported.

  Args:
    seed : Integer
      A non-negative integer that defines the random state. Default is `None`.
    seed_torch : Boolean
      If `True` sets the random seed for pytorch tensors, so pytorch module
      must be imported. Default is `True`.

  Returns:
    Nothing.
  """
  if seed is None:
    seed = np.random.choice(2 ** 32)
  random.seed(seed)
  np.random.seed(seed)
  if seed_torch:
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

  print(f'Random seed {seed} has been set.')


def seed_worker(worker_id):
  """
  DataLoader will reseed workers following randomness in
  multi-process data loading algorithm.

  Args:
    worker_id: integer
      ID of subprocess to seed. 0 means that
      the data will be loaded in the main process
      Refer: https://pytorch.org/docs/stable/data.html#data-loading-randomness for more details

  Returns:
    Nothing
  """
  worker_seed = torch.initial_seed() % 2**32
  np.random.seed(worker_seed)
  random.seed(worker_seed)


def set_device():
  """
  Set the device. CUDA if available, CPU otherwise

  Args:
    None

  Returns:
    Nothing
  """
  device = "cuda" if torch.cuda.is_available() else "cpu"
  if device != "cuda":
    print("This notebook isn't using and doesn't need a GPU. Good.")
  else:
    print("GPU is enabled in this notebook but not needed.")
    print("If possible, in the menu under `Runtime` -> ")
    print("`Change runtime type.`  select `CPU`")

  return device


SEED = 2021
set_seed(seed=SEED)
DEVICE = set_device()


def printmd(string):
  display(Markdown(string))


# the different utility .py files used in this notebook
filenames = ['gw_plotting.py', 'gw_board.py', 'gw_game.py',
             'gw_widgets.py', 'gw_NN_RL.py']
#filenames = []
# just run the code straight out of the response, no local copies needed!
for filename in filenames:
  url = f'https://raw.githubusercontent.com/dcownden/PerennialProblemsOfLifeWithABrain/main/utils/{filename}'
  response = requests.get(url)
  # Check that we got a valid response
  if response.status_code == 200:
    code = response.content.decode()
    exec(code)
  else:
    print(f'Failed to download {url}')

# environment contingent imports
try:
  print('Running in colab')
  from google.colab import output
  output.enable_custom_widget_manager()
  #from google.colab import output as colab_output
  #colab_output.enable_custom_widget_manager()
  IN_COLAB = True
except:
  IN_COLAB = False
  print('Not running in colab')

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%matplotlib widget
plt.style.use("https://raw.githubusercontent.com/dcownden/PerennialProblemsOfLifeWithABrain/main/pplb.mplstyle")
plt.ioff() #need to use plt.show() or display explicitly
logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR)


def content_review(notebook_section: str):
  return DatatopsContentReviewContainer(
    "",  # No text prompt
    notebook_section,
    {
      "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
      "name": "neuro_book",
      "user_key": "xuk960xj",
    },
  ).render()
feedback_prefix = "P1C2_S1"




################################################################
# refactor Monte Carlo for boards that support multiple critters
################################################################




class MonteCarlo():
  """
  Implementation of Monte Carlo Algorithm
  """


  def __init__(self, game, nnet, default_depth=5, random_seed=None):
    """
    Initialize Monte Carlo Parameters

    Args:
      game: Gridworld Game instance
        Instance of the gridworldGame class above;
      nnet: gridworldNet instance
        Instance of the gridworldNNet class above;
      args: dictionary
        Instantiates number of iterations and episodes, controls temperature threshold, queue length,
        arena, checkpointing, and neural network parameters:
        learning-rate: 0.001, dropout: 0.3, epochs: 10, batch_size: 64,
        num_channels: 512

    Returns:
      Nothing
    """
    self.game = game
    self.nnet = nnet
    self.default_depth = default_depth
    self.rng = np.random.default_rng(seed=random_seed)


  def pis_vs_from_board(self, board, critter):
    #helper function, to put board in canonical form that nn was trained on
    batch_size, n_rows, n_cols = board['pieces'].shape
    co_pieces = board['pieces'].copy()
    this_critter_locs = np.where(co_pieces == critter)
    all_critter_locs = np.where(co_pieces >= 1)
    # other critters are invisible to this player
    co_pieces[all_critter_locs] = 0
    # nnet trained to see self as 1
    co_pieces[this_critter_locs] = 1
    scalar_rounds_left = board['rounds_left'][0]
    co_rounds_left = scalar_rounds_left // self.game.num_critters
    if critter-1 < scalar_rounds_left % self.game.num_critters:
       # add an extra if we haven't had this players turn yet in the round cycle
       co_rounds_left = co_rounds_left + 1
    co_rounds_left = np.array([co_rounds_left]*batch_size)
    pis, vs = self.nnet.predict(co_pieces,
                                board['scores'][:,critter-1],
                                co_rounds_left)
    return pis, vs


  def simulate(self, board, actions, action_indexes, critter=1, depth=None):
    """
    Helper function to simulate one Monte Carlo rollout

    Args:
      board: triple (batch_size x x_size x y_size np.array of board position,
                     scalar of current score,
                     scalar of rounds left
      actions: batch size list/array of integer indexes for moves on each board
      these are assumed to be legal, no check for validity of moves
    Returns:
      temp_v:
        Terminal State
    """
    batch_size, x_size, y_size = board['pieces'].shape
    next_board = self.game.get_next_state(board, critter,
                                          actions, action_indexes)
    # in this version of the mc player, the existence of other players is
    # ignored, in another version of mc other players moves might be simulated
    next_board['active_player'] = critter-1

    if depth is None:
      depth = self.default_depth
    # potentially expand the game tree here,
    # but just do straight rollouts after this
    # doesn't expand to deal with all random food generation possibilities
    # just expands based on the actions given
    expand_bs, _, _ = next_board['pieces'].shape

    for i in range(depth):  # maxDepth
      if next_board['rounds_left'][0] <= 0:
        # check that game isn't over
        # assumes all boards have the same rounds left
        # no rounds left return scores as true values
        terminal_vs = next_board['scores'][:,critter-1].copy()
        return terminal_vs
      else:
        #pis, vs = self.nnet.predict(next_board['pieces'], next_board['scores'], next_board['rounds_left'])
        pis, vs = self.pis_vs_from_board(next_board, critter)
        valids = self.game.get_valid_actions(next_board, critter)
        masked_pis = pis * valids
        sum_pis = np.sum(masked_pis, axis=1)
        probs = np.array(
            [masked_pi / masked_pi.sum() if masked_pi.sum() > 0
             else valid / valid.sum()
             for valid, masked_pi in zip(valids, masked_pis)])
        samp = self.rng.uniform(size = expand_bs).reshape((expand_bs,1))
        sampled_actions = np.argmax(probs.cumsum(axis=1) > samp, axis=1)
      next_board = self.game.get_next_state(next_board, critter,
                                            sampled_actions)
      # in this version of the mc player, existence of other players is ignored
      # in another better version other players moves might be simulated, either
      # as copies of self, or as distinct environmental dynamics
      next_board['active_player'] = critter-1


    pis, vs = self.pis_vs_from_board(next_board, critter)
    #pis, vs = self.nnet.predict(next_board['pieces'], next_board['scores'],
    #                            next_board['rounds_left'])
    #print(vs.shape)
    return vs



# @title plotting functions
#################################################
# More plotting functions
#################################################


def plot_directions(fig, ax, loc_prob_dict, critter, deterministic=False,
                    name=None):
  """
  Plot vector field indicating critter direction probabilities.

  Args:
    fig, ax (matplotlib objects): Figure and axes objects for plotting.
    loc_prob_dict (dict): Dictionary with keys as (row, col) location tuples
      and values as lists of direction probabilities corresponding to the
      directions ['right', 'down', 'left', 'up'].
    critter (int): Identifier for which critter directions are associated with.
    deterministic (bool, optional): If True, the probabilities array is
      converted to 1-hot, and the arrows are plotted at the center of the cell
      and are larger. Defaults to False.
  """

  #looks like direction ignores inverted axis
  direction_vectors = {'right': (1, 0), 'down': (0, -1),
                       'left': (-1, 0), 'up': (0, 1)}
  # but offsets need to be aware of inverted
  direction_offsets = {'right': (0.1, 0), 'down': (0, 0.1),
                       'left': (-0.1, 0), 'up': (0, -0.1)}
  # Offsets for each critter type 1 and 2 to be used together, 0 by itself
  critter_offsets = {0: (0, 0), 1: (-0.05, -0.05), 2: (0.05, 0.05)}
  # same logic for colors
  critter_colors = {0: 'black', 1: 'red', 2: 'blue'}
  # Get the offset and color for this critter
  critter_offset = critter_offsets[critter]
  critter_color = critter_colors[critter]

  # Add legend only if critter is not 0
  custom_leg_handles = []
  if critter != 0:
    if name is None:
      name = f'Critter {critter}'
    legend_patch = mpatches.Patch(color=critter_color, label=name)
    # Add the legend for this critter
    custom_leg_handles.append(legend_patch)

  C, R, U, V, A = [], [], [], [], []

  for loc in loc_prob_dict.keys():
    row, col = loc
    probs = loc_prob_dict[loc]
    for dir_key, prob in probs.items():
      C.append(col + critter_offset[0] + direction_offsets[dir_key][0])
      R.append(row + critter_offset[1] + direction_offsets[dir_key][1])
      U.append(direction_vectors[dir_key][0])
      V.append(direction_vectors[dir_key][1])

      if deterministic:
        A.append(1 if prob == max(probs.values()) else 0)
      else:
        A.append(prob)

  linewidth = 1.5 if deterministic else 0.5
  scale = 15 if deterministic else 30

  ax.quiver(C, R, U, V, alpha=A, color=critter_color,
            scale=scale, linewidth=linewidth)
  return fig, ax, custom_leg_handles

def make_grid(num_rows, num_cols, figsize=(7,6), title=None):
  """Plots an n_rows by n_cols grid with cells centered on integer indices and
  returns fig and ax handles for futher use
  Args:
    num_rows (int): number of rows in the grid (vertical dimension)
    num_cols (int): number of cols in the grid (horizontal dimension)

  Returns:
    fig (matplotlib.figure.Figure): figure handle for the grid
    ax: (matplotlib.axes._axes.Axes): axes handle for the grid
  """
  # Create a new figure and axes with given figsize
  fig, ax = plt.subplots(figsize=figsize, layout='constrained')
  # Set width and height padding, remove horizontal and vertical spacing
  fig.get_layout_engine().set(w_pad=4 / 72, h_pad=4 / 72, hspace=0, wspace=0)
  # Show right and top borders (spines) of the plot
  ax.spines[['right', 'top']].set_visible(True)
  # Set major ticks (where grid lines will be) on x and y axes
  ax.set_xticks(np.arange(0, num_cols, 1))
  ax.set_yticks(np.arange(0, num_rows, 1))
  # Set labels for major ticks with font size of 8
  ax.set_xticklabels(np.arange(0, num_cols, 1),fontsize=8)
  ax.set_yticklabels(np.arange(0, num_rows, 1),fontsize=8)
  # Set minor ticks (no grid lines here) to be between major ticks
  ax.set_xticks(np.arange(0.5, num_cols-0.5, 1), minor=True)
  ax.set_yticks(np.arange(0.5, num_rows-0.5, 1), minor=True)
  # Move x-axis ticks to the top of the plot
  ax.xaxis.tick_top()
  # Set grid lines based on minor ticks, make them grey, dashed, and half transparent
  ax.grid(which='minor', color='grey', linestyle='-', linewidth=2, alpha=0.5)
  # Remove minor ticks (not the grid lines)
  ax.tick_params(which='minor', bottom=False, left=False)
  # Set limits of x and y axes
  ax.set_xlim(( -0.5, num_cols-0.5))
  ax.set_ylim(( -0.5, num_rows-0.5))
  # Invert y axis direction
  ax.invert_yaxis()
  # If title is provided, set it as the figure title
  if title is not None:
    fig.suptitle(title)
  # Hide header and footer, disable toolbar and resizing of the figure
  fig.canvas.header_visible = False
  fig.canvas.toolbar_visible = False
  fig.canvas.resizable = False
  fig.canvas.footer_visible = False
  # Redraw the figure with these settings
  fig.canvas.draw()
  # Return figure and axes handles for further customization
  return fig, ax


def plot_food(fig, ax, rc_food_loc, food=None):
  """
  Plots "food" on a grid implied by the given fig, ax arguments

  Args:
    fig, ax: matplotlib figure and axes objects
    rc_food_loc: ndarry(int) of shape (N:num_food x 2:row,col)
    food: a handle for the existing food matplotlib PatchCollenction object
    if one exists
  Returns:
    a handle for matplotlib PathCollection object of food scatter plot, either
    new if no handle was passed or updated if it was
  """
  # if no PathCollection handle passed in:
  if food is None:
    food = ax.scatter([], [], s=150, marker='o', color='red', label='Food')
  rc_food_loc = np.array(rc_food_loc, dtype=int)
  #matrix indexing convention is is [row-vertical, col-horizontal]
  #plotting indexing convention is (x-horizontal,y-vertical), hence flip
  food.set_offsets(np.fliplr(rc_food_loc))
  return food


def plot_critters(fig, ax, critter_specs: List[Dict[str, object]]) -> List[Dict[str, object]]:
  """
  Plots multiple types of "critters" on a grid implied by the given
  fig, ax arguments.

  Args:
    fig, ax: matplotlib figure and axes objects.
    critter_specs: List of dictionaries with keys 'location', 'name', 'color',
    'marker', 'int_id', 'rc_critter_loc' and optionally 'handle' for each
    critter.

  Returns:
    Updated critter_specs with handles.
  """
  for spec in critter_specs:
    # Ensure required keys are present
    for key in ['marker', 'color', 'name', 'rc_loc']:
      if key not in spec:
        raise ValueError(f"Key '{key}' missing in critter spec.")
    handle_ = spec.get('handle')
    if handle_ is None:
      handle_ = ax.scatter([], [], s=250, marker=spec['marker'],
                           color=spec['color'], label=spec['name'])
    handle_.set_offsets(np.flip(spec['rc_loc']))
    spec.update({'handle': handle_})
  return critter_specs


def plot_critter(fig, ax, rc_critter_loc,
                 critter=None, critter_name='Critter'):
  """
  Plots "critter" on a grid implied by the given fig, ax arguments

  Args:
    fig, ax: matplotlib figure and axes objects
    rc_critter_loc: ndarry(int) of shape (N:num_critters x 2:row,col)
    critter: a handle for the existing food matplotlib PatchCollenction object
    if one exists
  Returns:
    a handle for matplotlib PathCollection object of critter scatter plot,
    either new if no handle was passed in or updated if it was.
  """
  if critter is None:
    critter = ax.scatter([], [], s=250, marker='h',
                         color='blue', label=critter_name)
  # matrix indexing convention is is [row-vertical, col-horizontal]
  # plotting indexing convention is (x-horizontal,y-vertical), hence flip
  critter.set_offsets(np.flip(rc_critter_loc))
  return critter


def plot_fov(fig, ax, rc_critter, n_rows, n_cols, radius, has_fov, fov=None):
  """
  Plots a mask on a grid implied by the given fig, ax arguments

  Args:
    fig, ax: matplotlib figure and axes objects
    rc_critter: ndarry(int) (row,col) of the critter
    mask: a handle for the existing mask matplotlib Image object if one exists
  Returns:
    a handle for matplotlib Image object of mask, either new if no handle
    was passed in or updated if it was.
  """

  # Initialize mask as a semi-transparent overlay for the entire grid
  mask_array = np.ones((n_rows, n_cols, 4))
  mask_array[:, :, :3] = 0.5  # light grey color
  if has_fov == True:
    mask_array[:, :, 3] = 0.5  # 50% opacity
    # Create arrays representing the row and column indices
    rows = np.arange(n_rows)[:, np.newaxis]
    cols = np.arange(n_cols)[np.newaxis, :]
    # Iterate over each critter location
    dist = np.abs(rows - rc_critter[0]) + np.abs(cols - rc_critter[1])
    # Set the region within the specified radius around the critter to transparent
    mask_array[dist <= radius, 3] = 0
  else:
    mask_array[:, :, 3] = 0

  if fov is None:
    fov = ax.imshow(mask_array, origin='lower', zorder=2)
  else:
    fov.set_data(mask_array)

  return fov


def remove_ip_clutter(fig):
  fig.canvas.header_visible = False
  fig.canvas.toolbar_visible = False
  fig.canvas.resizable = False
  fig.canvas.footer_visible = False
  fig.canvas.draw()



# @title GridworldBoard class
#######################################################################
# extend GridworldGame class locally before integrating in shared utils
#######################################################################




class GridworldBoard():
  """
  A collection methods and parameters of a gridworld game board that
  define the logic of the game, and allows for multiple critters on the same
  board

  board state is represented by primarily by pieces, score, and rounds left
  pieces is a batch x n_rows x n_cols numpy array positive integers are critter
  locations 0's are empty space and -1's are food.

  For pieces first dim is batch, second dim row , third is col,
  so pieces[0][1][7] is the square in row 2, in column 8 of the first board in
  the batch of boards.

  scores is a batchsize x num_critters numpy array giving the scores for each
  critter on each board in the batch (note off by one indexing)

  rounds_left is how many rounds are left in the game.

  active_player keeps track of which players turn it is

  Note:
    In 2d np.array first dim is row (vertical), second dim is col (horizontal),
    i.e. top left corner is (0,0), so take care when visualizing/plotting
    as np.array visualization inline with typical tensor notation but at odds
    with conventional plotting where (0,0) is bottom left, first dim, x, is
    horizontal, second dim, y, is vertical
  """


  def __init__(self, batch_size=1,
               n_rows=7, n_cols=7,
               num_critters=2, num_food=10,
               lifetime=30, rng = None):
    """Set the parameters of the game."""
    self.n_rows = n_rows
    self.n_cols = n_cols
    self.batch_size = batch_size
    self.num_critters = num_critters
    self.num_food = num_food
    self.lifetime = lifetime
    if rng is None:
      self.rng = np.random.default_rng(seed=SEED)
    else:
      self.rng = rng


  def init_loc(self, n_rows, n_cols, num, rng=None):
    """
    Samples random 2d grid locations without replacement

    Args:
      n_rows: int, number of rows in the grid
      n_cols: int, number of columns in the grid
      num:    int, number of samples to generate. Should throw an error if num > n_rows x n_cols
      rng:    instance of numpy.random's default rng. Used for reproducibility.

    Returns:
      int_loc: ndarray(int) of shape (num,), flat indices for a 2D grid flattened into 1D
      rc_index: tuple(ndarray(int), ndarray(int)), a pair of arrays with the first giving
        the row indices and the second giving the col indices. Useful for indexing into
        an n_rows by n_cols numpy array.
      rc_plotting: ndarray(int) of shape (num, 2), 2D coordinates suitable for matplotlib plotting
    """

    # Set up default random generator, use the boards default if none explicitly given
    if rng is None:
      rng = self.rng
    # Choose 'num' unique random indices from a flat 1D array of size n_rows*n_cols
    int_loc = rng.choice(n_rows * n_cols, num, replace=False)
    # Convert the flat indices to 2D indices based on the original shape (n_rows, n_cols)
    rc_index = np.unravel_index(int_loc, (n_rows, n_cols))
    # Transpose indices to get num x 2 array for easy plotting with matplotlib
    rc_plotting = np.array(rc_index).T
    # Return 1D flat indices, 2D indices for numpy array indexing and 2D indices for plotting
    return int_loc, rc_index, rc_plotting


  def get_init_board_state(self):
    """Set up starting board using game parameters"""
    #set rounds_left and score
    self.rounds_left = (np.ones(self.batch_size) *
                        self.lifetime * self.num_critters)
    # each players move counts down the clock so making this a multiple of the
    # number of critters ensures every player gets an equal number of turns
    self.scores = np.zeros((self.batch_size, self.num_critters))
    # create an empty board array.
    self.pieces = np.zeros((self.batch_size, self.n_rows, self.n_cols))
    # Place critter and initial food items on the board randomly
    for ii in np.arange(self.batch_size):
      # num_food+num_critter because we want critter and food locations
      int_loc, rc_idx, rc_plot = self.init_loc(
        self.n_rows, self.n_cols, self.num_food+self.num_critters)
      # critter random start locations
      for c_ in np.arange(self.num_critters):
        self.pieces[(ii, rc_idx[0][c_], rc_idx[1][c_])] = c_ + 1
      # food random start locations
      self.pieces[(ii, rc_idx[0][self.num_critters:],
                   rc_idx[1][self.num_critters:])] = -1
    self.active_player = 0
    state = {'pieces': self.pieces.copy(),
             'scores': self.scores.copy(),
             'rounds_left': self.rounds_left.copy(),
             'active_player': copy(self.active_player)}
    return state


  def set_state(self, board):
    """ board is dictionary giving game state a triple of np arrays
      pieces:        numpy array (batch_size x n_rows x n_cols),
      scores:        numpy array (batch_size x num_critters)
      rounds_left:   numpy array (batch_size)
      active_player: int
    """
    self.pieces = board['pieces'].copy()
    self.scores = board['scores'].copy()
    self.rounds_left = board['rounds_left'].copy()
    self.active_player = copy(board['active_player'])


  def get_state(self):
    """ returns a board state, which is a triple of np arrays
    pieces,       - batch_size x n_rows x n_cols
    scores,       - batch_size
    rounds_left   - batch_size
    """
    state = {'pieces': self.pieces.copy(),
             'scores': self.scores.copy(),
             'rounds_left': self.rounds_left.copy(),
             'active_player': copy(self.active_player)}
    return state


  def __getitem__(self, index):
    return self.pieces[index]


  def execute_moves(self, moves, critter):
    """
    Updates the state of the board given the moves made.

    Args:
      moves: a 3-tuple of 1-d arrays each of length batch_size,
        the first array gives the specific board within the batch,
        the second array in the tuple gives the new row coord for each critter
        on each board and the third gives the new col coord.

    Notes:
      Assumes that there is exactly one valid move for each board in the
      batch of boards for the critter type given. i.e. it does't check for
      bounce/reflection on edges or with other critters, or for multiple move
      made on the same board. It only checks for eating food and adds new food
      when appropriate. Invalid moves could lead to illegal teleporting
      behavior, critter duplication, or index out of range errors.

      Currently just prints a message if critter making the move is not the
      active player, could enforce this more strictly if needed.
    """
    if critter-1 != self.active_player:
      # note critter is [1 to num_critter] inclusive so that it can be used
      # directly in where statements on pieces but self.active_player is
      # [0 to numcritter-1] inclusive so that it can be used directly in
      # indexing player lists
      raise ValueError("Warning! The critter moving is not the expected active player")
    #critters leave their spots
    self.pieces[self.pieces==critter] = 0
    #which critters have food in their new spots
    eats_food = self.pieces[moves] == -1
    # some critters eat and their scores go up
    # note critter is +int so need to -1 for indexing
    self.scores[:,critter-1] = self.scores[:,critter-1] + eats_food

    num_empty_after_eat = (self.n_rows*self.n_cols - self.num_food -
                           self.num_critters + 1) # +1 for the food just eaten
    # which boards in the batch had eating happen
    g_eating = np.where(eats_food)[0]
    # put critters in new positions
    self.pieces[moves] = critter
    if np.any(eats_food):
      # add random food to replace what is eaten
      possible_new_locs = np.where(np.logical_and(
          self.pieces == 0, #the spot is empty
          eats_food.reshape(self.batch_size, 1, 1))) #food eaten on that board
      food_sample_ = self.rng.choice(num_empty_after_eat,
                                     size=np.sum(eats_food))
      food_sample = food_sample_ + np.arange(len(g_eating))*num_empty_after_eat
      assert np.all(self.pieces[(possible_new_locs[0][food_sample],
                                 possible_new_locs[1][food_sample],
                                 possible_new_locs[2][food_sample])] == 0)
      #put new food on the board
      self.pieces[(possible_new_locs[0][food_sample],
                   possible_new_locs[1][food_sample],
                   possible_new_locs[2][food_sample])] = -1
    self.rounds_left = self.rounds_left - 1
    if not np.all(self.pieces.sum(axis=(1,2)) ==
                  ((self.num_food * -1) + np.sum(np.arange(self.num_critters)+1))):
      print(self.pieces.sum(axis=(1,2)))
      print(((self.num_food * -1) + np.sum(np.arange(self.num_critters)+1)))
    assert np.all(self.pieces.sum(axis=(1,2)) ==
                  ((self.num_food * -1) + np.sum(np.arange(self.num_critters)+1)))
    # next player's turn
    self.active_player = (self.active_player + 1) % (self.num_critters)


  def get_legal_moves(self, critter):
    """
    Identifies all legal moves for the critter, taking into acount
    bouncing/reflection at edges,

    Returns:
      A numpy int array of size batch x 3(g,x,y) x 4(possible moves)

    Note:
      moves[0,1,3] is the x coordinate of the move corresponding to the
      fourth offstet on the first board.
      moves[1,:,1] will give the g,x,y triple corresponding to the
      move on the second board and the second offset, actions are integers
    """
    # can only move one cell down, up, right, and left from current location
    critter_locs = np.array(np.where(self.pieces == critter))
    legal_offsets = np.stack([
      critter_locs + np.array([np.array([0,  1, 0])]*self.batch_size).T,
      critter_locs + np.array([np.array([0, -1, 0])]*self.batch_size).T,
      critter_locs + np.array([np.array([0, 0,  1])]*self.batch_size).T,
      critter_locs + np.array([np.array([0, 0, -1])]*self.batch_size).T])
    legal_offsets = np.vstack(np.transpose(legal_offsets, (0, 2, 1)))
    legal_offsets = set([tuple(m_) for m_ in legal_offsets])
    # must land on the board and not on another critter
    legal_destinations = np.where(self.pieces <= 0)
    legal_destinations = set([(g, r, c) for
                              g, r, c in zip(*legal_destinations)])
    # legal moves satisfy both these conditions
    legal_moves = legal_offsets.intersection(legal_destinations)
    return legal_moves


  def get_perceptions(self, radius, critter):
    """
    Generates a vector representation of the critter perceptions, oriented
    around the critter.

    Args:
      radius: int, how many grid squared the critter can see around it
        using L1  (Manhattan/cityblock) distance

    Returns:
      A batch_size x 2*radius*(radius+1) + 1, giving the values
      of the percept reading left to right, top to bottom over the board,
      for each board in the batch
    """
    # define the L1 ball mask
    diameter = radius*2+1
    mask = np.zeros((diameter, diameter), dtype=bool)
    mask_coords = np.array([(i-radius, j-radius)
      for i in range(diameter)
        for j in range(diameter)])
    mask_distances = cdist(mask_coords, [[0, 0]],
                           'cityblock').reshape(mask.shape)
    mask[mask_distances <= radius] = True
    mask[radius,radius] = False  # exclude the center

    # pad the array
    padded_arr = np.pad(self.pieces, ((0, 0), (radius, radius),
     (radius, radius)), constant_values=-2)

    # get locations of critters
    critter_locs = np.argwhere(padded_arr == critter)

    percepts = []
    for critter_loc in critter_locs:
      b, r, c = critter_loc
      surrounding = padded_arr[b, r-radius:r+radius+1, c-radius:c+radius+1]
      percept = surrounding[mask]
      percepts.append(percept)
    return(np.array(percepts))




# @title GridworldGame class
#######################################################################
# extend GridworldGame class locally before integrating in shared utils
#######################################################################



class GridworldGame():
  """
  A collection methods and parameters of a gridworld game that allow
  for interaction with and display of GridwordlBoard objects.
  """


  def __init__(self, batch_size=1, n_rows=7, n_cols=7,
               num_critters=2, num_food=10,
               lifetime=30, rng=None):
    """
    Initializes an instance of the class with the specified parameters.

    Args:
      batch_size (int, optional): Number of instances in a batch. Default is 1.
      n_rows (int, optional): Number of rows in the grid. Default is 7.
      n_cols (int, optional): Number of columns in the grid. Default is 7.
      num_critters (int, optional): Number of different agents running around
        on each board in the batch. Default is 2.
      num_food (int, optional): Number of food items. Default is 10.
      lifetime (int, optional): Time before critter's life ends, in terms of
        time steps. Default is 30.
      rng (numpy random number generator, optional): Random number generator
        for reproducibility. If None, uses default RNG with a preset seed.
    """

    # Check for positive integer inputs
    assert all(isinstance(i, int) and i >= 0
               for i in [batch_size, n_rows, n_cols, num_critters, num_food,
                         lifetime]), "All inputs must be non-negative integers."
    self.batch_size = batch_size
    self.n_rows = n_rows
    self.n_cols = n_cols
    self.num_critters = num_critters
    # Check for num_food exceeding maximum possible value
    max_food = n_rows * n_cols - num_critters
    if num_food > max_food:
      print(f'num_food is too large, setting it to maximum possible value: {max_food}')
      num_food = max_food
    self.num_food = num_food
    self.lifetime = lifetime
    # Set up random number generator
    if rng is None:
      self.rng = np.random.default_rng(seed=SEED)
    else:
      self.rng = rng


  def get_init_board(self):
    """
    Generates a starting board given the parameters of the game.
    Returns a tuple giving current state of the game
    """
    # current score, and rounds left in the episode
    b = GridworldBoard(batch_size=self.batch_size, n_rows=self.n_rows,
                       n_cols=self.n_cols, num_critters=self.num_critters,
                       num_food=self.num_food, lifetime=self.lifetime,
                       rng=self.rng)
    return b.get_init_board_state()


  def get_board_size(self):
    """Shape of a single board, doesn't give batch size"""
    return (self.n_rows, self.n_cols)


  def get_action_size(self):
    """
    Returns the number of all possible actions, even though only  2-4 of
    these will ever be valid on a given turn.
    Actions correspond to integer indexes of board locations,
    moves to g,r,c coordinate indexes of board locations
    """
    return self.n_rows * self.n_cols


  def get_batch_size(self):
    """
    Returns the number of actions, only 0-4 of these will ever be valid.
    Actions correspond to integer indexes of board locations,
    moves to r,c indexes of board locations
    """
    return self.batch_size


  def string_rep(self, board, g=0):
    """ A bytestring representation board g's state in the batch of boards"""
    return (board['pieces'][g].tobytes() + board['scores'][g].tobytes() +
            board['rounds_left'][g].tobytes())


  def get_square_symbol(self, piece):
    """ Translate integer piece value to symbol for display"""
    if piece == -1:
      return "X"
    elif piece == 0:
      return "-"
    elif piece >= 1:
      return "0"
    else:
      return "???????????????????????????"


  def string_rep_readable(self, board, g=0):
    """ A human readable representation of g-th board's state in the batch"""
    board_s = "".join([self.get_square_symbol(square)
                        for row in board['pieces'][g]
                          for square in row])
    board_s = board_s + '_' + str(board['scores'][g])
    board_s = board_s + '_' + str(board['rounds_left'][g])
    return board_s


  def get_scores(self, board):
    return board['scores'].copy()


  def get_rounds_left(self, board):
    return board['rounds_left'].copy()


  def display(self, board, g=0):
    """Displays the g-th games in the batch of boards"""
    print("   ", end="")
    for c_ in range(self.n_cols):
      print(c_, end=" ")
    print("")
    print("-----------------------")
    for c_ in range(self.n_cols):
      print(c_, "|", end="")    # Print the row
      for r_ in range(self.n_rows):
        piece = board['pieces'][g,c_,r_]    # Get the piece to print
        #print(piece)
        print(self.get_square_symbol(piece), end=" ")
      print("|")
    print("-----------------------")
    print("Rounds Left: " + str(board['rounds_left'][g]))
    print("Score: " + str(board['scores'][g]))


  def get_critter_rc(self, board, g, critter_index):
    return np.squeeze(np.array(np.where(board['pieces'][g]==critter_index)))


  def plot_moves(self, board, player0, g=0, player1=None,
                 fig=None, ax=None, p0_name='Player 0', p1_name='Player 1',
                 figsize=(6,5), critter_name='Critter', title=None,
                 deterministic=False):
    """
    Uses plotting functions to make picture of the current board state, and what
    a critter would do at each non-food location in the current board state
    """
    def make_prob_dict(critter_locs, play):
      offset_dict = {(0, 1): 'right',
                     (0,-1): 'left',
                     ( 1, 0): 'down',
                     (-1, 0): 'up'}
      index_probs = play[2].copy()
      loc_prob_dict = {}
      # for each non food locations
      for g, loc_ in enumerate(critter_locs):
        # this is the location as an r, c tuple
        rc_tup = tuple((loc_[1], loc_[2]))
        # the relevant probabilities
        raw_probs = index_probs[g]
        probs = raw_probs[raw_probs > 0]
        indexes = np.argwhere(raw_probs > 0)
        # turn the probability indexes into r, c coords
        rows = np.floor_divide(indexes, gwg.n_cols)
        cols = np.remainder(indexes, gwg.n_cols)
        moves = np.squeeze(np.array([z for z in zip(rows, cols)]), axis=2)
        #compute the offsets and turn them to strings
        offsets = moves - loc_[1:]
        str_offsets = np.array(list(map(offset_dict.get, map(tuple, offsets))))
        # update the loc_prob_dict for plotting
        prob_dict = dict(zip(str_offsets, probs))
        loc_prob_dict.update({rc_tup: prob_dict})
      return loc_prob_dict

    batch_size, n_rows, n_cols = board['pieces'].shape
    plt.ioff()
    if fig is None and ax is None:
      fig, ax = make_grid(n_rows, n_cols, figsize=figsize, title=title)

    rc_food_index = np.array(np.where(board['pieces'][g] == -1))
    rc_food_plotting = np.array(rc_food_index).T
    food = plot_food(fig, ax, rc_food_plotting)

    expanded_board = self.critter_everywhere_state_expansion(
      board, player0.critter_index, to_expand=g)
    critter_locs = np.argwhere(expanded_board['pieces']==player0.critter_index)
    #play the expanded state
    p0_play = player0.play(expanded_board)
    #get the prob dict
    p0_loc_prob_dict = make_prob_dict(critter_locs, p0_play)
    # same for player1 if there is one
    if player1 is not None:
      p1_play = player1.play(expanded_board)
      p1_loc_prob_dict = make_prob_dict(critter_locs, p1_play)

    existing_handels, _ = ax.get_legend_handles_labels()
    if player1 is None:
      fig, ax, leg_handles_0 = plot_directions(fig, ax, p0_loc_prob_dict,
        critter=0, deterministic=deterministic)
      leg_handles = existing_handels
    else:
      fig, ax, leg_handles_0 = plot_directions(fig, ax, p0_loc_prob_dict,
        critter=1, deterministic=deterministic, name=p0_name)
      fig, ax, leg_handles_1 = plot_directions(fig, ax, p1_loc_prob_dict,
        critter=2, deterministic=deterministic, name=p1_name)
      leg_handles = existing_handels + leg_handles_0 + leg_handles_1

    fig.legend(handles=leg_handles, loc="outside right upper")
    fig.canvas.draw()
    return fig, ax


  def plot_board(self, board, g=0,
                 fig=None, ax=None, critter_specs=None, food=None, fov=None,
                 legend_type='included',
                 has_fov=False, #fog_of_war feild_of_view
                 radius=2, figsize=(6,5), title=None,
                 name='Critter'):
    """Uses plotting functions to make picture of the current board state"""
    batch_size, n_rows, n_cols = board['pieces'].shape
    plt.ioff()
    if fig is None and ax is None:
      fig, ax = make_grid(n_rows, n_cols, figsize=figsize, title=title)

    # generate critter plotting specs if we don't already have them
    if critter_specs is None:
      critter_specs = []
      markers = ['h', 'd']  # hexagon and diamond
      colors = sns.color_palette("colorblind")
      for i in range(self.num_critters):
        critter_name = name if self.num_critters == 1 else f'{name} {i+1}'
        spec = {'marker': markers[i % len(markers)],
                'color': colors[i // len(markers) % len(colors)],
                'name': critter_name,
                'int_id': i+1}
        critter_specs.append(spec)
    # get critter locs and plot them
    assert len(critter_specs) == self.num_critters, "More/fewer specs than critters"
    for spec in critter_specs:
      rc_loc = np.array(np.where(board['pieces'][g] == spec['int_id'])).T
      spec.update({'rc_loc': rc_loc})
    critter_specs = plot_critters(fig, ax, critter_specs)

    # get food locs and plot them
    rc_food_index = np.array(np.where(board['pieces'][g] == -1))
    rc_food_plotting = np.array(rc_food_index).T
    if food is None:
      food = plot_food(fig, ax, rc_food_plotting)
    else:
      food = plot_food(fig, ax, rc_food_plotting, food)

    #plot field of view if doing that
    if has_fov:
      # will need to think about how to do this for multiple
      # critters, currently just use rc of first critter in the spec list
      if fov is None:
        fov = plot_fov(fig, ax, critter_specs[0]['rc_loc'], n_rows, n_cols,
                       radius, has_fov)
      else:
        fov = plot_fov(fig, ax, critter_specs[0]['rc_loc'], n_rows, n_cols,
                       radius, has_fov, fov)
    # make legend and draw and return figure
    if legend_type == 'included':
      fig.legend(loc = "outside right upper", markerscale=0.8)
      fig.canvas.draw()
      return fig, ax, critter_specs, food, fov
    elif legend_type == 'separate':
      fig_legend, ax_legend = plt.subplots(figsize=(1.5,1.5), layout='constrained')
      fig_legend.get_layout_engine().set(w_pad=0, h_pad=0, hspace=0, wspace=0)
      handles, labels = ax.get_legend_handles_labels()
      ax_legend.legend(handles, labels, loc='center', markerscale=0.8)
      ax_legend.axis('off')
      fig_legend.canvas.header_visible = False
      fig_legend.canvas.toolbar_visible = False
      fig_legend.canvas.resizable = False
      fig_legend.canvas.footer_visible = False
      fig_legend.canvas.draw()
      return fig, ax, critter_specs, food, fov, fig_legend, ax_legend
    else: #no legend
      fig.canvas.draw()
      return fig, ax, critter_specs, food, fov


  def get_valid_actions(self, board, critter):
    """
    A Helper function to translate the g,x,y, tuples provided the
    GridworldBoard.get_legal_moves method into valid actions, represented
    as binary vectors of len num_actions.

    Args:
      board: a triple of np arrays representing board state
        pieces,       - batch_size x n_rows x n_cols
        scores,       - batch_size
        rounds_left   - batch_size

    Returns:
      valids: np.ndarray(binary) batch_size x num_actions, 1's represent
              valid moves
    """
    batch_size, n_rows, n_cols = board['pieces'].shape
    b = GridworldBoard(batch_size=batch_size, n_rows=n_rows,
                       n_cols=n_cols, num_critters=self.num_critters,
                       num_food=self.num_food, lifetime=self.lifetime,
                       rng=self.rng)
    b.set_state(board)
    legal_moves =  b.get_legal_moves(critter)
    valids = np.zeros((batch_size, n_rows * n_cols))
    for g, r, c in legal_moves:
      valids[g, r * n_cols + c] = 1
    return valids


  def display_moves(self, board, critter=1, g=0):
    """Displays possible moves for the g-th games in the batch of boards"""
    batch_size, n_rows, n_cols = board['pieces'].shape
    A=np.reshape(self.get_valid_actions(board, critter)[g],
                 (n_rows, n_cols))
    print("  ")
    print("possible moves")
    print("   ", end="")
    for col in range(self.n_cols):
      print(col, end=" ")
    print("")
    print("-----------------------")
    for col in range(self.n_cols):
      print(col, "|", end="")    # Print the row
      for row in range(self.n_rows):
        piece = A[col][row]    # Get the piece to print
        print(self.get_square_symbol(piece), end=" ")
      print("|")
    print("-----------------------")


  def get_perceptions(self, board, radius, critter):
    batch_size, n_rows, n_cols = board['pieces'].shape
    b = GridworldBoard(batch_size=batch_size, n_rows=n_rows,
                       n_cols=n_cols, num_critters=self.num_critters,
                       num_food=self.num_food, lifetime=self.lifetime,
                       rng=self.rng)
    b.set_state(board)
    return(b.get_perceptions(radius, critter))


  def get_next_state(self, board, critter, actions, a_indx=None):
    """
    Helper function using GridworldBoard.execute_moves to update board state
    given actions on a batch of boards, for a given critter

    Args:
      board: a triple of np arrays representing board state
        pieces,       - batch_size x n_rows x n_cols
        scores,       - batch_size
        rounds_left   - batch_size
      critter: integer index of the critter type
      actions: list of flat integer indexes of critter's new board positions
      a_indx: list of integer indexes indicating which actions are being taken
        on which boards in the batch

    Returns:
      a board triple signifying next state

    Note:
      if len(actions) > batch_size of board the returned board state will have
      an expanded batch size, allowing multiple paths in the game tree to be
      explored in parallel

    """
    batch_size, n_rows, n_cols = board['pieces'].shape
    if board['rounds_left'][0] <= 0:
      # assumes all boards in the batch have the same rounds left
      # no rounds left return the board unchanged
      return board
    else:
      moves = self.actions_to_moves(actions)
      b = GridworldBoard(batch_size=len(actions), n_rows=n_rows,
                         n_cols=n_cols, num_critters=self.num_critters,
                         num_food=self.num_food, lifetime=self.lifetime,
                         rng=self.rng)
      if a_indx is None:
        # just one move on each board in the batch
        assert batch_size == len(actions)
        b.set_state(board)
      else:
        # potentially multiple moves on each board, expand the batch
        assert len(actions) == len(a_indx)
        new_pieces = np.array([board['pieces'][ai].copy() for ai in a_indx])
        new_scores = np.array([board['scores'][ai].copy() for ai in a_indx])
        new_rounds_left = np.array([board['rounds_left'][ai].copy() for ai in a_indx])
        new_active_player = copy(board['active_player'])
        new_state = {'pieces': new_pieces,
                     'scores': new_scores,
                     'rounds_left': new_rounds_left,
                     'active_player': new_active_player}
        b.set_state(new_state)
      b.execute_moves(moves, critter)
      return b.get_state()


  def actions_to_moves(self, actions):
    """
    Args:
      actions: a batch length list of integer indexes for the flattened boards,
      i.e. in the range(n_cols * n_rows) actions are often much easier to use
      as training targets for NN based RL agents.
    Returns
      moves: a 3-tuple of 1-d arrays each of length batch_size,
        the first array gives the specific board within the batch,
        the second array in the tuple gives the new row coord for each critter
        on each board and the third gives the new col coord. Note this is
        exactly the format expected by GridworldBoard.execute_moves, and
        is a canonical way of indexing arrays for quick numpy operations.
    """
    moves = (np.arange(len(actions)),
             np.floor_divide(actions, self.n_cols),
             np.remainder(actions, self.n_cols))
    return moves


  def moves_to_actions(self, moves):
    """
    Args:
      moves: a 3-tuple of 1-d arrays each of length batch_size,
        the first array gives the specific board within the batch,
        the second array in the tuple gives the new row coord for each critter
        on each board and the third gives the new col coord. Note this is
        exactly the format expected by GridworldBoard.execute_moves, and
        is a canonical way of indexing arrays for quick numpy operations.
    Returns:
      actions: a batch length list of integer indexes for the flattened boards,
      i.e. in the range(n_cols * n_rows) actions are often much easier to use
      as training targets for NN based RL agents.
    """
    _, rows, cols = moves
    actions = rows * self.n_cols + cols
    return actions


  def critter_oriented_get_next_state(self, board, critter, offsets):
    """
    Translates directions in reference to the critter's location into
    moves on the board in absolute terms, while checking for
    bouncing/reflecting then get's the next state.

    Args:
      board: a triple of np arrays representing board state
        pieces,       - batch_size x n_rows x n_cols
        scores,       - batch_size
        rounds_left   - batch_size
      offsets: batch length list of strings one 'up', 'down', 'left', 'right'

    Returns:
      a board triple signifying next state

    Note:
      Unlike get_next_state, this method does not allow for expansion
      of the game tree, i.e. len(offsets)==batch_size required
    """
    assert len(offsets) == board['pieces'].shape[0]
    batch_size, n_rows, n_cols = board['pieces'].shape
    b = GridworldBoard(batch_size=batch_size, n_rows=n_rows,
                       n_cols=n_cols, num_critters=self.num_critters,
                       num_food=self.num_food, lifetime=self.lifetime,
                       rng=self.rng)
    b.set_state(board)
    moves = self.critter_direction_to_move(board, offsets, critter)
    b.execute_moves(moves, critter)
    return(b.get_state())


  def critter_direction_to_move(self, board, offsets, critter):
    """
    Translates directions in reference to the critter's location into
    moves on the board in absolute terms, while checking for
    bouncing/reflecting then returns moves. Doesn't check for collisions with
    other critters though. In general player's move methods should be checking
    valid moves and only making legal ones.

    Args:
      board: dict of np arrays representing board state
        'pieces':       batch_size x n_rows x n_cols
        'scores':       batch_size
        'rounds_left':  batch_size
      offsets: batch length list of strings,
        one of 'up', 'down', 'left', 'right'
      critter: integer index for the critter we want moves for

    Returns:
      moves: a 3-tuple of 1-d arrays each of length batch_size,
        the first array gives the specific board within the batch,
        the second array in the tuple gives the new row coord for each critter
        on each board and the third gives the new col coord. Note this is
        exactly the format expected by GridworldBoard.execute_moves, and
        is a canonical way of indexing arrays for numpy.

    Note:
      Unlike get_next_state, this method does not allow for expansion
      of the game tree, i.e. len(offsets)==batch_size required
    """
    assert len(offsets) == board['pieces'].shape[0]
    batch_size, n_rows, n_cols = board['pieces'].shape
    offset_dict = {'up': (0, -1, 0),
                   'down': (0, 1, 0),
                   'left': (0, 0, -1),
                   'right': (0, 0, 1),
                   'still': (0, 0, 0)}
    this_critter_locs = np.where(board['pieces'] == critter)
    all_critter_locs = np.where(board['pieces'] >= 1)
    offsets_array = np.hstack([np.array(offset_dict[ost_]).reshape((3,1))
                           for ost_ in offsets])
    new_locs = np.array(this_critter_locs) + offsets_array
    #check bounces at boundaries
    new_locs[1,:] = np.where(new_locs[1] >=
                               n_rows, n_rows-2, new_locs[1])
    new_locs[2,:] = np.where(new_locs[2,:] >=
                               n_cols, n_cols-2, new_locs[2,:])
    new_locs[1,:] = np.where(new_locs[1,:] < 0, 1, new_locs[1,:])
    new_locs[2,:] = np.where(new_locs[2,:] < 0, 1, new_locs[2,:])
    moves = tuple(new_locs)
    return moves


  def direction_probs_to_flat_probs(self, board, direction_probs, critter):
    """
    Converts direction probabilities in reference to the critter's location into
    probability arrays on the flattened board.

    Args:
      board: dict of np arrays representing board state
        'pieces':       batch_size x n_rows x n_cols
        'scores':       batch_size
        'rounds_left':  batch_size
      direction_probs: batch length list of dictionaries with keys
        ['up', 'down', 'left', 'right'] and corresponding probabilities.

    Returns:
      probs_arrays: list of arrays, where each array is of length n_rows*n_cols
                    and represents the flattened probability distribution for
                    board in the batch.
    """
    batch_size, n_rows, n_cols = board['pieces'].shape
    offset_dict = {
        'up': np.array((0, -1, 0)),
        'down': np.array((0, 1, 0)),
        'left': np.array((0, 0, -1)),
        'right': np.array((0, 0, 1))}
    critter_locs = np.where(board['pieces'] == critter)
    probs_arrays = np.zeros((batch_size, n_rows * n_cols))
    for batch_index in range(batch_size):
      prob_array = np.zeros(n_rows * n_cols)
      for direction, prob in direction_probs[batch_index].items():
          offset = offset_dict[direction]
          new_loc = np.array(critter_locs)[:, batch_index] + offset
          # Check bounces at boundaries
          new_loc[1] = np.where(new_loc[1] >= n_rows, n_rows-2, new_loc[1])
          new_loc[2] = np.where(new_loc[2] >= n_cols, n_cols-2, new_loc[2])
          new_loc[1] = np.where(new_loc[1] < 0, 1, new_loc[1])
          new_loc[2] = np.where(new_loc[2] < 0, 1, new_loc[2])
          # Convert 2D location to flattened index
          flattened_index = new_loc[1] * n_cols + new_loc[2]
          prob_array[flattened_index] += prob
      probs_arrays[batch_index, :] = prob_array
    return list(probs_arrays)


  def action_to_critter_direction(self, board, critter, actions):
    """
    Translates an integer index action into up/down/left/right

    Args:
      board: a triple of np arrays representing board state
        pieces,       - batch_size x n_rows x n_cols
        scores,       - batch_size
        rounds_left   - batch_size
      actions: a batch size ndarry of integer indexes for actions on each board

    Returns:
      offsets: a batch length list of strings 'up', 'down', 'left', 'right'
    """
    offset_dict = {(0, 0, 1): 'right',
                   (0, 0,-1): 'left',
                   (0, 1, 0): 'down',
                   (0,-1, 0): 'up'}
    batch_size, n_rows, n_cols = board['pieces'].shape
    critter_locs = np.where(board['pieces'] == critter)
    moves = (np.arange(len(actions)),
               np.floor_divide(actions, n_cols),
               np.remainder(actions, n_cols))
    # need to reverse this from above, moves is equiv to new_locs
    # new_locs = np.array(critter_locs) + offsets_array
    offsets_array = np.array(moves) - np.array(critter_locs)
    offsets = [offset_dict[tuple(o_)] for o_ in offsets_array.T]
    return offsets


  def get_valid_directions(self, board, critter):
    """
    Transforms output of get_valid_actions to a list of the valid directions
    for each board in the batch for a given critter.
    """
    offset_dict = {( 0, 1): 'right',
                   ( 0,-1): 'left',
                   ( 1, 0): 'down',
                   (-1, 0): 'up'}
    batch_size, n_rows, n_cols = board['pieces'].shape
    valid_actions = self.get_valid_actions(board, critter)
    if batch_size != len(valid_actions):
      raise ValueError("Need Exactly one set of valid actions per board in batch")
    critter_locs = np.column_stack(np.where(board['pieces'] == critter))
    valid_directions = []
    for g, batch_valid in enumerate(valid_actions):
      valid_int_indices = np.where(batch_valid==1)[0]
      critter_loc = critter_locs[critter_locs[:, 0] == g, 1:]
      # critter_loc shape is (1, 2)
      critter_loc = np.squeeze(critter_loc)
      moves = np.column_stack([valid_int_indices // n_cols, valid_int_indices % n_cols])
      offsets = moves - critter_loc
      batch_valid_directions = [offset_dict[tuple(offset)] for offset in offsets]
      valid_directions.append(batch_valid_directions)
    return valid_directions


  def get_game_ended(self, board):
    """
    Helper function to signify if game has ended
    Returns a batch size np.array of -1 if not ended, and scores for each game
    in the batch if it is ended, note only returns scores if all games in the
    batch have ended
    """
    rounds_left = board['rounds_left']
    scores = board['scores']
    if np.any(rounds_left >= 1):
      return np.ones(self.batch_size) * -1.0
    else:
      return scores


  def critter_everywhere_state_expansion(self, board_state,
                                         critter=1, to_expand=0):
    """
    Expand a given board state by placing a critter at each non-food location.

    The function takes a game state and returns an expanded version of it. For
    each board in the state, it creates a new version of the board for every
    non-food location, placing a critter at that location. The scores and
    remaining rounds are copied for each new board. The result is a new game state
    with a larger number of boards, each representing a possible configuration
    with a critter at a different location.

    Args:
      board_state (dict): A dictionary containing the current game state.
      It should have the following keys:
        - 'pieces': a 3D numpy array (batch x n_col x n_row) representing the game
          board. -1 -> food, 0 -> empty cell, and 1 -> critter.
        - 'scores': 1D numpyp array of the score for each board in the batch.
        - 'rounds_left': a 1D numpy array of the rounds left for
          each board in the batch.
      critter: integer index to place on the expanded board state
      to_expand (list (int)): list of batch indices to have state expanded

    Returns:
      dict: A dictionary containing the expanded game state with the same keys
        as the input. The number of boards will be larger than the input state.
    """
    pieces = board_state['pieces'].copy()
    scores = board_state['scores'].copy()
    rounds_left = board_state['rounds_left'].copy()
    active_player = copy(board_state['active_player'])
    # Determine non-food locations
    non_food_locs = np.argwhere(pieces[to_expand] != -1)
    #scrub all existing critter locations,
    # maybe later only scrub specific critter type
    pieces[pieces >= 1] = 0
    # lists to store expanded states
    expanded_pieces = []
    expanded_scores = []
    expanded_rounds_left = []
    # Iterate over each non-food location
    for i in range(non_food_locs.shape[0]):
      # Create a copy of the board
      expanded_board = np.copy(pieces[to_expand])
      # Place the critter at the non-food location
      # later consider only placing at non-food,
      # non-other critter locs
      expanded_board[tuple(non_food_locs[i])] = critter
      # Add the expanded board to the list along score and rounds_left
      expanded_pieces.append(expanded_board)
      expanded_scores.append(scores[to_expand])
      expanded_rounds_left.append(rounds_left[to_expand])
    # Convert to arrays and create expanded board state
    expanded_state = {'pieces': np.stack(expanded_pieces),
                      'scores': np.array(expanded_scores),
                      'rounds_left': np.array(expanded_rounds_left),
                      'active_player': active_player}
    return expanded_state


  def play_game(self, players=[], collect_fov_data=False, fov_radius=2,
                visualize = False):
    """This method takes a list of players the same length as num_critters,
        and then plays a batch of games with them and returns the final board
        state"""
    if len(players) != self.num_critters:
      raise ValueError("number of players different than expected")

    board = self.get_init_board()
    if visualize == True:
      self.display(board, 0)

    if collect_fov_data is True:
      batch_size, n_rows, n_cols = board['pieces'].shape
      b = GridworldBoard(batch_size=batch_size,
                         n_rows=n_rows, n_cols=n_cols,
                         num_critters=self.num_critters,
                         num_food=self.gwg.num_food,
                         lifetime=self.gwg.lifetime,
                         rng=self.gwg.rng)


    for ii in range(self.lifetime):
      for jj, player in enumerate(players):
        active_player_index = board['active_player']
        old_scores = board['scores']
        if collect_fov_data is True:
          b.set_state(board)
          percepts = b.get_perceptions(fov_radius)

        a_player, _, _ = players[active_player_index].play(board)
        board = self.get_next_state(board, active_player_index+1, a_player)
        if visualize == True:
          self.display(board, 0)
    return board




# @title player zoo
###########################################################################
# make a separate player zoo
###########################################################################




class RandomValidPlayer():
  """
  Instantiate random player for GridWorld
  """


  def __init__(self, game, critter_index=1):
    self.game = game
    self.critter_index = critter_index
    assert (isinstance(critter_index, int) and
        0 < critter_index <= game.num_critters), "Value is not a positive integer or exceeds the upper limit."


  def play(self, board):
    """
    Simulates a batch of random game plays based on the given board state.

    This function computes the probability of each valid move being played
    (uniform for valid moves, 0 for others), then selects a move randomly for
    each game in the batch based on these probabilities.

    Args:
      board (dict): A dictionary representing the state of the game. It
          contains:
          - 'pieces': A (batch_size, x_size, y_size) numpy array indicating
                      the pieces on the board.
          - 'scores' (not used directly in this function, but expected in dict)
          - 'rounds_left' (not used directly in this function, but expected in dict)

    Returns:
      tuple:
      - a (numpy array): An array of shape (batch_size,) containing randomly
                         chosen actions for each game in the batch.
      - a_1hots (numpy array): An array of shape (batch_size, action_size)
                               with one-hot encoded actions.
      - probs (numpy array): An array of the same shape as 'valids' containing
                             the probability of each move being played.
    """
    batch_size, x_size, y_size = board['pieces'].shape
    valids = self.game.get_valid_actions(board, self.critter_index)
    action_size = self.game.get_action_size()

    probs = valids / np.sum(valids, axis=1).reshape(batch_size,1)

    a = [self.game.rng.choice(action_size, p=probs[ii])
                                for ii in range(batch_size)]
    a_1hots = np.zeros((batch_size, action_size))
    a_1hots[(range(batch_size), a)] = 1.0
    return np.array(a), a_1hots, probs




class RandomDirectionPlayer():
  """
  Instantiate random player for GridWorld
  """

  def __init__(self, game, critter_index=1):
    self.game = game
    self.critter_index = critter_index
    assert (isinstance(critter_index, int) and
        0 < critter_index <= game.num_critters), "Value is not a positive integer or exceeds the upper limit."

  def play(self, board):
    """
    Simulates a batch of random game plays based on the given board state.

    This function assigns a uniform probability to going up down left or right
    independent of whether it is at an edge or cornor or not. Then because of
    bouncing off edges it will have a higher probability of moving away from
    edges as opposed to along them than the random valid move player.

    Args:
      board (dict): A dictionary representing the state of the game. It
          contains:
          - 'pieces': A (batch_size, x_size, y_size) numpy array indicating
                      the pieces on the board.
          - 'scores' (not used directly in this function, but expected in dict)
          - 'rounds_left' (not used directly in this function, but expected in dict)

    Returns:
      tuple:
      - a (numpy array): An array of shape (batch_size,) containing randomly
                         chosen actions for each game in the batch.
      - a_1hots (numpy array): An array of shape (batch_size, action_size)
                               with one-hot encoded actions.
      - probs (numpy array): An array of the same shape as 'valids' containing
                             the probability of each move being played.
    """
    batch_size, n_rows, n_cols = board['pieces'].shape
    action_probs = {'up': 0.25, 'down': 0.25, 'left': 0.25, 'right': 0.25}

    critter_oriented_moves = self.game.rng.choice(list(action_probs.keys()),
                                                  size=(batch_size))
    direction_probs = [action_probs] * batch_size
    moves = self.game.critter_direction_to_move(board, critter_oriented_moves,
                                                self.critter_index)
    probs = self.game.direction_probs_to_flat_probs(board, direction_probs,
                                                    self.critter_index)
    sampled_actions = self.game.moves_to_actions(moves)
    a_1hots = np.zeros((batch_size, n_rows*n_cols))
    a_1hots[(range(batch_size), sampled_actions)] = 1.0

    return sampled_actions, a_1hots, probs




class MonteCarloBasedPlayer():
  """
  Simulate Player based on Monte Carlo Algorithm

  Note: Has dependencies in the gw_NN_RL.py util, namely a policy/value
  network and the Monte Carlo class.
  """

  def __init__(self, game, nnet,
               critter_index=1,
               default_depth=1,
               default_rollouts=1,
               default_K=4,
               default_temp=1.0,
               random_seed=None):
    """
    Initialize Monte Carlo Parameters

    Args:
      game: Gridworld Game instance
        Instance of the gridworldGame class above;
      nnet: gridworldNet instance
        Instance of the gridworldNNet class above;
      args: dictionary
        Instantiates number of iterations and episodes, controls temperature threshold, queue length,
        arena, checkpointing, and neural network parameters:
        learning-rate: 0.001, dropout: 0.3, epochs: 10, batch_size: 64,
        num_channels: 512

    Returns:
      Nothing
    """
    self.game = game
    self.critter_index = critter_index
    assert (isinstance(critter_index, int) and
        0 < critter_index <= game.num_critters), "Value is not a positive integer or exceeds the upper limit."
    self.nnet = nnet
    self.default_depth = default_depth
    self.default_rollouts = default_rollouts
    self.mc = MonteCarlo(self.game, self.nnet, self.default_depth)
    self.default_K = default_K
    self.default_temp = default_temp
    self.rng = np.random.default_rng(seed=random_seed)


  def play(self, board,
           num_rollouts=None,
           rollout_depth=None,
           K=None,
           softmax_temp=None):
    """
    Simulates a batch Monte Carlo based plays on the given board state.

    Computes the probability of each valid move being played using a softmax
    activation on the Monte Carlo based value (Q) of each action then selects a
    move randomly for each game in the batch based on those probabilities.

    Args:
      board (dict): A dictionary representing the state of the game. It
          contains:
          - 'pieces': A (batch_size, x_size, y_size) numpy array indicating
                      the pieces on the board.
          - 'scores' (not used directly in this function, but expected in dict)
          - 'rounds_left' (not used directly in this function, but expected in dict)

    Returns:
      tuple:
      - a (numpy array): An array of shape (batch_size,) containing randomly
                         chosen actions for each game in the batch.
      - a_1hots (numpy array): An array of shape (batch_size, action_size)
                               with one-hot encoded actions.
      - probs (numpy array): An array of the same shape as 'valids' containing
                             the probability of each move being played.
    """
    batch_size, n_rows, n_cols = board['pieces'].shape
    if num_rollouts is None:
      num_rollouts = self.default_rollouts
    if rollout_depth is None:
      rollout_depth = self.default_depth
    if K is None:
      K = self.default_K
    if softmax_temp is None:
      softmax_temp = self.default_temp

    # figure out top k actions according to normalize action probability
    # given by our policy network prediction
    #co_pieces = board['pieces'].copy()
    #this_critter_locs = np.where(co_pieces == self.critter_index+1)
    #all_critter_locs = np.where(co_pieces >= 1)
    # other critters are invisible to this player
    #co_pieces[all_critter_locs] = 0
    # nnet trained to see self as 1
    #co_pieces[this_critter_locs] = 1
    #scalar_rounds_left = board['rounds_left'][0]
    #co_rounds_left = scalar_rounds_left // self.game.num_critters
    #if self.critter_index-1 < scalar_rounds_left % self.game.num_critters:
       # add an extra if we haven't had this players turn yet in the round cycle
    #   co_rounds_left = co_rounds_left + 1
    #co_rounds_left = np.array([co_rounds_left]*batch_size)
    #pis, vs = self.nnet.predict(co_pieces,
    #                            board['scores'][:,self.critter_index-1],
    #                            co_rounds_left)
    pis, vs = self.mc.pis_vs_from_board(board, self.critter_index)
    valids = self.game.get_valid_actions(board, self.critter_index)
    masked_pis = pis * valids  # Masking invalid moves
    sum_pis = np.sum(masked_pis, axis=1)
    num_valid_actions = np.sum(valids, axis=1)
    effective_topk = np.array(np.minimum(num_valid_actions, K), dtype= int)
    probs = np.array([masked_pi / masked_pi.sum() if masked_pi.sum() > 0
                      else valid / valid.sum()
                      for valid, masked_pi in zip(valids, masked_pis)])
    partioned = np.argpartition(probs,-effective_topk)
    topk_actions = [partioned[g,-(ii+1)]
                      for g in range(batch_size)
                        for ii in range(effective_topk[g])]
    topk_actions_index = [ii
                            for ii, etk in enumerate(effective_topk)
                              for _ in range(etk)]
    values = np.zeros(len(topk_actions))
    # Do some rollouts
    for _ in range(num_rollouts):
      values = values + self.mc.simulate(board, topk_actions,
                                         topk_actions_index,
                                         critter=self.critter_index,
                                         depth=rollout_depth)
    values = values / num_rollouts

    value_expand = np.zeros((batch_size, n_rows*n_cols))
    value_expand[(topk_actions_index, topk_actions)] = values
    value_expand_shift = value_expand - np.max(value_expand, axis=1, keepdims=True)
    value_expand_scale = value_expand_shift/softmax_temp
    v_probs = np.exp(value_expand_scale) / np.sum(
        np.exp(value_expand_scale), axis=1, keepdims=True)
    v_probs = v_probs * valids
    v_probs = v_probs / np.sum(v_probs, axis=1, keepdims=True)
    samp = self.rng.uniform(size = batch_size).reshape((batch_size,1))
    sampled_actions = np.argmax(v_probs.cumsum(axis=1) > samp, axis=1)
    a_1Hots = np.zeros((batch_size, n_rows*n_cols))
    a_1Hots[(range(batch_size), sampled_actions)] = 1.0
    return sampled_actions, a_1Hots, v_probs




class SimpleRulePlayer():
  """
  A Player based on the following simple policy:
  If there is any food immediately nearby move towards it,
  otherwise it move randomly.
  """

  def __init__(self, game, fov_radius=2, critter_index=1):
    """

    """
    self.game = game
    self.critter_index = critter_index
    assert (isinstance(critter_index, int) and
        0 < critter_index <= game.num_critters), "Value is not a positive integer or exceeds the upper limit."
    self.fov_radius = fov_radius


  def simple_action_from_percept(self, percept):
    """
    Determine an action based on perception.

    Args:
      percept: A 1D array (len 12 if fov_radius = 2)representing the perception
        of the organism. Indices correspond to spaces around the organism. The
        values in the array can be -2 (out-of-bounds), 0 (empty space), or
        -1 (food).

    Returns:
      action: a str, one of 'up', 'down', 'left', 'right'. If food in one or
        more of the spaces immediately beside the organism, the function will
        return a random choice among these directions. If there is no food
        nearby, the function will return a random direction.
    """
    # a human interpretable overview of the percept structure
    percept_struct = [
      'far up', 'left up', 'near up', 'right up',
      'far left', 'near left', 'near right', 'far right',
      'left down', 'near down', 'right down', 'far down']
    # Defines directions corresponding to different perception indices
    direction_struct = [
      'None', 'None', 'up', 'None',
      'None', 'left', 'right', 'None',
      'None', 'down', 'None', 'None']
    # these are what count as nearby in the percept
    nearby_directions = ['near up', 'near left', 'near right', 'near down']
    # Get the corresponding indices in the percept array
    nearby_indices = [percept_struct.index(dir_) for dir_ in nearby_directions]
    # Identify the directions where food is located
    food_indices = [index for index in nearby_indices if percept[index] == -1]
    food_directions = [direction_struct[index] for index in food_indices]

    action_probs = {'up': 0.0, 'down': 0.0, 'left': 0.0, 'right': 0.0}
    if len(food_directions) > 0:  # If there is any food nearby
      # If there is any food nearby randomly choose a direction with food
      action = self.game.rng.choice(food_directions)  # Move towards a random one
      for direction in food_directions:
        action_probs[direction] = 1.0 /len(food_directions)
    else:
      # If there is no food nearby, move randomly
      action = self.game.rng.choice(['up', 'down', 'left', 'right'])
      for direction in ['up', 'down', 'left', 'right']:
        action_probs[direction] = 0.25

    return action, action_probs


  def play(self, board):
    """
    Simulate Play on a Board

    Args:
      board: dict {'pieces':
      (batch x num_rows x num_cols) np.ndarray of board position,
                  'scores': batch len array of current scores,
                  'rounds_left': batch len array of rounds left

    Returns:
      sampled_actions: a batch, row, col index of the move taken
      by each player on each board
      a_1hots: a batch nrow*ncol array of 1hot indices of those same moves
      probs: sampling probabilities for those 1hots (If the policy
      is deterministic a_1hots is returned here as well... or if getting the
      probs is an un-needed fuss to compute)

    """
    batch_size, n_rows, n_cols = board['pieces'].shape
    perceptions = self.game.get_perceptions(board, self.fov_radius,
                                            self.critter_index)

    critter_oriented_moves = []
    direction_probs = []
    for g in range(batch_size):
      action, action_probs = self.simple_action_from_percept(perceptions[g])
      critter_oriented_moves.append(action)
      direction_probs.append(action_probs)
    moves = self.game.critter_direction_to_move(board, critter_oriented_moves,
                                                direction_probs,
                                                self.critter_index)
    probs = self.game.direction_probs_to_flat_probs(board, direction_probs)
    sampled_actions = self.game.moves_to_actions(moves)
    a_1hots = np.zeros((batch_size, n_rows*n_cols))
    a_1hots[(range(batch_size), sampled_actions)] = 1.0

    return sampled_actions, a_1hots, probs




class PerceptParamPlayer():
  """
  A Player playing a parameterized policy defined by the given weights
  """


  def __init__(self, game, weights=None, fov_radius=2, critter_index=1):
    """
    Args:
      game: Gridworld Game instance
        Instance of the gridworldGame class above;
      weights: 4 x 12 numpy array (assumes fov_radius = 2), that gives the
        connection strengths between the 'perception' neurons and the direction
        'neurons'
      fov_radius: int how far around itself the critter perceives, weights is
        expecting fov_radius = 2
    Returns:
      Nothing
    """
    self.game = game
    self.critter_index = critter_index
    assert (isinstance(critter_index, int) and
        0 < critter_index <= game.num_critters), "Value is not a positive integer or exceeds the upper limit."
    if weights is None:
      self.W = np.array(
      [[1., 1., 4., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 4., 1., 1.],
       [0., 1., 0., 0., 1., 4., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 4., 1., 0., 0., 1., 0.]])
    else:
      self.W = weights
    self.fov_radius = fov_radius
    self.default_softmax_temp = 0.05


  def param_action_from_percept(self, percept, valid_directions, W,
                                softmax_temp=None):
    """
    Determine an action based on perception.

    Args:
      percept: A 1D len 12 array representing the perception of the organism.
        Indices correspond to spaces around the organism. The values in the
        array can be -2 (out-of-bounds), 0 (empty space), or -1 (food).
      W: a 4 x 12 weight matrix parameter representing the connection strengths
        between the 12 perceptions inputs and the 4 possible output actions.

    Returns:
      direction: a str, one of 'up', 'down', 'left', 'right'. If food in one or
        more of the spaces immediately beside the organism, the function will
        return a random choice among these directions. If there is no food
        nearby, the function will return a random direction.
      direction_probs: dictionary with probabilities of taking each action.
    """
    if len(valid_directions) == 0:
      # if there is no where legit to move, stay put
      return 'still', {direction: 0 for direction in output_struct}

    if softmax_temp is None:
      # very low temp, basically deterministic for this range of values
      softmax_temp = self.default_softmax_temp
    # a human interpretable overview of the percept structure
    percept_struct = [
      'far up', 'left up', 'near up', 'right up',
      'far left', 'near left', 'near right', 'far right',
      'left down', 'near down', 'right down', 'far down']
    # a human interpretable overview of the out structure
    output_struct = ['up', 'down', 'left', 'right']
    # boolean representation of percept, no edges, just 1's where food is,
    # zero otherwise, also means other organisms are invisible
    x = np.asarray(percept == -1, int)
    output_activations = W @ x

    # softmax shift by max, scale by temp
    shift_scale_ex = np.exp((output_activations -
                             np.max(output_activations))/softmax_temp)
    # set invalid direction activations to zero
    invalid_directions = [direction for direction in output_struct
                           if direction not in valid_directions]
    invalid_indices = [output_struct.index(direction)
                        for direction in valid_directions]
    sm = shift_scale_ex / shift_scale_ex.sum() #normalized
    # set invalid direction probabilities to zero
    invalid_directions = [direction for direction in output_struct
                           if direction not in valid_directions]
    invalid_indices = [output_struct.index(direction)
                        for direction in invalid_directions]
    sm[invalid_indices] = 0
    probs_sm = sm / sm.sum(axis=0) #re-normalized again for fp issues
    direction = self.game.rng.choice(output_struct, p=probs_sm)
    direction_probs = {direction: prob
                        for direction, prob in zip(output_struct, probs_sm)}
    return direction, direction_probs


  def play(self, board, temp=None):
    """
    Simulate Play on a Board

    Args:
      board: dict {'pieces':
      (batch x num_rows x num_cols) np.ndarray of board position,
                  'scores': batch len array of current scores,
                  'rounds_left': batch len array of rounds left

    Returns:
      sampled_actions: a batch, row, col index of the move taken
      by each player on each board
      a_1hots: a batch nrow*ncol array of 1hot indices of those same moves
      v_probs: sampling probabilities for those 1hots (If the policy
      is deterministic a_1hots is returned here as well... or if getting the
      probs is an un-needed fuss to compute)
    """
    if temp is None:
      temp = self.default_softmax_temp
    batch_size, n_rows, n_cols = board['pieces'].shape
    perceptions = self.game.get_perceptions(board, self.fov_radius,
                                            self.critter_index)
    critter_oriented_moves = []
    direction_probs = []

    # Get valid actions for each game in the batch
    valid_directions = self.game.get_valid_directions(board, self.critter_index)
    for g in range(batch_size):
      direction, batch_direction_probs = self.param_action_from_percept(
        perceptions[g], valid_directions[g], self.W, softmax_temp=temp)
      critter_oriented_moves.append(direction)
      direction_probs.append(batch_direction_probs)
    moves = self.game.critter_direction_to_move(board, critter_oriented_moves,
                                                self.critter_index)
    probs = self.game.direction_probs_to_flat_probs(board, direction_probs, self.critter_index)
    sampled_actions = self.game.moves_to_actions(moves)
    a_1hots = np.zeros((batch_size, n_rows*n_cols))
    a_1hots[(range(batch_size), sampled_actions)] = 1.0

    return sampled_actions, a_1hots, probs




class BatchOptPerceptParamPlayer():
  """
  A Player playing a parameterized policy defined by the given weights
  """


  def __init__(self, game, weights=None, fov_radius=2, critter_index=1,
               get_probs=False, deterministic=False):
    """
    Args:
      game: Gridworld Game instance
        Instance of the gridworldGame class above;
      weights: 4 x 12 numpy array (assumes fov_radius = 2), that gives the
        connection strengths between the 'perception' neurons and the direction
        'neurons'
      fov_radius: int how far around itself the critter percieves, weights is
        expecting fov_radius = 2
    Returns:
      Nothing
    """
    # all critters need these things
    self.game = game
    self.critter_index = critter_index
    assert (isinstance(critter_index, int) and
        0 < critter_index <= game.num_critters), "Value is not a positive integer or exceeds the upper limit."
    self.get_probs = get_probs
    # these things are specfic to this kind of critter
    self.deterministic = deterministic
    if weights is None:
      self.W = np.array(
      [[1., 1., 4., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 4., 1., 1.],
       [0., 1., 0., 0., 1., 4., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 4., 1., 0., 0., 1., 0.]])
    else:
      self.W = weights
    self.fov_radius = fov_radius
    self.default_softmax_temp = 0.05


  def direction_value_from_percept(self, percepts, W):
    """
    Determine an action based on perception.

    Args:
      percept: A batch by 1D len 12 array representing the perceptions of the
      organism. Indices correspond to spaces around the organism. The values in
      the array can be -2 (out-of-bounds), 0 (empty space), or -1 (food).
      W: a 4 x 12 weight matrix parameter representing the connection strengths
        between the 12 perceptions inputs and the 4 possible output actions.

    Returns:
      direction_probs: array of probabilities of taking each action.
    """
    # a human interpretable overview of the percept structure
    #percept_struct = [
    #  'far up', 'left up', 'near up', 'right up',
    #  'far left', 'near left', 'near right', 'far right',
    #  'left down', 'near down', 'right down', 'far down']
    # a human interpretable overview of the out structure
    #output_struct = ['up', 'down', 'left', 'right']
    # boolean representation of percept, no edges, just 1's where food is,
    # zero otherwise, also means other organisms are invisible
    # x is batch x 12
    x = np.asarray(percepts == -1, int)
    # W is 4 x 12
    # this does the broadcasting we want
    output_activations = (W @ x.T).T
    # output activations is batch by 4
    return output_activations


  def play(self, board, temp=None, W=None):
    """
    Simulate Play on a Board

    Args:
      board: dict {'pieces':
      (batch x num_rows x num_cols) np.ndarray of board position,
                  'scores': batch len array of current scores,
                  'rounds_left': batch len array of rounds left

    Returns:
      sampled_actions: a batch, row, col index of the move taken
      by each player on each board
      a_1hots: a batch nrow*ncol array of 1hot indices of those same moves
      v_probs: sampling probabilities for those 1hots (If the policy
      is deterministic a_1hots is returned here as well... or if getting the
      probs is an un-needed fuss to compute)
    """
    if temp is None:
      temp = self.default_softmax_temp
    if W is None:
      W = self.W
    batch_size, n_rows, n_cols = board['pieces'].shape
    perceptions = self.game.get_perceptions(board, self.fov_radius,
                                            self.critter_index)
    # note the fragile order based dependency on how legal offsets is written,
    # and how output activations are interpreted
    direction_v = self.direction_value_from_percept(perceptions, W)
    flat_ds = direction_v.T.ravel()

    critter_locs = np.array(np.where(board['pieces'] == self.critter_index))
    legal_offsets = np.stack([
    critter_locs + np.array([np.array([0, -1,  0])]*batch_size).T, # up
    critter_locs + np.array([np.array([0,  1,  0])]*batch_size).T, # down
    critter_locs + np.array([np.array([0,  0, -1])]*batch_size).T, # left
    critter_locs + np.array([np.array([0,  0,  1])]*batch_size).T]) #right
    legal_offsets = np.vstack(np.transpose(legal_offsets, (0, 2, 1)))

    # conditions for offsets on the board
    c1 = legal_offsets[:,1] >= 0
    c2 = legal_offsets[:,1] <= n_rows-1
    c3 = legal_offsets[:,2] >= 0
    c4 = legal_offsets[:,2] <= n_cols-1
    all_c = np.logical_and.reduce([c1, c2, c3, c4])

    batch_indexes = legal_offsets[:,0][all_c]
    action_indexes = legal_offsets[:,1][all_c] * n_cols + legal_offsets[:,2][all_c]
    direction_values = flat_ds[all_c]

    valids = gwg.get_valid_actions(board, self.critter_index)
    value_expand = np.zeros((batch_size, n_rows*n_cols))
    value_expand[(batch_indexes, action_indexes)] = direction_values
    valid_value_expand = value_expand * valids
    row_sums = valid_value_expand.sum(axis=1)
    zero_rows = (row_sums == 0)
    valid_row_sums = valids.sum(axis=1, keepdims=True)
    # Handle the case where a row in 'valids' is all zeros to avoid division by zero
    valid_row_sums[valid_row_sums == 0] = 1
    normalized_valids = valids / valid_row_sums
    # Update only those rows where the sum was zero
    valid_value_expand[zero_rows] = normalized_valids[zero_rows]
    if self.deterministic:
      sampled_actions = np.argmax(valid_value_expand, axis=1)
      a_1Hots = np.zeros((batch_size, n_rows * n_cols))
      a_1Hots[np.arange(batch_size), sampled_actions] = 1.0
      v_probs = a_1Hots
    else:
      value_expand_shift = value_expand - np.max(value_expand, axis=1, keepdims=True)
      value_expand_scale = value_expand_shift/temp
      v_probs = np.exp(value_expand_scale) / np.sum(
        np.exp(value_expand_scale), axis=1, keepdims=True)
      v_probs = v_probs / np.sum(v_probs, axis=1, keepdims=True)
      samp = self.game.rng.uniform(size = batch_size).reshape((batch_size,1))
      sampled_actions = np.argmax(v_probs.cumsum(axis=1) > samp, axis=1)
      a_1Hots = np.zeros((batch_size, n_rows*n_cols))
      a_1Hots[(range(batch_size), sampled_actions)] = 1.0
    return sampled_actions, a_1Hots, v_probs



# @title interactive gridworld widgets

########################################
# widgets refactor for multi-critter
#########################################
# Interactive Gridworld Game Widgets




class InteractiveGridworld():
  """
  A widget based object for interacting with a gridworld game
  """

  def __init__(self, gridworld_game, init_board=None, has_fov=False,
               radius=2, collect_fov_data=False,
               figsize=(6,5), critter_names=['Critter'], players=['human']):
    """
    Initializes a widget based object for interacting with a gridworld game

    Args:
      gridworld_game: an instance of GridworldGame object
        expects this to have batchsize 1
      init_board: (optional) a triple of np arrays representing board state
        pieces,       - batch_size x n_rows x n_cols
        scores,       - batch_size
        rounds_left   - batch_size
        if left out will initialize with a random board state
      has_fov: bool, whether or not to display fog of war around the critter
      radius: int, number of squares the critter can "see" around it
      figsize: tuple (int, int), size of the figure
      critter_names: a list of strings that determines what the critter is called
        in the plot legend, order should align with players
      player: a list of either 'human', None, or a player object with a play
        method and a critter_index attribute. If 'human' use buttons,  if None
        default to making a RandomValidPlayer object, otherwise use the
        player class provided to make the player objects and use a start button.
        The list needs to be as long as the gridworld_game.num_critters
        attribute. Order should align with critter_name.

      Note: fov is going to look pretty janky with more than one player, maybe
      we get fov to only turn on for the 'active' player?
    """

    # Set GridworldGame object and initialize the board state
    self.gwg = gridworld_game
    self.has_fov = has_fov
    self.radius = radius
    self.percept_len = 2*self.radius*(self.radius+1)
    self.collect_fov_data = collect_fov_data
    self.figsize = figsize
    # initialize players and plotting specs together to ensure alignment
    self.players = []
    self.any_human_players = False
    self.crit_specs = []
    markers = ['h', 'd']  # hexagon and diamond
    colors = sns.color_palette("colorblind")
    for i in range(self.gwg.num_critters):
      spec = {'marker': markers[i % len(markers)],
              'color': colors[i // len(markers) % len(colors)],
              'name': critter_names[i],
              'int_id': i+1}
      self.crit_specs.append(spec)
      player = players[i] #implict check that players is at least long enough
      if player is None:
        self.players.append(RandomValidPlayer(self.gwg, critter_index=i+1))
      elif player == 'human':
        self.players.append('human')
        # right now only ever have on human player with index 1
        self.any_human_players = True
      else:
        player.critter_index = i+1
        self.players.append(player)
    self.final_scores = []
    if init_board is None:
      self.board_state = self.gwg.get_init_board()
    else:
      self.board_state = init_board
    if self.collect_fov_data is True:
      # keep raw records of percept and eating for manipulation later
      self.percept_eat_records = []
      # keep data in contingency table of how many food items were in
      # the percept, and whether or not food was eaten
      self.fov_eat_table_data = np.zeros((2, self.percept_len+1))
    # Initialize widgets and buttons
    self.output = widgets.Output(layout=widgets.Layout(
      width = '20.0em', min_width='20.0em', max_width='21.0em',
      min_height='10.0em', overflow='auto'))
    self.scoreboard = widgets.Output(layout=widgets.Layout(
      min_width='12.5em', max_width='18.8em',
      min_height='6.3em', overflow='auto'))
    self.fov_eat_table_display = widgets.Output(layout=widgets.Layout(
      min_width='25.0em', min_height='18.8em', overflow='auto'))
    self.up_button = widgets.Button(description="Up",
      layout=widgets.Layout(width='6.3em'))
    self.down_button = widgets.Button(description="Down",
      layout=widgets.Layout(width='6.3em'))
    self.left_button = widgets.Button(description="Left",
      layout=widgets.Layout(width='6.3em'))
    self.right_button = widgets.Button(description="Right",
      layout=widgets.Layout(width='6.3em'))
    self.start_button = widgets.Button(description="Start",
      layout=widgets.Layout(width='6.3em'))

    # get plot canvas widgets and other plotting objects
    plt.ioff()
    if self.collect_fov_data and self.any_human_players:
      self.legend_type = None # don't keep regenerating the legend
      # do legend separately if showing observations and no human player
      (self.b_fig, self.b_ax, self.b_crit_specs, self.b_food, self.b_fov,
       self.b_fig_legend, self.b_ax_legend) = self.gwg.plot_board(
          self.board_state, g=0, critter_specs=self.crit_specs,
          legend_type='separate', figsize=self.figsize, has_fov=self.has_fov,
          radius=self.radius)
    elif len(self.players) > 1:
      self.legend_type=None # don't keep regenerating the legend
      (self.b_fig, self.b_ax, self.b_crit_specs, self.b_food, self.b_fov,
       self.b_fig_legend, self.b_ax_legend) = self.gwg.plot_board(
          self.board_state, g=0, critter_specs=self.crit_specs,
          has_fov=self.has_fov, legend_type='separate',
          radius=self.radius, figsize=self.figsize)
    else:
      self.legend_type = 'included'
      (self.b_fig, self.b_ax, self.b_crit_specs, self.b_food, self.b_fov
        ) = self.gwg.plot_board(self.board_state, g=0,
                                critter_specs=self.crit_specs,
                                has_fov=self.has_fov,
                                radius=self.radius, figsize=self.figsize)
    # lump buttons together
    self.buttons = widgets.HBox([self.left_button,
                               widgets.VBox([self.up_button, self.down_button]),
                               self.right_button])
    # automatically pick different layouts for different situations
    if self.any_human_players:
      self.board_and_buttons = widgets.VBox([self.b_fig.canvas,
                                             self.buttons])

      if len(self.players) == 1:
        self.output_and_score = widgets.HBox([self.scoreboard, self.output])
        self.no_table_final_display = widgets.VBox([self.board_and_buttons,
                                                  self.output_and_score])
        if self.collect_fov_data == True:
          # a single human player collecting data
          self.final_display = widgets.HBox([self.no_table_final_display,
                                           self.fov_eat_table_display])
        else: # self.collect_fov_data == False:
          # a single human player not collecting data
          self.final_display = self.no_table_final_display
      else:
        # more than one player, one of them human
        self.V_board_outbput = widgets.VBox([self.board_and_buttons,
                                             self.output])
        self.V_scoreboard_start_legend = widgets.VBox([
        self.scoreboard, self.start_button, self.b_fig_legend.canvas])
        self.final_display = widgets.HBox([self.V_board_outbput,
                                             self.V_scoreboard_start_legend])
    else: # player is some kind of ai
      if self.collect_fov_data == True:
        # an ai player with recording
        # in this case legend is separate
        self.V_score_start_output_legend = widgets.VBox([self.scoreboard,
          self.start_button,  self.output, self.b_fig_legend.canvas])
        self.V_board_table = widgets.VBox([self.b_fig.canvas,
                                           self.fov_eat_table_display])
        self.final_display = widgets.HBox([self.V_board_table,
                                           self.V_score_start_output_legend])
      else:
        if len(self.players) == 1:
          # an ai player without recording
          self.H_score_output_start = widgets.HBox([
            self.scoreboard, self.output, self.start_button])
          self.final_display = widgets.VBox([
            self.b_fig.canvas, self.H_score_output_start])
        else:
          # more than one ai player
          self.V_board_outbput = widgets.VBox([self.b_fig.canvas, self.output])
          self.V_scoreboard_start_legend = widgets.VBox([
              self.scoreboard, self.start_button, self.b_fig_legend.canvas])
          self.final_display = widgets.HBox([self.V_board_outbput,
                                             self.V_scoreboard_start_legend])

    # initialize text outputs
    with self.scoreboard:
      table = [['High Score:'] + ['--'] * self.gwg.num_critters,
               ['Last Score:'] + ['--'] * self.gwg.num_critters,
               ['Average Score:'] + ['--'] * self.gwg.num_critters,]
      if len(self.players) > 1:
        headers = [''] + [f'P{i+1}' for i in range(self.gwg.num_critters)]
        print(tabulate(table, headers=headers))
      else: # len(self.players) == 1
        print(tabulate(table))
    with self.output:
      if self.any_human_players:
        print('Click a button to start playing')
      else:
        print('Click the start button to run the simulation')
    with self.fov_eat_table_display:
      printmd("**Observations**")
      table_data = [[str(ii),
                     str(self.fov_eat_table_data[0,ii]),
                     str(self.fov_eat_table_data[1,ii])] for ii in range(11)]
      table = ([['Food in Percept', 'Food Not Eaten', 'Food Eaten']] +
               table_data)
      print(tabulate(table))

    # Connect the buttons to functions that do something
    self.up_button.on_click(self.on_up_button_clicked)
    self.down_button.on_click(self.on_down_button_clicked)
    self.left_button.on_click(self.on_left_button_clicked)
    self.right_button.on_click(self.on_right_button_clicked)
    self.start_button.on_click(self.on_start_button_clicked)


  def button_output_update(self, which_button):
    old_board = self.board_state.copy()
    # index is 0 through num_critter-1, critter value in pieces is
    # index + 1
    active_player_index = old_board['active_player']
    old_scores = old_board['scores'][0]
    if self.collect_fov_data is True:
      batch_size, n_rows, n_cols = old_board['pieces'].shape
      b = GridworldBoard(batch_size, n_rows, n_cols,
                         self.gwg.num_food, self.gwg.lifetime,
                         rng=self.gwg.rng)
      b.set_state(old_board)
      percept = b.get_perceptions(self.radius)[0]

    #print(self.players[active_player_index])
    #print(active_player_index)
    #if not self.players[active_player_index] == 'human':
    #  print(self.players[active_player_index].critter_index)
    #else:
    #  print('human has no critter_index attribute')

    if (isinstance(self.players[active_player_index], str) and
        'human' in self.players[active_player_index]):
      direction = which_button
    else:
      a_player, _, _ = self.players[active_player_index].play(old_board)
      a_player = self.gwg.action_to_critter_direction(old_board,
                                                      active_player_index+1,
                                                      a_player)
      # but we only want to apply their move to the appropriate board
      direction = a_player[0]

    self.board_state = self.gwg.critter_oriented_get_next_state(
          self.board_state, active_player_index+1, [direction])
    new_scores = self.board_state['scores'][0] #first batch first critter type
    rounds_left = self.board_state['rounds_left'][0]
    num_moves = np.floor(self.gwg.lifetime -
                         rounds_left / self.gwg.num_critters)
    if new_scores[active_player_index] > old_scores[active_player_index]:
      #eating happened
      eating_string = "They ate the food there!"
      did_eat = 1
    else: #eating didn't happen
      eating_string = "There's no food there."
      did_eat = 0
    row, col = self.gwg.get_critter_rc(self.board_state, 0,
                                       active_player_index+1)
    (self.b_fig, self.b_ax, self.b_crit_specs, self.b_food, self.b_fov
     ) = self.gwg.plot_board(self.board_state, g=0,
                             fig=self.b_fig, ax=self.b_ax,
                             critter_specs=self.b_crit_specs, food=self.b_food,
                             fov=self.b_fov, has_fov=self.has_fov,
                             radius=self.radius, legend_type=self.legend_type)
    if self.collect_fov_data is True:
      p_e_data = {'perception': percept.copy(),
                  'state': old_board,
                  'did_eat': bool(did_eat)}
      self.percept_eat_records.append(p_e_data)
      percept_int = np.sum(percept==-1) # number of food items in FoV
      self.fov_eat_table_data[did_eat, percept_int] += 1

    with self.output:
      clear_output()
      if len(self.players) == 1:
        print("The critter (tried) to move " + direction +
              " and is now at ({}, {}).".format(row,col))
        print(eating_string)
        print("Rounds Left: {}\nFood Eaten: {}\nFood Per Move: {:.2f}".format(
            rounds_left, new_scores[active_player_index],
            new_scores[active_player_index] / num_moves))
      else: # more than one players
        print("Critter {} (tried) to move ".format(active_player_index+1) +
              direction +
              " and is now at ({}, {}).".format(row, col))
        print(eating_string)
        print("Rounds Left: {}\nFood Eaten: {}".format(
            rounds_left, new_scores))
    if rounds_left == 0:
      self.final_scores.append(new_scores)
      with self.output:
        clear_output
        print('Game Over. Final Score {}'.format(new_scores))
        print('Resetting the board for another game')
        self.board_state = self.gwg.get_init_board()
      (self.b_fig, self.b_ax, self.b_crit_specs, self.b_food, self.b_fov
       ) = self.gwg.plot_board(self.board_state, 0, self.b_fig, self.b_ax,
                               self.b_crit_specs, self.b_food, self.b_fov,
                               has_fov=self.has_fov, radius=self.radius,
                               legend_type=self.legend_type)
    with self.scoreboard:
        clear_output()
        print('Games Played: ' + str(len(self.final_scores)))
        if len(self.players) == 1:
          if len(self.final_scores) > 0:
            table = [
              ['High Score:', str(np.max(np.array(self.final_scores)))],
              ['Last Score:', str(self.final_scores[-1])],
              ['Average Score',
              '{:.2f}'.format(np.mean(np.array(self.final_scores)))]]
          else:
            table = [['High Score:', '--'],
                     ['Last Score:', '--'],
                     ['Average Score:', '--']]
          print(tabulate(table))
        else: # len(self.players) > 1
          headers = [''] + [f'P{i+1}' for i in range(self.gwg.num_critters)]
          if len(self.final_scores) > 0:
            table = []
            # Assuming the batch size is 1 for now
            current_scores = self.final_scores[-1]
            max_scores = np.max(np.array(self.final_scores), axis=0)
            average_scores = np.mean(np.array(self.final_scores), axis=0)
            table.append(['High Scores:'] +
              [str(score) for score in max_scores])
            table.append(['Last Scores:'] +
              [str(score) for score in current_scores])
            table.append(['Average Scores:'] +
              ['{:.2f}'.format(score) for score in average_scores])
          else:
            table = [
              ['High Score:'] + ['--'] * self.gwg.num_critters,
              ['Last Score:'] + ['--'] * self.gwg.num_critters,
              ['Average Score:'] + ['--'] * self.gwg.num_critters,]
          print(tabulate(table, headers=headers))

    with self.fov_eat_table_display:
      clear_output()
      printmd("**Observations**")
      table_data = [[str(ii),
                     str(self.fov_eat_table_data[0,ii]),
                     str(self.fov_eat_table_data[1,ii])] for ii in range(11)]
      table = ([['Food in Percept', 'Food Not Eaten', 'Food Eaten']] +
               table_data)
      print(tabulate(table))

  def disable_direction_buttons(self):
    self.up_button.disabled = True
    self.down_button.disabled = True
    self.left_button.disabled = True
    self.right_button.disabled = True

  def enable_direction_buttons(self):
    self.up_button.disabled = False
    self.down_button.disabled = False
    self.left_button.disabled = False
    self.right_button.disabled = False

  def human_ai_player_loop(self, direction):
    self.disable_direction_buttons()
    for player in self.players:
      if player == 'human':
        self.button_output_update(direction)
      else:
        self.button_output_update('tbd')
    self.enable_direction_buttons()

  def on_up_button_clicked(self, *args):
    self.human_ai_player_loop('up')

  def on_down_button_clicked(self, *args):
    self.human_ai_player_loop('down')

  def on_left_button_clicked(self, *args):
    self.human_ai_player_loop('left')

  def on_right_button_clicked(self, *args):
    self.human_ai_player_loop('right')

  def on_start_button_clicked(self, *args):
    self.start_button.disabled = True
    for ii in range(self.gwg.lifetime*self.gwg.num_critters):
      self.button_output_update('tbd')
      time.sleep(0.2)
    self.start_button.disabled = False




class Head2HeadGridworld():
  """
  A widget for interacting with a gridworld game while an artificial player
  plays on an identical board or watching two artificial players play, again
  with identical starting positions (though RNG not synched between the two
  boards, so not like duplicate bridge). We are not going to worry about having
  more than 1 critter type playing in head to head, (maybe we will to talk
  about cooperation... maybe).
  """

  def __init__(self, gridworld_game, init_board=None,
               player0='human', p0_short_name='YOU', p0_long_name=None,
               player1=None, p1_short_name='THEM', p1_long_name=None,
               has_fov=False, radius=2, collect_fov_data=False,
               critter_name='Critter', figsize=(5,4.5),
               has_temp_slider=False):
    """
    Initializes a widget based object for interacting with a gridworld game

    Args:
      gridworld_game: an instance of GridworldGame object
        expects this to have batch_size of 2
      init_board: (optional) a triple of np arrays representing board state
        pieces,       - batch_size x n_rows x n_cols
        scores,       - batch_size
        rounds_left   - batch_size
        if left out will initialize with a random board state
      player0: object with a play method that takes a board state
        as an argument and returns a move. If none will use a random player
        if the special string 'human' is passed make arrow keys for that player
      player1: same deal as player0, never more than 1 human player
      has_fov: bool, whether or not to display field of view around the critter
      radius: int, number of squares the critter can "see" around it
    """
    # Set GridworldGame object and initialize the board state
    self.gwg = gridworld_game
    self.final_scores = []
    self.player0 = player0
    self.p0_short_name = p0_short_name
    self.p0_long_name = p0_long_name
    self.player1 = player1
    self.p1_short_name = p1_short_name
    self.p1_long_name = p1_long_name
    self.no_human = True
    if self.player0 == 'human':
      assert self.player1 != 'human'
      self.no_human = False
    if self.player1 == 'human':
      assert self.player0 != 'human'
      self.no_human = False
    self.p0_next_move = None
    self.p1_next_move = None
    self.has_fov = has_fov
    self.radius = radius
    self.percept_len = 2*self.radius*(self.radius+1)
    self.collect_fov_data = collect_fov_data
    self.critter_name = critter_name
    self.figsize = figsize
    if player0 is None:
      self.player0 = RandomValidPlayer(self.gwg)
    else:
      self.player0 = player0
    if player1 is None:
      self.player1 = RandomValidPlayer(self.gwg)
    else:
      self.player1 = player1
    self.has_temp_slider = has_temp_slider

    if self.collect_fov_data is True:
      self.percept_eat_records = []
      self.fov_eat_table_data = np.zeros((2, self.percept_len+1))
    if init_board is None:
      self.board_state = self.gwg.get_init_board()
    else:
      self.board_state = init_board
    #print(self.board_state)

    # both players have same starting board
    self.board_state['pieces'][1] = self.board_state['pieces'][0].copy()

    # Initialize widgets and buttons
    if self.has_temp_slider:
      self.sft_slider_label = widgets.Label(value='Softmax Temperature')
      self.sft_slider = widgets.FloatSlider(value=1.0, min=0.05,
                                            max=5.0, step=0.05)
      self.softmax_temp_slider = widgets.VBox([self.sft_slider_label,
                                               self.sft_slider])
    self.output0 = widgets.Output(layout=widgets.Layout(
      width = '20.0em', min_width='20.0em', max_width='21.0em',
      min_height='10.0em', overflow='auto'))
    self.output1 = widgets.Output(layout=widgets.Layout(
      width = '20.0em', min_width='20.0em', max_width='21.0em',
      min_height='10.0em', overflow='auto'))
    self.scoreboard = widgets.Output(layout=widgets.Layout(
      min_width='20em', max_width='21em', min_height='6.3em', overflow='auto'))
    self.up_button = widgets.Button(description="Up",
                                    layout=widgets.Layout(width='6.3em'))
    self.down_button = widgets.Button(description="Down",
                                      layout=widgets.Layout(width='6.3em'))
    self.left_button = widgets.Button(description="Left",
                                      layout=widgets.Layout(width='6.3em'))
    self.right_button = widgets.Button(description="Right",
                                       layout=widgets.Layout(width='6.3em'))
    self.start_button = widgets.Button(description="Start",
      layout=widgets.Layout(width='6.3em', margin='0.6em 0 0 0'))  # 0.6em top margin

    #set up buttons and outputs and layouts
    self.buttons = widgets.HBox([self.left_button,
                               widgets.VBox([self.up_button, self.down_button]),
                               self.right_button])
    plt.ioff()
    (self.b_fig0, self.b_ax0, self.b_crit_specs0, self.b_food0, self.b_fov0,
     self.b_fig_legend, self.b_ax_legend) = self.gwg.plot_board(
        self.board_state, g=0, legend_type='separate', figsize=self.figsize,
        has_fov=self.has_fov, radius=self.radius,
        name=self.critter_name, title=self.p0_long_name)
    (self.b_fig1, self.b_ax1, self.b_crit_specs1, self.b_food1, self.b_fov1
     ) = self.gwg.plot_board(self.board_state, g=1, legend_type=None,
                             figsize=self.figsize, has_fov=self.has_fov,
                             radius=self.radius, title=self.p1_long_name)
    # player 0 is human
    self.board_buttons_and_output0 = widgets.VBox(
      [self.b_fig0.canvas, self.buttons, self.output0])
    # player 1 is human
    self.board_buttons_and_output1 = widgets.VBox(
      [self.b_fig1.canvas, self.buttons, self.output1])
    # non human players
    self.board_and_output0 = widgets.VBox([self.b_fig0.canvas, self.output0])
    self.board_and_output1 = widgets.VBox([self.b_fig1.canvas, self.output1])

    self.legend_and_scores = widgets.VBox([self.b_fig_legend.canvas,
                                           self.scoreboard])
    if self.has_temp_slider:
      self.legend_scores_start = widgets.VBox([self.b_fig_legend.canvas,
                                               self.scoreboard,
                                               self.softmax_temp_slider,
                                               self.start_button])
    else:
      self.legend_scores_start = widgets.VBox([self.b_fig_legend.canvas,
                                               self.scoreboard,
                                               self.start_button])
    if self.player0 == 'human':
      self.final_display = widgets.HBox([self.board_buttons_and_output0,
                                         self.legend_and_scores,
                                         self.board_and_output1])
    elif self.player1 == 'human':
      self.final_display = widgets.HBox([self.board_and_output0,
                                         self.legend_and_scores,
                                         self.board_buttons_and_output1])
    else: # no human player
      self.final_display = widgets.HBox([self.board_and_output0,
                                          self.legend_scores_start,
                                          self.board_and_output1])
    # initial text outputs
    # if there's a temp slider check who, if anyone uses it
    self.p0_uses_temp = False
    self.p1_uses_temp = False
    if self.has_temp_slider:
      if self.player0=='human':
        pass
      else:
        try:
          _ = self.player0.play(self.board_state, temp=1.0)
          self.p0_uses_temp = True
        except TypeError: pass
      if self.player1 == 'human':
        pass
      else:
        try:
          _ = self.player1.play(self.board_state, temp=1.0)
          self.p1_uses_temp = True
        except TypeError: pass
      if not self.p0_uses_temp and not self.p1_uses_temp:
        with self.output0:
          print("Warning: neither player supports temperature adjustment. "
                "The slider will have no effect.")
    with self.output0:
      if self.no_human == False:
        print('Click a button to start.')
      else:
        print('Click the start button to run the simulation')
    with self.scoreboard:
      print('Games Played: ' + str(len(self.final_scores)))
      table = [['', self.p0_short_name, self.p1_short_name],
          ['High Score:', '--', '--'],
          ['Last Score:', '--', '--'],
          ['Avg. Score:', '--', '--']]
      print(tabulate(table))

    # Connect the buttons to functions that do something
    self.up_button.on_click(self.on_up_button_clicked)
    self.down_button.on_click(self.on_down_button_clicked)
    self.left_button.on_click(self.on_left_button_clicked)
    self.right_button.on_click(self.on_right_button_clicked)
    self.start_button.on_click(self.on_start_button_clicked)


  def button_output_update(self, which_button):
    old_board = self.board_state.copy()
    old_scores = old_board['scores'][:,0] #both batches only one critter type
    active_player_index = old_board['active_player']
    self.disable_buttons()
    if self.player0 == 'human':
      a_player0 = which_button
    else:
      if self.p0_next_move is not None:
        a_player0_ = self.p0_next_move
        self.p0_next_move = None
      else:
        with self.output0:
          print("AI is thinking...")
        if self.p0_uses_temp:
          a_player0_, _, _ = self.player0.play(old_board,
                                               temp=self.sft_slider.value)
        else:
          a_player0_, _, _ = self.player0.play(old_board)
      a_player0_ = self.gwg.action_to_critter_direction(old_board,
                                                        active_player_index+1,
                                                        a_player0_)
      a_player0 = a_player0_[0]
    if self.player1 == 'human':
      a_player1 = which_button
    else:
      if self.p1_next_move is not None:
        a_player1_ = self.p1_next_move
        self.p1_next_move = None
      else:
        with self.output1:
          print("AI is thinking...")
        if self.p1_uses_temp:
          a_player1_, _, _ = self.player1.play(old_board,
                                               temp=self.sft_slider.value)
        else:
          a_player1_, _, _ = self.player1.play(old_board)
      a_player1_ = self.gwg.action_to_critter_direction(old_board,
                                                        active_player_index+1,
                                                        a_player1_)
      a_player1 = a_player1_[1]
    self.enable_buttons()

    self.board_state = self.gwg.critter_oriented_get_next_state(
        self.board_state, active_player_index+1, [a_player0, a_player1])

    # Try to precompute next AI player move(s) if there are any rounds left
    if self.board_state['rounds_left'][0] > 0:
      if self.player0 != 'human':
        if self.p0_uses_temp:
          self.p0_next_move, _, _ = self.player0.play(
            self.board_state, temp=self.sft_slider.value)
        else:
          self.p0_next_move, _, _ = self.player0.play(self.board_state)
      if self.player1 != 'human':
        if self.p1_uses_temp:
          self.p1_next_move, _, _ = self.player1.play(
            self.board_state, temp=self.sft_slider.value)
        else:
          self.p1_next_move, _, _ = self.player1.play(self.board_state)

    if self.collect_fov_data is True:
      batch_size, n_rows, n_cols = old_board['pieces'].shape
      b = GridworldBoard(batch_size, n_rows, n_cols,
                         self.gwg.num_food, self.gwg.lifetime,
                         rng=self.gwg.rng)
      b.set_state(old_board)
      percept = b.get_perceptions(self.radius)

    new_scores = self.board_state['scores'][:,0] #both batches one critter type
    rounds_left = self.board_state['rounds_left'][0]
    num_moves = self.gwg.lifetime - rounds_left

    if new_scores[0] > old_scores[0]:
      eating_string0 = "They ate the food there!"
    else:
      eating_string0 = "There's no food there."
    if new_scores[1] > old_scores[1]:
      eating_string1 = "They ate the food there!"
    else:
      eating_string1 = "There's no food there."
    did_eat = int(new_scores[0] > old_scores[0])

    row0, col0 = self.gwg.get_critter_rc(self.board_state, 0, 1)
    (self.b_fig0, self.b_ax0, self.b_crit_specs0, self.b_food0, self.b_fov0
     ) = self.gwg.plot_board(self.board_state, 0, self.b_fig0, self.b_ax0,
                             self.b_crit_specs0, self.b_food0, self.b_fov0,
                             has_fov=self.has_fov, radius=self.radius,
                             legend_type=None)
    row1, col1 = self.gwg.get_critter_rc(self.board_state, 1, 1)
    (self.b_fig1, self.b_ax1, self.b_crit_specs1, self.b_food1, self.b_fov1
     ) = self.gwg.plot_board(self.board_state, 1, self.b_fig1, self.b_ax1,
                             self.b_crit_specs1, self.b_food1, self.b_fov1,
                             has_fov=self.has_fov, radius=self.radius,
                             legend_type=None)

    with self.output0:
      clear_output()
      if self.player0 == 'human':
        print("You clicked the " + which_button +
              " button and your critter is now at ({}, {}).".format(row0,col0))
      else:
        print("This player (tried) to move " + a_player0 +
              " and is now at ({}, {}).".format(row0,col0))
      print(eating_string0)
      print("Rounds Left: {} \nFood Eaten: {} \nFood Per Move: {:.2f}".format(
          rounds_left, new_scores[0], new_scores[0] / num_moves))
    with self.output1:
      clear_output()
      if self.player1 == 'human':
        print("You clicked the " + which_button +
              " button and your critter is now at ({}, {}).".format(row1,col1))
      else:
        print("This player (tried) to move " + a_player1 +
              " and is now at ({}, {}).".format(row1,col1))
      print(eating_string1)
      print("Rounds Left: {} \nFood Eaten: {} \nFood Per Move: {:.2f}".format(
        rounds_left, new_scores[1], new_scores[1] / num_moves))

    if self.collect_fov_data is True:
      p_e_data = (percept.copy(), did_eat, old_board)
      self.percept_eat_records.append(p_e_data)
      percept_int = np.sum(percept==-1, axis=1)
      self.fov_eat_table_data[did_eat, percept_int] += 1

    if rounds_left == 0:
      self.final_scores.append(new_scores)
      self.board_state = self.gwg.get_init_board()
      self.board_state['pieces'][1] = self.board_state['pieces'][0].copy()
      (self.b_fig0, self.b_ax0, self.b_crit_specs0, self.b_food0, self.b_fov0
       ) = self.gwg.plot_board(self.board_state, 0, self.b_fig0, self.b_ax0,
                             self.b_crit_specs0, self.b_food0, self.b_fov0,
                             has_fov=self.has_fov, radius=self.radius,
                             legend_type=None)
      (self.b_fig1, self.b_ax1, self.b_crit_specs1, self.b_food1, self.b_fov1
       ) = self.gwg.plot_board(self.board_state, 1, self.b_fig1, self.b_ax1,
                             self.b_crit_specs1, self.b_food1, self.b_fov1,
                             has_fov=self.has_fov, radius=self.radius,
                             legend_type=None)
      with self.output0:
        clear_output
        print('Game Over. Final Score {}'.format(new_scores[0]))
        print('Resetting the board for another game')
      with self.output1:
        clear_output
        print('Game Over. Final Score {}'.format(new_scores[1]))
        print('Resetting the board for another game')
    with self.scoreboard:
      clear_output()
      self.b_fig_legend.canvas.draw()
      print('Games Played: ' + str(len(self.final_scores)))
      if len(self.final_scores) > 0:
        table = [['', self.p0_short_name, self.p1_short_name],
          ['High Score:', str(np.max(np.array(self.final_scores)[:,0])),
                          str(np.max(np.array(self.final_scores)[:,1]))],
          ['Last Score:', str(self.final_scores[-1][0]),
                          str(self.final_scores[-1][1])],
          ['Average Score',
            '{:.2f}'.format(np.mean(np.array(self.final_scores)[:,0])),
            '{:.2f}'.format(np.mean(np.array(self.final_scores)[:,1]))]]
      else:
        table = [['', self.p0_short_name, self.p1_short_name],
          ['High Score:', '--', '--'],
          ['Last Score:', '--', '--'],
          ['Average Score:', '--', '--']]
      print(tabulate(table))


  def on_up_button_clicked(self, *args):
    self.button_output_update('up')

  def on_down_button_clicked(self, *args):
    self.button_output_update('down')

  def on_left_button_clicked(self, *args):
    self.button_output_update('left')

  def on_right_button_clicked(self, *args):
    self.button_output_update('right')

  def on_start_button_clicked(self, *args):
    self.start_button.disabled = True
    if self.has_temp_slider:
      self.softmax_temp_slider.disabled = True
    for ii in range(self.gwg.lifetime):
      self.button_output_update('tbd')
      time.sleep(0.2)
    self.start_button.disabled = False
    if self.has_temp_slider:
      self.softmax_temp_slider.disabled = False

  def disable_buttons(self):
    self.up_button.disabled = True
    self.down_button.disabled = True
    self.left_button.disabled = True
    self.right_button.disabled = True

  def enable_buttons(self):
    self.up_button.disabled = False
    self.down_button.disabled = False
    self.left_button.disabled = False
    self.right_button.disabled = False




class ExploreWeightsWidget:
  def __init__(self, game):
    self.game = game
    self.n_rows, self.n_cols = 4, 12  # four directions, twelve fov cells
    self.row_labels = ['Up', 'Down', 'Left', 'Right']
    self.col_labels = [
      'Far<br>Up', 'Left<br>Up', 'Near<br>Up', 'Right<br>Up',
      'Far<br>Left', 'Near<br>Left', 'Near<br>Right', 'Far<br>Right',
      'Left<br>Down', 'Near<br>Down', 'Right<br>Down', 'Far<br>Down'
      ]

    # Create column headers
    col_header = [widgets.Label(value='', layout=widgets.Layout(width='50px'))] + \
                 [widgets.HTML(value=label, layout=widgets.Layout(width='60px', min_height='60px')) for label in self.col_labels]

    self.text_fields = [widgets.HBox(col_header)]

    for label in self.row_labels:
      row_fields = [widgets.FloatText(value=0.0, step=0.1, layout=widgets.Layout(width='60px'))
                    for _ in range(self.n_cols)]
      row_label = widgets.Label(value=f"{label}:", layout=widgets.Layout(width='50px'))
      self.text_fields.append(widgets.HBox([row_label] + row_fields))

    # Create a button to start the game
    self.run_button = widgets.Button(description="Run Game")
    self.run_button.on_click(self.run_game)

    # set up fig and create placeholders for vertical lines and histograms
    colors = sns.color_palette("colorblind")
    self.current_color = colors[0]
    self.best_color = colors[1]
    self.prev_color = colors[2]
    self.fig, self.ax = plt.subplots(figsize=(6,4))
    self.ax.set_xlim([0,25])
    self.ax.set_ylim([0,1])
    remove_ip_clutter(self.fig)
    self.current_avg_line = self.ax.axvline(-1, color=self.current_color,
                                            linestyle='dashed', linewidth=3,
                                            label='Current')
    self.prev_avg_line = self.ax.axvline(-1, color=self.prev_color,
                                         linestyle='dashed', linewidth=3,
                                         label='Previous')
    self.best_avg_line = self.ax.axvline(-1, color=self.best_color,
                                         linestyle='dashed', linewidth=3,
                                         label='Best')
    if self.game.batch_size > 1:
      #only do hist for batches
      self.current_hist_bars = self.ax.bar([0]*10, [0]*10,
                                           color=self.current_color,
                                           label='Current Run')
      self.prev_hist_bars = self.ax.bar([0]*10, [0]*10, color=self.prev_color,
                                        label='Previous Run', alpha=0.5)
    self.fig.legend(loc='outside right upper')
    self.fig.canvas.draw()

    # Output widget to display game output and any other information
    self.out = widgets.Output()

    # keep track of important values
    self.best_avg_score = float('-inf')
    self.best_params = None
    self.prev_scores = []
    self.scores = np.array([])

    # Button to set the best weights
    self.best_button = widgets.Button(description="Set Best Weights")
    self.best_button.on_click(self.set_best_weights)

    # Add a ToggleButton for symmetry
    self.symmetry_toggle = widgets.ToggleButton(value=False,
                                                description='Enforce Symmetry',
                                                disabled=False,
                                                button_style='',
                                                tooltip='Toggle symmetry enforcement',
                                                icon='check')
    self.symmetry_toggle.observe(self.toggle_symmetry, 'value')

    self.final_display = widgets.VBox([
      *self.text_fields,
      widgets.HBox([self.run_button, self.best_button, self.symmetry_toggle]),
      self.fig.canvas,
      self.out])

    self.links = []  # To keep track of the jslink objects


  def run_game(self, *args):
    weights = []
    for hbox in self.text_fields[1:]:  # Skip the header row
      row_weights = [field.value for field in hbox.children[1:]]  # Skip the label at the first position
      weights.append(row_weights)
    weights = np.array(weights)
    ppp = PerceptParamPlayer(self.game, weights=weights)
    # Run the game
    final_board = self.game.play_game(players=[ppp], visualize=False)

    self.scores = final_board['scores'].flatten()
    avg_score = np.mean(self.scores)

    if avg_score > self.best_avg_score:
      self.best_avg_score = avg_score
      self.best_params = weights

    if self.game.batch_size > 1:
      # Compute and update histogram data
      counts, edges = np.histogram(self.scores, bins=10)
      counts = counts/np.sum(counts)
      prev_counts, prev_edges = np.histogram(self.prev_scores[-1], bins=10) if len(self.prev_scores) > 0 else ([0]*10, edges)
      prev_sum = np.sum(prev_counts)
      if prev_sum > 0:
        prev_counts = prev_counts / np.sum(prev_counts)
      # Update the height of bars for the current scores
      for bar, height, left in zip(self.current_hist_bars, counts, edges[:-1]):
          bar.set_height(height)
          bar.set_x(left)
          bar.set_width(edges[1] - edges[0])
      # Update the height of bars for the previous scores
      for bar, height, left in zip(self.prev_hist_bars, prev_counts, prev_edges[:-1]):
          bar.set_height(height)
          bar.set_x(left)
          bar.set_width(prev_edges[1] - prev_edges[0])
    # set vline data
    self.current_avg_line.set_xdata([avg_score])
    if len(self.prev_scores) > 0:
        prev_avg_score = np.mean(self.prev_scores[-1])
        self.prev_avg_line.set_xdata([prev_avg_score])
    self.best_avg_line.set_xdata([self.best_avg_score])

    #update the fig
    self.fig.legend(loc='outside right upper')
    self.fig.canvas.draw()
    # Display the output
    with self.out:
      clear_output()

      if self.game.batch_size > 1:
        print(f"Average Score This Time: {avg_score}")
        if len(self.prev_scores) > 0:
          prev_avg_score = np.mean(self.prev_scores[-1])
          print(f"Average Score Last Time: {prev_avg_score}")
        print(f"Best Ever Average Score: {self.best_avg_score}")
      else:
        print(f"Score This Run: {avg_score}")
        if len(self.prev_scores) > 0:
          print(f"Score Last Run: {prev_avg_score}")
        print(f"Best Score: {self.best_avg_score}")

    self.prev_scores.append(self.scores.copy())


  def link_symmetric_widgets(self, change):
    # Each row's symmetry permutation indices
    symmetry_indices = {
      'Up':    [ 0,  1,  2,  1,  3,  4,  4,  3,  5,  6,  5,  7],
      'Down':  [ 7,  5,  6,  5,  3,  4,  4,  3,  1,  2,  1,  0],
      'Left':  [ 3,  1,  4,  5,  0,  2,  6,  7,  1,  4,  5,  3],
      'Right': [ 3,  5,  4,  1,  7,  6,  2,  0,  5,  4,  1,  3]}

    if change['new']:  # If the toggle button is activated
      base_row = self.text_fields[1].children[1:]  # The 'Up' row
      base_perm = np.array(symmetry_indices['Up'])  # Convert to a NumPy array for easier manipulation

      for row_label, hbox in zip(self.row_labels, self.text_fields[1:]):  # Include all rows
        link_row = hbox.children[1:]  # Skip the label
        perm = np.array(symmetry_indices[row_label])  # Convert to a NumPy array for easier manipulation
        for i, j in enumerate(perm):
          base_index = np.flatnonzero(base_perm == j)[0]
          if row_label != 'Up' or base_index != i:  # Skip self-links
            link = widgets.jslink((base_row[base_index], 'value'), (link_row[i], 'value'))
            self.links.append(link)


  def unlink_symmetric_widgets(self):
    for link in self.links:
      link.unlink()
    self.links.clear()


  def toggle_symmetry(self, change):
    if change.new:
      self.link_symmetric_widgets(change)
    else:
      self.unlink_symmetric_widgets()


  def set_best_weights(self, *args):
    if self.best_params is not None:
      for i, hbox in enumerate(self.text_fields[1:]):
        for j, field in enumerate(hbox.children[1:]):
          field.value = self.best_params[i][j]



Random seed 2021 has been set.
This notebook isn't using and doesn't need a GPU. Good.
Running in colab


# 1.2.1.1: Artisanal Hand-Tuned Optimization

To start our optimization journey use the widget below to hand-tune the parameters of the policy. When you are happy with your parameter selection click the 'Go!' button to run a simulation and get a result. Try it twice or thrice. Is the stochasticity of the Gridworld environment making your optimization task harder or easier? Is the enforce symmetry button helpful? What does it do?

In [2]:
# @title Twiddle the Weights
# @markdown Run this cell to play with the weights for our 'Parameterized Weights' policy.
gwg = GridworldGame(batch_size=1, n_rows=7, n_cols=7,
                    num_critters=1, num_food=10, lifetime=30)
eww = ExploreWeightsWidget(gwg)
display(eww.fig.canvas)
clear_output()
display(eww.final_display)

VBox(children=(HBox(children=(Label(value='', layout=Layout(width='50px')), HTML(value='Far<br>Up', layout=Lay…

Okay, so it will probably take a long time to optimize if we can only get feedback on one stochastic simulation at a time. Also what are we even trying to optimize for? Our best score out of 10 simulation runs? The highest low score out of 100 simulation runs? The highest average score over 10, 100, 1000 simulation runs? The policy that achieves one of these goals might not be the same as the policy that achieves another of these goals, so we need to choose. A well defined evolutionary process often implies a specific goal, but we don't have a well defined evolutionary process just yet so instead we're going to have to just pick something. For now we're going to go with maximizing the ***Expected*** score, i.e. the goal is to have a policy with a [high average score over many simulations](## "The expectation of a random variable is connected to the mean of a set of samples of that variable by an important theorem in probability theory called the Law of Large Numbers. Specifically, the sample mean can be made arbitrarily close to the expected value of the random variable by making the sample size larger and larger. Imagine flipping a fair coin. Even though the chance of getting heads is 50%, in just 10 flips, there could easily be 7 heads and 3 tails just due to randomness, and 70% heads is very different from the true expected value of 50%. But flip that same coin a million times and the proportion of heads will be much closer to 50%. The Law of Large Numbers essentially says: the more flips, the closer to that true 50% mark.").  We can get better feedback on how well we are doing on this goal, despite the stochaticity of the environment and the policy, by averaging the score over many runs. Run the cell below to continue exploring the impact of weight changes on score, but now getting feedback from 400 simulations run in parallel.

In [3]:
# @title Twiddle the Weights Again
# @markdown Run this cell to play with the weights for our 'Parameterized Weights' policy. This time getting feedback on 400 simulation runs instead of just one. It may take a few seconds to run the simulations.
gwg = GridworldGame(batch_size=400, n_rows=7, n_cols=7,
                    num_critters=1, num_food=10, lifetime=30)
eww = ExploreWeightsWidget(gwg)
display(eww.fig.canvas)
clear_output()
display(eww.final_display)

VBox(children=(HBox(children=(Label(value='', layout=Layout(width='50px')), HTML(value='Far<br>Up', layout=Lay…

How well did you do? What differences did you notice between evaluating weights based on 400 simulations versus one?
[Answer](## "The simulations take longer to run, but the feedback is more consistent. This is due to an important theorem in statistics called the Central Limit Theorem. It states that, for large sample sizes, the average (or mean) of the sample values will follow a normal distribution, centered around the true average of the entire population. Furthermore the variability (or spread) of these averages decreases as you increase the sample size. Specifically, the standard deviation of the sample mean distribution is close to the standard deviation of the population divided by the square root of the sample size. These means that with a sample size of 400 we expect the standard deviation in the feedback to be 20 times smaller than in the single simulation case.")

You may now be wondering if there are ways of choosing parameter weights that are less tedious than dialing them in by hand and looking at the feedback. You may even be imagining what they might be, and how to implement them. There are ***many*** ways of setting parameter values to improve an objective function and we are going to start implementing and thinking about some of the simplest and most foundational ones. We start with grid search.

In [4]:
# @markdown Submit your feedback
content_review(f"{feedback_prefix}_M1")

VBox(children=(Label(value=''), HBox(children=(Button(description='🙂', layout=Layout(height='auto', width='aut…

# 1.2.1.2 Grid Search Parameter Tuning
Grid search is both simple and computationally intensive, and so is often reffered to as a 'Brute Force' method. The basic idea, in our case, is to think of the parameters of the policy as sitting in a 48 dimensional, real valued, vector space (one for each element of the weights matrix). In math symbols this is $\theta \in \mathbb{R}^{48}$. We can make a grid over a plausible set of parameters in this vector space, and evaluate the performance of a policy with weights corresponding to each point on the grid. Grids of evenly spaced numbers in vector spaces of arbitrary dimension come up frequently in many applications, so there are nice libraries and functions to help us make such grids. Let's start by only allowing parameters to be one of three values, $0$, $0.5$, or $1$.

In [5]:
# numpy's linspace makes an array of evenly spaced points on a line,
# we need 48 of those
dimensions = [np.linspace(0, 1, 3) for _ in range(48)]
print("'dimensions' is a: " + type(dimensions).__name__)
print('It has a length of ' + str(len(dimensions)))
print('Each element of dimensions looks like:')
display(dimensions[0])

'dimensions' is a: list
It has a length of 48
Each element of dimensions looks like:


array([0. , 0.5, 1. ])

Now let's make a grid of all those dimensions to evaluate our policy at.

In [6]:
# numpy.meshgrid or itertools.product can be used to enumerate every possible
# combination of values in each of those parameter dimension.
try:
  grid = np.meshgrid(*dimensions, indexing='ij')
except ValueError as e:
  display(f"Caught an error: {e}")

'Caught an error: maximum supported dimension for an ndarray is 32, found 48'

This should throw an error. The reason that the meshgrid function is throwing an error is because we are asking it to make a grid with $3^{48} \approx 7.9766 \times 10^{22}$ points in it. For context this is many orders of magnitude greater than the number of synapses in a human brain or the number of grains of sand on all the earth's beaches. More directly relevant, this is also much larger than the maximum number of float64's ($6.85 \times 10^9$) that could theoretically be stored in 51 GB of memory (The memory provision of a 'High-Ram' Colab instance circa September 2023). It is simply not feasible for us to iterate through all of those grid-points. We can reduce our problem to something more tractable by enforcing symmetry (though maybe the optimal strategy isn't symmetric). Because Gridworld is symmetric and the organism is symmetric, it is plausible that a good weight strength between the 'near right' perceptive cell and the 'right direction' output neuron should be the same as the weight strength between the 'near left' perceptive cell and the 'left direction' output neuron, which should also be the same as between 'near up' and 'up direction' and the same as between 'near down' and 'down direction'. Generalizing this thinking we can apply rotational symmetry to reduce the parameters by a factor of 4, down to 12 and then apply flip symmetries (e.g. down and up perceptions should be treated the same for purposes of deciding to move left or right) to further reduce from 12 parameters to 8. In contrast $3^8=6561$ is not [such a big number](## "From a practical perspective thinking about data set sizes, model sizes, and algorithm memory use three distinct order of magnitude categories should be considered. 1) Fits in memory life is great. 2) Fits on disk but not on memory, need to use a database to manage io between disk and memory but life is okay. 3) Does not fit on disk, need to set up a cluster and coordinate computation between different hardware, life is complicated"). This scale of grid easily fits in a (modern) computer's memory, we can readily work with it.

In [7]:
sym_dimensions = [np.linspace(0, 1, 3) for _ in range(8)]
try:
  # Generate all possible combinations of symmetry parameters
  param_combinations = list(itertools.product(*sym_dimensions))
except ValueError as e:
  display(f"Caught an error: {e}")

Okay, now that we have a tractable set of parameters to iterate through, let's implement grid search.

In [8]:
################################################################################
# TODO for students: Complete the lines with ...  to correctly map the symmetric
# parameters to the expanded weight matrix, and to update the best params and
# average score based on each evaluation of the params. Then comment out the
# line below.
raise NotImplementedError("Exercise: implement symmetry, and param updates")
################################################################################

game = GridworldGame(batch_size=9, n_rows=7, n_cols=7,
                     num_critters=1, num_food=10, lifetime=30,
                     rng=np.random.default_rng(48))
best_avg_score = float('-inf')
best_params = None

def convert_symmetry_to_weights(symmetry_params):
  # Initialize the weight matrix with zeros
  weights = np.zeros((4,12))
  symmetry_indices = {
    'Up':    [0,  1,  2,  1,  3,  4,  4,  3,  5,  6,  5,  7],
    'Down':  [7,  5,  6,  5,  3,  4,  4,  3,  1,  2,  1,  0],
    'Left':  [3,  1,  4,  5,  0,  2,  6,  7,  1,  4,  5,  3],
    'Right': [3,  5,  4,  1,  7,  6,  2,  0,  5,  4,  1,  3]}
  # Use the symmetry indices to populate the 48-dimensional weight vector
  for i, direction in enumerate(['Up', 'Down', 'Left', 'Right']):
    for j, idx in enumerate(symmetry_indices[direction]):
      weights[..., ...] = symmetry_params[...]
    return weights

# Loop through each combination
for params in tqdm(param_combinations):
  # Convert symmetry parameters to the actual weights
  weights = convert_symmetry_to_weights(params)

  # Run the game with the weights
  boppp = BatchOptPerceptParamPlayer(game, weights=weights, deterministic=True)
  final_board = game.play_game(players=[boppp], visualize=False)

  # Evaluate the score
  scores = final_board['scores'].flatten()
  avg_score = np.mean(scores)

  # Update best parameters if needed
  if avg_score > ...:
    best_avg_score = ...
    best_params = ...

print(best_params)
print(best_avg_score)

NotImplementedError: ignored

In [9]:
# to_remove solution
game = GridworldGame(batch_size=9, n_rows=7, n_cols=7,
                     num_critters=1, num_food=10, lifetime=30,
                     rng=np.random.default_rng(48))
best_avg_score = float('-inf')
best_params = None

def convert_symmetry_to_weights(symmetry_params):
  # Initialize the weight matrix with zeros
  weights = np.zeros((4,12))
  symmetry_indices = {
    'Up':    [0,  1,  2,  1,  3,  4,  4,  3,  5,  6,  5,  7],
    'Down':  [7,  5,  6,  5,  3,  4,  4,  3,  1,  2,  1,  0],
    'Left':  [3,  1,  4,  5,  0,  2,  6,  7,  1,  4,  5,  3],
    'Right': [3,  5,  4,  1,  7,  6,  2,  0,  5,  4,  1,  3]}
  # Use the symmetry indices to populate the 48-dimensional weight vector
  for i, direction in enumerate(['Up', 'Down', 'Left', 'Right']):
    for j, idx in enumerate(symmetry_indices[direction]):
      weights[i, j] = symmetry_params[idx]
    return weights

# Loop through each combination
for params in tqdm(param_combinations):
  # Convert symmetry parameters to the actual weights
  weights = convert_symmetry_to_weights(params)

  # Run the game with the weights
  boppp = BatchOptPerceptParamPlayer(game, weights=weights, deterministic=True)
  final_board = game.play_game(players=[boppp], visualize=False)

  # Evaluate the score
  scores = final_board['scores'].flatten()
  avg_score = np.mean(scores)

  # Update best parameters if needed
  if avg_score > best_avg_score:
    best_avg_score = avg_score
    best_params = params

print(best_params)
print(best_avg_score)

  0%|          | 0/6561 [00:00<?, ?it/s]

(0.0, 1.0, 0.5, 0.0, 0.0, 0.0, 0.5, 0.0)
3.111111111111111


Even just considering three points per dimension, and only evaluating each grid point using nine simulation runs, this looks like it is going to take about three minutes (on a default Colab instance circa September 2023). If we want to get a better result from this grid search approach we are going to have to both increase the resolution of the grid, and do a better evaluation of each grid point. Edit your solution and the grid definition cell above to increase the grid resolution to 5 points ranging from -1 to 1, and to evaluate each grid point with 25 simulations runs (hint edit batch_size). What does this do to the expected time to complete? This could take awhile so keep reading while you wait for this grid search to complete. You can also cancel the execution of this cell by clicking the square 'stop' symbol to the upper right of the code cell (Colab) or going to the `Runtime` drop-down and selecting `Interrupt execution`.

Even though the grid search idea is simple, 'evaluate all the possible options and pick the best one', from a practical perspective it suffers from what is called the **Curse of Dimensionality**. When 'all possible options' are derived from a combination of parameters in different dimensions, even if we consider only a very small number of possible values in each dimension, the total number of combinations to be explored is the product of the number of parameters to be considered in each dimension. This number of all possible combinations can thus grow very quickly as the number of points per dimension increases, so much so, that it is often referred to as a **combinatorial explosion**. As a result, approaches to finding good parameters using grid search can only ever be effective if the number of dimensions being considered is very low (or if selecting parameters before the heat-death of the universe is not a concern).

**Memory Awareness Exercise**: If you are running this notebook in a Colab instance pop open the resources tab and watch what happens to System RAM when you run the next cell. (If you're running this notebook locally you can use 'top' or similar in a linux environment, or the task manager in windows). Next increase the number of points considered from $3$ to $10$, and run the cell again, still observing the memory usage. This is the combinatorial explosion in action! What number do you need to go to crash your Colab instance (or start using [swap](## "Swap is when your computer runs out space in memory (RAM) and starts to use the disk as 'fake' RAM, which is usually an order of magnitude slower to read an write from, and will considerably slow most processes.") on you local machine?). Warning, save your work before trying this, you might also crash your local machine!

In [10]:
sym_dimensions = [np.linspace(0, 1, 3) for _ in range(8)]
try:
  # Generate all possible combinations of symmetry parameters
  param_combinations = list(itertools.product(*sym_dimensions))
except ValueError as e:
  display(f"Caught an error: {e}")

In [11]:
# @markdown Submit your feedback
content_review(f"{feedback_prefix}_M2")

VBox(children=(Label(value=''), HBox(children=(Button(description='🙂', layout=Layout(height='auto', width='aut…

# 1.2.1.3 Propose and Test Parameter Tuning
This scaling issue with grid search motivates us to come up with a parameter selection process that does not depend on iterating through an excessively long list of possible values. So, instead of labouriously iterating through a vast array of combinations, we'll 'test' parameter combinations (one at a time) that are close to the best known set of parameter values. If these test parameters outperform relative to the current best parameters they are 'accepted' and become the new best parameters. If the test parameters underperform relative to the current best parameters they are rejected and forgotten. After some fixed number of tests are rejected in a row we will say that the process has 'converged'. Let's make this.

In [12]:
################################################################################
# TODO for students: Complete the lines with ...  to set the stopping criterion
# for the while loop, and to update the best params and average score based on
# each evaluation of the params. Then comment out the line below.
# This optimization may take some a few minutes to complete.
raise NotImplementedError("Exercise: implement stopping criteria and param updates")
################################################################################

def propose_and_test(batch_size=25, high_batch_size=400,
                     max_rejected=100,
                     step_scale=1.0,
                     verbose=True):

  game = GridworldGame(batch_size=batch_size, n_rows=7, n_cols=7,
                       num_critters=1, num_food=10, lifetime=30,
                       rng=np.random.default_rng(48))
  high_batch_game = GridworldGame(batch_size=high_batch_size, n_rows=7, n_cols=7,
                                  num_critters=1, num_food=10, lifetime=30,
                                  rng=np.random.default_rng(48))
  # Initialize parameters
  initial_params = np.zeros(8)
  best_params = initial_params
  best_avg_score = float('-inf')
  rejected_count = 0
  total_tests = 0  # Number of iterations
  std_dev = step_scale  # Standard deviation for Gaussian proposal

  # Propose-and-test loop
  start_time = time.time()
  while rejected_count < ...:
    # Propose new parameters: sample from Gaussian centered at best_params
    delta_params = np.random.normal(0, std_dev, best_params.shape)
    proposal_params = best_params + delta_params
    # Convert symmetry parameters to actual weights
    weights = convert_symmetry_to_weights(proposal_params)
    # Run the game with the proposed weights
    boppp = BatchOptPerceptParamPlayer(game, weights=weights, deterministic=False)
    final_board = game.play_game(players=[boppp], visualize=False)
    # Evaluate the score
    scores = final_board['scores'].flatten()
    avg_score = np.mean(scores)

    # If a promising candidate is found, validate it with a high batch size evaluation
    if avg_score > best_avg_score:
      boppp_high_batch = BatchOptPerceptParamPlayer(high_batch_game, weights=weights, deterministic=False)
      final_board_high_batch = high_batch_game.play_game(players=[boppp_high_batch], visualize=False)
      scores_high_batch = final_board_high_batch['scores'].flatten()
      avg_score_high_batch = np.mean(scores_high_batch)
      # Only update best parameters if the candidate also performs well in the
      # high batch size evaluation to avoid choosing parameters based on 'luck'
      # i.e. from a really exceptional batch of simulations
      if avg_score_high_batch > ...:
        best_avg_score = ...
        best_params = ...
        if verbose:
          #print('best params so far:')
          #display(best_params)
          print(f"Best score so far: {best_avg_score}")
          print(f"Found after {rejected_count} tests")
        rejected_count = 0
      else:
        rejected_count += 1
    else:
      rejected_count += 1
    total_tests +=1
  end_time = time.time()
  elapsed_time = end_time - start_time

  if verbose:
    # Print the best found parameters and score
    print("Best Parameters:", best_params)
    print("Best Average Score:", best_avg_score)
    print("Parameter combinations tested:", total_tests)
    print(f"Time taken for the optimization loop: {elapsed_time:.2f} seconds")
  return best_params, best_avg_score

best_params, best_avg_score = propose_and_test()

NotImplementedError: ignored

In [13]:
# to_remove solution

def propose_and_test(batch_size=25, high_batch_size=400,
                     max_rejected=100,
                     step_scale=1.0,
                     verbose=True):

  game = GridworldGame(batch_size=batch_size, n_rows=7, n_cols=7,
                       num_critters=1, num_food=10, lifetime=30,
                       rng=np.random.default_rng(48))
  high_batch_game = GridworldGame(batch_size=high_batch_size, n_rows=7, n_cols=7,
                                  num_critters=1, num_food=10, lifetime=30,
                                  rng=np.random.default_rng(48))
  # Initialize parameters
  initial_params = np.zeros(8)
  best_params = initial_params
  best_avg_score = float('-inf')
  rejected_count = 0
  total_tests = 0  # Number of iterations
  std_dev = step_scale  # Standard deviation for Gaussian proposal

  # Propose-and-test loop
  start_time = time.time()
  while rejected_count < max_rejected:
    # Propose new parameters: sample from Gaussian centered at best_params
    delta_params = np.random.normal(0, std_dev, best_params.shape)
    proposal_params = best_params + delta_params
    # Convert symmetry parameters to actual weights
    weights = convert_symmetry_to_weights(proposal_params)
    # Run the game with the proposed weights
    boppp = BatchOptPerceptParamPlayer(game, weights=weights, deterministic=False)
    final_board = game.play_game(players=[boppp], visualize=False)
    # Evaluate the score
    scores = final_board['scores'].flatten()
    avg_score = np.mean(scores)

    # If a promising candidate is found, validate it with a high batch size evaluation
    if avg_score > best_avg_score:
      boppp_high_batch = BatchOptPerceptParamPlayer(high_batch_game, weights=weights, deterministic=False)
      final_board_high_batch = high_batch_game.play_game(players=[boppp_high_batch], visualize=False)
      scores_high_batch = final_board_high_batch['scores'].flatten()
      avg_score_high_batch = np.mean(scores_high_batch)
      # Only update best parameters if the candidate also performs well in the
      # high batch size evaluation to avoid choosing parameters based on 'luck'
      # i.e. from a really exceptional batch of simulations
      if avg_score_high_batch > best_avg_score:
        best_avg_score = avg_score_high_batch
        best_params = proposal_params
        if verbose:
          #print('best params so far:')
          #display(best_params)
          print(f"Best score so far: {best_avg_score}")
          print(f"Found after {rejected_count} tests")
        rejected_count = 0
      else:
        rejected_count += 1
    else:
      rejected_count += 1
    total_tests +=1
  end_time = time.time()
  elapsed_time = end_time - start_time

  if verbose:
    # Print the best found parameters and score
    print("Best Parameters:", best_params)
    print("Best Average Score:", best_avg_score)
    print("Parameter combinations tested:", total_tests)
    print(f"Time taken for the optimization loop: {elapsed_time:.2f} seconds")
  return best_params, best_avg_score

best_params, best_avg_score = propose_and_test()

Best score so far: 5.5025
Found after 0 tests
Best score so far: 6.0275
Found after 0 tests
Best score so far: 7.925
Found after 1 tests
Best score so far: 9.4325
Found after 0 tests
Best score so far: 10.425
Found after 9 tests
Best score so far: 10.4575
Found after 37 tests
Best score so far: 10.56
Found after 1 tests
Best score so far: 10.8825
Found after 9 tests
Best score so far: 11.45
Found after 8 tests
Best score so far: 11.67
Found after 87 tests
Best Parameters: [ 5.16068532 -2.11451798  5.64432851 -1.07815224 -1.33680313 -0.11250661
 -2.89415789  0.15679894]
Best Average Score: 11.67
Parameter combinations tested: 262
Time taken for the optimization loop: 39.15 seconds


Okay that took less time than grid search and deliverd a way better result. Now that we have this optimization process in place we can play with and think a bit about the 'hyperparameters' of the optimization process, i.e. the *parameters of the process* that we use to select the *parameters for our policy*. Run the code cells below. In each cell the values assigned to `max_rejected` or `step_scale` have been changed from the defaults above. As these cells execute, consider:
* What will happen when `max_rejected` is very low? [Answer](## "`max_rejected` determines how persistent the algorithm is in testing for better parameters before 'giving up' and accepting an answer. If `max_rejected` is too high the algorithm will take additional time to stop. On the other hand if `max_rejected` is too low, the algorithm might not search adequately for improvements and could stop prematurely, settling for a suboptimal solution.")
* How will a very small or a very large `step_scale` affect results? [Answer](## "The value of step_scale influences the speed of convergence towards optimal parameters. If it's too small, the algorithm will inch towards a solution very slowly. Additionally, if the evaluation function is stochastic then the inherent variation in the evaluation function will overwhelm any changes in performance due to very small parameter changes, which will make improvement impossible. An excessively large step_scale can make the method act like it's randomly sampling the parameter space since the guidance from the current best parameters becomes less relevant. A large `step_scale` is less of an issue in this context since the soft-max normalization means that good solutions are possible at many different scales of parameters, i.e. the relative proportions of the weights far more important than the order of magnitude of the weights")

In [14]:
best_params, best_avg_score = propose_and_test(max_rejected=10)

Best score so far: 5.2425
Found after 0 tests
Best score so far: 5.5475
Found after 0 tests
Best score so far: 5.7375
Found after 3 tests
Best score so far: 5.79
Found after 0 tests
Best score so far: 6.115
Found after 1 tests
Best score so far: 8.205
Found after 0 tests
Best Parameters: [ 4.2650492  -1.85844078  2.17254569 -1.16582709  2.34480054  4.70868765
  0.54930758  1.03102517]
Best Average Score: 8.205
Parameter combinations tested: 20
Time taken for the optimization loop: 6.39 seconds


In [15]:
best_params, best_avg_score = propose_and_test(step_scale=0.0001, max_rejected=200)

Best score so far: 6.2325
Found after 0 tests
Best score so far: 6.2675
Found after 4 tests
Best score so far: 6.295
Found after 60 tests
Best score so far: 6.355
Found after 1 tests
Best score so far: 6.3625
Found after 71 tests
Best score so far: 6.3825
Found after 29 tests
Best score so far: 6.385
Found after 169 tests
Best score so far: 6.445
Found after 42 tests
Best Parameters: [-0.00046484 -0.00061694  0.00314175  0.00336695  0.00224368  0.00350342
 -0.00051277  0.00038381]
Best Average Score: 6.445
Parameter combinations tested: 584
Time taken for the optimization loop: 154.55 seconds


In [16]:
best_params, best_avg_score = propose_and_test(step_scale=200, max_rejected=200)

Best score so far: 4.56
Found after 0 tests
Best score so far: 4.8025
Found after 1 tests
Best score so far: 4.9025
Found after 3 tests
Best score so far: 6.8525
Found after 0 tests
Best score so far: 7.0925
Found after 1 tests
Best score so far: 7.45
Found after 15 tests
Best score so far: 7.94
Found after 0 tests
Best score so far: 8.455
Found after 4 tests
Best score so far: 8.6775
Found after 0 tests
Best score so far: 8.775
Found after 1 tests
Best score so far: 8.84
Found after 3 tests
Best score so far: 9.545
Found after 6 tests
Best score so far: 9.6925
Found after 11 tests
Best score so far: 10.1275
Found after 11 tests
Best score so far: 10.175
Found after 1 tests
Best score so far: 10.45
Found after 18 tests
Best Parameters: [  833.47737267 -1152.9726364   1632.3622642     -8.62048108
  1025.40584312  -428.59703006  -960.93194724   -86.65456834]
Best Average Score: 10.45
Parameter combinations tested: 291
Time taken for the optimization loop: 55.90 seconds


In this particular problem, we observed that setting `max_rejected` too low can lead to catastrophic failure, i.e., a good solution is never be found. In contrast, setting `max_rejected` too high only results in taking longer to identify a solution. So, in this kind of problem, it's better err on the side of setting `max_rejected` too high. Similarly, we can see that setting `step_scale` too small might can lead to failure. This could be because the stochasticity of the feedback function is overwhelming the signal from parameter changes, or because the process is caught in a local minima, and small step sizes do not allow it to escape. On the other hand, due to softmax normalization, setting high `step_scale` values doesn't seem too problematic, though with very large step sizes it may be difficult to really 'dial in' a good solution, i.e. the parameters will be close to optimal but unable to settle in on a precise peak due to the large step scale. Therefore, in a problem like this, where the parameter scale can roughly adapt to the step scale, it's better to err on the side of setting `step_scale` higher.

In [17]:
# @markdown Submit your feedback
content_review(f"{feedback_prefix}_M3")

VBox(children=(Label(value=''), HBox(children=(Button(description='🙂', layout=Layout(height='auto', width='aut…

# 1.2.1.4 Reflections on Symmetry

In this sequence, symmetry emerged as a powerful tool for reducing the complexity of our problem. Evolutionary processes are also able to take advantage of symmetry. Many organisms, such as humans and most other vertebrates, exhibit bilateral symmetry, where one half of an organism (roughly) mirrors the other. Plants, too, display varied symmetries, from the radial symmetry of flowers to the repetitive patterns of leaves and branches. However, within these symmetrical development regimes, localized adaptivity breaks perfect symmetry. Take an oak tree, for instance. Each leaf follows the same developmental program encoded in its DNA. Yet, due to local conditions such as sunlight, moisture, and space, some leaves grow large, while others remain small. The developmental program for each leaf might be identical, but diverse outcomes can still arise from localized conditions.

This widespread symmetry in morphology is possible because cell structures in distal parts of the body can be intrinsically 'linked' by executing the same (or similar) developmental programs encoded in identical copies of DNA. Yet, just as with oak tree leaves, localized conditions and adaptions can lead to deviations from symmetry. As another example consider a person doing bicep curls with just one arm - over time, one arm becomes more muscular, breaking the bilateral symmetry typical in human bodies.

In neural systems, the applicability of symmetry is even more nuanced. On the surface, mammalian nervous systems, with their two brain hemispheres, might suggest that symmetry also governs neural systems. But, beyond the macro morphology and into the functional intricacies of the brain, the symmetry appears to break down. At a functional level, neural structures arise not just from pre-programmed ontogenetic processes but also—and often predominantly—from adaptive learning processes. These learning processes, dynamic adaptive adjustments in synaptic connections made in response to experience and environmental inputs, are inherently localized. Consider, does mastering writing with the right hand automatically confer equal proficiency to the left? Not really. Yet, does this deviation from symmetry imply its complete absence? Not quite. While the neural structures in the motor cortex governing the left and right hands will have followed different learning trajectories, the rules of neural plasticity guiding these trajectories are likely consistent. So, at this level of neural functionality, while we may not see direct physiological symmetry, shared plasticity rules between structures can be thought of as another, albeit more abstract, form of symmetry.

This concludes our introductory sequence on optimization. We learned about a few very basic optimization methods, and saw that while grid search was not powerful enough to solve our problem, a simple propose and test method was. We also looked into the importance of hyper-parameters in optimization processes. In the next sequence we will see just how robust this propose and test method is, looking at how it can be used to optimize a policy (behaviour) across of variety of environments, each with its own challenges.

In [18]:
# @markdown Submit your feedback
content_review(f"{feedback_prefix}_M4")

VBox(children=(Label(value=''), HBox(children=(Button(description='🙂', layout=Layout(height='auto', width='aut…

# Quiz

In [19]:
# @markdown **Run this cell** to take the quiz
comprehension_quiz = [
  {"question": "Which of the following best describes the term 'combinatorial explosion'?",
   "type": "multiple_choice",
   "answers": [
      {"answer": "The rapid growth of combinations when considering multiple parameters, often making brute-force searches infeasible.",
       "correct": True,
       "feedback": "Correct! The number of combinations grows exponentially with each added parameter."},
      {"answer": "An unexpected surge in the number of possible combinations due to a system error.",
       "correct": False,
       "feedback": "False. The increase is expected, not due to an error."},
      {"answer": "The explosion of data points in a database due to redundant entries.",
       "correct": False,
       "feedback": "False. This is not the primary meaning of combinatorial explosion."},
      {"answer": "The branching out of possibilities in a decision tree.",
       "correct": False,
       "feedback": "False. While decision trees can grow, this is not what's meant by combinatorial explosion."}
    ]
  },
  {"question": "When conducting a larger number of simulation runs, which of the following statements is true in the context of the Central Limit Theorem and the Law of Large Numbers?",
   "type": "multiple_choice",
   "answers": [
      {"answer": "The sample mean will diverge from the expected value of the random variable.",
       "correct": False,
       "feedback": "False. The law ensures convergence, not divergence."},
      {"answer": "Only the Central Limit Theorem assures that the sample mean will be closer to the expected value.",
       "correct": False,
       "feedback": "False. Both the CLT and LLN play roles in this context."},
      {"answer": "The distribution of the sample mean will be normally distributed regardless of the original distribution, and as the sample size increases, the sample mean will get closer to the expected value.",
       "correct": True,
       "feedback": "Correct! This captures the essence of both the Central Limit Theorem and the Law of Large Numbers."},
      {"answer": "The Law of Large Numbers and the Central Limit Theorem both state that larger samples guarantee a narrower distribution.",
       "correct": False,
       "feedback": "False. Only the CLT speaks about the distribution's shape; LLN focuses on the convergence of the sample mean."}
    ]
  },
  {"question": "Which statement best describes the role of symmetry in neural systems compared to morphological structures?",
   "type": "multiple_choice",
   "answers": [
      {"answer": "Both morphological and neural systems in organisms are governed by perfect bilateral symmetry.",
       "correct": False,
       "feedback": "False. While morphological systems often display symmetry, neural systems have more nuanced symmetry."},
      {"answer": "While morphological structures like body shape often exhibit symmetry due to identical DNA instructions, neural systems show localized adaptivity and learning, making their symmetry more nuanced.",
       "correct": True,
       "feedback": "Correct! Neural systems have functional adaptivity based on learning and environment."},
      {"answer": "Neural systems are always symmetric, and any deviation from symmetry indicates a neural malfunction.",
       "correct": False,
       "feedback": "False. Neural systems can deviate from perfect symmetry based on learning and localized adaptivity."},
      {"answer": "Symmetry in neural systems is only a byproduct of the symmetry seen in morphological structures.",
       "correct": False,
       "feedback": "False. Symmetry in neural systems, while influenced by morphology, has its own complexities due to learning and adaptivity."}
    ]
  }
]

display_quiz(comprehension_quiz)

<IPython.core.display.Javascript object>