Explaining Lukas's PointNet++ Regression Model Code

In [1]:
import torch #Import torch package
#Get required DL packages from pytorch geometric
#https://pytorch-geometric.readthedocs.io/en/latest/
from torch_geometric.nn import MLP, knn_interpolate, PointConv, global_max_pool, fps, radius

In [None]:
#Create Set Abstraction Module for use in PointNet
class SAModule(torch.nn.Module): #torch.nn.Module is the base class for all neural network modules
    """Set Abstraction (SA) layer of PointNet++: sample, group, then convolve.

    A forward pass (1) picks a subset of points with farthest point sampling,
    (2) connects every sampled point to the input points lying within radius
    ``r`` of it, and (3) runs a PointNet convolution over those edges so each
    sampled point receives a feature vector summarising its neighbourhood.
    """

    def __init__(self, ratio, r, nn, max_num_neighbors=64):
        """Store the sampling settings and build the point convolution.

        Args:
            ratio: Sampling ratio forwarded to ``fps`` (fraction of points kept).
            r: Neighbourhood radius forwarded to ``radius``.
            nn: Network handed to ``PointConv``; in this file it is an ``MLP``
                whose input width is the feature width plus 3 for positions.
            max_num_neighbors: Upper bound on the neighbours gathered for each
                sampled point. Defaults to 64, the value that used to be
                hard-coded in ``forward``.
        """
        super().__init__() #Initialise torch.nn.Module so sub-modules get registered
        self.ratio = ratio
        self.r = r
        self.max_num_neighbors = max_num_neighbors
        self.conv = PointConv(nn, add_self_loops=False) #PointNet convolution driven by `nn`

    def forward(self, x, pos, batch):
        """Down-sample the point cloud and compute features for the kept points.

        Args:
            x: Per-point features, or ``None`` when the points carry none.
            pos: Point positions.
            batch: Per-point index of the example each point belongs to.

        Returns:
            Tuple ``(x, pos, batch)`` restricted to the sampled points: the
            convolved features, ``pos[idx]`` and ``batch[idx]``.
        """
        #fps(): farthest point sampling; iteratively picks the point most distant from those already chosen
        idx = fps(pos, batch, ratio=self.ratio)
        #radius(): for each sampled point (pos[idx]) find all input points (pos) within distance r.
        #`row` indexes the sampled points and `col` indexes the input points.
        row, col = radius(pos, pos[idx], self.r, batch, batch[idx],
                          max_num_neighbors=self.max_num_neighbors)
        #edge_index is the graph connectivity as a 2 x num_edges tensor: first row = source
        #(an input point), second row = target (the sampled point that receives the message).
        edge_index = torch.stack([col, row], dim=0)
        #x_dst holds the features of the destination (sampled) points; it stays None when there are no features
        x_dst = None if x is None else x[idx]
        #Bipartite convolution: messages flow from all input points to the sampled points
        x = self.conv((x, x_dst), (pos, pos[idx]), edge_index)
        pos, batch = pos[idx], batch[idx]
        return x, pos, batch

class GlobalSAModule(torch.nn.Module):
    """Global set-abstraction layer: collapse each example to one feature row.

    Features and positions are concatenated, passed through ``nn`` and then
    max-pooled over the groups defined by ``batch``.
    """

    def __init__(self, nn):
        """Keep ``nn``, the network applied to the concatenated [x, pos] rows."""
        super().__init__()
        self.nn = nn

    def forward(self, x, pos, batch):
        """Return pooled features plus matching placeholder ``pos`` and ``batch``.

        Returns:
            Tuple ``(x, pos, batch)`` where ``x`` is the pooled output of
            ``global_max_pool``, ``pos`` is an all-zero tensor with one row per
            pooled row and 3 columns (a placeholder so later layers still get
            a position), and ``batch`` is ``0 .. x.size(0) - 1``, i.e. every
            pooled row forms its own group.
        """
        #Attach the coordinates to the features before the network sees them
        stacked = torch.cat([x, pos], dim=1)
        #Max-pool the transformed rows within each group given by `batch`
        pooled = global_max_pool(self.nn(stacked), batch)
        num_rows = pooled.size(0)
        #new_zeros() keeps the dtype and device of `pos`
        zero_pos = pos.new_zeros((num_rows, 3))
        group_ids = torch.arange(num_rows, device=batch.device)
        return pooled, zero_pos, group_ids

