In [1]:
import sys
import multiprocessing
import os.path as osp
import gym
from collections import defaultdict
import tensorflow as tf
import numpy as np

from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env
from baselines.common.tf_util import get_session, load_variables
from baselines import bench, logger
from importlib import import_module

from baselines.common.vec_env.vec_normalize import VecNormalize
from baselines.common import atari_wrappers, retro_wrappers

import joblib


try:
    from mpi4py import MPI
except ImportError:
    MPI = None

try:
    import pybullet_envs
except ImportError:
    pybullet_envs = None

try:
    import roboschool
except ImportError:
    roboschool = None

# Index every registered gym environment id by its environment type. The type
# is the last dotted component of the module part of the spec's entry point
# (the text before the first ':').
_game_envs = defaultdict(set)
for env in gym.envs.registry.all():
    # TODO: solve this with regexes
    env_type = env._entry_point.partition(':')[0].rpartition('.')[2]
    _game_envs[env_type].add(env.id)

"""
examples:

python -m baselines.run --alg=ppo2 --env=Robosimian-v2 --num_timesteps=0 --load_path=./models/robosimian_full_video --play

python -m baselines.run --alg=ppo2 --env=RobosimianIKTable-v2 --num_timesteps=0 --load_path=./models/robosimian_ik_2 --play

python -m baselines.run --alg=ppo2 --env=RobosimianIKTable-v2 --num_timesteps=0 --load_path=./models/FS_CS_50_rough_5mil --play

python -m baselines.run --alg=ppo2 --env=LeglessRobosimian-v2 --num_timesteps=0 --load_path=./models/robosimian_simple_4 --play

python -m baselines.run --alg=ppo2 --env=LeglessRobosimian3Wheels-v2 --num_timesteps=0 --load_path=./models/3_wheels --play

python -m baselines.run --alg=ppo2 --env=LeglessRobosimianXYZ-v2 --num_timesteps=1000000 --save_path=./models/simple_xyz_rough --play

python -m baselines.run --alg=ppo2 --env=RSBalance-v2 --num_timesteps=100000 --save_path=./models/balance_v0 --play

python -m baselines.run --alg=ppo2 --env=RSBalanceWithTraj-v2 --num_timesteps=100000 --save_path=./models/balance_with_traj_v0 --play

python -m baselines.run --alg=ppo2 --env=RSBalanceWithTraj-v2 --num_timesteps=0 --load_path=./models/balance_with_traj_v1 --play

"""

#sys.exit()
# The retro game names are hard-coded below: reading benchmark names directly
# from retro requires importing retro here, and for some reason that crashes
# tensorflow in ubuntu.
_game_envs['retro'] = {
    'BubbleBobble-Nes',
    'SuperMarioBros-Nes',
    'TwinBee3PokoPokoDaimaou-Nes',
    'SpaceHarrier-Nes',
    'SonicTheHedgehog-Genesis',
    'Vectorman-Genesis',
    'FinalFight-Snes',
    'SpaceInvaders-Snes',
}


def train(args, extra_args):
    """Build the environment for ``args.env`` and run the learn function of ``args.alg``.

    The keyword arguments handed to the learn function start from the
    algorithm's per-env-type defaults and are then overridden by
    ``extra_args``. The network comes from ``args.network`` when it is set,
    otherwise from the keyword arguments, otherwise from
    ``get_default_network``.

    Args:
        args: parsed command-line namespace; reads ``env``, ``num_timesteps``,
            ``seed``, ``alg`` and ``network`` (plus whatever ``build_env`` reads).
        extra_args: dict of additional keyword arguments for the learn function.

    Returns:
        A ``(model, env)`` tuple: the value returned by the learn function and
        the environment it was given.
    """
    env_type, env_id = get_env_type(args.env)
    print('env_type: {}'.format(env_type))

    learn = get_learn_function(args.alg)

    # Defaults first, explicit command-line extras win.
    alg_kwargs = get_learn_function_defaults(args.alg, env_type)
    alg_kwargs.update(extra_args)

    env = build_env(args, alg_kwargs)

    if args.network:
        alg_kwargs['network'] = args.network
    elif alg_kwargs.get('network') is None:
        alg_kwargs['network'] = get_default_network(env_type)

    print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))

    model = learn(
        env=env,
        seed=args.seed,
        total_timesteps=int(args.num_timesteps),
        **alg_kwargs
    )
    return model, env


