In [None]:
import os

import numpy as np
import torch
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra

# NOTE(review): the star import also supplies the simulation helpers used later
# (PMObject, trial_with_prediction, ...); consider importing them by name.
from open_anything_diffusion.simulations.simulation import *
from open_anything_diffusion.models.flow_diffuser_dit import (
    FlowTrajectoryDiffuserSimulationModule_DiT,
)
from open_anything_diffusion.models.flow_diffuser_pndit import (
    FlowTrajectoryDiffuserSimulationModule_PNDiT,
)
from open_anything_diffusion.models.modules.history_encoder import HistoryEncoder
from open_anything_diffusion.models.modules.dit_models import DiT, PN2DiT, PN2HisDiT

# Reproducibility and high-precision printing for numpy / torch outputs.
np.random.seed(42)
torch.manual_seed(42)
torch.set_printoptions(precision=10)
np.set_printoptions(precision=10)

# Hydra keeps a process-wide singleton; clear it first so re-running this cell
# does not fail with "GlobalHydra is already initialized".
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
initialize(config_path="../../configs", version_base="1.3")
cfg = compose(config_name="eval_sim_switch")  # alternative config: "eval_sim"

## Load the HisPNDiT model (PN2HisDiT + history encoder)

In [None]:
from open_anything_diffusion.models.flow_diffuser_hispndit import (
    FlowTrajectoryDiffuserSimulationModule_HisPNDiT,
)

# History-conditioned diffusion model: a PN2HisDiT denoiser plus a history encoder.
# Distinct names (not `network` / `ckpt_file`) so this cell cannot shadow the
# FlowBot network that the simulation cell passes to `trial_with_prediction`.
# NOTE(review): these hyperparameters presumably must match the ones the checkpoint
# was trained with (an earlier variant used depth=8, hidden_size=256, num_heads=4).
hispndit_networks = {
    "DiT": PN2HisDiT(
        history_embed_dim=128,
        in_channels=3,
        depth=5,
        hidden_size=128,
        num_heads=4,
        learn_sigma=True,
    ).cuda(),
    "History": HistoryEncoder(
        history_dim=128,
        history_len=1,
        batch_norm=True,
        transformer=False,
        repeat_dim=False,
    ).cuda(),
}

hispndit_ckpt_file = "/home/yishu/open_anything_diffusion/pretrained/fullset_half_half_hispndit.ckpt"
switch_model = FlowTrajectoryDiffuserSimulationModule_HisPNDiT(
    hispndit_networks, inference_cfg=cfg.inference, model_cfg=cfg.model
).cuda()
switch_model.load_from_ckpt(hispndit_ckpt_file)
switch_model.eval();  # trailing ';' suppresses the (very long) module repr

## Load the FlowBot baseline (PointNet++)

In [None]:
import rpad.pyg.nets.pointnet2 as pnp

# FlowBot baseline: a dense PointNet++ with 3 output channels per point
# (presumably the per-point flow vector).
network = pnp.PN2Dense(
    in_channels=0,
    out_channels=3,
    p=pnp.PN2DenseParams(),
)

ckpt_file = "/home/yishu/open_anything_diffusion/pretrained/fullset_half_half_flowbotRO.ckpt"

# The network stays on the CPU, so map checkpoint tensors there too; this also
# lets the cell run on machines without a GPU.
ckpt = torch.load(ckpt_file, map_location="cpu")
# Checkpoint keys carry a wrapper prefix ("<prefix>.<param>"); drop everything up
# to the first "." so the keys match the bare network's state_dict.
network.load_state_dict(
    {k.partition(".")[2]: v for k, v in ckpt["state_dict"].items()}
)
network.eval();  # trailing ';' suppresses the module repr

## Run a simulation trial

In [None]:
# Which PartNet-Mobility object / joint index to simulate, and the dataset root.
obj_id, joint_id = "12484", 0
pm_dir = os.path.expanduser("~/datasets/partnet-mobility/convex")

