In [1]:
import torch

### FPS方法

In [74]:
import torch
from torch.autograd import Variable

def farthest_point_sample(xyz, npoint): 

    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        result: sampled pointcloud , [B, npoint, 3]
    """
    
    device = xyz.device
    B, N, C = xyz.shape
    result = torch.zeros(B, npoint, C).to(device)
    
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)     
    distance = torch.ones(B, N).to(device) * 1e10                       

    batch_indices = torch.arange(B, dtype=torch.long).to(device)        
    
    barycenter = torch.sum((xyz), 1)                                    
    barycenter = barycenter/xyz.shape[1]
    barycenter = barycenter.view(B, 1, 3)

    dist = torch.sum((xyz - barycenter) ** 2, -1)
    farthest = torch.max(dist,1)[1]                                     

    for i in range(npoint):
        centroids[:, i] = farthest                                      
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)        
        dist = torch.sum((xyz - centroid) ** 2, -1)                     
        mask = dist < distance
        distance[mask] = dist[mask]                                     
        farthest = torch.max(distance, -1)[1]                           
    
    for b in range(0, B):
        for index in range(0, npoint):
            # todo 后面是不是可以改成生成式的形式？
            result[b, index] = xyz[b, centroids[b, index]]
    
    # r = [xyz[b, centroids[b, i]].item() for b in range(0, B) for i in range(0, npoint)]
    
    return result

In [75]:
r = farthest_point_sample(sim_data, 4)
r

tensor([[[0.9679, 0.3245, 0.2039],
         [0.0267, 0.4265, 0.7912],
         [0.2179, 0.8170, 0.0824],
         [0.4956, 0.9449, 0.9740]],

        [[0.9840, 0.6193, 0.0824],
         [0.3369, 0.8668, 0.9643],
         [0.2591, 0.0509, 0.2863],
         [0.5662, 0.3188, 0.8705]]])

In [76]:
r.shape

torch.Size([2, 4, 3])

### Slice方法

In [None]:
def get_slice(point_set, xyz_dim, index_slice, npoints):
    device = point_set.device
    B, _, C = point_set.shape
    result = torch.zeros((B, npoints, C)).to(device)
    
    def get_slice_index(index_slice, ratio_per_slice, overlap_ratio, num_all_points):
        start_index = index_slice * (ratio_per_slice - overlap_ratio) * num_all_points
        end_index = start_index + ratio_per_slice * num_all_points
        return int(start_index), int(end_index)

    def get_1_slice(point_set, xyz_dim, index_slice, npoints):
        # xyz_dim: 0, 1, 2 for x, y, z
        start_index, end_index = get_slice_index(index_slice, 0.4, 0.1, len(point_set))
        patch_index = torch.argsort(point_set, dim=0)[start_index: end_index, xyz_dim]
        patch = point_set[patch_index]
        random.shuffle(patch)
        if len(patch_index) > npoints:
            patch = fps(patch, npoints)
        return patch

    for b in range(0, B):
        result[b] = get_1_slice(point_set[b], xyz_dim, index_slice, npoints)
    return result

### Slice 测试

### Cube方法

In [80]:
def get_cube(point_set, side_length, npoints):
    device = point_set.device
    B, _, C = point_set.shape
    result = torch.zeros((B, npoints, C)).to(device)

    def point_in_cube(point_xyz, side_length):
        flag = True
        for i in range(0, len(point_xyz)):
            if abs(point_xyz[i]) >= (side_length / 2):
                flag = False
                break
        return flag

    def get_1_cube(point_set, side_length, npoints):
        sample_index = []
        for i in range(0, len(point_set)):
            if point_in_cube(point_set[i], side_length):
                sample_index.append(i)
        if len(sample_index) >= npoints:
            r = fps(point_set[sample_index], npoints)
            return r
        else:
            return get_1_cube(point_set, side_length + 0.1, npoints)

    for i in range(point_set.shape[0]):
        result[i] = get_1_cube(point_set[i], side_length, npoints)
    return result

### Cube方法测试

### Sphere方法

In [None]:
def get_sphere(point_set, radius, npoints):
    device = point_set.device
    B, _, C = point_set.shape
    result = torch.zeros((B, npoints, C)).to(device)

    def point_in_ball(point_xyz, center_xyz, radius):
        flag = False
        dist = 0
        for i in range(3):
            dist += (point_xyz[i] - center_xyz[i]) ** 2
        if dist <= radius ** 2:
            flag = True
        return flag

    def get_1_sphere(point_set, radius, npoints):
        center_xyz = torch.zeros([1, 3])
        sample_index = []
        for i in range(0, len(point_set)):
            if point_in_ball(point_set[i], center_xyz, radius):
                sample_index.append(i)
        if len(sample_index) >= npoints:
            r = fps(point_set[sample_index], npoints)
            return r
        else:
            return get_1_sphere(point_set, radius + 0.1, npoints)

    for i in range(point_set.shape[0]):
        result[i] = get_1_sphere(point_set[i], radius, npoints)
    return result

In [82]:
import numpy as np

x = np.random.random()
type(x)

float

In [90]:
drop_idx = np.where(np.random.random(20) <= 0.875)[0]
type(drop_idx)

numpy.ndarray

In [85]:
s = np.random.uniform(-1,0,10)
type(s)

numpy.ndarray

In [86]:
s

array([-0.91935565, -0.10548497, -0.00923473, -0.33601561, -0.34504973,
       -0.99699527, -0.522479  , -0.23381304, -0.37536566, -0.00571747])

In [87]:
st = torch.from_numpy(s)
st

tensor([-0.9194, -0.1055, -0.0092, -0.3360, -0.3450, -0.9970, -0.5225, -0.2338,
        -0.3754, -0.0057], dtype=torch.float64)

In [91]:
rand_point_set = torch.ones((4, 1000, 3))
rand_point_set.shape

torch.Size([4, 1000, 3])

In [94]:
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    device = batch_pc.device
    # batch_pc: BxNx3
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random()*max_dropout_ratio  # 0~0.875
        drop_idx = np.where(np.random.random((batch_pc.shape[1])) <= dropout_ratio)[0]
        drop_idx = torch.from_numpy(drop_idx).to(device)
        r = batch_pc.clone()
        if len(drop_idx) > 0:
            r[b, drop_idx, :] = batch_pc[b, 0, :]  # set to the first point
    return r

In [95]:
o1 = random_point_dropout(rand_point_set)
o1.shape

torch.Size([4, 1000, 3])

In [96]:
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    """ Randomly scale the point cloud. Scale is per point cloud.
        Input:
            BxNx3 array, original batch of point clouds
        Return:
            BxNx3 array, scaled batch of point clouds
    """
    B, N, C = batch_data.shape
    device = batch_data.device
    scales = np.random.uniform(scale_low, scale_high, B)
    scales = torch.from_numpy(scales).to(device)
    for batch_index in range(B):
        batch_data[batch_index, :, :] *= scales[batch_index]
    return batch_data

In [98]:
o2 = random_scale_point_cloud(rand_point_set)
o2.shape

torch.Size([4, 1000, 3])

In [99]:
def shift_point_cloud(batch_data, shift_range=0.1):
    """ Randomly shift point cloud. Shift is per point cloud.
        Input:
          BxNx3 array, original batch of point clouds
        Return:
          BxNx3 array, shifted batch of point clouds
    """
    B, N, C = batch_data.shape
    device = batch_data.device
    shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
    shifts = torch.from_numpy(shifts).to(device)
    for batch_index in range(B):
        batch_data[batch_index, :, :] += shifts[batch_index, :]
    return batch_data

In [100]:
o3 = shift_point_cloud(rand_point_set)
o3.shape

torch.Size([4, 1000, 3])

### Model 修改

In [None]:
class SimAttention_5(nn.Module):
    def __init__(self,
                 aug_function,
                 sub_function,
                 slice_function,
                 cube_function,
                 sphere_function,
                 online_encoder,
                 crossed_attention_method):
        super().__init__()
        self.aug_function = aug_function
        self.sub_function = sub_function
        self.slice_function = slice_function
        self.cube_function = cube_function
        self.sphere_function = sphere_function
        self.online_encoder = online_encoder
        self.target_encoder = None

        self.crossed_attention = crossed_attention_method

    def forward(self, x):
        aug1, aug2 = self.aug_function(x), self.aug_function(x)
        sub1, sub2 = self.sub_function(aug1, 1024), self.sub_function(aug2, 1024)
        slice1, slice2 = self.slice_function(aug1, 1, 1, 1024), self.slice_function(aug2, 1, 1, 1024)
        cube1, cube2 = self.cube_function(aug1, 0.2, 1024), self.cube_function(aug2, 0.2, 1024)
        sphere1, sphere2 = self.sphere_function(aug1, 0.2, 1024), self.sphere_function(aug2, 0.1, 1024)

        # [B, 1, N_f] N_f: output dimension of mlp: 512
        sub_feature_1 = self.online_encoder(sub1)
        sub_feature_3 = self.online_encoder(sub2)

        # with momentum encoder
        with torch.no_grad():
            if self.target_encoder is None:
                self.target_encoder = copy.deepcopy(self.online_encoder)
            else:
                for online_params, target_params in zip(self.online_encoder.parameters(),
                                                        self.target_encoder.parameters()):
                    target_weight, online_weight = target_params.data, online_params.data
                    # moving average decay is tao
                    tao = 0.99
                    target_params.data = target_weight * tao + (1 - tao) * online_weight
            for parameter in self.target_encoder.parameters():
                parameter.requires_grad = False
            sub_feature_2 = self.target_encoder(sub2)
            sub_feature_4 = self.target_encoder(sub1)

        # slice feature [B, 1, N_f]
        slice_feature_1 = self.online_encoder(slice1)
        slice_feature_2 = self.online_encoder(slice2)

        # cube feature  [B, 1, N_f]
        cube_feature_1 = self.online_encoder(cube1)
        cube_feature_2 = self.online_encoder(cube2)

        # sphere feature [B, 1, N_f]
        sphere_feature_1 = self.online_encoder(sphere1)
        sphere_feature_2 = self.online_encoder(sphere2)

        # crop feature concat [B, 3, N_f]
        crop_feature_1 = torch.cat((slice_feature_1, cube_feature_1, sphere_feature_1), dim=1)
        crop_feature_2 = torch.cat((slice_feature_2, cube_feature_2, sphere_feature_2), dim=1)
        # [B, 6, N_f]
        crop_feature = torch.cat((crop_feature_1, crop_feature_2), dim=1)

        # attention feature
        attn_feature_1 = self.crossed_attention(sub_feature_1, crop_feature)
        attn_feature_2 = self.crossed_attention(sub_feature_2, crop_feature)
        attn_feature_3 = self.crossed_attention(sub_feature_3, crop_feature)
        attn_feature_4 = self.crossed_attention(sub_feature_4, crop_feature)

        # loss
        loss_1 = loss_fn(attn_feature_1, attn_feature_2)
        loss_2 = loss_fn(attn_feature_3, attn_feature_4)
        loss = loss_1 + loss_2

        return loss.mean()

### PointWOLF 部分

In [None]:
class PointWOLF(object):
    # todo2: delete args
    def __init__(self, w_sigma):
        self.num_anchor = 4
        self.sample_type = 'fps'  # 'random'
        self.sigma = w_sigma

        self.R_range = (-abs(10), abs(10))
        self.S_range = (1., 3)
        self.T_range = (-abs(0.25), abs(0.25))

    def __call__(self, pos):
        """
        input :
            pos([N,3])

        output :
            pos([N,3]) : original pointcloud
            pos_new([N,3]) : Pointcloud augmneted by PointWOLF
        """
        device = pos.device
        pos = pos.cpu().numpy()
        M = self.num_anchor  # (Mx3)
        N, _ = pos.shape  # (N)

        if self.sample_type == 'random':
            idx = np.random.choice(N, M)  # (M)
        elif self.sample_type == 'fps':
            idx = self.fps(pos, M)  # (M)

        pos_anchor = pos[idx]  # (M,3), anchor point
        pos_repeat = np.expand_dims(pos, 0).repeat(M, axis=0)  # (M,N,3)
        pos_normalize = np.zeros_like(pos_repeat, dtype=pos.dtype)  # (M,N,3)
        pos_normalize = pos_repeat - pos_anchor.reshape(M, -1, 3)

        # Local transformation at anchor point
        pos_transformed = self.local_transformaton(pos_normalize)  # (M,N,3)

        # Move to origin space
        pos_transformed = pos_transformed + pos_anchor.reshape(M, -1, 3)  # (M,N,3)

        pos_new = self.kernel_regression(pos, pos_anchor, pos_transformed)
        pos_new = self.normalize(pos_new)
        result = torch.from_numpy(pos_new.astype('float32')).to(device)

        return result