In [None]:
# Notebook setup: seed the NumPy and PyTorch global RNGs first so that any
# randomness used later in the notebook starts from a fixed state.
import numpy as np
import torch
np.random.seed(42)
torch.manual_seed(42)
import os

# NOTE(review): numpy and torch are imported a second time below; the repeat
# is harmless (Python caches modules) but redundant.
import numpy as np
import pybullet as p
import rpad.pyg.nets.pointnet2 as pnp
import torch

# Trial with dit
from open_anything_diffusion.models.modules.dit_models import DiT
# Print tensors/arrays with 10 digits after the decimal point.
torch.set_printoptions(precision=10)  # Set higher precision for PyTorch outputs
np.set_printoptions(precision=10)

In [None]:
# Build the DiT backbone and move it to the GPU (requires CUDA).
# in_channels is written as 3 + 3 -- presumably two 3-D quantities stacked per
# point; TODO confirm against the DiT definition.
network = DiT(
    in_channels=3 + 3,
    depth=5,
    hidden_size=128,
    num_heads=4,
    learn_sigma=True,
).cuda()

# Path to a weights-only checkpoint on the author's machine (absolute path).
# NOTE(review): nothing in this cell loads the checkpoint into `network`;
# the load presumably happens in a later cell -- verify.
# ckpt_file = "/home/yishu/open_anything_diffusion/logs/train_trajectory_diffuser_dit/2024-03-10/10-54-13/checkpoints/epoch=459-step=80500-val_loss=0.00-weights-only.ckpt"
ckpt_file = "/home/yishu/open_anything_diffusion/logs/train_trajectory_diffuser_dit/2024-03-30/07-12-41/checkpoints/epoch=359-step=199080-val_loss=0.00-weights-only.ckpt"
from hydra import compose, initialize

# Initialise hydra with a config directory given relative to this notebook and
# compose the "eval_sim" configuration into `cfg`.
# NOTE(review): hydra's initialize() sets global state, so re-running this cell
# in the same kernel is likely to fail -- confirm.
initialize(config_path="../../configs", version_base="1.3")
cfg = compose(config_name="eval_sim")


## Simulation

### Run trial

In [None]:
import copy
import time
from dataclasses import dataclass
from typing import Optional

import imageio
import numpy as np
import pybullet as p
import torch
from flowbot3d.datasets.flow_dataset import compute_normalized_flow
from flowbot3d.grasping.agents.flowbot3d import FlowNetAnimation
from rpad.partnet_mobility_utils.data import PMObject
from rpad.partnet_mobility_utils.render.pybullet import PMRenderEnv
from rpad.pybullet_envs.suction_gripper import FloatingSuctionGripper
from scipy.spatial.transform import Rotation as R

from open_anything_diffusion.datasets.flow_trajectory_dataset import (
    compute_flow_trajectory,
)
from open_anything_diffusion.metrics.trajectory import normalize_trajectory


