In [None]:
%matplotlib inline
from mmdet3d.datasets import build_dataset
from tools.misc.browse_dataset import build_data_cfg
from mmdet3d.models import apply_3d_transformation
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, FormatStrFormatter, FixedLocator
import copy
import torch
import numpy as np
from mmcv import Config, DictAction
from mmdet3d.models import build_model
from mmdet3d.ops.voxel.voxelize import voxelization
from mmdet3d.ops import DynamicScatter
from mmdet3d.ops import (
    flat2window_v2,
    window2flat_v2,
    get_inner_win_inds,
    make_continuous_inds,
    get_flat2win_inds_v2,
    get_window_coors,
)
from mmdet3d.models.detectors.shared_fusion_net import SharedFusionNet
import time
import pickle
import random

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# Inputs: a pickled forward_train batch (batch size 2, per the filename) and the debug model config.
FORWARD_TRAIN_INPUT_PATH = "/shared-sst/forward_train_input_batch_size_2.pkl"
CONFIG_PATH = "configs/shared_sst/shared_fusion_lidar_detection_debug_config.py"

# NOTE(review): pickle.load can execute arbitrary code from the file — only load trusted dumps.
with open(FORWARD_TRAIN_INPUT_PATH, "rb") as f:
    forward_train_input = pickle.load(f)

# Unpacking relies on the dict's insertion order.
# NOTE(review): indexing by key would be more robust; the key names are not visible here — confirm.
points, img, img_metas, gt_bboxes_3d, gt_labels_3d = forward_train_input.values()
img = img.to(device).float()
points = [p.float() for p in points]

cfg = Config.fromfile(CONFIG_PATH)
model = build_model(cfg.model, train_cfg=cfg.get("train_cfg"), test_cfg=cfg.get("test_cfg"))
#model.init_weights()
# Move the model to the same device as the inputs; model.cuda() ignored the CPU fallback chosen above.
model.to(device)

In [None]:
def get_patch_coors(unflattened_patches, patch_size):
    """Build voxel-style coordinates for every image patch, in row-major (batch, row, col) order.

    Args:
        unflattened_patches: Tensor whose first three dims are (batch_size, height, width),
            counted in patches. Any trailing dims are ignored.
        patch_size: Side length of a patch in pixels. Used to map patch indices to the
            pixel coordinate of each patch centre.

    Returns:
        Float tensor of shape (batch_size * height * width, 4) on the input's device.
        Column 0 is the batch index and column 1 is left at zero. Columns 2 and 3 are the
        row and column patch indices scaled to pixels:
        ``index * patch_size + patch_size // 2``.
    """
    device = unflattened_patches.device
    batch_size, height, width = unflattened_patches.shape[:3]
    num_patches = batch_size * height * width

    # Build all index vectors directly on the target device: no numpy / host round-trip and
    # no Python loop over the batch. Row-major order: width varies fastest, then height,
    # then batch.
    batch_inds = torch.arange(batch_size, device=device).repeat_interleave(height * width)
    height_inds = torch.arange(height, device=device).repeat_interleave(width).repeat(batch_size)
    width_inds = torch.arange(width, device=device).repeat(batch_size * height)

    patch_coors = torch.zeros((num_patches, 4), device=device)
    patch_coors[:, 0] = batch_inds
    # Scale patch indices to the pixel coordinate of each patch centre.
    patch_coors[:, 2] = height_inds * patch_size + patch_size // 2
    patch_coors[:, 3] = width_inds * patch_size + patch_size // 2
    return patch_coors

In [None]:
# Voxelize point cloud
voxels, coors = model.voxelize(points)  # [Batch, Z, Y, X]
# Take the batch size from the number of input samples.
# The previous coors[-1, 0] + 1 was only correct if the last voxel row belonged to the last
# sample, i.e. coors ordered by batch index and every sample producing at least one voxel.
batch_size = len(points)
voxel_features, voxel_feature_coors = model.voxel_encoder(voxels, coors)
voxel_mean, _ = model.voxel_encoder.cluster_scatter(voxels, coors)

# Patchify wide image: camera images concatenated along dim 3 (presumably width) in camera_order
img_wide = torch.cat([img[:, i] for i in model.middle_encoder.camera_order], dim=3)
patches = model.patch_embedder(img_wide)

