In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell runs, so edits to project .py files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from glob import glob
import pickle
import sys
from typing import Dict, List
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from mmcv import Config
from torchpack.utils.config import configs
from mmdet3d.utils import recursive_eval
import random
from mmdet3d.datasets import build_dataloader, build_dataset
import torch
import mmcv
from mmdet.datasets.pipelines import Compose
from mmdet3d.datasets.pipelines import (
    LoadMultiViewImageFromFiles,
    LoadPointsFromFile,
    LoadAnnotations3D,
)
from mmcv.utils import build_from_cfg
from mmdet3d.datasets.builder import OBJECTSAMPLERS
from mmdet3d.core.bbox import box_np_ops
from tqdm import tqdm

In [3]:
# Move one directory up so the relative "data/..." and "notebooks/..." paths
# used by the following cells resolve.
# NOTE(review): this cell is not idempotent - running it a second time moves
# up one more directory. Restart the kernel before re-running the notebook.
%cd ..

/root/mmdet3d


In [4]:
# Input files, given relative to the current working directory.
tumtraf_i_dbinfos_path = "data/tumtraf-i-bevfusion/tumtraf_dbinfos_train.pkl"
default_config_path = "notebooks/resources/tumtraf-i_test_config.yaml"

# Threshold applied to "num_points_in_gt" when counting filtered samples.
filter_by_min_points = 5

# Fail early, with a readable message, when one of the inputs is missing.
for required_path, missing_message in (
    (tumtraf_i_dbinfos_path, "tumtraf-i dbinfos not found"),
    (default_config_path, "config not found"),
):
    assert os.path.exists(required_path), missing_message

# Read the ground-truth database infos.
# NOTE: pickle.load can execute arbitrary code; only open files you trust.
with open(tumtraf_i_dbinfos_path, "rb") as dbinfos_file:
    tumtraf_i_dbinfos = pickle.load(dbinfos_file)

# Load the YAML through torchpack, then wrap the evaluated result in an mmcv Config.
configs.load(default_config_path, recursive=True)
cfg = Config(recursive_eval(configs), filename=default_config_path)

In [5]:
# Per-class statistics over the ground-truth database entries.
# The textual report at the bottom is currently disabled, so the values
# computed here are not displayed.
for key, values in tumtraf_i_dbinfos.items():
    num_points = [entry["num_points_in_gt"] for entry in values]
    distances = [entry["distance"] for entry in values]

    # Histogram of the "difficulty" field.
    difficulty_counts = defaultdict(int)
    for entry in values:
        difficulty_counts[entry["difficulty"]] += 1

    # Entries that fall below the minimum-points threshold.
    filtered_num_sample = sum(
        1 for point_count in num_points if point_count < filter_by_min_points
    )

    avg_num_points, min_num_points, max_num_points = (
        np.mean(num_points),
        np.min(num_points),
        np.max(num_points),
    )
    avg_distance, min_distance, max_distance = (
        np.mean(distances),
        np.min(distances),
        np.max(distances),
    )

    # Turn the histogram into a list of (difficulty, count) pairs ordered by difficulty.
    difficulty_counts = sorted(difficulty_counts.items(), key=lambda pair: pair[0])

    # Disabled report:
    # print("="*60)
    # print(key)
    # print("="*60)
    # print(f"{'samples':<12} - total: {len(values):<8} - {'filtered: ' + str(filtered_num_sample):<15} - {'kept: ' + str(len(values) - filtered_num_sample):<10}")
    # print(f"{'num points':<12} - avg: {avg_num_points:<10.3f} - min: {min_num_points:<10} - max: {max_num_points:<10}")
    # print(f"{'distance':<12} - avg: {avg_distance:<10.3f} - min: {min_distance:<10.3f} - max: {max_distance:<10.3f}")
    # diff_string = " - ".join(f"{x}: {y:<12}" for x, y in dict(difficulty_counts).items())
    # print(f"{'difficulty':<12} - {diff_string}")

In [6]:
# Seed the Python, NumPy and PyTorch random number generators from the config
# (in that order) before the dataset is built.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(cfg.seed)

dataset = build_dataset(cfg.data.train)

In [8]:


# Class names of the dataset. An integer label is used as the index of its
# class in the per-class count arrays kept by CustomObjectPaste.
CLASSES = [
    "CAR",
    "TRAILER",
    "TRUCK",
    "VAN",
    "PEDESTRIAN",
    "BUS",
    "MOTORCYCLE",
    "BICYCLE",
    "EMERGENCY_VEHICLE",
]


