# Task 2: Graphs in 3D
## 2.3 Meshes

#### 2.3.0 Install and import libraries

In [None]:
import torch
import os
# Expose the installed torch version string as the TORCH env var so the shell
# command below can select the matching prebuilt wheel index.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# NOTE(review): PyG is installed from git HEAD (unpinned), so re-runs may pick up a
# different version; consider pinning a release for reproducibility.
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
# torch-cluster presumably backs the k-NN graph construction used by DynamicEdgeConv later.
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}.html

2.1.0+cu118
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Building wheel for torch_geometric (pyproject.toml) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.3/3.3 MB 30.8 MB/s eta 0:00:00


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import torch_geometric.nn as pyg_nn
import matplotlib.pyplot as plt
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_max_pool

#### 2.3.1 GNN on mesh connectivity (no coordinate features)

In [None]:
from torch_geometric.datasets import ModelNet
import torch_geometric.transforms as T

# Train and test splits of ModelNet stored under ./ModelNet10, each normalized with
# T.NormalizeScale as a pre_transform (train split is built first, then test).
train_dataset, test_dataset = (
    ModelNet(root="ModelNet10", train=is_train, pre_transform=T.NormalizeScale())
    for is_train in (True, False)
)

Downloading http://vision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip
Extracting ModelNet10/ModelNet10.zip
Processing...
Done!


In [None]:
# Structure-only baseline: every vertex gets the same constant feature (1.0), so the
# GCN can learn only from mesh connectivity. Using the vertex *index* as the feature
# (as before) depends on the arbitrary vertex ordering, so it is not
# permutation-invariant. It is also unscaled (values grow with the vertex count).
# The recorded run with it plateaued at ~11% validation accuracy.
transform = T.FaceToEdge(remove_faces=False)


def _to_mesh_graph(data, face_to_edge=transform):
    """Attach a constant [num_vertices, 1] feature, add mesh edges, and move to CUDA."""
    data.x = torch.ones(data.pos.size(0), 1)
    return face_to_edge(data).to("cuda")


train_dataset_mesh = [_to_mesh_graph(data) for data in train_dataset]
test_dataset_mesh = [_to_mesh_graph(data) for data in test_dataset]

