Skip to content

Commit

Permalink
Throw ArgumentNullExceptions for null arguments (#570)
Browse files Browse the repository at this point in the history
### Motivation and Context

Use ArgumentNullException as devs expect, and avoid exception-related
work when not throwing.

### Description

.NET developers expect that passing null erroneously to a method will
result in an ArgumentNullException, but currently it's resulting in a
custom ValidationException.

There are also places where string interpolation is being used to create
the error message; the work and allocation associated with that
interpolation is going to happen regardless of whether the exception is
thrown or not.
  • Loading branch information
stephentoub authored and dluc committed Apr 29, 2023
1 parent 923e6d9 commit b38816e
Show file tree
Hide file tree
Showing 26 changed files with 102 additions and 82 deletions.
Expand Up @@ -40,7 +40,7 @@ public abstract class ClientBase
CompleteRequestSettings requestSettings,
CancellationToken cancellationToken = default)
{
Verify.NotNull(requestSettings, "Completion settings cannot be empty");
Verify.NotNull(requestSettings);

if (requestSettings.MaxTokens < 1)
{
Expand Down Expand Up @@ -126,8 +126,8 @@ public abstract class ClientBase
ChatRequestSettings requestSettings,
CancellationToken cancellationToken = default)
{
Verify.NotNull(chat, "The chat history cannot be null");
Verify.NotNull(requestSettings, "Completion settings cannot be empty");
Verify.NotNull(chat);
Verify.NotNull(requestSettings);

if (requestSettings.MaxTokens < 1)
{
Expand Down
4 changes: 2 additions & 2 deletions dotnet/src/Extensions/Planning.ActionPlanner/ActionPlanner.cs
Expand Up @@ -47,7 +47,7 @@ public sealed class ActionPlanner
IKernel kernel,
string? prompt = null)
{
Verify.NotNull(kernel, "The planner requires a non-null kernel instance");
Verify.NotNull(kernel);

string promptTemplate = prompt ?? EmbeddedResource.Read("skprompt.txt");

Expand Down Expand Up @@ -140,7 +140,7 @@ public async Task<Plan> CreatePlanAsync(string goal)
[SKFunctionInput(Description = "The current goal processed by the planner", DefaultValue = "")]
public string ListOfFunctions(string goal, SKContext context)
{
Verify.NotNull(context.Skills, "The planner requires a non-null skill collection");
Verify.NotNull(context.Skills);
var functionsAvailable = context.Skills.GetFunctionsView();

// Prepare list using the format used by skprompt.txt
Expand Down
Expand Up @@ -89,9 +89,6 @@ internal static Plan ToPlanFromXml(this string xmlString, string goal, SKContext

if (!string.IsNullOrEmpty(functionName) && context.IsFunctionRegistered(skillName, functionName, out var skillFunction))
{
Verify.NotNull(functionName, nameof(functionName));
Verify.NotNull(skillFunction, nameof(skillFunction));

var planStep = new Plan(skillFunction);

var functionVariables = new ContextVariables();
Expand Down
Expand Up @@ -29,7 +29,7 @@ public sealed class SequentialPlanner
SequentialPlannerConfig? config = null,
string? prompt = null)
{
Verify.NotNull(kernel, $"{this.GetType().FullName} requires a kernel instance.");
Verify.NotNull(kernel);
this.Config = config ?? new();

this.Config.ExcludedSkills.Add(RestrictedSkillName);
Expand Down
Expand Up @@ -42,7 +42,7 @@ public static Embedding<TEmbedding> Empty
[JsonConstructor]
public Embedding(IEnumerable<TEmbedding> vector)
{
Verify.NotNull(vector, nameof(vector));
Verify.NotNull(vector);

if (!IsSupported)
{
Expand Down
Expand Up @@ -43,7 +43,7 @@ public static class EmbeddingGenerationExtensions
(this IEmbeddingGeneration<TValue, TEmbedding> generator, TValue value, CancellationToken cancellationToken = default)
where TEmbedding : unmanaged
{
Verify.NotNull(generator, "Embeddings generator cannot be NULL");
Verify.NotNull(generator);
return (await generator.GenerateEmbeddingsAsync(new[] { value }, cancellationToken).ConfigureAwait(false)).FirstOrDefault();
}
}
@@ -0,0 +1,23 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#pragma warning disable IDE0130 // Namespace does not match folder structure
// ReSharper disable once CheckNamespace
namespace System.Runtime.CompilerServices;
#pragma warning restore IDE0130

#if !NETCOREAPP

[AttributeUsage(AttributeTargets.Parameter, AllowMultiple = false, Inherited = false)]
internal sealed class CallerArgumentExpressionAttribute : Attribute
{
public CallerArgumentExpressionAttribute(string parameterName)
{
this.ParameterName = parameterName;
}

public string ParameterName { get; }
}

#endif
42 changes: 15 additions & 27 deletions dotnet/src/SemanticKernel.Abstractions/Diagnostics/Verify.cs
Expand Up @@ -15,18 +15,18 @@ internal static class Verify
private static readonly Regex s_asciiLettersDigitsUnderscoresRegex = new("^[0-9A-Za-z_]*$");

[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static void NotNull([NotNull] object? obj, string message)
internal static void NotNull([NotNull] object? obj, [CallerArgumentExpression(nameof(obj))] string? paramName = null)
{
if (obj is null)
{
ThrowValidationException(ValidationException.ErrorCodes.NullValue, message);
ThrowArgumentNullException(paramName);
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static void NotEmpty([NotNull] string? str, string message)
internal static void NotEmpty([NotNull] string? str, string? message = null, [CallerArgumentExpression(nameof(str))] string? paramName = null)
{
NotNull(str, message);
NotNull(str, paramName);
if (string.IsNullOrWhiteSpace(str))
{
ThrowValidationException(ValidationException.ErrorCodes.EmptyValue, message);
Expand Down Expand Up @@ -62,8 +62,8 @@ internal static void ValidFunctionParamName([NotNull] string? functionParamName)

internal static void StartsWith(string text, string prefix, string message)
{
NotEmpty(text, "The text to verify cannot be empty");
NotNull(prefix, "The prefix to verify is empty");
NotEmpty(text);
NotNull(prefix);
if (!text.StartsWith(prefix, StringComparison.OrdinalIgnoreCase))
{
ThrowValidationException(ValidationException.ErrorCodes.MissingPrefix, message);
Expand Down Expand Up @@ -104,33 +104,21 @@ internal static void ParametersUniqueness(IList<ParameterView> parameters)
}
}

internal static void GreaterThan<T>(T value, T min, string message) where T : IComparable<T>
{
int cmp = value.CompareTo(min);

if (cmp <= 0)
{
throw new ValidationException(ValidationException.ErrorCodes.OutOfRange, message);
}
}

public static void LessThan<T>(T value, T max, string message) where T : IComparable<T>
{
int cmp = value.CompareTo(max);

if (cmp >= 0)
{
throw new ValidationException(ValidationException.ErrorCodes.OutOfRange, message);
}
}

[DoesNotReturn]
private static void ThrowInvalidName(string kind, string name) =>
throw new KernelException(
KernelException.ErrorCodes.InvalidFunctionDescription,
$"A {kind} can contain only ASCII letters, digits, and underscores: '{name}' is not a valid name.");

[DoesNotReturn]
private static void ThrowValidationException(ValidationException.ErrorCodes errorCodes, string message) =>
internal static void ThrowArgumentNullException(string? paramName) =>
throw new ArgumentNullException(paramName);

[DoesNotReturn]
internal static void ThrowArgumentOutOfRangeException<T>(string? paramName, T actualValue, string message) =>
throw new ArgumentOutOfRangeException(paramName, actualValue, message);

[DoesNotReturn]
internal static void ThrowValidationException(ValidationException.ErrorCodes errorCodes, string? message) =>
throw new ValidationException(errorCodes, message);
}
Expand Up @@ -94,7 +94,10 @@ public SKContext Fail(string errorDescription, Exception? exception = null)
/// <returns>Delegate to execute the function</returns>
public ISKFunction Func(string skillName, string functionName)
{
Verify.NotNull(this.Skills, "The skill collection hasn't been set");
if (this.Skills is null)
{
Verify.ThrowValidationException(ValidationException.ErrorCodes.NullValue, nameof(this.Skills));
}

if (this.Skills.HasNativeFunction(skillName, functionName))
{
Expand Down
Expand Up @@ -179,7 +179,10 @@ public PromptTemplateConfig Compact()
public static PromptTemplateConfig FromJson(string json)
{
var result = Json.Deserialize<PromptTemplateConfig>(json);
Verify.NotNull(result, "Unable to deserialize prompt template config. The deserialized returned NULL.");
if (result is null)
{
Verify.ThrowValidationException(ValidationException.ErrorCodes.NullValue, "Unable to deserialize prompt template config. The deserialized returned NULL.");
}
return result;
}
}
Expand Up @@ -69,7 +69,7 @@ public void ItAllowsUnsupportedTypesOnEachOperation()
public void ItThrowsWithNullVector()
{
// Assert
Assert.Throws<ValidationException>(() => new Embedding<float>(null!));
Assert.Throws<ArgumentNullException>("vector", () => new Embedding<float>(null!));
}

[Fact]
Expand Down
@@ -1,8 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Memory.Collections;
using Xunit;

Expand All @@ -27,11 +27,10 @@ public void ItThrowsExceptionWhenCapacityIsInvalid()
};

// Act
var exception = Assert.Throws<ValidationException>(() => action());
var exception = Assert.Throws<ArgumentOutOfRangeException>("capacity", () => action());

// Assert
Assert.Equal(ValidationException.ErrorCodes.OutOfRange, exception.ErrorCode);
Assert.Equal("Out of range: MinHeap capacity must be greater than 0.", exception.Message);
Assert.Equal(-1, exception.ActualValue);
}

[Fact]
Expand Down Expand Up @@ -91,11 +90,10 @@ public void ItThrowsExceptionOnAddingItemsAtInvalidIndex()
var action = () => { minHeap.Add(items, startIndex); };

// Act
var exception = Assert.Throws<ValidationException>(() => action());
var exception = Assert.Throws<ArgumentOutOfRangeException>("startAt", () => action());

// Assert
Assert.Equal(ValidationException.ErrorCodes.OutOfRange, exception.ErrorCode);
Assert.Equal("Out of range: startAt value must be less than items count.", exception.Message);
Assert.Equal(startIndex, exception.ActualValue);
}

[Fact]
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/SemanticKernel/AI/Embeddings/IEmbeddingIndex.cs
Expand Up @@ -42,7 +42,7 @@ public static class EmbeddingIndexExtensions
double minScore = 0.0)
where TEmbedding : unmanaged
{
Verify.NotNull(index, "Embedding index cannot be NULL");
Verify.NotNull(index);
await foreach (var match in index.GetNearestMatchesAsync(collection, embedding, 1, minScore))
{
return match;
Expand Down
Expand Up @@ -34,7 +34,7 @@ public sealed class HuggingFaceTextCompletion : ITextCompletion, IDisposable
/// <param name="httpClientHandler">Instance of <see cref="HttpClientHandler"/> to setup specific scenarios.</param>
public HuggingFaceTextCompletion(Uri endpoint, string model, HttpClientHandler httpClientHandler)
{
Verify.NotNull(endpoint, "Endpoint cannot be null.");
Verify.NotNull(endpoint);
Verify.NotEmpty(model, "Model cannot be empty.");

this._endpoint = endpoint;
Expand All @@ -53,7 +53,7 @@ public HuggingFaceTextCompletion(Uri endpoint, string model, HttpClientHandler h
/// <param name="model">Model to use for service API call.</param>
public HuggingFaceTextCompletion(Uri endpoint, string model)
{
Verify.NotNull(endpoint, "Endpoint cannot be null.");
Verify.NotNull(endpoint);
Verify.NotEmpty(model, "Model cannot be empty.");

this._endpoint = endpoint;
Expand Down
Expand Up @@ -33,7 +33,7 @@ public sealed class HuggingFaceTextEmbeddingGeneration : IEmbeddingGeneration<st
/// <param name="httpClientHandler">Instance of <see cref="HttpClientHandler"/> to setup specific scenarios.</param>
public HuggingFaceTextEmbeddingGeneration(Uri endpoint, string model, HttpClientHandler httpClientHandler)
{
Verify.NotNull(endpoint, "Endpoint cannot be null.");
Verify.NotNull(endpoint);
Verify.NotEmpty(model, "Model cannot be empty.");

this._endpoint = endpoint;
Expand All @@ -52,7 +52,7 @@ public HuggingFaceTextEmbeddingGeneration(Uri endpoint, string model, HttpClient
/// <param name="model">Model to use for service API call.</param>
public HuggingFaceTextEmbeddingGeneration(Uri endpoint, string model)
{
Verify.NotNull(endpoint, "Endpoint cannot be null.");
Verify.NotNull(endpoint);
Verify.NotEmpty(model, "Model cannot be empty.");

this._endpoint = endpoint;
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/SemanticKernel/Kernel.cs
Expand Up @@ -123,7 +123,7 @@ public ISKFunction RegisterCustomFunction(string skillName, ISKFunction customFu
{
// Future-proofing the name not to contain special chars
Verify.ValidSkillName(skillName);
Verify.NotNull(customFunction, $"The {nameof(customFunction)} parameter is not set to an instance of an object.");
Verify.NotNull(customFunction);

customFunction.SetDefaultSkillCollection(this.Skills);
this._skillCollection.AddSemanticFunction(customFunction);
Expand Down
16 changes: 8 additions & 8 deletions dotnet/src/SemanticKernel/KernelBuilder.cs
Expand Up @@ -69,7 +69,7 @@ public IKernel Build()
/// <returns>Updated kernel builder including the logger.</returns>
public KernelBuilder WithLogger(ILogger log)
{
Verify.NotNull(log, "The logger instance provided is NULL");
Verify.NotNull(log);
this._log = log;
return this;
}
Expand All @@ -81,7 +81,7 @@ public KernelBuilder WithLogger(ILogger log)
/// <returns>Updated kernel builder including the semantic text memory entity.</returns>
public KernelBuilder WithMemory(ISemanticTextMemory memory)
{
Verify.NotNull(memory, "The memory instance provided is NULL");
Verify.NotNull(memory);
this._memory = memory;
return this;
}
Expand All @@ -93,7 +93,7 @@ public KernelBuilder WithMemory(ISemanticTextMemory memory)
/// <returns>Updated kernel builder including the memory storage.</returns>
public KernelBuilder WithMemoryStorage(IMemoryStore storage)
{
Verify.NotNull(storage, "The memory instance provided is NULL");
Verify.NotNull(storage);
this._memoryStorage = storage;
return this;
}
Expand All @@ -107,8 +107,8 @@ public KernelBuilder WithMemoryStorage(IMemoryStore storage)
public KernelBuilder WithMemoryStorageAndTextEmbeddingGeneration(
IMemoryStore storage, IEmbeddingGeneration<string, float> embeddingGenerator)
{
Verify.NotNull(storage, "The memory instance provided is NULL");
Verify.NotNull(embeddingGenerator, "The embedding generator instance provided is NULL");
Verify.NotNull(storage);
Verify.NotNull(embeddingGenerator);
this._memory = new SemanticTextMemory(storage, embeddingGenerator);
return this;
}
Expand All @@ -120,7 +120,7 @@ public KernelBuilder WithMemoryStorage(IMemoryStore storage)
/// <returns>Updated kernel builder including the retry handler factory.</returns>
public KernelBuilder WithRetryHandlerFactory(IDelegatingHandlerFactory httpHandlerFactory)
{
Verify.NotNull(httpHandlerFactory, "The retry handler factory instance provided is NULL");
Verify.NotNull(httpHandlerFactory);
this._httpHandlerFactory = httpHandlerFactory;
return this;
}
Expand All @@ -132,7 +132,7 @@ public KernelBuilder WithRetryHandlerFactory(IDelegatingHandlerFactory httpHandl
/// <returns>Updated kernel builder including the given configuration.</returns>
public KernelBuilder WithConfiguration(KernelConfig config)
{
Verify.NotNull(config, "The configuration instance provided is NULL");
Verify.NotNull(config);
this._config = config;
return this;
}
Expand All @@ -144,7 +144,7 @@ public KernelBuilder WithConfiguration(KernelConfig config)
/// <returns>Updated kernel builder including the updated configuration.</returns>
public KernelBuilder Configure(Action<KernelConfig> configure)
{
Verify.NotNull(configure, "The configuration action provided is NULL");
Verify.NotNull(configure);
configure.Invoke(this._config);
return this;
}
Expand Down

0 comments on commit b38816e

Please sign in to comment.