Skip to content
This repository has been archived by the owner on Apr 26, 2019. It is now read-only.

Commit

Permalink
path: add tests for minimum spanning tree functions
Browse files Browse the repository at this point in the history
* Replace incorrect Prim's implementation.
* Make functions return weight.
* Drop default to uniform cost function since this is essentially
  meaningless here.

Fixes #86.
  • Loading branch information
kortschak committed Aug 27, 2015
1 parent 202374f commit ad37f28
Show file tree
Hide file tree
Showing 3 changed files with 417 additions and 59 deletions.
7 changes: 1 addition & 6 deletions path/a_star_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,7 @@ func (e weightedEdge) From() graph.Node { return e.from }
func (e weightedEdge) To() graph.Node { return e.to }
func (e weightedEdge) Weight() float64 { return e.cost }

type costEdgeListGraph interface {
graph.Weighter
EdgeListerGraph
}

func isMonotonic(g costEdgeListGraph, h Heuristic) (ok bool, at graph.Edge, goal graph.Node) {
func isMonotonic(g UndirectedWeightLister, h Heuristic) (ok bool, at graph.Edge, goal graph.Node) {
for _, goal := range g.Nodes() {
for _, edge := range g.Edges() {
from := edge.From()
Expand Down
171 changes: 118 additions & 53 deletions path/spanning_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,101 +5,166 @@
package path

import (
"container/heap"
"math"
"sort"

"github.com/gonum/graph"
"github.com/gonum/graph/concrete"
"github.com/gonum/graph/internal"
)

// EdgeListerGraph is an undirected graph than returns its complete set of edges.
type EdgeListerGraph interface {
// UndirectedWeighter is an undirected graph that returns distinct edge weights.
type UndirectedWeighter interface {
graph.Undirected
Edges() []graph.Edge
graph.Weighter
}

// Prim generates a minimum spanning tree of g by greedy tree extension, placing
// the result in the destination. The destination is not cleared first.
func Prim(dst graph.MutableUndirected, g EdgeListerGraph) {
var weight Weighting
if wg, ok := g.(graph.Weighter); ok {
weight = wg.Weight
} else {
weight = UniformCost(g)
// the result in the destination. The destination is not cleared first. The weight
// of the minimum spanning tree is returned. If g is not connected, the minimum
// spanning forest will be constructed in dst and the sum of minimum spanning tree
// weights will be returned.
func Prim(dst graph.MutableUndirected, g UndirectedWeighter) float64 {
nodes := g.Nodes()
if len(nodes) == 0 {
return 0
}

nlist := g.Nodes()

if nlist == nil || len(nlist) == 0 {
return
u := nodes[0]
q := &primQueue{
indexOf: make(map[int]int, len(nodes)-1),
nodes: make([]concrete.Edge, 0, len(nodes)-1),
}

dst.AddNode(nlist[0])
remain := make(internal.IntSet)
for _, node := range nlist[1:] {
remain.Add(node.ID())
for _, u := range nodes[1:] {
heap.Push(q, concrete.Edge{F: u, W: math.Inf(1)})
}
for _, v := range g.From(u) {
w, ok := g.Weight(u, v)
if !ok {
panic("prim: unexpected invalid weight")
}
q.update(v, u, w)
}

var w float64
for q.Len() > 0 {
e := heap.Pop(q).(concrete.Edge)
if e.To() != nil && g.HasEdge(e.From(), e.To()) {
dst.SetEdge(e)
w += e.Weight()
}

edgeList := g.Edges()
for remain.Count() != 0 {
var edges []concrete.Edge
for _, e := range edgeList {
u := e.From()
v := e.To()
if (dst.Has(u) && remain.Has(v.ID())) || (dst.Has(v) && remain.Has(u.ID())) {
w, ok := weight(u, v)
u = e.From()
for _, n := range g.From(u) {
if key, ok := q.key(n); ok {
w, ok := g.Weight(u, n)
if !ok {
panic("prim: unexpected invalid weight")
}
edges = append(edges, concrete.Edge{F: u, T: v, W: w})
if w < key {
q.update(n, u, w)
}
}
}

sort.Sort(byWeight(edges))
min := edges[0]
dst.SetEdge(min)
remain.Remove(min.From().ID())
}
return w
}

// primQueue is an Prim's priority queue.
type primQueue struct {
indexOf map[int]int
nodes []concrete.Edge
}

func (q *primQueue) Less(i, j int) bool {
return q.nodes[i].Weight() < q.nodes[j].Weight()
}

func (q *primQueue) Swap(i, j int) {
q.indexOf[q.nodes[i].From().ID()] = j
q.indexOf[q.nodes[j].From().ID()] = i
q.nodes[i], q.nodes[j] = q.nodes[j], q.nodes[i]
}

// Kruskal generates a minimum spanning tree of g by greedy tree coalesence, placing
// the result in the destination. The destination is not cleared first.
func Kruskal(dst graph.MutableUndirected, g EdgeListerGraph) {
var weight Weighting
if wg, ok := g.(graph.Weighter); ok {
weight = wg.Weight
} else {
weight = UniformCost(g)
func (q *primQueue) Len() int {
return len(q.nodes)
}

func (q *primQueue) Push(x interface{}) {
n := x.(concrete.Edge)
q.indexOf[n.From().ID()] = len(q.nodes)
q.nodes = append(q.nodes, n)
}

func (q *primQueue) Pop() interface{} {
n := q.nodes[len(q.nodes)-1]
q.nodes = q.nodes[:len(q.nodes)-1]
delete(q.indexOf, n.From().ID())
return n
}

// key returns the key for the node n and whether the node is
// in the queue. If the node is not in the queue, key is returned
// as +Inf.
func (q *primQueue) key(u graph.Node) (key float64, ok bool) {
i, ok := q.indexOf[u.ID()]
if !ok {
return math.Inf(1), false
}
return q.nodes[i].Weight(), ok
}

edgeList := g.Edges()
edges := make([]concrete.Edge, 0, len(edgeList))
for _, e := range edgeList {
// update updates the node's position in the queue with the new key.
func (q *primQueue) update(u, v graph.Node, key float64) {
id := u.ID()
i, ok := q.indexOf[id]
if !ok {
return
}
q.nodes[i].T = v
q.nodes[i].W = key
heap.Fix(q, i)
}

// UndirectedWeightLister is an undirected graph that returns distinct edge weights and
// the set of edges in the graph.
type UndirectedWeightLister interface {
UndirectedWeighter
Edges() []graph.Edge
}

// Kruskal generates a minimum spanning tree of g by greedy tree coalescence, placing
// the result in the destination. The destination is not cleared first. The weight
// of the minimum spanning tree is returned. If g is not connected, the minimum
// spanning forest will be constructed in dst and the sum of minimum spanning tree
// weights will be returned.
func Kruskal(dst graph.MutableUndirected, g UndirectedWeightLister) float64 {
edges := g.Edges()
ascend := make([]concrete.Edge, 0, len(edges))
for _, e := range edges {
u := e.From()
v := e.To()
w, ok := weight(u, v)
w, ok := g.Weight(u, v)
if !ok {
panic("kruskal: unexpected invalid weight")
}
edges = append(edges, concrete.Edge{F: u, T: v, W: w})
ascend = append(ascend, concrete.Edge{F: u, T: v, W: w})
}

sort.Sort(byWeight(edges))
sort.Sort(byWeight(ascend))

ds := newDisjointSet()
for _, node := range g.Nodes() {
ds.makeSet(node.ID())
}

for _, e := range edges {
// The disjoint set doesn't really care for which is head and which is tail so this
// should work fine without checking both ways
var w float64
for _, e := range ascend {
if s1, s2 := ds.find(e.From().ID()), ds.find(e.To().ID()); s1 != s2 {
ds.union(s1, s2)
dst.SetEdge(e)
w += e.Weight()
}
}
return w
}

type byWeight []concrete.Edge
Expand Down
Loading

0 comments on commit ad37f28

Please sign in to comment.