class PMSuctionSim:
    """PyBullet simulation of a floating suction gripper acting on one object.

    The object is loaded through ``PMRenderEnv`` and the gripper through
    ``FloatingSuctionGripper``; both live in the physics client owned by the
    render environment (``self.render_env.client_id``).
    """

    def __init__(self, obj_id: str, dataset_path: str, gui: bool = False):
        """Load the object ``obj_id`` from ``dataset_path`` and park the gripper.

        Args:
            obj_id: Identifier of the object to load.
            dataset_path: Directory the object is loaded from.
            gui: Forwarded to ``PMRenderEnv``; when true, simulation loops
                also sleep 1/240 s per step.
        """
        self.render_env = PMRenderEnv(obj_id=obj_id, dataset_path=dataset_path, gui=gui)
        self.gui = gui
        self.gripper = FloatingSuctionGripper(self.render_env.client_id)
        self._park_gripper()

    def _park_gripper(self):
        """Move the gripper to its fixed parking pose."""
        self.gripper.set_pose(
            [-1, 0.6, 0.8], p.getQuaternionFromEuler([0, np.pi / 2, 0])
        )

    def reset(self):
        """No-op placeholder."""
        pass

    def reset_gripper(self):
        """Release the gripper and send it back to the parking pose."""
        self.gripper.release()
        self._park_gripper()

    def set_gripper_pose(self, pos, ori):
        """Set the gripper pose directly (position, orientation quaternion)."""
        self.gripper.set_pose(pos, ori)

    def set_joint_state(self, link_name: str, value: float):
        """Reset the joint of ``link_name`` to ``value`` with zero velocity."""
        p.resetJointState(
            self.render_env.obj_id,
            self.render_env.link_name_to_index[link_name],
            value,
            0.0,
            self.render_env.client_id,
        )

    def render(self, filter_nonobj_pts: bool = False, n_pts: Optional[int] = None):
        """Render the scene and optionally filter / subsample the point cloud.

        Args:
            filter_nonobj_pts: Keep only points belonging to the loaded object
                and replace the per-point segmentation ids by link indices.
            n_pts: If given, keep a random subset of at most ``n_pts`` points
                (drawn with ``np.random.permutation``).

        Returns:
            ``(rgb, depth, seg, P_cam, P_world, pc_seg, segmap)`` as produced
            by ``PMRenderEnv.render`` with the point-wise entries filtered.
        """
        output = self.render_env.render()
        rgb, depth, seg, P_cam, P_world, pc_seg, segmap = output

        if filter_nonobj_pts:
            # Map each segmentation id to the link index of our object;
            # everything else stays -1 and is dropped below.
            pc_seg_obj = np.ones_like(pc_seg) * -1
            for k, (body, link) in segmap.items():
                if body == self.render_env.obj_id:
                    ixs = pc_seg == k
                    pc_seg_obj[ixs] = link

            is_obj = pc_seg_obj != -1
            P_cam = P_cam[is_obj]
            P_world = P_world[is_obj]
            pc_seg = pc_seg_obj[is_obj]
        if n_pts is not None:
            perm = np.random.permutation(len(P_world))[:n_pts]
            P_cam = P_cam[perm]
            P_world = P_world[perm]
            pc_seg = pc_seg[perm]

        return rgb, depth, seg, P_cam, P_world, pc_seg, segmap

    def set_camera(self):
        """No-op placeholder."""
        pass

    @staticmethod
    def _approach_orientation(contact_vector):
        """Return the gripper quaternion for approaching along ``-contact_vector``.

        ``contact_vector`` must be a normalised 1-D float tensor. The frame is
        built with ``e_y = -contact_vector`` and a perpendicular ``e_x``
        obtained from the world z axis.
        """
        e_y = -contact_vector
        e_x = torch.cross(e_y, torch.tensor([0, 0, 1.0]).float(), dim=-1)
        if float(e_x.norm(dim=-1)) < 1e-6:
            # The contact vector is (anti)parallel to world z, so the cross
            # product vanishes and normalising it would yield NaNs. Any other
            # reference axis gives a valid perpendicular direction.
            e_x = torch.cross(e_y, torch.tensor([1.0, 0, 0]).float(), dim=-1)
        e_x = e_x / e_x.norm(dim=-1)
        e_z = torch.cross(e_x, e_y, dim=-1)
        e_z = e_z / e_z.norm(dim=-1)
        R_teleport = torch.stack([e_x, e_y, e_z], dim=1)
        R_gripper = torch.as_tensor(
            [
                [1, 0, 0],
                [0, 0, 1.0],
                [0, -1.0, 0],
            ]
        )
        return R.from_matrix((R_teleport @ R_gripper).numpy()).as_quat()

    def _capture_frame(self, frame_width: int = 640, frame_height: int = 480):
        """Render one RGB frame (H x W x 3, uint8) from the fixed video camera."""
        _, _, rgb_img, _, _ = p.getCameraImage(
            width=frame_width,
            height=frame_height,
            viewMatrix=p.computeViewMatrixFromYawPitchRoll(
                cameraTargetPosition=[0, 0, 0],
                distance=5,
                yaw=90,
                pitch=-30,
                roll=0,
                upAxisIndex=2,
            ),
            projectionMatrix=p.computeProjectionMatrixFOV(
                fov=60,
                aspect=float(frame_width) / frame_height,
                nearVal=0.1,
                farVal=100.0,
            ),
            # Render in this simulation's client, like every other call here.
            physicsClientId=self.render_env.client_id,
        )
        return np.array(rgb_img, dtype=np.uint8)[:, :, :3]

    def _approach(self, point, contact_vector, video_writer, standoff_d, max_steps=500):
        """Teleport to a standoff pose and move towards ``point`` until contact.

        The gripper is placed ``standoff_d`` away from ``point`` along the
        normalised ``contact_vector`` and then driven at 0.4 along
        ``-contact_vector`` for at most ``max_steps`` simulation steps. Every
        50th step a frame is appended to ``video_writer`` when one is given.

        Returns:
            Whether the gripper is in contact with the object.
        """
        contact_vector = (contact_vector / contact_vector.norm(dim=-1)).float()
        p_teleport = (torch.from_numpy(point) + contact_vector * standoff_d).float()
        self.gripper.set_pose(p_teleport, self._approach_orientation(contact_vector))

        contact = self.gripper.detect_contact(self.render_env.obj_id)
        self.gripper.set_velocity(-contact_vector * 0.4, [0, 0, 0])
        curr_steps = 0
        while not contact and curr_steps < max_steps:
            p.stepSimulation(self.render_env.client_id)

            if video_writer is not None and curr_steps % 50 == 49:
                video_writer.append_data(self._capture_frame())

            curr_steps += 1
            if self.gui:
                time.sleep(1 / 240.0)
            contact = self.gripper.detect_contact(self.render_env.obj_id)

        return contact

    def teleport_and_approach(
        self, point, contact_vector, video_writer=None, standoff_d: float = 0.2
    ):
        """Approach a single ``point`` along ``contact_vector``.

        Args:
            point: NumPy array with the target position.
            contact_vector: Torch tensor with the approach direction (need not
                be normalised).
            video_writer: Optional writer whose ``append_data`` receives frames.
            standoff_d: Distance from ``point`` at which the approach starts.

        Returns:
            Whether contact with the object was detected.
        """
        contact = self._approach(point, contact_vector, video_writer, standoff_d)
        if contact:
            print("contact detected")
        return contact

    def teleport(
        self, points, contact_vectors, video_writer=None, standoff_d: float = 0.2
    ):
        """Try candidate grasp points in order until one yields contact.

        Args:
            points: Iterable of NumPy positions.
            contact_vectors: Iterable of torch approach directions, paired
                with ``points``.
            video_writer: Optional writer whose ``append_data`` receives frames.
            standoff_d: Distance from each point at which the approach starts.

        Returns:
            ``(index, True)`` with the position of the first successful
            candidate in ``points``, or ``(-1, False)`` if none made contact.
        """
        for candidate_ix, (point, contact_vector) in enumerate(
            zip(points, contact_vectors)
        ):
            if self._approach(point, contact_vector, video_writer, standoff_d):
                print("contact detected")
                return candidate_ix, True

        return -1, False

    def attach(self):
        """Activate the suction gripper on the object."""
        self.gripper.activate(self.render_env.obj_id)

    def pull(self, direction, n_steps: int = 100):
        """Drive the gripper at 0.4 along normalised ``direction`` for ``n_steps`` steps."""
        direction = torch.as_tensor(direction)
        direction = direction / direction.norm(dim=-1)
        for _ in range(n_steps):
            self.gripper.set_velocity(direction * 0.4, [0, 0, 0])
            p.stepSimulation(self.render_env.client_id)
            if self.gui:
                time.sleep(1 / 240.0)

    def get_joint_value(self, target_link: str):
        """Return the current joint position of ``target_link``."""
        link_index = self.render_env.link_name_to_index[target_link]
        state = p.getJointState(
            self.render_env.obj_id, link_index, self.render_env.client_id
        )
        return state[0]

    def detect_success(self, target_link: str):
        """Report how far ``target_link`` has travelled between its joint limits.

        Returns:
            ``(success, normalized)`` where ``normalized`` is
            ``(current - lower) / (upper - lower)`` and ``success`` is true
            when less than 10% of the range is left before the upper limit.
            A joint whose limits coincide cannot move and yields
            ``(False, 0.0)`` instead of dividing by zero.
        """
        link_index = self.render_env.link_name_to_index[target_link]
        info = p.getJointInfo(
            self.render_env.obj_id, link_index, self.render_env.client_id
        )
        lower, upper = info[8], info[9]
        curr_pos = self.get_joint_value(target_link)

        joint_range = upper - lower
        if joint_range == 0:
            success, normalized = False, 0.0
        else:
            success = (upper - curr_pos) / joint_range < 0.1
            normalized = (curr_pos - lower) / joint_range
        print(
            f"lower: {lower}, upper: {upper}, curr: {curr_pos}, success:{success}"
        )

        return success, normalized

    def _iter_joints(self):
        """Yield ``(joint_index, joint_info)`` for every joint of the object."""
        obj_id, client_id = self.render_env.obj_id, self.render_env.client_id
        for joint_ix in range(p.getNumJoints(obj_id, client_id)):
            yield joint_ix, p.getJointInfo(obj_id, joint_ix, client_id)

    def _reset_joint(self, joint_ix, angle):
        """Reset joint ``joint_ix`` to ``angle`` with zero velocity."""
        p.resetJointState(
            self.render_env.obj_id, joint_ix, angle, 0, self.render_env.client_id
        )

    def randomize_joints(self):
        """Set every revolute/prismatic joint to a uniform random value within its limits."""
        for joint_ix, jinfo in self._iter_joints():
            if jinfo[2] == p.JOINT_REVOLUTE or jinfo[2] == p.JOINT_PRISMATIC:
                lower, upper = jinfo[8], jinfo[9]
                self._reset_joint(joint_ix, np.random.random() * (upper - lower) + lower)

    def randomize_specific_joints(self, joint_list):
        """Randomise only the joints whose name field ``jinfo[12]`` is in ``joint_list``."""
        for joint_ix, jinfo in self._iter_joints():
            if jinfo[12].decode("UTF-8") in joint_list:
                lower, upper = jinfo[8], jinfo[9]
                self._reset_joint(joint_ix, np.random.random() * (upper - lower) + lower)

    def articulate_specific_joints(self, joint_list, amount):
        """Set the listed joints to the fraction ``amount`` of their limit range."""
        for joint_ix, jinfo in self._iter_joints():
            if jinfo[12].decode("UTF-8") in joint_list:
                lower, upper = jinfo[8], jinfo[9]
                self._reset_joint(joint_ix, amount * (upper - lower) + lower)

    def randomize_joints_openclose(self, joint_list):
        """Put all listed joints at their lower (0) or upper (1) limit, chosen at random.

        The choice is stored in ``self.close_or_open`` (0: close, 1: open).
        """
        randind = np.random.choice([0, 1])
        self.close_or_open = randind
        for joint_ix, jinfo in self._iter_joints():
            if jinfo[12].decode("UTF-8") in joint_list:
                lower, upper = jinfo[8], jinfo[9]
                self._reset_joint(joint_ix, [lower, upper][randind])

