In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data

from module.dataset import ModelNet40
from module.utils import *

import os, sys
from collections import OrderedDict

import math

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

init 함수에서 feature_dim을 input으로 받아 Batch Normalization을 수행합니다.
BatchNorm(3), BatchNorm(64), BatchNorm(1024)와 같이 사용됩니다.

In [2]:
class BatchNorm(nn.Module):
    '''
        Perform batch normalization.
        Input: A tensor of size (N, M, feature_dim), or (N, feature_dim, M) 
                                               ( it would be the former case if feature_dim == M ), 
                or (N, feature_dim)
        Output: A tensor of the same size as input.
    '''
    def __init__(self, feature_dim):
        super(BatchNorm, self).__init__()
        self.feature_dim = feature_dim
        self.batchnorm = nn.BatchNorm1d(feature_dim)
        self.permute = Permute((0, 2, 1))
        
    def forward(self, x, _ = None):
        if (len(x.shape) == 3) and (x.shape[-1] == self.feature_dim):
            return self.permute(self.batchnorm(self.permute(x)))
        else:
            return self.batchnorm(x)

init 함수에서 받은 tuple형태의 param 파라미터를 통해 torch.permute_(param)을 실행합니다.
예를 들어, x가 (100, 200, 300)의 shape를 갖고 있을 때 x.permute((0, 1, 2))는 x와 동일합니다.

In [3]:
class Permute(nn.Module):
    def __init__(self, param):
        super(Permute, self).__init__()
        self.param = param
    def forward(self, x):
        return x.permute(self.param)

Fully connected layers로 이루어진 MLP를 구성합니다.
일반적으로 Fully connected layer, Batch normalization, Activation function이 set로 구성됩니다.

init 함수에서 hidden_size 파라미터는 
input dimension부터 hidden dimension들, output dimension까지의 tuple로 입력받습니다.

input dimension부터 output dimension까지를 한번에 입력받는다는 점에 주의해주세요.

batchnorm과 last_activation argument에 따라 옵션을 줄 수 있도록 작성하면 좋지만, 무시하셔도 괜찮습니다.

In [4]:
class MLP(nn.Module):
    def __init__(self, hidden_size, batchnorm = True, last_activation = True):
        super(MLP, self).__init__()
        q = []
        for i in range(len(hidden_size)-1):
            in_dim = hidden_size[i]
            out_dim = hidden_size[i+1]
            q.append(("Linear_%d" % i, nn.Linear(in_dim, out_dim)))
            if (i < len(hidden_size) - 2) or ((i == len(hidden_size) - 2) and (last_activation)):
                if (batchnorm):
                    q.append(("Batchnorm_%d" % i, BatchNorm(out_dim)))
                q.append(("ReLU_%d" % i, nn.ReLU(inplace=True)))
        self.mlp = nn.Sequential(OrderedDict(q))
    def forward(self, x):
        return self.mlp(x)

일반적인 max pooling이 아닌 global maxpooling입니다.
input의 shape가 (B, N, D)일 때, output의 shape가 (B,)

In [5]:
class MaxPooling(nn.Module):
    def __init__(self):
        super(MaxPooling, self).__init__()
    def forward(self, x, dim=1, keepdim = False):
        res, _ = torch.max(x, dim=dim, keepdim = keepdim)
        return res

Graph Convolutional Layer를 작성합니다.

Graph Convolution은 크게 A matrix(adjacency matrix)와 X input(features)를 입력값으로 갖습니다.

weight와 bias를 nn.Parameter로 직접 선언하고, 초기화를 위해 reset_parameters() 함수를 작성하여 사용합니다.

A를 adjacency matrix, X를 input features, X'를 output features, W를 weight, b를 bias라고 할 때,

[X' = A X W + b] 를 구현하시면 됩니다.

In [7]:
class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # in_features와 out_features는 batch 형태로 들어오지만, weight와 bias는 network이라 batch_norm거치지 않고, in_feature와 곱할때는 torch내에서 broadcasting처리됨
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features)) 
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

