In [None]:
# Shared MLP layer: kernel-size-1 Conv1d + BatchNorm + ReLU, applied to every point independently
class MLP(nn.Module):
    """One shared per-point layer: Conv1d(kernel_size=1) -> BatchNorm1d -> ReLU.

    Input is channels-first, (batch, input_size, length) as required by
    nn.Conv1d; output is (batch, output_size, length) and non-negative.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        # Registration order (conv, then bn) fixes the state_dict / parameters() order.
        self.conv = nn.Conv1d(input_size, output_size, kernel_size=1)
        self.bn = nn.BatchNorm1d(output_size)

    def forward(self, input):
        projected = self.conv(input)
        normalized = self.bn(projected)
        return F.relu(normalized)

# Fully connected layer followed by batch normalization and ReLU
class FC_BN(nn.Module):
    """Fully connected layer followed by BatchNorm1d and ReLU.

    Maps (batch, input_size) to (batch, output_size); the output is non-negative.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size, self.output_size = input_size, output_size
        self.lin = nn.Linear(in_features=input_size, out_features=output_size)
        self.bn = nn.BatchNorm1d(num_features=output_size)

    def forward(self, input):
        hidden = self.bn(self.lin(input))
        return F.relu(hidden)

class TNet(nn.Module):
    """Predict a k x k alignment matrix per sample (PointNet T-Net).

    Shared per-point layers k -> 64 -> 128 -> 1024, a max-pool over the point
    axis, fully connected layers 1024 -> 512 -> 256, and a final linear layer
    producing k*k values. The k x k identity is added to that output, so the
    prediction is a learned offset from identity.

    NOTE: fc3 keeps PyTorch's default initialisation (not zeros), so the
    matrix at initialisation is identity plus a random offset.

    Args:
        k: number of input channels and size of the output matrix.
    """

    def __init__(self, k=3):
        super().__init__()
        self.k = k

        self.mlp1 = MLP(self.k, 64)
        self.mlp2 = MLP(64, 128)
        self.mlp3 = MLP(128, 1024)

        self.fc_bn1 = FC_BN(1024, 512)
        self.fc_bn2 = FC_BN(512, 256)

        self.fc3 = nn.Linear(256, k * k)

    def forward(self, input):
        """Return a (batch_size, k, k) tensor for `input` of shape (batch_size, k, n_points).

        Channels come before points because the shared layers are Conv1d modules.
        """
        xb = self.mlp3(self.mlp2(self.mlp1(input)))

        # Symmetric aggregation: max over the point axis -> (batch_size, 1024).
        flat = F.max_pool1d(xb, kernel_size=xb.size(-1)).flatten(1)

        xb = self.fc_bn2(self.fc_bn1(flat))

        # Identity on xb's device/dtype (a bare `.cuda()` would pick the current
        # CUDA device, which breaks on cuda:1, MPS, ...); it broadcasts over the batch.
        identity = torch.eye(self.k, device=xb.device, dtype=xb.dtype)
        return self.fc3(xb).view(-1, self.k, self.k) + identity

In [None]:
class PointNet(nn.Module):
    """PointNet feature extractor used by the segmentation head.

    Pipeline: input T-Net (3x3) -> shared MLP(3, 64) -> shared MLP(64, 64)
    -> feature T-Net (64x64) -> shared MLP(64, 128) -> shared MLP(128, 1024)
    -> max-pool over points.

    forward(input) takes (bs, 3, n_pts) and returns
    ([per_point_features, global_features], matrix3x3, matrix64x64), where
    per_point_features is the transformed (bs, 64, n_pts) tensor and
    global_features is the pooled 1024-d vector broadcast to every point,
    (bs, 1024, n_pts). global_features is an expanded view, so it must not be
    modified in place.
    """

    def __init__(self):
        super().__init__()
        self.input_transform = TNet(k=3)

        self.mlp_64_1 = MLP(3, 64)
        self.mlp_64_2 = MLP(64, 64)
        self.feature_transform = TNet(k=64)

        self.mlp_128 = MLP(64, 128)
        self.mlp_1024 = MLP(128, 1024)

    @staticmethod
    def _apply_transform(features, matrix):
        """Right-multiply every point's feature row by `matrix`: (bs, c, n), (bs, c, c) -> (bs, c, n)."""
        return torch.bmm(features.transpose(1, 2), matrix).transpose(1, 2)

    def forward(self, input):
        n_pts = input.size(2)

        matrix3x3 = self.input_transform(input)
        aligned_points = self._apply_transform(input, matrix3x3)

        point_features = self.mlp_64_2(self.mlp_64_1(aligned_points))
        matrix64x64 = self.feature_transform(point_features)
        aligned_features = self._apply_transform(point_features, matrix64x64)

        extracted = self.mlp_1024(self.mlp_128(aligned_features))
        # Global feature: max over the point axis, kept as (bs, 1024, 1).
        global_feature = F.max_pool1d(extracted, kernel_size=extracted.size(-1))
        # Broadcast to every point as a view; the downstream torch.cat makes the
        # only copy (repeat + transposes would materialise an extra one).
        global_feature_repeated = global_feature.expand(-1, -1, n_pts)

        return [aligned_features, global_feature_repeated], matrix3x3, matrix64x64


