### Action space

The action-space is (0, 1, 2, 3) which correspond to:

| action | keyboard | game board |
| ------ | -------- | ---------- |
| 0      | w        | up         |
| 1      | a        | left       |
| 2      | s        | down       |
| 3      | d        | right      |

### State space
The state could be represented as a grid of action codes, where -2 is mapped to unoccupied (empty) cells:


| action | game board |
| ------ | ---------- |
| -3     | food       |
| -2     | empty      |
| -1      | head       |
| 0      | up         |
| 1      | left       |
| 2      | down       |
| 3      | right      |


Alternatively, we could use a list of the last $n$ actions, where $n$ is the total score of the snake plus the snake's initial length.

### Alternative state representations

For recurrent neural networks the following state might make sense:



### Features
Possible useful features:
* distance from head to food
Possible useful features:

* distance from head to food
* time until body-coordinate disappears

For recurrent neural network:

* a list of (x, y) snake-coordinates
* the position of the food

For convolutional neural networks:
* binary grid with snake body coordinates
* binary grid with snake head location
* binary grid with food location

In [4]:
import sys
import numpy as np
from gym import utils
from gym.spaces import Box, Discrete

In [5]:
# Human-readable name of every integer code a board cell can hold.
# Codes run from -3 to 3; the names below are listed in that order.
state2text = dict(zip(
    range(-3, 4),
    ('food', 'empty', 'head', 'up', 'left', 'down', 'right'),
))

# show the table, one "code name" pair per line
for code in state2text:
    print(code, state2text[code])

-3 food
-2 empty
-1 head
0 up
1 left
2 down
3 right


In [14]:
# Coloured glyph printed for every cell code.
# Other glyphs that were tried:
#   food  (-3): \u25CF, \u25CB
#   empty (-2): \u002A, \u0020, \u1F784, \u2022
#   head  (-1): \u25CF, \u25CB, \u2687
state2symbol = {
    code: utils.colorize(glyph, colour)
    for code, glyph, colour in (
        (-3, u'\u204E', 'cyan'),
        (-2, u'\u00B7', 'gray'),
        (-1, u'\u2689', 'green'),
        (0, u'\u2191', 'green'),
        (1, u'\u2190', 'green'),
        (2, u'\u2193', 'green'),
        (3, u'\u2192', 'green'),
    )
}

# show the table, one "code glyph" pair per line
for code in state2symbol:
    print(code, state2symbol[code])

KeyError: 'fuschia'

In [7]:
# Name of each action code; the position in the tuple is the action number.
action2text = dict(enumerate(('up', 'left', 'down', 'right')))

# show the table, one "code name" pair per line
for code in action2text:
    print(code, action2text[code])

0 up
1 left
2 down
3 right


In [8]:
# Box-drawing glyph for a body cell, keyed by the pair
# (action that entered the cell, action that left it).
# Each glyph serves the two action pairs listed next to it.
actseq2symbol = {
    pair: utils.colorize(glyph, 'green')
    for glyph, pairs in (
        (u'\u250c', ((0, 3), (1, 2))),
        (u'\u2510', ((0, 1), (3, 2))),
        (u'\u2514', ((2, 3), (1, 0))),
        (u'\u2518', ((2, 1), (3, 0))),
        (u'\u2500', ((1, 1), (3, 3))),
        (u'\u2502', ((0, 0), (2, 2))),
    )
    for pair in pairs
}

# show every pair with its glyph and the names of both actions
for pair, glyph in actseq2symbol.items():
    first, second = pair
    print(pair, glyph, action2text[first], action2text[second])

