In [None]:
import numpy as np
from pathlib import Path
import glob
import pickle
import matplotlib.pyplot as plt

import os
# Collect the checkpoint directories found directly under `path`, ordered by
# the integer that follows the last "_" in each entry.
# NOTE(review): `path` is not defined in the visible cells; it must already be
# set to a pathlib.Path-like object (it is combined with "/") before this runs.
# The 'checkpoint' test is applied to the directory's own name, not the full
# path: otherwise a parent directory whose name contains "checkpoint" would
# make every sibling directory match and then break the int() sort key below.
checkpoint_dirs = [
    d for d in glob.glob(str(path / "*"))
    if 'checkpoint' in os.path.basename(d) and os.path.isdir(d)
]
checkpoint_dirs.sort(key=lambda s: int(s.split("_")[-1]))

import r3l
import gym
from r3l.r3l_envs.inhand_env.pickup import SawyerDhandInHandObjectPickupFixed
from softlearning.environments.adapters.gym_adapter import GymAdapter

# Build the evaluation environment once and reset it so it is ready to use.
# The two strings are the positional arguments GymAdapter is constructed with.
ENV_DOMAIN = "SawyerDhandInHandValve3"
ENV_TASK = "PickupFixed-v0"
env = GymAdapter(ENV_DOMAIN, ENV_TASK)
env.reset()

In [None]:
from softlearning.policies.utils import get_policy_from_variant
from softlearning.models.utils import flatten_input_structure

def load_policy_from_checkpoint(ckpt_path, env):
    """Rebuild a policy from a checkpoint directory and wrap it for evaluation.

    Args:
        ckpt_path: Path to a checkpoint directory containing
            ``policy_params.pkl``. The experiment's ``params.pkl`` is read
            from the parent of this directory.
        env: Environment passed to ``get_policy_from_variant`` to build the
            policy.

    Returns:
        The callable produced by ``wrap_policy`` for the restored policy.

    Raises:
        FileNotFoundError: If either pickle file is missing.
    """
    # Use the `ckpt_path` argument. Reading a same-named global instead would
    # only work when the caller's loop variable happens to be in scope.
    # SECURITY: pickle.load executes arbitrary code from the file; only load
    # checkpoints produced by trusted training runs.
    with open(os.path.join(ckpt_path, "policy_params.pkl"), "rb") as f:
        policy_params = pickle.load(f)

    with open(os.path.join(ckpt_path, "..", "params.pkl"), "rb") as f:
        variant = pickle.load(f)

    # Only the first entry of the pickled parameters is loaded into the policy.
    # NOTE(review): presumably the pickup policy's weights — confirm against
    # the code that writes policy_params.pkl.
    pickup_params = policy_params[0]

    policy = get_policy_from_variant(variant, env)
    policy.set_weights(pickup_params)
    return wrap_policy(policy)

def wrap_policy(policy):
    """Return a function mapping one observation dict to one action.

    The returned callable picks the entries named in
    ``policy.observation_keys`` out of the observation dict, prepends a
    leading axis of size one to each, passes them through
    ``flatten_input_structure``, and queries the policy with
    ``set_deterministic(True)`` active. The first element of the result of
    ``policy.actions_np`` is returned.
    """
    def wrapped_policy(obs_dict):
        batched = {}
        for name in policy.observation_keys:
            # Indexing with None adds the leading axis.
            batched[name] = obs_dict[name][None, ...]
        policy_input = flatten_input_structure(batched)
        with policy.set_deterministic(True):
            batch_actions = policy.actions_np(policy_input)
            first_action = batch_actions[0]
        return first_action
    return wrapped_policy

In [None]:
# Evaluation settings.
N_EVAL_EPISODES = 1  # rollouts per evaluated checkpoint
T = 50               # environment steps per rollout
# An episode counts as a success when the third component of the final
# "object_xyz" observation exceeds this value.
# NOTE(review): presumably the object's height in the env's units — confirm.
SUCCESS_Z_THRESHOLD = 0.85

success_rates = []         # one success fraction (0.0-1.0) per evaluated checkpoint
obs_dicts_per_policy = []  # one list of final observation dicts per evaluated checkpoint
for ckpt_dir in checkpoint_dirs[::2]:  # evaluate every other checkpoint
    print("EVALUATING CHECKPOINT: ", ckpt_dir.split("_")[-1])
    policy = load_policy_from_checkpoint(ckpt_dir, env)

    successes = []
    obs_dicts = []
    for ep in range(N_EVAL_EPISODES):
        env.reset()
        for t in range(T):
            # Fixed-length rollout: the value returned by step() is ignored.
            env.step(policy(env.get_obs_dict()))
        obs_dict = env.get_obs_dict()
        success = obs_dict["object_xyz"][2] > SUCCESS_Z_THRESHOLD
        successes.append(success)
        obs_dicts.append(obs_dict)
    success_rate = np.array(successes).astype(int).mean()
    print("success % = ", success_rate)
    success_rates.append(success_rate)
    # Keep this checkpoint's final observations; without this append the
    # accumulator initialised above stays empty and `obs_dicts` is discarded.
    obs_dicts_per_policy.append(obs_dicts)
