In [None]:
import numpy as np
from d4rl.pointmaze import q_iteration
from d4rl.pointmaze.gridcraft import grid_env
from d4rl.pointmaze.gridcraft import grid_spec


# Shared 2-D constant vectors (float32).
ZEROS = np.zeros((2,), dtype=np.float32)
# Previously built with np.zeros, which made any "k * ONES" sentinel equal to
# the origin instead of (k, k).
ONES = np.ones((2,), dtype=np.float32)


class WaypointController(object):
    """PD controller that steers a point through grid waypoints to a target.

    Waypoints are planned on a grid environment built from ``maze_str``: the
    target cell is temporarily marked as a reward, Q-iteration is run, and the
    greedy (arg-max) policy is rolled out from the start cell.
    """

    def __init__(self, maze_str, solve_thresh=0.1, p_gain=10.0, d_gain=-1.0):
        """Build the planning grid and initialise controller state.

        Args:
            maze_str: Maze layout, parsed with ``grid_spec.spec_from_string``.
            solve_thresh: Distance below which a waypoint or the target counts
                as reached.
            p_gain: Gain applied to the position error.
            d_gain: Gain applied to the ``velocity`` passed to ``get_action``.
        """
        self.maze_str = maze_str
        # Sentinel far from any grid cell so the very first get_action() call
        # always plans. (The old "-1000 * ONES" evaluated to the origin because
        # ONES was all zeros, so a first target in cell (0, 0) skipped planning.)
        self._target = np.full((2,), -1000.0, dtype=np.float32)

        self.p_gain = p_gain
        self.d_gain = d_gain
        self.solve_thresh = solve_thresh
        self.vel_thresh = 0.1

        self._waypoint_idx = 0
        self._waypoints = []
        self._waypoint_prev_loc = ZEROS

        self.env = grid_env.GridEnv(grid_spec.spec_from_string(maze_str))

    def current_waypoint(self):
        """Return the waypoint the controller is currently heading for."""
        return self._waypoints[self._waypoint_idx]

    def get_action(self, location, velocity, target):
        """Compute the next clipped control action.

        Args:
            location: Current 2-D position (array-like supporting arithmetic).
            velocity: Current 2-D velocity, used for the damping term.
            target: Desired 2-D goal; it is snapped to a grid cell and a new
                plan is computed whenever that cell differs from the stored one.

        Returns:
            ``(action, solved)`` where ``action`` is clipped to ``[-1, 1]`` and
            ``solved`` is True once the point is within ``solve_thresh`` of the
            target and has moved less than ``vel_thresh`` since the last call.
        """
        target_cell = np.array(self.gridify_state(target))
        if np.linalg.norm(self._target - target_cell) > 1e-3:
            self._new_target(location, target)

        dist = np.linalg.norm(location - self._target)
        # Displacement since the previous call; used as a "has it settled" test
        # independent of the velocity argument.
        vel = self._waypoint_prev_loc - location
        vel_norm = np.linalg.norm(vel)
        task_not_solved = (dist >= self.solve_thresh) or (vel_norm >= self.vel_thresh)

        next_wpnt = self.current_waypoint() if task_not_solved else self._target

        # Compute control
        prop = next_wpnt - location
        action = self.p_gain * prop + self.d_gain * velocity

        dist_next_wpnt = np.linalg.norm(location - next_wpnt)
        reached_waypoint = (dist_next_wpnt < self.solve_thresh) and (vel_norm < self.vel_thresh)
        if task_not_solved and reached_waypoint:
            self._waypoint_idx += 1
            if self._waypoint_idx == len(self._waypoints) - 1:
                # Sanity check: the final waypoint must coincide with the target.
                assert np.linalg.norm(self._waypoints[self._waypoint_idx] - self._target) <= self.solve_thresh

        self._waypoint_prev_loc = location
        action = np.clip(action, -1.0, 1.0)
        return action, (not task_not_solved)

    def gridify_state(self, state):
        """Round the first two components of ``state`` to an integer cell tuple."""
        return (int(round(state[0])), int(round(state[1])))

    def _new_target(self, start, target):
        """Plan a fresh waypoint list from ``start`` to ``target``.

        Both points are snapped to grid cells. On success the stored waypoints,
        previous location and target are replaced; the waypoint index is reset
        before planning starts.
        """
        start = self.gridify_state(start)
        start_idx = self.env.gs.xy_to_idx(start)
        target = self.gridify_state(target)
        target_idx = self.env.gs.xy_to_idx(target)
        self._waypoint_idx = 0

        # Temporarily mark the target cell as rewarding; the finally-clause
        # guarantees the grid is restored even if planning raises.
        self.env.gs[target] = grid_spec.REWARD
        try:
            q_values = q_iteration.q_iteration(env=self.env, num_itrs=50, discount=0.99)
            # compute waypoints by performing a rollout in the grid
            max_ts = 100
            s = start_idx
            waypoints = []
            for _ in range(max_ts):
                a = np.argmax(q_values[s])
                new_s, _ = self.env.step_stateless(s, a)

                waypoint = self.env.gs.idx_to_xy(new_s)
                if new_s != target_idx:
                    # Jitter intermediate waypoints; the final one stays exact.
                    waypoint = waypoint - np.random.uniform(size=(2,)) * 0.2
                waypoints.append(waypoint)
                s = new_s
                if new_s == target_idx:
                    break
        finally:
            self.env.gs[target] = grid_spec.EMPTY

        self._waypoints = waypoints
        self._waypoint_prev_loc = start
        self._target = target


