In [1]:
import os
import time
import math
import argparse
import torch
from tqdm.auto import tqdm

from utils.dataset import *
from utils.misc import *
from utils.data import *
from models.vae_gaussian import *
from models.vae_flow import *
from models.flow import add_spectral_norm, spectral_norm_power_iteration
from evaluation import *

def normalize_point_clouds(pcs, mode, logger):
    """Normalize every point cloud of a batch, writing the result back into ``pcs``.

    Args:
        pcs: Tensor indexed by cloud along dim 0; each ``pcs[i]`` is one cloud
            whose last dimension holds the 3 coordinates (the shifts below are
            reshaped to ``(1, 3)``).
        mode: ``None`` to skip normalization, ``'shape_unit'`` to subtract the
            per-axis mean and divide by the standard deviation taken over all
            coordinates of the cloud, or ``'shape_bbox'`` to center the cloud on
            the midpoint of its bounding box and divide by half of the longest
            bounding-box side.
        logger: Object exposing ``info(msg)``; used for progress messages.

    Returns:
        The same tensor object ``pcs``. Unless ``mode`` is ``None`` its content
        has been overwritten in place with the normalized clouds.

    Raises:
        ValueError: If ``mode`` is not one of the supported values. (Previously
            an unknown mode surfaced as an ``UnboundLocalError`` on ``shift``
            while processing the first cloud.)
    """
    if mode is None:
        logger.info('Will not normalize point clouds.')
        return pcs
    # Reject unknown modes before touching the data.
    if mode not in ('shape_unit', 'shape_bbox'):
        raise ValueError(
            "Unknown normalization mode %r (expected None, 'shape_unit' or 'shape_bbox')" % (mode,)
        )
    logger.info('Normalization mode: %s' % mode)
    for i in tqdm(range(pcs.size(0)), desc='Normalize'):
        pc = pcs[i]
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        else:  # 'shape_bbox'
            pc_max, _ = pc.max(dim=0, keepdim=True) # (1, 3)
            pc_min, _ = pc.min(dim=0, keepdim=True) # (1, 3)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        # NOTE: a cloud with zero extent gives scale == 0 and therefore
        # non-finite values; this is not guarded against here.
        pcs[i] = (pc - shift) / scale
    return pcs

In [103]:

# Arguments
class Addict(dict):
    """Dictionary whose entries can also be read and written as attributes.

    The instance ``__dict__`` is aliased to the dictionary itself, so
    ``obj.key`` and ``obj['key']`` always refer to the same storage and
    ``vars(obj)`` returns the mapping.
    """

    def __init__(self, *args, **kwargs):
        # Populate the mapping exactly like a plain dict would.
        dict.__init__(self, *args, **kwargs)
        # Route attribute access through the mapping.
        self.__dict__ = self

# Run configuration; stands in for the command-line arguments of the original script.
args = Addict({
    "ckpt": "./pretrained/GEN_chair.pt",
    "categories": ["chair"],
    "save_dir": "./results",
    "device": "cuda",
    "dataset_path": "./data/shapenet.hdf5",
    "batch_size": 128,
    "sample_num_points": 5*2048,
    "normalize": "shape_bbox",
    "seed": 9988
})

# Logging
# Each execution of this cell creates a fresh, timestamped output directory.
save_dir = os.path.join(args.save_dir, 'GEN_Ours_%s_%d' % ('_'.join(args.categories), int(time.time())) )
os.makedirs(save_dir, exist_ok=True)

# Re-running this cell used to print every log line once per previous run.
# NOTE(review): this assumes get_logger attaches its handlers to the standard
# library logger named 'test' -- confirm against utils.misc.get_logger.
# Dropping the stale handlers first keeps one copy of each message.
import logging
_stale_logger = logging.getLogger('test')
for _handler in list(_stale_logger.handlers):
    _stale_logger.removeHandler(_handler)
    _handler.close()

logger = get_logger('test', save_dir)
# vars(args) is the Addict mapping itself, so this logs every configured option.
for k, v in vars(args).items():
    logger.info('[ARGS::%s] %s' % (k, repr(v)))



[2023-04-06 14:04:20,734::test::INFO] [ARGS::ckpt] './pretrained/GEN_chair.pt'
[2023-04-06 14:04:20,734::test::INFO] [ARGS::ckpt] './pretrained/GEN_chair.pt'
[2023-04-06 14:04:20,734::test::INFO] [ARGS::ckpt] './pretrained/GEN_chair.pt'
[2023-04-06 14:04:20,734::test::INFO] [ARGS::ckpt] './pretrained/GEN_chair.pt'
[2023-04-06 14:04:20,735::test::INFO] [ARGS::categories] ['chair']
[2023-04-06 14:04:20,735::test::INFO] [ARGS::categories] ['chair']
[2023-04-06 14:04:20,735::test::INFO] [ARGS::categories] ['chair']
[2023-04-06 14:04:20,735::test::INFO] [ARGS::categories] ['chair']
[2023-04-06 14:04:20,736::test::INFO] [ARGS::save_dir] './results'
[2023-04-06 14:04:20,736::test::INFO] [ARGS::save_dir] './results'
[2023-04-06 14:04:20,736::test::INFO] [ARGS::save_dir] './results'
[2023-04-06 14:04:20,736::test::INFO] [ARGS::save_dir] './results'
[2023-04-06 14:04:20,737::test::INFO] [ARGS::device] 'cuda'
[2023-04-06 14:04:20,737::test::INFO] [ARGS::device] 'cuda'
[2023-04-06 14:04:20,737::te

# Original main code

In [55]:
# Checkpoint
# NOTE: torch.load unpickles the file (the checkpoint carries an ``args``
# object next to the weights); only load checkpoints from a trusted source.
ckpt = torch.load(args.ckpt)
seed_all(args.seed)

# Datasets and loaders
logger.info('Loading datasets...')
test_dset = ShapeNetCore(
    path=args.dataset_path,
    cates=args.categories,
    split='test',
    scale_mode=args.normalize,
)
test_loader = DataLoader(test_dset, batch_size=args.batch_size, num_workers=0)

# Model
logger.info('Loading model...')
if ckpt['args'].model == 'gaussian':
    model = GaussianVAE(ckpt['args']).to(args.device)
elif ckpt['args'].model == 'flow':
    model = FlowVAE(ckpt['args']).to(args.device)
else:
    # Without this branch an unexpected model type left `model` undefined and
    # the failure only showed up as a NameError on the next line.
    raise ValueError('Unsupported model type in checkpoint: %r' % (ckpt['args'].model,))
logger.info(repr(model))
# if ckpt['args'].spectral_norm:
#     add_spectral_norm(model, logger=logger)
model.load_state_dict(ckpt['state_dict'])

# Reference Point Clouds
# One entry per test item, concatenated along a new leading dimension.
ref_pcs = torch.cat([data['pointcloud'].unsqueeze(0) for data in test_dset], dim=0)

# Generate Point Clouds
# Whole batches are sampled and the surplus is cut off afterwards, so the
# generated set has exactly as many clouds as the test set.
gen_pcs = []
num_batches = math.ceil(len(test_dset) / args.batch_size)
with torch.no_grad():
    for _ in tqdm(range(num_batches), 'Generate'):
        z = torch.randn([args.batch_size, ckpt['args'].latent_dim]).to(args.device)
        x = model.sample(z, args.sample_num_points, flexibility=ckpt['args'].flexibility)
        gen_pcs.append(x.detach().cpu())
gen_pcs = torch.cat(gen_pcs, dim=0)[:len(test_dset)]

if args.normalize is not None:
    gen_pcs = normalize_point_clouds(gen_pcs, mode=args.normalize, logger=logger)


[2023-04-05 19:12:57,456::test::INFO] Loading datasets...
[2023-04-05 19:12:57,456::test::INFO] Loading datasets...
[2023-04-05 19:12:57,456::test::INFO] Loading datasets...
[2023-04-05 19:12:57,538::test::INFO] Loading model...
[2023-04-05 19:12:57,538::test::INFO] Loading model...
[2023-04-05 19:12:57,538::test::INFO] Loading model...
[2023-04-05 19:12:57,570::test::INFO] FlowVAE(
  (encoder): PointNetEncoder(
    (conv1): Conv1d(3, 128, kernel_size=(1,), stride=(1,))
    (conv2): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
    (conv3): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
    (conv4): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
    (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, 

Generate:   0%|          | 0/4 [00:00<?, ?it/s]

KeyboardInterrupt: 

# Barebones Generation Code

In [104]:
# Checkpoint
# NOTE: torch.load unpickles the file (the checkpoint carries an ``args``
# object next to the weights); only load checkpoints from a trusted source.
ckpt = torch.load(args.ckpt)
seed_all(args.seed)

# Model
logger.info('Loading model...')
if ckpt['args'].model == 'gaussian':
    model = GaussianVAE(ckpt['args']).to(args.device)
elif ckpt['args'].model == 'flow':
    model = FlowVAE(ckpt['args']).to(args.device)
else:
    # Without this branch an unexpected model type left `model` undefined (or
    # silently reused a model from an earlier cell run).
    raise ValueError('Unsupported model type in checkpoint: %r' % (ckpt['args'].model,))
logger.info(repr(model))
# if ckpt['args'].spectral_norm:
#     add_spectral_norm(model, logger=logger)
# Kept as the last expression so the notebook displays the key-matching result.
model.load_state_dict(ckpt['state_dict'])

[2023-04-06 14:04:26,724::test::INFO] Loading model...
[2023-04-06 14:04:26,724::test::INFO] Loading model...
[2023-04-06 14:04:26,724::test::INFO] Loading model...
[2023-04-06 14:04:26,724::test::INFO] Loading model...
[2023-04-06 14:04:26,755::test::INFO] FlowVAE(
  (encoder): PointNetEncoder(
    (conv1): Conv1d(3, 128, kernel_size=(1,), stride=(1,))
    (conv2): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
    (conv3): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
    (conv4): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
    (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc1_m): Linear(in_features=512, out_features=256, bias=True)
    (fc2_m): Linear(in_feat

<All keys matched successfully>

In [105]:
# Generate Point Clouds
# `z` and `x` are left in the notebook namespace on purpose: the
# "Broken Down" cell reads both of them.
with torch.no_grad():
    z = torch.randn([3, ckpt['args'].latent_dim]).to(args.device)
    x = model.sample(z, args.sample_num_points, flexibility=ckpt['args'].flexibility)

# Concatenating a one-element list yields a CPU tensor that is separate from
# `x`, so the in-place normalization below cannot modify `x`.
gen_pcs = torch.cat([x.detach().cpu()], dim=0)

if args.normalize is not None:
    gen_pcs = normalize_point_clouds(gen_pcs, mode=args.normalize, logger=logger)


[2023-04-06 14:04:27,770::test::INFO] Normalization mode: shape_bbox
[2023-04-06 14:04:27,770::test::INFO] Normalization mode: shape_bbox
[2023-04-06 14:04:27,770::test::INFO] Normalization mode: shape_bbox
[2023-04-06 14:04:27,770::test::INFO] Normalization mode: shape_bbox


Normalize:   0%|          | 0/3 [00:00<?, ?it/s]

# Broken Down

In [132]:
# Runs only the diffusion sampling step on a chosen starting cloud `x_T` and
# keeps both the starting cloud (`samples`) and the result (`gen_pcs`).
#
# Hidden state: neither `x` nor `z` is defined in this cell. `z` is assigned in
# the generation cell; `x` is assigned there and reassigned by the
# data-manipulation cell. This cell also overwrites `x`, so running it again
# uses its own previous output as `x_T`.
gen_pcs = []
samples = []
with torch.no_grad():
    # Full sampling from a fresh latent code (disabled):
    # z = torch.randn([1, ckpt['args'].latent_dim]).to(args.device)
    # x = model.sample(z, args.sample_num_points, flexibility=ckpt['args'].flexibility)
    
    # Starting cloud: whatever `x` currently holds in the kernel.
    x_T = x
    # Alternative starting cloud built from Gaussian noise (disabled):
    # x_T = torch.randn([z.size(0), args.sample_num_points, 3]).to(args.device)
    # x_T[x_T[:, :, 0]>0] *= -1
    # x_T[:, :, 0] += 3
    
    # NOTE(review): `sample_this` is not defined in this notebook; presumably it
    # runs the sampler starting from `x_T` with `z` as context -- confirm in the
    # model code.
    x = model.diffusion.sample_this(x_T, context=z, flexibility=ckpt['args'].flexibility)

    gen_pcs.append(x.detach().cpu())
    samples.append(x_T.detach().cpu())
gen_pcs = torch.cat(gen_pcs, dim=0)
samples = torch.cat(samples, dim=0)

# Only the generated clouds are normalized (in place); `samples` is left as is.
if args.normalize is not None:
    gen_pcs = normalize_point_clouds(gen_pcs, mode=args.normalize, logger=logger)

[2023-04-06 14:39:38,490::test::INFO] Normalization mode: shape_bbox
[2023-04-06 14:39:38,490::test::INFO] Normalization mode: shape_bbox
[2023-04-06 14:39:38,490::test::INFO] Normalization mode: shape_bbox
[2023-04-06 14:39:38,490::test::INFO] Normalization mode: shape_bbox


Normalize:   0%|          | 0/3 [00:00<?, ?it/s]

# Visualize

In [133]:
%matplotlib qt

# Plot 3D Point Cloud using matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

def plot_3d_point_cloud(x, y, z, ax=None, title='', elev=10, azim=240, axis_off=False):
    """Draw one point cloud as a 3-D scatter plot.

    Args:
        x, y, z: Coordinates of the points. Torch tensors are copied to CPU
            numpy arrays; any other value is handed to matplotlib unchanged.
        ax: Axes to draw on. When ``None`` a new figure with a single 3-D
            subplot is created.
        title: Title set on the axes.
        elev, azim: Viewing angles forwarded to ``ax.view_init``.
        axis_off: When true, the axes decorations are switched off.

    Returns:
        The axes that were drawn on.
    """
    coords = []
    for values in (x, y, z):
        if isinstance(values, torch.Tensor):
            values = values.clone().detach().cpu().numpy()
        coords.append(values)
    xs, ys, zs = coords

    if ax is None:
        ax = plt.figure().add_subplot(111, projection='3d')

    ax.view_init(elev=elev, azim=azim)
    if axis_off:
        ax.axis('off')
    ax.set_title(title)

    # Marker colour follows the x coordinate (scaled by 1/10).
    ax.scatter(xs, ys, zs, s=0.5, c=xs / 10)
    return ax

print(samples[0].shape)
# Index of the cloud to display; the data-manipulation cell reuses `i`.
i = 0

# unbind splits the coordinate columns into separate x, y, z tensors.
plot_3d_point_cloud(*samples[i].unbind(dim=-1))  # starting cloud
plot_3d_point_cloud(*gen_pcs[i].unbind(dim=-1))  # generated cloud

torch.Size([8604, 3])


<Axes3D: >

# Data manipulation

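The cell below modifies the generated data to run experiments on the diffusion process: it keeps only the points whose first coordinate is positive, truncates every cloud to the same number of points, and adds Gaussian noise (scaled by 0.05).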

In [131]:
# Keep, for every cloud in `x`, only the points whose first coordinate is positive.
xnew = [cloud[cloud[:, 0] > 0] for cloud in x]
# The kept parts differ in length; cut them all to the shortest so they stack.
num_pts = min(part.shape[0] for part in xnew)
xnew = torch.stack([part[:num_pts] for part in xnew])

# Add noise
xnew = xnew + 0.05 * torch.randn(xnew.shape).to(xnew.device)

# `i` (the cloud index) and `plot_3d_point_cloud` come from the visualization cell.
plot_3d_point_cloud(*xnew[i].unbind(dim=-1), title='Noisy partial points') # x, y, z

# Publish the result as `x`, which the "Broken Down" cell uses as its starting cloud.
x = xnew

# Evaluate and save

In [None]:

# Save
# Requires `np` from the visualization cell and `ref_pcs` from the
# "Original main code" cell to be present in the kernel.
logger.info('Saving point clouds...')
np.save(os.path.join(save_dir, 'out.npy'), gen_pcs.numpy())

# Compute metrics
with torch.no_grad():
    results = {
        name: value.item()
        for name, value in compute_all_metrics(
            gen_pcs.to(args.device), ref_pcs.to(args.device), args.batch_size
        ).items()
    }
    jsd = jsd_between_point_cloud_sets(gen_pcs.cpu().numpy(), ref_pcs.cpu().numpy())
    results['jsd'] = jsd

# One log line per metric.
for k, v in results.items():
    logger.info('%s: %.12f' % (k, v))