# Convert patches to same format as voxels
# NOTE(review): assumes patches is (flat_patches, hw_shape) with flat_patches shaped (B, H*W, C) — confirm.
unflattened_patches = patches[0].unflatten(1, patches[1])
patch_features = patches[0].flatten(0, 1)
patch_coors = get_patch_coors(unflattened_patches, model.patch_embedder.projection.kernel_size[0])


sst_info = model.middle_encoder(
    voxel_features,
    voxel_feature_coors,
    voxel_mean,
    patch_features,
    patch_coors,
    img_metas,
    batch_size,
)

[batch_canvas] = model.backbone(sst_info)

### Plot input

In [None]:
# Camera indices in the order their images are concatenated into the wide image (img_wide / wide_image.png)
model.middle_encoder.camera_order

In [None]:
# Plot each camera view of the first sample in its own figure (saved as view1.png, view2.png, ...).
# Iterate over the filenames actually present instead of assuming exactly 6 cameras.
for view_idx, view_path in enumerate(img_metas[0]["filename"]):
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.imshow(plt.imread(view_path))
    ax.set_title(f"View {view_idx + 1}")
    fig.savefig(f"view{view_idx+1}.png", bbox_inches='tight')
    plt.show()

In [None]:
# Bird's-eye view of the first sample's point cloud: column 0 on the horizontal axis, column 1 on the vertical.
first_sample_points = points[0].cpu()
fig, ax = plt.subplots(figsize=(20, 20))
ax.scatter(first_sample_points[:, 0], first_sample_points[:, 1], s=0.1)
ax.axis("equal")
ax.set_title("Birds eye view of 2 aggregated point cloud sweeps")
fig.savefig("points.png")

In [None]:
# Wide panorama of the first sample: camera images placed side by side in the model's camera_order.
views = [plt.imread(img_metas[0]["filename"][cam_idx]) for cam_idx in model.middle_encoder.camera_order]
im = np.concatenate(views, axis=1)

fig, ax = plt.subplots(figsize=(100, 10))
ax.imshow(im)
fig.savefig("wide_image.png", bbox_inches='tight')


In [None]:
# BEV scatter of the first sample overlaid with a 200 x 200 grid of 1-unit cells.
# Grid lines sit on half-integer coordinates, so each cell is centred on an integer.
grid_edges = np.arange(-200 / 2, 200 / 2 + 1, 1) - 0.5

fig, ax = plt.subplots(figsize=(20, 20))
ax.scatter(points[0][:, 0].cpu(), points[0][:, 1].cpu(), s=0.1)
ax.axis("equal")
ax.set_title("Birds eye view of 2 aggregated point cloud sweeps")
ax.set_xticks(grid_edges)
ax.set_yticks(grid_edges)
ax.grid(True, which='major')
fig.savefig("points_grid.png")


In [None]:
# BEV scatter of the first sample with a two-level grid:
# thin minor lines every VOXEL_SIZE, and thick red major lines every
# VOXEL_SIZE * WINDOW_SIZE_VOXELS (= 8 units).
# NOTE(review): 0.5 and 16 are presumably the voxel size and the SST window size in voxels.
# Confirm against cfg before relying on the major grid lining up with real windows.
VOXEL_SIZE = 0.5
WINDOW_SIZE_VOXELS = 16

fig, ax = plt.subplots(1, 1, figsize=(20, 20))
ax.scatter(points[0][:, 0].cpu(), points[0][:, 1].cpu(), s=0.1)
ax.set_title("Birds eye view of 2 aggregated point cloud sweeps")
for axis in (ax.xaxis, ax.yaxis):
    # Fresh locator/formatter instances per axis: matplotlib tickers must not be shared between axes.
    axis.set_major_locator(MultipleLocator(VOXEL_SIZE * WINDOW_SIZE_VOXELS))
    # Integer tick labels on both axes (previously only the x axis had this formatter).
    axis.set_major_formatter(FormatStrFormatter('%d'))
    axis.set_minor_locator(MultipleLocator(VOXEL_SIZE))
    axis.grid(True, 'minor')
    axis.grid(True, 'major', linewidth=2, color='r')
ax.axis("equal")
fig.savefig("points_grid_windows.png")