class CustomObjectPaste:
    """Ground-truth paste augmentation that also records object-count statistics.

    Each call pastes objects drawn from the database sampler into one frame
    (3D boxes, labels and points; optionally image and 2D boxes) and adds the
    frame's object counts, before and after pasting, to running totals.

    Args:
        db_sampler (dict): Sampler config, built through the ``OBJECTSAMPLERS``
            registry. If it has no ``"type"`` key, ``"DataBaseSampler"`` is
            written into the passed config.
        sample_2d (bool): Also update ``data["img"]`` and ``data["gt_bboxes"]``.
        stop_epoch (int | None): Stored on the instance; ``__call__`` does not
            enforce it.

    Attributes:
        total_prev (int): Sum of per-frame box counts before pasting.
        total_after (int): Sum of per-frame box counts after pasting.
        prev_total_class_counts (np.ndarray): Per-class label counts before
            pasting, indexed like ``CLASSES``.
        after_total_class_counts (np.ndarray): Per-class label counts after
            pasting, indexed like ``CLASSES``.
    """

    def __init__(self, db_sampler, sample_2d=False, stop_epoch=None):
        self.sampler_cfg = db_sampler
        self.sample_2d = sample_2d
        if "type" not in db_sampler.keys():
            db_sampler["type"] = "DataBaseSampler"
        self.db_sampler = build_from_cfg(db_sampler, OBJECTSAMPLERS)
        self.epoch = -1
        self.stop_epoch = stop_epoch

        self.total_prev = 0
        self.total_after = 0

        self.prev_total_class_counts = np.zeros(len(CLASSES), dtype=np.int64)
        self.after_total_class_counts = np.zeros(len(CLASSES), dtype=np.int64)

    def set_epoch(self, epoch):
        """Store the current epoch number on the instance."""
        self.epoch = epoch

    @staticmethod
    def _count_labels(labels):
        """Count how many labels fall on each class index.

        Labels outside ``[0, len(CLASSES))`` are ignored, so a negative label
        is never attributed to a class through negative indexing.

        Args:
            labels (array-like): Integer class labels.
        Returns:
            np.ndarray: Integer array of length ``len(CLASSES)``.
        """
        labels = np.asarray(labels).reshape(-1).astype(np.int64)
        in_range = (labels >= 0) & (labels < len(CLASSES))
        return np.bincount(labels[in_range], minlength=len(CLASSES))

    @staticmethod
    def remove_points_in_boxes(points, boxes):
        """Remove the points in the sampled bounding boxes.
        Args:
            points (:obj:`BasePoints`): Input point cloud array.
            boxes (np.ndarray): Sampled ground truth boxes.
        Returns:
            np.ndarray: Points with those in the boxes removed.
        """
        masks = box_np_ops.points_in_rbbox(points.coord.numpy(), boxes)
        points = points[np.logical_not(masks.any(-1))]
        return points

    def __call__(self, data):
        """Paste sampled ground-truth objects into ``data`` and record counts.

        Args:
            data (dict): Result dict from the loading pipeline. Reads
                'gt_bboxes_3d', 'gt_labels_3d' and 'points' (plus 'img' and
                'gt_bboxes' when ``sample_2d`` is set).
        Returns:
            dict: The same dict with 'points', 'gt_bboxes_3d' and
                'gt_labels_3d' updated ('gt_bboxes' and 'img' too when
                ``sample_2d`` is set and objects were sampled).
        """
        # The stop_epoch early exit is intentionally not applied here.
        gt_bboxes_3d = data["gt_bboxes_3d"]
        gt_labels_3d = data["gt_labels_3d"]
        points = data["points"]

        before_count = gt_bboxes_3d.tensor.numpy().shape[0]
        before_class_counts = self._count_labels(gt_labels_3d)
        # When the sampler adds nothing, the frame is unchanged.
        after_count = before_count
        after_class_counts = before_class_counts

        if self.sample_2d:
            img = data["img"]
            gt_bboxes_2d = data["gt_bboxes"]
            # Assume for now 3D & 2D bboxes are the same
            sampled_dict = self.db_sampler.sample_all(
                gt_bboxes_3d.tensor.numpy(),
                gt_labels_3d,
                gt_bboxes_2d=gt_bboxes_2d,
                img=img,
            )
        else:
            sampled_dict = self.db_sampler.sample_all(
                gt_bboxes_3d.tensor.numpy(), gt_labels_3d, img=None
            )

        if sampled_dict is not None:
            sampled_gt_bboxes_3d = sampled_dict["gt_bboxes_3d"]
            sampled_points = sampled_dict["points"]
            sampled_gt_labels = sampled_dict["gt_labels_3d"]

            gt_labels_3d = np.concatenate([gt_labels_3d, sampled_gt_labels], axis=0)
            gt_bboxes_3d = gt_bboxes_3d.new_box(
                np.concatenate([gt_bboxes_3d.tensor.numpy(), sampled_gt_bboxes_3d])
            )

            after_count = gt_bboxes_3d.tensor.numpy().shape[0]
            after_class_counts = self._count_labels(gt_labels_3d)

            # Drop the original points inside the pasted boxes, then prepend
            # the sampled object points.
            points = self.remove_points_in_boxes(points, sampled_gt_bboxes_3d)
            points = points.cat([sampled_points, points])

            if self.sample_2d:
                sampled_gt_bboxes_2d = sampled_dict["gt_bboxes_2d"]
                gt_bboxes_2d = np.concatenate([gt_bboxes_2d, sampled_gt_bboxes_2d]).astype(
                    np.float32
                )

                data["gt_bboxes"] = gt_bboxes_2d
                data["img"] = sampled_dict["img"]

        data["gt_bboxes_3d"] = gt_bboxes_3d
        # np.long was removed from NumPy 1.24; int64 is the explicit equivalent.
        data["gt_labels_3d"] = gt_labels_3d.astype(np.int64)
        data["points"] = points

        # Accumulate for every frame, changed or not, so the totals describe
        # the whole dataset and can be divided by the number of frames.
        self.total_prev += before_count
        self.total_after += after_count
        self.prev_total_class_counts += before_class_counts
        self.after_total_class_counts += after_class_counts

        return data


