The goal of this TD is to code and play with [PointNet](https://arxiv.org/pdf/1612.00593.pdf), a standard architecture for deep learning for pointcloud.

In this lab, we will ask you to provide both:
- notebook with logs (training logs in particular)
- the weights of the trained networks (classifier and segmentation). We provide guidance for saving the logs directly in your google drive, you can also download them via the colab interface

It should be helpful to use the slides on the side of this TD to obtain additionnal illustrations.

The idea of PointNet is to learn functions which are by definition fit to process unordered set of elements, that is functions of the shape
$$
f\big(\left\{x_1,\dots, x_n\right\}\big) = g\big(h(x_1),\dots, h(x_n)\big)
$$

where $g$ is a *symmetric* function, such as the $\max$ function.

In this TD, we will build the network as presented in the following image (slide 34):
![](https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD4/PointNet_architecture.PNG)

It consists in:
1. An input / feature transformation module
3. Per-vertex MLP
4. Feature aggregation for classification

In [2]:
!pip install numpy matplotlib

[0m

In [1]:
#from google.colab import output
#output.enable_custom_widget_manager()

import os
import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch
import torch.nn.functional as F

from torch.utils.data import DataLoader

# 1. PointNet Feature Extractor

## 1.1 Input and feature Transformation

The feature transformation module takes as input a $d$ dimensional point-cloud and regresses a $d\times d$ transformation matrix $T$.

If the input consists in $[x,y,z]$ coordinates, we can expect the network to learn a $3\times 3$ matrix which for instance rotates the point cloud in a *canonical* orientation.

If the input consists in $d$-dimensional features, the transformation matrix is harder to visualize.

![](https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD4/feature_transform.PNG)

This is done in two steps:
1. Apply per-vertex MLP to obtain a $d_{out}$ dimensional embedding for each vertex
2. use MaxPooling over the first dimension to to obtain a simple $d_{out}$ dimensional vector **for the whole point cloud**
3. Use MLP to obtain a $d^2$ dimensional vector, which defines T (flattened).
---


### Question 1

Implement the `FeatTransform` module.

**NOTE :**

Except for the **last** linear layer of the second MLP, all layers consist in of the combination of `layer`, `Batchnorm` and `ReLu` activations. The last layer is a simple linear layer with **no activation** or batchnorm (we don't want to constraint the values of matrix T).

An important (and annoying) feature in [`nn.BatchNorm1d`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html) is that is expects the features to lie on the **second** dimension (the one after the batch). On the contrary fully connected layers expect features to lie on the **last** dimension. Unless the input is 2d, we would need to reshape values for each layer. Hopefully there is a small trick to handle this.

The trick consist in using [`nn.Conv1d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) with kernel size of $1$ to define fully connected layers. This is exactly equivalent to fully connected layers although a bit surprising at first sight. The adventage is that torch convolutions expect features to be on the **second** dimension.

Follow the following guide (or reshape at every layer, as you wish):

1. For the first MLP, the input is given with shape `(B,d,n)`. The features are in the second dimension, so use [`nn.Conv1d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html), which allows to use [`nn.BatchNorm1d`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html) directly without reshaping.
2. In the second layer you can give as input a tensor of shape `(B,d_out)` (instead of `(B,1,d_out)`). Both  [`nn.Linear`](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) and [`nn.BatchNorm1d`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html) will work since the **second** and **last** dimensions are the same.
3. Check the documentation of [`torch.max`](https://pytorch.org/docs/stable/generated/torch.max.html), as it does not return the same thing as numpy (use `torch.max(a,dim=dim).values`).
4. We add an argument to potentially remove batchnorm (for future experiments using batch size of 1)


To stabilize training, we actually output $I + T$ so that at initialization when $T\simeq 0$ this module has no effect.


In [2]:
class FeatTransform(nn.Module):
    # TODO
    def __init__(self, inp_dim=3, hidden_dims1=[64,128,1024], hidden_dims2=[512,256], use_bn=True):
      """
      inp_dim : int - dimension d of the input
      hidden_dims1 : list - hidden layers (including d_out) of the first MLP block (defined with nn.Conv1d...)
      hidden_dims2 : list - hidden layers (without d_out) of the second MLP block
      use_bn       : bool - whether to use batchnorm
      """
      super().__init__()

      ## TODO
      self.inp_dim = inp_dim
      self.use_bn = use_bn
      self.hidden_dims1 = hidden_dims1
      self.hidden_dims2 = hidden_dims2


      self.relu = nn.ReLU()


      MLP1_layer_list = []
      MLP2_layer_list = []


      ##MLP1
      for k in range(len(self.hidden_dims1)):
        if k == 0:
          MLP1_layer_list.append(nn.Conv1d(self.inp_dim, self.hidden_dims1[k], 1))
          if self.use_bn:
            MLP1_layer_list.append(nn.BatchNorm1d(self.hidden_dims1[k]))
          MLP1_layer_list.append(self.relu)

        else:
          MLP1_layer_list.append(nn.Conv1d(self.hidden_dims1[k-1], self.hidden_dims1[k], 1))
          if self.use_bn:
            MLP1_layer_list.append(nn.BatchNorm1d(self.hidden_dims1[k]))
          MLP1_layer_list.append(self.relu)

      self.MLP1 = nn.Sequential(*MLP1_layer_list)


      ## MLP2
      MLP2_layer_list.append(nn.Linear(self.hidden_dims1[-1], self.hidden_dims2[0]))                      #(B, self.hidden_dims2[0] )
      if self.use_bn:
        MLP2_layer_list.append(nn.BatchNorm1d(self.hidden_dims2[0]))                                      #(B, self.hidden_dims2[0] )
      MLP2_layer_list.append(self.relu)

      for j in range(len(self.hidden_dims2)):
        if j == len(self.hidden_dims2) - 1:
          MLP2_layer_list.append(nn.Linear(self.hidden_dims2[j], self.inp_dim**2))
        else:
          MLP2_layer_list.append(nn.Linear(self.hidden_dims2[j], self.hidden_dims2[j+1]))                     #(B, self.hidden_dims2[1] )
          if self.use_bn:
            MLP2_layer_list.append(nn.BatchNorm1d(self.hidden_dims2[j+1]))                                      #(B, self.hidden_dims2[1] )
          MLP2_layer_list.append(self.relu)

      self.MLP2 = nn.Sequential(*MLP2_layer_list)


    def forward(self, x):
        """
        x : (B, d, n) - This is standard shape for convolution inputs

        Output
        ------------
        (B, d , d) : output T defined as I + NET(x) for stability
        """
        device = x.device

        y = self.MLP1(x)

        z = torch.max(y, dim = 2).values.to(device)

        w = self.MLP2(z).to(device)

        T = w.view(x.shape[0], self.inp_dim, self.inp_dim).to(device)

        return T + torch.eye(self.inp_dim, device=T.device).unsqueeze(0)

In [3]:
# Try a forward pass
FeatTransform()(torch.rand(2,3,10))

tensor([[[ 1.3285, -0.5769,  0.6018],
         [ 0.5172,  1.4392,  0.1610],
         [-0.0623,  0.4242,  0.5250]],

        [[ 0.9192, -0.2623,  0.2022],
         [-0.4969,  0.3996,  0.5600],
         [ 0.1675,  0.5305,  1.4662]]], grad_fn=<AddBackward0>)

## 1.2 PointNet Feature extractor

We now build the PointNet feature extractor network (output is `global_feature`)
![](https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD4/PointNet_architecture.PNG)


It can be described as follows:
1. Predict first feature transform and apply to the input
2. Apply a first 1-layer MLP with BatchNorm + ReLu.
3. Preduct second feature transform and apply to the features
4. Apply a second MLP with Batchnorm + ReLu, **except** on last layer where there is no activation (`maxpool`will serve as non-linearity).
5. Apply Maxpooling to obtain the global feature vector

Note that for future segmentation task, we will also need to outpu the second transformed features (or the input to the second MLP).

---

### Question 2
Implement the `PointNetfeat` module with the following convention

1. Use again [`nn.Conv1d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) for both MLP blocks.
2. The **first** MLP will be a 1-layer MLP with output size being the input size of the second MLP (the picture has a typo somehow)
3. The module returns the transformed features before the are fed to the second MLP, the transformation matrix of the second block, and the global feature.
4. Both Feature Transform modules use the same hidden dimensions for their MLPs.
5. You can hardcode to default parameters, we will never change them


In [3]:
class PointNetfeat(nn.Module):

    ## TODO
    def __init__(self, inp_dim=3, mlp2_hidden_dims=[64,128,1024], transf_hidden_dims1=[64,128,1024], transf_hidden_dims2=[512,256], use_bn=True):
      """
      inp_dim   : int - dimension d of the input
      mlp2_hidden_dims : list - hidden layers (inluding d_out) of the second MLP block (defined with nn.Conv1d...)
      transf_hidden_dims1 : list - hidden layers (inluding d_out) of the first mlp in each FeatureTransform block
      transf_hidden_dims2 : list - hidden layers (without d_out) of the second mlp in each FeatureTransform block
      use_bn : bool - whether to use batchnorm
      """
      super().__init__()

      # USEFULE INFO
      self.inp_dim = inp_dim
      self.latent_dim1 = mlp2_hidden_dims[0] # size of the second features
      self.latent_dim2 = mlp2_hidden_dims[1]
      self.latent_dim3 = mlp2_hidden_dims[-1] # is also the out_dim
      self.use_bn = use_bn


      self.relu = nn.ReLU()
      self.FeatTransform1 = FeatTransform(inp_dim = self.inp_dim, hidden_dims1 = transf_hidden_dims1, hidden_dims2 = transf_hidden_dims2, use_bn = self.use_bn)
      self.FeatTransform2 = FeatTransform(inp_dim = self.latent_dim1, hidden_dims1 = transf_hidden_dims1, hidden_dims2 = transf_hidden_dims2, use_bn = self.use_bn)

      self.conv1 = nn.Conv1d(self.inp_dim, self.latent_dim1, 1)
      self.batchnorm1 = nn.BatchNorm1d(self.latent_dim1)

      self.conv2 = nn.Conv1d(self.latent_dim1, self.latent_dim2, 1)
      self.batchnorm2 = nn.BatchNorm1d(self.latent_dim2)

      self.conv3 = nn.Conv1d(self.latent_dim2, self.latent_dim3, 1)
      self.batchnorm3 = nn.BatchNorm1d(self.latent_dim3)

    def forward(self, x):
      """
      Parameters
      --------------
      x : (B, in_dim, n) input batch in convolution format.

      Output
      -------------
      global_feature : (B, d_out) output global feature (d_out=1024)
      T_feat : (B, latent1, latent1) transformation matrix of the second FeatureTransform module
      x_feat : (B, latent1, n) per-point features after being transformed by the second
                FeatureTransform module
      """
      T1 = self.FeatTransform1.forward(x)   # (B, in_dim, in_dim)
      x = T1 @ x                   # (B, in_dim, n)


      device = x.device


      x = self.conv1(x).to(device) # (B, self.latent_dim1, n)
      if self.use_bn:
        x = self.batchnorm1(x).to(device)            # (B, self.latent_dim1, n)
      x = self.relu(x).to(device)

      T_feat = self.FeatTransform2.forward(x).to(device)                 # (B, self.latent_dim1, self.latent_dim1)
      x_feat = T_feat @ x                             # (B, self.latent_dim1, n)

      x = self.conv2(x_feat).to(device)
      if self.use_bn:
        x = self.batchnorm2(x).to(device)
      x = self.relu(x).to(device)

      x = self.conv3(x).to(device)
      if self.use_bn:
        x = self.batchnorm3(x).to(device)

      global_feature = torch.max(x, dim = 2).values

      return global_feature, T_feat, x_feat

In [5]:
# Try a forward pass
global_, T_, x_ = PointNetfeat()(torch.rand(2,3,2048))
print(global_.shape, T_.shape, x_.shape) #check the shapes

torch.Size([2, 1024]) torch.Size([2, 64, 64]) torch.Size([2, 64, 2048])


# 2 PointNet Classifier

## 2.1 Network Design

The PointNet Classifier is the complete network seen as follow
![](https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD4/PointNet_architecture.PNG)


### Question 3
Implement the `PointNetCls` module.
1. The last MLP also uses BatchNorm + ReLu, **except** for the last layer, which uses either `log_softmax` directly or `softmax`, or even nothing. Note that the training loss you'll chose will be adapted to this choice !
2. You can hardcode to default parameters, we will not change them
3. We don't need the transformed feature, only the transformation matrix for classification

In [6]:
class PointNetCls(nn.Module):
    ## TODO
    def __init__(self, n_cls, inp_dim=3, cls_hidden_dims=[512,256], mlp2_hidden_dims=[64,128,1024], transf_hidden_dims1=[64,128,1024], transf_hidden_dims2=[512,256], use_bn=True):
        """
        PointNet classification network

        Parameters
        ------------------
            n_cls (int): Number of classes for classification
            inp_dim (int): Dimension of the input points
            cls_hidden_dims (list): List of hidden dimensions for the MLP classifier
            mlp2_hidden_dims (list): List of hidden dimensions for the second MLP block in the feature extractor
            transf_hidden_dims1 (list): List of hidden dimensions for the first MLP in each FeatureTransform block
            transf_hidden_dims2 (list): List of hidden dimensions for the second MLP in each FeatureTransform block
            use_bn (bool): Whether to use batchnorm
        """

        super(PointNetCls, self).__init__()

        self.use_bn = use_bn
        self.n_cls = n_cls
        self.cls_hidden_dims = cls_hidden_dims
        self.mlp2_hidden_dims = mlp2_hidden_dims


        self.relu = nn.ReLU()
        self.PointNetfeat = PointNetfeat(inp_dim, mlp2_hidden_dims, transf_hidden_dims1, transf_hidden_dims2, use_bn)

        self.linear1 = nn.Linear(self.mlp2_hidden_dims[-1], self.cls_hidden_dims[0])
        self.batchnorm1 = nn.BatchNorm1d(self.cls_hidden_dims[0])

        self.linear2 = nn.Linear(self.cls_hidden_dims[0], self.cls_hidden_dims[1])
        self.batchnorm2 = nn.BatchNorm1d(self.cls_hidden_dims[1])

        self.linear3 = nn.Linear(self.cls_hidden_dims[1], self.n_cls)


    def forward(self, x):
      """
      Parameters
      --------------
      inp : (B, n, d) input batch in standard format.

      Output
      -------------
      out : (B, n_cls) output score for each class
      T_feat : (B, latent1, latent1) transformation matrix of the second FeatureTransform module
      """
      x = x.transpose(2,1) # Put input in convolution format

      global_feature, T_feat, x_feat = self.PointNetfeat.forward(x)  # (B, d_out)


      device = x.device

      y = self.linear1.to(device)(global_feature)  #(B, self.cls_hidden_dims[0])
      if self.use_bn:
        y = self.batchnorm1.to(device)(y)
      y = self.relu.to(device)(y)

      y = self.linear2.to(device)(y)          #(B, self.cls_hidden_dims[1])
      if self.use_bn:
        y = self.batchnorm2.to(device)(y)
      y = self.relu.to(device)(y)

      y = self.linear3.to(device)(y)                     #(B, self.n_cls)

      return y, T_feat

In [7]:
out, T_feat = PointNetCls(40)(torch.rand(2,6,3))
print(out.shape, T_feat.shape)

torch.Size([2, 40]) torch.Size([2, 64, 64])


## 2.2 - Data Augmentation

Since PointNet takes $(X,Y,Z)$ coordinates as input, it is not invariant to rigid transformations of the point cloud. One possibility is to make it learn to be invariant, by rigidly modifying shapes during training.

### Question 4
Fill the random scaling and translation function using numpy.
1. Scale and translate each coordinate independantly
2. Scaling factor lies between $\frac{2}{3}$ and $\frac{3}{2}$
2. Translation lies between $-0.2$ and $0.2$

In [8]:
import numpy as np

In [4]:
def random_scale_transl(X):
        """Apply random scaling to the point cloud.
        Scaling is applied for

        Args:
            X (np.ndarray): Point cloud data, (N, 3).

        Returns:
            X_transformed (np.ndarray): Scaled (and translated) point cloud data (N, 3).
        """
        scaling_factor = np.random.uniform(2/3, 3/2)
        translation_factor = np.random.uniform(-0.2, 0.2)

        X_transformed = scaling_factor*X + translation_factor

        return X_transformed

### Question 5
For this specific dataset, all shapes pre-aligned, and only variability is rotation along the $y$ axis.

Therefore simply generate a random rotation matrix along the y axis.

In [5]:
def random_rotation(X):
        # TODO
        """Apply random rotation to the point cloud (along the y axis)

        Args:
            X (np.ndarray): Point cloud data, (N, 3).

        Returns:
            X_rot (np.ndarray): Rotated point cloud data (N, 3).
        """
        # Generate random rotation angle within specified range
        theta = np.random.uniform(-np.pi, np.pi)

        # Create rotation matrix for y-axis rotation
        rotation_matrix = np.array([
            [np.cos(theta),  0, np.sin(theta)],
            [0            ,  1, 0           ],
            [-np.sin(theta), 0, np.cos(theta)]
        ])

        # Apply rotation
        X_rot = X @ rotation_matrix.T

        return X_rot

## 2.3 Data Loading

We will train on the [ModelNet40](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Wu_3D_ShapeNets_A_2015_CVPR_paper.pdf) dataset.

This dataset consist in pointclouds of 12000 objects of 40 different categories such as airplane, car, plant, lamp. Because we only need to provide a global feature vector for each shape, we can subsample each surface with a fixed number of parameters, which allows us to use large batch sizes.

![](https://production-media.paperswithcode.com/datasets/modelnet.jpeg)


We provide below code for dataloading.

Let's first download the data and define the loaders.

In [13]:
!mkdir 'data'
!wget -O './data/modelnet40_ply_hdf5_2048.zip' https://huggingface.co/datasets/Msun/modelnet40/resolve/main/modelnet40_ply_hdf5_2048.zip?download=true --no-check-certificate
!unzip data/modelnet40_ply_hdf5_2048.zip -d ./data/

--2024-11-01 15:02:47--  https://huggingface.co/datasets/Msun/modelnet40/resolve/main/modelnet40_ply_hdf5_2048.zip?download=true
Resolving huggingface.co (huggingface.co)... 3.167.112.25, 3.167.112.38, 3.167.112.45, ...
Connecting to huggingface.co (huggingface.co)|3.167.112.25|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.hf.co/repos/36/90/3690071fa5fce49ab533e0e81bd5b7fd1fc7337b68386704b06e7c4dfe6eed96/f01b8189281fae5790e39deb9f3eca86e446b771bdc665c6ad05f28d039b20e7?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27modelnet40_ply_hdf5_2048.zip%3B+filename%3D%22modelnet40_ply_hdf5_2048.zip%22%3B&response-content-type=application%2Fzip&Expires=1730732567&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczMDczMjU2N319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy8zNi85MC8zNjkwMDcxZmE1ZmNlNDlhYjUzM2UwZTgxYmQ1YjdmZDFmYzczMzdiNjgzODY3MDRiMDZlN2M0ZGZlNmVlZDk2L2YwMWI4MTg5MjgxZ

In [15]:
!pip install tqdm h5py

Collecting h5py
  Downloading h5py-3.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Downloading h5py-3.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m100.7 MB/s[0m eta [36m0:00:00[0m
[0mInstalling collected packages: h5py
[0mSuccessfully installed h5py-3.12.1


In [16]:
import h5py
import os.path as osp

from pathlib import Path
from glob import glob
from tqdm.auto import tqdm
from torch.utils.data import Dataset

def load_h5(h5_filename):
    f = h5py.File(h5_filename)
    print(h5_filename)
    data = f['data'][:]
    label = f['label'][:]
    f.close()
    return (data, label)

def load_h5_files(data_path, files_list_path):
    """Load h5 into memory
    """
    files_list = [Path(line.rstrip()).name for line in open(osp.join(data_path, files_list_path))]
    data = []
    labels = []
    for i in range(len(files_list)):
        data_, labels_ = load_h5(os.path.join(data_path, files_list[i]))
        data.append(data_)
        labels.append(labels_)
    data = np.concatenate(data, axis=0)
    labels = np.concatenate(labels, axis=0)
    return data, labels


class ModelNet40Generator(Dataset):

    def __init__(self, mode, data_dir, files_list, num_classes=40, num_points=1024, rot_aug=False):
        """Initialize the ModelNet40Generator class.

        Args:
            mode (str): 'train' or 'test'.
            data_dir (str): Path to the data directory.
            files_list (str): Path to the files list.
            num_classes (int, optional): ModelNet class. Defaults to 40.
            num_points (int, optional): Input PC. Defaults to 1024.
            rot_aug (bool, optional): Augment data. Defaults to False.
        """

        assert mode.lower() in ('train', 'val', 'test')
        assert files_list in ('train_files.txt', 'test_files.txt')
        self.data, labels = load_h5_files(data_dir, files_list)

        self.mode = mode
        self.num_points = num_points
        self.num_classes = num_classes
        self.num_samples = self.data.shape[0]
        self.rot_aug = rot_aug
        self.labels = np.reshape(labels, (-1,))


    def __len__(self) -> int:
        return self.num_samples


    def __getitem__(self, idx):
        indexes = np.random.permutation(np.arange(self.num_points))[:self.num_points]
        X = self.data[idx, indexes, ...]
        y = self.labels[idx, ...]
        y_categorical = torch.from_numpy(np.eye(self.num_classes, dtype='uint8')[y]).long()

        if self.mode == 'train' and self.rot_aug:
            # X = self.random_scaling(X)
            X = self.random_rotation_3d(X)

        return torch.from_numpy(X).float(), y_categorical.float()


    def random_scaling(self, X):
        """Apply random scaling to the point cloud.

        Args:
            X (np.ndarray): Point cloud data, (N, 3).

        Returns:
            X (np.ndarray): Scaled point cloud data (N, 3).
        """
        return random_scale_transl(X)


    def random_rotation_3d(self, X):
        """Apply random rotation to the point cloud (axis agnostic).

        Args:
            X (np.ndarray): Point cloud data, (N, 3).

        Returns:
            rotated_data (np.ndarray): Rotated point cloud data (N, 3).
        """
        return random_rotation(X)



## 2.4. Classification Loss

PointNet classification loss is defined as the weighted sum of two losses: the classification loss $L_{cls}$, such as cross-entropy, and a regularization loss $L_{reg}$ which regularizes the FeatureTransform output.

In particular, the regularization loss forces the (second) predicted transformation matrix $T$ to be near-orthogonal (so close to a "rotation"):

$$
L_{reg}(T) = \|T^\top T - I\|_2^2
$$

This loss is averaged along the batch.

This regularization is only applied to the second transformation matrix, that is the one applied to the feature space.

### Question 6
Code the classification loss. Adapt $L_{cls}$ to whatever activation you used in you architecture.

In [17]:
def classification_loss(preds, gt_labels, T_feat, w_reg):
        """
        Classification loss for PointNet

        Parameters
        --------------
        preds : (B, n_cls) - output of the classifier
        gt_labels : (B,) - ground truth labels
        T_feat : (B, latent1, latent1) - transformation matrix of the second FeatureTransform module
        w_reg : weight of the regularization term

        Output
        -------------
        loss : classification loss L_{cls} + w_reg * L_{reg}
        """
        device = T_feat.device

        criterion = nn.CrossEntropyLoss()

        L_cls = criterion(preds, gt_labels).to(device)

        L_reg = torch.linalg.matrix_norm(torch.matmul(T_feat.transpose(1, 2), T_feat) - torch.eye(T_feat.shape[1]).to(device), ord=2)**2

        loss = L_cls + w_reg*torch.mean(L_reg)

        return loss

In [18]:
T_feat = torch.randint(0, 2, (4, 3, 3))
print(T_feat)
T_t = T_feat.transpose(1, 2)
print(T_t)
n = torch.linalg.matrix_norm(torch.matmul(T_feat.transpose(1, 2), T_feat) - torch.eye(T_feat.shape[1]),  ord=2)**2
print(n)

tensor([[[1, 0, 1],
         [1, 1, 1],
         [1, 0, 0]],

        [[0, 1, 1],
         [0, 1, 0],
         [1, 1, 1]],

        [[1, 1, 0],
         [0, 1, 1],
         [0, 1, 1]],

        [[1, 1, 0],
         [1, 1, 0],
         [1, 0, 1]]])
tensor([[[1, 1, 1],
         [0, 1, 0],
         [1, 1, 0]],

        [[0, 0, 1],
         [1, 1, 1],
         [1, 0, 1]],

        [[1, 0, 0],
         [1, 1, 1],
         [0, 1, 1]],

        [[1, 1, 1],
         [1, 1, 0],
         [0, 0, 1]]])
tensor([16.3937, 16.3937, 13.9282, 13.9282])


## 2.5 Training

This cell defines the trainer class, using your classification loss, simply run it.

In [41]:
class Trainer(object):

    def __init__(self, train_loader, valid_loader, device='cuda',
                 lr=1e-3, weight_decay=1e-4, num_epochs=200,
                 lr_decay_every = 50, lr_decay_rate = 0.5,
                 log_interval=10, save_dir=None, **model_cfg):

        """
        train_loader: (torch.utils.DataLoader) DataLoader for training set
        valid_loader: (torch.utils.DataLoader) DataLoader for validation set
        device: (str) 'cuda' or 'cpu'
        lr: (float) learning rate
        weight_decay: (float) weight decay for optimiser
        num_epochs: (int) number of epochs
        lr_decay_every: (int) decay learning rate every this many epochs
        lr_decay_rate: (float) decay learning rate by this factor
        log_interval: (int) print training stats every this many iterations
        save_dir: (str) directory to save model checkpoints
        model_cfg: (dict) keyword arguments for model
        """
        self.model = self.get_model(model_cfg).to(device)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.lr = lr
        self.weight_decay = weight_decay
        self.num_epochs = num_epochs

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        self.lr_decay_every = lr_decay_every
        self.lr_decay_rate = lr_decay_rate
        self.log_interval = log_interval
        self.save_dir = save_dir
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []

        self.inp_feat = model_cfg.get('inp_feat', 'xyz')
        self.num_eig = model_cfg.get('num_eig', 128)
        if not self.inp_feat in ['xyz', 'xyzn', 'hks', 'wks']:
            raise ValueError('inp_feat must be one of xyz, xyzn, hks, wks')

        self.model.to(self.device)

    def get_model(self, model_cfg):
        if model_cfg['name'].lower() == 'pointnet':
            return PointNetCls(n_cls=model_cfg['n_cls'],
                              inp_dim=model_cfg['inp_dim'],
                              cls_hidden_dims=model_cfg['cls_hidden_dims'],
                              mlp2_hidden_dims=model_cfg['mlp2_hidden_dims'],
                              transf_hidden_dims1=model_cfg['transf_hidden_dims1'],
                              transf_hidden_dims2=model_cfg['transf_hidden_dims2'])
        else:
            raise ValueError('%s must be one of PointNet, PointNet++'%model_cfg['name'])

    def adjust_lr(self):
        lr = self.lr * self.lr_decay_rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def forward_step(self, inp):
        """
        Perform a forward step of the model.

        Args:
            inp (torch.Tensor): (N, D) tensor of input features

        Returns:
            pred (torch.Tensor): (N, p_out) tensor of predicted labels.
        """
        inp = inp.to(self.device)
        preds, trans_feat = self.model(inp)
        return preds, trans_feat


    def get_loss(self, preds, gt_labels, trans_feat):
        loss = classification_loss(preds, gt_labels.argmax(dim=1), trans_feat, 1e-3)
        return loss

    def train_epoch(self):
        train_loss = 0
        train_acc = 0
        for i, (inp_pts, gt_labels) in enumerate(tqdm(self.train_loader)):
            self.optimizer.zero_grad()
            gt_labels = gt_labels.to(self.device)
            preds, trans_feat = self.forward_step(inp_pts)
            loss = self.get_loss(preds, gt_labels, trans_feat)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            pred_labels = torch.max(preds, dim=1).indices
            this_correct = pred_labels.eq(gt_labels.argmax(dim=-1)).sum().item()
            train_acc += this_correct/gt_labels.shape[0]

        return train_loss/len(self.train_loader), train_acc/len(self.train_loader)

    def valid_epoch(self):
        val_loss = 0
        val_acc = 0
        for i, (inp_pts, gt_labels) in enumerate(tqdm(self.valid_loader, leave=False)):
            gt_labels = gt_labels.to(self.device)
            preds, trans_feat = self.forward_step(inp_pts)
            loss = self.get_loss(preds, gt_labels, trans_feat)
            val_loss += loss.item()
            pred_labels = torch.max(preds, dim=1).indices
            this_correct = pred_labels.eq(gt_labels.argmax(dim=-1)).sum().item()
            val_acc += this_correct/gt_labels.shape[0]

        return val_loss/len(self.valid_loader), val_acc/len(self.valid_loader)

    def run(self):
        for epoch in tqdm(range(self.num_epochs)):
            self.model.train()

            if epoch % self.lr_decay_every == 0:
                self.adjust_lr()

            train_ep_loss, train_ep_acc = self.train_epoch()
            self.train_losses.append(train_ep_loss)
            self.train_accs.append(train_ep_acc)

            if epoch % self.log_interval == 0:
                val_loss, val_acc = self.valid_epoch()
                self.val_losses.append(val_loss)
                self.val_accs.append(val_acc)
                torch.save(self.model.state_dict(), os.path.join(self.save_dir, 'model_latest.pth'))
                print('Epoch: {:03d}, Train Loss: {:.4f}, Train Acc: {:.2f}%, Val Loss: {:.4f}, Val Acc: {:.2f}%'.format(epoch,
                                                                                                                       train_ep_loss,
                                                                                                                       1e2*train_ep_acc,
                                                                                                                       val_loss,
                                                                                                                       1e2*val_acc))
        torch.save(self.model.state_dict(), os.path.join(self.save_dir, 'model_final.pth'))

    def test(self):
        file_final = os.path.join(self.save_dir, 'model_final.pth')
        if not os.path.exists(file_final):
            print('-------------------------------------------------------')
            print('No final weights, switching to last weights')
            print('-------------------------------------------------------')
            file_final = os.path.join(self.save_dir, 'model_latest.pth')
        weights = torch.load(file_final)
        self.model.load_state_dict(weights)
        _, score = self.valid_epoch()
        print('Final Valid Accuracy : {:.2f}%'.format(1e2*score))

Run the following cells to train the network.

In [42]:
N_CLS = 40
data_dir = 'data/modelnet40_ply_hdf5_2048'
train_data = ModelNet40Generator(mode='train', data_dir=data_dir, files_list='train_files.txt', rot_aug=True)
val_data = ModelNet40Generator(mode='val', data_dir=data_dir, files_list='test_files.txt')
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=0)

data/modelnet40_ply_hdf5_2048/ply_data_train0.h5
data/modelnet40_ply_hdf5_2048/ply_data_train1.h5
data/modelnet40_ply_hdf5_2048/ply_data_train2.h5
data/modelnet40_ply_hdf5_2048/ply_data_train3.h5
data/modelnet40_ply_hdf5_2048/ply_data_train4.h5
data/modelnet40_ply_hdf5_2048/ply_data_test0.h5
data/modelnet40_ply_hdf5_2048/ply_data_test1.h5


You will need to modify this cell if you are not using Google Colab + Drive

In [43]:
import os
# Comment these if you are going local
#from google.colab import drive
#drive.mount('/content/drive/')
## Change this if your are going local, otherwise keep it (checkpoints will be in checkpoint folder of your drive)
save_dir = './checkpoints/classification'
os.makedirs(save_dir, exist_ok=True)

In [44]:
model_cfg = dict(name="pointnet", n_cls=N_CLS, inp_dim=3, cls_hidden_dims=[512,256], mlp2_hidden_dims=[64,128,1024], transf_hidden_dims1=[64,128,1024], transf_hidden_dims2=[512,256])

#model_cfg = {'name': 'pointnet', 'n_cls': N_CLS, 'conv_dims': [64, 128, 1024], 'fc_dims': [512, 256]}
trainer = Trainer(train_loader, val_loader, lr=0.001, weight_decay=0.0, num_epochs=100, save_dir=os.path.join(save_dir, 'classifier'), **model_cfg)
trainer.run()

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

Epoch: 000, Train Loss: 2.2903, Train Acc: 40.44%, Val Loss: 1.8861, Val Acc: 44.07%


  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

Epoch: 010, Train Loss: 0.6522, Train Acc: 79.69%, Val Loss: 0.8176, Val Acc: 74.92%


  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

Epoch: 020, Train Loss: 0.4701, Train Acc: 84.34%, Val Loss: 0.6835, Val Acc: 80.25%


  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

Epoch: 030, Train Loss: 0.4040, Train Acc: 86.21%, Val Loss: 0.6793, Val Acc: 82.05%


  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

Epoch: 040, Train Loss: 0.3012, Train Acc: 89.58%, Val Loss: 0.9070, Val Acc: 82.65%


  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

Epoch: 050, Train Loss: 0.2508, Train Acc: 91.08%, Val Loss: 0.7655, Val Acc: 83.41%


  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

Epoch: 060, Train Loss: 0.2274, Train Acc: 91.76%, Val Loss: 0.6695, Val Acc: 82.45%


  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

Epoch: 070, Train Loss: 0.2152, Train Acc: 92.41%, Val Loss: 0.6366, Val Acc: 83.77%


  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

Epoch: 080, Train Loss: 0.1835, Train Acc: 93.35%, Val Loss: 0.6557, Val Acc: 83.61%


  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

Epoch: 090, Train Loss: 0.1645, Train Acc: 94.45%, Val Loss: 0.6735, Val Acc: 84.05%


  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

  0%|          | 0/308 [00:00<?, ?it/s]

In [45]:
trainer.test()

  0%|          | 0/78 [00:00<?, ?it/s]

Final Valid Accuracy : 83.41%


# 3. PointNet Segmentation

## 3.1 Network Design

We now build the PointNet feature extractor network (slide 34, output is `global_feature`)
![](https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD4/pointnet_segmentation.PNG)


### Question 7
Build the segmentation network in the case of a batch size of 1.
**Disable all batchnorms**

In [6]:
class PointNetSeg(nn.Module):
    ## TODO
    def __init__(self, n_cls, inp_dim=3, seg_hidden_dims=[512,256,128], mlp2_hidden_dims=[64,128,1024], transf_hidden_dims1=[64,128,1024], transf_hidden_dims2=[512,256]):
        """
        PointNet classification network

        Parameters
        ------------------
            n_cls (int): Number of classes for classification
            inp_dim (int): Dimension of the input points
            seg_hidden_dims (list): List of hidden dimensions for the segmentation MLP
            mlp2_hidden_dims (list): List of hidden dimensions for the second MLP block in the feature extractor
            transf_hidden_dims1 (list): List of hidden dimensions for the first MLP in each FeatureTransform block
            transf_hidden_dims2 (list): List of hidden dimensions for the second MLP in each FeatureTransform block
        """

        super().__init__()

        self.input_dim = mlp2_hidden_dims[0] + mlp2_hidden_dims[-1]
        self.PointNetfeat = PointNetfeat(inp_dim, mlp2_hidden_dims, transf_hidden_dims1, transf_hidden_dims2, use_bn = False) #disable all batchnorms
        self.relu = nn.ReLU()


        self.seg_mlp_layers = []

        self.seg_mlp_layers.append(nn.Conv1d(self.input_dim, seg_hidden_dims[0], 1))
        self.seg_mlp_layers.append(self.relu)

        for i in range(len(seg_hidden_dims)-1):
            self.seg_mlp_layers.append(nn.Conv1d(seg_hidden_dims[i], seg_hidden_dims[i+1], 1))
            self.seg_mlp_layers.append(self.relu)

        self.seg_mlp_layers.append(nn.Conv1d(seg_hidden_dims[-1], n_cls, 1))


        self.seg_mlp = nn.Sequential(*self.seg_mlp_layers)



    def forward(self, x):
      """
      Parameters
      --------------
      inp : (B, n, d) input batch in standard format.

      Output
      -------------
      out : (B, n, n_cls) output score for each class
      T_feat : (B, latent1, latent1) transformation matrix of the second FeatureTransform module
      """
      x = x.transpose(2,1) # Put input in convolution format

      device = x.device

      global_feature, T_feat, x_feat = self.PointNetfeat.forward(x)

      x_feat = x_feat.transpose(2,1)          #(B, n , 64)
      new_global_feature = global_feature.unsqueeze(1).repeat(1, x.shape[2], 1)     # (B, n, 1024)

      x_input_mlp = torch.cat((x_feat, new_global_feature), dim = 2).transpose(2,1)            # (B, 1088, n)

      out = self.seg_mlp(x_input_mlp).transpose(2,1)                                                          # (B, n, n_cls)

      return out, T_feat

In [12]:
out_seg, T_feat = PointNetSeg(260)(torch.rand(2, 2046, 3))
print(out_seg.shape, T_feat.shape) # check the shapes

torch.Size([2, 2046, 260]) torch.Size([2, 64, 64])


## 3.2 Data Loading

We will train on the RNA dataset from Lab3. Since all shapes have different number of vertices, we can't use batch size bigger than one.

Run the following cells to download and define code to load the data.

In [13]:
!pip install potpourri3d
#Uncomment for colab
# !pip install git+https://github.com/skoch9/meshplot.git
# !pip install pythreejs

!wget https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD3/diffusion_utils.py
!wget https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD3/RNADataset.zip
!unzip -o RNADataset.zip
!mkdir models
!wget https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/material_TD3.zip
!unzip -o material_TD3.zip

Collecting potpourri3d
  Obtaining dependency information for potpourri3d from https://files.pythonhosted.org/packages/e2/da/3a578d70c99c77f13e94289ec60f952fa7c97fd27cb7506f088878fc930f/potpourri3d-1.1.0-cp311-cp311-macosx_11_0_arm64.whl.metadata
  Downloading potpourri3d-1.1.0-cp311-cp311-macosx_11_0_arm64.whl.metadata (17 kB)
Downloading potpourri3d-1.1.0-cp311-cp311-macosx_11_0_arm64.whl (736 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m736.5/736.5 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25hInstalling collected packages: potpourri3d
Successfully installed potpourri3d-1.1.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
--2024-11-01 20:20:09--  https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD3/diffusion_utils.py

In [28]:
!pip install scikit-learn scipy

[31mERROR: Could not find a version that satisfies the requirement meshplot (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for meshplot[0m[31m
[0m

In [29]:
!pip install pythreejs
!pip install git+https://github.com/skoch9/meshplot.git

[0mCollecting pythreejs
  Downloading pythreejs-2.4.2-py3-none-any.whl.metadata (5.4 kB)
Collecting ipydatawidgets>=1.1.1 (from pythreejs)
  Downloading ipydatawidgets-4.3.5-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting traittypes>=0.2.0 (from ipydatawidgets>=1.1.1->pythreejs)
  Downloading traittypes-0.2.1-py2.py3-none-any.whl.metadata (1.0 kB)
Downloading pythreejs-2.4.2-py3-none-any.whl (3.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m105.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ipydatawidgets-4.3.5-py2.py3-none-any.whl (271 kB)
Downloading traittypes-0.2.1-py2.py3-none-any.whl (8.6 kB)
[0mInstalling collected packages: traittypes, ipydatawidgets, pythreejs
[0mSuccessfully installed ipydatawidgets-4.3.5 pythreejs-2.4.2 traittypes-0.2.1
[0mCollecting git+https://github.com/skoch9/meshplot.git
  Cloning https://github.com/skoch9/meshplot.git to /tmp/pip-req-build-zt1kjg2c
  Running command git clone --filter=blob

In [7]:
from diffusion_utils import normalize_positions, read_mesh
from mesh_utils.mesh import TriMesh
import plot_utils as plu # Comment if you are not visualizing the results

class RNAMeshDataset():
    """RNA Mesh Dataset
    """

    def __init__(self, root_dir, train, n_classes=260):

        """
            root_dir (string): Directory with all the meshes.
            train (bool): If True, use the training set, else use the test set.
            num_eig (int): Number of eigenvalues to use.
            op_cache_dir (string): Directory to cache the operators.
            n_classes (int): Number of classes.
        """

        self.train = train  # bool
        self.root_dir = root_dir
        self.n_class = n_classes # (includes -1)

        # store in memory
        self.verts_list = []
        self.faces_list = []
        self.labels_list = []  # per-vertex

        # Load the meshes & labels
        if self.train:
            with open(os.path.join(self.root_dir, "train.txt")) as f:
                this_files = [line.rstrip() for line in f]
        else:
            with open(os.path.join(self.root_dir, "test.txt")) as f:
                this_files = [line.rstrip() for line in f]

        print("loading {} files: {}".format(len(this_files), this_files))

        # Load the actual files

        off_path = os.path.join(root_dir, "off")
        label_path = os.path.join(root_dir, "labels")
        for f in this_files:
            off_file = os.path.join(off_path, f)
            label_file = os.path.join(label_path, f[:-4] + ".txt")

            verts, faces = read_mesh(off_file)
            labels = np.loadtxt(label_file).astype(int) + 1 # shift -1 --> 0

            verts = torch.tensor(verts).float()
            faces = torch.tensor(faces)
            labels = torch.tensor(labels)

            # center and unit scale
            verts = normalize_positions(verts)

            self.verts_list.append(verts)
            self.faces_list.append(faces)
            self.labels_list.append(labels)


    def __len__(self):
        return len(self.verts_list)

    def __getitem__(self, idx):
        return self.verts_list[idx], self.faces_list[idx], self.labels_list[idx]

## 3.3 Segmentation Loss

PointNet segmentation loss is defined the same as the classification loss, except we know perform per-vertex classification using $L_{seg}$ instead of $L_{cls}$. Adapt again the loss to your network design


### Question 8
Code the segmentation loss. Adapt $L_{seg}$ to whatever activation you used in you architecture. You might be able to use exactly the same code.

In [8]:
def segmentation_loss(preds, gt_labels, T_feat, w_reg):
        """
        Classification loss for PointNet

        Parameters
        --------------
        preds : (B, n, n_cls) - output of the segmentation network
        gt_labels : (B, n) - ground truth labels
        T_feat : (B, latent1, latent1) - transformation matrix of the second FeatureTransform module
        w_reg : weight of the regularization term

        Output
        -------------
        loss : segmentation loss L_{seg} + w_reg * L_{reg}
        """

        device = T_feat.device

        criterion = nn.CrossEntropyLoss()

        L_cls = criterion(preds, gt_labels).to(device)

        L_reg = torch.linalg.matrix_norm(torch.matmul(T_feat.transpose(1, 2), T_feat) - torch.eye(T_feat.shape[1]).to(device), ord=2)**2

        loss = L_cls + w_reg*torch.mean(L_reg)

        return loss

## 3.4 Training

This defines a trainer class using your loss. Simply run the cells.

In [20]:
class TrainerSeg(object):

    def __init__(self, train_loader, valid_loader, device='cuda',
                 lr=1e-3, weight_decay=1e-4, num_epochs=200,
                 lr_decay_every = 50, lr_decay_rate = 0.5,
                 log_interval=10, save_dir=None, **model_cfg):

        """
        pointNet_cls: (nn.Module) class of the PointNet (or ++) model.
        train_loader: (torch.utils.DataLoader) DataLoader for training set
        valid_loader: (torch.utils.DataLoader) DataLoader for validation set
        device: (str) 'cuda' or 'cpu'
        lr: (float) learning rate
        weight_decay: (float) weight decay for optimiser
        num_epochs: (int) number of epochs
        lr_decay_every: (int) decay learning rate every this many epochs
        lr_decay_rate: (float) decay learning rate by this factor
        log_interval: (int) print training stats every this many iterations
        save_dir: (str) directory to save model checkpoints
        model_cfg: (dict) keyword arguments for model
        """
        self.model = self.get_model(model_cfg).to(device)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.lr = lr
        self.weight_decay = weight_decay
        self.num_epochs = num_epochs

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        self.loss = torch.nn.CrossEntropyLoss()

        self.lr_decay_every = lr_decay_every
        self.lr_decay_rate = lr_decay_rate
        self.log_interval = log_interval
        self.save_dir = save_dir
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []

        self.inp_feat = model_cfg.get('inp_feat', 'xyz')
        self.num_eig = model_cfg.get('num_eig', 128)
        if not self.inp_feat in ['xyz', 'xyzn', 'hks', 'wks']:
            raise ValueError('inp_feat must be one of xyz, xyzn, hks, wks')

        self.model.to(self.device)

    def get_model(self, model_cfg):
        if model_cfg['name'].lower() == 'pointnet':
            return PointNetSeg(n_cls=model_cfg['n_cls'],
                              inp_dim=model_cfg['inp_dim'],
                              seg_hidden_dims=model_cfg['seg_hidden_dims'],
                              mlp2_hidden_dims=model_cfg['mlp2_hidden_dims'],
                              transf_hidden_dims1=model_cfg['transf_hidden_dims1'],
                              transf_hidden_dims2=model_cfg['transf_hidden_dims2'])
        else:
            raise ValueError('%s must be one of PointNet, PointNet++'%model_cfg['name'])

    def adjust_lr(self):
        lr = self.lr * self.lr_decay_rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def forward_step(self, inp):
        """
        Perform a forward step of the model.

        Args:
            inp (torch.Tensor): (N, D) tensor of input features

        Returns:
            pred (torch.Tensor): (N, p_out) tensor of predicted labels.
        """
        inp = inp.to(self.device)
        verts = inp[..., :3].to(self.device)
        norms = inp[..., 3:].to(self.device)
        preds, trans_feat = self.model(inp)
        return preds, trans_feat



    def get_loss(self, preds, gt_labels, trans_feat):

        loss = segmentation_loss(preds, gt_labels, trans_feat, 1e-3)
        return loss

    def train_epoch(self):
        train_loss = 0
        train_acc = 0
        for i, (inp_pts, _, gt_labels) in enumerate(tqdm(self.train_loader)):
            self.optimizer.zero_grad()
            gt_labels = gt_labels.to(self.device)
            preds, trans_feat = self.forward_step(inp_pts)
            loss = self.get_loss(preds[0], gt_labels[0], trans_feat)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()

            pred_labels = torch.max(preds[0], dim=1).indices  # (N,)
            this_correct = pred_labels.eq(gt_labels).sum().item()
            train_acc += this_correct / inp_pts.shape[-2]

        return train_loss/len(self.train_loader), train_acc/len(self.train_loader)

    def valid_epoch(self):
        val_loss = 0
        val_acc = 0
        total = 0
        for i, (inp_pts, _, gt_labels) in enumerate(self.valid_loader):
            gt_labels = gt_labels.to(self.device)
            preds, trans_feat = self.forward_step(inp_pts)
            loss = self.get_loss(preds[0], gt_labels[0], trans_feat)
            val_loss += loss.item()

            pred_labels = torch.max(preds[0], dim=1).indices
            this_correct = pred_labels.eq(gt_labels).sum().item()
            total += inp_pts.shape[-2]
            val_acc += this_correct

        return val_loss/total, val_acc/total

    def run(self):
        for epoch in tqdm(range(self.num_epochs)):
            self.model.train()

            if epoch % self.lr_decay_every == 0:
                self.adjust_lr()

            train_ep_loss, train_ep_acc = self.train_epoch()
            self.train_losses.append(train_ep_loss)
            self.train_accs.append(train_ep_acc)

            if epoch % self.log_interval == 0:
                val_loss, val_acc = self.valid_epoch()
                self.val_losses.append(val_loss)
                self.val_accs.append(val_acc)
                torch.save(self.model.state_dict(), os.path.join(self.save_dir, 'model_latest.pth'))
                print('Epoch: {:03d}, Train Loss: {:.4f}, Train Acc: {:.2f}, Val Loss: {:.4f}, Val Acc: {:.2f}%'.format(epoch,
                                                                                                                       train_ep_loss,
                                                                                                                       1e2*train_ep_acc,
                                                                                                                       val_loss,
                                                                                                                       1e2*val_acc))
        torch.save(self.model.state_dict(), os.path.join(self.save_dir, 'model_final.pth'))

    def visualize(self):
        """
        We only test the first two shapes of validation set.
        """
        self.model.eval()
        test_seg_meshes = []

        for i, (inp_pts, inp_faces, gt_labels) in enumerate(self.valid_loader):
            gt_labels = gt_labels.to(self.device)
            preds, trans_feat = self.forward_step(inp_pts)
            pred_labels = torch.max(preds[0], dim=1).indices
            test_seg_meshes.append([TriMesh(inp_pts.squeeze().cpu().numpy(), inp_faces.squeeze().cpu().numpy()),
                                  pred_labels.squeeze().cpu().numpy()])
            if i==1:
                break

        cmap1 = plt.get_cmap("jet")(test_seg_meshes[0][-1] / (146))[:,:3]
        cmap2 = plt.get_cmap("jet")(test_seg_meshes[1][-1] / (146))[:,:3]

        plu.double_plot(test_seg_meshes[0][0], test_seg_meshes[1][0], cmap1, cmap2)

    def test(self):
        file_final = os.path.join(self.save_dir, 'model_final.pth')
        if not os.path.exists(file_final):
            print('-------------------------------------------------------')
            print('No final weights, switching to last weights')
            print('-------------------------------------------------------')
            file_final = os.path.join(self.save_dir, 'model_latest.pth')
        weights = torch.load(file_final)
        self.model.load_state_dict(weights)
        _, score = self.valid_epoch()
        print('Final Valid Segmentation Accuracy : {:.2f}%'.format(1e2*score))

Run these cells to see

In [9]:
root_dir = './RNADataset/'

train_dataset = RNAMeshDataset(root_dir, train=True)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0, persistent_workers=False)

valid_dataset = RNAMeshDataset(root_dir, train=False)
valid_loader = DataLoader(valid_dataset, batch_size=1, num_workers=0, persistent_workers=False)

loading 60 files: ['5JUO_D.off', '5T62_B.off', '4WCE_Y.off', '4V9K_BB.off', '5E81_16.off', '3J92_7.off', '4V7X_BB.off', '5T7V_C.off', '4TUE_RB.off', '5DGE_3.off', '4V7B_BB.off', '4U4O_3.off', '3JA1_LB.off', '3PIO_Y.off', '5EL7_16.off', '5MDY_3.off', '5H5U_B.off', '4U4O_7.off', '5TBW_AS.off', '5J30_RB.off', '4W4G_RB.off', '4V7J_AB.off', '4V8J_BB.off', '2OTJ_9.off', '4V7X_DB.off', '3J9Y_B.off', '1VQK_9.off', '1Q82_B.off', '4W2H_DB.off', '3G4S_9.off', '1NKW_9.off', '4V5J_BB.off', '3JBN_AB.off', '4V8E_CB.off', '4YZV_YB.off', '4V9P_CB.off', '5J5B_DB.off', '4U6F_7.off', '5OBM_7.off', '3JCD_B.off', '4V55_BA.off', '5IBB_16.off', '1VY7_BB.off', '5GAD_B.off', '4U4Q_3.off', '4WRA_1J.off', '5ADY_A.off', '4V70_BB.off', '4TUB_RB.off', '5AJ0_A4.off', '1KD1_B.off', '4V6N_AA.off', '4V8I_BB.off', '4W4G_YB.off', '4V8F_DB.off', '4V8P_E3.off', '4V5Y_BA.off', '4TUA_YB.off', '5DFE_RB.off', '4V5K_DB.off']
loading 13 files: ['4V83_BB.off', '1VQ6_9.off', '4V7K_BB.off', '4V9A_DB.off', '3JQ4_B.off', '5KCS_1B.off'

In [36]:
#model_cfg = {'name': 'pointnet', 'n_cls': 260, 'conv_dims': [64, 128, 1024], 'fc_dims': [512, 256]}
model_cfg = dict(name="pointnet", n_cls=260, inp_dim=3, seg_hidden_dims=[512,256], mlp2_hidden_dims=[64,128,1024], transf_hidden_dims1=[64,128,1024], transf_hidden_dims2=[512,256])

trainer = TrainerSeg(train_loader, valid_loader, lr=0.001, device='cuda', weight_decay=0.0, num_epochs=300, save_dir=os.path.join(save_dir, 'segmentation'), **model_cfg)
trainer.run()

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 000, Train Loss: 4.9819, Train Acc: 1.10, Val Loss: 0.0003, Val Acc: 1.38%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 010, Train Loss: 3.1557, Train Acc: 12.83, Val Loss: 0.0002, Val Acc: 17.07%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 020, Train Loss: 2.6876, Train Acc: 22.70, Val Loss: 0.0002, Val Acc: 23.07%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 030, Train Loss: 2.0130, Train Acc: 33.87, Val Loss: 0.0001, Val Acc: 33.06%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 040, Train Loss: 2.0148, Train Acc: 34.61, Val Loss: 0.0001, Val Acc: 37.94%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 050, Train Loss: 1.4187, Train Acc: 48.40, Val Loss: 0.0001, Val Acc: 49.81%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 060, Train Loss: 1.2951, Train Acc: 51.97, Val Loss: 0.0001, Val Acc: 55.20%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 070, Train Loss: 1.1387, Train Acc: 56.53, Val Loss: 0.0001, Val Acc: 55.72%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 080, Train Loss: 1.0102, Train Acc: 60.72, Val Loss: 0.0001, Val Acc: 57.17%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 090, Train Loss: 0.9903, Train Acc: 62.22, Val Loss: 0.0001, Val Acc: 63.17%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 100, Train Loss: 0.9089, Train Acc: 64.57, Val Loss: 0.0001, Val Acc: 63.49%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 110, Train Loss: 0.8742, Train Acc: 65.82, Val Loss: 0.0001, Val Acc: 62.29%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 120, Train Loss: 0.8481, Train Acc: 66.98, Val Loss: 0.0001, Val Acc: 56.47%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 130, Train Loss: 0.8902, Train Acc: 65.49, Val Loss: 0.0001, Val Acc: 64.97%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 140, Train Loss: 0.7503, Train Acc: 70.39, Val Loss: 0.0001, Val Acc: 57.40%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 150, Train Loss: 0.8416, Train Acc: 67.31, Val Loss: 0.0001, Val Acc: 58.36%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 160, Train Loss: 0.7325, Train Acc: 70.92, Val Loss: 0.0001, Val Acc: 64.81%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 170, Train Loss: 0.7559, Train Acc: 70.06, Val Loss: 0.0001, Val Acc: 67.05%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 180, Train Loss: 0.7259, Train Acc: 71.15, Val Loss: 0.0001, Val Acc: 70.61%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 190, Train Loss: 0.6989, Train Acc: 71.97, Val Loss: 0.0001, Val Acc: 68.51%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 200, Train Loss: 0.6630, Train Acc: 73.49, Val Loss: 0.0001, Val Acc: 69.85%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 210, Train Loss: 0.5870, Train Acc: 76.39, Val Loss: 0.0001, Val Acc: 67.88%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 220, Train Loss: 0.6325, Train Acc: 74.57, Val Loss: 0.0001, Val Acc: 69.15%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 230, Train Loss: 0.6217, Train Acc: 75.21, Val Loss: 0.0001, Val Acc: 69.38%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 240, Train Loss: 0.6564, Train Acc: 73.92, Val Loss: 0.0001, Val Acc: 70.37%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 250, Train Loss: 0.6831, Train Acc: 72.82, Val Loss: 0.0001, Val Acc: 65.05%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 260, Train Loss: 0.5088, Train Acc: 79.76, Val Loss: 0.0001, Val Acc: 73.75%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 270, Train Loss: 0.5675, Train Acc: 77.11, Val Loss: 0.0001, Val Acc: 70.75%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 280, Train Loss: 0.5148, Train Acc: 79.26, Val Loss: 0.0001, Val Acc: 72.44%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 290, Train Loss: 0.5515, Train Acc: 77.76, Val Loss: 0.0001, Val Acc: 71.62%


  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

Vizualize the results

The next cell is useful to just visualize the result from the trained model without training it an other time

In [10]:
#load the model only to visualize
model_cfg = dict(name="pointnet", 
                 n_cls=260, 
                 inp_dim=3, 
                 seg_hidden_dims=[512,256], 
                 mlp2_hidden_dims=[64,128,1024], 
                 transf_hidden_dims1=[64,128,1024], 
                 transf_hidden_dims2=[512,256])

#trainer function modified to load the trained model to plot the results
class TrainerSegVisualization(object):

    def __init__(self, train_loader, valid_loader, device='cuda',
                 lr=1e-3, weight_decay=1e-4, num_epochs=200,
                 lr_decay_every = 50, lr_decay_rate = 0.5,
                 log_interval=10, save_dir=None, **model_cfg):

        """
        pointNet_cls: (nn.Module) class of the PointNet (or ++) model.
        train_loader: (torch.utils.DataLoader) DataLoader for training set
        valid_loader: (torch.utils.DataLoader) DataLoader for validation set
        device: (str) 'cuda' or 'cpu'
        lr: (float) learning rate
        weight_decay: (float) weight decay for optimiser
        num_epochs: (int) number of epochs
        lr_decay_every: (int) decay learning rate every this many epochs
        lr_decay_rate: (float) decay learning rate by this factor
        log_interval: (int) print training stats every this many iterations
        save_dir: (str) directory to save model checkpoints
        model_cfg: (dict) keyword arguments for model
        """
        self.model = self.get_model(model_cfg).to(device)
        self.model.load_state_dict(torch.load('./checkpoints/segmentation/model_final.pth', weights_only=True, map_location=torch.device(device)))
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.lr = lr
        self.weight_decay = weight_decay

        self.inp_feat = model_cfg.get('inp_feat', 'xyz')
        self.num_eig = model_cfg.get('num_eig', 128)
        if not self.inp_feat in ['xyz', 'xyzn', 'hks', 'wks']:
            raise ValueError('inp_feat must be one of xyz, xyzn, hks, wks')

        self.model.to(self.device)

    def get_model(self, model_cfg):
        if model_cfg['name'].lower() == 'pointnet':
            return PointNetSeg(n_cls=model_cfg['n_cls'],
                              inp_dim=model_cfg['inp_dim'],
                              seg_hidden_dims=model_cfg['seg_hidden_dims'],
                              mlp2_hidden_dims=model_cfg['mlp2_hidden_dims'],
                              transf_hidden_dims1=model_cfg['transf_hidden_dims1'],
                              transf_hidden_dims2=model_cfg['transf_hidden_dims2'])
        else:
            raise ValueError('%s must be one of PointNet, PointNet++'%model_cfg['name'])


    def forward_step(self, inp):
        """
        Perform a forward step of the model.

        Args:
            inp (torch.Tensor): (N, D) tensor of input features

        Returns:
            pred (torch.Tensor): (N, p_out) tensor of predicted labels.
        """
        inp = inp.to(self.device)
        verts = inp[..., :3].to(self.device)
        norms = inp[..., 3:].to(self.device)
        preds, trans_feat = self.model(inp)
        return preds, trans_feat

    def visualize(self):
        """
        We only test the first two shapes of validation set.
        """
        self.model.eval()
        test_seg_meshes = []

        for i, (inp_pts, inp_faces, gt_labels) in enumerate(self.valid_loader):
            gt_labels = gt_labels.to(self.device)
            preds, trans_feat = self.forward_step(inp_pts)
            pred_labels = torch.max(preds[0], dim=1).indices
            test_seg_meshes.append([TriMesh(inp_pts.squeeze().cpu().numpy(), inp_faces.squeeze().cpu().numpy()),
                                  pred_labels.squeeze().cpu().numpy()])
            if i==1:
                break

        cmap1 = plt.get_cmap("jet")(test_seg_meshes[0][-1] / (146))[:,:3]
        cmap2 = plt.get_cmap("jet")(test_seg_meshes[1][-1] / (146))[:,:3]

        plu.double_plot(test_seg_meshes[0][0], test_seg_meshes[1][0], cmap1, cmap2)

trainer = TrainerSegVisualization(train_loader, valid_loader, lr=0.001, device='cpu', weight_decay=0.0, num_epochs=300, **model_cfg)


In [11]:
trainer.visualize()

HBox(children=(Output(), Output()))

HBox(children=(Output(), Output()))

In [38]:
trainer.test()

Final Valid Segmentation Accuracy : 71.05%


Final task: save the notebook .ipynb file **with logs**, and create a zip with the checkpoints folder. You can send your results to geometricdeeplearning@protonmail.com