@dataclass
class TrialResult:
    """Outcome of one simulated articulation trial, as filled in by ``run_trial``."""

    # True when env.detect_success reported the joint as (almost) fully open.
    success: bool
    # False when no candidate grasp point produced gripper contact.
    contact: bool
    # False when the trial was aborted because the joint cannot move or no
    # observed point lies on the target link; True otherwise.
    assertion: bool
    # Joint limits read from PyBullet (lower / upper) for a completed trial;
    # all three angle fields are 0 on the early-exit paths.
    init_angle: float
    final_angle: float
    # Joint value at the end of the trial.
    now_angle: float

    # UMPNet metric goes here
    # run_trial stores the normalised joint progress clamped to [0, 1].
    metric: float


class GTFlowModel:
    """Ground-truth flow oracle backed by ``compute_normalized_flow``."""

    def __init__(self, raw_data, env):
        """Keep references to the object description and the simulation."""
        self.raw_data = raw_data
        self.env = env

    def __call__(self, obs) -> torch.Tensor:
        """Compute normalised flow for the points of an observation.

        ``obs`` is the 7-tuple returned by the simulation's ``render``; only
        the world-frame points and the per-point segmentation are used. Every
        joint on the chain of any slider or hinge link is passed with value 0.
        """
        _rgb, _depth, _seg, _P_cam, P_world, pc_seg, _segmap = obs
        sim, pm_object = self.env, self.raw_data

        articulated = pm_object.semantics.by_type("slider")
        articulated += pm_object.semantics.by_type("hinge")
        zero_joint_values = {
            joint.name: 0
            for link in articulated
            for joint in pm_object.obj.get_chain(link.name)
        }

        flow = compute_normalized_flow(
            P_world,
            sim.render_env.T_world_base,
            zero_joint_values,
            pc_seg,
            sim.render_env.link_name_to_index,
            pm_object,
            "all",
        )
        return torch.from_numpy(flow)

    def get_movable_mask(self, obs) -> torch.Tensor:
        """Flag the points whose ground-truth flow is not (close to) zero.

        Note: the value actually returned is a NumPy boolean array.
        """
        flow = self(obs)
        is_static = np.isclose(flow, 0.0).all(axis=-1)
        return (~is_static).astype(np.bool_)


