Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cast null string parameters to varchar Fixes npgsql/EntityFramework6.Npgsql#121 #144

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 77 additions & 2 deletions EF6.PG.Tests/EntityFrameworkBasicTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ public void InsertAndSelect()
}
var someParameter = "Some";
Assert.IsTrue(context.Posts.Any(p => p.Title.StartsWith(someParameter)));
Assert.IsTrue(context.Posts.All(p => p.Title != null));
Assert.IsTrue(context.Posts.Any(p => someParameter != null));
Assert.IsTrue(context.Posts.Select(p => p.VarbitColumn == varbitVal).First());
Assert.IsTrue(context.Posts.All(p => p.VarbitColumn != null));
Assert.IsTrue(context.Posts.Any(p => varbitVal != null));
Assert.IsTrue(context.Posts.Select(p => p.VarbitColumn == "10011").First());
Assert.AreEqual(1, context.NoColumnsEntities.Count());
}
Expand Down Expand Up @@ -147,6 +151,77 @@ public void SelectWithWhere_Ef_TruncateTime()
}
}

[Test]
public void Select_Ef_Timezone()
{
var createdOnDate = new DateTimeOffset(2020, 12, 03, 22, 23, 0, TimeSpan.Zero);
using (var context = new BloggingContext(ConnectionString))
{
context.Logs.Add(new Log()
{
CreationDate = createdOnDate
});
context.SaveChanges();
}

using (var context = new BloggingContext(ConnectionString))
{
context.Database.ExecuteSqlCommand("SET TIMEZONE='UTC';");
var query = context.Logs.Select(p => NpgsqlDateTimeFunctions.Timezone("Pacific/Honolulu", p.CreationDate));
var createdOnDateInTimeZone = query.FirstOrDefault();
Assert.AreEqual(new DateTime(2020, 12, 03, 12, 23, 0), createdOnDateInTimeZone);
}
}

[Test]
public void Select_Ef_StringAgg()
{
DateTime createdOnDate = new DateTime(2014, 05, 08);
using (var context = new BloggingContext(ConnectionString))
{
var blog = new Blog()
{
Name = "Blog 1"
};
blog.Posts = new List<Post>();

blog.Posts.Add(new Post()
{
Content = "Content 1",
Rating = 1,
Title = "Title 1",
CreationDate = createdOnDate
});
blog.Posts.Add(new Post()
{
Content = "Content 2",
Rating = 2,
Title = "Title 2",
CreationDate = createdOnDate
});
blog.Posts.Add(new Post()
{
Content = "Content 3",
Rating = 3,
Title = "Title 3",
CreationDate = createdOnDate
});

context.Blogs.Add(blog);
context.SaveChanges();
}

using (var context = new BloggingContext(ConnectionString))
{
context.Database.Initialize(true);
var query = context.Posts
.GroupBy(p => p.BlogId)
.Select(g => g.Select(x => x.Title).StringAgg());
var result = query.FirstOrDefault();
Assert.AreEqual("Title 1, Title 2, Title 3", result);
}
}

