In [1]:
import jax.random as jrandom
import numpy as np
import os

from tqdm.notebook import tqdm

from jaxl.constants import *
from jaxl.models.utils import (
    get_model,
    load_config,
    load_params,
    get_wsrl_model,
    iterate_params,
    get_policy,
    policy_output_dim,
    get_residual_policy,
)
from jaxl.envs import get_environment
from jaxl.envs.rollouts import EvaluationRollout
from jaxl.utils import get_device, parse_dict

# Select the device used for this session via `get_device`. The CPU variant is
# active; the commented line is the alternative that requests "gpu:0".
# NOTE(review): presumably this configures the JAX backend — confirm in
# `jaxl.utils.get_device`.
# get_device("gpu:0")
get_device("cpu")

In [2]:
# Root directory holding the experiment logs (machine-specific absolute path).
result_dir = "/home/bryan/research/jaxl/logs/manipulator_learning"

# Sub-directory of `result_dir` that groups the runs being compared.
ablation_name = "stack"

# Run to evaluate. Exactly one `learner_name` assignment is active; the other
# candidate runs stay commented out so they can be swapped in quickly.
# learner_name = "cross_q-sac-06-03-24_16_53_41-9dd96b95-aefd-44fd-8894-4854c9c08abf"
# learner_name = "bc-06-04-24_09_43_02-a09b01c4-e33d-4e88-9eda-d7d36a68cdb8"
# learner_name = "bc-100k_steps-06-04-24_09_57_35-e8bd5a54-9148-41a9-ace8-f33c5cfbab9f"
# learner_name = "bc-10k_steps-06-04-24_10_06_35-333b32a8-c019-4fed-9b8f-1ce59166bb2b"
# learner_name = "warm_start_reinforce-06-04-24_13_28_32-b356a022-d53a-4b11-9726-ccfe4dca0777"
# learner_name = "rlpd-sac-06-05-24_16_15_56-c5ad96da-4ac4-466b-a221-74cfea71bd19"
# learner_name = (
#     "residual-rlpd-sac-06-06-24_17_55_49-1e0d722f-7e5d-4310-95de-2480ad35ab72"
# )
# learner_name = "residual-rlpd-sac-fixed_temp-06-07-24_09_25_20-c64a33d6-de4f-4cae-96e9-e4a09ef9b50c"
learner_name = "residual-rlpd-cross_q-deterministic_exploration-wide_critic-06-10-24_11_12_20-c40f7bbb-80de-47a6-b02f-a3568dd1a877"

# Full path of the selected run: <result_dir>/<ablation_name>/<learner_name>.
learner_path = os.path.join(result_dir, ablation_name, learner_name)

# Checkpoint identifier; it is appended to `learner_path` (separated by ":")
# when the parameters are loaded.
checkpoint = "latest"

In [3]:
# Load the configuration stored with the selected run. `load_config` returns a
# pair; only the second element is kept here, the first is discarded.
_, config = load_config(learner_path)
# Bare expression so the notebook displays the parsed configuration.
config

