From 403d04bc2a02067dc29c5b936195ad5775ee448a Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 10 Jan 2019 11:47:34 -0800 Subject: [PATCH 1/3] Support unfrozen models for GetModelSchema. --- .../TensorFlow/TensorflowUtils.cs | 19 ++++++++++++------- .../TensorflowTests.cs | 3 ++- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs index 1617b17e20..2b7f7bcf46 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs @@ -75,19 +75,18 @@ internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, str /// of kind , indicating the operation type of the node, and if that node has inputs in the graph, /// it contains metadata of kind , indicating the names of the input nodes. /// - /// An . + /// The environment to use. /// The name of the file containing the TensorFlow model. Currently only frozen model /// format is supported. - public static Schema GetModelSchema(IExceptionContext ectx, string modelFile) + public static Schema GetModelSchema(IHostEnvironment env, string modelFile) { - var bytes = File.ReadAllBytes(modelFile); - var session = LoadTFSession(ectx, bytes, modelFile); - return GetModelSchema(ectx, session.Graph); + var model = LoadTensorFlowModel(env, modelFile); + return GetModelSchema(env, model.Session.Graph); } /// /// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It - /// iterates over the columns of the returned by , + /// iterates over the columns of the returned by , /// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names. /// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type. /// @@ -95,7 +94,7 @@ public static Schema GetModelSchema(IExceptionContext ectx, string modelFile) /// public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelFile) { - var schema = GetModelSchema(null, modelFile); + var schema = GetModelSchema(new MLContext(), modelFile); for (int i = 0; i < schema.Count; i++) { @@ -310,6 +309,12 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity) } } + /// + /// Load tensor flow model into memory. + /// + /// The environment to use. + /// The model to load. + /// public static TensorFlowModelInfo LoadTensorFlowModel(IHostEnvironment env, string modelPath) { var session = GetSession(env, modelPath); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 8bfae482c2..a883dc6868 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -606,8 +606,9 @@ public void TensorFlowTransformCifar() public void TensorFlowTransformCifarSavedModel() { var modelLocation = "cifar_saved_model"; - var mlContext = new MLContext(seed: 1, conc: 1); + var loadModelSchema = TensorFlowUtils.GetModelSchema(mlContext, modelLocation); + Assert.Equal(86, loadModelSchema.Count); var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(mlContext, modelLocation); var schema = tensorFlowModel.GetInputSchema(); Assert.True(schema.TryGetColumnIndex("Input", out int column)); From 1ebd60717aa2db122c95ad162c0edf1a5af1e18a Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 10 Jan 2019 11:47:51 -0800 Subject: [PATCH 2/3] correct number in test --- .../ScenariosWithDirectInstantiation/TensorflowTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index a883dc6868..4699131680 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -608,7 +608,7 @@ public void TensorFlowTransformCifarSavedModel() var modelLocation = "cifar_saved_model"; var mlContext = new MLContext(seed: 1, conc: 1); var loadModelSchema = TensorFlowUtils.GetModelSchema(mlContext, modelLocation); - Assert.Equal(86, loadModelSchema.Count); + Assert.Equal(335, loadModelSchema.Count); var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(mlContext, modelLocation); var schema = tensorFlowModel.GetInputSchema(); Assert.True(schema.TryGetColumnIndex("Input", out int column)); From 953cdff18e97122345160db539c183f0e5a4be14 Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 10 Jan 2019 13:55:19 -0800 Subject: [PATCH 3/3] address comments --- .../TensorFlow/TensorflowUtils.cs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs index 2b7f7bcf46..f0bfb7e50a 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs @@ -76,11 +76,10 @@ internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, str /// it contains metadata of kind , indicating the names of the input nodes. /// /// The environment to use. - /// The name of the file containing the TensorFlow model. Currently only frozen model - /// format is supported. - public static Schema GetModelSchema(IHostEnvironment env, string modelFile) + /// Model to load. + public static Schema GetModelSchema(IHostEnvironment env, string modelPath) { - var model = LoadTensorFlowModel(env, modelFile); + var model = LoadTensorFlowModel(env, modelPath); return GetModelSchema(env, model.Session.Graph); } @@ -90,11 +89,11 @@ public static Schema GetModelSchema(IHostEnvironment env, string modelFile) /// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names. /// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type. /// - /// + /// Model to load. /// - public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelFile) + public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelPath) { - var schema = GetModelSchema(new MLContext(), modelFile); + var schema = GetModelSchema(new MLContext(), modelPath); for (int i = 0; i < schema.Count; i++) { @@ -310,7 +309,7 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity) } /// - /// Load tensor flow model into memory. + /// Load TensorFlow model into memory. /// /// The environment to use. /// The model to load.