Skip to content
This repository has been archived by the owner on Apr 26, 2019. It is now read-only.

Commit

Permalink
search: add Johnson all pair shortest path function
Browse files Browse the repository at this point in the history
  • Loading branch information
kortschak committed Jun 8, 2015
1 parent ecf7a0f commit 2babe18
Show file tree
Hide file tree
Showing 2 changed files with 258 additions and 0 deletions.
149 changes: 149 additions & 0 deletions search/johnson_apsp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package search

import (
"math"
"math/rand"

"github.com/gonum/graph"
"github.com/gonum/graph/concrete"
)

// JohnsonAllPaths returns a shortest-path tree for shortest paths in the graph g.
// If weight is nil and the graph does not implement graph.Coster, UniformCost is used.
func JohnsonAllPaths(g graph.Graph, weight graph.CostFunc) (paths ShortestPaths, ok bool) {
jg := johnsonWeightAdjuster{
g: g,
from: g.Neighbors,
to: g.Neighbors,
weight: weight,
}
switch g := g.(type) {
case graph.DirectedGraph:
jg.from = g.Successors
jg.to = g.Predecessors
jg.edgeTo = g.EdgeTo
default:
jg.edgeTo = g.EdgeBetween
}
if jg.weight == nil {
if g, ok := g.(graph.Coster); ok {
jg.weight = g.Cost
} else {
jg.weight = UniformCost
}
}

nodes := g.NodeList()
indexOf := make(map[int]int, len(nodes))
for i, n := range nodes {
indexOf[n.ID()] = i
}
sign := -1
for {
// Choose a random node ID until we find
// one that is not in g.
jg.q = sign * rand.Int()
if _, exists := indexOf[jg.q]; !exists {
break
}
sign *= -1
}

jg.bellmanFord = true
jg.adjustBy, ok = BellmanFordFrom(johnsonGraphNode(jg.q), jg, nil)
if !ok {
return paths, false
}

jg.bellmanFord = false
paths = DijkstraAllPaths(jg, nil)

for i, u := range paths.nodes {
hu := jg.adjustBy.WeightTo(u)
for j, v := range paths.nodes {
if i == j {
continue
}
hv := jg.adjustBy.WeightTo(v)
paths.dist.Set(i, j, paths.dist.At(i, j)-hu+hv)
}
}

return paths, ok
}

type johnsonWeightAdjuster struct {
q int
g graph.Graph

from, to func(graph.Node) []graph.Node
edgeTo func(graph.Node, graph.Node) graph.Edge
weight graph.CostFunc

bellmanFord bool
adjustBy Shortest
}

var (
_ graph.DirectedGraph = johnsonWeightAdjuster{}
_ graph.Coster = johnsonWeightAdjuster{}
)

func (g johnsonWeightAdjuster) NodeExists(n graph.Node) bool {
if g.bellmanFord && n.ID() == g.q {
return true
}
return g.g.NodeExists(n)

}

func (g johnsonWeightAdjuster) NodeList() []graph.Node {
if g.bellmanFord {
return append(g.g.NodeList(), johnsonGraphNode(g.q))
}
return g.g.NodeList()
}

func (g johnsonWeightAdjuster) Neighbors(n graph.Node) []graph.Node {
panic("search: unintended use of johnsonWeightAdjuster")
}

func (g johnsonWeightAdjuster) EdgeBetween(u, v graph.Node) graph.Edge {
panic("search: unintended use of johnsonWeightAdjuster")
}

func (g johnsonWeightAdjuster) Successors(n graph.Node) []graph.Node {
if g.bellmanFord && n.ID() == g.q {
return g.g.NodeList()
}
return g.from(n)
}

func (g johnsonWeightAdjuster) Predecessors(n graph.Node) []graph.Node {
panic("search: unintended use of johnsonWeightAdjuster")
}

