In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import re
import time
import math
import seaborn as sns
import yaml
import os

from src.mazeworld import WallAutoCollectMazeWorld, LimitedCollectWallMazeWorld, SimpleWallMazeWorld, SimpleWallMazeWorld2
from src.mdp import MDP, LegibleTaskMDP, LearnerMDP, Utilities, MiuraLegibleMDP
from tqdm import tqdm
from itertools import combinations
from pathlib import Path
np.set_printoptions(precision=9, linewidth=2000, threshold=10000, suppress=True)
%matplotlib notebook
%load_ext autoreload
%autoreload 2

In [2]:
def get_goal_states(states, goal, with_objs=True, goals=None):
    """Locate the goal state(s) for a task.

    Parameters
    ----------
    states : sequence of str
        State labels such as ``'3 8 C'``.
    goal : str
        Task/object label; matched as a substring of each label.
    with_objs : bool, default True
        If True, return the indices (positions in ``states``) of every state
        whose label contains ``goal``.  If False, search ``goals`` instead.
    goals : iterable of (row, col, label, ...) tuples, optional
        Object placements; only consulted when ``with_objs`` is False.

    Returns
    -------
    list of int
        When ``with_objs`` is True (possibly empty).
    str or None
        When ``with_objs`` is False: ``'row col'`` of the first entry of
        ``goals`` whose label contains ``goal``, or None if there is none.
    """
    if with_objs:
        # enumerate yields each position directly; calling list.index() for
        # every state made this quadratic in the number of states.
        return [idx for idx, state in enumerate(states) if goal in state]

    # None default avoids a shared mutable [] default argument.
    for obj in (goals if goals is not None else []):
        if goal in obj[2]:
            return str(obj[0]) + ' ' + str(obj[1])
    return None

def get_initial_states(states, task_locs):
    """Collect candidate start states.

    Returns every label in ``states`` that contains ``'N'`` (in their original
    order), followed by one ``'row col label'`` string per entry of
    ``task_locs`` (each entry indexed as ``(row, col, label)``).
    """
    plain_states = [s for s in states if 'N' in s]
    located_tasks = []
    for loc in task_locs:
        located_tasks.append(str(loc[0]) + ' ' + str(loc[1]) + ' ' + loc[2])
    return plain_states + located_tasks

def simulate(mdp, pol, mdp_tasks, leg_pol, x0, n_trajs, goal):
    """Roll out two policies from ``x0`` and score both sets of rollouts.

    For each of ``n_trajs`` repetitions, one trajectory is sampled with
    ``mdp.trajectory(x0, pol)`` and one with ``mdp_tasks.trajectory(x0, leg_pol)``.

    Returns
    -------
    tuple
        ``(mdp_r, mdp_rl, task_r, task_rl)``:
        ``mdp.trajectory_reward`` and ``mdp_tasks.trajectory_reward(..., goal)``
        of the ``pol`` rollouts, then the same two scores of the ``leg_pol``
        rollouts.
    """
    opt_rollouts = []
    leg_rollouts = []
    for _ in tqdm(range(n_trajs), desc='Simulate Trajectories'):
        opt_states, opt_actions = mdp.trajectory(x0, pol)
        leg_states, leg_actions = mdp_tasks.trajectory(x0, leg_pol)
        opt_rollouts.append([opt_states, opt_actions])
        leg_rollouts.append([leg_states, leg_actions])

    # Tuple elements are evaluated left to right, in the same order as before.
    return (mdp.trajectory_reward(opt_rollouts),
            mdp_tasks.trajectory_reward(opt_rollouts, goal),
            mdp.trajectory_reward(leg_rollouts),
            mdp_tasks.trajectory_reward(leg_rollouts, goal))

In [3]:
def create_world_view(n_rows, n_cols, obj_place, walls=None):
    """Draw an ``n_rows`` x ``n_cols`` grid world and return ``(fig, ax)``.

    Cell ``(row, col)`` is drawn centred at ``(col - 0.5, row - 0.5)``.  Grid
    lines sit on the integer (major) ticks, whose labels are blanked; cell
    numbers 1..n are shown on minor ticks at the cell centres.

    Parameters
    ----------
    n_rows, n_cols : int
        Grid size.
    obj_place : iterable of (row, col, label, ...)
        Each label is drawn as a mathtext marker in its cell.
    walls : iterable of polylines, optional
        Each polyline is a sequence of ``(row, col)`` points; drawn as thick
        black lines after the same -0.5 shift.
    """
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot()

    ax.set_xlim(0, n_cols)
    ax.set_ylim(0, n_rows)
    ax.set_xticks(list(range(1, n_cols + 1)))
    ax.set_yticks(list(range(1, n_rows + 1)))
    ax.tick_params(axis='x', which='both', bottom=False, top=False)
    ax.tick_params(axis='y', which='both', left=False, right=False)

    col_centres = [c + 0.5 for c in range(n_cols + 1)]
    row_centres = [r + 0.5 for r in range(n_rows + 1)]
    ax.set_xticklabels('')
    ax.xaxis.set_minor_locator(ticker.FixedLocator(col_centres))
    ax.xaxis.set_minor_formatter(ticker.FixedFormatter([str(c) for c in range(1, n_cols + 1)]))
    ax.set_yticklabels('')
    ax.yaxis.set_minor_locator(ticker.FixedLocator(row_centres))
    ax.yaxis.set_minor_formatter(ticker.FixedFormatter([str(r) for r in range(1, n_rows + 1)]))
    ax.grid(True)

    for obj in obj_place:
        ax.plot(obj[1] - 0.5, obj[0] - 0.5, marker='$' + obj[2] + '$', color='k', markersize=10)

    if walls:
        for wall in walls:
            points = np.array([list(p) for p in wall])
            ax.plot(points[:, 1] - 0.5, points[:, 0] - 0.5, color='k', linewidth=5)

    return fig, ax


def visualize_trajectory(x0, trajectory, figure, color, goal, zorder):
    """Overlay one processed trajectory on a grid-world figure.

    Parameters
    ----------
    x0 : tuple
        Start point laid out as ``(col, row, ...)`` (the layout produced by
        ``process_trajectory``); drawn as a grey dot.
    trajectory : iterable of (col, row, action)
        Steps as produced by ``process_trajectory``.  Moves 'U', 'D', 'L', 'R'
        are drawn as arrows; any other action is drawn as a ring.
    figure : (fig, ax) tuple, as returned by ``create_world_view``.
    color : matplotlib colour for arrows and rings.
    goal : (row, col, ...) tuple, e.g. an ``objs_states`` entry; marked with a gold cross.
    zorder : stacking order of this trajectory's arrows and rings.

    Raises
    ------
    KeyError
        If a step's action is not one of U, D, L, R, G, P, N.
    """
    actions = {'U': (0, 0.8), 'D': (0, -0.8), 'L': (-0.8, 0), 'R': (0.8, 0), 'G': (0, 0), 'P': (0, 0), 'N': (0, 0)}
    ax = figure[1]

    # Draw on the axes that was passed in rather than pyplot's "current" axes,
    # which may belong to a different figure.
    ax.plot(x0[0] - 0.5, x0[1] - 0.5, marker='o', markersize=15, color='dimgrey', zorder=0)
    ax.plot(goal[1] - 0.5, goal[0] - 0.5, marker='x', markersize=20, color='gold')
    for ptr in trajectory:
        x = ptr[0] - 0.5
        y = ptr[1] - 0.5
        act = ptr[2]
        dx, dy = actions[act]
        if act in ('U', 'D', 'L', 'R'):
            ax.arrow(x, y, dx, dy, head_width=0.1, head_length=0.1, lw=1.5, fc=color, ec=color, zorder=zorder)
        else:
            # Rings share the trajectory's zorder so overlaid trajectories stack consistently.
            ax.add_patch(plt.Circle((x, y), 0.3, linewidth=1.7, fill=False, color=color, zorder=zorder))

def draw_policy_states(figure, pol, goal_states, state_lst, action_lst, states, color, objs=True):
    """Draw, for each state in ``states``, every action the policy gives non-zero weight.

    Parameters
    ----------
    figure : (fig, ax) tuple, as returned by ``create_world_view``.
    pol : 2-D array indexed ``[state_index, action_index]``.
    goal_states : collection of state indices highlighted with a gold disc.
    state_lst : list of all state labels; a label's position is its row in ``pol``.
    action_lst : list of action labels; a label's position is its column in ``pol``.
    states : iterable of the state labels to draw (each must be in ``state_lst``).
    color : colour for arrows/rings.
    objs : if True, labels are parsed as ``'row col obj'``; otherwise only ``'row col'``.

    Raises
    ------
    ValueError
        If a state is not in ``state_lst`` or its label cannot be parsed.
    """
    actions = {'U': (0, 0.8), 'D': (0, -0.8), 'L': (-0.8, 0), 'R': (0.8, 0), 'G': (0, 0), 'P': (0, 0), 'N': (0, 0)}
    ax = figure[1]
    pattern = r"([0-9]+) ([0-9]+) ([a-z]+)" if objs else r"([0-9]+) ([0-9]+)"
    # Set membership instead of scanning the goal list for every drawn state.
    goal_set = set(goal_states)

    for state in states:
        # Looked up once; the index was previously recomputed for the goal check.
        state_idx = state_lst.index(state)
        state_split = re.match(pattern, state, re.I)
        if state_split is None:
            raise ValueError('Cannot parse state label %r' % (state,))
        y = int(state_split.group(1)) - 0.5
        x = int(state_split.group(2)) - 0.5
        if state_idx in goal_set:
            ax.add_patch(plt.Circle((x, y), 0.3, linewidth=1.7, color='gold', zorder=0))
        for action in np.nonzero(pol[state_idx, :])[0]:
            act = action_lst[action]
            dx, dy = actions[act]
            if act in ('U', 'D', 'L', 'R'):
                ax.arrow(x, y, dx, dy, head_width=0.1, head_length=0.1, lw=1.5, fc=color, ec=color)
            else:
                ax.add_patch(plt.Circle((x, y), 0.3, linewidth=1.7, fill=False, color=color))
        
def process_trajectory(trajectory, actions, objs=True):
    """Turn state labels plus actions into drawable ``(col, row, action)`` steps.

    ``trajectory[i]`` is paired with ``actions[i]`` for every state except the
    last.  Each ``'row col [obj]'`` label is parsed to integers and emitted
    column-first so the step can be used directly as ``(x, y)``.
    """
    pattern = r"([0-9]+) ([0-9]+) ([a-z]+)" if objs else r"([0-9]+) ([0-9]+)"
    steps = []
    for step_idx, state in enumerate(trajectory[:-1]):
        found = re.match(pattern, state, re.I)
        row = int(found.group(1))
        col = int(found.group(2))
        steps.append((col, row, actions[step_idx]))
    return steps

