In [1]:
import numpy as np
from tqdm import tqdm
import os
import h5py
import sklearn.metrics as metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.parallel
from torch.utils.data import Dataset, DataLoader
import torch.nn.init as init
import random


In [2]:

class TransformNet(nn.Module):
    """Small network that predicts one 3x3 matrix per sample from edge features.

    The final linear layer starts with zero weights and an identity bias, so a
    freshly constructed network outputs the identity matrix for every input.
    """

    def __init__(self):
        super().__init__()
        self.k = 3
        slope = 0.2

        # Norm layers that are exposed as attributes. ``bn3`` is the 512-wide
        # norm applied after ``linear1``; the 1024-wide norm used inside
        # ``conv3`` is only reachable as ``conv3[1]``.
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm1d(512)

        # Point-wise convolutions (kernel size 1) over the edge features.
        self.conv1 = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=1, bias=False),
            self.bn1,
            nn.LeakyReLU(negative_slope=slope),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=1, bias=False),
            self.bn2,
            nn.LeakyReLU(negative_slope=slope),
        )
        self.conv3 = nn.Sequential(
            nn.Conv1d(128, 1024, kernel_size=1, bias=False),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(negative_slope=slope),
        )

        # Fully connected head.
        self.linear1 = nn.Linear(1024, 512, bias=False)
        self.linear2 = nn.Linear(512, 256, bias=False)
        self.bn4 = nn.BatchNorm1d(256)
        self.transform = nn.Linear(256, self.k * self.k)

        # Identity initialisation: zero weights, bias laid out as eye(3).
        init.constant_(self.transform.weight, 0)
        init.eye_(self.transform.bias.view(3, 3))

    def forward(self, x):
        """Map edge features (bs, 6, n_points, k) to matrices (bs, 3, 3)."""
        batch = x.size(0)

        feat = self.conv2(self.conv1(x))
        feat = feat.max(dim=-1).values      # reduce over the last (neighbour) axis
        feat = self.conv3(feat)
        feat = feat.max(dim=-1).values      # reduce over the remaining point axis

        feat = F.leaky_relu(self.bn3(self.linear1(feat)), negative_slope=0.2)
        feat = F.leaky_relu(self.bn4(self.linear2(feat)), negative_slope=0.2)

        return self.transform(feat).view(batch, 3, 3)



def knn(x, k):
    """Return the indices of the k nearest points (self included) per point.

    x: tensor of shape (bs, n_dims, n_points).
    Returns an index tensor of shape (bs, n_points, k), ordered from nearest
    to farthest by squared Euclidean distance.
    """
    gram = torch.matmul(x.transpose(2, 1), x)               # (bs, n_points, n_points)
    sq_norm = (x ** 2).sum(dim=1, keepdim=True)             # (bs, 1, n_points)
    # -||xi - xj||^2 = 2 * <xi, xj> - ||xi||^2 - ||xj||^2; larger means closer.
    neg_sq_dist = 2 * gram - sq_norm - sq_norm.transpose(2, 1)
    return torch.topk(neg_sq_dist, k=k, dim=-1).indices     # (bs, n_points, k)


def get_graph_feature(x, k=20):
    """Build edge features from each point and its k nearest neighbours.

    x: tensor whose first axis is the batch and whose third axis is the point
       axis; it is viewed as (bs, n_dims, n_points).
    Returns a tensor of shape (bs, 2 * n_dims, n_points, k) whose first n_dims
    channels hold ``neighbour - centre`` and whose last n_dims channels hold
    the centre point repeated k times.
    """
    bs, n_pts = x.size(0), x.size(2)
    x = x.view(bs, -1, n_pts)
    n_dims = x.size(1)

    # Neighbour indices are per-sample; shift them so they address rows of the
    # flattened (bs * n_pts, n_dims) point table.
    neighbour_idx = knn(x, k=k)  # (bs, n_points, k)
    sample_offset = torch.arange(0, bs, device=x.device).view(-1, 1, 1) * n_pts
    flat_idx = (neighbour_idx + sample_offset).view(-1)

    points = x.transpose(2, 1).contiguous()  # (bs, n_points, n_dims)
    neighbours = points.view(bs * n_pts, -1)[flat_idx, :].view(bs, n_pts, k, n_dims)
    centres = points.view(bs, n_pts, 1, n_dims).repeat(1, 1, k, 1)

    edge = torch.cat((neighbours - centres, centres), dim=3)
    return edge.permute(0, 3, 1, 2).contiguous()  # (bs, 2*n_dims, n_points, k)


