diff --git a/lib/Frenetic_NetKAT_Equivalence.ml b/lib/Frenetic_NetKAT_Equivalence.ml index c3c9bdcb1..d8bb069d3 100644 --- a/lib/Frenetic_NetKAT_Equivalence.ml +++ b/lib/Frenetic_NetKAT_Equivalence.ml @@ -1,6 +1,6 @@ open Core.Std -module A = Frenetic_NetKAT_Compiler.Automaton +module Automaton = Frenetic_NetKAT_Compiler.Automaton module FDD = Frenetic_NetKAT_Compiler.FDD type state = FDD.t * FDD.t @@ -11,12 +11,20 @@ module type UPTO = sig end module Upto_Sym () : UPTO = struct - (* FIXME: avoid polymorphic hash/max/min *) + (* FIXME: avoid polymorphic hash/max/min/equality *) let cache = Hash_set.Poly.create () - let equiv s1 s2 = Hash_set.mem cache (min s1 s2, max s1 s2) + let equiv s1 s2 = (s1 = s2) || Hash_set.mem cache (min s1 s2, max s1 s2) let add_equiv s1 s2 = Hash_set.add cache (min s1 s2, max s1 s2) end +module Upto_Trans () : UPTO = struct + (* FIXME: avoid polymorphic hash/max/min/equality *) + let cache = Hashtbl.Poly.create () + let find = Hashtbl.find_or_add cache ~default:Union_find.create + let equiv s1 s2 = (s1 = s2) || Union_find.same_class (find s1) (find s2) + let add_equiv s1 s2 = Union_find.union (find s1) (find s2) +end + module Make_Naive(Upto : UPTO) = struct module SymPkt = struct @@ -41,7 +49,7 @@ module Make_Naive(Upto : UPTO) = struct - let equiv ?(pk=SymPkt.all) (a1 : A.t) (a2 : A.t) = + let equiv ?(pk=SymPkt.all) (a1 : Automaton.t) (a2 : Automaton.t) = let rec eq_states pk (s1 : int) (s2 : int) = let mask = SymPkt.to_alist pk in