In [1]:
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
from tqdm import tqdm


import plotly as plt
import os
import pprint
import argparse

import wandb
import torch
import numpy as np

In [7]:
# Scratch check: mean of a two-element tensor (default float dtype).
x = torch.tensor([0.4304, 0.5146])
x.mean()

tensor(0.4725)

In [None]:
# metric

# Batched Chamfer distance (works on any device), adapted from https://github.com/ThibaultGROUEIX/AtlasNet
def cd_cpu(sample, ref):
    """Nearest-neighbour squared distances between two batched point clouds.

    Args:
        sample: tensor of shape (B, N, D).
        ref: tensor of shape (B, M, D). M does not have to equal N.

    Returns:
        (dist_ref, dist_sample) where
            dist_ref[b, j]    = min_i ||sample[b, i] - ref[b, j]||^2, shape (B, M)
            dist_sample[b, i] = min_j ||sample[b, i] - ref[b, j]||^2, shape (B, N)
    """
    x, y = sample, ref
    # Squared norm of every point, shaped (B, N, 1) and (B, 1, M) so they broadcast
    # against the (B, N, M) cross term. This avoids building N x N / M x M Gram
    # matrices only to read their diagonals, and allows N != M.
    rx = (x * x).sum(dim=-1, keepdim=True)
    ry = (y * y).sum(dim=-1).unsqueeze(1)
    zz = torch.bmm(x, y.transpose(2, 1))
    # P[b, i, j] = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j> = ||x_i - y_j||^2
    P = rx + ry - 2 * zz
    return P.min(1)[0], P.min(2)[0]

def compute_cd(x, y, reduce_func=torch.mean):
    """Chamfer distance for each element of a batch.

    Both nearest-neighbour distance directions returned by ``cd_cpu`` are
    reduced over their point axis (dim 1) with ``reduce_func`` and then added.
    With ``torch.mean`` / ``torch.sum`` the result has one value per batch element.
    """
    y_to_x, x_to_y = cd_cpu(x, y)
    reduced_y_to_x = reduce_func(y_to_x, dim=1)
    reduced_x_to_y = reduce_func(x_to_y, dim=1)
    return reduced_y_to_x + reduced_x_to_y


#def compute_emd(x, y):
#    return match_cost(x, y) / x.size(1)


def compute_pairwise_cd_emd(x, y, batch_size=32):
    """Pairwise Chamfer distances between every cloud in ``x`` and every cloud in ``y``.

    (EMD is currently disabled; only the CD matrix is computed and returned.)

    Args:
        x: tensor of clouds, indexed as (NX, N, D).
        y: tensor of clouds, indexed as (NY, N, D).
        batch_size: number of ``y`` clouds compared against one ``x`` cloud per call.

    Returns:
        Tensor ``cd`` of shape (NX, NY) with ``cd[i, j] = compute_cd(x[i], y[j])``.
    """
    num_x, num_y = x.size(0), y.size(0)
    if num_x == 0 or num_y == 0:
        # torch.cat would fail on an empty list; an empty matrix is the natural result.
        return x.new_zeros((num_x, num_y))
    y = y.contiguous()
    rows = []
    for i in tqdm(range(num_x)):
        xi = x[i]
        chunks = []
        for j in range(0, num_y, batch_size):
            yb = y[j : j + batch_size]
            # Repeat x[i] once per cloud in the chunk; unsqueeze/expand works for any
            # point dimension and for non-contiguous slices.
            xb = xi.unsqueeze(0).expand(yb.size(0), -1, -1).contiguous()
            chunks.append(compute_cd(xb, yb).view(1, -1))
        rows.append(torch.cat(chunks, dim=1))
    return torch.cat(rows, dim=0)


def compute_mmd_cov(dxy):
    """Minimum Matching Distance (MMD) and Coverage (COV) from a distance matrix.

    ``dxy[i, j]`` is the distance between row-set item i and column-set item j.
    In ``compute_metrics`` the rows are its first argument and the columns its second.

    Returns:
        mmd: mean over columns of the smallest distance to any row.
        cov: fraction of columns that are the nearest column of at least one row.
            It is returned as a tensor with ``dxy``'s dtype and device.
    """
    nearest_col = dxy.min(dim=1).indices
    closest_row_dist = dxy.min(dim=0).values
    coverage = nearest_col.unique().numel() / dxy.size(1)
    return closest_row_dist.mean(), torch.tensor(coverage).to(dxy)