class DGCNN_partseg(nn.Module):
    """DGCNN-style part-segmentation network with a learned view selector.

    The forward pass first blends 24 candidate copies of the input point cloud
    into a single cloud using softmax weights predicted from the data, then
    runs three edge-convolution stages on that cloud and predicts per-point
    part logits conditioned on a 16-channel category vector.
    """

    def __init__(self, seg_num_all, k=20):
        """
        seg_num_all: number of output channels (part classes) of the final layer.
        k: number of nearest neighbours used by every graph-feature step.
        """
        super(DGCNN_partseg, self).__init__()
        self.seg_num_all = seg_num_all
        self.k = k
        self.transform_net = TransformNet()

        # Batch-norm layers; each one is also embedded in the matching conv block below.
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(64)
        self.bn5 = nn.BatchNorm2d(64)
        self.bn6 = nn.BatchNorm1d(1024)
        self.bn7 = nn.BatchNorm1d(64)
        self.bn8 = nn.BatchNorm1d(256)
        self.bn9 = nn.BatchNorm1d(256)
        self.bn10 = nn.BatchNorm1d(128)

        # conv1-conv5: edge convolutions (kernel size 1 over (num_points, k) grids).
        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                                   self.bn1, nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn2, nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False),
                                   self.bn3, nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn4, nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False),
                                   self.bn5, nn.LeakyReLU(negative_slope=0.2))
        # conv6: global embedding from the three concatenated 64-channel stage outputs (192 = 64 * 3).
        self.conv6 = nn.Sequential(nn.Conv1d(192, 1024, kernel_size=1, bias=False),
                                   self.bn6, nn.LeakyReLU(negative_slope=0.2))
        # conv7: embeds the 16-channel category vector into 64 channels.
        self.conv7 = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False),
                                   self.bn7, nn.LeakyReLU(negative_slope=0.2))
        # conv8-conv11: per-point head; 1280 = 1024 (global) + 64 (category) + 192 (local).
        self.conv8 = nn.Sequential(nn.Conv1d(1280, 256, kernel_size=1, bias=False),
                                   self.bn8, nn.LeakyReLU(negative_slope=0.2))
        self.dp1 = nn.Dropout(p=0.5)
        self.conv9 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False),
                                   self.bn9, nn.LeakyReLU(negative_slope=0.2))
        self.dp2 = nn.Dropout(p=0.5)
        self.conv10 = nn.Sequential(nn.Conv1d(256, 128, kernel_size=1, bias=False),
                                    self.bn10, nn.LeakyReLU(negative_slope=0.2))
        self.conv11 = nn.Conv1d(128, self.seg_num_all, kernel_size=1, bias=False)

        # Layers of the view selector: they score the 24 candidate clouds.
        self.linearfil1 = nn.Conv1d(24, 512, kernel_size=1)
        self.bnfil1 = nn.BatchNorm1d(512)
        self.linearfil2 = nn.Conv1d(512, 256, kernel_size=1)
        self.bnfil2 = nn.BatchNorm1d(256)
        self.linearfil3 = nn.Linear(256, 24)

    def forward(self, x, l):
        """
        x: tensor of shape (batch_size, 24, num_points, 3) holding 24 candidate
           copies of each point cloud.
        l: per-sample category vector with 16 values (reshaped to (batch_size, 16, 1)).
        Returns per-point logits of shape (batch_size, seg_num_all, num_points).
        """
        bs, n_points = x.size(0), x.size(2)

        # pose selector: softmax-weighted blend of the 24 candidate clouds
        xf = x.permute(0, 2, 3, 1).view(bs,n_points*3, 24)  # (batch_size, num_points*3, 24)
        c = F.leaky_relu(self.bnfil1(self.linearfil1(xf.transpose(2, 1))), negative_slope=0.2)
        c = F.leaky_relu(self.bnfil2(self.linearfil2(c)),negative_slope=0.2)
        c = F.adaptive_max_pool1d(c, 1).view(bs, -1)  # (batch_size, 256)
        c = F.softmax(self.linearfil3(c), dim=1).unsqueeze(-1)  # (batch_size, 24, 1) weights summing to 1
        x = torch.bmm(xf, c).view(bs,n_points, 3).permute(0, 2, 1).contiguous()  # (batch_size, 3, num_points)

        x0 = get_graph_feature(x, k=self.k)  # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        t = self.transform_net(x0)  # (batch_size, 3, 3)
        x = x.transpose(2, 1)  # (batch_size, 3, num_points) -> (batch_size, num_points, 3)
        x = torch.bmm(x, t)  # (batch_size, num_points, 3) * (batch_size, 3, 3) -> (batch_size, num_points, 3)
        x = x.transpose(2, 1)  # (batch_size, num_points, 3) -> (batch_size, 3, num_points)

        x = get_graph_feature(x, k=self.k)  # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        x = self.conv1(x)  # (batch_size, 3*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv2(x)  # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x1 = x.max(dim=-1, keepdim=False)[0]  # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x1, k=self.k)  # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv3(x)  # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv4(x)  # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x2 = x.max(dim=-1, keepdim=False)[0]  # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x2, k=self.k)  # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv5(x)  # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x3 = x.max(dim=-1, keepdim=False)[0]  # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = torch.cat((x1, x2, x3), dim=1)  # (batch_size, 64*3, num_points)

        x = self.conv6(x)  # (batch_size, 64*3, num_points) -> (batch_size, emb_dims, num_points)
        x = x.max(dim=-1, keepdim=True)[0]  # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims, 1)

        l = l.view(bs, -1, 1)  # (batch_size, num_categories, 1)
        l = self.conv7(l)  # (batch_size, num_categories, 1) -> (batch_size, 64, 1)

        x = torch.cat((x, l), dim=1)  # (batch_size, 1088, 1)
        x = x.repeat(1, 1, n_points)  # (batch_size, 1088, num_points)

        x = torch.cat((x, x1, x2, x3), dim=1)  # (batch_size, 1088+64*3, num_points)

        x = self.conv8(x)  # (batch_size, 1088+64*3, num_points) -> (batch_size, 256, num_points)
        x = self.dp1(x)
        x = self.conv9(x)  # (batch_size, 256, num_points) -> (batch_size, 256, num_points)
        x = self.dp2(x)
        x = self.conv10(x)  # (batch_size, 256, num_points) -> (batch_size, 128, num_points)
        x = self.conv11(x)  # (batch_size, 128, num_points) -> (batch_size, seg_num_all, num_points)

        return x


