In [1]:
import os

import numpy as np
import torch
import torch_scatter
from matplotlib import pyplot as plt
from torch_scatter import scatter_sum

from data import compile_data
from pointpillar import PointPillar
%matplotlib inline

In [2]:
# Toy inputs for torch_scatter.scatter_sum along dim=1.
src = torch.tensor([[[2, 1], [0, 1], [1, 1], [1, 4], [2, 3]], [[0, 0], [0, 2], [1, 2], [1, 3], [2, 4]]]) # (B, N, C) = (2, 5, 2)
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) # (B, N) = (2, 5): destination slot for each of the N rows of src

In [3]:
# Each row src[b, n] is summed into out[b, index[b, n]]; as the recorded output shows, the
# result is (B, dim_size, C) = (2, 8, 2) and slots that receive nothing stay 0.
torch_scatter.scatter_sum(src, index, dim=1, dim_size=8)

tensor([[[0, 0],
         [0, 0],
         [1, 4],
         [2, 3],
         [3, 2],
         [0, 1],
         [0, 0],
         [0, 0]],

        [[0, 2],
         [2, 4],
         [2, 5],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0]]])

In [4]:
version = 'mini'
# Set the NUSCENES_ROOT environment variable instead of editing this absolute path.
dataroot = os.environ.get('NUSCENES_ROOT', '/mnt/datasets/nuScenes')

# [min, max, cell size]; points_to_voxels below reads xbound/ybound/zbound this way.
xbound=[-30.0, 30.0, 0.15]
ybound=[-15.0, 15.0, 0.15]
zbound=[-10.0, 10.0, 20.0]
dbound=[4.0, 45.0, 1.0]

H=900
W=1600
resize_lim=(0.193, 0.225)
final_dim=(128, 352)
bot_pct_lim=(0.0, 0.22)
rot_lim=(-5.4, 5.4)
rand_flip=True
# NOTE(review): 5 is fewer than the 6 cameras listed in 'cams' below — presumably a
# per-sample subset; confirm against compile_data.
ncams=5
max_grad_norm=5.0

bsz=4
nworkers=10

grid_conf = {
    'xbound': xbound,
    'ybound': ybound,
    'zbound': zbound,
    'dbound': dbound,
}
data_aug_conf = {
    'resize_lim': resize_lim,
    'final_dim': final_dim,
    'rot_lim': rot_lim,
    'H': H, 'W': W,
    'rand_flip': rand_flip,
    'bot_pct_lim': bot_pct_lim,
    'cams': ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT',
             'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT'],
    'Ncams': ncams,
    'preprocess': False,
    'line_width': 5,
}
parser_name = 'segmentationdata'

[trainloader, valloader], [train_sampler, val_sampler] = compile_data(version, dataroot, data_aug_conf=data_aug_conf,
                                                              grid_conf=grid_conf, bsz=bsz, nworkers=nworkers,
                                                              parser_name=parser_name, distributed=False)


In [5]:
# Grab one batch for inspection. Raises StopIteration on an empty loader, whereas a
# `for ...: break` loop would silently leave `d` undefined.
d = next(iter(trainloader))

In [6]:
# Assumes each loader batch is a sequence whose first two items are the point coordinates
# and their validity mask — TODO confirm against compile_data's dataset.
points_xyz, points_mask = d[0], d[1]

In [7]:
# NOTE(review): the first argument (3) appears to be the number of output channels —
# points_feature.shape is recorded as (4, 3, 400, 200); confirm against PointPillar.
pp = PointPillar(3, xbound, ybound, zbound)

In [8]:
# Forward pass. NOTE(review): an assignment displays nothing, so the torch.Size lines recorded
# under this cell appear to be debug prints inside PointPillar — consider removing them there.
points_feature = pp(points_xyz, points_mask)

torch.Size([4, 128, 400, 200])
torch.Size([4, 256, 400, 200])


In [9]:
# Recorded as (4, 3, 400, 200): 400 x 200 = xbound/ybound extents (60 m, 30 m) / 0.15 cell size.
points_feature.shape

torch.Size([4, 3, 400, 200])

In [None]:
def raval_index(coords, dims):
    """Convert per-axis integer coordinates into flat row-major (C-order) indices.

    Torch counterpart of ``numpy.ravel_multi_index``: for ``dims = [d0, d1, ..., dk]``
    a row ``c`` maps to ``c0 * (d1 * ... * dk) + c1 * (d2 * ... * dk) + ... + ck``.

    Args:
        coords: (M, D) tensor of integer coordinates, one row per point.
        dims: (D,) tensor with the size of each axis.

    Returns:
        (M,) int64 tensor of flat indices. Integer arithmetic keeps the indices exact for
        grids with more than 2**24 cells, where float32 arithmetic would round them.
    """
    dims = dims.long()
    # Stride of axis i is the product of all later axis sizes; the last axis has stride 1.
    later_sizes = torch.cat((dims[1:], torch.ones(1, dtype=torch.long, device=dims.device)))
    multiplier = torch.flip(torch.cumprod(torch.flip(later_sizes, dims=[0]), dim=0), dims=[0])
    return torch.sum(coords.long() * multiplier, dim=1)

