In [None]:
# Evaluation script: runs a simulated rollout for each object in the evaluated
# dataset (currently restricted to the Door object ids listed below) and computes:
# - success: 90% open
# - distance to open
import json
import os

import hydra
import lightning as L
import numpy as np
import omegaconf
import pandas as pd
import rpad.pyg.nets.pointnet2 as pnp
import torch
import tqdm
import wandb
from rpad.visualize_3d import html

from open_anything_diffusion.datasets.flowbot import FlowBotDataModule
from open_anything_diffusion.simulations.simulation import trial_with_prediction
from open_anything_diffusion.utils.script_utils import PROJECT_ROOT, match_fn


def load_obj_id_to_category(path="../scripts/umpnet_data_split.json"):
    """Build a mapping from object id to category from the UMPNet data split.

    The split file is read as nested dicts of the form
    ``{<top-level key>: {category: {<split name>: [obj_id, ...]}}}``. The
    top-level and split-name keys are ignored; every id listed under a
    category is mapped to that category. If an id appears under several
    categories, the last one encountered wins.

    Args:
        path: Location of the split JSON. The default is resolved relative to
            the current working directory, matching the original notebook.

    Returns:
        dict mapping each object id (exactly as stored in the JSON) to its
        category name.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    id_to_cat = {}
    for category_dict in data.values():
        for category, split_dict in category_dict.items():
            for id_list in split_dict.values():
                for obj_id in id_list:  # avoid shadowing the builtin `id`
                    id_to_cat[obj_id] = category
    return id_to_cat


def load_obj_and_link(path="../scripts/umpnet_object_list.json"):
    """Load the UMPNet object list JSON and return it unparsed beyond json.load.

    The simulation loop in this notebook iterates the result with
    ``.items()`` as ``(obj_id, available_links)`` pairs, so it is expected to
    be a dict keyed by object id.

    Args:
        path: Location of the object-list JSON. The default is resolved
            relative to the current working directory, matching the original
            notebook.

    Returns:
        The decoded JSON document.
    """
    with open(path, "r", encoding="utf-8") as f:
        object_link_json = json.load(f)
    return object_link_json


id_to_cat = load_obj_id_to_category()
object_to_link = load_obj_and_link()

# Door objects evaluated by this notebook (ids of directories under the
# partnet-mobility raw dataset). Order is preserved from the original list.
object_ids = """
    8877 8893 8897 8903 8919 8930 8961 8997 9016 9032 9035 9041
    9065 9070 9107 9117 9127 9128 9148 9164 9168 9277 9280 9281
    9288 9386 9388 9410 8867 8983 8994 9003 9263 9393
