Skip to content

Commit

Permalink
Support writing IList<T> impls again (#5558)
Browse files Browse the repository at this point in the history
(cherry picked from commit 786ce89)
  • Loading branch information
NinoFloris committed Feb 9, 2024
1 parent ab7bf09 commit 9e24837
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 36 deletions.
3 changes: 0 additions & 3 deletions src/Npgsql/Internal/Converters/MultirangeConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ sealed class MultirangeConverter<T, TRange> : PgStreamingConverter<T>
readonly PgConverter<TRange> _rangeConverter;
readonly BufferRequirements _rangeRequirements;

static MultirangeConverter()
=> Debug.Assert(typeof(T).IsArray || typeof(T).IsGenericType && typeof(T).GetGenericTypeDefinition() == typeof(List<>));

public MultirangeConverter(PgConverter<TRange> rangeConverter)
{
if (!rangeConverter.CanConvert(DataFormat.Binary, out var bufferRequirements))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,9 @@ sealed class ArrayResolver : Resolver, IPgTypeInfoResolver
var info = Mappings.Find(type, dataTypeName, options);

Type? elementType = null;
if (info is null && dataTypeName is not null &&
(type is null || type == typeof(object) || TypeInfoMappingCollection.IsArrayLikeType(type, out elementType))
&& options.DatabaseInfo.GetPostgresType(dataTypeName) is PostgresArrayType { Element: var pgElementType })
if (info is null && dataTypeName is not null
&& options.DatabaseInfo.GetPostgresType(dataTypeName) is PostgresArrayType { Element: var pgElementType }
&& (type is null || type == typeof(object) || TypeInfoMappingCollection.IsArrayLikeType(type, out elementType)))
{
info = GetEnumArrayTypeInfo(elementType, pgElementType, type, dataTypeName.GetValueOrDefault(), options) ??
GetObjectArrayTypeInfo(elementType, pgElementType, type, dataTypeName.GetValueOrDefault(), options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ sealed class UnsupportedTypeInfoResolver<TBuilder> : IPgTypeInfoResolver
throw new NotSupportedException("Writing is not supported for jagged collections, use a multidimensional array instead.");

if (typeof(IEnumerable).IsAssignableFrom(type) && !typeof(IList).IsAssignableFrom(type) && type != typeof(string) && (dataTypeName is null || dataTypeName.Value.IsArray))
throw new NotSupportedException("Writing is not supported for IEnumerable parameters, use an array or List instead.");
throw new NotSupportedException("Writing is not supported for IEnumerable parameters, use an array or some implementation of IList<T> instead.");
}

return null;
Expand Down
73 changes: 46 additions & 27 deletions src/Npgsql/Internal/TypeInfoMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -272,11 +272,12 @@ static TypeInfoFactory CreateComposedFactory(Type mappingType, TypeInfoMapping i
};

Func<Type?, bool> GetArrayTypeMatchPredicate(Func<Type?, bool> elementTypeMatchPredicate)
=> type => type is null ? elementTypeMatchPredicate(null) : type.IsArray && elementTypeMatchPredicate.Invoke(type.GetElementType()!);
Func<Type?, bool> GetListTypeMatchPredicate(Func<Type?, bool> elementTypeMatchPredicate)
=> type => type is null ? elementTypeMatchPredicate(null) : type.IsConstructedGenericType && type.GetGenericTypeDefinition() is { } def
&& (def == typeof(List<>) || def == typeof(IList<>))
&& elementTypeMatchPredicate(type.GetGenericArguments()[0]);
=> type => type is null ? elementTypeMatchPredicate(null) : type.IsArray && elementTypeMatchPredicate(type.GetElementType()!);
Func<Type?, bool> GetListTypeMatchPredicate<TElement>(Func<Type?, bool> elementTypeMatchPredicate)
=> type => type is null ? elementTypeMatchPredicate(null)
// We anti-constrain on IsArray to avoid matching byte/sbyte, short/ushort int/uint
// with the list mapping of the earlier type when an exact match is probably available.
: !type.IsArray && typeof(IList<TElement>).IsAssignableFrom(type) && elementTypeMatchPredicate(typeof(TElement));

public void AddType<T>(string dataTypeName, TypeInfoFactory createInfo, bool isDefault = false) where T : class
=> AddType<T>(dataTypeName, createInfo, GetDefaultConfigure(isDefault));
Expand Down Expand Up @@ -330,7 +331,7 @@ static TypeInfoFactory CreateComposedFactory(Type mappingType, TypeInfoMapping i
{
// Always use a predicate to match all dimensions.
var arrayTypeMatchPredicate = GetArrayTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement)));
var listTypeMatchPredicate = GetListTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement)));
var listTypeMatchPredicate = GetListTypeMatchPredicate<TElement>(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement)));

