In [None]:
import numpy as np
import time
import utils
import matplotlib.pyplot as plt
%matplotlib inline
import torch_geometric
import torch.optim as optim
import tqdm
import model
from DeltaNetAE import DeltaNetAE

In [4]:
# --- Run configuration --------------------------------------------------------
batch_size = 32            # number of point clouds per mini-batch
output_folder = "output/"  # directory that receives logs, loss plot and samples
save_results = True        # True: write results to output_folder; False: display inline
use_GPU = True             # set to False to force CPU execution
latent_size = 128          # width of the autoencoder bottleneck

In [5]:
from Dataloaders import GetDataLoaders

# Chair point clouds stored as one NumPy array.
pc_array = np.load("data/chair_set.npy")
print(pc_array.shape)

# GetDataLoaders splits the array at random into train / test loaders
# (90% / 10% according to the author's note).
train_loader, test_loader = GetDataLoaders(
    npArray=pc_array,
    batch_size=batch_size,
)

# Every cloud is assumed to have the same number of points, so the size of the
# first training sample is used for all of them.
point_size = len(train_loader.dataset[0])
print(point_size)

(3746, 1024, 3)
1024


In [6]:
import glob
import os
import os.path as osp
from typing import Callable, List, Optional

import torch

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_zip,
)
from torch_geometric.io import fs, read_off


class MNLoader(InMemoryDataset):
    r"""Category subset of the ModelNet10/40 datasets from the `"3D ShapeNets:
    A Deep Representation for Volumetric Shapes"
    <https://people.csail.mit.edu/khosla/papers/cvpr2015_wu.pdf>`_ paper.

    This is an adaptation of :class:`torch_geometric.datasets.ModelNet`. It
    downloads the full archive, but :meth:`process` only reads the categories
    listed in :attr:`categories` (by default just ``'chair'``). Labels are the
    index of the category in that list, so with the default every object has
    ``y == 0``.

    .. note::

        Data objects hold mesh faces instead of edge indices.
        To convert the mesh to a graph, use the
        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.
        To convert the mesh to a point cloud, use the
        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to
        sample a fixed number of points on the mesh faces according to their
        face area.

    .. note::

        The processed files (``training.pt`` / ``test.pt``) do not encode which
        categories were used. After changing :attr:`categories` or the
        :obj:`pre_transform`, pass ``force_reload=True`` to rebuild them.

    Args:
        root (str): Root directory where the dataset should be saved.
        name (str, optional): The name of the dataset (:obj:`"10"` for
            ModelNet10, :obj:`"40"` for ModelNet40). (default: :obj:`"10"`)
        train (bool, optional): If :obj:`True`, loads the training dataset,
            otherwise the test dataset. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    **STATS** (of the full upstream datasets, not of the category subset
    loaded here):

    .. list-table::
        :widths: 20 10 10 10 10 10
        :header-rows: 1

        * - Name
          - #graphs
          - #nodes
          - #edges
          - #features
          - #classes
        * - ModelNet10
          - 4,899
          - ~9,508.2
          - ~37,450.5
          - 3
          - 10
        * - ModelNet40
          - 12,311
          - ~17,744.4
          - ~66,060.9
          - 3
          - 40
    """

    urls = {
        '10':
        'http://vision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip',
        '40': 'http://modelnet.cs.princeton.edu/ModelNet40.zip'
    }

    # Categories read by process_set(); label y is the position in this tuple.
    categories = ('chair',)

    def __init__(
        self,
        root: str,
        name: str = '10',
        train: bool = True,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        assert name in ['10', '40']
        self.name = name
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.load(path)

    @property
    def raw_file_names(self) -> List[str]:
        # Category folders that must exist in raw_dir (PyG downloads otherwise).
        return [
            'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor',
            'night_stand', 'sofa', 'table', 'toilet'
        ]

    @property
    def processed_file_names(self) -> List[str]:
        return ['training.pt', 'test.pt']

    def download(self) -> None:
        """Download and unzip the archive, then move it into ``raw_dir``."""
        path = download_url(self.urls[self.name], self.root)
        extract_zip(path, self.root)
        os.unlink(path)
        folder = osp.join(self.root, f'ModelNet{self.name}')
        fs.rm(self.raw_dir)
        os.rename(folder, self.raw_dir)

        # Delete osx metadata generated during compression of ModelNet10
        metadata_folder = osp.join(self.root, '__MACOSX')
        if osp.exists(metadata_folder):
            fs.rm(metadata_folder)

    def process(self) -> None:
        """Build and save the train split and the test split."""
        self.save(self.process_set('train'), self.processed_paths[0])
        self.save(self.process_set('test'), self.processed_paths[1])

    def process_set(self, dataset: str) -> List[Data]:
        """Read every ``.off`` mesh of :attr:`categories` for one split.

        Args:
            dataset: split sub-folder name, ``'train'`` or ``'test'``.

        Returns:
            List of Data objects with ``y`` set to the category index, after
            the optional ``pre_filter`` and ``pre_transform``.
        """
        data_list = []
        for target, category in enumerate(self.categories):
            folder = osp.join(self.raw_dir, category, dataset)
            paths = glob.glob(f'{folder}/{category}_*.off')
            for path in paths:
                data = read_off(path)
                data.y = torch.tensor([target])
                data_list.append(data)

        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        return data_list

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}{self.name}({len(self)})'


