Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🎉 Add Azure Content Moderator service for image analysis #143

Merged
merged 15 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 63 additions & 8 deletions webapi/Controllers/DocumentImportController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.Text;
using UglyToad.PdfPig;
using UglyToad.PdfPig.DocumentLayoutAnalysis.TextExtractor;
Expand Down Expand Up @@ -79,6 +80,7 @@ private enum SupportedFileType
private const string GlobalDocumentUploadedClientCall = "GlobalDocumentUploaded";
private const string ReceiveMessageClientCall = "ReceiveMessage";
private readonly IOcrEngine _ocrEngine;
private readonly AzureContentSafety? _contentSafetyService = null;

/// <summary>
/// Initializes a new instance of the <see cref="DocumentImportController"/> class.
Expand All @@ -91,7 +93,8 @@ private enum SupportedFileType
ChatMemorySourceRepository sourceRepository,
ChatMessageRepository messageRepository,
ChatParticipantRepository participantRepository,
IOcrEngine ocrEngine)
IOcrEngine ocrEngine,
AzureContentSafety? contentSafety = null)
{
this._logger = logger;
this._options = documentMemoryOptions.Value;
Expand All @@ -101,6 +104,19 @@ private enum SupportedFileType
this._messageRepository = messageRepository;
this._participantRepository = participantRepository;
this._ocrEngine = ocrEngine;
this._contentSafetyService = contentSafety;
}

/// <summary>
/// Gets the status of content safety.
/// </summary>
/// <returns></returns>
[HttpGet]
[Route("contentSafety/status")]
[ProducesResponseType(StatusCodes.Status200OK)]
public bool ContentSafetyStatus()
{
return this._contentSafetyService!.ContentSafetyStatus(this._logger);
}

/// <summary>
Expand Down Expand Up @@ -265,7 +281,7 @@ private async Task ValidateDocumentImportFormAsync(DocumentImportForm documentIm
throw new ArgumentException($"File {formFile.FileName} size exceeds the limit.");
}

// Make sure the file type is supported.
// Make sure the file type is supported and validate any images if ContentSafety is enabled.
var fileType = this.GetFileType(Path.GetFileName(formFile.FileName));
switch (fileType)
{
Expand All @@ -276,16 +292,41 @@ private async Task ValidateDocumentImportFormAsync(DocumentImportForm documentIm
case SupportedFileType.Jpg:
case SupportedFileType.Png:
case SupportedFileType.Tiff:
{
if (this._ocrSupportOptions.Type != OcrSupportOptions.OcrSupportType.None)
{
teresaqhoang marked this conversation as resolved.
Show resolved Hide resolved
if (documentImportForm.UseContentSafety)
{
if (!this._contentSafetyService!.ContentSafetyStatus(this._logger))
{
throw new ArgumentException("Unable to analyze image. Content Safety is currently disabled in the backend.");
}

var violations = new List<string>();
try
{
// Convert the form file to a base64 string
var base64Image = await this.ConvertFormFileToBase64Async(formFile);

// Call the content safety controller to analyze the image
var imageAnalysisResponse = await this._contentSafetyService!.ImageAnalysisAsync(base64Image, default);
violations = AzureContentSafety.ParseViolatedCategories(imageAnalysisResponse, this._contentSafetyService!.Options!.ViolationThreshold);
}
catch (Exception ex) when (!ex.IsCriticalException())
{
this._logger.LogError(ex, "Failed to analyze image {0} with Content Safety. ErrorCode: {{1}}", formFile.FileName, (ex as AIException)?.ErrorCode);
throw new AggregateException($"Failed to analyze image {formFile.FileName} with Content Safety.", ex);
}

if (violations.Count > 0)
{
throw new ArgumentException($"Unable to upload image {formFile.FileName}. Detected undesirable content with potential risk: {string.Join(", ", violations)}");
}
}
break;
}

throw new ArgumentException($"Unsupported image file type: {fileType} when " +
$"{OcrSupportOptions.PropertyName}:{nameof(OcrSupportOptions.Type)} is set to " +
nameof(OcrSupportOptions.OcrSupportType.None));
}
default:
throw new ArgumentException($"Unsupported file type: {fileType}");
}
Expand Down Expand Up @@ -315,11 +356,12 @@ private async Task<ImportResult> ImportDocumentHelperAsync(IKernel kernel, IForm
case SupportedFileType.Jpg:
case SupportedFileType.Png:
case SupportedFileType.Tiff:
{
documentContent = await this.ReadTextFromImageFileAsync(formFile);
if (documentContent.Trim().Length == 0)
{
teresaqhoang marked this conversation as resolved.
Show resolved Hide resolved
throw new ArgumentException($"Image {{{formFile.FileName}}} does not contain text.");
}
break;
}

default:
// This should never happen. Validation should have already caught this.
return ImportResult.Fail();
Expand Down Expand Up @@ -487,6 +529,19 @@ private async Task<string> ReadTextFromImageFileAsync(IFormFile file)
return textFromFile;
}

/// <summary>
/// Helper method to convert a form file to a base64 string.
/// </summary>
/// <param name="file">An IFormFile object.</param>
/// <returns>A Base64 string of the content of the image.</returns>
private async Task<string> ConvertFormFileToBase64Async(IFormFile formFile)
{
using var memoryStream = new MemoryStream();
await formFile.CopyToAsync(memoryStream);
var bytes = memoryStream.ToArray();
return Convert.ToBase64String(bytes);
}

/// <summary>
/// Read the content of a text file.
/// </summary>
Expand Down
20 changes: 20 additions & 0 deletions webapi/Extensions/SemanticKernelExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
using System.Threading.Tasks;
using CopilotChat.WebApi.Hubs;
using CopilotChat.WebApi.Options;
using CopilotChat.WebApi.Services;
using CopilotChat.WebApi.Skills.ChatSkills;
using CopilotChat.WebApi.Storage;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
Expand Down Expand Up @@ -61,6 +63,9 @@ internal static IServiceCollection AddSemanticKernelServices(this IServiceCollec
// Semantic memory
services.AddSemanticTextMemory();

// Azure Content Safety
services.AddContentSafety();

// Register skills
services.AddScoped<RegisterSkillsWithKernel>(sp => RegisterSkillsAsync);

Expand Down Expand Up @@ -103,6 +108,7 @@ public static IKernel RegisterChatSkill(this IKernel kernel, IServiceProvider sp
messageRelayHubContext: sp.GetRequiredService<IHubContext<MessageRelayHub>>(),
promptOptions: sp.GetRequiredService<IOptions<PromptsOptions>>(),
documentImportOptions: sp.GetRequiredService<IOptions<DocumentMemoryOptions>>(),
contentSafety: sp.GetService<AzureContentSafety>(),
planner: sp.GetRequiredService<CopilotChatPlanner>(),
logger: sp.GetRequiredService<ILogger<ChatSkill>>()),
nameof(ChatSkill));
Expand Down Expand Up @@ -253,6 +259,20 @@ private static void AddSemanticTextMemory(this IServiceCollection services)
.ToTextEmbeddingsService(logger: sp.GetRequiredService<ILogger<AIServiceOptions>>())));
}

