In [5]:
import sys
sys.path.append('../')
from deep_rl.gridworld import ReachGridWorld, PickGridWorld, PORGBEnv, GoalManager, ScaleObsEnv
from deep_rl.network import *
from deep_rl.utils import *
import os
import random
import dill
import json
import numpy as np
import matplotlib.pyplot as plt
from collections import namedtuple
from IPython.display import display
from PIL import Image
from pathlib import Path
from IPython.core.debugger import Tracer

def set_seed(s):
    """Seed every RNG the notebook relies on with ``s``.

    Seeds, in order, Python's ``random`` module, NumPy's global RNG
    (``np.random``) and PyTorch via ``torch.manual_seed``.

    Args:
        s: seed value (``torch.manual_seed`` expects an int).
    """
    # Import torch explicitly: this file never imports it directly, so it would
    # otherwise only resolve if ``from deep_rl.network import *`` re-exports it.
    import torch

    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)


set_seed(0)  # seed once at start-up so later stochastic cells are reproducible

def imshow(img):
    """Show an image-like array inline in the notebook.

    ``img`` is converted with ``np.asarray`` and cast to ``uint8`` (no
    rescaling is applied) before PIL renders it via IPython's ``display``.
    """
    pixels = np.asarray(img).astype(np.uint8)
    display(Image.fromarray(pixels))

def fload(fn, ftype):
    """Load data of the kind written by ``fsave``.

    Args:
        fn: path to read (str or path-like).
        ftype: 'json' (parsed with ``json.load``), 'pkl' (``dill.load``) or
            'png' (opened with PIL and returned as a NumPy array, the inverse
            of ``fsave``'s ``Image.fromarray(...).save``).

    Returns:
        The deserialized object; an ``np.ndarray`` for 'png'.

    Raises:
        ValueError: if ``ftype`` is not one of the supported types.
    """
    if ftype == 'json':
        with open(fn) as f:
            return json.load(f)
    if ftype == 'pkl':
        # dill.load can execute arbitrary code -- only load trusted files.
        with open(fn, 'rb') as f:
            return dill.load(f)
    if ftype == 'png':
        # Image.open is lazy and holds the file open; the context manager
        # releases it once the pixels have been copied into the array.
        with Image.open(fn) as img:
            return np.array(img)
    raise ValueError('cannot read this data type: {}'.format(ftype))
    
def fsave(data, fn, ftype):
    """Serialize ``data`` to ``fn`` in the format named by ``ftype``.

    Missing parent directories of ``fn`` are created first.

    Args:
        data: object to write. JSON-serializable for 'json', anything dill can
            pickle for 'pkl', an array accepted by ``PIL.Image.fromarray`` for 'png'.
        fn: destination path (str or path-like).
        ftype: one of 'json', 'pkl', 'png'.

    Raises:
        ValueError: if ``ftype`` is unsupported. This is checked before
            anything touches the filesystem, so no stray directories are made.
    """
    if ftype not in ('json', 'pkl', 'png'):
        raise ValueError('unsupported file type: {}'.format(ftype))
    dirname = os.path.dirname(fn)
    # dirname is '' for a bare filename and os.makedirs('') raises, so only
    # create a directory when there is one; exist_ok avoids a check-then-create race.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    if ftype == 'json':
        with open(fn, 'w') as f:
            json.dump(data, f)
    elif ftype == 'pkl':
        with open(fn, 'wb') as f:
            dill.dump(data, f)
    else:  # 'png'
        Image.fromarray(data).save(fn)
        
# Goal-placement settings: which map to use, how many object types (n_goal),
# and the minimum distance handed to GoalManager.gen_goals (min_dis).
GoalConfig = namedtuple('GoalConfig', 'map_name n_goal min_dis')

def visualize_env_config(env_config):
    """Display the initial observation of a pick env built from ``env_config``.

    Steps:
    - Build a ``PickGridWorld`` from ``env_config`` with min_dis=1, window=1,
      task_length=1 and seed=0, and wrap it in ``PORGBEnv(l=16)``.
    - Reset with ``sample_obj_pos=False``.
    - Reorder the observation axes (1, 2, 0) (presumably channels-first to
      channels-last — confirm), upscale it 16x by pixel repetition and show it.
    - Print the agent position.
    """
    world = PickGridWorld(**env_config, min_dis=1, window=1, task_length=1, seed=0)
    rgb_env = PORGBEnv(world, l=16)
    obs = rgb_env.reset(sample_obj_pos=False)
    upscaled = obs.transpose(1, 2, 0).repeat(16, 0).repeat(16, 1)
    imshow(upscaled)
    print(rgb_env.unwrapped.agent_pos)
    
