The goal of this TD is to code and play with [PointNet](https://arxiv.org/pdf/1612.00593.pdf), a standard architecture for deep learning for pointcloud.

In this lab, we will ask you to provide both:
- notebook with logs (training logs in particular)
- the weights of the trained networks (classifier and segmentation). We provide guidance for saving the logs directly in your google drive, you can also download them via the colab interface

It should be helpful to use the slides on the side of this TD to obtain additionnal illustrations.

The idea of PointNet is to learn functions which are by definition fit to process unordered set of elements, that is functions of the shape
$$
f\big(\left\{x_1,\dots, x_n\right\}\big) = g\big(h(x_1),\dots, h(x_n)\big)
$$

where $g$ is a *symmetric* function, such as the $\max$ function.

In this TD, we will build the network as presented in the following image (slide 34):
![](https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD4/PointNet_architecture.PNG)

It consists in:
1. An input / feature transformation module
3. Per-vertex MLP
4. Feature aggregation for classification

In [1]:
#from google.colab import output
#output.enable_custom_widget_manager()

import os
import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch
import torch.nn.functional as F

from torch.utils.data import DataLoader

# 1. PointNet Feature Extractor

## 1.1 Input and feature Transformation

The feature transformation module takes as input a $d$ dimensional point-cloud and regresses a $d\times d$ transformation matrix $T$.

If the input consists in $[x,y,z]$ coordinates, we can expect the network to learn a $3\times 3$ matrix which for instance rotates the point cloud in a *canonical* orientation.

If the input consists in $d$-dimensional features, the transformation matrix is harder to visualize.

![](https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD4/feature_transform.PNG)

This is done in two steps:
1. Apply per-vertex MLP to obtain a $d_{out}$ dimensional embedding for each vertex
2. use MaxPooling over the first dimension to to obtain a simple $d_{out}$ dimensional vector **for the whole point cloud**
3. Use MLP to obtain a $d^2$ dimensional vector, which defines T (flattened).
---


### Question 1

Implement the `FeatTransform` module.

**NOTE :**

Except for the **last** linear layer of the second MLP, all layers consist in of the combination of `layer`, `Batchnorm` and `ReLu` activations. The last layer is a simple linear layer with **no activation** or batchnorm (we don't want to constraint the values of matrix T).

An important (and annoying) feature in [`nn.BatchNorm1d`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html) is that is expects the features to lie on the **second** dimension (the one after the batch). On the contrary fully connected layers expect features to lie on the **last** dimension. Unless the input is 2d, we would need to reshape values for each layer. Hopefully there is a small trick to handle this.

The trick consist in using [`nn.Conv1d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) with kernel size of $1$ to define fully connected layers. This is exactly equivalent to fully connected layers although a bit surprising at first sight. The adventage is that torch convolutions expect features to be on the **second** dimension.

Follow the following guide (or reshape at every layer, as you wish):

1. For the first MLP, the input is given with shape `(B,d,n)`. The features are in the second dimension, so use [`nn.Conv1d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html), which allows to use [`nn.BatchNorm1d`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html) directly without reshaping.
2. In the second layer you can give as input a tensor of shape `(B,d_out)` (instead of `(B,1,d_out)`). Both  [`nn.Linear`](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) and [`nn.BatchNorm1d`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html) will work since the **second** and **last** dimensions are the same.
3. Check the documentation of [`torch.max`](https://pytorch.org/docs/stable/generated/torch.max.html), as it does not return the same thing as numpy (use `torch.max(a,dim=dim).values`).
4. We add an argument to potentially remove batchnorm (for future experiments using batch size of 1)


To stabilize training, we actually output $I + T$ so that at initialization when $T\simeq 0$ this module has no effect.


In [2]:
class FeatTransform(nn.Module):
    # TODO
    def __init__(self, inp_dim=3, hidden_dims1=[64,128,1024], hidden_dims2=[512,256], use_bn=True):
      """
      inp_dim : int - dimension d of the input
      hidden_dims1 : list - hidden layers (inluding d_out) of the first MLP block (defined with nn.Conv1d...)
      hidden_dims2 : list - hidden layers (without d_out) of the second MLP block
      use_bn       : bool - whether to use batchnorm
      """
      super().__init__()

      ## TODO
      self.inp_dim = inp_dim
      self.use_bn = use_bn
      self.first_mlp = nn.ModuleList()
      self.second_mlp = nn.ModuleList()
      #Contruct first MLP
      self.first_mlp.append(nn.Conv1d(inp_dim, hidden_dims1[0], 1))
      if use_bn:
        self.first_mlp.append(nn.BatchNorm1d(hidden_dims1[0]))
      self.first_mlp.append(nn.ReLU())
      for i in range(len(hidden_dims1)-1):
        self.first_mlp.append(nn.Conv1d(hidden_dims1[i], hidden_dims1[i+1], 1))
        if use_bn:
          self.first_mlp.append(nn.BatchNorm1d(hidden_dims1[i+1]))
        self.first_mlp.append(nn.ReLU())
      #Construct second MLP
      self.second_mlp.append(nn.Linear(hidden_dims1[-1], hidden_dims2[0]))
      if use_bn:
        self.second_mlp.append(nn.BatchNorm1d(hidden_dims2[0]))
      self.second_mlp.append(nn.ReLU())
      for i in range(len(hidden_dims2)-1):
        self.second_mlp.append(nn.Linear(hidden_dims2[i], hidden_dims2[i+1]))
        if use_bn:
          self.second_mlp.append(nn.BatchNorm1d(hidden_dims2[i+1]))
        self.second_mlp.append(nn.ReLU())
      #add output dim d^2
      self.second_mlp.append(nn.Linear(hidden_dims2[-1], inp_dim*inp_dim))

     
    def forward(self, x):
        """
        x : (B, d, n) - This is standard shape for convolution inputs

        Output
        ------------
        (B, d , d) : output T defined as I + NET(x) for stability
        """
        T = x
        for layer in self.first_mlp:
          T = layer(T)
        T = torch.max(T, dim=-1)[0]
        for layer in self.second_mlp:
          T = layer(T)
        T = T.reshape(-1, self.inp_dim, self.inp_dim)
        return T + torch.eye(self.inp_dim, device=T.device).unsqueeze(0)

In [3]:
# Try a forward pass
FeatTransform()(torch.rand(2,3,10)).shape

torch.Size([2, 3, 3])

## 1.2 PointNet Feature extractor

We now build the PointNet feature extractor network (output is `global_feature`)
![](https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD4/PointNet_architecture.PNG)


It can be described as follows:
1. Predict first feature transform and apply to the input
2. Apply a first 2-layer MLP with BatchNorm + ReLu.
3. Preduct second feature transform and apply to the features
4. Apply a second MLP with Batchnorm + ReLu, **except** on last layer where there is no activation (`maxpool`will serve as non-linearity).
5. Apply Maxpooling to obtain the global feature vector

Note that for future segmentation task, we will also need to outpu the second transformed features (or the input to the second MLP).

---

### Question 2
Implement the `PointNetfeat` module with the following convention

1. Use again [`nn.Conv1d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) for both MLP blocks.
2. The **first** MLP will be a 1-layer MLP with output size being the input size of the second MLP (the picture has a typo somehow)
3. The module returns the transformed features before the are fed to the second MLP, the transformation matrix of the second block, and the global feature.
4. Both Feature Transform modules use the same hidden dimensions for their MLPs.
5. You can hardcode to default parameters, we will never change them


In [4]:
class PointNetfeat(nn.Module):

    ## TODO
    def __init__(self, inp_dim=3, mlp2_hidden_dims=[64,128,1024], transf_hidden_dims1=[64,128,1024], transf_hidden_dims2=[512,256], use_bn=True):
      """
      inp_dim   : int - dimension d of the input
      mlp2_hidden_dims : list - hidden layers (inluding d_out) of the second MLP block (defined with nn.Conv1d...)
      transf_hidden_dims1 : list - hidden layers (inluding d_out) of the first mlp in each FeatureTransform block
      transf_hidden_dims2 : list - hidden layers (without d_out) of the second mlp in each FeatureTransform block
      use_bn : bool - whether to use batchnorm
      """
      super().__init__()

      # USEFULE INFO
      self.inp_dim = inp_dim
      self.latent_dim1 = mlp2_hidden_dims[0] # size of the second features
      self.latent_dim = mlp2_hidden_dims[-1] # is also the out_dim
      self.use_bn = use_bn
      self.feature_transform1 = FeatTransform(inp_dim, transf_hidden_dims1, transf_hidden_dims2, use_bn)
      self.feature_transform2 = FeatTransform(self.latent_dim1, transf_hidden_dims1, transf_hidden_dims2, use_bn)
      self.mlp1 = nn.Sequential()
      self.mlp2 = nn.Sequential()
      #Construct first MLP
      self.mlp1.append(nn.Conv1d(inp_dim, mlp2_hidden_dims[0], 1))
      if use_bn:
        self.mlp1.append(nn.BatchNorm1d(mlp2_hidden_dims[0]))
      self.mlp1.append(nn.ReLU())
      #Construct second MLP
      for i in range(len(mlp2_hidden_dims)-1):
        self.mlp2.append(nn.Conv1d(mlp2_hidden_dims[i], mlp2_hidden_dims[i+1], 1))
        if i < len(mlp2_hidden_dims)-2:
          if use_bn:
            self.mlp2.append(nn.BatchNorm1d(mlp2_hidden_dims[i+1]))
          self.mlp2.append(nn.ReLU())



    def forward(self, x):
      """
      Parameters
      --------------
      x : (B, in_dim, n) input batch in convolution format.

      Output
      -------------
      global_feature : (B, d_out) output global feature (d_out=1024)
      T_feat : (B, latent1, latent1) transformation matrix of the second FeatureTransform module
      x_feat : (B, latent1, n) per-point features after being transformed by the second
                FeatureTransform module
      """
      # First Feature Transform
      T_feat = self.feature_transform1(x)
      x_feat = torch.bmm(T_feat, x)
      # First MLP
      x_feat = self.mlp1(x_feat)
      # Second Feature Transform
      T_feat = self.feature_transform2(x_feat)
      x_feat = torch.bmm(T_feat, x_feat)
      # Second MLP
      global_feature = self.mlp2(x_feat)
      # Global Feature with maxpool
      global_feature = torch.max(global_feature, dim=-1)[0]

      return global_feature, T_feat, x_feat

In [5]:
# Try a forward pass
global_, T_, x_ = PointNetfeat()(torch.rand(2,3,2048))
print(global_.shape, T_.shape, x_.shape) #check the shapes

torch.Size([2, 1024]) torch.Size([2, 64, 64]) torch.Size([2, 64, 2048])


# 2 PointNet Classifier

## 2.1 Network Design

The PointNet Classifier is the complete network seen as follow
![](https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD4/PointNet_architecture.PNG)


### Question 3
Implement the `PointNetCls` module.
1. The last MLP also uses BatchNorm + ReLu, **except** for the last layer, which uses either `log_softmax` directly or `softmax`, or even nothing. Note that the training loss you'll chose will be adapted to this choice !
2. You can hardcode to default parameters, we will not change them
3. We don't need the transformed feature, only the transformation matrix for classification

In [6]:
class PointNetCls(nn.Module):
    ## TODO
    def __init__(self, n_cls, inp_dim=3, cls_hidden_dims=[512,256], mlp2_hidden_dims=[64,128,1024], transf_hidden_dims1=[64,128,1024], transf_hidden_dims2=[512,256], use_bn=True, final_activation=None):
        """
        PointNet classification network

        Parameters
        ------------------
            n_cls (int): Number of classes for classification
            inp_dim (int): Dimension of the input points
            cls_hidden_dims (list): List of hidden dimensions for the MLP classifier
            mlp2_hidden_dims (list): List of hidden dimensions for the second MLP block in the feature extractor
            transf_hidden_dims1 (list): List of hidden dimensions for the first MLP in each FeatureTransform block
            transf_hidden_dims2 (list): List of hidden dimensions for the second MLP in each FeatureTransform block
            use_bn (bool): Whether to use batchnorm
        """

        super(PointNetCls, self).__init__()

        self.use_bn = use_bn
        self.PointNetBase = PointNetfeat(inp_dim, mlp2_hidden_dims, transf_hidden_dims1, transf_hidden_dims2, use_bn)
        self.mlp_cls = nn.Sequential()
        #Construct MLP
        self.mlp_cls.append(nn.Linear(mlp2_hidden_dims[-1], cls_hidden_dims[0]))
        if use_bn:
          self.mlp_cls.append(nn.BatchNorm1d(cls_hidden_dims[0]))
        self.mlp_cls.append(nn.ReLU())
        for i in range(len(cls_hidden_dims)-1):
          self.mlp_cls.append(nn.Linear(cls_hidden_dims[i], cls_hidden_dims[i+1]))
        if use_bn:
            self.mlp_cls.append(nn.BatchNorm1d(cls_hidden_dims[i+1]))
        self.mlp_cls.append(nn.ReLU())
        self.mlp_cls.append(nn.Linear(cls_hidden_dims[-1], n_cls))
        #add last activation
        if final_activation is not None:
          self.mlp_cls.append(final_activation)

    def forward(self, x):
      """
      Parameters
      --------------
      inp : (B, n, d) input batch in standard format.

      Output
      -------------
      out : (B, n_cls) output score for each class
      T_feat : (B, latent1, latent1) transformation matrix of the second FeatureTransform module
      """
      x = x.transpose(2,1) # Put input in convolution format
      global_feature, T_feat, _ = self.PointNetBase(x)
      out = self.mlp_cls(global_feature)
      
      return out, T_feat

## 2.2 - Data Augmentation

Since PointNet takes $(X,Y,Z)$ coordinates as input, it is not invariant to rigid transformations of the point cloud. One possibility is to make it learn to be invariant, by rigidly modifying shapes during training.

### Question 4
Fill the random scaling and translation function using numpy.
1. Scale and translate each coordinate independantly
2. Scaling factor lies between $\frac{2}{3}$ and $\frac{3}{2}$
2. Translation lies between $-0.2$ and $0.2$

In [7]:
import numpy as np

In [8]:
def random_scale_transl(X):
        """Apply random scaling to the point cloud.
        Scaling is applied for

        Args:
            X (np.ndarray): Point cloud data, (N, 3).

        Returns:
            X_transformed (np.ndarray): Scaled (and translated) point cloud data (N, 3).
        """
        # Random scaling
        scale = np.random.uniform(2/3, 3/2, X.shape[1])
        X_transformed = X * scale
        # Random translation
        translation = np.random.uniform(-0.2, 0.2, X.shape[1])
        X_transformed = X_transformed + translation

        return X_transformed

### Question 5
For this specfific dataset, all shapes pre-aligned, and only variability is rotation along the $y$ axis.

Therefore simply generate a random rotation matrix along the y axis.

In [9]:
def random_rotation(X):
        # TODO
        """Apply random rotation to the point cloud (along the y axis)

        Args:
            X (np.ndarray): Point cloud data, (N, 3).

        Returns:
            X_rot (np.ndarray): Rotated point cloud data (N, 3).
        """
        # Random rotation
        theta = np.random.uniform(0, 2*np.pi)
        #The transpose of the rotation matrix theta along the y axis :
        rot_matrix = np.array([[np.cos(theta), 0, np.sin(theta)],
                               [0, 1, 0],
                               [-np.sin(theta), 0, np.cos(theta)]])
        X_rot = np.dot(X, rot_matrix)
        return X_rot

## 2.3 Data Loading

We will train on the [ModelNet40](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Wu_3D_ShapeNets_A_2015_CVPR_paper.pdf) dataset.

This dataset consist in pointclouds of 12000 objects of 40 different categories such as airplane, car, plant, lamp. Because we only need to provide a global feature vector for each shape, we can subsample each surface with a fixed number of parameters, which allows us to use large batch sizes.

![](https://production-media.paperswithcode.com/datasets/modelnet.jpeg)


We provide below code for dataloading.

Let's first download the data and define the loaders.

In [None]:
!mkdir 'data'
!wget -O './data/modelnet40_ply_hdf5_2048.zip' https://huggingface.co/datasets/Msun/modelnet40/resolve/main/modelnet40_ply_hdf5_2048.zip?download=true --no-check-certificate
!unzip data/modelnet40_ply_hdf5_2048.zip -d ./data/

'wget' n'est pas reconnu en tant que commande interne
ou externe, un programme ex�cutable ou un fichier de commandes.
'unzip' n'est pas reconnu en tant que commande interne
ou externe, un programme ex�cutable ou un fichier de commandes.


In [13]:
import h5py
import os.path as osp

from pathlib import Path
from glob import glob
from tqdm.auto import tqdm
from torch.utils.data import Dataset

def load_h5(h5_filename):
    f = h5py.File(h5_filename)
    print(h5_filename)
    data = f['data'][:]
    label = f['label'][:]
    f.close()
    return (data, label)

def load_h5_files(data_path, files_list_path):
    """Load h5 into memory
    """
    files_list = [Path(line.rstrip()).name for line in open(osp.join(data_path, files_list_path))]
    data = []
    labels = []
    for i in range(len(files_list)):
        data_, labels_ = load_h5(os.path.join(data_path, files_list[i]))
        data.append(data_)
        labels.append(labels_)
    data = np.concatenate(data, axis=0)
    labels = np.concatenate(labels, axis=0)
    return data, labels


class ModelNet40Generator(Dataset):

    def __init__(self, mode, data_dir, files_list, num_classes=40, num_points=1024, rot_aug=False):
        """Initialize the ModelNet40Generator class.

        Args:
            mode (str): 'train' or 'test'.
            data_dir (str): Path to the data directory.
            files_list (str): Path to the files list.
            num_classes (int, optional): ModelNet class. Defaults to 40.
            num_points (int, optional): Input PC. Defaults to 1024.
            rot_aug (bool, optional): Augment data. Defaults to False.
        """

        assert mode.lower() in ('train', 'val', 'test')
        assert files_list in ('train_files.txt', 'test_files.txt')
        self.data, labels = load_h5_files(data_dir, files_list)

        self.mode = mode
        self.num_points = num_points
        self.num_classes = num_classes
        self.num_samples = self.data.shape[0]
        self.rot_aug = rot_aug
        self.labels = np.reshape(labels, (-1,))


    def __len__(self) -> int:
        return self.num_samples


    def __getitem__(self, idx):
        indexes = np.random.permutation(np.arange(self.num_points))[:self.num_points]
        X = self.data[idx, indexes, ...]
        y = self.labels[idx, ...]
        y_categorical = torch.from_numpy(np.eye(self.num_classes, dtype='uint8')[y]).long()

        if self.mode == 'train' and self.rot_aug:
            # X = self.random_scaling(X)
            X = self.random_rotation_3d(X)

        return torch.from_numpy(X).float(), y_categorical.float()


    def random_scaling(self, X):
        """Apply random scaling to the point cloud.

        Args:
            X (np.ndarray): Point cloud data, (N, 3).

        Returns:
            X (np.ndarray): Scaled point cloud data (N, 3).
        """
        return random_scale_transl(X)


    def random_rotation_3d(self, X):
        """Apply random rotation to the point cloud (axis agnostic).

        Args:
            X (np.ndarray): Point cloud data, (N, 3).

        Returns:
            rotated_data (np.ndarray): Rotated point cloud data (N, 3).
        """
        return random_rotation(X)



## 2.4. Classification Loss

PointNet classification loss is defined as the weighted sum of two losses: the classification loss $L_{cls}$, such as cross-entropy, and a regularization loss $L_{reg}$ which regularizes the FeatureTransform output.

In particular, the the regularization loss forces the (second) predicted transformation matrix $T$ to be near-orthogonal (so close to a "rotation"):

$$
L_{reg}(T) = \|T^\top T - I\|_2^2
$$

This loss is averaged along the batch.

This regularization is only applied to the second transformation matrix, that is the one applied to the feature space.

### Question 6
Code the classification loss. Adapt $L_{cls}$ to whatever activation you used in you architecture.

In [14]:
def classification_loss(preds, gt_labels, T_feat, w_reg, final_activation=None):
        """
        Classification loss for PointNet

        Parameters
        --------------
        preds : (B, n_cls) - output of the classifier
        gt_labels : (B,) - ground truth labels
        T_feat : (B, latent1, latent1) - transformation matrix of the second FeatureTransform module
        w_reg : weight of the regularization term

        Output
        -------------
        loss : classification loss L_{cls} + w_reg * L_{reg}
        """
        #classification loss
        if isinstance(final_activation, nn.LogSoftmax):
            criterion = nn.NLLLoss()
        else:
            criterion = nn.CrossEntropyLoss()
        loss_cls = criterion(preds, gt_labels)

        #regularization loss
        B, _, _ = T_feat.shape
        I = torch.eye(T_feat.shape[1], device=T_feat.device).unsqueeze(0)
        loss_reg = torch.norm(torch.bmm(T_feat, T_feat.transpose(2,1)) - I, p='fro', dim=(-1,-2)).mean()

        return loss_cls + w_reg * loss_reg

## 2.5 Training

This cell defines the trainer class, using your classification loss, simply run it.

In [23]:
class Trainer(object):

    def __init__(self, train_loader, valid_loader, device='cuda',
                 lr=1e-3, weight_decay=1e-4, num_epochs=200,
                 lr_decay_every = 50, lr_decay_rate = 0.5,
                 log_interval=10, save_dir=None, **model_cfg):

        """
        train_loader: (torch.utils.DataLoader) DataLoader for training set
        valid_loader: (torch.utils.DataLoader) DataLoader for validation set
        device: (str) 'cuda' or 'cpu'
        lr: (float) learning rate
        weight_decay: (float) weight decay for optimiser
        num_epochs: (int) number of epochs
        lr_decay_every: (int) decay learning rate every this many epochs
        lr_decay_rate: (float) decay learning rate by this factor
        log_interval: (int) print training stats every this many iterations
        save_dir: (str) directory to save model checkpoints
        model_cfg: (dict) keyword arguments for model
        """
        self.model = self.get_model(model_cfg).to(device)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.lr = lr
        self.weight_decay = weight_decay
        self.num_epochs = num_epochs

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        self.lr_decay_every = lr_decay_every
        self.lr_decay_rate = lr_decay_rate
        self.log_interval = log_interval
        self.save_dir = save_dir
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []

        self.inp_feat = model_cfg.get('inp_feat', 'xyz')
        self.num_eig = model_cfg.get('num_eig', 128)
        if not self.inp_feat in ['xyz', 'xyzn', 'hks', 'wks']:
            raise ValueError('inp_feat must be one of xyz, xyzn, hks, wks')

        self.model.to(self.device)
        self.final_activation = model_cfg.get('final_activation', None)

    def get_model(self, model_cfg):
        if model_cfg['name'].lower() == 'pointnet':
            return PointNetCls(n_cls=model_cfg['n_cls'],
                              inp_dim=model_cfg['inp_dim'],
                              cls_hidden_dims=model_cfg['cls_hidden_dims'],
                              mlp2_hidden_dims=model_cfg['mlp2_hidden_dims'],
                              transf_hidden_dims1=model_cfg['transf_hidden_dims1'],
                              transf_hidden_dims2=model_cfg['transf_hidden_dims2'],
                              final_activation=model_cfg['final_activation'])
        else:
            raise ValueError('%s must be one of PointNet, PointNet++'%model_cfg['name'])

    def adjust_lr(self):
        lr = self.lr * self.lr_decay_rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def forward_step(self, inp):
        """
        Perform a forward step of the model.

        Args:
            inp (torch.Tensor): (N, D) tensor of input features

        Returns:
            pred (torch.Tensor): (N, p_out) tensor of predicted labels.
        """
        inp = inp.to(self.device)
        preds, trans_feat = self.model(inp)
        return preds, trans_feat


    def get_loss(self, preds, gt_labels, trans_feat):
        loss = classification_loss(preds, gt_labels.argmax(dim=1), trans_feat, 1e-3, self.final_activation)
        return loss 

    def train_epoch(self):
        train_loss = 0
        train_acc = 0
        for i, (inp_pts, gt_labels) in enumerate(tqdm(self.train_loader)):
            self.optimizer.zero_grad()
            gt_labels = gt_labels.to(self.device)
            preds, trans_feat = self.forward_step(inp_pts)
            loss = self.get_loss(preds, gt_labels, trans_feat)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            pred_labels = torch.max(preds, dim=1).indices
            this_correct = pred_labels.eq(gt_labels.argmax(dim=-1)).sum().item()
            train_acc += this_correct/gt_labels.shape[0]

        return train_loss/len(self.train_loader), train_acc/len(self.train_loader)

    def valid_epoch(self):
        val_loss = 0
        val_acc = 0
        for i, (inp_pts, gt_labels) in enumerate(tqdm(self.valid_loader, leave=False)):
            gt_labels = gt_labels.to(self.device)
            preds, trans_feat = self.forward_step(inp_pts)
            loss = self.get_loss(preds, gt_labels, trans_feat)
            val_loss += loss.item()
            pred_labels = torch.max(preds, dim=1).indices
            this_correct = pred_labels.eq(gt_labels.argmax(dim=-1)).sum().item()
            val_acc += this_correct/gt_labels.shape[0]

        return val_loss/len(self.valid_loader), val_acc/len(self.valid_loader)

    def run(self):
        for epoch in tqdm(range(self.num_epochs)):
            self.model.train()

            if epoch % self.lr_decay_every == 0:
                self.adjust_lr()

            train_ep_loss, train_ep_acc = self.train_epoch()
            self.train_losses.append(train_ep_loss)
            self.train_accs.append(train_ep_acc)

            if epoch % self.log_interval == 0:
                val_loss, val_acc = self.valid_epoch()
                self.val_losses.append(val_loss)
                self.val_accs.append(val_acc)
                torch.save(self.model.state_dict(), os.path.join(self.save_dir, 'model_latest.pth'))
                print('Epoch: {:03d}, Train Loss: {:.4f}, Train Acc: {:.2f}%, Val Loss: {:.4f}, Val Acc: {:.2f}%'.format(epoch,
                                                                                                                       train_ep_loss,
                                                                                                                       1e2*train_ep_acc,
                                                                                                                       val_loss,
                                                                                                                       1e2*val_acc))
        torch.save(self.model.state_dict(), os.path.join(self.save_dir, 'model_final.pth'))

    def test(self):
        file_final = os.path.join(self.save_dir, 'model_final.pth')
        if not os.path.exists(file_final):
            print('-------------------------------------------------------')
            print('No final weights, switching to last weights')
            print('-------------------------------------------------------')
            file_final = os.path.join(self.save_dir, 'model_latest.pth')
        weights = torch.load(file_final)
        self.model.load_state_dict(weights)
        _, score = self.valid_epoch()
        print('Final Valid Accuracy : {:.2f}%'.format(1e2*score))

Run the following cells to train the network.

In [16]:
N_CLS = 40
data_dir = 'data/modelnet40_ply_hdf5_2048'
train_data = ModelNet40Generator(mode='train', data_dir=data_dir, files_list='train_files.txt', rot_aug=True)
val_data = ModelNet40Generator(mode='val', data_dir=data_dir, files_list='test_files.txt')
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=0)

data/modelnet40_ply_hdf5_2048\ply_data_train0.h5
data/modelnet40_ply_hdf5_2048\ply_data_train1.h5
data/modelnet40_ply_hdf5_2048\ply_data_train2.h5
data/modelnet40_ply_hdf5_2048\ply_data_train3.h5
data/modelnet40_ply_hdf5_2048\ply_data_train4.h5
data/modelnet40_ply_hdf5_2048\ply_data_test0.h5
data/modelnet40_ply_hdf5_2048\ply_data_test1.h5


You will need to modify this cell if you are not using Google Colab + Drive

In [27]:
import os
# Comment these if you are going local
#from google.colab import drive
#drive.mount('/content/drive/')
## Change this if your are going local, otherwise keep it (checkpoints will be in checkpoint folder of your drive)
save_dir = 'checkpoints'
os.makedirs(save_dir, exist_ok=True)

In [28]:
model_cfg = dict(name="pointnet", n_cls=N_CLS, inp_dim=3, cls_hidden_dims=[512,256], mlp2_hidden_dims=[64,128,1024], 
                 transf_hidden_dims1=[64,128,1024], transf_hidden_dims2=[512,256], final_activation=nn.LogSoftmax(dim=-1))

#model_cfg = {'name': 'pointnet', 'n_cls': N_CLS, 'conv_dims': [64, 128, 1024], 'fc_dims': [512, 256]}
trainer = Trainer(train_loader, val_loader, lr=0.001, weight_decay=0.0, num_epochs=100, save_dir=os.path.join(save_dir, 'classifier'), **model_cfg)
trainer.run()

100%|██████████| 308/308 [00:13<00:00, 22.15it/s]
  1%|          | 1/100 [00:15<25:43, 15.59s/it]

Epoch: 000, Train Loss: 2.2530, Train Acc: 40.32%, Val Loss: 2.0115, Val Acc: 43.19%


100%|██████████| 308/308 [00:13<00:00, 23.33it/s]
100%|██████████| 308/308 [00:13<00:00, 22.18it/s]
100%|██████████| 308/308 [00:14<00:00, 21.64it/s]
100%|██████████| 308/308 [00:14<00:00, 21.81it/s]
100%|██████████| 308/308 [00:15<00:00, 19.97it/s]
100%|██████████| 308/308 [00:14<00:00, 21.25it/s]
100%|██████████| 308/308 [00:14<00:00, 21.32it/s]
100%|██████████| 308/308 [00:14<00:00, 21.29it/s]
100%|██████████| 308/308 [00:14<00:00, 21.38it/s]
100%|██████████| 308/308 [00:14<00:00, 21.32it/s]
 11%|█         | 11/100 [02:40<22:18, 15.04s/it]

Epoch: 010, Train Loss: 0.7525, Train Acc: 77.21%, Val Loss: 1.0801, Val Acc: 68.99%


100%|██████████| 308/308 [00:14<00:00, 21.39it/s]
100%|██████████| 308/308 [00:14<00:00, 21.35it/s]
100%|██████████| 308/308 [00:14<00:00, 21.34it/s]
100%|██████████| 308/308 [00:14<00:00, 21.26it/s]
100%|██████████| 308/308 [00:14<00:00, 21.23it/s]
100%|██████████| 308/308 [00:14<00:00, 21.24it/s]
100%|██████████| 308/308 [00:14<00:00, 21.03it/s]
100%|██████████| 308/308 [00:14<00:00, 21.24it/s]
100%|██████████| 308/308 [00:14<00:00, 21.21it/s]
100%|██████████| 308/308 [00:14<00:00, 21.11it/s]
 21%|██        | 21/100 [05:07<19:53, 15.10s/it]

Epoch: 020, Train Loss: 0.5759, Train Acc: 82.33%, Val Loss: 0.8691, Val Acc: 75.64%


100%|██████████| 308/308 [00:14<00:00, 21.25it/s]
100%|██████████| 308/308 [00:15<00:00, 19.40it/s]
100%|██████████| 308/308 [00:16<00:00, 18.57it/s]
100%|██████████| 308/308 [00:16<00:00, 18.82it/s]
100%|██████████| 308/308 [00:17<00:00, 17.71it/s]
100%|██████████| 308/308 [00:29<00:00, 10.37it/s]
100%|██████████| 308/308 [00:16<00:00, 18.45it/s]
100%|██████████| 308/308 [00:16<00:00, 18.62it/s]
100%|██████████| 308/308 [00:16<00:00, 18.19it/s]
100%|██████████| 308/308 [00:16<00:00, 18.54it/s]
 31%|███       | 31/100 [08:06<20:56, 18.21s/it]

Epoch: 030, Train Loss: 0.4598, Train Acc: 85.41%, Val Loss: 0.7724, Val Acc: 78.77%


100%|██████████| 308/308 [00:14<00:00, 20.76it/s]
100%|██████████| 308/308 [00:14<00:00, 20.76it/s]
100%|██████████| 308/308 [00:14<00:00, 20.98it/s]
100%|██████████| 308/308 [00:14<00:00, 20.94it/s]
100%|██████████| 308/308 [00:14<00:00, 20.94it/s]
100%|██████████| 308/308 [00:14<00:00, 21.05it/s]
100%|██████████| 308/308 [00:14<00:00, 21.08it/s]
100%|██████████| 308/308 [00:14<00:00, 21.12it/s]
100%|██████████| 308/308 [00:14<00:00, 20.91it/s]
100%|██████████| 308/308 [00:14<00:00, 20.93it/s]
 41%|████      | 41/100 [10:35<15:05, 15.35s/it]

Epoch: 040, Train Loss: 0.4083, Train Acc: 86.44%, Val Loss: 0.7230, Val Acc: 79.69%


100%|██████████| 308/308 [00:14<00:00, 20.66it/s]
100%|██████████| 308/308 [00:14<00:00, 20.68it/s]
100%|██████████| 308/308 [00:14<00:00, 20.75it/s]
100%|██████████| 308/308 [00:14<00:00, 20.83it/s]
100%|██████████| 308/308 [00:14<00:00, 21.21it/s]
100%|██████████| 308/308 [00:14<00:00, 21.10it/s]
100%|██████████| 308/308 [00:14<00:00, 20.54it/s]
100%|██████████| 308/308 [00:14<00:00, 20.99it/s]
100%|██████████| 308/308 [00:14<00:00, 21.23it/s]
100%|██████████| 308/308 [00:14<00:00, 21.10it/s]
 51%|█████     | 51/100 [13:05<12:29, 15.30s/it]

Epoch: 050, Train Loss: 0.3466, Train Acc: 88.38%, Val Loss: 0.7781, Val Acc: 79.61%


100%|██████████| 308/308 [00:14<00:00, 21.18it/s]
100%|██████████| 308/308 [00:14<00:00, 20.80it/s]
100%|██████████| 308/308 [00:14<00:00, 21.23it/s]
100%|██████████| 308/308 [00:14<00:00, 21.02it/s]
100%|██████████| 308/308 [00:14<00:00, 21.05it/s]
100%|██████████| 308/308 [00:14<00:00, 20.94it/s]
100%|██████████| 308/308 [00:14<00:00, 20.83it/s]
100%|██████████| 308/308 [00:14<00:00, 21.02it/s]
100%|██████████| 308/308 [00:14<00:00, 21.39it/s]
100%|██████████| 308/308 [00:14<00:00, 21.18it/s]
 61%|██████    | 61/100 [15:33<09:51, 15.16s/it]

Epoch: 060, Train Loss: 0.3083, Train Acc: 89.56%, Val Loss: 0.7165, Val Acc: 81.69%


100%|██████████| 308/308 [00:14<00:00, 21.16it/s]
100%|██████████| 308/308 [00:14<00:00, 21.04it/s]
100%|██████████| 308/308 [00:14<00:00, 21.31it/s]
100%|██████████| 308/308 [00:14<00:00, 21.05it/s]
100%|██████████| 308/308 [00:15<00:00, 20.48it/s]
100%|██████████| 308/308 [00:14<00:00, 20.66it/s]
100%|██████████| 308/308 [00:14<00:00, 20.87it/s]
100%|██████████| 308/308 [00:14<00:00, 20.64it/s]
100%|██████████| 308/308 [00:14<00:00, 20.74it/s]
100%|██████████| 308/308 [00:14<00:00, 20.72it/s]
 71%|███████   | 71/100 [18:03<07:27, 15.43s/it]

Epoch: 070, Train Loss: 0.2728, Train Acc: 90.65%, Val Loss: 0.7353, Val Acc: 80.97%


100%|██████████| 308/308 [00:14<00:00, 21.09it/s]
100%|██████████| 308/308 [00:14<00:00, 20.94it/s]
100%|██████████| 308/308 [00:14<00:00, 20.76it/s]
100%|██████████| 308/308 [00:14<00:00, 21.69it/s]
100%|██████████| 308/308 [00:14<00:00, 21.46it/s]
100%|██████████| 308/308 [00:14<00:00, 21.57it/s]
100%|██████████| 308/308 [00:14<00:00, 21.45it/s]
100%|██████████| 308/308 [00:14<00:00, 21.42it/s]
100%|██████████| 308/308 [00:14<00:00, 21.55it/s]
100%|██████████| 308/308 [00:14<00:00, 21.48it/s]
 81%|████████  | 81/100 [20:29<04:43, 14.91s/it]

Epoch: 080, Train Loss: 0.2412, Train Acc: 91.61%, Val Loss: 0.7207, Val Acc: 82.25%


100%|██████████| 308/308 [00:14<00:00, 21.48it/s]
100%|██████████| 308/308 [00:14<00:00, 21.35it/s]
100%|██████████| 308/308 [00:14<00:00, 21.42it/s]
100%|██████████| 308/308 [00:14<00:00, 21.56it/s]
100%|██████████| 308/308 [00:14<00:00, 21.57it/s]
100%|██████████| 308/308 [00:14<00:00, 21.47it/s]
100%|██████████| 308/308 [00:14<00:00, 21.56it/s]
100%|██████████| 308/308 [00:14<00:00, 21.57it/s]
100%|██████████| 308/308 [00:14<00:00, 21.52it/s]
100%|██████████| 308/308 [00:14<00:00, 21.54it/s]
 91%|█████████ | 91/100 [22:54<02:13, 14.85s/it]

Epoch: 090, Train Loss: 0.2352, Train Acc: 92.20%, Val Loss: 0.7039, Val Acc: 82.49%


100%|██████████| 308/308 [00:14<00:00, 21.45it/s]
100%|██████████| 308/308 [00:14<00:00, 21.53it/s]
100%|██████████| 308/308 [00:14<00:00, 21.43it/s]
100%|██████████| 308/308 [00:14<00:00, 21.44it/s]
100%|██████████| 308/308 [00:14<00:00, 21.64it/s]
100%|██████████| 308/308 [00:14<00:00, 21.62it/s]
100%|██████████| 308/308 [00:14<00:00, 21.67it/s]
100%|██████████| 308/308 [00:14<00:00, 21.72it/s]
100%|██████████| 308/308 [00:14<00:00, 21.55it/s]
100%|██████████| 100/100 [25:02<00:00, 15.03s/it]


In [29]:
trainer.test()

  weights = torch.load(file_final)
                                               

Final Valid Accuracy : 82.89%




# 3. PointNet Segmentation

## 3.1 Network Design

We now build the PointNet feature extractor network (slide 34, output is `global_feature`)
![](https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD4/pointnet_segmentation.PNG)


### Question 7
Build the segmentation network in the case of a batch size of 1.
**Disable all batchnorms**

In [106]:
class PointNetSeg(nn.Module):
    ## TODO
    def __init__(self, n_cls, inp_dim=3, seg_hidden_dims=[512,256,128], mlp2_hidden_dims=[64,128,1024], transf_hidden_dims1=[64,128,1024], transf_hidden_dims2=[512,256], final_activation=None):
        """
        PointNet classification network

        Parameters
        ------------------
            n_cls (int): Number of classes for classification
            inp_dim (int): Dimension of the input points
            seg_hidden_dims (list): List of hidden dimensions for the segmentation MLP
            mlp2_hidden_dims (list): List of hidden dimensions for the second MLP block in the feature extractor
            transf_hidden_dims1 (list): List of hidden dimensions for the first MLP in each FeatureTransform block
            transf_hidden_dims2 (list): List of hidden dimensions for the second MLP in each FeatureTransform block
        """

        super().__init__()
        self.PointNetBase = PointNetfeat(inp_dim, mlp2_hidden_dims, transf_hidden_dims1, transf_hidden_dims2, use_bn=False)
        self.mlp_seg = nn.Sequential()
        #Construct MLP
        self.mlp_seg.append(nn.Conv1d(mlp2_hidden_dims[-1]+mlp2_hidden_dims[0], seg_hidden_dims[0], 1))
        self.mlp_seg.append(nn.ReLU())
        for i in range(len(seg_hidden_dims)-1):
          self.mlp_seg.append(nn.Conv1d(seg_hidden_dims[i], seg_hidden_dims[i+1], 1))
          self.mlp_seg.append(nn.ReLU())
        self.mlp_seg.append(nn.Conv1d(seg_hidden_dims[-1], n_cls, 1))
        if final_activation is not None:
          self.mlp_seg.append(final_activation)

    def forward(self, x):
      """
      Parameters
      --------------
      inp : (B, n, d) input batch in standard format.

      Output
      -------------
      out : (B, n, n_cls) output score for each class
      T_feat : (B, latent1, latent1) transformation matrix of the second FeatureTransform module
      """
      x = x.transpose(2,1) # Put input in convolution format
      global_feature, T_feat, x_feat = self.PointNetBase(x)
      x_feat = torch.cat([global_feature.unsqueeze(-1).repeat(1,1,x_feat.shape[-1]), x_feat], dim=1)
      out = self.mlp_seg(x_feat)
      
      return out.transpose(-1,-2), T_feat

In [101]:
out_seg, T_feat = PointNetSeg(260)(torch.rand(2, 2046, 3))
print(out_seg.shape, T_feat.shape) # check the shapes

torch.Size([2, 2046, 260]) torch.Size([2, 64, 64])


## 3.2 Data Loading

We will train on the RNA dataset from Lab3. Since all shapes have different number of vertices, we can't use batch size bigger than one.

Run the following cells to download and define code to load the data.

In [None]:
!pip install potpourri3d
#Uncomment for colab
# !pip install git+https://github.com/skoch9/meshplot.git
# !pip install pythreejs

!wget https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD3/diffusion_utils.py
!wget https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/TD3/RNADataset.zip
!unzip -o RNADataset.zip
!mkdir models
!wget https://www.lix.polytechnique.fr/Labo/Robin.Magnet/INF631/material_TD3.zip
!unzip -o material_TD3.zip

In [91]:
from diffusion_utils import normalize_positions, read_mesh
from mesh_utils.mesh import TriMesh
import plot_utils as plu # Comment if you are not visualizing the results

class RNAMeshDataset(Dataset):
    """RNA Mesh Dataset
    """

    def __init__(self, root_dir, train, n_classes=260):

        """
            root_dir (string): Directory with all the meshes.
            train (bool): If True, use the training set, else use the test set.
            num_eig (int): Number of eigenvalues to use.
            op_cache_dir (string): Directory to cache the operators.
            n_classes (int): Number of classes.
        """

        self.train = train  # bool
        self.root_dir = root_dir
        self.n_class = n_classes # (includes -1)

        # store in memory
        self.verts_list = []
        self.faces_list = []
        self.labels_list = []  # per-vertex

        # Load the meshes & labels
        if self.train:
            with open(os.path.join(self.root_dir, "train.txt")) as f:
                this_files = [line.rstrip() for line in f]
        else:
            with open(os.path.join(self.root_dir, "test.txt")) as f:
                this_files = [line.rstrip() for line in f]

        print("loading {} files: {}".format(len(this_files), this_files))

        # Load the actual files

        off_path = os.path.join(root_dir, "off")
        label_path = os.path.join(root_dir, "labels")
        for f in this_files:
            off_file = os.path.join(off_path, f)
            label_file = os.path.join(label_path, f[:-4] + ".txt")

            verts, faces = read_mesh(off_file)
            labels = np.loadtxt(label_file).astype(int) + 1 # shift -1 --> 0

            verts = torch.tensor(verts).float()
            faces = torch.tensor(faces)
            labels = torch.tensor(labels)

            # center and unit scale
            verts = normalize_positions(verts)

            self.verts_list.append(verts)
            self.faces_list.append(faces)
            self.labels_list.append(labels)


    def __len__(self):
        return len(self.verts_list)

    def __getitem__(self, idx):
        return self.verts_list[idx], self.faces_list[idx], self.labels_list[idx]

## 3.3 Segmentation Loss

PointNet segmentation loss is defined the same as the classification loss, except we know perform per-vertex classification using $L_{seg}$ instead of $L_{cls}$. Adapt again the loss to your network design


### Question 8
Code the segmentation loss. Adapt $L_{seg}$ to whatever activation you used in you architecture. You might be able to use exactly the same code.

In [102]:
def segmentation_loss(preds, gt_labels, T_feat, w_reg, final_activation=None):
        """
        Classification loss for PointNet

        Parameters
        --------------
        preds : (B, n_cls) - output of the segmentation network
        gt_labels : (B, n) - ground truth labels
        T_feat : (B, latent1, latent1) - transformation matrix of the second FeatureTransform module
        w_reg : weight of the regularization term

        Output
        -------------
        loss : segmentation loss L_{seg} + w_reg * L_{reg}
        """
        #segmentation loss
        if isinstance(final_activation, nn.LogSoftmax):
            criterion = nn.NLLLoss()
        else:
            criterion = nn.CrossEntropyLoss()
        loss_seg = criterion(preds, gt_labels)

        #regularization loss
        B, _, _ = T_feat.shape
        I = torch.eye(T_feat.shape[1], device=T_feat.device).unsqueeze(0)
        loss_reg = torch.norm(torch.bmm(T_feat, T_feat.transpose(2,1)) - I, p='fro', dim=(-1,-2)).mean()
        
        return loss_seg + w_reg * loss_reg

## 3.4 Training

This defines a trainer class using your loss. Simply run the cells.

In [107]:
class TrainerSeg(object):

    def __init__(self, train_loader, valid_loader, device='cuda',
                 lr=1e-3, weight_decay=1e-4, num_epochs=200,
                 lr_decay_every = 50, lr_decay_rate = 0.5,
                 log_interval=10, save_dir=None, **model_cfg):

        """
        pointNet_cls: (nn.Module) class of the PointNet (or ++) model.
        train_loader: (torch.utils.DataLoader) DataLoader for training set
        valid_loader: (torch.utils.DataLoader) DataLoader for validation set
        device: (str) 'cuda' or 'cpu'
        lr: (float) learning rate
        weight_decay: (float) weight decay for optimiser
        num_epochs: (int) number of epochs
        lr_decay_every: (int) decay learning rate every this many epochs
        lr_decay_rate: (float) decay learning rate by this factor
        log_interval: (int) print training stats every this many iterations
        save_dir: (str) directory to save model checkpoints
        model_cfg: (dict) keyword arguments for model
        """
        self.model = self.get_model(model_cfg).to(device)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.lr = lr
        self.weight_decay = weight_decay
        self.num_epochs = num_epochs

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        self.loss = torch.nn.CrossEntropyLoss()

        self.lr_decay_every = lr_decay_every
        self.lr_decay_rate = lr_decay_rate
        self.log_interval = log_interval
        self.save_dir = save_dir
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []

        self.inp_feat = model_cfg.get('inp_feat', 'xyz')
        self.num_eig = model_cfg.get('num_eig', 128)
        if not self.inp_feat in ['xyz', 'xyzn', 'hks', 'wks']:
            raise ValueError('inp_feat must be one of xyz, xyzn, hks, wks')

        self.model.to(self.device)
        self.final_activation = model_cfg.get('final_activation', None)

    def get_model(self, model_cfg):
        if model_cfg['name'].lower() == 'pointnet':
            return PointNetSeg(n_cls=model_cfg['n_cls'],
                              inp_dim=model_cfg['inp_dim'],
                              seg_hidden_dims=model_cfg['seg_hidden_dims'],
                              mlp2_hidden_dims=model_cfg['mlp2_hidden_dims'],
                              transf_hidden_dims1=model_cfg['transf_hidden_dims1'],
                              transf_hidden_dims2=model_cfg['transf_hidden_dims2'],
                              final_activation=model_cfg['final_activation'])
        else:
            raise ValueError('%s must be one of PointNet, PointNet++'%model_cfg['name'])

    def adjust_lr(self):
        lr = self.lr * self.lr_decay_rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def forward_step(self, inp):
        """
        Perform a forward step of the model.

        Args:
            inp (torch.Tensor): (N, D) tensor of input features

        Returns:
            pred (torch.Tensor): (N, p_out) tensor of predicted labels.
        """
        inp = inp.to(self.device)
        verts = inp[..., :3].to(self.device)
        norms = inp[..., 3:].to(self.device)
        preds, trans_feat = self.model(inp)
        return preds, trans_feat



    def get_loss(self, preds, gt_labels, trans_feat):
        
        loss = segmentation_loss(preds, gt_labels, trans_feat, 1e-3, self.final_activation)
        return loss

    def train_epoch(self):
        train_loss = 0
        train_acc = 0
        for i, (inp_pts,_, gt_labels) in enumerate(tqdm(self.train_loader)):
            self.optimizer.zero_grad()
            gt_labels = gt_labels.to(self.device)
            preds, trans_feat = self.forward_step(inp_pts)
            loss = self.get_loss(preds[0], gt_labels[0], trans_feat)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()

            pred_labels = torch.max(preds[0], dim=1).indices  # (N,)
            this_correct = pred_labels.eq(gt_labels).sum().item()
            train_acc += this_correct / inp_pts.shape[-2]

        return train_loss/len(self.train_loader), train_acc/len(self.train_loader)

    def valid_epoch(self):
        val_loss = 0
        val_acc = 0
        total = 0
        
        for i, (inp_pts, _, gt_labels) in enumerate(self.valid_loader):
            gt_labels = gt_labels.to(self.device)
            preds, trans_feat = self.forward_step(inp_pts)
            loss = self.get_loss(preds[0], gt_labels[0], trans_feat)
            val_loss += loss.item()

            pred_labels = torch.max(preds[0], dim=1).indices
            this_correct = pred_labels.eq(gt_labels).sum().item()
            total += inp_pts.shape[-2]
            val_acc += this_correct

        return val_loss/total, val_acc/total

    def run(self):
        for epoch in tqdm(range(self.num_epochs)):
            self.model.train()

            if epoch % self.lr_decay_every == 0:
                self.adjust_lr()

            train_ep_loss, train_ep_acc = self.train_epoch()
            self.train_losses.append(train_ep_loss)
            self.train_accs.append(train_ep_acc)

            if epoch % self.log_interval == 0:
                val_loss, val_acc = self.valid_epoch()
                self.val_losses.append(val_loss)
                self.val_accs.append(val_acc)
                torch.save(self.model.state_dict(), os.path.join(self.save_dir, 'model_latest.pth'))
                print('Epoch: {:03d}, Train Loss: {:.4f}, Train Acc: {:.2f}, Val Loss: {:.4f}, Val Acc: {:.2f}%'.format(epoch,
                                                                                                                       train_ep_loss,
                                                                                                                       1e2*train_ep_acc,
                                                                                                                       val_loss,
                                                                                                                       1e2*val_acc))
        torch.save(self.model.state_dict(), os.path.join(self.save_dir, 'model_final.pth'))

    def visualize(self):
        """
        We only test the first two shapes of validation set.
        """
        self.model.eval()
        test_seg_meshes = []

        for i, (inp_pts, inp_faces, gt_labels) in enumerate(self.valid_loader):
            gt_labels = gt_labels.to(self.device)
            preds, trans_feat = self.forward_step(inp_pts)
            pred_labels = torch.max(preds[0], dim=1).indices
            test_seg_meshes.append([TriMesh(inp_pts.squeeze().cpu().numpy(), inp_faces.squeeze().cpu().numpy()),
                                  pred_labels.squeeze().cpu().numpy()])
            if i==1:
                break

        cmap1 = plt.get_cmap("jet")(test_seg_meshes[0][-1] / (146))[:,:3]
        cmap2 = plt.get_cmap("jet")(test_seg_meshes[1][-1] / (146))[:,:3]

        plu.double_plot(test_seg_meshes[0][0], test_seg_meshes[1][0], cmap1, cmap2)

    def test(self):
        file_final = os.path.join(self.save_dir, 'model_final.pth')
        if not os.path.exists(file_final):
            print('-------------------------------------------------------')
            print('No final weights, switching to last weights')
            print('-------------------------------------------------------')
            file_final = os.path.join(self.save_dir, 'model_latest.pth')
        weights = torch.load(file_final)
        self.model.load_state_dict(weights)
        _, score = self.valid_epoch()
        print('Final Valid Segmentation Accuracy : {:.2f}%'.format(1e2*score))

Run these cells to see

In [45]:
root_dir = './RNADataset/'

train_dataset = RNAMeshDataset(root_dir, train=True)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0, persistent_workers=False)

valid_dataset = RNAMeshDataset(root_dir, train=False)
valid_loader = DataLoader(valid_dataset, batch_size=1, num_workers=0, persistent_workers=False)

loading 60 files: ['5JUO_D.off', '5T62_B.off', '4WCE_Y.off', '4V9K_BB.off', '5E81_16.off', '3J92_7.off', '4V7X_BB.off', '5T7V_C.off', '4TUE_RB.off', '5DGE_3.off', '4V7B_BB.off', '4U4O_3.off', '3JA1_LB.off', '3PIO_Y.off', '5EL7_16.off', '5MDY_3.off', '5H5U_B.off', '4U4O_7.off', '5TBW_AS.off', '5J30_RB.off', '4W4G_RB.off', '4V7J_AB.off', '4V8J_BB.off', '2OTJ_9.off', '4V7X_DB.off', '3J9Y_B.off', '1VQK_9.off', '1Q82_B.off', '4W2H_DB.off', '3G4S_9.off', '1NKW_9.off', '4V5J_BB.off', '3JBN_AB.off', '4V8E_CB.off', '4YZV_YB.off', '4V9P_CB.off', '5J5B_DB.off', '4U6F_7.off', '5OBM_7.off', '3JCD_B.off', '4V55_BA.off', '5IBB_16.off', '1VY7_BB.off', '5GAD_B.off', '4U4Q_3.off', '4WRA_1J.off', '5ADY_A.off', '4V70_BB.off', '4TUB_RB.off', '5AJ0_A4.off', '1KD1_B.off', '4V6N_AA.off', '4V8I_BB.off', '4W4G_YB.off', '4V8F_DB.off', '4V8P_E3.off', '4V5Y_BA.off', '4TUA_YB.off', '5DFE_RB.off', '4V5K_DB.off']
loading 13 files: ['4V83_BB.off', '1VQ6_9.off', '4V7K_BB.off', '4V9A_DB.off', '3JQ4_B.off', '5KCS_1B.off'

In [None]:
#verify shapes of elements in the dataset to correct an error in the previous code (line 98)
for i in range(2):
    verts, faces, labels = train_dataset[i]
    print(verts.shape, faces.shape, labels.shape)

torch.Size([15073, 3]) torch.Size([30142, 3]) torch.Size([15073])
torch.Size([15678, 3]) torch.Size([30906, 3]) torch.Size([15678])


In [108]:
#model_cfg = {'name': 'pointnet', 'n_cls': 260, 'conv_dims': [64, 128, 1024], 'fc_dims': [512, 256]}
model_cfg = dict(name="pointnet", n_cls=260, inp_dim=3, seg_hidden_dims=[512,256], mlp2_hidden_dims=[64,128,1024], 
                 transf_hidden_dims1=[64,128,1024], transf_hidden_dims2=[512,256], final_activation=None)

trainer = TrainerSeg(train_loader, valid_loader, lr=0.001, device='cuda', weight_decay=0.0, num_epochs=300, save_dir=os.path.join(save_dir, 'segmentation'), **model_cfg)
trainer.run()

100%|██████████| 60/60 [00:02<00:00, 28.06it/s]
  0%|          | 1/300 [00:02<11:34,  2.32s/it]

Epoch: 000, Train Loss: 4.9845, Train Acc: 1.16, Val Loss: 0.0003, Val Acc: 1.38%


100%|██████████| 60/60 [00:01<00:00, 34.84it/s]
100%|██████████| 60/60 [00:01<00:00, 34.96it/s]
100%|██████████| 60/60 [00:01<00:00, 35.21it/s]
100%|██████████| 60/60 [00:01<00:00, 34.91it/s]
100%|██████████| 60/60 [00:01<00:00, 35.18it/s]
100%|██████████| 60/60 [00:01<00:00, 34.88it/s]
100%|██████████| 60/60 [00:01<00:00, 35.45it/s]
100%|██████████| 60/60 [00:01<00:00, 35.31it/s]
100%|██████████| 60/60 [00:01<00:00, 34.77it/s]
100%|██████████| 60/60 [00:01<00:00, 34.84it/s]]
  4%|▎         | 11/300 [00:19<08:33,  1.78s/it]

Epoch: 010, Train Loss: 3.2117, Train Acc: 12.03, Val Loss: 0.0002, Val Acc: 13.96%


100%|██████████| 60/60 [00:01<00:00, 30.63it/s]
100%|██████████| 60/60 [00:02<00:00, 29.91it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.16it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.25it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.41it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.05it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.45it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.25it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.22it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.23it/s]]
  7%|▋         | 21/300 [00:39<09:29,  2.04s/it]

Epoch: 020, Train Loss: 2.5731, Train Acc: 22.77, Val Loss: 0.0002, Val Acc: 24.85%


100%|██████████| 60/60 [00:02<00:00, 29.82it/s]
100%|██████████| 60/60 [00:02<00:00, 29.69it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.55it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.16it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.20it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.26it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.27it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.13it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.14it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.02it/s]]
 10%|█         | 31/300 [00:59<09:13,  2.06s/it]

Epoch: 030, Train Loss: 1.9391, Train Acc: 35.74, Val Loss: 0.0001, Val Acc: 41.45%


100%|██████████| 60/60 [00:02<00:00, 29.69it/s]
100%|██████████| 60/60 [00:02<00:00, 29.67it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.68it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.68it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.77it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.50it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.55it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.53it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.67it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.45it/s]]
 14%|█▎        | 41/300 [01:20<09:03,  2.10s/it]

Epoch: 040, Train Loss: 1.5682, Train Acc: 43.82, Val Loss: 0.0001, Val Acc: 43.90%


100%|██████████| 60/60 [00:02<00:00, 29.91it/s]
100%|██████████| 60/60 [00:02<00:00, 29.68it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.79it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.54it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.65it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.41it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.23it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.36it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.42it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.38it/s]]
 17%|█▋        | 51/300 [01:40<08:29,  2.05s/it]

Epoch: 050, Train Loss: 1.3723, Train Acc: 50.08, Val Loss: 0.0001, Val Acc: 49.87%


100%|██████████| 60/60 [00:01<00:00, 30.42it/s]
100%|██████████| 60/60 [00:01<00:00, 30.40it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.51it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.60it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.38it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.46it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.37it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.31it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.40it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.18it/s]]
 20%|██        | 61/300 [02:00<08:08,  2.05s/it]

Epoch: 060, Train Loss: 1.1512, Train Acc: 56.47, Val Loss: 0.0001, Val Acc: 53.98%


100%|██████████| 60/60 [00:02<00:00, 29.76it/s]
100%|██████████| 60/60 [00:02<00:00, 29.86it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.55it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.49it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.40it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.16it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.20it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.15it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.36it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.28it/s]]
 24%|██▎       | 71/300 [02:20<07:48,  2.05s/it]

Epoch: 070, Train Loss: 1.1567, Train Acc: 56.87, Val Loss: 0.0001, Val Acc: 60.89%


100%|██████████| 60/60 [00:01<00:00, 30.57it/s]
100%|██████████| 60/60 [00:01<00:00, 30.35it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.24it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.43it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.42it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.29it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.37it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.35it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.38it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.45it/s]]
 27%|██▋       | 81/300 [02:40<07:26,  2.04s/it]

Epoch: 080, Train Loss: 1.1456, Train Acc: 56.81, Val Loss: 0.0001, Val Acc: 59.61%


100%|██████████| 60/60 [00:01<00:00, 30.00it/s]
100%|██████████| 60/60 [00:02<00:00, 29.73it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.82it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.86it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.82it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.70it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.74it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.62it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.71it/s]]
100%|██████████| 60/60 [00:02<00:00, 29.77it/s]]
 30%|███       | 91/300 [03:01<07:14,  2.08s/it]

Epoch: 090, Train Loss: 1.0561, Train Acc: 59.74, Val Loss: 0.0001, Val Acc: 54.77%


100%|██████████| 60/60 [00:02<00:00, 29.96it/s]
100%|██████████| 60/60 [00:01<00:00, 30.48it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.56it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.51it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.29it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.16it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.43it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.26it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.46it/s]]
100%|██████████| 60/60 [00:01<00:00, 30.22it/s]t]
 34%|███▎      | 101/300 [03:21<06:46,  2.04s/it]

Epoch: 100, Train Loss: 0.9631, Train Acc: 63.53, Val Loss: 0.0001, Val Acc: 58.95%


100%|██████████| 60/60 [00:02<00:00, 29.78it/s]
100%|██████████| 60/60 [00:02<00:00, 29.73it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.67it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.63it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.89it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.61it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.76it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.70it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.74it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.72it/s]t]
 37%|███▋      | 111/300 [03:41<06:34,  2.08s/it]

Epoch: 110, Train Loss: 0.8801, Train Acc: 65.89, Val Loss: 0.0001, Val Acc: 64.48%


100%|██████████| 60/60 [00:02<00:00, 29.95it/s]
100%|██████████| 60/60 [00:02<00:00, 29.81it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.07it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.53it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.43it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.52it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.31it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.37it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.24it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.40it/s]t]
 40%|████      | 121/300 [04:01<06:06,  2.05s/it]

Epoch: 120, Train Loss: 0.7960, Train Acc: 68.93, Val Loss: 0.0001, Val Acc: 67.94%


100%|██████████| 60/60 [00:01<00:00, 30.11it/s]
100%|██████████| 60/60 [00:01<00:00, 30.04it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.74it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.85it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.48it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.46it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.46it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.36it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.60it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.39it/s]t]
 44%|████▎     | 131/300 [04:21<05:44,  2.04s/it]

Epoch: 130, Train Loss: 0.7748, Train Acc: 69.74, Val Loss: 0.0001, Val Acc: 62.81%


100%|██████████| 60/60 [00:02<00:00, 29.77it/s]
100%|██████████| 60/60 [00:02<00:00, 29.85it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.81it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.20it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.26it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.34it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.24it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.24it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.31it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.25it/s]t]
 47%|████▋     | 141/300 [04:42<05:25,  2.05s/it]

Epoch: 140, Train Loss: 0.7698, Train Acc: 70.00, Val Loss: 0.0001, Val Acc: 65.15%


100%|██████████| 60/60 [00:01<00:00, 30.71it/s]
100%|██████████| 60/60 [00:01<00:00, 30.58it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.38it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.68it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.28it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.28it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.23it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.27it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.79it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.55it/s]t]
 50%|█████     | 151/300 [05:02<05:08,  2.07s/it]

Epoch: 150, Train Loss: 0.7137, Train Acc: 71.81, Val Loss: 0.0001, Val Acc: 66.26%


100%|██████████| 60/60 [00:02<00:00, 29.87it/s]
100%|██████████| 60/60 [00:02<00:00, 29.67it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.70it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.69it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.71it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.10it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.56it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.55it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.67it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.02it/s]t]
 54%|█████▎    | 161/300 [05:22<04:48,  2.08s/it]

Epoch: 160, Train Loss: 0.6756, Train Acc: 73.32, Val Loss: 0.0001, Val Acc: 69.22%


100%|██████████| 60/60 [00:02<00:00, 28.77it/s]
100%|██████████| 60/60 [00:02<00:00, 28.47it/s]t]
100%|██████████| 60/60 [00:02<00:00, 28.06it/s]t]
100%|██████████| 60/60 [00:02<00:00, 27.30it/s]t]
100%|██████████| 60/60 [00:02<00:00, 26.95it/s]t]
100%|██████████| 60/60 [00:02<00:00, 28.46it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.13it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.30it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.63it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.64it/s]t]
 57%|█████▋    | 171/300 [05:43<04:27,  2.07s/it]

Epoch: 170, Train Loss: 0.6712, Train Acc: 73.65, Val Loss: 0.0001, Val Acc: 68.39%


100%|██████████| 60/60 [00:01<00:00, 30.01it/s]
100%|██████████| 60/60 [00:02<00:00, 29.84it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.51it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.50it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.46it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.29it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.49it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.31it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.33it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.64it/s]t]
 60%|██████    | 181/300 [06:03<04:04,  2.06s/it]

Epoch: 180, Train Loss: 0.6330, Train Acc: 75.26, Val Loss: 0.0001, Val Acc: 70.02%


100%|██████████| 60/60 [00:02<00:00, 29.88it/s]
100%|██████████| 60/60 [00:02<00:00, 29.65it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.88it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.70it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.76it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.80it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.69it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.65it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.74it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.67it/s]t]
 64%|██████▎   | 191/300 [06:23<03:47,  2.09s/it]

Epoch: 190, Train Loss: 1.2673, Train Acc: 60.53, Val Loss: 0.0001, Val Acc: 44.35%


100%|██████████| 60/60 [00:02<00:00, 29.92it/s]
100%|██████████| 60/60 [00:02<00:00, 29.94it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.73it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.86it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.82it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.71it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.53it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.72it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.59it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.78it/s]t]
 67%|██████▋   | 201/300 [06:44<03:26,  2.09s/it]

Epoch: 200, Train Loss: 0.6045, Train Acc: 76.31, Val Loss: 0.0001, Val Acc: 72.56%


100%|██████████| 60/60 [00:02<00:00, 29.80it/s]
100%|██████████| 60/60 [00:02<00:00, 29.90it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.26it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.06it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.54it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.14it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.61it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.59it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.78it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.76it/s]t]
 70%|███████   | 211/300 [07:04<03:05,  2.08s/it]

Epoch: 210, Train Loss: 0.5967, Train Acc: 76.39, Val Loss: 0.0001, Val Acc: 69.29%


100%|██████████| 60/60 [00:01<00:00, 30.26it/s]
100%|██████████| 60/60 [00:01<00:00, 30.32it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.37it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.42it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.46it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.48it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.27it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.40it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.33it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.28it/s]t]
 74%|███████▎  | 221/300 [07:24<02:41,  2.04s/it]

Epoch: 220, Train Loss: 0.6350, Train Acc: 74.78, Val Loss: 0.0001, Val Acc: 70.77%


100%|██████████| 60/60 [00:02<00:00, 29.92it/s]
100%|██████████| 60/60 [00:02<00:00, 29.84it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.70it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.83it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.62it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.86it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.82it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.83it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.81it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.72it/s]t]
 77%|███████▋  | 231/300 [07:45<02:23,  2.08s/it]

Epoch: 230, Train Loss: 0.6112, Train Acc: 76.17, Val Loss: 0.0001, Val Acc: 72.27%


100%|██████████| 60/60 [00:02<00:00, 29.85it/s]
100%|██████████| 60/60 [00:02<00:00, 29.89it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.75it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.24it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.94it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.85it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.69it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.76it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.51it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.66it/s]t]
 80%|████████  | 241/300 [08:05<02:02,  2.08s/it]

Epoch: 240, Train Loss: 0.5399, Train Acc: 78.59, Val Loss: 0.0001, Val Acc: 72.84%


100%|██████████| 60/60 [00:02<00:00, 29.85it/s]
100%|██████████| 60/60 [00:02<00:00, 29.78it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.71it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.80it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.66it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.64it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.32it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.84it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.72it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.82it/s]t]
 84%|████████▎ | 251/300 [08:25<01:41,  2.08s/it]

Epoch: 250, Train Loss: 0.5534, Train Acc: 77.97, Val Loss: 0.0001, Val Acc: 73.51%


100%|██████████| 60/60 [00:02<00:00, 29.84it/s]
100%|██████████| 60/60 [00:02<00:00, 29.87it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.91it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.78it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.92it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.88it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.82it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.82it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.82it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.72it/s]t]
 87%|████████▋ | 261/300 [08:46<01:21,  2.08s/it]

Epoch: 260, Train Loss: 0.6046, Train Acc: 76.19, Val Loss: 0.0001, Val Acc: 69.90%


100%|██████████| 60/60 [00:02<00:00, 29.89it/s]
100%|██████████| 60/60 [00:01<00:00, 30.05it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.14it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.87it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.79it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.79it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.99it/s]t]
100%|██████████| 60/60 [00:02<00:00, 30.00it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.89it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.91it/s]t]
 90%|█████████ | 271/300 [09:06<01:00,  2.07s/it]

Epoch: 270, Train Loss: 0.5936, Train Acc: 76.60, Val Loss: 0.0001, Val Acc: 71.73%


100%|██████████| 60/60 [00:01<00:00, 30.18it/s]
100%|██████████| 60/60 [00:01<00:00, 30.03it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.89it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.87it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.91it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.54it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.37it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.49it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.36it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.42it/s]t]
 94%|█████████▎| 281/300 [09:26<00:38,  2.04s/it]

Epoch: 280, Train Loss: 0.6135, Train Acc: 76.02, Val Loss: 0.0001, Val Acc: 74.73%


100%|██████████| 60/60 [00:02<00:00, 29.90it/s]
100%|██████████| 60/60 [00:02<00:00, 29.88it/s]t]
100%|██████████| 60/60 [00:01<00:00, 30.03it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.80it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.90it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.91it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.65it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.71it/s]t]
100%|██████████| 60/60 [00:02<00:00, 28.64it/s]t]
100%|██████████| 60/60 [00:02<00:00, 27.70it/s]t]
 97%|█████████▋| 291/300 [09:47<00:19,  2.15s/it]

Epoch: 290, Train Loss: 0.5717, Train Acc: 77.29, Val Loss: 0.0001, Val Acc: 71.44%


100%|██████████| 60/60 [00:02<00:00, 27.62it/s]
100%|██████████| 60/60 [00:02<00:00, 27.67it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.16it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.27it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.04it/s]t]
100%|██████████| 60/60 [00:02<00:00, 28.34it/s]t]
100%|██████████| 60/60 [00:02<00:00, 28.53it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.15it/s]t]
100%|██████████| 60/60 [00:02<00:00, 29.25it/s]t]
100%|██████████| 300/300 [10:06<00:00,  2.02s/it]


Vizualize the results

In [109]:
trainer.visualize()

HBox(children=(Output(), Output()))

HBox(children=(Output(), Output()))

In [110]:
trainer.test()

  weights = torch.load(file_final)


Final Valid Segmentation Accuracy : 69.49%


Final task: save the notebook .ipynb file **with logs**, and create a zip with the checkpoints folder. You can send your results to geometricdeeplearning@protonmail.com