var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName);

Expand Down Expand Up @@ -370,7 +371,7 @@ void AddArrayType(TypeInfoMapping elementMapping, Type type, Func<TypeInfoMappin
{
// Always use a predicate to match all dimensions.
var arrayTypeMatchPredicate = GetArrayTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement)));
var listTypeMatchPredicate = GetListTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement)));
var listTypeMatchPredicate = GetListTypeMatchPredicate<TElement>(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement)));

var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName);

Expand Down Expand Up @@ -427,9 +428,9 @@ void AddResolverArrayType(TypeInfoMapping elementMapping, Type type, Func<TypeIn
{
MatchRequirement = mapping.MatchRequirement,
TypeMatchPredicate = mapping.TypeMatchPredicate is not null
? type => type is null
? matchType => matchType is null
? mapping.TypeMatchPredicate(null)
: Nullable.GetUnderlyingType(type) is { } underlying && mapping.TypeMatchPredicate(underlying)
: matchType == nullableType && mapping.TypeMatchPredicate(type)
: null
});
}
Expand All @@ -448,10 +449,10 @@ void AddResolverArrayType(TypeInfoMapping elementMapping, Type type, Func<TypeIn
// Always use a predicate to match all dimensions.
var arrayTypeMatchPredicate = GetArrayTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement)));
var nullableArrayTypeMatchPredicate = GetArrayTypeMatchPredicate(nullableElementMapping.TypeMatchPredicate ?? (static type =>
type is null || (Nullable.GetUnderlyingType(type) is { } underlying && underlying == typeof(TElement))));
var listTypeMatchPredicate = GetListTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement)));
var nullableListTypeMatchPredicate = GetListTypeMatchPredicate(nullableElementMapping.TypeMatchPredicate ?? (static type =>
type is null || (Nullable.GetUnderlyingType(type) is { } underlying && underlying == typeof(TElement))));
type is null || type == typeof(TElement?)));
var listTypeMatchPredicate = GetListTypeMatchPredicate<TElement>(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement)));
var nullableListTypeMatchPredicate = GetListTypeMatchPredicate<TElement?>(nullableElementMapping.TypeMatchPredicate ?? (static type =>
type is null || type == typeof(TElement?)));

var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName);

Expand Down Expand Up @@ -544,9 +545,9 @@ PgTypeInfo CreateComposedPerInstance(PgTypeInfo innerTypeInfo, PgTypeInfo nullab
{
MatchRequirement = mapping.MatchRequirement,
TypeMatchPredicate = mapping.TypeMatchPredicate is not null
? type => type is null
? matchType => matchType is null
? mapping.TypeMatchPredicate(null)
: Nullable.GetUnderlyingType(type) is { } underlying && mapping.TypeMatchPredicate(underlying)
: matchType == nullableType && mapping.TypeMatchPredicate(type)
: null
});
}
Expand All @@ -565,10 +566,10 @@ PgTypeInfo CreateComposedPerInstance(PgTypeInfo innerTypeInfo, PgTypeInfo nullab
// Always use a predicate to match all dimensions.
var arrayTypeMatchPredicate = GetArrayTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement)));
var nullableArrayTypeMatchPredicate = GetArrayTypeMatchPredicate(nullableElementMapping.TypeMatchPredicate ?? (static type =>
type is null || (Nullable.GetUnderlyingType(type) is { } underlying && underlying == typeof(TElement))));
var listTypeMatchPredicate = GetListTypeMatchPredicate(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement)));
var nullableListTypeMatchPredicate = GetListTypeMatchPredicate(nullableElementMapping.TypeMatchPredicate ?? (static type =>
type is null || (Nullable.GetUnderlyingType(type) is { } underlying && underlying == typeof(TElement))));
type is null || type == typeof(TElement?)));
var listTypeMatchPredicate = GetListTypeMatchPredicate<TElement>(elementMapping.TypeMatchPredicate ?? (static type => type is null || type == typeof(TElement)));
var nullableListTypeMatchPredicate = GetListTypeMatchPredicate<TElement?>(nullableElementMapping.TypeMatchPredicate ?? (static type =>
type is null || type == typeof(TElement?)));