In [None]:
class MeshGNN(torch.nn.Module):
    """Two-layer GCN over mesh edges, global max pooling, and a linear classifier.

    Consumes 1-dimensional node features from ``data.x`` (conv1 has in_channels=1).
    """

    def __init__(self, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = pyg_nn.GCNConv(1, hidden_dim)
        self.conv2 = pyg_nn.GCNConv(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, train_data):
        """Return unnormalized class logits, one row per graph in the batch."""
        edge_index = train_data.edge_index
        h = F.relu(self.conv1(train_data.x, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        return self.classifier(global_max_pool(h, train_data.batch))

    @torch.no_grad()
    def eval_accuracy(self, evalloader):
        """Fraction of graphs in ``evalloader`` whose argmax prediction equals ``y``.

        Does not switch the module to eval mode; the caller is responsible for that.
        """
        n_correct, n_seen = 0, 0
        for batch in evalloader:
            predicted = self(batch).argmax(dim=1)
            n_correct += (predicted == batch.y).sum().item()
            n_seen += batch.y.size(0)
        return n_correct / n_seen

def use_gnn(seed, hidden_dim, train_set, test_set):
    """Train a MeshGNN for 20 epochs and return its best per-epoch test accuracy.

    Args:
        seed: passed to ``torch.manual_seed`` before the model is built.
        hidden_dim: width of both GCN layers.
        train_set, test_set: lists of graphs wrapped in PyG ``DataLoader``s (batch 32).

    Note: ``test_set`` is evaluated after every epoch and the maximum is returned, so
    the number is selected on the test set itself (optimistic). ``num_classes`` is
    read from the global ``train_dataset``.
    """
    torch.manual_seed(seed)

    model = MeshGNN(hidden_dim=hidden_dim, num_classes=train_dataset.num_classes).to("cuda")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0)

    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

    best_test_acc = 0
    for epoch in range(1, 21):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            loss = F.cross_entropy(model(batch), batch.y)
            loss.backward()
            optimizer.step()

        # Evaluate after every epoch.
        model.eval()
        valid_acc = model.eval_accuracy(test_loader)
        print("Epoch: {}\tValidation accuracy: {}".format(epoch, valid_acc))
        best_test_acc = max(best_test_acc, valid_acc)

    return best_test_acc

seed = 114514
test_acc = use_gnn(
    seed=seed,
    hidden_dim=16,
    train_set=train_dataset_mesh,
    test_set=test_dataset_mesh,
)
print(f"Test accuracy: {test_acc}\n")



Epoch: 1	Validation accuracy: 0.1222466960352423
Epoch: 2	Validation accuracy: 0.11013215859030837
Epoch: 3	Validation accuracy: 0.11123348017621146
Epoch: 4	Validation accuracy: 0.11013215859030837
Epoch: 5	Validation accuracy: 0.11013215859030837
Epoch: 6	Validation accuracy: 0.11013215859030837
Epoch: 7	Validation accuracy: 0.11013215859030837
Epoch: 8	Validation accuracy: 0.11013215859030837
Epoch: 9	Validation accuracy: 0.11013215859030837
Epoch: 10	Validation accuracy: 0.11013215859030837
Epoch: 11	Validation accuracy: 0.11013215859030837
Epoch: 12	Validation accuracy: 0.11013215859030837
Epoch: 13	Validation accuracy: 0.11013215859030837
Epoch: 14	Validation accuracy: 0.11013215859030837
Epoch: 15	Validation accuracy: 0.11013215859030837
Epoch: 16	Validation accuracy: 0.11013215859030837
Epoch: 17	Validation accuracy: 0.11013215859030837
Epoch: 18	Validation accuracy: 0.11013215859030837
Epoch: 19	Validation accuracy: 0.11013215859030837
Epoch: 20	Validation accuracy: 0.11013215

#### 2.3.2 GNN with coordinates

Using the original mesh edges, with vertex coordinates as node features:

In [None]:
class MeshOriginalEdge(torch.nn.Module):
    """Two-layer GCN on the mesh edges that uses vertex coordinates (``pos``) as features."""

    def __init__(self, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = pyg_nn.GCNConv(3, hidden_dim)
        self.conv2 = pyg_nn.GCNConv(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, train_data):
        """Return class logits (one row per graph) computed from ``pos`` only."""
        edge_index = train_data.edge_index
        hidden = F.relu(self.conv1(train_data.pos, edge_index))
        hidden = F.relu(self.conv2(hidden, edge_index))
        pooled = global_max_pool(hidden, train_data.batch)
        return self.classifier(pooled)

    @torch.no_grad()
    def eval_accuracy(self, evalloader):
        """Accuracy of argmax predictions over every graph yielded by ``evalloader``."""
        correct = total = 0
        for batch in evalloader:
            labels = batch.y
            correct += self(batch).argmax(dim=1).eq(labels).sum().item()
            total += len(labels)
        return correct / total

def use_gnn(seed, hidden_dim, train_set, test_set):
    """Train MeshOriginalEdge for 50 epochs; return the best per-epoch test accuracy.

    Args:
        seed: passed to ``torch.manual_seed`` before the model is built.
        hidden_dim: width of both GCN layers.
        train_set, test_set: lists of graphs wrapped in PyG ``DataLoader``s (batch 64).

    Note: the returned value is the maximum accuracy on ``test_set`` across epochs
    (selected on the test set). ``num_classes`` comes from the global ``train_dataset``.
    """
    torch.manual_seed(seed)

    model = MeshOriginalEdge(hidden_dim=hidden_dim, num_classes=train_dataset.num_classes).to("cuda")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0)

    train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

    best_test_acc = 0
    for epoch in range(1, 51):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            loss = F.cross_entropy(model(batch), batch.y)
            loss.backward()
            optimizer.step()

        model.eval()
        valid_acc = model.eval_accuracy(test_loader)
        print("Epoch: {}\tValidation accuracy: {}".format(epoch, valid_acc))
        best_test_acc = max(best_test_acc, valid_acc)

    return best_test_acc

seed = 114514
test_acc = use_gnn(
    seed=seed,
    hidden_dim=32,
    train_set=train_dataset_mesh,
    test_set=test_dataset_mesh,
)
print(f"Test accuracy: {test_acc}\n")

Epoch: 1	Validation accuracy: 0.41409691629955947
Epoch: 2	Validation accuracy: 0.6398678414096917
Epoch: 3	Validation accuracy: 0.6519823788546255
Epoch: 4	Validation accuracy: 0.710352422907489
Epoch: 5	Validation accuracy: 0.7125550660792952
Epoch: 6	Validation accuracy: 0.7555066079295154
Epoch: 7	Validation accuracy: 0.7202643171806168
Epoch: 8	Validation accuracy: 0.7257709251101322
Epoch: 9	Validation accuracy: 0.7444933920704846
Epoch: 10	Validation accuracy: 0.762114537444934
Epoch: 11	Validation accuracy: 0.7654185022026432
Epoch: 12	Validation accuracy: 0.7709251101321586
Epoch: 13	Validation accuracy: 0.7544052863436124
Epoch: 14	Validation accuracy: 0.7588105726872246
Epoch: 15	Validation accuracy: 0.75
Epoch: 16	Validation accuracy: 0.7863436123348018
Epoch: 17	Validation accuracy: 0.7775330396475771
Epoch: 18	Validation accuracy: 0.7720264317180616
Epoch: 19	Validation accuracy: 0.789647577092511
Epoch: 20	Validation accuracy: 0.7940528634361234
Epoch: 21	Validation accu

Using dynamically generated k-NN edges (`DynamicEdgeConv`), adapted from the PyG DGCNN classification example:    
https://github.com/pyg-team/pytorch_geometric/blob/b0053ce1c193ed3c25ce0adb105558000489dacb/examples/dgcnn_classification.py

In [None]:
# Convert every mesh into a 1024-point cloud with T.SamplePoints (train split first).
transform = T.SamplePoints(1024)
train_dataset_dyn_mesh, test_dataset_dyn_mesh = (
    [transform(sample).to("cuda") for sample in split]
    for split in (train_dataset, test_dataset)
)

In [None]:
from torch_geometric.nn import MLP, DynamicEdgeConv

class MeshEdgeConv(torch.nn.Module):
    """DGCNN-style point-cloud classifier built from two DynamicEdgeConv layers.

    Per-point features from both edge-conv layers (64 + 128 channels) are
    concatenated, lifted to 1024 dims, max-pooled per graph, and classified by an
    MLP. ``forward`` returns log-probabilities (pair with ``F.nll_loss``).
    """

    def __init__(self, out_channels, k=20, aggr='max'):
        super().__init__()
        self.conv1 = DynamicEdgeConv(MLP([2 * 3, 64, 64, 64]), k, aggr)
        self.conv2 = DynamicEdgeConv(MLP([2 * 64, 128]), k, aggr)
        self.lin1 = nn.Linear(128 + 64, 1024)
        self.mlp = MLP([1024, 512, 256, out_channels], dropout=0.5, norm=None)

    def forward(self, train_data):
        batch = train_data.batch
        feat1 = self.conv1(train_data.pos, batch)
        feat2 = self.conv2(feat1, batch)
        point_feat = self.lin1(torch.cat([feat1, feat2], dim=1))
        logits = self.mlp(global_max_pool(point_feat, batch))
        return F.log_softmax(logits, dim=1)

    @torch.no_grad()
    def eval_accuracy(self, evalloader):
        """Accuracy of argmax predictions over every graph yielded by ``evalloader``."""
        hits, seen = 0, 0
        for batch in evalloader:
            guess = self(batch).max(dim=1)[1]
            hits += guess.eq(batch.y).sum().item()
            seen += batch.y.shape[0]
        return hits / seen

def use_gnn(seed, train_set, test_set):
    """Train MeshEdgeConv (k=20) for 50 epochs; return the best per-epoch test accuracy.

    Adam (lr=1e-3) with StepLR halving the learning rate every 20 epochs; batch size 32.
    The returned value is the maximum accuracy on ``test_set`` across epochs.
    ``num_classes`` comes from the global ``train_dataset``.
    """
    torch.manual_seed(seed)

    model = MeshEdgeConv(train_dataset.num_classes, k=20).to("cuda")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

    best_test_acc = 0
    for epoch in range(1, 51):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            # model returns log-probabilities, hence nll_loss rather than cross_entropy
            loss = F.nll_loss(model(batch), batch.y)
            loss.backward()
            optimizer.step()
        scheduler.step()

        model.eval()
        valid_acc = model.eval_accuracy(test_loader)
        print("Epoch: {}\tValidation accuracy: {}".format(epoch, valid_acc))
        best_test_acc = max(best_test_acc, valid_acc)

    return best_test_acc

seed = 42
test_acc = use_gnn(
    seed=seed,
    train_set=train_dataset_dyn_mesh,
    test_set=test_dataset_dyn_mesh,
)
print(f"Test accuracy: {test_acc}\n")



Epoch: 1	Validation accuracy: 0.7191629955947136
Epoch: 2	Validation accuracy: 0.7731277533039648
Epoch: 3	Validation accuracy: 0.816079295154185
Epoch: 4	Validation accuracy: 0.8050660792951542
Epoch: 5	Validation accuracy: 0.8149779735682819
Epoch: 6	Validation accuracy: 0.8535242290748899
Epoch: 7	Validation accuracy: 0.8667400881057269
Epoch: 8	Validation accuracy: 0.8535242290748899
Epoch: 9	Validation accuracy: 0.8832599118942731
Epoch: 10	Validation accuracy: 0.8513215859030837
Epoch: 11	Validation accuracy: 0.8777533039647577
Epoch: 12	Validation accuracy: 0.8854625550660793
Epoch: 13	Validation accuracy: 0.8810572687224669
Epoch: 14	Validation accuracy: 0.8711453744493393
Epoch: 15	Validation accuracy: 0.8601321585903083
Epoch: 16	Validation accuracy: 0.8843612334801763
Epoch: 17	Validation accuracy: 0.8876651982378855
Epoch: 18	Validation accuracy: 0.8953744493392071
Epoch: 19	Validation accuracy: 0.8821585903083701
Epoch: 20	Validation accuracy: 0.9107929515418502
Epoch: 21	

#### 2.3.3 Rotation-invariant GNNs

*a. LGR-Net:*    
https://github.com/sailor-z/LGR-Net/tree/main

In [None]:
transform = T.SamplePoints(1024, include_normals=True)
# Random rotation about x, then y, then z (degrees=180 each), followed by sampling
# points and normals from the rotated mesh. Rotations are drawn once here, so each
# rotated sample keeps the same pose for the whole run.
rotate_transform = T.Compose(
    [T.RandomRotate(degrees=180, axis=axis) for axis in (0, 1, 2)]
    + [T.SamplePoints(1024, include_normals=True)]
)

train_ori_lgrnet = [transform(mesh).to("cuda") for mesh in train_dataset]
test_ori_lgrnet = [transform(mesh).to("cuda") for mesh in test_dataset]
train_rot_lgrnet = [rotate_transform(mesh).to("cuda") for mesh in train_dataset]
test_rot_lgrnet = [rotate_transform(mesh).to("cuda") for mesh in test_dataset]

In [None]:
def index_points(points, idx):
    """Gather rows of ``points`` separately for each batch element.

    Args:
        points: tensor indexed as ``points[b, n, :]`` (batch, points, channels).
        idx: integer tensor whose first dim is the batch; any trailing shape.
    Returns:
        Tensor of shape ``idx.shape + (C,)`` with ``out[b, ...] = points[b, idx[b, ...], :]``.
    """
    batch_size = points.shape[0]
    # Batch index shaped (idx.shape[0], 1, ..., 1), then broadcast to idx's full shape.
    batch_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(batch_size, dtype=torch.long, device=points.device)
    batch_indices = batch_indices.view(batch_shape).expand_as(idx)
    return points[batch_indices, idx, :]

def farthest_point_sample(xyz, npoint):
    """Iterative farthest point sampling (FPS).

    Args:
        xyz: tensor of shape (B, N, C) with point coordinates (any C >= 1).
        npoint: number of indices to select per batch element.
    Returns:
        LongTensor (B, npoint) of indices into dim 1 of ``xyz``. The first index is
        the point farthest from the cloud's mean; each following index maximizes the
        squared distance to the points already selected.
    """
    device = xyz.device
    B, N, C = xyz.size()
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Squared distance from every point to its nearest already-selected point.
    distance = torch.full((B, N), 1e10, device=device)

    # Deterministic start: the point farthest from the mean (no random seed point).
    mean = torch.mean(xyz, dim=1, keepdim=True)  # [B, 1, C]
    farthest = torch.max(torch.sum((xyz - mean) ** 2, -1), -1)[1]

    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # unsqueeze(1) keeps all C channels; a fixed .view(B, 1, 3) only worked for C == 3.
        centroid = xyz[batch_indices, farthest, :].unsqueeze(1)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        closer = dist < distance
        distance[closer] = dist[closer]
        farthest = torch.max(distance, -1)[1]
    return centroids

def knn(x, k):
    """Indices of each point's k nearest neighbours, dropping the single closest match.

    Args:
        x: tensor of shape (B, C, N) (channels first).
        k: number of neighbours to return.
    Returns:
        LongTensor (B, N, k) ranked by squared Euclidean distance; the top candidate
        (normally the point itself) is excluded.
    """
    gram = torch.matmul(x.transpose(2, 1), x)          # (B, N, N) inner products
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)   # (B, 1, N)
    # Negative squared distance, so topk(largest) yields the nearest points first.
    neg_sq_dist = 2 * gram - sq_norm - sq_norm.transpose(2, 1)
    nearest = neg_sq_dist.topk(k=k + 1, dim=-1)[1]
    return nearest[:, :, 1:]

def grouping(x, k=20, idx=None):
    """Gather the features of each point's k neighbours.

    Args:
        x: tensor viewed as (B, C, N) (channels first).
        k: neighbourhood size (for the ``knn`` call and the output shape).
        idx: optional LongTensor (B, N, k) of per-batch neighbour indices; computed
            with ``knn`` when omitted.
    Returns:
        Tensor (B, C, N, k) with ``out[b, :, n, j] = x[b, :, idx[b, n, j]]``.
    """
    batch_size, num_points = x.size(0), x.size(2)
    x = x.view(batch_size, -1, num_points)
    num_dims = x.size(1)
    if idx is None:
        idx = knn(x, k=k)
    # Offset per-batch indices into the flattened (B * N) point axis. Built on the
    # input's device instead of assuming CUDA, so CPU tensors work too.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    flat_idx = (idx + idx_base).view(-1)
    points = x.transpose(2, 1).contiguous().view(batch_size * num_points, -1)
    feature = points[flat_idx, :].view(batch_size, num_points, k, num_dims)
    return feature.permute(0, 3, 1, 2)

def darboux(points, normals, k):
    """Local point-pair features used by the LGR-Net local branch.

    For every point and each of its k nearest neighbours (``knn`` + ``grouping``)
    this computes the neighbour distance plus seven cosines between the offset
    vector, the two normals, and cross products built from them.

    Args:
        points: tensor (B, 3, N) of coordinates.
        normals: tensor (B, 3, N) of per-point normals.
        k: neighbourhood size.
    Returns:
        Tensor (B, 8, N, k): [distance, a1, ..., a7] stacked on dim 1.
    """
    B, C, N = points.size()
    eps = 1e-10  # guards divisions for zero-length vectors
    idx = knn(points, k)

    points_knn = grouping(points, k, idx)    # (B, C, N, k)
    normals_knn = grouping(normals, k, idx)  # (B, C, N, k)

    mid = points_knn - points.unsqueeze(-1)  # offsets from each point to its neighbours

    d = torch.norm(mid, p=2, dim=1)
    l1 = torch.norm(normals, p=2, dim=1).unsqueeze(-1)
    l2 = torch.norm(normals_knn, p=2, dim=1)

    a1 = torch.sum(mid * normals.unsqueeze(-1), dim=1) / (d * l1 + eps)
    a2 = torch.sum(mid * normals_knn, dim=1) / (d * l2 + eps)
    a3 = torch.sum(normals_knn * normals.unsqueeze(-1), dim=1) / (l2 * l1 + eps)

    # Flatten to (B*N, k, C) so the xyz axis is last.
    mid = mid.permute(0, 2, 3, 1).contiguous().view(-1, k, C)
    normals_knn = normals_knn.permute(0, 2, 3, 1).contiguous().view(-1, k, C)
    normals = normals.permute(0, 2, 1).contiguous().view(-1, 1, C).repeat(1, k, 1)

    # dim=-1 pins the cross product to the xyz axis. Without it torch falls back to
    # the first dimension of size 3, which is the neighbour axis whenever k == 3.
    v1 = torch.cross(mid, normals, dim=-1)
    v2 = torch.cross(v1, normals, dim=-1)
    v3 = torch.cross(mid, normals_knn, dim=-1)
    v4 = torch.cross(v3, normals_knn, dim=-1)

    d1 = torch.norm(v1, p=2, dim=-1)
    d2 = torch.norm(v2, p=2, dim=-1)
    d3 = torch.norm(v3, p=2, dim=-1)
    d4 = torch.norm(v4, p=2, dim=-1)

    def _cosine(u, v, nu, nv):
        # Row-wise cosine of (B*N, k, C) vectors, reshaped back to (B, N, k).
        return (torch.sum(u * v, dim=-1) / (nu * nv + eps)).view(B, N, k)

    a4 = _cosine(v1, v3, d1, d3)
    a5 = _cosine(v2, v4, d2, d4)
    a6 = _cosine(v1, v4, d1, d4)
    a7 = _cosine(v2, v3, d2, d3)

    return torch.stack([d, a1, a2, a3, a4, a5, a6, a7], dim=1)

def get_graph_feature(x, k=20, idx=None):
    """EdgeConv-style features: each point's own feature and its offsets to k neighbours.

    Args:
        x: tensor viewed as (B, C, N) (channels first).
        k: neighbourhood size.
        idx: optional LongTensor (B, N, k) of neighbour indices; ``knn`` when omitted.
    Returns:
        Tensor (B, 2C, N, k): channels [0, C) hold the centre feature ``x_n``,
        channels [C, 2C) hold ``x_neighbour - x_n``.
    """
    batch_size, num_points = x.size(0), x.size(2)
    x = x.view(batch_size, -1, num_points)
    num_dims = x.size(1)
    if idx is None:
        idx = knn(x, k=k)

    # Offset per-batch indices into the flattened (B * N) axis, on the input's device
    # rather than a hard-coded CUDA device.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    flat_idx = (idx + idx_base).view(-1)
    x = x.transpose(2, 1).contiguous()                     # (B, N, C)
    neighbours = x.view(batch_size * num_points, -1)[flat_idx, :]
    neighbours = neighbours.view(batch_size, num_points, k, num_dims)
    centre = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    return torch.cat((centre, neighbours - centre), dim=3).permute(0, 3, 1, 2)

def global_transform(points, npoints, train, knn):
    """Express a cloud in the frame of its FPS centroids' right singular vectors.

    Args:
        points: tensor (B, 3, N).
        npoints: number of farthest-point samples used to estimate the frame.
        train: if truthy, each axis sign is flipped at random (augmentation);
            otherwise signs are fixed deterministically from the first sampled point.
        knn: neighbourhood size forwarded to ``get_graph_feature`` (this parameter
            shadows the module-level ``knn`` function inside this function only).
    Returns:
        Tensor (B, 6, N, knn) from ``get_graph_feature`` on the re-oriented coordinates.
    """
    points = points.permute(0, 2, 1)                    # (B, N, 3)
    idx = farthest_point_sample(points, npoints)
    centroids = index_points(points, idx)               # (B, npoints, 3)
    U, S, V = torch.svd(centroids)                      # V: (B, 3, 3)

    # Singular vectors are defined only up to sign; choose a sign per axis. The
    # 0/1 mask is moved to V's device/dtype instead of a hard-coded CUDA float tensor.
    if train:
        flip = torch.randint(2, (points.size(0), 1, 3)).to(V)
    else:
        key_p = centroids[:, 0, :].unsqueeze(1)
        flip = torch.le(torch.matmul(key_p, V), 0).to(V)
    V = V - 2 * (V * flip)  # negate the selected columns

    xyz = torch.matmul(points, V).permute(0, 2, 1)
    return get_graph_feature(xyz, k=knn)

class feature_fusion(nn.Module):
    """Attention-weighted sum over the last dimension of a 1024-channel input.

    A 1x1 Conv2d + BatchNorm2d + LeakyReLU scores every entry, a softmax over
    dim -1 turns the scores into weights, and the input is summed over dim -1 with
    those weights (kept as a size-1 dim).
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1024, 1024, (1, 1), bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(negative_slope=0.2),
        )

    def forward(self, x):
        weights = F.softmax(self.conv(x), dim=-1)
        return (x * weights).sum(dim=-1, keepdim=True)

class dar_feat(nn.Module):
    """LGR-Net backbone fusing a local (Darboux) branch and a global (re-oriented) branch.

    Args:
        global_feat: if True, ``forward`` returns a (B, 2048) descriptor; otherwise
            per-point features of shape (B, 4096, N, 1).
        knn: neighbourhood size for both branches.
        train_idx: enables the random axis-sign flips in ``global_transform``; they
            are applied only while the module is in training mode.
        cv_bias: whether the 1x1 convolutions use a bias term.
    """

    def __init__(self, global_feat=True, knn=16, train_idx=True, cv_bias=False):
        super(dar_feat, self).__init__()
        self.knn = knn
        self.train_idx = train_idx
        self.cv_bias = cv_bias

        # Global branch: 6-channel edge features from global_transform.
        self.gb_gconv_1 = self._conv_bn_act(6, 64)
        self.gb_gconv_2 = self._conv_bn_act(64, 128)
        self.gb_gconv_3 = self._conv_bn_act(128, 512)
        self.gb_gconv_4 = self._conv_bn_act(512, 1024)

        # Local branch: 8-channel darboux features (distance + 7 cosines).
        self.lc_gconv_1 = self._conv_bn_act(8, 64)
        self.lc_gconv_2 = self._conv_bn_act(64, 128)
        self.lc_gconv_3 = self._conv_bn_act(128, 512)
        self.lc_gconv_4 = self._conv_bn_act(512, 1024)

        self.conv = self._conv_bn_act(1024, 2048)
        self.feature_fusion = feature_fusion()

        self.global_feat = global_feat

    def _conv_bn_act(self, in_ch, out_ch):
        """1x1 Conv2d -> BatchNorm2d -> LeakyReLU(0.2), honouring ``self.cv_bias``."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, (1, 1), bias=self.cv_bias),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(negative_slope=0.2),
        )

    def region_pooling(self, num, x):
        """Split the last dim into ``num`` chunks and max-pool each (unused by ``forward``)."""
        _, _, _, k_num = x.size()
        group = torch.chunk(x, num, dim=-1)
        feature = []
        for i in range(num):
            feature += [torch.max(group[i], dim=-1, keepdim=False)[0]]
        feature = torch.stack(feature).permute(1, 2, 3, 0)
        return feature

    def forward(self, points, normals):
        """points, normals: tensors (B, 3, N)."""
        n_pts = points.size(2)
        # Random sign flips are a training-time augmentation. Gating on self.training
        # makes eval deterministic and reaches global_transform's deterministic branch.
        # Previously eval also flipped signs at random whenever train_idx was set.
        use_random_flip = bool(self.train_idx) and self.training
        global_f = global_transform(points, 32, use_random_flip, self.knn)
        local_f = darboux(points, normals, self.knn)

        l_out = self.lc_gconv_2(self.lc_gconv_1(local_f))
        l_out = F.max_pool2d(l_out, (1, self.knn))           # pool over neighbours
        l_out = self.lc_gconv_4(self.lc_gconv_3(l_out))      # (B, 1024, N, 1)

        g_out = self.gb_gconv_2(self.gb_gconv_1(global_f))
        g_out = F.max_pool2d(g_out, (1, self.knn))           # pool over neighbours
        g_out = self.gb_gconv_4(self.gb_gconv_3(g_out))      # (B, 1024, N, 1)

        out = self.feature_fusion(torch.cat([g_out, l_out], dim=-1))  # (B, 1024, N, 1)
        out = self.conv(out)                                  # (B, 2048, N, 1)
        out = F.max_pool2d(out, (n_pts, 1)).view(-1, 2048)    # pool over points

        if self.global_feat:
            return out
        out = out.view(-1, 2048, 1, 1).repeat(1, 1, n_pts, 1)
        return torch.cat([out, g_out, l_out], 1)

