# In [1]:


import sys
sys.path.append("./")
import metaworld
import torch

import os
import time
import os.path as osp

import numpy as np

from torchrl.utils import get_args
from torchrl.utils import get_params
from torchrl.env import get_env


from torchrl.utils import Logger

class parser:
    """Stand-in for the training script's command-line arguments.

    Holds the run configuration as plain instance attributes so the notebook
    can run without argparse. Any attribute can be overridden by keyword,
    e.g. ``parser(seed=1, qf1_snap=None)``; unknown keywords raise TypeError.
    """

    # Directory holding the checkpoints this run is initialised from.
    SNAP_DIR = r'/root/metaworld-master/softmodule_log_5/MT10/MT10_Fixed_Modular_Shallow/mt10/20/model'

    def __init__(self, **overrides):
        self.config = 'meta_config/mt10/modular_2_2_2_256_reweight.json'
        self.id = 'MT10_Fixed_Modular_Shallow'
        self.worker_nums = 10
        self.eval_worker_nums = 10
        self.seed = 20
        self.log_dir = './log/MT10'
        # Checkpoint paths; a value of None means "do not load a snapshot".
        self.pf_snap = self.SNAP_DIR + '/model_pf_best.pth'
        self.pf1_snap = None
        self.pf2_snap = None
        self.qf1_snap = self.SNAP_DIR + '/model_qf1_best.pth'
        self.qf2_snap = self.SNAP_DIR + '/model_qf2_best.pth'
        for name, value in overrides.items():
            # Only allow overriding settings defined above, so typos fail loudly.
            if name not in vars(self):
                raise TypeError(
                    "parser() got an unexpected keyword argument {!r}".format(name))
            setattr(self, name, value)
        
# Instantiate the run settings, then load the JSON experiment parameters
# referenced by the config path.
args = parser()
params = get_params(args.config)

import torchrl.policies as policies
import torchrl.networks as networks
from torchrl.algo import SAC
from torchrl.algo import TwinSAC
from torchrl.algo import TwinSACQ
from torchrl.algo import MTSAC
from torchrl.collector.para import ParallelCollector
from torchrl.collector.para import AsyncParallelCollector
from torchrl.collector.para.mt import SingleTaskParallelCollectorBase
from torchrl.collector.para.async_mt import AsyncSingleTaskParallelCollector
from torchrl.collector.para.async_mt import AsyncMultiTaskParallelCollectorUniform

from torchrl.replay_buffers.shared import SharedBaseReplayBuffer
from torchrl.replay_buffers.shared import AsyncSharedReplayBuffer
import gym

from metaworld_utils.meta_env import get_meta_env

import random


# All tensors in this notebook live on the CPU.
device = torch.device("cpu")

# Build the multi-task environment together with the per-task classes and
# arguments that are later handed to the collector.
env, cls_dicts, cls_args = get_meta_env(
    params['env_name'], params['env'], params['meta_env'])