if __name__ == "__main__":
    # Smoke test: plan a route through a small maze and print the first action.
    print(q_iteration.__file__)
    # Rows are separated by a literal backslash character.
    TEST_MAZE = (
        "######\\"
        "#OOOO#\\"
        "#O##O#\\"
        "#OOOO#\\"
        "######"
    )
    controller = WaypointController(TEST_MAZE)
    start = np.array((1, 1), dtype=np.float32)
    target = np.array((4, 3), dtype=np.float32)
    # get_action takes (location, velocity, target); the demo starts at rest.
    velocity = np.zeros((2,), dtype=np.float32)
    act, done = controller.get_action(start, velocity, target)
    print('wpt:', controller._waypoints)
    print(act, done)


In [None]:
def reset_env(env, sg_dict):
    """Reset ``env`` and place it at the start/goal pair described by ``sg_dict``.

    Args:
        env: Environment exposing ``reset``, ``set_target`` and
            ``reset_to_location``.
        sg_dict: Mapping with a ``'goal_cell'`` and a ``'reset_cell'`` entry.

    Returns:
        Whatever ``env.reset_to_location`` returns for the reset cell.
    """
    env.reset()

    goal = sg_dict['goal_cell']
    env.set_target(goal)

    start = sg_dict['reset_cell']
    first_obs = env.reset_to_location(start)
    return first_obs


# TODO: save the data in the previous format instead of following d4rl's layout.
# The waypoint controller itself comes from d4rl.

