In [None]:
from builtins import range
import numpy as np
from numpy.linalg import norm
import matplotlib.pylab as plt
from mpl_toolkits.mplot3d import Axes3D
import warnings
warnings.filterwarnings('ignore')

# plotting function
def plot_3d_point_cloud(
    pc,
    show=True,
    show_axis=True,
    in_u_sphere=True,
    marker=".",
    c="b",
    s=4,
    alpha=0.8,
    figsize=(5, 5),
    elev=10,
    azim=10,
    miv=None,
    mav=None,
    squeeze=0.8,
    axis=None,
    title=None,
    *args,
    **kwargs
):
    """Scatter-plot one point cloud on 3D axes.

    Args:
        pc: cloud of shape (N, C>=3) or (1, N, C>=3). Either a numpy array or a
            torch tensor on any device, with or without gradients. Tensors are
            detached and copied to host memory before plotting.
        show: call ``plt.show()`` at the end.
        show_axis: if False, hide the axes.
        in_u_sphere: fix every axis range to [-0.5, 0.5]. Otherwise use
            ``miv``/``mav``, or derive them from the data.
        marker, c, s, alpha: forwarded to ``scatter``. When ``c`` holds numeric
            values mapped through a colormap, a colorbar is added.
        figsize: size of the new figure (unused when ``axis`` is given).
        elev, azim: camera elevation / azimuth passed to ``view_init``.
        miv, mav: lower / upper limit shared by all three axes when
            ``in_u_sphere`` is False. Computed from the data when None.
        squeeze: factor applied to the data min / max when deriving
            ``miv`` / ``mav``, to trim empty space.
        axis: existing 3D axes to draw into. A new figure is created if None.
        title: optional axes title.
        *args, **kwargs: extra arguments forwarded to ``scatter``.

    Returns:
        ``(fig, miv, mav)``: the figure holding the plot, and the axis limits
        that were applied. When ``axis`` is given, ``fig`` is the figure that
        owns it.
    """
    if hasattr(pc, "detach"):
        # torch tensor: drop autograd history and move to host memory so that
        # numpy and matplotlib can read it.
        pc = pc.detach().cpu().numpy()
    points = np.asarray(pc)
    if points.ndim == 3:
        points = points.squeeze(0)  # (1, N, C) -> (N, C)

    # Axes are permuted: cloud column 2 -> plot x, column 0 -> plot y,
    # column 1 -> plot z.
    x, y, z = points[:, 2], points[:, 0], points[:, 1]

    if axis is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection="3d")
    else:
        ax = axis
        fig = ax.figure  # the figure owning the caller's axes (not the axes itself)

    if title is not None:
        ax.set_title(title)  # title the axes we draw on, whatever plt.gca() is

    sc = ax.scatter(x, y, z, *args, marker=marker, c=c, s=s, alpha=alpha, **kwargs)
    ax.view_init(elev=elev, azim=azim)

    if in_u_sphere:
        miv, mav = -0.5, 0.5
    else:
        # Multiply with 'squeeze' to squeeze free-space.
        if miv is None:
            miv = squeeze * np.min([x.min(), y.min(), z.min()])
        if mav is None:
            mav = squeeze * np.max([x.max(), y.max(), z.max()])
    ax.set_xlim(miv, mav)
    ax.set_ylim(miv, mav)
    ax.set_zlim(miv, mav)
    fig.tight_layout()

    if not show_axis:
        ax.set_axis_off()

    # `c` is an explicit parameter, so it never appears in **kwargs. Decide on
    # a colorbar from the scatter itself instead: get_array() is non-None only
    # when the colours come from numeric values mapped through a colormap.
    if sc.get_array() is not None:
        fig.colorbar(sc, ax=ax)

    if show:
        plt.show()

    return fig, miv, mav

In [None]:
import os
# Show the working directory: the relative paths used below ("..", "../facescape/log/...")
# are resolved against it.
os.getcwd()

