Skip to content

Commit

Permalink
Partial implementation of CSHARP-433. Implemented OfType query operat…
Browse files Browse the repository at this point in the history
…or. Still need to implement "is" operator and comparison of types using "==".
  • Loading branch information
rstam committed Apr 12, 2012
1 parent 11427b5 commit 3b46a19
Show file tree
Hide file tree
Showing 8 changed files with 579 additions and 12 deletions.
1 change: 1 addition & 0 deletions Driver/Driver.csproj
Expand Up @@ -183,6 +183,7 @@
<Compile Include="Internal\MongoUpdateMessage.cs" />
<Compile Include="Internal\ReplicaSetConnector.cs" />
<Compile Include="Linq\Expressions\ExpressionFormatter.cs" />
<Compile Include="Linq\Expressions\ExpressionParameterFinder.cs" />
<Compile Include="Linq\Expressions\ExpressionParameterReplacer.cs" />
<Compile Include="Linq\LinqToMongo.cs" />
<Compile Include="Linq\Translators\DeserializationProjector.cs" />
Expand Down
82 changes: 82 additions & 0 deletions Driver/Linq/Expressions/ExpressionParameterFinder.cs
@@ -0,0 +1,82 @@
/* Copyright 2010-2012 10gen Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Text;

namespace MongoDB.Driver.Linq
{
/// <summary>
/// A class that finds the first parameter in an expression.
/// </summary>
public class ExpressionParameterFinder : ExpressionVisitor
{
// private fields
private ParameterExpression _parameter;

// constructors
/// <summary>
/// Initializes a new instance of the ExpressionParameterFinder class.
/// </summary>
public ExpressionParameterFinder()
{
}

// public static methods
/// <summary>
/// Finds the first parameter in an expression.
/// </summary>
/// <param name="node">The expression containing the parameter that should be found.</param>
/// <returns>The first parameter found in the expression (or null if none was found).</returns>
public static ParameterExpression FindParameter(Expression node)
{
var finder = new ExpressionParameterFinder();
finder.Visit(node);
return finder._parameter;
}

// protected methods
/// <summary>
/// Visits an Expression.
/// </summary>
/// <param name="node">The Expression.</param>
/// <returns>The Expression (posibly modified).</returns>
protected override Expression Visit(Expression node)
{
if (_parameter != null)
{
return node; // exit faster if we've already found the parameter
}
return base.Visit(node);
}

/// <summary>
/// Remembers this parameter if it is the first parameter found.
/// </summary>
/// <param name="node">The ParameterExpression.</param>
/// <returns>The ParameterExpression.</returns>
protected override Expression VisitParameter(ParameterExpression node)
{
if (_parameter == null)
{
_parameter = node;
}
return node;
}
}
}
122 changes: 119 additions & 3 deletions Driver/Linq/Translators/SelectQuery.cs
Expand Up @@ -37,6 +37,7 @@ public class SelectQuery : TranslatedQuery
{
// private fields
private LambdaExpression _where;
private Type _ofType;
private List<OrderByClause> _orderBy;
private LambdaExpression _projection;
private Expression _skip;
Expand All @@ -56,6 +57,14 @@ public SelectQuery(MongoCollection collection, Type documentType)
}

// public properties
/// <summary>
/// Gets the final result type if an OfType query operator was used (otherwise null).
/// </summary>
public Type OfType
{
get { return _ofType; }
}

/// <summary>
/// Gets a list of Expressions that defines the sort order (or null if not specified).
/// </summary>
Expand Down Expand Up @@ -157,6 +166,21 @@ public override object Execute()
cursor.SetLimit(ToInt32(_take));
}

if (_ofType != null)
{
if (_projection == null)
{
var paramExpression = Expression.Parameter(DocumentType, "x");
var convertExpression = Expression.Convert(paramExpression, _ofType);
_projection = Expression.Lambda(convertExpression, paramExpression);
}
else
{
// TODO: handle projection after OfType
throw new NotSupportedException();
}
}

IEnumerable enumerable;
if (_projection == null)
{
Expand Down Expand Up @@ -208,6 +232,7 @@ public void Translate(Expression expression)
}

var message = string.Format("Don't know how to translate expression: {0}.", ExpressionFormatter.ToString(expression));
throw new NotSupportedException(message);
}

// private methods
Expand Down Expand Up @@ -1239,10 +1264,39 @@ private void CombinePredicateWithWhereClause(MethodCallExpression methodCallExpr
return;
}

if (_where.Parameters.Count != 1)
{
throw new MongoInternalException("Where lambda expression should have one parameter.");
}
var whereBody = _where.Body;
var predicateBody = ExpressionParameterReplacer.ReplaceParameter(predicate.Body, predicate.Parameters[0], _where.Parameters[0]);
var whereParameter = _where.Parameters[0];

if (predicate.Parameters.Count != 1)
{
throw new MongoInternalException("Predicate lambda expression should have one parameter.");
}
var predicateBody = predicate.Body;
var predicateParameter = predicate.Parameters[0];

// when using OfType the parameter types might not match (but they do have to be compatible)
ParameterExpression parameter;
if (predicateParameter.Type.IsAssignableFrom(whereParameter.Type))
{
predicateBody = ExpressionParameterReplacer.ReplaceParameter(predicateBody, predicateParameter, whereParameter);
parameter = whereParameter;
}
else if (whereParameter.Type.IsAssignableFrom(predicateParameter.Type))
{
whereBody = ExpressionParameterReplacer.ReplaceParameter(whereBody, whereParameter, predicateParameter);
parameter = predicateParameter;
}
else
{
throw new NotSupportedException("Can't combine existing where clause with new predicate because parameter types are incompatible.");
}

var combinedBody = Expression.AndAlso(whereBody, predicateBody);
_where = Expression.Lambda(combinedBody, _where.Parameters.ToArray());
_where = Expression.Lambda(combinedBody, parameter);
}
}

Expand Down Expand Up @@ -1278,7 +1332,9 @@ private object ExecuteDistinct(IMongoQuery query)

private BsonSerializationInfo GetSerializationInfo(Expression expression)
{
var documentSerializer = BsonSerializer.LookupSerializer(DocumentType);
// when using OfType the documentType used by the parameter might be a subclass of the DocumentType from the collection
var parameterExpression = ExpressionParameterFinder.FindParameter(expression);
var documentSerializer = BsonSerializer.LookupSerializer(parameterExpression.Type);
return GetSerializationInfo(documentSerializer, expression);
}

Expand Down Expand Up @@ -1755,6 +1811,9 @@ private void TranslateMethodCall(MethodCallExpression methodCallExpression)
case "Min":
TranslateMaxMin(methodCallExpression);
break;
case "OfType":
TranslateOfType(methodCallExpression);
break;
case "OrderBy":
case "OrderByDescending":
TranslateOrderBy(methodCallExpression);
Expand All @@ -1781,6 +1840,63 @@ private void TranslateMethodCall(MethodCallExpression methodCallExpression)
}
}

private void TranslateOfType(MethodCallExpression methodCallExpression)
{
var method = methodCallExpression.Method;
if (method.DeclaringType != typeof(Queryable))
{
var message = string.Format("OfType method of class {0} is not supported.", BsonUtils.GetFriendlyTypeName(method.DeclaringType));
throw new NotSupportedException(message);
}
if (!method.IsStatic)
{
throw new NotSupportedException("Expected OfType to be a static method.");
}
if (!method.IsGenericMethod)
{
throw new NotSupportedException("Expected OfType to be a generic method.");
}
var actualType = method.GetGenericArguments()[0];

var args = methodCallExpression.Arguments.ToArray();
if (args.Length != 1)
{
throw new NotSupportedException("Expected OfType method to have a single argument.");
}
var sourceExpression = args[0];
if (!sourceExpression.Type.IsGenericType)
{
throw new NotSupportedException("Expected source argument to OfType to be a generic type.");
}
var nominalType = sourceExpression.Type.GetGenericArguments()[0];

if (nominalType == actualType)
{
return; // nothing to do
}

if (_projection != null)
{
throw new NotSupportedException("OfType after a projection is not supported.");
}

var discriminatorConvention = BsonDefaultSerializer.LookupDiscriminatorConvention(nominalType);
var discriminator = discriminatorConvention.GetDiscriminator(nominalType, actualType);
if (discriminator.IsBsonArray)
{
discriminator = discriminator.AsBsonArray[discriminator.AsBsonArray.Count - 1];
}

var injectMethodInfo = typeof(LinqToMongo).GetMethod("Inject");
var query = Query.EQ("_t", discriminator);
var body = Expression.Call(injectMethodInfo, Expression.Constant(query));
var parameter = Expression.Parameter(nominalType, "x");
var predicate = Expression.Lambda(body, parameter);
CombinePredicateWithWhereClause(methodCallExpression, predicate);

_ofType = actualType;
}

private void TranslateOrderBy(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Arguments.Count != 2)
Expand Down
2 changes: 2 additions & 0 deletions DriverUnitTests/DriverUnitTests.csproj
Expand Up @@ -162,6 +162,8 @@
<Compile Include="Jira\CSharp93Tests.cs" />
<Compile Include="Jira\CSharp98Tests.cs" />
<Compile Include="Jira\CSharp100Tests.cs" />
<Compile Include="Linq\SelectOfTypeHierarchicalTests.cs" />
<Compile Include="Linq\SelectOfTypeTests.cs" />
<Compile Include="Linq\SelectQueryTests.cs" />
<Compile Include="Linq\MongoQueryableTests.cs" />
<Compile Include="Linq\MongoQueryProviderTests.cs" />
Expand Down
2 changes: 2 additions & 0 deletions DriverUnitTests/GridFS/MongoGridFSStreamTests.cs
Expand Up @@ -282,7 +282,9 @@ public void TestUpdateMD5()
{
var bytes = new byte[] { 1, 2, 3, 4 };
stream.Write(bytes, 0, 4);
#pragma warning disable 618 // about obsolete BsonBinarySubType.OldBinary
stream.UpdateMD5 = false;
#pragma warning restore
}

fileInfo = _gridFS.FindOne("test");
Expand Down

0 comments on commit 3b46a19

Please sign in to comment.