In [4]:
# Environment: 8x8 maze with 5 labelled objects.
# objs_states: (row, col, label) per object. walls: polylines of (row, col) points on
# half-integer cell borders; the first four entries are the outer boundary.
# x0: start state label; goal: the task label used by later cells.
n_rows = 8
n_cols = 8
# objs_states = [(7, 2, 'P'), (4, 4, 'D'), (4, 1, 'C'), (8, 1, 'L'), (6, 7, 'T'), (8, 8, 'O')]
objs_states = [(7, 2, 'P'), (4, 4, 'D'), (4, 1, 'C'), (8, 1, 'L'), (6, 7, 'T')]
walls = [[(0.5, x + 0.5) for x in range(0, n_cols + 1)],
         [(n_rows + 0.5, x + 0.5) for x in range(0, n_cols + 1)], 
         [(x + 0.5, 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, n_cols + 0.5) for x in range(0, n_rows + 1)],
         [(1.5, 3.5), (2.5, 3.5)],
         [(3.5, 3.5), (4.5, 3.5)],
         [(2.5, 0.5), (2.5, 1.5)],
         [(2.5, 2.5), (2.5, 3.5)],
         [(1.5, x + 0.5) for x in range(4, 6)],
         [(1.5, 6.5), (1.5, 7.5)],
         [(7.5, x + 0.5) for x in range(0, 3)],
         [(7.5, x + 0.5) for x in range(5, 8)],
         [(x + 0.5, 4.5) for x in range(4, 6)],
         [(x + 0.5, 4.5) for x in range(6, 8)],
         [(4.5, x + 0.5) for x in range(3, 6)],
         [(4.5, x + 0.5) for x in range(6, 8)],
         [(x + 0.5, 7.5) for x in range(1, 4)],
         [(x + 0.5, 7.5) for x in range(5, 8)]]
# x0 = np.random.choice([x for x in X_a if 'N' in x])
x0 = '1 1 N'
goals = ['P', 'D', 'C', 'L', 'T']
goal = 'T'

fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

<IPython.core.display.Javascript object>

In [4]:
# Environment: 8x8 maze with 5 labelled objects (alternative layout).
# objs_states: (row, col, label) per object. walls: polylines of (row, col) points on
# half-integer cell borders; the first four entries are the outer boundary.
n_rows = 8
n_cols = 8
# objs_states = [(6, 1, 'P'), (1, 7, 'D'), (3, 2, 'C'), (8, 1, 'L'), (7, 7, 'T'), (8, 8, 'O')]
objs_states = [(6, 1, 'P'), (1, 7, 'D'), (3, 2, 'C'), (8, 1, 'L'), (7, 7, 'T')]
walls = [[(0.5, x + 0.5) for x in range(0, n_cols + 1)],
         [(n_rows + 0.5, x + 0.5) for x in range(0, n_cols + 1)], 
         [(x + 0.5, 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, n_cols + 0.5) for x in range(0, n_rows + 1)],
         [(0.5, 3.5), (1.5, 3.5), (2.5, 3.5)],
         [(x + 0.5, 4.5) for x in range(4, 6)],
         [(x + 0.5, 4.5) for x in range(6, 8)],
         [(x + 0.5, 2.5) for x in range(2, 5)],
         [(x + 0.5, 2.5) for x in range(6, 8)],
         [(x + 0.5, 7.5) for x in range(1, 3)],
         [(x + 0.5, 7.5) for x in range(3, 8)],
         [(2.5, 0.5), (2.5, 1.5), (2.5, 2.5)],
         [(2.5, x + 0.5) for x in range(3, 5)],
         [(2.5, x + 0.5) for x in range(5, 7)],
         [(7.5, x + 0.5) for x in range(0, 2)],
         [(7.5, x + 0.5) for x in range(5, 8)],
         [(4.5, x + 0.5) for x in range(4, 6)],
         [(4.5, x + 0.5) for x in range(6, 8)]]
x0 = '1 1 N'
goals = ['P', 'D', 'C', 'L', 'T']
goal = 'T'

fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

<IPython.core.display.Javascript object>

In [5]:
# Environment: 8x8 maze with 5 labelled objects (third layout, target 'P').
# objs_states: (row, col, label) per object. walls: polylines of (row, col) points on
# half-integer cell borders; the first four entries are the outer boundary.
n_rows = 8
n_cols = 8
# objs_states = [(8, 5, 'P'), (1, 5, 'D'), (4, 1, 'C'), (8, 1, 'L'), (4, 7, 'T'), (8, 8, 'O')]
objs_states = [(8, 5, 'P'), (1, 5, 'D'), (4, 1, 'C'), (8, 1, 'L'), (4, 7, 'T')]
walls = [[(0.5, x + 0.5) for x in range(0, n_cols + 1)],
         [(n_rows + 0.5, x + 0.5) for x in range(0, n_cols + 1)], 
         [(x + 0.5, 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, n_cols + 0.5) for x in range(0, n_rows + 1)],
         [(x + 0.5, 2.5) for x in range(2, 6)],
         [(x + 0.5, 2.5) for x in range(6, 8)],
         [(x + 0.5, 3.5) for x in range(0, 2)],
         [(x + 0.5, 3.5) for x in range(6, 8)],
         [(x + 0.5, 5.5) for x in range(6, 9)],
         [(x + 0.5, 6.5) for x in range(3, 6)],
         [(x + 0.5, 6.5) for x in range(6, 8)],
         [(2.5, x + 0.5) for x in range(0, 3)],
         [(1.5, x + 0.5) for x in range(3, 6)],
         [(1.5, x + 0.5) for x in range(6, 9)],
         [(3.5, x + 0.5) for x in range(6, 8)],
         [(6.5, x + 0.5) for x in range(0, 3)],
         [(6.5, x + 0.5) for x in range(3, 5)],
         [(6.5, x + 0.5) for x in range(6, 8)],
         [(5.5, x + 0.5) for x in range(7, 9)]]
x0 = '1 1 N'
goals = ['P', 'D', 'C', 'L', 'T']
goal = 'P'

fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

<IPython.core.display.Javascript object>

In [4]:
# Environment: 10x10 maze with 6 labelled objects, target 'O'.
# objs_states: (row, col, label) per object. walls: polylines of (row, col) points on
# half-integer cell borders; the first four entries are the outer boundary.
n_rows = 10
n_cols = 10
objs_states = [(9, 3, 'P'), (9, 1, 'D'), (1, 6, 'C'), (5, 4, 'L'), (10, 8, 'T'), (7, 9, 'O')]
walls = [[(0.5, x + 0.5) for x in range(0, n_cols + 1)],
         [(n_rows + 0.5, x + 0.5) for x in range(0, n_cols + 1)], 
         [(x + 0.5, 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, n_cols + 0.5) for x in range(0, n_rows + 1)],
         [(0.5, 2.5), (1.5, 2.5)],
         [(3.5, 6.5), (3.5, 7.5)],
         [(3.5, 0.5), (3.5, 1.5)],
         [(x + 0.5, 2.5) for x in range(2, 6)],
         [(x + 0.5, 2.5) for x in range(6, 10)],
         [(x + 0.5, 6.5) for x in range(0, 3)],
         [(x + 0.5, 8.5) for x in range(3, 8)],
         [(x + 0.5, 8.5) for x in range(8, 11)],
         [(3.5, x + 0.5) for x in range(2, 6)],
         [(3.5, x + 0.5) for x in range(8, 10)],
         [(8.5, x + 0.5) for x in range(2, 4)],
         [(8.5, x + 0.5) for x in range(4, 8)]]
# x0 = np.random.choice([x for x in X_a if 'N' in x])
x0 = '1 1 N'
goals = ['P', 'D', 'C', 'L', 'T', 'O']
goal = 'O'

fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

<IPython.core.display.Javascript object>

In [4]:
# Environment: 10x10 maze with 6 labelled objects, target 'T'.
# objs_states: (row, col, label) per object. walls: polylines of (row, col) points on
# half-integer cell borders; the first four entries are the outer boundary.
n_rows = 10
n_cols = 10
objs_states = [(1, 7, 'P'), (10, 10, 'D'), (6, 10, 'C'), (10, 2, 'L'), (10, 6, 'T'), (7, 1, 'O')]
walls = [[(0.5, x + 0.5) for x in range(0, n_cols + 1)],
         [(n_rows + 0.5, x + 0.5) for x in range(0, n_cols + 1)], 
         [(x + 0.5, 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, n_cols + 0.5) for x in range(0, n_rows + 1)],
         [(0.5, 3.5), (1.5, 3.5)],
         [(x + 0.5, 3.5) for x in range(3, 6)],
         [(x + 0.5, 3.5) for x in range(6, 9)],
         [(x + 0.5, 4.5) for x in range(0, 4)],
         [(x + 0.5, 4.5) for x in range(9, 11)],
         [(x + 0.5, 7.5) for x in range(9, 11)],
         [(x + 0.5, 8.5) for x in range(0, 4)],
         [(x + 0.5, 8.5) for x in range(7, 10)],
         [(x + 0.5, 8.5) for x in range(5, 7)],
         [(2.5, x + 0.5) for x in range(0, 4)],
         [(2.5, x + 0.5) for x in range(9, 11)],
         [(8.5, x + 0.5) for x in range(0, 2)],
         [(8.5, x + 0.5) for x in range(2, 4)],
         [(8.5, x + 0.5) for x in range(4, 6)],
         [(8.5, x + 0.5) for x in range(6, 8)],
         [(5.5, x + 0.5) for x in range(8, 11)],
         [(3.5, x + 0.5) for x in range(0, 2)],
         [(3.5, x + 0.5) for x in range(2, 4)],
         [(3.5, x + 0.5) for x in range(4, 6)],
         [(3.5, x + 0.5) for x in range(6, 8)]]
x0 = '1 1 N'
goals = ['P', 'D', 'C', 'L', 'T', 'O']
goal = 'T'

fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

<IPython.core.display.Javascript object>

In [4]:
# Environment: 10x10 maze with 6 labelled objects, target 'C'.
# objs_states: (row, col, label) per object. walls: polylines of (row, col) points on
# half-integer cell borders; the first four entries are the outer boundary.
n_rows = 10
n_cols = 10
objs_states = [(1, 7, 'P'), (10, 10, 'D'), (7, 10, 'C'), (9, 1, 'L'), (9, 5, 'T'), (5, 1, 'O')]
walls = [[(0.5, x + 0.5) for x in range(0, n_cols + 1)],
         [(n_rows + 0.5, x + 0.5) for x in range(0, n_cols + 1)], 
         [(x + 0.5, 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, n_cols + 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, 6.5) for x in range(4, 7)], 
         [(x + 0.5, 4.5) for x in range(4, 7)],
         [(x + 0.5, 2.5) for x in range(3, 6)],
         [(x + 0.5, 2.5) for x in range(6, 8)],
         [(x + 0.5, 1.5) for x in range(8, 10)],
         [(x + 0.5, 3.5) for x in range(0, 2)],
         [(x + 0.5, 3.5) for x in range(8, 11)],
         [(x + 0.5, 4.5) for x in range(0, 3)],
         [(x + 0.5, 7.5) for x in range(0, 3)],
         [(x + 0.5, 7.5) for x in range(8, 11)],
         [(x + 0.5, 8.5) for x in range(0, 4)],
         [(x + 0.5, 8.5) for x in range(4, 7)],
         [(x + 0.5, 8.5) for x in range(8, 10)],
         [(2.5, x + 0.5) for x in range(0, 3)],
         [(2.5, x + 0.5) for x in range(4, 6)],
         [(2.5, x + 0.5) for x in range(6, 8)],
         [(3.5, x + 0.5) for x in range(0, 3)],
         [(3.5, x + 0.5) for x in range(9, 11)],
         [(4.5, x + 0.5) for x in range(9, 11)],
         [(4.5, x + 0.5) for x in range(4, 7)],
         [(6.5, x + 0.5) for x in range(4, 7)],
         [(7.5, x + 0.5) for x in range(0, 3)],
         [(7.5, x + 0.5) for x in range(8, 11)],
         [(8.5, x + 0.5) for x in range(0, 2)],
         [(8.5, x + 0.5) for x in range(3, 6)],
         [(8.5, x + 0.5) for x in range(6, 8)],
         [(8.5, x + 0.5) for x in range(8, 10)]
        ]
