In [128]:
import os

import torch
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset.SDFDataset import SDFDataset
from model import DeepSDF
from loss import DeepSDFLoss

In [125]:
# --- Configuration: every knob a reader might tune lives here. ---

Z_DIM = 256        # latent code size (passed to SDFDataset as z_dim; the model input is Z_DIM + 3)
N_SAMPLES = 2000   # passed to SDFDataset as num_samples — presumably SDF points per shape; confirm
BATCH_SIZE = 8
LR = 1e-5 * BATCH_SIZE  # learning rate scaled linearly with the batch size
DELTA = 1.0        # first argument of DeepSDFLoss — presumably the SDF clamp distance; confirm in loss.py
LATENT_SD = 0.01   # std-dev used to initialise latents (SDFDataset) and also passed to DeepSDFLoss
LATENT_MEAN = 0.0  # mean used to initialise the latent vectors in SDFDataset
EPOCHS = 100

# Seed torch's global RNG so shuffling (DataLoader draws from it by default) and any
# torch-based latent initialisation are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Prefer Apple-silicon MPS (the original target), then CUDA, then CPU, so the
# notebook still runs on machines without an MPS backend instead of crashing.
if torch.backends.mps.is_available():
    DEV = torch.device("mps")
elif torch.cuda.is_available():
    DEV = torch.device("cuda")
else:
    DEV = torch.device("cpu")

In [121]:
# Root directory of the training shapes. Set SDF_DATASET_PATH to point elsewhere
# rather than editing this machine-specific absolute default.
DATASET_PATH = os.environ.get("SDF_DATASET_PATH", "/Users/gursi/Desktop/tight_models")

dataset = SDFDataset(
    dataset_path=DATASET_PATH,
    num_samples=N_SAMPLES,
    z_dim=Z_DIM,
    latent_mean=LATENT_MEAN,
    latent_sd=LATENT_SD
)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=dataset.collate_fn)

# Input width is latent size + 3 — presumably each latent is concatenated with an
# (x, y, z) query point; confirm against model.DeepSDF.
model = DeepSDF(Z_DIM + 3).to(DEV)
crit = DeepSDFLoss(DELTA, LATENT_SD)

# The dataset's latent vectors are optimised jointly with the network weights.
# Adam skips tensors whose .grad stays None, so a latent tensor without
# requires_grad would silently never train — fail fast instead.
assert dataset.latent_vectors.requires_grad, "dataset.latent_vectors must require grad to be optimised"
opt = optim.Adam([dataset.latent_vectors] + list(model.parameters()), lr=LR)

In [123]:
# Training loop: each step updates both the decoder weights and the latent vectors,
# since both were registered with `opt`.
model.train()  # make train-mode explicit in case an earlier cell switched to eval()
for e in range(EPOCHS):
    # tqdm infers the total from len(loader); desc= replaces a separate set_description call.
    loop = tqdm(loader, desc=f"Epoch: {e}")
    for coords, latents, sdfs in loop:
        opt.zero_grad()
        coords, latents, sdfs = coords.to(DEV), latents.to(DEV), sdfs.to(DEV)
        # Call the module rather than .forward() so registered nn.Module hooks run.
        predicted_sdf = model(latents, coords)
        loss = crit(predicted_sdf, sdfs, latents)
        loss.backward()
        opt.step()
        loop.set_postfix(loss=loss.item())

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Epoch: 0:   2%|▏         | 7/385 [00:47<42:40,  6.77s/it, loss=257]  


KeyboardInterrupt: 