-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
67 changed files
with
2,149 additions
and
459 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
bin/ | ||
obj/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Text; | ||
using MetadataManager; | ||
using PageManager; | ||
using QueryProcessing; | ||
using QueryProcessing.Exceptions; | ||
using QueryProcessing.Functions; | ||
using QueryProcessing.Utilities; | ||
|
||
namespace ImageProcessing | ||
{ | ||
public class ImageObjectClassificationFuncMappingHandler : IFunctionMappingHandler | ||
{ | ||
public MetadataColumn GetMetadataInfoForOutput(Sql.valueOrFunc.FuncCall func, MetadataColumn[] metadataColumns) | ||
{ | ||
const int MaxReturnTypeLength = 256; | ||
ColumnType[] columnTypes = FuncCallMapper.ExtractCallTypes(func, metadataColumns); | ||
|
||
if (columnTypes.Length != 1) | ||
{ | ||
throw new InvalidFunctionArgument("Object classification requires 1 argument (image path)"); | ||
} | ||
|
||
if (columnTypes[0] != ColumnType.String && columnTypes[0] != ColumnType.StringPointer) | ||
{ | ||
throw new InvalidFunctionArgument("invalid argument type for object classification"); | ||
} | ||
|
||
return new MetadataColumn(0, 0, "Object_Classification_Result", new ColumnInfo(ColumnType.String, MaxReturnTypeLength)); | ||
} | ||
|
||
public IFunctionCall MapToFunctor(ColumnType[] args) | ||
{ | ||
if (args.Length != 1) | ||
{ | ||
throw new InvalidFunctionArgument("Object classification requires 1 argument (image path)"); | ||
} | ||
|
||
if (args[0] != ColumnType.String && args[0] != ColumnType.StringPointer) | ||
{ | ||
throw new InvalidFunctionArgument("invalid argument type for object classification"); | ||
} | ||
|
||
return new ImageObjectClassificationFunctor(); | ||
} | ||
} | ||
|
||
public class ImageObjectClassificationFunctor : IFunctionCall | ||
{ | ||
public void ExecCompute(RowHolder inputRowHolder, RowHolder outputRowHolder, Union2Type<MetadataColumn, Sql.value>[] sourceArguments, int outputPosition) | ||
{ | ||
FunctorArgChecks.CheckInputArguments(sourceArguments, new[] { ColumnType.String }); | ||
FunctorArgExtractString arg = new FunctorArgExtractString(inputRowHolder, sourceArguments); | ||
const string inceptionPb = "tensorflow_inception_graph.pb"; | ||
const string labelsTxt = "imagenet_comp_graph_label_strings.txt"; | ||
|
||
TFModelImageLabelScorer scorer = new TFModelImageLabelScorer(inceptionPb, labelsTxt); | ||
ImageLabelPredictionProbability score = scorer.ScoreSingle(arg.ArgOne); | ||
|
||
outputRowHolder.SetField(outputPosition, score.PredictedLabels[0].ToCharArray()); | ||
} | ||
|
||
public IComparable ExecCompute(RowHolder inputRowHolder, Union2Type<MetadataColumn, Sql.value>[] sourceArguments) | ||
{ | ||
FunctorArgChecks.CheckInputArguments(sourceArguments, new[] { ColumnType.String }); | ||
FunctorArgExtractString arg = new FunctorArgExtractString(inputRowHolder, sourceArguments); | ||
throw new NotImplementedException(); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
using Microsoft.ML.Data; | ||
|
||
namespace ImageProcessing | ||
{ | ||
    /// <summary>
    /// Input row for the image scoring pipeline: identifies one image by its path.
    /// </summary>
    public class ImageDataSource | ||
    { | ||
        /// <summary>
        /// Path of the image to score. TFModelImageLabelScorer.LoadModel names this field
        /// as the input column of the image-loading transform.
        /// </summary>
        [LoadColumn(0)] | ||
        public string ImagePath; | ||
    } | ||
|
||
    /// <summary>
    /// Scoring result for one image: the selected labels together with their probabilities.
    /// Instances are populated by TFModelImageLabelScorer from the output of its
    /// best-label selection, so both arrays have the same length and order.
    /// </summary>
    public class ImageLabelPredictionProbability : ImageDataSource | ||
    { | ||
        /// <summary>Selected label texts; may be empty when no label passes the scorer's threshold.</summary>
        public string[] PredictedLabels; | ||
        /// <summary>Probability of each entry in <see cref="PredictedLabels"/>, at the same index.</summary>
        public float[] Probabilities { get; set; } | ||
    } | ||
|
||
    /// <summary>
    /// Raw output row of the TensorFlow scoring pipeline.
    /// </summary>
    public class ImageLabelPrediction | ||
    { | ||
        /// <summary>
        /// Values of the model's output tensor (the column named by
        /// InceptionSettings.outputTensorName). Despite the name these are numeric scores,
        /// not label texts; the scorer treats them as per-label probabilities.
        /// </summary>
        [ColumnName(TFModelImageLabelScorer.InceptionSettings.outputTensorName)] | ||
        public float[] PredictedLabels; | ||
    } | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
<Project Sdk="Microsoft.NET.Sdk"> | ||
|
||
<PropertyGroup> | ||
<TargetFramework>netcoreapp3.1</TargetFramework> | ||
</PropertyGroup> | ||
|
||
<ItemGroup> | ||
<PackageReference Include="Microsoft.ML" Version="1.5.5" /> | ||
<PackageReference Include="Microsoft.ML.ImageAnalytics" Version="1.5.5" /> | ||
<PackageReference Include="Microsoft.ML.TensorFlow" Version="1.5.5" /> | ||
<PackageReference Include="Microsoft.ML.TensorFlow.Redist" Version="0.14.0" /> | ||
</ItemGroup> | ||
|
||
<ItemGroup> | ||
<Content Include="assets_image_processing\*.*"> | ||
<CopyToOutputDirectory>Always</CopyToOutputDirectory> | ||
</Content> | ||
</ItemGroup> | ||
|
||
<ItemGroup> | ||
<ProjectReference Include="..\QueryProcessing\QueryProcessing.csproj" /> | ||
</ItemGroup> | ||
|
||
</Project> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
<Project Sdk="Microsoft.NET.Sdk"> | ||
|
||
<PropertyGroup> | ||
<TargetFramework>netcoreapp3.1</TargetFramework> | ||
</PropertyGroup> | ||
|
||
<ItemGroup> | ||
<PackageReference Include="Microsoft.ML" Version="1.5.5" /> | ||
<PackageReference Include="Microsoft.ML.ImageAnalytics" Version="1.5.5" /> | ||
<PackageReference Include="Microsoft.ML.TensorFlow" Version="1.5.5" /> | ||
<PackageReference Include="Microsoft.ML.TensorFlow.Redist" Version="0.14.0" /> | ||
</ItemGroup> | ||
|
||
<ItemGroup> | ||
<Content Include="assets_image_processing\*.*"> | ||
<CopyToOutputDirectory>Always</CopyToOutputDirectory> | ||
</Content> | ||
</ItemGroup> | ||
|
||
<ItemGroup> | ||
<ProjectReference Include="..\QueryProcessing\QueryProcessing.csproj" /> | ||
</ItemGroup> | ||
|
||
</Project> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
using Microsoft.ML; | ||
using System.Collections.Generic; | ||
using System.IO; | ||
using System.Linq; | ||
|
||
namespace ImageProcessing | ||
{ | ||
public class TFModelImageLabelScorer | ||
{ | ||
private readonly string modelLocation; | ||
private readonly string labelsLocation; | ||
private readonly MLContext mlContext; | ||
|
||
private static string ImageReal = nameof(ImageReal); | ||
|
||
private readonly PredictionEngine<ImageDataSource, ImageLabelPrediction> model; | ||
|
||
private static string GetAssetsPath() | ||
{ | ||
FileInfo dataRoot = new FileInfo(typeof(TFModelImageLabelScorer).Assembly.Location); | ||
string assemblyFolderPath = dataRoot.Directory.FullName; | ||
return Path.Combine(assemblyFolderPath, "assets_image_processing"); | ||
} | ||
|
||
public TFModelImageLabelScorer(string modelFileName, string labelsFileName) | ||
{ | ||
string assetsPath = GetAssetsPath(); | ||
this.modelLocation = Path.Combine(assetsPath, modelFileName); | ||
this.labelsLocation = Path.Combine(assetsPath, labelsFileName); | ||
this.mlContext = new MLContext(); | ||
this.model = LoadModel(modelLocation); | ||
|
||
} | ||
|
||
public struct ImageNetSettings | ||
{ | ||
public const int imageHeight = 224; | ||
public const int imageWidth = 224; | ||
public const float mean = 117; | ||
public const bool channelsLast = true; | ||
public const int returnTopNLabels = 10; | ||
public const float probabilityThreshold = 0.1f; | ||
} | ||
|
||
public struct InceptionSettings | ||
{ | ||
// input tensor name | ||
public const string inputTensorName = "input"; | ||
|
||
// output tensor name | ||
public const string outputTensorName = "softmax2"; | ||
} | ||
|
||
public ImageLabelPredictionProbability ScoreSingle(string path) | ||
{ | ||
return PredictDataUsingModelSinge(labelsLocation, path); | ||
} | ||
|
||
private IEnumerable<ImageDataSource> LoadSource() | ||
{ | ||
return Enumerable.Empty<ImageDataSource>(); | ||
} | ||
|
||
private PredictionEngine<ImageDataSource, ImageLabelPrediction> LoadModel(string modelLocation) | ||
{ | ||
// Don't train anything. | ||
// Keep the source empty and feed with when scorer is invoked. | ||
// TODO: This is ugly. Not sure if there is a nicer way to do this. | ||
var data = mlContext.Data.LoadFromEnumerable(LoadSource()); | ||
|
||
var pipeline = mlContext.Transforms.LoadImages(outputColumnName: "input", imageFolder: ".", inputColumnName: nameof(ImageDataSource.ImagePath)) | ||
.Append(mlContext.Transforms.ResizeImages(outputColumnName: "input", imageWidth: ImageNetSettings.imageWidth, imageHeight: ImageNetSettings.imageHeight, inputColumnName: "input")) | ||
.Append(mlContext.Transforms.ExtractPixels(outputColumnName: "input", interleavePixelColors: ImageNetSettings.channelsLast, offsetImage: ImageNetSettings.mean)) | ||
.Append(mlContext.Model.LoadTensorFlowModel(modelLocation). | ||
ScoreTensorFlowModel(outputColumnNames: new[] { "softmax2" }, | ||
inputColumnNames: new[] { "input" }, addBatchDimensionInput: true)); | ||
|
||
ITransformer model = pipeline.Fit(data); | ||
|
||
var predictionEngine = mlContext.Model.CreatePredictionEngine<ImageDataSource, ImageLabelPrediction>(model); | ||
|
||
return predictionEngine; | ||
} | ||
|
||
protected ImageLabelPredictionProbability PredictDataUsingModelSinge(string labelsLocation, string path) | ||
{ | ||
string[] labels = File.ReadAllLines(labelsLocation); | ||
|
||
ImageDataSource dataSource = new ImageDataSource() { ImagePath = path }; | ||
|
||
var probs = this.model.Predict(dataSource).PredictedLabels; | ||
var bestLabels = GetBestLabels(labels, probs, ImageNetSettings.returnTopNLabels); | ||
|
||
return new ImageLabelPredictionProbability() | ||
{ | ||
ImagePath = dataSource.ImagePath, | ||
PredictedLabels = bestLabels.Item1, | ||
Probabilities = bestLabels.Item2, | ||
}; | ||
} | ||
|
||
private static (string[], float[]) GetBestLabels(string[] labels, float[] probs, int topN) | ||
{ | ||
// TODO: This is naive slow implementation. | ||
List<string> bestLabels = new List<string>(); | ||
List<float> bestProbabilities = new List<float>(); | ||
var lblsList = labels.ToList(); | ||
var probsList = probs.ToList(); | ||
|
||
for (int i = 0; i < topN; i++) | ||
{ | ||
var max = probsList.Max(); | ||
var index = probsList.IndexOf(max); | ||
|
||
if (max >= ImageNetSettings.probabilityThreshold) | ||
{ | ||
bestLabels.Add(lblsList[index]); | ||
bestProbabilities.Add(max); | ||
} | ||
|
||
lblsList.RemoveAt(index); | ||
probsList.RemoveAt(index); | ||
} | ||
|
||
return (bestLabels.ToArray(), bestProbabilities.ToArray()); | ||
} | ||
} | ||
} |
Oops, something went wrong.