In [1]:
import os.path as osp

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import ModelNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, PointNetConv, fps, global_max_pool, radius
from torch_geometric.typing import WITH_TORCH_CLUSTER

if not WITH_TORCH_CLUSTER:
    quit("This example requires 'torch-cluster'")

In [2]:
# data set imports
import dataset_utils as du
from torch.utils.data import DataLoader, random_split

B = 1 

# load dataset
dataset = du.SDFDataset("./cars100")
train_percent = 0.95 

# split dataset into training and validation sets
train_size = int(train_percent * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# create data loaders for training and validation sets
train_loader = DataLoader(train_dataset, batch_size=B, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=B, shuffle=False, num_workers=4, pin_memory=True)

print(f"Train dataset size: {len(train_dataset)}, Validation dataset size: {len(val_dataset)}")

Train dataset size: 95, Validation dataset size: 5


In [3]:
def process_batch(batch, np_in=1024, np_q=1024):
    B, N, _ = batch.shape
    assert B == 1, "Batch size must be 1 for now."
    
    x = None # ! no features for now
    pos = batch[:, :np_in, :3]
    
    idx = torch.randperm(N)[:np_q]
    query_pos = batch[:, idx, :3] # query positions
    query_sdf = batch[:, idx, 3] # SDF values
    batch_vec = torch.tensor([1]) # batch indices (only have batch_size of 1 for now)
    
    return x, pos.contiguous(), batch_vec.contiguous(), query_pos, query_sdf

Define Layers and Model

In [4]:
class SAModule(torch.nn.Module):
    def __init__(self, ratio, r, nn):
        super().__init__()
        self.ratio = ratio
        self.r = r
        self.conv = PointNetConv(nn, add_self_loops=False)

    def forward(self, x, pos, batch):
        idx = fps(pos, batch, ratio=self.ratio)
        row, col = radius(pos, pos[idx], self.r, batch, batch[idx],
                          max_num_neighbors=64)
        edge_index = torch.stack([col, row], dim=0)
        x_dst = None if x is None else x[idx]
        x = self.conv((x, x_dst), (pos, pos[idx]), edge_index)
        pos, batch = pos[idx], batch[idx]
        return x, pos, batch


class GlobalSAModule(torch.nn.Module):
    def __init__(self, nn):
        super().__init__()
        self.nn = nn

    def forward(self, x, pos, batch):
        x = self.nn(torch.cat([x, pos], dim=1))
        x = global_max_pool(x, batch)
        pos = pos.new_zeros((x.size(0), 3))
        batch = torch.arange(x.size(0), device=batch.device)
        return x, pos, batch


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Input channels account for both `pos` and node features.
        self.sa1_module = SAModule(0.5, 0.2, MLP([3, 64, 64, 128]))
        self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
        self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))

        self.encode = MLP([1024, 512, 256], dropout=0.5, norm=None)
        self.sdf = MLP([256 + 3, 128, 64, 1], dropout=0.5, norm=None)

    def forward(self, x, pos, batch, query_pos):
        # encode shape
        sa0_out = (x, pos, batch)
        sa1_out = self.sa1_module(*sa0_out)
        sa2_out = self.sa2_module(*sa1_out)
        sa3_out = self.sa3_module(*sa2_out)
        x, pos, batch = sa3_out
        
        x = self.encode(x)
        x = torch.cat((x.unsqueeze(1).repeat(1, query_pos.shape[1], 1), query_pos), dim=-1) # concatenate encoded shape with query positions

        return self.sdf(x)

In [5]:
# # test
# # make random tensor of shape (B, 256, 3) 
# x = torch.randn(B, 256)
# queries = torch.randn(B, 1024, 3) # queries for SDF values

# # want to make a tensor of shape (B, 1024, 256 + 3) where the first 256 columns are the x values and the last 3 columns are the query positions
# x = torch.cat((x.unsqueeze(1).repeat(1, 1024, 1), queries), dim=-1)  # (B, 1024, 256 + 3)
# mlp = MLP([256 + 3, 128, 64, 1], dropout=0.5, norm=None)  # MLP for SDF prediction
# x = mlp(x)  # (B, 1024, 1)

# print(f"x shape: {x.shape}")
# print(f"queries shape: {queries.shape[1]}")

In [6]:
# test model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

batch = next(iter(train_loader))
print(f'Batch shape: {batch.shape}')
x, pos, batch, query_pos, query_sdf = process_batch(batch)
print(f'pos shape: {pos.shape}, query_pos shape: {query_pos.shape}, query_sdf shape: {query_sdf.shape}')

# print(f'pos shape: {pos.shape}')
# # forward pass
# output = model(x, pos, batch, query_pos)
# # print output shape
# print(f"Output shape: {output.shape}")

Batch shape: torch.Size([1, 50000, 4])
pos shape: torch.Size([1, 1024, 3]), query_pos shape: torch.Size([1, 1024, 3]), query_sdf shape: torch.Size([1, 1024])


In [None]:
from tqdm import tqdm

EPOCHS = 5

def train(epoch):
    model.train()
    for batch in train_loader:
        x, pos, batch, query_pos, query_sdf = process_batch(batch)
        optimizer.zero_grad()
        out = model(x, pos, batch, query_pos).squeeze(-1)
        loss = F.mse_loss(out, query_sdf)  # L1 loss for SDF prediction
        loss.backward()
        optimizer.step()
    
    return loss.item()




for epoch in tqdm(range(1, EPOCHS + 1)):
    print('Loss: {:.4f}'.format(train(epoch)))

NameError: name '__file__' is not defined