In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import re
import time
import math
import seaborn as sns
import yaml
import os
import datetime

from src.mazeworld import SimpleWallMazeWorld2
from src.mdp import MDP, LegibleTaskMDP, Utilities
from tqdm import tqdm
from itertools import combinations
from pathlib import Path
from typing import Dict, List, Tuple
np.set_printoptions(precision=9, linewidth=2000, threshold=10000, suppress=True)
%matplotlib notebook
%load_ext autoreload
%autoreload 2

In [None]:
def get_goal_states(states: np.ndarray, goal: str, with_objs: bool=True, goals: List[Tuple]=None):
    """Find the goal state(s) that correspond to a goal label.

    Args:
        states: world states, each one a string.
        goal: MDP goal, either a single goal like "A" or a composition of goals like "ABC".
        with_objs: whether the state strings themselves carry the objective label.
        goals: only read when ``with_objs`` is false; tuples whose first two entries are
            joined into the returned state string and whose third entry is the label
            that is searched for ``goal``.

    Returns:
        When ``with_objs`` is true, the list of positions in ``states`` of every state
        whose string contains ``goal`` (empty list when nothing matches).
        Otherwise the string ``"<g[0]> <g[1]>"`` built from the first entry of ``goals``
        whose label contains ``goal``, or ``None`` when no entry matches.
    """
    if with_objs:
        # enumerate yields each state's own position in a single pass; the previous
        # list.index lookup rescanned the list for every match and always reported
        # the first occurrence of a repeated state string.
        return [idx for idx, state in enumerate(states) if state.find(goal) != -1]

    # ``goals`` defaults to None (not a shared mutable list); treat it as "no candidates"
    for g in goals or ():
        if g[2].find(goal) != -1:
            return str(g[0]) + ' ' + str(g[1])
    return None

def get_initial_states(states: np.ndarray, task_locs: List[Tuple]):
    """Collect every state an agent may start from.

    Args:
        states: world states, each one a string.
        task_locs: tuples describing each possible goal in the scenario; the first
            two entries are stringified and the third is appended as-is.

    Returns:
        A list with every state whose string contains 'N', followed by one
        "<e[0]> <e[1]> <e[2]>" string per entry of ``task_locs``.
    """
    neutral_states = []
    for candidate in states:
        if 'N' in candidate:
            neutral_states.append(candidate)

    goal_location_states = []
    for loc in task_locs:
        goal_location_states.append(str(loc[0]) + ' ' + str(loc[1]) + ' ' + loc[2])

    return neutral_states + goal_location_states

def simulate(mdp: MDP, pol: np.ndarray, leg_mdp: LegibleTaskMDP, leg_pol: np.ndarray, x0: str,
             n_trajs: int, goal: int, rng_gen: np.random.Generator):
    """Compare an optimal MDP against a legible MDP over sampled trajectories.

    Args:
        mdp: optimal MDP.
        pol: optimal policy.
        leg_mdp: legible MDP.
        leg_pol: legible policy.
        x0: initial state for every trajectory.
        n_trajs: number of trajectories sampled from each model.
        goal: value forwarded as second argument to ``leg_mdp.trajectory_reward``
            (callers in this notebook pass the goal's index in the goal list).
        rng_gen: random number generator shared by both models.

    Returns:
        Tuple ``(mdp_r, mdp_rl, task_r, task_rl)``: the optimal-policy trajectories scored
        by ``mdp`` and by ``leg_mdp``, then the legible-policy trajectories scored by
        ``mdp`` and by ``leg_mdp``.
    """
    optimal_rollouts = []
    legible_rollouts = []

    for _ in tqdm(range(n_trajs), desc='Simulate Trajectories'):
        # the optimal model always draws from the generator before the legible one
        opt_states, opt_actions = mdp.trajectory(x0, pol, rng_gen)
        leg_states, leg_actions = leg_mdp.trajectory(x0, leg_pol, rng_gen)
        optimal_rollouts.append([opt_states, opt_actions])
        legible_rollouts.append([leg_states, leg_actions])

    # score both trajectory sets under both reward definitions
    mdp_r = mdp.trajectory_reward(optimal_rollouts)
    mdp_rl = leg_mdp.trajectory_reward(optimal_rollouts, goal)
    task_r = mdp.trajectory_reward(legible_rollouts)
    task_rl = leg_mdp.trajectory_reward(legible_rollouts, goal)

    return mdp_r, mdp_rl, task_r, task_rl