In [8]:
from torch_geometric.transforms import Compose, SamplePoints
import deltaconv.transforms as T2

# Number of points sampled per shape. Tied to point_size (the size the
# autoencoder is constructed for) instead of a separate hard-coded 1024.
num_points = point_size

# pre_transform runs once, when MNLoader first processes the raw meshes, and
# its output is cached on disk. After changing it, rebuild the cache with
# MNLoader(..., force_reload=True).
pre_transform = Compose((
    T2.NormalizeScale(),
    SamplePoints(num_points, include_normals=True),
    T2.GeodesicFPS(num_points),
))

train_ds = MNLoader('modelnet', '10', True, pre_transform=pre_transform)
train_loader = torch_geometric.loader.DataLoader(train_ds, batch_size=batch_size, shuffle=True)

test_ds = MNLoader('modelnet', '10', False, pre_transform=pre_transform)
test_loader = torch_geometric.loader.DataLoader(test_ds, batch_size=batch_size, shuffle=False)

# Training/evaluation reshape each flat batch with pos.view(batch, -1, 3),
# which is only correct when every cloud has the same number of points.
assert train_ds[0].pos.shape[0] == num_points, "sampled cloud size != point_size"

In [None]:
# Parameter count of the PointCloudVAE baseline, for comparison with the
# DeltaNet autoencoder built in the next cell. The VAE itself is not kept:
# binding it to `net` only produced a model that the next cell silently replaced.
vae_param_count = sum(p.numel() for p in model.PointCloudVAE(point_size, latent_size).parameters())
print(vae_param_count)

930816


In [None]:
# DeltaNet-based autoencoder; this is the `net` trained in the rest of the notebook.
net = DeltaNetAE(
    in_channels=3,
    point_size=point_size,
    latent_size=latent_size,
    conv_channels=[64, 128, 256],
    mlp_depth=2,
    num_neighbors=10,
    grad_regularizer=0.001,
    grad_kernel_width=1,
)

# Total number of trainable and non-trainable parameters.
print(sum(param.numel() for param in net.parameters()))

477504
3040704


In [None]:
# Pick the compute device. If CUDA was requested but is unavailable, fall back
# to the CPU instead of failing inside `.to(device)`.
if use_GPU and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    if use_GPU:
        print("use_GPU=True but CUDA is not available; falling back to CPU.")
    device = torch.device("cpu")

# NOTE: torch.nn.DataParallel is intentionally not used. Its scatter step only
# splits tensors / tuples / lists / dicts, so a torch_geometric Batch would be
# handed unsplit to every replica. Multi-GPU training with PyG needs
# torch_geometric.nn.DataParallel together with DataListLoader.
net = net.to(device)

In [10]:
def chamfer_distance(A, B):
    """
    Symmetric Chamfer distance between point sets ``A`` and ``B``.

    For each point of one set, the Euclidean (not squared) distance to its
    nearest neighbour in the other set is taken. Both directions are averaged
    over points and summed, and the result is averaged over the batch.

    Args:
        A: tensor of shape (batch, n, d), or unbatched (n, d).
        B: tensor of shape (batch, m, d), or unbatched (m, d).

    Returns:
        0-dim tensor with the (batch-averaged) Chamfer distance.
    """
    pairwise = torch.cdist(A, B)  # (..., n, m)
    # Negative dims make the same code work for batched and unbatched inputs.
    b_to_nearest_a = pairwise.min(dim=-2).values  # (..., m)
    a_to_nearest_b = pairwise.min(dim=-1).values  # (..., n)
    return torch.mean(b_to_nearest_a.mean(dim=-1) + a_to_nearest_b.mean(dim=-1))

