Custom mapper for normalizing bounding box #4824

Open
suryasid09 opened this issue Feb 27, 2023 · 1 comment
suryasid09 commented Feb 27, 2023

Instructions To Reproduce the Issue:

I am trying to write a custom DatasetMapper, following the detectron2 tutorial. The mapper is responsible for normalizing the bounding-box coordinates of my custom dataset. I am not sure where exactly detectron2 performs this normalization, or whether it does so at all; it would be helpful if you could point that out as well. In the meantime, I have written my own code for it.
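For reference, normalizing an absolute box just divides the x coordinates by the image width and the y coordinates by the image height. The plain-Python sketch below (the helper names are hypothetical) converts an absolute [x, y, w, h] box to normalized corner form and back. Values produced this way are relative, so they should not be labeled BoxMode.XYXY_ABS, which denotes absolute pixel coordinates.

```python
def xywh_abs_to_xyxy_rel(box, width, height):
    """Convert an absolute [x, y, w, h] box to normalized [x0, y0, x1, y1] in [0, 1]."""
    x, y, w, h = box
    # Compute the corners first, then divide by the image size
    return [x / width, y / height, (x + w) / width, (y + h) / height]


def xyxy_rel_to_abs(box, width, height):
    """Undo the normalization: scale relative corners back to pixel coordinates."""
    x0, y0, x1, y1 = box
    return [x0 * width, y0 * height, x1 * width, y1 * height]


rel = xywh_abs_to_xyxy_rel([100, 50, 200, 100], width=400, height=200)
print(rel)  # [0.25, 0.25, 0.75, 0.75]
print(xyxy_rel_to_abs(rel, width=400, height=200))  # [100.0, 50.0, 300.0, 150.0]
```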

  1. Full runnable code or full changes you made:
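<
import copy

import cv2
import torch
from detectron2.data import DatasetCatalog, build_detection_train_loader
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.structures import BoxMode


# The enclosing function is not shown in the original snippet; this name is illustrative.
def build_train_loader(cfg):
    dataset_dict = DatasetCatalog.get(cfg.DATASETS.TRAIN[0])

    def mapper(dataset_dict):
        dataset_dict = copy.deepcopy(dataset_dict)
        # Read the image (OpenCV loads BGR) and convert it to RGB
        image = cv2.imread(dataset_dict["file_name"], cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        height, width = image.shape[:2]
        # Resize the image to 800x800 (the returned transform is not used below)
        auginput = T.AugInput(image)
        transform = T.Resize((800, 800))(auginput)
        image = torch.from_numpy(auginput.image.transpose(2, 0, 1))  # HWC -> CHW
        annotations = []
        for annotation in dataset_dict["annotations"]:
            bbox = annotation["bbox"]
            # Divide by the original image size, then convert XYWH to XYXY
            bbox = [bbox[0] / width, bbox[1] / height, bbox[2] / width, bbox[3] / height]
            bbox = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]
            annotations.append({
                "bbox": bbox,
                "bbox_mode": BoxMode.XYXY_ABS,
                "category_id": annotation["category_id"]
            })
        # The returned dict contains only "image" and "instances"
        return {
            "image": image,
            "instances": utils.annotations_to_instances(annotations, image.shape[:2])
        }

    return build_detection_train_loader(cfg, mapper=mapper)
>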
  2. Full logs or other relevant observations:
<Traceback (most recent call last):
  File "/home/jovyan/.local/lib/python3.8/site-packages/detectron2/engine/train_loop.py", line 134, in train
    self.run_step()
  File "/home/jovyan/.local/lib/python3.8/site-packages/detectron2/engine/train_loop.py", line 222, in run_step
    data = next(self._data_loader_iter)
  File "/home/jovyan/.local/lib/python3.8/site-packages/detectron2/data/common.py", line 180, in __iter__
    w, h = d["width"], d["height"]
KeyError: 'width'
[02/27 13:52:49 d2.engine.hooks]: Total training time: 0:00:00 (0:00:00 on hooks)
[02/27 13:52:49 d2.utils.events]:  iter: 1    lr: N/A  max_mem: 199M
Traceback (most recent call last):
  File "main.py", line 85, in <module>
    launch(
  File "/home/jovyan/.local/lib/python3.8/site-packages/detectron2/engine/launch.py", line 62, in launch
    main_func(*args)
  File "main.py", line 80, in main
    return trainer.train()
  File "/home/jovyan/thesis_s2577712/DeFRCN_voc_format/thesis-pascal_voc_format/defrcn/engine/defaults.py", line 387, in train
    super().train(self.start_iter, self.max_iter)
  File "/home/jovyan/.local/lib/python3.8/site-packages/detectron2/engine/train_loop.py", line 134, in train
    self.run_step()
  File "/home/jovyan/.local/lib/python3.8/site-packages/detectron2/engine/train_loop.py", line 222, in run_step
    data = next(self._data_loader_iter)
  File "/home/jovyan/.local/lib/python3.8/site-packages/detectron2/data/common.py", line 180, in __iter__
    w, h = d["width"], d["height"]
KeyError: 'width'>
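The KeyError is raised by AspectRatioGroupedDataset (detectron2/data/common.py in the traceback), which reads "width" and "height" from every dict the mapper returns so that it can batch landscape and portrait images separately. The mapper above returns only "image" and "instances". The plain-Python sketch below, which does not import detectron2, shows the bookkeeping a fixed mapper would need under these assumptions: the function names are hypothetical, the two mode constants mirror the integer values of BoxMode.XYXY_ABS and BoxMode.XYWH_ABS, and boxes stay in absolute pixels but are scaled by the same factors as the resize.

```python
import copy

# Assumed to mirror detectron2's BoxMode.XYXY_ABS (0) and BoxMode.XYWH_ABS (1),
# so the sketch runs without detectron2 installed
XYXY_ABS, XYWH_ABS = 0, 1


def to_xyxy_abs(bbox, mode):
    """Return an absolute [x0, y0, x1, y1] box for the two absolute modes."""
    if mode == XYXY_ABS:
        return list(bbox)
    if mode == XYWH_ABS:
        x, y, w, h = bbox
        return [x, y, x + w, y + h]
    raise ValueError(f"unsupported bbox_mode: {mode}")


def map_record(dataset_dict, new_h, new_w):
    """Hypothetical mapper core: keep the metadata keys, rescale boxes to the resized image."""
    record = copy.deepcopy(dataset_dict)
    h, w = record["height"], record["width"]  # original image size
    sx, sy = new_w / w, new_h / h  # the same factors a resize to (new_h, new_w) applies
    annos = []
    for anno in record.pop("annotations", []):
        x0, y0, x1, y1 = to_xyxy_abs(anno["bbox"], anno["bbox_mode"])
        annos.append({
            "bbox": [x0 * sx, y0 * sy, x1 * sx, y1 * sy],  # still absolute pixels
            "bbox_mode": XYXY_ABS,
            "category_id": anno["category_id"],
        })
    # A real mapper would turn these into Instances and add the "image" tensor
    record["annotations"] = annos
    return record  # "height" and "width" are still present


d = {"file_name": "img.jpg", "height": 200, "width": 400,
     "annotations": [{"bbox": [10, 20, 30, 40], "bbox_mode": XYWH_ABS, "category_id": 3}]}
out = map_record(d, new_h=800, new_w=800)
# Landscape/portrait check of the kind the aspect-ratio grouping performs
print(0 if out["width"] > out["height"] else 1)  # 0
print(out["annotations"][0]["bbox"])  # [20.0, 80.0, 80.0, 240.0]
```

Alternatively, cfg.DATALOADER.ASPECT_RATIO_GROUPING = False should skip the grouping step that reads these keys, though keeping "height" and "width" in the mapper output is the more conventional fix.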

Environment:

----------------------  ------------------------------------------------------------------------
sys.platform            linux
Python                  3.8.10 (default, Nov 14 2022, 12:59:47) [GCC 9.4.0]
numpy                   1.23.1
detectron2              0.3 @/home/jovyan/.local/lib/python3.8/site-packages/detectron2
Compiler                GCC 7.3
CUDA compiler           CUDA 10.1
detectron2 arch flags   3.7, 5.0, 5.2, 6.0, 6.1, 7.0, 7.5
DETECTRON2_ENV_MODULE   <not set>
PyTorch                 1.6.0+cu101 @/home/jovyan/.local/lib/python3.8/site-packages/torch
PyTorch debug build     False
GPU available           True
GPU 0,1                 Tesla T4 (arch=7.5)
CUDA_HOME               /usr/local/cuda
Pillow                  9.4.0
torchvision             0.7.0+cu101 @/home/jovyan/.local/lib/python3.8/site-packages/torchvision
torchvision arch flags  3.5, 5.0, 6.0, 7.0, 7.5
fvcore                  0.1.5.post20221221
cv2                     4.7.0
----------------------  ------------------------------------------------------------------------
PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2019.0.5 Product Build 20190808 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.5.0 (Git Hash e2ac1fac44c5078ca927cb9b90e1b3066a0b2ed0)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.1
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75
  - CuDNN 7.6.3
  - Magma 2.5.2
  - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_VULKAN_WRAPPER -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF, 

@vkrishnamurthy11

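Usually, detectron2 takes care of rescaling the boxes itself when it computes the loss and other metrics, so a mapper should return boxes in absolute pixel coordinates rather than normalized ones (utils.annotations_to_instances converts every box to BoxMode.XYXY_ABS). Predicted boxes are rescaled back to the original image size in GeneralizedRCNN._postprocess (detectron2/modeling/meta_arch/rcnn.py):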

def _postprocess(instances, batched_inputs: List[Dict[str, torch.Tensor]], image_sizes):
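That method reads the requested output size from each input dict's "height" and "width" (falling back to the size of the resized image) and calls detector_postprocess, which scales the predicted boxes by the ratio of the output size to the resized input size and then clips them to the image. A numpy sketch of just that arithmetic, with a hypothetical function name and not the library's implementation:

```python
import numpy as np


def rescale_boxes(boxes, resized_hw, output_hw):
    """Map XYXY boxes predicted on the resized image back to the output (original) size."""
    boxes = np.asarray(boxes, dtype=np.float64)
    (rh, rw), (oh, ow) = resized_hw, output_hw
    # x coordinates scale with the width ratio, y coordinates with the height ratio
    scale = np.array([ow / rw, oh / rh, ow / rw, oh / rh])
    out = boxes * scale
    # Clip to the output image bounds
    out[:, 0::2] = out[:, 0::2].clip(0, ow)
    out[:, 1::2] = out[:, 1::2].clip(0, oh)
    return out


pred = [[20.0, 80.0, 80.0, 240.0]]  # predicted on an 800x800 input
print(rescale_boxes(pred, resized_hw=(800, 800), output_hw=(200, 400)))
```

Because the original "height" and "width" drive this step, they are worth keeping in whatever dicts a custom mapper produces.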