# An episode does not necessarily start where the previous one ended, which is
# why the collected dataset is not one continuous trajectory.
def collect_short_dataset(env_name, num_data, seed=1127, verbose=True):
    """Collect waypoint-controller rollouts from a maze environment.

    NOTE(review): this function looks like a partially merged port. It relies on
    names that are not defined inside it (``noisy``, ``max_episode_steps``,
    ``append_data``, ``reset_data``, ``observation_data``, ``action_data``,
    ``termination_data``, ``info``); they must exist at module level for the
    corresponding lines to run.

    Args:
        env_name: Name passed to ``datasets.load_environment`` and
            ``get_start_state_goal_pairs``.
        num_data: Sample budget; the loop runs while
            ``data_idx < int(num_data) - 1``.
        seed: Seed for ``numpy``, ``random`` and the env's action and
            observation spaces.
        verbose: Print running counters whenever an episode is closed out.

    Returns:
        ``(env, dataset)`` where ``dataset`` bundles the observation, action and
        termination buffers with the success and episode counts.
    """
    # seed
    np.random.seed(seed)
    random.seed(seed)

    # environment initialisation
    env = datasets.load_environment(env_name)
    maze = env.str_maze_spec
    # seed
    env.action_space.seed(seed)
    env.observation_space.seed(seed)

    controller = waypoint_controller.WaypointController(maze)
    # get data collecting start state and goal pairs
    train_start_state_goal = get_start_state_goal_pairs(env_name, env)

    done = False
    data = reset_data()
    ts = 0
    num_dc = len(train_start_state_goal)

    # Loop state. These were previously read before ever being assigned, which
    # raised UnboundLocalError on the first evaluation of the loop condition.
    data_idx = 0
    episode_idx = 0
    episode_start_idx = 0
    success_count = 0
    terminated = False
    truncated = False

    sg_dict = train_start_state_goal[episode_idx % num_dc]
    obs = reset_env(env, sg_dict)

    while data_idx < int(num_data) - 1:
        # compute actions
        position = obs[0:2]
        velocity = obs[2:4]
        act, done = controller.get_action(position, velocity, env._target)
        if noisy:
            act = act + np.random.randn(*act.shape) * 0.5
        act = np.clip(act, -1.0, 1.0)

        if ts >= max_episode_steps:
            done = True
        append_data(data, obs, act, env._target, done, env.sim.data)
        # (s, a) in the dataset

        data_idx += 1
        ns, _, _, _ = env.step(act)

        if len(data['observations']) % 10000 == 0:
            print(len(data['observations']))

        ts += 1
        if done:
            env.set_target()
            done = False
            ts = 0
        else:
            # Feed the new observation to the controller on the next step
            # (this used to assign to an unused name, leaving ``obs`` stale).
            obs = ns

        # NOTE(review): nothing in this loop sets ``terminated``/``truncated``
        # to True and ``info`` is never produced here, so the bookkeeping below
        # is currently unreachable. It also treats ``obs`` as a dict and calls
        # ``env.reset(options=...)``, unlike the code above - decide which
        # environment API this function targets and wire it up accordingly.
        if terminated or truncated:
            if info['success']:
                success_count += 1

                observation_data["episode_id"][data_idx] = episode_idx
                observation_data["observation"][data_idx] = obs["observation"]
                observation_data["achieved_goal"][data_idx] = obs["achieved_goal"]
                observation_data["desired_goal"][data_idx] = obs["desired_goal"]
                termination_data[data_idx] = terminated or truncated

                data_idx += 1
                episode_idx += 1
                episode_start_idx = data_idx

            else:
                # Failed episode: rewind so its samples get overwritten.
                data_idx = episode_start_idx

            sg_dict = train_start_state_goal[episode_idx % num_dc]

            obs, _ = env.reset(options=sg_dict)

            terminated = False
            truncated = False

            if verbose:
                print("STEPS RECORDED:", data_idx)
                print("EPISODES RECORDED:", episode_idx)
                print("SUCCESS EPISODES RECORDED:", success_count)

    # NOTE(review): the buffers returned here are not the ``data`` dict filled
    # by ``append_data`` above - confirm which one is meant to be returned.
    dataset = {
        "observations": observation_data,
        "actions": action_data,
        "terminations": termination_data,
        "success_count": success_count,
        "episode_count": episode_idx,
    }

    return env, dataset
    
    
# Smoke-run the collector with the smallest sample budget.
collect_short_dataset(env_name, num_data=1)



In [None]:
# Background mask: cells of the maze map equal to 1 (presumably walls - confirm).
maze = env.unwrapped.maze
maze_map = np.array(maze._maze_map)
_background = maze_map == 1


train_start_state_goal = get_start_state_goal_pairs(dataset_name, env)

