In [3]:
### Testing reset

from four_room.env import FourRoomsEnv
from four_room.wrappers import gym_wrapper
from four_room.utils import obs_to_state, obs_to_img
import gymnasium as gym
import dill

# Register the custom environment class under the id that gym.make() uses below.
# Re-running this cell registers the same id again; the recorded cell output shows
# gymnasium's "Overriding environment ... already in registry" warning for that case.
gym.register('MiniGrid-FourRooms-v1', FourRoomsEnv)

# NOTE(review): dill.load unpickles the file and can therefore execute code stored in
# it, so only load config files that come from a trusted source.
with open('./four_room/configs/fourrooms_train_config.pl', 'rb') as file:
    train_config = dill.load(file)
# State tuples used throughout this notebook have the layout:
#   (player location x, player location y, player direction, goal location x, goal location y,
#    door position up, door position down, door position left, door position right)


def describe(state):
    """
    Print a human-readable summary of one observation.

    The argument is decoded with ``obs_to_state`` and the resulting tuple is printed as
    four lines: a header, the player position and direction, the goal position and the
    four door positions.
    """
    # Rebind the parameter to the decoded tuple; the f-strings below index into it.
    state = obs_to_state(state)
    report = (
        "----Observation----",
        f"Player: {state[0]},{state[1]}, dir: {state[2]}",
        f"Goal: {state[3]},{state[4]}",
        f"Doors: up:{state[5]}, down:{state[6]}, left:{state[7]}, right:{state[8]}",
    )
    for line in report:
        print(line)



def make_env_fn(config, seed: int = 0, rank: int = 0):
    """
    Build a zero-argument factory that creates one wrapped FourRooms environment.

    Args:
        config: mapping providing the 'agent positions', 'goal positions',
            'topologies' and 'agent directions' entries handed to gym.make.
        seed: base seed for the environment's first reset.
        rank: offset added to ``seed``.

    Returns:
        A callable that creates the environment, resets it once with
        ``seed + rank`` and returns it.
    """
    def _init():
        layout_kwargs = dict(
            agent_pos=config['agent positions'],
            goal_pos=config['goal positions'],
            doors_pos=config['topologies'],
            agent_dir=config['agent directions'],
        )
        base_env = gym.make('MiniGrid-FourRooms-v1', **layout_kwargs)
        wrapped_env = gym_wrapper(base_env)
        # Seed through an initial reset before handing the environment out.
        wrapped_env.reset(seed=seed+rank)
        return wrapped_env

    return _init

# Two environments built from the same config but seeded differently: `target_env`
# provides the situation to copy, `env` is the one that should reproduce it.
env = make_env_fn(train_config,seed=123958)()
target_env = make_env_fn(train_config,seed=412318)()

# `target` is the observation returned by reset(); it is passed to the other
# environment through the 'load_state' reset option.
# NOTE(review): confirm that 'load_state' accepts an observation in this form.
target, _ = target_env.reset()
obs,_ = env.reset(options={'load_state':target})
# Apply the same action (2) to both environments so the follow-up observations can be
# compared as well. In the recorded output the player moves one cell along its facing
# direction, so action 2 presumably means "move forward" -- TODO confirm.
target2, _, _ , _ , _  = target_env.step(2)
obs2,  _ ,_ ,_ ,_  = env.step(2)

print("----To mimic: ----")
describe(target)
describe(target2)

print("----Reconstruction: ---")

# In the recorded output the player position and direction match the target, but the
# goal and door values differ, i.e. the reconstruction is only partial.
describe(obs)
describe(obs2)



----To mimic: ----
----Observation----
Player: 5,5, dir: 1
Goal: 2,1
Doors: up:0, down:1, left:1, right:1
----Observation----
Player: 5,6, dir: 1
Goal: 2,1
Doors: up:0, down:1, left:1, right:1
----Reconstruction: ---
----Observation----
Player: 5,5, dir: 1
Goal: 3,7
Doors: up:2, down:1, left:0, right:1
----Observation----
Player: 5,6, dir: 1
Goal: 3,7
Doors: up:2, down:1, left:0, right:1


  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


