In [None]:
%matplotlib inline
from mmdet3d.datasets import build_dataset
from tools.misc.browse_dataset import build_data_cfg
from mmdet3d.models import apply_3d_transformation
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, FormatStrFormatter, FixedLocator
import copy
import torch
import numpy as np
from mmcv import Config, DictAction
from mmdet3d.models import build_model
from mmdet3d.ops.voxel.voxelize import voxelization
from mmdet3d.ops import DynamicScatter
from mmdet3d.ops import (
    flat2window_v2,
    window2flat_v2,
    get_inner_win_inds,
    make_continuous_inds,
    get_flat2win_inds_v2,
    get_window_coors,
)
from mmdet3d.models.detectors.shared_fusion_net import SharedFusionNet
import time
import pickle
import random

# Run on the first GPU when available. The image, the points and the model are all
# moved to this same device below, so a CPU fallback does not mix devices.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# Paths collected in one place so they are easy to change outside the original container.
FORWARD_TRAIN_INPUT_PATH = "/shared-sst/forward_train_input_batch_size_2.pkl"
CONFIG_PATH = "configs/shared_sst/shared_fusion_lidar_detection_debug_config.py"

# SECURITY: pickle.load can execute arbitrary code; only load files you produced yourself.
with open(FORWARD_TRAIN_INPUT_PATH, "rb") as f:
    forward_train_input = pickle.load(f)

# NOTE(review): this unpacking relies on the mapping's insertion order being
# (points, img, img_metas, gt_bboxes_3d, gt_labels_3d) — confirm against the dumping code.
points, img, img_metas, gt_bboxes_3d, gt_labels_3d = forward_train_input.values()
img = img.to(device).float()
# Keep points on the same device as the image/model (no-op if already there).
points = [p.to(device).float() for p in points]

cfg = Config.fromfile(CONFIG_PATH)
model = build_model(cfg.model, train_cfg=cfg.get("train_cfg"), test_cfg=cfg.get("test_cfg"))
# model.init_weights()  # NOTE(review): explicit weight init is skipped here — confirm intended.
# Use `device` (not an unconditional .cuda()) so the CPU fallback above stays consistent.
model = model.to(device)

In [None]:
def get_patch_coors(unflattened_patches, patch_size):
    """Build voxel-style coordinates for every image patch.

    Args:
        unflattened_patches (torch.Tensor): Patch grid whose first three
            dimensions are read as (batch, height, width); any trailing
            dimensions are ignored. NOTE(review): presumably (B, H, W, C) as
            produced by ``unflatten`` in the caller — confirm.
        patch_size (int): Side length of a square patch in pixels.

    Returns:
        torch.Tensor: float32 tensor of shape (B * H * W, 4) on the same device
        as ``unflattened_patches``. Columns are [batch_idx, 0, y, x], where y/x
        are ``index * patch_size + patch_size // 2`` (the patch center in
        pixels). Rows are ordered batch-major, then row, then column, i.e. the
        row for (b, h, w) is at ``b * H * W + h * W + w``.
    """
    device = unflattened_patches.device
    batch_size, height, width = unflattened_patches.shape[:3]
    num_patches = batch_size * height * width

    # Build all index columns directly on the target device (no numpy, no
    # per-batch loop, no host->device copies). The repeat / repeat_interleave
    # combination reproduces row-major (b, h, w) flattening order.
    col_idx = torch.arange(width, device=device).repeat(batch_size * height)
    row_idx = torch.arange(height, device=device).repeat_interleave(width).repeat(batch_size)
    batch_idx = torch.arange(batch_size, device=device).repeat_interleave(height * width)

    # float32 output; column 1 (z) is left at zero.
    patch_coors = torch.zeros((num_patches, 4), device=device)
    patch_coors[:, 0] = batch_idx
    # Scale grid indices to pixel coordinates of the patch centers.
    patch_coors[:, 2] = row_idx * patch_size + patch_size // 2
    patch_coors[:, 3] = col_idx * patch_size + patch_size // 2
    return patch_coors

In [None]:
# Voxelize point cloud
voxels, coors = model.voxelize(points)  # [Batch, Z, Y, X]
# Batch size = number of per-sample point clouds fed to voxelize. The previous
# `coors[-1, 0].item() + 1` assumed the last voxel row belongs to the last sample,
# which undercounts if that sample produces no voxels.
# NOTE(review): assumes voxelize assigns batch index by list position — confirm.
batch_size = len(points)
voxel_features, voxel_feature_coors = model.voxel_encoder(voxels, coors)
voxel_mean, _ = model.voxel_encoder.cluster_scatter(voxels, coors)

# Patchify wide image: cameras (in the middle encoder's order) are concatenated along dim 3.
# NOTE(review): assumes img is (B, N_cams, C, H, W) so dim 3 of img[:, i] is width — confirm.
img_wide = torch.cat([img[:, i] for i in model.middle_encoder.camera_order], dim=3)
patches = model.patch_embedder(img_wide)

# Convert patches to same format as voxels: flat features + per-row coordinates.
patch_tokens, patch_grid_hw = patches[0], patches[1]
unflattened_patches = patch_tokens.unflatten(1, patch_grid_hw)
patch_features = patch_tokens.flatten(0, 1)
patch_coors = get_patch_coors(unflattened_patches, model.patch_embedder.projection.kernel_size[0])


sst_info = model.middle_encoder(
    voxel_features,
    voxel_feature_coors,
    voxel_mean,
    patch_features,
    patch_coors,
    img_metas,
    batch_size,
)

[batch_canvas] = model.backbone(sst_info)