Skip to content
This repository has been archived by the owner on Apr 26, 2019. It is now read-only.

Commit

Permalink
search: add DijkstraAllPaths and DijkstraFrom
Browse files Browse the repository at this point in the history
Also fix copyright dates.
  • Loading branch information
kortschak committed Jun 8, 2015
1 parent 83c55cd commit ac2edf6
Show file tree
Hide file tree
Showing 6 changed files with 1,168 additions and 526 deletions.
187 changes: 187 additions & 0 deletions search/dijkstra.go
@@ -0,0 +1,187 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package search

import (
"container/heap"
"math"

"github.com/gonum/graph"
"github.com/gonum/matrix/mat64"
)

// DijkstraFrom returns a shortest-path tree for a shortest path from u to all nodes in
// the graph g. If weight is nil and the graph does not implement graph.Coster, UniformCost
// is used. DijkstraFrom will panic if g has a u-reachable negative edge weight.
func DijkstraFrom(u graph.Node, g graph.Graph, weight graph.CostFunc) Shortest {
if !g.NodeExists(u) {
return Shortest{from: u}
}
var (
from = g.Neighbors
edgeTo func(graph.Node, graph.Node) graph.Edge
)
switch g := g.(type) {
case graph.DirectedGraph:
from = g.Successors
edgeTo = g.EdgeTo
default:
edgeTo = g.EdgeBetween
}
if weight == nil {
if g, ok := g.(graph.Coster); ok {
weight = g.Cost
} else {
weight = UniformCost
}
}

nodes := g.NodeList()

indexOf := make(map[int]int, len(nodes))
for i, n := range nodes {
indexOf[n.ID()] = i
}

path := Shortest{
from: u,

nodes: nodes,
indexOf: indexOf,

dist: make([]float64, len(nodes)),
next: make([]int, len(nodes)),
}
for i := range path.dist {
path.dist[i] = math.Inf(1)
}

// Dijkstra's algorithm here is implemented essentially as
// described in Function B.2 in figure 6 of UTCS Technical
// Report TR-07-54.
//
// http://www.cs.utexas.edu/ftp/techreports/tr07-54.pdf
Q := priorityQueue{{node: u, dist: 0}}
for Q.Len() != 0 {
mid := heap.Pop(&Q).(distanceNode)
k := path.indexOf[mid.node.ID()]
if mid.dist < path.dist[k] {
path.dist[k] = mid.dist
}
for _, v := range from(mid.node) {
j := path.indexOf[v.ID()]
w := weight(edgeTo(mid.node, v))
if w < 0 {
panic("dijkstra: negative edge weight")
}
joint := path.dist[k] + w
if joint < path.dist[j] {
heap.Push(&Q, distanceNode{node: v, dist: joint})
path.set(j, joint, k)
}
}
}

return path
}

// DijkstraAllPaths returns a shortest-path tree for shortest paths in the graph g.
// If weight is nil and the graph does not implement graph.Coster, UniformCost is used.
// DijkstraAllPaths will panic if g has a negative edge weight.
func DijkstraAllPaths(g graph.Graph, weight graph.CostFunc) (paths ShortestPaths) {
var (
from = g.Neighbors
edgeTo func(graph.Node, graph.Node) graph.Edge
)
switch g := g.(type) {
case graph.DirectedGraph:
from = g.Successors
edgeTo = g.EdgeTo
default:
edgeTo = g.EdgeBetween
}
if weight == nil {
if g, ok := g.(graph.Coster); ok {
weight = g.Cost
} else {
weight = UniformCost
}
}

nodes := g.NodeList()

indexOf := make(map[int]int, len(nodes))
for i, n := range nodes {
indexOf[n.ID()] = i
}

dist := make([]float64, len(nodes)*len(nodes))
for i := range dist {
dist[i] = math.Inf(1)
}
paths = ShortestPaths{
nodes: nodes,
indexOf: indexOf,

dist: mat64.NewDense(len(nodes), len(nodes), dist),
next: make([][]int, len(nodes)*len(nodes)),
forward: false,
}

var Q priorityQueue
for i, u := range nodes {
// Dijkstra's algorithm here is implemented essentially as
// described in Function B.2 in figure 6 of UTCS Technical
// Report TR-07-54 with the addition of handling multiple
// co-equal paths.
//
// http://www.cs.utexas.edu/ftp/techreports/tr07-54.pdf

// Q must be empty at this point.
heap.Push(&Q, distanceNode{node: u, dist: 0})
for Q.Len() != 0 {
mid := heap.Pop(&Q).(distanceNode)
k := paths.indexOf[mid.node.ID()]
if mid.dist < paths.dist.At(i, k) {
paths.dist.Set(i, k, mid.dist)
}
for _, v := range from(mid.node) {
j := paths.indexOf[v.ID()]
w := weight(edgeTo(mid.node, v))
if w < 0 {
panic("dijkstra: negative edge weight")
}
joint := paths.dist.At(i, k) + w
if joint < paths.dist.At(i, j) {
heap.Push(&Q, distanceNode{node: v, dist: joint})
paths.set(i, j, joint, k)
} else if joint == paths.dist.At(i, j) {
paths.add(i, j, k)
}
}
}
}

return paths
}

type distanceNode struct {
node graph.Node
dist float64
}

