In [1]:
# Imports
import numpy as np
import matplotlib.pyplot as plt
from builtins import range, input

# Side length of the square game board: the grid holds LENGTH x LENGTH cells.
# The agent's Q-table row count (4 ** LENGTH) is also derived from this value.
LENGTH = 5

In [12]:
class Environment:
    """Snake game on a square grid, exposed as a small RL environment.

    Cells are addressed as ``[row, col]`` lists. ``self.snake`` is ordered
    tail first, so ``self.snake[-1]`` is the head. ``self.apple`` is the cell
    that currently holds the apple.
    """

    def __init__(self, length=None):
        """Create the game and place the snake and the first apple.

        Args:
            length: Side of the square board. When omitted, the module-level
                ``LENGTH`` constant is used.
        """
        self.length = LENGTH if length is None else length
        self.board = np.zeros((self.length, self.length))

        self.ended = False
        self.num_states = None

        self.apple = []
        self.snake = []

        self.reset()

    def reset(self):
        """Put the snake back on its starting cells and drop a new apple."""
        self.set_snake()
        self.set_apple()

    def draw_board(self):
        """Print the board, one text row per grid row, then a blank line.

        Exactly one glyph is printed per cell. The snake is drawn over the
        apple and the head over body segments, so a just-eaten apple (still
        stored under the head) does not add an extra glyph to the row.
        """
        head_index = len(self.snake) - 1
        for i in range(self.length):
            glyphs = []
            for j in range(self.length):
                glyph = "⌏"
                if i == self.apple[0] and j == self.apple[1]:
                    glyph = "🍏"
                # Later segments override earlier ones; the head is last.
                for x, s in enumerate(self.snake):
                    if i == s[0] and j == s[1]:
                        glyph = "🌝" if x == head_index else "🌕"
                glyphs.append(glyph)
            print("|" + "".join(glyphs) + "|")
        print("")

    def game_over(self):
        """Return True when the head overlaps the body or left the board."""
        head = self.snake[-1]

        # Case hits on itself. The segment right behind the head is skipped:
        # after a move it sits on the head's previous cell, never on the head.
        for segment in self.snake[0:-2]:
            if segment == head:
                return True

        # Case hits a wall: any coordinate outside 0 .. length - 1.
        for coordinate in head:
            if coordinate < 0 or coordinate >= self.length:
                return True
        return False

    def get_state(self):
        """Return the integer state id used to index the Q-table.

        Every board cell contributes ``row + col + code`` to the result, where
        ``code`` is 3 for the apple, 2 for the snake's head, 1 for a body
        segment and 0 for an empty cell; a snake segment takes precedence
        over the apple on the same cell.

        NOTE: the contributions are summed, so the ``row + col`` part is the
        same for every board and the result only reflects how many in-board
        cells hold each kind of occupant. Different layouts therefore share
        the same state id; this is not a one-id-per-layout encoding (that
        would need 4 ** (length * length) ids).
        """
        head_index = len(self.snake) - 1
        total = 0
        for i in range(self.length):
            for j in range(self.length):
                code = 0
                if i == self.apple[0] and j == self.apple[1]:
                    code = 3
                for x, s in enumerate(self.snake):
                    if i == s[0] and j == s[1]:
                        code = 2 if x == head_index else 1
                total += i + j + code
        return total

    def reward(self):
        """Return 20 while the head is on the apple's cell, otherwise 0."""
        return 20 if self.snake[-1] == self.apple else 0

    def set_apple(self):
        """Move the apple to a random cell that is not part of the snake.

        If the snake already covers every board cell there is nowhere to put
        the apple; it is then left where it is instead of looping forever.
        """
        occupied = {
            (s[0], s[1])
            for s in self.snake
            if 0 <= s[0] < self.length and 0 <= s[1] < self.length
        }
        if len(occupied) >= self.length * self.length:
            return

        while True:
            apple = [int(np.random.rand() * self.length),
                     int(np.random.rand() * self.length)]
            if apple not in self.snake:
                break
        self.apple = apple

    def set_snake(self):
        """Place a two-segment snake on row 2 with its head at column 1."""
        self.snake = [[2, 0], [2, 1]]

    def move_snake(self, action):
        """Advance the snake one cell.

        Args:
            action: 0 moves up (row - 1), 2 moves down (row + 1), 3 moves
                right (col + 1); any other value moves left (col - 1).

        Moving onto the apple grows the snake by one segment; the apple stays
        under the new head so ``reward()`` can report it, and is replaced by a
        fresh one at the start of the following move.
        """
        head = self.snake[-1]

        # The previous move ate the apple: drop a new one before moving on.
        if head == self.apple:
            self.set_apple()

        # w = 0, a = 1, s = 2, d = 3
        if action == 0:
            movement = [-1, 0]
        elif action == 3:
            movement = [0, 1]
        elif action == 2:
            movement = [1, 0]
        else:
            movement = [0, -1]

        new_head = [head[0] + movement[0], head[1] + movement[1]]
        if new_head == self.apple:
            # Store a copy so the head never shares a list with self.apple.
            self.snake.append(list(self.apple))
        else:
            # Each segment takes the place of the one ahead of it; the old
            # tail cell is released.
            self.snake[:] = self.snake[1:] + [new_head]

In [30]:
class Agent:
    """Greedy tabular Q-learning agent with four actions per state."""

    def __init__(self, num_states=None):
        """Allocate a zero-filled Q-table.

        Args:
            num_states: Number of rows in the Q-table. When omitted,
                ``4 ** LENGTH`` is used.
        """
        if num_states is None:
            num_states = 4 ** LENGTH
        self.Q = np.zeros([num_states, 4])
        self.alpha = 0.618
        # Running total maintained by update_reward().
        self.reward = 0

    def take_action(self, env, action):
        """Apply ``action`` to ``env``; return ``(state, reward, done)``."""
        env.move_snake(action)
        return env.get_state(), env.reward(), env.game_over()

    def update_reward(self, reward):
        """Add ``reward`` to the agent's running reward total."""
        self.reward += reward

    def train(self, env, episodes=1):
        """Run greedy Q-learning episodes on ``env``, updating ``self.Q``.

        Args:
            env: Object providing ``get_state()``, ``reward()``,
                ``game_over()``, ``move_snake(action)`` and ``reset()``.
            episodes: Number of episodes to play (default 1).

        Each episode runs until ``env.game_over()`` is true. An environment
        that is already over when an episode starts is reset first. The
        initial state of every episode is printed; the final state and total
        reward of the last episode are printed at the end.
        """
        # Defined up front so the final prints work even with episodes == 0.
        state = env.get_state()
        G = 0
        for episode in range(1, episodes + 1):
            if env.game_over():
                env.reset()
            done = False
            G, reward = 0, 0
            state = env.get_state()
            print("Initial State = {}".format(state))
            while not done:
                action = np.argmax(self.Q[state])
                state2, reward, done = self.take_action(env, action)
                # Nothing follows a terminal state, so its future value is 0.
                future = 0 if done else np.max(self.Q[state2])
                self.Q[state, action] += self.alpha * (
                    reward + future - self.Q[state, action]
                )
                G += reward
                state = state2
        print("Final State = {}".format(state))
        print("Reward = {}".format(G))
    
        

In [31]:
# Build a fresh game and a fresh learner, then run one training session.
env, agent = Environment(), Agent()
agent.train(env)

Initial State = 106
False
False
True
Final State = 104
Reward = 0