class LGRNet(nn.Module):
    """LGR-Net classifier: ``dar_feat`` global descriptor followed by an MLP head.

    Args:
        k: number of output classes.
        knn: neighbourhood size passed to ``dar_feat``.
        train_idx: forwarded to ``dar_feat`` (random axis-sign flips).
        cv_bias: bias flag for the backbone convs and the first classifier layer.
        num_points: points per cloud. ``forward`` reshapes the flat batched
            ``pos``/``normal`` tensors with it, so it must match the sampler
            (1024 points in this notebook's data preparation).
    """

    def __init__(self, k=10, knn=16, train_idx=True, cv_bias=False, num_points=1024):
        super(LGRNet, self).__init__()

        self.class_nums = k
        self.knn = knn
        self.cv_bias = cv_bias
        self.train_idx = train_idx
        self.num_points = num_points

        self.feat = dar_feat(global_feat=True, knn=self.knn, train_idx=self.train_idx, cv_bias=self.cv_bias)

        self.classify = nn.Sequential(
            nn.Linear(2048, 512, bias=self.cv_bias),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(256, self.class_nums),
        )
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-normal weights / zero biases for Conv2d and Linear; unit BatchNorm1d."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias.data, 0)
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, train_data):
        """Return class logits (B, k) for a PyG batch carrying ``pos`` and ``normal``."""
        # The batch stores all clouds concatenated along dim 0; recover (B, 3, num_points).
        points = train_data.pos.reshape(-1, self.num_points, 3).transpose(1, 2)
        normals = train_data.normal.reshape(-1, self.num_points, 3).transpose(1, 2)
        return self.classify(self.feat(points, normals))

    @torch.no_grad()
    def eval_accuracy(self, evalloader):
        """Accuracy of argmax predictions over every cloud yielded by ``evalloader``."""
        correct = 0
        total = 0
        for eval_data in evalloader:
            pred = self(eval_data).max(dim=1)[1]
            correct += pred.eq(eval_data.y).sum().item()
            total += eval_data.y.shape[0]
        return correct / total


