## Inference

In [None]:
# Checkpoint restored in the cells below. The path is machine-specific;
# point it at a local copy before running the notebook elsewhere.
ckpt_path = (
    "/home/yishu/open_anything_diffusion/logs/train_trajectory/"
    "2023-10-18/08-45-11/checkpoints/"
    "epoch=74-step=15000-val_loss=0.00-weights-only.ckpt"
)

In [None]:
from open_anything_diffusion.models.flow_trajectory_diffuser import (
    FlowTrajectoryDiffusionModule,
)

In [None]:
from hydra import compose, initialize
from omegaconf import OmegaConf
# Initialise Hydra with "../configs" as the config directory, then compose the
# "train_synthetic" config into `cfg`; the later cells read `cfg.training` and
# `cfg.model` from it.
# NOTE(review): re-running this cell in the same kernel may fail because Hydra
# is already initialised — confirm, and restart the kernel if so.
initialize(config_path="../configs", version_base="1.3")
cfg = compose(config_name="train_synthetic")

In [None]:
# Display the composed config for inspection.
cfg

In [None]:
import rpad.pyg.nets.pointnet2 as pnp
# PN2Dense backbone with 67 input channels, 3 output channels and the default
# PN2DenseParams.
# NOTE(review): these channel counts presumably have to match the ones the
# checkpoint was trained with — confirm against the training config.
network = pnp.PN2Dense(in_channels=67, out_channels=3, p=pnp.PN2DenseParams())

# Wrap the backbone in the diffusion module, configured from the `training`
# and `model` groups of the composed config.
model = FlowTrajectoryDiffusionModule(network, cfg.training, cfg.model)

In [None]:
import torch
# Deserialize every tensor onto the CPU instead of the device it was saved
# from: `model` still lives on the CPU at this point (it is moved to the GPU
# in a later cell), and this keeps the load independent of which GPUs exist.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt["state_dict"])

In [None]:
import torch_geometric.loader as tgl
from open_anything_diffusion.datasets.flow_trajectory_dataset_pyg import FlowTrajectoryPyGDataset
from open_anything_diffusion.datasets.flow_trajectory import FlowTrajectoryDataModule
# Data module restricted to a toy split: `toy_dataset` lists the object IDs
# used for each of the "train-train", "train-test" and "test" entries.
datamodule = FlowTrajectoryDataModule(
    root="/home/yishu/datasets/partnet-mobility",
    batch_size=1,
    num_workers=30,
    n_proc=2,
    seed=42,
    trajectory_len=cfg.training.trajectory_len,  # Only used when training trajectory model
    toy_dataset={
        "id": "door-1",
        "train-train": ["8994", "9035"],
        "train-test": ["8994", "9035"],
        "test": ["8867"],
    },
)

# One loader per split exposed by the data module. Note that `val_dataloader`
# is the module's *train-val* loader.
train_dataloader = datamodule.train_dataloader()
val_dataloader = datamodule.train_val_dataloader()
unseen_dataloader = datamodule.unseen_dataloader()

# datamodule = FlowTrajectoryPyGDataset(
#     root="/home/yishu/datasets/partnet-mobility/raw",
#     split="umpnet-train-test",
#     randomize_joints=True,
#     randomize_camera=True,
#     # batch_size=1,
#     # num_workers=30,
#     # n_proc=2,
#     seed=42,
#     trajectory_len=cfg.training.trajectory_len,  # Only used when training trajectory model
# )
# unseen_dataloader = tgl.DataLoader(datamodule, 1, shuffle=False, num_workers=0)

# Materialise the validation loader as (index, batch) pairs so individual
# batches can be picked by position in the cells below.
samples = [(index, item) for index, item in enumerate(val_dataloader)]

In [None]:
from open_anything_diffusion.metrics.trajectory import artflownet_loss, flow_metrics, normalize_trajectory
from flowbot3d.grasping.agents.flowbot3d import FlowNetAnimation
import numpy as np