In [3]:


def load_data_partseg(data_dir, partition):
    """Load every HDF5 file listed for one partition.

    The list file ``<data_dir>/<partition>/<partition>_list.txt`` names one
    HDF5 file per line, relative to ``<data_dir>/<partition>/``. From each file
    the datasets ``data`` (cast to float32), ``label`` (int64) and ``pid``
    (int64) are read.

    Returns a tuple ``(data, label, seg)`` of NumPy arrays with one leading
    entry per listed file, in list order.
    """
    all_data, all_label, all_seg = [], [], []
    list_path = os.path.join(data_dir, '{}/{}_list.txt'.format(partition, partition))
    # Context managers guarantee the list file and every HDF5 handle are
    # closed even if a read fails part-way through.
    with open(list_path) as list_file:
        for line in list_file:
            file_name = line.strip()
            if not file_name:
                # Tolerate blank / trailing empty lines in the list file.
                continue
            h5_path = os.path.join(data_dir, '{}/{}'.format(partition, file_name))
            with h5py.File(h5_path, 'r', swmr=True) as f:
                data = f['data'][:].astype('float32')
                label = f['label'][:].astype('int64')
                seg = f['pid'][:].astype('int64')
            all_data.append(data)
            all_label.append(label)
            all_seg.append(seg)
    return np.asarray(all_data), np.asarray(all_label), np.asarray(all_seg)


