diff --git a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs
index 921309d7ef..e100f9af9f 100644
--- a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs
+++ b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs
@@ -11,6 +11,16 @@ namespace Microsoft.ML.Runtime.ImageAnalytics.EntryPoints
{
public static class ImageAnalytics
{
+ // This method is needed for the Pipeline API, since ModuleCatalog does not load entry points that are located
+ // in assemblies that aren't directly used in the code. Users who want to use ImageAnalytics components will have to call
+ // ImageAnalytics.Initialize() before creating the pipeline.
+ ///
+ /// Initialize the Image Analytics environment. Call this method before adding Image components to a learning pipeline.
+ ///
+ public static void Initialize()
+ {
+ }
+
[TlcModule.EntryPoint(Name = "Transforms.ImageLoader", Desc = ImageLoaderTransform.Summary,
UserName = ImageLoaderTransform.UserName, ShortName = ImageLoaderTransform.LoaderSignature)]
public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, ImageLoaderTransform.Arguments input)
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
index ac11c7fa8d..8b83bc6cd4 100644
--- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
+++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
@@ -17,10 +17,10 @@
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
-[assembly: LoadableClass(ImageResizerTransform.Summary, typeof(ImageResizerTransform), typeof(ImageResizerTransform.Arguments),
+[assembly: LoadableClass(ImageResizerTransform.Summary, typeof(IDataTransform), typeof(ImageResizerTransform), typeof(ImageResizerTransform.Arguments),
typeof(SignatureDataTransform), ImageResizerTransform.UserName, "ImageResizerTransform", "ImageResizer")]
-[assembly: LoadableClass(ImageResizerTransform.Summary, typeof(ImageResizerTransform), null, typeof(SignatureLoadDataTransform),
+[assembly: LoadableClass(ImageResizerTransform.Summary, typeof(IDataTransform), typeof(ImageResizerTransform), null, typeof(SignatureLoadDataTransform),
ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)]
[assembly: LoadableClass(typeof(ImageResizerTransform), null, typeof(SignatureLoadModel),
diff --git a/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj b/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj
index c04bdde2da..a0ee42b93c 100644
--- a/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj
+++ b/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj
@@ -10,6 +10,7 @@
+
diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
index 980df5e641..4a2558f0b7 100644
--- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
@@ -5,11 +5,23 @@
using System;
using System.Runtime.InteropServices;
using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints;
namespace Microsoft.ML.Transforms.TensorFlow
{
- internal partial class TensorFlowUtils
+ public static class TensorFlowUtils
{
+ // This method is needed for the Pipeline API, since ModuleCatalog does not load entry points that are located
+ // in assemblies that aren't directly used in the code. Users who want to use TensorFlow components will have to call
+ // TensorFlowUtils.Initialize() before creating the pipeline.
+ ///
+ /// Initialize the TensorFlow environment. Call this method before adding TensorFlow components to a learning pipeline.
+ ///
+ public static void Initialize()
+ {
+ ImageAnalytics.Initialize();
+ }
+
internal static PrimitiveType Tf2MlNetType(TFDataType type)
{
switch (type)
@@ -27,7 +39,7 @@ internal static PrimitiveType Tf2MlNetType(TFDataType type)
}
}
- public static unsafe void FetchData(IntPtr data, T[] result)
+ internal static unsafe void FetchData(IntPtr data, T[] result)
{
var size = result.Length;
diff --git a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs
index affdbea72c..46433b1aef 100644
--- a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs
@@ -9,6 +9,7 @@
using Microsoft.ML.Runtime.LightGBM;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
+using Microsoft.ML.Transforms.TensorFlow;
using System.Collections.Generic;
using System.IO;
using Xunit;
@@ -57,6 +58,7 @@ public void TensorFlowTransformCifarLearningPipelineTest()
pipeline.Add(new TextToKeyConverter("Label"));
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
+ TensorFlowUtils.Initialize();
var model = pipeline.Train();
string[] scoreLabels;
model.TryGetScoreLabelNames(out scoreLabels);