class GTTrajectoryModel:
    """Ground-truth trajectory oracle backed by ``compute_flow_trajectory``."""

    def __init__(self, raw_data, env, traj_len=20):
        """Keep the object description, the simulation and the trajectory length."""
        self.traj_len = traj_len
        self.env = env
        self.raw_data = raw_data

    def __call__(self, obs) -> torch.Tensor:
        """Compute a ``traj_len``-step flow trajectory for an observation.

        ``obs`` is the 7-tuple returned by the simulation's ``render``; only
        the world-frame points and the per-point segmentation are used. Every
        joint on the chain of any slider or hinge link is passed with value 0.
        The second value returned by ``compute_flow_trajectory`` is discarded.
        """
        _rgb, _depth, _seg, _P_cam, P_world, pc_seg, _segmap = obs
        sim, pm_object = self.env, self.raw_data

        articulated = pm_object.semantics.by_type("slider")
        articulated += pm_object.semantics.by_type("hinge")
        zero_joint_values = {
            joint.name: 0
            for link in articulated
            for joint in pm_object.obj.get_chain(link.name)
        }

        result = compute_flow_trajectory(
            self.traj_len,
            P_world,
            sim.render_env.T_world_base,
            zero_joint_values,
            pc_seg,
            sim.render_env.link_name_to_index,
            pm_object,
            "all",
        )
        trajectory, _ = result
        return torch.from_numpy(trajectory)


