In [None]:
%load_ext autoreload
%autoreload 2
import contextlib
import os
import pickle
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from tqdm import tqdm

from flipper_training.experiments.ppo.eval import get_eval_rollout, log_from_eval_rollout, PPOExperimentConfig

In [None]:
# Load the evaluation seeds: the file holds one integer per line.
with open("../cross_eval_configs/cross_eval_seeds.txt") as f:
    # Skip blank lines (e.g. an empty line at the end of the file), which int() cannot parse.
    seeds = [int(line) for line in f if line.strip()]

print(seeds)

In [None]:
# Name of the training run to evaluate: it selects the directory under ../runs/ppo
# that is loaded and the sub-directory of ../cross_eval_results that receives the results.
run = "final_trunk_thesis_training_42_2025-05-09_21-01-13"

In [None]:
# Locate the training run on disk and load the configuration stored with it.
run_path = Path("../runs/ppo") / run
train_config = OmegaConf.load(run_path / "config.yaml")

# Point the config at the checkpoint to evaluate; weights_step picks the
# policy_<step>.pth / vecnorm_<step>.pth pair inside the run's weights directory.
weights_step = "final"
weights_dir = run_path / "weights"
train_config.policy_weights_path = weights_dir / f"policy_{weights_step}.pth"
train_config.vecnorm_weights_path = weights_dir / f"vecnorm_{weights_step}.pth"

# Fail early, with the offending path in the message, if either file is missing.
for weights_kind, weights_path in (
    ("Policy", train_config.policy_weights_path),
    ("Vecnorm", train_config.vecnorm_weights_path),
):
    assert weights_path.exists(), f"{weights_kind} weights not found at {weights_path}"

In [None]:
# Evaluation-time overrides applied on top of the loaded training configuration.
train_config.num_robots = 16
train_config.objective_opts.cache_size = 10
train_config.max_eval_steps = 1000
# Build the PPOExperimentConfig from the config's entries; from here on
# train_config refers to that object rather than the loaded OmegaConf config.
train_config = PPOExperimentConfig(**train_config)

In [None]:
# Directory that is searched for the cross-evaluation YAML configs.
test_configs_base = Path("../cross_eval_configs")
# Results of this run are stored under ../cross_eval_results/<run>; create it if needed.
eval_results_base = Path("../cross_eval_results", run)
eval_results_base.mkdir(parents=True, exist_ok=True)

In [None]:
# Writable handle on the OS null device, used as the target when stdout/stderr
# are redirected during the evaluation rollouts.
# NOTE(review): nothing in the visible cells closes this handle; call
# devnull_handle.close() once all evaluations are finished.
devnull_handle = open(os.devnull, "w")

In [None]:
def _evaluate_over_seeds(config, desc, announce_seeds=False):
    """
    Run one evaluation rollout per seed and collect the resulting logs.

    Args:
        config: Experiment config passed to ``get_eval_rollout``; its ``seed``
            attribute is overwritten before every rollout.
        desc: Label shown on the progress bar.
        announce_seeds: If True, additionally write a line naming each seed
            before its rollout starts.

    Returns:
        A list with one ``log_from_eval_rollout`` result per entry of ``seeds``.
    """
    seed_logs = []
    for seed in tqdm(seeds, desc=desc):
        if announce_seeds:
            tqdm.write(f"Evaluating seed {seed}")
        config.seed = seed
        # Discard whatever the rollout writes to stdout/stderr.
        with contextlib.redirect_stdout(devnull_handle), contextlib.redirect_stderr(devnull_handle):
            _, eval_rollout = get_eval_rollout(config)
        seed_logs.append(log_from_eval_rollout(eval_rollout))
    return seed_logs


# Training environment
training_results_path = eval_results_base / "training.pkl"
if not training_results_path.exists():
    print("Evaluating training environment")
    results = _evaluate_over_seeds(train_config, "Training environment", announce_seeds=True)
    with open(training_results_path, "wb") as f:
        pickle.dump(results, f)

# Cross-evaluation environments: one YAML config per environment.
# NOTE(review): train_config is modified in place below, so after this loop it
# describes the last test environment. Re-run the config cells before
# re-evaluating the training environment.
for test_config_path in test_configs_base.glob("*.yaml"):
    # The skip check must look at the file that is actually written below
    # ("<stem>.pkl"); checking the bare stem never matched, so finished
    # environments were evaluated again on every run.
    results_path = eval_results_base / f"{test_config_path.stem}.pkl"
    if results_path.exists():
        print(f"Skipping {test_config_path.stem} as it already exists")
        continue
    print(f"Evaluating {test_config_path.stem}")
    test_config = OmegaConf.load(test_config_path)
    # Swap in the test environment's objective and heightmap generator settings.
    train_config.objective_opts = test_config["objective_opts"]
    train_config.heightmap_gen_opts = test_config["heightmap_gen_opts"]
    train_config.objective = test_config["objective"]
    train_config.heightmap_gen = test_config["heightmap_gen"]
    results = _evaluate_over_seeds(train_config, f"Evaluating {test_config_path.stem}")

    with open(results_path, "wb") as f:
        pickle.dump(results, f)

In [None]:
# Load every stored result file, keyed by file stem (e.g. "training", "stairs_easy").
# The paths are sorted because directory listing order is filesystem-dependent;
# sorting keeps the dict order, and therefore the plot order, reproducible.
# NOTE: pickle.load can execute arbitrary code; only load result files produced by this project.
results_dict = {}
for result_path in sorted(eval_results_base.glob("*.pkl")):
    with open(result_path, "rb") as result_file:
        results = pickle.load(result_file)
    results_dict[result_path.stem] = results

