Skip to content

Commit

Permalink
Added a new Session constructor to load session from SavedModel.
Browse files Browse the repository at this point in the history
  • Loading branch information
Canming Huang committed Jun 23, 2020
1 parent 72a36d9 commit 476fe2a
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 1 deletion.
84 changes: 84 additions & 0 deletions Emgu.TF/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,74 @@ namespace Emgu.TF
/// </summary>
public class Session : UnmanagedObject
{
private bool _graphNeedDispose = false;

private Graph _graph;
private Buffer _metaGraphDef;

/// <summary>
/// The Graph of this session
/// </summary>
public Graph Graph
{
get { return _graph; }
}

/// <summary>
/// Create a Session from a SavedModel. If successful, populates the internal graph with the contents of the Graph and
/// <paramref name="metaGraphDef"/> with the MetaGraphDef of the loaded model.
/// </summary>
/// <param name="exportDir">Must be set to the path of the exported SavedModel.</param>
/// <param name="tags">Must include the set of tags used to identify one MetaGraphDef in the SavedModel.</param>
/// <param name="sessionOptions">Session options</param>
/// <param name="runOptions"></param>
/// <param name="status">The status</param>
public Session(
String exportDir,
String[] tags,
SessionOptions sessionOptions = null,
Buffer runOptions = null,
Status status = null)
{
_graph = new Graph();
_graphNeedDispose = true;
_metaGraphDef = new Buffer();

IntPtr exportDirPtr = Marshal.StringToHGlobalAuto(exportDir);

IntPtr[] tagsNative = new IntPtr[tags.Length];
for (int i = 0; i < tags.Length; i++)
tagsNative[i] = Marshal.StringToHGlobalAuto(tags[i]);
GCHandle tagsNativeHandle = GCHandle.Alloc(tagsNative, GCHandleType.Pinned);
try
{
using (StatusChecker checker = new StatusChecker(status))
_ptr = TfInvoke.tfeLoadSessionFromSavedModel(
sessionOptions,
runOptions,
exportDirPtr,
tagsNativeHandle.AddrOfPinnedObject(),
tagsNative.Length,
_graph,
_metaGraphDef,
checker.Status
);
}
catch (Exception)
{
throw;
}
finally
{
Marshal.FreeHGlobal(exportDirPtr);
tagsNativeHandle.Free();
for (int i = 0; i < tags.Length; i++)
{
Marshal.FreeHGlobal(tagsNative[i]);
}
}

}

/// <summary>
/// Return a new execution session with the associated graph.
Expand All @@ -29,6 +96,7 @@ public class Session : UnmanagedObject
public Session(Graph graph, SessionOptions sessionOptions = null, Status status = null)
{
_graph = graph;
_graphNeedDispose = false;

using (StatusChecker checker = new StatusChecker(status))
_ptr = TfInvoke.tfeNewSession(graph, sessionOptions, checker.Status);
Expand Down Expand Up @@ -56,6 +124,11 @@ protected override void DisposeObject()
TfInvoke.tfeDeleteSession(ref _ptr, checker.Status);
}

if (_graphNeedDispose && _graph != null)
{
_graph.Dispose();
}

_graph = null;
}

Expand Down Expand Up @@ -220,6 +293,17 @@ public static partial class TfInvoke
[DllImport(ExternLibrary, CallingConvention = TfInvoke.TFCallingConvention)]
internal static extern IntPtr tfeNewSession(IntPtr graph, IntPtr opts, IntPtr status);

[DllImport(ExternLibrary, CallingConvention = TfInvoke.TFCallingConvention)]
internal static extern IntPtr tfeLoadSessionFromSavedModel(
IntPtr sessionOptions,
IntPtr runOptions,
IntPtr exportDir,
IntPtr tags,
int tagsLen,
IntPtr graph,
IntPtr metaGraphDef,
IntPtr status);

[DllImport(ExternLibrary, CallingConvention = TfInvoke.TFCallingConvention)]
internal static extern void tfeDeleteSession(ref IntPtr session, IntPtr status);

Expand Down
2 changes: 1 addition & 1 deletion tensorflow

0 comments on commit 476fe2a

Please sign in to comment.