## Load the model

In [None]:
from open_anything_diffusion.models.modules.dit_models import DGDiT, DiT
from open_anything_diffusion.models.flow_diffuser_dit import FlowTrajectoryDiffuserInferenceModule_DiT
from open_anything_diffusion.models.flow_diffuser_dgdit import FlowTrajectoryDiffuserInferenceModule_DGDiT
# Inference-module wrapper class for each supported model type.
inference_module_class = dict(
    dit=FlowTrajectoryDiffuserInferenceModule_DiT,
    dgdit=FlowTrajectoryDiffuserInferenceModule_DGDiT,
)
# Backbone network for each model type. Both are instantiated eagerly here,
# DiT first and then DGDiT; the hyperparameters presumably match the
# pretrained checkpoints (confirm against the training config).
networks = dict(
    dit=DiT(in_channels=6, depth=5, hidden_size=128, num_heads=4, learn_sigma=True),
    dgdit=DGDiT(in_channels=3, depth=5, hidden_size=128, patch_size=1, num_heads=4, n_points=1200),
)

In [None]:
class InferenceConfig:
    """Inference settings passed to the inference module as ``inference_cfg``.

    Args:
        batch_size: Stored as ``self.batch_size``; defaults to 1.
        trajectory_len: Stored as ``self.trajectory_len``; defaults to 1.
            Presumably the number of predicted flow-trajectory steps —
            confirm against the inference module.
    """

    def __init__(self, batch_size=1, trajectory_len=1):
        self.batch_size = batch_size
        self.trajectory_len = trajectory_len


inference_config = InferenceConfig()

class ModelConfig:
    """Model settings passed to the inference module as ``model_cfg``.

    Args:
        num_train_timesteps: Stored as ``self.num_train_timesteps``; defaults
            to 100. Presumably the diffusion step count the checkpoint was
            trained with — confirm against the training config.
    """

    def __init__(self, num_train_timesteps=100):
        self.num_train_timesteps = num_train_timesteps


model_config = ModelConfig()

In [None]:
import os
# Checkpoints are stored as "<train_type>_<model_type>.ckpt" under ckpt_dir.
ckpt_dir = './pretrained'
# Dataset the model was trained on: 'door_half_half' or 'fullset_half_half'.
train_type = 'fullset_half_half'
# Network architecture: 'dit' or 'dgdit'.
model_type = 'dit'
ckpt_name = '{}_{}.ckpt'.format(train_type, model_type)
ckpt_path = os.path.join(ckpt_dir, ckpt_name)

In [None]:
# Wrap the selected backbone (moved to the GPU) in its inference module,
# move the wrapper to the GPU, load the pretrained weights and switch to
# eval mode.
wrapper_cls = inference_module_class[model_type]
backbone = networks[model_type].cuda()
model = wrapper_cls(backbone, inference_cfg=inference_config, model_cfg=model_config)
model = model.cuda()
model.load_from_ckpt(ckpt_path)
model.eval()

## Make a prediction

Read the point cloud

In [None]:
import numpy as np
# Directory of captured point clouds (machine-specific absolute path).
pcd_dir = '/home/yishu/Azure_Kinect_ROS_Driver/src/pc_data_for_yishu'
# os.listdir returns entries in an arbitrary, filesystem-dependent order.
# Sort them so that indexing into pcd_paths selects the same file every run.
pcd_paths = [os.path.join(pcd_dir, pcd_name) for pcd_name in sorted(os.listdir(pcd_dir))]

In [None]:
# Index into pcd_paths of the cloud to load. It is currently unused because
# the explicit path below is loaded instead. It is named pcd_id so that it
# does not shadow the builtin id().
pcd_id = 2
# path = pcd_paths[pcd_id]
path = '/home/yishu/Azure_Kinect_ROS_Driver/src/pc_data_for_yishu/fridge_L_open_fully.npy'
print(path)
pcd = np.load(path)

Downsample it to 1200 points with farthest point sampling.

In [None]:
# pytorch3d provides farthest point sampling, but it has a CUDA conflict with the current environment, so a small NumPy implementation is used instead.
import numpy as np

def farthest_point_sampling(points, k, rng=None):
    """Greedily select ``k`` rows of ``points`` that are spread far apart.

    The first point is chosen at random. Each following step adds the point
    whose Euclidean distance to its nearest already-chosen point is largest.

    Args:
        points: Array-like of shape (num_points, dim).
        k: Number of points to return. If ``k`` exceeds the number of
            distinct points, the result contains repeated points.
        rng: Optional ``numpy.random.Generator`` used to draw the starting
            point, for reproducible sampling. When omitted, the global NumPy
            RNG (``np.random.randint``) is used, as before.

    Returns:
        Array of shape (k, dim) holding the selected rows of ``points``.

    Raises:
        ValueError: If ``k`` is negative, or if ``points`` is empty while
            ``k > 0``.
    """
    points = np.asarray(points)
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if k == 0:
        return points[:0]
    num_points = points.shape[0]
    if num_points == 0:
        raise ValueError("cannot sample from an empty point cloud")

    chosen_indices = np.empty(k, dtype=np.intp)
    if rng is None:
        chosen_indices[0] = np.random.randint(num_points)
    else:
        chosen_indices[0] = rng.integers(num_points)

    # distances[j] is the distance from point j to its nearest chosen point.
    distances = np.full(num_points, np.inf)
    for i in range(1, k):
        dist = np.linalg.norm(points - points[chosen_indices[i - 1]], axis=1)
        np.minimum(distances, dist, out=distances)
        chosen_indices[i] = np.argmax(distances)

    return points[chosen_indices]