class FPModule(torch.nn.Module):
    """Feature Propagation (FP) layer: carry coarse features back to a finer level.

    Features are interpolated from the coarse points onto the skip-connection
    points with ``knn_interpolate``, optionally concatenated with the skip
    features, and then transformed by ``nn``.
    """

    def __init__(self, k, nn):
        """Store ``k`` (neighbours used for interpolation) and the network ``nn``."""
        super().__init__()
        self.k = k
        self.nn = nn

    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
        """Interpolate ``x`` onto ``pos_skip`` and fuse it with ``x_skip``.

        Returns:
            Tuple ``(x, pos_skip, batch_skip)``: the transformed features along
            with the position and batch tensors of the skip level, unchanged.
        """
        #Interpolate the coarse features at the skip-level positions using k neighbours
        upsampled = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
        if x_skip is None:
            #No skip features to merge: pass the interpolated features on alone
            fused = upsampled
        else:
            fused = torch.cat([upsampled, x_skip], dim=1)
        return self.nn(fused), pos_skip, batch_skip


class Net(torch.nn.Module):
    """PointNet++ encoder/decoder ending in a single-output-channel MLP.

    Three set-abstraction stages (two local, one global) shrink the point
    cloud; three feature-propagation stages then carry the features back to
    the original points, where a final MLP with one output channel is applied.
    """

    def __init__(self, num_features):
        """Build every stage; ``num_features`` is the width of ``data.x``."""
        super().__init__()

        #Encoder. Each MLP input width is the incoming feature width + 3 for `pos`.
        #The torch geometric MLP() builds a multilayer perceptron from a list of channel sizes.
        sa1_nn = MLP([3 + num_features, 64, 64, 128])
        self.sa1_module = SAModule(0.2, 2, sa1_nn)
        sa2_nn = MLP([128 + 3, 128, 128, 256])
        self.sa2_module = SAModule(0.25, 8, sa2_nn)
        sa3_nn = MLP([256 + 3, 256, 512, 1024])
        self.sa3_module = GlobalSAModule(sa3_nn)

        #Decoder. Each MLP input width is the interpolated width + the skip-connection width.
        fp3_nn = MLP([1024 + 256, 256, 256])
        self.fp3_module = FPModule(1, fp3_nn)
        fp2_nn = MLP([256 + 128, 256, 128])
        self.fp2_module = FPModule(3, fp2_nn)
        fp1_nn = MLP([128 + num_features, 128, 128, 128])
        self.fp1_module = FPModule(3, fp1_nn)

        #Prediction head: one output channel, dropout 0.5, batch norm disabled
        self.mlp = MLP([128, 128, 128, 1], dropout=0.5,
                       batch_norm=False)

    def forward(self, data):
        """Run the encoder and decoder on ``data`` and return the head's output.

        Args:
            data: Object exposing ``x``, ``pos`` and ``batch`` attributes.
        """
        #Every level is an (x, pos, batch) tuple; `*level` unpacks it into positional arguments
        level0 = (data.x, data.pos, data.batch)
        level1 = self.sa1_module(*level0)
        level2 = self.sa2_module(*level1)
        level3 = self.sa3_module(*level2)

        #Each FP stage receives the coarser level followed by the skip level it upsamples onto
        up2 = self.fp3_module(*level3, *level2)
        up1 = self.fp2_module(*up2, *level1)
        point_features, _, _ = self.fp1_module(*up1, *level0)

        return self.mlp(point_features)
        