In [1]:
from datasets import PartDataset
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [3]:
# Load the ShapeNet part-annotation dataset in classification mode; each item
# yields a (point_set, class_label) pair (see the batch-building cell below).
d = PartDataset(classification = True, root = 'shapenetcore_partanno_segmentation_benchmark_v0')

{'Earphone': 0, 'Motorbike': 1, 'Rocket': 2, 'Car': 3, 'Laptop': 4, 'Cap': 5, 'Skateboard': 6, 'Lamp': 10, 'Guitar': 8, 'Bag': 9, 'Mug': 7, 'Table': 11, 'Airplane': 12, 'Pistol': 13, 'Chair': 14, 'Knife': 15}


In [4]:
# Number of dataset classes; KDNet below defaults to k = 16 outputs to match.
len(d.classes)

16

In [5]:
# Points per cloud and the kd-tree depth needed to halve them down to one point.
NUM_POINTS = 2048
# np.log2 is exact for powers of two; a ratio of two rounded natural logs need
# not be, and truncating it with astype(int) can lose a level.
levels = int(np.log2(NUM_POINTS))
assert 2 ** levels == NUM_POINTS, "NUM_POINTS must be a power of two"

In [6]:
# Placeholder of `levels` zeros; the batch-building cell rebinds `cutdim`
# per sample before anything reads it.
cutdim = torch.LongTensor(int(levels)).zero_()

