In [1]:
import gym
import os
import torch
import argparse
import pickle
import gtimer as gt
import numpy as np
from tqdm import tqdm
import sys
sys.path.append("./dssm")
from train import main as train_ssm
from diayn.examples.diayn import get_algorithm, get_algorithm_resume, experiment
from rlkit.envs.wrappers import NormalizedBoxEnv
from rlkit.samplers.util import DIAYNRollout as rollout


def collect(env, diayn, depth, args, n_trials=10, train_ratio=0.9):
    """Roll out every DIAYN skill in ``env`` and pickle a train/test split.

    For each skill index ``0 .. skill_dim - 1`` the current evaluation policy
    is rolled out ``n_trials`` times for at most ``args.H`` steps. Each rollout
    is stored as ``[actions, next_observations]``. The first
    ``int(len(data) * train_ratio)`` rollouts are written to
    ``<args.data_dir>/train{depth}.pkl`` and the rest to
    ``<args.data_dir>/test{depth}.pkl`` (the directory is created if needed).

    Note: rollouts are appended skill by skill and the split is NOT shuffled,
    so the test file holds the rollouts of the highest-numbered skills.

    Args:
        env: Environment to roll the policy out in.
        diayn: DIAYN algorithm; its eval data collector supplies the policy.
        depth: Iteration index used in the output file names.
        args: Parsed arguments; ``args.H`` (max rollout length) and
            ``args.data_dir`` are read.
        n_trials: Rollouts per skill (default 10).
        train_ratio: Fraction of rollouts put in the train file, in [0, 1]
            (default 0.9).

    Returns:
        Tuple ``(train_path, test_path)`` of the files written.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1].
    """
    # Validate before the (expensive) rollouts rather than after them.
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(
            "train_ratio must be in [0, 1], got {!r}".format(train_ratio))

    policy = diayn.eval_data_collector.get_snapshot()['policy']
    n_skills = policy.stochastic_policy.skill_dim

    data = []
    for skill in tqdm(range(n_skills)):
        for _ in range(n_trials):
            path = rollout(
                env,
                policy,
                skill,
                max_path_length=args.H,
                render=False,
            )
            data.append([path['actions'], path['next_observations']])

    n_train = int(len(data) * train_ratio)
    train_data, test_data = data[:n_train], data[n_train:]

    train_path = os.path.join(args.data_dir, "./train{}.pkl".format(depth))
    test_path = os.path.join(args.data_dir, "./test{}.pkl".format(depth))
    os.makedirs(args.data_dir, exist_ok=True)
    for out_path, split in ((train_path, train_data), (test_path, test_data)):
        with open(out_path, mode='wb') as f:
            pickle.dump(split, f)
    return train_path, test_path


def update_sim(env, depth, args, epochs=1):
    """Fit the state-space model for ``depth`` and wrap it as a simulator env.

    Calls the SSM trainer (``train_ssm``) with a command-line style argument
    string, then builds a ``SimNormalizedBoxEnv`` around ``env`` that is driven
    by the trained model.

    Args:
        env: Environment to wrap (its action bounds are used by the simulator).
        depth: Iteration index; selects which data/param files are used.
        args: Parsed arguments; ``args.H`` is forwarded to the trainer and
            ``args`` is passed on to ``SimNormalizedBoxEnv``.
        epochs: Number of SSM training epochs (default 1).

    Returns:
        The ``SimNormalizedBoxEnv`` instance.
    """
    # NOTE(review): ``args.data_dir`` is not forwarded to the trainer, so it
    # presumably uses its own default data directory (logged as './data/').
    # If --data_dir differs from that default, the SSM trains on different
    # files than collect()/SimNormalizedBoxEnv use. Confirm the trainer's flag
    # name before forwarding it.
    ssm = train_ssm("--H {} --depth {} --epochs {}".format(args.H, depth, epochs))
    return SimNormalizedBoxEnv(env, ssm, depth, args)