x0 = '1 1 N'
goals = ['P', 'D', 'C', 'L', 'T', 'O']
goal = 'C'

fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

<IPython.core.display.Javascript object>

In [4]:
# Environment: 10x10 maze with 4 labelled objects, target 'T'.
# objs_states: (row, col, label) per object. walls: polylines of (row, col) points on
# half-integer cell borders; the first four entries are the outer boundary.
n_rows = 10
n_cols = 10
objs_states = [(8, 3, 'P'), (5, 7, 'D'), (5, 2, 'C'), (8, 7, 'T')]
walls = [[(0.5, x + 0.5) for x in range(0, n_cols + 1)],
         [(n_rows + 0.5, x + 0.5) for x in range(0, n_cols + 1)], 
         [(x + 0.5, 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, n_cols + 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, 1.5) for x in range(2, 10)],
         [(x + 0.5, 3.5) for x in range(0, 3)],
         [(x + 0.5, 3.5) for x in range(3, 7)],
         [(x + 0.5, 3.5) for x in range(7, 9)],
         [(x + 0.5, 5.5) for x in range(7, 9)],
         [(x + 0.5, 6.5) for x in range(0, 3)],
         [(x + 0.5, 6.5) for x in range(4, 7)],
         [(x + 0.5, 8.5) for x in range(4, 6)],
         [(x + 0.5, 8.5) for x in range(7, 9)],
         [(x + 0.5, 9.5) for x in range(2, 10)],
         [(2.5, x + 0.5) for x in range(1, 3)],
         [(2.5, x + 0.5) for x in range(3, 5)],
         [(2.5, x + 0.5) for x in range(5, 7)],
         [(2.5, x + 0.5) for x in range(8, 10)],
         [(3.5, x + 0.5) for x in range(2, 4)],
         [(4.5, x + 0.5) for x in range(6, 8)],
         [(6.5, x + 0.5) for x in range(2, 4)],
         [(6.5, x + 0.5) for x in range(6, 9)],
         [(7.5, x + 0.5) for x in range(1, 4)],
         [(7.5, x + 0.5) for x in range(5, 9)],
         [(9.5, x + 0.5) for x in range(1, 3)],
         [(9.5, x + 0.5) for x in range(3, 7)],
         [(9.5, x + 0.5) for x in range(7, 10)]
        ]
x0 = '1 1 N'
goals = ['P', 'D', 'C', 'T']
goal = 'T'
# Longest goal label plus 2.
max_goal_len = max([len(g) for g in goals]) + 2

fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

<IPython.core.display.Javascript object>

In [66]:
# Environment: 5x8 maze with 3 objects (A, B, C), target 'B'.
# objs_states: (row, col, label) per object. walls: polylines of (row, col) points on
# half-integer cell borders; the first four entries are the outer boundary.
n_rows = 5
n_cols = 8
objs_states = [(3, 5, 'A'), (5, 8, 'B'), (3, 8, 'C')]
walls = [[(0.5, x + 0.5) for x in range(0, n_cols + 1)],
         [(n_rows + 0.5, x + 0.5) for x in range(0, n_cols + 1)], 
         [(x + 0.5, 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, n_cols + 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, 2.5) for x in range(2, 4)],
         [(x + 0.5, 4.5) for x in range(2, 4)],
         [(x + 0.5, 3.5) for x in range(3, 5)],
         [(x + 0.5, 7.5) for x in range(3, 5)],
         [(2.5, x + 0.5) for x in range(2, 5)],
         [(3.5, x + 0.5) for x in range(2, 4)],
         [(3.5, x + 0.5) for x in range(4, 8)],
         [(4.5, x + 0.5) for x in range(3, 8)]
        ]
x0 = '1 1 N'
goals = ['A', 'B', 'C']
goal = 'B'
# Longest goal label plus 2.
max_goal_len = max([len(g) for g in goals]) + 2

fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

<IPython.core.display.Javascript object>

In [None]:
# Environment: 5x8 maze with 3 objects (A, B, C), start at '3 1 N', target 'C'.
# objs_states: (row, col, label) per object. walls: polylines of (row, col) points on
# half-integer cell borders; the first four entries are the outer boundary.
n_rows = 5
n_cols = 8
objs_states = [(1, 3, 'A'), (5, 6, 'B'), (3, 8, 'C')]
walls = [[(0.5, x + 0.5) for x in range(0, n_cols + 1)],
         [(n_rows + 0.5, x + 0.5) for x in range(0, n_cols + 1)], 
         [(x + 0.5, 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, n_cols + 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, 3.5) for x in range(0, 3)],
         [(x + 0.5, 4.5) for x in range(0, 3)],
         [(2.5, x + 0.5) for x in range(3, 5)]
        ]
x0 = '3 1 N'
goals = ['A', 'B', 'C']
goal = 'C'
# Longest goal label plus 2.
max_goal_len = max([len(g) for g in goals]) + 2

fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

In [None]:
# Environment: 5x8 maze with 3 objects on the bottom row, target 'C'.
# objs_states: (row, col, label) per object. walls: polylines of (row, col) points on
# half-integer cell borders; the first four entries are the outer boundary.
n_rows = 5
n_cols = 8
objs_states = [(5, 4, 'A'), (5, 1, 'B'), (5, 7, 'C')]
walls = [[(0.5, x + 0.5) for x in range(0, n_cols + 1)],
         [(n_rows + 0.5, x + 0.5) for x in range(0, n_cols + 1)], 
         [(x + 0.5, 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, n_cols + 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, 1.5) for x in range(2, 4)],
         [(x + 0.5, 2.5) for x in range(2, 4)],
         [(x + 0.5, 7.5) for x in range(2, 4)],
         [(2.5, x + 0.5) for x in range(0, 2)],
         [(2.5, x + 0.5) for x in range(2, 8)],
         [(3.5, x + 0.5) for x in range(0, 2)],
         [(3.5, x + 0.5) for x in range(2, 8)]
        ]
x0 = '1 1 N'
goals = ['A', 'B', 'C']
goal = 'C'
# Longest goal label plus 2.
max_goal_len = max([len(g) for g in goals]) + 2

fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

In [78]:
# Build the world, one reward-based MDP per task, and the legible-task MDP for `goal`,
# then solve the optimal and the legible policy for `goal`.
print('##########################################')
print('#####  Wall Auto Collect Maze World  #####')
print('##########################################')
beta = 0.5
max_goal_len = 1 + max(len(g) for g in goals)
# Other world generators: WallAutoCollectMazeWorld(), LimitedCollectWallMazeWorld(),
# SimpleWallMazeWorld().
wacmw = SimpleWallMazeWorld2()
X_w, A_w, P_w = wacmw.generate_world(n_rows, n_cols, objs_states, walls, 'stochastic', 0.15)
with_objs = True
goal_states = get_goal_states(X_w, goal)

print('### Computing Costs and Creating Task MDPs ###')
mdps_w, v_mdps_w, q_mdps_w, task_mdps_w = {}, {}, {}, {}
dists, costs = [], []
rewards = {}
for i in tqdm(range(len(goals)), desc='Single Task MDPs'):
    # c = wacmw.generate_costs(goals[i], X_w, A_w)
    c = wacmw.generate_rewards(goals[i], X_w, A_w)
    costs.append(c)
    rewards[goals[i]] = c
    mdp = MDP(X_w, A_w, P_w, c, 0.9, get_goal_states(X_w, goals[i]), 'rewards')
    pol, q = mdp.policy_iteration()
    v = Utilities.v_from_q(q, pol)
    q_mdps_w[goals[i]] = q
    v_mdps_w[goals[i]] = v
    dists.append(mdp.policy_dist(pol))
    mdps_w['mdp%d' % (i + 1)] = mdp
dists = np.array(dists)
print('Legible task MDP')
task_mdp_w = LegibleTaskMDP(X_w, A_w, P_w, 0.9, goal, objs_states, goals, beta, goal_states, 1,
                            'leg_optimal', q_mdps=q_mdps_w, v_mdps=v_mdps_w, dists=dists)

leg_costs = [task_mdp_w.costs[k] for k in range(len(goals))]

print('### Computing Optimal policy ###')
time1 = time.time()
pol_w, Q1 = mdps_w['mdp%d' % (goals.index(goal) + 1)].policy_iteration()
print('Took %.3f seconds to compute policy' % (time.time() - time1))

print('### Computing Legible policy ###')
time1 = time.time()
task_pol_w, task_Q = task_mdp_w.policy_iteration(goals.index(goal))
print('Took %.3f seconds to compute policy' % (time.time() - time1))

##########################################
#####  Wall Auto Collect Maze World  #####
##########################################
### Computing Costs and Creating Task MDPs ###


Single Task MDPs: 100%|██████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 18.40it/s]

Iteration 1Iteration 2Iteration 3N. iterations:  3
Iteration 1Iteration 2Iteration 3N. iterations:  3
Iteration 1Iteration 2Iteration 3N. iterations:  3
Legible task MDP
### Computing Optimal policy ###
Iteration 1Iteration 2Iteration 3N. iterations: 




 3
Took 0.008 seconds to compute policy
### Computing Legible policy ###
Iteration 1Iteration 2Iteration 3Iteration 4N. iterations:  4
Took 0.010 seconds to compute policy


In [68]:
# Sanity check: sample uniformly among the states whose Q-row for `goal` sums to non-zero,
# excluding the goal states themselves, and count how often each state is drawn.
nonzerostates = np.nonzero(q_mdps_w[goal].sum(axis=1))[0]
print(nonzerostates)
# Exclude *every* goal state. Building one np.delete result per goal state and keeping
# only [0] dropped just the first goal state, and failed when goal_states was empty.
clean_states = nonzerostates[~np.isin(nonzerostates, goal_states)]
rng = np.random.default_rng(0)  # seeded so the counts are reproducible
count = np.zeros(len(X_w))
np.add.at(count, rng.choice(clean_states, size=10000), 1)
print(count)

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 20 21 22 23 24 25 26 31 32 33 34 35 36 37 38 39]
[302. 280. 305. 304. 301. 308. 314. 317. 288. 280. 302. 313. 315. 291. 356. 326. 274. 294.   0.   0. 309. 311. 295. 280. 302. 311. 297.   0.   0.   0.   0. 320. 289. 275. 315. 287. 301. 324. 314.   0.]


