In [8]:
#Importing necessary dependencies
import itertools
import json
import random
import dataclasses
import pygame
import matplotlib.pyplot as plt

# Build the initial Q-table and persist it to qtable.json.
# Every state key the agent can produce is created up front: 3 horizontal food
# tokens x 3 vertical food tokens x 16 surroundings patterns = 144 states, each
# mapped to four zeroed action values (left, right, up, down).
sequence = [''.join(bits) for bits in itertools.product('01', repeat=4)]
widths = ['0', '1', 'NA']
heights = ['2', '3', 'NA']

# Keys are the str() of a (width, height, surroundings) tuple; each state gets
# its own fresh list so later in-place updates never alias another state.
qtable = {
    str((width_token, height_token, surroundings)): [0, 0, 0, 0]
    for width_token in widths
    for height_token in heights
    for surroundings in sequence
}

with open("qtable.json", "w") as qtable_file:
    json.dump(qtable, qtable_file)



In [11]:


@dataclasses.dataclass
class GameState:
    """Snapshot of one frame as observed by the agent; built by Brain.getstate."""
    # (dist_x, dist_y): food coordinate minus head coordinate on each axis.
    distance: tuple
    # (pos_x, pos_y) direction tokens for the food relative to the head:
    # pos_x is '1' (food to the right), '0' (to the left) or 'NA' (same x);
    # pos_y is '3' (food below), '2' (above) or 'NA' (same y).
    position: tuple
    # Four '0'/'1' flags for the cells left, right, above and below the head,
    # in that order; '1' marks a cell that is off the board or part of the body.
    surroundings: str
    # The food coordinates that were passed to getstate for this frame.
    food: tuple