namespace(logging_config=namespace(save_path='./logs/manipulator_learning/stack',
                                   experiment_name='residual-rlpd-cross_q-deterministic_exploration-wide_critic',
                                   log_interval=1,
                                   checkpoint_interval=10),
          model_config=namespace(backbone=namespace(architecture='mlp',
                                                    layers=[256, 256, 256],
                                                    activation='tanh',
                                                    flatten=True,
                                                    policy_distribution='deterministic',
                                                    pretrained_model='/home/bryan/research/jaxl/logs/manipulator_learning/stack/bc-10k_steps-06-04-24_10_06_35-333b32a8-c019-4fed-9b8f-1ce59166bb2b:latest',
                                                    include_absorbing_state=True),
                               

In [4]:
# Environment specification for the evaluation: the "stack" main task of the
# PandaPlayInsertTrayXYZState environment with `dense_reward` disabled.
env_config = dict(
    env_type="manipulator_learning",
    env_name="PandaPlayInsertTrayXYZState",
    env_kwargs=dict(main_task="stack", dense_reward=False),
)

# `parse_dict` converts the plain dict into the form `get_environment` expects.
env = get_environment(parse_dict(env_config))

pybullet build time: Nov 28 2023 23:45:17


Loaded EGL 1.5 after reload.
GL_VENDOR=NVIDIA Corporation
GL_RENDERER=Quadro RTX 5000 with Max-Q Design/PCIe/SSE2
GL_VERSION=3.3.0 NVIDIA 535.171.04
GL_SHADING_LANGUAGE_VERSION=3.30 NVIDIA via Cg compiler
Version = 3.3.0 NVIDIA 535.171.04
Vendor = NVIDIA Corporation
Renderer = Quadro RTX 5000 with Max-Q Design/PCIe/SSE2


EGL device choice: -1 of 4.


In [5]:
# Build the policy that matches the learner described by `config`, then load
# the saved parameters for the chosen checkpoint.
#
# `include_absorbing_state` is forwarded to the evaluation rollout in a later
# cell; it is switched on only in the BC and WSRL branches below.
# NOTE(review): the residual branch leaves it False even though the displayed
# config has `include_absorbing_state=True` for the backbone — confirm this is
# intended.
include_absorbing_state = False
if config.learner_config.task == CONST_RESIDUAL:
    # Residual setup: a backbone model and a residual model are built
    # separately and combined into one policy by `get_residual_policy`.
    # NOTE(review): `backbone_act_dim` is computed but never used in this cell;
    # the backbone model is built with `env.act_dim` instead — confirm which
    # output dimension is intended.
    backbone_act_dim = policy_output_dim(env.act_dim, config.model_config.backbone)
    residual_act_dim = policy_output_dim(env.act_dim, config.model_config.residual)

    backbone_model = get_model(
        env.observation_space.shape, env.act_dim, config.model_config.backbone
    )
    residual_model = get_model(
        env.observation_space.shape, residual_act_dim, config.model_config.residual
    )
    policy = get_residual_policy(
        backbone_model,
        residual_model,
        config.model_config,
    )
else:
    # Single-model setup: the branch taken depends first on the learner type
    # (BC) and then on the task (WSRL); anything else uses the plain policy
    # model.
    model_out_dim = policy_output_dim(env.act_dim, config.learner_config)
    if config.learner_config.learner == CONST_BC:
        # BC model takes the flattened observation size plus one extra input
        # and outputs `env.act_dim` values (`model_out_dim` is not used here).
        # NOTE(review): the extra input is presumably the absorbing-state
        # flag — confirm.
        model = get_model(
            int(np.prod(env.observation_space.shape)) + 1,
            env.act_dim,
            config.model_config,
        )
        include_absorbing_state = True
    elif config.learner_config.task == CONST_WSRL:
        model = get_wsrl_model(
            env.observation_space.shape, model_out_dim, config.model_config.policy
        )
        include_absorbing_state = True
    else:
        model = get_model(
            env.observation_space.shape, model_out_dim, config.model_config.policy
        )
    policy = get_policy(model, config.learner_config)

# Load the saved parameters addressed as "<learner_path>:<checkpoint>".
params = load_params(f"{learner_path}:{checkpoint}")

CUDA backend failed to initialize: jaxlib/cuda/versions_helpers.cc:99: operation cuInit(0) failed: CUDA_ERROR_NO_DEVICE.(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [6]:
# Evaluation settings consumed by the rollout cells below.
# `total_episodes` is passed to `rollout.rollout`; `eval_seed` is passed to the
# `EvaluationRollout` constructor.
total_episodes, eval_seed = 50, 42

# Both flags are forwarded as the `render=` / `random=` rollout options.
render = random = False

In [7]:
# Create the evaluation rollout helper for the environment, seeded with
# `eval_seed`.
rollout = EvaluationRollout(env, eval_seed)

In [8]:
# Policy parameters are stored under the model dict of the loaded checkpoint.
policy_params = params[CONST_MODEL_DICT][CONST_MODEL][CONST_POLICY]

# Keyword options forwarded unchanged to the rollout.
rollout_options = {
    "random": random,
    "render": render,
    "include_absorbing_state": include_absorbing_state,
}

# Run `total_episodes` evaluation episodes; the episodic returns are read from
# `rollout.episodic_returns` afterwards.
# NOTE(review): the meaning of the third positional argument (False) is not
# visible in this notebook — confirm against `EvaluationRollout.rollout`.
rollout.rollout(policy_params, policy, False, total_episodes, **rollout_options)

100%|██████████| 50/50 [01:54<00:00,  2.28s/it]


In [9]:
# Summarise the evaluation as (mean return, standard deviation of returns,
# number of episodes whose return is strictly positive).
episodic_returns = np.array(rollout.episodic_returns)
np.mean(episodic_returns), np.std(episodic_returns), np.sum(episodic_returns > 0)

(281.26, 97.1332713337711, 45)

10k: (203.98, 147.88853775732588)  
100k: (253.7, 127.34115595517422)  
1M: (215.96, 148.52123888521803)

```
Residual RLPD @ 300 deterministic: (208.96, 150.14752212407635)
Residual RLPD @ 400 deterministic: (210.16, 137.68796025796883)
Residual RLPD @ 400 stochastic: (213.56, 135.53481619126504)
```

```
Residual RLPD with fixed temp @ initialization random: (130.32, 139.2440217747247)
Residual RLPD with fixed temp @ 10 random: (125.94, 150.11827470364827)
Residual RLPD with fixed temp @ 20 random: (224.86, 137.30244134755944)
Residual RLPD with fixed temp @ 150 random: (252.2, 125.54106897744657)
```

In [10]:
# Unconditional failure: this always raises AssertionError.
# NOTE(review): presumably a deliberate stop so that "Run All" does not proceed
# past this point — confirm, and remove it if no later cells need guarding.
assert 0

AssertionError: 

```
Residual RLPD CrossQ @ 160 deterministic: (280.46, 94.8232482042247, 46)
Residual RLPD CrossQ @ latest deterministic: (281.26, 97.1332713337711, 45)
```