From 612f39d7a4c61f02f65af24f0acf24bb13a23da0 Mon Sep 17 00:00:00 2001 From: Ricardo Peres Date: Thu, 16 Apr 2015 09:10:10 +0100 Subject: [PATCH 1/2] NH-3470 --- .../Linq/QueryReadOnlyTests.cs | 47 ++++++++++++++ src/NHibernate.Test/NHibernate.Test.csproj | 1 + .../GroupBy/AggregatingGroupByRewriter.cs | 1 + src/NHibernate/Linq/LinqExtensionMethods.cs | 28 +++++--- src/NHibernate/Linq/NhRelinqQueryParser.cs | 64 ++++++++++++++++++- .../QueryReferenceExpressionFlattener.cs | 1 + .../Linq/ReWriters/ResultOperatorRewriter.cs | 1 + .../Linq/Visitors/QueryModelVisitor.cs | 1 + .../ProcessAsReadOnly.cs | 15 +++++ src/NHibernate/NHibernate.csproj | 1 + 10 files changed, 147 insertions(+), 13 deletions(-) create mode 100644 src/NHibernate.Test/Linq/QueryReadOnlyTests.cs create mode 100644 src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAsReadOnly.cs diff --git a/src/NHibernate.Test/Linq/QueryReadOnlyTests.cs b/src/NHibernate.Test/Linq/QueryReadOnlyTests.cs new file mode 100644 index 00000000000..1cb51c2aeb1 --- /dev/null +++ b/src/NHibernate.Test/Linq/QueryReadOnlyTests.cs @@ -0,0 +1,47 @@ +using System.Linq; +using NHibernate.Cfg; +using NHibernate.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.Linq +{ + public class QueryReadOnlyTests : LinqTestCase + { + protected override void Configure(Configuration configuration) + { + base.Configure(configuration); + } + + [Test] + public void CanSetReadOnlyOnLinqQueries() + { + var result = (from e in db.Customers + where e.CompanyName == "Bon app'" + select e).AsReadOnly().ToList(); + + Assert.That(result.All(x => this.session.IsReadOnly(x)), Is.True); + } + + + [Test] + public void CanSetReadOnlyOnLinqPagingQuery() + { + var result = (from e in db.Customers + select e).Skip(1).Take(1).AsReadOnly().ToList(); + + Assert.That(result.All(x => this.session.IsReadOnly(x)), Is.True); + } + + + [Test] + public void CanSetReadOnlyBeforeSkipOnLinqOrderedPageQuery() + { + var result = (from e in db.Customers + orderby e.CompanyName + select e) + .AsReadOnly().Skip(5).Take(5).ToList(); + + Assert.That(result.All(x => this.session.IsReadOnly(x)), Is.True); + } + } +} \ No newline at end of file diff --git a/src/NHibernate.Test/NHibernate.Test.csproj b/src/NHibernate.Test/NHibernate.Test.csproj index e5aa8a0e3c6..de8fc95e316 100644 --- a/src/NHibernate.Test/NHibernate.Test.csproj +++ b/src/NHibernate.Test/NHibernate.Test.csproj @@ -539,6 +539,7 @@ + diff --git a/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs b/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs index 11bf22bfb0d..53b1a3306c7 100644 --- a/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/AggregatingGroupByRewriter.cs @@ -39,6 +39,7 @@ public static class AggregatingGroupByRewriter typeof (AnyResultOperator), typeof (AllResultOperator), typeof (TimeoutResultOperator), + typeof (AsReadOnlyResultOperator), typeof (CacheableResultOperator) }; diff --git a/src/NHibernate/Linq/LinqExtensionMethods.cs b/src/NHibernate/Linq/LinqExtensionMethods.cs index e9ee33324f0..09cc9bdc4e1 100755 --- a/src/NHibernate/Linq/LinqExtensionMethods.cs +++ b/src/NHibernate/Linq/LinqExtensionMethods.cs @@ -22,7 +22,7 @@ public static IQueryable Query(this IStatelessSession session) public static IQueryable Cacheable(this IQueryable query) { - var method = ReflectionHelper.GetMethodDefinition(() => Cacheable(null)).MakeGenericMethod(typeof (T)); + var method = ReflectionHelper.GetMethodDefinition(() => Cacheable(null)).MakeGenericMethod(typeof(T)); var callExpression = Expression.Call(method, query.Expression); @@ -31,7 +31,7 @@ public static IQueryable Cacheable(this IQueryable query) public static IQueryable CacheMode(this IQueryable query, CacheMode cacheMode) { - var method = ReflectionHelper.GetMethodDefinition(() => CacheMode(null, NHibernate.CacheMode.Normal)).MakeGenericMethod(typeof (T)); + var method = ReflectionHelper.GetMethodDefinition(() => CacheMode(null, NHibernate.CacheMode.Normal)).MakeGenericMethod(typeof(T)); var callExpression = Expression.Call(method, query.Expression, Expression.Constant(cacheMode)); @@ -40,13 +40,21 @@ public static IQueryable CacheMode(this IQueryable query, CacheMode cac public static IQueryable CacheRegion(this IQueryable query, string region) { - var method = ReflectionHelper.GetMethodDefinition(() => CacheRegion(null, null)).MakeGenericMethod(typeof (T)); + var method = ReflectionHelper.GetMethodDefinition(() => CacheRegion(null, null)).MakeGenericMethod(typeof(T)); var callExpression = Expression.Call(method, query.Expression, Expression.Constant(region)); return new NhQueryable(query.Provider, callExpression); } + public static IQueryable AsReadOnly(this IQueryable query) + { + var method = ReflectionHelper.GetMethodDefinition(() => AsReadOnly(null)).MakeGenericMethod(typeof(T)); + + var callExpression = Expression.Call(method, query.Expression); + + return new NhQueryable(query.Provider, callExpression); + } public static IQueryable Timeout(this IQueryable query, int timeout) { @@ -63,9 +71,9 @@ public static IEnumerable ToFuture(this IQueryable query) if (nhQueryable == null) throw new NotSupportedException("Query needs to be of type QueryableBase"); - var provider = (INhQueryProvider) nhQueryable.Provider; + var provider = (INhQueryProvider)nhQueryable.Provider; var future = provider.ExecuteFuture(nhQueryable.Expression); - return (IEnumerable) future; + return (IEnumerable)future; } public static IFutureValue ToFutureValue(this IQueryable query) @@ -74,14 +82,14 @@ public static IFutureValue ToFutureValue(this IQueryable query) if (nhQueryable == null) throw new NotSupportedException("Query needs to be of type QueryableBase"); - var provider = (INhQueryProvider) nhQueryable.Provider; + var provider = (INhQueryProvider)nhQueryable.Provider; var future = provider.ExecuteFuture(nhQueryable.Expression); if (future is IEnumerable) { - return new FutureValue(() => ((IEnumerable) future)); + return new FutureValue(() => ((IEnumerable)future)); } - return (IFutureValue) future; + return (IFutureValue)future; } public static IFutureValue ToFutureValue(this IQueryable query, Expression, TResult>> selector) @@ -90,13 +98,13 @@ public static IFutureValue ToFutureValue(this IQueryable if (nhQueryable == null) throw new NotSupportedException("Query needs to be of type QueryableBase"); - var provider = (INhQueryProvider) query.Provider; + var provider = (INhQueryProvider)query.Provider; var expression = ReplacingExpressionTreeVisitor.Replace(selector.Parameters.Single(), query.Expression, selector.Body); - return (IFutureValue) provider.ExecuteFuture(expression); + return (IFutureValue)provider.ExecuteFuture(expression); } } } diff --git a/src/NHibernate/Linq/NhRelinqQueryParser.cs b/src/NHibernate/Linq/NhRelinqQueryParser.cs index 86c1410099e..a7f10082b43 100644 --- a/src/NHibernate/Linq/NhRelinqQueryParser.cs +++ b/src/NHibernate/Linq/NhRelinqQueryParser.cs @@ -75,7 +75,14 @@ public NHibernateNodeTypeProvider() new[] { ReflectionHelper.GetMethodDefinition(() => LinqExtensionMethods.Timeout(null, 0)), - }, typeof (TimeoutExpressionNode) + }, typeof(TimeoutExpressionNode) + ); + + methodInfoRegistry.Register( + new[] + { + ReflectionHelper.GetMethodDefinition(() => LinqExtensionMethods.AsReadOnly(null)), + }, typeof(AsReadOnlyExpressionNode) ); var nodeTypeProvider = ExpressionTreeParser.CreateDefaultNodeTypeProvider(); @@ -100,7 +107,8 @@ public System.Type GetNodeType(MethodInfo method) public class AsQueryableExpressionNode : MethodCallExpressionNodeBase { - public AsQueryableExpressionNode(MethodCallExpressionParseInfo parseInfo) : base(parseInfo) + public AsQueryableExpressionNode(MethodCallExpressionParseInfo parseInfo) + : base(parseInfo) { } @@ -120,7 +128,8 @@ public class CacheableExpressionNode : ResultOperatorExpressionNodeBase private readonly MethodCallExpressionParseInfo _parseInfo; private readonly ConstantExpression _data; - public CacheableExpressionNode(MethodCallExpressionParseInfo parseInfo, ConstantExpression data) : base(parseInfo, null, null) + public CacheableExpressionNode(MethodCallExpressionParseInfo parseInfo, ConstantExpression data) + : base(parseInfo, null, null) { _parseInfo = parseInfo; _data = data; @@ -168,6 +177,55 @@ public override void TransformExpressions(Func transform } } + internal class AsReadOnlyExpressionNode : ResultOperatorExpressionNodeBase + { + private readonly MethodCallExpressionParseInfo _parseInfo; + + public AsReadOnlyExpressionNode(MethodCallExpressionParseInfo parseInfo) + : base(parseInfo, null, null) + { + _parseInfo = parseInfo; + } + + public override Expression Resolve(ParameterExpression inputParameter, Expression expressionToBeResolved, ClauseGenerationContext clauseGenerationContext) + { + return Source.Resolve(inputParameter, expressionToBeResolved, clauseGenerationContext); + } + + protected override ResultOperatorBase CreateResultOperator(ClauseGenerationContext clauseGenerationContext) + { + return new AsReadOnlyResultOperator(_parseInfo); + } + } + + internal class AsReadOnlyResultOperator : ResultOperatorBase + { + public MethodCallExpressionParseInfo ParseInfo { get; private set; } + + public AsReadOnlyResultOperator(MethodCallExpressionParseInfo parseInfo) + { + ParseInfo = parseInfo; + } + + public override IStreamedData ExecuteInMemory(IStreamedData input) + { + throw new NotImplementedException(); + } + + public override IStreamedDataInfo GetOutputDataInfo(IStreamedDataInfo inputInfo) + { + return inputInfo; + } + + public override ResultOperatorBase Clone(CloneContext cloneContext) + { + throw new NotImplementedException(); + } + + public override void TransformExpressions(Func transformation) + { + } + } internal class TimeoutExpressionNode : ResultOperatorExpressionNodeBase { diff --git a/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs b/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs index 9ec70a5fabf..3968c6de5f9 100644 --- a/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs +++ b/src/NHibernate/Linq/ReWriters/QueryReferenceExpressionFlattener.cs @@ -16,6 +16,7 @@ public class QueryReferenceExpressionFlattener : ExpressionTreeVisitor { typeof (CacheableResultOperator), typeof (TimeoutResultOperator), + typeof (AsReadOnlyResultOperator), }; private QueryReferenceExpressionFlattener(QueryModel model) diff --git a/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs b/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs index fe1e50f7c40..493c5916fdc 100644 --- a/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs +++ b/src/NHibernate/Linq/ReWriters/ResultOperatorRewriter.cs @@ -67,6 +67,7 @@ private class ResultOperatorExpressionRewriter : ExpressionTreeVisitor typeof(OfTypeResultOperator), typeof(CacheableResultOperator), typeof(TimeoutResultOperator), + typeof(AsReadOnlyResultOperator), typeof(CastResultOperator), // see ProcessCast class }; diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 9ac7c237acc..ee1b7da76c5 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -116,6 +116,7 @@ static QueryModelVisitor() ResultOperatorMap.Add(); ResultOperatorMap.Add(); ResultOperatorMap.Add(); + ResultOperatorMap.Add(); ResultOperatorMap.Add(); ResultOperatorMap.Add(); } diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAsReadOnly.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAsReadOnly.cs new file mode 100644 index 00000000000..8441cb37a59 --- /dev/null +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessAsReadOnly.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace NHibernate.Linq.Visitors.ResultOperatorProcessors +{ + internal class ProcessAsReadOnly : IResultOperatorProcessor + { + public void Process(AsReadOnlyResultOperator resultOperator, QueryModelVisitor queryModelVisitor, IntermediateHqlTree tree) + { + tree.AddAdditionalCriteria((q, p) => q.SetReadOnly(true)); + } + } +} \ No newline at end of file diff --git a/src/NHibernate/NHibernate.csproj b/src/NHibernate/NHibernate.csproj index cbb25cad16d..365321a7c18 100644 --- a/src/NHibernate/NHibernate.csproj +++ b/src/NHibernate/NHibernate.csproj @@ -321,6 +321,7 @@ + From 1b01cc6c9a95b4c14d32183b09b0a5c4beae17b3 Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Mon, 13 Feb 2017 01:13:21 +1300 Subject: [PATCH 2/2] Copy test from #260 --- .../Linq/QueryReadOnlyTests.cs | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/NHibernate.Test/Linq/QueryReadOnlyTests.cs b/src/NHibernate.Test/Linq/QueryReadOnlyTests.cs index 1cb51c2aeb1..7d315ad8653 100644 --- a/src/NHibernate.Test/Linq/QueryReadOnlyTests.cs +++ b/src/NHibernate.Test/Linq/QueryReadOnlyTests.cs @@ -7,11 +7,6 @@ namespace NHibernate.Test.Linq { public class QueryReadOnlyTests : LinqTestCase { - protected override void Configure(Configuration configuration) - { - base.Configure(configuration); - } - [Test] public void CanSetReadOnlyOnLinqQueries() { @@ -43,5 +38,22 @@ orderby e.CompanyName Assert.That(result.All(x => this.session.IsReadOnly(x)), Is.True); } + + [Test] + public void CanSetReadOnlyOnLinqGroupPageQuery() + { + var subQuery = db.Customers.Where(e2 => e2.CompanyName.Contains("a")).Select(e2 => e2.CustomerId) + .AsReadOnly(); // This AsReadOnly() should not cause trouble, and be ignored. + + var result = (from e in db.Customers + where subQuery.Contains(e.CustomerId) + group e by e.CompanyName + into g + select new { g.Key, Count = g.Count() }) + .Skip(5).Take(5) + .AsReadOnly().ToList(); + + Assert.That(result.All(x => this.session.IsReadOnly(x)), Is.True); + } } -} \ No newline at end of file +}