In [131]:
# IMPORTS #

import json
import random
from IPython.display import clear_output
import time
import sys
import copy
import math
import numpy as np

In [132]:
# VARIABLE DECLARATIONS #

# Conventions used throughout the notebook:
#   direction: 0 = up (y - 1), 1 = right (x + 1), 2 = down (y + 1), 3 = left (x - 1)
#   gameState: 1 = running, 0 = game over
gridSize, startingBodyLength = 10, 3  # board is gridSize x gridSize; snake starts this many cells long

In [133]:
def newGame():
    """Create a fresh game-state dict: a straight snake plus one food cell.

    The head is drawn uniformly from [startingBodyLength - 1,
    gridSize - 1 - startingBodyLength] on each axis. That keeps the trailing
    body (laid out opposite to the heading) on the grid. The food is drawn
    from the same range, rejecting cells that overlap the body.

    Returns:
        dict with keys "gameState", "justAteFood", "turn", "head",
        "direction", "body" (list of [x, y], body[0] is the head object)
        and "food" ([x, y]).
    """
    lo = startingBodyLength - 1
    hi = gridSize - 1 - startingBodyLength
    # randint is uniform. round(random.uniform(a, b)) gave the two end values
    # only half the probability of the interior ones (e.g. directions 0 and 3).
    head = [random.randint(lo, hi), random.randint(lo, hi)]
    direction = random.randint(0, 3)
    state = {
        "gameState": 1,
        "justAteFood": False,
        "turn": 0,
        "head": head,
        "direction": direction,
    }

    # Offset from each segment to the next one: opposite to the heading.
    bx, by = {0: (0, 1), 1: (-1, 0), 2: (0, -1), 3: (1, 0)}[direction]
    body = [head]  # body[0] is the same list object as state["head"]
    for _ in range(1, startingBodyLength):
        px, py = body[-1]
        body.append([px + bx, py + by])
    state["body"] = body

    # Rejection-sample a food cell that is not on the snake.
    while True:
        food = [random.randint(lo, hi), random.randint(lo, hi)]
        if food not in body:
            state["food"] = food
            return state