# Downsample the loaded cloud to 1200 points and show the resulting shape.
sampled_points = farthest_point_sampling(pcd, k=1200)
print(sampled_points.shape)

In [None]:
import torch
# Run the model on the sampled cloud. The result is assumed to be
# (num_points, trajectory_len, 3) — TODO confirm. Keep only step 0.
pred_trajectory = model.predict(sampled_points)
pred_flow = pred_trajectory[:, 0, :]

## Visualize the prediction

In [None]:
import torch
import numpy as np
from flowbot3d.grasping.agents.flowbot3d import FlowNetAnimation
# Plot the downsampled cloud together with the predicted flow.
animation = FlowNetAnimation()
animation.add_trace(
    torch.as_tensor(sampled_points),
    # Add the leading axis with unsqueeze(0). Wrapping an ndarray in a Python
    # list before torch.as_tensor takes PyTorch's slow list-of-ndarrays path
    # and emits a UserWarning.
    torch.as_tensor(sampled_points).unsqueeze(0),
    pred_flow.detach().cpu().unsqueeze(0),
    "red",
)
fig = animation.animate()
fig.show()

In [None]:
import torch
import numpy as np
from flowbot3d.grasping.agents.flowbot3d import FlowNetAnimation
# Plot the full raw cloud with zero flow, for comparison with the prediction.
animation = FlowNetAnimation()
animation.add_trace(
    torch.as_tensor(pcd),
    # unsqueeze(0) adds the leading axis without going through PyTorch's slow
    # list-of-ndarrays conversion (which also emits a UserWarning).
    torch.as_tensor(pcd).unsqueeze(0),
    torch.as_tensor(np.zeros_like(pcd)).unsqueeze(0),
    "red",
)
fig = animation.animate()
fig.show()

## About the policy

In [None]:
import torch


# Prototype policy heuristic: decide whether to re-grasp at a different point.
def switch_grasp_point(
    last_gripper_pos,
    current_gripper_pos,
    flow_prediction,
    current_pcd,
    move_threshold=0.01,
    leverage_threshold=0.2,
):
    """Return True if the gripper should switch to a new grasp point.

    A switch is requested when either:

    * the gripper moved less than ``move_threshold`` (Euclidean distance)
      between ``last_gripper_pos`` and ``current_gripper_pos``, or
    * the largest predicted flow magnitude in the cloud exceeds the flow
      magnitude at the current grasp point by more than ``leverage_threshold``.

    The current grasp point is the point of ``current_pcd`` closest to
    ``current_gripper_pos``.

    Args:
        last_gripper_pos: Gripper position at the previous step (tensor,
            ndarray or sequence with the same length as a row of
            ``current_pcd``).
        current_gripper_pos: Gripper position at the current step.
        flow_prediction: Per-point flow of shape (N, 3). Row i is assumed to
            belong to ``current_pcd[i]`` — confirm with the caller.
        current_pcd: Point cloud of shape (N, D).
        move_threshold: Minimum movement that counts as progress.
            Defaults to the previous hard-coded 0.01.
        leverage_threshold: Flow-magnitude gain that justifies switching.
            Defaults to the previous hard-coded 0.2.

    Returns:
        A Python bool.

    Raises:
        ValueError: If ``current_pcd`` contains no points.
    """
    pcd = torch.as_tensor(current_pcd, dtype=torch.float32)
    if pcd.shape[0] == 0:
        raise ValueError("current_pcd must contain at least one point")
    gripper = torch.as_tensor(current_gripper_pos, dtype=torch.float32, device=pcd.device)
    last_gripper = torch.as_tensor(last_gripper_pos, dtype=torch.float32, device=pcd.device)
    flow = torch.as_tensor(flow_prediction).to(torch.float32)

    # 1 - find the point in current_pcd closest to the gripper.
    grasp_point_id = int(torch.linalg.norm(pcd - gripper, dim=-1).argmin())

    # 2 - compare the flow at the grasp point with the maximum predicted flow.
    flow_norms = torch.linalg.norm(flow, dim=-1)
    leverage_increase = float(flow_norms.max() - flow_norms[grasp_point_id])

    # Use the distance moved rather than the signed difference, so that large
    # movements in either direction count as progress.
    moved = float(torch.linalg.norm(last_gripper - gripper))
    return moved < move_threshold or leverage_increase > leverage_threshold