diff --git a/alg.go b/alg.go index eb80e9f..5bc333c 100644 --- a/alg.go +++ b/alg.go @@ -14,6 +14,10 @@ package lpg +import ( + "context" +) + // Sources finds all the source nodes in the graph func SourcesItr(graph *Graph) NodeIterator { return nodeIterator{ @@ -73,9 +77,13 @@ func EdgesBetweenNodes(from, to *Node) []*Edge { // nodeEquivalenceFunction will be called for all pairs of nodes. The // edgeEquivalenceFunction will be called for edges connecting // equivalent nodes. -func CheckIsomorphism(g1, g2 *Graph, nodeEquivalenceFunc func(n1, n2 *Node) bool, edgeEquivalenceFunc func(e1, e2 *Edge) bool) bool { +// +// This is a potentially long running function. Cancel the context to +// stop. If the function returns because of context cancellation, +// error will be ctx.Err() +func CheckIsomorphism(ctx context.Context, g1, g2 *Graph, nodeEquivalenceFunc func(n1, n2 *Node) bool, edgeEquivalenceFunc func(e1, e2 *Edge) bool) (bool, error) { if g1.NumNodes() != g2.NumNodes() || g1.NumEdges() != g2.NumEdges() { - return false + return false, nil } // Slice of all nodes of g1 @@ -92,7 +100,10 @@ func CheckIsomorphism(g1, g2 *Graph, nodeEquivalenceFunc func(n1, n2 *Node) bool } } if len(equivalences[i]) == 0 { - return false + return false, nil + } + if err := ctx.Err(); err != nil { + return false, err } } @@ -153,13 +164,16 @@ func CheckIsomorphism(g1, g2 *Graph, nodeEquivalenceFunc func(n1, n2 *Node) bool for { nodeMapping := buildNodeEquivalence() if isIsomorphism(nodeMapping) { - return true + return true, nil } if !next() { break } + if err := ctx.Err(); err != nil { + return false, err + } } - return false + return false, nil } // ForEachNode iterates through all the nodes of g until predicate diff --git a/clone_test.go b/clone_test.go index b8572fa..07fb552 100644 --- a/clone_test.go +++ b/clone_test.go @@ -15,6 +15,7 @@ package lpg import ( + "context" "reflect" "testing" ) @@ -52,14 +53,14 @@ func TestClone(t *testing.T) { return value }) - if !CheckIsomorphism(source, target, func(n1, n2 *Node) bool { + if ok, _ := CheckIsomorphism(context.Background(), source, target, func(n1, n2 *Node) bool { result := n1.GetLabels().HasAll(n2.GetLabels().Slice()...) && reflect.DeepEqual(n1.properties, n2.properties) return result }, func(e1, e2 *Edge) bool { result := e1.label == e2.label && reflect.DeepEqual(e1.properties, e2.properties) return result - }) { + }); !ok { t.Errorf("Clone result not isomorphic") } }