In [None]:
# Resolve the joint to actuate: the object's hinge joints followed by its sliders.
raw_data = PMObject(os.path.join(pm_dir, obj_id))
joint_candidates = raw_data.semantics.by_type("hinge") + raw_data.semantics.by_type(
    "slider"
)
available_joints = [joint.name for joint in joint_candidates]
try:
    target_link = available_joints[joint_id]
except IndexError:
    raise IndexError(
        f"joint_id={joint_id} is out of range: object {obj_id} has "
        f"{len(available_joints)} hinge/slider joint(s): {available_joints}"
    ) from None
print(target_link)

# Set to True to run the HisPNDiT switch model (history-aware diffuser) instead of
# the FlowBot baseline. Keep the plot cell's method label in sync with this choice.
USE_HISTORY_MODEL = False
n_steps = 30

if USE_HISTORY_MODEL:
    trial_figs, trial_results, all_signals = trial_with_diffuser_history(
        obj_id=obj_id,
        model=switch_model,
        history_model=switch_model,
        n_step=n_steps,
        gui=False,
        website=True,
        all_joint=False,
        available_joints=[target_link],
        consistency_check=True,
        history_filter=True,
        analysis=True,
    )
else:
    trial_figs, trial_results, all_signals = trial_with_prediction(
        obj_id=obj_id,
        network=network,
        n_step=n_steps,
        gui=False,
        all_joint=False,
        available_joints=[target_link],
        website=True,
        sgp=False,
        analysis=True,
    )

# Only one joint was simulated, so unpack the analysis signals of the first entry.
(
    sim_trajectory,
    update_history_signals,
    cc_cnts,
    sgp_signals,
    visual_all_points,
    visual_link_ixs,
    visual_grasp_points_idx,
    visual_grasp_points,
    visual_flows,
) = all_signals[0]

In [None]:
import os

import matplotlib.pyplot as plt

# Policy that produced `sim_trajectory`. The simulation cell runs
# `trial_with_prediction` (FlowBot) by default; update both names if that changes.
method_label = "FlowBot"
method_tag = "flowbot"

# A single explicit figure: the title and size must be set on the figure that is
# actually drawn on and saved.
fig, ax1 = plt.subplots(figsize=(10, 6))
ax1.set_title(f"{method_label} - {obj_id} {target_link}")

# One x position per entry of the recorded open-ratio trajectory.
steps = list(range(len(sim_trajectory)))
open_ratios = sim_trajectory

# Drawn segment by segment so individual steps can be recoloured; currently all black.
segment_colors = ["black"] * max(len(steps) - 1, 0)
for i in range(len(steps) - 1):
    ax1.plot(steps[i : i + 2], open_ratios[i : i + 2], color=segment_colors[i], alpha=0.6)

ax1.set_xlabel("Step")
ax1.set_yticks(np.linspace(0, 1, 11))
ax1.set_ylabel("Open ratio")

# Mark steps whose `sgp_signals` / `update_history_signals` entry is truthy. Each
# label goes on its first marker only, so the legend lists it once.
sgp_label_added = False
history_label_added = False
for i in range(len(update_history_signals)):
    if sgp_signals[i]:
        ax1.plot(
            steps[i], open_ratios[i], marker="^", color="yellow", markersize=15, alpha=0.8,
            label="_nolegend_" if sgp_label_added else "SGP",
        )
        sgp_label_added = True
    if update_history_signals[i]:
        ax1.plot(
            steps[i], open_ratios[i], marker="*", color="red", markersize=10, alpha=0.6,
            label="_nolegend_" if history_label_added else "History",
        )
        history_label_added = True

# Avoid matplotlib's "No artists with labels found" warning and empty legend box.
if sgp_label_added or history_label_added:
    ax1.legend()

# Bar height per step: 1 at step 0, otherwise cc_cnts[i] + 1.
trial_counts = [0] * len(steps)
for i in range(1, len(cc_cnts)):
    trial_counts[i] = cc_cnts[i] + 1
