From c7f4a5a60a3b422eeb49335a93c3807fbea7a155 Mon Sep 17 00:00:00 2001 From: elb3k Date: Thu, 27 May 2021 15:24:56 +0900 Subject: [PATCH 1/2] Custom mean, std --- .../pytorch_mobile/PyTorchMobilePlugin.java | 11 ++++++- ios/Classes/Helpers/UIImageExtension.m | 8 ++--- ios/Classes/PytorchMobilePlugin.mm | 10 +++++- lib/model.dart | 31 ++++++++++++++++--- 4 files changed, 49 insertions(+), 11 deletions(-) diff --git a/android/src/main/java/io/fynn/pytorch_mobile/PyTorchMobilePlugin.java b/android/src/main/java/io/fynn/pytorch_mobile/PyTorchMobilePlugin.java index 1daddc5..1ce7eeb 100644 --- a/android/src/main/java/io/fynn/pytorch_mobile/PyTorchMobilePlugin.java +++ b/android/src/main/java/io/fynn/pytorch_mobile/PyTorchMobilePlugin.java @@ -107,6 +107,15 @@ public void onMethodCall(@NonNull MethodCall call, @NonNull Result result) { byte[] imageData = call.argument("image"); int width = call.argument("width"); int height = call.argument("height"); + // Custom mean + ArrayList _mean = call.argument("mean"); + final float [] mean = Convert.toFloatPrimitives(_mean.toArray(new Double[0])); + + // Custom std + ArrayList _std = call.argument("std"); + final float[] std = Convert.toFloatPrimitives(_std.toArray(new Double[0])); + + imageModule = modules.get(index); @@ -119,7 +128,7 @@ public void onMethodCall(@NonNull MethodCall call, @NonNull Result result) { } final Tensor imageInputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap, - TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB); + mean, std); final Tensor imageOutputTensor = imageModule.forward(IValue.from(imageInputTensor)).toTensor(); diff --git a/ios/Classes/Helpers/UIImageExtension.m b/ios/Classes/Helpers/UIImageExtension.m index 5332043..e44b4cb 100644 --- a/ios/Classes/Helpers/UIImageExtension.m +++ b/ios/Classes/Helpers/UIImageExtension.m @@ -12,7 +12,7 @@ + (UIImage*)resize:(UIImage*)image toWidth:(int) width toHeight:(int)height { return newImage; } -+ (nullable 
float*)normalize:(UIImage*)image{ ++ (nullable float*)normalize:(UIImage*)image mean:(NSArray*)mean std:(NSArray*)std { CGImageRef cgImage = [image CGImage]; NSUInteger w = CGImageGetWidth(cgImage); NSUInteger h = CGImageGetHeight(cgImage); @@ -37,9 +37,9 @@ + (nullable float*)normalize:(UIImage*)image{ float* normalizedBuffer = malloc(3*h*w * sizeof(float)); for(int i = 0; i < (w*h); i++) { - normalizedBuffer[i] = (rawBytes[i * 4 + 0] / 255.0 - 0.485) / 0.229; - normalizedBuffer[w * h + i] = (rawBytes[i * 4 + 1] / 255.0 - 0.456) / 0.224; - normalizedBuffer[w * h * 2 + i] = (rawBytes[i * 4 + 2] / 255.0 - 0.406) / 0.225; + normalizedBuffer[i] = (rawBytes[i * 4 + 0] / 255.0 - [mean[0] floatValue]) / [std[0] floatValue]; + normalizedBuffer[w * h + i] = (rawBytes[i * 4 + 1] / 255.0 - [mean[1] floatValue]) / [std[1] floatValue]; + normalizedBuffer[w * h * 2 + i] = (rawBytes[i * 4 + 2] / 255.0 - [mean[2] floatValue]) / [std[2] floatValue]; } return normalizedBuffer; diff --git a/ios/Classes/PytorchMobilePlugin.mm b/ios/Classes/PytorchMobilePlugin.mm index 31473e1..330f3ae 100644 --- a/ios/Classes/PytorchMobilePlugin.mm +++ b/ios/Classes/PytorchMobilePlugin.mm @@ -71,6 +71,9 @@ - (void)handleMethodCall:(FlutterMethodCall*)call result:(FlutterResult)result { float* input; int width; int height; + NSArray* mean; + NSArray* std; + try { int index = [call.arguments[@"index"] intValue]; imageModule = modules[index]; @@ -78,11 +81,16 @@ - (void)handleMethodCall:(FlutterMethodCall*)call result:(FlutterResult)result { FlutterStandardTypedData *imageData = call.arguments[@"image"]; width = [call.arguments[@"width"] intValue]; height = [call.arguments[@"height"] intValue]; + // Custom mean + mean = call.arguments[@"mean"]; + // Custom std + std = call.arguments[@"std"]; + UIImage *image = [UIImage imageWithData: imageData.data]; image = [UIImageExtension resize:image toWidth:width toHeight:height]; - input = [UIImageExtension normalize:image]; + input = [UIImageExtension normalize:image mean:mean std:std]; } catch (const std::exception& e) { NSLog(@"PyTorchMobile: 
error reading image!\n%s", e.what()); } diff --git a/lib/model.dart b/lib/model.dart index 35ebf3d..877ff31 100644 --- a/lib/model.dart +++ b/lib/model.dart @@ -3,6 +3,9 @@ import 'dart:io'; import 'package:flutter/services.dart'; import 'package:pytorch_mobile/enums/dtype.dart'; +const TORCHVISION_NORM_MEAN_RGB = [0.485, 0.456, 0.406]; +const TORCHVISION_NORM_STD_RGB = [0.229, 0.224, 0.225]; + class Model { static const MethodChannel _channel = const MethodChannel('pytorch_mobile'); @@ -24,14 +27,21 @@ class Model { ///predicts image and returns the supposed label belonging to it Future getImagePrediction( - File image, int width, int height, String labelPath) async { + File image, int width, int height, String labelPath, + {List mean = TORCHVISION_NORM_MEAN_RGB, + List std = TORCHVISION_NORM_STD_RGB}) async { + // Assert mean std + assert(mean.length == 3, "Mean should have size of 3"); + assert(std.length == 3, "STD should have size of 3"); List labels = await _getLabels(labelPath); List byteArray = image.readAsBytesSync(); final List prediction = await _channel.invokeListMethod("predictImage", { "index": _index, "image": byteArray, "width": width, - "height": height + "height": height, + "mean": mean, + "std": std }); double maxScore = double.negativeInfinity; int maxScoreIndex = -1; @@ -45,9 +55,20 @@ class Model { } ///predicts image but returns the raw net output - Future getImagePredictionList(File image, int width, int height) async { - final List prediction = await _channel.invokeListMethod("predictImage", - {"index": _index, "image": image.readAsBytesSync(), "width": width, "height": height}); + Future getImagePredictionList(File image, int width, int height, + {List mean = TORCHVISION_NORM_MEAN_RGB, + List std = TORCHVISION_NORM_STD_RGB}) async { + // Assert mean std + assert(mean.length == 3, "Mean should have size of 3"); + assert(std.length == 3, "STD should have size of 3"); + final List prediction = await _channel.invokeListMethod("predictImage", { 
+ "index": _index, + "image": image.readAsBytesSync(), + "width": width, + "height": height, + "mean": mean, + "std": std + }); return prediction; } From 403b45081b14787dfa320c773aac04157abebbf4 Mon Sep 17 00:00:00 2001 From: elb3k Date: Thu, 27 May 2021 15:43:49 +0900 Subject: [PATCH 2/2] Java fix and example code --- README.md | 8 ++++++++ .../java/io/fynn/pytorch_mobile/PyTorchMobilePlugin.java | 6 ++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 98d5955..eb130ab 100644 --- a/README.md +++ b/README.md @@ -53,5 +53,13 @@ String prediction = await _imageModel .getImagePrediction(image, 224, 224, "assets/labels/labels.csv"); ``` +### Image prediction for an image with custom mean and std +```dart +final mean = [0.5, 0.5, 0.5]; +final std = [0.5, 0.5, 0.5]; +String prediction = await _imageModel + .getImagePrediction(image, 224, 224, "assets/labels/labels.csv", mean: mean, std: std); +``` + ## Contact fynnmaarten.business@gmail.com \ No newline at end of file diff --git a/android/src/main/java/io/fynn/pytorch_mobile/PyTorchMobilePlugin.java b/android/src/main/java/io/fynn/pytorch_mobile/PyTorchMobilePlugin.java index 1ce7eeb..1de158d 100644 --- a/android/src/main/java/io/fynn/pytorch_mobile/PyTorchMobilePlugin.java +++ b/android/src/main/java/io/fynn/pytorch_mobile/PyTorchMobilePlugin.java @@ -102,6 +102,8 @@ public void onMethodCall(@NonNull MethodCall call, @NonNull Result result) { case "predictImage": Module imageModule = null; Bitmap bitmap = null; + float [] mean = null; + float [] std = null; try { int index = call.argument("index"); byte[] imageData = call.argument("image"); @@ -109,11 +111,11 @@ public void onMethodCall(@NonNull MethodCall call, @NonNull Result result) { int height = call.argument("height"); // Custom mean ArrayList _mean = call.argument("mean"); - final float [] mean = Convert.toFloatPrimitives(_mean.toArray(new Double[0])); + mean = Convert.toFloatPrimitives(_mean.toArray(new Double[0])); 
// Custom std ArrayList _std = call.argument("std"); - final float[] std = Convert.toFloatPrimitives(_std.toArray(new Double[0])); + std = Convert.toFloatPrimitives(_std.toArray(new Double[0]));