def use_lgr(seed, train_set, test_set):
    """Train a default LGRNet (10 classes) for 20 epochs; return the best test accuracy.

    Adam (lr=1e-3, weight_decay=5e-4), batch size 16, cross-entropy loss. The returned
    value is the maximum accuracy on ``test_set`` across epochs (selected on the test set).
    """
    torch.manual_seed(seed)

    model = LGRNet().to("cuda")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

    train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=16, shuffle=False)

    best_test_acc = 0
    for epoch in range(1, 21):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            loss = F.cross_entropy(model(batch), batch.y)
            loss.backward()
            optimizer.step()

        model.eval()
        valid_acc = model.eval_accuracy(test_loader)
        print("Epoch: {}\tValidation accuracy: {}".format(epoch, valid_acc))
        best_test_acc = max(best_test_acc, valid_acc)

    return best_test_acc


seed = 42

ori_ori_acc = use_lgr(seed=seed, train_set=train_ori_lgrnet, test_set=test_ori_lgrnet)
print(f"Accuracy of original training & original test: {ori_ori_acc}\n")

ori_rot_acc = use_lgr(seed=seed, train_set=train_ori_lgrnet, test_set=test_rot_lgrnet)
print(f"Accuracy of original training & rotated test: {ori_rot_acc}\n")