In [None]:
# Miura-style legible trajectory for `goal`, starting from the optimal policy pol_w.
# Note: this rebinds the global `beta`; cells run afterwards that read `beta` see 0.3.
beta, depth, n_its = 0.3, 20, 50000
x0 = '1 1 N'
miura_mdp = MiuraLegibleMDP(X_w, A_w, P_w, 0.9, goal, goals, beta, goal_states, q_mdps=q_mdps_w)
time1 = time.time()
miura_traj = miura_mdp.legible_trajectory(x0, pol_w, depth, n_its, beta)
print('Took %.3f seconds to compute trajectory' % (time.time() - time1))
print(miura_traj)

In [None]:
# Draw the Miura trajectory (blue) together with the legible-policy trajectory (black).
print('Miura trajectory')
goal_state = objs_states[goals.index(goal)]
i = 0
traj, a_traj = miura_traj
p_traj = process_trajectory(traj, a_traj)

print('Legible trajectory for task: ' + goal)
task_traj, task_act = task_mdp_w.trajectory(x0, task_pol_w)
print('Trajectory: ' + str(task_traj))
opt_mdp = mdps_w['mdp' + str(goals.index(goal) + 1)]
print('Cost: ' + str(opt_mdp.trajectory_reward([[task_traj, task_act]])))
print('Legible Reward: ' + str(task_mdp_w.trajectory_reward([[task_traj, task_act]], goals.index(goal))))
t_leg = process_trajectory(task_traj, task_act)

figure = create_world_view(n_rows, n_cols, objs_states, walls)
for steps, colour, layer in ((p_traj, 'blue', 0), (t_leg, 'k', 1)):
    visualize_trajectory(steps[0], steps, figure, colour, goal_state, zorder=layer)
fig, _ = figure
fig.show()

In [None]:
# Legibility (tolerance 0.025) of the per-goal optimal policies vs. the per-goal legible
# policies, reported for the current `goal`.
goal_idx = goals.index(goal)

q_pi_mdps = np.array([q_mdps_w[g_name] for g_name in goals])
print(task_mdp_w.policy_legibility(q_pi_mdps, v_mdps_w, len(X_w), len(A_w), len(goals), 0.025, goals)[goal_idx])

q_pi_mdps_l = []
v_mdps_l = []
for g_idx, g_name in enumerate(goals):
    # Legible policy for goal g_idx, using the same per-goal call as the plotting cell
    # (task_mdp_w.policy_iteration(goal index)). The loop previously solved for the fixed
    # `goal` on every iteration and built a LegibleTaskMDP that was never used.
    pol_l, q_pi_l = task_mdp_w.policy_iteration(g_idx)
    # Value of *this* policy; previously read leftover `q`/`pol` globals from another cell.
    v_l = Utilities.v_from_q(q_pi_l, pol_l)
    q_pi_mdps_l.append(q_pi_l)
    v_mdps_l.append(v_l)
q_pi_mdps_l = np.array(q_pi_mdps_l)
v_mdps_l = np.array(v_mdps_l)
# NOTE(review): legibility is still scored against the optimal values v_mdps_w (a dict keyed
# by goal label); v_mdps_l is computed but not passed — confirm which is intended.
print(task_mdp_w.policy_legibility(q_pi_mdps_l, v_mdps_w, len(X_w), len(A_w), len(goals), 0.025, goals)[goal_idx])

In [None]:
# Optimal policy (blue) and greedy legible policy (black) for `goal` over all states.
states = X_w
print('Goal: ' + goal)
opt_figure = create_world_view(n_rows, n_cols, objs_states, walls)
# goal_states must be passed; omitting it shifted every later positional argument.
draw_policy_states(opt_figure, pol_w, goal_states, list(X_w), list(A_w), states, 'b', with_objs)
fig, _ = opt_figure
fig.show()
# Build the greedy legible policy here so the cell runs on a fresh kernel: uniform over the
# actions whose value ties the per-state maximum.
# NOTE(review): deriving it from the legible Q-values task_Q is an assumption — confirm.
opt_leg_pol = np.isclose(task_Q, task_Q.max(axis=1, keepdims=True), atol=1e-8, rtol=1e-8).astype(int)
opt_leg_pol = opt_leg_pol / opt_leg_pol.sum(axis=1, keepdims=True)
leg_figure = create_world_view(n_rows, n_cols, objs_states, walls)
draw_policy_states(leg_figure, opt_leg_pol, goal_states, list(X_w), list(A_w), states, 'k', with_objs)
fig, _ = leg_figure
fig.show()

In [None]:
# Print every state's optimal value V and action values Q under each goal.
# enumerate supplies the state index directly (list(X_w).index(x) per lookup was O(n^2)).
for x_idx, x in enumerate(X_w):
    print('State: %s' % x)
    for g in goals:
        print('Goal: %s V = %s\tQ = %s' % (g, str(v_mdps_w[g][x_idx]), str(q_mdps_w[g][x_idx])))
        # print('Goal: %s adv = %s' % (g, str(leg_costs[goals.index(g)][list(X_w).index(x)])))

In [None]:
# Optimal policy (blue) vs. legible policy (black) for `goal`, drawn over all states.
# states = get_initial_states(X_w, objs_states)
states = X_w
print('Goal: ' + goal)
state_names, action_names = list(X_w), list(A_w)
opt_figure = create_world_view(n_rows, n_cols, objs_states, walls)
draw_policy_states(opt_figure, pol_w, goal_states, state_names, action_names, states, 'b', with_objs)
fig, _ = opt_figure
fig.show()
leg_figure = create_world_view(n_rows, n_cols, objs_states, walls)
draw_policy_states(leg_figure, task_pol_w, goal_states, state_names, action_names, states, 'k', with_objs)
fig, _ = leg_figure
fig.show()
# for x in states:
#     print('State %s' % x)
#     x_idx = list(X_w).index(x)
#     state_cost_sum = 0
#     for g_idx in range(len(goals)):
#         print('Goal: %s' % goals[g_idx], end=' ')
#         print(leg_costs[g_idx][x_idx])
#         state_cost_sum += leg_costs[g_idx][x_idx]
#     print('Goal: %s Q Values ' % goal, end=' ')
#     print(task_Q[x_idx])
#     print(state_cost_sum)
#     print('-------------------------------------------------')

In [75]:
# Roll out the optimal and the legible policy for `goal` from x0, print each trajectory
# with its cost and legible reward, and draw the legible one.
x0 = '1 1 N'
goal_idx = goals.index(goal)
goal_state = objs_states[goal_idx]
print('Initial State: ' + x0)
print('##########################################')
print('#####  Wall Auto Collect Maze World  #####')
print('##########################################')

print('Optimal trajectory for task: ' + goal)
opt_mdp = mdps_w['mdp' + str(goal_idx + 1)]
t1, a1 = opt_mdp.trajectory(x0, pol_w)
print('Trajectory: ' + str(t1))
print('Cost: ' + str(opt_mdp.trajectory_reward([[t1, a1]])))
print('Legible Reward: ' + str(task_mdp_w.trajectory_reward([[t1, a1]], goal_idx)))
t_opt = process_trajectory(t1, a1)

print('Legible trajectory for task: ' + goal)
task_traj, task_act = task_mdp_w.trajectory(x0, task_pol_w)
print('Trajectory: ' + str(task_traj))
print('Cost: ' + str(opt_mdp.trajectory_reward([[task_traj, task_act]])))
print('Legible Reward: ' + str(task_mdp_w.trajectory_reward([[task_traj, task_act]], goal_idx)))
t_leg = process_trajectory(task_traj, task_act)

figure = create_world_view(n_rows, n_cols, objs_states, walls)
# visualize_trajectory(t_opt[0], t_opt, figure, 'b', goal_state, zorder=0)
visualize_trajectory(t_leg[0], t_leg, figure, 'k', goal_state, zorder=1)
fig, _ = figure
fig.show()

# print('Getting model performance!!')
# clock_1 = time.time()
# mdp_r, mdp_rl, leg_mdp_r, leg_mdp_rl = simulate(mdps_w['mdp' + str(goals.index(goal) + 1)], pol_w,
#                                                 task_mdp_w, task_pol_w, x0, 1000, goals.index(goal))
# time_simulation = time.time() - clock_1
# print('Simulation length = %.3f' % time_simulation)
# print('Optimal Policy performance:\nCost: %.3f\nLegible Reward: %.3f' % (mdp_r, mdp_rl))
# print('Legible Policy performance:\nCost: %.3f\nLegible Reward: %.3f' % (leg_mdp_r, leg_mdp_rl))

Initial State: 1 1 N
##########################################
#####  Wall Auto Collect Maze World  #####
##########################################
Optimal trajectory for task: B
Trajectory: ['1 1 N' '1 2 N' '2 2 N' '2 3 N' '2 4 N' '2 5 N' '2 6 N' '3 6 N' '3 7 N' '3 8 C' '4 8 N' '5 8 B']


100%|████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s]


Cost: 0.31381059609000017


100%|███████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 998.41it/s]


Legible Reward: 2.033749555360062
Legible trajectory for task: B
Trajectory: ['1 1 N' '2 1 N' '3 1 N' '4 1 N' '4 1 N' '5 1 N' '5 2 N' '5 3 N' '5 4 N' '5 5 N' '5 6 N' '5 7 N' '5 8 B']


100%|████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s]


Cost: 0.28242953648100017


100%|████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s]

Legible Reward: 3.0245290868402552





<IPython.core.display.Javascript object>

In [76]:
# Draw every trajectory the optimal and the legible policy can produce from x0.
# Each trajectory gets its own colour (the last palette entry repeats once exhausted),
# and earlier trajectories are stacked on top via a decreasing zorder.
colors = ['blue', 'darkred', 'green', 'black', 'orange', 'pink', 'yellow', 'magenta', 'brown', 'cyan', 'khaki', 'olivedrab', 'lightcoral']

trajs, a_trajs = mdps_w['mdp' + str(goals.index(goal) + 1)].all_trajectories(x0, pol_w)
print('Optimal trajectories: %d' % len(trajs))
figure = create_world_view(n_rows, n_cols, objs_states, walls)
for j, (traj, a_traj) in enumerate(zip(trajs, a_trajs)):
    p_traj = process_trajectory(traj, a_traj)
    visualize_trajectory(p_traj[0], p_traj, figure, colors[min(j, len(colors) - 1)], goal_state,
                         zorder=(len(trajs) - j))
fig, _ = figure
fig.show()

leg_trajs, leg_a_trajs = task_mdp_w.all_trajectories(x0, task_pol_w)
print('Legible trajectories: %d' % len(leg_trajs))
figure = create_world_view(n_rows, n_cols, objs_states, walls)
for j, (traj, a_traj) in enumerate(zip(leg_trajs, leg_a_trajs)):
    p_traj = process_trajectory(traj, a_traj)
    # Count down from the number of *legible* trajectories (len(trajs) is the optimal count).
    visualize_trajectory(p_traj[0], p_traj, figure, colors[min(j, len(colors) - 1)], goal_state,
                         zorder=(len(leg_trajs) - j))
