In [13]:
import dowel_wrapper
import os
import pickle
import torch
import numpy as np
from sklearn.neighbors import KernelDensity
import matplotlib.pyplot as plt
import seaborn as sns
from envs.mujoco.ant_env import AntEnv
import io
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import matplotlib.colors as mcolors
import matplotlib.cm as cm
from matplotlib.lines import Line2D
import scipy.stats as stats

In [2]:
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that forces pickled torch storages to be materialized on the CPU.

    The global ``torch.storage._load_from_bytes`` (presumably how torch storages
    serialize themselves under plain ``pickle``) is replaced by a loader that calls
    ``torch.load(..., map_location='cpu')``. Every other global is resolved by the
    default ``pickle.Unpickler.find_class``.

    Like any unpickler this can execute arbitrary code; only use it on trusted
    checkpoint files.
    """

    def find_class(self, module, name):
        intercept = (module, name) == ('torch.storage', '_load_from_bytes')
        if not intercept:
            return super().find_class(module, name)

        def _load_storage_on_cpu(raw_bytes):
            return torch.load(io.BytesIO(raw_bytes), map_location='cpu')

        return _load_storage_on_cpu

METRA_EXP_ROOT = 'exp/ant_multi_goals_metra_chkpt_30k/sd000_s_21377600.0.1721411268_ant_nav_prime_sac'
METRA_SF_TD_EXP_ROOT = 'exp/ant_multi_goals_metra_sf_td_chkpt_30k/sd000_s_21377629.0.1721411543_ant_nav_prime_sac'


def _load_experiment(exp_root, chkpt):
    """Load one experiment's algo and its standalone option policy onto the CPU.

    Args:
        exp_root: experiment directory containing ``itr_{chkpt}.pkl`` (a pickled dict
            with an ``'algo'`` entry) and ``option_policy{chkpt}.pt`` (a torch-saved
            dict with a ``'policy'`` entry).
        chkpt: training iteration used to build both file names.

    Returns:
        ``(algo, option_policy)``; the algo's option-policy module is moved to the CPU,
        and the standalone option policy is moved to the CPU and put in eval mode.
    """
    with open(os.path.join(exp_root, f'itr_{chkpt}.pkl'), 'rb') as f:
        # CPU_Unpickler remaps GPU-pickled torch storages to the CPU.
        itr = CPU_Unpickler(f).load()

    algo = itr['algo']
    algo.device = 'cpu'
    algo.option_policy._module.to('cpu')

    # NOTE(review): from torch 2.6 on, torch.load defaults to weights_only=True, which
    # refuses to unpickle whole policy objects; pass weights_only=False if this fails.
    option_data = torch.load(os.path.join(exp_root, f'option_policy{chkpt}.pt'),
                             map_location=torch.device('cpu'))
    option_policy = option_data['policy']
    option_policy.to('cpu')
    option_policy.eval()
    return algo, option_policy


def load_models(chkpt,
                metra_exp_root=METRA_EXP_ROOT,
                metra_sf_td_exp_root=METRA_SF_TD_EXP_ROOT):
    """Load the METRA and METRA SF TD algos and option policies at iteration ``chkpt``.

    Args:
        chkpt: training iteration of the snapshots to load.
        metra_exp_root: experiment directory of the METRA run.
        metra_sf_td_exp_root: experiment directory of the METRA SF TD run.

    Returns:
        ``(metra_algo, metra_option_policy, metra_sf_td_algo, metra_sf_td_option_policy)``,
        all on the CPU.
    """
    metra_algo, metra_option_policy = _load_experiment(metra_exp_root, chkpt)
    metra_sf_td_algo, metra_sf_td_option_policy = _load_experiment(metra_sf_td_exp_root, chkpt)
    return metra_algo, metra_option_policy, metra_sf_td_algo, metra_sf_td_option_policy

# Training iteration whose snapshots (itr_<CHKPT>.pkl / option_policy<CHKPT>.pt) are evaluated.
CHKPT = 60000
(metra_algo, metra_option_policy,
 metra_sf_td_algo, metra_sf_td_option_policy) = load_models(CHKPT)

In [3]:
import functools
from garagei.experiment.option_local_runner import OptionLocalRunner
from garaged.src.garage.experiment.experiment import ExperimentContext
from garagei.sampler.option_multiprocessing_sampler import OptionMultiprocessingSampler
from iod.utils import get_normalizer_preset
from garagei.envs.consistent_normalized_env import consistent_normalize
from garagei.envs.child_policy_env import ChildPolicyEnv

def make_env(cp_path):
    """Build the multi-goal AntNavPrime environment driven by a pretrained child policy.

    The raw ``AntNavPrimeEnv`` is observation-normalized with the ``ant`` preset
    statistics (extended by ``cp_num_truncate_obs`` extra dims) and then wrapped in
    ``ChildPolicyEnv`` together with the child-policy checkpoint loaded from ``cp_path``.

    Args:
        cp_path: Path to the child-policy checkpoint. If no file exists at that exact
            path it is treated as a glob pattern and the first match in sorted order is
            used (sorting makes the choice deterministic when several files match).

    Returns:
        The ``ChildPolicyEnv``-wrapped environment.

    Raises:
        FileNotFoundError: if ``cp_path`` neither exists nor matches any file.
    """
    import glob

    from envs.mujoco.ant_nav_prime_env import AntNavPrimeEnv

    env = AntNavPrimeEnv(
        max_path_length=200,
        goal_range=7.5,
        num_goal_steps=50,
        reward_type='esparse',
    )

    # Number of trailing observation dims passed as cp_num_truncate_obs to ChildPolicyEnv;
    # presumably the goal coordinates hidden from the child policy -- TODO confirm.
    cp_num_truncate_obs = 2

    normalizer_mean, normalizer_std = get_normalizer_preset('ant_preset')
    if cp_num_truncate_obs > 0:
        # Extend the preset stats with mean 0 / std 1 for the extra dims so that
        # (assuming the wrapper computes (obs - mean) / std) they pass through unchanged.
        normalizer_mean = np.concatenate([normalizer_mean, np.zeros(cp_num_truncate_obs)])
        normalizer_std = np.concatenate([normalizer_std, np.ones(cp_num_truncate_obs)])
    env = consistent_normalize(env, normalize_obs=True, mean=normalizer_mean, std=normalizer_std)

    if not os.path.exists(cp_path):
        matches = sorted(glob.glob(cp_path))
        if not matches:
            raise FileNotFoundError(f'No child-policy checkpoint found at or matching {cp_path!r}')
        cp_path = matches[0]
    cp_dict = torch.load(cp_path, map_location='cpu')

    return ChildPolicyEnv(
        env,
        cp_dict,
        cp_action_range=1.5,
        cp_unit_length=1,
        cp_multi_step=25,
        cp_num_truncate_obs=cp_num_truncate_obs,
    )
    
METRA_CP_PATH = 'exp/ant_metra/sd000_s_56955647.0.1718292963_ant_metra/option_policy30000.pt'
METRA_SF_TD_CP_PATH = 'exp/ant_metra_sf_td/sd000_s_56969167.0.1718294685_ant_metra_sf/option_policy30000.pt'


def _make_option_runner(algo, env, make_env_fn):
    """Create and set up an OptionLocalRunner that samples ``algo`` on ``env``.

    Sampling uses ``OptionMultiprocessingSampler`` with one worker and one thread;
    ``make_env_fn`` is the zero-argument factory the sampler is given to build envs.
    NOTE(review): snapshot_dir='.' means any snapshot the runner saves lands in the
    current working directory.
    """
    runner = OptionLocalRunner(ExperimentContext(
        snapshot_dir='.',
        snapshot_mode='last',
        snapshot_gap=1,
    ))
    runner.setup(
        algo=algo,
        env=env,
        make_env=make_env_fn,
        sampler_cls=OptionMultiprocessingSampler,
        sampler_args=dict(n_thread=1),
        n_workers=1,
    )
    return runner


metra_env = make_env(METRA_CP_PATH)
metra_contextualized_make_env = functools.partial(make_env, cp_path=METRA_CP_PATH)

metra_sf_td_env = make_env(METRA_SF_TD_CP_PATH)
metra_sf_td_contextualized_make_env = functools.partial(make_env, cp_path=METRA_SF_TD_CP_PATH)

# Setup runners
metra_runner = _make_option_runner(metra_algo, metra_env, metra_contextualized_make_env)
metra_sf_td_runner = _make_option_runner(
    metra_sf_td_algo, metra_sf_td_env, metra_sf_td_contextualized_make_env)

In [4]:
def _get_trajectories(runner,
                      sampler_key,
                      batch_size=None,
                      extras=None,
                      update_stats=False,
                      worker_update=None,
                      env_update=None,
                      option_policy=None):
    """Sample exactly ``batch_size`` trajectories through ``runner``'s sampler.

    The workers receive ``option_policy``'s current parameters (detached, on CPU)
    as ``agent_update`` plus the optional worker/env updates.

    Args:
        runner: runner exposing ``obtain_exact_trajectories`` and ``step_itr``.
        sampler_key: sampler to use; a leading ``'local_'`` is stripped to obtain the
            policy key. NOTE(review): only the key ``'option_policy'`` is offered to
            ``_get_policy_param_values``, so any other resulting key raises KeyError.
        batch_size: number of trajectories; defaults to ``len(extras)``.
        extras: per-trajectory extra dicts (e.g. ``{'option': ...}``).
        update_stats, worker_update, env_update: forwarded to the runner unchanged.
        option_policy: policy whose parameters are pushed to the workers.

    Returns:
        The list of sampled trajectories (the runner's ``infos`` are discarded).

    Raises:
        ValueError: if neither ``batch_size`` nor ``extras`` is given.
    """
    import time

    if batch_size is None:
        if extras is None:
            raise ValueError('_get_trajectories needs batch_size or extras to size the batch')
        batch_size = len(extras)
    policy_sampler_key = sampler_key[len('local_'):] if sampler_key.startswith('local_') else sampler_key

    # Wall-clock time of exporting the policy parameters plus sampling.
    start = time.perf_counter()
    trajectories, _infos = runner.obtain_exact_trajectories(
        runner.step_itr,
        sampler_key=sampler_key,
        batch_size=batch_size,
        agent_update=_get_policy_param_values({'option_policy': option_policy}, policy_sampler_key),
        env_update=env_update,
        worker_update=worker_update,
        extras=extras,
        update_stats=update_stats,
    )
    elapsed = time.perf_counter() - start
    print(f'_get_trajectories({sampler_key}) {elapsed}s')

    return trajectories

def _get_policy_param_values(policy, key):
    param_dict = policy[key].get_param_values()
    for k in param_dict.keys():
        param_dict[k] = param_dict[k].detach().cpu()
    return param_dict

def _generate_option_extras(options):
    return [{'option': option} for option in options]

In [15]:
# Number of rollouts for the (currently disabled) sampler-based run below.
NUM_RANDOM_TRAJECTORIES = 1

# Disabled: draw NUM_RANDOM_TRAJECTORIES rollouts of the METRA option policy through
# the runner's sampler (deterministic policy, non-deterministic initial state per the
# worker_update flags).
# metra_random_trajectories = _get_trajectories(
#     metra_runner,
#     sampler_key='option_policy',
#     extras=[{} for _ in range(NUM_RANDOM_TRAJECTORIES)],
#     worker_update=dict(
#         _render=False,
#         _deterministic_initial_state=False,
#         _deterministic_policy=True,
#     ),
#     env_update=dict(_action_noise_std=None),
#     option_policy=metra_option_policy,
# )

def evaluate_goals(env, option_policy, num_goals=4):
    """Roll out one episode of the multi-goal task and record trajectory and goal progress.

    The last two entries of every observation are treated as the current goal.

    Args:
        env: environment with ``reset()`` and ``step(action, debug=True)``; the ``info``
            returned by ``step`` must contain ``'coordinates'`` (rows whose first two
            columns are x, y), ``'original_next_observations'`` (iterable of
            observations) and ``'goal_1'`` ... ``'goal_<num_goals>'``.
        option_policy: object with ``get_action(obs) -> (action, agent_info)``.
        num_goals: how many ``goal_k`` entries to track (4 by default).

    Returns:
        ``(xs, ys, attempted_goals, reward_total, all_goals, success)``:
        xs / ys -- one array of x / y positions per ``env.step`` call;
        attempted_goals -- the goal in effect before each step;
        reward_total -- summed reward;
        all_goals -- goals in the order they appeared (consecutive duplicates merged),
        minus a trailing goal that was revealed but never pursued;
        success -- ``{'goal_k': 1 if info['goal_k'] was ever > 0 else 0}``.
    """
    obs = env.reset()
    done = False
    reward_total = 0
    attempted_goals = []
    all_goals = [obs[-2:]]
    xs, ys = [], []
    success = {f'goal_{k}': 0 for k in range(1, num_goals + 1)}

    while not done:
        # Goal in effect when this action was chosen.
        attempted_goals.append(obs[-2:])
        action, _agent_info = option_policy.get_action(obs)
        obs, reward, done, info = env.step(action, debug=True)

        reward_total += reward
        xs.append(info['coordinates'][:, 0])
        ys.append(info['coordinates'][:, 1])
        # One step can yield several underlying observations; record each goal switch in order.
        for sub_obs in info['original_next_observations']:
            if not np.allclose(all_goals[-1], sub_obs[-2:]):
                all_goals.append(sub_obs[-2:])

        # Success is sticky: once a goal has been reached it stays reached.
        for key in success:
            success[key] = max(success[key], int(info[key] > 0))

    # The goal revealed by the final observation is never pursued, so drop it -- but
    # only if it really is new; otherwise we would discard the goal being pursued.
    if attempted_goals and not np.allclose(all_goals[-1], attempted_goals[-1]):
        all_goals.pop()

    return xs, ys, attempted_goals, reward_total, all_goals, success
    
def plot_multi_goal(env, option_policy, label=None):
    """Roll out one episode and plot visited positions coloured by the goal being pursued.

    Each goal is drawn as a numbered star with a radius-3 circle around it (presumably
    the success radius -- confirm against the env). Trajectory points are scattered in
    the colour of the goal that was active when they were collected. The per-goal
    success dict is printed.

    Args:
        env: environment passed to ``evaluate_goals``.
        option_policy: policy passed to ``evaluate_goals``.
        label: optional method name used as the plot title (e.g. 'METRA').
    """
    xs, ys, attempted_goals, _, all_goals, success = evaluate_goals(env, option_policy)
    print(success)

    fig, ax = plt.subplots()
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; the pyplot accessor remains.
    cmap = plt.get_cmap('tab10', 10)

    # Goals are numbered from 1; colour 0 is reserved for points whose goal is unmatched.
    for goal_num, goal in enumerate(all_goals, 1):
        color = cmap(goal_num)
        ax.scatter(goal[0], goal[1], color=color, marker='*', s=700)
        ax.text(goal[0], goal[1], str(goal_num), fontsize=10, ha='center',
                color='white', verticalalignment='center')
        ax.add_patch(plt.Circle((goal[0], goal[1]), 3, color=color, fill=False))

    for seg_xs, seg_ys, goal in zip(xs, ys, attempted_goals):
        goal_num = next(
            (i for i, known in enumerate(all_goals, 1) if np.allclose(goal, known)),
            0,
        )
        ax.scatter(seg_xs, seg_ys, color=cmap(goal_num), s=0.5)

    ax.set_title(label if label is not None else 'Multi-goal rollout')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal', 'box')
    plt.show()
    
# plot_multi_goal(metra_env, metra_option_policy)

# Number of independent evaluation episodes per method.
NUM_ROLLOUTS = 100


def _collect_success(env, option_policy, num_rollouts):
    """Run ``evaluate_goals`` ``num_rollouts`` times and gather per-goal success flags.

    Returns:
        defaultdict mapping each ``'goal_k'`` key to a list of 0/1 flags, one per rollout.
    """
    per_goal = defaultdict(list)
    for _ in range(num_rollouts):
        success = evaluate_goals(env, option_policy)[-1]
        for goal_key, flag in success.items():
            per_goal[goal_key].append(flag)
    return per_goal


def _print_success_rates(title, per_goal, confidence=0.95):
    """Print each goal's mean success rate with a Student-t confidence interval.

    The interval is centred on the mean with scale = SEM and ``len(flags) - 1``
    degrees of freedom. When SEM is 0 (all rollouts agree) or undefined (one
    rollout), scipy would return (nan, nan), so the degenerate interval
    (mean, mean) is printed instead.
    """
    print(title)
    for goal_key, flags in per_goal.items():
        mean = np.mean(flags)
        sem = stats.sem(flags)
        if sem > 0:
            low, high = stats.t.interval(confidence, len(flags) - 1, loc=mean, scale=sem)
        else:
            low = high = mean
        print(f'{goal_key} {mean:.2f} ({low:.3f}, {high:.3f})')
    print()


# METRA metrics
metra_success = _collect_success(metra_env, metra_option_policy, NUM_ROLLOUTS)
_print_success_rates('METRA SUCCESS RATES', metra_success)

# METRA SF TD metrics
metra_sf_td_success = _collect_success(metra_sf_td_env, metra_sf_td_option_policy, NUM_ROLLOUTS)
_print_success_rates('METRA SF TD SUCCESS RATES', metra_sf_td_success)

# Keep `all_success` bound to the last method's results for any later cells that read it.
all_success = metra_sf_td_success
    



METRA SUCCESS RATES
goal_1 0.93 (0.8791182485051084, 0.9808817514948917))
goal_2 0.37 (0.27371853383960454, 0.46628146616039545))
goal_3 0.37 (0.2737185338396045, 0.4662814661603955))
goal_4 0.35 (0.25488209883009155, 0.4451179011699084))

METRA SUCCESS RATES
goal_1 0.93 (0.8791182485051084, 0.9808817514948917))
goal_2 0.37 (0.27371853383960454, 0.46628146616039545))
goal_3 0.37 (0.2737185338396045, 0.4662814661603955))
goal_4 0.35 (0.25488209883009155, 0.4451179011699084))

METRA SF TD SUCCESS RATES
goal_1 0.95 ((0.9065371337811178, 0.9934628662188821))
goal_2 0.38 ((0.2832036009440236, 0.4767963990559764))
goal_3 0.38 ((0.28320360094402364, 0.47679639905597637))
goal_4 0.44 ((0.3410098664856729, 0.5389901335143271))

METRA SF TD SUCCESS RATES
goal_1 0.95 ((0.9065371337811178, 0.9934628662188821))
goal_2 0.38 ((0.2832036009440236, 0.4767963990559764))
goal_3 0.38 ((0.28320360094402364, 0.47679639905597637))
goal_4 0.44 ((0.3410098664856729, 0.5389901335143271))