# One figure per group of four start/goal pairs; a trailing partial group is
# not drawn.
PANELS_PER_FIGURE = 4
n_figures = len(train_start_state_goal) // PANELS_PER_FIGURE

for i in range(n_figures):
    fig, axes = plt.subplots(
        1, PANELS_PER_FIGURE, figsize=(20, 5), constrained_layout=True
    )

    for idx, ax in enumerate(axes):
        k = PANELS_PER_FIGURE * i + idx

        ax.imshow(_background * 0.5, cmap=plt.cm.binary, vmin=0, vmax=1)

        # collected trajectory, coloured light-to-dark along its length
        obs = trjs[k]['observations']['observation'][:, :2]
        ij = cell_xy_to_rowcol(maze, obs) - 0.5
        colors = plt.cm.Blues(np.linspace(0.4, 1.0, len(ij)))
        traj_rows, traj_cols = ij[:, 0], ij[:, 1]
        ax.scatter(traj_cols, traj_rows, s=10, c=colors)

        # start (blue) and goal (red) cells, given as (row, col)
        pair = train_start_state_goal[k]
        goal_cell = pair['goal_cell']
        reset_cell = pair['reset_cell']
        ax.scatter(goal_cell[1], goal_cell[0], color='Red', s=100)
        ax.scatter(reset_cell[1], reset_cell[0], color='Blue', s=100)

        ax.axis('off')

    plt.show()

In [None]:
def cell_xy_to_rowcol(maze, xy_pos: np.ndarray) -> np.ndarray:
    """Map ``(x, y)`` positions to fractional ``(row, col)`` maze coordinates.

    Args:
        maze: Object exposing ``x_map_center``, ``y_map_center`` and
            ``maze_size_scaling``.
        xy_pos: 2-D array whose column 0 is x and column 1 is y.

    Returns:
        Array with one ``(row, col)`` pair per input row. Values are not
        rounded to integers.
    """
    scale = maze.maze_size_scaling
    xs = xy_pos[:, 0]
    ys = xy_pos[:, 1]

    # Rows grow downwards, so y is flipped around the map centre.
    rows = ((maze.y_map_center - ys) / scale).reshape(-1, 1)
    cols = ((xs + maze.x_map_center) / scale).reshape(-1, 1)

    return np.hstack((rows, cols))

In [None]:
# Per-environment bounds used to normalise observations before plotting.
# Each value is a 4-tuple; Maze2dRenderer.renders divides observation column 0
# by the second element and column 1 by the fourth. The first and third
# elements are unpacked but not used there.
MAZE_BOUNDS = {
    "maze2d-umaze-v1": (0, 5, 0, 5),
    "maze2d-medium-v1": (0, 8, 0, 8),
    "maze2d-large-v1": (0, 9, 0, 12),
    "maze2d-xxlarge-v1": (0, 18, 0, 24),
}