In [1]:
# Training states, one tuple per state, laid out as:
#   (player x, player y, player direction, goal x, goal y,
#    door up, door down, door left, door right)
# The literal has 112 rows, but rows 41-80 repeat rows 1-40 and rows 81-112 repeat
# rows 1-32, so later cells wrap it in set() before using it.
training_set=[(1, 3, 1, 3, 6, 1, 2, 1, 1),
(6, 7, 0, 6, 5, 1, 2, 2, 2),
(6, 7, 2, 2, 2, 0, 0, 1, 2),
(2, 1, 0, 2, 3, 1, 0, 2, 0),
(5, 7, 3, 3, 3, 2, 1, 0, 2),
(7, 1, 3, 1, 5, 2, 0, 2, 2),
(5, 5, 3, 1, 3, 1, 2, 2, 0),
(6, 1, 3, 2, 3, 1, 1, 0, 2),
(6, 1, 1, 7, 6, 2, 0, 2, 1),
(5, 6, 0, 7, 2, 1, 2, 0, 0),
(1, 7, 3, 7, 5, 1, 0, 0, 0),
(7, 2, 0, 3, 7, 2, 1, 0, 1),
(7, 1, 3, 7, 5, 0, 1, 2, 2),
(1, 3, 3, 3, 5, 1, 1, 0, 0),
(6, 7, 0, 2, 6, 2, 1, 1, 2),
(3, 2, 3, 7, 3, 0, 1, 1, 0),
(5, 6, 2, 3, 7, 1, 0, 1, 2),
(2, 2, 1, 5, 6, 0, 2, 0, 0),
(5, 2, 0, 6, 6, 0, 2, 0, 1),
(2, 1, 2, 6, 2, 0, 0, 2, 0),
(7, 6, 0, 2, 3, 1, 2, 0, 1),
(6, 5, 0, 5, 3, 0, 0, 1, 1),
(7, 6, 0, 3, 2, 2, 0, 1, 0),
(3, 7, 0, 6, 6, 2, 0, 0, 0),
(1, 6, 3, 6, 1, 1, 1, 2, 1),
(1, 6, 1, 2, 5, 0, 2, 1, 0),
(2, 2, 3, 3, 5, 1, 0, 1, 0),
(1, 7, 0, 1, 2, 1, 0, 2, 2),
(6, 7, 1, 3, 7, 2, 2, 0, 0),
(3, 2, 3, 2, 7, 0, 2, 1, 1),
(5, 5, 3, 6, 3, 2, 1, 2, 1),
(5, 2, 1, 1, 6, 0, 2, 0, 2),
(1, 6, 2, 6, 1, 1, 1, 1, 0),
(2, 1, 1, 5, 5, 1, 1, 1, 1),
(5, 5, 1, 2, 1, 0, 1, 1, 1),
(1, 7, 3, 3, 1, 2, 2, 2, 1),
(3, 3, 2, 1, 5, 2, 0, 1, 1),
(3, 1, 0, 2, 7, 1, 0, 0, 1),
(5, 6, 3, 1, 3, 2, 2, 1, 2),
(1, 2, 0, 7, 6, 0, 1, 2, 1),
(1, 3, 1, 3, 6, 1, 2, 1, 1),
(6, 7, 0, 6, 5, 1, 2, 2, 2),
(6, 7, 2, 2, 2, 0, 0, 1, 2),
(2, 1, 0, 2, 3, 1, 0, 2, 0),
(5, 7, 3, 3, 3, 2, 1, 0, 2),
(7, 1, 3, 1, 5, 2, 0, 2, 2),
(5, 5, 3, 1, 3, 1, 2, 2, 0),
(6, 1, 3, 2, 3, 1, 1, 0, 2),
(6, 1, 1, 7, 6, 2, 0, 2, 1),
(5, 6, 0, 7, 2, 1, 2, 0, 0),
(1, 7, 3, 7, 5, 1, 0, 0, 0),
(7, 2, 0, 3, 7, 2, 1, 0, 1),
(7, 1, 3, 7, 5, 0, 1, 2, 2),
(1, 3, 3, 3, 5, 1, 1, 0, 0),
(6, 7, 0, 2, 6, 2, 1, 1, 2),
(3, 2, 3, 7, 3, 0, 1, 1, 0),
(5, 6, 2, 3, 7, 1, 0, 1, 2),
(2, 2, 1, 5, 6, 0, 2, 0, 0),
(5, 2, 0, 6, 6, 0, 2, 0, 1),
(2, 1, 2, 6, 2, 0, 0, 2, 0),
(7, 6, 0, 2, 3, 1, 2, 0, 1),
(6, 5, 0, 5, 3, 0, 0, 1, 1),
(7, 6, 0, 3, 2, 2, 0, 1, 0),
(3, 7, 0, 6, 6, 2, 0, 0, 0),
(1, 6, 3, 6, 1, 1, 1, 2, 1),
(1, 6, 1, 2, 5, 0, 2, 1, 0),
(2, 2, 3, 3, 5, 1, 0, 1, 0),
(1, 7, 0, 1, 2, 1, 0, 2, 2),
(6, 7, 1, 3, 7, 2, 2, 0, 0),
(3, 2, 3, 2, 7, 0, 2, 1, 1),
(5, 5, 3, 6, 3, 2, 1, 2, 1),
(5, 2, 1, 1, 6, 0, 2, 0, 2),
(1, 6, 2, 6, 1, 1, 1, 1, 0),
(2, 1, 1, 5, 5, 1, 1, 1, 1),
(5, 5, 1, 2, 1, 0, 1, 1, 1),
(1, 7, 3, 3, 1, 2, 2, 2, 1),
(3, 3, 2, 1, 5, 2, 0, 1, 1),
(3, 1, 0, 2, 7, 1, 0, 0, 1),
(5, 6, 3, 1, 3, 2, 2, 1, 2),
(1, 2, 0, 7, 6, 0, 1, 2, 1),
(1, 3, 1, 3, 6, 1, 2, 1, 1),
(6, 7, 0, 6, 5, 1, 2, 2, 2),
(6, 7, 2, 2, 2, 0, 0, 1, 2),
(2, 1, 0, 2, 3, 1, 0, 2, 0),
(5, 7, 3, 3, 3, 2, 1, 0, 2),
(7, 1, 3, 1, 5, 2, 0, 2, 2),
(5, 5, 3, 1, 3, 1, 2, 2, 0),
(6, 1, 3, 2, 3, 1, 1, 0, 2),
(6, 1, 1, 7, 6, 2, 0, 2, 1),
(5, 6, 0, 7, 2, 1, 2, 0, 0),
(1, 7, 3, 7, 5, 1, 0, 0, 0),
(7, 2, 0, 3, 7, 2, 1, 0, 1),
(7, 1, 3, 7, 5, 0, 1, 2, 2),
(1, 3, 3, 3, 5, 1, 1, 0, 0),
(6, 7, 0, 2, 6, 2, 1, 1, 2),
(3, 2, 3, 7, 3, 0, 1, 1, 0),
(5, 6, 2, 3, 7, 1, 0, 1, 2),
(2, 2, 1, 5, 6, 0, 2, 0, 0),
(5, 2, 0, 6, 6, 0, 2, 0, 1),
(2, 1, 2, 6, 2, 0, 0, 2, 0),
(7, 6, 0, 2, 3, 1, 2, 0, 1),
(6, 5, 0, 5, 3, 0, 0, 1, 1),
(7, 6, 0, 3, 2, 2, 0, 1, 0),
(3, 7, 0, 6, 6, 2, 0, 0, 0),
(1, 6, 3, 6, 1, 1, 1, 2, 1),
(1, 6, 1, 2, 5, 0, 2, 1, 0),
(2, 2, 3, 3, 5, 1, 0, 1, 0),
(1, 7, 0, 1, 2, 1, 0, 2, 2),
(6, 7, 1, 3, 7, 2, 2, 0, 0),
(3, 2, 3, 2, 7, 0, 2, 1, 1),
(5, 5, 3, 6, 3, 2, 1, 2, 1),
(5, 2, 1, 1, 6, 0, 2, 0, 2)]

