In [2]:
import os
import torch
from torchsummary import summary

from det3d import torchie
from det3d.models import build_detector
from det3d.torchie.parallel import MegDataParallel
from det3d.torchie.trainer import load_checkpoint

from torch.utils.data import DataLoader
from det3d.datasets import build_dataset
from det3d.torchie.parallel import collate_kitti
from det3d.torchie.trainer.trainer import example_to_device

In [3]:
cfg = torchie.Config.fromfile("../examples/second/configs/config.py")

In [4]:
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
checkpoint_path = os.path.join("../exp_se_ssd_v1_8", "se-ssd-model.pth")
checkpoint = load_checkpoint(model, checkpoint_path, map_location="cpu")

In [None]:
model = MegDataParallel(model, device_ids=[0])

In [26]:
# data part
dataset = build_dataset(cfg.data.val)
batch_size = cfg.data.samples_per_gpu
num_workers = cfg.data.workers_per_gpu
data_loader = DataLoader(dataset, batch_size=batch_size, sampler=None, num_workers=num_workers, collate_fn=collate_kitti, shuffle=False,)

In [27]:
# data part
def get_dataset_ids(mode='val'):
    assert mode in ['test', 'val', 'trainval', 'val']
    id_file_path = "../det3d/datasets/ImageSets/{}.txt".format(mode)
    with open(id_file_path, 'r') as f:
        ids = f.readlines()
    ids = list(map(int, ids))
    return ids

In [40]:
# data part
kitti_dataset = data_loader.dataset         # det3d.datasets.kitti.kitti.KittiDataset
samples = []
valid_ids = get_dataset_ids('test')
for id in [6]:
    index = valid_ids.index(id)
    samples.append(kitti_dataset[index])
batch_samples = collate_kitti(samples)
example = example_to_device(batch_samples, device=torch.device('cuda'))

In [41]:
print([k for k, _ in example.items()])
annos = example['annos']
meta = example['metadata']

['metadata', 'points', 'voxels', 'shape', 'num_points', 'num_voxels', 'coordinates', 'anchors', 'calib', 'annos']


In [42]:
with torch.no_grad():
    # outputs: predicted results in lidar coord.
    outputs = model(example, return_loss=False, rescale=True)

In [43]:
meta, annos, outputs

([{'image_prefix': PosixPath('/home/liyue/workspace/datasets/KITTI'),
   'num_point_features': 4,
   'image_idx': 15,
   'image_shape': array([ 374, 1238], dtype=int32),
   'token': '15'}],
 array([{'boxes': array([[ 4.34868808,  2.78305102, -0.95996055,  1.66999996,  4.13999987,
          1.57000005,  1.72000003],
        [ 7.87614118, -4.70337004, -0.5614374 ,  1.05999994,  0.73000002,
          1.80999994,  1.25      ],
        [24.41451656, -2.33182605, -0.70579305,  0.83999997,  0.86000001,
          1.73000002, -1.48000002],
        [24.49862131, -3.17148418, -0.6651138 ,  0.89999998,  0.94999999,
          1.80999994, -1.46000004],
        [23.5540589 ,  1.91323242, -0.58234927,  0.92000002,  1.00999999,
          1.77999997, -1.53999996]]), 'names': array(['Car', 'Pedestrian', 'Pedestrian', 'Pedestrian', 'Pedestrian'],
       dtype='<U10')}], dtype=object),
 [{'box3d_lidar': tensor([[ 4.4854,  2.7864, -0.9088,  1.6959,  4.0749,  1.5271,  1.7115],
           [27.6373,  3.5021, -

In [44]:
results_dict = {}
for output in outputs:
    token = output["metadata"]["token"]
    for k, v in output.items():
        if k not in ["metadata", ]:
            output[k] = v.to(torch.device("cpu"))
    results_dict.update({token: output, })

In [45]:
results_dict

{'15': {'box3d_lidar': tensor([[ 4.4854,  2.7864, -0.9088,  1.6959,  4.0749,  1.5271,  1.7115],
          [27.6373,  3.5021, -0.6575,  1.5260,  3.1540,  1.4776,  4.7586],
          [34.0037, 26.0929, -1.3481,  1.7285,  3.9487,  1.5847, -0.4384],
          [51.1645, -1.3558, -0.7618,  1.6977,  3.8721,  1.7024,  2.5026],
          [47.5214, 35.7089, -2.2138,  1.6744,  3.5265,  1.5699,  0.6065]]),
  'scores': tensor([0.4420, 0.1753, 0.1482, 0.0935, 0.0630]),
  'label_preds': tensor([0, 0, 0, 0, 0]),
  'metadata': {'image_prefix': PosixPath('/home/liyue/workspace/datasets/KITTI'),
   'num_point_features': 4,
   'image_idx': 15,
   'image_shape': array([ 374, 1238], dtype=int32),
   'token': '15'}}}