/// <summary>
/// Adds Azure Content Safety
/// </summary>
internal static void AddContentSafety(this IServiceCollection services)
{
IConfiguration configuration = services.BuildServiceProvider().GetRequiredService<IConfiguration>();
ContentSafetyOptions options = configuration.GetSection(ContentSafetyOptions.PropertyName).Get<ContentSafetyOptions>();

if (options.Enabled)
{
services.AddSingleton<AzureContentSafety>(sp => new AzureContentSafety(new Uri(options.Endpoint), options.Key, options));
}
}

/// <summary>
/// Add the completion backend to the kernel config
/// </summary>
Expand Down
6 changes: 6 additions & 0 deletions webapi/Extensions/ServiceExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ public static IServiceCollection AddOptions(this IServiceCollection services, Co
.ValidateOnStart()
.PostConfigure(TrimStringProperties);

// Content safety options
services.AddOptions<ContentSafetyOptions>()
.Bind(configuration.GetSection(ContentSafetyOptions.PropertyName))
.ValidateOnStart()
.PostConfigure(TrimStringProperties);

return services;
}

Expand Down
5 changes: 5 additions & 0 deletions webapi/Models/Request/DocumentImportForm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,9 @@ public enum DocumentScopes
/// Will be used to create the chat message representing the document upload.
/// </summary>
public string UserName { get; set; } = string.Empty;

/// <summary>
/// Flag indicating whether user has content safety enabled from the client.
/// </summary>
public bool UseContentSafety { get; set; } = false;
alliscode marked this conversation as resolved.
Show resolved Hide resolved
}
37 changes: 37 additions & 0 deletions webapi/Models/Response/ImageAnalysisResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json.Serialization;
using CopilotChat.WebApi.Services;

namespace CopilotChat.WebApi.Models.Response;

/// <summary>
/// Response definition to the /contentsafety/image:analyze
/// endpoint made by the AzureContentSafety.
/// </summary>
public class ImageAnalysisResponse
{
/// <summary>
/// Gets or sets the AnalysisResult related to hate.
/// </summary>
[JsonPropertyName("hateResult")]
public AnalysisResult? HateResult { get; set; }

/// <summary>
/// Gets or sets the AnalysisResult related to self-harm.
/// </summary>
[JsonPropertyName("selfHarmResult")]
public AnalysisResult? SelfHarmResult { get; set; }

/// <summary>
/// Gets or sets the AnalysisResult related to sexual content.
/// </summary>
[JsonPropertyName("sexualResult")]
public AnalysisResult? SexualResult { get; set; }

/// <summary>
/// Gets or sets the AnalysisResult related to violence.
/// </summary>
[JsonPropertyName("violenceResult")]
public AnalysisResult? ViolenceResult { get; set; }
}
38 changes: 38 additions & 0 deletions webapi/Options/ContentSafetyOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.ComponentModel.DataAnnotations;

namespace CopilotChat.WebApi.Options;

/// <summary>
/// Configuration options for content safety.
/// </summary>
public class ContentSafetyOptions
{
public const string PropertyName = "ContentSafety";

/// <summary>
/// Whether to enable content safety.
/// </summary>
[Required, NotEmptyOrWhitespace]
public bool Enabled { get; set; } = false;

/// <summary>
/// Azure Content Safety endpoints
/// </summary>
[RequiredOnPropertyValue(nameof(Enabled), true)]
public string Endpoint { get; set; } = string.Empty;

/// <summary>
/// Key to access the content safety service.
/// </summary>
[RequiredOnPropertyValue(nameof(Enabled), true)]
public string Key { get; set; } = string.Empty;

/// <summary>
/// Set the violation threshold. See https://learn.microsoft.com/en-us/azure/ai-services/content-safety/quickstart-image for details.
/// </summary>
[Range(0, 6)]
public short ViolationThreshold { get; set; } = 4;
}
Loading