In [None]:
def list_of_dicts_to_dict_of_lists(list_of_dicts):
    """
    Transpose a sequence of dicts into a single dict of lists.

    Every key found in any input dict maps to the list of values stored under
    that key, in input order. A dict that lacks a key contributes nothing to
    that key's list, so the lists may differ in length.
    """
    transposed = {}
    for record in list_of_dicts:
        for key, value in record.items():
            transposed.setdefault(key, []).append(value)
    return transposed

In [None]:
# Per environment, turn the list of per-seed logs into one dict mapping each
# metric name to the list of its per-seed values.
results_transposed = {}
for env_key, seed_logs in results_dict.items():
    results_transposed[env_key] = list_of_dicts_to_dict_of_lists(seed_logs)
results_transposed.keys()

In [None]:
# Environment keys that receive a derived display name, in table order.
cross_eval_env_keys = (
    "stairs_hard",
    "stairs_easy",
    "barrier_easy",
    "barrier_hard",
    "gauss_coarse_hard",
    "gauss_fine_easy",
    "gauss_fine_hard",
    "gauss_coarse_easy",
    "trunk_easy",
    "trunk_hard",
)

# Display names used as tick labels: the training distribution has its own
# label; every other key is capitalised word by word and joined with hyphens
# (e.g. "gauss_fine_easy" -> "Gauss-Fine-Easy").
key2name = {"training": "Training Dist"}
for env_key in cross_eval_env_keys:
    key2name[env_key] = "-".join(word.capitalize() for word in env_key.split("_"))

In [None]:
# Success-rate box plots: one figure for the "easy" environments and one for
# the "hard" environments, each also containing the training distribution.
# This cell does not set rcParams["text.usetex"], so text is rendered with
# whatever setting is currently active.

# Separate keys for easy/training and hard/training
easy_keys = [k for k in results_transposed if "easy" in k or "training" in k]
hard_keys = [k for k in results_transposed if "hard" in k or "training" in k]

green_color = "#a8e6a3"
line_color = "black"

# Both figures share the same layout, so they are drawn by a single loop.
for group_keys in (easy_keys, hard_keys):
    fig, ax = plt.subplots(1, 1, figsize=(16, 12), dpi=200)
    ax.set_ylabel("Success Rate", fontsize=24, labelpad=12)
    ax.set_ylim(0, 1)
    # One box per environment, built from that environment's per-seed values.
    for i, k in enumerate(group_keys):
        pct_succeeded = results_transposed[k]["eval/pct_succeeded"]
        ax.boxplot(
            pct_succeeded,
            positions=[i],
            widths=0.3,
            patch_artist=True,
            boxprops=dict(facecolor=green_color, color=line_color),
            medianprops=dict(color=line_color),
            meanprops=dict(markerfacecolor="#ff4081", markeredgecolor="black", markersize=20),
            capprops=dict(color=line_color),
            whiskerprops=dict(color=line_color),
            showmeans=True,
        )
    ax.set_xticks(range(len(group_keys)))
    # Use the display name when one is defined, otherwise the raw key.
    ax.set_xticklabels([key2name.get(k.split(".")[0], k.split(".")[0]) for k in group_keys], rotation=45, fontsize=20)
    ax.tick_params(axis="y", labelsize=20)
    ax.set_title("Cross Evaluation Results", fontsize=22, pad=10)
    plt.tight_layout()
    plt.show()

In [None]:
# Failure-rate box plots: one figure for the "easy" environments and one for
# the "hard" environments, each also containing the training distribution.
# Text in these figures is rendered through LaTeX.
matplotlib.rcParams["text.usetex"] = True

# Separate keys for easy/training and hard/training
easy_keys = [k for k in results_transposed if "easy" in k or "training" in k]
hard_keys = [k for k in results_transposed if "hard" in k or "training" in k]

red_color = "#ff8a80"
line_color = "black"

# Both figures share the same layout, so they are drawn by a single loop.
for group_keys in (easy_keys, hard_keys):
    fig, ax = plt.subplots(1, 1, figsize=(16, 12), dpi=200)
    ax.set_ylabel("Failure Rate", fontsize=24, labelpad=12)
    ax.set_ylim(0, 1)
    # One box per environment, built from that environment's per-seed values.
    for i, k in enumerate(group_keys):
        pct_failed = results_transposed[k]["eval/pct_failed"]
        ax.boxplot(
            pct_failed,
            positions=[i],
            widths=0.3,
            patch_artist=True,
            boxprops=dict(facecolor=red_color, color=line_color),
            medianprops=dict(color=line_color),
            meanprops=dict(markerfacecolor="#a8e6a3", markeredgecolor="black", markersize=20),
            capprops=dict(color=line_color),
            whiskerprops=dict(color=line_color),
            showmeans=True,
        )
    ax.set_xticks(range(len(group_keys)))
    tick_labels = []
    for k in group_keys:
        stem = k.split(".")[0]
        # Keys without a display name fall back to the raw stem. Its underscores
        # are escaped because LaTeX text mode rejects a bare "_".
        tick_labels.append(key2name.get(stem, stem.replace("_", r"\_")))
    ax.set_xticklabels(tick_labels, rotation=45, fontsize=20)
    ax.tick_params(axis="y", labelsize=20)
    ax.set_title("Cross Evaluation Results", fontsize=22, pad=10)
    plt.tight_layout()
    plt.show()