This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

How to use a different number of classes? (tried to look at other issues) #273

Closed
jbitton opened this issue Dec 14, 2018 · 10 comments

Comments

@jbitton

jbitton commented Dec 14, 2018

❓ Questions and Help

Hi there, I've been trying to get the repo to work with a new dataset (DDSM, mammography data), and I believe I'm close, but the final step is to actually use the correct number of classes. I've modified the dataset to resemble the structure of COCO.

The DDSM dataset has three classes (background, benign, and malignant). To get it working, I followed the example in #166 (I changed ROI_BOX_HEAD.NUM_CLASSES to 3 and modified the Checkpointer class). However, I'm still getting the following error:

2018-12-14 03:36:35,444 maskrcnn_benchmark.trainer INFO: Start training
start_iter 0
getting item 2491
classes: tensor([3])
self.json_category_id_to_contiguous_id: {0: 1, 1: 2, 2: 3}
/opt/conda/conda-bld/pytorch-nightly_1544606458595/work/aten/src/THCUNN/ClassNLLCriterion.cu:105: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int, long) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [31,0,0] Assertion `t >= 0 && t < n_classes` failed.
THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch-nightly_1544606458595/work/aten/src/THCUNN/generic/ClassNLLCriterion.cu line=111 error=59 : device-side assert triggered
getting item 2767
classes: tensor([3])
self.json_category_id_to_contiguous_id: {0: 1, 1: 2, 2: 3}
Traceback (most recent call last):
  File "tools/train_net.py", line 169, in <module>
    main()
  File "tools/train_net.py", line 162, in main
    model = train(cfg, args.local_rank, args.distributed)
  File "tools/train_net.py", line 71, in train
    arguments,
  File "/scratch/jtb470/fb-mrcnn/maskrcnn-benchmark/maskrcnn_benchmark/engine/trainer.py", line 82, in do_train
    loss_dict = model(images, targets)
  File "/home/jtb470/.conda/envs/cv-fb-mrcnn/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/scratch/jtb470/fb-mrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py", line 52, in forward
    x, result, detector_losses = self.roi_heads(features, proposals, targets)
  File "/home/jtb470/.conda/envs/cv-fb-mrcnn/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/scratch/jtb470/fb-mrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/roi_heads/roi_heads.py", line 23, in forward
    x, detections, loss_box = self.box(features, proposals, targets)
  File "/home/jtb470/.conda/envs/cv-fb-mrcnn/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/scratch/jtb470/fb-mrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py", line 55, in forward
    [class_logits], [box_regression]
  File "/scratch/jtb470/fb-mrcnn/maskrcnn-benchmark/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py", line 139, in __call__
    classification_loss = F.cross_entropy(class_logits, labels)
  File "/home/jtb470/.conda/envs/cv-fb-mrcnn/lib/python3.7/site-packages/torch/nn/functional.py", line 1970, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/home/jtb470/.conda/envs/cv-fb-mrcnn/lib/python3.7/site-packages/torch/nn/functional.py", line 1790, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch-nightly_1544606458595/work/aten/src/THCUNN/generic/ClassNLLCriterion.cu:111

I've tried looking at #15 and other issues, and quite frankly I'm still lost as to the right procedure for using a different number of classes. What am I missing? What else do I need to do?

If it's any help, this is my config file:

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
  BACKBONE:
    CONV_BODY: "R-50-FPN"
    OUT_CHANNELS: 256
  RPN:
    USE_FPN: True
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
  ROI_BOX_HEAD:
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
    NUM_CLASSES: 3
  ROI_MASK_HEAD:
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
    PREDICTOR: "MaskRCNNC4Predictor"
    POOLER_RESOLUTION: 14
    POOLER_SAMPLING_RATIO: 2
    RESOLUTION: 28
    SHARE_BOX_FEATURE_EXTRACTOR: False
  MASK_ON: True
DATASETS:
  TRAIN: ("ddsm_train",)
  TEST: ("ddsm_val",)
DATALOADER:
  NUM_WORKERS: 0
  SIZE_DIVISIBILITY: 32
SOLVER:
  BASE_LR: 0.0025
  WEIGHT_DECAY: 0.0001
  STEPS: (60000, 80000)
  MAX_ITER: 90000
  IMS_PER_BATCH: 2
TEST:
  IMS_PER_BATCH: 2

Thank you so much in advance.

@fmassa
Contributor

fmassa commented Dec 14, 2018

Hi,

Are you training your model from one of the COCO pre-trained models or just from ImageNet?

The error you are seeing usually happens when your dataset outputs a label index that is greater than or equal to NUM_CLASSES (valid labels run from 0, the background, up to NUM_CLASSES - 1).

From a quick look at your class mapping, it looks like you map category id 2 to label 3; is that right?

{0: 1, 1: 2, 2: 3}

This might mean that your dataset is returning labels up to 3, which is out of range for NUM_CLASSES: 3 and would explain the crash.

Could you check that and report back?
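One way to run the check suggested above is to collect the labels your dataset actually returns and compare them against NUM_CLASSES. The sketch below is plain Python with a hypothetical helper name; in practice you would gather `labels` from the targets produced by your Dataset's `__getitem__`. It relies only on the fact that `F.cross_entropy` (the call that fails in the traceback) expects every target in the range 0 to NUM_CLASSES - 1.

```python
def find_out_of_range_labels(labels, num_classes):
    """Return the distinct labels that are invalid for NUM_CLASSES = num_classes.

    Valid class indices are 0 (background) through num_classes - 1;
    anything else makes cross_entropy hit the device-side assert.
    """
    return sorted({label for label in labels if not 0 <= label < num_classes})


# Labels produced by the mapping printed in the log for category ids 0, 1, 2.
mapping = {0: 1, 1: 2, 2: 3}
labels = [mapping[cat_id] for cat_id in (0, 1, 2)]
print(find_out_of_range_labels(labels, num_classes=3))  # [3]
```

Running this on the CPU side gives a readable answer instead of the opaque CUDA assert.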

@jbitton
Author

jbitton commented Dec 14, 2018

Hi @fmassa, thanks for the response. Perhaps I misunderstood the self.json_category_id_to_contiguous_id variable? This is how I have it set in my Dataset class:

self.json_category_id_to_contiguous_id = {
     v: i + 1 for i, v in enumerate([0] + self.ddsm.getCatIds())
}

The difference from COCODataset is that I prepended [0]. Could this be the issue?

Also, I double-checked my preprocessing code: it only ever returns the category ids 0, 1, or 2.
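To see how the extra `[0]` shifts the labels, the sketch below builds the mapping both ways: once with the COCODataset-style comprehension (the comment above says the only difference is the prepended `[0]`) and once with the prefix. The category ids `[1, 2]` are a hypothetical example; with them, the prefixed version happens to reproduce the mapping printed in the log.

```python
def coco_style_mapping(cat_ids):
    # COCODataset pattern: contiguous labels start at 1, leaving 0 for background.
    return {v: i + 1 for i, v in enumerate(cat_ids)}


def prefixed_mapping(cat_ids):
    # Variant from the comment above: an extra 0 is prepended before enumerating.
    return {v: i + 1 for i, v in enumerate([0] + cat_ids)}


cat_ids = [1, 2]  # hypothetical category ids read from the annotation JSON
print(coco_style_mapping(cat_ids))  # {1: 1, 2: 2}
print(prefixed_mapping(cat_ids))    # {0: 1, 1: 2, 2: 3}
```

With the prefix, the largest label becomes 3, one past the valid range for NUM_CLASSES: 3; without it, the labels stay within 1 to 2.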

@jbitton
Author

jbitton commented Dec 14, 2018

Yup, that seems to be the issue. Once I removed the [0], the error went away. Dumb mistake on my part, apologies!

@jbitton jbitton closed this as completed Dec 14, 2018
@fmassa
Contributor

fmassa commented Dec 14, 2018

Cool, great that you managed to fix the issue!

@adrifloresm

adrifloresm commented Jan 4, 2019

@fmassa I am running into the same issue: "RuntimeError: copy_if failed to synchronize: device-side assert triggered".

I have a dataset with 4 classes in COCO format. I edited the Checkpointer class as in #166, and I set ROI_BOX_HEAD.NUM_CLASSES to 4.

The output of these prints in my maskrcnn_benchmark/data/datasets/coco.py is:
print(self.coco.getCatIds()) gives [0, 1, 2, 3]
print(self.json_category_id_to_contiguous_id.items()) gives dict_items([(0, 1), (1, 2), (2, 3), (3, 4)])

What am I missing? Is it not necessary to set ROI_BOX_HEAD.NUM_CLASSES to 4? When I don't set it, training works.

Thanks in advance!

2019-01-04 18:15:47,495 maskrcnn_benchmark.trainer INFO: Start training

RuntimeError    Traceback (most recent call last)
<ipython-input-8-ba0e1f55fb0d> in <module>
      7         device,
      8         checkpoint_period,
----> 9         arguments,
     10     )

/maskrcnn-benchmark/maskrcnn_benchmark/engine/trainer.py in do_train(model, data_loader, optimizer, scheduler, checkpointer, device, checkpoint_period, arguments)
     64         targets = [target.to(device) for target in targets]
     65 
---> 66         loss_dict = model(images, targets)
     67 
     68         losses = sum(loss for loss in loss_dict.values())

/miniconda/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

/maskrcnn-benchmark/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py in forward(self, images, targets)
     50         proposals, proposal_losses = self.rpn(images, features, targets)
     51         if self.roi_heads:
---> 52             x, result, detector_losses = self.roi_heads(features, proposals, targets)
     53         else:
     54             # RPN-only models don't have roi_heads

/miniconda/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

/maskrcnn-benchmark/maskrcnn_benchmark/modeling/roi_heads/roi_heads.py in forward(self, features, proposals, targets)
     21         losses = {}
     22         # TODO rename x to roi_box_features, if it doesn't increase memory consumption
---> 23         x, detections, loss_box = self.box(features, proposals, targets)
     24         losses.update(loss_box)
     25         if self.cfg.MODEL.MASK_ON:

/miniconda/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

/maskrcnn-benchmark/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py in forward(self, features, proposals, targets)
     53 
     54         loss_classifier, loss_box_reg = self.loss_evaluator(
---> 55             [class_logits], [box_regression]
     56         )
     57         return (

/maskrcnn-benchmark/maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py in __call__(self, class_logits, box_regression)
    142         # the corresponding ground truth labels, to be used with
    143         # advanced indexing
--> 144         sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
    145         labels_pos = labels[sampled_pos_inds_subset]
    146         map_inds = 4 * labels_pos[:, None] + torch.tensor([0, 1, 2, 3], device=device)

RuntimeError: copy_if failed to synchronize: device-side assert triggered
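The Checkpointer change from #166 mentioned in this comment matters when starting from a COCO-trained checkpoint: the final classification, box-regression and mask-logit layers there are sized for 81 classes, so they cannot be loaded into a model with a different NUM_CLASSES. The following is a generic sketch of that idea, not the exact change from #166. The key fragments `cls_score`, `bbox_pred` and `mask_fcn_logits` and the helper name are assumptions; compare them against the keys in your own checkpoint before relying on them.

```python
# Assumed key fragments for the layers whose shapes depend on NUM_CLASSES.
CLASS_SPECIFIC_KEYS = (
    "cls_score",
    "bbox_pred",
    "mask_fcn_logits",
)


def drop_class_specific_weights(state_dict, fragments=CLASS_SPECIFIC_KEYS):
    """Return a new dict without entries whose key contains any fragment.

    The input dict is not modified; values (normally tensors) are passed through.
    """
    return {
        key: value
        for key, value in state_dict.items()
        if not any(fragment in key for fragment in fragments)
    }


# Toy state dict: strings stand in for tensors, since only the keys matter here.
toy = {
    "backbone.body.stem.conv1.weight": "...",
    "roi_heads.box.predictor.cls_score.weight": "...",
    "roi_heads.box.predictor.bbox_pred.bias": "...",
    "roi_heads.mask.predictor.mask_fcn_logits.weight": "...",
}
print(sorted(drop_class_specific_weights(toy)))  # ['backbone.body.stem.conv1.weight']
```

The dropped layers are then freshly initialized with the new number of classes, while the backbone and the rest of the heads keep their pre-trained weights.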

@fmassa
Contributor

fmassa commented Jan 7, 2019

@adrifloresm it looks like your label indices go up to 4? I believe they should be in the range 0-3 if you have 4 classes (counting background).

@adrifloresm

@fmassa thank you for your response. Indeed, my issue was that I did not know I had to count the background class in the config setting, so ROI_BOX_HEAD.NUM_CLASSES had to be 5.
Issue #297 helped me realize that!
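The rule behind this fix can be written down directly: since contiguous labels start at 1, the largest label equals the number of categories, so NUM_CLASSES must be one more than that. A small sketch using the `getCatIds()` output reported earlier in this thread (the helper name is hypothetical):

```python
def num_classes_for(cat_ids):
    """NUM_CLASSES for ROI_BOX_HEAD: one slot per category plus background (0)."""
    return len(cat_ids) + 1


cat_ids = [0, 1, 2, 3]  # getCatIds() output reported above
mapping = {v: i + 1 for i, v in enumerate(cat_ids)}
num_classes = num_classes_for(cat_ids)

# Every contiguous label must be a valid class index: 0 <= label < num_classes.
assert max(mapping.values()) < num_classes
print(num_classes)  # 5
```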

I had also made the mistake of not deleting the previous checkpoint (I should have deleted the output folder after testing with 81 classes), so training was loading that checkpoint instead of creating a new one.

Thanks for the help!
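The stale-checkpoint problem described in this comment can be caught before training starts. A minimal sketch, assuming the trainer records its most recent checkpoint in a file named `last_checkpoint` inside OUTPUT_DIR (verify this against your version of the Checkpointer); the helper name is hypothetical:

```python
import os


def has_stale_checkpoint(output_dir):
    """True if output_dir already holds a checkpoint pointer from an earlier run.

    Assumes the latest checkpoint is recorded in a file called
    "last_checkpoint"; adjust the name if your trainer uses another one.
    """
    return os.path.isfile(os.path.join(output_dir, "last_checkpoint"))
```

If it returns True for the OUTPUT_DIR you are about to train with, point the run at a fresh directory or delete the old one, so that weights saved for 81 classes are not resumed.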

@BobZhangHT

@fmassa Hi! I have a question about how to label background annotations (negative samples) when creating a COCO-format dataset. Suppose I have a medical image without any mask, so during training it should count as background: its 'segmentation' is [] (empty), 'area' is 0, and 'bbox' is [0, 0, 0, 0]. When assigning categories to the annotations of all images, I give the positive samples class 1 (category id 1) and the negative samples class 0, but during training I then need to set ROI_BOX_HEAD.NUM_CLASSES to 3 rather than 2, because the model automatically adds another 'background' class. Is there a way to avoid this conflict? Thank you very much!

@fmassa
Contributor

fmassa commented Jan 21, 2019

@BobZhangHT does issue #169 address your question? If an image has no labels and the patch from #169 works, then you just need to modify COCODataset to support returning no classes at all.
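Following this suggestion, negative images do not need a category of their own: if they simply carry no annotations, NUM_CLASSES can stay at 2 (background plus the one positive class). The sketch below shows, with a hypothetical helper, how a dataset's `__getitem__` could turn COCO-style annotations into boxes and labels while skipping placeholder entries like `bbox: [0, 0, 0, 0]`. Whether maskrcnn-benchmark accepts the resulting empty targets depends on the patch from #169, as noted above.

```python
def annotations_to_target(annotations, cat_to_label):
    """Convert one image's COCO-style annotations into (boxes, labels).

    Annotations with a degenerate box (zero width or height), such as the
    placeholder [0, 0, 0, 0] entries, are skipped instead of being given an
    explicit "background" category. An image with nothing left yields two
    empty lists, i.e. a pure negative example.
    """
    boxes, labels = [], []
    for ann in annotations:
        x, y, w, h = ann["bbox"]
        if w <= 0 or h <= 0:
            continue
        boxes.append([x, y, x + w, y + h])  # COCO xywh -> xyxy
        labels.append(cat_to_label[ann["category_id"]])
    return boxes, labels


cat_to_label = {1: 1}  # one foreground category; background (0) stays implicit
negative = [{"bbox": [0, 0, 0, 0], "category_id": 0, "area": 0, "segmentation": []}]
positive = [{"bbox": [10, 20, 30, 40], "category_id": 1}]
print(annotations_to_target(negative, cat_to_label))  # ([], [])
print(annotations_to_target(positive, cat_to_label))  # ([[10, 20, 40, 60]], [1])
```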

@BobZhangHT

Thank you sincerely for your suggestion! :)
