In [282]:
# given state s, input i, get next state of game
# given state s, inputs i, get state after len(i) turns of game
# generate start state
# visualize state

In [283]:
# IMPORTS #

import json
import random
from IPython.display import clear_output
import time
import sys
import copy

In [284]:
# VARIABLE DECLARATIONS #

# directions: 0 = up, 1 = right, 2 = down, 3 = left
# gameState: 1 = running, 0 = game over
gridSize = 10 # 10x10 grid
startingBodyLength = 3

In [285]:
#generate new game state
def newGame():
    state = {
        "gameState": 1,
        "justAteFood": False,
        "turn": 0,
        "head": [round(random.uniform(startingBodyLength-1, gridSize-1-startingBodyLength)),
                 round(random.uniform(startingBodyLength-1, gridSize-1-startingBodyLength))],
        "direction": round(random.uniform(0, 3)),
    }
    body = [state["head"]]
    bx, by = 0, 0
    if state["direction"]==0:
        by = 1
    elif state["direction"]==1:
        bx = -1
    elif state["direction"]==2:
        by = -1
    elif state["direction"]==3:
        bx = 1
    for i in range(1, startingBodyLength):
        prevBody = body[i-1]
        body.append([prevBody[0]+bx, prevBody[1]+by])
    state["body"] = body
    if state["body"][0]!=state["head"]:
        print("HEAD IS NOT ON BODY", state)
        sys.exit(0)
    
    foundGoodFoodStart = False
    while not foundGoodFoodStart:
        fx = round(random.uniform(startingBodyLength-1, gridSize-1-startingBodyLength))
        fy = round(random.uniform(startingBodyLength-1, gridSize-1-startingBodyLength))
        badFood = False
        for i in range(0, startingBodyLength):
            if state["body"][i][0] == fx and state["body"][i][1] == fy:
                badFood = True
        if not badFood:
            state["food"] = [fx, fy]
            foundGoodFoodStart = True
    return state

In [286]:
#advance one turn in game, given state and input
#inp=-1: left, inp=0: forward, inp=1: right
def step(state, inp):
    state["turn"] += 1
    state["direction"] = (state["direction"]+inp+4)%4 #set new direction
    newHeadX, newHeadY = 0, 0
    if state["direction"]==0:
        newHeadY = -1
    elif state["direction"]==1:
        newHeadX = 1
    elif state["direction"]==2:
        newHeadY = 1
    elif state["direction"]==3:
        newHeadX = -1
    newHeadX += state["head"][0]
    newHeadY += state["head"][1]
    state["head"] = [newHeadX, newHeadY] #set new head
    state["body"].insert(0, state["head"]) #add new head to body
    if state["head"][0]==state["food"][0] and state["head"][1]==state["food"][1]: #ate food
        state["justAteFood"]= True
        if len(state["body"])!=gridSize*gridSize: #snake does not fill entire board yet
            foundGoodFoodStart = False
            while not foundGoodFoodStart: #create new food
                fx = round(random.uniform(startingBodyLength-1, gridSize-1-startingBodyLength))
                fy = round(random.uniform(startingBodyLength-1, gridSize-1-startingBodyLength))
                badFood = False
                for i in range(0, startingBodyLength):
                    if state["body"][i][0] == fx and state["body"][i][1] == fy:
                        badFood = True
                if not badFood:
                    state["food"] = [fx, fy]
                    foundGoodFoodStart = True
        else: #snake fills entire board
            state["gameState"] = 0
    else: #did not eat food
        state["justAteFood"]= False
        bodyLen = len(state["body"])
        state["body"].pop(bodyLen-1) #get rid of last tail
    
    #check if game over
    hx = state["head"][0]
    hy = state["head"][1]
    if hx<0 or hx >= gridSize or hy<0 or hy >= gridSize: #hit a wall
        state["gameState"] = 0
    else: #check if ran into self
        for i in range(1, len(state["body"])):
            if state["body"][i][0]==hx and state["body"][i][1]==hy:
                state["gameState"] = 0
    
    return state

In [287]:
def steps(state, inps):
    si = copy.deepcopy(state)
#     si = state.copy()
    for i in range(0, len(inps)):
        si = step(si, inps[i])
        if si["gameState"]==0:
            return si
    return si