In [11]:
# Adam over all model parameters with a fixed learning rate of 5e-4.
optimizer = optim.Adam(params=net.parameters(), lr=5e-4)

In [12]:
def train_epoch():
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``net``, ``optimizer``, ``device`` and
    ``chamfer_distance``.

    Returns:
        float: mean of the per-batch Chamfer losses of this epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    net.train()  # training-mode behaviour for any dropout / batch-norm layers
    epoch_loss = 0.0
    num_batches = 0
    for data in train_loader:
        optimizer.zero_grad()
        data = data.to(device)
        output = net(data)
        # Reshape the flat PyG `pos` tensor to (output.shape[0], points, 3);
        # assumes every cloud in the batch has the same number of points.
        target = data.pos.view(output.shape[0], -1, 3)
        loss = chamfer_distance(target, output)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    # Divide by the batch count (the old `/ i` divided by count - 1).
    return epoch_loss / num_batches

In [13]:
def test_batch(data): # test with a batch of inputs
    """Evaluate ``net`` on one batch without tracking gradients.

    The model is put in eval mode for the forward pass, and its previous
    train/eval mode is restored afterwards.

    Args:
        data: a torch_geometric batch (needs ``.to(device)`` and ``.pos``).

    Returns:
        tuple: ``(loss, output)``, i.e. the Chamfer loss as a Python float and
        the network output moved to the CPU.
    """
    was_training = net.training
    net.eval()  # inference behaviour for any dropout / batch-norm layers
    try:
        with torch.no_grad():
            data = data.to(device)
            output = net(data)
            loss = chamfer_distance(data.pos.view(output.shape[0], -1, 3), output)
    finally:
        # Restore the caller's mode so a following training step still trains.
        net.train(was_training)

    return loss.item(), output.cpu()

In [14]:
def test_epoch(): # test with all test set
    """Evaluate ``net`` on every batch of ``test_loader``.

    Returns:
        float: mean of the per-batch Chamfer losses over the test set.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    # test_batch already disables gradient tracking for each batch.
    batch_losses = [test_batch(data)[0] for data in test_loader]
    if not batch_losses:
        raise ValueError("test_loader yielded no batches")
    # Divide by the batch count (the old `/ i` divided by count - 1).
    return sum(batch_losses) / len(batch_losses)

In [15]:
# When results are saved, start the run by clearing output_folder via utils.clear_folder.
if save_results:
    utils.clear_folder(output_folder)

In [None]:
num_epochs = 500      # total training epochs
snapshot_every = 50   # epochs between saved input/output point-cloud images

train_loss_list = []
test_loss_list = []

for i in tqdm.tqdm(range(num_epochs)):
    start_time = time.time()

    train_loss = train_epoch()  # train one epoch, get the average loss
    train_loss_list.append(train_loss)

    test_loss = test_epoch()  # average loss over the test set
    test_loss_list.append(test_loss)

    epoch_time = time.time() - start_time

    log_line = f"epoch {i} train loss : {train_loss} test loss : {test_loss} epoch time : {epoch_time}\n"

    # Redraw the train/test loss curves on a fresh figure every epoch.
    fig, ax = plt.subplots()
    ax.plot(train_loss_list, label="Train")
    ax.plot(test_loss_list, label="Test")
    ax.set_xlabel("epoch")
    ax.set_ylabel("Chamfer loss")
    ax.legend()

    if save_results:  # save all outputs to the save folder
        # append the text output to the log file
        with open(output_folder + "prints.txt", "a") as file:
            file.write(log_line)

        # update the loss graph
        fig.savefig(output_folder + "loss.png")
        plt.close(fig)

        # save input/output as image file
        if i % snapshot_every == 0:
            test_samples = next(iter(test_loader))
            _, test_output = test_batch(test_samples)
            # test_batch moves the batch to `device`; bring inputs back as (B, N, 3).
            test_inputs = test_samples.pos.view(test_output.shape[0], -1, 3).cpu()
            utils.plotPCbatch(test_inputs, test_output, show=False, save=True,
                              name=output_folder + "epoch_" + str(i))

    else:  # display all outputs
        test_samples = next(iter(test_loader))
        _, test_output = test_batch(test_samples)
        # Same input format as the save branch (previously the raw PyG Batch was passed).
        test_inputs = test_samples.pos.view(test_output.shape[0], -1, 3).cpu()
        utils.plotPCbatch(test_inputs, test_output)

        print(log_line)

        plt.show()

        


 92%|█████████▏| 459/500 [04:37<00:25,  1.62it/s]