def update_policy(diayn, sim, log_dir, args, env=None):
    """Train the DIAYN agent inside the learned simulator, then reload it.

    Runs ``experiment`` with ``sim`` as both the exploration and evaluation
    environment, then rebuilds the algorithm from the ``params.pkl`` snapshot
    in ``log_dir`` via ``get_algorithm_resume``.

    Args:
        diayn: DIAYN algorithm object to train.
        sim: Simulated environment (a ``SimNormalizedBoxEnv``) to train in.
        log_dir: Directory whose ``params.pkl`` is resumed from.
        args: Parsed arguments; passed to ``experiment`` and
            ``args.skill_dim`` is forwarded on resume.
        env: Environment the resumed algorithm is bound to (used for both
            exploration and evaluation). Defaults to the env wrapped by
            ``sim``. Previously this silently read a module-level ``env``.

    Returns:
        Tuple ``(diayn, log_dir)`` as returned by ``get_algorithm_resume``.

    Raises:
        ValueError: If ``env`` is not given and cannot be taken from ``sim``.
    """
    if env is None:
        # The env handed to SimNormalizedBoxEnv is stored as ``_wrapped_env``
        # (the attribute its step() reads), i.e. the caller's env.
        env = getattr(sim, "_wrapped_env", None)
        if env is None:
            # Fail before the expensive training run, not after it.
            raise ValueError(
                "update_policy: pass `env` explicitly; it could not be "
                "inferred from `sim`")
    experiment(diayn, sim, sim, args)
    # NOTE(review): assumes experiment() writes its snapshot to ``log_dir``;
    # if it creates a fresh log directory, this resumes a stale params.pkl.
    params_path = os.path.join(log_dir, "params.pkl")
    return get_algorithm_resume(env, env, args.skill_dim, params_path)


class SimNormalizedBoxEnv(NormalizedBoxEnv):
    """Env wrapper whose transitions come from a learned state-space model.

    ``step`` never advances the real simulator. The incoming action is mapped
    into the wrapped env's action bounds, standardized with the action
    statistics loaded at construction, and fed to ``ssm.step``. The predicted
    observation is de-standardized and returned with reward 0 and
    ``done=False``. ``reset`` still resets the underlying env to obtain a
    genuine initial observation, and uses it to initialise the model.

    NOTE(review): ``__init__`` also monkey-patches ``self.env.step`` and
    ``self.env.reset``. ``env`` is never assigned in this class, so the
    attribute is presumably resolved through the wrapper's attribute
    forwarding to an object inside the ``env`` passed in, which the caller
    may still use directly. If so, that env's step/reset are rerouted through
    this simulator as well. Constructing a second instance on the same env
    would also capture the previous instance's ``reset`` as ``envreset``,
    chaining the two. Confirm before relying on the original env afterwards.
    """

    def __init__(self, env, ssm, depth, args):
        """Wrap ``env``, attach ``ssm`` and load normalization stats.

        Args:
            env: Environment to wrap. ``step`` reads the action bounds from
                ``self._wrapped_env`` (presumably this object).
            ssm: Trained model; this class calls ``ssm.step(action_tensor)``
                and ``ssm.reset(obs_tensor)``.
            depth: Iteration index selecting ``param{depth}.pkl``.
            args: Parsed arguments; only ``args.data_dir`` is read here.
        """
        super(SimNormalizedBoxEnv, self).__init__(env)
        self.ssm = ssm
        # (action mean, action std, obs mean, obs std) used to (de)standardize
        # the model's inputs/outputs. Presumably written by the SSM trainer,
        # since collect() only writes train/test pickles. pickle.load is only
        # safe because this is a locally produced, trusted file.
        with open(os.path.join(args.data_dir, "param{}.pkl".format(depth)),
                  mode='rb') as f:
            self.a_mean, self.a_std, self.o_mean, self.o_std = \
                pickle.load(f)
        # Reroute the inner env's step/reset through this simulator, keeping
        # the original reset so reset() can still get a real observation.
        # NOTE(review): see class docstring about the shared inner env.
        self.env.step = self.step
        self.envreset = self.env.reset
        self.env.reset = self.reset

    def step(self, action):
        """Advance the learned model one step instead of the real env.

        Args:
            action: Action assumed to lie in [-1, 1] per dimension. It is
                rescaled to the wrapped env's ``[low, high]`` and clipped.

        Returns:
            ``(next_obs, 0, False, {})``. The reward is always 0 and ``done``
            is always False, so episode termination must come from the caller
            (e.g. a rollout's max path length).
        """
        # Map [-1, 1] -> [lb, ub] and clip to the bounds (same formula as the
        # commented-out real-env path below was presumably paired with).
        lb = self._wrapped_env.action_space.low
        ub = self._wrapped_env.action_space.high
        scaled_action = lb + (action + 1.) * 0.5 * (ub - lb)
        scaled_action = np.clip(scaled_action, lb, ub)

        # Original real-environment transition, kept for reference only:
        # wrapped_step = self._wrapped_env.step(scaled_action)
        # next_obs, reward, done, info = wrapped_step
        # if self._should_normalize:
        #     next_obs = self._apply_normalize_obs(next_obs)

        # Standardize with the stored action stats and add a leading batch
        # axis of size 1. NOTE(review): if a_mean/a_std are float64 the result
        # is upcast to float64 despite astype(np.float32); confirm the dtype
        # the model expects.
        a = scaled_action.astype(np.float32)
        a = (a - self.a_mean) / self.a_std
        a = torch.from_numpy(np.array([a]))
        # [0] presumably drops the batch axis (or selects the first returned
        # item); confirm against the SSM's step() return value.
        o = self.ssm.step(a)[0]
        o = o.cpu().detach().numpy()
        # Undo observation standardization to return raw-scale observations.
        next_obs = o * self.o_std + self.o_mean

        # return next_obs, reward * self._reward_scale, done, info
        return next_obs, 0, False, {}

    def reset(self, **kwargs):
        """Reset the underlying env and initialise the model from its obs.

        Args:
            **kwargs: Forwarded to the saved original ``reset``.

        Returns:
            The initial observation from the underlying env, not normalized.
        """
        o_original = self.envreset(**kwargs)
        # Standardize and add a batch axis of size 1, matching step()'s input.
        o = o_original.astype(np.float32)
        o = (o - self.o_mean) / self.o_std
        o = torch.from_numpy(np.array([o]))
        # The return value is discarded; the call presumably matters only for
        # its effect on the model's internal state.
        o = self.ssm.reset(o)
        return o_original
        

