diff --git a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs index d43b933323..23d7eece24 100644 --- a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs +++ b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs @@ -66,7 +66,8 @@ namespace Microsoft.ML.Vision /// /// ### Training Algorithm Details /// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained model such as Resnet50 for the purpose - /// of classifying images. + /// of classifying images. The technique was inspired by [TensorFlow's retrain image classification tutorial] + /// (https://www.tensorflow.org/hub/tutorials/image_retraining) /// ]]> /// /// @@ -392,7 +393,7 @@ public sealed class Options : TrainerInputBaseWithLabel public Action MetricsCallback = null; /// - /// Indicates the path where the models get downloaded to and cache files saved, default is a new temporary directory + /// Indicates the path where the image bottleneck cache files and trained model are saved, default is a new temporary directory /// [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the path where the models get downloaded to and cache files saved, default is a new temporary directory.", SortOrder = 15)] public string WorkspacePath = null; @@ -591,6 +592,7 @@ private void InitializeTrainingGraph(IDataView input) _classCount = labelCount == 1 ? 
2 : (int)labelCount; var imageSize = ImagePreprocessingSize[_options.Arch]; _session = LoadTensorFlowSessionFromMetaGraph(Host, _options.Arch).Session; + _session.graph.as_default(); (_jpegData, _resizedImage) = AddJpegDecoding(imageSize.Item1, imageSize.Item2, 3); _jpegDataTensorName = _jpegData.name; _resizedImageTensorName = _resizedImage.name; @@ -631,6 +633,14 @@ private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape private protected override ImageClassificationModelParameters TrainModelCore(TrainContext trainContext) { + // Workspace directory is cleaned after training run. However, the pipeline can be re-used by calling + // fit() again after transform(), in which case we must ensure workspace directory exists. This scenario + // is typical in the case of cross-validation. + if (!Directory.Exists(_options.WorkspacePath)) + { + Directory.CreateDirectory(_options.WorkspacePath); + } + InitializeTrainingGraph(trainContext.TrainingSet.Data); CheckTrainingParameters(_options); var validationSet = trainContext.ValidationSet?.Data ?? _options.ValidationSet; @@ -1301,7 +1311,7 @@ private void VariableSummaries(RefVariable var) var optimizer = useLearningRateScheduler ? tf.train.GradientDescentOptimizer(_learningRateInput) : tf.train.GradientDescentOptimizer(learningRate); - _trainStep = optimizer.minimize(crossEntropyMean); + _trainStep = optimizer.minimize(crossEntropyMean); }); return (_trainStep, crossEntropyMean, _labelTensor, _softMaxTensor); @@ -1341,6 +1351,11 @@ private void Dispose(bool disposing) { _session.close(); } + + if (_session != null && _session.graph != IntPtr.Zero) + { + _session.graph.Dispose(); + } } /// @@ -1527,6 +1542,11 @@ private void Dispose(bool disposing) { _session.close(); } + + if (_session != null && _session.graph != IntPtr.Zero) + { + _session.graph.Dispose(); + } } } }