Skip to content

Commit

Permalink
Add support for a collection of OrtValue as inputs and outputs to C# …
Browse files Browse the repository at this point in the history
…TrainingSession (#19048)
  • Loading branch information
baijumeswani committed Jan 26, 2024
1 parent 358650d commit fc44f96
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 0 deletions.
107 changes: 107 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,48 @@ public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> TrainStep(
}
}

/// <summary>
/// This function performs a training step that computes the outputs of the training model and the gradients
/// of the trainable parameters for the given OrtValue inputs. The train step is performed based on the training model
/// that was provided to the training session.
/// The TrainStep method is equivalent of running forward propagation and backward propagation in a single
/// step.
/// The gradients computed are stored inside the training session state so they can be later consumed
/// by the OptimizerStep function.
/// The gradients can be lazily reset by invoking the LazyResetGrad function.
/// Example usage:
/// <code>
/// using OrtValue x = OrtValue.CreateTensorValueFromMemory(...);
/// using OrtValue label = OrtValue.CreateTensorValueFromMemory(...);
/// List<OrtValue> inputValues = new List<OrtValue> { x, label };
/// using (var loss = trainingSession.TrainStep(inputValues))
/// {
/// // process output values
/// }
/// </code>
/// </summary>
/// <param name="inputValues">Specify a collection of <see cref="OrtValue"/> that indicates the input values to the training model.</param>
/// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
public IDisposableReadOnlyCollection<OrtValue> TrainStep(IReadOnlyCollection<OrtValue> inputValues)
{
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues);
IntPtr[] outputValuesArray = new IntPtr[(int)_trainOutputCount];

NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, IntPtr.Zero, (UIntPtr)inputValues.Count,
inputValuesArray, (UIntPtr)_trainOutputCount, outputValuesArray));


var disposableHandles = new DisposableOrtValueHandleArray(outputValuesArray);
try
{
return CreateDisposableResult(disposableHandles);
}
finally
{
disposableHandles.Dispose();
}
}

/// <summary>
/// Convert native OrtValue handles to OrtValue instances
/// in an exceptions safe manner.
Expand Down Expand Up @@ -370,6 +412,42 @@ public void EvalStep(
inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray));
}

/// <summary>
/// This function performs an eval step that computes the outputs of the eval model for the given inputs.
/// Inputs are expected to be of type OrtValue. The eval step is performed based on the eval model that was
/// provided to the training session.
/// Example usage:
/// <code>
/// using OrtValue x = OrtValue.CreateTensorValueFromMemory(...);
/// using OrtValue label = OrtValue.CreateTensorValueFromMemory(...);
/// List<OrtValue> inputValues = new List<OrtValue> { x, label };
/// using (var loss = trainingSession.EvalSteps(inputValues))
/// {
/// // process output values
/// }
/// </code>
/// </summary>
/// <param name="inputValues">Specify a collection of <see cref="OrtValue"/> that indicates the input values to the eval model.</param>
public IDisposableReadOnlyCollection<OrtValue> EvalStep(IReadOnlyCollection<OrtValue> inputValues)
{
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues);
IntPtr[] outputValuesArray = new IntPtr[(int)_evalOutputCount];

NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, IntPtr.Zero, (UIntPtr)inputValues.Count,
inputValuesArray, (UIntPtr)_evalOutputCount, outputValuesArray));


var disposableHandles = new DisposableOrtValueHandleArray(outputValuesArray);
try
{
return CreateDisposableResult(disposableHandles);
}
finally
{
disposableHandles.Dispose();
}
}


/// <summary>
/// Sets the learning rate for this training session.
Expand Down Expand Up @@ -702,6 +780,35 @@ private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection<FixedBufferOnnxValue> v
return valuesArray;
}

private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection<OrtValue> inputValues)
{
var valuesArray = new IntPtr[inputValues.Count];
for (int index = 0; index < inputValues.Count; ++index)
{
valuesArray[index] = inputValues.ElementAt(index).Handle;
}
return valuesArray;
}

