Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

New annotation mode "ai_mask" that generates shapes as mask #1358

Merged
merged 9 commits into from
Dec 2, 2023
36 changes: 34 additions & 2 deletions labelme/ai/models/segment_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import onnxruntime
import PIL.Image
import skimage.measure
import skimage

from ...logger import logger

Expand Down Expand Up @@ -71,6 +71,18 @@ def predict_polygon_from_points(self, points, point_labels):
)
return polygon

def predict_mask_from_points(self, points, point_labels):
    """Predict a mask for the current image from prompt points.

    The image embedding is obtained from ``self._get_image_embedding()``
    and handed, together with the stored image, image size and decoder
    session, to ``_compute_mask_from_points``.

    Args:
        points: Prompt points, forwarded unchanged.
        point_labels: Labels for ``points``, forwarded unchanged.

    Returns:
        The value returned by ``_compute_mask_from_points``.
    """
    # Fetch the embedding first, before any other attribute is read.
    embedding = self._get_image_embedding()
    decoder_kwargs = dict(
        image_size=self._image_size,
        decoder_session=self._decoder_session,
        image=self._image,
        image_embedding=embedding,
        points=points,
        point_labels=point_labels,
    )
    return _compute_mask_from_points(**decoder_kwargs)


def _compute_scale_to_resize_image(image_size, image):
height, width = image.shape[:2]
Expand Down Expand Up @@ -127,7 +139,7 @@ def _get_contour_length(contour):
return np.linalg.norm(contour_end - contour_start, axis=1).sum()


