Skip to content

Commit

Permalink
Low-level AttachGraph method
Browse files Browse the repository at this point in the history
Traverses a graph and provides a callback for each entry visited. See #1229

No async version yet, and no higher-level overloads.
  • Loading branch information
ajcvickers committed Dec 8, 2014
1 parent 48b776f commit 0feb4cf
Show file tree
Hide file tree
Showing 9 changed files with 457 additions and 148 deletions.
40 changes: 30 additions & 10 deletions src/EntityFramework.Core/ChangeTracking/ChangeTracker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
using System.Collections.Generic;
using System.Linq;
using JetBrains.Annotations;
using Microsoft.Data.Entity.Infrastructure;
using Microsoft.Data.Entity.Utilities;

namespace Microsoft.Data.Entity.ChangeTracking
{
// This is the app-developer facing public API to the change tracker
public class ChangeTracker
{
private readonly StateManager _stateManager;
private readonly ChangeDetector _changeDetector;
private readonly EntityEntryGraphIterator _graphIterator;
private readonly DbContextService<DbContext> _context;

/// <summary>
/// This constructor is intended only for use when creating test doubles that will override members
Expand All @@ -24,35 +26,53 @@ protected ChangeTracker()
{
}

public ChangeTracker([NotNull] StateManager stateManager, [NotNull] ChangeDetector changeDetector)
public ChangeTracker(
[NotNull] StateManager stateManager,
[NotNull] ChangeDetector changeDetector,
[NotNull] EntityEntryGraphIterator graphIterator,
[NotNull] DbContextService<DbContext> context)
{
Check.NotNull(stateManager, "stateManager");
Check.NotNull(changeDetector, "changeDetector");
Check.NotNull(graphIterator, "graphIterator");
Check.NotNull(context, "context");

_stateManager = stateManager;
StateManager = stateManager;
_changeDetector = changeDetector;
_graphIterator = graphIterator;
_context = context;
}

public virtual IEnumerable<EntityEntry> Entries()
{
return _stateManager.StateEntries.Select(e => new EntityEntry(e));
return StateManager.StateEntries.Select(e => new EntityEntry(_context.Service, e));
}

public virtual IEnumerable<EntityEntry<TEntity>> Entries<TEntity>()
{
return _stateManager.StateEntries
return StateManager.StateEntries
.Where(e => e.Entity is TEntity)
.Select(e => new EntityEntry<TEntity>(e));
.Select(e => new EntityEntry<TEntity>(_context.Service, e));
}

public virtual StateManager StateManager
public virtual StateManager StateManager { get; }

public virtual DbContext Context => _context.Service;

public virtual bool DetectChanges()
{
get { return _stateManager; }
return _changeDetector.DetectChanges(StateManager);
}

public virtual bool DetectChanges()
public virtual void AttachGraph([NotNull] object rootEntity, [NotNull] Action<EntityEntry> callback)
{
return _changeDetector.DetectChanges(_stateManager);
Check.NotNull(rootEntity, "rootEntity");
Check.NotNull(callback, "callback");

foreach (var entry in _graphIterator.TraverseGraph(rootEntity))
{
callback(entry);
}
}
}
}
25 changes: 10 additions & 15 deletions src/EntityFramework.Core/ChangeTracking/EntityEntry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,41 +10,36 @@ namespace Microsoft.Data.Entity.ChangeTracking
[DebuggerDisplay("{_stateEntry,nq}")]
public class EntityEntry
{
private readonly StateEntry _stateEntry;

public EntityEntry([NotNull] StateEntry stateEntry)
public EntityEntry([NotNull] DbContext context, [NotNull] StateEntry stateEntry)
{
Check.NotNull(stateEntry, "stateEntry");
Check.NotNull(context, "context");

_stateEntry = stateEntry;
StateEntry = stateEntry;
Context = context;
}

public virtual object Entity
{
get { return _stateEntry.Entity; }
}
public virtual object Entity => StateEntry.Entity;

public virtual EntityState State
{
get { return _stateEntry.EntityState; }
get { return StateEntry.EntityState; }
set
{
Check.IsDefined(value, "value");

_stateEntry.EntityState = value;
StateEntry.EntityState = value;
}
}

public virtual StateEntry StateEntry
{
get { return _stateEntry; }
}
public virtual StateEntry StateEntry { get; }
public virtual DbContext Context { get; }

public virtual PropertyEntry Property([NotNull] string propertyName)
{
Check.NotEmpty(propertyName, "propertyName");

return new PropertyEntry(_stateEntry, propertyName);
return new PropertyEntry(StateEntry, propertyName);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections;
using System.Collections.Generic;
using JetBrains.Annotations;
using Microsoft.Data.Entity.Infrastructure;
using Microsoft.Data.Entity.Metadata;
using Microsoft.Data.Entity.Utilities;

namespace Microsoft.Data.Entity.ChangeTracking
{
public class EntityEntryGraphIterator
{
private readonly DbContextService<DbContext> _context;

public EntityEntryGraphIterator(
[NotNull] DbContextService<DbContext> context)
{
Check.NotNull(context, "context");

_context = context;
}

public virtual IEnumerable<EntityEntry> TraverseGraph([NotNull] object entity)
{
Check.NotNull(entity, "entity");

var entry = _context.Service.Entry(entity);

if (entry.State != EntityState.Unknown)
{
yield break;
}

yield return entry;

if (entry.State != EntityState.Unknown)
{
var navigations = entry.StateEntry.EntityType.Navigations;

foreach (var navigation in navigations)
{
var navigationValue = entry.StateEntry[navigation];

if (navigationValue != null)
{
if (navigation.IsCollection())
{
foreach (var relatedEntity in (IEnumerable)navigationValue)
{
foreach (var relatedEntry in TraverseGraph(relatedEntity))
{
yield return relatedEntry;
}
}
}
else
{
foreach (var relatedEntry in TraverseGraph(navigationValue))
{
yield return relatedEntry;
}
}
}
}
}
}
}
}
9 changes: 3 additions & 6 deletions src/EntityFramework.Core/ChangeTracking/EntityEntry`.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,12 @@ namespace Microsoft.Data.Entity.ChangeTracking
{
public class EntityEntry<TEntity> : EntityEntry
{
public EntityEntry([NotNull] StateEntry stateEntry)
: base(stateEntry)
public EntityEntry([NotNull] DbContext context, [NotNull] StateEntry stateEntry)
: base(context, stateEntry)
{
}

public new virtual TEntity Entity
{
get { return (TEntity)base.Entity; }
}
public new virtual TEntity Entity => (TEntity)base.Entity;

public virtual PropertyEntry<TEntity, TProperty> Property<TProperty>(
[NotNull] Expression<Func<TEntity, TProperty>> propertyExpression)
Expand Down
8 changes: 4 additions & 4 deletions src/EntityFramework.Core/DbContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,14 @@ public virtual EntityEntry<TEntity> Entry<TEntity>([NotNull] TEntity entity)
{
Check.NotNull(entity, "entity");

return new EntityEntry<TEntity>(GetStateManager().GetOrCreateEntry(entity));
return new EntityEntry<TEntity>(this, GetStateManager().GetOrCreateEntry(entity));
}

public virtual EntityEntry Entry([NotNull] object entity)
{
Check.NotNull(entity, "entity");

return new EntityEntry(GetStateManager().GetOrCreateEntry(entity));
return new EntityEntry(this, GetStateManager().GetOrCreateEntry(entity));
}

public virtual EntityEntry<TEntity> Add<TEntity>([NotNull] TEntity entity)
Expand Down Expand Up @@ -435,7 +435,7 @@ private List<EntityEntry<TEntity>> GetOrCreateEntries<TEntity>(IEnumerable<TEnti
{
var stateManager = GetStateManager();

return entities.Select(e => new EntityEntry<TEntity>( stateManager.GetOrCreateEntry(e))).ToList();
return entities.Select(e => new EntityEntry<TEntity>(this, stateManager.GetOrCreateEntry(e))).ToList();
}

public virtual IReadOnlyList<EntityEntry> Add([NotNull] params object[] entities)
Expand Down Expand Up @@ -507,7 +507,7 @@ private List<EntityEntry> GetOrCreateEntries(IEnumerable<object> entities)
{
var stateManager = GetStateManager();

return entities.Select(e => new EntityEntry(stateManager.GetOrCreateEntry(e))).ToList();
return entities.Select(e => new EntityEntry(this, stateManager.GetOrCreateEntry(e))).ToList();
}

public virtual Database Database
Expand Down
1 change: 1 addition & 0 deletions src/EntityFramework.Core/EntityFramework.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
<Link>LoggingExtensions.cs</Link>
</Compile>
<Compile Include="ChangeTracking\ChangeDetector.cs" />
<Compile Include="ChangeTracking\EntityEntryGraphIterator.cs" />
<Compile Include="ChangeTracking\IPropertyListener.cs" />
<Compile Include="ChangeTracking\IRelationshipListener.cs" />
<Compile Include="ChangeTracking\PropertyBagEntryExtensions.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public static class EntityServiceCollectionExtensions
.AddScoped<ValueGenerationManager>()
.AddScoped<EntityQueryExecutor>()
.AddScoped<ChangeTracker>()
.AddScoped<EntityEntryGraphIterator>()
.AddScoped(DbContextServices.ModelFactory)
.AddScoped(DbContextServices.ContextFactory)
.AddScoped(DbContextServices.ContextOptionsFactory)
Expand Down

0 comments on commit 0feb4cf

Please sign in to comment.