In [7]:
def split_ps(point_set):
    """Split a point set in two at the median of its widest coordinate axis.

    Args:
        point_set: (N, D) tensor of points. N is expected to be even (a power
            of two when used to build the full kd-tree).

    Returns:
        (left_ps, right_ps, dim) where ``left_ps`` holds the N // 2 points with
        the largest coordinate along ``dim``, ``right_ps`` the N // 2 smallest,
        and ``dim`` is the split axis as a Python int. Every input point lands
        in exactly one half (for odd N, one median-valued point is dropped from
        ``right_ps`` so both halves keep N // 2 points).
    """
    num_points = point_set.size(0)
    half = num_points // 2  # floor division: `/` gives a float on Python 3

    # Split along the axis with the largest extent.
    extent = (point_set.max(0)[0] - point_set.min(0)[0]).view(-1)
    dim = int(extent.max(0)[1].view(-1)[0])
    coords = point_set[:, dim]

    # Lower median (same convention as torch.median), computed via sort so it
    # does not depend on torch.median's version-specific return type.
    cut = torch.sort(coords)[0][(num_points - 1) // 2]

    # view(-1) rather than squeeze(): squeeze turns a single match into a 0-d
    # tensor that can be neither sliced nor concatenated.
    above = torch.nonzero(coords > cut).view(-1)
    below = torch.nonzero(coords < cut).view(-1)
    ties = torch.nonzero(coords == cut).view(-1)

    # Share the median-valued points between the halves instead of repeating
    # one of them, so no point is duplicated or silently discarded.
    fill = max(half - above.numel(), 0)
    left_idx = torch.cat([above, ties[:fill]], 0)
    right_idx = torch.cat([below, ties[fill:]], 0)[:half]

    left_ps = torch.index_select(point_set, 0, left_idx)
    right_ps = torch.index_select(point_set, 0, right_idx)
    return left_ps, right_ps, dim

In [29]:
%%time
BATCH_SIZE = 15  # number of dataset samples stacked into the demo batch

points_batch = []
cutdim_batch = []
for sample_idx in range(BATCH_SIZE):
    point_set, class_label = d[sample_idx]
    # tree[l] holds the 2**l point subsets at depth l; `_` avoids clobbering
    # the loop index (Python 2 list comprehensions leak their variable).
    tree = [[point_set]] + [[] for _ in range(levels)]
    # cutdim[l] records, for each node at depth l + 1, its parent's split axis.
    cutdim = [[] for _ in range(levels)]
    for level in range(levels):
        for node in tree[level]:
            left_ps, right_ps, dim = split_ps(node)
            tree[level + 1].extend([left_ps, right_ps])
            # Both children inherit the axis their parent was split on.
            cutdim[level].extend([dim, dim])
    cutdim = [torch.LongTensor(axes) for axes in cutdim]
    # Leaves are stacked in tree order; presumably (num_points, 1, 3) with one
    # point per leaf. Reshape to (1, 3, num_points) for Conv1d input.
    points = torch.stack(tree[-1])
    points_batch.append(points.squeeze(1).t().unsqueeze(0))
    cutdim_batch.append(cutdim)

CPU times: user 3.62 s, sys: 14 ms, total: 3.64 s
Wall time: 3.65 s


In [50]:
# Stack the per-sample clouds into one (batch, 3, num_points) input.
points_v = Variable(torch.cat(points_batch, 0))
# Transpose cutdim_batch from [sample][level] to [level][sample], stacking each
# level into a single (batch, nodes_at_level) tensor.
cutdim_processed = [torch.stack(per_level, 0) for per_level in zip(*cutdim_batch)]

In [136]:
class KDNet(nn.Module):
    """Kd-tree network classifier over clouds of 2 ** 11 = 2048 points.

    Each of the 11 ``convN`` layers handles one tree level, from the 2048
    leaves up to the root's two children. A layer produces three candidate
    feature vectors per node (one per split axis), keeps the candidate for the
    axis recorded in the node's cut-dimension tensor, and max-pools sibling
    pairs into their parent.
    """

    def __init__(self, k = 16):
        """Build the per-level convolutions and the classifier head.

        Args:
            k: number of output classes.
        """
        super(KDNet, self).__init__()
        # convN maps C_in -> 3 * C_out: one C_out-sized candidate per split axis.
        self.conv1 = nn.Conv1d(3,8 * 3,1,1)
        self.conv2 = nn.Conv1d(8,32 * 3,1,1)
        self.conv3 = nn.Conv1d(32,64 * 3,1,1)
        self.conv4 = nn.Conv1d(64,64 * 3,1,1)
        self.conv5 = nn.Conv1d(64,64 * 3,1,1)
        self.conv6 = nn.Conv1d(64,128 * 3,1,1)
        self.conv7 = nn.Conv1d(128,256 * 3,1,1)
        self.conv8 = nn.Conv1d(256,512 * 3,1,1)
        self.conv9 = nn.Conv1d(512,512 * 3,1,1)
        self.conv10 = nn.Conv1d(512,512 * 3,1,1)
        self.conv11 = nn.Conv1d(512,1024 * 3,1,1)
        self.fc = nn.Linear(1024, k)

    @staticmethod
    def _kdconv(x, dim, featdim, sel, conv):
        """Process one tree level: conv, pick split-axis features, pool siblings.

        Args:
            x: (batch, C_in, dim) features, one column per tree node.
            dim: number of nodes at this level (must be even).
            featdim: features kept per node; ``conv`` outputs 3 * featdim channels.
            sel: (batch, dim) integer tensor with values in {0, 1, 2}, the
                split axis recorded for each node.
            conv: 1x1 ``nn.Conv1d`` for this level.

        Returns:
            (batch, featdim, dim // 2) features for the parent level.
        """
        batchsize = x.size(0)
        # Channel f * 3 + a of the conv output is candidate feature f for axis a.
        x = F.relu(conv(x)).view(batchsize, featdim, 3, dim)
        # Gather x[b, f, sel[b, p], p]: for every node p, the candidate matching
        # its own split axis. Gathering along the axis dimension keeps each
        # node reading its own column.
        index = sel.long().contiguous().view(batchsize, 1, 1, dim)
        index = index.expand(batchsize, featdim, 1, dim).contiguous()
        if x.is_cuda:
            index = index.cuda()
        x = x.gather(2, Variable(index))
        # Nodes 2q and 2q + 1 are siblings; max-pool each pair into parent q.
        x = x.view(batchsize, featdim, dim // 2, 2)
        return x.max(-1)[0].view(batchsize, featdim, dim // 2)

    def forward(self, x, c):
        """Classify a batch of kd-tree-ordered point clouds.

        Args:
            x: (batch, 3, 2048) coordinates with leaves in tree order.
            c: list of 11 split-axis tensors; c[l] has shape (batch, 2 ** (l + 1)),
               so c[-1] covers the 2048 leaves and c[0] the root's two children.

        Returns:
            (batch, k) log-probabilities.

        Raises:
            ValueError: if x does not contain exactly 2 ** 11 points.
        """
        convs = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5,
                 self.conv6, self.conv7, self.conv8, self.conv9, self.conv10,
                 self.conv11]
        dim = x.size(2)
        if dim != 2 ** len(convs):
            raise ValueError('KDNet expects %d points per cloud, got %d'
                             % (2 ** len(convs), dim))
        for level, conv in enumerate(convs):
            x = self._kdconv(x, dim, conv.out_channels // 3, c[-(level + 1)], conv)
            dim //= 2
        x = x.view(-1, self.fc.in_features)
        return F.log_softmax(self.fc(x), dim=1)

net = KDNet()

In [137]:
# Forward pass on the demo batch; x holds (batch, k) log-probabilities.
x = net(points_v, cutdim_processed)

In [138]:
# Input batch shape: (batch, 3 coordinates, points per cloud).
points_v.size()

torch.Size([15, 3, 2048])

In [139]:
# Smoke-test backprop: differentiate the summed outputs w.r.t. the parameters.
x.sum().backward()

In [140]:
# Label of the last sample read by the batch-building loop (loop variable left in scope).
class_label


 0
[torch.LongTensor of size 1]

In [135]:
# NOTE(review): execution count [135] is lower than the cells above it, so the
# output below comes from an earlier kernel state; re-run the notebook in order.
x

Variable containing:

Columns 0 to 9 
-2.7780 -2.7592 -2.7585 -2.7613 -2.7265 -2.7407 -2.7838 -2.7462 -2.7619 -2.8141
-2.7769 -2.7602 -2.7597 -2.7613 -2.7270 -2.7394 -2.7838 -2.7458 -2.7614 -2.8118
-2.7770 -2.7599 -2.7598 -2.7613 -2.7266 -2.7392 -2.7840 -2.7456 -2.7613 -2.8123
-2.7778 -2.7603 -2.7586 -2.7621 -2.7264 -2.7403 -2.7828 -2.7451 -2.7613 -2.8128
-2.7629 -2.7719 -2.7881 -2.7567 -2.7186 -2.7342 -2.7802 -2.7661 -2.7773 -2.8205
-2.7630 -2.7726 -2.7879 -2.7574 -2.7176 -2.7360 -2.7803 -2.7648 -2.7785 -2.8198
-2.7772 -2.7599 -2.7598 -2.7613 -2.7266 -2.7392 -2.7840 -2.7455 -2.7613 -2.8123
-2.7772 -2.7596 -2.7599 -2.7602 -2.7266 -2.7393 -2.7854 -2.7465 -2.7623 -2.8118
-2.7775 -2.7594 -2.7576 -2.7616 -2.7268 -2.7406 -2.7856 -2.7473 -2.7626 -2.8133
-2.7644 -2.7734 -2.7887 -2.7560 -2.7164 -2.7352 -2.7774 -2.7660 -2.7785 -2.8201
-2.7775 -2.7595 -2.7582 -2.7616 -2.7261 -2.7408 -2.7848 -2.7466 -2.7621 -2.8136
-2.7766 -2.7607 -2.7601 -2.7613 -2.7271 -2.7392 -2.7846 -2.7454 -2.7610 -2.8119
-2