In [None]:
def create_world_view(n_rows: int, n_cols: int, obj_place: List, walls: List=None):
    """Build a figure showing the maze world grid, its goals and its walls.

    Args:
        n_rows: number of rows in the maze world.
        n_cols: number of columns in the maze world.
        obj_place: tuples ``(row, column, label)`` with the position of each goal
            and the label drawn at that position.
        walls: optional list of walls; each wall is a sequence of ``(row, column)``
            points that are joined by a thick line.

    Returns:
        The ``(fig, ax)`` pair of the newly created figure.
    """
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot()

    # one unit per cell; major ticks sit on the cell borders and carry the grid lines
    ax.set_xlim(0, n_cols)
    ax.set_ylim(0, n_rows)
    ax.set_xticks(list(range(1, n_cols + 1)))
    ax.set_yticks(list(range(1, n_rows + 1)))
    ax.tick_params(axis='x', which='both', bottom=False, top=False)
    ax.tick_params(axis='y', which='both', left=False, right=False)

    # hide the major labels and label the cell centres through the minor ticks instead
    ax.set_xticklabels('')
    col_centres = [c + 0.5 for c in range(n_cols + 1)]
    col_labels = [str(c) for c in range(1, n_cols + 1)]
    ax.xaxis.set_minor_locator(ticker.FixedLocator(col_centres))
    ax.xaxis.set_minor_formatter(ticker.FixedFormatter(col_labels))
    ax.set_yticklabels('')
    row_centres = [r + 0.5 for r in range(n_rows + 1)]
    row_labels = [str(r) for r in range(1, n_rows + 1)]
    ax.yaxis.set_minor_locator(ticker.FixedLocator(row_centres))
    ax.yaxis.set_minor_formatter(ticker.FixedFormatter(row_labels))
    ax.grid(True)

    # each goal is drawn as its label, centred in its cell
    for placed in obj_place:
        centre_x = placed[1] - 0.5
        centre_y = placed[0] - 0.5
        label_marker = '$' + placed[2] + '$'
        ax.plot(centre_x, centre_y, marker=label_marker, color='k', markersize=10)

    if walls:
        for segment in walls:
            points = np.array([list(point) for point in segment])
            # column index 1 is the horizontal coordinate, index 0 the vertical one
            ax.plot(points[:, 1] - 0.5, points[:, 0] - 0.5, color='k', linewidth=5)

    return fig, ax

def visualize_trajectory(x0: Tuple, trajectory: np.ndarray, figure: plt.Figure, color: str, goal: Tuple, zorder: int):
    """Draw a trajectory on a maze world figure.

    Args:
        x0: trajectory's initial position; ``x0[0]`` is used as the horizontal
            coordinate and ``x0[1]`` as the vertical one.
        trajectory: sequence of ``(x, y, action)`` steps to draw.
        figure: the ``(fig, ax)`` pair returned by ``create_world_view`` (only the
            axes, ``figure[1]``, is used).
        color: color used for the trajectory's arrows and circles.
        goal: trajectory's final position; ``goal[1]`` is used as the horizontal
            coordinate and ``goal[0]`` as the vertical one.
        zorder: drawing order of the arrows, useful when several trajectories
            share the same figure.
    """
    # displacement drawn for each action; actions without movement map to (0, 0)
    actions = {'U': (0, 0.8), 'D': (0, -0.8), 'L':(-0.8, 0), 'R':(0.8, 0), 'G': (0, 0), 'P':(0, 0), 'N': (0, 0)}
    ax = figure[1]

    # Draw the start/goal markers on the axes that was handed in. Going through
    # pyplot would target whichever figure happens to be current, which is not
    # necessarily the one the arrows below are added to.
    ax.plot(x0[0]-0.5, x0[1]-0.5, marker='o', markersize=15, color='dimgrey', zorder=0)
    ax.plot(goal[1]-0.5, goal[0]-0.5, marker='x', markersize=20, color='gold')
    for ptr in trajectory:
        x = ptr[0]-0.5
        y = ptr[1]-0.5
        a = actions[ptr[2]]
        if ptr[2] in ['U', 'D', 'L', 'R']:
            # movement actions are arrows pointing towards the next cell
            ax.arrow(x, y, a[0], a[1], head_width=0.1, head_length=0.1, lw=1.5, fc=color, ec=color, zorder=zorder)
        else:
            # every other action is shown as an empty circle on the current cell
            ax.add_patch(plt.Circle((x, y), 0.3, linewidth=1.7, fill=False, color=color))