rot_rot_acc = use_lgr(seed=seed, train_set=train_rot_lgrnet, test_set=test_rot_lgrnet)
print(f"Accuracy of rotated training & rotated test: {rot_rot_acc}\n")

Epoch: 1	Validation accuracy: 0.6266519823788547
Epoch: 2	Validation accuracy: 0.7257709251101322
Epoch: 3	Validation accuracy: 0.7577092511013216
Epoch: 4	Validation accuracy: 0.7455947136563876
Epoch: 5	Validation accuracy: 0.7720264317180616
Epoch: 6	Validation accuracy: 0.7555066079295154
Epoch: 7	Validation accuracy: 0.7797356828193832
Epoch: 8	Validation accuracy: 0.7984581497797357
Epoch: 9	Validation accuracy: 0.7918502202643172
Epoch: 10	Validation accuracy: 0.7687224669603524
Epoch: 11	Validation accuracy: 0.789647577092511
Epoch: 12	Validation accuracy: 0.7169603524229075
Epoch: 13	Validation accuracy: 0.7819383259911894
Epoch: 14	Validation accuracy: 0.8171806167400881
Epoch: 15	Validation accuracy: 0.8105726872246696
Epoch: 16	Validation accuracy: 0.7951541850220264
Epoch: 17	Validation accuracy: 0.7852422907488987
Epoch: 18	Validation accuracy: 0.7544052863436124
Epoch: 19	Validation accuracy: 0.829295154185022
Epoch: 20	Validation accuracy: 0.8392070484581498
Accuracy of