def choose_grasp_points(raw_pred_flow, raw_point_cloud, filter_edge=False, k=20):
    """Select the ``k`` points with the largest predicted flow magnitude.

    Args:
        raw_pred_flow: Torch tensor with one flow vector per point; it is
            cloned, so the caller's tensor is left untouched.
        raw_point_cloud: Point positions, indexed in step with the flow.
        filter_edge: When true, flow is zeroed for points that have few
            close neighbours (fewer than the 30th percentile of neighbour
            counts, with "close" meaning below the 10th percentile of all
            pairwise distances) so they are not preferred.
        k: Maximum number of candidates.

    Returns:
        ``(indices, flows, points)`` for the selected candidates, ordered by
        decreasing flow norm. If only one point is available it is returned
        twice.
    """
    flow = raw_pred_flow.clone()
    cloud = raw_point_cloud

    if filter_edge:
        # Dense pairwise Euclidean distances between all points.
        offsets = cloud[:, np.newaxis, :] - cloud[np.newaxis, :, :]
        pairwise = np.sqrt(np.sum(offsets**2, axis=2))
        near_radius = np.percentile(pairwise, 10)
        neighbour_count = np.sum(pairwise < near_radius, axis=0)
        sparse_cutoff = np.percentile(neighbour_count, 30)
        # Sparse neighbourhoods are treated as edges and excluded.
        flow[neighbour_count < sparse_cutoff] = 0

    n_candidates = min(k, len(flow))
    _, chosen = torch.topk(flow.norm(dim=-1), n_candidates)
    if n_candidates == 1:
        chosen = torch.tensor([*chosen, *chosen])
    return chosen, flow[chosen], cloud[chosen]


def _failed_trial(assertion):
    """Build the zeroed ``TrialResult`` shared by all early-exit paths of ``run_trial``."""
    return TrialResult(
        success=False,
        assertion=assertion,
        contact=False,
        init_angle=0,
        final_angle=0,
        now_angle=0,
        metric=0,
    )


def _capture_trial_frame(client_id, frame_width=640, frame_height=480):
    """Render one RGB frame (H x W x 3, uint8) from the fixed video camera."""
    _, _, rgb_img, _, _ = p.getCameraImage(
        width=frame_width,
        height=frame_height,
        viewMatrix=p.computeViewMatrixFromYawPitchRoll(
            cameraTargetPosition=[0, 0, 0],
            distance=5,
            yaw=90,
            pitch=-30,
            roll=0,
            upAxisIndex=2,
        ),
        projectionMatrix=p.computeProjectionMatrixFOV(
            fov=60,
            aspect=float(frame_width) / frame_height,
            nearVal=0.1,
            farVal=100.0,
        ),
        # Render in the trial's own physics client.
        physicsClientId=client_id,
    )
    return np.array(rgb_img, dtype=np.uint8)[:, :, :3]


def _stop_trial_recording(log_id, writer):
    """Stop whichever video recording (PyBullet logging or imageio writer) is active."""
    if log_id is not None:
        p.stopStateLogging(log_id)
    if writer is not None:
        writer.close()


def _record_grasp_trace(
    animation, P_world, pred_flow, link_ixs, link_point_ix, grasp_point, grasp_flow
):
    """Add one animation frame showing the standoff vector at the grasp point.

    ``link_point_ix`` indexes the points selected by the boolean mask
    ``link_ixs``; it is translated to a row of the full point cloud before
    writing. (Assigning through ``array[mask][i]`` would only modify a
    temporary copy and leave the displayed flow all zero.)
    """
    segmented_flow = np.zeros_like(pred_flow)
    contact_vector = (grasp_flow / grasp_flow.norm(dim=-1)).float()
    p_teleport = (torch.from_numpy(grasp_point) + contact_vector * 0.2).float()
    target_row = np.flatnonzero(link_ixs)[int(link_point_ix)]
    segmented_flow[target_row] = p_teleport.detach().numpy() - grasp_point
    animation.add_trace(
        torch.as_tensor(P_world),
        torch.as_tensor([P_world]),
        torch.as_tensor([segmented_flow]),
        "red",
    )