def generate_24_rotations():
    """Return the 24 signed 3x3 permutation matrices with determinant +1.

    Each matrix is a column permutation of the identity whose columns are
    multiplied by a sign pattern. Cyclic permutations are paired with patterns
    containing an even number of -1 entries, and the three swaps with patterns
    containing an odd number, so every result is a proper rotation.
    The list holds integer arrays: 12 from the cyclic permutations first, then
    12 from the swaps.
    """
    cyclic_perms = ([0, 1, 2], [1, 2, 0], [2, 0, 1])
    swap_perms = ([0, 2, 1], [1, 0, 2], [2, 1, 0])
    even_flips = ((1, 1, 1), (-1, -1, 1), (-1, 1, -1), (1, -1, -1))
    odd_flips = ((-1, -1, -1), (-1, 1, 1), (1, -1, 1), (1, 1, -1))

    rotations = []
    for perms, flips in ((cyclic_perms, even_flips), (swap_perms, odd_flips)):
        for perm in perms:
            base = np.identity(3)[:, perm].astype(int)
            # Broadcasting a length-3 sign vector scales column j by signs[j].
            rotations.extend(base * np.array(signs) for signs in flips)
    return rotations


class ShapeNetPartSeg(Dataset):
    """Part-segmentation dataset that serves each shape under 24 rotations.

    Every item is the stack of the shape multiplied by each of the 24 matrices
    from ``generate_24_rotations``. For the 'train' partition the order of the
    24 views is shuffled on every access; otherwise the order is fixed.
    """

    def __init__(self, data_dir, partition, class_choice=None):
        """
        data_dir: root directory handed to ``load_data_partseg``.
        partition: partition name; 'train' enables view shuffling.
        class_choice: optional category name (a key of ``cat2id``) used to
            keep only the shapes of that category.
        """
        self.all_R = generate_24_rotations()
        self.data, self.label, self.seg = load_data_partseg(data_dir, partition)

        category_names = ['airplane', 'bag', 'cap', 'car', 'chair',
                          'earphone', 'guitar', 'knife', 'lamp', 'laptop',
                          'motor', 'mug', 'pistol', 'rocket', 'skateboard', 'table']
        self.cat2id = {name: cat_id for cat_id, name in enumerate(category_names)}
        # Per category: number of part labels and the first global part id.
        self.seg_num = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3]
        self.index_start = [0, 4, 6, 8, 12, 16, 19, 22, 24, 28, 30, 36, 38, 41, 44, 47]
        self.partition = partition
        self.class_choice = class_choice

        if self.class_choice is None:
            # All categories: 50 part labels in total, starting at id 0.
            self.seg_num_all, self.seg_start_index = 50, 0
        else:
            chosen_id = self.cat2id[self.class_choice]
            keep = (self.label == chosen_id).squeeze()
            self.data = self.data[keep]
            self.label = self.label[keep]
            self.seg = self.seg[keep]
            self.seg_num_all = self.seg_num[chosen_id]
            self.seg_start_index = self.index_start[chosen_id]

    def __getitem__(self, item):
        """Return ``(views, label, seg)``; ``views`` is float32 of shape (24, 2048, 3)."""
        points, label, seg = self.data[item].reshape(1, 2048, 3), self.label[item], self.seg[item]
        view_order = np.arange(24)
        # Training samples see the 24 views in a random order.
        if self.partition == 'train':
            random.shuffle(view_order)
        views = [points @ self.all_R[view_id] for view_id in view_order]
        return np.concatenate(views).astype('float32'), label, seg

    def __len__(self):
        """Number of shapes held by the dataset."""
        return self.data.shape[0]
    


In [4]:
def cal_loss(pred, label, smoothing_eps=0.2):
    """Cross-entropy with label smoothing, averaged over all rows of ``pred``.

    pred: logits of shape (n_samples, n_class).
    label: integer class ids; flattened to (n_samples,).
    smoothing_eps: probability mass taken from the true class and spread
        evenly over the remaining ``n_class - 1`` classes.
    """
    targets = label.contiguous().view(-1)
    num_classes = pred.size(1)

    hard = torch.zeros_like(pred).scatter(1, targets.view(-1, 1), 1)
    on_part = hard * (1 - smoothing_eps)
    off_part = (1 - hard) * smoothing_eps / (num_classes - 1)
    soft = on_part + off_part

    log_prob = F.log_softmax(pred, dim=1)
    per_sample = -(soft * log_prob).sum(dim=1)
    return per_sample.mean()