fig, _ = figure
fig.show()

Optimal trajectories: 66


<IPython.core.display.Javascript object>

Legible trajectories: 1


<IPython.core.display.Javascript object>

In [79]:
# For each goal: draw the optimal policy (blue) and the legible policy (black) over all
# states, and keep that goal's legible Q-values in leg_qs.
# states = get_initial_states(X_w, objs_states)
states = X_w
leg_qs = []
for g in goals:
    print('Goal: ' + g)
    g_idx = goals.index(g)
    opt_mdp = mdps_w['mdp' + str(g_idx + 1)]
    opt_pol, _ = opt_mdp.policy_iteration()
    leg_pol, leg_q = task_mdp_w.policy_iteration(g_idx)
    leg_qs.append(leg_q)
    g_goal_states = get_goal_states(X_w, g)
    state_names, action_names = list(X_w), list(A_w)
    opt_figure = create_world_view(n_rows, n_cols, objs_states, walls)
    draw_policy_states(opt_figure, opt_pol, g_goal_states, state_names, action_names, states, 'b', with_objs)
    fig, _ = opt_figure
    fig.show()
    leg_figure = create_world_view(n_rows, n_cols, objs_states, walls)
    draw_policy_states(leg_figure, leg_pol, g_goal_states, state_names, action_names, states, 'k', with_objs)
    fig, _ = leg_figure
    fig.show()

Goal: A
Iteration 1Iteration 2Iteration 3N. iterations:  3
Iteration 1Iteration 2Iteration 3Iteration 4N. iterations:  4


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Goal: B
Iteration 1Iteration 2Iteration 3N. iterations:  3
Iteration 1Iteration 2Iteration 3Iteration 4N. iterations:  4


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Goal: C
N. iterations:  3
N. iterations:  6


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# For each state: every goal's legible cost and legible Q row, then their sums over goals.
state_names = list(X_w)
for x in states:
    print('State %s' % x)
    x_idx = state_names.index(x)
    state_cost_sum = 0
    state_q_sum = 0
    for g_idx, g_name in enumerate(goals):
        cost_val = leg_costs[g_idx][x_idx]
        q_val = leg_qs[g_idx][x_idx]
        print('Goal: %s' % g_name, end=' ')
        print(cost_val, q_val)
        state_cost_sum += cost_val
        state_q_sum += q_val
    print(state_cost_sum, state_q_sum)
    print('-------------------------------------------------')

In [None]:
# Compare, for each goal, the cost of its optimal policy and of its legible policy over
# the start states.
states = get_initial_states(X_w, objs_states)

for goal_idx, g in enumerate(goals):
    opt_mdp = mdps_w['mdp' + str(goal_idx + 1)]
    # Per-goal legible MDPs exist only when the cost-based pipeline filled task_mdps_w; the
    # reward-based pipeline leaves it empty and solves every goal with the shared task_mdp_w.
    leg_mdp = task_mdps_w.get('leg_mdp_' + str(goal_idx + 1), task_mdp_w)

    opt_pol, _ = opt_mdp.policy_iteration()
    # Legible policy_iteration takes the goal index, as in every other call in this notebook.
    leg_pol, _ = leg_mdp.policy_iteration(goal_idx)

    opt_pol_cost = task_mdp_w.policy_cost(opt_pol, states)
    leg_pol_cost = task_mdp_w.policy_cost(leg_pol, states)
    print(g, opt_pol_cost, leg_pol_cost)

In [None]:
# Environment: small 5x4 maze with 3 objects (P, D, C), target 'P'.
# objs_states: (row, col, label) per object. walls: polylines of (row, col) points on
# half-integer cell borders; the first four entries are the outer boundary.
n_rows = 5
n_cols = 4
objs_states = [(3, 4, 'P'), (3, 1, 'D'), (5, 2, 'C')]
walls = [[(0.5, x + 0.5) for x in range(0, n_cols + 1)],
         [(n_rows + 0.5, x + 0.5) for x in range(0, n_cols + 1)], 
         [(x + 0.5, 0.5) for x in range(0, n_rows + 1)], 
         [(x + 0.5, n_cols + 0.5) for x in range(0, n_rows + 1)], 
         [(1.5, x + 0.5) for x in range(0, 2)],
         [(1.5, x + 0.5) for x in range(2, 4)],
         [(4.5, x + 0.5) for x in range(1, 3)],
         [(x + 0.5, 1.5) for x in range(1, 4)],
         [(x + 0.5, 2.5) for x in range(1, 3)],
        ]
# x0 = np.random.choice([x for x in X_a if 'N' in x])
x0 = '1 1 N'
goals = ['P', 'D', 'C']
goal = 'P'

fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()

In [None]:
# Build the maze world, one cost-based MDP per goal and one legible task MDP per
# goal, then solve the optimal and legible policies for the selected `goal`.
print('##########################################')
print('#####  Wall Auto Collect Maze World  #####')
print('##########################################')
# NOTE(review): despite the banner, the world built here is SimpleWallMazeWorld2.
wacmw = SimpleWallMazeWorld2()
# 'stochastic', 0.15 — presumably the transition model and its noise level;
# confirm against SimpleWallMazeWorld2.generate_world.
X_w, A_w, P_w = wacmw.generate_world(n_rows, n_cols, objs_states, walls, 'stochastic', 0.15)

# Index of the selected goal; fails fast if `goal` is not one of `goals`.
sel_goal_idx = goals.index(goal)

print('### Computing Costs and Creating Task MDPs ###')
mdps_w = {}
q_mdps_w = []
task_mdps_w = {}
costs = []
for i in tqdm(range(len(goals)), desc='Single Task MDPs'):
    c = wacmw.generate_costs(goals[i], X_w, A_w)
    costs.append(c)
    mdp = MDP(X_w, A_w, P_w, c, 0.9, get_goal_states(X_w, goals[i]), 'costs')
    # The Q-values of every single-goal MDP are handed to the legible MDPs below.
    _, q = mdp.policy_iteration()
    q_mdps_w.append(q)
    mdps_w['mdp' + str(i + 1)] = mdp
print('Legible task MDP')
for i in tqdm(range(len(goals)), desc='Legible Task MDPs'):
    # NOTE(review): meaning of the positional 1.0 and -1 arguments not visible here.
    mdp = LegibleTaskMDP(X_w, A_w, P_w, 0.9, goals[i], objs_states, goals, 1.0, get_goal_states(X_w, goals[i]),
                         -1, 'leg_optimal', q_mdps=q_mdps_w)
    task_mdps_w['leg_mdp_' + str(i + 1)] = mdp
task_mdp_w = task_mdps_w['leg_mdp_' + str(sel_goal_idx + 1)]
# One entry of the selected legible MDP's `costs` per goal index.
leg_costs = [task_mdp_w.costs[idx] for idx in range(len(goals))]

print('### Computing Optimal policy ###')
time1 = time.time()
pol_w, Q1 = mdps_w['mdp' + str(sel_goal_idx + 1)].policy_iteration()
print('Took %.3f seconds to compute policy' % (time.time() - time1))

print('### Computing Legible policy ###')
time1 = time.time()
task_pol_w, task_Q = task_mdp_w.policy_iteration(sel_goal_idx)
print('Took %.3f seconds to compute policy' % (time.time() - time1))

In [None]:
# Roll out the optimal and legible policies for `goal` from `x0`, report each path's
# cost and legible reward, overlay both paths on the maze, dump per-state Q-values,
# then average both metrics over 1000 rollouts via `simulate`.
print('Initial State: ' + x0)
print('##########################################')
print('#####  Wall Auto Collect Maze World  #####')
print('##########################################')

print('Optimal trajectory for task: ' + goal)
sel_goal_idx = goals.index(goal)
opt_mdp_w = mdps_w['mdp' + str(sel_goal_idx + 1)]
t1, a1 = opt_mdp_w.trajectory(x0, pol_w)
print('Trajectory: ' + str(t1))
print('Cost: ' + str(opt_mdp_w.trajectory_reward([[t1, a1]])))
# NOTE(review): here the goal index is passed before the trajectory list, while
# `simulate` calls `trajectory_reward(trajs, goal)` trajectories-first — one of the
# two call sites likely has its arguments swapped; check LegibleTaskMDP.
print('Legible Reward: ' + str(task_mdp_w.trajectory_reward(sel_goal_idx, [[t1, a1]])))
t_opt = process_trajectory(t1, a1)

print('Legible trajectory for task: ' + goal)
task_traj, task_act = task_mdp_w.trajectory(x0, task_pol_w)
print('Trajectory: ' + str(task_traj))
print('Cost: ' + str(opt_mdp_w.trajectory_reward([[task_traj, task_act]])))
print('Legible Reward: ' + str(task_mdp_w.trajectory_reward(sel_goal_idx, [[task_traj, task_act]])))
t_leg = process_trajectory(task_traj, task_act)

# visualize_trajectory receives the whole tuple returned by create_world_view;
# the figure is unpacked from that tuple afterwards.
world_view = create_world_view(n_rows, n_cols, objs_states, walls)
visualize_trajectory(t_opt[0], t_opt, world_view, 'b')
visualize_trajectory(t_leg[0], t_leg, world_view, 'k')
fig, _ = world_view
fig.show()

# Q-values of every single-goal MDP and the legible cost for `goal`, per state.
states = X_w
state_list = list(X_w)
for state in states:
    s_idx = state_list.index(state)
    print('State: %s' % (state))
    for g in goals:
        print('Goal %s Q-Values %s' % (g, str(q_mdps_w[goals.index(g)][s_idx])))
    print('Cost: ' + str(leg_costs[sel_goal_idx][s_idx]))

print('Getting model performance!!')
sim_start = time.time()
mdp_r, mdp_rl, leg_mdp_r, leg_mdp_rl = simulate(opt_mdp_w, pol_w, task_mdp_w, task_pol_w,
                                                x0, 1000, sel_goal_idx)
time_simulation = time.time() - sim_start
print('Simulation length = %.3f' % time_simulation)
print('Optimal Policy performance:\nCost: %.3f\nLegible Reward: %.3f' % (mdp_r, mdp_rl))
print('legible Policy performance:\nCost: %.3f\nLegible Reward: %.3f' % (leg_mdp_r, leg_mdp_rl))

In [None]:
# Bayesian IRL goal inference: feed growing prefixes of the optimal trajectory
# (t1, a1) to two learners — one built on the per-goal `costs`, one on `leg_costs` —
# and print the goal index each infers next to the true goal index.
print('######################################')
print('#####     IRL Agent Learning     #####')
print('######################################')

print('IRL Agent')
# NOTE(review): meaning of the trailing -1 / 1 arguments is not visible here.
opt_learner = LearnerMDP(X_w, A_w, P_w, 0.9, costs, -1)
leg_learner = LearnerMDP(X_w, A_w, P_w, 0.9, leg_costs, 1)

print('Preparing Trajectories')
# Encode the trajectory as [state index, action index] rows.
state_list = list(X_w)
action_list = list(A_w)
p_traj = np.array([[state_list.index(t1[j]), action_list.index(a1[j])] for j in range(len(t1))])

