-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
TensorflowUtils.cs
408 lines (368 loc) · 18 KB
/
TensorflowUtils.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Security.AccessControl;
using System.Security.Principal;
using System.Threading;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
namespace Microsoft.ML.Transforms.TensorFlow
{
public static class TensorFlowUtils
{
public const string OpType = "OpType";
public const string InputOps = "InputOps";
// Enumerates the graph's operations and represents each one whose first output has an
// ML.NET-convertible element type as a schema column. When 'opType' is non-null, only
// operations of that operation type are included.
internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, string opType = null)
{
    var res = new List<KeyValuePair<string, ColumnType>>();
    // Per-column metadata getters; these lists stay parallel to 'res'.
    var opTypeGetters = new List<MetadataUtils.MetadataGetter<ReadOnlyMemory<char>>>();
    var inputOpsGetters = new List<MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>>();
    var inputOpsLengths = new List<int>();
    foreach (var op in graph)
    {
        if (opType != null && opType != op.OpType)
            continue;
        // Only the operation's first output determines the column type.
        var tfType = op[0].OutputType;
        var mlType = Tf2MlNetTypeOrNull(tfType);
        // If the type is not supported in ML.NET then we cannot represent it as a column in an ISchema.
        // We also cannot output it with a TensorFlowTransform, so we skip it.
        if (mlType == null)
            continue;
        var shape = graph.GetTensorShape(op[0]);
        var shapeArray = shape.ToIntArray();
        inputOpsLengths.Add(op.NumInputs);
        MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>> inputOpsGetter = null;
        if (op.NumInputs > 0)
        {
            // Capture the input op names eagerly; the getter below closes over this buffer.
            var inputOps = new ReadOnlyMemory<char>[op.NumInputs];
            for (int i = 0; i < op.NumInputs; i++)
            {
                var input = op.GetInput(i);
                inputOps[i] = new ReadOnlyMemory<char>(input.Operation.Name.ToArray());
            }
            inputOpsGetter = (int col, ref VBuffer<ReadOnlyMemory<char>> dst) =>
                dst = new VBuffer<ReadOnlyMemory<char>>(op.NumInputs, inputOps);
        }
        inputOpsGetters.Add(inputOpsGetter);
        MetadataUtils.MetadataGetter<ReadOnlyMemory<char>> opTypeGetter =
            (int col, ref ReadOnlyMemory<char> dst) => dst = new ReadOnlyMemory<char>(op.OpType.ToArray());
        opTypeGetters.Add(opTypeGetter);
        // Column type from the tensor shape:
        //  - a single non-positive (unknown/variable) dimension -> variable-size vector;
        //  - all dimensions after the first known (> 0) -> fixed-size vector, dropping the
        //    first dimension when it is non-positive (presumably a batch dimension);
        //  - anything else -> variable-size vector.
        var columnType = Utils.Size(shapeArray) == 1 && shapeArray[0] <= 0 ? new VectorType(mlType) :
            Utils.Size(shapeArray) > 0 && shapeArray.Skip(1).All(x => x > 0) ?
            new VectorType(mlType, shapeArray[0] > 0 ? shapeArray : shapeArray.Skip(1).ToArray())
            : new VectorType(mlType);
        res.Add(new KeyValuePair<string, ColumnType>(op.Name, columnType));
    }
    return Schema.Create(new TensorFlowSchema(ectx, res.ToArray(), opTypeGetters.ToArray(), inputOpsGetters.ToArray(), inputOpsLengths.ToArray()));
}
/// <summary>
/// Retrieves information about the graph nodes of a TensorFlow model as an <see cref="ISchema"/>.
/// Every node whose output type is compatible with the types supported by
/// <see cref="TensorFlowTransform"/> appears as a column named after the node, typed with the
/// node's output item type (and shape, when known). Each column carries metadata of kind
/// <see cref="OpType"/> with the node's operation type, and — for nodes that have inputs in the
/// graph — metadata of kind <see cref="InputOps"/> with the names of the input nodes.
/// </summary>
/// <param name="ectx">An <see cref="IExceptionContext"/>.</param>
/// <param name="modelFile">The name of the file containing the TensorFlow model. Currently only
/// the frozen model format is supported.</param>
public static Schema GetModelSchema(IExceptionContext ectx, string modelFile)
{
    var modelBytes = File.ReadAllBytes(modelFile);
    return GetModelSchema(ectx, LoadTFSession(ectx, modelBytes, modelFile).Graph);
}
/// <summary>
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It
/// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IExceptionContext, string)"/>,
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names.
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type.
/// </summary>
/// <param name="modelFile">Path to the frozen TensorFlow model file to inspect.</param>
/// <returns>A lazily-evaluated sequence of (node name, operation type, column type, input node names) tuples.</returns>
public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelFile)
{
    var schema = GetModelSchema(null, modelFile);
    for (int i = 0; i < schema.Count; i++)
    {
        var name = schema[i].Name;
        var type = schema[i].Type;
        // Every column produced by GetModelSchema carries OpType metadata of text type.
        var metadataType = schema[i].Metadata.Schema.GetColumnOrNull(TensorFlowUtils.OpType)?.Type;
        Contracts.Assert(metadataType != null && metadataType is TextType);
        ReadOnlyMemory<char> opType = default;
        schema[i].Metadata.GetValue(TensorFlowUtils.OpType, ref opType);
        // InputOps metadata is only present for nodes that have inputs in the graph.
        metadataType = schema[i].Metadata.Schema.GetColumnOrNull(TensorFlowUtils.InputOps)?.Type;
        VBuffer<ReadOnlyMemory<char>> inputOps = default;
        if (metadataType != null)
        {
            Contracts.Assert(metadataType.IsKnownSizeVector && metadataType.ItemType is TextType);
            schema[i].Metadata.GetValue(TensorFlowUtils.InputOps, ref inputOps);
        }
        // NOTE(review): when InputOps metadata is absent, 'inputOps' stays a default VBuffer —
        // presumably DenseValues() yields an empty sequence then; verify against VBuffer semantics.
        string[] inputOpsResult = inputOps.DenseValues()
            .Select(input => input.ToString())
            .ToArray();
        yield return (name, opType.ToString(), type, inputOpsResult);
    }
}
/// <summary>
/// Maps a TensorFlow element type to the corresponding ML.NET primitive type.
/// </summary>
/// <param name="type">The TensorFlow tensor element type.</param>
/// <returns>The ML.NET primitive type corresponding to <paramref name="type"/>.</returns>
/// <exception cref="NotSupportedException">Thrown when the type has no ML.NET equivalent.</exception>
internal static PrimitiveType Tf2MlNetType(TFDataType type)
{
    var mlNetType = Tf2MlNetTypeOrNull(type);
    if (mlNetType == null)
        // Include the offending type in the message so callers can diagnose which
        // tensor of an unsupported model triggered the failure.
        throw new NotSupportedException($"TensorFlow type '{type}' is not supported.");
    return mlNetType;
}
// Returns the ML.NET primitive type matching the given TensorFlow element type,
// or null when the type has no ML.NET representation.
private static PrimitiveType Tf2MlNetTypeOrNull(TFDataType type)
{
    switch (type)
    {
        case TFDataType.Float:
        case TFDataType.Float_ref:
            // Both the value and the reference variant map to single precision.
            return NumberType.R4;
        case TFDataType.Double:
            return NumberType.R8;
        case TFDataType.UInt8:
            return NumberType.U1;
        case TFDataType.UInt16:
            return NumberType.U2;
        case TFDataType.UInt32:
            return NumberType.U4;
        case TFDataType.UInt64:
            return NumberType.U8;
        case TFDataType.Int16:
            return NumberType.I2;
        case TFDataType.Int32:
            return NumberType.I4;
        case TFDataType.Int64:
            return NumberType.I8;
        default:
            return null;
    }
}
/// <summary>
/// Imports a frozen TensorFlow graph from the given protobuf bytes and returns a session over it.
/// </summary>
/// <param name="ectx">Exception context used to surface import failures.</param>
/// <param name="modelBytes">The serialized (frozen) graph definition.</param>
/// <param name="modelFile">Optional path the bytes came from, used to enrich the error message.</param>
internal static TFSession LoadTFSession(IExceptionContext ectx, byte[] modelBytes, string modelFile = null)
{
    var graph = new TFGraph();
    try
    {
        graph.Import(modelBytes, "");
    }
    catch (Exception ex)
    {
        // Preserve the TensorFlow exception as the inner exception in both branches;
        // previously the 'modelFile' branch discarded 'ex', losing the failure cause.
        if (!string.IsNullOrEmpty(modelFile))
            throw ectx.Except(ex, $"TensorFlow exception triggered while loading model from '{modelFile}'");
#pragma warning disable MSML_NoMessagesForLoadContext
        throw ectx.ExceptDecode(ex, "Tensorflow exception triggered while loading model.");
#pragma warning restore MSML_NoMessagesForLoadContext
    }
    return new TFSession(graph);
}
// Loads a TensorFlow SavedModel (directory format) and returns a session over its graph.
private static TFSession LoadTFSession(IHostEnvironment env, string exportDirSavedModel)
{
    Contracts.Check(env != null, nameof(env));
    env.CheckValue(exportDirSavedModel, nameof(exportDirSavedModel));
    // SavedModels exported for inference are tagged "serve".
    return TFSession.FromSavedModel(
        new TFSessionOptions(),
        null,
        exportDirSavedModel,
        new string[] { "serve" },
        new TFGraph(),
        new TFBuffer());
}
// A TensorFlow frozen model is a single file, whereas an un-frozen (SavedModel) one has a
// well-defined folder structure. Given a modelPath, this utility decides which of the two
// it is purely from the file-system attributes: a directory means SavedModel.
internal static bool IsSavedModel(IHostEnvironment env, string modelPath)
{
    Contracts.Check(env != null, nameof(env));
    env.CheckNonWhiteSpace(modelPath, nameof(modelPath));
    var attributes = File.GetAttributes(modelPath);
    return (attributes & FileAttributes.Directory) == FileAttributes.Directory;
}
// Currently used in TensorFlowTransform to protect temporary folders used when working with TensorFlow's SavedModel format.
// Models are considered executable code, so we need to ACL the temp folders for high-rights process (so low-rights process can't access it).
/// <summary>
/// Given a folder path, create it with proper ACL if it doesn't exist.
/// Fails if the folder name is empty, or can't create the folder.
/// </summary>
internal static void CreateFolderWithAclIfNotExists(IHostEnvironment env, string folder)
{
    Contracts.Check(env != null, nameof(env));
    env.CheckNonWhiteSpace(folder, nameof(folder));
    //if directory exists, do nothing.
    if (Directory.Exists(folder))
        return;
    // WindowsIdentity is Windows-only; on other platforms GetCurrent throws
    // PlatformNotSupportedException and we fall through with a null identity.
    WindowsIdentity currentIdentity = null;
    try
    {
        currentIdentity = WindowsIdentity.GetCurrent();
    }
    catch (PlatformNotSupportedException)
    { }
    if (currentIdentity != null && new WindowsPrincipal(currentIdentity).IsInRole(WindowsBuiltInRole.Administrator))
    {
        // Create high integrity dir and set no delete policy for all files under the directory.
        // In case of failure, throw exception.
        CreateTempDirectoryWithAcl(folder, currentIdentity.User.ToString());
    }
    else
    {
        // Non-Windows, or a non-admin process: fall back to a plain directory with default permissions.
        try
        {
            Directory.CreateDirectory(folder);
        }
        catch (Exception exc)
        {
            throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. \nException: {exc.Message}");
        }
    }
}
/// <summary>
/// Recursively deletes <paramref name="folder"/>, retrying on transient <see cref="IOException"/>s
/// (e.g. a file handle the OS has not yet released). Rethrows after the retry budget is exhausted.
/// </summary>
internal static void DeleteFolderWithRetries(IHostEnvironment env, string folder)
{
    Contracts.Check(env != null, nameof(env));
    int currentRetry = 0;
    const int maxRetryCount = 10;
    using (var ch = env.Start("Delete folder"))
    {
        for (; ; )
        {
            try
            {
                currentRetry++;
                Directory.Delete(folder, true);
                break;
            }
            catch (IOException e)
            {
                if (currentRetry > maxRetryCount)
                    throw;
                ch.Info("Error deleting folder. {0}. Retry,", e.Message);
                // Give the OS a moment to release the offending handle before retrying;
                // previously the loop retried immediately, so all attempts could fail
                // within the same instant for the same transient reason.
                Thread.Sleep(100);
            }
        }
    }
}
// Creates 'folder' locked down via an explicit security descriptor: the given identity is
// denied delete on contained objects, built-in administrators get full access, and the
// directory is labeled high-integrity with a no-write-up policy. Windows-only (NTFS ACLs).
private static void CreateTempDirectoryWithAcl(string folder, string identity)
{
    // Dacl Sddl string:
    // D: Dacl type
    // D; Deny access
    // OI; Object inherit ace
    // SD; Standard delete function
    // wIdentity.User Sid of the given user.
    // A; Allow access
    // OICI; Object inherit, container inherit
    // FA File access
    // BA Built-in administrators
    // S: Sacl type
    // ML;; Mandatory Label
    // NW;;; No write policy
    // HI High integrity processes only
    string sddl = "D:(D;OI;SD;;;" + identity + ")(A;OICI;FA;;;BA)S:(ML;OI;NW;;;HI)";
    try
    {
        var dir = Directory.CreateDirectory(folder);
        DirectorySecurity dirSec = new DirectorySecurity();
        dirSec.SetSecurityDescriptorSddlForm(sddl);
        dirSec.SetAccessRuleProtection(true, false); // disable inheritance
        dir.SetAccessControl(dirSec);
        // Cleaning out the directory, in case someone managed to sneak in between creation and setting ACL.
        DirectoryInfo dirInfo = new DirectoryInfo(folder);
        foreach (FileInfo file in dirInfo.GetFiles())
        {
            file.Delete();
        }
        foreach (DirectoryInfo subDirInfo in dirInfo.GetDirectories())
        {
            subDirInfo.Delete(true);
        }
    }
    catch (Exception exc)
    {
        // Any failure (creation, ACL application, or cleanup) is surfaced as a parameter error.
        throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. \nException: {exc.Message}");
    }
}
/// <summary>
/// Loads a TensorFlow model — either a frozen model file or a SavedModel directory —
/// and wraps the resulting session in a <see cref="TensorFlowModelInfo"/>.
/// </summary>
public static TensorFlowModelInfo LoadTensorFlowModel(IHostEnvironment env, string modelPath)
{
    return new TensorFlowModelInfo(env, GetSession(env, modelPath), modelPath);
}
// Opens a TF session for the model at 'modelPath', dispatching on the on-disk format:
// a directory is treated as a SavedModel, a file as a frozen graph.
internal static TFSession GetSession(IHostEnvironment env, string modelPath)
{
    Contracts.Check(env != null, nameof(env));
    if (IsSavedModel(env, modelPath))
    {
        env.CheckUserArg(Directory.Exists(modelPath), nameof(modelPath));
        return LoadTFSession(env, modelPath);
    }
    // Frozen-graph format: read the whole protobuf file into memory and import it.
    env.CheckUserArg(File.Exists(modelPath), nameof(modelPath));
    return LoadTFSession(env, File.ReadAllBytes(modelPath), modelPath);
}
// Copies result.Length elements from the unmanaged buffer at 'data' into 'result'.
// The caller must guarantee the native buffer holds at least result.Length elements of T.
internal static unsafe void FetchData<T>(IntPtr data, Span<T> result)
{
    new Span<T>(data.ToPointer(), result.Length).CopyTo(result);
}
// Whitelist of TensorFlow tensor element types the transform can marshal into ML.NET columns.
// Mirrors Tf2MlNetTypeOrNull, except that Float_ref is not accepted here.
internal static bool IsTypeSupported(TFDataType tfoutput)
{
    return tfoutput == TFDataType.Float
        || tfoutput == TFDataType.Double
        || tfoutput == TFDataType.UInt8
        || tfoutput == TFDataType.UInt16
        || tfoutput == TFDataType.UInt32
        || tfoutput == TFDataType.UInt64
        || tfoutput == TFDataType.Int16
        || tfoutput == TFDataType.Int32
        || tfoutput == TFDataType.Int64;
}
// Schema implementation backing GetModelSchema: a simple column schema augmented with
// per-column OpType metadata and, for columns whose op has inputs, InputOps metadata.
private sealed class TensorFlowSchema : SimpleSchemaBase
{
    // Parallel arrays, one entry per column; _inputOpsGetters[col] may be null
    // when the corresponding op has no inputs.
    private readonly MetadataUtils.MetadataGetter<ReadOnlyMemory<char>>[] _opTypeGetters;
    private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[] _inputOpsGetters;
    private readonly int[] _inputOpsLengths;
    // All getter arrays must be exactly one entry per column.
    public TensorFlowSchema(IExceptionContext ectx, KeyValuePair<string, ColumnType>[] columns,
        MetadataUtils.MetadataGetter<ReadOnlyMemory<char>>[] opTypeGetters,
        MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[] inputOpsGetters, int[] inputOpsLengths)
        : base(ectx, columns)
    {
        ectx.CheckParam(Utils.Size(opTypeGetters) == ColumnCount, nameof(opTypeGetters));
        ectx.CheckParam(Utils.Size(inputOpsGetters) == ColumnCount, nameof(inputOpsGetters));
        ectx.CheckParam(Utils.Size(inputOpsLengths) == ColumnCount, nameof(inputOpsLengths));
        _opTypeGetters = opTypeGetters;
        _inputOpsGetters = inputOpsGetters;
        _inputOpsLengths = inputOpsLengths;
    }
    // Dispatches a metadata request to the stored getter; throws for any kind this
    // schema does not provide (or InputOps on a column without inputs).
    protected override void GetMetadataCore<TValue>(string kind, int col, ref TValue value)
    {
        Ectx.Assert(0 <= col && col < ColumnCount);
        if (kind == OpType)
            _opTypeGetters[col].Marshal(col, ref value);
        else if (kind == InputOps && _inputOpsGetters[col] != null)
            _inputOpsGetters[col].Marshal(col, ref value);
        else
            throw Ectx.ExceptGetMetadata();
    }
    // OpType is a scalar text; InputOps is a fixed-size text vector whose length is the
    // op's input count. Returns null for kinds not present on the column.
    protected override ColumnType GetMetadataTypeOrNullCore(string kind, int col)
    {
        Ectx.Assert(0 <= col && col < ColumnCount);
        if (kind == OpType)
            return TextType.Instance;
        if (kind == InputOps && _inputOpsGetters[col] != null)
            return new VectorType(TextType.Instance, _inputOpsLengths[col]);
        return null;
    }
    // Enumerates the metadata kinds available on the column: always OpType,
    // plus InputOps when the op has inputs.
    protected override IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypesCore(int col)
    {
        Ectx.Assert(0 <= col && col < ColumnCount);
        yield return new KeyValuePair<string, ColumnType>(OpType, TextType.Instance);
        if (_inputOpsGetters[col] != null)
            yield return new KeyValuePair<string, ColumnType>(InputOps, new VectorType(TextType.Instance, _inputOpsLengths[col]));
    }
}
}
}