diff --git a/android/android-snippets/app/build.gradle b/android/android-snippets/app/build.gradle index 8f9dc44aaa..0063c4849d 100644 --- a/android/android-snippets/app/build.gradle +++ b/android/android-snippets/app/build.gradle @@ -41,8 +41,9 @@ dependencies { // Barcode model implementation 'com.google.mlkit:barcode-scanning:16.0.1' - // Object feature and model + // Object detection and tracking features implementation 'com.google.mlkit:object-detection:16.1.0' + implementation 'com.google.mlkit:object-detection-custom:16.1.0' // Face features implementation 'com.google.android.gms:play-services-mlkit-face-detection:16.1.0' diff --git a/android/android-snippets/app/src/main/java/com/google/example/mlkit/ObjectDetectionActivity.java b/android/android-snippets/app/src/main/java/com/google/example/mlkit/ObjectDetectionActivity.java new file mode 100644 index 0000000000..23d8d0d55d --- /dev/null +++ b/android/android-snippets/app/src/main/java/com/google/example/mlkit/ObjectDetectionActivity.java @@ -0,0 +1,166 @@ +/* + * Copyright 2020 Google LLC. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.example.mlkit; + +import android.graphics.Bitmap; +import android.graphics.Rect; +import android.os.Bundle; + +import androidx.annotation.NonNull; +import androidx.appcompat.app.AppCompatActivity; + +import com.google.android.gms.tasks.OnFailureListener; +import com.google.android.gms.tasks.OnSuccessListener; +import com.google.mlkit.common.model.LocalModel; +import com.google.mlkit.vision.common.InputImage; +import com.google.mlkit.vision.objects.DetectedObject; +import com.google.mlkit.vision.objects.ObjectDetection; +import com.google.mlkit.vision.objects.ObjectDetector; +import com.google.mlkit.vision.objects.custom.CustomObjectDetectorOptions; +import com.google.mlkit.vision.objects.defaults.ObjectDetectorOptions; +import com.google.mlkit.vision.objects.defaults.PredefinedCategory; + +import java.util.ArrayList; +import java.util.List; + +public class ObjectDetectionActivity extends AppCompatActivity { + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + } + + private void useDefaultObjectDetector() { + // [START create_default_options] + // Live detection and tracking + ObjectDetectorOptions options = + new ObjectDetectorOptions.Builder() + .setDetectorMode(ObjectDetectorOptions.STREAM_MODE) + .enableClassification() // Optional + .build(); + + // Multiple object detection in static images + options = + new ObjectDetectorOptions.Builder() + .setDetectorMode(ObjectDetectorOptions.SINGLE_IMAGE_MODE) + .enableMultipleObjects() + .enableClassification() // Optional + .build(); + // [END create_default_options] + + // [START create_detector] + ObjectDetector objectDetector = ObjectDetection.getClient(options); + // [END create_detector] + + InputImage image = + InputImage.fromBitmap( + Bitmap.createBitmap(new int[100 * 100], 100, 100, Bitmap.Config.ARGB_8888), + 0); + + // [START process_image] + objectDetector.process(image) + .addOnSuccessListener( + new OnSuccessListener>() { + @Override + public void onSuccess(List detectedObjects) { + // Task completed successfully + // ... + } + }) + .addOnFailureListener( + new OnFailureListener() { + @Override + public void onFailure(@NonNull Exception e) { + // Task failed with an exception + // ... + } + }); + // [END process_image] + + List results = new ArrayList<>(); + // [START read_results_default] + // The list of detected objects contains one item if multiple + // object detection wasn't enabled. + for (DetectedObject detectedObject : results) { + Rect boundingBox = detectedObject.getBoundingBox(); + Integer trackingId = detectedObject.getTrackingId(); + for (DetectedObject.Label label : detectedObject.getLabels()) { + String text = label.getText(); + if (PredefinedCategory.FOOD.equals(text)) { + // ... + } + int index = label.getIndex(); + if (PredefinedCategory.FOOD_INDEX == index) { + // ... + } + float confidence = label.getConfidence(); + } + } + // [END read_results_default] + } + + private void useCustomObjectDetector() { + InputImage image = + InputImage.fromBitmap( + Bitmap.createBitmap(new int[100 * 100], 100, 100, Bitmap.Config.ARGB_8888), + 0); + + // [START create_local_model] + LocalModel localModel = + new LocalModel.Builder() + .setAssetFilePath("asset_file_path_to_tflite_model") + // or .setAbsoluteFilePath("absolute_file_path_to_tflite_model") + .build(); + // [END create_local_model] + + // [START create_custom_options] + // Live detection and tracking + CustomObjectDetectorOptions options = + new CustomObjectDetectorOptions.Builder(localModel) + .setDetectorMode(CustomObjectDetectorOptions.STREAM_MODE) + .enableClassification() + .setClassificationConfidenceThreshold(0.5f) + .setMaxPerObjectLabelCount(3) + .build(); + + // Multiple object detection in static images + options = + new CustomObjectDetectorOptions.Builder(localModel) + .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE) + .enableMultipleObjects() + .enableClassification() + .setClassificationConfidenceThreshold(0.5f) + .setMaxPerObjectLabelCount(3) + .build(); + // [END create_custom_options] + + List results = new ArrayList<>(); + // [START read_results_custom] + // The list of detected objects contains one item if multiple + // object detection wasn't enabled. + for (DetectedObject detectedObject : results) { + Rect boundingBox = detectedObject.getBoundingBox(); + Integer trackingId = detectedObject.getTrackingId(); + for (DetectedObject.Label label : detectedObject.getLabels()) { + String text = label.getText(); + int index = label.getIndex(); + float confidence = label.getConfidence(); + } + } + // [END read_results_custom] + } +} diff --git a/android/android-snippets/app/src/main/java/com/google/example/mlkit/kotlin/ObjectDetectionActivity.kt b/android/android-snippets/app/src/main/java/com/google/example/mlkit/kotlin/ObjectDetectionActivity.kt new file mode 100644 index 0000000000..27308f5ea4 --- /dev/null +++ b/android/android-snippets/app/src/main/java/com/google/example/mlkit/kotlin/ObjectDetectionActivity.kt @@ -0,0 +1,134 @@ +/* + * Copyright 2020 Google LLC. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.example.mlkit.kotlin + +import android.graphics.Bitmap +import androidx.appcompat.app.AppCompatActivity +import com.google.mlkit.common.model.LocalModel +import com.google.mlkit.vision.common.InputImage +import com.google.mlkit.vision.objects.DetectedObject +import com.google.mlkit.vision.objects.ObjectDetection +import com.google.mlkit.vision.objects.custom.CustomObjectDetectorOptions +import com.google.mlkit.vision.objects.defaults.ObjectDetectorOptions +import com.google.mlkit.vision.objects.defaults.PredefinedCategory + +class ObjectDetectionActivity : AppCompatActivity() { + + private fun useDefaultObjectDetector() { + // [START create_default_options] + // Live detection and tracking + var options = ObjectDetectorOptions.Builder() + .setDetectorMode(ObjectDetectorOptions.STREAM_MODE) + .enableClassification() // Optional + .build() + + // Multiple object detection in static images + options = ObjectDetectorOptions.Builder() + .setDetectorMode(ObjectDetectorOptions.SINGLE_IMAGE_MODE) + .enableMultipleObjects() + .enableClassification() // Optional + .build() + // [END create_default_options] + + // [START create_detector] + val objectDetector = ObjectDetection.getClient(options) + // [END create_detector] + + val image = InputImage.fromBitmap( + Bitmap.createBitmap(IntArray(100 * 100), 100, 100, Bitmap.Config.ARGB_8888), + 0) + + // [START process_image] + objectDetector.process(image) + .addOnSuccessListener { results -> + // Task completed successfully + // ... + } + .addOnFailureListener { e -> + // Task failed with an exception + // ... + } + // [END process_image] + + val results = listOf() + // [START read_results_default] + for (detectedObject in results) { + val boundingBox = detectedObject.boundingBox + val trackingId = detectedObject.trackingId + for (label in detectedObject.labels) { + val text = label.text + if (PredefinedCategory.FOOD == text) { + // ... + } + val index = label.index + if (PredefinedCategory.FOOD_INDEX == index) { + // ... + } + val confidence = label.confidence + } + } + // [END read_results_default] + } + + private fun useCustomObjectDetector() { + val image = InputImage.fromBitmap( + Bitmap.createBitmap(IntArray(100 * 100), 100, 100, Bitmap.Config.ARGB_8888), + 0) + + // [START create_local_model] + val localModel = + LocalModel.Builder() + .setAssetFilePath("asset_file_path_to_tflite_model") + // or .setAbsoluteFilePath("absolute_file_path_to_tflite_model") + .build() + // [END create_local_model] + + // [START create_custom_options] + // Live detection and tracking + var options = + CustomObjectDetectorOptions.Builder(localModel) + .setDetectorMode(CustomObjectDetectorOptions.STREAM_MODE) + .enableClassification() + .setClassificationConfidenceThreshold(0.5f) + .setMaxPerObjectLabelCount(3) + .build() + + // Multiple object detection in static images + options = + CustomObjectDetectorOptions.Builder(localModel) + .setDetectorMode(CustomObjectDetectorOptions.SINGLE_IMAGE_MODE) + .enableMultipleObjects() + .enableClassification() + .setClassificationConfidenceThreshold(0.5f) + .setMaxPerObjectLabelCount(3) + .build() + // [END create_custom_options] + + val results = listOf() + // [START read_results_custom] + for (detectedObject in results) { + val boundingBox = detectedObject.boundingBox + val trackingId = detectedObject.trackingId + for (label in detectedObject.labels) { + val text = label.text + val index = label.index + val confidence = label.confidence + } + } + // [END read_results_custom] + } +} diff --git a/android/automl/app/src/main/java/com/google/mlkit/vision/automl/demo/StillImageActivity.java b/android/automl/app/src/main/java/com/google/mlkit/vision/automl/demo/StillImageActivity.java index 8b2d7d49a6..c44baacd5e 100755 --- a/android/automl/app/src/main/java/com/google/mlkit/vision/automl/demo/StillImageActivity.java +++ b/android/automl/app/src/main/java/com/google/mlkit/vision/automl/demo/StillImageActivity.java @@ -23,15 +23,12 @@ import android.net.Uri; import android.os.Bundle; import android.provider.MediaStore; - -import android.view.ViewTreeObserver.OnGlobalLayoutListener; -import androidx.appcompat.app.AppCompatActivity; - import android.util.Log; import android.util.Pair; import android.view.Menu; import android.view.MenuInflater; import android.view.View; +import android.view.ViewTreeObserver.OnGlobalLayoutListener; import android.widget.AdapterView; import android.widget.AdapterView.OnItemSelectedListener; import android.widget.ArrayAdapter; @@ -40,6 +37,8 @@ import android.widget.Spinner; import android.widget.Toast; +import androidx.appcompat.app.AppCompatActivity; + import com.google.android.gms.common.annotation.KeepName; import com.google.mlkit.vision.automl.demo.automl.AutoMLImageLabelerProcessor; import com.google.mlkit.vision.automl.demo.automl.AutoMLImageLabelerProcessor.Mode; diff --git a/android/automl/app/src/main/java/com/google/mlkit/vision/automl/demo/automl/AutoMLImageLabelerProcessor.java b/android/automl/app/src/main/java/com/google/mlkit/vision/automl/demo/automl/AutoMLImageLabelerProcessor.java index 84355cbb37..a9a52b13d7 100755 --- a/android/automl/app/src/main/java/com/google/mlkit/vision/automl/demo/automl/AutoMLImageLabelerProcessor.java +++ b/android/automl/app/src/main/java/com/google/mlkit/vision/automl/demo/automl/AutoMLImageLabelerProcessor.java @@ -61,28 +61,18 @@ public AutoMLImageLabelerProcessor(Context context, Mode mode) { new AutoMLImageLabelerRemoteModel.Builder(remoteModelName).build(); createDetector(remoteModel); - RemoteModelManager.getInstance() - .isModelDownloaded(remoteModel) - .addOnCompleteListener( - task -> { - if (!task.getResult()) { - Log.d(TAG, "Model needs to be downloaded"); - DownloadConditions downloadConditions = - new DownloadConditions.Builder().requireWifi().build(); - modelDownloadingTask = - RemoteModelManager.getInstance().download(remoteModel, downloadConditions); - modelDownloadingTask.addOnFailureListener( - ignored -> - Toast.makeText( - context, - "Model download failed for AutoMLImageLabelerImpl," - + " please check your connection.", - Toast.LENGTH_LONG) - .show()); - } else { - Log.d(TAG, "Model Exist Locally"); - } - }); + DownloadConditions downloadConditions = + new DownloadConditions.Builder().requireWifi().build(); + modelDownloadingTask = + RemoteModelManager.getInstance() + .download(remoteModel, downloadConditions) + .addOnFailureListener(ignored -> + Toast.makeText( + context, + "Model download failed for AutoMLImageLabelerImpl," + + " please check your connection.", + Toast.LENGTH_LONG) + .show()); } @Override @@ -97,10 +87,7 @@ public void stop() { @Override protected Task> detectInImage(InputImage image) { - if (modelDownloadingTask == null) { - // No download task means only the locally bundled model is used. Model can be used directly. - return imageLabeler.process(image); - } else if (!modelDownloadingTask.isComplete()) { + if (!modelDownloadingTask.isComplete()) { if (mode == Mode.LIVE_PREVIEW) { Log.i(TAG, "Model download is in progress. Skip detecting image."); return Tasks.forResult(new ArrayList<>());