In [49]:
import numpy as np
from PIL import Image
import cv2
import pickle
import time
from tqdm import tqdm


In [50]:
# --- Grid ---------------------------------------------------------------------
SIZE = 10  # the world is a SIZE x SIZE grid of cells

# Keys into the colour map `d` (one per kind of object drawn on screen).
thief_key, police_key, gold_key = 1, 2, 3

# --- Learning hyper-parameters ------------------------------------------------
episodes = 10_000
learning_rate = 0.2
gamma = 0.9
truncate_ep = 200  # maximum number of steps in one episode

# --- Rewards (gold_penalty is positive: it is the reward for reaching gold) ---
move_penalty = -1
police_penalty = -100
gold_penalty = 50

# Render every `show_every`-th episode.
show_every = 1000

# RGB colour coding: thief red, police green, gold blue.
d = {
    thief_key: (255, 0, 0),
    police_key: (0, 255, 0),
    gold_key: (0, 0, 255),
}

In [51]:
# 6 state variables: the (dx, dy) offset of police1, police2 and gold relative to the thief: 3*2 = 6
# each offset component ranges over -(SIZE-1)..SIZE-1, so every state variable needs 2*SIZE-1 slots
# 8 actions: up, down, left, right, up-right, up-left, down-right, down-left
# example if SIZE=2
# {((-1, -1), (-1, -1), (-1, -1)): [-6.592410750130212,
#   -3.2471509698026493,
#   -4.7096766987080025,
#   -6.3440140636200555,
#   -6.902051084930461,
#   -6.515393612375639,
#   -7.2622537200687605,
#   -1.709905353157553],
#  ((-1, -1), (-1, -1), (-1, 0)): [-1.4594508327038005,
#   -4.32404910526015,
#   -7.686202446915885,
#   -2.6785328772095767,
#   -4.254465011558587,
#   -5.19922263273938,
#   -7.765618843532037,
#   -0.10793496750316933],
#  ((-1, -1), (-1, -1), (-1, 1)): [-6.205776496499926,
#   -5.87045518036915,
#   -2.0057085630353964,
#   -2.940079148163205,
#   -0.7256731319898639,
#   -1.6015840101670271,
#   -4.516944477364622,
#   -0.3334368784800672],

# q_table = {}
# for a in range(-SIZE+1, SIZE):
#   for b in range(-SIZE+1, SIZE):
#     for c in range(-SIZE+1, SIZE):
#       for d in range(-SIZE+1, SIZE):
#         for e in range(-SIZE+1, SIZE):
#           for f in range(-SIZE+1, SIZE):
#             q_table[((a,b),(c,d),(e,f))] = [np.random.uniform(-8, 0) for i in range(8)]

In [52]:
# A better method to implement the Q-table: one dense NumPy array.
#
# State = the (dx, dy) offset of police1, police2 and gold relative to the
# thief (as returned by Grid.subtract), i.e. 6 integers. Each offset component
# lies in -(SIZE-1)..SIZE-1, so every state axis has 2*SIZE-1 slots; the last
# axis holds the 8 actions.
#
# Indexing note: train()/eval() index this array with the RAW offsets, so a
# negative offset -k wraps around to slot 2*SIZE-1-k. With 2*SIZE-1 slots per
# axis that mapping is still one-to-one (offsets 0..SIZE-1 -> slots 0..SIZE-1,
# offsets -(SIZE-1)..-1 -> slots SIZE..2*SIZE-2), so no two states collide.

q_table_shape = (2 * SIZE - 1,) * 6 + (8,)  # 6 state axes + 8 actions

# Initialize the Q-table with random values drawn uniformly from [-8, 0).
q_table = np.random.uniform(-8, 0, q_table_shape)

# Example of accessing a Q-value for a specific state-action pair.
# Under the raw-offset indexing above, the all-zero state means "both police
# and the gold are on the thief's cell" (slot SIZE-1 would mean offset +(SIZE-1),
# i.e. the far corner, not the centre).
state = (0, 0, 0, 0, 0, 0)
action = 3  # Example action
q_value = q_table[state + (action,)]
print(q_value)

-3.0672018080984262