class MazeRenderer:
    """Render 2-D paths on top of a maze background image."""

    def __init__(self, env):
        """Set up the background mask from an environment or its name.

        Args:
            env: Either an environment name (loaded via ``load_environment``)
                or an environment object exposing ``_config``.
        """
        if isinstance(env, str):
            env = load_environment(env)
        # These must be set for both input kinds; previously they were only
        # assigned when ``env`` was a string, so passing an env object left
        # ``_background`` undefined.
        self._config = env._config
        self._background = self._config != " "
        self._remove_margins = False
        self._extent = (0, 1, 1, 0)

    def renders(self, observations, conditions=None, title=None):
        """Draw one path over the maze background and return it as an image.

        Args:
            observations: 2-D array; column 1 is plotted on the horizontal axis
                and column 0 on the vertical axis.
            conditions: Accepted for interface compatibility; not used here.
            title: Optional plot title.

        Returns:
            The image produced by ``plot2img`` for the current figure.
        """
        plt.clf()
        fig = plt.gcf()
        fig.set_size_inches(5, 5)
        plt.imshow(
            self._background * 0.5,
            extent=self._extent,
            cmap=plt.cm.binary,
            vmin=0,
            vmax=1,
        )

        path_length = len(observations)
        # Colour the points from the start to the end of the path.
        colors = plt.cm.jet(np.linspace(0, 1, path_length))
        plt.plot(observations[:, 1], observations[:, 0], c="black", zorder=10)
        plt.scatter(observations[:, 1], observations[:, 0], c=colors, zorder=20)
        plt.axis("off")
        plt.title(title)
        img = plot2img(fig, remove_margins=self._remove_margins)
        return img

    def composite(self, savepath, paths, ncol=5, **kwargs):
        """Render several paths and save them as a single image grid.

        Args:
            savepath: Destination file path.
            paths: Sequence of argument tuples; each one is unpacked into
                ``renders`` (so its first item is the observations array).
            ncol: Number of columns in the grid; must divide ``len(paths)``.
            **kwargs: Per-path keyword arguments, distributed by ``zipkw``.
        """
        assert (
            len(paths) % ncol == 0
        ), "Number of paths must be divisible by number of columns"

        images = np.stack(
            [self.renders(*path, **kw) for path, kw in zipkw(paths, **kwargs)],
            axis=0,
        )

        nrow = len(images) // ncol
        # Tile the individual images row-major into one (nrow x ncol) grid.
        images = einops.rearrange(
            images, "(nrow ncol) H W C -> (nrow H) (ncol W) C", nrow=nrow, ncol=ncol
        )
        imageio.imsave(savepath, images)
        print(f"Saved {len(paths)} samples to: {savepath}")


class Maze2dRenderer(MazeRenderer):
    """Maze renderer that normalises observations with ``MAZE_BOUNDS``."""

    def __init__(self, env, observation_dim=None):
        """Load the named environment and derive the background mask.

        Args:
            env: Environment name; also used as the key into ``MAZE_BOUNDS``.
            observation_dim: Accepted but not used; the dimension is taken from
                the environment's observation space.

        Note: the parent ``__init__`` is intentionally not called here; every
        attribute the parent's methods read is set directly below.
        """
        self.env_name = env
        self.env = load_environment(env)
        # 10 is presumably the wall code in ``maze_arr`` - confirm.
        self._background = self.env.maze_arr == 10
        self.observation_dim = np.prod(self.env.observation_space.shape)
        self.action_dim = np.prod(self.env.action_space.shape)
        self.goal = None
        self._remove_margins = False
        self._extent = (0, 1, 1, 0)

    def renders(self, observations, conditions=None, **kwargs):
        """Normalise positions by the maze bounds and render them.

        Args:
            observations: 2-D array whose first two columns are positions.
            conditions: Optional array-like scaled like the observations
                (assumes its last axis starts with the same two coordinates -
                TODO confirm against callers).
            **kwargs: Forwarded to ``MazeRenderer.renders``.

        Raises:
            RuntimeError: If the bounds entry has neither 2 nor 4 elements.
        """
        bounds = MAZE_BOUNDS[self.env_name]

        # ``+ 0.5`` already yields a new array; copy the conditions too so the
        # caller's data is never modified in place.
        observations = observations + 0.5
        if conditions is not None:
            conditions = np.array(conditions, dtype=float)

        if len(bounds) == 2:
            _, scale = bounds
            observations /= scale
            if conditions is not None:
                conditions /= scale
        elif len(bounds) == 4:
            _, iscale, _, jscale = bounds
            observations[:, 0] /= iscale
            observations[:, 1] /= jscale
            # Previously this path divided by an unbound ``scale`` and crashed
            # whenever conditions were supplied.
            if conditions is not None:
                conditions[..., 0] /= iscale
                conditions[..., 1] /= jscale
        else:
            raise RuntimeError(f"Unrecognized bounds for {self.env_name}: {bounds}")

        return super().renders(observations, conditions, **kwargs)