def get_pick_config(goal_config, train_combos=None, seed=0):
    """Build a PickGridWorld env config with procedurally placed objects.

    ``MAX_OBJ_NUM + 1`` goal positions are generated with
    ``goal_config.min_dis`` and ``seed``. The LAST generated position always
    goes first in the layout, so the test object's cell does not depend on
    ``n_goal``. The first ``n_goal - 1`` positions fill the remaining slots.

    Args:
        goal_config: a ``GoalConfig`` (map_name, n_goal, min_dis).
        train_combos: training combos; defaults to
            ``[(0, i) for i in range(1, n_goal)]``.
        seed: seed forwarded to ``GoalManager``.

    Returns:
        dict with keys ``map_names``, ``train_combos``, ``test_combos``
        (always ``[(0, 0)]``), ``num_obj_types`` (= n_goal) and ``obj_pos``
        (a single layout list).

    Raises:
        ValueError: if ``n_goal`` is outside ``[1, MAX_OBJ_NUM + 1]``.
    """
    MAX_OBJ_NUM = 15
    n_goal = goal_config.n_goal
    # Validate up front. n_goal < 1 turns the slice below negative (n_goal=0
    # gives [:-1], i.e. 15 stray positions). n_goal > MAX_OBJ_NUM + 1 slices
    # back over the last position, putting two objects on the same cell.
    if not 1 <= n_goal <= MAX_OBJ_NUM + 1:
        raise ValueError(
            'n_goal must be in [1, {}], got {}'.format(MAX_OBJ_NUM + 1, n_goal))
    goal_manager = GoalManager(goal_config.map_name, seed=seed)
    positions = goal_manager.gen_goals(MAX_OBJ_NUM + 1, min_dis=goal_config.min_dis)
    # NOTE(review): assumes gen_goals returns a list; `+` on a NumPy array
    # would add elementwise instead of concatenating — confirm.
    obj_pos = [positions[-1:] + positions[:n_goal - 1]]  # always the same test
    if train_combos is None:
        train_combos = [(0, i) for i in range(1, n_goal)]
    env_config = dict(
        map_names=[goal_config.map_name],
        train_combos=train_combos,
        test_combos=[(0, 0)],
        num_obj_types=n_goal,
        obj_pos=obj_pos,
    )
    return env_config

# FourRoom

In [6]:
set_seed(0)

# Base four-room config: one map, four object types at fixed cells, and
# combo (0, 0) used for both training and testing.
env_config = {
    'map_names': ['fourroom'],
    'train_combos': [(0, 0)],
    'test_combos': [(0, 0)],
    'num_obj_types': 4,
    'obj_pos': [[(1, 1), (9, 1), (1, 9), (9, 9)]],
}

def save_individual(n_configs=12, out_dir='../data/env_configs/pick2/fourroom'):
    """Pickle one four-room config per train combo ``(0, i)``, ``0 <= i < n_configs``.

    Every config shares the map, object count, object positions and test combo
    ``(0, 0)``; only ``train_combos`` differs. Each is written to
    ``<out_dir>/p2.fourroom.<i>`` with ``fsave(..., ftype='pkl')``, which
    creates ``out_dir`` if it is missing.

    Args:
        n_configs: number of configs to write (default 12, the original count).
        out_dir: destination directory (default is the original path).
    """
    # NOTE(review): num_obj_types is 4 but i runs up to n_configs - 1 (11 by
    # default); if the second combo element indexes object types, this
    # overshoots — confirm against PickGridWorld's combo semantics.
    for i in range(n_configs):
        # Named `config` so it does not shadow the module-level `env_config`.
        config = dict(
            map_names=['fourroom'],
            train_combos=[(0, i)],
            test_combos=[(0, 0)],
            num_obj_types=4,
            obj_pos=[[(1, 1), (9, 1), (1, 9), (9, 9)]],
        )
        fsave(
            config,
            os.path.join(out_dir, 'p2.fourroom.{}'.format(i)),
            ftype='pkl',
        )
#visualize_env_config(env_config)

In [7]:
save_individual()  # pickle the per-train-combo fourroom configs under ../data/env_configs/pick2/fourroom/