Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions DevProxy.Abstractions/LanguageModel/OpenAIModels.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using DevProxy.Abstractions.Utils;
using Microsoft.Extensions.Logging;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace DevProxy.Abstractions.LanguageModel;
Expand All @@ -20,6 +23,95 @@ public class OpenAIRequest
public double? Temperature { get; set; }
[JsonPropertyName("top_p")]
public double? TopP { get; set; }

public static bool TryGetOpenAIRequest(string content, ILogger logger, out OpenAIRequest? request)
{
logger.LogTrace("{Method} called", nameof(TryGetOpenAIRequest));

request = null;

if (string.IsNullOrEmpty(content))
{
logger.LogDebug("Request content is empty or null");
return false;
}

try
{
logger.LogDebug("Checking if the request is an OpenAI request...");

var rawRequest = JsonSerializer.Deserialize<JsonElement>(content, ProxyUtils.JsonSerializerOptions);

// Check for completion request (has "prompt", but not specific to image)
if (rawRequest.TryGetProperty("prompt", out _) &&
!rawRequest.TryGetProperty("size", out _) &&
!rawRequest.TryGetProperty("n", out _))
{
logger.LogDebug("Request is a completion request");
request = JsonSerializer.Deserialize<OpenAICompletionRequest>(content, ProxyUtils.JsonSerializerOptions);
return true;
}

// Chat completion request
if (rawRequest.TryGetProperty("messages", out _))
{
logger.LogDebug("Request is a chat completion request");
request = JsonSerializer.Deserialize<OpenAIChatCompletionRequest>(content, ProxyUtils.JsonSerializerOptions);
return true;
}

// Embedding request
if (rawRequest.TryGetProperty("input", out _) &&
rawRequest.TryGetProperty("model", out _) &&
!rawRequest.TryGetProperty("voice", out _))
{
logger.LogDebug("Request is an embedding request");
request = JsonSerializer.Deserialize<OpenAIEmbeddingRequest>(content, ProxyUtils.JsonSerializerOptions);
return true;
}

// Image generation request
if (rawRequest.TryGetProperty("prompt", out _) &&
(rawRequest.TryGetProperty("size", out _) || rawRequest.TryGetProperty("n", out _)))
{
logger.LogDebug("Request is an image generation request");
request = JsonSerializer.Deserialize<OpenAIImageRequest>(content, ProxyUtils.JsonSerializerOptions);
return true;
}

// Audio transcription request
if (rawRequest.TryGetProperty("file", out _))
{
logger.LogDebug("Request is an audio transcription request");
request = JsonSerializer.Deserialize<OpenAIAudioRequest>(content, ProxyUtils.JsonSerializerOptions);
return true;
}

// Audio speech synthesis request
if (rawRequest.TryGetProperty("input", out _) && rawRequest.TryGetProperty("voice", out _))
{
logger.LogDebug("Request is an audio speech synthesis request");
request = JsonSerializer.Deserialize<OpenAIAudioSpeechRequest>(content, ProxyUtils.JsonSerializerOptions);
return true;
}

// Fine-tuning request
if (rawRequest.TryGetProperty("training_file", out _))
{
logger.LogDebug("Request is a fine-tuning request");
request = JsonSerializer.Deserialize<OpenAIFineTuneRequest>(content, ProxyUtils.JsonSerializerOptions);
return true;
}

logger.LogDebug("Request is not an OpenAI request.");
return false;
}
catch (JsonException ex)
{
logger.LogDebug(ex, "Failed to deserialize OpenAI request.");
return false;
}
}
}

public class OpenAIResponse : ILanguageModelCompletionResponse
Expand Down Expand Up @@ -82,10 +174,18 @@ public class OpenAIResponseUsage
public long CompletionTokens { get; set; }
[JsonPropertyName("prompt_tokens")]
public long PromptTokens { get; set; }
[JsonPropertyName("prompt_tokens_details")]
public PromptTokenDetails? PromptTokensDetails { get; set; }
[JsonPropertyName("total_tokens")]
public long TotalTokens { get; set; }
}

public class PromptTokenDetails
{
[JsonPropertyName("cached_tokens")]
public long CachedTokens { get; set; }
}

public class OpenAIResponsePromptFilterResult
{
[JsonPropertyName("content_filter_results")]
Expand Down
154 changes: 4 additions & 150 deletions DevProxy.Plugins/Inspection/OpenAITelemetryPlugin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using DevProxy.Abstractions.Plugins;
using DevProxy.Abstractions.Proxy;
using DevProxy.Abstractions.Utils;
using DevProxy.Plugins.Utils;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
Expand All @@ -19,7 +20,6 @@
using System.Diagnostics;
using System.Diagnostics.Metrics;
using System.Text.Json;
using Titanium.Web.Proxy.Http;

namespace DevProxy.Plugins.Inspection;

Expand Down Expand Up @@ -108,7 +108,7 @@ public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken ca
return Task.CompletedTask;
}

if (!TryGetOpenAIRequest(request.BodyString, out var openAiRequest) || openAiRequest is null)
if (!OpenAIRequest.TryGetOpenAIRequest(request.BodyString, Logger, out var openAiRequest) || openAiRequest is null)
{
Logger.LogRequest("Skipping non-OpenAI request", MessageType.Skipped, new(e.Session));
return Task.CompletedTask;
Expand Down Expand Up @@ -323,9 +323,9 @@ private void ProcessSuccessResponse(Activity activity, ProxyResponseArgs e)
}

var bodyString = response.BodyString;
if (IsStreamingResponse(response))
if (HttpUtils.IsStreamingResponse(response, Logger))
{
bodyString = GetBodyFromStreamingResponse(response);
bodyString = HttpUtils.GetBodyFromStreamingResponse(response, Logger);
}

