In [3]:
# Install required packages.
# The three compiled extensions are pulled from the wheel index for torch 1.9.0 + CUDA 10.2.
# NOTE(review): that index presumably has to match the installed torch build — confirm
# before running against a different torch/CUDA version.
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-cluster -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
# torch-geometric itself is installed without a version pin.
!pip install -q torch-geometric

[K     |████████████████████████████████| 2.6MB 4.3MB/s 
[K     |████████████████████████████████| 1.4MB 3.9MB/s 
[K     |████████████████████████████████| 931kB 4.0MB/s 
[K     |████████████████████████████████| 225kB 4.4MB/s 
[K     |████████████████████████████████| 235kB 36.5MB/s 
[K     |████████████████████████████████| 51kB 6.8MB/s 
[?25h  Building wheel for torch-geometric (setup.py) ... [?25l[?25hdone


In [None]:
# Alternative (disabled): install a released pytorch3d build instead of the repository head.
# !pip install pytorch3d>=0.4.0
# Install pytorch3d straight from the GitHub repository; no tag or commit is pinned.
!pip install "git+https://github.com/facebookresearch/pytorch3d.git"

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

from torchsummary import summary

from torch_geometric.datasets import ModelNet, ShapeNet
from torch_geometric.transforms import Compose, FixedPoints, SamplePoints, NormalizeScale
from torch_geometric.data import DataLoader

import pytorch3d
from pytorch3d.loss.chamfer import chamfer_distance

import numpy as np

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import time

In [12]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")    
# Show the torch version and the selected device as the cell output.
torch.__version__, device

('1.9.0+cu102', device(type='cuda'))

In [225]:
# ShapeNet category to evaluate; also used as the filename prefix of the checkpoints
# loaded in the "Load models from disk" section.
shape_name = 'Lamp'
# Test split of that single category, rooted at the current directory.
dataset = ShapeNet(root='.', categories=shape_name, split='test')

In [226]:
# Seed every RNG the notebook may draw from so that re-running it is repeatable.
SEED = 11
torch.manual_seed(SEED)
# NOTE(review): FixedPoints with its default replace=True is believed to draw its indices
# through NumPy rather than torch, so NumPy's global RNG is seeded as well — confirm
# against the installed torch_geometric version.
np.random.seed(SEED)

n_points = 2048 # number of points to sample
dataset.transform = FixedPoints(num=n_points) # samples points from a point cloud

In [227]:
# Alias for the single-category dataset; the cells further down refer to it by this name.
single_class_data = dataset
# Number of shapes in the split (shown as the cell output).
n_shapes = len(single_class_data)
n_shapes

286

## The Chamfer loss

(based on the Chamfer pseudo-distance)

In [228]:
class ChamferLoss(nn.Module):
    """Module wrapper around pytorch3d's ``chamfer_distance``.

    Args:
        point_reduction: forwarded unchanged as ``point_reduction``.
        batch_reduction: forwarded unchanged as ``batch_reduction``.
    """

    def __init__(self, point_reduction='sum', batch_reduction='mean'):
        super().__init__()
        self.batch_reduction = batch_reduction
        self.point_reduction = point_reduction

    def forward(self, x, y):
        """Return the distance term of ``chamfer_distance(x, y)``.

        ``chamfer_distance`` returns a pair; only its first element is kept.
        """
        # https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/loss/chamfer.html
        distance_and_extra = chamfer_distance(
            x,
            y,
            point_reduction=self.point_reduction,
            batch_reduction=self.batch_reduction,
        )
        distance, _ = distance_and_extra
        return distance

## The Autoencoder (AE)

In [229]:
class Autoencoder(nn.Module):
    """Point-cloud autoencoder: per-point 1x1 convolutions, max-pool, MLP decoder.

    Args:
        encoder_in_ch: number of input channels per point.
        encoder_out_ch: width of the bottleneck code produced by the encoder.
        n_points: number of points the decoder emits (it outputs ``3 * n_points`` values).
    """

    def __init__(self, encoder_in_ch=3, encoder_out_ch=128, n_points=2048):
        super().__init__()
        self.n_points = n_points

        # Channel widths of the five Conv1d -> BatchNorm1d -> ReLU stages.
        # The layers are appended in the same order as a hand-written Sequential,
        # so the state_dict keys (encoder.0 ... encoder.14) are unchanged.
        channel_widths = [encoder_in_ch, 64, 128, 128, 256, encoder_out_ch]
        encoder_layers = []
        for width_in, width_out in zip(channel_widths[:-1], channel_widths[1:]):
            encoder_layers.append(nn.Conv1d(width_in, width_out, 1, 1))
            encoder_layers.append(nn.BatchNorm1d(width_out))
            encoder_layers.append(nn.ReLU(True))
        self.encoder = nn.Sequential(*encoder_layers)

        # The encoder outputs (batch_size x encoder_out_ch x n_points);
        # forward() max-pools over the last axis, so the decoder receives
        # (batch_size x encoder_out_ch).
        self.decoder = nn.Sequential(
            nn.Linear(encoder_out_ch, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 3 * n_points),
        )

    def forward(self, x):
        """Encode ``x``, max-pool over the last axis, decode, and reshape to (batch, 3, -1)."""
        per_point_features = self.encoder(x)
        code, _ = torch.max(per_point_features, dim=-1)
        flat_points = self.decoder(code)
        batch = flat_points.size(0)
        return flat_points.view(batch, 3, -1)

In [230]:
# Build the autoencoder for the same number of points the dataset transform samples,
# instead of repeating the literal 2048, so the two cannot drift apart.
ae_model = Autoencoder(encoder_in_ch=3, encoder_out_ch=128, n_points=n_points).to(device)
# Print a per-layer summary for an input of shape (3, n_points).
summary(ae_model, (3, n_points))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1             [-1, 64, 2048]             256
       BatchNorm1d-2             [-1, 64, 2048]             128
              ReLU-3             [-1, 64, 2048]               0
            Conv1d-4            [-1, 128, 2048]           8,320
       BatchNorm1d-5            [-1, 128, 2048]             256
              ReLU-6            [-1, 128, 2048]               0
            Conv1d-7            [-1, 128, 2048]          16,512
       BatchNorm1d-8            [-1, 128, 2048]             256
              ReLU-9            [-1, 128, 2048]               0
           Conv1d-10            [-1, 256, 2048]          33,024
      BatchNorm1d-11            [-1, 256, 2048]             512
             ReLU-12            [-1, 256, 2048]               0
           Conv1d-13            [-1, 128, 2048]          32,896
      BatchNorm1d-14            [-1, 12

## The Latent GAN (l-GAN)

In [231]:
class LatentGenerator(nn.Module):
    """Two-layer MLP that maps a noise vector to a latent code.

    Both linear layers are followed by a ReLU, so the output is non-negative.

    Args:
        input_dim: size of the input noise vector.
        output_dim: size of the generated latent code.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()

        hidden_dim = 128
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(True),
            nn.Linear(hidden_dim, output_dim),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Run ``x`` through the MLP."""
        return self.model(x)

    def init_params(self, layer):
        """Initialiser meant for ``Module.apply``: Xavier-uniform weights, zero biases.

        Only layers whose type is exactly ``nn.Linear`` are touched.
        """
        if type(layer) != nn.Linear:
            return
        nn.init.xavier_uniform_(layer.weight.data)
        nn.init.zeros_(layer.bias.data)

### Instantiate the l-GAN model




In [232]:
# instantiate the l-GAN generator:
l_gan_gen_input_dim = 128
l_gan_gen_output_dim = 128 # equal to the output_dim of the encoder of the AE model

l_gan_gen = LatentGenerator(l_gan_gen_input_dim, l_gan_gen_output_dim).to(device)

# initialize the parameters:
# (the weights are replaced later when the checkpoint is loaded)
l_gan_gen.apply(l_gan_gen.init_params)

# Mean and standard deviation of the Gaussian noise; the generation cell reuses them.
mu = 0.
sigma = 0.2

# Smoke test: one noise vector through the generator, printing input and output shapes.
z = torch.FloatTensor(1, l_gan_gen_input_dim).normal_(mu, sigma).to(device)

print(z.shape)
print(l_gan_gen(z).shape)

summary(l_gan_gen, (l_gan_gen_input_dim,))

torch.Size([1, 128])
torch.Size([1, 128])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 128]          16,512
              ReLU-2                  [-1, 128]               0
            Linear-3                  [-1, 128]          16,512
              ReLU-4                  [-1, 128]               0
Total params: 33,024
Trainable params: 33,024
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.13
Estimated Total Size (MB): 0.13
----------------------------------------------------------------


## Raw GAN (r-GAN)

(See https://arxiv.org/pdf/1707.02392.pdf)

In [233]:
class Generator(nn.Module):
    """MLP that maps a noise vector directly to a point cloud (raw GAN generator).

    Args:
        input_dim: size of the input noise vector.
        output_dim: number of output values; forward() reshapes them to (batch, 3, -1).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()

        # Hidden widths; each hidden Linear is followed by a ReLU, the last Linear is not.
        # Layers are appended in the original order, so state_dict keys stay model.0 ... model.8.
        hidden_widths = [64, 128, 512, 1024]
        layers = []
        width_in = input_dim
        for width_out in hidden_widths:
            layers.append(nn.Linear(width_in, width_out))
            layers.append(nn.ReLU(True))
            width_in = width_out
        layers.append(nn.Linear(width_in, output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Generate samples from noise ``x`` and return them as (batch_size, 3, n_points)."""
        flat = self.model(x)
        batch = x.size(0)
        # reshape to batch_size x 3 x n_points
        samples = flat.view(batch, 3, -1)

        # TODO: try clipping the output to be in [-1, 1], e.g., by applying tanh
        # samples.clamp_(min=-1, max=1)

        return samples

    def init_params(self, layer):
        """Initialiser meant for ``Module.apply``: Xavier-uniform weights, zero biases.

        Only layers whose type is exactly ``nn.Linear`` are touched.
        """
        if type(layer) != nn.Linear:
            return
        nn.init.xavier_uniform_(layer.weight.data)
        nn.init.zeros_(layer.bias.data)

# Alias used by the rest of the notebook for the raw-GAN generator.
RawGanGenerator = Generator

### Instantiate the r-GAN model


In [234]:
# Size of the noise vector fed to the raw GAN generator.
raw_gan_gen_input_dim = 128
# The generator emits n_points*3 values, which its forward() reshapes to (batch, 3, n_points).
raw_gan_gen = RawGanGenerator(raw_gan_gen_input_dim, n_points*3).to(device)

## Load models from disk

In [235]:
# Restore the trained weights. `map_location=device` remaps the stored tensors onto the
# device selected earlier, so GPU-saved checkpoints also load on a CPU-only machine.
# SECURITY: torch.load unpickles the file — only load checkpoints from a trusted source.
# NOTE(review): the models stay in train mode; call .eval() before using the AE encoder,
# which contains BatchNorm1d layers.
ae_model.load_state_dict(torch.load(shape_name+'_ae_model.tar', map_location=device))

checkpoint = torch.load(shape_name+'_l_gan.tar', map_location=device)
l_gan_gen.load_state_dict(checkpoint['l_gan_generator'])

checkpoint = torch.load(shape_name+'_raw_gan.tar', map_location=device)
raw_gan_gen.load_state_dict(checkpoint['raw_gan_generator'])

<All keys matched successfully>

## Metrics

### Coverage and Fidelity

In [236]:
from tqdm.notebook import tqdm

def coverage(real_clouds, gen_clouds, distance_fn):
    """Fraction of real clouds that are the nearest neighbour of some generated cloud.

    Each generated cloud is matched to the real cloud with the smallest
    ``distance_fn(generated, real)``; the score is the number of distinct matched
    real clouds divided by ``len(real_clouds)``.

    Args:
        real_clouds: sequence of reference clouds.
        gen_clouds: sequence of generated clouds.
        distance_fn: callable ``(generated, real) -> distance``; each result must be
            convertible to a scalar (a number or a one-element tensor).

    Returns:
        float in [0, 1].
    """
    matches = [
        torch.argmin(torch.as_tensor([distance_fn(a, b) for b in real_clouds]))
        for a in gen_clouds
    ]
    # torch.unique needs a tensor, not a Python list of 0-dim tensors.
    score = len(torch.unique(torch.as_tensor(matches))) / len(real_clouds)
    return score


def coverage_2(real_clouds, gen_clouds, distance_fn):
    """Coverage score computed pair by pair, with a progress bar.

    For each generated cloud, finds the index of the real cloud minimising
    ``distance_fn(generated, real)``, then returns the number of distinct
    matched indices divided by ``len(real_clouds)``.
    """
    nearest_real = []
    for gen_cloud in tqdm(gen_clouds):
        dists = torch.as_tensor(
            [distance_fn(gen_cloud, real_cloud) for real_cloud in real_clouds]
        )
        nearest_real.append(torch.argmin(dists))

    distinct_matches = torch.unique(torch.as_tensor(nearest_real))
    return len(distinct_matches) / len(real_clouds)

def coverage_3(real_clouds, gen_clouds, distance_fn, batch_size=50):
    """Coverage score with batched distance evaluation.

    Each generated cloud is compared against the real clouds in chunks of up to
    ``batch_size``: the chunk is concatenated along dim 0 and the generated cloud is
    repeated to the same length, so ``distance_fn`` is called once per chunk and is
    expected to return one distance per pair.

    Args:
        real_clouds: list of real-cloud tensors that can be concatenated along dim 0
            (the notebook passes tensors of shape (1, n_points, 3)).
        gen_clouds: iterable of generated-cloud tensors of the same per-cloud shape.
        distance_fn: batched distance, called as ``distance_fn(generated, real)``.
        batch_size: maximum number of real clouds per ``distance_fn`` call.

    Returns:
        Number of distinct nearest real clouds divided by ``len(real_clouds)``.
    """
    n_real = len(real_clouds)
    nearest_real = []

    for gen_cloud in tqdm(gen_clouds):
        chunk_distances = []
        for start in range(0, n_real, batch_size):
            real_chunk = torch.cat(real_clouds[start:start + batch_size])
            # Tile the generated cloud so it pairs with every real cloud in the chunk.
            tiled_gen = gen_cloud.repeat(len(real_chunk), 1, 1)
            chunk_distances.append(distance_fn(tiled_gen, real_chunk))

        # Index of the closest real cloud over all chunks.
        nearest_real.append(torch.argmin(torch.cat(chunk_distances)))

    distinct_matches = torch.unique(torch.as_tensor(nearest_real))
    return len(distinct_matches) / n_real


def fidelity(real_clouds, gen_clouds, distance_fn):
    """Mean, over the real clouds, of the distance to the closest generated cloud.

    Args:
        real_clouds: sequence of reference clouds.
        gen_clouds: sequence of generated clouds.
        distance_fn: callable ``(real, generated) -> distance``; each result must be
            convertible to a scalar (a number or a one-element tensor).

    Returns:
        0-dim tensor holding the mean minimum distance.
    """
    distances = [[distance_fn(a, b) for b in gen_clouds] for a in real_clouds]
    # min(dim=...) returns a (values, indices) pair; only the values are averaged.
    min_distances = torch.as_tensor(distances).min(dim=-1).values

    score = torch.mean(min_distances)

    return score

def fidelity_2(real_clouds, gen_clouds, distance_fn):
    """Fidelity score computed pair by pair, with a progress bar.

    For each real cloud, takes the smallest ``distance_fn(real, generated)`` over all
    generated clouds, then returns the mean of those minima as a NumPy value.
    """
    closest = []
    for real_cloud in tqdm(real_clouds):
        dists = torch.as_tensor(
            [distance_fn(real_cloud, gen_cloud) for gen_cloud in gen_clouds]
        )
        closest.append(torch.min(dists))

    mean_min = torch.mean(torch.as_tensor(closest))
    return mean_min.cpu().numpy()

def fidelity_3(real_clouds, gen_clouds, distance_fn, batch_size=50):
    """Fidelity score with batched distance evaluation.

    Each real cloud is compared against the generated clouds in chunks of up to
    ``batch_size``: the chunk is concatenated along dim 0 and the real cloud is
    repeated to the same length, so ``distance_fn`` is called once per chunk and is
    expected to return one distance per pair.

    Args:
        real_clouds: iterable of real-cloud tensors (the notebook passes tensors of
            shape (1, n_points, 3)).
        gen_clouds: list of generated-cloud tensors that can be concatenated along dim 0.
        distance_fn: batched distance, called as ``distance_fn(real, generated)``.
        batch_size: maximum number of generated clouds per ``distance_fn`` call.

    Returns:
        Mean over real clouds of the minimum distance, as a NumPy value.
    """
    n_gen = len(gen_clouds)
    closest = []

    for real_cloud in tqdm(real_clouds):
        chunk_distances = []
        for start in range(0, n_gen, batch_size):
            gen_chunk = torch.cat(gen_clouds[start:start + batch_size])
            # Tile the real cloud so it pairs with every generated cloud in the chunk.
            tiled_real = real_cloud.repeat(len(gen_chunk), 1, 1)
            chunk_distances.append(distance_fn(tiled_real, gen_chunk))

        # Distance to the closest generated cloud over all chunks.
        closest.append(torch.min(torch.cat(chunk_distances)))

    mean_min = torch.mean(torch.as_tensor(closest))
    return mean_min.cpu().numpy()

### Collect all samples belonging to the target class

In [237]:
# Load the real clouds one at a time, in dataset order, so each list entry is one cloud.
batch_size = 1
dataloader = DataLoader(dataset=single_class_data,
                        batch_size=batch_size,
                        shuffle=False,
                        drop_last=False)

target_samples = []

# The loader is finite, so no explicit stop condition is needed.
for batch_data in dataloader:
    # Reshape the batch's positions to (batch_size, n_points, 3) and move them to the device.
    target = batch_data.pos.view(batch_size, n_points, 3).to(device)
    target_samples.append(target)

len(target_samples)

286

### Generate a bunch of point clouds

In [243]:
batch_size = 1
# Generate as many clouds as there are real ones.
n_samples = len(single_class_data)

# Selects the generator under evaluation: latent GAN + AE decoder (True) or raw GAN (False).
is_latent_gan = False

# Size the noise for the generator that is actually used.
noise_dim = l_gan_gen_input_dim if is_latent_gan else raw_gan_gen_input_dim

generated_samples = []

# Inference only: skip building autograd graphs.
with torch.no_grad():
    for idx in range(n_samples):
        # sample the noise vector:
        z = torch.Tensor(batch_size, noise_dim).normal_(mu, sigma).to(device)

        if is_latent_gan:
            # generate bottleneck representation, and then decode point cloud
            gen_samples = ae_model.decoder(l_gan_gen(z))
        else:
            gen_samples = raw_gan_gen(z)

        # (batch_size, 3, n_points) -> (batch_size, n_points, 3), the layout of the real clouds.
        gen_samples = gen_samples.view(batch_size, 3, n_points).transpose(2, 1).detach()
        generated_samples.append(gen_samples)

len(generated_samples)

286

In [244]:
# Generated and real clouds must share the same layout before distances are computed.
generated_samples[0].shape, target_samples[0].shape

(torch.Size([1, 2048, 3]), torch.Size([1, 2048, 3]))

## Point cloud distance: Chamfer

In [245]:
# When True, per-point distances are averaged ('mean') rather than summed ('sum').
normalize = True
# batch_reduction=None is passed through to chamfer_distance.
# NOTE(review): the batched metrics rely on this giving one distance per pair — confirm.
pointcloud_distance = ChamferLoss(batch_reduction=None, 
                                  point_reduction='mean' if normalize else 'sum')

## Evaluate coverage

In [246]:
# Number of clouds taken from each set for the evaluation.
n_samples = len(single_class_data)

# Coverage of the real set by the generated set, under the Chamfer distance.
c_score = coverage_3(target_samples[:n_samples], generated_samples[:n_samples], distance_fn=pointcloud_distance, batch_size=256)

print(f"\ncoverage CD: {c_score:.4}")

HBox(children=(FloatProgress(value=0.0, max=286.0), HTML(value='')))



coverage CD: 0.01748


## Evaluate fidelity

In [247]:
# Number of clouds taken from each set for the evaluation.
n_samples = len(single_class_data)

# Mean distance from each real cloud to its closest generated cloud, under the Chamfer distance.
f_score = fidelity_3(target_samples[:n_samples], generated_samples[:n_samples], distance_fn=pointcloud_distance, batch_size=256)

print(f"\nfidelity CD (MMD-CD): {f_score:.4}")

HBox(children=(FloatProgress(value=0.0, max=286.0), HTML(value='')))



fidelity CD (MMD-CD): 0.01961


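### Sanity check

Comparing the real samples against themselves should give the best possible scores: a coverage of 1.0 and a fidelity of 0.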

In [124]:
# Real set against itself: every cloud is its own nearest neighbour.
coverage_3(target_samples[:n_samples], target_samples[:n_samples], distance_fn=pointcloud_distance, batch_size=256)

HBox(children=(FloatProgress(value=0.0, max=704.0), HTML(value='')))




1.0

In [125]:
# Real set against itself: each cloud's closest match is itself.
fidelity_3(target_samples[:n_samples], target_samples[:n_samples], distance_fn=pointcloud_distance, batch_size=256)

HBox(children=(FloatProgress(value=0.0, max=704.0), HTML(value='')))




array(0., dtype=float32)