In [288]:
def visualizeState(state):
    grid = [["_" for x in range(gridSize)] for y in range(gridSize)] 
    grid[state["head"][0]][state["head"][1]] = "#"
    for i in range(1, len(state["body"])):
        try:
            grid[state["body"][i][0]][state["body"][i][1]] = "*"
        except:
            print(state)
            print("=========ERRRORRR==========")
            sys.exit(0)
            
    grid[state["food"][0]][state["food"][1]] = "@"
    
    for i in range(0, len(grid)):
        for j in range(0, len(grid[i])):
            print(grid[i][j], end ="")
        print()

In [289]:
def visualizeSteps(state, inps, speed):
    si = copy.deepcopy(state)
#     si = state.copy()
    for i in range(0, len(inps)):
        si = step(si, inps[i])
        if si["gameState"]==0:
            return
        clear_output(wait=True)
        visualizeState(si)
        time.sleep(speed)
    return

In [290]:
def reduce(state):
    res = 0
    res += state["head"][0]
    res += state["head"][1]*10
    res += state["food"][0]*100
    res += state["food"][1]*1000
    return res

In [291]:
def isIn(reduState, q_table):
    return q_table.get(reduState, -1)
#     for i in range(0, len(q_table)):
#         if q_table[i][0]==reduState:
#             return i
#     return -1

In [309]:
alpha = 0.7
gamma = 1
epsilon = 0.1
stepReward = -0.5
foodReward = 50.0
deathReward = -5.0

numEps = 100000
q_table = {}

In [316]:
# epsilon = 0.01
# numEps = 1

inps = []
ssEnd = {}
# q_table = [[reduce(startState.copy()), 0, 0, 0]] #[reduced State, q_left, q_forward, q_right]

oldbars = 0

totalRewards = 0
barNumbers = 50

for g in range(0, numEps):
    startState = newGame()
    si = copy.deepcopy(startState)

    #==== progress printing =====
    if g==numEps-1:
        ssEnd = copy.deepcopy(startState)
    bars = round(barNumbers*g/numEps)
    if bars>oldbars:
        oldbars = bars
        clear_output(wait=True)
        print("["+("="*bars)+(" "*(barNumbers-bars))+"]")
    #==== end progress printing =====
    
    while si["gameState"]!=0:
        rsi = reduce(si)
        if isIn(rsi, q_table) == -1:
            q_table[rsi] = [0, 0, 0]
        rowSi = isIn(rsi, q_table)
        
        #find q-values of all actions from current state
        qLeft = rowSi[0]
        qForward = rowSi[1]
        qRight = rowSi[2]

        #find max q-value
        at = 0
        if random.uniform(0,1)<epsilon:
            at = round(random.uniform(-1, 1))
        else:
            if qLeft>=qForward and qLeft>=qRight:
                at = -1
            elif qForward>=qLeft and qForward>=qRight:
                at = 0
            elif qRight>=qLeft and qRight>=qForward:
                at = 1
        
        if g==numEps-1:
            inps.append(at)
        
        #find next state with max q-value action
        siPlus1 = step(si, at)
        sip1hx = siPlus1["head"][0]
        sip1hy = siPlus1["head"][1]
        maxFutureQ = 0
        if 0<=sip1hx<gridSize and 0<=sip1hy<gridSize:
            redu_siPlus1 = reduce(siPlus1)
            row_redu_siPlus1 = isIn(redu_siPlus1, q_table)
            if row_redu_siPlus1!=-1:
                qsil = row_redu_siPlus1[0]
                qsif = row_redu_siPlus1[1]
                qsir = row_redu_siPlus1[2]
                if qsil>=qsif and qsil>=qsir:
                    maxFutureQ = qsil
                elif qsif>=qsil and qsif>=qsir:
                    maxFutureQ = qsif
                elif qsir>=qsil and qsir>=qsif:
                    maxFutureQ = qsir
            else:
                q_table[redu_siPlus1] = [0, 0, 0]
                
        rt = 0
        if siPlus1["gameState"]==0:
            rt = deathReward
        elif siPlus1["justAteFood"]:
            rt = foodReward
        else:
            rt = stepReward
        if g==numEps-1:
            totalRewards += rt
        q_table[rsi][at+1] = (1-alpha)*q_table[rsi][at+1]+alpha*(rt + gamma*maxFutureQ)
        si = siPlus1
visualizeSteps(ssEnd, inps, 0.25)
print(totalRewards)

__________
__________
__________
__________
____@_____
__________
_____*#___
_____**___
__________
__________
88.0


In [315]:
print(len(q_table))

2475
