In [345]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import plotly.express as px
from tqdm import tqdm
import copy
from deepdiff import DeepDiff
import random

## CREATE A MAZE

In [384]:
def create_maze(size):
    a,b = size
    random_array = np.random.rand(a,b)
    random_array = (random_array >= 0.8).astype(int)

    return np.dstack([random_array]*3)

maze = create_maze((10,10))
px.imshow(maze,zmin = 0,zmax = 1)

In [385]:
start = (0,0)
end = (9,9)

maze[start[0],start[1],:] = 0
maze[end[0],end[1],:] = 0

## RFL MODEL

In [390]:
class MazeSolver():
    def __init__(self,start,end,maze,round = 2):
        self.all_states = [(i,j) for i in range(maze.shape[0]) for j in range(maze.shape[1])]
        self.states = [(i,j) for (i,j) in self.all_states if not maze[i,j,:].mean() > 0 and (i,j) != end]
        self.values = {(i,j) : 0 for (i,j) in self.all_states if not maze[i,j,:].mean() > 0}
        self.round = round
        self.policy_iter()
        self.prev_policy = {}
        self.start = start
        self.end = end
        self.reward = -1
        self.discount_factor = 1
        self.delta = float('inf')
        self.val_matrix = np.zeros((maze.shape[0],maze.shape[1]))
        self.maze = maze

    def Delta(self):
        max = 0
        for state in self.states:
            diff = np.abs(self.values[state] - self.prev_values[state])
            if  diff > max:
                max = diff
        self.delta = max

    def policy_eval(self):
        while self.delta > 10**(-self.round):
            self.prev_values = copy.deepcopy(self.values)
            for state in self.states:
                temp = 0
                for action in self.transition_probs[state].keys():
                    nxt_state = tuple(np.array(state) + np.array(action))
                    temp +=  self.transition_probs[state][action]*(self.reward + self.discount_factor*self.values[nxt_state])
                self.values[state] = temp
                self.Delta()
                   
    def policy_iter(self):
        acns = [(0,1),(1,0),(-1,0),(0,-1)]
        self.transition_probs = {}
        self.policy = {}
        for state in self.states:
            values = []
            for ac in acns:
                nxt_state = tuple(np.array(state) + np.array(ac))
                if nxt_state in self.all_states:
                    a,b = nxt_state
                    if not maze[a,b,:].mean() > 0:
                        values.append(self.values[nxt_state])
                    else:
                        values.append(-float('inf'))
                else:
                    values.append(-float('inf'))
            values = np.array(values)
            values = np.round(values,self.round)
            feasible_acns = np.array(acns)[(values == np.max(values)) & (values != -float('inf'))].tolist()
            self.policy[state] = feasible_acns
            prob = 1/len(feasible_acns) if len(feasible_acns) > 0 else 0
            self.transition_probs[state] = {tuple(ac) : prob for ac in feasible_acns}

    def policy_stable(self):
        diff = DeepDiff(self.policy,self.prev_policy)
        return True if not diff else False

    def plot_path(self,max_count = 1000):
        maze = self.maze.copy()
        zmax = np.max(maze)
        maze[start[0],start[1],:] = [0,255,0]
        state = start
        count = 0
        while state != end and count < max_count:
            action = random.choice(self.policy[state])
            state = tuple(map(sum, zip(state,action)))
            a,b = state
            maze[a,b,:] = [0,0,255]
            count += 1
        maze[end[0],end[1],:] = [255,0,0]
        plt = px.imshow(maze,zmin = 0,zmax = zmax)
        plt.show()

    def run(self):
        count = 0
        while not self.policy_stable():
            self.policy_eval() 
            self.prev_policy = copy.deepcopy(self.policy)
            self.policy_iter()
            count += 1
            print(count,' pass done.')
        print('optimization done')


In [391]:
model = MazeSolver(start,end,maze,round = 1)

## Run

In [392]:
model.run()

1  pass done.
2  pass done.
optimization done


In [393]:
model.plot_path()