""".split()


# Accumulators filled by the simulation loop: per-trial metric values and
# the matching figures for the "success" and "fail" buckets.
suc_results, suc_figs = [], []
fail_results, fail_figs = [], []

from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf

# Hydra refuses to initialize twice in the same process, which makes
# re-running this notebook cell fail; drop any previous initialization first
# (a no-op on the first run).
GlobalHydra.instance().clear()
initialize(config_path="../configs", version_base="1.3")
cfg = compose(config_name="eval_sim")


######################################################################
# Torch settings.
######################################################################

# Make deterministic + reproducible.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Allow reduced-precision internal computation (TF32 / bfloat16) for float32
# matmuls; trades a little accuracy for speed on Ampere-class GPUs (3090+).
torch.set_float32_matmul_precision("medium")

# Global seed for reproducibility.
L.seed_everything(42)


# One extra input channel when the config asks for a mask input.
mask_channel = 1 if cfg.inference.mask_input_channel else 0
# Dense (per-point) PointNet++ head with 3 * trajectory_len outputs per point
# (presumably one 3-vector per trajectory step — TODO confirm).
network = pnp.PN2Dense(
    in_channels=mask_channel,
    out_channels=3 * cfg.inference.trajectory_len,
    p=pnp.PN2DenseParams(),
)

ckpt_file = "/home/yishu/open_anything_diffusion/logs/train_trajectory/2023-12-05/21-52-28/checkpoints/epoch=199-step=157200.ckpt"

# Deserialize on CPU. The network above is built on CPU and load_state_dict
# copies tensors into its existing parameters, so the loaded weights are the
# same; this just avoids needing a GPU (and GPU memory) to read a checkpoint
# that was saved from CUDA tensors.
ckpt = torch.load(ckpt_file, map_location="cpu")
# Checkpoint keys carry a leading "<prefix>." (presumably the training
# module's attribute holding this network — TODO confirm); keep everything
# after the first "." so keys match the bare network's state dict.
network.load_state_dict(
    {key.partition(".")[2]: value for key, value in ckpt["state_dict"].items()}
)

# Simulation and results.
print("Simulating")

# Sorted so the table's row order is stable across runs (iteration order of a
# set of strings varies with hash randomization).
obj_cats = sorted(set(id_to_cat.values()))
metric_df = pd.DataFrame(
    np.zeros((len(obj_cats), 3)),
    index=obj_cats,
    columns=["count", "success_rate", "norm_dist"],
)
# Number of contact-making trials per category; denominator for the averages.
category_counts = {}
door_ids = set(object_ids)  # O(1) membership for the Door-only filter

# for obj_id, obj_cat in tqdm.tqdm(list(id_to_cat.items())):
for obj_id, available_links in tqdm.tqdm(list(object_to_link.items())):
    if obj_id not in door_ids:  # For Door dataset
        continue

    obj_cat = id_to_cat[obj_id]
    if not os.path.exists(f"/home/yishu/datasets/partnet-mobility/raw/{obj_id}"):
        continue
    print(f"OBJ {obj_id} of {obj_cat}")
    trial_figs, trial_results = trial_with_prediction(
        obj_id=obj_id,
        network=network,
        n_step=30,
        gui=cfg.gui,
        all_joint=True,
        # available_joints=available_links,   # Don't use this for doors
        website=cfg.website,
    )

    if not trial_results:  # No trials were returned for this object.
        continue

    category_counts.setdefault(obj_cat, 0)
    # Iterating trial_figs yields the keys it is indexed with below.
    # NOTE(review): assumes those keys are ordered like trial_results.
    for fig_key, result in zip(trial_figs, trial_results):
        if not result.contact:
            # Trials that never made contact are excluded from all metrics.
            continue
        category_counts[obj_cat] += 1
        # Single .loc[row, col] so the write lands in metric_df itself;
        # .loc[row][col] += ... writes into a temporary row Series.
        metric_df.loc[obj_cat, "success_rate"] += result.success
        metric_df.loc[obj_cat, "norm_dist"] += result.metric

        if result.success:
            suc_results.append(result.metric)
            suc_figs.append(trial_figs[fig_key])

        # NOTE(review): "fail" bucket is metric < 0.1, independent of
        # result.success — confirm the threshold and direction are intended.
        if result.metric < 0.1:
            fail_results.append(result.metric)
            fail_figs.append(trial_figs[fig_key])

# Per-category averages over contact-making trials, computed once after all
# objects have been simulated (always defined, even if nothing ran).
wandb_df = metric_df.copy(deep=True)
for cat, count in category_counts.items():
    wandb_df.loc[cat, "count"] = count
    if count:
        wandb_df.loc[cat, ["success_rate", "norm_dist"]] /= count
    else:
        # No contact-making trials: the averages are undefined.
        wandb_df.loc[cat, ["success_rate", "norm_dist"]] = np.nan

table = wandb.Table(dataframe=wandb_df.reset_index())
# Logging is disabled here; with an active W&B run:
# run.log({"simulation_metric_table": table})


In [None]:
# Show the metric values of the trials collected into the "fail" bucket.
fail_results

In [None]:
# Show the per-category table (count, averaged success_rate and norm_dist).
wandb_df

In [None]:
# Inspect one "fail" trial figure; guard against fewer failures than the index.
fail_idx = 3
if fail_idx < len(fail_figs):
    fail_figs[fail_idx].show()
else:
    print(f"Only {len(fail_figs)} fail figures collected; nothing at index {fail_idx}.")