# In [1]:
import random
import json

import dataclasses
@dataclasses.dataclass
class State:
    """Snapshot of the chaser's situation, as assembled by ``RL.state_params``."""

    # (horizontal, vertical) offset from the chaser head to the target,
    # computed as target coordinate minus head coordinate.
    distance: tuple
    # Four '0'/'1' flags in the order (up, right, down, left) telling on which
    # side(s) of the chaser head the target lies.
    target_position: tuple
    # Four-character '0'/'1' string for the squares (left, right, above, below)
    # of the head; '1' marks a wall or a chaser body segment.
    obstacle_position: str
    # The target position exactly as it was passed to ``RL.state_params``.
    target: tuple


class RL(object):
    """Tabular Q-learning agent that steers a chaser towards a target.

    The Q-table maps a state string (see ``state_str``) to a list of four
    Q-values, one per action key in ``self.actions``. Actions are picked with
    an epsilon-greedy policy and every decision is appended to
    ``self.qvalues_record`` so that ``New_q_vals`` can replay the episode and
    update the table.
    """

    # Offset (in the same units as the chaser coordinates) used to probe the
    # four squares around the chaser head.
    # NOTE(review): this is hard-coded while ``square_size`` is stored but
    # never used, and the wall check compares these coordinates against
    # ``width_squares`` / ``height_squares`` directly -- confirm the intended
    # units with the caller before changing either.
    _NEIGHBOR_STEP = 21

    def __init__(self, height_squares, width_squares, square_size):
        """Set up hyper-parameters, load the Q-table and store screen sizes.

        Args:
            height_squares: Upper bound used for the vertical wall check.
            width_squares: Upper bound used for the horizontal wall check.
            square_size: Stored on the instance; not read by this class.

        Raises:
            OSError: If the default ``qvalues.json`` cannot be opened.
        """
        # Bellman equation parameters
        self.e = 0.1   # epsilon: probability of taking a random action
        self.lr = 0.5  # learning rate (between 0 and 1)
        self.df = 0.1  # discount factor (between 0 and 1)

        # Q-value table and the per-episode log of (state, action) decisions
        self.qvalues = self.read_qvalues()
        self.qvalues_record = []

        # Action space: Q-table column index -> direction name
        self.actions = {
            0: 'up',
            1: 'left',
            2: 'down',
            3: 'right',
        }

        # Screen parameters
        self.height_squares = height_squares
        self.width_squares = width_squares
        self.square_size = square_size

    def read_qvalues(self, path="qvalues.json"):
        """Load and return the Q-table stored as JSON at ``path``."""
        with open(path, "r") as f:
            return json.load(f)

    def write_qvalues(self, path="qvalues.json"):
        """Write the current Q-table to ``path`` as JSON."""
        with open(path, "w") as f:
            json.dump(self.qvalues, f)

    def choose_action(self, chaser, target):
        """Pick the next action with an epsilon-greedy policy and log it.

        Args:
            chaser: Sequence of chaser positions; the last element is the head.
            target: Position of the target.

        Returns:
            The chosen direction name ('up', 'left', 'down' or 'right').

        Raises:
            KeyError: If the current state has no entry in the Q-table.
        """
        state = self.state_params(chaser, target)

        if random.random() < self.e:
            # Exploration: any of the four actions, uniformly at random.
            action_key = random.choice(list(self.actions))
        else:
            # Exploitation: first action holding the highest Q-value.
            q_row = self.qvalues[self.state_str(state)]
            action_key = q_row.index(max(q_row))

        action = self.actions[action_key]
        self.qvalues_record.append({'state': state, 'action': action_key})
        return action

    def Reset(self):
        """Forget the recorded (state, action) history of the episode."""
        self.qvalues_record = []

    def New_q_vals(self, reason):
        """Replay the recorded history, newest first, and update the Q-table.

        Args:
            reason: Truthy when the chaser met a wall or ran into its own
                body; the most recent recorded action is then penalised with
                a reward of -1 and no future term.

        Every pair of consecutive records is then updated with the Bellman
        equation: reward +1 when the target changed between the two states
        (the chaser found it) or when the chaser got closer on either axis,
        otherwise -1.
        """
        history = self.qvalues_record[::-1]
        if not history:
            return

        if reason:
            # Terminal penalty. Applied before the loop so it also happens
            # when only one step was recorded and so that it does not consume
            # the update of the newest transition.
            last_key = self.state_str(history[0]['state'])
            last_action = history[0]['action']
            old_q = self.qvalues[last_key][last_action]
            self.qvalues[last_key][last_action] = old_q + self.lr * (-1 - old_q)

        for current, previous in zip(history, history[1:]):
            current_state = current['state']
            previous_state = previous['state']
            previous_action = previous['action']

            if previous_state.target != current_state.target:
                # The chaser found the target.
                reward = 1
            elif (abs(previous_state.distance[0]) > abs(current_state.distance[0])
                  or abs(previous_state.distance[1]) > abs(current_state.distance[1])):
                # The chaser moved closer to the target.
                reward = 1
            else:
                # The chaser did not get closer to the target.
                reward = -1

            # Bellman equation
            prev_key = self.state_str(previous_state)
            next_key = self.state_str(current_state)
            old_q = self.qvalues[prev_key][previous_action]
            best_next = max(self.qvalues[next_key])
            self.qvalues[prev_key][previous_action] = old_q + self.lr * (
                reward + self.df * best_next - old_q
            )

    def state_params(self, chaser, target):
        """Build the ``State`` describing the chaser relative to its target.

        Args:
            chaser: Sequence of chaser positions; the last element is the head
                and the preceding elements are treated as the body.
            target: Position of the target (index 0 horizontal, 1 vertical).

        Returns:
            A ``State`` holding the (horizontal, vertical) distance, the
            (up, right, down, left) target flags, the obstacle string for the
            (left, right, above, below) squares, and ``target`` itself.
        """
        head_x, head_y = chaser[-1][0], chaser[-1][1]
        horizontal_dist = target[0] - head_x
        vertical_dist = target[1] - head_y

        # Direction flags: '1' when the target lies on that side of the head.
        pos_u = '1' if vertical_dist < 0 else '0'
        pos_l = '1' if horizontal_dist < 0 else '0'
        pos_d = '1' if vertical_dist > 0 else '0'
        pos_r = '1' if horizontal_dist > 0 else '0'

        step = self._NEIGHBOR_STEP
        surrounding_positions = [
            (head_x - step, head_y),  # square to the left of the head
            (head_x + step, head_y),  # square to the right of the head
            (head_x, head_y - step),  # square above the head
            (head_x, head_y + step),  # square below the head
        ]

        body = chaser[:-1]  # everything except the head
        flags = []
        for sq in surrounding_positions:
            hits_wall = (
                sq[0] < 0
                or sq[1] < 0
                or sq[0] >= self.width_squares
                or sq[1] >= self.height_squares
            )
            flags.append('1' if hits_wall or sq in body else '0')
        obstacle_position = ''.join(flags)

        return State(
            (horizontal_dist, vertical_dist),
            (pos_u, pos_r, pos_d, pos_l),
            obstacle_position,
            target,
        )

    def state_str(self, state):
        """Return the Q-table key for ``state``.

        The key is the ``str`` of a 5-tuple: the four target-direction flags
        followed by the obstacle string.
        """
        flags = state.target_position
        return str((flags[0], flags[1], flags[2], flags[3], state.obstacle_position))