def build_env(args, alg_kwargs):
    """Create the (possibly vectorized and wrapped) environment for ``args.env``.

    The construction depends on the environment type returned by
    ``get_env_type`` and, for atari, on ``args.alg``:

    * atari: 'acer' gets a plain vectorized env, 'deepq' and 'trpo_mpi' get a
      single monitored env with the deepmind wrappers, every other algorithm
      gets a vectorized env with 4 stacked frames.
    * retro: a single monitored retro env with the deepmind-retro wrappers.
    * anything else: a vectorized env built after creating a TF session with
      single-threaded op parallelism; mujoco envs are additionally wrapped in
      ``VecNormalize``.

    Args:
        args: parsed command-line namespace; reads ``num_env``, ``alg``,
            ``seed``, ``env``, ``gamestate`` and ``reward_scale``.
        alg_kwargs: keyword arguments of the learn function; only
            ``'load_path'`` is consulted (mujoco branch).

    Returns:
        The constructed environment.
    """
    cpu_total = multiprocessing.cpu_count()
    if sys.platform == 'darwin':
        cpu_total //= 2
    nenv = args.num_env or cpu_total
    alg = args.alg
    if MPI:
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        rank = 0
    seed = args.seed

    env_type, env_id = get_env_type(args.env)

    if env_type == 'atari':
        if alg == 'acer':
            return make_vec_env(env_id, env_type, nenv, seed)

        if alg == 'deepq':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(env, logger.get_dir())
            return atari_wrappers.wrap_deepmind(env, frame_stack=True, scale=True)

        if alg == 'trpo_mpi':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            # One monitor file per MPI rank when a log directory is configured.
            monitor_path = logger.get_dir() and osp.join(logger.get_dir(), str(rank))
            env = bench.Monitor(env, monitor_path)
            env = atari_wrappers.wrap_deepmind(env)
            # TODO check if the second seeding is necessary, and eventually remove
            env.seed(seed)
            return env

        frame_stack_size = 4
        vec_env = make_vec_env(env_id, env_type, nenv, seed)
        return VecFrameStack(vec_env, frame_stack_size)

    if env_type == 'retro':
        import retro
        gamestate = args.gamestate or 'Level1-1'
        env = retro_wrappers.make_retro(
            game=args.env,
            state=gamestate,
            max_episode_steps=10000,
            use_restricted_actions=retro.Actions.DISCRETE,
        )
        env.seed(args.seed)
        env = bench.Monitor(env, logger.get_dir())
        return retro_wrappers.wrap_deepmind_retro(env)

    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        intra_op_parallelism_threads=1,
        inter_op_parallelism_threads=1,
    )
    get_session(session_config)

    env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)

    if env_type != 'mujoco':
        return env

    env = VecNormalize(env)
    if alg_kwargs.get('load_path'):
        sess = get_session()
        loaded_params = joblib.load(osp.expanduser(alg_kwargs['load_path']))
        # Overwrite every trainable variable with the value stored under its name.
        assign_ops = [var.assign(loaded_params[var.name]) for var in tf.trainable_variables()]
        sess.run(assign_ops)

        # Per the original author's note: this is necessary to update the
        # .mean / .count values of the running statistics after the restore,
        # otherwise they stay out of sync with the loaded variables.
        env.ob_rms._set_mean_var_count()
        env.ret_rms._set_mean_var_count()

    return env


def get_env_type(env_id):
    """Resolve ``env_id`` to an ``(env_type, env_id)`` pair using ``_game_envs``.

    ``env_id`` may be either a concrete environment id or the name of an
    environment type. In the latter case a representative id of that type is
    returned; the alphabetically first id is chosen so the result does not
    depend on set iteration order.

    Args:
        env_id: environment id, or environment type name.

    Returns:
        Tuple ``(env_type, env_id)``.

    Raises:
        AssertionError: if ``env_id`` is neither a known environment type nor
            a known environment id. Raised explicitly so the check is not
            stripped when Python runs with ``-O``.
    """
    if env_id in _game_envs:
        # The caller passed an environment type; pick one of its ids.
        env_type = env_id
        return env_type, sorted(_game_envs[env_type])[0]

    for env_type, ids in _game_envs.items():
        if env_id in ids:
            return env_type, env_id

    raise AssertionError(
        'env_id {} is not recognized in env types {}'.format(env_id, list(_game_envs.keys()))
    )