def run_trial(
    env: PMSuctionSim,
    raw_data: PMObject,
    target_link: str,
    model,
    gt_model=None,  # When we use mask_input_channel=True, this is the mask generator
    n_steps: int = 30,
    n_pts: int = 1200,
    save_name: str = "unknown",
    website: bool = False,
    gui: bool = False,
) -> TrialResult:
    """Run one closed-loop trial that tries to open ``target_link``.

    The policy predicts per-point flow, grasps the point on the target link
    with the largest flow that the gripper can reach, pulls along the
    predicted direction and repeats for at most ``n_steps`` pulls. The
    physics client of ``env`` is disconnected before returning.

    Args:
        env: Simulation to act in.
        raw_data: Object description (used for the link's joint type).
        target_link: Name of the link to actuate.
        model: Callable mapping an observation (and, when ``gt_model`` is
            given, a movable-point mask) to a per-point flow trajectory.
        gt_model: Optional object providing ``get_movable_mask(obs)``.
        n_steps: Maximum number of pull steps.
        n_pts: Number of points sampled per observation.
        save_name: Base name of the video written when ``website`` is true.
        website: Collect a flow animation and record a video.
        gui: With ``website``, record through PyBullet state logging instead
            of writing captured frames with imageio.

    Returns:
        NOTE: despite the annotation, a 3-tuple
        ``(animation_results, TrialResult, sim_trajectory)`` is returned.
        ``animation_results`` is ``None`` unless ``website`` is true and
        ``sim_trajectory`` holds the normalised joint position after each
        step (``n_steps + 1`` entries, the first fixed at 0.05).
    """
    torch.manual_seed(42)
    torch.set_printoptions(precision=10)  # Set higher precision for PyTorch outputs
    np.set_printoptions(precision=10)

    sim_trajectory = [0.05] + [0] * (n_steps)  # start from 0.05

    animation = FlowNetAnimation() if website else None
    log_id = None  # PyBullet video logging handle (website + gui)
    writer = None  # imageio writer (website without gui)

    # First, reset the environment.
    env.reset()
    client_id = env.render_env.client_id
    target_link_ix = env.render_env.link_name_to_index[target_link]

    # Joint information
    info = p.getJointInfo(env.render_env.obj_id, target_link_ix, client_id)
    init_angle, target_angle = info[8], info[9]

    # Hinges (doors in particular sometimes collide with themselves when
    # fully closed) start 5% of the way into their range.
    if raw_data.semantics.by_name(target_link).type == "hinge":
        env.set_joint_state(
            target_link, init_angle + 0.05 * (target_angle - init_angle)
        )

    # Predict the flow on the observation.
    pc_obs = env.render(filter_nonobj_pts=True, n_pts=n_pts)
    rgb, depth, seg, P_cam, P_world, pc_seg, segmap = pc_obs

    if init_angle == target_angle:  # Not movable
        p.disconnect(physicsClientId=client_id)
        return None, _failed_trial(assertion=False), sim_trajectory

    if gt_model is None:  # GT Flow model
        pred_trajectory = model(copy.deepcopy(pc_obs))
    else:
        movable_mask = gt_model.get_movable_mask(pc_obs)
        pred_trajectory = model(copy.deepcopy(pc_obs), movable_mask)
    pred_trajectory = pred_trajectory.reshape(
        pred_trajectory.shape[0], -1, pred_trajectory.shape[-1]
    )
    traj_len = pred_trajectory.shape[1]  # Trajectory length
    print(f"Predicting {traj_len} length trajectories.")
    pred_flow = pred_trajectory[:, 0, :]

    # Filter down just the points on the target link.
    link_ixs = pc_seg == target_link_ix
    if not link_ixs.any():
        p.disconnect(physicsClientId=client_id)
        print("link_ixs finds no point")
        animation_results = animation.animate() if website else None
        return animation_results, _failed_trial(assertion=False), sim_trajectory

    if website:
        video_file = f"./logs/simu_eval/video_assets/{save_name}.mp4"
        if gui:
            # Record simulation video
            log_id = p.startStateLogging(p.STATE_LOGGING_VIDEO_MP4, video_file)
        else:
            writer = imageio.get_writer(video_file, fps=5)
            writer.append_data(_capture_trial_frame(client_id))

    # The attachment candidates are the points with the highest flow.
    topk_ix, best_flows, best_points = choose_grasp_points(
        pred_flow[link_ixs], P_world[link_ixs], filter_edge=False, k=20
    )

    # Teleport to an approach pose, approach the object and grasp.
    # `rank` is the position of the successful candidate within the top-k
    # list (-1 if none), so it must index best_flows / best_points rather
    # than the link points themselves.
    rank, contact = env.teleport(best_points, best_flows, video_writer=writer)
    best_flow = best_flows[rank]
    best_point = best_points[rank]
    best_flow_ix = int(topk_ix[rank])
    last_step_grasp_point = best_point

    if not contact:
        if website:
            _record_grasp_trace(
                animation, P_world, pred_flow, link_ixs, best_flow_ix, best_point, best_flow
            )
        _stop_trial_recording(log_id, writer)

        print("No contact!")
        p.disconnect(physicsClientId=client_id)
        animation_results = None if not website else animation.animate()
        return animation_results, _failed_trial(assertion=True), sim_trajectory

    env.attach()

    pc_obs = env.render(filter_nonobj_pts=True, n_pts=n_pts)
    success = False

    global_step = 0
    while global_step < n_steps:
        # Predict the flow on the observation.
        if gt_model is None:  # GT Flow model
            pred_trajectory = model(copy.deepcopy(pc_obs))
        else:
            movable_mask = gt_model.get_movable_mask(pc_obs)
            pred_trajectory = model(pc_obs, movable_mask)
        pred_trajectory = pred_trajectory.reshape(
            pred_trajectory.shape[0], -1, pred_trajectory.shape[-1]
        )

        # NOTE(review): for trajectories longer than one step the flow of
        # later steps is paired with freshly re-sampled points below; the
        # point correspondence presumably only holds for traj_len == 1.
        for traj_step in range(pred_trajectory.shape[1]):
            if global_step == n_steps:
                break
            global_step += 1
            pred_flow = pred_trajectory[:, traj_step, :]
            rgb, depth, seg, P_cam, P_world, pc_seg, segmap = pc_obs

            # Filter down just the points on the target link.
            link_ixs = pc_seg == target_link_ix
            if not link_ixs.any():
                _stop_trial_recording(log_id, writer)
                p.disconnect(physicsClientId=client_id)
                print("link_ixs finds no point")
                animation_results = animation.animate() if website else None
                return animation_results, _failed_trial(assertion=False), sim_trajectory

            link_points = P_world[link_ixs]
            link_flow = pred_flow[link_ixs]

            # Get the best direction.
            topk_ix, best_flows, best_points = choose_grasp_points(
                link_flow, link_points, filter_edge=False, k=20
            )

            # (1) Strategy 1 - Don't change grasp point
            # (2) Strategy 2 - Change grasp point when leverage difference is large
            lev_diff_thres = 0.2
            no_movement_thres = -1  # negative: the "did not move" test is disabled

            # The currently grasped point is taken to be the link point
            # closest to the gripper base.
            gripper_tip_pos, _ = p.getBasePositionAndOrientation(
                env.gripper.body_id, physicsClientId=client_id
            )
            pcd_dist = torch.tensor(link_points - np.array(gripper_tip_pos)).norm(
                dim=-1
            )
            grasp_point_id = int(pcd_dist.argmin())
            lev_diff = best_flows.norm(dim=-1) - link_flow[grasp_point_id].norm(
                dim=-1
            )

            # grasp_point_id indexes the link points, not the full cloud.
            gripper_movement = torch.from_numpy(
                link_points[grasp_point_id] - last_step_grasp_point
            ).norm()

            if (
                gripper_movement < no_movement_thres
                or lev_diff[0] > lev_diff_thres
            ):
                # Only regrasp if the new point's leverage is a great increase.
                env.reset_gripper()
                p.stepSimulation(client_id)  # Make sure the constraint is lifted

                rank, contact = env.teleport(
                    best_points, best_flows, video_writer=writer
                )
                best_flow = best_flows[rank]
                best_point = best_points[rank]
                best_flow_ix = int(topk_ix[rank])
                last_step_grasp_point = best_point  # Grasp a new point

                if not contact:
                    if website:
                        _record_grasp_trace(
                            animation,
                            P_world,
                            pred_flow,
                            link_ixs,
                            best_flow_ix,
                            best_point,
                            best_flow,
                        )
                    _stop_trial_recording(log_id, writer)

                    print("No contact!")
                    p.disconnect(physicsClientId=client_id)
                    animation_results = None if not website else animation.animate()
                    return animation_results, _failed_trial(assertion=True), sim_trajectory

                env.attach()
            else:
                # Keep the current grasp point but pull along the direction
                # of the strongest predicted flow.
                best_flow = best_flows[0]
                best_point = link_points[grasp_point_id]
                last_step_grasp_point = best_point
                best_flow_ix = grasp_point_id

            env.attach()
            # Perform the pulling.
            env.pull(best_flow)
            env.attach()

            if website:
                # Add pcd to flow animation
                _record_grasp_trace(
                    animation, P_world, pred_flow, link_ixs, best_flow_ix, best_point, best_flow
                )
                # Frames are only captured when recording through imageio;
                # in gui mode PyBullet records the video itself.
                if writer is not None:
                    writer.append_data(_capture_trial_frame(client_id))

            success, sim_trajectory[global_step] = env.detect_success(target_link)

            if success:
                # Hold the final value for all remaining steps.
                remaining = len(sim_trajectory) - global_step
                sim_trajectory[global_step:] = [sim_trajectory[global_step]] * remaining
                break

            pc_obs = env.render(filter_nonobj_pts=True, n_pts=n_pts)

        if success:
            break

    # calculate the metrics
    curr_pos = env.get_joint_value(target_link)
    metric = (curr_pos - init_angle) / (target_angle - init_angle)
    metric = min(max(metric, 0), 1)

    _stop_trial_recording(log_id, writer)

    p.disconnect(physicsClientId=client_id)
    animation_results = None if not website else animation.animate()
    return (
        animation_results,
        TrialResult(  # Save the flow visuals
            success=success,
            contact=True,
            assertion=True,
            init_angle=init_angle,
            final_angle=target_angle,
            now_angle=curr_pos,
            metric=metric,
        ),
        sim_trajectory,
    )


