In [3]:
from torch_geometric.datasets import PCPNetDataset
from torch_geometric.transforms import ToSparseTensor, KNNGraph, Compose
from normal_diffusion.data.patches import PatchDataloader
from normal_diffusion.data.transforms import DistanceToEdgeWeight, KeepNormals

# Choose the root directory where you want to save the dataset
root = "../data/PCPNetDataset"
dataset = PCPNetDataset(
    root=root,
    category="NoNoise",
    split="train",
    transform=Compose([KeepNormals(), KNNGraph(k=6)]),
)
dataloader = PatchDataloader(dataset, batch_size=128, hops=10, transform=Compose([DistanceToEdgeWeight(), ToSparseTensor()])) # can add ToSparseTensor conversion here 
print(len(dataloader))
first_collection = next(iter(dataloader))
print(first_collection.x.shape)
print(first_collection.adj_t)
print(first_collection)

6250
torch.Size([36942, 3])
SparseTensor(row=tensor([    0,     0,     0,  ..., 36941, 36941, 36941]),
             col=tensor([  128,   129,   130,  ..., 36936, 36939, 36940]),
             val=tensor([0.1265, 0.1278, 0.1422,  ..., 0.3345, 0.3348, 0.3308]),
             size=(36942, 36942), nnz=208025, density=0.02%)
DataBatch(x=[36942, 3], pos=[36942, 3], test_idx=[40000], ptr=[9], n_id=[36942], e_id=[208025], input_id=[128], batch_size=128, adj_t=[36942, 36942, nnz=208025])


In [4]:
import torch
from normal_diffusion.models import GCNModel
model = GCNModel()
t = torch.ones(first_collection.x.shape[0])
predicted_normals = model(graph_data=first_collection, t=t)
print(predicted_normals)


tensor([[  2.8732,  -3.9300,   6.6652],
        [-10.8953,  -8.9923, -15.1326],
        [  0.9552,  -2.8138,  -3.9986],
        ...,
        [ -0.8777,   0.7203,  -1.2785],
        [ -0.8786,   0.7215,  -1.2775],
        [ -0.8776,   0.7219,  -1.2774]], grad_fn=<AddBackward0>)


  return torch.sparse_csr_tensor(rowptr, col, value, self.sizes())


In [5]:
from diffusers import DDPMScheduler
from normal_diffusion.training.training import train_diffusion
scheduler = DDPMScheduler(num_train_timesteps=5, beta_schedule="squaredcos_cap_v2", clip_sample=False)
train_diffusion(model=model, dataloader=dataloader, scheduler=scheduler, n_epochs=100, lr=1e-3)

  from .autonotebook import tqdm as notebook_tqdm
  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


KeyboardInterrupt: 