Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use native-assets to vendor MediaPipe SDK #9

Merged
merged 8 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 49 additions & 16 deletions packages/mediapipe-core/lib/src/task_options.dart
Original file line number Diff line number Diff line change
Expand Up @@ -21,51 +21,84 @@ import 'third_party/mediapipe/generated/mediapipe_common_bindings.dart'
/// classifier's desired behavior.
class BaseOptions extends Equatable {
/// Generative constructor that creates a [BaseOptions] instance.
const BaseOptions({this.modelAssetBuffer, this.modelAssetPath})
: assert(
const BaseOptions._({
this.modelAssetBuffer,
this.modelAssetPath,
this.modelAssetBufferCount,
required _BaseOptionsType type,
}) : assert(
!(modelAssetBuffer == null && modelAssetPath == null),
'You must supply either `modelAssetBuffer` or `modelAssetPath`',
),
assert(
!(modelAssetBuffer != null && modelAssetPath != null),
'You must only supply one of `modelAssetBuffer` and `modelAssetPath`',
);
),
assert(
(modelAssetBuffer == null) == (modelAssetBufferCount == null),
'modelAssetBuffer and modelAssetBufferCount must only be submitted '
'together',
),
_type = type;

/// Constructor for [BaseOptions] classes using a file system path.
///
/// In practice, this is unsupported, as assets in Flutter are bundled into
/// the build output and not available on disk. However, it can potentially
/// be helpful for testing / development purposes.
factory BaseOptions.path(String path) => BaseOptions._(
modelAssetPath: path,
type: _BaseOptionsType.path,
);

/// Constructor for [BaseOptions] classes using an in-memory pointer to the
/// MediaPipe SDK.
///
/// In practice, this is the only option supported for production builds.
factory BaseOptions.memory(Uint8List buffer) {
return BaseOptions._(
modelAssetBuffer: buffer,
modelAssetBufferCount: buffer.lengthInBytes,
type: _BaseOptionsType.memory,
);
}

/// The model asset file contents as bytes;
final Uint8List? modelAssetBuffer;

/// The size of the model assets buffer (or `0` if not set).
final int? modelAssetBufferCount;

/// Path to the model asset file.
final String? modelAssetPath;

final _BaseOptionsType _type;

/// Converts this pure-Dart representation into C-memory suitable for the
/// MediaPipe SDK to instantiate various classifiers.
Pointer<bindings.BaseOptions> toStruct() {
final struct = calloc<bindings.BaseOptions>();

if (modelAssetPath != null) {
if (_type == _BaseOptionsType.path) {
struct.ref.model_asset_path = prepareString(modelAssetPath!);
}
if (modelAssetBuffer != null) {
if (_type == _BaseOptionsType.memory) {
struct.ref.model_asset_buffer = prepareUint8List(modelAssetBuffer!);
struct.ref.model_asset_buffer_count = modelAssetBuffer!.lengthInBytes;
}
return struct;
}

@override
List<Object?> get props => [modelAssetBuffer, modelAssetPath];

/// Releases all C memory held by this [bindings.BaseOptions] struct.
static void freeStruct(bindings.BaseOptions struct) {
if (struct.model_asset_buffer.address != 0) {
calloc.free(struct.model_asset_buffer);
}
if (struct.model_asset_path.address != 0) {
calloc.free(struct.model_asset_path);
}
}
List<Object?> get props => [
modelAssetBuffer,
modelAssetPath,
modelAssetBufferCount,
];
}

enum _BaseOptionsType { path, memory }

/// Dart representation of MediaPipe's "ClassifierOptions" concept.
///
/// Classifier options shared across MediaPipe classification tasks.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ final class BaseOptions extends ffi.Struct {

external ffi.Pointer<ffi.Char> model_asset_path;

@ffi.UnsignedInt()
@ffi.Int()
external int model_asset_buffer_count;
}

Expand Down
24 changes: 3 additions & 21 deletions packages/mediapipe-core/test/task_options_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,16 @@ import 'package:test/test.dart';
import 'package:mediapipe_core/mediapipe_core.dart';

void main() {
group('BaseOptions constructor should', () {
test('enforce exactly one of modelPath and modelBuffer', () {
expect(
() => BaseOptions(
modelAssetPath: 'abc',
modelAssetBuffer: Uint8List.fromList([1, 2, 3]),
),
throwsA(TypeMatcher<AssertionError>()),
);

expect(BaseOptions.new, throwsA(TypeMatcher<AssertionError>()));
});
});

group('BaseOptions.toStruct/fromStruct should', () {
test('allocate memory in C for a modelAssetPath', () {
final options = BaseOptions(modelAssetPath: 'abc');
final options = BaseOptions.path('abc');
final struct = options.toStruct();
expect(toDartString(struct.ref.model_asset_path), 'abc');
expectNullPtr(struct.ref.model_asset_buffer);
});

test('allocate memory in C for a modelAssetBuffer', () {
final options = BaseOptions(
modelAssetBuffer: Uint8List.fromList([1, 2, 3]),
);
final options = BaseOptions.memory(Uint8List.fromList([1, 2, 3]));
final struct = options.toStruct();
expect(
toUint8List(struct.ref.model_asset_buffer),
Expand All @@ -44,9 +28,7 @@ void main() {
});

test('allocate memory in C for a modelAssetBuffer containing 0', () {
final options = BaseOptions(
modelAssetBuffer: Uint8List.fromList([1, 2, 0, 3]),
);
final options = BaseOptions.memory(Uint8List.fromList([1, 2, 0, 3]));
final struct = options.toStruct();
expect(
toUint8List(struct.ref.model_asset_buffer),
Expand Down
41 changes: 41 additions & 0 deletions packages/mediapipe-task-text/build.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import 'dart:io';
import 'package:native_assets_cli/native_assets_cli.dart';
import 'package:http/http.dart' as http;

const cloudAssetFilename = 'libtext_classifier-v0.0.3.dylib';
const localAssetFilename = 'libtext_classifier.dylib';
const assetLocation =
'https://storage.googleapis.com/random-storage-asdf/$cloudAssetFilename';

Future<void> main(List<String> args) async {
final buildConfig = await BuildConfig.fromArgs(args);
final buildOutput = BuildOutput();
final downloadFileLocation = buildConfig.outDir.resolve(localAssetFilename);
if (!buildConfig.dryRun) {
final downloadUri = Uri.parse(assetLocation);
final downloadResponse = await http.get(downloadUri);
final downloadedFile = File(downloadFileLocation.toFilePath());
if (downloadResponse.statusCode == 200) {
if (downloadedFile.existsSync()) {
downloadedFile.deleteSync();
}
downloadedFile.createSync();
downloadedFile.writeAsBytes(downloadResponse.bodyBytes);
} else {
throw Exception(
'${downloadResponse.statusCode} :: ${downloadResponse.body}');
}
}
buildOutput.dependencies.dependencies
.add(buildConfig.packageRoot.resolve('build.dart'));
buildOutput.assets.add(
Asset(
// What should this `id` be?
id: 'package:mediapipe_text/src/mediapipe_text_bindings.dart',
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved
linkMode: LinkMode.dynamic,
target: Target.macOSArm64,
path: AssetAbsolutePath(downloadFileLocation),
),
);
await buildOutput.writeToFile(outDir: buildConfig.outDir);
}
1 change: 1 addition & 0 deletions packages/mediapipe-task-text/example/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class _MainAppState extends State<MainApp> {
super.initState();
_controller.text = 'Hello, world!';
_initClassifier();
Future.delayed(const Duration(milliseconds: 500)).then((_) => _classify());
}

Future<void> _initClassifier() async {
Expand Down
1 change: 0 additions & 1 deletion packages/mediapipe-task-text/example/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,3 @@ flutter:

assets:
- assets/bert_classifier.tflite
- assets/libtext_classifier.dylib
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class TextClassifierOptions {
}) {
assert(!kIsWeb, 'fromAssetPath cannot be used on the web');
return TextClassifierOptions(
baseOptions: BaseOptions(modelAssetPath: assetPath),
baseOptions: BaseOptions.path(assetPath),
classifierOptions: classifierOptions,
);
}
Expand All @@ -45,7 +45,7 @@ class TextClassifierOptions {
ClassifierOptions classifierOptions = const ClassifierOptions(),
}) =>
TextClassifierOptions(
baseOptions: BaseOptions(modelAssetBuffer: assetBuffer),
baseOptions: BaseOptions.memory(assetBuffer),
classifierOptions: classifierOptions,
);

Expand Down
3 changes: 3 additions & 0 deletions packages/mediapipe-task-text/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ dependencies:
sdk: flutter
http: ^1.1.0
logging: ^1.2.0
http: ^1.1.0
mediapipe_core:
path: ../mediapipe-core
native_assets_cli: ^0.3.0
native_toolchain_c: ^0.3.0

dev_dependencies:
ffigen: ^9.0.1
Expand Down