# The class represents the main brain of the agent
class Brain(object):
    """Tabular Q-learning agent that chooses the snake's moves.

    A state is keyed by the direction of the food relative to the head plus
    four obstacle flags for the cells next to the head (see getstate and
    getstateStr). Every (state, action) pair chosen by act is appended to
    ``self.log`` and replayed by updateqvalues.
    """

    def __init__(self, width, height, block_size):
        """Store the board geometry, set the learning parameters and load the Q-table.

        Args:
            width: board width; neighbour cells with x >= width count as obstacles.
            height: board height; neighbour cells with y >= height count as obstacles.
            block_size: distance between the head and each neighbouring cell.
        """
        # Represent the parameters for the game board
        self.width = width
        self.height = height
        self.block_size = block_size

        # Represent the learning parameters
        self.epsilon = 0.1
        self.learning_rate = 0.7
        self.discount = 0.5

        # Contains the log of the State/Action pairs and values
        self.qvalues = self.loadqvalues()
        self.log = []

        # Actions that the agent can take; the key is the index into a Q-table row
        self.actions = {
            0: 'left',
            1: 'right',
            2: 'up',
            3: 'down'
        }

    #Resets the log to an empty list
    def reset(self):
        """Forget the state/action history (called before each new game)."""
        self.log = []

    #Loads the q-values
    def loadqvalues(self, path="qtable.json"):
        """Read the Q-table from a JSON file.

        Args:
            path: file to read; defaults to "qtable.json".

        Returns:
            The decoded JSON object (state string -> list of action values).
        """
        with open(path, "r") as f:
            qvalues = json.load(f)
        return qvalues

    #This function decides how the agent should act in a given state and also remembers them.
    def act(self, snake, food):
        """Pick the next move epsilon-greedily and record it in the log.

        Args:
            snake: sequence of (x, y) cells; the head is the last element.
            food: (x, y) position of the food.

        Returns:
            The action name: 'left', 'right', 'up' or 'down'.
        """
        state = self.getstate(snake, food)

        # With probability epsilon explore with a random action, otherwise exploit
        # the best known action (ties resolve to the lowest action index).
        rand = random.uniform(0, 1)
        if rand < self.epsilon:
            action_key = random.choices(list(self.actions.keys()))[0]
        else:
            state_scores = self.qvalues[self.getstateStr(state)]
            action_key = state_scores.index(max(state_scores))
        action_val = self.actions[action_key]

        # Remember the actions it took at each state
        self.log.append({
            'state': state,
            'action': action_key
            })
        return action_val

    #Using the Bellman Equation, updates the q-values
    def updateqvalues(self, reason):
        """Replay the logged history, newest first, and update the Q-table.

        Args:
            reason: 'Tail' if the snake ran into itself (reward -5), 'Screen' if
                it left the board (reward -1); any other value means the snake
                is still alive and no terminal update is made.

        The terminal update is applied to the most recent log entry even when it
        is the only entry, so a death on the very first logged action is
        penalised as well.
        """
        log = self.log[::-1]
        if not log:
            return

        first_transition = 0
        if reason == 'Tail' or reason == 'Screen':
            # The game is over, so the Bellman update has no future-reward term.
            reward = -5 if reason == 'Tail' else -1
            _state = log[0]['state']
            _action = log[0]['action']
            state_str = self.getstateStr(_state)
            self.qvalues[state_str][_action] = (1-self.learning_rate) * self.qvalues[state_str][_action] + self.learning_rate * reward
            # The newest entry has been consumed by the terminal update, so the
            # transition that led into it is not replayed on this call.
            first_transition = 1

        #Case where the snake survives: replay each (previous -> current) transition
        for i in range(first_transition, len(log) - 1):
            # Current state of the snake
            s1 = log[i]['state']
            #Previous state of the snake
            s0 = log[i+1]['state']
            #Action taken at previous state
            a0 = log[i+1]['action']

            #Distance to the food in the previous state
            x1 = s0.distance[0]
            y1 = s0.distance[1]
            #Distance to the food in the current state
            x2 = s1.distance[0]
            y2 = s1.distance[1]

            # The food moved between the two states, i.e. the snake ate it
            if s0.food != s1.food:
                reward = 10
            # The distance shrank on at least one axis: the snake got closer
            elif (abs(x1) > abs(x2) or abs(y1) > abs(y2)):
                reward = 1
            #Case where the snake starts to move away from the food
            else:
                reward = -1

            #The q-values would be updated accordingly
            state_str = self.getstateStr(s0)
            new_state_str = self.getstateStr(s1)
            # Bellman equation
            self.qvalues[state_str][a0] = (1-self.learning_rate)*(self.qvalues[state_str][a0])+self.learning_rate*(reward+self.discount*max(self.qvalues[new_state_str]))

    # True when a neighbouring cell is off the board or occupied by the body
    def _is_blocked(self, cell, body):
        return (
            cell[0] < 0 or cell[1] < 0
            or cell[0] >= self.width or cell[1] >= self.height
            or cell in body
        )

    #This function describes the position of the snake in the environment
    def getstate(self, snake, food):
        """Build the GameState for the current frame.

        Args:
            snake: sequence of (x, y) cells; the head is the last element.
            food: (x, y) position of the food.

        Returns:
            GameState with the raw food offset, the food direction tokens, the
            four obstacle flags (left, right, up, down) and the food position.
        """
        head = snake[-1]
        #Distance described with respect to the food
        dist_x = food[0] - head[0]
        dist_y = food[1] - head[1]

        if dist_x > 0:
            #Food is to the right of the snake
            pos_x = '1'
        elif dist_x < 0:
            #Food is to the left of the snake
            pos_x = '0'
        else:
            #Snake and food share the same x coordinate
            pos_x = 'NA'

        if dist_y > 0:
            #Food is below the snake
            pos_y = '3'
        elif dist_y < 0:
            #Food is above the snake
            pos_y = '2'
        else:
            #Snake and food share the same y coordinate
            pos_y = 'NA'

        #Cells next to the head, in the order left, right, up, down
        sequences = [
            (head[0]-self.block_size, head[1]),
            (head[0]+self.block_size, head[1]),
            (head[0], head[1]-self.block_size),
            (head[0], head[1]+self.block_size),
        ]

        #Value of 1 is given for obstacles (board edge or body), 0 for a free cell.
        body = snake[:-1]
        surroundings = ''.join(
            '1' if self._is_blocked(cell, body) else '0' for cell in sequences
        )

        return GameState((dist_x, dist_y), (pos_x, pos_y), surroundings, food)

    def getstateStr(self, state):
        """Return the Q-table key for a state: str((pos_x, pos_y, surroundings))."""
        return str((state.position[0], state.position[1], state.surroundings))


In [None]:
#Game code adapted from the following website: https://www.edureka.co/blog/snake-game-with-pygame/

#Configuration of the game environment

# Board geometry: the display is DIS_WIDTH x DIS_HEIGHT and the snake moves
# BLOCK_SIZE per step.
BLOCK_SIZE = 20
DIS_WIDTH = 400
DIS_HEIGHT = 400

# Value handed to clock.tick() each frame; set very high so the frame-rate cap
# does not slow training.
FRAMESPEED = 99999

# RGB colours used when drawing the game
BLUE = (50, 153, 213)      # background fill
GREEN = (0, 255, 0)        # food
BLACK = (0, 0, 0)          # snake segments
YELLOW = (255, 255, 102)   # score text and the highlighted snake segment

pygame.init()