@torch.no_grad()
def compute_metrics(x, y, batch_size):
    """Coverage and MMD under Chamfer distance between sample sets ``x`` and ``y``.

    ``Trainer._test_end`` calls this as ``compute_metrics(generated, reference, ...)``.

    Returns:
        (scores, distances):
            scores:    {"COV-CD": ..., "MMD-CD": ...} as CPU tensors.
            distances: {"CD_YX": matrix with entry [j, i] = CD(y[j], x[i])} on CPU.
    """
    cd_yx = compute_pairwise_cd_emd(y, x, batch_size)
    # Transpose so rows index x and columns index y, as compute_mmd_cov expects.
    mmd_cd, cov_cd = compute_mmd_cov(cd_yx.t())
    scores = {"COV-CD": cov_cd.cpu(), "MMD-CD": mmd_cd.cpu()}
    distances = {"CD_YX": cd_yx.cpu()}
    return scores, distances


In [None]:
# Configuration
# Paths default to the original author's machine. Set the PCGAN_ROOT_DIR and
# PCGAN_DATA_ROOT environment variables to run elsewhere without editing this cell.
root_dir = os.environ.get("PCGAN_ROOT_DIR", "/Users/kevin/projects/cs236g/default-project")
data_dir = os.path.join(os.environ.get("PCGAN_DATA_ROOT", "/Users/kevin/CS236G"), "data")
ckpt_dir = os.path.join(root_dir, "checkpoints")
# Name of current experiment. Checkpoints will be stored in '{ckpt_dir}/{name}/'.
name = "exp1"
# Manual seed for reproducibility.
seed = 0
# Point cloud category.
cate = "airplane"
# Resume training from the last checkpoint in '{ckpt_dir}/{name}/'.
resume = False
batch_size = 8
# Number of points sampled from each training sample.
# NOTE: currently unused -- point subsampling in the Lidar dataset is disabled.
tr_sample_size = 10
# Number of points sampled from each testing sample (currently unused, see above).
te_sample_size = 10
# Total training epochs.
max_epoch = 2000
# The generator is updated once every `repeat_d` discriminator updates.
repeat_d = 5
log_every_n_step = 20
val_every_n_epoch = 20
ckpt_every_n_epoch = 100
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# utils.py

def plot_samples(samples, num=8, rows=2, cols=4):
    """Scatter up to ``num`` randomly chosen point clouds in a ``rows`` x ``cols`` grid.

    Args:
        samples: tensor of clouds indexed as (S, N, >=3). Columns 0, 1 and 2 are read.
        num: number of clouds to draw. It is capped at ``rows * cols`` so every
            trace has a grid cell.
        rows, cols: subplot grid layout.

    Returns:
        A plotly Figure.

    Column 2 is plotted on plotly's y axis and column 1 on its z axis. Presumably
    this is so the data's vertical axis renders upright; verify against the data.
    """
    # Import the submodules explicitly: a bare `import plotly` does not guarantee
    # that `plotly.subplots` / `plotly.graph_objects` are loaded.
    from plotly.subplots import make_subplots
    import plotly.graph_objects as go

    num = min(num, rows * cols)
    fig = make_subplots(
        rows=rows,
        cols=cols,
        specs=[[{"type": "Scatter3d"} for _ in range(cols)] for _ in range(rows)],
    )
    indices = torch.randperm(samples.size(0))[:num]
    # Hand plotly plain numpy arrays; detach in case a caller passes tensors that track gradients.
    for i, sample in enumerate(samples[indices].detach().cpu().numpy()):
        fig.add_trace(
            go.Scatter3d(
                x=sample[:, 0],
                y=sample[:, 2],
                z=sample[:, 1],
                mode="markers",
                marker=dict(size=3, opacity=0.8),
            ),
            row=i // cols + 1,
            col=i % cols + 1,
        )
    fig.update_layout(showlegend=False)
    return fig

In [None]:
# Dataset.py

