Skip to content

Commit

Permalink
Bugfix object detection visualize function(#739)
Browse files Browse the repository at this point in the history
  • Loading branch information
ulya-tkch committed Jun 8, 2023
1 parent bf3b56c commit 7d84aa4
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions cleanlab/object_detection/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def visualize(
Optional dictionary mapping one-hot-encoded class labels back to their original class names in the format ``{"integer-label": "original-class-name"}``.
save_path:
Path to save figure at. If a path is provided, the figure is saved instead of displayed.
Path to save figure at. If a path is provided, the figure is saved. To save in a specific image format, add desired file extension to the end of `save_path`. Allowed file extensions are: 'png', 'pdf', 'ps', 'eps', and 'svg'.
figsize:
Optional figure size for plotting the image.
Expand Down Expand Up @@ -115,7 +115,7 @@ def visualize(
ax.imshow(image)
if label is not None:
fig, ax = _draw_boxes(
fig, ax, abbox, alabels, edgecolor="r", linestyle="-.", linewidth=1
fig, ax, abbox, alabels, edgecolor="r", linestyle="-", linewidth=1
)
if prediction is not None:
_, _ = _draw_boxes(fig, ax, pbbox, plabels, edgecolor="b", linestyle="-.", linewidth=1)
Expand All @@ -135,27 +135,29 @@ def visualize(
_, _ = _draw_boxes(
fig, axes[1], pbbox, plabels, edgecolor="b", linestyle="-.", linewidth=1
)

bbox_extra_artists = None
if label or prediction is not None:
legend, plt = _plot_legend(class_names, label, prediction)
if save_path: # save with legend
plt.savefig(
save_path,
format="pdf",
bbox_extra_artists=(legend,),
bbox_inches="tight",
transparent=True,
pad_inches=0.5,
)
elif save_path:
bbox_extra_artists = (legend,)

if save_path:
allowed_image_formats = set(["png", "pdf", "ps", "eps", "svg"])
image_format = (
save_path[-3:]
if len(save_path) > 2
and save_path[-4] == "."
and save_path[-3:] in allowed_image_formats
else None
)

plt.savefig(
save_path,
format="pdf",
format=image_format,
bbox_extra_artists=bbox_extra_artists,
bbox_inches="tight",
transparent=True,
pad_inches=0.5,
)

plt.show()


Expand Down

0 comments on commit 7d84aa4

Please sign in to comment.