private static IDisposableReadOnlyCollection<OrtValue> CreateDisposableResult(DisposableOrtValueHandleArray disposableHandles)
{
var outputValues = new DisposableList<OrtValue>(disposableHandles.Span.Length);
try
{
for (int i = 0; i < disposableHandles.Span.Length; i++)
{
outputValues.Add(new OrtValue(disposableHandles.Span[i]));
disposableHandles.Span[i] = IntPtr.Zero;
}
return outputValues;
}
catch (Exception)
{
outputValues.Dispose();
throw;
}
}

private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection<string> names, DisposableList<IDisposable> cleanupList)
{
cleanupList.Capacity += names.Count;
Expand Down
75 changes: 75 additions & 0 deletions csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,81 @@ public void TestUpdateParameter()
}
}

[Fact(DisplayName = "TestTrainingSessionTrainStepWithOrtValues")]
public void TestTrainingSessionTrainStepWithOrtValues()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");

var trainingSession = new TrainingSession(state, trainingPath, optimizerPath);
cleanUp.Add(trainingSession);

float[] expectedOutput = TestDataLoader.LoadTensorFromFile("loss_1.out");
var expectedOutputDimensions = new int[] { 1 };
float[] inputData = TestDataLoader.LoadTensorFromFile("input-0.in");
long[] inputShape = { 2, 784 };
Int32[] labelsData = { 1, 1 };
long[] labelsShape = { 2 };

using OrtValue inputOrtValue = OrtValue.CreateTensorValueFromMemory<float>(inputData, inputShape);
using OrtValue labelsOrtValue = OrtValue.CreateTensorValueFromMemory<Int32>(labelsData, labelsShape);
var inputValues = new List<OrtValue> { inputOrtValue, labelsOrtValue };

using (var results = trainingSession.TrainStep(inputValues))
{
Assert.Single(results);
var outputOrtValue = results[0];
Assert.True(outputOrtValue.IsTensor);
var resultSpan = outputOrtValue.GetTensorDataAsSpan<float>().ToArray();
Assert.Equal(expectedOutput, resultSpan, new FloatComparer());
}
}
}

[Fact(DisplayName = "TestTrainingSessionEvalStepWithOrtValues")]
public void TestTrainingSessionEvalStepWithOrtValues()
{
string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt");
using (var cleanUp = new DisposableListTest<IDisposable>())
{
var state = CheckpointState.LoadCheckpoint(checkpointPath);
cleanUp.Add(state);
Assert.NotNull(state);
string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx");
string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx");
string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx");

var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath);
cleanUp.Add(trainingSession);

float[] expectedOutput = TestDataLoader.LoadTensorFromFile("loss_1.out");
var expectedOutputDimensions = new int[] { 1 };
float[] inputData = TestDataLoader.LoadTensorFromFile("input-0.in");
long[] inputShape = { 2, 784 };
Int32[] labelsData = { 1, 1 };
long[] labelsShape = { 2 };

using OrtValue inputOrtValue = OrtValue.CreateTensorValueFromMemory<float>(inputData, inputShape);
using OrtValue labelsOrtValue = OrtValue.CreateTensorValueFromMemory<Int32>(labelsData, labelsShape);
var inputValues = new List<OrtValue> { inputOrtValue, labelsOrtValue };

using (var results = trainingSession.EvalStep(inputValues))
{
Assert.Single(results);
var outputOrtValue = results[0];
Assert.True(outputOrtValue.IsTensor);
var resultSpan = outputOrtValue.GetTensorDataAsSpan<float>().ToArray();
Assert.Equal(expectedOutput, resultSpan, new FloatComparer());
}
}
}

internal class FloatComparer : IEqualityComparer<float>
{
private float atol = 1e-3f;
Expand Down

0 comments on commit fc44f96

Please sign in to comment.