*b. PointNet:*    
https://github.com/fxia22/pointnet.pytorch/tree/master

In [None]:
transform = T.SamplePoints(1024)
# Random rotation about x, then y, then z (degrees=180 each) before sampling. The
# rotations are drawn once here, so each rotated cloud keeps its pose for the run.
rotate_transform = T.Compose(
    [T.RandomRotate(degrees=180, axis=axis) for axis in range(3)]
    + [T.SamplePoints(1024)]
)

# Built in the order: original train, original test, rotated train, rotated test.
train_ori_pointnet, test_ori_pointnet, train_rot_pointnet, test_rot_pointnet = (
    [tf(mesh).to("cuda") for mesh in split]
    for tf, split in (
        (transform, train_dataset),
        (transform, test_dataset),
        (rotate_transform, train_dataset),
        (rotate_transform, test_dataset),
    )
)

In [None]:
class STN3d(nn.Module):
    """PointNet input-transform network: regresses a 3x3 matrix per point cloud.

    The 9 regressed values are added to the flattened identity, so a zero regression
    yields the identity transform. bn1..bn5 and ``self.relu`` are registered but not
    used in ``forward`` (kept so parameters/state_dict stay unchanged).
    """

    def __init__(self):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        """x: tensor (B, 3, N). Returns a (B, 3, 3) transform per cloud."""
        batchsize = x.size(0)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0].view(-1, 1024)  # max-pool over points

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        # Identity offset created on the input's device/dtype (previously forced to CUDA).
        iden = torch.eye(3, dtype=x.dtype, device=x.device).view(1, 9).repeat(batchsize, 1)
        return (x + iden).view(-1, 3, 3)

class PointNetfeat(nn.Module):
    """PointNet encoder: STN3d input transform, shared per-point MLP, max pooling.

    Args:
        global_feat: if True, ``forward`` returns the (B, 1024) max-pooled feature;
            otherwise (B, 1088, N) — the global feature tiled over points and
            concatenated with the 64-d per-point features.
        feature_transform: the 64x64 feature-transform network (``self.fstn``) is not
            defined in this notebook, so enabling it raises ``NotImplementedError`` at
            construction instead of an ``AttributeError`` on the first forward pass.

    bn1..bn3 are registered but not applied in ``forward``.
    """

    def __init__(self, global_feat = True, feature_transform=False):
        super(PointNetfeat, self).__init__()
        if feature_transform:
            raise NotImplementedError(
                "feature_transform=True needs an STNkd feature-transform module, "
                "which is not defined in this notebook")
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform

    def forward(self, x):
        """x: tensor (B, 3, N). Returns (features, trans, trans_feat)."""
        n_pts = x.size(2)
        trans = self.stn(x)                                        # (B, 3, 3)
        x = torch.bmm(x.transpose(2, 1), trans).transpose(2, 1)    # apply input transform
        x = F.relu(self.conv1(x))                                  # (B, 64, N)

        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = torch.bmm(x.transpose(2, 1), trans_feat).transpose(2, 1)
        else:
            trans_feat = None

        pointfeat = x
        x = F.relu(self.conv2(x))
        x = self.conv3(x)                                          # no activation before pooling
        x = torch.max(x, 2, keepdim=True)[0].view(-1, 1024)
        if self.global_feat:
            return x, trans, trans_feat
        x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
        return torch.cat([x, pointfeat], 1), trans, trans_feat

class PointNetCls(nn.Module):
    """PointNet classifier head over ``PointNetfeat``; ``forward`` returns log-probabilities.

    Args:
        k: number of classes.
        num_points: points per cloud; ``eval_accuracy`` uses it to reshape batched ``pos``.
        feature_transform: forwarded to ``PointNetfeat``.
    """

    def __init__(self, k, num_points, feature_transform=False):
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        # bn1, bn2 and relu are registered but not used by forward.
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()
        self.num_points = num_points

    def forward(self, x):
        """x: tensor (B, 3, num_points). Returns (B, k) log-probabilities."""
        global_feat = self.feat(x)[0]
        hidden = F.relu(self.fc1(global_feat))
        hidden = F.relu(self.dropout(self.fc2(hidden)))
        return F.log_softmax(self.fc3(hidden), dim=1)

    @torch.no_grad()
    def eval_accuracy(self, evalloader):
        """Accuracy of argmax predictions; batched ``pos`` is reshaped to (B, 3, num_points)."""
        correct, total = 0, 0
        for batch in evalloader:
            clouds = batch.pos.reshape(-1, self.num_points, 3).transpose(1, 2)
            predicted = self(clouds).max(dim=1)[1]
            correct += predicted.eq(batch.y).sum().item()
            total += batch.y.shape[0]
        return correct / total