def points_to_voxels(
    points_xyz,
    points_mask,
    grid_range_x,
    grid_range_y,
    grid_range_z
):
    """Assign points to cells of a regular 3-D voxel grid and gather per-voxel statistics.

    Args:
        points_xyz: (B, N, 3) float tensor of point coordinates.
        points_mask: (B, N) tensor; entries < 1.0 mark padding (invalid) points.
        grid_range_x, grid_range_y, grid_range_z: ``[min, max, voxel_size]`` for each axis.

    Returns:
        dict of tensors, including:
          'voxel_indices'     (B, N) int64 flat row-major voxel index of each point,
          'voxel_paddings'    (B, N) float, 1.0 for padding or out-of-grid points,
          'points_per_voxel'  (B, num_voxels) float count of valid points per voxel,
          'voxel_centroids'   (B, num_voxels, 3) mean of the valid points in each voxel,
          'local_points_xyz'  (B, N, 3) point minus the centroid of its voxel.
        Invalid points are mapped to voxel 0 but excluded from counts and centroids.
    """
    batch_size, num_points, _ = points_xyz.shape
    device, dtype = points_xyz.device, points_xyz.dtype
    grid_ranges = (grid_range_x, grid_range_y, grid_range_z)

    # Cells per axis (truncated division, as in the original grid definition).
    grid_size_np = np.asarray([(r[1] - r[0]) / r[2] for r in grid_ranges]).astype('int32')
    # Python int computed in int64 so the cell count cannot overflow int32.
    num_voxels = int(np.prod(grid_size_np, dtype=np.int64))
    voxel_size = torch.tensor([r[2] for r in grid_ranges], dtype=dtype, device=device)
    grid_offset = torch.tensor([r[0] for r in grid_ranges], dtype=dtype, device=device)
    grid_size = torch.from_numpy(grid_size_np).to(device).int()

    shifted_points_xyz = points_xyz - grid_offset
    voxel_xyz = shifted_points_xyz / voxel_size
    # floor() rather than int() truncation: truncation maps (-1, 0) to 0, which would put
    # points just below the grid's lower edge into the first cell instead of flagging them.
    voxel_coords = torch.floor(voxel_xyz).int()

    is_padding = (points_mask < 1.0) | torch.any(
        (voxel_coords >= grid_size) | (voxel_coords < 0), dim=-1)

    # Row-major flattening (x slowest, z fastest) in exact int64, done inline so this
    # function does not depend on any (re)definition of raval_index.
    coords_long = voxel_coords.long()
    voxel_indices = ((coords_long[..., 0] * int(grid_size_np[1]) + coords_long[..., 1])
                     * int(grid_size_np[2]) + coords_long[..., 2])
    voxel_indices = torch.where(is_padding, torch.zeros_like(voxel_indices), voxel_indices)

    voxel_centers = (0.5 + voxel_coords.to(dtype)) * voxel_size + grid_offset
    padding_3d = torch.unsqueeze(is_padding, dim=-1)
    voxel_coords = torch.where(padding_3d, torch.zeros_like(voxel_coords), voxel_coords)
    voxel_xyz = torch.where(padding_3d, torch.zeros_like(voxel_xyz), voxel_xyz)
    voxel_paddings = is_padding.float()
    valid = 1 - voxel_paddings

    # Count only valid points per voxel.
    points_per_voxel = torch.zeros((batch_size, num_voxels), dtype=valid.dtype, device=device)
    points_per_voxel.scatter_add_(1, voxel_indices, valid)
    voxel_point_count = torch.gather(points_per_voxel, dim=1, index=voxel_indices)

    # Mean of the valid points per voxel. Invalid points all carry index 0, so they must be
    # zeroed (torch.where, so NaN/inf padding values cannot leak) before summing; empty
    # voxels get a centroid of 0.
    index_xyz = torch.unsqueeze(voxel_indices, dim=-1).repeat(1, 1, 3)
    valid_xyz = torch.where(padding_3d, torch.zeros_like(points_xyz), points_xyz)
    voxel_sums = torch.zeros((batch_size, num_voxels, 3), dtype=dtype, device=device)
    voxel_sums.scatter_add_(1, index_xyz, valid_xyz)
    voxel_centroids = voxel_sums / points_per_voxel.clamp(min=1).unsqueeze(-1).to(dtype)
    point_centroids = torch.gather(voxel_centroids, dim=1, index=index_xyz)
    local_points_xyz = points_xyz - point_centroids

    result = {
        'local_points_xyz': local_points_xyz,
        'shifted_points_xyz': shifted_points_xyz,
        'point_centroids': point_centroids,
        'points_xyz': points_xyz,
        'grid_offset': grid_offset,
        'voxel_coords': voxel_coords,
        'voxel_centers': voxel_centers,
        'voxel_indices': voxel_indices,
        'voxel_paddings': voxel_paddings,
        'voxel_mask': 1 - voxel_paddings,
        'num_voxels': num_voxels,
        'grid_size': grid_size,
        'voxel_xyz': voxel_xyz,
        'voxel_size': voxel_size,
        'voxel_point_count': voxel_point_count,
        'points_per_voxel': points_per_voxel,
        'voxel_centroids': voxel_centroids,
    }
    return result