// priorityQueue implements a no-dec priority queue.
type priorityQueue []distanceNode

func (q priorityQueue) Len() int { return len(q) }
func (q priorityQueue) Less(i, j int) bool { return q[i].dist < q[j].dist }
func (q priorityQueue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
func (q *priorityQueue) Push(n interface{}) { *q = append(*q, n.(distanceNode)) }
func (q *priorityQueue) Pop() interface{} {
t := *q
var n interface{}
n, *q = t[len(t)-1], t[:len(t)-1]
return n
}
154 changes: 154 additions & 0 deletions search/dijkstra_test.go
@@ -0,0 +1,154 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package search_test

import (
"math"
"reflect"
"sort"
"testing"

"github.com/gonum/graph"
"github.com/gonum/graph/internal"
"github.com/gonum/graph/search"
)

func TestDijkstraFrom(t *testing.T) {
for _, test := range positiveWeightTests {
g := test.g()
for _, e := range test.edges {
switch g := g.(type) {
case graph.MutableDirectedGraph:
g.AddDirectedEdge(e, e.Cost)
case graph.MutableGraph:
g.AddUndirectedEdge(e, e.Cost)
default:
panic("dijkstra: bad graph type")
}
}

pt := search.DijkstraFrom(test.query.From(), g.(graph.Graph), nil)

if pt.From().ID() != test.query.From().ID() {
t.Fatalf("%q: unexpected from node ID: got:%d want:%d", pt.From().ID(), test.query.From().ID())
}

p, weight := pt.To(test.query.To())
if weight != test.weight {
t.Errorf("%q: unexpected weight from Between: got:%f want:%f",
test.name, weight, test.weight)
}
if weight := pt.WeightTo(test.query.To()); weight != test.weight {
t.Errorf("%q: unexpected weight from Weight: got:%f want:%f",
test.name, weight, test.weight)
}

var got []int
for _, n := range p {
got = append(got, n.ID())
}
ok := len(got) == 0 && len(test.want) == 0
for _, sp := range test.want {
if reflect.DeepEqual(got, sp) {
ok = true
break
}
}
if !ok {
t.Errorf("%q: unexpected shortest path:\ngot: %v\nwant from:%v",
test.name, p, test.want)
}

np, weight := pt.To(test.none.To())
if pt.From().ID() == test.none.From().ID() && (np != nil || !math.IsInf(weight, 1)) {
t.Errorf("%q: unexpected path:\ngot: path=%v weight=%f\nwant:path=<nil> weight=+Inf",
test.name, np, weight)
}
}
}

func TestDijkstraAllPaths(t *testing.T) {
for _, test := range positiveWeightTests {
g := test.g()
for _, e := range test.edges {
switch g := g.(type) {
case graph.MutableDirectedGraph:
g.AddDirectedEdge(e, e.Cost)
case graph.MutableGraph:
g.AddUndirectedEdge(e, e.Cost)
default:
panic("dijkstra: bad graph type")
}
}

pt := search.DijkstraAllPaths(g.(graph.Graph), nil)

// Check all random paths returned are OK.
for i := 0; i < 10; i++ {
p, weight, unique := pt.Between(test.query.From(), test.query.To())
if weight != test.weight {
t.Errorf("%q: unexpected weight from Between: got:%f want:%f",
test.name, weight, test.weight)
}
if weight := pt.Weight(test.query.From(), test.query.To()); weight != test.weight {
t.Errorf("%q: unexpected weight from Weight: got:%f want:%f",
test.name, weight, test.weight)
}
if unique != test.unique {
t.Errorf("%q: unexpected number of paths: got: unique=%t want: unique=%t",
test.name, unique, test.unique)
}

var got []int
for _, n := range p {
got = append(got, n.ID())
}
ok := len(got) == 0 && len(test.want) == 0
for _, sp := range test.want {
if reflect.DeepEqual(got, sp) {
ok = true
break
}
}
if !ok {
t.Errorf("%q: unexpected shortest path:\ngot: %v\nwant from:%v",
test.name, p, test.want)
}
}

np, weight, unique := pt.Between(test.none.From(), test.none.To())
if np != nil || !math.IsInf(weight, 1) || unique != false {
t.Errorf("%q: unexpected path:\ngot: path=%v weight=%f unique=%t\nwant:path=<nil> weight=+Inf unique=false",
test.name, np, weight, unique)
}

paths, weight := pt.AllBetween(test.query.From(), test.query.To())
if weight != test.weight {
t.Errorf("%q: unexpected weight from Between: got:%f want:%f",
test.name, weight, test.weight)
}

var got [][]int
if len(paths) != 0 {
got = make([][]int, len(paths))
}
for i, p := range paths {
for _, v := range p {
got[i] = append(got[i], v.ID())
}
}
sort.Sort(internal.BySliceValues(got))
if !reflect.DeepEqual(got, test.want) {
t.Errorf("testing %q: unexpected shortest paths:\ngot: %v\nwant:%v",
test.name, got, test.want)
}

nps, weight := pt.AllBetween(test.none.From(), test.none.To())
if nps != nil || !math.IsInf(weight, 1) {
t.Errorf("%q: unexpected path:\ngot: paths=%v weight=%f\nwant:path=<nil> weight=+Inf",
test.name, nps, weight)
}
}
}

0 comments on commit ac2edf6

Please sign in to comment.