def use_pointnet(seed, num_points, train_set, test_set, *, num_classes=None,
                 epochs=50, lr=0.001, weight_decay=5e-4, batch_size=64,
                 device="cuda", eval_every=1):
    """Train a PointNet classifier and return its best accuracy on ``test_set``.

    Args:
        seed: passed to ``torch.manual_seed`` before the model is built.
        num_points: points per sample; every sample's ``pos`` must have exactly
            this many rows because batches are reshaped to ``(-1, num_points, 3)``.
        train_set: training graphs (``pos`` and ``y`` attributes).
        test_set: evaluation graphs (``pos`` and ``y`` attributes).
        num_classes: number of output classes; defaults to the global
            ``train_dataset.num_classes`` (the previous hard-coded behaviour).
        epochs, lr, weight_decay, batch_size: training hyper-parameters; the
            defaults are the values previously hard-coded here.
        device: device the model is moved to; the data sets are assumed to
            already live there.
        eval_every: evaluate on ``test_set`` every this many epochs (>= 1).

    Returns:
        The highest accuracy observed on ``test_set`` over all evaluations.
        NOTE: the epoch is selected on the test set itself, so this is an
        optimistic estimate, not an unbiased held-out score.

    Raises:
        ValueError: if ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError("eval_every must be >= 1")

    torch.manual_seed(seed)

    if num_classes is None:
        num_classes = train_dataset.num_classes
    pointnet_model = PointNetCls(k=num_classes, num_points=num_points).to(device)
    optimizer = torch.optim.Adam(pointnet_model.parameters(), lr=lr, weight_decay=weight_decay)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    best_test_acc = 0
    for epoch in range(epochs):
        pointnet_model.train()
        for train_data in train_loader:
            optimizer.zero_grad()
            # flat per-point coordinates -> (batch, 3, num_points)
            points = train_data.pos.reshape(-1, num_points, 3).transpose(1, 2)
            pred = pointnet_model(points)
            loss = F.nll_loss(pred, train_data.y)
            loss.backward()
            optimizer.step()

        if (epoch + 1) % eval_every == 0:
            pointnet_model.eval()
            valid_acc = pointnet_model.eval_accuracy(test_loader)
            print("Epoch: {}\tValidation accuracy: {}".format(epoch + 1, valid_acc))
            best_test_acc = max(best_test_acc, valid_acc)

    return best_test_acc

# PointNet rotation-robustness study: three train/test pairings share the same
# seed and point count so that only the data split differs between runs.
seed = 42
num_points = 1024
pointnet_common_kwargs = dict(seed=seed, num_points=num_points)

ori_ori_acc = use_pointnet(train_set=train_ori_pointnet, test_set=test_ori_pointnet,
                           **pointnet_common_kwargs)
print("Accuracy of original training & original test: {}\n".format(ori_ori_acc))

ori_rot_acc = use_pointnet(train_set=train_ori_pointnet, test_set=test_rot_pointnet,
                           **pointnet_common_kwargs)
print("Accuracy of original training & rotated test: {}\n".format(ori_rot_acc))

rot_rot_acc = use_pointnet(train_set=train_rot_pointnet, test_set=test_rot_pointnet,
                           **pointnet_common_kwargs)
print("Accuracy of rotated training & rotated test: {}\n".format(rot_rot_acc))

Epoch: 1	Validation accuracy: 0.6288546255506607
Epoch: 2	Validation accuracy: 0.73568281938326
Epoch: 3	Validation accuracy: 0.789647577092511
Epoch: 4	Validation accuracy: 0.7918502202643172
Epoch: 5	Validation accuracy: 0.7863436123348018
Epoch: 6	Validation accuracy: 0.829295154185022
Epoch: 7	Validation accuracy: 0.8325991189427313
Epoch: 8	Validation accuracy: 0.8314977973568282
Epoch: 9	Validation accuracy: 0.8568281938325991
Epoch: 10	Validation accuracy: 0.8502202643171806
Epoch: 11	Validation accuracy: 0.8381057268722467
Epoch: 12	Validation accuracy: 0.8634361233480177
Epoch: 13	Validation accuracy: 0.8634361233480177
Epoch: 14	Validation accuracy: 0.8744493392070485
Epoch: 15	Validation accuracy: 0.8392070484581498
Epoch: 16	Validation accuracy: 0.8667400881057269
Epoch: 17	Validation accuracy: 0.8590308370044053
Epoch: 18	Validation accuracy: 0.8270925110132159
Epoch: 19	Validation accuracy: 0.8667400881057269
Epoch: 20	Validation accuracy: 0.8425110132158591
Epoch: 21	Val

*c. Use the original mesh edges (from `FaceToEdge`) as the graph:*

In [None]:
import random

# Build mesh-edge graphs for the original and randomly rotated point sets.
# PyG's RandomRotate draws its angle with Python's `random` module; without a
# seed the rotated splits (and therefore the reported accuracies) change on
# every run. Torch is seeded as well in case a transform draws from it.
rotation_seed = 0
random.seed(rotation_seed)
torch.manual_seed(rotation_seed)

transform = T.FaceToEdge(remove_faces=False)
rotate_transform = T.Compose([T.RandomRotate(degrees=180, axis=0),
                              T.RandomRotate(degrees=180, axis=1),
                              T.RandomRotate(degrees=180, axis=2),
                              T.FaceToEdge(remove_faces=False)])

train_ori_edges = [transform(data).to("cuda") for data in train_dataset]
test_ori_edges = [transform(data).to("cuda") for data in test_dataset]
train_rot_edges = [rotate_transform(data).to("cuda") for data in train_dataset]
test_rot_edges = [rotate_transform(data).to("cuda") for data in test_dataset]

In [None]:
class MeshOriginalEdge(torch.nn.Module):
    """Two-layer GCN mesh classifier that uses vertex coordinates as features.

    Node features are the raw ``pos`` coordinates (3 per vertex). Message
    passing runs over ``edge_index``; in this notebook those are the mesh
    edges produced by ``T.FaceToEdge``. A per-graph max pool and a linear
    classifier follow.
    """

    def __init__(self, hidden_dim, num_classes):
        super(MeshOriginalEdge, self).__init__()
        self.conv1 = pyg_nn.GCNConv(3, hidden_dim)
        self.conv2 = pyg_nn.GCNConv(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, train_data):
        """Return unnormalised class logits, one row per graph in the batch.

        Args:
            train_data: a PyG batch with ``pos``, ``edge_index`` and ``batch``
                attributes. Evaluation batches are passed here too, despite
                the parameter name.
        """
        pos, edge_index, batch = train_data.pos, train_data.edge_index, train_data.batch
        x = F.relu(self.conv1(pos, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        pooled = global_max_pool(x, batch)
        return self.classifier(pooled)

    @torch.no_grad()
    def eval_accuracy(self, evalloader):
        """Return the top-1 accuracy of the model over ``evalloader``.

        The model is put in eval mode for the duration of the call, and its
        previous train/eval mode is restored afterwards, even if an exception
        is raised.

        Args:
            evalloader: iterable of PyG batches accepted by ``forward`` that
                also expose integer labels ``y``.

        Returns:
            Fraction of correctly classified graphs, in ``[0, 1]``.

        Raises:
            ValueError: if ``evalloader`` yields no samples.
        """
        was_training = self.training
        self.eval()
        correct = 0
        total = 0
        try:
            for eval_data in evalloader:
                pred = self(eval_data).argmax(dim=1)
                correct += pred.eq(eval_data.y).sum().item()
                total += eval_data.y.shape[0]
        finally:
            self.train(was_training)
        if total == 0:
            raise ValueError("evalloader yielded no samples")
        return correct / total


def use_edges(seed, hidden_dim, train_set, test_set, *, num_classes=None,
              epochs=50, lr=0.005, weight_decay=0, batch_size=64,
              device="cuda", eval_every=1):
    """Train a ``MeshOriginalEdge`` GCN and return its best accuracy on ``test_set``.

    Args:
        seed: passed to ``torch.manual_seed`` before the model is built.
        hidden_dim: width of both GCN layers.
        train_set: training graphs with ``pos``, ``edge_index`` and ``y``.
        test_set: evaluation graphs with ``pos``, ``edge_index`` and ``y``.
        num_classes: number of output classes; defaults to the global
            ``train_dataset.num_classes`` (the previous hard-coded behaviour).
        epochs, lr, weight_decay, batch_size: training hyper-parameters; the
            defaults are the values previously hard-coded here.
        device: device the model is moved to; the data sets are assumed to
            already live there.
        eval_every: evaluate on ``test_set`` every this many epochs (>= 1).

    Returns:
        The highest accuracy observed on ``test_set`` over all evaluations.
        NOTE: the epoch is selected on the test set itself, so this is an
        optimistic estimate, not an unbiased held-out score.

    Raises:
        ValueError: if ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError("eval_every must be >= 1")

    torch.manual_seed(seed)

    if num_classes is None:
        num_classes = train_dataset.num_classes
    gnn_model = MeshOriginalEdge(hidden_dim=hidden_dim, num_classes=num_classes).to(device)
    optimizer = torch.optim.Adam(gnn_model.parameters(), lr=lr, weight_decay=weight_decay)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    best_test_acc = 0
    for epoch in range(epochs):
        gnn_model.train()
        for train_data in train_loader:
            optimizer.zero_grad()
            pred = gnn_model(train_data)
            # forward() returns raw logits, so cross_entropy (not nll_loss) applies
            loss = F.cross_entropy(pred, train_data.y)
            loss.backward()
            optimizer.step()

        if (epoch + 1) % eval_every == 0:
            gnn_model.eval()
            valid_acc = gnn_model.eval_accuracy(test_loader)
            print("Epoch: {}\tValidation accuracy: {}".format(epoch + 1, valid_acc))
            best_test_acc = max(best_test_acc, valid_acc)

    return best_test_acc