In [53]:
class Grid:
    """An object (thief, police officer or gold) sitting on a cell of a square grid.

    Coordinates always satisfy ``0 <= x, y < size``: moves that would leave the
    grid are clamped to its border.
    """

    # Action id -> (dx, dy). Ids 0-3 are diagonal steps, 4-7 are axis-aligned.
    _MOVES = {
        0: (1, 1),
        1: (-1, -1),
        2: (-1, 1),
        3: (1, -1),
        4: (1, 0),
        5: (0, 1),
        6: (-1, 0),
        7: (0, -1),
    }

    def __init__(self, size=SIZE):
        """Place the object on a uniformly random cell of a ``size`` x ``size`` grid."""
        self.size = size
        self.x = np.random.randint(0, size)
        self.y = np.random.randint(0, size)

    def subtract(self, other):
        """Return the offset ``(self.x - other.x, self.y - other.y)`` as a tuple."""
        return (self.x - other.x, self.y - other.y)

    def action(self, choice):
        """Apply one of the 8 discrete actions (ids 0-7); any other value is ignored."""
        step = self._MOVES.get(choice)
        if step is not None:
            self.move(x=step[0], y=step[1])

    def move(self, x=None, y=None):
        """Shift by ``(x, y)`` and clamp the result to the grid.

        A component left as ``None`` is replaced by a random step in {-1, 0, 1}.
        An explicit ``0`` means "stay put on this axis". The previous
        ``if not x`` test treated ``0`` like an omitted value, so the
        axis-aligned actions 4-7 also took a random step on the other axis.
        """
        self.x += np.random.randint(-1, 2) if x is None else x
        self.y += np.random.randint(-1, 2) if y is None else y

        # Clamp to this object's own grid (previously the global SIZE was used).
        self.x = min(max(self.x, 0), self.size - 1)
        self.y = min(max(self.y, 0), self.size - 1)

In [54]:
def run():
    """Legacy entry point, kept for backward compatibility; delegates to ``train()``.

    ``train()`` is defined in the next cell and is resolved when ``run()`` is called.

    The previous body was an older, diverging copy of the training loop. It
    indexed the NumPy Q-table with a tuple of (dx, dy) tuples, which NumPy
    treats as advanced indexing over the first three axes. As a result,
    ``q_table[dstate][action]`` raised IndexError for every action >= 2, and the
    update ``q_table[dstate][action] = ...`` was written into a temporary copy
    and lost. Delegating removes the duplicated loop.
    """
    train()

In [55]:
def _flat_state(police1, police2, gold, thief):
    """Return the 6-int Q-table state index: offsets of police1, police2 and gold from the thief."""
    return police1.subtract(thief) + police2.subtract(thief) + gold.subtract(thief)


def _step_reward(thief, police1, police2, gold):
    """Reward for the thief's current cell; a police officer takes precedence over gold."""
    here = (thief.x, thief.y)
    if here in ((police1.x, police1.y), (police2.x, police2.y)):
        return police_penalty
    if here == (gold.x, gold.y):
        return gold_penalty
    return move_penalty


