diff --git a/DevProxy.Plugins/Inspection/OpenAITelemetryPlugin.cs b/DevProxy.Plugins/Inspection/OpenAITelemetryPlugin.cs index 126db908..7ff4a0b1 100644 --- a/DevProxy.Plugins/Inspection/OpenAITelemetryPlugin.cs +++ b/DevProxy.Plugins/Inspection/OpenAITelemetryPlugin.cs @@ -11,12 +11,12 @@ using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using OpenTelemetry; using OpenTelemetry.Exporter; using OpenTelemetry.Metrics; using OpenTelemetry.Resources; using OpenTelemetry.Trace; -using System.Collections.Concurrent; using System.Diagnostics; using System.Diagnostics.Metrics; using System.Text.Json; @@ -67,7 +67,6 @@ public sealed class OpenAITelemetryPlugin( private LanguageModelPricesLoader? _loader; private MeterProvider? _meterProvider; private TracerProvider? _tracerProvider; - private readonly ConcurrentDictionary> _modelUsage = []; public override string Name => nameof(OpenAITelemetryPlugin); @@ -196,17 +195,18 @@ public override Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); + ArgumentNullException.ThrowIfNull(e); + var report = new OpenAITelemetryPluginReport { Application = Configuration.Application, Environment = Configuration.Environment, Currency = Configuration.Currency, IncludeCosts = Configuration.IncludeCosts, - ModelUsage = _modelUsage.ToDictionary() + ModelUsage = GetOpenAIModelUsage(e.RequestLogs) }; StoreReport(report, e); - _modelUsage.Clear(); Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); return Task.CompletedTask; @@ -849,16 +849,6 @@ private void RecordUsageMetrics(Activity activity, OpenAIRequest request, OpenAI .SetTag(SemanticConvention.GEN_AI_USAGE_OUTPUT_TOKENS, usage.CompletionTokens) .SetTag(SemanticConvention.GEN_AI_USAGE_TOTAL_TOKENS, usage.TotalTokens); - var reportModelUsageInformation = new OpenAITelemetryPluginReportModelUsageInformation - { - Model = response.Model, - PromptTokens = usage.PromptTokens, - CompletionTokens = usage.CompletionTokens, - CachedTokens = usage.PromptTokensDetails?.CachedTokens ?? 0L - }; - var usagePerModel = _modelUsage.GetOrAdd(response.Model, model => []); - usagePerModel.Add(reportModelUsageInformation); - if (!Configuration.IncludeCosts || Configuration.Prices is null) { Logger.LogDebug("Cost tracking is disabled or prices data is not available"); @@ -895,7 +885,6 @@ private void RecordUsageMetrics(Activity activity, OpenAIRequest request, OpenAI new(SemanticConvention.GEN_AI_REQUEST_MODEL, request.Model), new(SemanticConvention.GEN_AI_RESPONSE_MODEL, response.Model) ]); - reportModelUsageInformation.Cost = totalCost; } else { @@ -905,6 +894,100 @@ private void RecordUsageMetrics(Activity activity, OpenAIRequest request, OpenAI Logger.LogTrace("RecordUsageMetrics() finished"); } + private Dictionary> GetOpenAIModelUsage(IEnumerable requestLogs) + { + var modelUsage = new Dictionary>(); + var openAIRequestLogs = requestLogs.Where(r => + r is not null && + r.Context is not null && + r.Context.Session is not null && + r.MessageType == MessageType.InterceptedResponse && + string.Equals("POST", r.Context.Session.HttpClient.Request.Method, StringComparison.OrdinalIgnoreCase) && + r.Context.Session.HttpClient.Response.StatusCode >= 200 && + r.Context.Session.HttpClient.Response.StatusCode < 300 && + r.Context.Session.HttpClient.Response.HasBody && + !string.IsNullOrEmpty(r.Context.Session.HttpClient.Response.BodyString) && + ProxyUtils.MatchesUrlToWatch(UrlsToWatch, r.Context.Session.HttpClient.Request.RequestUri.AbsoluteUri) && + OpenAIRequest.TryGetOpenAIRequest(r.Context.Session.HttpClient.Request.BodyString, NullLogger.Instance, out var openAiRequest) && + openAiRequest is not null + ); + + foreach (var requestLog in openAIRequestLogs) + { + try + { + var response = JsonSerializer.Deserialize(requestLog.Context!.Session.HttpClient.Response.BodyString, ProxyUtils.JsonSerializerOptions); + if (response is null) + { + continue; + } + + var reportModelUsageInfo = GetReportModelUsageInfo(response); + if (modelUsage.TryGetValue(response.Model, out var usagePerModel)) + { + usagePerModel.AddRange(reportModelUsageInfo); + } + else + { + modelUsage.Add(response.Model, reportModelUsageInfo); + } + } + catch (JsonException ex) + { + Logger.LogError(ex, "Failed to deserialize OpenAI response"); + } + } + + return modelUsage; + } + + private List GetReportModelUsageInfo(OpenAIResponse response) + { + Logger.LogTrace("GetReportModelUsageInfo() called"); + var usagePerModel = new List(); + var usage = response.Usage; + if (usage is null) + { + return usagePerModel; + } + + var reportModelUsageInformation = new OpenAITelemetryPluginReportModelUsageInformation + { + Model = response.Model, + PromptTokens = usage.PromptTokens, + CompletionTokens = usage.CompletionTokens, + CachedTokens = usage.PromptTokensDetails?.CachedTokens ?? 0L + }; + usagePerModel.Add(reportModelUsageInformation); + + if (!Configuration.IncludeCosts || Configuration.Prices is null) + { + Logger.LogDebug("Cost tracking is disabled or prices data is not available"); + return usagePerModel; + } + + if (string.IsNullOrEmpty(response.Model)) + { + Logger.LogDebug("Response model is empty or null"); + return usagePerModel; + } + + var (inputCost, outputCost) = Configuration.Prices.CalculateCost(response.Model, usage.PromptTokens, usage.CompletionTokens); + + if (inputCost > 0) + { + var totalCost = inputCost + outputCost; + reportModelUsageInformation.Cost = totalCost; + } + else + { + Logger.LogDebug("Input cost is zero, skipping cost metrics recording"); + } + + Logger.LogTrace("GetReportModelUsageInfo() finished"); + return usagePerModel; + } + private static string GetOperationName(OpenAIRequest request) { if (request == null) diff --git a/DevProxy/Proxy/ProxyStateController.cs b/DevProxy/Proxy/ProxyStateController.cs index 1f418abe..93e6d6fb 100644 --- a/DevProxy/Proxy/ProxyStateController.cs +++ b/DevProxy/Proxy/ProxyStateController.cs @@ -69,7 +69,7 @@ public async Task StopRecordingAsync(CancellationToken cancellationToken) public async Task MockRequestAsync(CancellationToken cancellationToken) { - var eventArgs = new EventArgs(); + var eventArgs = EventArgs.Empty; foreach (var plugin in _plugins.Where(p => p.Enabled)) {