def _compute_polygon_from_points(
def _compute_mask_from_points(
image_size, decoder_session, image, image_embedding, points, point_labels
):
input_point = np.array(points, dtype=np.float32)
Expand Down Expand Up @@ -163,10 +175,30 @@ def _compute_polygon_from_points(
masks, _, _ = decoder_session.run(None, decoder_inputs)
mask = masks[0, 0] # (1, 1, H, W) -> (H, W)
mask = mask > 0.0

MIN_SIZE_RATIO = 0.05
skimage.morphology.remove_small_objects(
mask, min_size=mask.sum() * MIN_SIZE_RATIO, out=mask
)

if 0:
imgviz.io.imsave(
"mask.jpg", imgviz.label2rgb(mask, imgviz.rgb2gray(image))
)
return mask


def _compute_polygon_from_points(
image_size, decoder_session, image, image_embedding, points, point_labels
):
mask = _compute_mask_from_points(
image_size=image_size,
decoder_session=decoder_session,
image=image,
image_embedding=image_embedding,
points=points,
point_labels=point_labels,
)

contours = skimage.measure.find_contours(np.pad(mask, pad_width=1))
contour = max(contours, key=_get_contour_length)
Expand Down
125 changes: 48 additions & 77 deletions labelme/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,28 @@ def __init__(
self.tr("Start drawing ai_polygon. Ctrl+LeftClick ends creation."),
enabled=False,
)
createAiPolygonMode.changed.connect(
lambda: self.canvas.initializeAiModel(
name=self._selectAiModelComboBox.currentText()
)
if self.canvas.createMode == "ai_polygon"
else None
)
createAiMaskMode = action(
self.tr("Create AI-Mask"),
lambda: self.toggleDrawMode(False, createMode="ai_mask"),
None,
"objects",
self.tr("Start drawing ai_mask. Ctrl+LeftClick ends creation."),
enabled=False,
)
createAiMaskMode.changed.connect(
lambda: self.canvas.initializeAiModel(
name=self._selectAiModelComboBox.currentText()
)
if self.canvas.createMode == "ai_mask"
else None
)
editMode = action(
self.tr("Edit Polygons"),
self.setEditMode,
Expand Down Expand Up @@ -620,6 +642,7 @@ def __init__(
createPointMode=createPointMode,
createLineStripMode=createLineStripMode,
createAiPolygonMode=createAiPolygonMode,
createAiMaskMode=createAiMaskMode,
zoom=zoom,
zoomIn=zoomIn,
zoomOut=zoomOut,
Expand Down Expand Up @@ -655,6 +678,7 @@ def __init__(
createPointMode,
createLineStripMode,
createAiPolygonMode,
createAiMaskMode,
editMode,
edit,
duplicate,
Expand All @@ -674,6 +698,7 @@ def __init__(
createPointMode,
createLineStripMode,
createAiPolygonMode,
createAiMaskMode,
editMode,
brightnessContrast,
),
Expand Down Expand Up @@ -762,11 +787,12 @@ def __init__(
)
self._selectAiModelComboBox.addItems([model.name for model in MODELS])
self._selectAiModelComboBox.setCurrentIndex(1)
self._selectAiModelComboBox.setEnabled(False)
self._selectAiModelComboBox.currentIndexChanged.connect(
lambda: self.canvas.initializeAiModel(
name=self._selectAiModelComboBox.currentText()
)
if self.canvas.createMode in ["ai_polygon", "ai_mask"]
else None
)

self.tools = self.toolbar("Tools")
Expand Down Expand Up @@ -892,6 +918,7 @@ def populateModeActions(self):
self.actions.createPointMode,
self.actions.createLineStripMode,
self.actions.createAiPolygonMode,
self.actions.createAiMaskMode,
self.actions.editMode,
)
utils.addActions(self.menus.edit, actions + self.actions.editMenu)
Expand Down Expand Up @@ -924,6 +951,7 @@ def setClean(self):
self.actions.createPointMode.setEnabled(True)
self.actions.createLineStripMode.setEnabled(True)
self.actions.createAiPolygonMode.setEnabled(True)
self.actions.createAiMaskMode.setEnabled(True)
title = __appname__
if self.filename is not None:
title = "{} - {}".format(title, self.filename)
Expand Down Expand Up @@ -992,86 +1020,25 @@ def toggleDrawingSensitive(self, drawing=True):
self.actions.delete.setEnabled(not drawing)

def toggleDrawMode(self, edit=True, createMode="polygon"):
draw_actions = {
"polygon": self.actions.createMode,
"rectangle": self.actions.createRectangleMode,
"circle": self.actions.createCircleMode,
"point": self.actions.createPointMode,
"line": self.actions.createLineMode,
"linestrip": self.actions.createLineStripMode,
"ai_polygon": self.actions.createAiPolygonMode,
"ai_mask": self.actions.createAiMaskMode,
}

self.canvas.setEditing(edit)
self.canvas.createMode = createMode
if edit:
self.actions.createMode.setEnabled(True)
self.actions.createRectangleMode.setEnabled(True)
self.actions.createCircleMode.setEnabled(True)
self.actions.createLineMode.setEnabled(True)
self.actions.createPointMode.setEnabled(True)
self.actions.createLineStripMode.setEnabled(True)
self.actions.createAiPolygonMode.setEnabled(True)
self._selectAiModelComboBox.setEnabled(False)
for draw_action in draw_actions.values():
draw_action.setEnabled(True)
else:
if createMode == "polygon":
self.actions.createMode.setEnabled(False)
self.actions.createRectangleMode.setEnabled(True)
self.actions.createCircleMode.setEnabled(True)
self.actions.createLineMode.setEnabled(True)
self.actions.createPointMode.setEnabled(True)
self.actions.createLineStripMode.setEnabled(True)
self.actions.createAiPolygonMode.setEnabled(True)
self._selectAiModelComboBox.setEnabled(False)
elif createMode == "rectangle":
self.actions.createMode.setEnabled(True)
self.actions.createRectangleMode.setEnabled(False)
self.actions.createCircleMode.setEnabled(True)
self.actions.createLineMode.setEnabled(True)
self.actions.createPointMode.setEnabled(True)
self.actions.createLineStripMode.setEnabled(True)
self.actions.createAiPolygonMode.setEnabled(True)
self._selectAiModelComboBox.setEnabled(False)
elif createMode == "line":
self.actions.createMode.setEnabled(True)
self.actions.createRectangleMode.setEnabled(True)
self.actions.createCircleMode.setEnabled(True)
self.actions.createLineMode.setEnabled(False)
self.actions.createPointMode.setEnabled(True)
self.actions.createLineStripMode.setEnabled(True)
self.actions.createAiPolygonMode.setEnabled(True)
self._selectAiModelComboBox.setEnabled(False)
elif createMode == "point":
self.actions.createMode.setEnabled(True)
self.actions.createRectangleMode.setEnabled(True)
self.actions.createCircleMode.setEnabled(True)
self.actions.createLineMode.setEnabled(True)
self.actions.createPointMode.setEnabled(False)
self.actions.createLineStripMode.setEnabled(True)
self.actions.createAiPolygonMode.setEnabled(True)
self._selectAiModelComboBox.setEnabled(False)
elif createMode == "circle":
self.actions.createMode.setEnabled(True)
self.actions.createRectangleMode.setEnabled(True)
self.actions.createCircleMode.setEnabled(False)
self.actions.createLineMode.setEnabled(True)
self.actions.createPointMode.setEnabled(True)
self.actions.createLineStripMode.setEnabled(True)
self.actions.createAiPolygonMode.setEnabled(True)
self._selectAiModelComboBox.setEnabled(False)
elif createMode == "linestrip":
self.actions.createMode.setEnabled(True)
self.actions.createRectangleMode.setEnabled(True)
self.actions.createCircleMode.setEnabled(True)
self.actions.createLineMode.setEnabled(True)
self.actions.createPointMode.setEnabled(True)
self.actions.createLineStripMode.setEnabled(False)
self.actions.createAiPolygonMode.setEnabled(True)
self._selectAiModelComboBox.setEnabled(False)
elif createMode == "ai_polygon":
self.actions.createMode.setEnabled(True)
self.actions.createRectangleMode.setEnabled(True)
self.actions.createCircleMode.setEnabled(True)
self.actions.createLineMode.setEnabled(True)
self.actions.createPointMode.setEnabled(True)
self.actions.createLineStripMode.setEnabled(True)
self.actions.createAiPolygonMode.setEnabled(False)
self.canvas.initializeAiModel(
name=self._selectAiModelComboBox.currentText()
)
self._selectAiModelComboBox.setEnabled(True)
else:
raise ValueError("Unsupported createMode: %s" % createMode)
for draw_mode, draw_action in draw_actions.items():
draw_action.setEnabled(createMode != draw_mode)
self.actions.editMode.setEnabled(not edit)

def setEditMode(self):
Expand Down Expand Up @@ -1286,6 +1253,7 @@ def loadLabels(self, shapes):
shape_type=shape_type,
group_id=group_id,
description=description,
mask=shape["mask"],
)
for x, y in points:
shape.addPoint(QtCore.QPointF(x, y))
Expand Down Expand Up @@ -1325,6 +1293,9 @@ def format_shape(s):
description=s.description,
shape_type=s.shape_type,
flags=s.flags,
mask=None
if s.mask is None
else utils.img_arr_to_b64(s.mask),
)
)
return data
Expand Down
3 changes: 2 additions & 1 deletion labelme/config/default_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ shape:
vertex_fill_color: [0, 255, 0, 255]
# selecting / hovering
select_line_color: [255, 255, 255, 255]
select_fill_color: [0, 255, 0, 155]
select_fill_color: [0, 255, 0, 64]
hvertex_fill_color: [255, 255, 255, 255]
point_size: 8

Expand Down Expand Up @@ -77,6 +77,7 @@ canvas:
point: false
linestrip: false
ai_polygon: false
ai_mask: false

shortcuts:
close: Ctrl+W
Expand Down
4 changes: 4 additions & 0 deletions labelme/label_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def load(self, filename):
"shape_type",
"flags",
"description",
"mask",
]
try:
with open(filename, "r") as f:
Expand Down Expand Up @@ -112,6 +113,9 @@ def load(self, filename):
flags=s.get("flags", {}),
description=s.get("description"),
group_id=s.get("group_id"),
mask=utils.img_b64_to_arr(s["mask"])
if s.get("mask")
else None,
other_data={
k: v for k, v in s.items() if k not in shape_keys
},
Expand Down
Loading
Loading