def _render(thief, police1, police2, gold, wait_ms):
    """Draw the grid in the "ENV" window, wait ``wait_ms`` ms, and return True if 'q' was pressed."""
    env = np.zeros((SIZE, SIZE, 3), dtype=np.uint8)  # 3 = number of RGB channels
    env[gold.x][gold.y] = d[gold_key]
    env[thief.x][thief.y] = d[thief_key]
    env[police1.x][police1.y] = d[police_key]
    env[police2.x][police2.y] = d[police_key]
    image = Image.fromarray(env, 'RGB').resize((300, 300))
    # cv2.imshow expects BGR channel order; convert so the RGB colours in `d` show as intended.
    cv2.imshow("ENV", cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
    # `&` (not `and`): the old `waitKey(500) and 0xFF == ord('q')` could never be true.
    return cv2.waitKey(wait_ms) & 0xFF == ord('q')


def train():
    """Train the global ``q_table`` with tabular Q-learning and return it.

    Each episode places police1, police2, the gold and the thief on random cells.
    The thief then takes uniformly random actions for at most ``truncate_ep``
    steps. Pure exploration is fine here because the Q-learning update bootstraps
    from the greedy (max) next value. Every ``show_every``-th episode is rendered
    with OpenCV, and pressing 'q' ends that episode early.

    Returns:
        The global ``q_table`` array (the same object), updated in place.
    """
    for eps in tqdm(range(episodes), desc="Episodes"):
        police1, police2, gold, thief = Grid(), Grid(), Grid(), Grid()
        show = eps % show_every == 0

        for _ in range(truncate_ep):
            idx = _flat_state(police1, police2, gold, thief)
            action = np.random.randint(0, 8)
            thief.action(action)
            idx += (action,)

            reward = _step_reward(thief, police1, police2, gold)
            terminal = reward in (gold_penalty, police_penalty)

            if reward == gold_penalty:
                new_qval = gold_penalty
            elif terminal:
                # Caught: nothing follows a terminal step, so there is no bootstrap term.
                new_qval = (1 - learning_rate) * q_table[idx] + learning_rate * reward
            else:
                new_state = _flat_state(police1, police2, gold, thief)
                max_future_qval = np.max(q_table[new_state])
                new_qval = (1 - learning_rate) * q_table[idx] + learning_rate * (reward + gamma * max_future_qval)
            q_table[idx] = new_qval

            if show and _render(thief, police1, police2, gold, 500 if terminal else 1):
                break
            # Gold/police end the episode whether or not it is rendered
            # (previously only rendered episodes terminated).
            if terminal:
                break

    return q_table
    

In [56]:
# Train the Q-table. train() writes its updates into the global `q_table`
# array and returns that same array, so this assignment rebinds the name to it.
q_table = train()

Episodes: 100%|██████████| 10000/10000 [01:48<00:00, 92.38it/s]


In [60]:
def eval(q_table, show=True):
    """Play one greedy episode with a learned Q-table and report the outcome.

    NOTE: this name shadows Python's built-in ``eval`` in the notebook namespace.
    It is kept unchanged for backward compatibility.

    Args:
        q_table: Q-value array indexed by the 6 raw (dx, dy) offsets of police1,
            police2 and gold from the thief, followed by the action id.
        show: render every step in an OpenCV window (default, as before).

    Returns:
        The reward of the last step taken (``gold_penalty``, ``police_penalty``,
        or ``move_penalty`` if no terminal cell was reached). It is also printed.
        Returns None when ``truncate_ep`` is 0.
    """
    police1 = Grid()
    police2 = Grid()
    gold = Grid()
    thief = Grid()
    reward = None  # defined even if the loop body never runs

    for _ in tqdm(range(truncate_ep), desc="Truncate Episodes", leave=False):
        state = police1.subtract(thief) + police2.subtract(thief) + gold.subtract(thief)
        action = q_table[state].argmax()  # greedy action
        thief.action(action)

        here = (thief.x, thief.y)
        if here in ((police1.x, police1.y), (police2.x, police2.y)):
            reward = police_penalty
        elif here == (gold.x, gold.y):
            reward = gold_penalty
        else:
            reward = move_penalty
        terminal = reward in (gold_penalty, police_penalty)

        if show:
            env = np.zeros((SIZE, SIZE, 3), dtype=np.uint8)  # 3 = number of RGB channels
            env[gold.x][gold.y] = d[gold_key]
            env[thief.x][thief.y] = d[thief_key]
            env[police1.x][police1.y] = d[police_key]
            env[police2.x][police2.y] = d[police_key]
            image = Image.fromarray(env, 'RGB').resize((300, 300))
            # cv2.imshow expects BGR; convert so the RGB colours in `d` show as intended.
            cv2.imshow("ENV", cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
            # `&` (not `and`): the old `waitKey(500) and 0xFF == ord('q')` was never true.
            if cv2.waitKey(500 if terminal else 1) & 0xFF == ord('q'):
                break

        if terminal:
            break

    print(reward)
    return reward


In [73]:
# Play one greedy episode with the trained table; eval() prints the last step's
# reward (gold_penalty, police_penalty, or move_penalty if neither was reached).
eval(q_table)

                                                                   

50