def draw_policy_states(figure: plt.Figure, pol: np.ndarray, goal_states: List, states: np.ndarray, 
                       action_lst: List, color: str, objs: bool=True):
    """Draw every action a policy may choose in each state.

    Args:
        figure: the ``(fig, ax)`` pair of the maze world figure (only the axes,
            ``figure[1]``, is used).
        pol: policy to draw, indexed as ``pol[state_index, action_index]``; every
            non-zero entry is drawn.
        goal_states: indexes of the goal states, highlighted with a gold disc.
        states: world states, strings starting with "<row> <column>".
        action_lst: list of the world actions, indexed like the policy's columns.
        color: color used to draw the policy.
        objs: whether the state strings include a trailing object/goal label.
    """
    actions = {'U': (0, 0.8), 'D': (0, -0.8), 'L':(-0.8, 0), 'R':(0.8, 0), 'G': (0, 0), 'P':(0, 0), 'N': (0, 0)}
    ax = figure[1]

    # the state format does not change inside the loop, so pick the pattern once
    if objs:
        state_pattern = r"([0-9]+) ([0-9]+) ([a-z]+)"
    else:
        state_pattern = r"([0-9]+) ([0-9]+)"

    # enumerate provides each state's index directly instead of searching the
    # state list again for every state
    for state_idx, state in enumerate(states):
        state_split = re.match(state_pattern, state, re.I)
        y = int(state_split.group(1)) - 0.5
        x = int(state_split.group(2)) - 0.5
        if state_idx in goal_states:
            ax.add_patch(plt.Circle((x, y), 0.3, linewidth=1.7, color='gold', zorder=0))
        # indexes of the actions with non-zero entries in the policy for this state
        pol_actions = np.nonzero(pol[state_idx, :])[0]
        for action in pol_actions:
            act = action_lst[action]
            a = actions[act]
            if act in ['U', 'D', 'L', 'R']:
                ax.arrow(x, y, a[0], a[1], head_width=0.1, head_length=0.1, lw=1.5, fc=color, ec=color)
            else:
                ax.add_patch(plt.Circle((x, y), 0.3, linewidth=1.7, fill=False, color=color))

def process_trajectory(trajectory: np.ndarray, actions: np.ndarray, objs: bool=True):
    """Merge a state sequence and an action sequence into state-action steps.

    Args:
        trajectory: sequence of state strings starting with "<row> <column>".
        actions: sequence of actions, one per state except the last.
        objs: whether the state strings include a trailing object/goal label.

    Returns:
        List of ``(x, y, action)`` tuples, where ``x`` is the second number of the
        state string and ``y`` the first. The final state of ``trajectory`` is
        not included.
    """
    pattern = r"([0-9]+) ([0-9]+) ([a-z]+)" if objs else r"([0-9]+) ([0-9]+)"
    steps = []

    # the last state has no action associated, so it is left out
    for step_idx, state in enumerate(trajectory[:-1]):
        parsed = re.match(pattern, state, re.I)
        row = int(parsed.group(1))
        col = int(parsed.group(2))
        steps.append((col, row, actions[step_idx]))

    return steps

In [None]:
# Grid size of the maze world
n_rows = 10
n_cols = 10

# Location of every possible goal, as (row, column, label)
objs_states = [(9, 3, 'P'), (9, 1, 'D'), (1, 6, 'C'), (5, 4, 'L'), (10, 8, 'T'), (7, 9, 'O')]

