In [None]:
import rpad.partnet_mobility_utils.dataset as rpd
# (object id, object class) pairs for the UMPNet test split.
all_objs = rpd.UMPNET_TEST_OBJS
# Lookup table: object id -> object category name (e.g. "Door").
id_to_obj_class = {oid: cls_name for oid, cls_name in all_objs}

In [None]:
# Distinct object categories present in the test split.
{*id_to_obj_class.values()}

In [None]:
from hydra import compose, initialize
from omegaconf import OmegaConf
# Compose the "train" config from ../configs. Using initialize() as a context manager
# resets Hydra's global state on exit, so re-running this cell does not raise because
# Hydra was already initialized by a previous run.
with initialize(config_path="../configs", version_base="1.3"):
    cfg = compose(config_name="train")

In [None]:
import torch_geometric.loader as tgl
from open_anything_diffusion.datasets.flow_trajectory_dataset_pyg import FlowTrajectoryPyGDataset
import os

# Root of the raw PartNet-Mobility data. Set PARTNET_MOBILITY_ROOT to run on another
# machine; the default is the original author's local path.
DATASET_ROOT = os.environ.get(
    "PARTNET_MOBILITY_ROOT", "/home/yishu/datasets/partnet-mobility/raw"
)

datamodule = FlowTrajectoryPyGDataset(
    root=DATASET_ROOT,
    split="umpnet-test",
    randomize_joints=True,
    randomize_camera=True,
    seed=42,
    trajectory_len=cfg.training.trajectory_len,  # Only used when training trajectory model
)
# Batch size 1, fixed order, loaded in the main process.
val_dataloader = tgl.DataLoader(datamodule, 1, shuffle=False, num_workers=0)

# Materialize every (index, batch) pair once so later cells can filter without reloading.
samples = list(enumerate(val_dataloader))

In [None]:
import tqdm
# Keep only the batches whose object category is "Door".
door_samples = []
for _, batch in tqdm.tqdm(samples):
    # Each loader item is a batch of size 1, so id[0] is the object's id.
    if id_to_obj_class[batch.id[0]] == "Door":
        door_samples.append(batch)
door_cnt = len(door_samples)

In [None]:
# Number of Door-category samples found in the test split.
door_cnt

### Diffusion visualization

In [None]:
import torch
from flowbot3d.grasping.agents.flowbot3d import FlowNetAnimation
from open_anything_diffusion.metrics.trajectory import artflownet_loss, flow_metrics, normalize_trajectory

@torch.no_grad()
def diffuse_visual(batch, model, points_per_sample=1200, trace_every=50):
    """Run the full reverse-diffusion loop on one batch and record intermediate trajectories.

    Starting from Gaussian noise shaped like ``batch.delta``, the model is queried at every
    timestep in ``model.noise_scheduler.timesteps`` and the scheduler's ``step`` produces the
    next, less noisy trajectory. When ``t % trace_every == 0`` the current (normalized)
    trajectory is added as a trace to a ``FlowNetAnimation``.

    Side effects: puts ``model`` in eval mode and sets ``batch.traj_noise`` and
    ``batch.timesteps`` on the given batch.

    Args:
        batch: PyG batch exposing ``pos``, ``delta`` and ``mask``. Callers pass a single
            sample; behaviour for multi-sample batches is untested here (TODO confirm).
        model: diffusion module exposing ``noise_scheduler``, ``sample_size`` and ``device``.
        points_per_sample: points per object, used to derive the batch size from
            ``batch.delta.shape[0]`` (default 1200, the value previously hard-coded).
        trace_every: record an animation trace when ``t % trace_every == 0``.

    Returns:
        ``(cos_dist, animation)``: the cosine metric returned by ``flow_metrics`` on the
        flow-only (masked) points, and the populated ``FlowNetAnimation``.
    """
    model.eval()

    animation = FlowNetAnimation()
    pcd = batch.pos.cpu().numpy()

    bs = batch.delta.shape[0] // points_per_sample
    # Start the reverse process from fresh noise with the same shape as the target trajectories.
    batch.traj_noise = torch.randn_like(batch.delta, device="cuda")

    for t in model.noise_scheduler.timesteps:
        # Uniform timestep for every sample in the batch.
        batch.timesteps = (torch.zeros(bs, device=model.device) + t).long()

        model_output = model(batch)  # bs * points, traj_len * 3
        model_output = model_output.reshape(model_output.shape[0], -1, 3)  # bs * points, traj_len, 3

        # The scheduler steps on per-sample tensors: (-1, sample_size, traj_len, 3).
        step_shape = (-1, model.sample_size, model_output.shape[1], model_output.shape[2])
        batch.traj_noise = model.noise_scheduler.step(
            model_output.reshape(step_shape),
            t,
            batch.traj_noise.reshape(step_shape),
        ).prev_sample
        # Back to the flat (points, traj_len, 3) layout the model consumes.
        batch.traj_noise = torch.flatten(batch.traj_noise, start_dim=0, end_dim=1)

        if t % trace_every == 0:
            animation.add_trace(
                torch.as_tensor(pcd),
                torch.as_tensor([pcd]),
                torch.as_tensor([normalize_trajectory(batch.traj_noise).squeeze().cpu().numpy()]),
                "red",
            )

    # Compare normalized prediction against the equally normalized ground truth so that
    # rmse / magnitude error are on the same scale (the cosine term is scale-free).
    f_pred = normalize_trajectory(batch.traj_noise)
    f_target = normalize_trajectory(batch.delta).float()
    f_ix = batch.mask.bool()

    # Metrics on flow-only regions.
    rmse, cos_dist, mag_error = flow_metrics(f_pred[f_ix], f_target[f_ix])

    return cos_dist, animation