trial_counts[0] = 1

ax2 = ax1.twinx()
ax2.bar(steps, trial_counts, color="blue", alpha=0.2)
ax2.set_ylabel("Trial counts", color="purple")
ax2.tick_params(axis="y", labelcolor="purple")

os.makedirs("./traj_visuals", exist_ok=True)
fig.savefig(f"./traj_visuals/{obj_id}_{target_link}_{method_tag}.jpg")

In [None]:
# Render the interactive figure recorded for the simulated link.
link_fig = trial_figs[target_link]
link_fig.show()

### Grasp point & flow visualization (animated)

In [None]:
from flowbot3d.grasping.agents.flowbot3d import FlowNetAnimation
from open_anything_diffusion.metrics.trajectory import normalize_trajectory

animation = FlowNetAnimation()

# One animation trace per recorded step: the point cloud plus a flow field that is
# zero everywhere except at the chosen grasp point on the target link.
for P_world, link_ixs, grasp_point_idx, grasp_point, flow in zip(
    visual_all_points,
    visual_link_ixs,
    visual_grasp_points_idx,
    visual_grasp_points,
    visual_flows,
):
    # Sanity check: the indexed point should coincide with the recorded grasp point.
    print(P_world[link_ixs][grasp_point_idx], grasp_point)

    # Scatter the single grasp-point flow back into a full-cloud flow array.
    segmented_flow = np.zeros_like(P_world[link_ixs])
    segmented_flow[grasp_point_idx] = flow
    full_flow = np.zeros_like(P_world)
    full_flow[link_ixs] = segmented_flow

    # normalize_trajectory is fed an (N, 1, 3) array here; drop only that added axis
    # (a bare squeeze() would also collapse N when the cloud has a single point).
    full_flow = np.array(
        normalize_trajectory(torch.from_numpy(np.expand_dims(full_flow, 1))).squeeze(1)
    )

    # One colour per point, sized from the actual cloud rather than a fixed 1200.
    point_colors = ["red"] * len(P_world)
    animation.add_trace(
        torch.as_tensor(P_world),
        torch.as_tensor(P_world).unsqueeze(0),
        torch.as_tensor(full_flow * 3).unsqueeze(0),  # scaled x3, presumably for visibility
        point_colors,
    )

animation.animate()

### Save visualization data, then plot a static grasp point & flow view

In [None]:
import pickle

# Persist the trial figure plus the per-step grasp visualisation data.
# The FlowNetAnimation object itself is intentionally left out.
with open(f"data_flowbot3D_{obj_id}_{target_link.replace('_', '')}.pkl", "wb") as out_file:
    payload = {
        "flow_visual": trial_figs[target_link],
        "visual_points": visual_all_points,
        "visual_link_ixs": visual_link_ixs,
        "visual_grasp_points_idx": visual_grasp_points_idx,
        "visual_grasp_points": visual_grasp_points,
        "visual_flow": visual_flows,
    }
    pickle.dump(payload, out_file)

In [None]:
import numpy as np
import plotly.graph_objects as go
import importlib
import plotly.io as pio

# Static 3-D view of one recorded step: the point cloud (target-link points red,
# the rest blue) and the grasp flow drawn as a green arrow with a cone head.
step = 0
points = visual_all_points[step]
link_ixs = visual_link_ixs[step]
grasp_point_idx = visual_grasp_points_idx[step]
grasp_point = visual_grasp_points[step]
flow = visual_flows[step]
num_points = len(points)  # derived from the cloud instead of a hard-coded 1200

# object dtype so colour names of any length can be assigned; a fixed-width '<U4'
# array would silently truncate e.g. 'green' to 'gree'.
colors = np.full(num_points, "blue", dtype=object)
colors[link_ixs] = "red"