# Mesh-edge GCN rotation-robustness study: same seed and width for all three
# train/test pairings so that only the data split differs between runs.
seed = 114514
edge_common_kwargs = dict(seed=seed, hidden_dim=32)

ori_ori_acc = use_edges(train_set=train_ori_edges, test_set=test_ori_edges,
                        **edge_common_kwargs)
print("Accuracy of original training & original test: {}\n".format(ori_ori_acc))

ori_rot_acc = use_edges(train_set=train_ori_edges, test_set=test_rot_edges,
                        **edge_common_kwargs)
print("Accuracy of original training & rotated test: {}\n".format(ori_rot_acc))

rot_rot_acc = use_edges(train_set=train_rot_edges, test_set=test_rot_edges,
                        **edge_common_kwargs)
print("Accuracy of rotated training & rotated test: {}\n".format(rot_rot_acc))



Epoch: 1	Validation accuracy: 0.41409691629955947
Epoch: 2	Validation accuracy: 0.6398678414096917
Epoch: 3	Validation accuracy: 0.6519823788546255
Epoch: 4	Validation accuracy: 0.710352422907489
Epoch: 5	Validation accuracy: 0.7136563876651982
Epoch: 6	Validation accuracy: 0.7555066079295154
Epoch: 7	Validation accuracy: 0.7180616740088106
Epoch: 8	Validation accuracy: 0.724669603524229
Epoch: 9	Validation accuracy: 0.7455947136563876
Epoch: 10	Validation accuracy: 0.763215859030837
Epoch: 11	Validation accuracy: 0.7687224669603524
Epoch: 12	Validation accuracy: 0.7709251101321586
Epoch: 13	Validation accuracy: 0.7533039647577092
Epoch: 14	Validation accuracy: 0.7588105726872246
Epoch: 15	Validation accuracy: 0.7533039647577092
Epoch: 16	Validation accuracy: 0.7852422907488987
Epoch: 17	Validation accuracy: 0.7775330396475771
Epoch: 18	Validation accuracy: 0.7687224669603524
Epoch: 19	Validation accuracy: 0.7918502202643172
Epoch: 20	Validation accuracy: 0.7951541850220264
Epoch: 21	V

Release GPU memory

In [None]:
import gc

# Step 1: run Python's cyclic garbage collector so unreachable objects (and any
# tensors they own) are freed.
# Step 2: ask PyTorch to hand its unused cached CUDA blocks back to the driver.
# Tensors still referenced by live variables are NOT released by either step.
gc.collect()
torch.cuda.empty_cache()