Skip to content

Commit

Permalink
React to network type changes in Npgsql
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Oct 11, 2023
1 parent bd0537a commit bf80079
Show file tree
Hide file tree
Showing 8 changed files with 885 additions and 692 deletions.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Diagnostics.CodeAnalysis;
using System.Net;
using System.Runtime.CompilerServices;
using System.Text;
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal;
Expand Down Expand Up @@ -115,6 +116,15 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
case ExpressionType.Convert
when unaryExpression.Type == typeof(ITuple) && unaryExpression.Operand.Type.IsAssignableTo(typeof(ITuple)):
return Visit(unaryExpression.Operand);

// We map both IPAddress and NpgsqlInet to PG inet, and translate many methods accepting NpgsqlInet, so ignore casts from
// IPAddress to NpgsqlInet.
// On the PostgreSQL side, cidr is also implicitly convertible to inet, and at the ADO.NET level NpgsqlCidr has a similar
// implicit conversion operator to NpgsqlInet. So remove that cast as well.
case ExpressionType.Convert
when unaryExpression.Type == typeof(NpgsqlInet)
&& (unaryExpression.Operand.Type == typeof(IPAddress) || unaryExpression.Operand.Type == typeof(NpgsqlCidr)):
return Visit(unaryExpression.Operand);
}

return base.VisitUnary(unaryExpression);
Expand Down
112 changes: 52 additions & 60 deletions src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Collections;
using System.Diagnostics.CodeAnalysis;
using System.Net;
using System.Runtime.CompilerServices;
using System.Text.RegularExpressions;
using Npgsql.EntityFrameworkCore.PostgreSQL.Internal;
Expand Down Expand Up @@ -231,39 +232,40 @@ public NpgsqlSqlExpressionFactory(SqlExpressionFactoryDependencies dependencies)
SqlExpression right,
RelationalTypeMapping? typeMapping)
{
Check.NotNull(left, nameof(left));
Check.NotNull(right, nameof(right));

if (operatorType == ExpressionType.Subtract)
switch (operatorType)
{
if (left.Type == typeof(DateTime) && right.Type == typeof(DateTime) ||
left.Type == typeof(DateTimeOffset) && right.Type == typeof(DateTimeOffset) ||
left.Type == typeof(TimeOnly) && right.Type == typeof(TimeOnly))
case ExpressionType.Subtract
when left.Type == typeof(DateTime) && right.Type == typeof(DateTime)
|| left.Type == typeof(DateTimeOffset) && right.Type == typeof(DateTimeOffset)
|| left.Type == typeof(TimeOnly) && right.Type == typeof(TimeOnly):
{
return (SqlBinaryExpression)ApplyTypeMapping(
new SqlBinaryExpression(operatorType, left, right, typeof(TimeSpan), null), typeMapping);
new SqlBinaryExpression(ExpressionType.Subtract, left, right, typeof(TimeSpan), null), typeMapping);
}

if (left.Type.FullName == "NodaTime.Instant" && right.Type.FullName == "NodaTime.Instant" ||
left.Type.FullName == "NodaTime.ZonedDateTime" && right.Type.FullName == "NodaTime.ZonedDateTime")
case ExpressionType.Subtract
when left.Type.FullName == "NodaTime.Instant" && right.Type.FullName == "NodaTime.Instant"
|| left.Type.FullName == "NodaTime.ZonedDateTime" && right.Type.FullName == "NodaTime.ZonedDateTime":
{
_nodaTimeDurationType ??= left.Type.Assembly.GetType("NodaTime.Duration");
return (SqlBinaryExpression)ApplyTypeMapping(
new SqlBinaryExpression(operatorType, left, right, _nodaTimeDurationType!, null), typeMapping);
new SqlBinaryExpression(ExpressionType.Subtract, left, right, _nodaTimeDurationType!, null), typeMapping);
}

