diff --git a/graph/path/a_star.go b/graph/path/a_star.go index a91da5f61..8fb8e1ebd 100644 --- a/graph/path/a_star.go +++ b/graph/path/a_star.go @@ -21,14 +21,14 @@ import ( // the path from n to t is less than or equal to the true cost of that path. // // If h is nil, AStar will use the g.HeuristicCost method if g implements HeuristicCoster, -// falling back to NullHeuristic otherwise. If the graph does not implement graph.Weighter, +// falling back to NullHeuristic otherwise. If the graph does not implement Weighted, // UniformCost is used. AStar will panic if g has an A*-reachable negative edge weight. func AStar(s, t graph.Node, g graph.Graph, h Heuristic) (path Shortest, expanded int) { if !g.Has(s.ID()) || !g.Has(t.ID()) { return Shortest{from: s}, 0 } var weight Weighting - if wg, ok := g.(graph.Weighted); ok { + if wg, ok := g.(Weighted); ok { weight = wg.Weight } else { weight = UniformCost(g) diff --git a/graph/path/bellman_ford_moore.go b/graph/path/bellman_ford_moore.go index 9786175b0..45d637391 100644 --- a/graph/path/bellman_ford_moore.go +++ b/graph/path/bellman_ford_moore.go @@ -8,7 +8,7 @@ import "gonum.org/v1/gonum/graph" // BellmanFordFrom returns a shortest-path tree for a shortest path from u to all nodes in // the graph g, or false indicating that a negative cycle exists in the graph. If the graph -// does not implement graph.Weighter, UniformCost is used. +// does not implement Weighted, UniformCost is used. // // The time complexity of BellmanFordFrom is O(|V|.|E|). func BellmanFordFrom(u graph.Node, g graph.Graph) (path Shortest, ok bool) { @@ -16,7 +16,7 @@ func BellmanFordFrom(u graph.Node, g graph.Graph) (path Shortest, ok bool) { return Shortest{from: u}, true } var weight Weighting - if wg, ok := g.(graph.Weighted); ok { + if wg, ok := g.(Weighted); ok { weight = wg.Weight } else { weight = UniformCost(g) diff --git a/graph/path/dijkstra.go b/graph/path/dijkstra.go index 6a8a37ebe..dcacd6b02 100644 --- a/graph/path/dijkstra.go +++ b/graph/path/dijkstra.go @@ -8,27 +8,38 @@ import ( "container/heap" "gonum.org/v1/gonum/graph" + "gonum.org/v1/gonum/graph/traverse" ) // DijkstraFrom returns a shortest-path tree for a shortest path from u to all nodes in -// the graph g. If the graph does not implement graph.Weighter, UniformCost is used. +// the graph g. If the graph does not implement Weighted, UniformCost is used. // DijkstraFrom will panic if g has a u-reachable negative edge weight. // +// If g is a graph.Graph, all nodes of the graph will be stored in the shortest-path +// tree, otherwise only nodes reachable from u will be stored. +// // The time complexity of DijkstrFrom is O(|E|.log|V|). -func DijkstraFrom(u graph.Node, g graph.Graph) Shortest { - if !g.Has(u.ID()) { - return Shortest{from: u} +func DijkstraFrom(u graph.Node, g traverse.Graph) Shortest { + var path Shortest + if h, ok := g.(graph.Graph); ok { + if !h.Has(u.ID()) { + return Shortest{from: u} + } + path = newShortestFrom(u, h.Nodes()) + } else { + if g.From(u.ID()) == nil { + return Shortest{from: u} + } + path = newShortestFrom(u, []graph.Node{u}) } + var weight Weighting - if wg, ok := g.(graph.Weighted); ok { + if wg, ok := g.(Weighted); ok { weight = wg.Weight } else { weight = UniformCost(g) } - nodes := g.Nodes() - path := newShortestFrom(u, nodes) - // Dijkstra's algorithm here is implemented essentially as // described in Function B.2 in figure 6 of UTCS Technical // Report TR-07-54. @@ -49,7 +60,10 @@ func DijkstraFrom(u graph.Node, g graph.Graph) Shortest { mnid := mid.node.ID() for _, v := range g.From(mnid) { vid := v.ID() - j := path.indexOf[vid] + j, ok := path.indexOf[vid] + if !ok { + j = path.add(v) + } w, ok := weight(mnid, vid) if !ok { panic("dijkstra: unexpected invalid weight") diff --git a/graph/path/dijkstra_test.go b/graph/path/dijkstra_test.go index fe6450e21..e695a1233 100644 --- a/graph/path/dijkstra_test.go +++ b/graph/path/dijkstra_test.go @@ -13,6 +13,7 @@ import ( "gonum.org/v1/gonum/graph" "gonum.org/v1/gonum/graph/internal/ordered" "gonum.org/v1/gonum/graph/path/internal/testgraphs" + "gonum.org/v1/gonum/graph/traverse" ) func TestDijkstraFrom(t *testing.T) { @@ -22,65 +23,82 @@ func TestDijkstraFrom(t *testing.T) { g.SetWeightedEdge(e) } - var ( - pt Shortest - - panicked bool - ) - func() { - defer func() { - panicked = recover() != nil + for _, tg := range []struct { + typ string + g traverse.Graph + }{ + {"complete", g.(graph.Graph)}, + {"incremental", incremental{g.(graph.Weighted)}}, + } { + var ( + pt Shortest + + panicked bool + ) + func() { + defer func() { + panicked = recover() != nil + }() + pt = DijkstraFrom(test.Query.From(), tg.g) }() - pt = DijkstraFrom(test.Query.From(), g.(graph.Graph)) - }() - if panicked || test.HasNegativeWeight { - if !test.HasNegativeWeight { - t.Errorf("%q: unexpected panic", test.Name) - } - if !panicked { - t.Errorf("%q: expected panic for negative edge weight", test.Name) + if panicked || test.HasNegativeWeight { + if !test.HasNegativeWeight { + t.Errorf("%q %s: unexpected panic", test.Name, tg.typ) + } + if !panicked { + t.Errorf("%q %s: expected panic for negative edge weight", test.Name, tg.typ) + } + continue } - continue - } - if pt.From().ID() != test.Query.From().ID() { - t.Fatalf("%q: unexpected from node ID: got:%d want:%d", test.Name, pt.From().ID(), test.Query.From().ID()) - } + if pt.From().ID() != test.Query.From().ID() { + t.Fatalf("%q %s: unexpected from node ID: got:%d want:%d", test.Name, tg.typ, pt.From().ID(), test.Query.From().ID()) + } - p, weight := pt.To(test.Query.To().ID()) - if weight != test.Weight { - t.Errorf("%q: unexpected weight from Between: got:%f want:%f", - test.Name, weight, test.Weight) - } - if weight := pt.WeightTo(test.Query.To().ID()); weight != test.Weight { - t.Errorf("%q: unexpected weight from Weight: got:%f want:%f", - test.Name, weight, test.Weight) - } + p, weight := pt.To(test.Query.To().ID()) + if weight != test.Weight { + t.Errorf("%q %s: unexpected weight from Between: got:%f want:%f", + test.Name, tg.typ, weight, test.Weight) + } + if weight := pt.WeightTo(test.Query.To().ID()); weight != test.Weight { + t.Errorf("%q %s: unexpected weight from Weight: got:%f want:%f", + test.Name, tg.typ, weight, test.Weight) + } - var got []int64 - for _, n := range p { - got = append(got, n.ID()) - } - ok := len(got) == 0 && len(test.WantPaths) == 0 - for _, sp := range test.WantPaths { - if reflect.DeepEqual(got, sp) { - ok = true - break + var got []int64 + for _, n := range p { + got = append(got, n.ID()) + } + ok := len(got) == 0 && len(test.WantPaths) == 0 + for _, sp := range test.WantPaths { + if reflect.DeepEqual(got, sp) { + ok = true + break + } + } + if !ok { + t.Errorf("%q %s: unexpected shortest path:\ngot: %v\nwant from:%v", + test.Name, tg.typ, p, test.WantPaths) } - } - if !ok { - t.Errorf("%q: unexpected shortest path:\ngot: %v\nwant from:%v", - test.Name, p, test.WantPaths) - } - np, weight := pt.To(test.NoPathFor.To().ID()) - if pt.From().ID() == test.NoPathFor.From().ID() && (np != nil || !math.IsInf(weight, 1)) { - t.Errorf("%q: unexpected path:\ngot: path=%v weight=%f\nwant:path= weight=+Inf", - test.Name, np, weight) + np, weight := pt.To(test.NoPathFor.To().ID()) + if pt.From().ID() == test.NoPathFor.From().ID() && (np != nil || !math.IsInf(weight, 1)) { + t.Errorf("%q %s: unexpected path:\ngot: path=%v weight=%f\nwant:path= weight=+Inf", + test.Name, tg.typ, np, weight) + } } } } +type weightedTraverseGraph interface { + traverse.Graph + Weighted +} + +type incremental struct { + weightedTraverseGraph +} + func TestDijkstraAllPaths(t *testing.T) { for _, test := range testgraphs.ShortestPathTests { g := test.Graph() diff --git a/graph/path/floydwarshall.go b/graph/path/floydwarshall.go index db9d91486..8b41c79c5 100644 --- a/graph/path/floydwarshall.go +++ b/graph/path/floydwarshall.go @@ -8,12 +8,12 @@ import "gonum.org/v1/gonum/graph" // FloydWarshall returns a shortest-path tree for the graph g or false indicating // that a negative cycle exists in the graph. If the graph does not implement -// graph.Weighter, UniformCost is used. +// Weighted, UniformCost is used. // // The time complexity of FloydWarshall is O(|V|^3). func FloydWarshall(g graph.Graph) (paths AllShortest, ok bool) { var weight Weighting - if wg, ok := g.(graph.Weighted); ok { + if wg, ok := g.(Weighted); ok { weight = wg.Weight } else { weight = UniformCost(g) diff --git a/graph/path/johnson_apsp.go b/graph/path/johnson_apsp.go index ccabbbeff..557c2310f 100644 --- a/graph/path/johnson_apsp.go +++ b/graph/path/johnson_apsp.go @@ -14,7 +14,7 @@ import ( ) // JohnsonAllPaths returns a shortest-path tree for shortest paths in the graph g. -// If the graph does not implement graph.Weighter, UniformCost is used. +// If the graph does not implement Weighted, UniformCost is used. // // The time complexity of JohnsonAllPaths is O(|V|.|E|+|V|^2.log|V|). func JohnsonAllPaths(g graph.Graph) (paths AllShortest, ok bool) { @@ -23,7 +23,7 @@ func JohnsonAllPaths(g graph.Graph) (paths AllShortest, ok bool) { from: g.From, edgeTo: g.Edge, } - if wg, ok := g.(graph.Weighted); ok { + if wg, ok := g.(Weighted); ok { jg.weight = wg.Weight } else { jg.weight = UniformCost(g) diff --git a/graph/path/shortest.go b/graph/path/shortest.go index c2d120d44..2a37f37f7 100644 --- a/graph/path/shortest.go +++ b/graph/path/shortest.go @@ -83,6 +83,22 @@ func newShortestFrom(u graph.Node, nodes []graph.Node) Shortest { return p } +// add adds a node to the Shortest, initialising its stored index and returning, and +// setting the distance and position as unconnected. add will panic if the node is +// already present. +func (p *Shortest) add(u graph.Node) int { + uid := u.ID() + if _, exists := p.indexOf[uid]; exists { + panic("shortest: adding existing node") + } + idx := len(p.nodes) + p.indexOf[uid] = idx + p.nodes = append(p.nodes, u) + p.dist = append(p.dist, math.Inf(1)) + p.next = append(p.next, -1) + return idx +} + func (p Shortest) set(to int, weight float64, mid int) { p.dist[to] = weight p.next[to] = mid diff --git a/graph/path/weight.go b/graph/path/weight.go index d5e416e9b..625dda494 100644 --- a/graph/path/weight.go +++ b/graph/path/weight.go @@ -8,15 +8,30 @@ import ( "math" "gonum.org/v1/gonum/graph" + "gonum.org/v1/gonum/graph/traverse" ) +// Weighted is a weighted graph. It is a subset of graph.Weighted. +type Weighted interface { + // Weight returns the weight for the edge between + // x and y with IDs xid and yid if Edge(xid, yid) + // returns a non-nil Edge. + // If x and y are the same node or there is no + // joining edge between the two nodes the weight + // value returned is implementation dependent. + // Weight returns true if an edge exists between + // x and y or if x and y have the same ID, false + // otherwise. + Weight(xid, yid int64) (w float64, ok bool) +} + // Weighting is a mapping between a pair of nodes and a weight. It follows the // semantics of the Weighter interface. type Weighting func(xid, yid int64) (w float64, ok bool) // UniformCost returns a Weighting that returns an edge cost of 1 for existing // edges, zero for node identity and Inf for otherwise absent edges. -func UniformCost(g graph.Graph) Weighting { +func UniformCost(g traverse.Graph) Weighting { return func(xid, yid int64) (w float64, ok bool) { if xid == yid { return 0, true