In [None]:
# --- Imports -------------------------------------------------------------------------
import json
import os

import numpy as np
import torch

# NOTE(review): star import — it presumably provides PMObject and trial_with_prediction
# used below; prefer importing those names explicitly once confirmed.
from open_anything_diffusion.simulations.simulation import *

from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra

from open_anything_diffusion.models.flow_diffuser_dit import (
    FlowTrajectoryDiffuserSimulationModule_DiT,
)
from open_anything_diffusion.models.flow_diffuser_pndit import (
    FlowTrajectoryDiffuserSimulationModule_PNDiT,
)
from open_anything_diffusion.models.modules.history_encoder import HistoryEncoder
from open_anything_diffusion.models.modules.dit_models import DiT, PN2DiT, PN2HisDiT

# --- Reproducibility / display ---------------------------------------------------------
# Seed the NumPy and PyTorch RNGs so rollouts are repeatable.
np.random.seed(42)
torch.manual_seed(42)
# Print tensors and arrays with 10 decimal places.
torch.set_printoptions(precision=10)
np.set_printoptions(precision=10)

# --- Hydra config ----------------------------------------------------------------------
# hydra.initialize() raises if Hydra is already initialized, which would make re-running
# this cell in the same kernel fail; clear any previous global Hydra state first.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
initialize(config_path="../../configs", version_base="1.3")
cfg = compose(config_name="eval_sim_switch")
# Alternative config:
# cfg = compose(config_name="eval_sim")

In [None]:
import rpad.pyg.nets.pointnet2 as pnp

# PointNet++ dense network with 3 output channels (presumably a per-point 3-D flow
# vector — the FlowBot baseline) and in_channels=0 (presumably no extra per-point
# features beyond positions — TODO confirm against rpad docs).
network = pnp.PN2Dense(
    in_channels=0,
    out_channels=3,
    p=pnp.PN2DenseParams(),
)

# NOTE(review): absolute, user-specific path — adjust for your machine.
ckpt_file = "/home/yishu/open_anything_diffusion/pretrained/fullset_half_half_flowbotRO.ckpt"

# Load the network weights. map_location="cpu" lets the checkpoint load on machines that
# lack the device it was saved from; load_state_dict copies the tensors into `network`
# (which lives on CPU here), so the resulting weights are the same either way.
ckpt = torch.load(ckpt_file, map_location="cpu")
# Each checkpoint key carries a leading "<prefix>." (presumably the wrapper module's
# attribute name); partition(".")[2] drops everything up to the first "." so the keys
# match the bare PN2Dense parameter names.
network.load_state_dict(
    {k.partition(".")[2]: v for k, v in ckpt["state_dict"].items()}
)
network.eval()

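### Find potentially occluded examples

Run the FlowBot baseline on test-split objects from the categories listed in `might_occlude_names`, and flag every trial whose simulated trajectory decreases at any step.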

In [None]:
import os

# Root of the local PartNet-Mobility dataset ("convex" subset); object folders are joined
# onto this path by object id later in the notebook.
pm_dir = os.path.expanduser("~/datasets/partnet-mobility/convex")

In [None]:
import json

def load_obj_id_to_category(
    toy_dataset=None,
    split_path="../../scripts/umpnet_data_split_new.json",
):
    """Build a mapping from object id to a category/split label.

    Args:
        toy_dataset: Optional mapping with the keys ``"train-train"`` and
            ``"train-test"``, each an iterable of object ids. When given, every id is
            labelled with its split name and ``split_path`` is not read.
        split_path: JSON file read when ``toy_dataset`` is None. It is traversed as
            ``{group: {category: {split: [obj_id, ...]}}}``; the top-level group key
            is ignored.

    Returns:
        dict mapping object id -> label. Labels are ``"<category>_<split>"`` when read
        from ``split_path``, or the split name itself for ``toy_dataset``. If an id
        appears more than once, the last occurrence wins.
    """
    id_to_cat = {}
    if toy_dataset is None:
        with open(split_path, "r") as f:
            data = json.load(f)
        for category_dict in data.values():
            for category, split_dict in category_dict.items():
                for split_name, obj_ids in split_dict.items():
                    for obj_id in obj_ids:
                        id_to_cat[obj_id] = f"{category}_{split_name}"
    else:
        for split in ("train-train", "train-test"):
            for obj_id in toy_dataset[split]:
                id_to_cat[obj_id] = split
    return id_to_cat

# Label every object id as "<Category>_<split>" using the default split file.
id_to_cat = load_obj_id_to_category(None)

# Objects to evaluate: each key is an object id, each value the list of joints/links
# tried for it in the scan below.
# NOTE(review): absolute, user-specific path — adjust for your machine.
movable_links_path = "/home/yishu/open_anything_diffusion/scripts/movable_links_fullset_000.json"
with open(movable_links_path, "r") as links_file:
    obj_dict = json.load(links_file)

In [None]:
# Distinct "<Category>_<split>" labels. Sorting makes the printed output reproducible;
# iterating a set of strings gives an order that varies with hash randomization.
all_categories = sorted(set(id_to_cat.values()))
# Category names with the split suffix removed (text before the first "_").
all_categories_name = sorted({label.split("_")[0] for label in all_categories})
print(all_categories_name)

In [None]:
# Categories to scan for possible occlusion; only their test-split objects are used,
# matching the "<Category>_<split>" labels built by load_obj_id_to_category.
might_occlude_names = [
    "WashingMachine",
    "Door",
    "Safe",
    "Dishwasher",
    "Refrigerator",
    "Microwave",
    "Oven",
]
might_occludes = [f"{category}_test" for category in might_occlude_names]

In [None]:
def _trajectory_regresses(trajectory):
    """Return True if any entry of ``trajectory`` is strictly smaller than the one before it.

    Only ``len()`` and integer indexing are used; elements must support ``<``.
    """
    return any(trajectory[i] < trajectory[i - 1] for i in range(1, len(trajectory)))


# FlowBot-baseline trials whose simulated trajectory goes backwards at least once.
# Each entry is [trial_figs, trial_results, all_signals[0]], recorded once per trial.
might_be_occluded_trials = []

for obj_id, joint_ids in obj_dict.items():
    # Only scan objects from the selected categories. Ids missing from the split file
    # are skipped instead of raising KeyError.
    if id_to_cat.get(obj_id) not in might_occludes:
        continue
    print(obj_id, joint_ids)
    for target_link in joint_ids:
        # Each entry of obj_dict[obj_id] is passed directly as the joint to actuate.
        trial_figs, trial_results, all_signals = trial_with_prediction(
            obj_id=obj_id,
            network=network,
            n_step=30,
            gui=False,
            all_joint=False,
            available_joints=[target_link],
            website=True,
            sgp=False,
            analysis=True,
        )
        try:
            (
                sim_trajectory,
                update_history_signals,
                cc_cnts,
                sgp_signals,
                visual_all_points,
                visual_link_ixs,
                visual_grasp_points_idx,
                visual_grasp_points,
                visual_flows,
            ) = all_signals[0]
            regresses = _trajectory_regresses(sim_trajectory)
        except Exception as exc:
            # Best-effort scan: report and skip trials whose signals are missing or
            # malformed. Unlike a bare `except:`, this lets KeyboardInterrupt stop the loop.
            print(f"Skipping {obj_id} / {target_link}: {exc!r}")
            continue
        if regresses:
            might_be_occluded_trials.append([trial_figs, trial_results, all_signals[0]])
            print(f"Found! id: {obj_id}")