In [None]:
## Diffusion model
from dataclasses import dataclass, field

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

# import rpad.pyg.nets.pointnet2 as pnp
import tqdm
import wandb
from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
from flowbot3d.grasping.agents.flowbot3d import FlowNetAnimation

from open_anything_diffusion.datasets.flow_trajectory import FlowTrajectoryDataModule
from open_anything_diffusion.models.diffusion.diffuser import TrajDiffuser
from open_anything_diffusion.models.diffusion.model import PNDiffuser


In [None]:

@dataclass
class TrainingConfig:
    """Hyperparameters and checkpoint paths for the diffusion training notebook.

    Every attribute is type-annotated so that ``@dataclass`` registers it as a
    real field. Without annotations the decorator sees no fields at all: the
    generated ``__init__`` takes no arguments and ``repr``/``==`` ignore every
    value. With them, settings can be overridden per instance, e.g.
    ``TrainingConfig(traj_len=5)``.

    List-valued defaults use ``default_factory`` so each instance gets its own
    list instead of sharing one class-level object.
    """

    device: str = "cuda"

    image_size: int = 128  # the generated image resolution
    train_batch_size: int = 16
    eval_batch_size: int = 16  # how many images to sample during evaluation
    num_epochs: int = 100000  # an earlier experiment used 10 here
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 1000
    save_image_epochs: int = 10
    save_model_epochs: int = 30
    mixed_precision: str = "no"  # `no` for float32, `fp16` for automatic mixed precision
    train_sample_number: int = 1

    traj_len: int = 1

    # Diffuser params
    num_train_timesteps: int = 100
    seed: int = 0
    # Lists of ints; built per instance so that mutating one config cannot leak into another.
    sample_size: list = field(default_factory=lambda: [1, 1200])
    in_channels: int = 3
    out_channels: int = 3
    cross_attention_dim: int = 3
    block_out_channels: list = field(default_factory=lambda: [128, 256, 512, 512])
    attention_head_dim: int = 3

    # ckpt params
    read_ckpt_path: str = "./diffusion_best_ckpt.pth"
    save_ckpt_path: str = "./diffusion_overfit_best_5_ckpt.pth"

In [None]:
# Experiment configuration shared by the cells below.
config = TrainingConfig()

# Data module reading from the dataset directory given by `root`.
# NOTE(review): `root` is an absolute, machine-specific path — consider making it configurable.
datamodule = FlowTrajectoryDataModule(
    root="/home/yishu/datasets/partnet-mobility",
    trajectory_len=config.traj_len,  # only used when training a trajectory model
    batch_size=1,
    num_workers=30,
    n_proc=2,
    seed=42,
)

train_dataloader = datamodule.train_dataloader()
# NOTE(review): this uses `train_val_dataloader()`, not a plain validation loader — confirm intended.
val_dataloader = datamodule.train_val_dataloader()

# Overfit setup: materialise every (index, batch) pair from the training loader,
# then keep only the first batch as the single sample to work with.
# NOTE(review): this iterates the entire loader up front; if `samples` is not used
# elsewhere, `next(iter(train_dataloader))` would fetch just the first batch.
samples = list(enumerate(train_dataloader))
sample = samples[0][1]

In [None]:
# Display the first batch taken from the training loader to inspect its contents.
sample

In [None]:
# Build the PNDiffuser network from the experiment config.
# The input channel count scales with the trajectory length (3 per step —
# presumably one xyz vector per step; TODO confirm against PNDiffuser).
# Earlier experiments also tried passing `sample_size=1200` and `emb_dims=3` here.
model = PNDiffuser(
    traj_len=config.traj_len,
    in_channels=3 * config.traj_len,
    time_embed_dim=64,
)

In [None]:
# Smoke-test the forward pass: a uniform-random tensor of shape (1, 3, 1, 1200),
# the integer 0 as second argument (presumably the diffusion timestep — confirm),
# and the overfit batch as third argument.
# NOTE(review): the shape literals look tied to config (in_channels=3, traj_len=1,
# sample_size[1]=1200) — verify before changing either side.
model(torch.rand((1, 3, 1, 1200)), 0, sample)