Skip to content

Commit

Permalink
existing tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
craiggwilson committed May 1, 2010
1 parent c1b238e commit 0d4e26d
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 300 deletions.
11 changes: 11 additions & 0 deletions source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryTests.cs
Expand Up @@ -350,5 +350,16 @@ public void Complex_Addition()

Assert.AreEqual(1, people.Count);
}

[Test]
public void Join()
{
var people = Enumerable.ToList(
from p in collection.Linq()
join op in collection.Linq() on p.PrimaryAddress equals op.PrimaryAddress
select p);

Assert.AreEqual(0, people.Count);
}
}
}
99 changes: 9 additions & 90 deletions source/MongoDB/Linq/ExecutionBuilder.cs
Expand Up @@ -18,7 +18,6 @@ internal class ExecutionBuilder : MongoExpressionVisitor
private int _numCursors;
private Expression _provider;
private MemberInfo _receivingMember;
private Scope _scope;
private List<ParameterExpression> _variables;


Expand Down Expand Up @@ -80,6 +79,11 @@ protected override Expression VisitClientJoin(ClientJoinExpression clientJoin)
return access;
}

protected override Expression VisitField(FieldExpression field)
{
return Visit(field.Expression);
}

protected override Expression VisitProjection(ProjectionExpression projection)
{
if (_isTop)
Expand Down Expand Up @@ -119,7 +123,6 @@ private Expression Build(Expression expression)
private Expression BuildInner(Expression expression)
{
var builder = new ExecutionBuilder();
builder._scope = _scope;
builder._receivingMember = _receivingMember;
builder._numCursors = _numCursors;
builder._lookup = _lookup;
Expand All @@ -128,33 +131,17 @@ private Expression BuildInner(Expression expression)

private Expression ExecuteProjection(ProjectionExpression projection)
{
projection = (ProjectionExpression)new Parameterizer().Parameterize(projection);

if(_scope != null)
projection = (ProjectionExpression)new OuterParameterizer().Parameterize(projection, _scope.Alias);

var saveScope = _scope;
var document = Expression.Parameter(projection.Projector.Type, "d" + (_numCursors++));
_scope = new Scope(_scope, document, projection.Source.Alias, projection.Source.Fields);
var projector = Expression.Lambda(Visit(projection.Projector), document);
_scope = saveScope;

var projection = base.VisitProjection(projection);
var queryObject = new MongoQueryObjectBuilder().Build(projection);
queryObject.Projector = new ProjectionBuilder().Build(queryObject, projector);

var namedValues = new NamedValueGatherer().Gather(projection.Source);
var names = namedValues.Select(v => v.Name).ToArray();
var values = namedValues.Select(v => Expression.Convert(Visit(v.Value), typeof(object))).ToArray();
queryObject.Projector = new ProjectionBuilder().Build(projection.Projector, queryObject.DocumentType, "d" + (_numCursors++), queryObject.IsMapReduce);
queryObject.Aggregator = projection.Aggregator;

Expression result = Expression.Call(
_provider,
"ExecuteQueryObject",
new[] { queryObject.DocumentType, queryObject.Projector.Type },
Type.EmptyTypes,
Expression.Constant(queryObject, typeof(MongoQueryObject)));

if(projection.Aggregator != null)
result = new ExpressionReplacer().Replace(projection.Aggregator.Body, projection.Aggregator.Parameters[0], result);

return result;
}

Expand Down Expand Up @@ -192,74 +179,6 @@ private static Expression MakeSequence(IList<Expression> expressions)
return Expression.Convert(Expression.Call(typeof(ExecutionBuilder), "Sequence", null, Expression.NewArrayInit(typeof(object), expressions)), last.Type);
}

private class Scope
{
private ParameterExpression _document;
private Scope _outer;
private Dictionary<string, int> _nameMap;

internal Alias Alias { get; private set; }

public Scope(Scope outer, ParameterExpression document, Alias alias, IEnumerable<FieldDeclaration> fields)
{
_outer = outer;
_document = document;
Alias = alias;
_nameMap = fields.Select((f, i) => new { f, i }).ToDictionary(x => x.f.Name, x => x.i);
}

public bool TryGetValue(FieldExpression field, out ParameterExpression document, out int ordinal)
{
for (Scope s = this; s != null; s = s._outer)
{
if (field.Alias == s.Alias && _nameMap.TryGetValue(field.Name, out ordinal))
{
document = _document;
return true;
}
}
document = null;
ordinal = 0;
return false;
}
}

private class OuterParameterizer : MongoExpressionVisitor
{
private int _paramIndex;
private Alias _outerAlias;
private Dictionary<FieldExpression, NamedValueExpression> _map;

public Expression Parameterize(Expression expression, Alias outerAlias)
{
_outerAlias = outerAlias;
return Visit(expression);
}

protected override Expression VisitProjection(ProjectionExpression projection)
{
SelectExpression select = (SelectExpression)Visit(projection.Source);
if (select != projection.Source)
return new ProjectionExpression(select, projection.Projector, projection.Aggregator);
return projection;
}

protected override Expression VisitField(FieldExpression field)
{
if (field.Alias == _outerAlias)
{
NamedValueExpression nv;
if (!_map.TryGetValue(field, out nv))
{
nv = new NamedValueExpression("n" + (_paramIndex++), field);
_map.Add(field, nv);
}
return nv;
}
return field;
}
}

private class CompoundKey : IEquatable<CompoundKey>
{
private object[] _values;
Expand Down
6 changes: 6 additions & 0 deletions source/MongoDB/Linq/MongoQueryObject.cs
Expand Up @@ -12,6 +12,12 @@ internal class MongoQueryObject
private Document _query;
private Document _sort;

/// <summary>
/// Gets or sets the aggregator.
/// </summary>
/// <value>The aggregator.</value>
public LambdaExpression Aggregator { get; set; }

/// <summary>
/// Gets or sets the name of the collection.
/// </summary>
Expand Down
56 changes: 30 additions & 26 deletions source/MongoDB/Linq/MongoQueryProvider.cs
Expand Up @@ -138,12 +138,12 @@ internal MongoQueryObject GetQueryObject(Expression expression)
/// </summary>
/// <param name="queryObject">The query object.</param>
/// <returns></returns>
internal IEnumerable<TResult> ExecuteQueryObject<TDocument, TResult>(MongoQueryObject queryObject){
internal object ExecuteQueryObject(MongoQueryObject queryObject){
if (queryObject.IsCount)
return ExecuteCount<TDocument, TResult>(queryObject);
return ExecuteCount(queryObject);
else if (queryObject.IsMapReduce)
return ExecuteMapReduce<TDocument, TResult>(queryObject);
return ExecuteFind<TDocument, TResult>(queryObject);
return ExecuteMapReduce(queryObject);
return ExecuteFind(queryObject);
}

private Expression BuildExecutionPlan(Expression expression)
Expand Down Expand Up @@ -220,20 +220,18 @@ private bool CanBeEvaluatedLocally(Expression expression)
/// </summary>
/// <param name="queryObject">The query object.</param>
/// <returns></returns>
private IEnumerable<TResult> ExecuteCount<TDocument, TResult>(MongoQueryObject queryObject)
private object ExecuteCount(MongoQueryObject queryObject)
{
var miGetCollection = typeof(IMongoDatabase).GetMethods().Where(m => m.Name == "GetCollection" && m.GetGenericArguments().Length == 1 && m.GetParameters().Length == 1).Single().MakeGenericMethod(queryObject.DocumentType);
var collection = miGetCollection.Invoke(queryObject.Database, new[] { queryObject.CollectionName });

IEnumerable<TDocument> documents;
if (queryObject.Query == null)
documents = new[] { (TDocument)collection.GetType().GetMethod("Count", Type.EmptyTypes).Invoke(collection, null) };
documents = new[] { (TDocument)collection.GetType().GetMethod("Count", new[] { typeof(object) }).Invoke(collection, new[] { queryObject.Query }) };
return Convert.ToInt32(collection.GetType().GetMethod("Count", Type.EmptyTypes).Invoke(collection, null));

return Project(documents, (Func<TDocument, TResult>)queryObject.Projector.Compile());
return Convert.ToInt32(collection.GetType().GetMethod("Count", new[] { typeof(object) }).Invoke(collection, new[] { queryObject.Query }));
}

private IEnumerable<TResult> ExecuteFind<TDocument, TResult>(MongoQueryObject queryObject)
private object ExecuteFind(MongoQueryObject queryObject)
{
var miGetCollection = typeof(IMongoDatabase).GetMethods().Where(m => m.Name == "GetCollection" && m.GetGenericArguments().Length == 1 && m.GetParameters().Length == 1).Single().MakeGenericMethod(queryObject.DocumentType);
var collection = miGetCollection.Invoke(queryObject.Database, new[] { queryObject.CollectionName });
Expand All @@ -256,11 +254,11 @@ private bool CanBeEvaluatedLocally(Expression expression)
cursorType.GetMethod("Limit").Invoke(cursor, new object[] { queryObject.NumberToLimit });
cursorType.GetMethod("Skip").Invoke(cursor, new object[] { queryObject.NumberToSkip });

var documents = (IEnumerable<TDocument>)cursor.GetType().GetProperty("Documents").GetValue(cursor, null);
return Project(documents, (Func<TDocument, TResult>)queryObject.Projector.Compile());
var executor = GetExecutor(queryObject.DocumentType, queryObject.Projector, queryObject.Aggregator, true);
return executor.Compile().DynamicInvoke(cursor.GetType().GetProperty("Documents").GetValue(cursor, null));
}

private IEnumerable<TResult> ExecuteMapReduce<TDocument, TResult>(MongoQueryObject queryObject)
private object ExecuteMapReduce(MongoQueryObject queryObject)
{
var miGetCollection = typeof(IMongoDatabase).GetMethods().Where(m => m.Name == "GetCollection" && m.GetGenericArguments().Length == 1 && m.GetParameters().Length == 1).Single().MakeGenericMethod(queryObject.DocumentType);
var collection = miGetCollection.Invoke(queryObject.Database, new[] { queryObject.CollectionName });
Expand All @@ -271,35 +269,41 @@ private bool CanBeEvaluatedLocally(Expression expression)
mapReduce.Finalize = new Code(queryObject.FinalizerFunction);
mapReduce.Query = queryObject.Query;

if (queryObject.Sort != null)
if(queryObject.Sort != null)
mapReduce.Sort = queryObject.Sort;

mapReduce.Limit = queryObject.NumberToLimit;
if (queryObject.NumberToSkip != 0)
throw new InvalidQueryException("MapReduce queries do no support Skips.");

var documents = (IEnumerable<TDocument>)mapReduce.Documents;
return Project(documents, (Func<TDocument, TResult>)queryObject.Projector.Compile());
}

private IEnumerable<TResult> Project<TDocument, TResult>(IEnumerable<TDocument> documents, Func<TDocument, TResult> projector)
{
foreach (var doc in documents)
{
yield return projector(doc);
}
var executor = GetExecutor(typeof(Document), queryObject.Projector, queryObject.Aggregator, true);
return executor.Compile().DynamicInvoke(mapReduce.Documents);
}

private static LambdaExpression GetExecutor(Type documentType, LambdaExpression projector, bool boxReturn)
private static LambdaExpression GetExecutor(Type documentType, LambdaExpression projector, LambdaExpression aggregator, bool boxReturn)
{
var documents = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(documentType), "documents");
Expression body = Expression.New(typeof(ProjectionReader<,>).MakeGenericType(documentType, projector.Body.Type).GetConstructors()[0], documents, projector);
Expression body = Expression.Call(
typeof(MongoQueryProvider),
"Project",
new[] { documentType, projector.Body.Type },
documents,
projector);
if (aggregator != null)
body = Expression.Invoke(aggregator, body);

if (boxReturn && body.Type != typeof(object))
body = Expression.Convert(body, typeof(object));

return Expression.Lambda(body, documents);
}

private static IEnumerable<TResult> Project<TDocument, TResult>(IEnumerable<TDocument> documents, Func<TDocument, TResult> projector)
{
foreach (var doc in documents)
yield return projector(doc);
}

private class RootQueryableFinder : MongoExpressionVisitor
{
private Expression _root;
Expand Down
70 changes: 0 additions & 70 deletions source/MongoDB/Linq/ProjectionReader.cs

This file was deleted.

Expand Up @@ -81,7 +81,7 @@ protected override Expression VisitSubquery(SubqueryExpression subquery)
private bool CanJoinOnClient(SelectExpression select)
{
return !select.IsDistinct
&& select.GroupBy != null
&& select.GroupBy == null
&& !new AggregateChecker().HasAggregates(select);
}

Expand Down

0 comments on commit 0d4e26d

Please sign in to comment.