# Grasp arrow: unit flow direction scaled to a fixed display length. A zero flow
# yields a zero-length arrow instead of NaN coordinates.
arrow_length = 0.3
flow_vec = np.asarray(flow, dtype=float)
flow_norm = np.linalg.norm(flow_vec)
if flow_norm > 0:
    arrow_direction = flow_vec / flow_norm * arrow_length
else:
    arrow_direction = np.zeros_like(flow_vec)
arrow_origin = np.array(grasp_point)
arrow_end = arrow_origin + arrow_direction

scatter = go.Scatter3d(
    x=points[:, 0], y=points[:, 1], z=points[:, 2],
    mode="markers",
    marker=dict(size=3, color=colors, opacity=0.2),
)

arrow = go.Scatter3d(
    x=[arrow_origin[0], arrow_end[0]],
    y=[arrow_origin[1], arrow_end[1]],
    z=[arrow_origin[2], arrow_end[2]],
    mode="lines+markers",
    line=dict(color="green", width=10),
    marker=dict(size=3, color="black"),
)

# Arrow head at the tip, pointing along the arrow (cone vector = 0.7 * direction).
cone_vec = arrow_direction * 0.7
cone = go.Cone(
    x=[arrow_end[0]], y=[arrow_end[1]], z=[arrow_end[2]],
    u=[cone_vec[0]], v=[cone_vec[1]], w=[cone_vec[2]],
    colorscale="Greens", showscale=False,
)

# Separate name so the matplotlib `fig` from the trajectory plot is not clobbered.
grasp_fig = go.Figure(data=[scatter, arrow, cone])

camera = dict(
    up=dict(x=0, y=0, z=0.5),
    center=dict(x=0, y=0.05, z=0.05),
    eye=dict(x=1.32, y=0.48, z=0.84),
)

grasp_fig.update_layout(
    scene=dict(
        xaxis=dict(visible=False, range=[-0.5, 0.5]),
        yaxis=dict(visible=False),
        zaxis=dict(visible=False),
        xaxis_showspikes=False,
        yaxis_showspikes=False,
        zaxis_showspikes=False,
        bgcolor="rgba(255,255,255,1)",
    ),
    width=500,
    height=500,
    showlegend=False,
    coloraxis_showscale=False,
    margin=dict(l=0, r=0, b=0, t=0),
    paper_bgcolor="rgba(0,0,0,0)",
    scene_camera=camera,
)

grasp_fig.show()
# To export a still: pio.write_image(grasp_fig, f"{step}.jpeg", format="jpeg")


In [None]:
# Debug aid: display the raw go.Scatter3d point-cloud trace built in the static plot cell.
scatter

In [None]:
# Endpoints of the grasp arrow from the static plot, separated by a space.
print(" ".join(str(endpoint) for endpoint in (arrow_origin, arrow_end)))

## Read saved simulation results to generate demos

In [None]:
import json

# Per-instance results written by a `sim_demo_trajectory_pn++` run.
instance_result_path = "/home/yishu/open_anything_diffusion/logs/sim_demo_trajectory_pn++/2024-06-22/11-37-33/logs/instance_result.json"
with open(instance_result_path, "r") as result_file:
    data = json.load(result_file)

In [None]:
# Display the parsed instance-result JSON in full (may be large).
data

In [None]:
import pickle as pkl

# NOTE: pickle.load can execute arbitrary code; only load files produced by a trusted run.
# This rebinds `animation`, shadowing the FlowNetAnimation built earlier.
flow_vis_path = "/home/yishu/open_anything_diffusion/logs/sim_demo_trajectory_pn++/2024-06-22/11-37-33/logs/flow_vis/8867_link_1.pkl"
with open(flow_vis_path, "rb") as pkl_file:
    animation = pkl.load(pkl_file)

In [None]:
# Show the visualization object unpickled from the flow_vis file.
# NOTE(review): assumes the pickled object exposes .show() (e.g. a plotly Figure); confirm.
animation.show()