
Cannot plot batch with ObjectDetectionVisualizer #1923

Closed
lcoandrade opened this issue Sep 19, 2023 · 4 comments

Comments

@lcoandrade

lcoandrade commented Sep 19, 2023

🐛 Bug

I'm using a Kaggle dataset for an object detection task. I converted it from YOLO format to COCO format so that I could use ObjectDetectionImageDataset. When I try to plot a batch with ObjectDetectionVisualizer, I get an error.
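For context on that conversion: a YOLO label line stores "class cx cy w h" with the box center and size normalized to [0, 1], while a COCO annotation stores "bbox" as [x_min, y_min, width, height] in pixels. The sketch below shows only that per-box step; it is not the conversion script used here, the helper names yolo_line_to_coco_bbox and make_coco_annotation are hypothetical, and it assumes category_id is kept equal to the YOLO class index (whether that holds depends on how the "categories" list in annotations.json was written).

```python
from typing import Dict, List, Tuple


def yolo_line_to_coco_bbox(line: str, img_w: int, img_h: int) -> Tuple[int, List[float]]:
    """Convert a YOLO label line 'class cx cy w h' (normalized to [0, 1])
    into (class_id, [x_min, y_min, width, height]) in pixels, COCO bbox order."""
    parts = line.split()
    class_id = int(parts[0])
    cx, cy, w, h = (float(v) for v in parts[1:5])
    box_w = w * img_w
    box_h = h * img_h
    # YOLO gives the box center; COCO wants the top-left corner.
    x_min = cx * img_w - box_w / 2
    y_min = cy * img_h - box_h / 2
    return class_id, [x_min, y_min, box_w, box_h]


def make_coco_annotation(ann_id: int, image_id: int, line: str, img_w: int, img_h: int) -> Dict:
    """Build one COCO-style annotation dict (assumes category_id == YOLO class index)."""
    class_id, bbox = yolo_line_to_coco_bbox(line, img_w, img_h)
    return {
        "id": ann_id,
        "image_id": image_id,
        "category_id": class_id,
        "bbox": bbox,
        "area": bbox[2] * bbox[3],
        "iscrowd": 0,
    }


print(make_coco_annotation(1, 1, "5 0.5 0.5 0.25 0.5", 640, 640))
```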

My ObjectDetectionImageDataset seems OK, as shown below:

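import os
from rastervision.pytorch_learner import ObjectDetectionImageDataset

# IMAGES: image folder; OUTPUT_DIR: folder holding the converted COCO annotations.json
ds = ObjectDetectionImageDataset(
    img_dir=IMAGES,
    annotation_uri=os.path.join(OUTPUT_DIR, 'annotations.json'),
)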

len(ds) returns:

19820

To check the dataset, I'm using this:

x, y = ds[500]
print(x.shape)
print(y.boxes.shape)
print(y)

This gives me:

torch.Size([3, 640, 640])
torch.Size([326, 4])
{'boxes': tensor([[  53.,    0.,  225.,    6.],
        [   0.,    0.,  238.,   15.],
        [  47.,  130.,  108.,  265.],
        ...,
        [ 301.,  150.,  607.,  310.],
        [ 592.,  164., 1199.,  336.],
        [ 122.,  406.,  285.,  832.]]),
 'class_ids': tensor([53, 53,  5, 53, 56,  5,  5,  5, 53, 53, 53,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5, 13,  5, 48, 48, 48,  9, 53, 53, 53,
         5,  5,  5, 53, 11,  5,  5,  5,  5,  5, 48,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5, 48,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5, 48, 48,  5,
         5, 11])}
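Because draw_boxes (line 280 in the traceback) looks up each box's label with np.array(class_names)[class_ids], a quick consistency check on one sample can rule out label problems before plotting. The helper check_sample_labels below is hypothetical, a minimal NumPy-only sketch: it checks that boxes have shape (N, 4), that there is one class ID per box, that class_names converts to a 1-D array, and that every ID is a valid index into it.

```python
import numpy as np


def check_sample_labels(boxes, class_ids, class_names):
    """Return a list of problems found in one sample; an empty list means none found.

    Hypothetical helper: boxes should be (N, 4), class_ids (N,), and class_names
    is whatever will be passed to the visualizer.
    """
    boxes = np.asarray(boxes)
    class_ids = np.asarray(class_ids).reshape(-1)
    problems = []
    if boxes.ndim != 2 or boxes.shape[1] != 4:
        problems.append(f"boxes should have shape (N, 4), got {boxes.shape}")
    n_boxes = boxes.shape[0] if boxes.ndim >= 1 else 0
    if n_boxes != class_ids.size:
        problems.append(f"{n_boxes} boxes but {class_ids.size} class IDs")
    # draw_boxes indexes np.array(class_names) with the class IDs,
    # so the names must convert to a 1-D array.
    names = np.array(class_names)
    if names.ndim != 1:
        problems.append(f"class_names converts to a {names.ndim}-D array, expected 1-D")
    elif class_ids.size > 0 and (class_ids.min() < 0 or class_ids.max() >= len(names)):
        problems.append(
            f"class IDs must be in [0, {len(names) - 1}], "
            f"got [{class_ids.min()}, {class_ids.max()}]"
        )
    return problems
```

With the sample above, it could be called as something like check_sample_labels(y.boxes, y.get_field('class_ids'), CLASS_NAMES); np.asarray accepts CPU tensors directly.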

To Reproduce

To reproduce, run the following (as shown in the tutorial):

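from rastervision.pytorch_learner import ObjectDetectionVisualizer

viz = ObjectDetectionVisualizer(
    class_names=CLASS_NAMES,
    #class_colors=['red', 'green'],
)
viz.scale = 8
# x.unsqueeze(0) adds a batch dimension; y is wrapped in a list to make a batch of one
viz.plot_batch(x.unsqueeze(0), [y], show=True)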

This gives me:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[12], line 6
      1 viz = ObjectDetectionVisualizer(
      2     class_names=CLASS_NAMES,
      3     #class_colors=['red', 'green'],
      4 )
      5 viz.scale = 8
----> 6 viz.plot_batch(x.unsqueeze(0), [y], show=True)

File /opt/conda/lib/python3.10/site-packages/rastervision/pytorch_learner/dataset/visualizer/visualizer.py:111, in Visualizer.plot_batch(self, x, y, output_path, z, batch_limit, show)
    108     fig, axs = plt.subplots(**params['fig_args'],
    109                             **params['subplot_args'])
    110     plot_xyz_args = params['plot_xyz_args']
--> 111     self._plot_batch(fig, axs, plot_xyz_args, x, y=y, z=z)
    112 elif x.ndim == 5:
    113     # If a temporal dimension is present, we divide the figure into
    114     # multiple subfigures--one for each batch. Then, in each subfigure,
   (...)
    117     # of only displaying subplot titles once per batch (above the first
    118     # row in each batch).
    119     batch_sz, T, *_ = x.shape

File /opt/conda/lib/python3.10/site-packages/rastervision/pytorch_learner/dataset/visualizer/visualizer.py:171, in Visualizer._plot_batch(self, fig, axs, plot_xyz_args, x, y, z)
    169 for i, row_axs in enumerate(axs):
    170     _z = None if z is None else z[i]
--> 171     self.plot_xyz(row_axs, x[i], y[i], z=_z, **plot_xyz_args[i])

File /opt/conda/lib/python3.10/site-packages/rastervision/pytorch_learner/dataset/visualizer/object_detection_visualizer.py:31, in ObjectDetectionVisualizer.plot_xyz(self, axs, x, y, z, plot_title)
     28 class_colors = self.class_colors
     30 imgs = channel_groups_to_imgs(x, channel_groups)
---> 31 imgs = [draw_boxes(img, y, class_names, class_colors) for img in imgs]
     32 plot_channel_groups(axs, imgs, channel_groups, plot_title=plot_title)

File /opt/conda/lib/python3.10/site-packages/rastervision/pytorch_learner/dataset/visualizer/object_detection_visualizer.py:31, in <listcomp>(.0)
     28 class_colors = self.class_colors
     30 imgs = channel_groups_to_imgs(x, channel_groups)
---> 31 imgs = [draw_boxes(img, y, class_names, class_colors) for img in imgs]
     32 plot_channel_groups(axs, imgs, channel_groups, plot_title=plot_title)

File /opt/conda/lib/python3.10/site-packages/rastervision/pytorch_learner/object_detection_utils.py:280, in draw_boxes(x, y, class_names, class_colors)
    277 scores: Optional[torch.Tensor] = y.get_field('scores')
    279 if len(boxes) > 0:
--> 280     box_annotations: List[str] = np.array(class_names)[class_ids].tolist()
    281     if scores is not None:
    282         box_annotations = [
    283             f'{ann} | {score:.2f}'
    284             for ann, score in zip(box_annotations, scores)
    285         ]

IndexError: too many indices for array: array is 0-dimensional, but 1 were indexed

Expected behavior

Plot the batch using ObjectDetectionVisualizer.

Environment

Running Raster Vision directly in Windows is not supported, and we recommend that you run it from within a Docker container.

  • How you installed and are running Raster Vision (pip install on local vs. inside Docker image): pip install
  • Raster Vision version or commit: 0.21.1
  • OS (e.g., Linux): Kaggle
  • Python version:
  • CUDA/cuDNN version if running on GPU:
  • Any other relevant information:

Additional context

@AdeelH
Collaborator

AdeelH commented Sep 19, 2023

Is your CLASS_NAMES defined? What about viz.class_names? I can only reproduce that NumPy error if I do:

np.array(None)[[0]]
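The same 0-dimensional array appears when class_names is a dict_values view rather than None. A dict_values object is iterable and has a length, but it is not a sequence (it does not support indexing), so NumPy wraps the whole object in a 0-d array of dtype object instead of building a 1-D array of strings. A minimal sketch, with a hypothetical three-entry class_map:

```python
import numpy as np

# Hypothetical class map; only the .values() call matters here.
class_map = {0: "Fixed-wing Aircraft", 1: "Small Aircraft", 2: "Passenger/Cargo Plane"}
names_view = class_map.values()  # dict_values: iterable and sized, but not a sequence

# NumPy does not treat dict_values as a sequence, so the result is a
# single object stored in a 0-dimensional array.
arr = np.array(names_view)
print(type(names_view).__name__, arr.ndim, arr.dtype)  # prints: dict_values 0 object

# Indexing that 0-d array with class IDs raises IndexError, like
# np.array(None)[[0]] and like line 280 of draw_boxes in the traceback.
try:
    arr[[0, 2]]
except IndexError as err:
    print("IndexError:", err)
```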

@lcoandrade
Author

lcoandrade commented Sep 19, 2023

These are my class names:

print(CLASS_NAMES)

This also works correctly:

print(viz.class_names)

This gives me:

dict_values(['Fixed-wing Aircraft', 'Small Aircraft', 'Passenger/Cargo Plane', 'Helicopter', 'Passenger Vehicle', 'Small Car', 'Bus', 'Pickup Truck', 'Utility Truck', 'Truck', 'Cargo Truck', 'Truck Tractor w/ Box Trailer', 'Truck Tractor', 'Trailer', 'Truck Tractor w/ Flatbed Trailer', 'Truck Tractor w/ Liquid Tank', 'Crane Truck', 'Railway Vehicle', 'Passenger Car', 'Cargo/Container Car', 'Flat Car', 'Tank car', 'Locomotive', 'Maritime Vessel', 'Motorboat', 'Sailboat', 'Tugboat', 'Barge', 'Fishing Vessel', 'Ferry', 'Yacht', 'Container Ship', 'Oil Tanker', 'Engineering Vehicle', 'Tower crane', 'Container Crane', 'Reach Stacker', 'Straddle Carrier', 'Mobile Crane', 'Dump Truck', 'Haul Truck', 'Scraper/Tractor', 'Front loader/Bulldozer', 'Excavator', 'Cement Mixer', 'Ground Grader', 'Hut/Tent', 'Shed', 'Building', 'Aircraft Hangar', 'Damaged Building', 'Facility', 'Construction Site', 'Vehicle Lot', 'Helipad', 'Storage Tank', 'Shipping container lot', 'Shipping Container', 'Pylon', 'Tower'])

@AdeelH
Collaborator

AdeelH commented Sep 19, 2023

Make it a list instead of dict_values.
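In the notebook, that means defining CLASS_NAMES as a list or passing class_names=list(CLASS_NAMES) to ObjectDetectionVisualizer. The NumPy-only sketch below (class_map and box_labels are hypothetical names) shows that the lookup draw_boxes performs succeeds once the names are a real list.

```python
import numpy as np


def box_labels(class_names, class_ids):
    # Same lookup as line 280 of draw_boxes in the traceback above.
    return np.array(class_names)[class_ids].tolist()


# Hypothetical class map standing in for whatever CLASS_NAMES was built from.
class_map = {0: "Fixed-wing Aircraft", 1: "Small Aircraft", 2: "Passenger/Cargo Plane"}

# Build an actual list, ordered by class ID, instead of using class_map.values().
CLASS_NAMES = [class_map[i] for i in sorted(class_map)]

print(box_labels(CLASS_NAMES, np.array([2, 0, 2])))
# prints: ['Passenger/Cargo Plane', 'Fixed-wing Aircraft', 'Passenger/Cargo Plane']
```

Building the list from sorted keys, rather than calling list(class_map.values()), also guarantees that position i holds the name for class ID i even if the dict was not filled in ID order.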

@lcoandrade
Author

Thank you very much!