GraphConvolution layers를 BatchNorm과 조합하여 만든 GraphConvolutional Network 모듈입니다.

3개의 세트(GraphConvolution + BatchNorm + ReLU)으로 조합되어 있으며,

각 Graph Convolution의 output features를 Graph Convolution의 input features와 concatenation하여

정보를 확장하도록 설계합니다.

Concatentation에 따른 dimension의 변화를 잘 고려하여 Graph Convolution의 in_dim과 out_dim을 적절히 설정하면 됩니다.

In [24]:
class GCN(nn.Module):
    def __init__(self, n_in, n_hid1, n_hid2, n_out):
        super(GCN, self).__init__()

        self.gc1 = GraphConvolution(n_in, n_hid1)
        self.batchnorm1 = BatchNorm(n_hid1)
        self.gc2 = GraphConvolution(n_in+n_hid1, n_hid2)
        self.batchnorm2 = BatchNorm(n_hid2)
        self.gc3 = GraphConvolution(n_in+n_hid1+n_hid2, n_out)
        self.batchnorm3 = BatchNorm(n_out)
    
    def forward(self, xs, adjs = None):
        if (adjs is None): # dataloader로 부터 xs와 adjs를 같이 받을 예정이지만 만약 adjs가 따로 주어지지 않았다면, xs에 포함되어 있다고 생각
            xs, adjs = xs 
        
        num_points = xs.shape[1]
        
        xs = torch.cat(tuple(xs), dim=0) # batch*(n,d) 로 변환, v1.8 이전의 pytorch에서는 batchwise sparse-multiplication이 제대로 지원되지 않아서 들어간 trick        xs = xs.to(device)               
        adjs = adjs.to(device)
        
        xs1 = torch.cat( (xs, F.relu(self.batchnorm1(self.gc1(xs, adjs)))), dim=1)
        del xs  #memory 절약하기 위해
        xs2 = torch.cat( (xs1, F.relu(self.batchnorm2(self.gc2(xs1, adjs)))), dim=1)
        del xs1 #memory 절약하기 위해
        xs3 = torch.cat( (xs2, F.relu(self.batchnorm3(self.gc3(xs2, adjs)))), dim=1)
        del xs2 #memory 절약하기 위해
        
        res = xs3
        ys = torch.stack(torch.split(res, num_points, dim=0)).to(device) # 다시 (bach,n,d)로 변환
        return ys

Input 또는 중간의 feature를 permutation과 rigid motion에 invariant하도록 만들기 위한 모듈입니다.
- nfeat, 64, 128, 1024로 mapping되는 MLP
- max pooling, batch normalization
- 다시 1024, 512, 256, nfeat*nfeat로 mapping되는 MLP로 구성됩니다.

최종적으로 (B, n_feat*n_feat)의 output을 (B, n_feat, n_feat)로 shape을 변경해 return해주면 됩니다.

TNet에 GCN 모듈을 추가해서 이웃한 포인트들과의 local information을 aggregation 할 수 있도록 설계합니다.

Maxpooling 이전의 전반부를 encoder라고하고, Maxpooling 이후의 후반부를 decoder라고 할 수 있습니다.

encoder와 decoder 사이에 GCN을 추가해서 embedded features끼리 local information을 공유할 수 있도록 합니다.

GCN 모듈은 128, 128, 256, 512로 dimension이 mapping 되도록 설정합니다.

In [25]:
class TNet(nn.Module):
    def __init__(self, nfeat, dropout = 0):
        super(TNet, self).__init__()
        self.nfeat = nfeat
        self.encoder = MLP((nfeat, 64, 128, 128))
        self.gcn = GCN(128, 128, 256, 512)
        self.decoder = nn.Sequential(MaxPooling(), BatchNorm(1024), 
                                     MLP((1024, 512, 256)), nn.Dropout(dropout), MLP((256, nfeat*nfeat)))
        
    def forward(self, x, adjs):
        batch_size = x.shape[0]
        x = self.decoder(self.gcn(self.encoder(x), adjs))
        return x.view(batch_size, self.nfeat, self.nfeat)

