Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,13 @@ String prediction = await _imageModel
.getImagePrediction(image, 224, 224, "assets/labels/labels.csv");
```

### Image prediction for an image with custom mean and std
```dart
final mean = [0.5, 0.5, 0.5];
final std = [0.5, 0.5, 0.5];
String prediction = await _imageModel
.getImagePrediction(image, 224, 224, "assets/labels/labels.csv", mean: mean, std: std);
```

## Contact
fynnmaarten.business@gmail.com
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,22 @@ public void onMethodCall(@NonNull MethodCall call, @NonNull Result result) {
case "predictImage":
Module imageModule = null;
Bitmap bitmap = null;
float [] mean = null;
float [] std = null;
try {
int index = call.argument("index");
byte[] imageData = call.argument("image");
int width = call.argument("width");
int height = call.argument("height");
// Custom mean
ArrayList<Double> _mean = call.argument("mean");
mean = Convert.toFloatPrimitives(_mean.toArray(new Double[0]));

// Custom std
ArrayList<Double> _std = call.argument("std");
std = Convert.toFloatPrimitives(_std.toArray(new Double[0]));



imageModule = modules.get(index);

Expand All @@ -119,7 +130,7 @@ public void onMethodCall(@NonNull MethodCall call, @NonNull Result result) {
}

final Tensor imageInputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
mean, std);

final Tensor imageOutputTensor = imageModule.forward(IValue.from(imageInputTensor)).toTensor();

Expand Down
8 changes: 4 additions & 4 deletions ios/Classes/Helpers/UIImageExtension.m
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ + (UIImage*)resize:(UIImage*)image toWidth:(int) width toHeight:(int)height {
return newImage;
}

+ (nullable float*)normalize:(UIImage*)image{
+ (nullable float*)normalize:(UIImage*)image mean:(NSArray<NSNumber*>*) std:(NSArray<NSNumber*>*) {
CGImageRef cgImage = [image CGImage];
NSUInteger w = CGImageGetWidth(cgImage);
NSUInteger h = CGImageGetHeight(cgImage);
Expand All @@ -37,9 +37,9 @@ + (nullable float*)normalize:(UIImage*)image{

float* normalizedBuffer = malloc(3*h*w * sizeof(float));
for(int i = 0; i < (w*h); i++) {
normalizedBuffer[i] = (rawBytes[i * 4 + 0] / 255.0 - 0.485) / 0.229;
normalizedBuffer[w * h + i] = (rawBytes[i * 4 + 1] / 255.0 - 0.456) / 0.224;
normalizedBuffer[w * h * 2 + i] = (rawBytes[i * 4 + 2] / 255.0 - 0.406) / 0.225;
normalizedBuffer[i] = (rawBytes[i * 4 + 0] / 255.0 - mean[0]) / std[0];
normalizedBuffer[w * h + i] = (rawBytes[i * 4 + 1] / 255.0 - mean[1]) / std[1];
normalizedBuffer[w * h * 2 + i] = (rawBytes[i * 4 + 2] / 255.0 - mean[2]) / std[2];
}

return normalizedBuffer;
Expand Down
10 changes: 9 additions & 1 deletion ios/Classes/PytorchMobilePlugin.mm
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,26 @@ - (void)handleMethodCall:(FlutterMethodCall*)call result:(FlutterResult)result {
float* input;
int width;
int height;
NSArray<NSNumber*>* mean;
NSArray<NSNumber*>* std;

try {
int index = [call.arguments[@"index"] intValue];
imageModule = modules[index];

FlutterStandardTypedData *imageData = call.arguments[@"image"];
width = [call.arguments[@"width"] intValue];
height = [call.arguments[@"height"] intValue];
// Custom mean
mean = call.arguments[@"mean"];
// Custom std
std = call.arguments[@"std"];


UIImage *image = [UIImage imageWithData: imageData.data];
image = [UIImageExtension resize:image toWidth:width toHeight:height];

input = [UIImageExtension normalize:image];
input = [UIImageExtension normalize:image mean:mean std:std];
} catch (const std::exception& e) {
NSLog(@"PyTorchMobile: error reading image!\n%s", e.what());
}
Expand Down
26 changes: 21 additions & 5 deletions lib/model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ import 'dart:io';
import 'package:flutter/services.dart';
import 'package:pytorch_mobile/enums/dtype.dart';

const TORCHVISION_NORM_MEAN_RGB = [0.485, 0.456, 0.406];
const TORCHVISION_NORM_STD_RGB = [0.229, 0.224, 0.225];

class Model {
static const MethodChannel _channel = const MethodChannel('pytorch_mobile');

Expand All @@ -24,14 +27,21 @@ class Model {

///predicts image and returns the supposed label belonging to it
Future<String> getImagePrediction(
File image, int width, int height, String labelPath) async {
File image, int width, int height, String labelPath,
{List<double> mean = TORCHVISION_NORM_MEAN_RGB,
List<double> std = TORCHVISION_NORM_STD_RGB}) async {
// Assert mean std
assert(mean.length == 3, "Mean should have size of 3");
assert(std.length == 3, "STD should have size of 3");
List<String> labels = await _getLabels(labelPath);
List byteArray = image.readAsBytesSync();
final List? prediction = await _channel.invokeListMethod("predictImage", {
"index": _index,
"image": byteArray,
"width": width,
"height": height
"height": height,
"mean": mean,
"std": std
});
double maxScore = double.negativeInfinity;
int maxScoreIndex = -1;
Expand All @@ -45,13 +55,19 @@ class Model {
}

///predicts image but returns the raw net output
Future<List?> getImagePredictionList(
File image, int width, int height) async {
Future<List?> getImagePredictionList(File image, int width, int height,
{List<double> mean = TORCHVISION_NORM_MEAN_RGB,
List<double> std = TORCHVISION_NORM_STD_RGB}) async {
// Assert mean std
assert(mean.length == 3, "Mean should have size of 3");
assert(std.length == 3, "STD should have size of 3");
final List? prediction = await _channel.invokeListMethod("predictImage", {
"index": _index,
"image": image.readAsBytesSync(),
"width": width,
"height": height
"height": height,
"mean": mean,
"std": std
});
return prediction;
}
Expand Down