In [1]:
"""
Evaluate P-FID between two batches of point clouds.

The point cloud batches should be saved to two npz files, each holding an
arr_0 array of shape [N x K x 3], where N is the number of clouds and K is
the number of points in each cloud.
"""

import argparse

import torch
import numpy as np
from torchvision import transforms
from tqdm.auto import tqdm

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.download import load_checkpoint
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.util.plotting import plot_point_cloud

from point_e.evals.feature_extractor import PointNetClassifier, get_torch_devices
from point_e.evals.fid_is import compute_statistics
from point_e.evals.npz_stream import NpzStreamer

from PIL import Image

import os

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Base model variant. 'base40M' also works; 'base300M' or 'base1B' give better results.
base_name = 'base300M'

# Override the number of conditioning views for both stages before the models
# are built (the configs are mutated in place).
for config_name in (base_name, 'upsample'):
    MODEL_CONFIGS[config_name]["n_views"] = 1

# Stage 1: the base model and its diffusion process.
print('creating base model...')
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_model.eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])

# Stage 2: the upsampler model and its diffusion process.
print('creating upsample model...')
upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
upsampler_model.eval()
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])

# Fetch the pretrained weights for each stage and load them into the models.
print('downloading base checkpoint...')
base_state = load_checkpoint(base_name, device)
base_model.load_state_dict(base_state)

print('downloading upsampler checkpoint...')
upsampler_state = load_checkpoint('upsample', device)
upsampler_model.load_state_dict(upsampler_state)

# The base stage produces `base_points` points; the upsampler adds the ones
# still missing to reach `total_points`.
base_points, total_points = 1024, 4096

sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    num_points=[base_points, total_points - base_points],
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0, 3.0],
)

creating base model...
[-] Low Res Diff Transformer set with 1024+256*1 input tokens.
creating upsample model...
[-] High Res Diff Transformer set with 3072+256*1 input tokens.
downloading base checkpoint...
downloading upsampler checkpoint...


In [3]:
# ---- Evaluation
# Build the PointNet classifier that later cells use to extract features.
# NOTE(review): cache_dir=None presumably selects the library's default cache
# location — confirm against PointNetClassifier.
print("creating classifier...")
clf_devices = get_torch_devices()
clf = PointNetClassifier(devices=clf_devices, cache_dir=None)

creating classifier...


In [4]:
# --- Loading
# Walk the dataset folder. Every "<name>.png" view is expected to sit next to a
# "<name>.npz" archive whose "coords" entry holds the ground-truth point cloud.
# For each pair, the ground-truth cloud is loaded and a cloud is generated from
# the view; both are accumulated and periodically written to disk.
base_path = os.path.join("..", "..", "..", "Datasets", "banner_pcs")
save_each = 2  # write the checkpoint .npz files every `save_each` processed views
step = 0

ground_clouds = None
gen_clouds = None

for root, dirs, files in os.walk(base_path):
    # Sorting in place makes the traversal order (and so the saved arrays)
    # reproducible across runs and platforms.
    dirs.sort()
    for f in sorted(files):

        # Only the rendered views drive the loop; everything else is ignored.
        if not f.endswith(".png"):
            continue

        filename = f[:-4]

        # os.walk recurses into sub-folders, so paths must be built from `root`
        # (the folder currently being visited), not from `base_path`.
        view_path = os.path.join(root, f"{filename}.png")
        cloud_path = os.path.join(root, f"{filename}.npz")

        if not os.path.isfile(cloud_path):
            print(f"skipping {view_path}: no matching ground-truth .npz")
            continue

        # Load the ground point cloud; the context manager closes the archive.
        with np.load(cloud_path) as archive:
            pointcloud = torch.Tensor(archive["coords"]).unsqueeze(0)

        # Predict a point cloud from the view, keeping only the final sample
        # yielded by the progressive sampler. The image file is closed as soon
        # as sampling is done.
        samples = None
        with Image.open(view_path) as view:
            for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[view]))):
                samples = x

        if samples is None:
            raise RuntimeError(f"sampler produced no output for {view_path}")

        # Swap the last two axes and keep the first three channels.
        # NOTE(review): assumes the sampler output is (batch, channels, points)
        # with XYZ first and the RGB aux channels after — confirm.
        # Moving the result to the CPU keeps the accumulated clouds out of GPU
        # memory for the rest of the run.
        samples = samples.permute(0, 2, 1)[:, :, :3].detach().cpu()

        # Append ground truth and prediction together, only once sampling has
        # succeeded, so both sets stay paired even if the loop is interrupted.
        if ground_clouds is None:
            ground_clouds, gen_clouds = pointcloud, samples
        else:
            ground_clouds = torch.cat((ground_clouds, pointcloud), 0)
            gen_clouds = torch.cat((gen_clouds, samples), 0)

        step += 1
        if step % save_each == 0:
            # Periodic checkpoint; a positional array is stored under "arr_0".
            ground = ground_clouds.numpy()
            np.savez("ground.npz", ground)

            gen = gen_clouds.numpy()
            np.savez("gen.npz", gen)

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

KeyboardInterrupt: 

In [None]:
# ---- Computing distances in batch
# Persist both batches as .npz files, extract classifier features for each
# file and report the Frechet distance between the two sets of statistics.
if ground_clouds is None or gen_clouds is None:
    # Nothing was collected (e.g. no .png views were found); fail with a clear
    # message instead of an AttributeError on None.
    raise RuntimeError("no point clouds collected; run the loading cell on a non-empty dataset first")

# .cpu() is a no-op for tensors already on the CPU and keeps this cell working
# when the clouds live on another device.
ground = ground_clouds.cpu().numpy()
np.savez("ground.npz", ground)

gen = gen_clouds.cpu().numpy()[:, :, :3]
np.savez("gen.npz", gen)


def _npz_statistics(npz_path):
    """Return the feature statistics of the point clouds stored in `npz_path`.

    Features come from the module-level classifier `clf`; the predictions it
    also returns are discarded. The feature array is released when the
    function returns, so only the statistics stay alive.
    """
    features, _ = clf.features_and_preds(NpzStreamer(npz_path))
    return compute_statistics(features)


stats_1 = _npz_statistics("ground.npz")
stats_2 = _npz_statistics("gen.npz")

print(f"P-FID: {stats_1.frechet_distance(stats_2)}")