위의 모듈들을 모두 조합해 최종적으로 point cloud classification을 위한 pointNet을 구성해보겠습니다.
TNet의 output은 transform matrix이므로 TNet의 input과 matrix multiplication을 해야합니다.

- TNet : (B,N,3) > (B,N,3)
- Encoder : BatchNorm & MLP (B,N,3) > (B,N,3) > (B,N,64) > (B,N,64)
- TNet : (B,N,64) > (B,N,64)

- BatchNorm : (B,N,64) > (B,N,64)
- MLP : (B,N,64) > (B,N,64) > (B,N,128) > (B,N,128)
- GCN : (B,N,128) > (B,N,128) > (B,N,256) > (B,N,512*2)

- Maxpooling : (B,N,1024) > (B,1024)
- Decoder : BatchNorm & MLP (B,1024) > (B,1024) > (B,512) > (B,256)
- Dropout : (B,256) > (B,256)
- FC-layer (Linear or Conv1d) for k classes : (B,256) > (B,nclass)

In [26]:
class PointNetGCN(nn.Module):
    def __init__(self, nfeat, nclass, dropout = 0):
        super(PointNetGCN, self).__init__()

        self.input_transform = TNet(nfeat, 0.1)
        self.encoder = nn.Sequential(BatchNorm(3), MLP((nfeat, 64, 64)))
        self.feature_transform = TNet(64, 0.1)
        self.batchnorm = BatchNorm(64)
        self.mlp = MLP((64, 64, 128, 128)) # tuple형태로 input을 받음
        self.gcn = GCN(128, 128, 256, 512)
        self.maxpooling = MaxPooling()
        self.decoder = nn.Sequential(BatchNorm(1024), MLP((1024, 512, 256)), nn.Dropout(dropout), nn.Linear(256, nclass))

        self.eye64 = torch.eye(64).to(device)
        
    def forward(self, xs, adjs):
        batch_size = xs.shape[0]
        
        transform = self.input_transform(xs, adjs)
        xs = torch.stack([torch.mm(xs[i],transform[i]) for i in range(batch_size)])
        xs = self.encoder(xs)
        
        transform = self.feature_transform(xs, adjs)
        xs = torch.stack([torch.mm(xs[i],transform[i]) for i in range(batch_size)])
        
        xs = self.gcn(self.mlp(self.batchnorm(xs)), adjs)
        xs = self.decoder(self.maxpooling(xs))
        
        if (self.training):
            transform_transpose = transform.transpose(1, 2)
            tmp = torch.stack([torch.mm(transform[i], transform_transpose[i]) for i in range(batch_size)])
            L_reg = ((tmp - self.eye64) ** 2).sum() / batch_size
            
        return (F.log_softmax(xs, dim=1), L_reg) if self.training else F.log_softmax(xs, dim=1)


In [27]:

lr = 0.001
num_points = 128
save_name = "PointNet.pt"

########### loading data ###########
train_data = ModelNet40(num_points)
test_data = ModelNet40(num_points, 'test')

train_size = int(0.9 * len(train_data))
valid_size = len(train_data) - train_size
train_data, valid_data = Data.random_split(train_data, [train_size, valid_size])
valid_data.partition = 'valid'
train_data.partition = 'train'

print("train data size: ", len(train_data))
print("valid data size: ", len(valid_data))
print("test data size: ", len(test_data))

def collate_fn(batch):
    Xs = torch.stack([X for X, _, _ in batch])
    #adjs = [adj for _, adj, _ in batch]
    
    global num_points
    batch_size = len(batch)
    edges = torch.cat( tuple(batch[i][1][0] + i*num_points for i in range(batch_size)), dim=0)
    values = torch.cat( tuple(batch[i][1][1] for i in range(batch_size)), dim=0)
    N = num_points * batch_size
    adjs = torch.sparse.FloatTensor(edges.t(), values, torch.Size([N,N]))
    
    Ys = torch.tensor([Y for _,_, Y in batch], dtype = torch.long)
    return Xs, adjs, Ys

