diff --git a/src/Microsoft.ML.TensorFlow/doc.xml b/src/Microsoft.ML.TensorFlow/doc.xml
index 29d7ca2844..63410cda02 100644
--- a/src/Microsoft.ML.TensorFlow/doc.xml
+++ b/src/Microsoft.ML.TensorFlow/doc.xml
@@ -10,6 +10,8 @@
The TensorflowTransform extracts the specified output from the operation computed on the graph (given the input(s)) using a pre-trained Tensorflow model.
The transform takes as input the Tensorflow model together with the names of the inputs to the model and name of the operation for which output values will be extracted from the model.
+ This transform requires the Microsoft.ML.TensorFlow nuget to be installed.
+
The TensorflowTransform has following assumptions regarding the input, output and processing of data.
-
@@ -23,6 +25,9 @@
Upon success, the transform will introduce a new column in based on the name of the output column specified.
+
+ The inputs and outputs of a TensorFlow model can be obtained using the summarize_graph tool.
+
@@ -71,4 +76,4 @@
-
\ No newline at end of file
+
diff --git a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs
index d7d2e9a2de..affdbea72c 100644
--- a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs
@@ -7,6 +7,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.ImageAnalytics;
using Microsoft.ML.Runtime.LightGBM;
+using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using System.Collections.Generic;
using System.IO;
@@ -16,7 +17,7 @@ namespace Microsoft.ML.Scenarios
{
public partial class ScenariosTests
{
- [Fact(Skip = "Disabled due to this bug https://github.com/dotnet/machinelearning/issues/770")]
+ [Fact]
public void TensorFlowTransformCifarLearningPipelineTest()
{
var imageHeight = 32;
@@ -52,23 +53,35 @@ public void TensorFlowTransformCifarLearningPipelineTest()
OutputColumn = "Output"
});
- using (var environment = new TlcEnvironment())
+ pipeline.Add(new ColumnConcatenator(outputColumn: "Features", "Output"));
+ pipeline.Add(new TextToKeyConverter("Label"));
+ pipeline.Add(new StochasticDualCoordinateAscentClassifier());
+
+ var model = pipeline.Train();
+ string[] scoreLabels;
+ model.TryGetScoreLabelNames(out scoreLabels);
+
+ Assert.NotNull(scoreLabels);
+ Assert.Equal(3, scoreLabels.Length);
+ Assert.Equal("banana", scoreLabels[0]);
+ Assert.Equal("hotdog", scoreLabels[1]);
+ Assert.Equal("tomato", scoreLabels[2]);
+
+ CifarPrediction prediction = model.Predict(new CifarData()
{
- IDataView trans = pipeline.Execute(environment);
- Assert.NotNull(trans);
+ ImagePath = GetDataPath("images/banana.jpg")
+ });
+ Assert.Equal(1, prediction.PredictedLabels[0], 2);
+ Assert.Equal(0, prediction.PredictedLabels[1], 2);
+ Assert.Equal(0, prediction.PredictedLabels[2], 2);
- trans.Schema.TryGetColumnIndex("Output", out int output);
- using (var cursor = trans.GetRowCursor(col => col == output))
- {
- var buffer = default(VBuffer);
- var getter = cursor.GetGetter>(output);
- while (cursor.MoveNext())
- {
- getter(ref buffer);
- Assert.Equal(10, buffer.Length);
- }
- }
- }
+ prediction = model.Predict(new CifarData()
+ {
+ ImagePath = GetDataPath("images/hotdog.jpg")
+ });
+ Assert.Equal(0, prediction.PredictedLabels[0], 2);
+ Assert.Equal(1, prediction.PredictedLabels[1], 2);
+ Assert.Equal(0, prediction.PredictedLabels[2], 2);
}
}
@@ -78,6 +91,12 @@ public class CifarData
public string ImagePath;
[Column("1")]
- public string Name;
+ public string Label;
+ }
+
+ public class CifarPrediction
+ {
+ [ColumnName("Score")]
+ public float[] PredictedLabels;
}
}
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index c04b35669c..0bffb5c4d0 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -119,7 +119,7 @@ public void TensorFlowTransformMNISTConvTest()
var metrics = Evaluate(env, testDataScorer);
Assert.Equal(0.99, metrics.AccuracyMicro, 2);
- Assert.Equal(0.99, metrics.AccuracyMicro, 2);
+ Assert.Equal(1.0, metrics.AccuracyMacro, 2);
// Create prediction engine and test predictions
var model = env.CreatePredictionEngine(testDataScorer);
@@ -215,12 +215,63 @@ public void TensorFlowTransformCifar()
{
var buffer = default(VBuffer);
var getter = cursor.GetGetter>(output);
+ var numRows = 0;
while (cursor.MoveNext())
{
getter(ref buffer);
Assert.Equal(10, buffer.Length);
+ numRows += 1;
+ }
+ Assert.Equal(3, numRows);
+ }
+ }
+ }
+
+ [Fact]
+ public void TensorFlowTransformCifarInvalidShape()
+ {
+ var model_location = "cifar_model/frozen_model.pb";
+
+ using (var env = new TlcEnvironment())
+ {
+ var imageHeight = 28;
+ var imageWidth = 28;
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ {
+ Column = new ImageLoaderTransform.Column[1]
+ {
+ new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" }
+ },
+ ImageFolder = imageFolder
+ }, data);
+ var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
+ {
+ Column = new ImageResizerTransform.Column[1]{
+ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
+ }
+ }, images);
+
+ var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
+ {
+ Column = new ImagePixelExtractorTransform.Column[1]{
+ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "Input", UseAlpha=false, InterleaveArgb=true}
}
+ }, cropped);
+
+ var thrown = false;
+ try
+ {
+ IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, "Output", "Input");
+ }
+ catch
+ {
+ thrown = true;
}
+ Assert.True(thrown);
}
}
}