class Lidar(torch.utils.data.Dataset):
    """In-memory dataset of point clouds stored as ``.npy`` files.

    Every ``*.npy`` file in ``{data_dir}/{split}`` is loaded, cast to float32 and
    stacked along a new leading dimension. ``split`` is the sub-directory name:
    "Train", "Val" or "Test". All files must hold arrays of the same shape,
    because they are concatenated into one tensor. That shape is presumably
    (num_points, 3); TODO confirm.

    No normalization is applied, because the data is assumed to be min-max scaled
    already. Point subsampling is likewise not performed.
    """

    def __init__(self, data_dir, split):
        split_dir = os.path.join(data_dir, split)
        # Sort so that sample indices (and therefore seeded shuffling and validation
        # metrics) do not depend on the filesystem's directory-listing order.
        fnames = sorted(f for f in os.listdir(split_dir) if f.endswith(".npy"))
        if not fnames:
            raise RuntimeError(f"No .npy files found in {split_dir}")

        clouds = []
        for fname in fnames:
            # np.newaxis adds a leading dim so the clouds can be concatenated into
            # one (num_files, ...) tensor.
            sample = np.load(os.path.join(split_dir, fname))[np.newaxis, ...]
            clouds.append(torch.from_numpy(sample).float())
        self.data = torch.cat(clouds, dim=0)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [None]:
# Model.py