In [12]:
# Distinct player start cells (x, y): the first two fields of every training state.
# The set comprehension already removes duplicates, so training_set needs no
# intermediate set() copy.
starting_pos = {(state[0], state[1]) for state in training_set}

# Tuples compare lexicographically (x first, then y), which is the intended order.
# The previous numeric key x + y/10 only produced that order while y < 10.
starting = sorted(starting_pos)
print(starting)


[(1, 2), (1, 3), (1, 6), (1, 7), (2, 1), (2, 2), (3, 1), (3, 2), (3, 3), (3, 7), (5, 2), (5, 5), (5, 6), (5, 7), (6, 1), (6, 5), (6, 7), (7, 1), (7, 2), (7, 6)]


In [103]:
import numpy as np
def obs_to_state(obs):
    """
        Turn a numpy observation array into a tuple of plain ints of the form:
        (player location x, player location y, player direction, goal location x, goal location y,
            door position up, door position down, door position left, door position right)

        The observation is indexed as obs[channel, row, column] on a 9x9 grid, with
        channel 0 = player cell, 1 = cell the player faces, 2 = walls, 3 = goal.
        The grid may be cyclically shifted; it is rolled back so that the outer walls
        sit on rows/columns 0 and 8 before any coordinate is read.

        Raises:
            ValueError: if the facing marker is not directly next to the player, or
                if a dividing wall does not contain exactly two door gaps.
    """
    walls = obs[2]

    # Rows / columns made only of wall cells are the outer border; in a shifted grid
    # the two border lines sit next to each other (cyclically).
    full_rows = np.where(walls.sum(axis=1) == 9)[0]
    full_cols = np.where(walls.sum(axis=0) == 9)[0]

    # Roll so that the first border line found ends up on index 8 ...
    shift = np.array([8, 8]) - np.array([full_rows[0], full_cols[0]])
    # ... unless the border already sits on indices 0 and 8, in which case that axis
    # is not shifted at all.
    if full_rows[0] == 0 and full_rows[1] == 8:
        shift[0] = 0
    if full_cols[0] == 0 and full_cols[1] == 8:
        shift[1] = 0

    uncentered_obs = np.roll(obs, tuple(shift), axis=(1, 2))

    # np.where yields (rows, columns); states are expressed as (x, y) = (column, row).
    player_rows, player_cols = np.where(uncentered_obs[0] == 1)
    marker_rows, marker_cols = np.where(uncentered_obs[1] == 1)
    player_loc = (player_cols[0], player_rows[0])

    # Offset player - marker; the marker lies one cell away in the facing direction.
    offset = (int(player_cols[0] - marker_cols[0]), int(player_rows[0] - marker_rows[0]))
    offset_to_dir = {
        (-1, 0): 0,  # right
        (0, -1): 1,  # down
        (1, 0): 2,   # left
        (0, 1): 3,   # up
    }
    if offset not in offset_to_dir:
        # Previously this fell through every branch and died with UnboundLocalError.
        raise ValueError(f"direction marker is not adjacent to the player (offset {offset})")
    player_dir = offset_to_dir[offset]

    goal_rows, goal_cols = np.where(uncentered_obs[3] == 1)
    goal_loc = (goal_cols[0], goal_rows[0])

    # Door gaps are the zero cells of the dividing walls (column 4 and row 4). The
    # first gap of each wall is stored relative to index 1, the second relative to 5.
    walls = uncentered_obs[2]
    vertical_gaps = np.where(walls[:, 4] == 0)[0]
    horizontal_gaps = np.where(walls[4, :] == 0)[0]
    if vertical_gaps.size != 2 or horizontal_gaps.size != 2:
        # A single gap would otherwise broadcast silently into two wrong values.
        raise ValueError("expected exactly two door gaps in each dividing wall")
    gap_origin = np.array([1, 5])
    doors_pos = (*(vertical_gaps - gap_origin), *(horizontal_gaps - gap_origin))

    # Convert numpy scalars to plain ints so the state prints and compares cleanly.
    return tuple(int(value) for value in (*player_loc, player_dir, *goal_loc, *doors_pos))




