### 1. Dynamic Graph CNN base model.

##### Imports

In [2]:
import os
import sys
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

##### Model Implementation for Semantic Segmentation
We start with the <strong>PointNet</strong> implementation, since it serves as the backbone on top of which we add EdgeConv.

In [None]:
class PointNet(nn.Module):
    """PointNet-style network producing one logit vector per point cloud.

    Five shared per-point 1x1 convolutions (each followed by BatchNorm + ReLU)
    lift every point to ``args.embedded_dims`` features. A global max-pool over
    the point axis aggregates them into one vector per cloud, and a two-layer
    MLP head (with dropout) maps that vector to ``out_channels`` logits.

    Args:
        args: config object; only ``args.embedded_dims`` is read here.
        out_channels: number of output logits per cloud (default 40).
    """

    def __init__(self, args, out_channels=40):
        super(PointNet, self).__init__()
        self.conv_1 = nn.Conv1d(3, 64, kernel_size=1, bias=False)
        self.conv_2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv_3 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv_4 = nn.Conv1d(64, 128, kernel_size=1, bias=False)
        self.conv_5 = nn.Conv1d(128, args.embedded_dims, kernel_size=1, bias=False)
        self.batch_norm_1 = nn.BatchNorm1d(64)
        self.batch_norm_2 = nn.BatchNorm1d(64)
        self.batch_norm_3 = nn.BatchNorm1d(64)
        self.batch_norm_4 = nn.BatchNorm1d(128)
        self.batch_norm_5 = nn.BatchNorm1d(args.embedded_dims)
        self.fc_1 = nn.Linear(args.embedded_dims, 512, bias=False)
        self.batch_norm_6 = nn.BatchNorm1d(512)
        self.dropout = nn.Dropout()
        self.fc_2 = nn.Linear(512, out_channels)

    def forward(self, x):
        """Return logits of shape (batch_size, out_channels).

        Args:
            x: tensor of shape (batch_size, 3, num_points); ``conv_1`` expects
                exactly 3 input channels.
        """
        x = F.relu(self.batch_norm_1(self.conv_1(x)))
        x = F.relu(self.batch_norm_2(self.conv_2(x)))
        x = F.relu(self.batch_norm_3(self.conv_3(x)))
        x = F.relu(self.batch_norm_4(self.conv_4(x)))
        x = F.relu(self.batch_norm_5(self.conv_5(x)))
        # Global max-pool over points -> (batch_size, embedded_dims, 1), then
        # drop only the pooled axis. A bare ``squeeze()`` would also drop the
        # batch axis when batch_size == 1, and batch_norm_6 rejects 1-D input.
        x = F.adaptive_max_pool1d(x, 1).squeeze(-1)
        x = F.relu(self.batch_norm_6(self.fc_1(x)))
        x = self.dropout(x)
        x = self.fc_2(x)
        return x

Next, we implement the <strong>DGCNN</strong> network, which is built from EdgeConv layers and tailored for semantic segmentation.

