label value in CustomDataset #3629

Closed
hubutui opened this issue Aug 27, 2020 · 11 comments · Fixed by #3221

@hubutui
Contributor

hubutui commented Aug 27, 2020

Describe the bug
Hi, I'm new to object detection, though I have some experience in classification and segmentation. There we usually regard label value 0 as background. However, it seems mmdetection regards label value 0 as the first foreground class? I tested with the KittiTiny dataset, and here are my scripts.

Reproduction

  1. convert KittiTiny to mmdet's middle format
#!/usr/bin/env python
# Convert KittiTiny annotations to mmdet's middle format.

import os.path as osp
import pickle

import mmcv
import numpy as np

if __name__ == "__main__":
    CLASSES = ('Car', 'Pedestrian', 'Cyclist')
    # 0-based foreground labels: Car=0, Pedestrian=1, Cyclist=2
    cat2label = {k: i for i, k in enumerate(CLASSES)}
    image_list = mmcv.list_from_file("data/kitti_tiny/train.txt")
    data_infos = []
    for image_id in image_list:
        filename = osp.join("data/kitti_tiny/training/image_2", image_id + ".jpeg")
        image = mmcv.imread(filename)
        height, width = image.shape[:2]

        data_info = dict(filename=f'{image_id}.jpeg', width=width, height=height)

        # load annotations
        lines = mmcv.list_from_file(osp.join("data/kitti_tiny/training/label_2", f'{image_id}.txt'))

        content = [line.strip().split(' ') for line in lines]
        bbox_names = [x[0] for x in content]
        bboxes = [[float(info) for info in x[4:8]] for x in content]

        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_labels_ignore = []

        # boxes whose class is not in CLASSES (e.g. 'DontCare') go to the ignore lists
        for bbox_name, bbox in zip(bbox_names, bboxes):
            if bbox_name in cat2label:
                gt_labels.append(cat2label[bbox_name])
                gt_bboxes.append(bbox)
            else:
                gt_labels_ignore.append(-1)
                gt_bboxes_ignore.append(bbox)

        # np.int64 is used explicitly; the np.long alias was deprecated in NumPy 1.20
        data_anno = dict(
            bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
            labels=np.array(gt_labels, dtype=np.int64),
            bboxes_ignore=np.array(gt_bboxes_ignore,
                                   dtype=np.float32).reshape(-1, 4),
            labels_ignore=np.array(gt_labels_ignore, dtype=np.int64))

        data_info.update(ann=data_anno)
        data_infos.append(data_info)
    with open("kitti.pkl", 'wb') as f:
        pickle.dump(data_infos, f)
  2. config kitti_detection.py
# modified from configs/_base_/datasets/coco_detection.py
# uses the same dataset pipeline as coco_detection
# dataset_type is CustomDataset, wrapped in ClassBalancedDataset for the training set
# the evaluation metric for CustomDataset is mAP
dataset_type = 'CustomDataset'
data_root = 'data/kitti/'
classes = data_root + 'classes.txt'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    classes=classes,
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        classes=classes,
        type='ClassBalancedDataset',
        oversample_thr=1e-3,
        dataset=dict(
            classes=classes,
            type=dataset_type,
            ann_file=data_root + 'train.pickle',
            img_prefix=data_root + 'train',
            pipeline=train_pipeline
        )
    ),
    val=dict(
        classes=classes,
        type=dataset_type,
        ann_file=data_root + 'train.pickle',
        img_prefix=data_root + 'train',
        pipeline=test_pipeline
    ),
    test=dict(
        classes=classes,
        type=dataset_type,
        ann_file=data_root + 'train.pickle',
        img_prefix=data_root + 'train',
        pipeline=test_pipeline
    )
)
evaluation = dict(interval=1, metric='mAP')

and classes.txt

Car
Pedestrian
Cyclist
  3. the faster rcnn config faster_rcnn_r50_fpn_1x_kitti.py
_base_ = [
    '../_base_/models/faster_rcnn_r50_fpn.py',
    'kitti_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
    roi_head=dict(
        bbox_head=dict(
            num_classes=3
        )
    )
)
  4. I added these lines to tools/train.py after the datasets are built:
# collect the set of label values present in the training annotations
x = set()
for item in datasets[0].dataset.data_infos:
    for label in item['ann']['labels']:
        x.add(label)
print(x)
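
As a variant of the check in step 4, the following sketch counts how often each label value occurs instead of only collecting the set. It assumes annotations in the middle format from step 1 (a list of dicts with `ann['labels']`). The helper name `label_histogram` and the sample data are hypothetical, and plain lists stand in for the NumPy arrays.

```python
from collections import Counter


def label_histogram(data_infos):
    """Count occurrences of each label value across all images."""
    counts = Counter()
    for info in data_infos:
        # int() also accepts NumPy integer scalars
        counts.update(int(label) for label in info['ann']['labels'])
    return counts


# Hypothetical sample in the middle format (lists instead of arrays).
sample = [
    dict(filename='a.jpeg', ann=dict(labels=[0, 0, 2])),
    dict(filename='b.jpeg', ann=dict(labels=[1])),
]
hist = label_histogram(sample)
print(sorted(hist.items()))  # [(0, 2), (1, 1), (2, 1)]
```

A histogram also makes it easy to spot a class that never appears, which a plain set would hide only if you did not know how many classes to expect.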

Environment

  1. Please run python mmdet/utils/collect_env.py to collect necessary environment information and paste it here.
sys.platform: linux
Python: 3.7.7 (default, May  7 2020, 21:25:33) [GCC 7.3.0]
CUDA available: True
CUDA_HOME: /cm/shared/apps/cuda10.1/toolkit/10.1.243
NVCC: Cuda compilation tools, release 10.1, V10.1.243
GPU 0: GeForce RTX 2080 Ti
GCC: gcc (GCC) 4.8.5 20150623 (Red Hat 4.8.5-39)
PyTorch: 1.6.0
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.1 Product Build 20200208 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.5.0 (Git Hash e2ac1fac44c5078ca927cb9b90e1b3066a0b2ed0)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.1
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
  - CuDNN 7.6.3
  - Magma 2.5.2
  - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_VULKAN_WRAPPER -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF, 

TorchVision: 0.7.0
OpenCV: 4.4.0
MMCV: 1.1.1
MMDetection: 2.3.0+unknown
MMDetection Compiler: GCC 7.3
MMDetection CUDA Compiler: 10.1
  2. You may add additional information that may be helpful for locating the problem, such as
    • conda install -c pytorch pytorch cudatoolkit=10.1.243 torchvision

Outputs

{0, 1, 2}

So the label values are 0, 1, 2 for the 3 classes; 0 is not background.

@hubutui
Contributor Author

hubutui commented Aug 27, 2020

I checked the Codebase Conventions:

In MMDetection 2.0, label “K” means background, and labels [0, K-1] correspond to the K = num_categories object categories.

So, for CustomDataset, we need to regenerate the middle format and make sure the label values are in [0, K-1], am I right?
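
A minimal sketch of such a check, assuming the middle format described above (a list of dicts with `filename` and `ann['labels']`). The function name `check_labels` and the sample annotations are hypothetical and not part of MMDetection.

```python
def check_labels(data_infos, num_classes):
    """Return (filename, label) pairs whose label is outside [0, num_classes - 1]."""
    bad = []
    for info in data_infos:
        for label in info['ann']['labels']:
            if not 0 <= int(label) < num_classes:
                bad.append((info['filename'], int(label)))
    return bad


# Hypothetical annotations: K = 3 classes, so valid labels are 0, 1, 2.
infos = [
    dict(filename='ok.jpeg', ann=dict(labels=[0, 2])),
    # 3 == K is the background index, so it must not appear in the annotations
    dict(filename='bad.jpeg', ann=dict(labels=[1, 3])),
]
print(check_labels(infos, num_classes=3))  # [('bad.jpeg', 3)]
```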

@RyanXLi
Contributor

RyanXLi commented Aug 31, 2020

Yes. As you have quoted, in MMDetection 2.0 the background class is put at the end and class 0 is the first foreground class.
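
For annotations that were written with the "0 is background" habit (foreground classes numbered from 1), a shift by one is enough. The sketch below assumes exactly that input convention; `to_zero_based` is a hypothetical helper, not an MMDetection function.

```python
def to_zero_based(labels):
    """Shift 1-based foreground labels (background was 0) to [0, K-1]."""
    converted = []
    for label in labels:
        if label < 1:
            # ground-truth boxes should never carry the old background label
            raise ValueError(f'expected a 1-based foreground label, got {label}')
        converted.append(label - 1)
    return converted


old_labels = [1, 3, 2]  # assumed old scheme: 1 = Car, 2 = Pedestrian, 3 = Cyclist
new_labels = to_zero_based(old_labels)
print(new_labels)  # [0, 2, 1]
```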

hellock assigned ZwwWayne and unassigned RyanXLi Sep 12, 2020
@emergencyd

emergencyd commented Sep 19, 2020

According to https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/rpn_head.py Line 22, background_label=0.
I'm quite confused now. When preparing my annotations, what value should I assign for my foreground items? (assuming that I only have 1 foreground)
@hellock

@hubutui
Contributor Author

hubutui commented Sep 19, 2020 via email

@emergencyd

For CustomDataset, your label value is 0.

On Sat, Sep 19, 2020, 22:09 Shujian Deng @.***> wrote: According to https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/rpn_head.py Line 22, background_label=0. I'm quite confused now. When preparing my annotations, what value should I assign for my foreground items? (assuming that I only have 1 foreground)

Did you notice the code? It states that background_label=0, so I'm not sure whether something is wrong.

@hubutui
Contributor Author

hubutui commented Sep 21, 2020

@emergencyd Check https://mmdetection.readthedocs.io/en/latest/compatibility.html#codebase-conventions, for CustomDataset, label value [0, K-1], where 0 is for the first foreground class.

@emergencyd

@emergencyd Check https://mmdetection.readthedocs.io/en/latest/compatibility.html#codebase-conventions, for CustomDataset, label value [0, K-1], where 0 is for the first foreground class.

I know it is stated that we should use K for the background. But the code shows that 0 is for the background.

@hubutui
Contributor Author

hubutui commented Sep 21, 2020

@emergencyd you mean this line?

if self.use_sigmoid_cls:

But, according to
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),

use_sigmoid=True.

@emergencyd

emergencyd commented Sep 21, 2020

No. As I mentioned before, please check

super(RPNHead, self).__init__(
    1, in_channels, background_label=0, **kwargs)

and

assert (self.background_label == 0
        or self.background_label == num_classes)

@hubutui

@hubutui
Contributor Author

hubutui commented Sep 21, 2020

@RyanXLi @ZwwWayne Could you answer it? Maybe update the document to clarify.

@ZwwWayne
Collaborator

The background_label=0 setting exists only so that the RPN can use binary cross entropy loss. Everywhere else, the background label is N. When generating your labels, you do not need to worry about the label used inside the RPN; just assign the foreground class labels in [0, N-1]. PR #3221 will unify the notion, and we will merge it in the near future.
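
One way to picture the "foreground in [0, N-1], background is N" convention is to look at the targets a sigmoid-style classifier would be trained on. The sketch below is an illustration under that assumption only; `sigmoid_targets` is a hypothetical helper and not MMDetection's loss code.

```python
def sigmoid_targets(labels, num_classes):
    """Build per-class binary targets for sigmoid-style classification.

    Labels in [0, num_classes - 1] become one-hot rows; the background
    label (== num_classes) becomes an all-zero row.
    """
    targets = []
    for label in labels:
        row = [0] * num_classes
        if 0 <= label < num_classes:
            row[label] = 1
        targets.append(row)
    return targets


# N = 3 foreground classes, background label = 3.
multi = sigmoid_targets([0, 2, 3], num_classes=3)
print(multi)  # [[1, 0, 0], [0, 0, 1], [0, 0, 0]]

# Single "object" class: the target reduces to object vs. not object.
binary = sigmoid_targets([0, 1], num_classes=1)
print(binary)  # [[1], [0]]
```

In the single-class case the target is a plain binary value, so whichever integer a head uses internally for background has no effect on how the annotation file should be written.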
