# Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import torch
from torch.utils.data import DataLoader

from datasets import FromNpDataset, ModelnetDataset
from models import VAE, cd, ENCODER_HIDDEN, DECODER_LAYERS
import transforms

# Prepare dataset

In [None]:
# Training-time transform pipeline: RandomRotation followed by GaussianNoise,
# both constructed with 0.01 (the meaning of that argument is defined in the
# project-local `transforms` module — TODO confirm units).
t = transforms.Compose([
    transforms.RandomRotation(0.01),
    transforms.GaussianNoise(0.01),
])

# The training dataset applies `t`; the test dataset is created with no transform.
train_dataset = ModelnetDataset(transform=t)
test_dataset = ModelnetDataset(transform=None)

# Both loaders use batches of 24 and shuffle. Note that the test loader
# shuffles too, so `next(iter(test_loader))` yields a different batch on each
# fresh iterator.
train_loader = DataLoader(train_dataset, batch_size=24,
                        shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=24,
                        shuffle=True, num_workers=1)

# Visualize some data

In [None]:
def subplot_num(m, i, j):
    """Return the zero-based row-major index of cell (i, j) in a grid with m columns."""
    row_start = i * m
    return row_start + j

def plot_samples(samples, n, m):
    """Draw the first n*m entries of ``samples`` as 3D scatter plots on an n-by-m grid.

    Each sample is split into three equal parts along its first axis, which
    are used as the X, Y and Z coordinates of the scatter plot.  Samples are
    laid out row by row, and the figure is shown once every cell is filled.
    """
    figure = plt.figure(figsize=(18, 18))
    for row in range(n):
        row_start = row * m
        for col in range(m):
            cell = row_start + col
            xs, ys, zs = np.split(samples[cell], 3)
            # Matplotlib subplot positions are 1-based, hence the +1.
            axes = figure.add_subplot(n, m, cell + 1, projection='3d')
            axes.scatter(xs, ys, zs)
    plt.show()
    
def apply_m_times(f, samples, m):
    """Return a flat list holding each sample followed by m results of ``f(sample)``.

    ``f`` is called m separate times on the same original sample (not
    chained), so a non-deterministic ``f`` yields m different variants.
    """
    expanded = []
    for original in samples:
        expanded.append(original)
        expanded.extend(f(original) for _ in range(m))
    return expanded

In [None]:
# Preview the transform pipeline: take the first N samples of one test batch
# and plot each one followed by M outputs of `t` applied to that same sample
# (one grid row per sample, M+1 columns). The samples are converted to NumPy
# before being passed to `t`.
N, M = 4, 3
batch = next(iter(test_loader))
samples = apply_m_times(lambda x: t(x.numpy()), batch[:N], M)
plot_samples(samples, N, M+1)

# Experiments

Reference — FoldingNet (Yang et al., CVPR 2018): http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_FoldingNet_Point_Cloud_CVPR_2018_paper.pdf

In [None]:
# Build/load the VAE from the ENCODER_HIDDEN / DECODER_LAYERS configuration
# exported by `models`; 'ae' presumably names the stored checkpoint — TODO confirm.
model = VAE.load_from_drive(ENCODER_HIDDEN, DECODER_LAYERS, 'ae')

In [None]:
# Preview model outputs: take the first N samples of one test batch and plot
# each one followed by M model outputs for that same sample. Each sample gets
# a batch dimension added before the forward pass; the first element of the
# model's return value is used, with the batch dimension removed again.
N, M = 4, 3
batch = next(iter(test_loader))
# Gradients are not needed for visualisation.
with torch.no_grad():
    samples = apply_m_times(lambda x: model(x.unsqueeze(0))[0].squeeze(0), batch[:N], M)
plot_samples(samples, N, M+1)

In [None]:
def reconstruct_dataset(model, loader):
    """Run ``model`` over every batch of ``loader`` and collect the reconstructions.

    The model is put into eval mode and moved to the GPU when CUDA is
    available (otherwise everything runs on the CPU).  It is always moved
    back to the CPU before returning, even if a batch fails, and it is left
    in eval mode.

    Args:
        model: module whose forward pass returns a 3-tuple; the first element
            is taken as the reconstruction of the input batch.
        loader: iterable yielding input batches as tensors.

    Returns:
        A CPU tensor with all reconstructions concatenated along dim 0, in
        the order the loader produced them.
    """
    # Fall back to the CPU instead of crashing on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.eval()
    model.to(device)
    recs = []

    try:
        with torch.no_grad():
            for batch in loader:
                rec, _, _ = model(batch.to(device))
                # Move each result off the device immediately so device
                # memory holds one batch at a time, not the whole dataset.
                recs.append(rec.cpu())
    finally:
        # Later cells call the model with CPU tensors, so always restore it.
        model.to('cpu')

    return torch.cat(recs, dim=0)