if (left.Type.FullName == "NodaTime.LocalDateTime" && right.Type.FullName == "NodaTime.LocalDateTime" ||
left.Type.FullName == "NodaTime.LocalTime" && right.Type.FullName == "NodaTime.LocalTime")
case ExpressionType.Subtract
when left.Type.FullName == "NodaTime.LocalDateTime" && right.Type.FullName == "NodaTime.LocalDateTime"
|| left.Type.FullName == "NodaTime.LocalTime" && right.Type.FullName == "NodaTime.LocalTime":
{
_nodaTimePeriodType ??= left.Type.Assembly.GetType("NodaTime.Period");
return (SqlBinaryExpression)ApplyTypeMapping(
new SqlBinaryExpression(operatorType, left, right, _nodaTimePeriodType!, null), typeMapping);
new SqlBinaryExpression(ExpressionType.Subtract, left, right, _nodaTimePeriodType!, null), typeMapping);
}

if (left.Type.FullName == "NodaTime.LocalDate" && right.Type.FullName == "NodaTime.LocalDate")
case ExpressionType.Subtract
when left.Type.FullName == "NodaTime.LocalDate" && right.Type.FullName == "NodaTime.LocalDate":
{
return (SqlBinaryExpression)ApplyTypeMapping(
new SqlBinaryExpression(operatorType, left, right, typeof(int), null), typeMapping);
new SqlBinaryExpression(ExpressionType.Subtract, left, right, typeof(int), null), typeMapping);
}
}