AddResponseTypeSpecificTags(activity, openAiRequest, bodyString);
Expand Down Expand Up @@ -895,95 +895,6 @@ private void RecordUsageMetrics(Activity activity, OpenAIRequest request, OpenAI
Logger.LogTrace("RecordUsageMetrics() finished");
}

private bool TryGetOpenAIRequest(string content, out OpenAIRequest? request)
{
Logger.LogTrace("TryGetOpenAIRequest() called");

request = null;

if (string.IsNullOrEmpty(content))
{
Logger.LogDebug("Request content is empty or null");
return false;
}

try
{
Logger.LogDebug("Checking if the request is an OpenAI request...");

var rawRequest = JsonSerializer.Deserialize<JsonElement>(content, ProxyUtils.JsonSerializerOptions);

// Check for completion request (has "prompt", but not specific to image)
if (rawRequest.TryGetProperty("prompt", out _) &&
!rawRequest.TryGetProperty("size", out _) &&
!rawRequest.TryGetProperty("n", out _))
{
Logger.LogDebug("Request is a completion request");
request = JsonSerializer.Deserialize<OpenAICompletionRequest>(content, ProxyUtils.JsonSerializerOptions);
return true;
}

// Chat completion request
if (rawRequest.TryGetProperty("messages", out _))
{
Logger.LogDebug("Request is a chat completion request");
request = JsonSerializer.Deserialize<OpenAIChatCompletionRequest>(content, ProxyUtils.JsonSerializerOptions);
return true;
}

// Embedding request
if (rawRequest.TryGetProperty("input", out _) &&
rawRequest.TryGetProperty("model", out _) &&
!rawRequest.TryGetProperty("voice", out _))
{
Logger.LogDebug("Request is an embedding request");
request = JsonSerializer.Deserialize<OpenAIEmbeddingRequest>(content, ProxyUtils.JsonSerializerOptions);
return true;
}

// Image generation request
if (rawRequest.TryGetProperty("prompt", out _) &&
(rawRequest.TryGetProperty("size", out _) || rawRequest.TryGetProperty("n", out _)))
{
Logger.LogDebug("Request is an image generation request");
request = JsonSerializer.Deserialize<OpenAIImageRequest>(content, ProxyUtils.JsonSerializerOptions);
return true;
}

// Audio transcription request
if (rawRequest.TryGetProperty("file", out _))
{
Logger.LogDebug("Request is an audio transcription request");
request = JsonSerializer.Deserialize<OpenAIAudioRequest>(content, ProxyUtils.JsonSerializerOptions);
return true;
}

// Audio speech synthesis request
if (rawRequest.TryGetProperty("input", out _) && rawRequest.TryGetProperty("voice", out _))
{
Logger.LogDebug("Request is an audio speech synthesis request");
request = JsonSerializer.Deserialize<OpenAIAudioSpeechRequest>(content, ProxyUtils.JsonSerializerOptions);
return true;
}

// Fine-tuning request
if (rawRequest.TryGetProperty("training_file", out _))
{
Logger.LogDebug("Request is a fine-tuning request");
request = JsonSerializer.Deserialize<OpenAIFineTuneRequest>(content, ProxyUtils.JsonSerializerOptions);
return true;
}

Logger.LogDebug("Request is not an OpenAI request.");
return false;
}
catch (JsonException ex)
{
Logger.LogDebug(ex, "Failed to deserialize OpenAI request.");
return false;
}
}

private static string GetOperationName(OpenAIRequest request)
{
if (request == null)
Expand All @@ -1004,63 +915,6 @@ private static string GetOperationName(OpenAIRequest request)
};
}

private bool IsStreamingResponse(Response response)
{
Logger.LogTrace("{Method} called", nameof(IsStreamingResponse));
var contentType = response.Headers.FirstOrDefault(h => h.Name.Equals("content-type", StringComparison.OrdinalIgnoreCase))?.Value;
if (string.IsNullOrEmpty(contentType))
{
Logger.LogDebug("No content-type header found");
return false;
}

var isStreamingResponse = contentType.Contains("text/event-stream", StringComparison.OrdinalIgnoreCase);
Logger.LogDebug("IsStreamingResponse: {IsStreamingResponse}", isStreamingResponse);

Logger.LogTrace("{Method} finished", nameof(IsStreamingResponse));
return isStreamingResponse;
}

private string GetBodyFromStreamingResponse(Response response)
{
Logger.LogTrace("{Method} called", nameof(GetBodyFromStreamingResponse));

// default to the whole body
var bodyString = response.BodyString;

var chunks = bodyString.Split("\n\n", StringSplitOptions.RemoveEmptyEntries);
if (chunks.Length == 0)
{
Logger.LogDebug("No chunks found in the response body");
return bodyString;
}

// check if the last chunk is `data: [DONE]`
var lastChunk = chunks.Last().Trim();
if (lastChunk.Equals("data: [DONE]", StringComparison.OrdinalIgnoreCase))
{
// get next to last chunk
var chunk = chunks.Length > 1 ? chunks[^2].Trim() : string.Empty;
if (chunk.StartsWith("data: ", StringComparison.OrdinalIgnoreCase))
{
// remove the "data: " prefix
bodyString = chunk["data: ".Length..].Trim();
Logger.LogDebug("Last chunk starts with 'data: ', using the last chunk as the body: {BodyString}", bodyString);
}
else
{
Logger.LogDebug("Last chunk does not start with 'data: ', using the whole body");
}
}
else
{
Logger.LogDebug("Last chunk is not `data: [DONE]`, using the whole body");
}

Logger.LogTrace("{Method} finished", nameof(GetBodyFromStreamingResponse));
return bodyString;
}

public void Dispose()
{
_loader?.Dispose();
Expand Down
Loading