In [118]:
import numpy as np

def state_to_obs(state):
    """
    Turn a state tuple back into a numpy observation array.

    Args:
        state: (player x, player y, player direction, goal x, goal y,
            door up, door down, door left, door right).

    Returns:
        A float array of shape (4, 9, 9) indexed as [channel, row, column], with
        channel 0 = player cell, 1 = cell the player faces, 2 = walls, 3 = goal.
        The grid is cyclically rolled so that the player sits on cell (4, 4).

    Raises:
        ValueError: if the player direction is not one of 0, 1, 2, 3.
    """
    (player_loc_x, player_loc_y, player_dir, goal_loc_x, goal_loc_y,
     door_pos_up, door_pos_down, door_pos_left, door_pos_right) = state

    # (dx, dy) from the player to the cell it faces, per direction code.
    facing_offsets = {
        0: (1, 0),   # right
        1: (0, 1),   # down
        2: (-1, 0),  # left
        3: (0, -1),  # up
    }
    if player_dir not in facing_offsets:
        # An unknown direction used to yield an observation without any marker.
        raise ValueError(f"player direction must be 0, 1, 2 or 3, got {player_dir!r}")
    dx, dy = facing_offsets[player_dir]

    obs = np.zeros((4, 9, 9))

    # Player, facing marker and goal, written in absolute grid coordinates.
    obs[0, player_loc_y, player_loc_x] = 1
    obs[1, player_loc_y + dy, player_loc_x + dx] = 1
    obs[3, goal_loc_y, goal_loc_x] = 1

    # Walls: the outer border plus the two dividing walls through the middle.
    obs[2, 0, :] = 1
    obs[2, 8, :] = 1
    obs[2, :, 0] = 1
    obs[2, :, 8] = 1
    obs[2, 4, :] = 1
    obs[2, :, 4] = 1

    # Doors are gaps in the dividing walls; each position is an offset from the
    # first cell of its wall segment (index 1 or index 5).
    obs[2, door_pos_up + 1, 4] = 0
    obs[2, door_pos_down + 5, 4] = 0
    obs[2, 4, door_pos_left + 1] = 0
    obs[2, 4, door_pos_right + 5] = 0

    # Cyclic shift that moves the player onto the centre cell (4, 4).
    return np.roll(obs, (4 - player_loc_x, 4 - player_loc_y), axis=(2, 1))