def get_default_network(env_type):
    """Return the default network name for an environment type.

    Args:
        env_type: environment type name.

    Returns:
        ``'cnn'`` for ``'atari'``, ``'mlp'`` for every other environment type.
    """
    # Every env type other than atari falls back to 'mlp', so there is no
    # "unknown env type" error case (the former trailing raise was unreachable).
    return 'cnn' if env_type == 'atari' else 'mlp'


def get_alg_module(alg, submodule=None):
    """Import and return the module ``<package>.<alg>.<submodule>``.

    ``submodule`` defaults to ``alg``. The ``baselines`` package is tried
    first; if that import fails with ``ImportError`` the same dotted path is
    imported from ``rl_algs`` instead, and any error from that second attempt
    propagates to the caller.
    """
    module_tail = [alg, submodule or alg]
    try:
        # first try to import the alg module from baselines
        return import_module('.'.join(['baselines'] + module_tail))
    except ImportError:
        # then from rl_algs
        return import_module('.'.join(['rl_algs'] + module_tail))


def get_learn_function(alg):
    """Return the ``learn`` attribute of the module for algorithm ``alg``."""
    alg_module = get_alg_module(alg)
    return alg_module.learn


def get_learn_function_defaults(alg, env_type):
    """Return the default learn-function kwargs of ``alg`` for ``env_type``.

    Looks up a callable named ``env_type`` in the algorithm's ``defaults``
    module and returns its result. An empty dict is returned when the module
    cannot be imported or an ``AttributeError`` occurs during the lookup or
    the call.
    """
    try:
        defaults_module = get_alg_module(alg, 'defaults')
        make_defaults = getattr(defaults_module, env_type)
        return make_defaults()
    except (ImportError, AttributeError):
        return {}


def parse(v):
    '''
    Convert the value of a command-line arg to a python object if possible,
    otherwise keep it as a string.

    Args:
        v: raw string value of a command-line argument.

    Returns:
        The result of evaluating ``v`` as a Python expression, or ``v``
        itself when evaluation fails for any reason.

    Raises:
        AssertionError: if ``v`` is not a string.
    '''

    assert isinstance(v, str)
    try:
        # SECURITY NOTE(review): eval() executes arbitrary Python taken from
        # the command line. Only use this with trusted input; it is kept
        # (rather than ast.literal_eval) so that expressions still work.
        return eval(v)
    except Exception:
        # Not only NameError/SyntaxError: values such as "1/0" or "a.b" raise
        # other exception types and must also fall back to the raw string.
        return v


def main():
    """Command-line entry point: train a model, optionally save it and play it.

    Parses the common arguments plus any unknown ``--key=value`` extras
    (converted with ``parse``), configures the logger, trains via ``train``,
    saves the model on rank 0 when ``--save_path`` is given, and with
    ``--play`` steps the trained model in the environment until 100 episode
    terminations have been observed. The environment is closed before
    returning, even if saving or playing fails.
    """
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args()
    extra_args = {k: parse(v) for k, v in parse_unknown_args(unknown_args).items()}

    # configure logger, disable logging in child MPI processes (with rank > 0)
    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        logger.configure()
    else:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    model, env = train(args, extra_args)

    try:
        # Only some environments expose running observation statistics
        # (build_env wraps mujoco envs in VecNormalize); skip the debug print
        # for the others instead of failing with AttributeError.
        ob_rms = getattr(env, 'ob_rms', None)
        if ob_rms is not None:
            print(ob_rms.mean)

        if args.save_path is not None and rank == 0:
            save_path = osp.expanduser(args.save_path)
            model.save(save_path)

        if args.play:
            logger.log("Running trained model")
            obs = env.reset()

            doneCounter = 0
            num_runs = 100

            while doneCounter < num_runs:
                actions, _, _, _ = model.step(obs)
                obs, rewards, done, infos = env.step(actions)
                # Vectorized envs report one flag per sub-env; count a step
                # as "done" as soon as any of them finished.
                done = done.any() if isinstance(done, np.ndarray) else done

                if done:
                    print(doneCounter, "done", infos)
                    obs = env.reset()
                    doneCounter += 1
    finally:
        # Release environment resources (e.g. worker processes, viewers).
        env.close()


  from ._conv import register_converters as _register_converters


Logging to /var/folders/qq/gpxz4l6s1tndfdhysbz8bdym0000gn/T/openai-2018-11-30-18-04-44-622709