Expand Down Expand Up @@ -422,66 +424,56 @@ public virtual PostgresBinaryExpression Overlaps(SqlExpression left, SqlExpressi

private SqlBinaryExpression ApplyTypeMappingOnSqlBinary(SqlBinaryExpression binary, RelationalTypeMapping? typeMapping)
{
var (left, right) = (binary.Left, binary.Right);

// The default SqlExpressionFactory behavior is to assume that the two added operands have the same type,
// and so to infer one side's mapping from the other if needed. Here we take care of some heterogeneous
// operand cases where this doesn't work:
// * Period + Period (???)

if (binary.OperatorType is ExpressionType.Add or ExpressionType.Subtract)
switch (binary.OperatorType)
{
var (left, right) = (binary.Left, binary.Right);
var leftType = left.Type.UnwrapNullableType();
var rightType = right.Type.UnwrapNullableType();

// Note that we apply the given type mapping from above to the left operand (which has the same CLR type as
// the binary expression's)

// DateTime + TimeSpan => DateTime
// DateTimeOffset + TimeSpan => DateTimeOffset
// TimeOnly + TimeSpan => TimeOnly
if (rightType == typeof(TimeSpan)
&& (
leftType == typeof(DateTime)
|| leftType == typeof(DateTimeOffset)
|| leftType == typeof(TimeOnly)
)
|| rightType.FullName == "NodaTime.Period"
&& leftType.FullName is "NodaTime.LocalDateTime" or "NodaTime.LocalDate" or "NodaTime.LocalTime"
|| rightType.FullName == "NodaTime.Duration"
&& leftType.FullName is "NodaTime.Instant" or "NodaTime.ZonedDateTime")
case ExpressionType.Add or ExpressionType.Subtract
when right.Type == typeof(TimeSpan)
&& (left.Type == typeof(DateTime) || left.Type == typeof(DateTimeOffset) || left.Type == typeof(TimeOnly))
|| right.Type.FullName == "NodaTime.Period"
&& left.Type.FullName is "NodaTime.LocalDateTime" or "NodaTime.LocalDate" or "NodaTime.LocalTime"
|| right.Type.FullName == "NodaTime.Duration"
&& left.Type.FullName is "NodaTime.Instant" or "NodaTime.ZonedDateTime":
{
var newLeft = ApplyTypeMapping(left, typeMapping);
var newRight = ApplyDefaultTypeMapping(right);
return new SqlBinaryExpression(binary.OperatorType, newLeft, newRight, binary.Type, newLeft.TypeMapping);
}

if (binary.OperatorType == ExpressionType.Subtract)
// DateTime - DateTime => TimeSpan
// DateTimeOffset - DateTimeOffset => TimeSpan
// DateOnly - DateOnly => TimeSpan
// TimeOnly - TimeOnly => TimeSpan
// Instant - Instant => Duration
// LocalDateTime - LocalDateTime => int (days)
case ExpressionType.Subtract
when left.Type == typeof(DateTime) && right.Type == typeof(DateTime)
|| left.Type == typeof(DateTimeOffset) && right.Type == typeof(DateTimeOffset)
|| left.Type == typeof(DateOnly) && right.Type == typeof(DateOnly)
|| left.Type == typeof(TimeOnly) && right.Type == typeof(TimeOnly)
|| left.Type.FullName == "NodaTime.Instant" && right.Type.FullName == "NodaTime.Instant"
|| left.Type.FullName == "NodaTime.LocalDateTime" && right.Type.FullName == "NodaTime.LocalDateTime"
|| left.Type.FullName == "NodaTime.ZonedDateTime" && right.Type.FullName == "NodaTime.ZonedDateTime"
|| left.Type.FullName == "NodaTime.LocalDate" && right.Type.FullName == "NodaTime.LocalDate"
|| left.Type.FullName == "NodaTime.LocalTime" && right.Type.FullName == "NodaTime.LocalTime":
{
// DateTime - DateTime => TimeSpan
// DateTimeOffset - DateTimeOffset => TimeSpan
// DateOnly - DateOnly => TimeSpan
// TimeOnly - TimeOnly => TimeSpan
// Instant - Instant => Duration
// LocalDateTime - LocalDateTime => int (days)
if (leftType == typeof(DateTime) && rightType == typeof(DateTime)
|| leftType == typeof(DateTimeOffset) && rightType == typeof(DateTimeOffset)
|| leftType == typeof(DateOnly) && rightType == typeof(DateOnly)
|| leftType == typeof(TimeOnly) && rightType == typeof(TimeOnly)
|| leftType.FullName == "NodaTime.Instant" && rightType.FullName == "NodaTime.Instant"
|| leftType.FullName == "NodaTime.LocalDateTime" && rightType.FullName == "NodaTime.LocalDateTime"
|| leftType.FullName == "NodaTime.ZonedDateTime" && rightType.FullName == "NodaTime.ZonedDateTime"
|| leftType.FullName == "NodaTime.LocalDate" && rightType.FullName == "NodaTime.LocalDate"
|| leftType.FullName == "NodaTime.LocalTime" && rightType.FullName == "NodaTime.LocalTime")
{
var inferredTypeMapping = ExpressionExtensions.InferTypeMapping(left, right);

return new SqlBinaryExpression(
ExpressionType.Subtract,
ApplyTypeMapping(left, inferredTypeMapping),
ApplyTypeMapping(right, inferredTypeMapping),
binary.Type,
typeMapping ?? _typeMappingSource.FindMapping(binary.Type, "interval"));
}
var inferredTypeMapping = ExpressionExtensions.InferTypeMapping(left, right);

return new SqlBinaryExpression(
ExpressionType.Subtract,
ApplyTypeMapping(left, inferredTypeMapping),
ApplyTypeMapping(right, inferredTypeMapping),
binary.Type,
typeMapping ?? _typeMappingSource.FindMapping(binary.Type, "interval"));
}
}

Expand Down
39 changes: 26 additions & 13 deletions src/EFCore.PG/Storage/Internal/Mapping/NpgsqlNetworkTypeMappings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,14 @@ public class NpgsqlInetTypeMapping : NpgsqlTypeMapping
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public NpgsqlInetTypeMapping() : base("inet", typeof(IPAddress), NpgsqlDbType.Inet) {}
public NpgsqlInetTypeMapping(Type clrType)
: base("inet", clrType, NpgsqlDbType.Inet)
{
if (clrType != typeof(IPAddress) && clrType != typeof(NpgsqlInet))
{
throw new ArgumentException($"Only {nameof(IPAddress)} and {nameof(NpgsqlInet)} are supported", nameof(clrType));
}
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -154,7 +161,7 @@ protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters p
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override string GenerateNonNullSqlLiteral(object value)
=> $"INET '{(IPAddress)value}'";
=> $"INET '{value}'";

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -163,9 +170,15 @@ protected override string GenerateNonNullSqlLiteral(object value)
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public override Expression GenerateCodeLiteral(object value)
=> Expression.Call(ParseMethod, Expression.Constant(((IPAddress)value).ToString()));

private static readonly MethodInfo ParseMethod = typeof(IPAddress).GetMethod("Parse", new[] { typeof(string) })!;
=> value switch
{
IPAddress ip => Expression.Call(IPAddressParseMethod, Expression.Constant(ip.ToString())),
NpgsqlInet ip => Expression.New(NpgsqlInetConstructor, Expression.Constant(ip.ToString())),
_ => throw new UnreachableException()
};

private static readonly MethodInfo IPAddressParseMethod = typeof(IPAddress).GetMethod("Parse", new[] { typeof(string) })!;
private static readonly ConstructorInfo NpgsqlInetConstructor = typeof(NpgsqlInet).GetConstructor(new[] { typeof(string) })!;
}

/// <summary>
Expand All @@ -182,7 +195,7 @@ public class NpgsqlCidrTypeMapping : NpgsqlTypeMapping
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public NpgsqlCidrTypeMapping() : base("cidr", typeof((IPAddress, int)), NpgsqlDbType.Cidr) {}
public NpgsqlCidrTypeMapping() : base("cidr", typeof(NpgsqlCidr), NpgsqlDbType.Cidr) {}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down Expand Up @@ -210,8 +223,8 @@ protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters p
/// </summary>
protected override string GenerateNonNullSqlLiteral(object value)
{
var cidr = ((IPAddress Address, int Subnet))value;
return $"CIDR '{cidr.Address}/{cidr.Subnet}'";
var cidr = (NpgsqlCidr)value;
return $"CIDR '{cidr.Address}/{cidr.Netmask}'";
}

/// <summary>
Expand All @@ -222,15 +235,15 @@ protected override string GenerateNonNullSqlLiteral(object value)
/// </summary>
public override Expression GenerateCodeLiteral(object value)
{
var cidr = ((IPAddress Address, int Subnet))value;
var cidr = (NpgsqlCidr)value;
return Expression.New(
Constructor,
NpgsqlCidrConstructor,
Expression.Call(ParseMethod, Expression.Constant(cidr.Address.ToString())),
Expression.Constant(cidr.Subnet));
Expression.Constant(cidr.Netmask));
}

private static readonly MethodInfo ParseMethod = typeof(IPAddress).GetMethod("Parse", new[] { typeof(string) })!;

private static readonly ConstructorInfo Constructor =
typeof((IPAddress, int)).GetConstructor(new[] { typeof(IPAddress), typeof(int) })!;
private static readonly ConstructorInfo NpgsqlCidrConstructor =
typeof(NpgsqlCidr).GetConstructor(new[] { typeof(IPAddress), typeof(byte) })!;
}
10 changes: 6 additions & 4 deletions src/EFCore.PG/Storage/Internal/NpgsqlTypeMappingSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ static NpgsqlTypeMappingSource()
// Network address types
private readonly NpgsqlMacaddrTypeMapping _macaddr = new();
private readonly NpgsqlMacaddr8TypeMapping _macaddr8 = new();
private readonly NpgsqlInetTypeMapping _inet = new();
private readonly NpgsqlInetTypeMapping _inetAsIPAddress = new(typeof(IPAddress));
private readonly NpgsqlInetTypeMapping _inetAsNpgsqlInet = new(typeof(NpgsqlInet));
private readonly NpgsqlCidrTypeMapping _cidr = new();

// Built-in geometric types
Expand Down Expand Up @@ -255,7 +256,7 @@ static NpgsqlTypeMappingSource()

{ "macaddr", new[] { _macaddr } },
{ "macaddr8", new[] { _macaddr8 } },
{ "inet", new[] { _inet } },
{ "inet", new RelationalTypeMapping[] { _inetAsIPAddress, _inetAsNpgsqlInet } },
{ "cidr", new[] { _cidr } },

{ "point", new[] { _point } },
Expand Down Expand Up @@ -327,8 +328,9 @@ static NpgsqlTypeMappingSource()
{ typeof(DateTimeOffset), _timestamptzDto },

{ typeof(PhysicalAddress), _macaddr },
{ typeof(IPAddress), _inet },
{ typeof((IPAddress, int)), _cidr },
{ typeof(IPAddress), _inetAsIPAddress },
{ typeof(NpgsqlInet), _inetAsNpgsqlInet },
{ typeof(NpgsqlCidr), _cidr },

{ typeof(BitArray), _varbit },
{ typeof(ImmutableDictionary<string, string>), _immutableHstore },
Expand Down
Loading

0 comments on commit bf80079

Please sign in to comment.