# Wall segments; each one is a list of (row, column) points joined by a line when drawn
walls = [
         # outer boundary of the world
         [(0.5, col + 0.5) for col in range(n_cols + 1)],
         [(n_rows + 0.5, col + 0.5) for col in range(n_cols + 1)],
         [(row + 0.5, 0.5) for row in range(n_rows + 1)],
         [(row + 0.5, n_cols + 0.5) for row in range(n_rows + 1)],

         # interior segments that keep the column coordinate fixed
         [(row + 0.5, 2.5) for row in range(2, 6)],
         [(row + 0.5, 2.5) for row in range(6, 10)],
         [(row + 0.5, 6.5) for row in range(3)],
         [(row + 0.5, 8.5) for row in range(3, 8)],
         [(row + 0.5, 8.5) for row in range(8, 11)],

         # remaining interior segments
         [(0.5, 2.5), (1.5, 2.5)],
         [(3.5, 6.5), (3.5, 7.5)],
         [(3.5, 0.5), (3.5, 1.5)],
         [(3.5, col + 0.5) for col in range(2, 6)],
         [(3.5, col + 0.5) for col in range(8, 10)],
         [(8.5, col + 0.5) for col in range(2, 4)],
         [(8.5, col + 0.5) for col in range(4, 8)]]

# State every trajectory starts from unless overridden
x0 = '1 1 N'

# Labels of the possible goals, and the goal used by default
goals = ['P', 'D', 'C', 'L', 'T', 'O']
goal = 'D'

# Show the resulting world
fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

In [None]:
# Grid size of the maze world
n_rows = 10
n_cols = 10

# Location of every possible goal, as (row, column, label)
objs_states = [(1, 7, 'P'), (1, 10, 'D'), (10, 10, 'C'), (10, 2, 'L'), (10, 6, 'T'), (7, 1, 'O')]

# Wall segments; each one is a list of (row, column) points joined by a line when drawn
walls = [
         # outer boundary of the world
         [(0.5, col + 0.5) for col in range(n_cols + 1)],
         [(n_rows + 0.5, col + 0.5) for col in range(n_cols + 1)],
         [(row + 0.5, 0.5) for row in range(n_rows + 1)],
         [(row + 0.5, n_cols + 0.5) for row in range(n_rows + 1)],

         # interior segments that keep the column coordinate fixed
         [(0.5, 3.5), (1.5, 3.5)],
         [(row + 0.5, 3.5) for row in range(3, 6)],
         [(row + 0.5, 3.5) for row in range(6, 9)],
         [(row + 0.5, 4.5) for row in range(4)],
         [(row + 0.5, 4.5) for row in range(9, 11)],
         [(row + 0.5, 7.5) for row in range(9, 11)],
         [(row + 0.5, 8.5) for row in range(4)],
         [(row + 0.5, 8.5) for row in range(7, 10)],
         [(row + 0.5, 8.5) for row in range(5, 7)],

         # interior segments that keep the row coordinate fixed
         [(2.5, col + 0.5) for col in range(4)],
         [(2.5, col + 0.5) for col in range(9, 11)],
         [(8.5, col + 0.5) for col in range(2)],
         [(8.5, col + 0.5) for col in range(2, 4)],
         [(8.5, col + 0.5) for col in range(4, 6)],
         [(8.5, col + 0.5) for col in range(6, 8)],
         [(5.5, col + 0.5) for col in range(8, 11)],
         [(3.5, col + 0.5) for col in range(2)],
         [(3.5, col + 0.5) for col in range(2, 4)],
         [(3.5, col + 0.5) for col in range(4, 6)],
         [(3.5, col + 0.5) for col in range(6, 8)]
        ]

# State every trajectory starts from unless overridden
x0 = '1 1 N'

# Labels of the possible goals, and the goal used by default
goals = ['P', 'D', 'C', 'L', 'T', 'O']
goal = 'T'

# Show the resulting world
fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

In [None]:
# Grid size of the maze world
n_rows = 10
n_cols = 10

# Location of every possible goal, as (row, column, label)
objs_states = [(1, 7, 'P'), (10, 10, 'D'), (7, 10, 'C'), (9, 1, 'L'), (9, 5, 'T'), (5, 1, 'O')]