func (g johnsonWeightAdjuster) EdgeTo(u, v graph.Node) graph.Edge {
if g.bellmanFord && u.ID() == g.q && g.g.NodeExists(v) {
return concrete.Edge{johnsonGraphNode(g.q), v}
}
return g.edgeTo(u, v)
}

func (g johnsonWeightAdjuster) Cost(e graph.Edge) float64 {
if g.bellmanFord {
switch g.q {
case e.From().ID():
return 0
case e.To().ID():
return math.Inf(1)
default:
return g.weight(e)
}
}
return g.weight(e) + g.adjustBy.WeightTo(e.From()) - g.adjustBy.WeightTo(e.To())
}
109 changes: 109 additions & 0 deletions search/johnson_apsp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package search_test

import (
"math"
"reflect"
"sort"
"testing"

"github.com/gonum/graph"
"github.com/gonum/graph/internal"
"github.com/gonum/graph/search"
)

func TestJohnsonAllPaths(t *testing.T) {
for _, test := range shortestPathTests {
g := test.g()
for _, e := range test.edges {
switch g := g.(type) {
case graph.MutableDirectedGraph:
g.AddDirectedEdge(e, e.Cost)
case graph.MutableGraph:
g.AddUndirectedEdge(e, e.Cost)
default:
panic("johnson: bad graph type")
}
}

pt, ok := search.JohnsonAllPaths(g.(graph.Graph), nil)
if test.hasNegativeCycle {
if ok {
t.Errorf("%q: expected negative cycle", test.name)
}
continue
}
if !ok {
t.Fatalf("%q: unexpected negative cycle", test.name)
}

// Check all random paths returned are OK.
for i := 0; i < 10; i++ {
p, weight, unique := pt.Between(test.query.From(), test.query.To())
if weight != test.weight {
t.Errorf("%q: unexpected weight from Between: got:%f want:%f",
test.name, weight, test.weight)
}
if weight := pt.Weight(test.query.From(), test.query.To()); weight != test.weight {
t.Errorf("%q: unexpected weight from Weight: got:%f want:%f",
test.name, weight, test.weight)
}
if unique != test.unique {
t.Errorf("%q: unexpected number of paths: got: unique=%t want: unique=%t",
test.name, unique, test.unique)
}

var got []int
for _, n := range p {
got = append(got, n.ID())
}
ok := len(got) == 0 && len(test.want) == 0
for _, sp := range test.want {
if reflect.DeepEqual(got, sp) {
ok = true
break
}
}
if !ok {
t.Errorf("%q: unexpected shortest path:\ngot: %v\nwant from:%v",
test.name, p, test.want)
}
}

np, weight, unique := pt.Between(test.none.From(), test.none.To())
if np != nil || !math.IsInf(weight, 1) || unique != false {
t.Errorf("%q: unexpected path:\ngot: path=%v weight=%f unique=%t\nwant:path=<nil> weight=+Inf unique=false",
test.name, np, weight, unique)
}

paths, weight := pt.AllBetween(test.query.From(), test.query.To())
if weight != test.weight {
t.Errorf("%q: unexpected weight from Between: got:%f want:%f",
test.name, weight, test.weight)
}

var got [][]int
if len(paths) != 0 {
got = make([][]int, len(paths))
}
for i, p := range paths {
for _, v := range p {
got[i] = append(got[i], v.ID())
}
}
sort.Sort(internal.BySliceValues(got))
if !reflect.DeepEqual(got, test.want) {
t.Errorf("testing %q: unexpected shortest paths:\ngot: %v\nwant:%v",
test.name, got, test.want)
}

nps, weight := pt.AllBetween(test.none.From(), test.none.To())
if nps != nil || !math.IsInf(weight, 1) {
t.Errorf("%q: unexpected path:\ngot: paths=%v weight=%f\nwant:path=<nil> weight=+Inf",
test.name, nps, weight)
}
}
}

0 comments on commit 2babe18

Please sign in to comment.