Skip to content

Commit

Permalink
Merge pull request #643 from ejgallego/require_memo
Browse files Browse the repository at this point in the history
[memo] Add memo table for Require.
  • Loading branch information
ejgallego committed Jan 20, 2024
2 parents 1344b35 + 2fe6527 commit 1176337
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 81 deletions.
2 changes: 1 addition & 1 deletion fleche/doc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ module Util = struct

let print_stats () =
(if !Config.v.mem_stats then
let size = Memo.Interp.size () in
let size = Memo.all_size () in
Io.Log.trace "stats" (string_of_int size));
let stats = Stats.dump () in
Io.Log.trace "cache" (Stats.to_string stats);
Expand Down
218 changes: 139 additions & 79 deletions fleche/memo.ml
Original file line number Diff line number Diff line change
Expand Up @@ -66,44 +66,14 @@ module MemoTable = struct
end
end

module Interp = struct
(* Loc-independent command evalution and caching. *)
module VernacInput = struct
type t = Coq.State.t * Coq.Ast.t

(* This crutially relies on our ppx to ignore the CAst location *)
let equal (st1, v1) (st2, v2) =
if Coq.Ast.compare v1 v2 = 0 then
if Coq.State.compare st1 st2 = 0 then true else false
else false

let hash (st, v) = Hashtbl.hash (Coq.Ast.hash v, st)
end

type t = VernacInput.t

let input_info (st, v) =
Format.asprintf "stm: %d | st %d" (Coq.Ast.hash v) (Hashtbl.hash st)

module HC = MemoTable.Make (VernacInput)

module Result = struct
(* We store the location as to compute an offset for cached results *)
type t = Loc.t * (Coq.State.t, Loc.t) Coq.Protect.E.t
end

type cache = Result.t HC.t

let cache : cache ref = ref (HC.create 1000)

(* This is very expensive *)
let size () = Obj.reachable_words (Obj.magic cache)

let in_cache st stm =
let kind = CS.Kind.Hashing in
CS.record ~kind ~f:(HC.find_opt !cache) (st, stm)