In [None]:
def nearest(sample, loader):
    """Return the smallest ``cd`` distance between ``sample`` and any item in ``loader``.

    The sample is repeated (as a view) to match each batch's size so that
    ``cd`` receives two equally sized batches, mirroring how ``cd_sample``
    calls it.  Runs on the GPU when CUDA is available, otherwise on the CPU.

    Args:
        sample: a single un-batched sample tensor; assumed to be 2-D, since a
            batch dimension is added and expanded over three dimensions.
        loader: iterable yielding batches of candidate samples as tensors.

    Returns:
        The minimum distance as a Python float, or ``inf`` when the loader
        yields no batches.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sample_batch = sample.unsqueeze(0).to(device)
    best = np.inf
    expanded = None
    for batch in loader:
        batch = batch.to(device)
        # Rebuild the expanded view only when the batch size changes
        # (typically just for the final, shorter batch).
        if expanded is None or expanded.shape[0] != batch.shape[0]:
            expanded = sample_batch.expand(batch.shape[0], -1, -1)
        # .item() already returns a host-side Python number.
        best = min(best, torch.min(cd(batch, expanded)).item())
    return best

def rec_nearest(data, loader):
    """Return a list with, for every sample in ``data``, its ``nearest`` distance to ``loader``.

    Samples are taken by index along the first dimension of ``data`` and the
    results are returned in that same order.
    """
    distances = []
    for idx in range(data.shape[0]):
        distances.append(nearest(data[idx], loader))
    return distances

In [None]:
def cd_sample(data, idx, max_batch):
    """Return the ``cd`` distances between ``data[idx]`` and every earlier sample.

    The earlier samples ``data[:idx]`` are processed in chunks of at most
    ``max_batch`` items; for each chunk the reference sample is expanded (as
    a view) to the chunk size so ``cd`` receives two equally sized batches.
    Runs on the GPU when CUDA is available, otherwise on the CPU.

    Args:
        data: tensor of samples indexed along dim 0; each sample is assumed
            to be 2-D, since the expansion uses three dimensions.
        idx: index of the reference sample.
        max_batch: maximum number of samples sent to ``cd`` at once.

    Returns:
        A CPU tensor with the concatenated distances, in the order of
        ``data[:idx]``; an empty tensor when there are no earlier samples.

    Raises:
        ValueError: if ``max_batch`` is not positive.
    """
    if max_batch <= 0:
        # A non-positive chunk size could never make progress.
        raise ValueError("max_batch must be positive")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sample = data[idx].unsqueeze(0).to(device)
    results = []
    for start in range(0, idx, max_batch):
        stop = min(start + max_batch, idx)
        chunk = data[start:stop].to(device)
        sample_batch = sample.expand(stop - start, -1, -1)
        # Move each chunk's distances to the host right away to keep device
        # memory bounded by one chunk.
        results.append(cd(chunk, sample_batch).cpu())

    if not results:
        # torch.cat rejects an empty list; no earlier samples means no distances.
        return torch.empty(0)
    return torch.cat(results, dim=0)

def cd_dataset(data, max_batch):
    """Build the symmetric matrix of pairwise ``cd`` distances for ``data``.

    The matrix is float64 and starts filled with ``inf``.  For every sample,
    the distances to all earlier samples (from ``cd_sample``) are written
    into both the matching row and column.  The diagonal is never written, so
    it stays ``inf`` and a sample is never its own nearest neighbour.
    """
    count = len(data)
    dist = torch.full((count, count), np.inf, dtype=torch.float64)
    for col in range(1, count):
        to_earlier = cd_sample(data, col, max_batch)
        dist[:col, col] = to_earlier
        dist[col, :col] = to_earlier
    return dist
            
def dataset_nearest(data, max_batch=64):
    """Return, for each sample in ``data``, the smallest distance to any other sample.

    Takes the per-row minimum of the pairwise matrix from ``cd_dataset``;
    ``max_batch`` is forwarded to it unchanged.
    """
    pairwise = cd_dataset(data, max_batch)
    row_min, _ = torch.min(pairwise, dim=1)
    return row_min

In [None]:
def cov(model, dataset, batch_size=64):
    """Return the fraction of dataset samples that are "covered" by a reconstruction.

    A sample counts as covered when its nearest reconstruction (by ``cd``)
    is no farther away than its nearest *other* sample of the dataset.

    Args:
        model: model passed to ``reconstruct_dataset``.
        dataset: dataset usable with ``DataLoader`` that also exposes its
            samples as a NumPy array in ``dataset.data``.
        batch_size: batch size for both loaders and for the chunked pairwise
            distance computation.

    Returns:
        A 0-dim float tensor in [0, 1].
    """
    with torch.no_grad():
        # shuffle=False keeps reconstructions in dataset order.
        data_loader = DataLoader(dataset, batch_size=batch_size,
                                 shuffle=False, num_workers=2)
        rec = reconstruct_dataset(model, data_loader).numpy()
        rec_loader = DataLoader(FromNpDataset(rec), batch_size=batch_size,
                                shuffle=False, num_workers=2)

        data = torch.from_numpy(dataset.data)

        d_dataset = dataset_nearest(data, max_batch=batch_size)
        # rec_nearest returns a plain list; turn it into a tensor so the
        # comparison below is element-wise.
        d_rec = torch.tensor(rec_nearest(data, rec_loader),
                             dtype=d_dataset.dtype)

        covered = d_rec <= d_dataset
        # Convert the count to float before dividing so the ratio is not
        # truncated by integer division on older torch versions.
        return covered.sum().float() / len(dataset)

In [None]:
# Report the coverage score of the loaded model on the test dataset
# (the one created without a transform).
print(cov(model, test_dataset))