# State tuple layout:
#   (player location x, player location y, player direction, goal location x, goal location y,
#    door position up, door position down, door position left, door position right)

# Drop repeated states. NOTE: this rebinds training_set; set() does not keep the
# original order.
training_set=list(set(training_set))

# Round-trip every distinct state (not just the first three of an arbitrary order)
# through state_to_obs -> obs_to_state and require that it comes back unchanged.
for state in training_set:
    decoded = obs_to_state(state_to_obs(state))
    assert state == decoded, f"round trip changed {state} into {decoded}"

    # Same layout with the player moved to (5, 5). The result of this call used to
    # be thrown away; it is now checked as well.
    relocated = (5, 5, *state[2:])
    decoded = obs_to_state(state_to_obs(relocated))
    assert relocated == decoded, f"round trip changed {relocated} into {decoded}"


In [None]:
(3, 7, 0, 6, 6, 2, 0, 0, 0)
 [[1 1 0 0 0 0 0 0 0]
  [1 1 0 1 1 1 0 1 1]
  [1 1 0 0 0 0 0 0 0]
  [1 1 0 0 0 1 0 0 0]
  [1 1 0 0 0 1 0 0 0]
  [1 1 1 1 1 1 1 1 1]
  [1 1 1 1 1 1 1 1 1]
  [1 1 0 0 0 1 0 0 0]
  [1 1 0 0 0 1 0 0 0]]

 [[0 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 1 0]
  [0 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0 0]]]

In [None]:
##saving for worse times

#TODO update to one hot encoded goal (either 2x8, or 64 array)
class MultiInput_CNN_Goal(BaseFeaturesExtractor):
    """
    Features extractor for dict observations made of a grid image and a goal.

    The 'observation' entry (channels first, e.g. 4x9x9) goes through three 3x3
    convolutions with circular padding; the 'desired_goal' entry (two one-hot
    vectors of width 8, for x and y) goes through a small linear layer. Both
    results are concatenated and projected to ``features_dim``.

    Args:
        observation_space: dict space containing at least an 'observation' entry.
        features_dim: size of the returned feature vector.
        device: stored on the instance as ``self.device``; nothing in this class
            moves tensors to it.
    """

    def __init__(self, observation_space: gym.spaces.Dict, features_dim: int = 512, device=torch.device("cuda")):
        super().__init__(observation_space, features_dim)
        self.device = device
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper

        # input = observation(4,9,9)
        n_input_channels = observation_space['observation'].shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 64, kernel_size=3, stride=1, padding=1, padding_mode='circular'),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, padding_mode='circular'),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, padding_mode='circular'),
            nn.ReLU(),
            nn.Flatten(),
        )
        # input = goal(2,) -> goal(16,) as both x and y are one-hot encoded
        self.lin_goal = nn.Sequential(nn.Linear(16, 64), nn.ReLU())

        # Run a dummy batch through the CNN once to learn the flattened size.
        with torch.no_grad():
            n_flatten = self.cnn(torch.ones(1, n_input_channels, *observation_space['observation'].shape[1:])).shape[1]
        self.linear = nn.Sequential(nn.Linear(n_flatten + 64, features_dim), nn.ReLU())

    def forward(self, observations: dict) -> torch.Tensor:
        """
        Compute the feature vector for a dict of tensors.

        Args:
            observations: mapping with an 'observation' tensor of shape (B, C, H, W)
                and a 'desired_goal' tensor of shape (B, 2, 8). Unbatched inputs
                ((C, H, W) and (2, 8)) get a batch dimension of 1 added.

        Returns:
            Tensor of shape (B, features_dim).
        """
        obs = observations['observation']
        goal = observations['desired_goal']

        # Add the batch dimension to both inputs when it is missing. Previously only
        # the image was handled, so the goal indexing below failed on unbatched input.
        if obs.dim() < 4:
            obs = obs.unsqueeze(0)
        if goal.dim() < 3:
            goal = goal.unsqueeze(0)

        # Lay the x and y one-hot vectors side by side: (B, 2, 8) -> (B, 16).
        goal_stack = torch.cat([goal[:, 0, :], goal[:, 1, :]], dim=1)

        processed_obs = self.cnn(obs)
        processed_goal = self.lin_goal(goal_stack)

        return self.linear(torch.cat([processed_obs, processed_goal], dim=1))
