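## Flowbot inference

Load a trained flow-trajectory checkpoint and report the mean flow metrics (RMSE, cosine distance, magnitude error, flow loss) over the cached train-val samples. Then visualize the predicted flow on a few individual samples.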

In [None]:
import os

# Checkpoint to evaluate. Set FLOWBOT_CKPT_PATH to override without editing the notebook;
# otherwise the original training-run checkpoint is used.
ckpt_path = os.environ.get(
    "FLOWBOT_CKPT_PATH",
    "/home/yishu/open_anything_diffusion/logs/train_trajectory/2023-11-15/23-59-12/checkpoints/epoch=199-step=9200.ckpt",
)

In [None]:
from open_anything_diffusion.models.flow_trajectory_predictor import (
    FlowTrajectoryTrainingModule
)

In [None]:
from hydra import compose, initialize
from omegaconf import OmegaConf
# Compose the training config. Using `initialize` as a context manager tears down the
# global Hydra state on exit, so re-running this cell does not fail with
# "GlobalHydra is already initialized".
with initialize(config_path="../configs", version_base="1.3"):
    cfg = compose(config_name="train")

In [None]:
# Render the composed config as YAML; the default DictConfig repr is a single long line.
print(OmegaConf.to_yaml(cfg))

In [None]:
import rpad.pyg.nets.pointnet2 as pnp
# PointNet++ dense (per-point) network with default hyper-parameters. It takes no extra
# per-point input features (in_channels=0, presumably xyz only) and emits a 3-D vector
# per point.
# NOTE(review): out_channels=3 implies a single-step prediction; confirm this agrees with
# cfg.training.trajectory_len for the checkpoint being loaded.
network = pnp.PN2Dense(in_channels=0, out_channels=3, p=pnp.PN2DenseParams())

# Wrap the backbone in the training module configured from cfg.training.
model = FlowTrajectoryTrainingModule(network, cfg.training)

In [None]:
import torch
# Load on CPU so a checkpoint saved from a GPU run restores regardless of which devices
# exist here; the model is moved to CUDA explicitly in later cells.
ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt["state_dict"])

In [None]:
import torch_geometric.loader as tgl
from open_anything_diffusion.datasets.flow_trajectory_dataset_pyg import FlowTrajectoryPyGDataset
from open_anything_diffusion.datasets.flow_trajectory import FlowTrajectoryDataModule
import os

# Door-only split of PartNet-Mobility object ids used for this evaluation.
# A smaller "door-1" debug split (train: 8994, 9035; test: 8867) was used previously.
DOOR_SPLIT = {
    "id": "door-full-new",
    "train-train": ["8877", "8893", "8897", "8903", "8919", "8930", "8961", "8997", "9016", "9032", "9035", "9041", "9065", "9070", "9107", "9117", "9127", "9128", "9148", "9164", "9168", "9277", "9280", "9281", "9288", "9386", "9388", "9410"],
    "train-test": ["8867", "8983", "8994", "9003", "9263", "9393"],
    "test": ["8867", "8983", "8994", "9003", "9263", "9393"],
}

datamodule = FlowTrajectoryDataModule(
    root="/home/yishu/datasets/partnet-mobility",
    batch_size=1,
    # Do not request more loader workers than this machine has CPUs
    # (presumably forwarded to the torch DataLoader, which warns otherwise).
    num_workers=min(30, os.cpu_count() or 1),
    n_proc=2,
    seed=42,
    trajectory_len=cfg.training.trajectory_len,  # Only used when training trajectory model
    toy_dataset=DOOR_SPLIT,
)

train_val_dataloader = datamodule.train_val_dataloader()
val_dataloader = datamodule.val_dataloader()

# Materialize every (index, batch) pair once so later cells can index samples directly.
samples = list(enumerate(train_val_dataloader))

In [None]:
from open_anything_diffusion.metrics.trajectory import artflownet_loss, flow_metrics, normalize_trajectory
from flowbot3d.grasping.agents.flowbot3d import FlowNetAnimation
import numpy as np

@torch.no_grad()
def flowbot_visual(batch, model):  # 1 sample batch
    """Run ``model`` on a single-object batch and score its flow prediction.

    Args:
        batch: A PyG batch holding one object (callers use batch_size=1). Reads the
            ``pos``, ``mask`` and ``delta`` fields.
        model: Flow predictor. It is switched to eval mode as a side effect.

    Returns:
        ``(rmse, cos_dist, mag_error, loss, animation)``:
        - The three metrics from ``flow_metrics``, computed on masked points only.
        - ``loss`` from ``artflownet_loss`` over all points.
        - ``animation``, a ``FlowNetAnimation`` with the predicted flow drawn in red.
        Prediction and target are both passed through ``normalize_trajectory`` first.
    """
    model.eval()

    pcd = batch.pos.cpu().numpy()
    # The network emits one 3-D vector per point. Insert a singleton step axis before
    # normalizing (presumably (N, 1, 3) == trajectory_len 1; confirm against the data).
    f_pred = normalize_trajectory(model(batch)[:, None, :])

    animation = FlowNetAnimation()
    # Build the batched (1, N, 3) tensors directly. Wrapping numpy arrays in a Python
    # list before torch.as_tensor is very slow and triggers a torch warning.
    animation.add_trace(
        torch.as_tensor(pcd),
        torch.as_tensor(pcd).unsqueeze(0),
        f_pred.squeeze().cpu().unsqueeze(0),
        "red",
    )

    # Loss against the normalized ground-truth trajectory.
    n_nodes = torch.as_tensor(
        [d.num_nodes for d in batch.to_data_list()],  # type: ignore
        device=batch.pos.device,
    )
    f_ix = batch.mask.bool()
    f_target = normalize_trajectory(batch.delta).float()
    loss = artflownet_loss(f_pred, f_target, n_nodes)

    # Metrics on flow-carrying points only. Compare against the SAME normalized target
    # the loss uses: the prediction is normalized, so raw `batch.delta` is not on the
    # same scale and would distort RMSE / magnitude error.
    rmse, cos_dist, mag_error = flow_metrics(f_pred[f_ix], f_target[f_ix])

    return rmse, cos_dist, mag_error, loss, animation

In [None]:
# Average each metric over every cached train-val sample. Each entry of `samples` is an
# (index, batch) pair from enumerate(); every batch holds a single object.
model = model.cuda()
all_rmse = all_cos_dist = all_mag_error = all_loss = 0.0
for _, sample in samples:
    sample = sample.cuda()
    # flowbot_visual also builds an animation; it is not needed for the averages.
    rmse, cos_dist, mag_error, loss, animation = flowbot_visual(sample, model)
    all_rmse += rmse.item()
    all_cos_dist += cos_dist.item()
    all_loss += loss.item()
    all_mag_error += mag_error.item()

n_samples = len(samples)
if n_samples:
    all_rmse /= n_samples
    all_cos_dist /= n_samples
    all_mag_error /= n_samples
    all_loss /= n_samples
else:
    # Avoid ZeroDivisionError on an empty split; NaN makes the problem obvious.
    all_rmse = all_cos_dist = all_mag_error = all_loss = float("nan")
print(f"rmse:{all_rmse:.4f}, cos:{all_cos_dist:.4f}, mag:{all_mag_error:.4f}, flowloss:{all_loss:.4f}")

### Example 1

In [None]:
# Number of cached (index, batch) pairs available for the examples below.
len(samples)

In [None]:
# Pick cached train-val sample 15 and make sure both it and the model live on the GPU.
batch = sample = samples[15][1].cuda()
model = model.cuda()

In [None]:
# Score the selected sample, print a compact summary, then render its predicted flow.
(rmse, cos_dist, mag_error, loss, animation) = flowbot_visual(sample, model)
print(
    f"rmse:{rmse.item():.4f}, cos:{cos_dist.item():.4f}, "
    f"mag:{mag_error.item():.4f}, flowloss:{loss.item():.4f}"
)
fig = animation.animate()
fig.show()

In [None]:
# Raw metric tensors (shows dtype/device, unlike the formatted summary).
print(*(rmse, cos_dist, mag_error, loss), sep=" ")

### Example 2

In [None]:
# Pick cached train-val sample 1 and make sure both it and the model live on the GPU.
batch = sample = samples[1][1].cuda()
model = model.cuda()

In [None]:
# Score the selected sample and render its predicted flow as an animation.
(rmse, cos_dist, mag_error, loss, animation) = flowbot_visual(sample, model)
fig = animation.animate()
fig.show()

In [None]:
# Raw metric tensors for this sample.
print(*(rmse, cos_dist, mag_error, loss), sep=" ")

### Test 1

In [None]:
# Pick cached train-val sample 0 and put it and the model on the GPU.
batch = sample = samples[0][1].cuda()
model = model.cuda()

# Score it and render the predicted flow.
(rmse, cos_dist, mag_error, loss, animation) = flowbot_visual(sample, model)
fig = animation.animate()
fig.show()

In [None]:
# Raw metric tensors for this sample.
print(*(rmse, cos_dist, mag_error, loss), sep=" ")