def _random_food_position():
    """Return a random (x, y) food position snapped to the BLOCK_SIZE grid.

    x is drawn before y. Coordinates are floats and, for a board whose sides
    are multiples of BLOCK_SIZE, never exceed DIS_WIDTH/DIS_HEIGHT - BLOCK_SIZE.
    """
    cell = float(BLOCK_SIZE)
    food_x = round(random.randrange(0, DIS_WIDTH - BLOCK_SIZE) / cell) * cell
    food_y = round(random.randrange(0, DIS_HEIGHT - BLOCK_SIZE) / cell) * cell
    return food_x, food_y


def GameLoop():
    """Play one game with the global ``brain`` choosing every move.

    Each frame asks the brain for an action, moves the snake, checks for death
    and food, redraws the display and lets the brain update its Q-table.

    Returns:
        (score, reason): score is the number of food items eaten; reason is
        'Screen' (left the board) or 'Tail' (ran into itself).
    """
    # The Draw* helpers render onto this module-level surface.
    global dis

    dis = pygame.display.set_mode((DIS_WIDTH, DIS_HEIGHT))
    pygame.display.set_caption('Snake')
    clock = pygame.time.Clock()

    # Starting position of snake
    x1 = DIS_WIDTH / 2
    y1 = DIS_HEIGHT / 2
    x1_change = 0
    y1_change = 0
    snake_list = [(x1, y1)]
    length_of_snake = 1

    # Create first food item at a random location
    food_posx, food_posy = _random_food_position()

    dead = False
    reason = None
    while not dead:
        # Service the window-system event queue every frame; without this the
        # OS treats the window as unresponsive because nothing else reads events.
        pygame.event.pump()

        # Get the agent's action and translate it into a displacement
        action = brain.act(snake_list, (food_posx, food_posy))
        if action == "left":
            x1_change = -BLOCK_SIZE
            y1_change = 0
        elif action == "right":
            x1_change = BLOCK_SIZE
            y1_change = 0
        elif action == "up":
            y1_change = -BLOCK_SIZE
            x1_change = 0
        elif action == "down":
            y1_change = BLOCK_SIZE
            x1_change = 0

        # Move the snake: the new head is appended at the end of the list
        x1 += x1_change
        y1 += y1_change
        head = (x1, y1)
        snake_list.append(head)

        # Check if snake is off screen, in which case it would be dead
        if x1 >= DIS_WIDTH or x1 < 0 or y1 >= DIS_HEIGHT or y1 < 0:
            reason = 'Screen'
            dead = True

        # Check if snake hit tail, in which case it would be dead.
        # If both checks fire, 'Tail' overrides 'Screen'.
        if len(snake_list) > 4:
            if head in snake_list[:-1]:
                reason = 'Tail'
                dead = True

        # Check if snake ate food. If it did then generate a new food item
        if x1 == food_posx and y1 == food_posy:
            food_posx, food_posy = _random_food_position()
            length_of_snake += 1

        # Delete the oldest cell since we just added a head for moving, unless we ate a food
        if len(snake_list) > length_of_snake:
            del snake_list[0]

        # Draw food, snake and update score
        dis.fill(BLUE)
        DrawFood(food_posx, food_posy)
        DrawSnake(snake_list)
        DrawScore(length_of_snake - 1)
        pygame.display.update()

        # Update Q Table
        brain.updateqvalues(reason)

        # Next Frame
        clock.tick(FRAMESPEED)

    return length_of_snake - 1, reason

#Draws the food item
def DrawFood(food_posx, food_posy):
    """Paint one GREEN BLOCK_SIZE square on the display with its top-left corner at (food_posx, food_posy)."""
    food_rect = (food_posx, food_posy, BLOCK_SIZE, BLOCK_SIZE)
    pygame.draw.rect(dis, GREEN, food_rect)

#Shows the current score
def DrawScore(score):
    """Render "Score: <score>" in YELLOW at the top-left corner of the display."""
    score_font = pygame.font.SysFont("comicsansms", 35)
    label = score_font.render(f"Score: {score}", True, YELLOW)
    dis.blit(label, (0, 0))

#Renders the snake
def DrawSnake(snake_list):
    """Draw the snake: body segments in BLACK and the head in YELLOW.

    Args:
        snake_list: sequence of (x, y) cells ordered tail first; the head is the
            last element, because the game loop appends each new head and
            deletes index 0 to move the tail.
    """
    if not snake_list:
        return

    # The head is the most recently appended cell, not index 0.
    *body, head = snake_list
    for segment in body:
        pygame.draw.rect(dis, BLACK, [segment[0], segment[1], BLOCK_SIZE, BLOCK_SIZE])
    # Draw the head last so it stays visible even if it overlaps a body cell.
    pygame.draw.rect(dis, YELLOW, [head[0], head[1], BLOCK_SIZE, BLOCK_SIZE])




