In [14]:
import torch
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import scipy
import os
import numpy as np
from pysdf import SDF

In [79]:
# Notebook-wide hyperparameters.
Z_DIM = 256                  # width of each per-shape latent code
N_SAMPLES = 2000             # query points requested per shape
BATCH_SIZE = 8               # shapes per optimisation step
LR = 1e-5 * BATCH_SIZE       # learning rate scaled linearly with the batch size
# Pick the best available accelerator instead of assuming Apple MPS, so the
# notebook also runs on CUDA machines and on CPU-only machines.
if torch.backends.mps.is_available():
    DEV = torch.device("mps")
elif torch.cuda.is_available():
    DEV = torch.device("cuda")
else:
    DEV = torch.device("cpu")
DELTA = 1.0                  # presumably the clamp half-width handed to DeepSDFLoss -- confirm at the call site

In [8]:
class DeepSDF(nn.Module):
    """
    Fully connected decoder that maps (latent code, xyz coordinate) pairs to a
    signed distance value.

    Layers 1-7 are Linear -> ReLU -> Dropout blocks. The concatenated
    (latent, xyz) input is re-injected in front of layer 5 as a skip
    connection, which is why layer 4 emits `layer_size - input_dim` features.
    The output layer is a plain Linear so that negative distances can be
    predicted.
    """

    def __init__(self, input_dim, layer_size = 512, dropout_p = 0.2):
        """
        input_dim: size of the concatenated network input (z_dim + 3)
        layer_size: width of the hidden layers; must be larger than input_dim
        dropout_p: dropout probability used in every hidden block

        Raises ValueError if layer_size is not larger than input_dim.
        """
        super(DeepSDF, self).__init__()
        if layer_size <= input_dim:
            raise ValueError(
                f"layer_size ({layer_size}) must be larger than input_dim ({input_dim}) "
                "to leave room for the skip connection"
            )
        self.dropout_p = dropout_p
        self.input_layer = self.create_layer_block(input_dim, layer_size)
        self.layer2 = self.create_layer_block(layer_size, layer_size)
        self.layer3 = self.create_layer_block(layer_size, layer_size)
        # Narrower output so that cat([layer4_out, skip_x]) is layer_size wide again.
        self.layer4 = self.create_layer_block(layer_size, layer_size - input_dim)
        self.layer5 = self.create_layer_block(layer_size, layer_size)
        self.layer6 = self.create_layer_block(layer_size, layer_size)
        self.layer7 = self.create_layer_block(layer_size, layer_size)
        # Output head: no ReLU (it would clip every negative distance to 0) and
        # no Dropout (it would randomly zero the prediction itself). Wrapped in
        # nn.Sequential so parameter names stay 'layer8.0.weight' / 'layer8.0.bias'.
        self.layer8 = nn.Sequential(nn.Linear(layer_size, 1))

    def create_layer_block(self, input_size, output_size):
        """Return one hidden block: Linear -> ReLU -> Dropout(self.dropout_p)."""
        return nn.Sequential(
            nn.Linear(input_size, output_size),
            nn.ReLU(),
            nn.Dropout(self.dropout_p)
        )

    def forward(self, latent_vec, coords):
        """
        latent_vec has shape [batch_size, z_dim]
        coords has shape [num_coords, 3] (one grid shared by the whole batch)
            or [batch_size, num_coords, 3] (a separate grid per batch element)

        Returns a tensor of shape [batch_size, num_coords] holding the
        predicted SDF at each input coordinate.

        Raises ValueError if coords is neither 2-D nor 3-D, or if its batch
        dimension disagrees with latent_vec.
        """
        batch_size = latent_vec.shape[0]
        if coords.dim() == 2:
            # Shared grid: broadcast over the batch axis (expand makes no copy).
            coords = coords.unsqueeze(0).expand(batch_size, -1, -1)
        elif coords.dim() != 3:
            raise ValueError(
                f"coords must have 2 or 3 dimensions, got shape {tuple(coords.shape)}"
            )
        if coords.shape[0] != batch_size:
            raise ValueError(
                f"coords batch size {coords.shape[0]} does not match "
                f"latent_vec batch size {batch_size}"
            )
        num_coords = coords.shape[1]
        # latent_vec now has shape [batch_size, num_coords, z_dim], broadcast on the middle axis
        latent_vec = latent_vec.unsqueeze(1).expand(-1, num_coords, -1)

        x = torch.cat([latent_vec, coords], dim = -1)
        skip_x = x

        x = self.input_layer(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(torch.cat([x, skip_x], dim = -1)) # skip connection
        x = self.layer6(x)
        x = self.layer7(x)
        x = self.layer8(x)

        # return has shape [batch_size, num_coords], where each element is the SDF
        # at the given input coordinate
        return x.squeeze(-1)

# Smoke test: push random latents and a shared random coordinate grid through a
# freshly initialised network and check the output layout.
test_model = DeepSDF(Z_DIM + 3)
random_latent = torch.randn(32, Z_DIM)
random_coords = torch.randn(1000, 3)
output = test_model(random_latent, random_coords)
# Fail loudly under Restart & Run All if the layout ever regresses, rather
# than relying on someone reading the printed shape.
assert output.shape == (random_latent.shape[0], random_coords.shape[0]), output.shape
print(output.shape)

torch.Size([32, 1000])


In [60]:
class SDFDataset(Dataset):
    """
    Dataset of meshes paired with one learnable latent code per mesh.

    Every file in `dataset_path` (except ".DS_Store") is treated as one shape.
    Each item is a tuple (query points, latent code, SDF values at the points).
    """

    def __init__(self, dataset_path, num_samples, z_dim=256, latent_mean=0.0, latent_sd=0.01):
        """
        dataset_path: directory holding one mesh file per shape
        num_samples: number of query points requested per shape
        z_dim: size of each latent code
        latent_mean, latent_sd: mean / standard deviation of the Gaussian used
            to initialise the latent codes
        """
        fnames = sorted(name for name in os.listdir(dataset_path) if name != ".DS_Store")
        self.file_paths = [os.path.join(dataset_path, name) for name in fnames]
        self.num_samples = num_samples
        # Scale and shift BEFORE enabling gradients. Doing arithmetic on a tensor
        # that already requires grad yields a non-leaf tensor, which an optimizer
        # refuses to take and whose .grad is never populated.
        initial_codes = torch.randn(len(fnames), z_dim) * latent_sd + latent_mean
        self.latent_vectors = initial_codes.requires_grad_(True)

    def __getitem__(self, idx):
        """
        Return (points, latent, sdf) for shape `idx`.

        points: float32 tensor [n, 3] of Poisson-disk samples from the unit cube
        latent: row `idx` of self.latent_vectors (keeps the autograd link)
        sdf: float32 tensor [n] of SDF values at `points`

        n may be smaller than self.num_samples if the sampler cannot place
        that many points; collate_fn truncates a batch to the shortest item.
        """
        pd_sampler = scipy.stats.qmc.PoissonDisk(d=3)
        poisson_grid_points = pd_sampler.random(self.num_samples)
        # The file is indexed by "vertices" / "faces", i.e. it is expected to be
        # an .npz archive; the context manager closes its file handle.
        with np.load(self.file_paths[idx]) as arr:
            # NOTE(review): pysdf appears to report positive values inside the
            # mesh by default -- confirm this sign convention suits the loss.
            sdf_values = SDF(arr["vertices"], arr["faces"])(poisson_grid_points)
        # The sampler yields float64; cast to float32 to match the model weights
        # (float64 is also unsupported on the MPS backend).
        points = torch.from_numpy(poisson_grid_points).float()
        sdf = torch.from_numpy(np.asarray(sdf_values)).float()
        return points, self.latent_vectors[idx], sdf

    def collate_fn(self, x):
        """
        Stack a list of (points, latent, sdf) items into batched tensors.

        Items may hold different numbers of points, so points and sdf are
        truncated to the shortest item before stacking.

        Returns (points [B, n, 3], latents [B, z_dim], sdf [B, n]).
        """
        min_sample_len = min(len(item[0]) for item in x)
        points = torch.stack([item[0][:min_sample_len] for item in x], dim=0)
        latents = torch.stack([item[1] for item in x], dim=0)
        sdf = torch.stack([item[2][:min_sample_len] for item in x], dim=0)
        return points, latents, sdf

    def __len__(self):
        """Number of shapes (files) in the dataset."""
        return len(self.file_paths)

In [83]:
class DeepSDFLoss:
    """
    Clamped L1 reconstruction loss plus a squared-norm penalty on the latent codes.

    loss = mean(|clamp(yhat, -delta, delta) - clamp(y, -delta, delta)|)
           + mean_over_batch(sum(latent ** 2)) / sd ** 2
    """

    def __init__(self, delta, sd):
        """
        delta: half-width of the clamp applied to both predictions and targets
        sd: scale of the latent penalty; the penalty is multiplied by 1 / sd**2
        """
        self.sd = sd
        self.delta = delta
        self.mae = nn.L1Loss()

    def __call__(self, yhat, y, latent):
        """
        yhat: predicted SDF values
        y: target SDF values, same shape as yhat
        latent: latent codes; the last axis is summed, the rest are averaged

        Returns a scalar tensor: reconstruction term plus latent penalty.
        """
        lower, upper = -self.delta, self.delta
        clamped_pred = torch.clamp(yhat, lower, upper)
        clamped_target = torch.clamp(y, lower, upper)
        reconstruction = self.mae(clamped_pred, clamped_target)

        inverse_variance = 1 / (self.sd ** 2)
        squared_norms = torch.pow(latent, 2).sum(dim=-1)
        penalty = squared_norms.mean() * inverse_variance
        return reconstruction + penalty

tensor(266.0443, grad_fn=<AddBackward0>)

In [81]:
# Directory holding the mesh files. Set the SDF_DATASET_PATH environment
# variable to run on another machine; the default is the original local path.
DATASET_PATH = os.environ.get("SDF_DATASET_PATH", "/Users/gursi/Desktop/tight_models")

dataset = SDFDataset(
    dataset_path=DATASET_PATH,
    num_samples=N_SAMPLES,
    z_dim=Z_DIM
)
# The dataset's own collate_fn is required because items can contain different
# numbers of sampled points.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=dataset.collate_fn)

In [None]:
# Build the decoder (input = latent code + xyz coordinate) and move it to the
# configured device.
model = DeepSDF(input_dim=Z_DIM + 3)
model = model.to(DEV)
