diff --git a/pkgs/sdk/server-ai/src/Interfaces/ILdAiConfigTracker.cs b/pkgs/sdk/server-ai/src/Interfaces/ILdAiConfigTracker.cs index aa652d14..937afe8e 100644 --- a/pkgs/sdk/server-ai/src/Interfaces/ILdAiConfigTracker.cs +++ b/pkgs/sdk/server-ai/src/Interfaces/ILdAiConfigTracker.cs @@ -45,6 +45,11 @@ public interface ILdAiConfigTracker /// public void TrackSuccess(); + /// + /// Tracks an unsuccessful generation event related to this config. + /// + public void TrackError(); + /// /// Tracks a request to a provider. The request is a task that returns a , which /// contains information about the request such as token usage and metrics. diff --git a/pkgs/sdk/server-ai/src/LdAiConfigTracker.cs b/pkgs/sdk/server-ai/src/LdAiConfigTracker.cs index c820fecf..bfe7eb0b 100644 --- a/pkgs/sdk/server-ai/src/LdAiConfigTracker.cs +++ b/pkgs/sdk/server-ai/src/LdAiConfigTracker.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Runtime.CompilerServices; using System.Threading.Tasks; using LaunchDarkly.Sdk.Server.Ai.Config; using LaunchDarkly.Sdk.Server.Ai.Interfaces; @@ -21,6 +22,8 @@ public class LdAiConfigTracker : ILdAiConfigTracker private const string FeedbackPositive = "$ld:ai:feedback:user:positive"; private const string FeedbackNegative = "$ld:ai:feedback:user:negative"; private const string Generation = "$ld:ai:generation"; + private const string GenerationSuccess = "$ld:ai:generation:success"; + private const string GenerationError = "$ld:ai:generation:error"; private const string TokenTotal = "$ld:ai:tokens:total"; private const string TokenInput = "$ld:ai:tokens:input"; private const string TokenOutput = "$ld:ai:tokens:output"; @@ -57,18 +60,14 @@ public void TrackDuration(float durationMs) => /// public async Task TrackDurationOfTask(Task task) - { - var result = await MeasureDurationOfTaskMs(task); - TrackDuration(result.Item2); - return result.Item1; - } - - private static async Task> MeasureDurationOfTaskMs(Task task) { var sw = Stopwatch.StartNew(); - var result = await task; - sw.Stop(); - return Tuple.Create(result, sw.ElapsedMilliseconds); + try { + return await task; + } finally { + sw.Stop(); + TrackDuration(sw.ElapsedMilliseconds); + } } /// @@ -90,23 +89,44 @@ public void TrackFeedback(Feedback feedback) /// public void TrackSuccess() { + _client.Track(GenerationSuccess, _context, _trackData, 1); + _client.Track(Generation, _context, _trackData, 1); + } + + /// + public void TrackError() + { + _client.Track(GenerationError, _context, _trackData, 1); _client.Track(Generation, _context, _trackData, 1); } /// public async Task TrackRequest(Task request) { - var (result, durationMs) = await MeasureDurationOfTaskMs(request); - TrackSuccess(); + var sw = Stopwatch.StartNew(); + try + { + var result = await request; + TrackSuccess(); + + sw.Stop(); + TrackDuration(result.Metrics?.LatencyMs ?? sw.ElapsedMilliseconds); - TrackDuration(result.Metrics?.LatencyMs ?? durationMs); + if (result.Usage != null) + { + TrackTokens(result.Usage.Value); + } - if (result.Usage != null) + return result; + } + catch (Exception) { - TrackTokens(result.Usage.Value); + sw.Stop(); + TrackDuration(sw.ElapsedMilliseconds); + TrackError(); + throw; } - return result; } /// diff --git a/pkgs/sdk/server-ai/test/LdAiConfigTrackerTest.cs b/pkgs/sdk/server-ai/test/LdAiConfigTrackerTest.cs index b5120431..4145a332 100644 --- a/pkgs/sdk/server-ai/test/LdAiConfigTrackerTest.cs +++ b/pkgs/sdk/server-ai/test/LdAiConfigTrackerTest.cs @@ -68,6 +68,27 @@ public void CanTrackSuccess() var tracker = new LdAiConfigTracker(mockClient.Object, flagKey, config, context); tracker.TrackSuccess(); mockClient.Verify(x => x.Track("$ld:ai:generation", context, data, 1.0f), Times.Once); + mockClient.Verify(x => x.Track("$ld:ai:generation:success", context, data, 1.0f), Times.Once); + } + + + [Fact] + public void CanTrackError() + { + var mockClient = new Mock(); + var context = Context.New("key"); + const string flagKey = "key"; + var config = LdAiConfig.Disabled; + var data = LdValue.ObjectFrom(new Dictionary + { + { "variationKey", LdValue.Of(config.VariationKey) }, + { "configKey", LdValue.Of(flagKey) } + }); + + var tracker = new LdAiConfigTracker(mockClient.Object, flagKey, config, context); + tracker.TrackError(); + mockClient.Verify(x => x.Track("$ld:ai:generation", context, data, 1.0f), Times.Once); + mockClient.Verify(x => x.Track("$ld:ai:generation:error", context, data, 1.0f), Times.Once); } @@ -189,6 +210,8 @@ public void CanTrackResponseWithSpecificLatency() var result = tracker.TrackRequest(Task.Run(() => givenResponse)); Assert.Equal(givenResponse, result.Result); + mockClient.Verify(x => x.Track("$ld:ai:generation:success", context, data, 1.0f), Times.Once); + mockClient.Verify(x => x.Track("$ld:ai:generation", context, data, 1.0f), Times.Once); mockClient.Verify(x => x.Track("$ld:ai:tokens:total", context, data, 1.0f), Times.Once); mockClient.Verify(x => x.Track("$ld:ai:tokens:input", context, data, 2.0f), Times.Once); mockClient.Verify(x => x.Track("$ld:ai:tokens:output", context, data, 3.0f), Times.Once); @@ -228,5 +251,29 @@ public void CanTrackResponseWithPartialData() // if latency isn't provided via Statistics, then it is automatically measured. mockClient.Verify(x => x.Track("$ld:ai:duration:total", context, data, It.IsAny()), Times.Once); } + + [Fact] + public async Task CanTrackExceptionFromResponse() + { + var mockClient = new Mock(); + var context = Context.New("key"); + const string flagKey = "key"; + var config = LdAiConfig.Disabled; + var data = LdValue.ObjectFrom(new Dictionary + { + { "variationKey", LdValue.Of(config.VariationKey) }, + { "configKey", LdValue.Of(flagKey) } + }); + + var tracker = new LdAiConfigTracker(mockClient.Object, flagKey, config, context); + + await Assert.ThrowsAsync(() => tracker.TrackRequest(Task.FromException(new System.Exception("I am an exception")))); + + mockClient.Verify(x => x.Track("$ld:ai:generation", context, data, 1.0f), Times.Once); + mockClient.Verify(x => x.Track("$ld:ai:generation:error", context, data, 1.0f), Times.Once); + + // if latency isn't provided via Statistics, then it is automatically measured. + mockClient.Verify(x => x.Track("$ld:ai:duration:total", context, data, It.IsAny()), Times.Once); + } } }