train_iter  = Data.DataLoader(train_data, shuffle = True, batch_size = 32, collate_fn = collate_fn)
valid_iter = Data.DataLoader(valid_data, batch_size = 32, collate_fn = collate_fn)
test_iter = Data.DataLoader(test_data, batch_size = 32, collate_fn = collate_fn)

train data size:  8856
valid data size:  984
test data size:  2468


In [28]:
############### loading model ####################

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = PointNetGCN(nfeat=3, nclass=40, dropout=0.3)
net.to(device)
print(net)

PointNetGCN(
  (input_transform): TNet(
    (encoder): MLP(
      (mlp): Sequential(
        (Linear_0): Linear(in_features=3, out_features=64, bias=True)
        (Batchnorm_0): BatchNorm(
          (batchnorm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (permute): Permute()
        )
        (ReLU_0): ReLU(inplace=True)
        (Linear_1): Linear(in_features=64, out_features=128, bias=True)
        (Batchnorm_1): BatchNorm(
          (batchnorm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (permute): Permute()
        )
        (ReLU_1): ReLU(inplace=True)
        (Linear_2): Linear(in_features=128, out_features=128, bias=True)
        (Batchnorm_2): BatchNorm(
          (batchnorm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (permute): Permute()
        )
        (ReLU_2): ReLU(inplace=True)
      )
    )
    (gcn): GCN(
      (gc1): GraphConvo

In [29]:
############### training #########################
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay = 0.0001)
loss = nn.NLLLoss()

def adjust_lr(optimizer, decay_rate=0.95):
    for param_group in optimizer.param_groups:
        param_group['lr'] *= decay_rate

retrain = True
if os.path.exists(save_name):
    print("Model parameters have already been trained before. Retrain (y) or tune (n) ?")
    ans = input()
    if not (ans == 'y'):
        checkpoint = torch.load(save_name, map_location = device)
        net.load_state_dict(checkpoint["net"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        for g in optimizer.param_groups:
            g['lr'] = lr
        
train_model(train_iter, valid_iter, net, loss, optimizer, device = device, max_epochs = 1000, adjust_lr = adjust_lr,
            early_stop = EarlyStop(patience = 20, save_name = save_name))
    

############### testing ##########################

loss, acc = evaluate_model(test_iter, net, loss)
print('test acc = %.6f' % (acc))

training on  cuda


100%|█| 277/277 [00:34<00:00, 


epoch 0, train loss 3.2410 (acc 0.155375), valid loss 3.0697 (acc 0.195122), time 36.1 sec


100%|█| 277/277 [00:34<00:00, 


epoch 1, train loss 3.0021 (acc 0.190831), valid loss 2.9773 (acc 0.195122), time 36.7 sec


100%|█| 277/277 [00:35<00:00, 


epoch 2, train loss 2.9260 (acc 0.200203), valid loss 2.8720 (acc 0.233740), time 37.8 sec


100%|█| 277/277 [00:36<00:00, 


epoch 3, train loss 2.7817 (acc 0.223916), valid loss 2.7212 (acc 0.270325), time 38.5 sec


100%|█| 277/277 [00:36<00:00, 


epoch 4, train loss 2.7029 (acc 0.246612), valid loss 2.6764 (acc 0.258130), time 38.7 sec


100%|█| 277/277 [00:36<00:00, 


epoch 5, train loss 2.6038 (acc 0.260163), valid loss 2.5283 (acc 0.290650), time 38.4 sec


100%|█| 277/277 [00:37<00:00, 


epoch 6, train loss 2.5312 (acc 0.281165), valid loss 2.4581 (acc 0.301829), time 39.4 sec


100%|█| 277/277 [00:37<00:00, 


epoch 7, train loss 2.4843 (acc 0.294490), valid loss 2.3645 (acc 0.355691), time 39.3 sec


100%|█| 277/277 [00:37<00:00, 


epoch 8, train loss 2.3834 (acc 0.324074), valid loss 2.2985 (acc 0.352642), time 39.1 sec


100%|█| 277/277 [00:37<00:00, 


epoch 9, train loss 2.3230 (acc 0.343270), valid loss 2.2093 (acc 0.385163), time 39.1 sec


100%|█| 277/277 [00:36<00:00, 


epoch 10, train loss 2.2795 (acc 0.356481), valid loss 2.1859 (acc 0.389228), time 38.9 sec


100%|█| 277/277 [00:37<00:00, 


epoch 11, train loss 2.2277 (acc 0.371387), valid loss 2.1956 (acc 0.397358), time 39.1 sec
EarlyStopping counter: 1 out of 20


100%|█| 277/277 [00:36<00:00, 


epoch 12, train loss 2.1370 (acc 0.390470), valid loss 1.9920 (acc 0.431911), time 38.9 sec


100%|█| 277/277 [00:37<00:00, 


epoch 13, train loss 2.2083 (acc 0.371725), valid loss 2.0918 (acc 0.422764), time 39.3 sec
EarlyStopping counter: 1 out of 20


100%|█| 277/277 [00:37<00:00, 


epoch 14, train loss 2.1102 (acc 0.398939), valid loss 2.0530 (acc 0.430894), time 39.3 sec
EarlyStopping counter: 2 out of 20


100%|█| 277/277 [00:36<00:00, 


epoch 15, train loss 2.0280 (acc 0.417118), valid loss 1.9303 (acc 0.477642), time 38.9 sec


100%|█| 277/277 [00:37<00:00, 


epoch 16, train loss 1.9340 (acc 0.443993), valid loss 1.8293 (acc 0.487805), time 39.8 sec


100%|█| 277/277 [00:38<00:00, 


epoch 17, train loss 1.8927 (acc 0.448058), valid loss 1.8079 (acc 0.512195), time 41.1 sec


100%|█| 277/277 [00:38<00:00, 


epoch 18, train loss 1.8156 (acc 0.475045), valid loss 1.6744 (acc 0.526423), time 40.5 sec


100%|█| 277/277 [00:38<00:00, 


epoch 19, train loss 1.7717 (acc 0.489612), valid loss 1.8843 (acc 0.493902), time 40.6 sec
EarlyStopping counter: 1 out of 20


100%|█| 277/277 [00:37<00:00, 


epoch 20, train loss 1.7680 (acc 0.484192), valid loss 1.7411 (acc 0.561992), time 39.5 sec
EarlyStopping counter: 2 out of 20


100%|█| 277/277 [00:37<00:00, 


epoch 21, train loss 1.6510 (acc 0.517728), valid loss 1.5198 (acc 0.568089), time 39.1 sec


100%|█| 277/277 [00:37<00:00, 


epoch 22, train loss 1.6228 (acc 0.524616), valid loss 1.5441 (acc 0.552846), time 39.1 sec
EarlyStopping counter: 1 out of 20


100%|█| 277/277 [00:37<00:00, 


epoch 23, train loss 1.5777 (acc 0.541667), valid loss 1.4884 (acc 0.588415), time 39.0 sec


100%|█| 277/277 [00:36<00:00, 


epoch 24, train loss 1.5021 (acc 0.560863), valid loss 1.5107 (acc 0.602642), time 38.1 sec
EarlyStopping counter: 1 out of 20


100%|█| 277/277 [00:35<00:00, 


epoch 25, train loss 1.4310 (acc 0.578591), valid loss 1.3779 (acc 0.630081), time 38.0 sec


100%|█| 277/277 [00:35<00:00, 


epoch 26, train loss 1.4086 (acc 0.579607), valid loss 1.4194 (acc 0.634146), time 37.6 sec
EarlyStopping counter: 1 out of 20


100%|█| 277/277 [00:35<00:00, 


epoch 27, train loss 1.3574 (acc 0.596996), valid loss 1.2504 (acc 0.635163), time 37.9 sec


100%|█| 277/277 [00:39<00:00, 


epoch 28, train loss 1.4631 (acc 0.571251), valid loss 1.6199 (acc 0.565041), time 41.8 sec
EarlyStopping counter: 1 out of 20


100%|█| 277/277 [00:39<00:00, 


epoch 29, train loss 1.4779 (acc 0.563234), valid loss 1.3079 (acc 0.628049), time 42.5 sec
EarlyStopping counter: 2 out of 20


100%|█| 277/277 [00:38<00:00, 


epoch 30, train loss 1.3891 (acc 0.589092), valid loss 1.2782 (acc 0.621951), time 40.9 sec
EarlyStopping counter: 3 out of 20


100%|█| 277/277 [00:37<00:00, 


epoch 31, train loss 1.3502 (acc 0.597561), valid loss 1.4205 (acc 0.635163), time 39.7 sec
EarlyStopping counter: 4 out of 20


100%|█| 277/277 [00:39<00:00, 


epoch 32, train loss 1.2868 (acc 0.615402), valid loss 1.5648 (acc 0.653455), time 41.4 sec
EarlyStopping counter: 5 out of 20


100%|█| 277/277 [00:38<00:00, 


epoch 33, train loss 1.2399 (acc 0.635501), valid loss 1.1023 (acc 0.659553), time 40.2 sec


100%|█| 277/277 [00:39<00:00, 


epoch 34, train loss 1.2119 (acc 0.639341), valid loss 1.1314 (acc 0.679878), time 41.6 sec
EarlyStopping counter: 1 out of 20


100%|█| 277/277 [00:36<00:00, 


epoch 35, train loss 1.1833 (acc 0.647922), valid loss 1.1347 (acc 0.689024), time 38.9 sec
EarlyStopping counter: 2 out of 20


100%|█| 277/277 [00:35<00:00, 


epoch 36, train loss 1.1599 (acc 0.651987), valid loss 1.0299 (acc 0.697154), time 37.7 sec


100%|█| 277/277 [00:37<00:00, 


epoch 37, train loss 1.1358 (acc 0.653907), valid loss 1.1188 (acc 0.701220), time 39.3 sec
EarlyStopping counter: 1 out of 20


100%|█| 277/277 [00:37<00:00, 


epoch 38, train loss 1.1378 (acc 0.657859), valid loss 1.0660 (acc 0.718496), time 39.8 sec
EarlyStopping counter: 2 out of 20


100%|█| 277/277 [00:37<00:00, 


epoch 39, train loss 1.1077 (acc 0.663731), valid loss 1.2017 (acc 0.701220), time 39.5 sec
EarlyStopping counter: 3 out of 20


100%|█| 277/277 [00:37<00:00, 


epoch 40, train loss 1.0961 (acc 0.665876), valid loss 0.9678 (acc 0.723577), time 39.7 sec


100%|█| 277/277 [00:38<00:00, 


epoch 41, train loss 1.0885 (acc 0.667231), valid loss 0.9593 (acc 0.723577), time 40.1 sec


100%|█| 277/277 [00:36<00:00, 


epoch 42, train loss 1.0577 (acc 0.675474), valid loss 1.2343 (acc 0.711382), time 38.7 sec
EarlyStopping counter: 1 out of 20


100%|█| 277/277 [00:35<00:00, 


epoch 43, train loss 1.0310 (acc 0.684395), valid loss 0.9714 (acc 0.731707), time 37.7 sec
EarlyStopping counter: 2 out of 20


100%|█| 277/277 [00:35<00:00, 


epoch 44, train loss 1.0117 (acc 0.689250), valid loss 0.9833 (acc 0.754065), time 37.7 sec
EarlyStopping counter: 3 out of 20


100%|█| 277/277 [00:35<00:00, 


epoch 45, train loss 0.9984 (acc 0.692977), valid loss 1.0014 (acc 0.730691), time 37.9 sec
EarlyStopping counter: 4 out of 20


100%|█| 277/277 [00:36<00:00, 


epoch 46, train loss 0.9791 (acc 0.705849), valid loss 0.9264 (acc 0.751016), time 38.0 sec


100%|█| 277/277 [00:35<00:00, 


epoch 47, train loss 0.9687 (acc 0.700429), valid loss 1.5428 (acc 0.744919), time 37.8 sec
EarlyStopping counter: 1 out of 20


100%|█| 277/277 [00:36<00:00, 


epoch 48, train loss 0.9411 (acc 0.712737), valid loss 1.1077 (acc 0.737805), time 38.2 sec
EarlyStopping counter: 2 out of 20


100%|█| 277/277 [00:34<00:00, 


epoch 49, train loss 0.9334 (acc 0.707430), valid loss 1.0888 (acc 0.768293), time 36.7 sec
EarlyStopping counter: 3 out of 20


100%|█| 277/277 [00:34<00:00, 


epoch 50, train loss 0.9348 (acc 0.710479), valid loss 0.8285 (acc 0.768293), time 36.4 sec


100%|█| 277/277 [00:34<00:00, 


epoch 51, train loss 0.9350 (acc 0.715108), valid loss 1.1449 (acc 0.760163), time 36.6 sec
EarlyStopping counter: 1 out of 20


100%|█| 277/277 [00:34<00:00, 


epoch 52, train loss 0.9260 (acc 0.714770), valid loss 0.9079 (acc 0.756098), time 36.6 sec
EarlyStopping counter: 2 out of 20


100%|█| 277/277 [00:34<00:00, 


epoch 53, train loss 0.9076 (acc 0.718722), valid loss 0.8786 (acc 0.767276), time 36.8 sec
EarlyStopping counter: 3 out of 20


100%|█| 277/277 [00:35<00:00, 


epoch 54, train loss 0.8846 (acc 0.726513), valid loss 0.8562 (acc 0.789634), time 37.3 sec
EarlyStopping counter: 4 out of 20


100%|█| 277/277 [00:35<00:00, 


epoch 55, train loss 0.8807 (acc 0.731594), valid loss 0.7615 (acc 0.781504), time 37.6 sec


100%|█| 277/277 [00:35<00:00, 


epoch 56, train loss 0.8667 (acc 0.727981), valid loss 0.8809 (acc 0.785569), time 37.6 sec
EarlyStopping counter: 1 out of 20


100%|█| 277/277 [00:36<00:00, 


epoch 57, train loss 0.8559 (acc 0.737805), valid loss 0.7822 (acc 0.781504), time 38.1 sec
EarlyStopping counter: 2 out of 20


100%|█| 277/277 [00:35<00:00, 


epoch 58, train loss 0.8594 (acc 0.736111), valid loss 0.7884 (acc 0.781504), time 37.7 sec
EarlyStopping counter: 3 out of 20


100%|█| 277/277 [00:34<00:00, 


epoch 59, train loss 0.8464 (acc 0.730691), valid loss 1.1100 (acc 0.782520), time 36.4 sec
EarlyStopping counter: 4 out of 20


100%|█| 277/277 [00:34<00:00, 


epoch 60, train loss 0.8464 (acc 0.740063), valid loss 0.9403 (acc 0.771341), time 36.4 sec
EarlyStopping counter: 5 out of 20


100%|█| 277/277 [00:35<00:00, 


epoch 61, train loss 0.8211 (acc 0.748419), valid loss 1.1093 (acc 0.774390), time 37.8 sec
EarlyStopping counter: 6 out of 20


100%|█| 277/277 [00:36<00:00, 


epoch 62, train loss 0.8314 (acc 0.743451), valid loss 0.9534 (acc 0.764228), time 38.3 sec
EarlyStopping counter: 7 out of 20


100%|█| 277/277 [00:36<00:00, 


epoch 63, train loss 0.8255 (acc 0.743338), valid loss 0.8625 (acc 0.780488), time 38.1 sec
EarlyStopping counter: 8 out of 20


100%|█| 277/277 [00:36<00:00, 


epoch 64, train loss 0.8125 (acc 0.746048), valid loss 0.9995 (acc 0.776423), time 38.1 sec
EarlyStopping counter: 9 out of 20


100%|█| 277/277 [00:35<00:00, 


epoch 65, train loss 0.8162 (acc 0.745483), valid loss 0.7698 (acc 0.778455), time 37.9 sec
EarlyStopping counter: 10 out of 20


100%|█| 277/277 [00:35<00:00, 


epoch 66, train loss 0.7913 (acc 0.750113), valid loss 0.8880 (acc 0.780488), time 37.7 sec
EarlyStopping counter: 11 out of 20


100%|█| 277/277 [00:34<00:00, 


epoch 67, train loss 0.7975 (acc 0.752597), valid loss 0.9810 (acc 0.790650), time 36.4 sec
EarlyStopping counter: 12 out of 20


100%|█| 277/277 [00:34<00:00, 


epoch 68, train loss 0.7733 (acc 0.760953), valid loss 0.8313 (acc 0.795732), time 36.3 sec
EarlyStopping counter: 13 out of 20


100%|█| 277/277 [00:34<00:00, 


epoch 69, train loss 0.7591 (acc 0.768631), valid loss 0.9235 (acc 0.778455), time 36.3 sec
EarlyStopping counter: 14 out of 20


100%|█| 277/277 [00:34<00:00, 


epoch 70, train loss 0.7655 (acc 0.761969), valid loss 0.7947 (acc 0.786585), time 36.1 sec
EarlyStopping counter: 15 out of 20


100%|█| 277/277 [00:34<00:00, 


epoch 71, train loss 0.7737 (acc 0.754855), valid loss 0.9962 (acc 0.787602), time 36.3 sec
EarlyStopping counter: 16 out of 20


100%|█| 277/277 [00:34<00:00, 


epoch 72, train loss 0.7584 (acc 0.765921), valid loss 1.0623 (acc 0.785569), time 36.1 sec
EarlyStopping counter: 17 out of 20


100%|█| 277/277 [00:35<00:00, 


epoch 73, train loss 0.7500 (acc 0.764905), valid loss 0.9480 (acc 0.771341), time 36.9 sec
EarlyStopping counter: 18 out of 20


100%|█| 277/277 [00:36<00:00, 


epoch 74, train loss 0.7490 (acc 0.763889), valid loss 0.8062 (acc 0.799797), time 38.1 sec
EarlyStopping counter: 19 out of 20


100%|█| 277/277 [00:34<00:00, 


epoch 75, train loss 0.7380 (acc 0.765808), valid loss 0.8180 (acc 0.785569), time 36.5 sec
EarlyStopping counter: 20 out of 20
test acc = 0.792950


In [33]:
def build_knn(self, input, k):
    batch_size, num_points, _ = input.size()

    inner = torch.matmul(input, input.transpose(1,2)) # (B, N, D) @ (B, D, N) > (B, N, N)
    xx = torch.sum(input**2, dim=2, keepdim=True) # (B, N, D) > (B, N, 1)
    pairwise_distance = -1 * (xx + xx.transpose(2,1) - 2*inner) # (B, N, 1) + (B, 1, N) -2*(B, N, N) > (B, N, N) , topk_indices는 제일 큰값 k개 return, 우리가 필요한 것은 제일 작은값 k개 이므로 -1

    _, topk_indices = pairwise_distance.topk(k=k, dim=-1) # (B, N, K)
    base_indices = torch.arange(num_points)[None, :, None].repeat(batch_size, 1, k).to(topk_indices.device) #(B, N, K)
    indices = torch.stack([base_indices.view(batch_size, -1), topk_indices.view(batch_size, -1)], dim=1) # (B, 2, N*K)

    adj = torch.stack([torch.sparse.FloatTensor(indices[b], (1/k) * torch.ones_like(indices[b][0], dtype=torch.float), torch.Size([num_points, num_points])) for b in range(batch_size)])

    return adj