if __name__ == "__main__":
    # Driver: for each depth, fit the learned simulator (SSM) on the rollout
    # data and then train the DIAYN policy inside that simulator.
    parser = argparse.ArgumentParser()
    parser.add_argument('env', type=str,
                        help='environment')
    parser.add_argument("--data_dir", type=str,
                        default="./data/")
    parser.add_argument('--skill_dim', type=int, default=100,
                        help='skill dimension')
    parser.add_argument('--H', type=int, default=300,
                        help='Max length of rollout')
    parser.add_argument('--D', type=int, default=2,
                        help='Depth (The number of update)')
    # Opt-in replacement for the previously commented-out collect() call;
    # off by default so existing train/test pickles in data_dir are reused.
    parser.add_argument('--collect', action='store_true',
                        help='collect fresh rollouts before each SSM fit '
                             'instead of reusing the pickles in data_dir')
    # argv is hard-coded so the cell runs inside Jupyter, whose sys.argv holds
    # the kernel launcher's flags rather than this script's arguments.
    args = parser.parse_args("HalfCheetah-v2 --H 100".split())

    env = NormalizedBoxEnv(gym.make(str(args.env)))
    sim = None
    diayn, log_dir = get_algorithm(env, env, args.skill_dim)

    for depth in range(args.D):
        if args.collect:
            # Rollouts of the current policy in `env` -> train/test pickles.
            collect(env, diayn, depth, args)
        sim = update_sim(env, depth, args)
        diayn, log_dir = update_policy(diayn, sim, log_dir, args)
        # Clear gtimer's global state between training runs (presumably so the
        # next run's timers do not clash with ones already registered).
        gt.reset()