def calculate_shape_IoU(pred_np, seg_np, label, class_choice):
    """Compute the mean part IoU of every (selected) shape.

    pred_np: predicted part ids, indexed as ``pred_np[shape_idx]``.
    seg_np: ground-truth part ids with the same layout; ``seg_np.shape[0]``
        is the number of shapes.
    label: category id of each shape; any shape that flattens to one entry
        per shape is accepted (e.g. ``(n,)`` or ``(n, 1)``).
    class_choice: ``None`` to score every shape, otherwise only shapes whose
        category id equals ``class_choice`` are scored.

    For each shape only the parts belonging to its category are considered.
    A part that is absent from both prediction and ground truth counts as
    IoU 1. Returns a 1-D array with one mean IoU per scored shape.
    """
    seg_num = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3]
    index_start = [0, 4, 6, 8, 12, 16, 19, 22, 24, 28, 30, 36, 38, 41, 44, 47]
    # reshape(-1) rather than squeeze(): squeeze() turns a single-shape label
    # array into a 0-d array, which cannot be indexed by shape_idx.
    label = np.asarray(label).reshape(-1)
    shape_ious = []
    for shape_idx in range(seg_np.shape[0]):
        category = int(label[shape_idx])
        if class_choice is not None and category != class_choice:
            continue
        start_index = index_start[category]
        shape_pred, shape_true = pred_np[shape_idx], seg_np[shape_idx]
        part_ious = []
        for part in range(start_index, start_index + seg_num[category]):
            pred_mask, true_mask = shape_pred == part, shape_true == part
            I = np.sum(np.logical_and(pred_mask, true_mask))
            U = np.sum(np.logical_or(pred_mask, true_mask))
            part_ious.append(1 if U == 0 else I / float(U))
        shape_ious.append(np.mean(part_ious))
    return np.asarray(shape_ious)



In [5]:


def _one_hot_categories(label, num_categories=16):
    """Turn a batch of integer category ids into a float32 one-hot tensor of shape (batch, num_categories)."""
    return F.one_hot(label.reshape(-1).long(), num_classes=num_categories).float()


def _log(log_path, message):
    """Print ``message`` and append it as one line to the file at ``log_path``."""
    print(message)
    with open(log_path, 'a') as log_file:
        log_file.write(message + '\n')