[Test]
public void SelectWithLike_SpecialCharacters()
{
Expand Down Expand Up @@ -791,7 +866,7 @@ public void Test_string_null_propagation()
Console.WriteLine(query.ToString());
StringAssert.AreEqualIgnoringCase(
"SELECT CASE WHEN (COALESCE(@p__linq__0,E'default_value') IS NULL) THEN (E'')"
+ " WHEN (@p__linq__0 IS NULL) THEN (E'default_value') ELSE (@p__linq__0) END ||"
+ " WHEN (CAST (@p__linq__0 AS varchar) IS NULL) THEN (E'default_value') ELSE (@p__linq__0) END ||"
+ " E'_postfix' AS \"C1\" FROM \"dbo\".\"Blogs\" AS \"Extent1\"",
query.ToString());
}
Expand Down Expand Up @@ -819,7 +894,7 @@ public void Test_string_multiple_null_propagation()
Console.WriteLine(query.ToString());
StringAssert.AreEqualIgnoringCase(
"SELECT CASE WHEN (COALESCE(@p__linq__0,COALESCE(@p__linq__1,@p__linq__2)) IS NULL)"
+ " THEN (E'') WHEN (@p__linq__0 IS NULL) THEN (COALESCE(@p__linq__1,@p__linq__2)) ELSE"
+ " THEN (E'') WHEN (CAST (@p__linq__0 AS varchar) IS NULL) THEN (COALESCE(@p__linq__1,@p__linq__2)) ELSE"
+ " (@p__linq__0) END || E'_postfix' AS \"C1\" FROM \"dbo\".\"Blogs\" AS \"Extent1\"",
query.ToString());
}
Expand Down
18 changes: 16 additions & 2 deletions EF6.PG.Tests/Support/EntityFrameworkTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
using System.Data.Entity.Core.Metadata.Edm;
using System.Data.Entity.Core.Objects;
using System.Data.Entity.Infrastructure;
using System.Data.Entity.ModelConfiguration.Conventions;
using System.Reflection;
using NpgsqlTypes;

// ReSharper disable once CheckNamespace
Expand Down Expand Up @@ -51,6 +53,7 @@ protected void SetUp()
{
context.Blogs.RemoveRange(context.Blogs);
context.Posts.RemoveRange(context.Posts);
context.Logs.RemoveRange(context.Logs);
context.NoColumnsEntities.RemoveRange(context.NoColumnsEntities);
context.SaveChanges();
}
Expand Down Expand Up @@ -80,6 +83,12 @@ public class Post
public virtual Blog Blog { get; set; }
}

public class Log
{
public int Id { get; set; }
public DateTimeOffset CreationDate { get; set; }
}

