This repository has been archived by the owner on Apr 26, 2019. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
search: add DijkstraAllPaths and DijkstraFrom
Also fix copyright dates.
- Loading branch information
Showing
6 changed files
with
1,168 additions
and
526 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
// Copyright ©2015 The gonum Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package search | ||
|
||
import ( | ||
"container/heap" | ||
"math" | ||
|
||
"github.com/gonum/graph" | ||
"github.com/gonum/matrix/mat64" | ||
) | ||
|
||
// DijkstraFrom returns a shortest-path tree for a shortest path from u to all nodes in | ||
// the graph g. If weight is nil and the graph does not implement graph.Coster, UniformCost | ||
// is used. DijkstraFrom will panic if g has a u-reachable negative edge weight. | ||
func DijkstraFrom(u graph.Node, g graph.Graph, weight graph.CostFunc) Shortest { | ||
if !g.NodeExists(u) { | ||
return Shortest{from: u} | ||
} | ||
var ( | ||
from = g.Neighbors | ||
edgeTo func(graph.Node, graph.Node) graph.Edge | ||
) | ||
switch g := g.(type) { | ||
case graph.DirectedGraph: | ||
from = g.Successors | ||
edgeTo = g.EdgeTo | ||
default: | ||
edgeTo = g.EdgeBetween | ||
} | ||
if weight == nil { | ||
if g, ok := g.(graph.Coster); ok { | ||
weight = g.Cost | ||
} else { | ||
weight = UniformCost | ||
} | ||
} | ||
|
||
nodes := g.NodeList() | ||
|
||
indexOf := make(map[int]int, len(nodes)) | ||
for i, n := range nodes { | ||
indexOf[n.ID()] = i | ||
} | ||
|
||
path := Shortest{ | ||
from: u, | ||
|
||
nodes: nodes, | ||
indexOf: indexOf, | ||
|
||
dist: make([]float64, len(nodes)), | ||
next: make([]int, len(nodes)), | ||
} | ||
for i := range path.dist { | ||
path.dist[i] = math.Inf(1) | ||
} | ||
|
||
// Dijkstra's algorithm here is implemented essentially as | ||
// described in Function B.2 in figure 6 of UTCS Technical | ||
// Report TR-07-54. | ||
// | ||
// http://www.cs.utexas.edu/ftp/techreports/tr07-54.pdf | ||
Q := priorityQueue{{node: u, dist: 0}} | ||
for Q.Len() != 0 { | ||
mid := heap.Pop(&Q).(distanceNode) | ||
k := path.indexOf[mid.node.ID()] | ||
if mid.dist < path.dist[k] { | ||
path.dist[k] = mid.dist | ||
} | ||
for _, v := range from(mid.node) { | ||
j := path.indexOf[v.ID()] | ||
w := weight(edgeTo(mid.node, v)) | ||
if w < 0 { | ||
panic("dijkstra: negative edge weight") | ||
} | ||
joint := path.dist[k] + w | ||
if joint < path.dist[j] { | ||
heap.Push(&Q, distanceNode{node: v, dist: joint}) | ||
path.set(j, joint, k) | ||
} | ||
} | ||
} | ||
|
||
return path | ||
} | ||
|
||
// DijkstraAllPaths returns a shortest-path tree for shortest paths in the graph g. | ||
// If weight is nil and the graph does not implement graph.Coster, UniformCost is used. | ||
// DijkstraAllPaths will panic if g has a negative edge weight. | ||
func DijkstraAllPaths(g graph.Graph, weight graph.CostFunc) (paths ShortestPaths) { | ||
var ( | ||
from = g.Neighbors | ||
edgeTo func(graph.Node, graph.Node) graph.Edge | ||
) | ||
switch g := g.(type) { | ||
case graph.DirectedGraph: | ||
from = g.Successors | ||
edgeTo = g.EdgeTo | ||
default: | ||
edgeTo = g.EdgeBetween | ||
} | ||
if weight == nil { | ||
if g, ok := g.(graph.Coster); ok { | ||
weight = g.Cost | ||
} else { | ||
weight = UniformCost | ||
} | ||
} | ||
|
||
nodes := g.NodeList() | ||
|
||
indexOf := make(map[int]int, len(nodes)) | ||
for i, n := range nodes { | ||
indexOf[n.ID()] = i | ||
} | ||
|
||
dist := make([]float64, len(nodes)*len(nodes)) | ||
for i := range dist { | ||
dist[i] = math.Inf(1) | ||
} | ||
paths = ShortestPaths{ | ||
nodes: nodes, | ||
indexOf: indexOf, | ||
|
||
dist: mat64.NewDense(len(nodes), len(nodes), dist), | ||
next: make([][]int, len(nodes)*len(nodes)), | ||
forward: false, | ||
} | ||
|
||
var Q priorityQueue | ||
for i, u := range nodes { | ||
// Dijkstra's algorithm here is implemented essentially as | ||
// described in Function B.2 in figure 6 of UTCS Technical | ||
// Report TR-07-54 with the addition of handling multiple | ||
// co-equal paths. | ||
// | ||
// http://www.cs.utexas.edu/ftp/techreports/tr07-54.pdf | ||
|
||
// Q must be empty at this point. | ||
heap.Push(&Q, distanceNode{node: u, dist: 0}) | ||
for Q.Len() != 0 { | ||
mid := heap.Pop(&Q).(distanceNode) | ||
k := paths.indexOf[mid.node.ID()] | ||
if mid.dist < paths.dist.At(i, k) { | ||
paths.dist.Set(i, k, mid.dist) | ||
} | ||
for _, v := range from(mid.node) { | ||
j := paths.indexOf[v.ID()] | ||
w := weight(edgeTo(mid.node, v)) | ||
if w < 0 { | ||
panic("dijkstra: negative edge weight") | ||
} | ||
joint := paths.dist.At(i, k) + w | ||
if joint < paths.dist.At(i, j) { | ||
heap.Push(&Q, distanceNode{node: v, dist: joint}) | ||
paths.set(i, j, joint, k) | ||
} else if joint == paths.dist.At(i, j) { | ||
paths.add(i, j, k) | ||
} | ||
} | ||
} | ||
} | ||
|
||
return paths | ||
} | ||
|
||
type distanceNode struct { | ||
node graph.Node | ||
dist float64 | ||
} | ||
|
||
// priorityQueue implements a no-dec priority queue. | ||
type priorityQueue []distanceNode | ||
|
||
func (q priorityQueue) Len() int { return len(q) } | ||
func (q priorityQueue) Less(i, j int) bool { return q[i].dist < q[j].dist } | ||
func (q priorityQueue) Swap(i, j int) { q[i], q[j] = q[j], q[i] } | ||
func (q *priorityQueue) Push(n interface{}) { *q = append(*q, n.(distanceNode)) } | ||
func (q *priorityQueue) Pop() interface{} { | ||
t := *q | ||
var n interface{} | ||
n, *q = t[len(t)-1], t[:len(t)-1] | ||
return n | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
// Copyright ©2015 The gonum Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package search_test | ||
|
||
import ( | ||
"math" | ||
"reflect" | ||
"sort" | ||
"testing" | ||
|
||
"github.com/gonum/graph" | ||
"github.com/gonum/graph/internal" | ||
"github.com/gonum/graph/search" | ||
) | ||
|
||
func TestDijkstraFrom(t *testing.T) { | ||
for _, test := range positiveWeightTests { | ||
g := test.g() | ||
for _, e := range test.edges { | ||
switch g := g.(type) { | ||
case graph.MutableDirectedGraph: | ||
g.AddDirectedEdge(e, e.Cost) | ||
case graph.MutableGraph: | ||
g.AddUndirectedEdge(e, e.Cost) | ||
default: | ||
panic("dijkstra: bad graph type") | ||
} | ||
} | ||
|
||
pt := search.DijkstraFrom(test.query.From(), g.(graph.Graph), nil) | ||
|
||
if pt.From().ID() != test.query.From().ID() { | ||
t.Fatalf("%q: unexpected from node ID: got:%d want:%d", pt.From().ID(), test.query.From().ID()) | ||
} | ||
|
||
p, weight := pt.To(test.query.To()) | ||
if weight != test.weight { | ||
t.Errorf("%q: unexpected weight from Between: got:%f want:%f", | ||
test.name, weight, test.weight) | ||
} | ||
if weight := pt.WeightTo(test.query.To()); weight != test.weight { | ||
t.Errorf("%q: unexpected weight from Weight: got:%f want:%f", | ||
test.name, weight, test.weight) | ||
} | ||
|
||
var got []int | ||
for _, n := range p { | ||
got = append(got, n.ID()) | ||
} | ||
ok := len(got) == 0 && len(test.want) == 0 | ||
for _, sp := range test.want { | ||
if reflect.DeepEqual(got, sp) { | ||
ok = true | ||
break | ||
} | ||
} | ||
if !ok { | ||
t.Errorf("%q: unexpected shortest path:\ngot: %v\nwant from:%v", | ||
test.name, p, test.want) | ||
} | ||
|
||
np, weight := pt.To(test.none.To()) | ||
if pt.From().ID() == test.none.From().ID() && (np != nil || !math.IsInf(weight, 1)) { | ||
t.Errorf("%q: unexpected path:\ngot: path=%v weight=%f\nwant:path=<nil> weight=+Inf", | ||
test.name, np, weight) | ||
} | ||
} | ||
} | ||
|
||
func TestDijkstraAllPaths(t *testing.T) { | ||
for _, test := range positiveWeightTests { | ||
g := test.g() | ||
for _, e := range test.edges { | ||
switch g := g.(type) { | ||
case graph.MutableDirectedGraph: | ||
g.AddDirectedEdge(e, e.Cost) | ||
case graph.MutableGraph: | ||
g.AddUndirectedEdge(e, e.Cost) | ||
default: | ||
panic("dijkstra: bad graph type") | ||
} | ||
} | ||
|
||
pt := search.DijkstraAllPaths(g.(graph.Graph), nil) | ||
|
||
// Check all random paths returned are OK. | ||
for i := 0; i < 10; i++ { | ||
p, weight, unique := pt.Between(test.query.From(), test.query.To()) | ||
if weight != test.weight { | ||
t.Errorf("%q: unexpected weight from Between: got:%f want:%f", | ||
test.name, weight, test.weight) | ||
} | ||
if weight := pt.Weight(test.query.From(), test.query.To()); weight != test.weight { | ||
t.Errorf("%q: unexpected weight from Weight: got:%f want:%f", | ||
test.name, weight, test.weight) | ||
} | ||
if unique != test.unique { | ||
t.Errorf("%q: unexpected number of paths: got: unique=%t want: unique=%t", | ||
test.name, unique, test.unique) | ||
} | ||
|
||
var got []int | ||
for _, n := range p { | ||
got = append(got, n.ID()) | ||
} | ||
ok := len(got) == 0 && len(test.want) == 0 | ||
for _, sp := range test.want { | ||
if reflect.DeepEqual(got, sp) { | ||
ok = true | ||
break | ||
} | ||
} | ||
if !ok { | ||
t.Errorf("%q: unexpected shortest path:\ngot: %v\nwant from:%v", | ||
test.name, p, test.want) | ||
} | ||
} | ||
|
||
np, weight, unique := pt.Between(test.none.From(), test.none.To()) | ||
if np != nil || !math.IsInf(weight, 1) || unique != false { | ||
t.Errorf("%q: unexpected path:\ngot: path=%v weight=%f unique=%t\nwant:path=<nil> weight=+Inf unique=false", | ||
test.name, np, weight, unique) | ||
} | ||
|
||
paths, weight := pt.AllBetween(test.query.From(), test.query.To()) | ||
if weight != test.weight { | ||
t.Errorf("%q: unexpected weight from Between: got:%f want:%f", | ||
test.name, weight, test.weight) | ||
} | ||
|
||
var got [][]int | ||
if len(paths) != 0 { | ||
got = make([][]int, len(paths)) | ||
} | ||
for i, p := range paths { | ||
for _, v := range p { | ||
got[i] = append(got[i], v.ID()) | ||
} | ||
} | ||
sort.Sort(internal.BySliceValues(got)) | ||
if !reflect.DeepEqual(got, test.want) { | ||
t.Errorf("testing %q: unexpected shortest paths:\ngot: %v\nwant:%v", | ||
test.name, got, test.want) | ||
} | ||
|
||
nps, weight := pt.AllBetween(test.none.From(), test.none.To()) | ||
if nps != nil || !math.IsInf(weight, 1) { | ||
t.Errorf("%q: unexpected path:\ngot: paths=%v weight=%f\nwant:path=<nil> weight=+Inf", | ||
test.name, nps, weight) | ||
} | ||
} | ||
} |
Oops, something went wrong.