@torch.no_grad()
def diffuse_visual(initial_noise, batch, model):  # 1 sample batch
    """Run the reverse-diffusion loop on ``batch`` and record every step.

    Starting from ``initial_noise``, the model output is passed through
    ``model.noise_scheduler.step`` once for each entry of
    ``model.noise_scheduler.timesteps``. After every step the current
    trajectory is added as a trace to a ``FlowNetAnimation``.

    Args:
        initial_noise: Starting value assigned to ``batch.traj_noise``.
        batch: Batch providing ``pos``, ``mask`` and ``delta``. It is mutated:
            ``traj_noise`` and ``timesteps`` are overwritten on it.
        model: Module exposing ``noise_scheduler``, ``sample_size`` and
            ``device``. It is switched to eval mode.

    Returns:
        A ``(cos_dist, animation)`` tuple: the second value returned by
        ``flow_metrics`` for the masked points, and the animation holding one
        trace per diffusion step.
    """
    model.eval()

    animation = FlowNetAnimation()
    pcd = batch.pos.cpu().numpy()

    # Number of samples in the batch, derived from a hard-coded 1200 rows per
    # sample.
    # NOTE(review): presumably 1200 equals model.sample_size — confirm before
    # using this with a different point count.
    bs = batch.delta.shape[0] // 1200
    batch.traj_noise = initial_noise

    for t in model.noise_scheduler.timesteps:
        # Every sample in the batch is denoised at the same timestep t.
        batch.timesteps = torch.zeros(bs, device=model.device) + t  # Uniform t steps
        batch.timesteps = batch.timesteps.long()

        model_output = model(batch)  # bs * 1200, traj_len * 3
        model_output = model_output.reshape(model_output.shape[0], -1, 3)  # bs * 1200, traj_len, 3

        # The scheduler is stepped on tensors with a leading per-sample axis
        # of size model.sample_size; the result is flattened back afterwards.
        batch.traj_noise = model.noise_scheduler.step(
            model_output.reshape(
                -1, model.sample_size, model_output.shape[1], model_output.shape[2]
            ),
            t,
            batch.traj_noise.reshape(
                -1, model.sample_size, model_output.shape[1], model_output.shape[2]
            ),
        ).prev_sample
        batch.traj_noise = torch.flatten(batch.traj_noise, start_dim=0, end_dim=1)

        # Record the trajectory after this step as one animation trace.
        flow = batch.traj_noise.squeeze().cpu().numpy()
        animation.add_trace(
            torch.as_tensor(pcd),
            torch.as_tensor([pcd]),
            torch.as_tensor([flow]),
            "red",
        )

    f_pred = normalize_trajectory(batch.traj_noise)

    # Compute the metrics on the masked (flow-only) points. The prediction is
    # normalised, the target is the raw batch.delta.
    f_ix = batch.mask.bool()
    _, cos_dist, _ = flow_metrics(f_pred[f_ix], batch.delta[f_ix])

    return cos_dist, animation

In [None]:
# Inspect the point positions of the first enumerated validation batch.
samples[0][1].pos

In [None]:
# Number of batches collected from the validation loader.
len(samples)

In [None]:
# Move the second validation batch and the model onto the GPU. `sample` and
# `batch` are two names for the same object.
batch = samples[1][1].cuda()
sample = batch
model = model.cuda()

In [None]:
# Draw standard-normal noise with the same shape as the target flow, run one
# full denoising pass from it, and render the recorded steps.
initial_noise = torch.randn_like(batch.delta, device="cuda")
cos_dist, animation = diffuse_visual(initial_noise, batch, model)
fig = animation.animate()
fig.show()

In [None]:
import tqdm
# Re-sample the initial noise up to 100 times and stop at the first run whose
# cos_dist drops below -0.7. If no run qualifies, `cos_dist` and `animation`
# hold the result of the last attempt.
for i in tqdm.tqdm(range(100)):
    initial_noise = torch.randn_like(batch.delta, device="cuda")
    cos_dist, animation = diffuse_visual(initial_noise, batch, model)
    if cos_dist < -0.7:
        break

In [None]:
# cos_dist of the last run executed by the search loop above.
cos_dist

In [None]:
# Render the denoising steps recorded for that run.
fig = animation.animate()
fig.show()

In [None]:
import tqdm
# Same search as before with the opposite criterion: stop at the first run
# whose cos_dist exceeds 0.5. If no run qualifies, `cos_dist` and `animation`
# hold the result of the last attempt.
for i in tqdm.tqdm(range(100)):
    initial_noise = torch.randn_like(batch.delta, device="cuda")
    cos_dist, animation = diffuse_visual(initial_noise, batch, model)
    if cos_dist > 0.5:
        break

In [None]:
# cos_dist of the last run executed by the search loop above.
cos_dist

In [None]:
# Render the denoising steps recorded for that run.
fig = animation.animate()
fig.show()

## Find multimodal cases

In [None]:
import tqdm

# Look for the first validation sample on which repeated denoising runs land
# on both sides: at least one run with cos_dist > 0.7 and at least one with
# cos_dist < 0. The matching animations are kept in `correct_animation` and
# `incorrect_animation` (the last qualifying run of each kind wins).
repeat_times = 10  # denoising runs per sample
stop = False  # becomes True once such a sample has been found
for sample_id, sample in tqdm.tqdm(samples):
    batch = sample.cuda()
    has_correct = False
    has_incorrect = False
    for _ in range(repeat_times):
        # diffuse_visual takes the starting noise as its first argument, so a
        # fresh draw is made for every run.
        initial_noise = torch.randn_like(batch.delta, device="cuda")
        cos_dist, animation = diffuse_visual(initial_noise, batch, model)
        if cos_dist > 0.7:
            has_correct = True
            correct_animation = animation
        elif cos_dist < 0:
            has_incorrect = True
            incorrect_animation = animation
    if has_correct and has_incorrect:
        print(sample_id, sample_id)
        stop = True
        break