print('Learning')
traj_len = len(p_traj)
step = 2
# Prefix lengths step, 2*step, ... always ending with the full trajectory length.
indexes = list(range(step, traj_len + 1, step))
if traj_len % step != 0:
    indexes.append(traj_len)

true_goal_idx = goals.index(goal)
for prefix_len in tqdm(indexes):
    _, o_idx = opt_learner.birl_inference(p_traj[:prefix_len], 0.9)
    _, l_idx = leg_learner.birl_inference(p_traj[:prefix_len], 0.9)
    print(o_idx, l_idx, true_goal_idx)

In [None]:
# NOTE(review): this cell repeats the preceding IRL cell verbatim; consider removing one.
# Runs Bayesian IRL on growing prefixes of the optimal trajectory with a cost-based
# learner and a legible-cost learner, printing each inferred goal index and the true one.
print('######################################')
print('#####     IRL Agent Learning     #####')
print('######################################')

print('IRL Agent')
opt_learner = LearnerMDP(X_w, A_w, P_w, 0.9, costs, -1)
leg_learner = LearnerMDP(X_w, A_w, P_w, 0.9, leg_costs, 1)

print('Preparing Trajectories')
x_list = list(X_w)
a_list = list(A_w)
encoded = []
for j, state in enumerate(t1):
    encoded.append([x_list.index(state), a_list.index(a1[j])])
p_traj = np.array(encoded)

print('Learning')
traj_len = len(p_traj)
step = 2
indexes = [k for k in range(step, traj_len + 1, step)]
remainder = traj_len % step
if remainder:
    indexes.append(traj_len)
    n_idx = traj_len // step + 1
else:
    n_idx = traj_len // step

for i in tqdm(range(n_idx)):
    idx = indexes[i]
    prefix = p_traj[:idx]
    r, o_idx = opt_learner.birl_inference(prefix, 0.9)
    r, l_idx = leg_learner.birl_inference(prefix, 0.9)
    print(o_idx, l_idx, goals.index(goal))

In [None]:
# 8x8 maze with six objects; draw it and print the wall coordinates.
n_rows = 8
n_cols = 8
# NOTE(review): 'C' was listed at column 70, outside the 8-column grid; column 7 is
# assumed to be the intended value — confirm.
objs_states = [(8, 5, 'P'), (1, 5, 'D'), (6, 7, 'C'), (8, 1, 'L'), (4, 7, 'T'), (8, 8, 'O')]
# Catch coordinate typos before they reach the world/plot code.
assert all(1 <= r <= n_rows and 1 <= c <= n_cols for r, c, _ in objs_states), 'object outside grid'
walls = [[(0.5, x + 0.5) for x in range(0, n_cols + 1)],
         [(n_rows + 0.5, x + 0.5) for x in range(0, n_cols + 1)],
         [(x + 0.5, 0.5) for x in range(0, n_rows + 1)],
         [(x + 0.5, n_cols + 0.5) for x in range(0, n_rows + 1)],
         [(x + 0.5, 2.5) for x in range(2, 6)],
         [(x + 0.5, 2.5) for x in range(6, 8)],
         [(x + 0.5, 3.5) for x in range(0, 2)],
         [(x + 0.5, 3.5) for x in range(6, 8)],
         [(x + 0.5, 5.5) for x in range(6, 9)],
         [(x + 0.5, 6.5) for x in range(3, 6)],
         [(x + 0.5, 6.5) for x in range(6, 8)],
         [(2.5, x + 0.5) for x in range(0, 3)],
         [(1.5, x + 0.5) for x in range(3, 6)],
         [(1.5, x + 0.5) for x in range(6, 9)],
         [(3.5, x + 0.5) for x in range(6, 8)],
         [(6.5, x + 0.5) for x in range(0, 3)],
         [(6.5, x + 0.5) for x in range(3, 5)],
         [(6.5, x + 0.5) for x in range(6, 8)],
         [(5.5, x + 0.5) for x in range(7, 9)]
        ]

# create_world_view returns a (figure, axes) pair, so unpack before calling .show().
fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()
print(np.array(walls, dtype=object))

In [None]:
# 90x90 maze with six objects; draw it and print the wall coordinates.
n_rows = 90
n_cols = 90
objs_states = [(7, 84, 'P'), (61, 85, 'D'), (3, 37, 'C'), (30, 7, 'L'), (31, 85, 'T'), (70, 27, 'O')]
# Alternative 10-object layout:
# objs_states = [(9, 11, 'A'), (20, 19, 'B'), (16, 15, 'C'), (7, 6, 'D'), (13, 23, 'E'), (7, 1, 'F'),
#                (24, 3, 'L'), (18, 6, 'O'), (5, 20, 'P'), (19, 25, 'T')]

# Interior wall segments as (fixed coordinate, start, stop) triples; the other
# coordinate runs over range(start, stop) shifted by 0.5.
fixed_row_segments = [(10.5, 65, 91), (20.5, 70, 78), (20.5, 83, 91), (24.5, 25, 41),
                      (35.5, 40, 51), (40.5, 5, 11), (55.5, 40, 51), (60.5, 20, 31),
                      (70.5, 70, 78), (70.5, 83, 91), (83.5, 15, 25)]
fixed_col_segments = [(10.5, 0, 14), (10.5, 20, 41), (15.5, 60, 84), (25.5, 0, 19),
                      (30.5, 60, 84), (40.5, 0, 36), (40.5, 55, 91), (50.5, 0, 36),
                      (50.5, 55, 91), (70.5, 20, 41), (70.5, 50, 71)]

outer_walls = [
    [(0.5, c + 0.5) for c in range(n_cols + 1)],
    [(n_rows + 0.5, c + 0.5) for c in range(n_cols + 1)],
    [(r + 0.5, 0.5) for r in range(n_rows + 1)],
    [(r + 0.5, n_cols + 0.5) for r in range(n_rows + 1)],
]
walls = (outer_walls
         + [[(row, c + 0.5) for c in range(lo, hi)] for row, lo, hi in fixed_row_segments]
         + [[(r + 0.5, col) for r in range(lo, hi)] for col, lo, hi in fixed_col_segments])

fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()
print(np.array(walls, dtype=object))

In [5]:
# 5x8 maze with objects A, B and C; start at x0 and pursue `goal`.
n_rows = 5
n_cols = 8
objs_states = [(3, 5, 'A'), (5, 8, 'B'), (3, 8, 'C')]

# Interior wall segments as (fixed coordinate, start, stop) triples; the other
# coordinate runs over range(start, stop) shifted by 0.5.
fixed_col_segments = [(2.5, 2, 4), (4.5, 2, 4), (3.5, 3, 5), (7.5, 3, 5)]
fixed_row_segments = [(2.5, 2, 5), (3.5, 2, 4), (3.5, 4, 8), (4.5, 3, 8)]

outer_walls = [
    [(0.5, c + 0.5) for c in range(n_cols + 1)],
    [(n_rows + 0.5, c + 0.5) for c in range(n_cols + 1)],
    [(r + 0.5, 0.5) for r in range(n_rows + 1)],
    [(r + 0.5, n_cols + 0.5) for r in range(n_rows + 1)],
]
walls = (outer_walls
         + [[(r + 0.5, col) for r in range(lo, hi)] for col, lo, hi in fixed_col_segments]
         + [[(row, c + 0.5) for c in range(lo, hi)] for row, lo, hi in fixed_row_segments])

x0 = '1 1 N'
goals = ['A', 'B', 'C']
goal = 'B'
max_goal_len = max(len(g) for g in goals) + 2

fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()
print(np.array(walls, dtype=object))

<IPython.core.display.Javascript object>

[list([(0.5, 0.5), (0.5, 1.5), (0.5, 2.5), (0.5, 3.5), (0.5, 4.5), (0.5, 5.5), (0.5, 6.5), (0.5, 7.5), (0.5, 8.5)]) list([(5.5, 0.5), (5.5, 1.5), (5.5, 2.5), (5.5, 3.5), (5.5, 4.5), (5.5, 5.5), (5.5, 6.5), (5.5, 7.5), (5.5, 8.5)]) list([(0.5, 0.5), (1.5, 0.5), (2.5, 0.5), (3.5, 0.5), (4.5, 0.5), (5.5, 0.5)]) list([(0.5, 8.5), (1.5, 8.5), (2.5, 8.5), (3.5, 8.5), (4.5, 8.5), (5.5, 8.5)]) list([(2.5, 2.5), (3.5, 2.5)]) list([(2.5, 4.5), (3.5, 4.5)]) list([(3.5, 3.5), (4.5, 3.5)]) list([(3.5, 7.5), (4.5, 7.5)]) list([(2.5, 2.5), (2.5, 3.5), (2.5, 4.5)]) list([(3.5, 2.5), (3.5, 3.5)]) list([(3.5, 4.5), (3.5, 5.5), (3.5, 6.5), (3.5, 7.5)]) list([(4.5, 3.5), (4.5, 4.5), (4.5, 5.5), (4.5, 6.5), (4.5, 7.5)])]


In [None]:
# 8x8 maze with objects A, B and C; start at x0 and pursue `goal`.
# NOTE(review): the interior walls and object positions match the 5x8 layout on a
# larger grid — confirm the extra rows are meant to be wall-free.
n_rows = 8
n_cols = 8
objs_states = [(3, 5, 'A'), (5, 8, 'B'), (3, 8, 'C')]

# Interior wall segments as (fixed coordinate, start, stop) triples; the other
# coordinate runs over range(start, stop) shifted by 0.5.
fixed_col_segments = [(2.5, 2, 4), (4.5, 2, 4), (3.5, 3, 5), (7.5, 3, 5)]
fixed_row_segments = [(2.5, 2, 5), (3.5, 2, 4), (3.5, 4, 8), (4.5, 3, 8)]

outer_walls = [
    [(0.5, c + 0.5) for c in range(n_cols + 1)],
    [(n_rows + 0.5, c + 0.5) for c in range(n_cols + 1)],
    [(r + 0.5, 0.5) for r in range(n_rows + 1)],
    [(r + 0.5, n_cols + 0.5) for r in range(n_rows + 1)],
]
walls = (outer_walls
         + [[(r + 0.5, col) for r in range(lo, hi)] for col, lo, hi in fixed_col_segments]
         + [[(row, c + 0.5) for c in range(lo, hi)] for row, lo, hi in fixed_row_segments])

x0 = '1 1 N'
goals = ['A', 'B', 'C']
goal = 'B'
max_goal_len = max(len(g) for g in goals) + 2

fig, _ = create_world_view(n_rows, n_cols, objs_states, walls)
fig.show()
print(np.array(walls, dtype=object))

In [6]:
# Sample (start state, goal) pairs for the 5x8 world. For each sample, pick a random
# task, solve its reward MDP, and draw a start state among states whose Q-row sum is
# non-zero, excluding the task's goal states.
world = '5x8_world'
n_samples = 250
with open('data/configs/' + (world + '.yaml')) as file:
    config_params = yaml.full_load(file)

