From 2babe18bd4e9576951e502e7cde2c7bf19b9ed2c Mon Sep 17 00:00:00 2001 From: kortschak Date: Fri, 5 Jun 2015 10:13:26 +0930 Subject: [PATCH] search: add Johnson all pair shortest path function --- search/johnson_apsp.go | 149 ++++++++++++++++++++++++++++++++++++ search/johnson_apsp_test.go | 109 ++++++++++++++++++++++++++ 2 files changed, 258 insertions(+) create mode 100644 search/johnson_apsp.go create mode 100644 search/johnson_apsp_test.go diff --git a/search/johnson_apsp.go b/search/johnson_apsp.go new file mode 100644 index 00000000..63bdb666 --- /dev/null +++ b/search/johnson_apsp.go @@ -0,0 +1,149 @@ +// Copyright ©2015 The gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package search + +import ( + "math" + "math/rand" + + "github.com/gonum/graph" + "github.com/gonum/graph/concrete" +) + +// JohnsonAllPaths returns a shortest-path tree for shortest paths in the graph g. +// If weight is nil and the graph does not implement graph.Coster, UniformCost is used. +func JohnsonAllPaths(g graph.Graph, weight graph.CostFunc) (paths ShortestPaths, ok bool) { + jg := johnsonWeightAdjuster{ + g: g, + from: g.Neighbors, + to: g.Neighbors, + weight: weight, + } + switch g := g.(type) { + case graph.DirectedGraph: + jg.from = g.Successors + jg.to = g.Predecessors + jg.edgeTo = g.EdgeTo + default: + jg.edgeTo = g.EdgeBetween + } + if jg.weight == nil { + if g, ok := g.(graph.Coster); ok { + jg.weight = g.Cost + } else { + jg.weight = UniformCost + } + } + + nodes := g.NodeList() + indexOf := make(map[int]int, len(nodes)) + for i, n := range nodes { + indexOf[n.ID()] = i + } + sign := -1 + for { + // Choose a random node ID until we find + // one that is not in g. + jg.q = sign * rand.Int() + if _, exists := indexOf[jg.q]; !exists { + break + } + sign *= -1 + } + + jg.bellmanFord = true + jg.adjustBy, ok = BellmanFordFrom(johnsonGraphNode(jg.q), jg, nil) + if !ok { + return paths, false + } + + jg.bellmanFord = false + paths = DijkstraAllPaths(jg, nil) + + for i, u := range paths.nodes { + hu := jg.adjustBy.WeightTo(u) + for j, v := range paths.nodes { + if i == j { + continue + } + hv := jg.adjustBy.WeightTo(v) + paths.dist.Set(i, j, paths.dist.At(i, j)-hu+hv) + } + } + + return paths, ok +} + +type johnsonWeightAdjuster struct { + q int + g graph.Graph + + from, to func(graph.Node) []graph.Node + edgeTo func(graph.Node, graph.Node) graph.Edge + weight graph.CostFunc + + bellmanFord bool + adjustBy Shortest +} + +var ( + _ graph.DirectedGraph = johnsonWeightAdjuster{} + _ graph.Coster = johnsonWeightAdjuster{} +) + +func (g johnsonWeightAdjuster) NodeExists(n graph.Node) bool { + if g.bellmanFord && n.ID() == g.q { + return true + } + return g.g.NodeExists(n) + +} + +func (g johnsonWeightAdjuster) NodeList() []graph.Node { + if g.bellmanFord { + return append(g.g.NodeList(), johnsonGraphNode(g.q)) + } + return g.g.NodeList() +} + +func (g johnsonWeightAdjuster) Neighbors(n graph.Node) []graph.Node { + panic("search: unintended use of johnsonWeightAdjuster") +} + +func (g johnsonWeightAdjuster) EdgeBetween(u, v graph.Node) graph.Edge { + panic("search: unintended use of johnsonWeightAdjuster") +} + +func (g johnsonWeightAdjuster) Successors(n graph.Node) []graph.Node { + if g.bellmanFord && n.ID() == g.q { + return g.g.NodeList() + } + return g.from(n) +} + +func (g johnsonWeightAdjuster) Predecessors(n graph.Node) []graph.Node { + panic("search: unintended use of johnsonWeightAdjuster") +} + +func (g johnsonWeightAdjuster) EdgeTo(u, v graph.Node) graph.Edge { + if g.bellmanFord && u.ID() == g.q && g.g.NodeExists(v) { + return concrete.Edge{johnsonGraphNode(g.q), v} + } + return g.edgeTo(u, v) +} + +func (g johnsonWeightAdjuster) Cost(e graph.Edge) float64 { + if g.bellmanFord { + switch g.q { + case e.From().ID(): + return 0 + case e.To().ID(): + return math.Inf(1) + default: + return g.weight(e) + } + } + return g.weight(e) + g.adjustBy.WeightTo(e.From()) - g.adjustBy.WeightTo(e.To()) +} diff --git a/search/johnson_apsp_test.go b/search/johnson_apsp_test.go new file mode 100644 index 00000000..9a6d912a --- /dev/null +++ b/search/johnson_apsp_test.go @@ -0,0 +1,109 @@ +// Copyright ©2015 The gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package search_test + +import ( + "math" + "reflect" + "sort" + "testing" + + "github.com/gonum/graph" + "github.com/gonum/graph/internal" + "github.com/gonum/graph/search" +) + +func TestJohnsonAllPaths(t *testing.T) { + for _, test := range shortestPathTests { + g := test.g() + for _, e := range test.edges { + switch g := g.(type) { + case graph.MutableDirectedGraph: + g.AddDirectedEdge(e, e.Cost) + case graph.MutableGraph: + g.AddUndirectedEdge(e, e.Cost) + default: + panic("johnson: bad graph type") + } + } + + pt, ok := search.JohnsonAllPaths(g.(graph.Graph), nil) + if test.hasNegativeCycle { + if ok { + t.Errorf("%q: expected negative cycle", test.name) + } + continue + } + if !ok { + t.Fatalf("%q: unexpected negative cycle", test.name) + } + + // Check all random paths returned are OK. + for i := 0; i < 10; i++ { + p, weight, unique := pt.Between(test.query.From(), test.query.To()) + if weight != test.weight { + t.Errorf("%q: unexpected weight from Between: got:%f want:%f", + test.name, weight, test.weight) + } + if weight := pt.Weight(test.query.From(), test.query.To()); weight != test.weight { + t.Errorf("%q: unexpected weight from Weight: got:%f want:%f", + test.name, weight, test.weight) + } + if unique != test.unique { + t.Errorf("%q: unexpected number of paths: got: unique=%t want: unique=%t", + test.name, unique, test.unique) + } + + var got []int + for _, n := range p { + got = append(got, n.ID()) + } + ok := len(got) == 0 && len(test.want) == 0 + for _, sp := range test.want { + if reflect.DeepEqual(got, sp) { + ok = true + break + } + } + if !ok { + t.Errorf("%q: unexpected shortest path:\ngot: %v\nwant from:%v", + test.name, p, test.want) + } + } + + np, weight, unique := pt.Between(test.none.From(), test.none.To()) + if np != nil || !math.IsInf(weight, 1) || unique != false { + t.Errorf("%q: unexpected path:\ngot: path=%v weight=%f unique=%t\nwant:path= weight=+Inf unique=false", + test.name, np, weight, unique) + } + + paths, weight := pt.AllBetween(test.query.From(), test.query.To()) + if weight != test.weight { + t.Errorf("%q: unexpected weight from Between: got:%f want:%f", + test.name, weight, test.weight) + } + + var got [][]int + if len(paths) != 0 { + got = make([][]int, len(paths)) + } + for i, p := range paths { + for _, v := range p { + got[i] = append(got[i], v.ID()) + } + } + sort.Sort(internal.BySliceValues(got)) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("testing %q: unexpected shortest paths:\ngot: %v\nwant:%v", + test.name, got, test.want) + } + + nps, weight := pt.AllBetween(test.none.From(), test.none.To()) + if nps != nil || !math.IsInf(weight, 1) { + t.Errorf("%q: unexpected path:\ngot: paths=%v weight=%f\nwant:path= weight=+Inf", + test.name, nps, weight) + } + } +}