# Training schedule: play TOTAL_GAMES games; games numbered above
# LAST_EXPLORATION_GAME are played greedily (epsilon = 0).
TOTAL_GAMES = 400
LAST_EXPLORATION_GAME = 200

game_count = 0

brain = Brain(DIS_WIDTH, DIS_HEIGHT, BLOCK_SIZE)
results = []

try:
    while game_count < TOTAL_GAMES:
        brain.reset()
        #We stop it from making random moves once the exploration phase is over
        brain.epsilon = 0 if game_count > LAST_EXPLORATION_GAME else 0.1
        score, reason = GameLoop()
        results.append(score)
        #Outputs the game number along with the score and reason of death
        print(f"Games number: {game_count}; Final Score: {score}; Died because of: {reason}")
        game_count += 1
finally:
    # Always shut pygame down, even if training is interrupted or raises,
    # so the game window does not linger.
    pygame.quit()

Games number: 0; Final Score: 0; Died because of: Screen
Games number: 1; Final Score: 3; Died because of: Tail
Games number: 2; Final Score: 3; Died because of: Tail
Games number: 3; Final Score: 3; Died because of: Screen
Games number: 4; Final Score: 3; Died because of: Tail
Games number: 5; Final Score: 2; Died because of: Screen
Games number: 6; Final Score: 3; Died because of: Tail
Games number: 7; Final Score: 3; Died because of: Tail
Games number: 8; Final Score: 3; Died because of: Tail
Games number: 9; Final Score: 3; Died because of: Tail
Games number: 10; Final Score: 3; Died because of: Tail
Games number: 11; Final Score: 3; Died because of: Tail
Games number: 12; Final Score: 4; Died because of: Tail
Games number: 13; Final Score: 3; Died because of: Tail
Games number: 14; Final Score: 4; Died because of: Tail
Games number: 15; Final Score: 2; Died because of: Screen
Games number: 16; Final Score: 4; Died because of: Tail
Games number: 17; Final Score: 5; Died because of:

Games number: 147; Final Score: 9; Died because of: Tail
Games number: 148; Final Score: 1; Died because of: Screen
Games number: 149; Final Score: 2; Died because of: Screen
Games number: 150; Final Score: 3; Died because of: Screen
Games number: 151; Final Score: 3; Died because of: Tail
Games number: 152; Final Score: 4; Died because of: Tail
Games number: 153; Final Score: 2; Died because of: Screen
Games number: 154; Final Score: 4; Died because of: Tail
Games number: 155; Final Score: 4; Died because of: Tail
Games number: 156; Final Score: 12; Died because of: Tail
Games number: 157; Final Score: 6; Died because of: Tail
Games number: 158; Final Score: 3; Died because of: Tail
Games number: 159; Final Score: 3; Died because of: Tail
Games number: 160; Final Score: 6; Died because of: Tail
Games number: 161; Final Score: 4; Died because of: Tail
Games number: 162; Final Score: 2; Died because of: Screen
Games number: 163; Final Score: 5; Died because of: Tail
Games number: 164; F

Games number: 289; Final Score: 49; Died because of: Tail
Games number: 290; Final Score: 20; Died because of: Tail
Games number: 291; Final Score: 17; Died because of: Tail
Games number: 292; Final Score: 30; Died because of: Tail
Games number: 293; Final Score: 23; Died because of: Tail
Games number: 294; Final Score: 21; Died because of: Tail
Games number: 295; Final Score: 19; Died because of: Tail
Games number: 296; Final Score: 17; Died because of: Tail
Games number: 297; Final Score: 25; Died because of: Tail
Games number: 298; Final Score: 12; Died because of: Tail
Games number: 299; Final Score: 19; Died because of: Tail
Games number: 300; Final Score: 35; Died because of: Tail
Games number: 301; Final Score: 32; Died because of: Tail
Games number: 302; Final Score: 28; Died because of: Tail
Games number: 303; Final Score: 23; Died because of: Tail
Games number: 304; Final Score: 6; Died because of: Tail
Games number: 305; Final Score: 17; Died because of: Tail
Games number: 3

In [None]:
# Score per game over the whole training run
fig, ax = plt.subplots()
ax.plot(results)
ax.set_xlabel("Number of Games")
ax.set_ylabel("Score")
ax.set_title('Performance Measure')
plt.show()
# Report the hyper-parameters held by the agent rather than a hard-coded copy,
# so the caption cannot drift from the values used in training.
print(f'Parameters used: Learning rate = {brain.learning_rate}  Discount Rate = {brain.discount}')