n_cols = config_params['n_cols']
n_rows = config_params['n_rows']
walls = config_params['walls']
task_states = config_params['task_states']
tasks = config_params['tasks']

wacmw_test = SimpleWallMazeWorld2()
X_w, A_w, P_w = wacmw_test.generate_world(n_rows, n_cols, task_states, walls, 'stochastic', 0.15)
state_goal = []
for i in range(n_samples):
    goal = np.random.choice(tasks)
    goal_states = get_goal_states(list(X_w), goal)
    c = wacmw_test.generate_rewards(goal, X_w, A_w)
    mdp = MDP(X_w, A_w, P_w, c, 0.9, goal_states, 'rewards', False)
    pol, q = mdp.policy_iteration()
    nonzerostates = np.nonzero(q.sum(axis=1))[0]
    # Remove *every* goal state from the candidates (also safe when goal_states is empty).
    init_states = np.setdiff1d(nonzerostates, goal_states)
    x0 = X_w[np.random.choice(init_states)]
    state_goal += [(x0, goal)]
    print('Completed %d iterations, %.2f%%' % ((i + 1), (i + 1) / n_samples * 100), end='\r')
print(world + ': ' + str(state_goal))

5x8_world: [('4 2 N', 'C'), ('5 4 N', 'B'), ('5 4 N', 'B'), ('2 5 N', 'C'), ('5 5 N', 'A'), ('3 1 N', 'B'), ('3 1 N', 'B'), ('4 1 N', 'A'), ('3 6 N', 'A'), ('1 1 N', 'A'), ('3 7 N', 'B'), ('1 8 N', 'A'), ('1 2 N', 'A'), ('1 6 N', 'A'), ('5 6 N', 'B'), ('5 7 N', 'C'), ('1 8 N', 'C'), ('5 4 N', 'A'), ('2 4 N', 'B'), ('5 6 N', 'A'), ('2 8 N', 'C'), ('5 3 N', 'A'), ('1 3 N', 'A'), ('1 3 N', 'C'), ('1 8 N', 'C'), ('5 2 N', 'C'), ('5 1 N', 'C'), ('3 6 N', 'A'), ('5 7 N', 'A'), ('1 6 N', 'C'), ('2 8 N', 'C'), ('5 8 B', 'A'), ('2 7 N', 'B'), ('1 5 N', 'C'), ('2 4 N', 'A'), ('1 3 N', 'A'), ('5 6 N', 'A'), ('1 8 N', 'A'), ('2 2 N', 'C'), ('2 2 N', 'B'), ('4 3 N', 'B'), ('1 3 N', 'C'), ('1 3 N', 'C'), ('1 3 N', 'B'), ('5 1 N', 'C'), ('5 1 N', 'B'), ('5 4 N', 'A'), ('5 6 N', 'B'), ('4 8 N', 'B'), ('5 1 N', 'A'), ('3 8 C', 'B'), ('1 1 N', 'A'), ('4 2 N', 'A'), ('5 6 N', 'C'), ('1 2 N', 'C'), ('3 8 C', 'B'), ('1 8 N', 'B'), ('1 4 N', 'A'), ('1 2 N', 'B'), ('1 3 N', 'A'), ('2 3 N', 'B'), ('3 1 N', 'A

In [5]:
SCALABILITY_WORLDS = {3: '40x40_world', 5: '60x60_world', 7: '80x80_world', 8: '90x90_world'}

# (not used in this cell)
OBJECTS_WORLDS = {1: '25x25_world_g3', 2: '25x25_world_g4', 3: '25x25_world_g5', 4: '25x25_world_g6',
                  5: '25x25_world_g7', 6: '25x25_world_g8', 7: '25x25_world_g9', 8: '25x25_world_g10'}

n_samples = 250

# For each scalability world: load its config, build the maze and sample
# (start state, goal) pairs, excluding the goal's own states as starts.
for key, world in SCALABILITY_WORLDS.items():
    with open('data/configs/' + (world + '.yaml')) as file:
        config_params = yaml.full_load(file)

    n_cols = config_params['n_cols']
    n_rows = config_params['n_rows']
    walls = config_params['walls']
    task_states = config_params['task_states']
    tasks = config_params['tasks']

    wacmw_test = SimpleWallMazeWorld2()
    X_w, A_w, P_w = wacmw_test.generate_world(n_rows, n_cols, task_states, walls, 'stochastic', 0.15)
    state_goal = []
    for i in range(n_samples):
        goal = np.random.choice(tasks)
        goal_states = get_goal_states(list(X_w), goal)
        c = wacmw_test.generate_rewards(goal, X_w, A_w)
        mdp = MDP(X_w, A_w, P_w, c, 0.9, goal_states, 'rewards', False)
        pol, q = mdp.policy_iteration()
        nonzerostates = np.nonzero(q.sum(axis=1))[0]
        # Remove *every* goal state from the candidates (also safe when goal_states is empty).
        init_states = np.setdiff1d(nonzerostates, goal_states)
        x0 = X_w[np.random.choice(init_states)]
        state_goal += [(x0, goal)]
        print('Completed %d iterations, %.2f%%' % ((i + 1), (i + 1) / n_samples * 100), end='\r')
    print(world + ': ' + str(state_goal))

40x40_world: [('27 37 N', 'P'), ('19 5 N', 'C'), ('15 15 N', 'D'), ('37 15 N', 'T'), ('31 37 N', 'P'), ('23 31 N', 'O'), ('21 8 N', 'L'), ('16 25 N', 'D'), ('27 38 N', 'D'), ('21 30 N', 'P'), ('8 31 N', 'T'), ('4 34 N', 'D'), ('1 20 N', 'O'), ('21 13 N', 'D'), ('4 20 N', 'O'), ('23 37 N', 'D'), ('28 33 N', 'L'), ('30 2 N', 'O'), ('8 19 N', 'D'), ('19 37 N', 'P'), ('11 1 N', 'C'), ('6 23 N', 'P'), ('18 11 N', 'D'), ('38 8 N', 'O'), ('10 6 N', 'T'), ('18 14 N', 'D'), ('40 33 N', 'D'), ('1 35 N', 'D'), ('10 5 N', 'L'), ('29 8 N', 'D'), ('16 35 N', 'L'), ('6 4 N', 'D'), ('38 18 N', 'O'), ('8 27 N', 'P'), ('27 28 N', 'T'), ('12 32 N', 'T'), ('13 11 N', 'T'), ('19 12 N', 'T'), ('40 7 N', 'O'), ('27 35 N', 'D'), ('8 22 N', 'D'), ('24 11 N', 'L'), ('8 31 N', 'L'), ('13 31 N', 'P'), ('4 23 N', 'D'), ('40 6 N', 'O'), ('16 37 N', 'P'), ('12 13 N', 'T'), ('11 30 N', 'L'), ('10 7 N', 'O'), ('38 7 N', 'C'), ('11 2 N', 'C'), ('7 9 N', 'O'), ('20 27 N', 'T'), ('36 20 N', 'O'), ('20 38 N', 'L'), ('33 1

80x80_world: [('28 24 N', 'T'), ('63 16 N', 'T'), ('20 76 N', 'C'), ('41 56 N', 'O'), ('31 21 N', 'O'), ('32 15 N', 'D'), ('40 28 N', 'O'), ('51 57 N', 'T'), ('7 70 N', 'D'), ('10 37 N', 'P'), ('49 38 N', 'L'), ('49 11 N', 'T'), ('77 25 N', 'P'), ('3 12 N', 'L'), ('64 20 N', 'P'), ('46 4 N', 'D'), ('57 80 N', 'L'), ('13 71 N', 'L'), ('53 53 N', 'C'), ('59 18 N', 'P'), ('53 16 N', 'P'), ('68 51 N', 'D'), ('29 6 N', 'L'), ('67 42 N', 'L'), ('69 43 N', 'O'), ('38 13 N', 'D'), ('4 37 N', 'P'), ('21 19 N', 'D'), ('49 56 N', 'L'), ('41 12 N', 'T'), ('41 62 N', 'P'), ('71 37 N', 'T'), ('42 11 N', 'L'), ('27 69 N', 'L'), ('58 69 N', 'L'), ('45 61 N', 'T'), ('23 30 N', 'P'), ('42 7 N', 'L'), ('24 79 N', 'D'), ('8 14 N', 'O'), ('54 42 N', 'T'), ('22 55 N', 'L'), ('63 79 N', 'P'), ('2 14 N', 'O'), ('76 55 N', 'C'), ('37 66 N', 'D'), ('6 8 N', 'T'), ('39 2 N', 'D'), ('11 29 N', 'O'), ('44 21 N', 'T'), ('29 2 N', 'O'), ('14 35 N', 'O'), ('65 37 N', 'D'), ('42 30 N', 'L'), ('62 23 N', 'D'), ('37 3 N

In [17]:
OBJECTS_WORLDS = {1: '25x25_world_g3', 2: '25x25_world_g4', 3: '25x25_world_g5', 4: '25x25_world_g6',
                  5: '25x25_world_g7', 6: '25x25_world_g8', 7: '25x25_world_g9', 8: '25x25_world_g10'}

n_samples = 250

# For each object-count world: load its config, build the maze and sample
# (start state, goal) pairs, excluding the goal's own states as starts.
for key, world in OBJECTS_WORLDS.items():
    with open('data/configs/' + (world + '.yaml')) as file:
        config_params = yaml.full_load(file)

    n_cols = config_params['n_cols']
    n_rows = config_params['n_rows']
    walls = config_params['walls']
    task_states = config_params['task_states']
    tasks = config_params['tasks']

    wacmw_test = SimpleWallMazeWorld2()
    X_w, A_w, P_w = wacmw_test.generate_world(n_rows, n_cols, task_states, walls, 'stochastic', 0.15)
    state_goal = []
    for i in range(n_samples):
        goal = np.random.choice(tasks)
        goal_states = get_goal_states(list(X_w), goal)
        c = wacmw_test.generate_rewards(goal, X_w, A_w)
        mdp = MDP(X_w, A_w, P_w, c, 0.9, goal_states, 'rewards', False)
        pol, q = mdp.policy_iteration()
        nonzerostates = np.nonzero(q.sum(axis=1))[0]
        # Remove *every* goal state from the candidates (also safe when goal_states is empty).
        init_states = np.setdiff1d(nonzerostates, goal_states)
        x0 = X_w[np.random.choice(init_states)]
        state_goal += [(x0, goal)]
        print('Completed %d iterations, %.2f%%' % ((i + 1), (i + 1) / n_samples * 100), end='\r')
    print(world + ': ' + str(state_goal))

25x25_world_g3: [('1 1 N', 'D'), ('6 24 N', 'C'), ('16 24 N', 'O'), ('23 2 N', 'O'), ('22 7 N', 'O'), ('12 8 N', 'C'), ('3 1 N', 'O'), ('25 12 N', 'D'), ('21 18 N', 'D'), ('9 17 N', 'D'), ('7 13 N', 'D'), ('11 7 N', 'D'), ('6 10 N', 'C'), ('12 13 N', 'C'), ('24 16 N', 'D'), ('19 17 N', 'D'), ('21 6 N', 'C'), ('16 4 N', 'O'), ('10 4 N', 'D'), ('10 11 N', 'O'), ('16 5 N', 'C'), ('3 22 N', 'D'), ('22 2 N', 'D'), ('5 23 N', 'O'), ('2 1 N', 'D'), ('3 6 N', 'C'), ('2 18 N', 'O'), ('2 12 N', 'D'), ('12 14 N', 'C'), ('16 14 N', 'D'), ('6 21 N', 'O'), ('9 24 N', 'C'), ('5 18 N', 'O'), ('2 4 N', 'D'), ('11 20 N', 'D'), ('25 13 N', 'D'), ('18 16 N', 'C'), ('7 5 N', 'O'), ('5 16 N', 'O'), ('23 4 N', 'D'), ('7 19 N', 'O'), ('9 5 N', 'O'), ('17 9 N', 'C'), ('2 15 N', 'O'), ('5 23 N', 'O'), ('7 20 N', 'O'), ('12 1 N', 'O'), ('4 19 N', 'C'), ('19 23 N', 'C'), ('25 9 N', 'D'), ('3 19 N', 'D'), ('14 5 N', 'D'), ('7 15 N', 'O'), ('25 13 N', 'D'), ('13 12 N', 'O'), ('20 9 N', 'O'), ('24 3 N', 'C'), ('6 24

25x25_world_g5: [('8 4 N', 'B'), ('19 5 N', 'D'), ('13 10 N', 'B'), ('20 9 N', 'O'), ('8 2 N', 'O'), ('3 2 N', 'P'), ('4 15 N', 'B'), ('12 4 N', 'B'), ('2 13 N', 'C'), ('4 22 N', 'B'), ('15 20 N', 'O'), ('4 4 N', 'B'), ('8 19 N', 'C'), ('8 12 N', 'C'), ('19 10 N', 'O'), ('3 18 N', 'C'), ('23 20 N', 'P'), ('25 8 N', 'B'), ('15 7 N', 'D'), ('24 20 N', 'O'), ('13 16 N', 'P'), ('23 22 N', 'O'), ('15 3 N', 'O'), ('21 8 N', 'B'), ('17 9 N', 'D'), ('7 3 N', 'O'), ('4 24 N', 'D'), ('22 22 N', 'D'), ('17 18 N', 'O'), ('24 1 N', 'P'), ('10 3 N', 'D'), ('21 2 N', 'P'), ('12 8 N', 'B'), ('11 7 N', 'O'), ('15 7 N', 'B'), ('11 21 N', 'C'), ('20 19 B', 'C'), ('11 2 N', 'C'), ('13 5 N', 'C'), ('10 13 N', 'D'), ('21 8 N', 'D'), ('19 10 N', 'P'), ('13 4 N', 'B'), ('2 23 N', 'C'), ('12 21 N', 'D'), ('17 24 N', 'C'), ('25 10 N', 'O'), ('17 13 N', 'C'), ('18 22 N', 'C'), ('13 7 N', 'B'), ('18 24 N', 'C'), ('10 7 N', 'C'), ('22 6 N', 'C'), ('8 14 N', 'C'), ('9 13 N', 'C'), ('2 21 N', 'B'), ('15 21 N', 'B'),

25x25_world_g7: [('2 13 N', 'E'), ('13 9 N', 'P'), ('12 23 N', 'D'), ('19 16 N', 'A'), ('16 1 N', 'B'), ('19 21 N', 'C'), ('16 19 N', 'O'), ('17 15 N', 'D'), ('9 25 N', 'E'), ('14 5 N', 'O'), ('7 23 N', 'D'), ('19 13 N', 'P'), ('19 4 N', 'E'), ('19 18 N', 'P'), ('17 16 N', 'O'), ('23 6 N', 'D'), ('9 5 N', 'P'), ('5 12 N', 'E'), ('2 20 N', 'C'), ('20 18 N', 'O'), ('8 8 N', 'E'), ('6 2 N', 'C'), ('15 2 N', 'E'), ('9 20 N', 'C'), ('14 9 N', 'B'), ('20 10 N', 'O'), ('22 18 N', 'E'), ('1 22 N', 'P'), ('21 22 N', 'O'), ('11 17 N', 'B'), ('5 7 N', 'A'), ('8 13 N', 'P'), ('13 23 E', 'O'), ('22 13 N', 'C'), ('15 16 N', 'E'), ('25 25 N', 'P'), ('3 24 N', 'D'), ('22 21 N', 'B'), ('3 21 N', 'A'), ('17 24 N', 'O'), ('13 15 N', 'C'), ('10 5 N', 'C'), ('19 16 N', 'P'), ('18 18 N', 'C'), ('21 3 N', 'B'), ('16 25 N', 'A'), ('12 15 N', 'C'), ('11 22 N', 'B'), ('22 21 N', 'B'), ('19 22 N', 'D'), ('21 23 N', 'E'), ('7 14 N', 'O'), ('1 5 N', 'E'), ('19 22 N', 'D'), ('23 16 N', 'A'), ('24 11 N', 'D'), ('13 

25x25_world_g9: [('22 20 N', 'A'), ('23 1 N', 'T'), ('5 10 N', 'E'), ('22 20 N', 'A'), ('18 9 N', 'B'), ('9 2 N', 'P'), ('16 7 N', 'T'), ('14 15 N', 'A'), ('17 21 N', 'E'), ('11 4 N', 'D'), ('24 10 N', 'B'), ('11 13 N', 'B'), ('23 6 N', 'O'), ('18 23 N', 'O'), ('7 3 N', 'C'), ('4 22 N', 'B'), ('18 24 N', 'O'), ('6 22 N', 'P'), ('17 22 N', 'O'), ('11 13 N', 'B'), ('12 9 N', 'C'), ('15 21 N', 'E'), ('13 13 N', 'D'), ('15 5 N', 'E'), ('15 7 N', 'L'), ('15 22 N', 'B'), ('9 2 N', 'E'), ('18 10 N', 'E'), ('25 16 N', 'T'), ('16 24 N', 'L'), ('1 12 N', 'L'), ('9 24 N', 'P'), ('22 9 N', 'L'), ('5 6 N', 'T'), ('13 21 N', 'C'), ('23 17 N', 'O'), ('16 1 N', 'T'), ('7 7 N', 'T'), ('2 22 N', 'A'), ('9 8 N', 'P'), ('19 2 N', 'C'), ('13 25 N', 'C'), ('16 24 N', 'A'), ('23 9 N', 'D'), ('21 2 N', 'C'), ('3 12 N', 'A'), ('22 9 N', 'D'), ('15 3 N', 'T'), ('24 24 N', 'T'), ('7 14 N', 'E'), ('1 11 N', 'P'), ('13 17 N', 'T'), ('22 20 N', 'C'), ('20 8 N', 'T'), ('3 3 N', 'C'), ('14 9 N', 'B'), ('12 4 N', 'T')

In [9]:
# Load the 10x10_world_3 layout from its YAML config and draw it.
world = '10x10_world_3'
with open('data/configs/' + (world + '.yaml')) as file:
    config_params = yaml.full_load(file)

n_cols = config_params['n_cols']
n_rows = config_params['n_rows']
walls = config_params['walls']
task_states = config_params['task_states']
tasks = config_params['tasks']

# Plot this world's own objects (task_states from the config), not the
# `objs_states` left over from whichever maze cell ran previously.
fig, _ = create_world_view(n_rows, n_cols, task_states, walls)
fig.show()

<IPython.core.display.Javascript object>

In [13]:
# Sample start states for 10x10_world_3. For each of n_samples random tasks, solve
# its reward MDP and draw n_starts start states (with replacement) among states whose
# Q-row sum is non-zero, excluding the task's goal states.
world = '10x10_world_3'
n_samples = 250
n_starts = 20
with open('data/configs/' + (world + '.yaml')) as file:
    config_params = yaml.full_load(file)

n_cols = config_params['n_cols']
n_rows = config_params['n_rows']
walls = config_params['walls']
task_states = config_params['task_states']
tasks = config_params['tasks']

wacmw_test = SimpleWallMazeWorld2()
X_w, A_w, P_w = wacmw_test.generate_world(n_rows, n_cols, task_states, walls, 'stochastic', 0.15)
state_goal = []
state_goal_samples = []
for i in range(n_samples):
    goal = np.random.choice(tasks)
    goal_states = get_goal_states(list(X_w), goal)
    c = wacmw_test.generate_rewards(goal, X_w, A_w)
    mdp = MDP(X_w, A_w, P_w, c, 0.9, goal_states, 'rewards', False)
    pol, q = mdp.policy_iteration()
    nonzerostates = np.nonzero(q.sum(axis=1))[0]
    # Remove *every* goal state from the candidates (also safe when goal_states is empty).
    init_states = np.setdiff1d(nonzerostates, goal_states)
    x0 = X_w[np.random.choice(init_states, size=n_starts)]
    # NOTE(review): the sampled `goal` is not stored with the start states, so these
    # lists cannot be paired back with their task — confirm this is intended.
    state_goal += [x0[0]]
    state_goal_samples += [list(x0)]
    print('Completed %d iterations, %.2f%%' % ((i + 1), (i + 1) / n_samples * 100), end='\r')
print('trajectory: ' + str(state_goal))
print('samples: ' + str(state_goal_samples))

trajectory: ['4 4 N', '8 1 N', '1 7 P', '3 7 N', '9 1 L', '9 10 N', '9 7 N', '2 5 N', '7 9 N', '7 5 N', '3 6 N', '7 2 N', '10 5 N', '3 5 N', '4 8 N', '8 5 N', '10 2 N', '10 5 N', '6 8 N', '1 3 N', '3 8 N', '2 10 N', '1 4 N', '10 6 N', '5 7 N', '8 10 N', '3 9 N', '10 6 N', '3 6 N', '9 9 N', '7 1 N', '5 10 N', '3 4 N', '2 9 N', '2 8 N', '10 4 N', '9 5 T', '10 10 D', '2 7 N', '9 3 N', '5 4 N', '7 6 N', '10 8 N', '10 8 N', '4 9 N', '10 5 N', '1 3 N', '7 2 N', '1 9 N', '9 4 N', '2 10 N', '9 5 T', '4 4 N', '1 10 N', '2 6 N', '6 4 N', '9 1 L', '10 9 N', '1 8 N', '4 4 N', '10 1 N', '6 8 N', '5 1 O', '8 5 N', '2 6 N', '5 8 N', '7 4 N', '3 4 N', '5 2 N', '3 1 N', '8 4 N', '1 2 N', '3 8 N', '9 3 N', '5 8 N', '8 9 N', '9 8 N', '3 9 N', '7 10 N', '6 1 N', '2 7 N', '4 5 N', '9 10 N', '7 9 N', '10 1 N', '7 7 N', '2 5 N', '7 8 N', '3 4 N', '5 4 N', '8 6 N', '1 1 N', '9 2 N', '5 1 O', '1 9 N', '2 1 N', '6 1 N', '2 2 N', '3 4 N', '10 7 N', '7 6 N', '1 1 N', '3 7 N', '8 2 N', '3 9 N', '2 6 N', '8 5 N', '