### Trial with diffuser

In [None]:
from rpad.partnet_mobility_utils.data import PMObject
from open_anything_diffusion.simulations.suction import PMSuctionSim
def trial_with_diffuser(
    obj_id="41083",
    model=None,
    n_step=30,
    gui=False,
    all_joint=False,
    website=False,
    available_joints=None,
    pm_dir=None,
    assertion_log_path=None,
):
    """Run simulated opening trials on the joints of one PartNet-Mobility object.

    For every picked joint a fresh ``PMSuctionSim`` is created, all hinge and
    slider joints of the object are reset to their lower limit, and
    ``run_trial`` is executed with ``model``.

    Args:
        obj_id: Object id; also the sub-directory name under ``pm_dir``.
        model: Model forwarded unchanged to ``run_trial``.
        n_step: Forwarded to ``run_trial`` as ``n_steps``.
        gui: Forwarded to ``PMSuctionSim`` and ``run_trial``.
        all_joint: If True, run one trial per entry of ``available_joints``;
            otherwise pick a single joint with ``np.random.randint``.
        website: Forwarded to ``run_trial``.
        available_joints: Joint names to choose from. ``None`` means every
            hinge and slider joint listed in the object's semantics.
        pm_dir: Dataset root directory. ``None`` keeps the previous default,
            ``~/datasets/partnet-mobility/convex``.
        assertion_log_path: File that receives one appended line per trial
            whose ``result.assertion`` is False. ``None`` keeps the previous
            default path.

    Returns:
        A tuple ``(figs, results, sim_trajectories)``:
        ``figs`` maps joint name to the first value returned by ``run_trial``,
        ``results`` lists the second value returned by ``run_trial``; both
        only contain trials whose ``assertion`` and ``contact`` flags are not
        False. ``sim_trajectories`` holds the third value returned by
        ``run_trial`` for every attempted joint, skipped ones included.
        All three are empty when there is no joint to choose from.
    """
    if pm_dir is None:
        # The "raw" dataset variant lives at ~/datasets/partnet-mobility/raw.
        pm_dir = "~/datasets/partnet-mobility/convex"
    pm_dir = os.path.expanduser(pm_dir)
    if assertion_log_path is None:
        assertion_log_path = (
            "/home/yishu/open_anything_diffusion/logs/assertion_failure.txt"
        )

    raw_data = PMObject(os.path.join(pm_dir, obj_id))

    # Names of all articulated joints. Computed once: the list is needed both
    # as the default candidate set and for resetting the object before each
    # trial, and it does not change between trials.
    articulated_joints = [
        joint.name
        for joint in raw_data.semantics.by_type("hinge")
        + raw_data.semantics.by_type("slider")
    ]

    if available_joints is None:  # No explicit joint set was passed in
        available_joints = list(articulated_joints)

    print("available_joints:", available_joints)
    if len(available_joints) == 0:
        # Nothing to open. Without this guard the random pick below would call
        # np.random.randint(0, 0) and raise ValueError, whereas all_joint=True
        # already yields empty results for the same input.
        return {}, [], []

    if all_joint:  # Need to traverse all the joints
        picked_joints = available_joints
    else:
        picked_joints = [available_joints[np.random.randint(0, len(available_joints))]]

    sim_trajectories = []
    results = []
    figs = {}
    for joint_name in picked_joints:
        print(f"opening {joint_name}")
        env = PMSuctionSim(obj_id, pm_dir, gui=gui)

        # Close all joints before the trial starts.
        for link_to_restore in articulated_joints:
            info = p.getJointInfo(
                env.render_env.obj_id,
                env.render_env.link_name_to_index[link_to_restore],
                env.render_env.client_id,
            )
            # Index 8 of PyBullet's getJointInfo tuple is the joint lower limit.
            lower_limit = info[8]
            env.set_joint_state(link_to_restore, lower_limit)

        fig, result, sim_trajectory = run_trial(
            env,
            raw_data,
            joint_name,
            model,
            gt_model=None,  # Don't need mask
            n_steps=n_step,
            save_name=f"{obj_id}_{joint_name}",
            website=website,
            gui=gui,
        )
        # The trajectory is recorded even when the trial is skipped below.
        sim_trajectories.append(sim_trajectory)
        if result.assertion is False:
            with open(assertion_log_path, "a") as f:
                f.write(f"Object: {obj_id}; Joint: {joint_name}\n")
            continue
        if result.contact is False:
            continue
        figs[joint_name] = fig
        results.append(result)

    return figs, results, sim_trajectories