### Model

In [None]:
import rpad.pyg.nets.pointnet2 as pnp
from open_anything_diffusion.models.flow_trajectory_diffuser import (
    FlowTrajectoryDiffusionModule,
)
# NOTE(review): absolute, machine-specific checkpoint path — adjust for your environment.
ckpt_path = "/home/yishu/open_anything_diffusion/logs/train_trajectory/2023-08-31/16-13-10/checkpoints/epoch=394-step=310470-val_loss=0.00-weights-only.ckpt"
# Channel counts presumably must match the architecture stored in the checkpoint,
# otherwise load_state_dict below will reject the weights — TODO confirm origin of 67.
network = pnp.PN2Dense(
    in_channels=67,
    out_channels=3,
    p=pnp.PN2DenseParams(),
)

model = FlowTrajectoryDiffusionModule(network, cfg.training, cfg.model)
# Deserialize onto the CPU first so loading does not depend on which GPU the checkpoint
# was saved from; the module is moved to CUDA afterwards.
# torch.load unpickles the file — only load checkpoints from trusted sources.
ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt["state_dict"])
model = model.cuda()

In [None]:
import tqdm
import math
# Diffuse each evaluated sample N_REPEATS times from fresh noise and keep the runs with
# the highest and lowest cosine metric returned by diffuse_visual.
N_REPEATS = 10
# NOTE(review): only the second Door sample is evaluated; use `door_samples` for the full set.
eval_samples = door_samples[1:2]

best_animations = []
best_cos_dists = []
worst_animations = []
worst_cos_dists = []
mean_cos_dist = 0
for sample in tqdm.tqdm(eval_samples):
    best_cos = best_cos_reverse = None
    best_animation = worst_animation = None
    for _ in range(N_REPEATS):
        cos_dist, animation = diffuse_visual(sample.cuda(), model)
        # Seed both extremes from the first run so neither animation can stay unbound.
        if best_cos is None or cos_dist > best_cos:
            best_cos, best_animation = cos_dist, animation
        if best_cos_reverse is None or cos_dist < best_cos_reverse:
            best_cos_reverse, worst_animation = cos_dist, animation

    best_animations.append(best_animation)
    best_cos_dists.append(best_cos)
    worst_animations.append(worst_animation)
    worst_cos_dists.append(best_cos_reverse)
    mean_cos_dist += best_cos
# Average over the samples actually evaluated (guard against an empty selection).
mean_cos_dist /= max(len(eval_samples), 1)

In [None]:
# Best-run cosine metric summed over evaluated samples and divided by the sample-count
# divisor used in the evaluation loop — check that divisor matches the samples evaluated.
mean_cos_dist

In [None]:
# One line per evaluated sample: best and worst cosine metric across its repeats.
# zip() follows however many samples were actually evaluated (no hard-coded count).
for best, worst in zip(best_cos_dists, worst_cos_dists):
    print(best, worst)

In [None]:
# Animation of the best-scoring run for the first evaluated sample.
# (`animation` itself is a single FlowNetAnimation, so it cannot be indexed.)
fig = best_animations[0].animate()
fig.show()

In [None]:
# Animate `animation`, i.e. the run most recently produced by the evaluation loop
# (last repeat of the last sample) — not necessarily the best or worst run.
# NOTE(review): use best_animations[i] / worst_animations[i] to view a specific run.
fig = animation.animate()
fig.show()