diff --git a/src/Analysis/Ast/Impl/Dependencies/DependencyResolver.cs b/src/Analysis/Ast/Impl/Dependencies/DependencyResolver.cs index 5e806d7a8..7e41ae874 100644 --- a/src/Analysis/Ast/Impl/Dependencies/DependencyResolver.cs +++ b/src/Analysis/Ast/Impl/Dependencies/DependencyResolver.cs @@ -68,10 +68,33 @@ public int Remove(TKey key) { return _version; } - Interlocked.Increment(ref _version); + var version = Interlocked.Increment(ref _version); + var vertex = _vertices[index]; _vertices[index] = default; - return _version; + if (vertex == null) { + return version; + } + + foreach (var incomingIndex in vertex.Incoming) { + var incoming = _vertices[incomingIndex]; + if (incoming != null && incoming.IsSealed) { + _vertices[incomingIndex] = new DependencyVertex(incoming, version, false); + } + } + + if (!vertex.IsSealed) { + return version; + } + + foreach (var outgoingIndex in vertex.Outgoing) { + var outgoing = _vertices[outgoingIndex]; + if (outgoing != null && !outgoing.IsNew) { + _vertices[outgoingIndex] = new DependencyVertex(outgoing, version, true); + } + } + + return version; } } @@ -135,7 +158,7 @@ private ImmutableArray EnsureKeys(int index, ImmutableArray keys, int } else { var vertex = _vertices[keyIndex]; if (vertex != default && vertex.IsSealed && !vertex.ContainsOutgoing(index)) { - _vertices[keyIndex] = new DependencyVertex(vertex, version); + _vertices[keyIndex] = new DependencyVertex(vertex, version, false); } } @@ -263,13 +286,14 @@ private bool TryBuildReverseGraph(ImmutableArray> private bool TryCreateWalkingGraph(in ImmutableArray> vertices, int version, out ImmutableArray> analysisGraph) { var nodesByVertexIndex = new Dictionary>(); - foreach (var vertex in vertices) { + for (var index = 0; index < vertices.Count; index++) { + var vertex = vertices[index]; if (vertex == null || vertex.IsWalked) { continue; } - var node = new WalkingVertex(vertices[vertex.Index]); - nodesByVertexIndex[vertex.Index] = node; + var node = new WalkingVertex(vertices[index]); + nodesByVertexIndex[index] = node; } if (nodesByVertexIndex.Count == 0) { @@ -288,9 +312,12 @@ private bool TryCreateWalkingGraph(in ImmutableArray(vertex); nodesByVertexIndex[outgoingIndex] = outgoingNode; - queue.Enqueue(outgoingNode); } diff --git a/src/Analysis/Ast/Impl/Dependencies/DependencyVertex.cs b/src/Analysis/Ast/Impl/Dependencies/DependencyVertex.cs index 52b14cb5d..03bffead2 100644 --- a/src/Analysis/Ast/Impl/Dependencies/DependencyVertex.cs +++ b/src/Analysis/Ast/Impl/Dependencies/DependencyVertex.cs @@ -29,6 +29,7 @@ internal sealed class DependencyVertex { public int Index { get; } public string DebuggerDisplay => $"{Key}:{Value}"; + public bool IsNew => _state == (int)State.New; public bool IsSealed => _state >= (int)State.Sealed; public bool IsWalked => _state == (int)State.Walked; @@ -38,7 +39,7 @@ internal sealed class DependencyVertex { private HashSet _outgoing; private static HashSet _empty = new HashSet(); - public DependencyVertex(DependencyVertex oldVertex, int version) { + public DependencyVertex(DependencyVertex oldVertex, int version, bool isNew) { Key = oldVertex.Key; Value = oldVertex.Value; IsRoot = oldVertex.IsRoot; @@ -48,7 +49,7 @@ public DependencyVertex(DependencyVertex oldVertex, int version) { Version = version; _outgoing = oldVertex.Outgoing; - _state = oldVertex.IsWalked ? (int)State.ChangedOutgoing : (int)State.New; + _state = !isNew && oldVertex.IsWalked ? (int)State.ChangedOutgoing : (int)State.New; } public DependencyVertex(TKey key, TValue value, bool isRoot, ImmutableArray incoming, int version, int index) { diff --git a/src/Analysis/Ast/Test/DependencyResolverTests.cs b/src/Analysis/Ast/Test/DependencyResolverTests.cs index e9ed6b646..5cbcee7d4 100644 --- a/src/Analysis/Ast/Test/DependencyResolverTests.cs +++ b/src/Analysis/Ast/Test/DependencyResolverTests.cs @@ -177,6 +177,84 @@ public async Task ChangeValue_MissingKeys() { result.ToString().Should().Be("AD"); } + [TestMethod] + public async Task ChangeValue_Add() { + var resolver = new DependencyResolver(); + resolver.ChangeValue("A", "A:BD", true, "B", "D"); + resolver.ChangeValue("C", "C", false); + + var walker = resolver.CreateWalker(); + walker.MissingKeys.Should().Equal("B", "D"); + var node1 = await walker.GetNextAsync(default); + var node2 = await walker.GetNextAsync(default); + node1.Value.Should().Be("A:BD"); + node2.Value.Should().Be("C"); + node1.Commit(); + node2.Commit(); + + walker.Remaining.Should().Be(0); + + resolver.ChangeValue("B", "B", false); + walker = resolver.CreateWalker(); + walker.MissingKeys.Should().Equal("D"); + + var node = await walker.GetNextAsync(default); + node.Value.Should().Be("B"); + node.Commit(); + + node = await walker.GetNextAsync(default); + node.Value.Should().Be("A:BD"); + node.Commit(); + + walker.Remaining.Should().Be(0); + + resolver.ChangeValue("D", "D:C", false); + walker = resolver.CreateWalker(); + walker.MissingKeys.Should().BeEmpty(); + + node = await walker.GetNextAsync(default); + node.Value.Should().Be("D:C"); + node.Commit(); + + node = await walker.GetNextAsync(default); + node.Value.Should().Be("A:BD"); + node.Commit(); + + walker.Remaining.Should().Be(0); + } + + [TestMethod] + public async Task ChangeValue_Remove() { + var resolver = new DependencyResolver(); + resolver.ChangeValue("A", "A:BC", true, "B", "C"); + resolver.ChangeValue("B", "B:C", false, "C"); + resolver.ChangeValue("C", "C", false); + + var walker = resolver.CreateWalker(); + walker.MissingKeys.Should().BeEmpty(); + var node = await walker.GetNextAsync(default); + node.Value.Should().Be("C"); + node.Commit(); + + node = await walker.GetNextAsync(default); + node.Value.Should().Be("B:C"); + node.Commit(); + + node = await walker.GetNextAsync(default); + node.Value.Should().Be("A:BC"); + node.Commit(); + + resolver.Remove("B"); + walker = resolver.CreateWalker(); + walker.MissingKeys.Should().Equal("B"); + + node = await walker.GetNextAsync(default); + node.Value.Should().Be("A:BC"); + node.Commit(); + + walker.Remaining.Should().Be(0); + } + [TestMethod] public async Task ChangeValue_RemoveKeys() { var resolver = new DependencyResolver();