From ad37f285003ec8519d0a5472d91f20abf89f9303 Mon Sep 17 00:00:00 2001 From: kortschak Date: Wed, 26 Aug 2015 16:14:41 +0930 Subject: [PATCH] path: add tests for minimum spanning tree functions * Replace incorrect Prim's implementation. * Make functions return weight. * Drop default to uniform cost function since this is essentially meaningless here. Fixes #86. --- path/a_star_test.go | 7 +- path/spanning_tree.go | 171 ++++++++++++++------- path/spanning_tree_test.go | 298 +++++++++++++++++++++++++++++++++++++ 3 files changed, 417 insertions(+), 59 deletions(-) create mode 100644 path/spanning_tree_test.go diff --git a/path/a_star_test.go b/path/a_star_test.go index ddb236d2..63e8c1d6 100644 --- a/path/a_star_test.go +++ b/path/a_star_test.go @@ -226,12 +226,7 @@ func (e weightedEdge) From() graph.Node { return e.from } func (e weightedEdge) To() graph.Node { return e.to } func (e weightedEdge) Weight() float64 { return e.cost } -type costEdgeListGraph interface { - graph.Weighter - EdgeListerGraph -} - -func isMonotonic(g costEdgeListGraph, h Heuristic) (ok bool, at graph.Edge, goal graph.Node) { +func isMonotonic(g UndirectedWeightLister, h Heuristic) (ok bool, at graph.Edge, goal graph.Node) { for _, goal := range g.Nodes() { for _, edge := range g.Edges() { from := edge.From() diff --git a/path/spanning_tree.go b/path/spanning_tree.go index f0d6f060..59ffb396 100644 --- a/path/spanning_tree.go +++ b/path/spanning_tree.go @@ -5,101 +5,166 @@ package path import ( + "container/heap" + "math" "sort" "github.com/gonum/graph" "github.com/gonum/graph/concrete" - "github.com/gonum/graph/internal" ) -// EdgeListerGraph is an undirected graph than returns its complete set of edges. -type EdgeListerGraph interface { +// UndirectedWeighter is an undirected graph that returns distinct edge weights. +type UndirectedWeighter interface { graph.Undirected - Edges() []graph.Edge + graph.Weighter } // Prim generates a minimum spanning tree of g by greedy tree extension, placing -// the result in the destination. The destination is not cleared first. -func Prim(dst graph.MutableUndirected, g EdgeListerGraph) { - var weight Weighting - if wg, ok := g.(graph.Weighter); ok { - weight = wg.Weight - } else { - weight = UniformCost(g) +// the result in the destination. The destination is not cleared first. The weight +// of the minimum spanning tree is returned. If g is not connected, the minimum +// spanning forest will be constructed in dst and the sum of minimum spanning tree +// weights will be returned. +func Prim(dst graph.MutableUndirected, g UndirectedWeighter) float64 { + nodes := g.Nodes() + if len(nodes) == 0 { + return 0 } - - nlist := g.Nodes() - - if nlist == nil || len(nlist) == 0 { - return + u := nodes[0] + q := &primQueue{ + indexOf: make(map[int]int, len(nodes)-1), + nodes: make([]concrete.Edge, 0, len(nodes)-1), } - - dst.AddNode(nlist[0]) - remain := make(internal.IntSet) - for _, node := range nlist[1:] { - remain.Add(node.ID()) + for _, u := range nodes[1:] { + heap.Push(q, concrete.Edge{F: u, W: math.Inf(1)}) } + for _, v := range g.From(u) { + w, ok := g.Weight(u, v) + if !ok { + panic("prim: unexpected invalid weight") + } + q.update(v, u, w) + } + + var w float64 + for q.Len() > 0 { + e := heap.Pop(q).(concrete.Edge) + if e.To() != nil && g.HasEdge(e.From(), e.To()) { + dst.SetEdge(e) + w += e.Weight() + } - edgeList := g.Edges() - for remain.Count() != 0 { - var edges []concrete.Edge - for _, e := range edgeList { - u := e.From() - v := e.To() - if (dst.Has(u) && remain.Has(v.ID())) || (dst.Has(v) && remain.Has(u.ID())) { - w, ok := weight(u, v) + u = e.From() + for _, n := range g.From(u) { + if key, ok := q.key(n); ok { + w, ok := g.Weight(u, n) if !ok { panic("prim: unexpected invalid weight") } - edges = append(edges, concrete.Edge{F: u, T: v, W: w}) + if w < key { + q.update(n, u, w) + } } } - - sort.Sort(byWeight(edges)) - min := edges[0] - dst.SetEdge(min) - remain.Remove(min.From().ID()) } + return w +} + +// primQueue is an Prim's priority queue. +type primQueue struct { + indexOf map[int]int + nodes []concrete.Edge +} + +func (q *primQueue) Less(i, j int) bool { + return q.nodes[i].Weight() < q.nodes[j].Weight() +} +func (q *primQueue) Swap(i, j int) { + q.indexOf[q.nodes[i].From().ID()] = j + q.indexOf[q.nodes[j].From().ID()] = i + q.nodes[i], q.nodes[j] = q.nodes[j], q.nodes[i] } -// Kruskal generates a minimum spanning tree of g by greedy tree coalesence, placing -// the result in the destination. The destination is not cleared first. -func Kruskal(dst graph.MutableUndirected, g EdgeListerGraph) { - var weight Weighting - if wg, ok := g.(graph.Weighter); ok { - weight = wg.Weight - } else { - weight = UniformCost(g) +func (q *primQueue) Len() int { + return len(q.nodes) +} + +func (q *primQueue) Push(x interface{}) { + n := x.(concrete.Edge) + q.indexOf[n.From().ID()] = len(q.nodes) + q.nodes = append(q.nodes, n) +} + +func (q *primQueue) Pop() interface{} { + n := q.nodes[len(q.nodes)-1] + q.nodes = q.nodes[:len(q.nodes)-1] + delete(q.indexOf, n.From().ID()) + return n +} + +// key returns the key for the node n and whether the node is +// in the queue. If the node is not in the queue, key is returned +// as +Inf. +func (q *primQueue) key(u graph.Node) (key float64, ok bool) { + i, ok := q.indexOf[u.ID()] + if !ok { + return math.Inf(1), false } + return q.nodes[i].Weight(), ok +} - edgeList := g.Edges() - edges := make([]concrete.Edge, 0, len(edgeList)) - for _, e := range edgeList { +// update updates the node's position in the queue with the new key. +func (q *primQueue) update(u, v graph.Node, key float64) { + id := u.ID() + i, ok := q.indexOf[id] + if !ok { + return + } + q.nodes[i].T = v + q.nodes[i].W = key + heap.Fix(q, i) +} + +// UndirectedWeightLister is an undirected graph that returns distinct edge weights and +// the set of edges in the graph. +type UndirectedWeightLister interface { + UndirectedWeighter + Edges() []graph.Edge +} + +// Kruskal generates a minimum spanning tree of g by greedy tree coalescence, placing +// the result in the destination. The destination is not cleared first. The weight +// of the minimum spanning tree is returned. If g is not connected, the minimum +// spanning forest will be constructed in dst and the sum of minimum spanning tree +// weights will be returned. +func Kruskal(dst graph.MutableUndirected, g UndirectedWeightLister) float64 { + edges := g.Edges() + ascend := make([]concrete.Edge, 0, len(edges)) + for _, e := range edges { u := e.From() v := e.To() - w, ok := weight(u, v) + w, ok := g.Weight(u, v) if !ok { panic("kruskal: unexpected invalid weight") } - edges = append(edges, concrete.Edge{F: u, T: v, W: w}) + ascend = append(ascend, concrete.Edge{F: u, T: v, W: w}) } - - sort.Sort(byWeight(edges)) + sort.Sort(byWeight(ascend)) ds := newDisjointSet() for _, node := range g.Nodes() { ds.makeSet(node.ID()) } - for _, e := range edges { - // The disjoint set doesn't really care for which is head and which is tail so this - // should work fine without checking both ways + var w float64 + for _, e := range ascend { if s1, s2 := ds.find(e.From().ID()), ds.find(e.To().ID()); s1 != s2 { ds.union(s1, s2) dst.SetEdge(e) + w += e.Weight() } } + return w } type byWeight []concrete.Edge diff --git a/path/spanning_tree_test.go b/path/spanning_tree_test.go new file mode 100644 index 00000000..2f071c83 --- /dev/null +++ b/path/spanning_tree_test.go @@ -0,0 +1,298 @@ +// Copyright ©2015 The gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package path + +import ( + "fmt" + "math" + "testing" + + "github.com/gonum/graph/encoding/dot" + + "github.com/gonum/graph" + "github.com/gonum/graph/concrete" +) + +func init() { + for _, test := range spanningTreeTests { + var w float64 + for _, e := range test.treeEdges { + w += e.W + } + if w != test.want { + panic(fmt.Sprintf("bad test: %s weight mismatch: %v != %v", test.name, w, test.want)) + } + } +} + +type spanningGraph interface { + graph.MutableUndirected + graph.Weighter + Edges() []graph.Edge +} + +var spanningTreeTests = []struct { + name string + graph func() spanningGraph + edges []concrete.Edge + want float64 + treeEdges []concrete.Edge +}{ + { + name: "Empty", + graph: func() spanningGraph { return concrete.NewGraph(0, math.Inf(1)) }, + want: 0, + }, + { + // https://upload.wikimedia.org/wikipedia/commons/f/f7/Prim%27s_algorithm.svg + // Modified to make edge weights unique; A--B is increased to 2.5 otherwise + // to prevent the alternative solution being found. + name: "Prim WP figure 1", + graph: func() spanningGraph { return concrete.NewGraph(0, math.Inf(1)) }, + edges: []concrete.Edge{ + {F: concrete.Node('A'), T: concrete.Node('B'), W: 2.5}, + {F: concrete.Node('A'), T: concrete.Node('D'), W: 1}, + {F: concrete.Node('B'), T: concrete.Node('D'), W: 2}, + {F: concrete.Node('C'), T: concrete.Node('D'), W: 3}, + }, + + want: 6, + treeEdges: []concrete.Edge{ + {F: concrete.Node('A'), T: concrete.Node('D'), W: 1}, + {F: concrete.Node('B'), T: concrete.Node('D'), W: 2}, + {F: concrete.Node('C'), T: concrete.Node('D'), W: 3}, + }, + }, + { + // https://upload.wikimedia.org/wikipedia/commons/5/5c/MST_kruskal_en.gif + name: "Kruskal WP figure 1", + graph: func() spanningGraph { return concrete.NewGraph(0, math.Inf(1)) }, + edges: []concrete.Edge{ + {F: concrete.Node('a'), T: concrete.Node('b'), W: 3}, + {F: concrete.Node('a'), T: concrete.Node('e'), W: 1}, + {F: concrete.Node('b'), T: concrete.Node('c'), W: 5}, + {F: concrete.Node('b'), T: concrete.Node('e'), W: 4}, + {F: concrete.Node('c'), T: concrete.Node('d'), W: 2}, + {F: concrete.Node('c'), T: concrete.Node('e'), W: 6}, + {F: concrete.Node('d'), T: concrete.Node('e'), W: 7}, + }, + + want: 11, + treeEdges: []concrete.Edge{ + {F: concrete.Node('a'), T: concrete.Node('b'), W: 3}, + {F: concrete.Node('a'), T: concrete.Node('e'), W: 1}, + {F: concrete.Node('b'), T: concrete.Node('c'), W: 5}, + {F: concrete.Node('c'), T: concrete.Node('d'), W: 2}, + }, + }, + { + // https://upload.wikimedia.org/wikipedia/commons/8/87/Kruskal_Algorithm_6.svg + name: "Kruskal WP example", + graph: func() spanningGraph { return concrete.NewGraph(0, math.Inf(1)) }, + edges: []concrete.Edge{ + {F: concrete.Node('A'), T: concrete.Node('B'), W: 7}, + {F: concrete.Node('A'), T: concrete.Node('D'), W: 5}, + {F: concrete.Node('B'), T: concrete.Node('C'), W: 8}, + {F: concrete.Node('B'), T: concrete.Node('D'), W: 9}, + {F: concrete.Node('B'), T: concrete.Node('E'), W: 7}, + {F: concrete.Node('C'), T: concrete.Node('E'), W: 5}, + {F: concrete.Node('D'), T: concrete.Node('E'), W: 15}, + {F: concrete.Node('D'), T: concrete.Node('F'), W: 6}, + {F: concrete.Node('E'), T: concrete.Node('F'), W: 8}, + {F: concrete.Node('E'), T: concrete.Node('G'), W: 9}, + {F: concrete.Node('F'), T: concrete.Node('G'), W: 11}, + }, + + want: 39, + treeEdges: []concrete.Edge{ + {F: concrete.Node('A'), T: concrete.Node('B'), W: 7}, + {F: concrete.Node('A'), T: concrete.Node('D'), W: 5}, + {F: concrete.Node('B'), T: concrete.Node('E'), W: 7}, + {F: concrete.Node('C'), T: concrete.Node('E'), W: 5}, + {F: concrete.Node('D'), T: concrete.Node('F'), W: 6}, + {F: concrete.Node('E'), T: concrete.Node('G'), W: 9}, + }, + }, + { + // https://upload.wikimedia.org/wikipedia/commons/2/2e/Boruvka%27s_algorithm_%28Sollin%27s_algorithm%29_Anim.gif + name: "Borůvka WP example", + graph: func() spanningGraph { return concrete.NewGraph(0, math.Inf(1)) }, + edges: []concrete.Edge{ + {F: concrete.Node('A'), T: concrete.Node('B'), W: 13}, + {F: concrete.Node('A'), T: concrete.Node('C'), W: 6}, + {F: concrete.Node('B'), T: concrete.Node('C'), W: 7}, + {F: concrete.Node('B'), T: concrete.Node('D'), W: 1}, + {F: concrete.Node('C'), T: concrete.Node('D'), W: 14}, + {F: concrete.Node('C'), T: concrete.Node('E'), W: 8}, + {F: concrete.Node('C'), T: concrete.Node('H'), W: 20}, + {F: concrete.Node('D'), T: concrete.Node('E'), W: 9}, + {F: concrete.Node('D'), T: concrete.Node('F'), W: 3}, + {F: concrete.Node('E'), T: concrete.Node('F'), W: 2}, + {F: concrete.Node('E'), T: concrete.Node('J'), W: 18}, + {F: concrete.Node('G'), T: concrete.Node('H'), W: 15}, + {F: concrete.Node('G'), T: concrete.Node('I'), W: 5}, + {F: concrete.Node('G'), T: concrete.Node('J'), W: 19}, + {F: concrete.Node('G'), T: concrete.Node('K'), W: 10}, + {F: concrete.Node('H'), T: concrete.Node('J'), W: 17}, + {F: concrete.Node('I'), T: concrete.Node('K'), W: 11}, + {F: concrete.Node('J'), T: concrete.Node('K'), W: 16}, + {F: concrete.Node('J'), T: concrete.Node('L'), W: 4}, + {F: concrete.Node('K'), T: concrete.Node('L'), W: 12}, + }, + + want: 83, + treeEdges: []concrete.Edge{ + {F: concrete.Node('A'), T: concrete.Node('C'), W: 6}, + {F: concrete.Node('B'), T: concrete.Node('C'), W: 7}, + {F: concrete.Node('B'), T: concrete.Node('D'), W: 1}, + {F: concrete.Node('D'), T: concrete.Node('F'), W: 3}, + {F: concrete.Node('E'), T: concrete.Node('F'), W: 2}, + {F: concrete.Node('E'), T: concrete.Node('J'), W: 18}, + {F: concrete.Node('G'), T: concrete.Node('H'), W: 15}, + {F: concrete.Node('G'), T: concrete.Node('I'), W: 5}, + {F: concrete.Node('G'), T: concrete.Node('K'), W: 10}, + {F: concrete.Node('J'), T: concrete.Node('L'), W: 4}, + {F: concrete.Node('K'), T: concrete.Node('L'), W: 12}, + }, + }, + { + // https://upload.wikimedia.org/wikipedia/commons/d/d2/Minimum_spanning_tree.svg + // Nodes labelled row major. + name: "Minimum Spanning Tree WP figure 1", + graph: func() spanningGraph { return concrete.NewGraph(0, math.Inf(1)) }, + edges: []concrete.Edge{ + {F: concrete.Node(1), T: concrete.Node(2), W: 4}, + {F: concrete.Node(1), T: concrete.Node(3), W: 1}, + {F: concrete.Node(1), T: concrete.Node(4), W: 4}, + {F: concrete.Node(2), T: concrete.Node(3), W: 5}, + {F: concrete.Node(2), T: concrete.Node(5), W: 9}, + {F: concrete.Node(2), T: concrete.Node(6), W: 9}, + {F: concrete.Node(2), T: concrete.Node(8), W: 7}, + {F: concrete.Node(3), T: concrete.Node(4), W: 3}, + {F: concrete.Node(3), T: concrete.Node(8), W: 9}, + {F: concrete.Node(4), T: concrete.Node(8), W: 10}, + {F: concrete.Node(4), T: concrete.Node(10), W: 18}, + {F: concrete.Node(5), T: concrete.Node(6), W: 2}, + {F: concrete.Node(5), T: concrete.Node(7), W: 4}, + {F: concrete.Node(5), T: concrete.Node(9), W: 6}, + {F: concrete.Node(6), T: concrete.Node(7), W: 2}, + {F: concrete.Node(6), T: concrete.Node(8), W: 8}, + {F: concrete.Node(7), T: concrete.Node(8), W: 9}, + {F: concrete.Node(7), T: concrete.Node(9), W: 3}, + {F: concrete.Node(7), T: concrete.Node(10), W: 9}, + {F: concrete.Node(8), T: concrete.Node(10), W: 8}, + {F: concrete.Node(9), T: concrete.Node(10), W: 9}, + }, + + want: 38, + treeEdges: []concrete.Edge{ + {F: concrete.Node(1), T: concrete.Node(2), W: 4}, + {F: concrete.Node(1), T: concrete.Node(3), W: 1}, + {F: concrete.Node(2), T: concrete.Node(8), W: 7}, + {F: concrete.Node(3), T: concrete.Node(4), W: 3}, + {F: concrete.Node(5), T: concrete.Node(6), W: 2}, + {F: concrete.Node(6), T: concrete.Node(7), W: 2}, + {F: concrete.Node(6), T: concrete.Node(8), W: 8}, + {F: concrete.Node(7), T: concrete.Node(9), W: 3}, + {F: concrete.Node(8), T: concrete.Node(10), W: 8}, + }, + }, + + { + // https://upload.wikimedia.org/wikipedia/commons/2/2e/Boruvka%27s_algorithm_%28Sollin%27s_algorithm%29_Anim.gif + // but with C--H and E--J cut. + name: "Borůvka WP example cut", + graph: func() spanningGraph { return concrete.NewGraph(0, math.Inf(1)) }, + edges: []concrete.Edge{ + {F: concrete.Node('A'), T: concrete.Node('B'), W: 13}, + {F: concrete.Node('A'), T: concrete.Node('C'), W: 6}, + {F: concrete.Node('B'), T: concrete.Node('C'), W: 7}, + {F: concrete.Node('B'), T: concrete.Node('D'), W: 1}, + {F: concrete.Node('C'), T: concrete.Node('D'), W: 14}, + {F: concrete.Node('C'), T: concrete.Node('E'), W: 8}, + {F: concrete.Node('D'), T: concrete.Node('E'), W: 9}, + {F: concrete.Node('D'), T: concrete.Node('F'), W: 3}, + {F: concrete.Node('E'), T: concrete.Node('F'), W: 2}, + {F: concrete.Node('G'), T: concrete.Node('H'), W: 15}, + {F: concrete.Node('G'), T: concrete.Node('I'), W: 5}, + {F: concrete.Node('G'), T: concrete.Node('J'), W: 19}, + {F: concrete.Node('G'), T: concrete.Node('K'), W: 10}, + {F: concrete.Node('H'), T: concrete.Node('J'), W: 17}, + {F: concrete.Node('I'), T: concrete.Node('K'), W: 11}, + {F: concrete.Node('J'), T: concrete.Node('K'), W: 16}, + {F: concrete.Node('J'), T: concrete.Node('L'), W: 4}, + {F: concrete.Node('K'), T: concrete.Node('L'), W: 12}, + }, + + want: 65, + treeEdges: []concrete.Edge{ + {F: concrete.Node('A'), T: concrete.Node('C'), W: 6}, + {F: concrete.Node('B'), T: concrete.Node('C'), W: 7}, + {F: concrete.Node('B'), T: concrete.Node('D'), W: 1}, + {F: concrete.Node('D'), T: concrete.Node('F'), W: 3}, + {F: concrete.Node('E'), T: concrete.Node('F'), W: 2}, + {F: concrete.Node('G'), T: concrete.Node('H'), W: 15}, + {F: concrete.Node('G'), T: concrete.Node('I'), W: 5}, + {F: concrete.Node('G'), T: concrete.Node('K'), W: 10}, + {F: concrete.Node('J'), T: concrete.Node('L'), W: 4}, + {F: concrete.Node('K'), T: concrete.Node('L'), W: 12}, + }, + }, +} + +func testMinumumSpanning(mst func(dst graph.MutableUndirected, g spanningGraph) float64, t *testing.T) { + for _, test := range spanningTreeTests { + g := test.graph() + for _, e := range test.edges { + g.SetEdge(e) + } + + dst := concrete.NewGraph(0, math.Inf(1)) + w := mst(dst, g) + if w != test.want { + t.Errorf("unexpected minimum spanning tree weight for %q: got: %f want: %f", + test.name, w, test.want) + } + var got float64 + for _, e := range dst.Edges() { + got += e.Weight() + } + if got != test.want { + t.Errorf("unexpected minimum spanning tree edge weight sum for %q: got: %f want: %f", + test.name, got, test.want) + } + + gotEdges := dst.Edges() + if len(gotEdges) != len(test.treeEdges) { + t.Errorf("unexpected number of spanning tree edges for %q: got: %d want: %d", + test.name, len(gotEdges), len(test.treeEdges)) + } + for _, e := range test.treeEdges { + w, ok := dst.Weight(e.From(), e.To()) + if !ok { + t.Errorf("spanning tree edge not found in graph for %q: %+v", + test.name, e) + b, _ := dot.Marshal(dst, "", "", " ", false) + fmt.Printf("%s\n", b) + } + if w != e.Weight() { + t.Errorf("unexpected spanning tree edge weight for %q: got: %f want: %f", + test.name, w, e.Weight()) + } + } + } +} + +func TestKruskal(t *testing.T) { + testMinumumSpanning(func(dst graph.MutableUndirected, g spanningGraph) float64 { + return Kruskal(dst, g) + }, t) +} + +func TestPrim(t *testing.T) { + testMinumumSpanning(func(dst graph.MutableUndirected, g spanningGraph) float64 { + return Prim(dst, g) + }, t) +}