In [None]:
import numpy as np
from tqdm import tqdm
import os
import h5py
import sklearn.metrics as metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.init as init
import random



In [None]:

class TransformNet(nn.Module):
    """Predict one 3x3 matrix per sample from a stack of edge features.

    The input is expected to carry 6 channels (the first convolution is
    ``Conv2d(6, 64)``). The final linear layer starts with zero weights and an
    identity bias, so a freshly constructed network returns the identity
    matrix for every sample.
    """

    def __init__(self):
        super().__init__()
        self.k = 3
        slope = 0.2

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        # NOTE: `bn3` is bound twice in this constructor. This first, 1024-wide
        # norm stays reachable only through `conv3` after the rebind below.
        self.bn3 = nn.BatchNorm1d(1024)

        self.conv1 = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=1, bias=False),
            self.bn1,
            nn.LeakyReLU(negative_slope=slope),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=1, bias=False),
            self.bn2,
            nn.LeakyReLU(negative_slope=slope),
        )
        self.conv3 = nn.Sequential(
            nn.Conv1d(128, 1024, kernel_size=1, bias=False),
            self.bn3,
            nn.LeakyReLU(negative_slope=slope),
        )

        self.linear1 = nn.Linear(1024, 512, bias=False)
        # Rebind: from here on `self.bn3` is the 512-wide norm that follows
        # `linear1` in `forward`.
        self.bn3 = nn.BatchNorm1d(512)
        self.linear2 = nn.Linear(512, 256, bias=False)
        self.bn4 = nn.BatchNorm1d(256)

        # Zero weights + identity bias => the initial prediction is the identity.
        self.transform = nn.Linear(256, 3 * 3)
        init.constant_(self.transform.weight, 0)
        init.eye_(self.transform.bias.view(3, 3))

    def forward(self, x):
        """Map edge features (batch, 6, num_points, k) to matrices (batch, 3, 3)."""
        n_batch = x.size(0)

        feat = self.conv2(self.conv1(x))   # (B, 6, N, k) -> (B, 64, N, k) -> (B, 128, N, k)
        feat = feat.max(dim=-1)[0]         # max over neighbours -> (B, 128, N)
        feat = self.conv3(feat)            # (B, 128, N) -> (B, 1024, N)
        feat = feat.max(dim=-1)[0]         # max over points -> (B, 1024)

        # Two fully connected stages: 1024 -> 512 -> 256.
        for fc, norm in ((self.linear1, self.bn3), (self.linear2, self.bn4)):
            feat = F.leaky_relu(norm(fc(feat)), negative_slope=0.2)

        flat = self.transform(feat)        # (B, 256) -> (B, 9)
        return flat.view(n_batch, 3, 3)


def knn(x, k):
    """Return the indices of the k nearest points for every point.

    Args:
        x: Tensor indexed as (batch_size, num_dims, num_points).
        k: Number of neighbours to keep per point.

    Returns:
        Integer tensor of shape (batch_size, num_points, k). Each point is
        compared against all points including itself.
    """
    # -||a - b||^2 = -||a||^2 + 2 a.b - ||b||^2, built from the Gram matrix.
    gram = torch.matmul(x.transpose(2, 1), x)
    cross_term = -2 * gram
    sq_norm = (x ** 2).sum(dim=1, keepdim=True)
    neg_sq_dist = -sq_norm - cross_term - sq_norm.transpose(2, 1)

    # Largest negative distance == smallest distance.
    _, nearest = neg_sq_dist.topk(k=k, dim=-1)   # (batch_size, num_points, k)
    return nearest