In [None]:
# import
import torch
import torchvision
import sys
sys.path.append("..")
from models import pointnet_cls, pointnet_utils
from src import FPSSampler, SampleNet
from src.pctransforms import OnUnitCube, PointcloudRandomInputDropout, PointcloudToTensor
from src.chamfer_distance import ChamferDistance
from data.facescape_loader import FaceScape

# device: use the GPU when there is one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so that the shuffled test sample drawn below (and any
# torch randomness afterwards) is the same on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# dataset
# NOTE(review): the first FaceScape argument is presumably the number of points
# per cloud (it matches "all_annotations1024.npy") — confirm.
transforms = torchvision.transforms.Compose([PointcloudToTensor(), OnUnitCube()])
testset = FaceScape(
    1024,
    transforms=transforms,
    train=False,
    annotations="all_annotations1024.npy",
    contrastive=False,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    pin_memory=False,
)

# add noise function
def add_noise(data, model,
              learn_noise=False,
              resample=False,
              pointwise=False,
              sigma=0.):
    """Add noise to the points of a ``(points, target)`` pair.

    Modes, checked in this order:
      * ``learn_noise=False``: zero-mean Gaussian noise with std ``sigma``.
      * ``resample=True``: noise drawn with ``rsample`` from
        N(model.noise_mean[0], |model.noise_std[0]|).
      * ``pointwise=True``: noise drawn with ``rsample`` from N(loc, |scale|),
        where ``scale, loc = model.get_std(points)``.
      * otherwise: the fixed offset ``model.noise[0]`` is added.

    Args:
        data: indexable whose item 0 is the point tensor and item 1 the target.
        model: sampler holding the learned noise parameters. It is unused when
            ``learn_noise`` is False.
        learn_noise, resample, pointwise: select the mode (see above).
        sigma: std of the Gaussian noise when ``learn_noise`` is False.

    Returns:
        ``(noisy_points, data[1])``.
    """
    # Local import, so this function does not depend on a later cell having
    # imported ``torch.distributions as tod``.
    import torch.distributions as tod

    points, target = data[0], data[1]

    if not learn_noise:
        # `data` itself is a tuple; the device has to come from the point tensor.
        noise = torch.normal(mean=0.0, std=sigma, size=points.shape).to(points.device)
    elif resample or pointwise:
        if resample:
            scale, loc = model.noise_std[0].abs(), model.noise_mean[0]
        else:
            scale, loc = model.get_std(points)
        # +1e-10 keeps the scale strictly positive, as Normal requires.
        # rsample() keeps the draw differentiable w.r.t. loc and scale.
        noise = tod.Normal(loc=loc, scale=scale.abs() + 1e-10).rsample()
    else:
        noise = model.noise[0]

    return points + noise, target
    

In [None]:
# Draw one sample from the shuffled test loader (batch_size=1). The loader
# yields three items; only the first two (inputs, target) are used here.
(inputs, target, _) = next(iter(testloader))

# plot the complete input cloud
_ = plot_3d_point_cloud(
    inputs,
    title="Complete input point cloud",
)


## Plot Exp Gender: sampled point clouds for the expression (exp) and gender tasks

In [None]:
# Task classifiers sharing one PointNet configuration. The first get_model
# argument (20 for exp, 2 for gender) presumably sets the number of classes.
pointnet_exp, pointnet_gender = (
    pointnet_cls.get_model(n_out, use_enc_stn=True, num_in_points=1024).eval().to(device)
    for n_out in (20, 2)
)

# Learned sampler with 64 output points and no learned noise.
# NOTE(review): built with device="cpu" and then moved to `device`. Confirm that
# SampleNet's `device` argument does not pin internal tensors to the CPU.
sampler_s = SampleNet(
    num_out_points=64,
    bottleneck_size=128,
    group_size=7,
    initial_temperature=1.0,
    input_shape="bnc",
    output_shape="bnc",
    use_stn=True,
    device="cpu",
    learn_noise=False,
    pointwise_dist=False,
).eval().to(device)