In [None]:
class PointNetSeg(nn.Module):
    """Per-point segmentation network on top of the PointNet extractor.

    The per-point and global features from PointNet are concatenated along the
    channel axis (1088 channels) and passed through shared layers
    1088 -> 512 -> 256 -> 128. A final 1x1 convolution produces `classes` raw
    scores per point.

    Args:
        classes: number of segmentation classes per point.

    forward(input) returns (log_probs, matrix3x3, matrix64x64); log_probs has
    `classes` channels at dim 1, normalised with log-softmax over that dim.
    """

    def __init__(self, classes = 3):
        super().__init__()
        self.pointnet = PointNet()

        self.mlp_512_r = MLP(1088, 512)
        self.mlp_256_r = MLP(512, 256)
        self.mlp_128_r = MLP(256, 128)
        # Output layer is a plain 1x1 conv (no BatchNorm/ReLU) so class scores can
        # be negative; a ReLU here would clamp logits at 0 and zero their gradients.
        # NOTE(review): state_dict keys differ from the old MLP-based head
        # (mlp_3_r.conv.* / mlp_3_r.bn.*), so checkpoints saved before this change
        # will not load without remapping.
        self.classifier = nn.Conv1d(128, classes, 1)

        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, input):
        features, matrix3x3, matrix64x64 = self.pointnet(input)
        stack = torch.cat(features, 1)  # per-point + global features along channels

        hidden = self.mlp_512_r(stack)
        hidden = self.mlp_256_r(hidden)
        hidden = self.mlp_128_r(hidden)
        scores = self.classifier(hidden)

        return self.logsoftmax(scores), matrix3x3, matrix64x64


In [None]:
def pointNetLoss(outputs, labels, m3x3, m64x64, alpha = 0.0001):
    criterion = torch.nn.NLLLoss()
    bs=outputs.size(0)
    id3x3 = torch.eye(3, requires_grad=True).repeat(bs,1,1)
    id64x64 = torch.eye(64, requires_grad=True).repeat(bs,1,1)
    if outputs.is_cuda:
        id3x3=id3x3.cuda()
        id64x64=id64x64.cuda()
    diff3x3 = id3x3-torch.bmm(m3x3,m3x3.transpose(1,2))
    diff64x64 = id64x64-torch.bmm(m64x64,m64x64.transpose(1,2))
    return criterion(outputs, labels) + alpha * (torch.norm(diff3x3)+torch.norm(diff64x64)) / float(bs)


In [None]:
# Build the segmentation model on the target device, then its optimizer.
# nn.Module.to() moves the parameters in place and returns the module itself.
pointnet = PointNetSeg().to(device)
optimizer = torch.optim.Adam(params=pointnet.parameters(), lr=0.005)

def train(model, train_loader, val_loader=None,  epochs=15, save=True):
    """Train `model`, validating and checkpointing after every epoch.

    Relies on notebook-level names: `optimizer` (must have been built over
    `model.parameters()`), `device`, `pointNetLoss`, `os` and `drive_path`.

    Args:
        model: network returning (log_probs, m3x3, m64x64) for a batch of points
            transposed to channels-first.
        train_loader: iterable of (points, labels) batches.
        val_loader: optional iterable of (points, labels) batches. When None,
            validation and checkpointing are skipped.
        epochs: number of passes over `train_loader`.
        save: when True, save `model.state_dict()` whenever validation accuracy
            improves on the best seen so far.
    """
    log_every = 10  # mini-batches per printed loss average
    best_val_acc = -1.0
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device).float()
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs, m3x3, m64x64 = model(inputs.transpose(1, 2))
            loss = pointNetLoss(outputs, labels, m3x3, m64x64)
            loss.backward()
            optimizer.step()

            # Print the mean loss over the last `log_every` mini-batches; the sum
            # must cover exactly that many batches for the division to be right.
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0

        if val_loader is None:
            continue

        model.eval()
        correct = total = 0

        # validation: per-point accuracy
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device).float()
                labels = labels.to(device)
                outputs, _, _ = model(inputs.transpose(1, 2))
                predicted = outputs.argmax(dim=1)

                total += labels.numel()
                correct += (predicted == labels).sum().item()

        if total == 0:
            # Empty validation set: no accuracy to report or compare.
            continue

        print("correct", correct, "/", total)
        val_acc = 100.0 * correct / total
        print('Valid accuracy: %d %%' % val_acc)

        # save the model
        if save and val_acc > best_val_acc:
            best_val_acc = val_acc
            # NOTE(review): torch.save writes a PyTorch checkpoint, not YAML; the
            # `.yml` name is kept so any existing loading code still finds it.
            path = os.path.join(drive_path, "MyDrive", "pointnetmodel.yml")
            print("best_val_acc:", val_acc, "saving model at", path)
            torch.save(model.state_dict(), path)

# Train the segmentation network; weights are saved whenever validation accuracy improves.
train(model=pointnet, train_loader=train_loader, val_loader=val_loader, save=True)