In [None]:
from open_anything_diffusion.models.flow_diffuser_dit import (
    FlowTrajectoryDiffuserSimulationModule_DiT,
)

# Wrap the DiT network in the simulation module, move it to the GPU, load the
# checkpoint referenced by `ckpt_file`, and switch to evaluation mode.
model = FlowTrajectoryDiffuserSimulationModule_DiT(
    network,
    inference_cfg=cfg.inference,
    model_cfg=cfg.model,
)
model = model.cuda()
model.load_from_ckpt(ckpt_file)
model.eval()

# Run a single headless 5-step trial on joint "link_2" of object 102663.
# `sim_trajectory` receives the list of per-joint trajectories.
trial_figs, trial_results, sim_trajectory = trial_with_diffuser(
    obj_id="102663",
    available_joints=["link_2"],
    all_joint=False,
    model=model,
    n_step=5,
    gui=False,
    website=cfg.website,
)

In [None]:
# Display the figure recorded for link_2. The entry is absent when the trial
# was skipped (assertion or contact flag False) and is None when `website` is
# disabled, so guard both cases instead of raising KeyError / AttributeError.
link_2_fig = trial_figs.get("link_2")
if link_2_fig is None:
    print("No figure available for link_2 (trial skipped or `website` disabled).")
else:
    link_2_fig.show()