In [134]:
def _randomFreeCell(body):
    """Return a uniformly random [x, y] grid cell not covered by ``body``.

    Cells are enumerated in index order x + gridSize * y, so a given RNG
    draw maps to the same cell as the previous negate-and-delete scheme.
    """
    occupied = {x + gridSize * y for x, y in body}
    free = [i for i in range(gridSize * gridSize) if i not in occupied]
    cell = free[random.randint(0, len(free) - 1)]
    return [cell % gridSize, cell // gridSize]


def step(state, inp):
    """Advance the game one turn. Mutates ``state`` in place and returns it.

    Args:
        state: game-state dict as built by newGame().
        inp: relative turn: -1 = left, 0 = forward, 1 = right.

    Returns:
        The same ``state`` object, updated. If the head lands on the food,
        the snake grows and new food is placed on a free cell (or the game
        ends if the board is full). Otherwise the tail is dropped. The game
        ends (gameState = 0) when the head leaves the grid or hits the body.
    """
    state["turn"] += 1
    state["direction"] = (state["direction"] + inp + 4) % 4
    # Movement per direction: 0 up, 1 right, 2 down, 3 left.
    dx, dy = ((0, -1), (1, 0), (0, 1), (-1, 0))[state["direction"]]
    head = [state["head"][0] + dx, state["head"][1] + dy]
    state["head"] = head
    state["body"].insert(0, head)

    food = state["food"]
    if head[0] == food[0] and head[1] == food[1]:
        state["justAteFood"] = True
        if len(state["body"]) != gridSize * gridSize:
            state["food"] = _randomFreeCell(state["body"])
        else:  # snake fills the entire board
            state["gameState"] = 0
    else:
        state["justAteFood"] = False
        state["body"].pop()  # drop the tail so the length stays constant

    hx, hy = head
    if not (0 <= hx < gridSize and 0 <= hy < gridSize):  # hit a wall
        state["gameState"] = 0
    elif head in state["body"][1:]:  # ran into itself
        state["gameState"] = 0

    return state

In [135]:
def steps(state, inps):
    """Apply each relative turn in ``inps`` to a deep copy of ``state``.

    Stops early as soon as a step ends the game. The caller's ``state`` is
    never modified. Returns the resulting state.
    """
    current = copy.deepcopy(state)
    for action in inps:
        current = step(current, action)
        if current["gameState"] == 0:
            break
    return current

In [136]:
def visualizeState(state):
    """Print the board as text, one line per y value (top line is y = 0).

    Symbols: '#' head, '*' body, '@' food, '_' empty. The grid is indexed
    grid[y][x] (row = y, column = x). step() moves "up" (direction 0) by
    decreasing y, so that direction now also points up on screen; indexing
    grid[x][y] would print the board transposed.

    Raises:
        ValueError: if the head, a body segment or the food lies off the grid.
    """
    grid = [["_"] * gridSize for _ in range(gridSize)]

    def place(cell, symbol):
        x, y = cell[0], cell[1]
        # Explicit bounds check: negative indices would otherwise wrap silently.
        if not (0 <= x < gridSize and 0 <= y < gridSize):
            raise ValueError("cell {} is off the {}x{} grid in state {}".format(cell, gridSize, gridSize, state))
        grid[y][x] = symbol

    place(state["head"], "#")
    for segment in state["body"][1:]:
        place(segment, "*")
    place(state["food"], "@")

    for row in grid:
        print("".join(row))

In [137]:
def visualizeSteps(state, inps, speed):
    """Animate ``inps`` applied to a deep copy of ``state``.

    After each turn the output is cleared, the board is redrawn, and the
    function sleeps ``speed`` seconds. It stops without drawing at the first
    turn that ends the game.
    """
    frame = copy.deepcopy(state)
    for move in inps:
        frame = step(frame, move)
        if frame["gameState"] == 0:
            break
        clear_output(wait=True)
        visualizeState(frame)
        time.sleep(speed)

In [138]:
# def reduce(state):
#     res = 0
#     res += state["head"][0]
#     res += state["head"][1]*10
#     res += state["food"][0]*100
#     res += state["food"][1]*1000
#     return res

In [139]:
def isIn(reduState, q_table):
    """Return the Q-value row stored under ``reduState``, or -1 if there is none."""
    if reduState in q_table:
        return q_table[reduState]
    return -1

In [140]:
# Module-level learning settings. sim() takes its own alpha/gamma/epsilon/reward
# arguments (with different defaults), so these only affect code that reads
# the globals directly.
alpha, gamma, epsilon = 0.7, 1, 0.1
stepReward, foodReward, deathReward = -0.5, 50.0, -5.0

numEps = 1000
# Global Q-table: state key -> [q_left, q_forward, q_right].
q_table = {}

In [141]:
# # epsilon = 0.01
# # numEps = 1

# inps = []
# ssEnd = {}
# # q_table = [[reduce(startState.copy()), 0, 0, 0]] #[reduced State, q_left, q_forward, q_right]

# oldbars = 0

# totalRewards = 0
# barNumbers = 50

# for g in range(0, numEps):
#     startState = newGame()
#     si = copy.deepcopy(startState)

#     #==== progress printing =====
#     if g==numEps-1:
#         ssEnd = copy.deepcopy(startState)
#     bars = round(barNumbers*g/numEps)
#     if bars>oldbars:
#         oldbars = bars
#         clear_output(wait=True)
#         print("["+("="*bars)+(" "*(barNumbers-bars))+"]")
#     #==== end progress printing =====
    
#     while si["gameState"]!=0:
#         rsi = reduce(si)
#         if isIn(rsi, q_table) == -1:
#             q_table[rsi] = [0, 0, 0]
#         rowSi = isIn(rsi, q_table)
        
#         #find q-values of all actions from current state
#         qLeft = rowSi[0]
#         qForward = rowSi[1]
#         qRight = rowSi[2]

#         #find max q-value
#         at = 0
#         if random.uniform(0,1)<epsilon:
#             at = round(random.uniform(-1, 1))
#         else:
#             if qLeft>=qForward and qLeft>=qRight:
#                 at = -1
#             elif qForward>=qLeft and qForward>=qRight:
#                 at = 0
#             elif qRight>=qLeft and qRight>=qForward:
#                 at = 1
        
#         if g==numEps-1:
#             inps.append(at)
        
#         #find next state with max q-value action
#         siPlus1 = step(si, at)
#         sip1hx = siPlus1["head"][0]
#         sip1hy = siPlus1["head"][1]
#         maxFutureQ = 0
#         if 0<=sip1hx<gridSize and 0<=sip1hy<gridSize:
#             redu_siPlus1 = reduce(siPlus1)
#             row_redu_siPlus1 = isIn(redu_siPlus1, q_table)
#             if row_redu_siPlus1!=-1:
#                 qsil = row_redu_siPlus1[0]
#                 qsif = row_redu_siPlus1[1]
#                 qsir = row_redu_siPlus1[2]
#                 if qsil>=qsif and qsil>=qsir:
#                     maxFutureQ = qsil
#                 elif qsif>=qsil and qsif>=qsir:
#                     maxFutureQ = qsif
#                 elif qsir>=qsil and qsir>=qsif:
#                     maxFutureQ = qsir
#             else:
#                 q_table[redu_siPlus1] = [0, 0, 0]
                
#         rt = 0
#         if siPlus1["gameState"]==0:
#             rt = deathReward
#         elif siPlus1["justAteFood"]:
#             rt = foodReward
#         else:
#             rt = stepReward
#         if g==numEps-1:
#             totalRewards += rt
#         q_table[rsi][at+1] = (1-alpha)*q_table[rsi][at+1]+alpha*(rt + gamma*maxFutureQ)
#         si = siPlus1
# visualizeSteps(ssEnd, inps, 0.25)
# print(totalRewards)

In [142]:
# print(len(q_table))

In [143]:
def visualizeStates(state, inps, foods, speed):
    """Replay ``inps`` from a deep copy of ``state``, forcing recorded food positions.

    step() places new food at random, so after every eaten food the next
    entry of ``foods`` overrides it. Indexing starts at foods[1]; foods[0] is
    presumably the food already present in ``state``. Frames are drawn
    ``speed`` seconds apart. Replay stops, without drawing, at the first
    turn that ends the game.
    """
    replay = copy.deepcopy(state)
    nextFood = 1
    for action in inps:
        replay = step(replay, action)
        if replay["gameState"] == 0:
            break
        if replay["justAteFood"]:
            recorded = foods[nextFood]
            replay["food"] = [recorded[0], recorded[1]]
            nextFood += 1
        clear_output(wait=True)
        visualizeState(replay)
        time.sleep(speed)

Neural Nets

In [144]:
def isClearStraightAhead(state, hx, hy, d):
    """Return 1 if the cell one step ahead of (hx, hy) along heading ``d`` is free, else 0.

    "Free" means on the grid and not occupied by any body segment. The whole
    body is checked, including the tail cell that step() vacates on a
    non-eating move, so the test is conservative.
    """
    dx, dy = {0: (0, -1), 1: (1, 0), 2: (0, 1), 3: (-1, 0)}.get(d, (0, 0))
    aheadX, aheadY = hx + dx, hy + dy
    onGrid = 0 <= aheadX < gridSize and 0 <= aheadY < gridSize
    if not onGrid:
        return 0  # about to head into a wall
    blocked = any(seg[0] == aheadX and seg[1] == aheadY for seg in state["body"])
    return 0 if blocked else 1

In [145]:
def isClearLeft(state, hx, hy, d):
    """Return 1 if the cell to the snake's left (relative to heading ``d``) is free, else 0.

    "Free" means on the grid and not occupied by any body segment. The whole
    body is checked, including the tail cell that step() vacates on a
    non-eating move.
    """
    # Left of "up" is -x, left of "right" is -y, and so on.
    offsets = {0: (-1, 0), 1: (0, -1), 2: (1, 0), 3: (0, 1)}
    dx, dy = offsets.get(d, (0, 0))
    leftX = hx + dx
    leftY = hy + dy
    if leftX < 0 or leftY < 0 or leftX >= gridSize or leftY >= gridSize:
        return 0  # wall is to the left
    for segment in state["body"]:
        if segment[0] == leftX and segment[1] == leftY:
            return 0
    return 1

In [146]:
def isClearRight(state, hx, hy, d):
    """Return 1 if the cell to the snake's right (relative to heading ``d``) is free, else 0.

    "Free" means on the grid and not occupied by any body segment. The whole
    body is checked, including the tail cell that step() vacates on a
    non-eating move.
    """
    # Right of "up" is +x, right of "right" is +y, and so on.
    rightOffset = {0: (1, 0), 1: (0, 1), 2: (-1, 0), 3: (0, -1)}.get(d, (0, 0))
    target = (hx + rightOffset[0], hy + rightOffset[1])
    insideGrid = all(0 <= coord < gridSize for coord in target)
    if not insideGrid:
        return 0  # wall is to the right
    hitsBody = any(seg[0] == target[0] and seg[1] == target[1] for seg in state["body"])
    return int(not hitsBody)

In [147]:
def getFoodRelatively(hx, hy, d, fx, fy):
    """Classify where the food (fx, fy) lies relative to the head's heading ``d``.

    Returns a [straight, left, right] flag list:
        [1, 0, 0]  food within ~45 degrees of straight ahead
        [0, 1, 0]  food to the left
        [0, 0, 1]  food to the right
        [0, 0, 0]  food behind, or on the head cell itself

    The food-on-head case occurs when the board is full and no new food is
    placed. A zero offset used to divide by zero and fall through to an
    implicit ``None``, which crashed state2Tuple().
    """
    # (forward unit vector, left unit vector) per heading 0 up, 1 right, 2 down, 3 left
    frames = {
        0: ((0, -1), (-1, 0)),
        1: ((1, 0), (0, -1)),
        2: ((0, 1), (1, 0)),
        3: ((-1, 0), (0, 1)),
    }
    vdiff = np.array([fx - hx, fy - hy])
    dist = np.linalg.norm(vdiff)
    if dist == 0:
        return [0, 0, 0]
    vdir, vleft = (np.array(v) for v in frames[d])

    # Both frame vectors have unit length, so cos = dot / |vdiff|.
    # 0.707 is slightly below sqrt(2)/2, so exact diagonals count as
    # ahead/behind and every non-zero offset matches one of the branches.
    cosang = np.dot(vdir, vdiff) / dist
    if cosang > 0.707:
        return [1, 0, 0]
    if cosang < -0.707:
        return [0, 0, 0]

    cosang = np.dot(vleft, vdiff) / dist
    if cosang > 0.707:
        return [0, 1, 0]
    if cosang < -0.707:
        return [0, 0, 1]
    return [0, 0, 0]  # defensive; unreachable for integer offsets

In [148]:
def state2Tuple(state):
    """Encode ``state`` as the 6-character string key used by q_table.

    Characters, in order: clear straight ahead, clear left, clear right
    (each "1"/"0"), then the three food-direction flags from
    getFoodRelatively(). Despite the name, the result is a str.
    """
    head, heading = state["head"], state["direction"]
    hx, hy = head[0], head[1]
    clearance = (
        isClearStraightAhead(state, hx, hy, heading),
        isClearLeft(state, hx, hy, heading),
        isClearRight(state, hx, hy, heading),
    )
    food = getFoodRelatively(hx, hy, heading, state["food"][0], state["food"][1])
    return "".join(map(str, clearance + (food[0], food[1], food[2])))

In [149]:
def setupQTable():
    """(Re)initialise the global q_table with a fresh zero row for every state key.

    Keys are the 3 clearance bits (straight, left, right) followed by one of
    the four food-direction suffixes that getFoodRelatively() can produce.
    Each value is a new [q_left, q_forward, q_right] list, so rows never
    share storage.
    """
    foodSuffixes = ("001", "010", "100", "000")
    for pattern in range(8):
        clearance = format(pattern, "03b")  # csa, cl, cr with csa as the high bit
        for suffix in foodSuffixes:
            q_table[clearance + suffix] = [0, 0, 0]

In [150]:
def helpQTable(helpAmount=50):
    """Seed the global q_table with hand-written action hints before training.

    Expects every key created by setupQTable() to exist. For each clearance
    pattern (straight, left, right; 1 = clear):

    * food unobstructed: if the neighbour toward the food is clear, that
      action's Q-value is set to ``helpAmount``;
    * only one choice: if exactly one neighbour is clear, that action gets
      ``helpAmount`` for every food direction. This includes food behind
      ("000"), which the previous hand-written list omitted.

    Args:
        helpAmount: Q-value written for each hinted action (default 50).
    """
    # Row columns: 0 = left, 1 = forward, 2 = right.
    for csa in (0, 1):
        for cl in (0, 1):
            for cr in (0, 1):
                clearance = str(csa) + str(cl) + str(cr)
                # when food is unobstructed, head toward it
                if csa:
                    q_table[clearance + "100"][1] = helpAmount
                if cr:
                    q_table[clearance + "001"][2] = helpAmount
                if cl:
                    q_table[clearance + "010"][0] = helpAmount
                # when only one move avoids a collision, take it
                if csa + cl + cr == 1:
                    onlyAction = 0 if cl else (1 if csa else 2)
                    for food in ("100", "010", "001", "000"):
                        q_table[clearance + food][onlyAction] = helpAmount

In [151]:
def readQTable(filename):
    """Merge Q-table entries from a text file into the global q_table.

    The file must contain a single Python dict literal of
    ``state key -> [q_left, q_forward, q_right]``, such as the text
    ``str(q_table)`` produces. It is parsed with ``ast.literal_eval``, which
    accepts only literals, so a corrupted or malicious file cannot execute
    code the way ``eval`` allowed.

    Raises:
        OSError: if the file cannot be opened.
        ValueError / SyntaxError: if the contents are not a valid literal.
    """
    import ast

    with open(filename) as fileobj:
        qt = ast.literal_eval(fileobj.read())
    for key in qt:
        q_table[key] = qt[key]

#### *alpha needs to be small* ($\alpha$ = 0.2)
#### *gamma also needs to be small* ($\gamma$ = 0.3)
![alt-text](http://www.thelogomix.com/files/imagecache/v3-logo-detail/snake-01.png "a logo")

In [259]:
# Q-learning update applied in runSim():
#   q_table[rsi][at+1] = (1-alpha)*q_table[rsi][at+1] + alpha*(rt + gamma*maxFutureQ)
def sim(alpha=0.2, gamma=0.8, minEpsilon=0.005, maxEpsilon=0.01, numEps=1000, foodReward=10, stepReward=-2, deathReward=-10, printStats=False, showLastSim=True):
    """Reset and seed the global q_table, then train it with runSim().

    Args:
        alpha: learning rate.
        gamma: discount factor.
        minEpsilon, maxEpsilon: exploration rate decays linearly from
            maxEpsilon (first episode) toward minEpsilon.
        numEps: number of training episodes.
        foodReward, stepReward, deathReward: per-turn rewards.
        printStats: if true, collect per-bucket stats and return them.
        showLastSim: if true, replay the final episode at the end.

    Returns:
        The stats dict when ``printStats`` is true (10 equal-sized buckets
        of episodes with "totalTurnsLived" and "totalFoodEaten"), else None.
    """
    # Clear the *global* table in place. A plain `q_table = {}` here would
    # only bind a local name, while setupQTable/helpQTable/runSim all use the
    # module-level table.
    q_table.clear()
    setupQTable()
    helpQTable()

    buckets = 10
    stats = {
        "buckets": buckets,
        "totalEps": 0,
        "lastBucket": -1,
        "bucketSize": numEps / buckets,
        "totalTurnsLived": [0] * buckets,
        "totalFoodEaten": [0] * buckets,
    }
    return runSim(alpha, gamma, minEpsilon, maxEpsilon, numEps, foodReward, stepReward, deathReward, printStats, showLastSim, stats)

In [260]:
def _qRow(key):
    """Return the global q_table row for ``key``; raise KeyError if it is missing."""
    if key not in q_table:
        raise KeyError("state {!r} is not in q_table; was setupQTable() called?".format(key))
    return q_table[key]


def _chooseAction(row, epsilon):
    """Epsilon-greedy relative action (-1 left, 0 forward, 1 right) from a [l, f, r] row.

    Greedy ties go to the lowest index (left, then forward).
    """
    if random.uniform(0, 1) < epsilon:
        # randint is uniform over {-1, 0, 1}; round(uniform(-1, 1)) picked 0 half the time.
        return random.randint(-1, 1)
    return row.index(max(row)) - 1


def _bootstrapValue(nextState):
    """max_a Q(s', a) for the successor state, or 0 when s' is terminal."""
    # Terminal states (wall, self-collision, or a full board) have no future value.
    if nextState["gameState"] == 0:
        return 0
    return max(_qRow(state2Tuple(nextState)))


def _reward(nextState, foodReward, stepReward, deathReward):
    """Reward for the transition into ``nextState``."""
    # Eating is checked first: the only game-ending turn on which food is
    # eaten is the one that fills the board, which is a win, not a death.
    if nextState["justAteFood"]:
        return foodReward
    if nextState["gameState"] == 0:
        return deathReward
    return stepReward


def runSim(alpha, gamma, minEpsilon, maxEpsilon, numEps, foodReward, stepReward, deathReward, printStats, showLastSim, stats):
    """Train the global q_table with epsilon-greedy tabular Q-learning.

    Each episode starts from newGame() and runs until game over, applying
    Q(s,a) <- (1-alpha) Q(s,a) + alpha (r + gamma * max_a' Q(s',a')).
    Epsilon decays linearly from maxEpsilon toward minEpsilon over the
    episodes. A text progress bar is redrawn as episodes complete.

    Args:
        stats: dict prepared by sim(). It is only read and updated when
            ``printStats`` is true.

    Returns:
        ``stats`` when ``printStats`` is true, otherwise None. When
        ``showLastSim`` is true, the last episode is replayed first.
    """
    inps = []    # actions taken in the last episode (for replay)
    foods = []   # successive food positions in the last episode (for replay)
    ssEnd = {}   # start state of the last episode
    oldbars = -1
    barNumbers = 100

    for g in range(numEps):
        isLastEp = g == numEps - 1
        epsilon = (1 - g / numEps) * (maxEpsilon - minEpsilon) + minEpsilon

        if printStats:
            bucket = math.floor(g / stats["bucketSize"])
            stats["lastBucket"] = max(stats["lastBucket"], bucket)
            stats["totalEps"] += 1

        startState = newGame()
        si = copy.deepcopy(startState)
        if isLastEp:
            ssEnd = copy.deepcopy(startState)
            # Record the initial food up front so it is not lost if it is eaten on turn 1.
            foods.append(list(startState["food"]))

        # ==== progress printing ====
        bars = round(barNumbers * g / numEps)
        if bars > oldbars:
            oldbars = bars
            clear_output(wait=True)
            print("[" + "=" * bars + " " * (barNumbers - bars) + "]")

        while si["gameState"] != 0:
            if printStats:
                stats["totalTurnsLived"][stats["lastBucket"]] += 1
            rsi = state2Tuple(si)
            at = _chooseAction(_qRow(rsi), epsilon)
            if isLastEp:
                inps.append(at)

            # step() mutates si in place and returns it, so si is siPlus1 afterwards;
            # rsi was computed beforehand and still describes the old state.
            siPlus1 = step(si, at)
            maxFutureQ = _bootstrapValue(siPlus1)
            rt = _reward(siPlus1, foodReward, stepReward, deathReward)
            if printStats and siPlus1["justAteFood"]:
                stats["totalFoodEaten"][stats["lastBucket"]] += 1
            if isLastEp and siPlus1["food"] != foods[-1]:
                foods.append(list(siPlus1["food"]))

            row = q_table[rsi]
            row[at + 1] = (1 - alpha) * row[at + 1] + alpha * (rt + gamma * maxFutureQ)
            si = siPlus1

    if showLastSim:
        visualizeStates(copy.deepcopy(ssEnd), inps, foods, 0.2)
    return stats if printStats else None
#     print("Average turns:", (stats["totalTurnsLived"][stats["buckets"]-1]*stats["buckets"]/stats["totalEps"]))
#     print("Average food eaten:", (stats["totalFoodEaten"][stats["buckets"]-1]*stats["buckets"]/stats["totalEps"]))

In [261]:
stats = sim(printStats=True, showLastSim=False)
# One entry per bucket of episodes: food eaten, then turns survived.
for key in ("totalFoodEaten", "totalTurnsLived"):
    print(stats[key])

[177, 384, 1116, 1391, 1230, 1199, 1175, 1231, 1567, 1338]
[3936, 5126, 9969, 13193, 11505, 10166, 15560, 18981, 14492, 12778]


In [243]:
# f = open("greatqtable.txt","w")
# f.write( str(q_table) )
# f.close()