public class ClrEnumEntity
{
public int Id { get; set; }
Expand Down Expand Up @@ -156,13 +165,14 @@ public BloggingContext(string connection)

public DbSet<Blog> Blogs { get; set; }
public DbSet<Post> Posts { get; set; }
public DbSet<Log> Logs { get; set; }
public DbSet<NoColumnsEntity> NoColumnsEntities { get; set; }
public DbSet<ClrEnumEntity> ClrEnumEntities { get; set; }
public DbSet<ClrEnumCompositeKeyEntity> ClrEnumCompositeKeyEntities { get; set; }
public DbSet<User> Users { get; set; }
public DbSet<Editor> Editors { get; set; }
public DbSet<Administrator> Administrators { get; set; }

[DbFunction("BloggingContext", "ClrStoredAddFunction")]
public static int StoredAddFunction(int val1, int val2)
{
Expand All @@ -183,14 +193,15 @@ public IQueryable<Blog> GetBlogsByName(string name)
return ((IObjectContextAdapter)this).ObjectContext.CreateQuery<Blog>(
$"[GetBlogsByName](@Name)", nameParameter);
}

private static DbCompiledModel CreateModel(NpgsqlConnection connection)
{
var dbModelBuilder = new DbModelBuilder(DbModelBuilderVersion.Latest);

// Import Sets
dbModelBuilder.Entity<Blog>();
dbModelBuilder.Entity<Post>();
dbModelBuilder.Entity<Log>();
dbModelBuilder.Entity<NoColumnsEntity>();
dbModelBuilder.Entity<ClrEnumEntity>();
dbModelBuilder.Entity<ClrEnumCompositeKeyEntity>();
Expand All @@ -202,6 +213,9 @@ private static DbCompiledModel CreateModel(NpgsqlConnection connection)
var dbModel = dbModelBuilder.Build(connection);
var edmType = PrimitiveType.GetEdmPrimitiveType(PrimitiveTypeKind.Int32);

//these parameter types need to match both the database method and the C# method for EF to link
var edmStringType = PrimitiveType.GetEdmPrimitiveType(PrimitiveTypeKind.String);

var addFunc = EdmFunction.Create(
"ClrStoredAddFunction",
"BloggingContext",
Expand Down
4 changes: 3 additions & 1 deletion EF6.PG/EF6.PG.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<Copyright>Copyright 2019 © The Npgsql Development Team</Copyright>
<Company>Npgsql</Company>
<PackageTags>npgsql postgresql postgres data database entity framework ef orm</PackageTags>
<VersionPrefix>6.4.0</VersionPrefix>
<VersionPrefix>6.4.3</VersionPrefix>
<LangVersion>latest</LangVersion>
<TargetFrameworks>net45;net461;netstandard21</TargetFrameworks>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
Expand All @@ -23,6 +23,8 @@
<RepositoryType>git</RepositoryType>
<RepositoryUrl>git://github.com/npgsql/EntityFramework6.Npgsql</RepositoryUrl>
<Deterministic>true</Deterministic>
<PackageId>SureWx.EntityFramework6.Npgsql</PackageId>
<Version>6.4.3</Version>
</PropertyGroup>
<ItemGroup>
<EmbeddedResource Include="Resources/*" />
Expand Down
21 changes: 21 additions & 0 deletions EF6.PG/NpgsqlAggregateFunctions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Data.Entity;
using System.Diagnostics.CodeAnalysis;

namespace Npgsql
{
/// <summary>
/// Use this class in LINQ queries to emit aggregate functions.
/// </summary>
[SuppressMessage("ReSharper", "UnusedParameter.Global")]
public static class NpgsqlAggregateFunctions
{
/// <summary>
/// Concatenate strings
/// </summary>
[DbFunction("NpgsqlAggregateFunctions", "StringAgg")]
public static string StringAgg<TSource>(this IEnumerable<TSource> source)
=> throw new NotSupportedException();
}
}
20 changes: 20 additions & 0 deletions EF6.PG/NpgsqlDateTimeFunctions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using System;
using System.Data.Entity;
using System.Diagnostics.CodeAnalysis;

namespace Npgsql
{
/// <summary>
/// Use this class in LINQ queries to emit timestamp manipulation SQL fragments.
/// </summary>
[SuppressMessage("ReSharper", "UnusedParameter.Global")]
public static class NpgsqlDateTimeFunctions
{
/// <summary>
/// Convert a timestamptz to a timezone
/// </summary>
[DbFunction("Npgsql", "timezone")]
public static DateTime Timezone(string zone, DateTimeOffset timestamp)
=> throw new NotSupportedException();
}
}
26 changes: 24 additions & 2 deletions EF6.PG/NpgsqlProviderManifest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,36 @@ public override string EscapeLikeArgument([NotNull] string argument)
public override bool SupportsInExpression() => true;

public override ReadOnlyCollection<EdmFunction> GetStoreFunctions()
=> new[] { typeof(NpgsqlTextFunctions).GetTypeInfo(), typeof(NpgsqlTypeFunctions) }
=> new[] { typeof(NpgsqlTextFunctions).GetTypeInfo(), typeof(NpgsqlTypeFunctions), typeof(NpgsqlDateTimeFunctions) }
.SelectMany(x => x.GetMethods(BindingFlags.Public | BindingFlags.Static))
.Select(x => new { Method = x, DbFunction = x.GetCustomAttribute<DbFunctionAttribute>() })
.Where(x => x.DbFunction != null)
.Select(x => CreateComposableEdmFunction(x.Method, x.DbFunction))
.Union(new []
{
EdmFunction.Create("StringAgg", "NpgsqlAggregateFunctions", DataSpace.SSpace, new EdmFunctionPayload
{
ParameterTypeSemantics = ParameterTypeSemantics.AllowImplicitConversion,
IsComposable = true,
IsAggregate = true,
Schema = "",
IsFromProviderManifest = true,
IsBuiltIn = true,
StoreFunctionName = "string_agg",
ReturnParameters = new[]
{
FunctionParameter.Create("ReturnType", PrimitiveType.GetEdmPrimitiveType(PrimitiveTypeKind.String), ParameterMode.ReturnValue)
},
Parameters = new[]
{
FunctionParameter.Create("input", PrimitiveType.GetEdmPrimitiveType(PrimitiveTypeKind.String).GetEdmPrimitiveType().GetCollectionType(), ParameterMode.In),
//FunctionParameter.Create("separator", PrimitiveType.GetEdmPrimitiveType(PrimitiveTypeKind.String), ParameterMode.In),
},
}, null),
})
.ToList()
.AsReadOnly();

static EdmFunction CreateComposableEdmFunction([NotNull] MethodInfo method, [NotNull] DbFunctionAttribute dbFunctionInfo)
{
if (method == null)
Expand Down
13 changes: 13 additions & 0 deletions EF6.PG/Properties/PublishProfiles/FolderProfile.pubxml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<?xml version="1.0" encoding="utf-8"?>
<!--
https://go.microsoft.com/fwlink/?LinkID=208121.
-->
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<PropertyGroup>
<Configuration>Debug</Configuration>
<Platform>Any CPU</Platform>
<PublishDir>bin\Release\net45\publish\</PublishDir>
<PublishProtocol>FileSystem</PublishProtocol>
<TargetFramework>net45</TargetFramework>
</PropertyGroup>
</Project>
44 changes: 43 additions & 1 deletion EF6.PG/SqlGenerators/SqlBaseGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,15 @@ public override VisitedExpression Visit([NotNull] DbIsOfExpression expression)
}

public override VisitedExpression Visit([NotNull] DbIsNullExpression expression)
=> OperatorExpression.Build(Operator.IsNull, _useNewPrecedences, expression.Argument.Accept(this));
{
if (expression.Argument.ExpressionKind == DbExpressionKind.ParameterReference &&
(expression.Argument.ResultType.EdmType as PrimitiveType)?.PrimitiveTypeKind == PrimitiveTypeKind.String)
{
var castedParameterExpression = new CastExpression(expression.Argument.Accept(this), "varchar");
return OperatorExpression.Build(Operator.IsNull, _useNewPrecedences, castedParameterExpression);
}
return OperatorExpression.Build(Operator.IsNull, _useNewPrecedences, expression.Argument.Accept(this));
}

// NOT EXISTS
public override VisitedExpression Visit([NotNull] DbIsEmptyExpression expression)
Expand Down Expand Up @@ -928,6 +936,40 @@ VisitedExpression VisitFunction(DbFunctionAggregate functionAggregate)
aggregate.AddArgument(aggregateArg);
return new CastExpression(aggregate, GetDbType(functionAggregate.ResultType.EdmType));
}

if (functionAggregate.Function.NamespaceName == "NpgsqlAggregateFunctions")
{
FunctionExpression aggregate;
try
{
aggregate = new FunctionExpression(functionAggregate.Function.StoreFunctionNameAttribute);
}
catch (KeyNotFoundException)
{
throw new NotSupportedException();
}
Debug.Assert(functionAggregate.Arguments.Count == 1);
VisitedExpression aggregateArg;
if (functionAggregate.Distinct)
{
aggregateArg = new LiteralExpression("DISTINCT ");
((LiteralExpression)aggregateArg).Append(functionAggregate.Arguments[0].Accept(this));
}
else
{
aggregateArg = functionAggregate.Arguments[0].Accept(this);
}
aggregate.AddArgument(aggregateArg);

if (functionAggregate.Function.Name == "StringAgg")
{
// HACK add second argument, since EF doesn't support more than one argument for aggregate functions
aggregate.AddArgument(new ConstantExpression(", ", functionAggregate.ResultType));
}

return new CastExpression(aggregate, GetDbType(functionAggregate.ResultType.EdmType));
}

throw new NotSupportedException();
}

Expand Down
7 changes: 4 additions & 3 deletions EntityFramework6.Npgsql.sln
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 15
VisualStudioVersion = 15.0.27703.2035

Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 16
VisualStudioVersion = 16.0.31105.61
MinimumVisualStudioVersion = 10.0.40219.1
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{4A5A60DD-41B6-40BF-B677-227A921ECCC8}"
ProjectSection(SolutionItems) = preProject
Expand Down