In [None]:
# Project voxel means onto the wide image. Below, column 0 is used as the batch index and
# columns 3 / 2 are plotted as image x / y.
# NOTE(review): the last argument was a hard-coded 2. It occupies the same position as
# batch_size in the middle_encoder call and matches the pickled batch size, so it is assumed
# to be the batch size — confirm.
voxel_mean_2d = model.middle_encoder.get_voxel_mean_2d_coords(voxel_feature_coors, voxel_mean, img_metas, batch_size)

In [None]:
# Overlay the first sample's projected voxel means on its wide image.
# voxel_mean_2d holds voxels from every sample (column 0 is the batch index), so keep only
# batch 0. Otherwise the other samples' points are drawn on top of sample 0's images.
views = [plt.imread(img_metas[0]["filename"][i]) for i in model.middle_encoder.camera_order]
im = np.concatenate(views, axis=1)
first_sample_voxels = voxel_mean_2d[voxel_mean_2d[:, 0] == 0].cpu()

fig, ax = plt.subplots(figsize=(100, 10))
ax.imshow(im)
ax.scatter(first_sample_voxels[:, 3], first_sample_voxels[:, 2], s=1, color="red")
fig.savefig("wide_image_with_points.png", bbox_inches='tight')

In [None]:
# Paint every projected voxel by the SST window (batch_win_inds_shift0) it falls in, one subplot per sample.
# NOTE(review): like the original per-voxel loop, this assumes the first voxel_mean_2d.shape[0]
# entries of sst_info["batch_win_inds_shift0"] are row-aligned with voxel_mean_2d.
# Confirm against the middle encoder.
batch_win_inds = sst_info["batch_win_inds_shift0"].cpu()
num_voxels = voxel_mean_2d.shape[0]
assert batch_win_inds.shape[0] >= num_voxels, (
    f"batch_win_inds_shift0 has {batch_win_inds.shape[0]} entries but voxel_mean_2d has {num_voxels} rows"
)

# One colour per unique window. return_inverse gives each entry's position in the sorted unique
# values, which replaces the per-voxel torch.where search (previously O(voxels * windows)).
unique_batch_win_inds, win_color_inds = torch.unique(batch_win_inds, return_inverse=True)
color_map = plt.get_cmap("gist_rainbow")
plot_colors = [
    color_map(i / len(unique_batch_win_inds))
    for i in range(len(unique_batch_win_inds))
]
# Shuffle so neighbouring windows get dissimilar colours; a fixed seed keeps the figure reproducible.
random.Random(0).shuffle(plot_colors)
voxel_colors = np.asarray(plot_colors)[win_color_inds[:num_voxels].numpy()]

voxel_mean_2d = voxel_mean_2d.cpu().int()
# squeeze=False keeps axs indexable per sample even when batch_size == 1.
fig, axs = plt.subplots(batch_size, 1, figsize=(60, batch_size * 10), squeeze=False)
axs = axs[:, 0]

for batch_idx, ax in enumerate(axs):
    views = [plt.imread(img_metas[batch_idx]["filename"][i]) for i in model.middle_encoder.camera_order]
    im = np.concatenate(views, axis=1)
    ax.imshow(im)

    # Draw this sample's voxels with a single scatter call.
    # The previous per-voxel ax.plot(x, y) had no marker, so each one-point line rendered nothing.
    in_batch = voxel_mean_2d[:, 0] == batch_idx
    ax.scatter(
        voxel_mean_2d[in_batch, 3],
        voxel_mean_2d[in_batch, 2],
        c=voxel_colors[in_batch.numpy()],
        s=4,
    )

    ax.set_title(f"Batch {batch_idx}")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    # NOTE(review): window_shape is not defined in any earlier cell of this notebook.
    # It must be set before running this cell (presumably from the middle encoder / config)
    # — TODO confirm its source.
    ax.xaxis.set_major_locator(MultipleLocator(window_shape[0]))
    ax.xaxis.set_major_formatter(FormatStrFormatter("%d"))
    ax.yaxis.set_major_locator(MultipleLocator(window_shape[1]))
    ax.yaxis.set_major_formatter(FormatStrFormatter("%d"))
    ax.xaxis.grid(True, "major", linewidth=2)
    ax.yaxis.grid(True, "major", linewidth=2)