# Wall segments; each one is a list of (row, column) points joined by a line when drawn
walls = [
         # outer boundary of the world
         [(0.5, col + 0.5) for col in range(n_cols + 1)],
         [(n_rows + 0.5, col + 0.5) for col in range(n_cols + 1)],
         [(row + 0.5, 0.5) for row in range(n_rows + 1)],
         [(row + 0.5, n_cols + 0.5) for row in range(n_rows + 1)],

         # interior segments that keep the column coordinate fixed
         [(row + 0.5, 6.5) for row in range(4, 7)],
         [(row + 0.5, 4.5) for row in range(4, 7)],
         [(row + 0.5, 2.5) for row in range(3, 6)],
         [(row + 0.5, 2.5) for row in range(6, 8)],
         [(row + 0.5, 1.5) for row in range(8, 10)],
         [(row + 0.5, 3.5) for row in range(2)],
         [(row + 0.5, 3.5) for row in range(8, 11)],
         [(row + 0.5, 4.5) for row in range(3)],
         [(row + 0.5, 7.5) for row in range(3)],
         [(row + 0.5, 7.5) for row in range(8, 11)],
         [(row + 0.5, 8.5) for row in range(4)],
         [(row + 0.5, 8.5) for row in range(4, 7)],
         [(row + 0.5, 8.5) for row in range(8, 10)],

         # interior segments that keep the row coordinate fixed
         [(2.5, col + 0.5) for col in range(3)],
         [(2.5, col + 0.5) for col in range(4, 6)],
         [(2.5, col + 0.5) for col in range(6, 8)],
         [(3.5, col + 0.5) for col in range(3)],
         [(3.5, col + 0.5) for col in range(9, 11)],
         [(4.5, col + 0.5) for col in range(9, 11)],
         [(4.5, col + 0.5) for col in range(4, 7)],
         [(6.5, col + 0.5) for col in range(4, 7)],
         [(7.5, col + 0.5) for col in range(3)],
         [(7.5, col + 0.5) for col in range(8, 11)],
         [(8.5, col + 0.5) for col in range(2)],
         [(8.5, col + 0.5) for col in range(3, 6)],
         [(8.5, col + 0.5) for col in range(6, 8)],
         [(8.5, col + 0.5) for col in range(8, 10)]
        ]

# State every trajectory starts from unless overridden
x0 = '1 1 N'

# Labels of the possible goals, and the goal used by default
goals = ['P', 'D', 'C', 'L', 'T', 'O']
goal = 'C'

# Show the resulting world
fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

In [None]:
# Grid size of the maze world
n_rows = 10
n_cols = 10

# Location of every possible goal, as (row, column, label)
objs_states = [(8, 3, 'P'), (5, 7, 'D'), (5, 2, 'C'), (8, 7, 'T'), (1, 4, 'L'), (1, 7, 'O')]

# Wall segments; each one is a list of (row, column) points joined by a line when drawn
walls = [
         # outer boundary of the world
         [(0.5, col + 0.5) for col in range(n_cols + 1)],
         [(n_rows + 0.5, col + 0.5) for col in range(n_cols + 1)],
         [(row + 0.5, 0.5) for row in range(n_rows + 1)],
         [(row + 0.5, n_cols + 0.5) for row in range(n_rows + 1)],

         # interior segments that keep the column coordinate fixed
         [(row + 0.5, 1.5) for row in range(2, 7)],
         [(row + 0.5, 1.5) for row in range(7, 10)],
         [(row + 0.5, 3.5) for row in range(3)],
         [(row + 0.5, 3.5) for row in range(3, 7)],
         [(row + 0.5, 3.5) for row in range(7, 9)],
         [(row + 0.5, 5.5) for row in range(3)],
         [(row + 0.5, 5.5) for row in range(7, 9)],
         [(row + 0.5, 7.5) for row in range(3)],
         [(row + 0.5, 6.5) for row in range(4, 7)],
         [(row + 0.5, 8.5) for row in range(4, 6)],
         [(row + 0.5, 8.5) for row in range(7, 9)],
         [(row + 0.5, 9.5) for row in range(2, 7)],
         [(row + 0.5, 9.5) for row in range(7, 10)],

         # interior segments that keep the row coordinate fixed
         [(2.5, col + 0.5) for col in range(1, 3)],
         [(2.5, col + 0.5) for col in range(4, 7)],
         [(3.5, col + 0.5) for col in range(2, 4)],
         [(4.5, col + 0.5) for col in range(6, 8)],
         [(6.5, col + 0.5) for col in range(2, 4)],
         [(6.5, col + 0.5) for col in range(6, 9)],
         [(7.5, col + 0.5) for col in range(1, 4)],
         [(7.5, col + 0.5) for col in range(5, 9)],
         [(9.5, col + 0.5) for col in range(1, 5)],
         [(9.5, col + 0.5) for col in range(5, 10)],
        ]