In [None]:
# Voxelise only the first three point channels, using the same grid bounds passed to PointPillar.
result = points_to_voxels(points_xyz[:, :, :3], points_mask, xbound, ybound, zbound)

In [None]:
# Reshape flat per-voxel centroids into (B, X, Y, 3) BEV maps. The batch size comes from the
# tensor itself instead of a hard-coded 4; this assumes a single z cell (grid_size[2] == 1,
# as zbound gives), otherwise the element count would not match the view.
centroids_flat = result['voxel_centroids']
grid_x, grid_y = (int(n) for n in result['grid_size'][:2])
centroids_flat.view(centroids_flat.shape[0], grid_x, grid_y, 3).shape

In [None]:
# BEV occupancy of sample 0: a cell is occupied if any of its z-bins holds a valid point.
# Grid dimensions come from the result rather than a hard-coded 400 x 200.
grid_x, grid_y, grid_z = (int(n) for n in result['grid_size'])
occupied = (result['points_per_voxel'][0].cpu().numpy().reshape(grid_x, grid_y, grid_z) > 0).any(axis=-1)
fig, ax = plt.subplots()
ax.imshow(occupied)
ax.set(title='Occupied voxels (sample 0)', xlabel='y cell', ylabel='x cell');

In [None]:
# Total points counted into voxels for sample 0 (points_per_voxel weights each point by
# 1 - voxel_paddings, so padding and out-of-grid points add 0).
result['points_per_voxel'][0].sum()

In [None]:
# Number of non-padding points in sample 0, assuming points_mask holds 0/1 values; compare
# with the points_per_voxel total above (the gap is points that fell outside the grid).
points_mask[0].sum()

In [None]:
# Continuous grid position (in voxel units) of point 20000 of sample 0; zeros if that point
# is padding. NOTE(review): 20000 is an arbitrary probe index.
result['voxel_xyz'][0, 20000]

In [None]:
# Flat voxel index of the first point of sample 0 (0 if that point is padding or outside the grid).
result['voxel_indices'][0, 0]

In [None]:
# Number of cells along x, y, z (int32 tensor).
result['grid_size']

In [None]:
# First 500 mask entries of sample 0 (values < 1.0 are treated as padding by points_to_voxels).
points_mask[0][0: 500]

In [None]:
# `a` is not defined in any other visible cell, so this failed on a fresh kernel.
# Define a small sample array to exercise pad_or_trim_to_np below.
a = np.ones((5, 4, 3))
# Per-axis difference between the target shape and `a` (> 0: pad, < 0: trim).
np.asarray((7, 3, 2)) - a.shape

In [None]:
# NOTE(review): requires `a` from an earlier cell; on a fresh kernel it must be defined first.
np.shape(a)

In [None]:
def pad_or_trim_to_np(x, shape, pad_val=0):
    """Pad (at the end) or trim ``x`` so that it has exactly ``shape``.

    Args:
        x: array-like with ``len(shape)`` dimensions.
        shape: target size for every axis.
        pad_val: fill value for padded entries.

    Returns:
        numpy array of shape ``tuple(shape)``: ``x`` trimmed on axes where it is larger
        than ``shape`` and padded with ``pad_val`` on axes where it is smaller.

    Raises:
        ValueError: if ``x`` and ``shape`` have a different number of dimensions.
    """
    x = np.asarray(x)
    shape = np.asarray(shape)
    if x.ndim != shape.size:
        raise ValueError(f'x has {x.ndim} dims but target shape {tuple(shape)} has {shape.size}')
    pad = shape - np.minimum(x.shape, shape)
    x = np.pad(x, np.stack([np.zeros_like(pad), pad], axis=1), constant_values=pad_val)
    # Trim every axis, not only the first two.
    return x[tuple(slice(0, int(n)) for n in shape)]

In [None]:
# NOTE(review): requires `a` from an earlier cell; on a fresh kernel it must be defined first.
pad_or_trim_to_np(a, [7, 3, 2])

In [None]:
def raval_index(coords, dims):
    """Convert per-axis integer coordinates into flat row-major (C-order) indices.

    Args:
        coords: (M, D) integer coordinates — torch tensor or array-like.
        dims: length-D sizes of each axis — torch tensor or array-like.

    Returns:
        (M,) int64 tensor, equal to ``numpy.ravel_multi_index(coords.T, dims)``.
    """
    # NOTE: this is a second definition of raval_index in the notebook and shadows any earlier
    # one for every cell run afterwards, so it must work for tensors (no negative-step slicing)
    # as well as array-likes.
    coords = torch.as_tensor(coords).long()
    dims = torch.as_tensor(dims, device=coords.device).long()
    # Horner's scheme ((c0 * d1 + c1) * d2 + c2) ...: exact int64, C-order strides.
    indices = torch.zeros(coords.shape[0], dtype=torch.long, device=coords.device)
    for axis in range(coords.shape[1]):
        indices = indices * dims[axis] + coords[:, axis]
    return indices