(0, 3) [32m┌[0m up right
(1, 2) [32m┌[0m left down
(0, 1) [32m┐[0m up left
(3, 2) [32m┐[0m right down
(2, 3) [32m└[0m down right
(1, 0) [32m└[0m left up
(2, 1) [32m┘[0m down left
(3, 0) [32m┘[0m right up
(1, 1) [32m─[0m left left
(3, 3) [32m─[0m right right
(0, 0) [32m│[0m up up
(2, 2) [32m│[0m down down


In [9]:
def render(s, score=0, mode='human'):
    """Print the score and a framed picture of the grid `s` to stdout.

    Every cell code in `s` is translated through `state2symbol`.  The text
    is written without a trailing newline.  `mode` is accepted but unused.
    """
    emit = sys.stdout.write
    border = "\n+" + "-" * s.shape[1] + "+"

    emit(f'Score: {score}')
    emit(border)
    for cells in s.tolist():
        glyphs = "".join(state2symbol[code] for code in cells)
        emit("\n|" + glyphs + "|")
    emit(border)
    

# board size: m rows by n columns
m, n = 6, 6
# Cell codes run from -3 (food) to 3 (right).  The upper bound must be 3,
# otherwise a sampled grid can never contain the 'right' code.
state_space = Box(low=-3, high=3, shape=(m,n), dtype=np.int64)
# draw one random grid and display it
x = state_space.sample()
x

array([[ 0,  1,  0,  0,  0,  0],
       [ 0,  2,  2,  0,  1,  0],
       [ 0,  2, -2, -2, -2,  1],
       [ 1,  2,  2,  1,  0,  1],
       [-2,  0, -2,  2,  0,  0],
       [-1,  1,  0,  0, -2,  0]])

In [60]:
# four discrete actions: 0 up, 1 left, 2 down, 3 right
action_space = Discrete(4)
# draw one action from the space and display it
action_space.sample()

3

In [61]:
# setup user input 
from IPython.display import clear_output
import sys

# keyboard key -> action code (0 up, 1 left, 2 down, 3 right)
key2action = {'w': 0, 'a':1, 's':2, 'd':3}
def get_input():
    """Prompt until the user enters a valid move and return its action code.

    The cell output is cleared after every prompt.  Anything other than
    'w', 'a', 's', 'd' or 'q' is ignored and the prompt is shown again.

    Returns:
        The action code stored in `key2action` for the key entered.

    Raises:
        SystemExit: when the user enters 'q'.
    """
    # Loop instead of recursing, so that a long run of invalid entries
    # cannot exhaust the interpreter's recursion limit.
    while True:
        x = input('Enter move: ')
        clear_output()

        if x == 'q':
            sys.exit('Thanks for playing!')
        if x in key2action:
            return key2action[x]

# try it out: asks for one move and displays the resulting action code
get_input()

1

In [62]:
# scripted moves used to lay out a test snake (0 up, 1 left, 2 down, 3 right)
action_seq = [2, 3, 0, 3, 2, 2, 2, 1, 0, 1, 2, 2, 1, 0, 0, 0]

# walk the moves from the starting cell, recording every (row, col) visited;
# the last entry is where the head ends up
snake_seq = [(0,1)]
for move in action_seq:
    row, col = snake_seq[-1]
    # (row, col) offset of each action; unknown codes leave the cell unchanged
    d_row, d_col = {0: (-1, 0), 1: (0, -1), 2: (1, 0), 3: (0, 1)}.get(move, (0, 0))
    snake_seq.append((row + d_row, col + d_col))

# glyph for every body cell: the first cell gets the (0, 0) glyph, each
# later cell the glyph of the (previous action, next action) pair
symbol_seq = [actseq2symbol[(0,0)]] + [
    actseq2symbol[pair] for pair in zip(action_seq[:-1], action_seq[1:])
]

In [63]:
# Build the bookkeeping record of a test game from the sequences above.
initial_length = 3
info = dict(
    food=(5,4),                # (row, col) of the food
    snake_seq=snake_seq,       # cells occupied by the snake, head last
    action_seq=action_seq,     # action taken from each body cell
    symbol_seq=symbol_seq,     # glyph drawn for each body cell
    # NOTE(review): len(snake_seq) already counts every snake cell, so adding
    # initial_length may over-count -- confirm the intended meaning.
    snake_length=len(snake_seq)+initial_length,
)

In [64]:
def info2state(info):
    """Build the integer cell-code grid described by `info`.

    The grid has the module-level size (m, n).  It starts out all empty
    (-2); each body cell then receives the action paired with it, the last
    snake cell becomes the head (-1) and the food cell becomes -3.

    Args:
        info: dict with the keys 'snake_seq', 'action_seq' and 'food'.

    Returns:
        An (m, n) int64 array of cell codes.
    """
    grid = np.full((m, n), -2, dtype=np.int64)

    # body: every cell shows the action paired with it
    for (row, col), act in zip(info['snake_seq'], info['action_seq']):
        grid[row, col] = act

    # head: overrides whatever the body loop wrote there
    head_row, head_col = info['snake_seq'][-1]
    grid[head_row, head_col] = -1

    # food: written last
    food_row, food_col = info['food']
    grid[food_row, food_col] = -3

    return grid

def render(state, score=0, mode='human'):
    """Print the score and a framed picture of `state` to stdout.

    Every cell code in `state` is translated through `state2symbol`; each
    printed line ends with a newline.  `mode` is accepted but unused.
    """
    emit = sys.stdout.write
    border = "+" + "-" * state.shape[1] + "+\n"

    emit(f'Score: {score}\n')
    emit(border)
    for cells in state.tolist():
        glyphs = "".join(state2symbol[code] for code in cells)
        emit("|" + glyphs + "|\n")
    emit(border)
    
# display state: rebuild the cell-code grid from `info` and print it
state = info2state(info)
render(state)        

Score: 0
+------+
|[37m·[0m[32m↓[0m[32m→[0m[32m↓[0m[37m·[0m[37m·[0m|
|[32m⚉[0m[32m→[0m[32m↑[0m[32m↓[0m[37m·[0m[37m·[0m|
|[32m↑[0m[32m↓[0m[32m←[0m[32m↓[0m[37m·[0m[37m·[0m|
|[32m↑[0m[32m↓[0m[32m↑[0m[32m←[0m[37m·[0m[37m·[0m|
|[32m↑[0m[32m←[0m[37m·[0m[37m·[0m[37m·[0m[37m·[0m|
|[37m·[0m[37m·[0m[37m·[0m[37m·[0m[36m●[0m[37m·[0m|
+------+


In [65]:
def info2board(info):
    """Build the printable glyph grid described by `info`.

    The grid has the module-level size (m, n) and holds strings of at most
    16 characters.  It starts out filled with the empty-cell glyph; body
    cells then receive their glyph from 'symbol_seq', and the head and the
    food receive their glyphs from `state2symbol`.

    Args:
        info: dict with the keys 'snake_seq', 'symbol_seq' and 'food'.

    Returns:
        An (m, n) array of glyph strings.
    """
    board = np.empty((m, n), dtype='<U16')
    board.fill(state2symbol[-2])

    # body: one glyph per snake cell
    for (row, col), glyph in zip(info['snake_seq'], info['symbol_seq']):
        board[row, col] = glyph

    # head: overrides whatever the body loop wrote there
    head_row, head_col = info['snake_seq'][-1]
    board[head_row, head_col] = state2symbol[-1]

    # food: written last
    food_row, food_col = info['food']
    board[food_row, food_col] = state2symbol[-3]

    return board

def render_board(board, score=0):
    """Print the score and a framed picture of a glyph grid to stdout.

    Args:
        board: 2-D array whose cells are the strings to print.
        score: number shown on the first line.
    """
    emit = sys.stdout.write
    border = "+" + "-" * board.shape[1] + "+\n"

    emit(f'Score: {score}\n')
    emit(border)
    for cells in board.tolist():
        emit("|" + "".join(cells) + "|\n")
    emit(border)
        
# display board: rebuild the glyph grid from `info` and print it
board = info2board(info)
render_board(board)    

Score: 0
+------+
|[37m·[0m[32m│[0m[32m┌[0m[32m┐[0m[37m·[0m[37m·[0m|
|[32m⚉[0m[32m└[0m[32m┘[0m[32m│[0m[37m·[0m[37m·[0m|
|[32m│[0m[32m┌[0m[32m┐[0m[32m│[0m[37m·[0m[37m·[0m|
|[32m│[0m[32m│[0m[32m└[0m[32m┘[0m[37m·[0m[37m·[0m|
|[32m└[0m[32m┘[0m[37m·[0m[37m·[0m[37m·[0m[37m·[0m|
|[37m·[0m[37m·[0m[37m·[0m[37m·[0m[36m●[0m[37m·[0m|
+------+


In [66]:
def step(state, info, action):
    """Move the snake one cell in the direction `action`.

    `state` and `info` are modified in place and also returned.  The board
    size is taken from the module-level `m` (rows) and `n` (columns).

    Args:
        state: 2-D grid of cell codes (-3 food, -2 empty, -1 head, 0-3 the
            action taken from a body cell), as produced by `info2state`.
        info: dict with the keys 'snake_seq', 'action_seq', 'symbol_seq',
            'food' and 'snake_length'.
        action: 0 up, 1 left, 2 down, 3 right.

    Returns:
        The tuple (state, info).

    Raises:
        SystemExit: when the head leaves the board or runs into the body.
    """

    # get head of snake
    head = info['snake_seq'][-1]
    i, j = head

    # implement action
    if action==0:   i-=1 # up
    elif action==1: j-=1 # left
    elif action==2: i+=1 # down
    elif action==3: j+=1 # right

    # check out of bounds
    if i<0 or j<0 or i==m or j==n:
        sys.exit('You lose!')

    # check collision; the tail is dropped first because its cell is vacated
    # by this very move and may therefore be entered
    # NOTE(review): the tail is dropped on every step, so the sequences never
    # get longer when food is eaten (only 'snake_length' grows) -- confirm
    # whether growth is meant to be handled here.
    tail = info['snake_seq'].pop(0)
    if (i,j) in info['snake_seq']:
        sys.exit('You lose!')

    # update state
    last_action = info['action_seq'][-1]
    # A body cell stores the action taken FROM it (this is what info2state
    # writes), so the old head cell gets the current action, not the
    # previous one.
    state[head[0], head[1]] = action
    state[tail[0], tail[1]] = -2
    state[i, j] = -1

    # update info
    info['symbol_seq'].pop(0)
    info['action_seq'].pop(0)

    # glyph of the old head cell: entered by last_action, left by action
    symbol = actseq2symbol[(last_action, action)]
    info['symbol_seq'].append(symbol)
    info['action_seq'].append(action)
    info['snake_seq'].append((i,j))

    # update food
    if (i,j) == info['food']:
        open_locs = list(zip(*np.where(state==-2)))
        ind = np.random.randint(len(open_locs), size=1)[0]
        food_i, food_j = open_locs[ind]
        info['food'] = (int(food_i), int(food_j))
        # keep the grid in sync with info: mark the new food cell
        state[food_i, food_j] = -3
        info['snake_length'] +=1
    return state, info

# replay a fixed list of actions, printing the board after every step
for a in [0,3,2,3,0,3,3,2,2,2,2,2,2,2,2]:
    state, info = step(state, info, a)
    render_board(info2board(info))

Score: 0
+------+
|[32m⚉[0m[37m·[0m[32m┌[0m[32m┐[0m[37m·[0m[37m·[0m|
|[32m│[0m[32m└[0m[32m┘[0m[32m│[0m[37m·[0m[37m·[0m|
|[32m│[0m[32m┌[0m[32m┐[0m[32m│[0m[37m·[0m[37m·[0m|
|[32m│[0m[32m│[0m[32m└[0m[32m┘[0m[37m·[0m[37m·[0m|
|[32m└[0m[32m┘[0m[37m·[0m[37m·[0m[37m·[0m[37m·[0m|
|[37m·[0m[37m·[0m[37m·[0m[37m·[0m[36m●[0m[37m·[0m|
+------+
Score: 0
+------+
|[32m┌[0m[32m⚉[0m[32m┌[0m[32m┐[0m[37m·[0m[37m·[0m|
|[32m│[0m[37m·[0m[32m┘[0m[32m│[0m[37m·[0m[37m·[0m|
|[32m│[0m[32m┌[0m[32m┐[0m[32m│[0m[37m·[0m[37m·[0m|
|[32m│[0m[32m│[0m[32m└[0m[32m┘[0m[37m·[0m[37m·[0m|
|[32m└[0m[32m┘[0m[37m·[0m[37m·[0m[37m·[0m[37m·[0m|
|[37m·[0m[37m·[0m[37m·[0m[37m·[0m[36m●[0m[37m·[0m|
+------+
Score: 0
+------+
|[32m┌[0m[32m┐[0m[32m┌[0m[32m┐[0m[37m·[0m[37m·[0m|
|[32m│[0m[32m⚉[0m[37m·[0m[32m│[0m[37m·[0m[37m·[0m|
|[32m│[0m[32m┌[0m[32m┐[0m[32m│[0m[37m

SystemExit: You lose!

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [599]:
# print the current cell-code grid
render(state)

Score: 0
+------+
|[32m→[0m[32m↓[0m[32m→[0m[32m→[0m[32m↓[0m |
|[32m↑[0m[32m→[0m[32m↑[0m [32m↓[0m |
|[32m↑[0m[36m★[0m  [32m↓[0m |
|[32m↑[0m   [32m↓[0m |
|[32m↑[0m[32m←[0m  [32m↓[0m |
|    [32m●[0m |
+------+


In [154]:
from random import randint
from sys import exit
from msvcrt import getch
from copy import deepcopy
import numpy as np
import string
import os


class game:
    """Skeleton of a board game: an integer grid plus a running score."""
    #------------------
    # attributes
    #------------------
    def __init__(self, w, h):
        """Create an all-zero board with `h` rows and `w` columns.

        Args:
            w: number of columns.
            h: number of rows.
        """
        # `np.int` was only an alias of the builtin `int` and has been
        # removed from NumPy; using `int` yields the same dtype.
        self.board = np.zeros((h,w), dtype=int)  # board
        self.score = 0


    #------------------
    # implement move
    def move_blocks(self):
        """Unfinished: only defines the `get_move` helper and returns None."""
        def get_move():
            """Block until an arrow key or ESC is read through `getch`.

            Returns:
                'down', 'up', 'right' or 'left' for the matching key code,
                or None once key code 27 (ESC) is read.
            """
            # second key code that follows the 224 prefix -> direction
            arrows = {80: 'down', 72: 'up', 77: 'right', 75: 'left'}
            while True:
                key = ord(getch())
                if key == 27:
                    break
                if key == 224:
                    move = arrows.get(ord(getch()))
                    if move is not None:
                        return move
                    # any other code after the prefix: keep waiting

In [None]:



# NOTE(review): scratch copy of the body of TaxiEnv.render, pasted outside
# any class or function.  `self`, `outfile` and `mode` are not defined here
# and the final `if` has no body, so this cell cannot run on its own.

# copy the map and turn each byte into a one-character string
out = self.desc.copy().tolist()
out = [[c.decode('utf-8') for c in line] for line in out]
taxirow, taxicol, passidx, destidx = self.decode(self.s)
# an empty taxi cell is drawn as "_"
def ul(x): return "_" if x == " " else x
if passidx < 4:
    # passenger is waiting: highlight the taxi in yellow, the passenger in blue
    out[1+taxirow][2*taxicol+1] = utils.colorize(out[1+taxirow][2*taxicol+1], 'yellow', highlight=True)
    pi, pj = self.locs[passidx]
    out[1+pi][2*pj+1] = utils.colorize(out[1+pi][2*pj+1], 'blue', bold=True)
else: # passenger in taxi
    out[1+taxirow][2*taxicol+1] = utils.colorize(ul(out[1+taxirow][2*taxicol+1]), 'green', highlight=True)

# colour the destination magenta
di, dj = self.locs[destidx]
out[1+di][2*dj+1] = utils.colorize(out[1+di][2*dj+1], 'magenta')
outfile.write("\n".join(["".join(row) for row in out])+"\n")
# print the name of the last action, or an empty line if there is none
if self.lastaction is not None:
    outfile.write("  ({})\n".format(["South", "North", "East", "West", "Pickup", "Dropoff"][self.lastaction]))
else: outfile.write("\n")

# No need to return anything for human
if mode != 'human':

NameError: name 's' is not defined

In [12]:
import sys
from six import StringIO
from gym import utils
from gym.envs.toy_text import discrete
import numpy as np

# ASCII layout of the 5x5 taxi grid.  Grid cell (row, col) sits at character
# [1 + row][2 * col + 1]; the characters between cells are separators.
# TaxiEnv lets the taxi move east/west only across a ':' separator, so any
# other separator (such as '|') blocks the move.  The letters R, G, Y and B
# sit on the four cells listed in TaxiEnv.locs.
MAP = [
    "+---------+",
    "|R: | : :G|",
    "| : : : : |",
    "| : : : : |",
    "| | : | : |",
    "|Y| : |B: |",
    "+---------+",
]

class TaxiEnv(discrete.DiscreteEnv):
    """
    The Taxi Problem, from "Hierarchical Reinforcement Learning with the
    MAXQ Value Function Decomposition" by Tom Dietterich.

    A state combines the taxi row, the taxi column, the passenger index
    (0-3: waiting at that location, 4: inside the taxi) and the destination
    index.  The six actions are South, North, East, West, Pickup, Dropoff.

    Colours used when rendering:
    - blue: passenger
    - magenta: destination
    - yellow: empty taxi
    - green: full taxi
    - other letters: locations
    """
    metadata = {'render.modes': ['human', 'ansi']}

    def __init__(self):
        """Build the full transition table and the start-state distribution."""
        self.desc = np.asarray(MAP,dtype='c')

        # (row, col) of the four pick-up / drop-off locations
        self.locs = locs = [(0,0), (0,4), (4,0), (4,3)]

        num_states = 500
        num_rows = 5
        num_cols = 5
        last_row = num_rows - 1
        last_col = num_cols - 1
        num_actions = 6
        start_weights = np.zeros(num_states)
        transitions = {
            s: {a: [] for a in range(num_actions)} for s in range(num_states)
        }

        # every (row, col, passenger, destination) combination, row-major
        combos = (
            (row, col, passidx, destidx)
            for row in range(5)
            for col in range(5)
            for passidx in range(5)
            for destidx in range(4)
        )
        for row, col, passidx, destidx in combos:
            state = self.encode(row, col, passidx, destidx)
            # valid starts: passenger waiting somewhere other than the destination
            if passidx < 4 and passidx != destidx:
                start_weights[state] += 1
            taxiloc = (row, col)

            for a in range(num_actions):
                # defaults: nothing moves, one step costs -1
                newrow, newcol, newpassidx = row, col, passidx
                reward = -1
                done = False

                if a == 0:      # south
                    newrow = min(row + 1, last_row)
                elif a == 1:    # north
                    newrow = max(row - 1, 0)
                elif a == 2:    # east, only across a ':' separator
                    if self.desc[1 + row, 2 * col + 2] == b":":
                        newcol = min(col + 1, last_col)
                elif a == 3:    # west, only across a ':' separator
                    if self.desc[1 + row, 2 * col] == b":":
                        newcol = max(col - 1, 0)
                elif a == 4:    # pickup
                    if passidx < 4 and taxiloc == locs[passidx]:
                        newpassidx = 4
                    else:
                        reward = -10
                elif a == 5:    # dropoff
                    if passidx != 4:
                        # nobody in the taxi
                        reward = -10
                    elif taxiloc == locs[destidx]:
                        done = True
                        reward = 20
                    elif taxiloc in locs:
                        # dropped at one of the other locations
                        newpassidx = locs.index(taxiloc)
                    else:
                        reward = -10

                newstate = self.encode(newrow, newcol, newpassidx, destidx)
                transitions[state][a].append((1.0, newstate, reward, done))

        start_weights /= start_weights.sum()
        discrete.DiscreteEnv.__init__(
            self, num_states, num_actions, transitions, start_weights
        )

    def encode(self, taxirow, taxicol, passloc, destidx):
        """Pack the four state components into a single integer."""
        # mixed radix: (5) 5, 5, 4
        return ((taxirow * 5 + taxicol) * 5 + passloc) * 4 + destidx

    def decode(self, i):
        """Unpack a state integer.

        Returns an iterator that yields taxirow, taxicol, passloc, destidx.
        """
        i, destidx = divmod(i, 4)
        i, passloc = divmod(i, 5)
        taxirow, taxicol = divmod(i, 5)
        assert 0 <= taxirow < 5
        return reversed([destidx, passloc, taxicol, taxirow])

    def render(self, mode='human'):
        """Draw the current state.

        Writes to stdout, or to a fresh StringIO when `mode` is 'ansi'.
        Returns the output object for every mode other than 'human'.
        """
        if mode == 'ansi':
            outfile = StringIO()
        else:
            outfile = sys.stdout

        # map as a grid of one-character strings
        grid = [
            [ch.decode('utf-8') for ch in line]
            for line in self.desc.copy().tolist()
        ]
        taxirow, taxicol, passidx, destidx = self.decode(self.s)
        taxi_r = 1 + taxirow
        taxi_c = 2 * taxicol + 1

        if passidx < 4:
            # empty taxi in yellow, waiting passenger in blue
            grid[taxi_r][taxi_c] = utils.colorize(
                grid[taxi_r][taxi_c], 'yellow', highlight=True)
            pi, pj = self.locs[passidx]
            grid[1 + pi][2 * pj + 1] = utils.colorize(
                grid[1 + pi][2 * pj + 1], 'blue', bold=True)
        else:
            # passenger in taxi: green, with a blank cell drawn as "_"
            cell = grid[taxi_r][taxi_c]
            if cell == " ":
                cell = "_"
            grid[taxi_r][taxi_c] = utils.colorize(cell, 'green', highlight=True)

        di, dj = self.locs[destidx]
        grid[1 + di][2 * dj + 1] = utils.colorize(
            grid[1 + di][2 * dj + 1], 'magenta')
        outfile.write("\n".join("".join(cells) for cells in grid) + "\n")

        if self.lastaction is None:
            outfile.write("\n")
        else:
            action_names = ["South", "North", "East", "West", "Pickup", "Dropoff"]
            outfile.write("  ({})\n".format(action_names[self.lastaction]))

        # No need to return anything for human
        if mode != 'human':
            return outfile

gym.wrappers.time_limit.TimeLimit

In [10]:
# NOTE(review): `env` is not created in any visible cell -- presumably an
# environment instance made elsewhere; confirm before running this cell.
env.close()