## Check whether the two datasets are identical

In [None]:
from flowbot3d.datasets.flow_dataset_pyg import Flowbot3DPyGDataset
from python_ml_project_template.datasets.flow_trajectory_dataset_pyg import FlowTrajectoryPyGDataset

#### Proof that seed_everything doesn't work

In [None]:
import lightning as L
# Seed through Lightning, then build the Flowbot3D dataset with camera
# randomization enabled and print the positions of object '7167'.
# No per-call seed is passed to get_data here.
L.seed_everything(24)
flowbot3d = Flowbot3DPyGDataset(
    root='/home/yishu/datasets/partnet-mobility/raw/',
    split="umpnet-train-train",
    randomize_camera=True,
)
print(flowbot3d.get_data('7167').pos)

In [None]:
import lightning as L
# Same global seed as the Flowbot3D cell, but for the trajectory dataset
# (trajectory length 1). get_data is called with seed=None, i.e. without an
# explicit per-call seed, and the positions of object '7167' are printed.
L.seed_everything(24)
trajectory = FlowTrajectoryPyGDataset(
    root='/home/yishu/datasets/partnet-mobility-trajectory/raw/',
    split="umpnet-train-train",
    randomize_camera=True,
    trajectory_len=1,
)
print(trajectory.get_data('7167', seed=None).pos)

In [None]:
# Sample object '100031' with an explicit per-call seed and show its `pos`
# (being the cell's last expression, the value is displayed by the notebook).
flowbot3d.get_data('100031', seed=10).pos

#### Proof that when `get_data` is given the same seed, both datasets produce the same results

In [None]:
# Print the Flowbot3D sample's `flow` and the trajectory sample's `delta` for
# the same object id ('103303') and the same per-call seed, so the two outputs
# can be compared by eye.
# NOTE(review): `delta` presumably carries an extra singleton dimension (the
# on-disk comparison later in this notebook squeezes it) — confirm.
print(flowbot3d.get_data('103303', seed=10).flow)
print(trajectory.get_data('103303', seed=10).delta)

## Check the generated dataset against the on-disk dataset

In [None]:
# Directories of processed samples: the original dataset and the generated
# trajectory dataset that is compared against it in the cells below.
original_dataset = '/home/yishu/datasets/partnet-mobility/processed_rj_rc/'
generated_dataset = '/home/yishu/datasets/partnet-mobility-trajectory/processed_1_rj_rc/'

In [None]:
import os

# os.listdir returns entries in arbitrary order, so two directories holding
# exactly the same file names may still yield differently ordered lists.
# Sort both listings so the equality check compares contents, not order.
objects = sorted(os.listdir(original_dataset))
gen_objects = sorted(os.listdir(generated_dataset))
print(objects)
assert objects == gen_objects, "The objects are not the same"

In [None]:
import torch
import tqdm
import torch_geometric.data as tgd
# Compare, file by file, the `flow` stored in the original dataset with the
# `delta` stored in the generated dataset, and count the files that differ.
# Mismatches are counted rather than asserted so the loop covers every file.
count = 0
for fname in tqdm.tqdm(objects):  # `fname`, not `object`: don't shadow the builtin
    # Only entries whose name starts with an ASCII digit are compared;
    # everything else in the directory listing is skipped.
    if not ('0' <= fname[0] <= '9'):
        continue
    original_flow = torch.load(os.path.join(original_dataset, fname))[0].flow
    generated_flow = torch.load(os.path.join(generated_dataset, fname))[0].delta
    # Drop singleton dimensions once and reuse the result for both checks.
    # NOTE(review): presumably `delta` has an extra length-1 trajectory
    # dimension compared with `flow` — confirm.
    squeezed_flow = generated_flow.squeeze()
    diff = original_flow - squeezed_flow
    # Show only the non-zero differences (an empty tensor means identical values).
    print(diff[diff != 0])
    if not torch.equal(original_flow, squeezed_flow):
        count += 1
print(count)

In [None]:
# Compare, file by file, how many points are set in the `mask` of the original
# and the generated sample; report and count the files where the sums differ.
count = 0
for fname in tqdm.tqdm(objects):  # `fname`, not `object`: don't shadow the builtin
    # Only entries whose name starts with an ASCII digit are compared.
    if not ('0' <= fname[0] <= '9'):
        continue
    # Load each file once and reuse the sample for both the mask check and the
    # shape report, instead of re-reading it from disk on a mismatch.
    original_sample = torch.load(os.path.join(original_dataset, fname))[0]
    generated_sample = torch.load(os.path.join(generated_dataset, fname))[0]
    original_mask = original_sample.mask
    generated_mask = generated_sample.mask

    if torch.sum(original_mask) != torch.sum(generated_mask):
        count += 1
        print(f"{fname} is not the same.")
        print(original_sample.flow.shape)
        print(generated_sample.delta.squeeze().shape)

# NOTE: len(objects) also counts the entries skipped by the digit filter above.
print(count, len(objects))