def get_graph_feature(x, k=20, idx=None):
    """Build edge features from each point and its neighbours.

    Args:
        x: Tensor indexed as (batch_size, num_dims, num_points).
        k: Number of neighbours to look up when ``idx`` is not supplied.
        idx: Optional precomputed neighbour indices of shape
            (batch_size, num_points, k'). When given, the neighbour count is
            taken from ``idx`` itself and the ``k`` argument is ignored, so
            callers no longer have to keep the two in sync.

    Returns:
        Contiguous tensor of shape (batch_size, 2 * num_dims, num_points, k).
        The first ``num_dims`` channels hold ``neighbour - centre`` and the
        remaining channels hold the centre point repeated for each neighbour.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    # reshape (rather than view) also accepts inputs whose strides do not
    # allow a zero-copy view.
    x = x.reshape(batch_size, -1, num_points)

    if idx is None:
        idx = knn(x, k=k)   # (batch_size, num_points, k)
    else:
        k = idx.size(-1)

    # Shift per-sample indices so they address rows of the flattened
    # (batch_size * num_points, num_dims) point table.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    flat_idx = (idx + idx_base).reshape(-1)

    num_dims = x.size(1)
    points = x.transpose(2, 1).contiguous()   # (batch_size, num_points, num_dims)

    neighbours = points.view(batch_size * num_points, -1)[flat_idx, :]
    neighbours = neighbours.view(batch_size, num_points, k, num_dims)

    # expand() broadcasts the centre point without materialising k copies;
    # torch.cat below produces the actual output storage.
    centres = points.view(batch_size, num_points, 1, num_dims).expand(-1, -1, k, -1)

    feature = torch.cat((neighbours - centres, centres), dim=3)
    return feature.permute(0, 3, 1, 2).contiguous()



class DGCNN(nn.Module):
    """DGCNN classifier preceded by a learned soft selection over poses.

    ``forward`` takes a tensor indexed as (batch, pose, point, xyz). A small
    selector network produces softmax weights over the pose axis, the poses
    are blended into a single point cloud with those weights, the blended
    cloud is aligned by ``TransformNet`` and then classified by four graph
    convolution stages followed by a fully connected head.

    The selector layers are built for 24 poses (``linearsel1`` has 24 input
    channels), so the pose axis of the input must have 24 entries. The number
    of points is read from the input.

    Args:
        k: Number of nearest neighbours used for every graph feature.
        output_channels: Number of class logits returned by ``forward``.
    """

    def __init__(self, k=25, output_channels=40):
        super().__init__()
        self.k = k

        # Norm layers that get wired into conv1..conv5 below.
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(1024)
        self.transform_net = TransformNet()

        # Pose selector: scores each of the 24 candidate poses.
        self.linearsel1 = nn.Conv1d(24, 512, kernel_size=1)
        self.bnsel1 = nn.BatchNorm1d(512)
        self.linearsel2 = nn.Conv1d(512, 256, kernel_size=1)
        self.bnsel2 = nn.BatchNorm1d(256)
        self.linearsel3 = nn.Linear(256, 24)

        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), self.bn1, nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False), self.bn2, nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 128, kernel_size=1, bias=False), self.bn3, nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(128*2, 256, kernel_size=1, bias=False), self.bn4, nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False), self.bn5, nn.LeakyReLU(negative_slope=0.2))

        # NOTE(review): these rebinds replace the attributes above with fresh
        # BatchNorm1d modules that `forward` never calls; the norms actually
        # used stay inside conv1..conv5. They are kept so the module's
        # state_dict keeps the same keys and shapes as before.
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(256)
        self.bn5 = nn.BatchNorm1d(1024)

        # Classification head.
        self.linear1 = nn.Linear(1024*2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout()
        self.linear2 = nn.Linear(512, 512)
        self.bn7 = nn.BatchNorm1d(512)
        self.dp2 = nn.Dropout()
        self.linear3 = nn.Linear(512, output_channels)

    def forward(self, x):
        """Classify a batch of multi-pose point clouds.

        Args:
            x: Tensor indexed as (batch_size, num_poses, num_points, 3).

        Returns:
            Logits of shape (batch_size, output_channels).
        """
        bs = x.size(0)
        # Read the sizes from the input instead of assuming 24 x 1024.
        num_poses, num_points = x.size(1), x.size(2)

        # pose selector
        # (B, P, N, 3) -> (B, N, 3, P) -> (B, N*3, P)
        xf = x.permute(0, 2, 3, 1).reshape(bs, num_points * 3, num_poses)
        s = F.leaky_relu(self.bnsel1(self.linearsel1(xf.transpose(2, 1))), negative_slope=0.2)
        s = F.leaky_relu(self.bnsel2(self.linearsel2(s)), negative_slope=0.2)
        s = F.adaptive_max_pool1d(s, 1).view(bs, -1)
        s = F.softmax(self.linearsel3(s), dim=1).unsqueeze(-1)   # (B, P, 1) pose weights
        # Weighted blend of the poses -> (B, 3, N)
        x_int = torch.bmm(xf, s).view(bs, num_points, 3).permute(0, 2, 1)

        x0 = get_graph_feature(x_int, k=self.k)  # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        t = self.transform_net(x0)  # (batch_size, 3, 3)
        x = x_int.transpose(2, 1)  # (batch_size, 3, num_points) -> (batch_size, num_points, 3)
        x = torch.bmm(x, t)  # (batch_size, num_points, 3) * (batch_size, 3, 3) -> (batch_size, num_points, 3)
        x = x.transpose(2, 1)

        # Four graph stages; each rebuilds the neighbour graph from the
        # previous stage's features and max-pools over the neighbours.
        x = get_graph_feature(x, k=self.k)
        x = self.conv1(x)
        x1 = x.max(dim=-1, keepdim=False)[0]
        x = get_graph_feature(x1, k=self.k)
        x = self.conv2(x)
        x2 = x.max(dim=-1, keepdim=False)[0]
        x = get_graph_feature(x2, k=self.k)
        x = self.conv3(x)
        x3 = x.max(dim=-1, keepdim=False)[0]
        x = get_graph_feature(x3, k=self.k)
        x = self.conv4(x)
        x4 = x.max(dim=-1, keepdim=False)[0]

        # 64 + 64 + 128 + 256 = 512 channels into conv5.
        x = torch.cat((x1, x2, x3, x4), dim=1)
        x = self.conv5(x)

        # Global max and average pooling, concatenated -> (B, 2048).
        pooled_max = F.adaptive_max_pool1d(x, 1).view(bs, -1)
        pooled_avg = F.adaptive_avg_pool1d(x, 1).view(bs, -1)
        x = torch.cat((pooled_max, pooled_avg), 1)

        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2)
        x = self.dp1(x)
        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2)
        x = self.dp2(x)
        x = self.linear3(x)
        return x
    
    

In [None]:



"""
dataloader 
"""

def generate_24_rotations():
    """Return 24 distinct 3x3 integer signed-permutation matrices.

    Each matrix is a column permutation of the identity with some columns
    negated. Cyclic permutations are paired with an even number of negated
    columns and the remaining permutations with an odd number, so every
    matrix has determinant +1.

    Returns:
        List of 24 integer arrays of shape (3, 3): first the 12 built from
        the cyclic permutations, then the 12 built from the others.
    """
    families = (
        # (column permutations, per-column sign patterns)
        (
            ([0, 1, 2], [1, 2, 0], [2, 0, 1]),
            ((1, 1, 1), (-1, -1, 1), (-1, 1, -1), (1, -1, -1)),
        ),
        (
            ([0, 2, 1], [1, 0, 2], [2, 1, 0]),
            ((-1, -1, -1), (-1, 1, 1), (1, -1, 1), (1, 1, -1)),
        ),
    )

    rotations = []
    for permutations, sign_patterns in families:
        for perm in permutations:
            base = np.identity(3)[:, perm].astype(int)
            for signs in sign_patterns:
                # Broadcasting over rows scales column j by signs[j].
                rotations.append(base * np.asarray(signs))
    return rotations



class H5Loader(Dataset):
    """Dataset that serves each HDF5 sample as a stack of 24 rotated copies.

    Every sample file must provide a ``data`` dataset, whose first 1024 rows
    are used as xyz points, and a ``label`` dataset.

    Args:
        list_dir: Directory that the names in ``file_list`` are relative to.
        file_list: Text file with one sample file name per line. Blank lines
            are skipped.
        partition: ``'train'`` picks a random row of ``vcand`` as the rotation
            order for each item; any other value uses the fixed order 0..23.
        vcand: Indexable table whose rows 0..23 each hold 24 rotation ids
            (indices into the list from ``generate_24_rotations``).
    """

    def __init__(self, list_dir, file_list, partition, vcand):
        self.data_list = []
        self.partition = partition
        self.all_R = generate_24_rotations()
        self.vcand = vcand
        # Close the list file deterministically instead of leaking the handle.
        with open(file_list) as list_file:
            for line in list_file:
                file_name = line.rstrip()
                if not file_name:
                    # A blank line would otherwise become an entry that points
                    # at the directory itself.
                    continue
                self.data_list.append(os.path.join(list_dir, file_name))

    def __getitem__(self, ind):
        """Return ``(data, label)`` tensors; ``data`` has shape (24, 1024, 3)."""
        # The context manager closes the file even if a read raises.
        with h5py.File(self.data_list[ind], 'r', swmr=True) as h5_file:
            points = h5_file['data'][:]
            label = h5_file['label'][:]

        pt = points[:1024, :].reshape(1, 1024, 3)

        nums = np.arange(24)
        # if training data, select pattern among 24 possible patterns
        if self.partition == 'train':
            nums = self.vcand[random.randint(0, 23)]

        pclist = [pt @ self.all_R[nums[i]] for i in range(24)]
        data = np.concatenate(pclist)
        return torch.from_numpy(data), torch.from_numpy(label)

    def __len__(self):
        return len(self.data_list)

    

In [None]:


"""
training function
"""

def cal_loss(pred, label, smoothing_eps=0.2):
    """Cross-entropy against label-smoothed targets, averaged over the batch.

    Args:
        pred: Logits of shape (batch_size, n_class).
        label: Integer class ids with ``batch_size`` elements in any shape.
        smoothing_eps: Probability mass moved from the true class and spread
            evenly over the other ``n_class - 1`` classes.

    Returns:
        Scalar loss tensor.
    """
    target = label.contiguous().view(-1)
    n_class = pred.size(1)

    # 1 at the true class, 0 elsewhere.
    hard = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
    # True class keeps 1 - eps; every other class gets eps / (n_class - 1).
    soft = hard * (1 - smoothing_eps) + (1 - hard) * smoothing_eps / (n_class - 1)

    log_prob = F.log_softmax(pred, dim=1)
    per_sample = (soft * log_prob).sum(dim=1)
    return -per_sample.mean()



def train(data_dir, log_dir, device, n_epoch=500, lr=1e-1, batch_size=16, n_class=40, smoothing=True):
    """Train a DGCNN and periodically evaluate it with 24-view voting.

    Side effects: reads ``./all_id.txt``, truncates and then appends to the
    file at ``log_dir``, and writes the best-scoring weights to
    ``./modelnet40_checkpoint.t7``.

    Args:
        data_dir: Root directory containing ``train/train_list.txt`` and
            ``test/test_list.txt`` plus the sample files they name.
        log_dir: Path of the text log file (a file path, despite the name).
        device: torch device the model and batches are moved to.
        n_epoch: Number of training epochs; also the cosine schedule length.
        lr: Initial SGD learning rate.
        batch_size: Batch size for both loaders.
        n_class: Number of output classes.
        smoothing: When true the loss uses ``cal_loss``'s default label
            smoothing; when false smoothing is turned off.
    """
    # dataloader
    # `np.int` was removed from NumPy; the builtin `int` is what it aliased.
    vcand = np.loadtxt('./all_id.txt').astype(int)
    train_set = H5Loader(os.path.join(data_dir, 'train'), os.path.join(data_dir, 'train/train_list.txt'), 'train', vcand)
    test_set = H5Loader(os.path.join(data_dir, 'test'), os.path.join(data_dir, 'test/test_list.txt'), 'test', vcand)
    train_generator = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
    test_generator = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=False)
    train_size, valid_size = len(train_set), len(test_set)
    # NOTE: the values printed here are sample counts, not batch counts.
    print('training data batches: {}, validation data batches: {}'.format(train_size, valid_size))

    # initialization
    model = DGCNN(output_channels=n_class).to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epoch, eta_min=1e-6)
    max_test_acc = 0.0
    # Start every run with an empty log file.
    with open(log_dir, 'w'):
        pass
    # Loop-invariant: the view-order table used for test-time voting.
    view_ids = torch.from_numpy(vcand).to(device).long()
    print('init done')

    # training and validation
    for epoch in range(n_epoch):
        # training phase
        loss_val = 0.0
        correct_train = 0.0
        seen_train = 0
        n_batches = 0
        model.train()
        for inputs, labels in tqdm(train_generator):
            inputs, labels = inputs.to(device).float(), labels.to(device).long().flatten()
            optimizer.zero_grad()
            pred = model(inputs)
            if smoothing:
                loss = cal_loss(pred, labels)
            else:
                loss = cal_loss(pred, labels, smoothing_eps=0.0)
            correct_train += (torch.max(pred.detach(), dim=1)[1] == labels).sum().item()
            seen_train += labels.size(0)
            loss.backward()
            optimizer.step()
            loss_val += loss.item()
            n_batches += 1
        scheduler.step()
        # drop_last=True discards the final partial batch, so divide by the
        # samples actually seen rather than by the dataset size.
        curr_train_acc = correct_train / max(seen_train, 1)
        avg_loss = loss_val / max(n_batches, 1)
        print('Epoch {}: avg_loss is {}, curr_train_acc is {:.5%};'.format(epoch, avg_loss, curr_train_acc))

        # testing phase
        if epoch >= 1000 or epoch % 100 == 0:
            model.eval()
            test_pred, test_true = [], []
            with torch.no_grad():
                for inputs, labels in tqdm(test_generator):
                    # flatten() keeps a 1-D label array even for a final batch
                    # of one sample; squeeze() would yield a 0-d array that
                    # np.concatenate rejects.
                    inputs, labels = inputs.to(device).float(), labels.to(device).flatten()
                    tmp_pred = []
                    for vw in range(24):
                        view_logits = model(inputs[:, view_ids[vw]])
                        tmp_pred.append(view_logits.detach().cpu().numpy().reshape(1, -1, n_class))
                    # Sum the logits over the 24 view orders, then vote.
                    pred = np.sum(np.concatenate(tmp_pred), axis=0)
                    pred = np.argmax(pred, axis=1)
                    test_pred.append(pred)
                    test_true.append(labels.detach().cpu().numpy())

                curr_test_acc = metrics.accuracy_score(np.concatenate(test_true), np.concatenate(test_pred))
                if max_test_acc < curr_test_acc:
                    max_test_acc = curr_test_acc
                    torch.save(model.state_dict(), './modelnet40_checkpoint.t7')
                print('Epoch {}: curr_test_acc is {:.2%};  max_test_acc is {:.2%}'.format(epoch, curr_test_acc, max_test_acc))

            # logger to txt (the logged loss is the epoch total, as before)
            with open(log_dir, 'a') as f:
                f.write('Epoch {}: loss {}, curr_acc_train {:.2%}, curr_acc {:.2%},  max_acc is {:.2%}\n'.format(epoch, loss_val, curr_train_acc, curr_test_acc, max_test_acc))

    print('training finished')



In [None]:

if __name__ == '__main__':

    #directories need to be changed
    data_dir = './dataset/modelnet40/pca/'
    log_dir = 'modelnet40_log.txt'

    # Prefer the originally configured GPU, but fall back to whatever is
    # present so the notebook still runs on machines with fewer devices.
    if torch.cuda.is_available() and torch.cuda.device_count() > 2:
        device = torch.device('cuda:2')
    elif torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    print('using device: {}'.format(device))

    # Seed every RNG in play: the data loader draws its rotation pattern from
    # the `random` module, not from torch.
    random.seed(123)
    np.random.seed(123)
    torch.manual_seed(123)
    torch.cuda.manual_seed(123)

    train(data_dir, log_dir, device=device, n_epoch=4000, lr=1e-1, batch_size=32, n_class=40, smoothing=True)