(* XXX: Move elsewhere *)
(* XXX: Move elsewhere *)
module Loc_utils : sig
val adjust_offset :
stm_loc:Loc.t
-> cached_loc:Loc.t
-> ('a, Loc.t) Coq.Protect.E.t
-> ('a, Loc.t) Coq.Protect.E.t
end = struct
let loc_offset (l1 : Loc.t) (l2 : Loc.t) =
let line_offset = l2.line_nb - l1.line_nb in
let bol_offset = l2.bol_pos - l1.bol_pos in
Expand Down Expand Up @@ -138,69 +108,159 @@ module Interp = struct
let offset = loc_offset cached_loc stm_loc in
let f = loc_apply_offset offset in
Coq.Protect.E.map_loc ~f res
end

module type EvalType = sig
include Hashtbl.HashedType

type output

val eval : t -> (output, Loc.t) Coq.Protect.E.t
end

module SEval (E : EvalType) = struct
type t = E.t

module HC = MemoTable.Make (E)

let cache = HC.create 1000
let size () = Obj.reachable_words (Obj.magic cache)

let eval v =
match HC.find_opt cache v with
| None ->
let admitted_st = E.eval v in
HC.add_execution cache v admitted_st;
admitted_st
| Some admitted_st -> admitted_st
end

module type LocEvalType = sig
include EvalType

val loc_of_input : t -> Loc.t
val input_info : t -> string
end

module CEval (E : LocEvalType) = struct
type t = E.t

let eval (st, stm) : _ Stats.t =
let stm_loc = Coq.Ast.loc stm |> Option.get in
match in_cache st stm with
module HC = MemoTable.Make (E)

module Result = struct
(* We store the location as to compute an offset for cached results *)
type t = Loc.t * (E.output, Loc.t) Coq.Protect.E.t
end

type cache = Result.t HC.t

let cache : cache ref = ref (HC.create 1000)

(* This is very expensive *)
let size () = Obj.reachable_words (Obj.magic cache)
let input_info = E.input_info

let in_cache i =
let kind = CS.Kind.Hashing in
CS.record ~kind ~f:(HC.find_opt !cache) i

let eval i : _ Stats.t =
let stm_loc = E.loc_of_input i in
match in_cache i with
| Some (cached_loc, res), time ->
if Debug.cache then Io.Log.trace "memo" "cache hit";
CacheStats.hit ();
let res = adjust_offset ~stm_loc ~cached_loc res in
let res = Loc_utils.adjust_offset ~stm_loc ~cached_loc res in
Stats.make ~cache_hit:true ~time res
| None, time_hash ->
if Debug.cache then Io.Log.trace "memo" "cache miss";
CacheStats.miss ();
let kind = CS.Kind.Exec in
let res, time_interp = CS.record ~kind ~f:(Coq.Interp.interp ~st) stm in
let () = HC.add_execution_loc !cache (st, stm) (stm_loc, res) in
let res, time_interp = CS.record ~kind ~f:E.eval i in
let () = HC.add_execution_loc !cache i (stm_loc, res) in
let time = time_hash +. time_interp in
Stats.make ~time res
end

module Admit = struct
type t = Coq.State.t
module VernacEval = struct
type t = Coq.State.t * Coq.Ast.t

module C = MemoTable.Make (Coq.State)
(* This crutially relies on our ppx to ignore the CAst location *)
let equal (st1, v1) (st2, v2) =
if Coq.Ast.compare v1 v2 = 0 then
if Coq.State.compare st1 st2 = 0 then true else false
else false

let cache = C.create 1000
let hash (st, v) = Hashtbl.hash (Coq.Ast.hash v, Coq.State.hash st)
let loc_of_input (_, stm) = Coq.Ast.loc stm |> Option.get

let eval v =
match C.find_opt cache v with
| None ->
let admitted_st = Coq.State.admit ~st:v in
C.add_execution cache v admitted_st;
admitted_st
| Some admitted_st -> admitted_st
let input_info (st, v) =
Format.asprintf "stm: %d | st %d" (Coq.Ast.hash v) (Hashtbl.hash st)

type output = Coq.State.t

let eval (st, stm) = Coq.Interp.interp ~st stm
end

module Init = struct
module S = struct
type t = Coq.State.t * Coq.Workspace.t * Lang.LUri.File.t
module Interp = CEval (VernacEval)

let equal (s1, w1, u1) (s2, w2, u2) : bool =
if Lang.LUri.File.compare u1 u2 = 0 then
if Coq.Workspace.compare w1 w2 = 0 then
if Coq.State.compare s1 s2 = 0 then true else false
else false
else false
module RequireEval = struct
type t = Coq.State.t * Coq.Files.t * Coq.Ast.Require.t

let hash (st, w, uri) =
Hashtbl.hash
(Coq.State.hash st, Coq.Workspace.hash w, Lang.LUri.File.hash uri)
end
(* This crutially relies on our ppx to ignore the CAst location *)
let equal (st1, f1, r1) (st2, f2, r2) =
if
Coq.Ast.Require.compare r1 r2 = 0
&& Coq.Files.compare f1 f2 = 0
&& Coq.State.compare st1 st2 = 0
then true
else false

type t = S.t
let hash (st, f, v) =
Hashtbl.hash (Coq.Ast.Require.hash v, Coq.Files.hash f, Coq.State.hash st)

module C = MemoTable.Make (S)
let input_info (st, f, v) =
Format.asprintf "stm: %d | st %d | f %d" (Coq.Ast.Require.hash v)
(Hashtbl.hash st) (Coq.Files.hash f)

let cache = C.create 1000
let loc_of_input (_, _, stm) = Option.get stm.Coq.Ast.Require.loc

let eval v =
match C.find_opt cache v with
| None ->
let root_state, workspace, uri = v in
let admitted_st = Coq.Init.doc_init ~root_state ~workspace ~uri in
C.add_execution cache v admitted_st;
admitted_st
| Some res -> res
type output = Coq.State.t

let eval (st, files, stm) = Coq.Interp.Require.interp ~st files stm
end

module Require = CEval (RequireEval)

module Admit = SEval (struct
include Coq.State

type output = Coq.State.t

let eval st = Coq.State.admit ~st
end)

module InitEval = struct
type t = Coq.State.t * Coq.Workspace.t * Lang.LUri.File.t

let equal (s1, w1, u1) (s2, w2, u2) : bool =
if Lang.LUri.File.compare u1 u2 = 0 then
if Coq.Workspace.compare w1 w2 = 0 then
if Coq.State.compare s1 s2 = 0 then true else false
else false
else false

let hash (st, w, uri) =
Hashtbl.hash
(Coq.State.hash st, Coq.Workspace.hash w, Lang.LUri.File.hash uri)

type output = Coq.State.t

let eval (root_state, workspace, uri) =
Coq.Init.doc_init ~root_state ~workspace ~uri
end

module Init = SEval (InitEval)

let all_size () =
Init.size () + Interp.size () + Require.size () + Admit.size ()
21 changes: 21 additions & 0 deletions fleche/memo.mli
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ end
module Init : sig
type t = Coq.State.t * Coq.Workspace.t * Lang.LUri.File.t

(** [size ()] Return the size in words, expensive *)
val size : unit -> int

val eval : t -> (Coq.State.t, Loc.t) Coq.Protect.E.t
end

Expand All @@ -26,9 +29,25 @@ module Interp : sig
val input_info : t -> string
end

module Require : sig
type t = Coq.State.t * Coq.Files.t * Coq.Ast.Require.t

(** Interpret a require, possibly memoizing it *)
val eval : t -> (Coq.State.t, Loc.t) Coq.Protect.E.t Stats.t

(** [size ()] Return the size in words, expensive *)
val size : unit -> int

(** debug *)
val input_info : t -> string
end

module Admit : sig
type t = Coq.State.t

(** [size ()] Return the size in words, expensive *)
val size : unit -> int

val eval : t -> (Coq.State.t, Loc.t) Coq.Protect.E.t
end

Expand All @@ -38,3 +57,5 @@ module CacheStats : sig
(** Returns the hit ratio of the cache *)
val stats : unit -> string
end

val all_size : unit -> int
2 changes: 1 addition & 1 deletion fleche/perf_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ let make (doc : Doc.t) =
let n_stm = List.length doc.nodes in
let stats = get_stats ~doc in
let cache_size =
if display_cache_size then Memo.Interp.size () |> float_of_int else 0.0
if display_cache_size then Memo.all_size () |> float_of_int else 0.0
in
let summary =
Format.asprintf "{ num sentences: %d@\n; stats: %s; cache: %a@\n}" n_stm
Expand Down

0 comments on commit 1176337

Please sign in to comment.