# State every trajectory starts from unless overridden
x0 = '1 1 N'

# Labels of the possible goals (same order as objs_states), and the goal used by default
goals = ['P', 'D', 'C', 'T', 'L', 'O']
goal = 'T'

# Length of the longest goal label plus two
max_goal_len = max(len(label) for label in goals) + 2

# Show the resulting world
fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

In [None]:
print('##########################################')
print('#####  Wall Auto Collect Maze World  #####')
print('##########################################')
beta = 0.5                      # used in the LegibleTaskMDP to define how close to the optimal to follow

wacmw = SimpleWallMazeWorld2()  # type of world to use, this is a simplified version of the maze world, 
                                # where each goal is just to "visit" one location instead of a sequence of locations
                                # and the robot just needs to pass by the location to consider it "visited"

# Build the world (states, actions, transition probabilities) from the definitions of the previous cells
X_w, A_w, P_w = wacmw.generate_world(n_rows, n_cols, objs_states, walls, 'stochastic', 0.15)

with_objs = True                # world states include if the state has a possible goal or if it is neutral "N"

goal_states = get_goal_states(X_w, goal) # possible goal state indexes

print('### Computing Costs and Creating Task MDPs ###')
legible_function = 'leg_optimal' # there are two legible functions possible: 'leg_optimal' and 'leg_weight'
                                 # 'leg_optimal' is the simpler and default legible function, 'leg_weight'
                                 # is more complex but in most cases not needed
opt_mdps = {}                    # dictionary that stores the optimal MDPs for each goal, keyed 'mdp1', 'mdp2', ...
opt_v_mdps = {}                  # dictionary that stores the expected optimal reward for each state (used by the 'leg_weight' function)
opt_q_mdps = {}                  # dictionary that stores the expected optimal reward for each state-action pair
leg_mdps = {}                    # dictionary that stores the legible MDPs for each goal
dists = []                       # list of average distance between each state and each goal state (used by the 'leg_weight' function)
# NOTE(review): seeding from the clock makes every run different; use a fixed seed to reproduce results
rng_gen = np.random.default_rng(int(time.time())) # random number generator used for randomization calls
verbosity = False                # defines if we want extra info from the MDPs when computing the final policies

print('### Defining the MDPs for each different goal ###')
for i in range(len(goals)):
    # Generate rewards for each goal and corresponding MDP
    c = wacmw.generate_rewards(goals[i], X_w, A_w)
    # MDP constructor arguments, in order:
    #    x - world states
    #    a - world actions
    #    p - world state transition probabilities
    #    c - MDP reward function
    #    gamma - discount factor
    #    goal_states - list of the MDP's goal states
    #    feedback_type - 'costs' or 'rewards' depending on if we are using costs or rewards as the reward function
    #    verbose - defines if we want extra information from internal methods
    mdp = MDP(X_w, A_w, P_w, c, 0.9, get_goal_states(X_w, goals[i]), 'rewards', verbosity)
    
    # Find optimal policy, optimal q-values and optimal expected state rewards
    pol, q = mdp.policy_iteration()
    v = Utilities.v_from_q(q, pol)
    
    # Store q-values, expected rewards and optimal MDP
    opt_q_mdps[goals[i]] = q
    opt_v_mdps[goals[i]] = v
    #dists += [mdp.policy_dist(pol, rng_gen)] # uncomment this line if you use the 'leg_weight' legible function
    opt_mdps['mdp' + str(i + 1)] = mdp