# FPS baseline sampler with the same number of output points.
sampler_f = FPSSampler(
    num_out_points=64,
    permute=True,
    input_shape="bnc",
    output_shape="bnc",
    learn_noise=False,
    pointwise_dist=False,
)

In [None]:
# plot sampled point cloud exp: 64 points
EXP_SAMPLENET64 = "../facescape/log/exp/SAMPLENET64/train_sampler_best.pth"
EXP_POINTNET64 = "../facescape/log/exp/SAMPLENET64/train_model_best.pth"
EXP_FPS64 = "../facescape/log/exp/FPS64/train_sampler_best.pth"
EXP_POINTNET64_FPS = "../facescape/log/exp/FPS64/train_model_best.pth"


def run_exp_samplenet(inputs, target, sampler, model_path, sampler_path):
    """Load expression-task checkpoints, then sample ``inputs`` with ``sampler``.

    Side effects: loads ``model_path`` into the global ``pointnet_exp`` and
    ``sampler_path`` into ``sampler``.

    Args:
        inputs: batch of input clouds. Moved to ``device`` here.
        target: labels of ``inputs``. Passed through unchanged.
        sampler: sampler whose weights are stored at ``sampler_path``.
        model_path: checkpoint for ``pointnet_exp``.
        sampler_path: checkpoint for ``sampler``. A path containing "FPS"
            selects the FPS branch.

    Returns:
        ``(projected_points, target)`` in the non-FPS branch. In the FPS
        branch, the raw output of ``sampler``. NOTE(review): that may not be a
        ``(points, target)`` pair like the other branch — confirm what callers
        expect.
    """
    # map_location: put the weights on this notebook's device, whichever
    # device the checkpoints were saved from.
    pointnet_exp.load_state_dict(torch.load(model_path, map_location=device))
    sampler.load_state_dict(torch.load(sampler_path, map_location=device))

    if "FPS" in sampler_path:
        return sampler(inputs.to(device))

    # The sampler returns a 3-tuple; only its second item (projected points) is used.
    _, p0_projected, _ = sampler(inputs.to(device))
    return p0_projected, target

# Sample with the trained expression SampleNet, then classify the sampled cloud.
# no_grad: this is inference only, and it leaves the outputs free of autograd
# history, so they can be converted to numpy for plotting.
with torch.no_grad():
    sampled_data = run_exp_samplenet(inputs, target, sampler_s, EXP_POINTNET64, EXP_SAMPLENET64)
    p0_projected, target = sampled_data
    x, trans_feat = pointnet_exp(p0_projected)

# plot sampled data
_ = plot_3d_point_cloud(p0_projected.cpu(), title="Sampled point cloud (exp)")

# Accuracy on this batch. Column -1 of `target` is the label compared with the
# expression prediction.
pred = x.argmax(dim=1)
t = target[:, -1]
total = t.size(0)
correct = float((pred.cpu() == t.cpu()).sum().item()) / total

print(correct, pred, t)


In [None]:
GENDER_FPS64 = "../facescape/log/gender/FPS64/train_sampler_best.pth"
GENDER_SAMPLENET64 = "../facescape/log/gender/SAMPLENET64/train_sampler_best.pth"
GENDER_POINTNET64 = "../facescape/log/gender/SAMPLENET64/train_model_best.pth"
GENDER_POINTNET64_FPS = "../facescape/log/gender/FPS64/train_model_best.pth"