class MaxBlock(nn.Module):
    """Per-point linear layer applied after subtracting the per-feature max over dim 1.

    For an input indexed as (B, N, in_dim), dim 1 is presumably the point axis
    (TODO confirm). The max is taken over all points, so reordering the points
    reorders the output rows in the same way. The block is therefore
    permutation-equivariant; invariance comes from the pooling in ``Encoder``.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        centered = x - x.max(dim=1, keepdim=True).values
        return self.proj(centered)


class Encoder(nn.Module):
    """Encodes a set of points into one latent vector per set.

    ``phi`` applies three (MaxBlock, Tanh) stages to every point. The features are
    then max-pooled over dim 1, which makes the result independent of point
    order. Finally ``ro`` maps the pooled vector to ``z1_dim``.

    Args:
        x_dim: features per input point (e.g. 3 for xyz).
        d_dim: hidden feature width.
        z1_dim: size of the output latent vector.
    """

    def __init__(self, x_dim, d_dim, z1_dim):
        super().__init__()
        # Keep this creation order. It fixes the state_dict keys (phi.0 ... phi.5)
        # that saved checkpoints use, and the order of seeded weight initialization.
        stages = []
        for in_features in (x_dim, d_dim, d_dim):
            stages.append(MaxBlock(in_features, d_dim))
            stages.append(nn.Tanh())
        self.phi = nn.Sequential(*stages)
        self.ro = nn.Sequential(
            nn.Linear(d_dim, d_dim),
            nn.Tanh(),
            nn.Linear(d_dim, z1_dim),
        )

    def forward(self, x):
        pooled = self.phi(x).max(dim=1).values
        return self.ro(pooled)


class Decoder(nn.Module):
    """Decodes a (latent, noise) pair into one output point per noise vector.

    ``fc(z1)`` and ``fu(z2)`` are added, broadcasting ``z1`` across points when it
    has a singleton point axis. The sum then passes through a Softplus MLP with
    four hidden ``Linear(h_dim, h_dim)`` layers and a final ``Linear(h_dim, x_dim)``.
    """

    def __init__(self, x_dim, z1_dim, z2_dim, h_dim=512):
        super().__init__()
        self.fc = nn.Linear(z1_dim, h_dim)
        self.fu = nn.Linear(z2_dim, h_dim, bias=False)
        layers = []
        for _ in range(4):
            layers += [nn.Softplus(), nn.Linear(h_dim, h_dim)]
        layers += [nn.Softplus(), nn.Linear(h_dim, x_dim)]
        # Indices dec.0 ... dec.9 form the state_dict keys used by saved checkpoints.
        self.dec = nn.Sequential(*layers)

    def forward(self, z1, z2):
        return self.dec(self.fc(z1) + self.fu(z2))


class Generator(nn.Module):
    """Autoencoder-style point-cloud generator.

    ``encode`` maps an input cloud to a latent ``z1`` with a singleton point axis.
    ``decode`` draws an independent noise vector for every output point and
    decodes all points from (z1, noise) in a single batched call. Each point
    depends only on the shared ``z1`` and its own noise.
    """

    def __init__(self, x_dim=3, d_dim=256, z1_dim=256, z2_dim=10):
        super().__init__()
        self.z2_dim = z2_dim
        self.enc = Encoder(x_dim, d_dim, z1_dim)
        self.dec = Decoder(x_dim, z1_dim, z2_dim)

    def encode(self, x):
        """Latent code for each cloud, with a singleton point axis inserted at dim 1."""
        return self.enc(x).unsqueeze(dim=1)

    def decode(self, z1, B, N, device):
        """Generate ``N`` points for each of ``B`` clouds conditioned on ``z1``.

        The per-point noise has shape (B, N, z2_dim). It is drawn from the default
        (CPU) random generator and then moved to ``device``.
        """
        noise = torch.randn((B, N, self.z2_dim)).to(device)
        return self.dec(z1, noise)

    def forward(self, x):
        """Reconstruct ``x``. Returns ``(points, z1)``.

        ``x`` is indexed as (batch, num_points, ...); the output has the same
        batch and point counts.
        """
        latent = self.encode(x)
        batch, num_points = x.size(0), x.size(1)
        points = self.decode(latent, batch, num_points, x.device)
        return points, latent


class Discriminator(nn.Module):
    """Per-point critic conditioned on a latent vector ``z1``.

    There are three stages. Each stage adds a projection of ``z1`` to a projection
    of the previous features and runs a Softplus MLP. The first two stages output
    ``h_dim - z1_dim`` features; the last outputs ``o_dim`` scores per point.

    In the trainer's hinge losses, the discriminator pushes scores of real
    clouds up and of generated clouds down. The generator pushes the scores of
    its outputs up.
    """

    def __init__(self, x_dim=3, z1_dim=256, h_dim=1024, o_dim=1):
        super().__init__()

        def mlp(out_dim):
            # Softplus -> Linear(h, h) -> Softplus -> Linear(h, out_dim)
            return nn.Sequential(
                nn.Softplus(),
                nn.Linear(h_dim, h_dim),
                nn.Softplus(),
                nn.Linear(h_dim, out_dim),
            )

        mid_dim = h_dim - z1_dim
        # Creation order fixes both the state_dict keys and seeded init order.
        self.fc = nn.Linear(z1_dim, h_dim)
        self.fu = nn.Linear(x_dim, h_dim, bias=False)
        self.d1 = mlp(mid_dim)
        self.sc = nn.Linear(z1_dim, h_dim)
        self.su = nn.Linear(mid_dim, h_dim, bias=False)
        self.d2 = mlp(mid_dim)
        self.tc = nn.Linear(z1_dim, h_dim)
        self.tu = nn.Linear(mid_dim, h_dim, bias=False)
        # A classification variant could widen o_dim and add a softmax here.
        self.d3 = mlp(o_dim)

    def forward(self, x, z1):
        feats = self.d1(self.fc(z1) + self.fu(x))
        feats = self.d2(self.sc(z1) + self.su(feats))
        return self.d3(self.tc(z1) + self.tu(feats))

In [None]:
# Trainer.py

class Trainer:
    """Trains a generator/discriminator pair with a hinge GAN loss plus a Chamfer loss.

    Each training batch does one discriminator update. Every ``repeat_d``-th
    global step also does one generator update. Every ``val_every_n_epoch``
    epochs the generator is evaluated on the validation loader, and metrics plus
    a sample figure are logged to wandb. Every ``ckpt_every_n_epoch`` epochs a
    checkpoint ``{ckpt_dir}/{epoch}.pth`` is written.

    Evaluation via ``test`` only needs ``net_g``, ``device`` and ``batch_size``.
    Training and resuming also need the discriminator, optimizers, schedulers
    and the scheduling/logging parameters.
    """

    def __init__(
        self,
        net_g,
        device,
        batch_size,
        net_d=None,
        opt_g=None,
        opt_d=None,
        sch_g=None,
        sch_d=None,
        max_epoch=None,
        repeat_d=None,
        log_every_n_step=None,
        val_every_n_epoch=None,
        ckpt_every_n_epoch=None,
        ckpt_dir=None,
    ):
        self.net_g = net_g.to(device)
        self.device = device
        self.batch_size = batch_size
        self.net_d = net_d.to(device) if net_d is not None else None
        self.opt_g = opt_g
        self.opt_d = opt_d
        self.sch_g = sch_g
        self.sch_d = sch_d
        self.step = 0
        self.epoch = 0
        self.max_epoch = max_epoch
        self.repeat_d = repeat_d
        self.log_every_n_step = log_every_n_step
        self.val_every_n_epoch = val_every_n_epoch
        self.ckpt_every_n_epoch = ckpt_every_n_epoch
        self.ckpt_dir = ckpt_dir

    def _components(self):
        """Checkpointed objects, keyed by their name in the checkpoint dict (may be None)."""
        return {
            "net_g": self.net_g,
            "net_d": self.net_d,
            "opt_g": self.opt_g,
            "opt_d": self.opt_d,
            "sch_g": self.sch_g,
            "sch_d": self.sch_d,
        }

    def _state_dict(self):
        """Checkpoint dict. Components that were not provided are stored as None."""
        state = {
            k: (m.state_dict() if m is not None else None)
            for k, m in self._components().items()
        }
        state["step"] = self.step
        state["epoch"] = self.epoch
        state["max_epoch"] = self.max_epoch
        return state

    def _load_state_dict(self, state_dict):
        """Restore components present both here and in ``state_dict``, plus the counters.

        Note: ``max_epoch`` is restored from the checkpoint, which overrides the
        value passed to ``__init__``.
        """
        for k, m in self._components().items():
            if m is not None and state_dict.get(k) is not None:
                m.load_state_dict(state_dict[k])
        self.step, self.epoch, self.max_epoch = (
            state_dict.get(k) for k in ("step", "epoch", "max_epoch")
        )

    def save_checkpoint(self):
        """Write the current state to ``{ckpt_dir}/{epoch}.pth``."""
        ckpt_path = os.path.join(self.ckpt_dir, f"{self.epoch}.pth")
        torch.save(self._state_dict(), ckpt_path)

    def load_checkpoint(self, ckpt_path=None):
        """Load ``ckpt_path``, or the highest-epoch ``<epoch>.pth`` in ``ckpt_dir`` if not given."""
        if not ckpt_path:
            # Only consider "<int>.pth" names, as written by save_checkpoint; any
            # other .pth file would otherwise break the int() sort key.
            candidates = [
                p for p in os.listdir(self.ckpt_dir)
                if p.endswith(".pth") and p[:-4].isdigit()
            ]
            assert candidates, "No checkpoints found."
            latest = max(candidates, key=lambda f: int(f[:-4]))
            ckpt_path = os.path.join(self.ckpt_dir, latest)
        # map_location lets a checkpoint written on a GPU be resumed on a CPU-only
        # machine (and vice versa).
        self._load_state_dict(torch.load(ckpt_path, map_location=self.device))

    def _train_step_g(self, x):
        """Generator loss on a real batch ``x``.

        The loss has two terms:
        - the adversarial term ``-mean(D(o, z1))``, where ``z1`` is detached, so
          this term reaches the encoder only through the generated points ``o``;
        - the Chamfer reconstruction term between ``o`` and ``x``, summed over
          points in each direction and averaged over the batch.
        """
        o, z1 = self.net_g(x)
        op = self.net_d(o, z1.detach())
        loss_op = -op.mean()
        # TODO: the reconstruction loss can be swapped out here.
        loss_cd = compute_cd(o, x, reduce_func=torch.sum).mean()
        return loss_op + loss_cd

    def _train_step_d(self, x):
        """Hinge discriminator loss: relu(1 - D(real)) + relu(1 + D(fake)), each averaged."""
        # The generator only supplies inputs here; its outputs must not receive
        # gradients, so skip building its autograd graph entirely.
        with torch.no_grad():
            o, z1 = self.net_g(x)
        xp = self.net_d(x, z1)
        op = self.net_d(o, z1)
        return F.relu(1.0 - xp).mean() + F.relu(1.0 + op).mean()

    def train(self, train_loader, val_loader):
        """Train until ``self.epoch`` reaches ``self.max_epoch``.

        Validation and checkpointing happen at the start of matching epochs.
        """
        # Most recent generator loss. It persists across batches and epochs because
        # the generator only steps every `repeat_d` global steps, and after a resume
        # it may not have stepped yet in this call.
        loss_g_val = None
        while self.epoch < self.max_epoch:

            # Validation and checkpointing
            if self.epoch % self.val_every_n_epoch == 0:
                (metrics, _), samples = self.test(val_loader)
                wandb.log({**metrics, "samples": samples, "epoch": self.epoch})
            if self.epoch % self.ckpt_every_n_epoch == 0:
                self.save_checkpoint()

            self.net_g.train()
            self.net_d.train()
            with tqdm(train_loader) as t:
                for batch in t:
                    batch = batch.to(self.device)

                    # Discriminator update (every step)
                    loss_d = self._train_step_d(batch)
                    self.opt_d.zero_grad()
                    loss_d.backward()
                    self.opt_d.step()
                    loss_d_val = loss_d.item()

                    # Generator update (every `repeat_d` steps)
                    if self.step % self.repeat_d == 0:
                        loss_g = self._train_step_g(batch)
                        self.opt_g.zero_grad()
                        loss_g.backward()
                        self.opt_g.step()
                        loss_g_val = loss_g.item()

                    # Stepwise logging
                    g_str = "n/a" if loss_g_val is None else f"{loss_g_val:.2f}"
                    t.set_description(
                        f"Epoch:{self.epoch}|L(G):{g_str}|L(D):{loss_d_val:.2f}"
                    )
                    if self.step % self.log_every_n_step == 0:
                        log = {"loss_d": loss_d_val, "step": self.step, "epoch": self.epoch}
                        if loss_g_val is not None:
                            log["loss_g"] = loss_g_val
                        wandb.log(log)

                    self.step += 1
            self.sch_g.step()
            self.sch_d.step()
            self.epoch += 1

    def _test_step(self, x):
        """Reconstruct one batch. Returns ``(prediction, ground_truth)``."""
        o, _ = self.net_g(x)
        return o, x

    def _test_end(self, o, x):
        """Compute metrics and a sample figure from all predictions ``o`` and ground truth ``x``.

        Returns ``(compute_metrics(o, x, batch_size), figure)``.
        """
        # TODO: the evaluation metric can be swapped out here.
        metrics = compute_metrics(o, x, self.batch_size)
        samples = plot_samples(o)
        return metrics, samples

    @torch.no_grad()
    def test(self, test_loader):
        """Evaluate ``net_g`` over ``test_loader``. ``net_g`` is left in eval mode."""
        results = []
        self.net_g.eval()
        for batch in tqdm(test_loader):
            results.append(self._test_step(batch.to(self.device)))
        predictions, targets = (torch.cat(parts, dim=0) for parts in zip(*results))
        return self._test_end(predictions, targets)

In [None]:
# Fix seed
np.random.seed(seed)
torch.manual_seed(seed)

# Setup checkpoint directory: checkpoints go to '{ckpt_dir}/{name}/'.
# makedirs creates missing parents, and exist_ok makes re-running this cell safe.
ckpt_subdir = os.path.join(ckpt_dir, name)
os.makedirs(ckpt_subdir, exist_ok=True)

# Setup logging
wandb.init(project="pcgan")

# Setup dataloaders
import sys

# - Pinned host memory only speeds up copies to a CUDA device. PyTorch warns when
#   it is requested without one.
# - Worker processes must re-import the Dataset class. On macOS they are spawned,
#   and a class defined inside a notebook usually cannot be re-imported there,
#   so load in-process on macOS.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=0 if sys.platform == "darwin" else 2,
    pin_memory=torch.cuda.is_available(),
)
train_loader = torch.utils.data.DataLoader(
    dataset=Lidar(data_dir=data_dir, split="Train"),
    shuffle=True,
    drop_last=True,
    **loader_kwargs,
)
val_loader = torch.utils.data.DataLoader(
    dataset=Lidar(data_dir=data_dir, split="Val"),
    shuffle=False,
    drop_last=False,
    **loader_kwargs,
)

# Setup model, optimizer and scheduler (constant learning rates via LambdaLR)
net_g = Generator()
net_d = Discriminator()
opt_g = torch.optim.Adam(net_g.parameters(), lr=4e-4, betas=(0.9, 0.999))
opt_d = torch.optim.Adam(net_d.parameters(), lr=2e-4, betas=(0.9, 0.999))
sch_g = torch.optim.lr_scheduler.LambdaLR(opt_g, lr_lambda=lambda e: 1.0)
sch_d = torch.optim.lr_scheduler.LambdaLR(opt_d, lr_lambda=lambda e: 1.0)

# Setup trainer
trainer = Trainer(
    net_g=net_g,
    net_d=net_d,
    opt_g=opt_g,
    opt_d=opt_d,
    sch_g=sch_g,
    sch_d=sch_d,
    device=device,
    batch_size=batch_size,
    max_epoch=max_epoch,
    repeat_d=repeat_d,
    log_every_n_step=log_every_n_step,
    val_every_n_epoch=val_every_n_epoch,
    ckpt_every_n_epoch=ckpt_every_n_epoch,
    ckpt_dir=ckpt_subdir,
)

# Load checkpoint
if resume:
    trainer.load_checkpoint()

# Start training
trainer.train(train_loader, val_loader)
# Loss of Generator
# Loss of Discriminator
# Train Epoch