dists = np.array(dists)          # stays empty unless the policy_dist line above is enabled

print('### Defining Legible MDP for intended goal ###')
# LegibleTaskMDP constructor arguments, in order:
#    x - world states
#    a - world actions
#    p - world state transition probabilities
#    gamma - discount factor
#    verbose - defines if we want extra information from internal methods
#    task - legible MDP's goal
#    task_states - list of each goal's maze world position and label
#    tasks - list of possible goals
#    beta - constant that defines how close the legible policy follows the optimal policy 
#    goal_states - list of the legible MDP's goal states
#    sign - 1 or -1 whether we are using rewards or costs as the optimal MDP's reward function
#    leg_func - legible function being used
#    q_mdps - optimal q-values for each goal
#    v_mdps - optimal expected state rewards for each goal (used only by the 'leg_weight' function)
#    dists - average distance between each state and each goal (used only by the 'leg_weight' function)
leg_mdp = LegibleTaskMDP(X_w, A_w, P_w, 0.9, verbosity, goal, objs_states, goals, beta, goal_states, 1, 
                         legible_function, q_mdps=opt_q_mdps, v_mdps=opt_v_mdps, dists=dists)

print('### Computing Optimal policy ###')
time1 = time.time()
opt_pol, opt_q = opt_mdps['mdp' + str(goals.index(goal) + 1)].policy_iteration()
print('Took %.3f seconds to compute policy' % (time.time() - time1))

print('### Computing Legible policy ###')
time1 = time.time()
# The legible MDP created above is `leg_mdp`; the name previously used here was
# never defined in this notebook and failed on a fresh kernel.
leg_pol, leg_q = leg_mdp.policy_iteration(goals.index(goal))
print('Took %.3f seconds to compute policy' % (time.time() - time1))

In [None]:
x0 = '1 1 N'
goal_idx = goals.index(goal)                   # position of the intended goal in the goal list
goal_state = objs_states[goal_idx]             # maze position and label of the intended goal
opt_mdp = opt_mdps['mdp' + str(goal_idx + 1)]  # optimal MDP of the intended goal
now = datetime.datetime.now()
print('Initial State: ' + x0)

# Get optimal trajectory
print('Optimal trajectory for task: ' + goal)
rng_gen = np.random.default_rng(int(time.time()))
opt_states, opt_actions = opt_mdp.trajectory(x0, opt_pol, rng_gen)
print('Trajectory: ' + str(opt_states))
print('Cost: ' + str(opt_mdp.trajectory_reward([[opt_states, opt_actions]])))
# `leg_mdp` is the legible MDP built for this goal; the name used here before was undefined
print('Legible Reward: ' + str(leg_mdp.trajectory_reward([[opt_states, opt_actions]], goal_idx)))
opt_traj = process_trajectory(opt_states, opt_actions)

# Get legible trajectory
print('Legible trajectory for task: ' + goal)
rng_gen = np.random.default_rng(int(time.time()))
leg_states, leg_act = leg_mdp.trajectory(x0, leg_pol, rng_gen)
print('Trajectory: ' + str(leg_states))
print('Cost: ' + str(opt_mdp.trajectory_reward([[leg_states, leg_act]])))
print('Legible Reward: ' + str(leg_mdp.trajectory_reward([[leg_states, leg_act]], goal_idx)))
leg_traj = process_trajectory(leg_states, leg_act)

# Visualize both trajectories on the same figure (optimal in blue, legible in black on top)
figure = create_world_view(n_rows, n_cols, objs_states, walls)
visualize_trajectory(opt_traj[0], opt_traj, figure, 'b', goal_state, zorder=0)
visualize_trajectory(leg_traj[0], leg_traj, figure, 'k', goal_state, zorder=1)
fig, _ = figure
fig.show()