def run_gender_samplenet(inputs, target, sampler, model_path, sampler_path):
    """Load gender-task checkpoints, then sample ``inputs`` with ``sampler``.

    Side effects: loads ``model_path`` into the global ``pointnet_gender`` and
    ``sampler_path`` into ``sampler``.

    Args:
        inputs: batch of input clouds. Moved to ``device`` here.
        target: labels of ``inputs``. Passed through unchanged.
        sampler: sampler whose weights are stored at ``sampler_path``.
        model_path: checkpoint for ``pointnet_gender``.
        sampler_path: checkpoint for ``sampler``. A path containing "FPS"
            selects the FPS branch.

    Returns:
        ``(projected_points, target)`` in the non-FPS branch. In the FPS
        branch, the raw output of ``sampler``. NOTE(review): that may not be a
        ``(points, target)`` pair like the other branch — confirm what callers
        expect.
    """
    # map_location: put the weights on this notebook's device, whichever
    # device the checkpoints were saved from.
    pointnet_gender.load_state_dict(torch.load(model_path, map_location=device))
    sampler.load_state_dict(torch.load(sampler_path, map_location=device))

    if "FPS" in sampler_path:
        return sampler(inputs.to(device))

    # The sampler returns a 3-tuple; only its second item (projected points) is used.
    _, p0_projected, _ = sampler(inputs.to(device))
    return p0_projected, target


# Sample with the trained gender SampleNet, then classify the sampled cloud.
# no_grad: this is inference only, and it leaves the outputs free of autograd
# history, so they can be converted to numpy for plotting.
with torch.no_grad():
    sampled_data = run_gender_samplenet(inputs, target, sampler_s, GENDER_POINTNET64, GENDER_SAMPLENET64)
    p0_projected, target = sampled_data
    x, trans_feat = pointnet_gender(p0_projected)

# plot sampled data
_ = plot_3d_point_cloud(p0_projected.cpu(), title="Sampled point cloud (gender)")

# Accuracy on this batch. Column -2 of `target` is the label compared with the
# gender prediction.
pred = x.argmax(dim=1)
t = target[:, -2]
total = t.size(0)
correct = float((pred.cpu() == t.cpu()).sum().item()) / total

print(correct, pred, t)


In [None]:
# plot private point cloud exp-gender: 64 points
# Kept at module level: the add_noise cell may look up `tod` as a global.
import torch.distributions as tod

# CSN + Pointwise
CSN_POINTWISE_SAMPLER = "../facescape/log/csn_pointwise_discrim/5/50/exp_gender_1/SAMPLENET64/train_sampler_best.pth"
CSN_POINTWISE_POINTNET = "../facescape/log/csn_pointwise_discrim/5/50/exp_gender_1/SAMPLENET64/train_model_best.pth"
CSN_POINTWISE_ATTACKER = "../facescape/log/csn_pointwise_discrim/5/50/exp_gender_1/finetune/train_model_best.pth"

# SampleNet variant with learn_noise and pointwise_dist enabled.
# NOTE(review): built with device="cpu" and then moved to `device` — confirm.
sampler = SampleNet(
    num_out_points=64,
    bottleneck_size=128,
    group_size=7,
    initial_temperature=1.0,
    input_shape="bnc",
    output_shape="bnc",
    use_stn=True,
    device="cpu",
    learn_noise=True,
    pointwise_dist=True,
).eval().to(device)

# PointNet built like pointnet_gender, loaded from the "finetune" checkpoint.
attacker = pointnet_cls.get_model(2, use_enc_stn=True, num_in_points=1024).eval().to(device)

# map_location: put the weights on this notebook's device, whichever device
# the checkpoints were saved from.
pointnet_exp.load_state_dict(torch.load(CSN_POINTWISE_POINTNET, map_location=device))
sampler.load_state_dict(torch.load(CSN_POINTWISE_SAMPLER, map_location=device))
attacker.load_state_dict(torch.load(CSN_POINTWISE_ATTACKER, map_location=device))

# sample using samplenet, then perturb the projected points with the
# learned point-wise noise
p0_simplified, p0_projected, original = sampler(inputs.to(device))
sampled_data = add_noise(
    (p0_projected, target),
    sampler,
    learn_noise=True,
    resample=False,
    pointwise=True,
)

# plot sampled data
_ = plot_3d_point_cloud(sampled_data[0].detach().to("cpu"), title="CSN Pointwise (exp-gender)")