var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName);

Expand Down Expand Up @@ -652,18 +653,36 @@ void AddPolymorphicResolverArrayType(TypeInfoMapping elementMapping, Type type,
}

/// Returns whether type matches any of the types we register pg arrays as.
[UnconditionalSuppressMessage("Trimming", "IL2070",
Justification = "Checking for IList<T> implementing types requires interface list enumeration which isn't compatible with trimming. " +
"However as long as a concrete IList<T> is rooted somewhere in the app, for instance through an `AddArrayType<T>(...)` mapping, every implementation must keep it.")]
// We care about IList<T> implementations if the instantiation is actually rooted by us through an Array mapping.
// Dynamic resolvers are a notable counterexample, but they are all correctly marked with RequiresUnreferencedCode.
public static bool IsArrayLikeType(Type type, [NotNullWhen(true)] out Type? elementType)
{
elementType = type switch
if (type.GetElementType() is { } t)
{
{ IsArray: true } => type.GetElementType(),
{ IsConstructedGenericType: true } when type.GetGenericTypeDefinition() is { } def &&
(def == typeof(List<>) || def == typeof(IList<>))
=> type.GetGenericArguments()[0],
_ => null
};
elementType = t;
return true;
}

if (type.IsConstructedGenericType && type.GetGenericTypeDefinition() is var def && (def == typeof(List<>) || def == typeof(IList<>)))
{
elementType = type.GetGenericArguments()[0];
return true;
}

return elementType is not null;
foreach (var inf in type.GetInterfaces())
{
if (inf.IsConstructedGenericType && inf.GetGenericTypeDefinition() == typeof(IList<>))
{
elementType = inf.GetGenericArguments()[0];
return true;
}
}

elementType = null;
return false;
}

static string GetArrayDataTypeName(string dataTypeName)
Expand Down
23 changes: 21 additions & 2 deletions test/Npgsql.Tests/Types/ArrayTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Data;
Expand Down Expand Up @@ -155,6 +156,18 @@ public async Task Generic_List()
=> await AssertType(
new List<int> { 1, 2, 3 }, "{1,2,3}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array, isDefaultForReading: false);

[Test]
public async Task Write_IList_implementation()
=> await AssertTypeWrite(
ImmutableArray.Create(1, 2, 3), "{1,2,3}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array);

[Test]
public void Read_IList_implementation_throws()
{
Assert.ThrowsAsync<InvalidCastException>(() =>
AssertTypeRead("{1,2,3}", "integer[]", ImmutableArray.Create(1, 2, 3), isDefault: false));
}

[Test]
public async Task Generic_IList()
{
Expand Down Expand Up @@ -283,8 +296,14 @@ public async Task Writing_IEnumerable_is_not_supported()
{
await using var conn = await OpenConnectionAsync();
await using var cmd = new NpgsqlCommand("SELECT @p1", conn);
cmd.Parameters.AddWithValue("p1", Enumerable.Range(1, 3));
Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf<InvalidCastException>().With.Property("InnerException").Message.Contains("array or List"));
cmd.Parameters.AddWithValue("p1", new EnumerableOnly<int>());
Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf<InvalidCastException>().With.Property("InnerException").Message.Contains("array or some implementation of IList<T>"));
}

class EnumerableOnly<T> : IEnumerable<T>
{
public IEnumerator<T> GetEnumerator() => throw new NotImplementedException();
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}

[Test, IssueLink("https://github.com/npgsql/npgsql/issues/960")]
Expand Down

0 comments on commit 9e24837

Please sign in to comment.