No personal conf_private.py found.
doodad not detected
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
2021-02-23 23:47:27.223781 UTC | Variant:
2021-02-23 23:47:27.224314 UTC | {
  "replay_buffer_size": 1000000,
  "algorithm_kwargs": {
    "num_eval_steps_per_epoch": 5000,
    "max_path_length": 1000,
    "min_num_steps_before_training": 1000,
    "num_epochs": 1,
    "num_trains_per_train_loop": 1000,
    "num_expl_steps_per_train_loop": 1000,
    "batch_size": 256
  },
  "trainer_kwargs": {
    "reward_scale": 1,
    "policy_lr": 0.0003,
    "discount": 0.99,
    "qf_lr": 0.0003,
    "soft_target_tau": 0.005,
    "use_automatic_entropy_tuning": true,
    "target_update_period": 1
  },
  "version": "normal",
  "l

[I 210223 23:47:27 train:47] args: Namespace(B=64, H=100, T=10, a_dim=6, data_dir='./data/', depth=0, device=[0], epochs=1, h_dim=128, iters_to_accumulate=1, load_epoch=None, o_dim=17, s_dim=64, seed=0, timestamp='Feb23_23-47-27')


(900, 100, 6) (900, 100, 17)
(100, 100, 6) (100, 100, 17)


[I 210223 23:47:29 trainer:53] (train) Epoch: 1 {'x_loss': 245.81510271344865, 'loss': 31405331.598214287, 's_aux_loss': 35.18444102151053, 's_loss': 31405086.42410714}
[I 210223 23:47:29 trainer:53] (test) Epoch: 1 {'x_loss': 209.17181396484375, 'loss': 236711.1875, 's_aux_loss': 30.627729415893555, 's_loss': 236502.015625}


[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
2021-02-23 23:47:52.467997 UTC | [DIAYN_100_HalfCheetah-v2_2021_02_23_23_47_27_0000--s-0] Epoch 0 finished
------------------------------------------  --------------
replay_buffer/size                          2000
trainer/Intrinsic Rewards                     -0.0042671
trainer/DF Loss                                4.60944
trainer/DF Accuracy                            0
trainer/QF1 Loss                              15.6222
trainer/QF2 Loss                              15.6948
trainer/Policy Loss                           -4.02072
trainer/Q1 Predictions Mean                    0.00414251
trainer/Q1 Predictions Std                     0.000796448
trainer/Q1 Predictions Max                     0.00619714
trainer/Q1 Predictions Min                     0.00261055
trainer/Q2 Predictions Mean                   -0.00516464
trainer/Q2 Predictions Std                     0.000520687
tra

[I 210223 23:47:52 train:47] args: Namespace(B=64, H=100, T=10, a_dim=6, data_dir='./data/', depth=1, device=[0], epochs=1, h_dim=128, iters_to_accumulate=1, load_epoch=None, o_dim=17, s_dim=64, seed=0, timestamp='Feb23_23-47-52')


(900, 100, 6) (900, 100, 17)
(100, 100, 6) (100, 100, 17)


[I 210223 23:47:53 trainer:53] (train) Epoch: 1 {'x_loss': 238.6244386945452, 'loss': 27548396.651785713, 's_aux_loss': 41.14420182364328, 's_loss': 27548157.36495536}
[I 210223 23:47:53 trainer:53] (test) Epoch: 1 {'x_loss': 216.10598754882812, 'loss': 188829.890625, 's_aux_loss': 36.294288635253906, 's_loss': 188613.78125}


[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
2021-02-23 23:48:16.932403 UTC | [DIAYN_100_HalfCheetah-v2_2021_02_23_23_47_27_0000--s-0] [DIAYN_100_HalfCheetah-v2_2021_02_23_23_47_52_0000--s-0] Epoch 0 finished
------------------------------------------  --------------
replay_buffer/size                          2000
trainer/Intrinsic Rewards                    -11.573
trainer/DF Loss                               16.1781
trainer/DF Accuracy                            0
trainer/QF1 Loss                              93.0057
trainer/QF2 Loss                              80.0507
trainer/Policy Loss                          -30.6685
trainer/Q1 Predictions Mean                   27.2991
trainer/Q1 Predictions Std                     2.99603
trainer/Q1 Predictions Max                    32.3188
trainer/Q1 Predictions Min                    21.258
trainer/Q2 Predictions Mean                   26.5708
trainer/Q2 Predictions Std      