# Per-class values handed to the database sampler as "sample_groups".
# NOTE(review): presumably the number of objects of each class to aim for per
# frame - confirm against the sampler implementation.
custom_sample_group = dict(
    CAR=3,
    TRAILER=2,
    TRUCK=4,
    VAN=5,
    PEDESTRIAN=7,
    BUS=2,
    MOTORCYCLE=5,
    BICYCLE=5,
    EMERGENCY_VEHICLE=3,
)

# Step 3 of the training pipeline config is the one carrying the db_sampler.
db_sampler_cfg = cfg.data.train.dataset.pipeline[3].db_sampler
db_sampler_cfg.sample_groups = custom_sample_group
print(db_sampler_cfg.sample_groups)

# Loading-only pipeline: images, points and 3D annotations.
loading_steps = [
    LoadMultiViewImageFromFiles(),
    LoadPointsFromFile("LIDAR", 5, 4),
    LoadAnnotations3D(),
]
pipeline = Compose(loading_steps)

gtp = CustomObjectPaste(db_sampler_cfg, stop_epoch=99)
train_set = dataset.dataset
total_len = len(train_set)

# Push every frame through loading and the paste augmentation so that `gtp`
# collects its statistics.
for idx in tqdm(range(total_len)):
    input_dict = train_set.get_data_info(idx)
    train_set.pre_pipeline(input_dict)
    input_dict = gtp(pipeline(input_dict))

{'CAR': 3, 'TRAILER': 2, 'TRUCK': 4, 'VAN': 5, 'PEDESTRIAN': 7, 'BUS': 2, 'MOTORCYCLE': 5, 'BICYCLE': 5, 'EMERGENCY_VEHICLE': 3}


100%|██████████| 1920/1920 [01:28<00:00, 21.58it/s]


In [55]:

# Report of the object counts collected by `gtp` during the dataset sweep.
LABEL_WIDTH = 10   # width of the leading row-label column
COLUMN_WIDTH = 12  # width of each per-class column

# Guard the per-frame averages against an empty dataset.
frame_count = max(total_len, 1)
added_per_class = gtp.after_total_class_counts - gtp.prev_total_class_counts


def table_row(label, cells):
    """Return one report line: a left-aligned label followed by left-aligned cells."""
    return f"{label:<{LABEL_WIDTH}}" + "".join(
        f"{cell:<{COLUMN_WIDTH}}" for cell in cells
    )


# The separator is as long as the header line, including any class name that
# overflows its column.
header_rule = "=" * len(table_row("total", CLASSES))

print("initial total", gtp.total_prev, "avg", f"{gtp.total_prev / frame_count:.3f}")
print("final total", gtp.total_after, "avg", f"{gtp.total_after / frame_count:.3f}")

print()

# Absolute counts per class.
print(table_row("total", CLASSES))
print(header_rule)
print(table_row("initial", gtp.prev_total_class_counts))
print(table_row("augment", (f"+{delta}" for delta in added_per_class)))
print(table_row("final", gtp.after_total_class_counts))

print()

# Average counts per frame.
print(table_row("avg/frame", CLASSES))
print(header_rule)
print(
    table_row(
        "initial",
        (f"{count / frame_count:.3f}" for count in gtp.prev_total_class_counts),
    )
)
print(
    table_row(
        "final",
        (f"{count / frame_count:.3f}" for count in gtp.after_total_class_counts),
    )
)

inital total 30202 avg 15.730
final total 61378 avg 31.968

total     CAR         TRAILER     TRUCK       VAN         PEDESTRIAN  BUS         MOTORCYCLE  BICYCLE     EMERGENCY_VEHICLE
intial    18213       2454        2128        3508        2013        714         526         460         186         
augment   +120        +901        +3319       +3704       +7555       +2138       +3345       +6366       +3728       
final     18333       3355        5447        7212        9568        2852        3871        6826        3914        

avg/frame CAR         TRAILER     TRUCK       VAN         PEDESTRIAN  BUS         MOTORCYCLE  BICYCLE     EMERGENCY_VEHICLE
intial    9.486       1.278       1.108       1.827       1.048       0.372       0.274       0.240       0.097       
final     9.548       1.747       2.837       3.756       4.983       1.485       2.016       3.555       2.039       