# Simulate model performance
print('Getting model performance!!')
rng_gen = np.random.default_rng(int(time.time()))
clock_1 = time.time()
mdp_r, mdp_rl, leg_mdp_r, leg_mdp_rl = simulate(opt_mdp, opt_pol,
                                                leg_mdp, leg_pol, x0, 1000, goal_idx, rng_gen)
time_simulation = time.time() - clock_1
print('Simulation length = %.3f' % time_simulation)
print('Optimal Policy performance:\nCost: %.3f\nLegible Reward: %.3f' % (mdp_r, mdp_rl))
print('Legible Policy performance:\nCost: %.3f\nLegible Reward: %.3f' % (leg_mdp_r, leg_mdp_rl))

In [None]:
# Colors cycled through when drawing several trajectories; the last one is reused once exhausted
colors = ['blue', 'darkred', 'green', 'black', 'orange', 'pink', 'yellow', 'magenta', 
          'brown', 'cyan', 'khaki', 'olivedrab', 'lightcoral']
x0 = '1 1 N'
goal = 'O'
goal_idx = goals.index(goal)
# The goal changed in this cell, so the plotted goal position has to be looked up
# again; otherwise the marker would still show the goal of the previous cell.
goal_state = objs_states[goal_idx]

# Optimal and legible models/policies for the goal selected above
opt_mdp = opt_mdps['mdp' + str(goal_idx + 1)]
opt_pol, _ = opt_mdp.policy_iteration()
leg_mdp = LegibleTaskMDP(X_w, A_w, P_w, 0.9, True, goal, objs_states, goals, beta, 
                         get_goal_states(X_w, goal), 1, 'leg_optimal', q_mdps=opt_q_mdps, 
                         v_mdps=opt_v_mdps, dists=dists)
leg_pol, _ = leg_mdp.policy_iteration(goal_idx)

# Visualize all possible optimal trajectories for given starting state
# (the generator is created before it is used)
rng_gen = np.random.default_rng(int(time.time()))
trajs, a_trajs = opt_mdp.all_trajectories(x0, opt_pol, rng_gen)
print('Optimal trajectories: %d' % len(trajs))
figure = create_world_view(n_rows, n_cols, objs_states, walls)
for j, traj in enumerate(trajs):
    a_traj = a_trajs[j]
    p_traj = process_trajectory(traj, a_traj)
    # earlier trajectories get a higher zorder, so they are drawn on top
    visualize_trajectory(p_traj[0], p_traj, figure, colors[min(j, len(colors) - 1)], goal_state,
                         zorder=(len(trajs) - j))
fig, _ = figure
fig.show()

# Visualize all possible legible trajectories for given starting state
rng_gen = np.random.default_rng(int(time.time()))
leg_trajs, leg_a_trajs = leg_mdp.all_trajectories(x0, leg_pol, rng_gen)
print('Legible trajectories: %d' % len(leg_trajs))
figure = create_world_view(n_rows, n_cols, objs_states, walls)
for j, traj in enumerate(leg_trajs):
    a_traj = leg_a_trajs[j]
    p_traj = process_trajectory(traj, a_traj)
    # zorder is derived from the number of legible trajectories being drawn
    visualize_trajectory(p_traj[0], p_traj, figure, colors[min(j, len(colors) - 1)], goal_state,
                         zorder=(len(leg_trajs) - j))
fig, _ = figure
fig.show()

In [None]:
states = X_w
print('Goal: ' + goal)
# Recompute the goal state indexes for the goal currently selected, since `goal`
# may have been changed after `goal_states` was first computed.
goal_states = get_goal_states(X_w, goal)

# Visualize optimal policy
# (`opt_pol` is the optimal policy computed earlier in the notebook; the name
# previously used here was never defined)
opt_figure = create_world_view(n_rows, n_cols, objs_states, walls)
draw_policy_states(opt_figure, opt_pol, goal_states, states, list(A_w), 'b', with_objs)
fig, _ = opt_figure
fig.show()

# Visualize legible policy
# (`leg_pol` is the legible policy computed earlier in the notebook)
leg_figure = create_world_view(n_rows, n_cols, objs_states, walls)
draw_policy_states(leg_figure, leg_pol, goal_states, states, list(A_w), 'k', with_objs)
fig, _ = leg_figure
fig.show()