In [93]:
import nn_utils
import torch
import torch.nn as nn
def get_edge_features(x, k):
    """
    Args:
        x: point cloud [B, dims, N]
        k: kNN neighbours
    Return:
        [B, 2dims, N, k]    
    """
    B, dims, N = x.shape

    # batched pair-wise distance
    xt = x.permute(0, 2, 1)
    xi = -2 * torch.bmm(xt, x)
    xs = torch.sum(xt**2, dim=2, keepdim=True)
    xst = xs.permute(0, 2, 1)
    dist = xi + xs + xst # [B, N, N]

    # get k NN id    
    _, idx = torch.sort(dist, dim=2)
    idx = idx[: ,: ,1:k+1] # [B, N, k]
    idx = idx.contiguous().view(B, N*k)

    # gather
    neighbors = []
    for b in range(B):
        tmp = torch.index_select(x[b], 1, idx[b]) # [d, N*k] <- [d, N], 0, [N*k]
        tmp = tmp.view(dims, N, k)
        neighbors.append(tmp)
    neighbors = torch.stack(neighbors) # [B, d, N, k]

    # centralize
    central = x.unsqueeze(3) # [B, d, N, 1]
    central = central.repeat(1, 1, 1, k) # [B, d, N, k]

    ee = torch.cat([central, neighbors-central], dim=1)
    assert ee.shape == (B, 2*dims, N, k)
    return ee

class conv2dbr(nn.Module):
    """ Conv2d-bn-relu
    [B, Fin, H, W] -> [B, Fout, H, W]
    """
    def __init__(self, Fin, Fout, kernel_size, stride=1):
        super(conv2dbr, self).__init__()
        self.fout = Fout
        self.kernel_size = kernel_size[1]
        self.conv = nn.Conv2d(Fin, Fout, kernel_size, stride)
        self.conv2 = nn.Conv2d(Fout,Fin,1,1)
        self.bn = nn.BatchNorm2d(Fout)
        self.ac = nn.ReLU(True)

    def forward(self, x):
        
        batch_size,Fin,H,W = x.shape
        print('batch_size,Fin,H,W:',x.shape)
        x1 = self.conv(x) # [B, Fout, H, W]
        print(x.shape)
        x1 = x1.view(batch_size,-1,self.fout)   #[B,HW, Fout] 
        print('x1:',x1.shape)
        x2 = torch.transpose(x1, 2, 1)      #[B,Fout,HW]     
        print('x2:',x2.shape)
        x3 = torch.bmm(x1,x2)    #[B,HW,HW] 
        print('x3:',x3.shape)
        x4 = x1    #[B,HW, Fout] 
        print('x4:',x4.shape)
        x5 = torch.bmm(x3,x4)
        print('x5:',x5.shape)
        x6 = x5.view(batch_size,self.fout,H,W//self.kernel_size)
        print('x6:',x6.shape)
        x = self.conv(x6)
        x = self.bn(x)
        x = self.ac(x)
        return x
    
# class conv2dbr(nn.Module):
#     """ Conv2d-bn-relu
#     [B, Fin, H, W] -> [B, Fout, H, W]
#     """
#     def __init__(self, Fin, Fout, kernel_size, stride=1):
#         super(conv2dbr, self).__init__()
#         self.kernel_size = kernel_size[1]
#         self.conv = nn.Conv2d(Fin, Fout, kernel_size, stride)
#         self.bn = nn.BatchNorm2d(Fout)
#         self.ac = nn.ReLU(True)

#     def forward(self, x):
#         print(x.shape)
#         print(self.kernel_size)
#         x = self.conv(x) # [B, Fout, H, W]
#         print('x:',x.shape)
#         x = self.bn(x)
#         x = self.ac(x)
#         return x
    
class upsample_edgeConv(nn.Module):
    """ Edge Convolution using 1x1 Conv h
    [B, Fin, N] -> [B, Fout, N]
    """
    def __init__(self, Fin, Fout, k):
        super(upsample_edgeConv, self).__init__()
        self.k = k
        self.Fin = Fin
        self.Fout = Fout
        # print(2*Fin, Fout)
        self.conv = conv2dbr(2*Fin, Fout, [1, 20], [1, 20])
        # self.conv = conv2dbr(2*Fin, Fin//2, [1, 20], [1, 20])

    def forward(self, x):
        B, Fin, N = x.shape
        x = get_edge_features(x, self.k); # [B, 2Fin, N, k]
        # print('econv: {}'.format(x.shape))
        # x = self.conv(x) # [B, Fout, N, k]
        # print('econv: {}'.format(x.shape))

        # --------------interpolate---------------------
        # print(input_org[0, :, 0, 0:4])
        BB, CC, NN, KK = x.size()
        inp = x.view(BB*CC*NN, 1, KK)
        oup = torch.nn.functional.interpolate(inp, scale_factor=2, mode='nearest')
        # print(oup.shape)
        oup = oup.view(BB, CC, NN, KK*2)
        # print(oup.shape)

        x = self.conv(oup) # [B, Fout, N, k]
        # print(x.shape)

        # exit()
        # if Fin == 12:
        #     x = x.view(BB, 3, 2*NN)
        # else:
        #     x = x.view(BB, Fin//2, 2*NN)

        x = x.view(BB, self.Fout, 2*NN)
        print(x.shape)
        # print(x.shape)

        # x, _ = torch.max(x, 3) # [B, Fout, N]
        # print('econv: {}'.format(x.shape))
        # assert x.shape == (B, self.Fout, 2*N)
        # exit()
        return x


In [94]:
x = torch.randn(50,128,32)

In [95]:
k = upsample_edgeConv(128,512,20)

In [96]:
y = k(x)

batch_size,Fin,H,W: torch.Size([50, 256, 32, 40])
torch.Size([50, 256, 32, 40])
x1: torch.Size([50, 64, 512])
x2: torch.Size([50, 512, 64])
x3: torch.Size([50, 64, 64])
x4: torch.Size([50, 64, 512])
x5: torch.Size([50, 64, 512])
x: torch.Size([50, 512, 32, 2])
torch.Size([50, 512, 64])


In [10]:
y.shape

torch.Size([50, 512, 64])