Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class SubjectSegmentationPainter extends CustomPainter {
void paint(Canvas canvas, Size size) {
final int width = mask.width;
final int height = mask.height;
final List<Subject> subjects = mask.subjects;
final List<Subject> subjects = mask.subjects ?? [];

final paint = Paint()..style = PaintingStyle.fill;

Expand All @@ -30,7 +30,7 @@ class SubjectSegmentationPainter extends CustomPainter {
final int startY = subject.startY;
final int subjectWidth = subject.subjectWidth;
final int subjectHeight = subject.subjectHeight;
final List<double> confidences = subject.confidences;
final List<double> confidences = subject.confidences ?? [];

for (int y = 0; y < subjectHeight; y++) {
for (int x = 0; y < subjectWidth; x++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ class SubjectSegmenterView extends StatefulWidget {
}

class _SubjectSegmenterViewState extends State<SubjectSegmenterView> {
final SubjectSegmenter _segmenter = SubjectSegmenter();
final SubjectSegmenter _segmenter = SubjectSegmenter(
options: SubjectSegmenterOptions(enableForegroundConfidenceMask: true));
bool _canProcess = true;
bool _isBusy = false;
CustomPaint? _customPaint;
Expand Down Expand Up @@ -56,8 +57,7 @@ class _SubjectSegmenterViewState extends State<SubjectSegmenterView> {
_customPaint = CustomPaint(painter: painter);
} else {
// TODO: set _customPaint to draw on top of image
_text = 'There is a mask with ${mask.subjects.length} subjects';

_text = 'There is a mask with ${mask.subjects?.length} subjects';
_customPaint = null;
}
_isBusy = false;
Expand Down
3 changes: 2 additions & 1 deletion packages/google_mlkit_subject_segmentation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ final InputImage inputImage;
#### Create an instance of `SubjectSegmenter`

```dart
final segmenter = SubjectSegmenter();
final options = SubjectSegmenterOptions();
final segmenter = SubjectSegmenter(options: options);
```

#### Process image
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
package com.google_mlkit_subject_segmentation;

import android.content.Context;
import android.graphics.Bitmap;

import androidx.annotation.NonNull;

import com.google.mlkit.vision.common.InputImage;
import com.google.mlkit.vision.segmentation.subject.Subject;
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentation;
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentationResult;
import com.google.mlkit.vision.segmentation.subject.SubjectSegmenter;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import io.flutter.Log;
import io.flutter.plugin.common.MethodCall;
Expand All @@ -27,8 +35,6 @@ public class SubjectSegmenterProcess implements MethodChannel.MethodCallHandler
private static final String CLOSE = "vision#closeSubjectSegmenter";

private final Context context;

private static final String TAG = "Logger";

private int imageWidth;
private int imageHeight;
Expand All @@ -55,55 +61,119 @@ public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result
}
}

private SubjectSegmenter initialize(MethodCall call) {
SubjectSegmenterOptions.Builder builder = new SubjectSegmenterOptions.Builder()
.enableMultipleSubjects(new SubjectSegmenterOptions.SubjectResultOptions.Builder()
.enableConfidenceMask().build());
SubjectSegmenterOptions options = builder.build();
return SubjectSegmentation.getClient(options);
}

private void handleDetection(MethodCall call, MethodChannel.Result result){
Map<String, Object> imageData = (Map<String, Object>) call.argument("imageData");
InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result);
if (inputImage == null) return;
private void handleDetection(MethodCall call, MethodChannel.Result result) {
InputImage inputImage = InputImageConverter.getInputImageFromData(call.argument("imageData"), context, result);
if(inputImage == null) return;
imageHeight = inputImage.getHeight();
imageWidth = inputImage.getWidth();

String id = call.argument("id");
SubjectSegmenter subjectSegmenter = instances.get(id);
if (subjectSegmenter == null) {
subjectSegmenter = initialize(call);
instances.put(id, subjectSegmenter);
SubjectSegmenter subjectSegmenter = getOrCreateSegmenter(id, call);

subjectSegmenter.process(inputImage)
.addOnSuccessListener(subjectSegmentationResult -> processResult(subjectSegmentationResult, call, result))
.addOnFailureListener(e -> result.error("Subject segmentation failure!", e.getMessage(), e));

}

private SubjectSegmenter getOrCreateSegmenter(String id, MethodCall call) {
return instances.computeIfAbsent(id, k -> initialize(call));
}
private SubjectSegmenter initialize(MethodCall call) {
Map<String, Object> options = call.argument("options");
SubjectSegmenterOptions.Builder builder = new SubjectSegmenterOptions.Builder();
assert options != null;
configureBuilder(builder, options);
return SubjectSegmentation.getClient(builder.build());
}

private void configureBuilder(SubjectSegmenterOptions.Builder builder, Map<String, Object> options) {
if(Boolean.TRUE.equals(options.get("enableForegroundBitmap"))){
builder.enableForegroundBitmap();
}
if(Boolean.TRUE.equals(options.get("enableForegroundConfidenceMask"))){
builder.enableForegroundConfidenceMask();
}
configureMultipleSubjects(builder, options);
}

private void configureMultipleSubjects(SubjectSegmenterOptions.Builder builder, Map<String, Object> options) {
boolean enableMultiConfidenceMask = Boolean.TRUE.equals(options.get("enableMultiConfidenceMask")) ;
boolean enableMultiSubjectBitmap = Boolean.TRUE.equals(options.get("enableMultiSubjectBitmap"));

if(enableMultiConfidenceMask || enableMultiSubjectBitmap) {
SubjectSegmenterOptions.SubjectResultOptions.Builder subjectBuilder = new SubjectSegmenterOptions.SubjectResultOptions.Builder();
if(enableMultiConfidenceMask) subjectBuilder.enableConfidenceMask();
if(enableMultiSubjectBitmap) subjectBuilder.enableSubjectBitmap();
builder.enableMultipleSubjects(subjectBuilder.build());
}
}

private void processResult(SubjectSegmentationResult subjectSegmentationResult, MethodCall call, MethodChannel.Result result) {
Map<String, Object> resultMap = new HashMap<>();
Map<String, Object> options = call.argument("options");

assert options != null;
if(Boolean.TRUE.equals(options.get("enableForegroundBitmap"))) {
addForegroundBitmap(resultMap, subjectSegmentationResult.getForegroundBitmap());
}

if(Boolean.TRUE.equals(options.get("enableForegroundConfidenceMask"))){
addConfidenceMask(resultMap, subjectSegmentationResult.getForegroundConfidenceMask());
}
if(Boolean.TRUE.equals(options.get("enableMultiConfidenceMask")) || Boolean.TRUE.equals(options.get("enableMultiSubjectBitmap"))) {

subjectSegmenter.process(inputImage)
.addOnSuccessListener( subjectSegmentationResult -> {
List<Map<String, Object>> subjectsData = new ArrayList<>();
for(Subject subject : subjectSegmentationResult.getSubjects()){
Map<String, Object> subjectData = getStringObjectMap(subject);
for(Subject subject: subjectSegmentationResult.getSubjects()){
Map<String, Object> subjectData = getStringObjectMap(subject, options);
subjectsData.add(subjectData);
}
Map<String, Object> map = new HashMap<>();
map.put("subjects", subjectsData);
map.put("width", imageWidth);
map.put("height", imageHeight);
result.success(map);
}).addOnFailureListener( e -> result.error("Subject segmentation failed!", e.getMessage(), e) );
resultMap.put("subjects", subjectsData);
}
resultMap.put("width", imageWidth);
resultMap.put("height", imageHeight);

result.success(resultMap);
}

private void addForegroundBitmap(Map<String, Object> map, Bitmap bitmap) {
if(bitmap != null) {
map.put("bitmap", getBitmapBytes(bitmap));
}
}

private void addConfidenceMask(Map<String, Object> map, FloatBuffer mask) {
if(mask != null) {
map.put("confidences", getConfidences(mask));
}
}

private static float[] getConfidences(FloatBuffer floatBuffer) {
float[] confidences = new float[floatBuffer.remaining()];
floatBuffer.get(confidences);
return confidences;
}

private static byte[] getBitmapBytes(Bitmap bitmap) {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
bitmap.compress(Bitmap.CompressFormat.PNG, 100, outputStream);
return outputStream.toByteArray();
}


@NonNull
private static Map<String, Object> getStringObjectMap(Subject subject) {
private static Map<String, Object> getStringObjectMap(Subject subject, Map<String, Object> options) {
Map<String, Object> subjectData = new HashMap<>();
subjectData.put("startX", subject.getStartX());
subjectData.put("startY", subject.getStartY());
subjectData.put("width", subject.getWidth());
subjectData.put("height", subject.getHeight());

FloatBuffer confidenceMask = subject.getConfidenceMask();
assert confidenceMask != null;
float[] confidences = new float[confidenceMask.remaining()];
confidenceMask.get(confidences);
subjectData.put("confidences", confidences);
if(Boolean.TRUE.equals(options.get("enableMultiConfidenceMask"))){
subjectData.put("confidences", getConfidences(Objects.requireNonNull(subject.getConfidenceMask())));
}
if(Boolean.TRUE.equals(options.get("enableMultiSubjectBitmap"))) {
subjectData.put("bitmap", getBitmapBytes(Objects.requireNonNull(subject.getBitmap())));
}
return subjectData;
}

Expand Down
Loading