In [None]:
class DGCNN(nn.Module):
    """DGCNN (EdgeConv) network producing per-point class scores.

    Three EdgeConv stages (``get_graph_feature`` + shared 1x1 Conv2d layers +
    max over the last axis) produce 64-channel per-point features ``x1``,
    ``x2`` and ``x3``. Their concatenation is lifted to ``args.embedded_dims``
    channels and max-pooled over points into a global descriptor. That
    descriptor is tiled back to every point and concatenated with ``x1``,
    ``x2`` and ``x3`` before the per-point classifier head.

    Args:
        args: config object; reads ``args.k`` (passed as ``k`` to
            ``get_graph_feature``), ``args.embedded_dims`` (width of the
            global feature) and ``args.dropout`` (dropout probability before
            the final layer).
        num_clases: number of output classes per point (default 13). The
            spelling is kept for backward compatibility with keyword callers.
    """

    def __init__(self, args, num_clases=13):
        super(DGCNN, self).__init__()
        self.args = args
        self.k = args.k
        self.batch_norm_1 = nn.BatchNorm2d(64)
        self.batch_norm_2 = nn.BatchNorm2d(64)
        self.batch_norm_3 = nn.BatchNorm2d(64)
        self.batch_norm_4 = nn.BatchNorm2d(64)
        self.batch_norm_5 = nn.BatchNorm2d(64)
        self.batch_norm_6 = nn.BatchNorm1d(args.embedded_dims)
        self.batch_norm_7 = nn.BatchNorm1d(512)
        self.batch_norm_8 = nn.BatchNorm1d(256)

        # 18 = 9 * 2 input channels, matching get_graph_feature(..., dim9=True)
        # applied to a 9-channel input in forward().
        self.conv_1 = nn.Sequential(nn.Conv2d(18, 64, kernel_size=1, bias=False),
                                    self.batch_norm_1,
                                    nn.LeakyReLU(negative_slope=0.2))
        self.conv_2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                    self.batch_norm_2,
                                    nn.LeakyReLU(negative_slope=0.2))
        self.conv_3 = nn.Sequential(nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False),
                                    self.batch_norm_3,
                                    nn.LeakyReLU(negative_slope=0.2))
        self.conv_4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                    self.batch_norm_4,
                                    nn.LeakyReLU(negative_slope=0.2))
        self.conv_5 = nn.Sequential(nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False),
                                    self.batch_norm_5,
                                    nn.LeakyReLU(negative_slope=0.2))
        # Input is cat(x1, x2, x3): three 64-channel feature maps.
        self.conv_6 = nn.Sequential(nn.Conv1d(64 * 3, args.embedded_dims, kernel_size=1, bias=False),
                                    self.batch_norm_6,
                                    nn.LeakyReLU(negative_slope=0.2))
        # Input is cat(global descriptor, x1, x2, x3). The width is derived from
        # embedded_dims; a hard-coded 1216 is only valid for embedded_dims == 1024.
        self.conv_7 = nn.Sequential(nn.Conv1d(args.embedded_dims + 64 * 3, 512, kernel_size=1, bias=False),
                                    self.batch_norm_7,
                                    nn.LeakyReLU(negative_slope=0.2))
        self.conv_8 = nn.Sequential(nn.Conv1d(512, 256, kernel_size=1, bias=False),
                                    self.batch_norm_8,
                                    nn.LeakyReLU(negative_slope=0.2))
        self.dropout = nn.Dropout(p=args.dropout)
        self.conv_9 = nn.Conv1d(256, num_clases, kernel_size=1, bias=False)

    def forward(self, x):
        """Return per-point logits of shape (batch_size, num_clases, num_points).

        Args:
            x: tensor of shape (batch_size, 9, num_points); ``conv_1`` expects
                the 18 = 9 * 2 channels produced by ``get_graph_feature`` with
                ``dim9=True``.
        """
        num_points = x.size(2)

        # NOTE(review): get_graph_feature is defined elsewhere in the project;
        # the shapes below follow the original comments (it appears to map
        # (B, C, N) -> (B, 2*C, N, k)) — confirm against its definition.
        x = get_graph_feature(x, k=self.k, dim9=True)   # (batch_size, 9, num_points) -> (batch_size, 9*2, num_points, k)
        x = self.conv_1(x)                              # (batch_size, 9*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv_2(x)                              # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x1 = x.max(dim=-1, keepdim=False)[0]            # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x1, k=self.k)             # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv_3(x)                              # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv_4(x)                              # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x2 = x.max(dim=-1, keepdim=False)[0]            # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x2, k=self.k)             # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv_5(x)                              # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x3 = x.max(dim=-1, keepdim=False)[0]            # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = torch.cat((x1, x2, x3), dim=1)              # (batch_size, 64*3, num_points)

        x = self.conv_6(x)                              # (batch_size, 64*3, num_points) -> (batch_size, emb_dims, num_points)
        x = x.max(dim=-1, keepdim=True)[0]              # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims, 1)

        # Tile the global descriptor to every point, then append local features.
        x = x.repeat(1, 1, num_points)                  # (batch_size, emb_dims, num_points)
        x = torch.cat((x, x1, x2, x3), dim=1)           # (batch_size, emb_dims+64*3, num_points)

        x = self.conv_7(x)                              # (batch_size, emb_dims+64*3, num_points) -> (batch_size, 512, num_points)
        x = self.conv_8(x)                              # (batch_size, 512, num_points) -> (batch_size, 256, num_points)
        x = self.dropout(x)
        x = self.conv_9(x)                              # (batch_size, 256, num_points) -> (batch_size, num_clases, num_points)

        return x