In [None]:
# Create a dataset.
import os

from open_anything_diffusion.datasets.flow_trajectory import FlowTrajectoryDataModule

# Dataset location. Set the PARTNET_MOBILITY_ROOT environment variable to point
# at a local copy; the fallback is the path this notebook was originally
# written against, so existing setups keep working unchanged.
DATASET_ROOT = os.environ.get(
    "PARTNET_MOBILITY_ROOT", "/home/yishu/datasets/partnet-mobility"
)

datamodule = FlowTrajectoryDataModule(
    root=DATASET_ROOT,
    batch_size=1,
    num_workers=30,
    n_proc=2,
    seed=42,
    trajectory_len=1,  # Only used when training trajectory model
)

In [None]:
# Define a contextual model.
# from open_anything_diffusion.models.diffusion.model import PNDiffuser

from torch import nn
import torch

import rpad.pyg.nets.pointnet2 as pnp
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

class PNTrajPredictor(nn.Module):
    """Timestep-conditioned noise predictor built on a dense PointNet++ backbone.

    Each point's feature vector is its (flattened) noisy trajectory concatenated
    with an embedding of the diffusion timestep; the backbone maps that to a
    per-point prediction with ``3 * traj_len`` channels.

    Args:
        traj_len: Number of 3D steps per point trajectory.
        time_emb_dim: Width of the timestep embedding appended to every point.
    """

    def __init__(self, traj_len: int = 1, time_emb_dim: int = 64):
        super().__init__()

        self.traj_len = traj_len

        # Modules to go from an integer timestep to a `time_emb_dim`-wide embedding:
        # `Timesteps` projects the timestep, `TimestepEmbedding` maps the projection.
        self.time_proj = Timesteps(
            num_channels=time_emb_dim, flip_sin_to_cos=True, downscale_freq_shift=0
        )
        self.time_emb = TimestepEmbedding(
            in_channels=time_emb_dim, time_embed_dim=time_emb_dim
        )

        # Backbone: input is noisy trajectory + time embedding, output is one
        # 3-vector per trajectory step.
        in_channels = 3 * self.traj_len + time_emb_dim
        self.backbone = pnp.PN2Dense(
            in_channels=in_channels,
            out_channels=3 * self.traj_len,
            p=pnp.PN2DenseParams(),
        )

    def forward(self, batch):
        """Predict per-point output for the noisy trajectories stored on ``batch``.

        Reads ``batch.timesteps`` (1-D, either a single timestep or one per
        batch item) and ``batch.traj_noise`` (``num_points x 3 * traj_len``).
        When more than one timestep is given, ``batch.batch`` must map every
        point to its batch item (assumed to be the PyG batch-assignment vector
        — TODO confirm against the dataloader).

        Side effect: overwrites ``batch.x`` with the backbone input features.

        Returns:
            Tensor of shape ``(num_points, traj_len, 3)``.
        """
        timesteps = batch.timesteps
        traj_noise = batch.traj_noise  # num_points, 3 * traj_len
        num_points = traj_noise.shape[0]

        # Get the time embedding: one row per supplied timestep.
        t_emb = self.time_emb(self.time_proj(timesteps))  # num_timesteps, time_emb_dim

        # Give every point the embedding of ITS OWN batch item. The previous
        # `repeat(1, num_points, 1)` produced num_timesteps * num_points rows,
        # which only matches `traj_noise` when there is exactly one timestep.
        if t_emb.shape[0] == 1:
            t_emb = t_emb.expand(num_points, -1)  # num_points, time_emb_dim
        else:
            t_emb = t_emb[batch.batch]  # num_points, time_emb_dim

        # Concatenate the time embedding with the trajectory noise.
        x = torch.cat([traj_noise, t_emb], dim=-1)  # num_points, 3 * traj_len + time_emb_dim

        # Put it back in the batch so the backbone picks it up as point features.
        batch.x = x

        # The backbone.
        pred = self.backbone(batch)  # num_points, 3 * traj_len

        return pred.reshape(pred.shape[0], self.traj_len, 3)

In [None]:
# Pull a single batch out of the training dataloader and move it onto the GPU.
train_loader = datamodule.train_dataloader()
first_batch = next(iter(train_loader))
batch = first_batch.cuda()

In [None]:
# Collect every (index, batch) pair the training dataloader yields.
# Note: this walks the whole dataloader before the first pair can be read.
all_train_batches = datamodule.train_dataloader()
samples = [(idx, item) for idx, item in enumerate(all_train_batches)]

# Keep the batch (second element) of the first pair.
first_pair = samples[0]
sample = first_pair[1]

In [None]:
# Display the batch pulled from the training dataloader to inspect its fields.
batch

In [None]:
from diffusers import DDPMScheduler
# Number of training timesteps handed to the DDPM noise scheduler.
NUM_TRAIN_TIMESTEPS = 500
noise_scheduler = DDPMScheduler(num_train_timesteps=NUM_TRAIN_TIMESTEPS)

In [None]:
# Build the predictor on the GPU and attach an AdamW optimizer to its parameters.
LEARNING_RATE = 1e-4

model = PNTrajPredictor()
model = model.cuda()

optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
device = "cuda"

# Training loop: repeatedly fits the model on the single `batch` fetched earlier.
# Switching to train mode is loop-invariant, so do it once up front.
model.train()

for i in range(10000):
    # Draw one diffusion timestep for this step. The upper bound is read from
    # the scheduler's config (direct attribute access on the scheduler is
    # deprecated in diffusers).
    batch.timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (1,), device=device
    )

    # Noise the clean target and store it, flattened over its last two dims,
    # as the model input (assumes batch.delta is num_points x traj_len x 3
    # — TODO confirm against the dataset).
    noise = torch.randn_like(batch.delta, device=device)
    noisy_delta = noise_scheduler.add_noise(batch.delta, noise, batch.timesteps)
    batch.traj_noise = torch.flatten(noisy_delta, start_dim=1, end_dim=2)

    # Get the prediction.
    optimizer.zero_grad()
    pred = model(batch)

    # The model is trained to recover the noise that was added.
    loss = torch.nn.functional.mse_loss(pred, noise)

    loss.backward()
    optimizer.step()

    print(loss.item())


In [None]:
# Inference: start from Gaussian noise and iteratively denoise it with the
# scheduler, querying the model once per scheduler timestep.

# Sampling should not run in train mode (the training loop leaves the model in
# it); eval mode disables train-only behaviour of any dropout / normalisation
# layers the backbone may contain.
model.eval()

# Match the shape, dtype and device of the training target instead of
# hard-coding (1200, 1, 3), so the cell works for any point count / traj_len.
noisy_input = torch.randn_like(batch.delta)

with torch.no_grad():
    for t in noise_scheduler.timesteps:
        # The model reads its conditioning from the batch object.
        batch.timesteps = torch.tensor([t.item()]).cuda()
        batch.traj_noise = torch.flatten(noisy_input, start_dim=1, end_dim=2)

        model_output = model(batch)

        # One reverse-diffusion step: replace the sample with the scheduler's
        # estimate for the previous timestep.
        noisy_input = noise_scheduler.step(
            model_output, t, noisy_input
        ).prev_sample

In [None]:
# Show the denoised samples for the points where batch.mask == 1.
noisy_input[batch.mask==1]

In [None]:
# Show the clean target tensor (the one the training loop adds noise to) for comparison.
batch.delta

In [None]:
# Shape of the target entries selected by batch.mask == 1.
batch.delta[batch.mask==1].shape