def train(data_dir, log_dir, device, n_epoch=1000, lr=1e-3, bs=16):
    """Train ``DGCNN_partseg`` on the part-segmentation data and evaluate it periodically.

    data_dir: dataset root passed to ``ShapeNetPartSeg``.
    log_dir: path of the text log file; it is truncated at start-up and one
        line is appended per epoch summary and per evaluation.
    device: torch device the model and batches are moved to.
    n_epoch: number of training epochs (also the cosine-annealing period).
    lr: base learning rate; the SGD optimiser starts at ``lr * 100``.
    bs: test batch size; training uses batches of ``bs * 2``.

    Side effects: reads the view-index table from ``./all_id.txt`` and writes
    the best model (by mean test IoU) to ``./partseg_checkpoint.t7``.
    """
    seg_num_all = 50  # total number of part labels when all categories are used
    n_views = 24

    # dataloader
    train_set = ShapeNetPartSeg(data_dir=data_dir, partition='train', class_choice=None)
    test_set = ShapeNetPartSeg(data_dir=data_dir, partition='test', class_choice=None)
    train_generator = DataLoader(train_set, batch_size=bs*2, shuffle=True, num_workers=2, drop_last=True)
    test_generator = DataLoader(test_set, batch_size=bs, shuffle=True, num_workers=2, drop_last=False)
    print('training data size: {}, test data size: {}'.format(len(train_set), len(test_set)))

    # initialization
    model = DGCNN_partseg(seg_num_all=seg_num_all).to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr * 100, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epoch, eta_min=1e-3)
    open(log_dir, 'w').close()  # start every run with an empty log file
    best_test_iou = 0
    # np.int was removed in NumPy 1.24; use an explicit 64-bit integer type.
    vcand = np.loadtxt('./all_id.txt').astype(np.int64)
    # Row vw of this table selects/reorders the view axis for test pass vw.
    ids = torch.from_numpy(vcand).to(device).long()
    print('init done')

    # training and validation
    for epoch in range(n_epoch):
        # training phase
        loss_val = 0.0
        train_true, train_pred, train_label = [], [], []
        model.train()
        for data, label, seg in tqdm(train_generator):
            label_one_hot = _one_hot_categories(label).to(device)
            data, seg_dev = data.to(device), seg.to(device)
            optimizer.zero_grad()
            pred = model(data, label_one_hot).permute(0, 2, 1).contiguous()  # bs * n_points * n_class
            loss = cal_loss(pred.view(-1, seg_num_all), seg_dev.view(-1))
            loss.backward()
            optimizer.step()
            loss_val += loss.item()
            train_true.append(seg.numpy())
            train_pred.append(pred.detach().max(dim=2)[1].cpu().numpy())
            train_label.append(label.reshape(-1).numpy())
        scheduler.step()
        train_true, train_pred, train_label = np.concatenate(train_true), np.concatenate(train_pred), np.concatenate(train_label)
        curr_train_iou = np.mean(calculate_shape_IoU(train_pred, train_true, train_label, class_choice=None))
        _log(log_dir, 'Epoch {}, loss {:.2f}, train_iou: {:.2f}'.format(epoch, loss_val, curr_train_iou))

        # testing phase: only late in training, once training IoU is high, or every 100 epochs
        if epoch >= 2000 or curr_train_iou >= 0.9 or epoch % 100 == 0:
            model.eval()
            with torch.no_grad():
                test_true, test_pred, test_label = [], [], []
                for data, label, seg in tqdm(test_generator):
                    label_one_hot = _one_hot_categories(label).to(device)
                    data = data.to(device)
                    # roll through all 24 combinations and sum the logits on the device
                    logits_sum = None
                    for vw in range(n_views):
                        logits = model(data[:, ids[vw]], label_one_hot).permute(0, 2, 1)
                        logits_sum = logits if logits_sum is None else logits_sum + logits
                    test_true.append(seg.numpy())
                    test_pred.append(logits_sum.argmax(dim=2).cpu().numpy())
                    test_label.append(label.reshape(-1).numpy())
                test_true, test_pred, test_label = np.concatenate(test_true), np.concatenate(test_pred), np.concatenate(test_label)
                curr_test_iou = np.mean(calculate_shape_IoU(test_pred, test_true, test_label, class_choice=None))
                if best_test_iou < curr_test_iou:
                    best_test_iou = curr_test_iou
                    torch.save(model.state_dict(), './partseg_checkpoint.t7')
                _log(log_dir, 'Epoch {}, curr_test_iou {:.2%}, best_test_iou {:.2%}'.format(epoch, curr_test_iou, best_test_iou))




In [None]:

if __name__ == '__main__':
    data_dir = os.path.join(os.getcwd(), './dataset/partseg/pca/')
    log_dir = 'partseg_log.txt'

    # Prefer GPU 4 (the device this experiment was written for), but fall back
    # to the first GPU, or to the CPU, instead of crashing on smaller hosts.
    preferred_gpu = 4
    if torch.cuda.is_available():
        gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
        device = torch.device('cuda:{}'.format(gpu_index))
    else:
        device = torch.device('cpu')

    # Seed every RNG the pipeline touches: the dataset shuffles views with
    # the stdlib `random` module, the rest goes through NumPy / PyTorch.
    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)  # silently ignored when CUDA is unavailable

    train(data_dir, log_dir, device=device, n_epoch=3000, lr=1e-4, bs=16)

     
    

training data size: 14007, test data size: 2874


  0%|                                             | 0/437 [00:00<?, ?it/s]

init done


100%|███████████████████████████████████| 437/437 [02:55<00:00,  2.50it/s]
  0%|                                             | 0/180 [00:00<?, ?it/s]

Epoch 0, loss 970.52, train_iou: 0.40


 73%|█████████████████████████▋         | 132/180 [04:45<01:43,  2.16s/it]

## 