# Seed the env, torch, numpy and Python RNGs (in that order) with one seed.
for _seed_fn in (env.seed, torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(args.seed)

buffer_param = params['replay_buffer']

# Default the experiment name to the config file's stem; an explicit id wins.
if args.id is None:
    experiment_name = os.path.split(os.path.splitext(args.config)[0])[-1]
else:
    experiment_name = args.id
logger = Logger(experiment_name, params['env_name'], args.seed, params, args.log_dir)

# These entries are later passed to the agent as keyword arguments.
params['general_setting'].update(env=env, logger=logger, device=device)

# Base network class, forwarded to the network constructors via **params['net'].
params['net'].update(base_type=networks.MLPBase)

# Start worker processes with 'spawn'; force=True allows re-setting the start
# method if it was already set (e.g. when this cell is re-run).
import torch.multiprocessing as mp
mp.set_start_method('spawn', force=True)

from torchrl.networks.init import normal_init

# Probe the environment once so network input sizes come from a real
# observation and the active task's one-hot embedding.
example_ob = env.reset()
example_embedding = env.active_task_one_hot


def _load_snapshot(module, snapshot_path):
    """Load a saved state_dict from ``snapshot_path`` into ``module``.

    Does nothing when ``snapshot_path`` is None. Tensors are mapped to CPU.
    """
    if snapshot_path is None:
        return
    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints you produced yourself.
    module.load_state_dict(torch.load(snapshot_path, map_location='cpu'))


pf = policies.ModularGuassianGatedCascadeCondContPolicy(
    input_shape=env.observation_space.shape[0],
    em_input_shape=np.prod(example_embedding.shape),
    output_shape=2 * env.action_space.shape[0],
    **params['net'])
# Honour the configured path instead of a duplicated hard-coded one.
_load_snapshot(pf, args.pf_snap)

# The critics' input size is observation dim + action dim.
qf_input_shape = env.observation_space.shape[0] + env.action_space.shape[0]
qf1 = networks.FlattenModularGatedCascadeCondNet(
    input_shape=qf_input_shape,
    em_input_shape=np.prod(example_embedding.shape),
    output_shape=1,
    **params['net'])
qf2 = networks.FlattenModularGatedCascadeCondNet(
    input_shape=qf_input_shape,
    em_input_shape=np.prod(example_embedding.shape),
    output_shape=1,
    **params['net'])

# Each critic restores its own checkpoint (qf1 previously loaded qf2's weights).
_load_snapshot(qf1, args.qf1_snap)
_load_snapshot(qf2, args.qf2_snap)

# One sample transition describing each buffer field; presumably used by
# build_by_example to size the shared storage (verify in torchrl).
example_dict = dict(
    obs=example_ob,
    next_obs=example_ob,
    acts=env.action_space.sample(),
    rewards=[0],
    terminals=[False],
    task_idxs=[0],
    embedding_inputs=example_embedding,
)

replay_buffer = AsyncSharedReplayBuffer(int(buffer_param['size']), args.worker_nums)
replay_buffer.build_by_example(example_dict)

params['general_setting']['replay_buffer'] = replay_buffer

# The collector trains for pretrain + regular epochs but evaluates only
# over the regular ones.
epochs = (params['general_setting']['pretrain_epochs']
          + params['general_setting']['num_epochs'])

print(env.action_space, env.observation_space, sep='\n')

params['general_setting']['collector'] = AsyncMultiTaskParallelCollectorUniform(
    env=env,
    pf=pf,
    replay_buffer=replay_buffer,
    env_cls=cls_dicts,
    env_args=[params["env"], cls_args, params["meta_env"]],
    device=device,
    reset_idx=True,
    epoch_frames=params['general_setting']['epoch_frames'],
    max_episode_frames=params['general_setting']['max_episode_frames'],
    eval_episodes=params['general_setting']['eval_episodes'],
    worker_nums=args.worker_nums,
    eval_worker_nums=args.eval_worker_nums,
    train_epochs=epochs,
    eval_epochs=params['general_setting']['num_epochs'],
)

# batch_size must be an int; checkpoints are written under <work_dir>/model.
params['general_setting'].update(
    batch_size=int(params['general_setting']['batch_size']),
    save_dir=osp.join(logger.work_dir, "model"),
)

agent = MTSAC(
    pf=pf, qf1=qf1, qf2=qf2,
    task_nums=env.num_tasks,
    **params['sac'],
    **params['general_setting'],
)
agent.train()

2021-01-24 23:42:58,096 MainThread INFO: Experiment Name:MT10_Fixed_Modular_Shallow
2021-01-24 23:42:58,099 MainThread INFO: {
  "env_name": "mt10",
  "env": {
    "reward_scale": 1,
    "obs_norm": false
  },
  "meta_env": {
    "obs_type": "with_goal"
  },
  "replay_buffer": {
    "size": 1000000.0
  },
  "net": {
    "hidden_shapes": [
      400,
      400
    ],
    "em_hidden_shapes": [
      400
    ],
    "num_layers": 2,
    "num_modules": 2,
    "module_hidden": 256,
    "num_gating_layers": 2,
    "gating_hidden": 256,
    "add_bn": false,
    "pre_softmax": false
  },
  "general_setting": {
    "discount": 0.99,
    "pretrain_epochs": 20,
    "num_epochs": 7500,
    "epoch_frames": 200,
    "max_episode_frames": 200,
    "batch_size": 1280,
    "min_pool": 10000,
    "target_hard_update_period": 1000,
    "use_soft_update": true,
    "tau": 0.005,
    "opt_times": 200,
    "eval_episodes": 3
  },
  "sac": {
    "plr": 0.0003,
    "qlr": 0.0003,
    "reparameterization": true



TypeError: relu() missing 1 required positional argument: 'input'