Skip to content

Commit

Permalink
add a command for a fast feature extractor.
Browse files Browse the repository at this point in the history
The build_fast_feature_extractor command updates the feature extractor,
so that PaMpeR ignores irrelevant assertions when generating recommendations.

This feature is not tested yet.

If the fast feature extractor has not been built, PaMpeR uses the slow version instead,
while emitting a warning message.

I also cleaned the code base.
  • Loading branch information
Yutaka Nagashima committed Nov 20, 2017
1 parent 8abf6cd commit 769569e
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 52 deletions.
13 changes: 7 additions & 6 deletions PaMpeR/Assertions.ML
Expand Up @@ -154,23 +154,24 @@ end;
(*** ASSERTIONS ***)
signature ASSERTIONS =
sig
(*written assertions*)
val eval_assertion_for_ML: Proof.state -> bool list
val eval_assertion: Proof.state -> string
val eval_assertion_gen : Proof.state -> (thm -> Proof.context -> thm list -> bool) list -> bool list;
val eval_assertion_for_ML_real: Proof.state -> real list;
val eval_assertion_for_ML_bool: Proof.state -> bool list;
val eval_assertion: Proof.state -> string;
val assertions: (thm -> Proof.context -> thm list -> bool) list;
end;

(*** Assertions: Implementation of assertions ***)
structure Assertions(*:ASSERTIONS*) =
structure Assertions : ASSERTIONS =
struct

structure AU = Assert_Util;

infix 1 >>= >=> liftM;
infix 1 >>= liftM;

type context = Proof.context;

fun (m >>= f) = Option.mapPartial f m;
fun (f >=> g) = Option.composePartial (g f);
fun (m liftM f) = Option.map f m;

(** Assertions about the existence of certain objects (rules) in the proof state **)
Expand Down
24 changes: 12 additions & 12 deletions PaMpeR/Decision_Tree.thy
Expand Up @@ -43,7 +43,7 @@ sig
val print_final_tree: final_tree -> string;
val parse_printed_tree: string -> final_tree;
val lookup_exp: bool list -> final_tree -> real;
val used_features: final_tree -> feature_name list;
val used_features: final_tree list -> feature_name list;
end;
*}

Expand Down Expand Up @@ -114,10 +114,10 @@ fun split_database fname data = split_database' fname data ([],[])
Utils.debug_mssg false ("the number of right elements is " ^ Int.toString (length right)) ();
p));
fun get_RSS (fname as (Database.Feature fint):feature_name) (data:database) =
fun get_RSS (fint:feature_name) (data:database) =
let
val _ = Utils.debug_mssg false ("splitting at Feature " ^ Int.toString fint) ();
val (trues, falses) = split_database fname data : (database * database);
val (trues, falses) = split_database fint data : (database * database);
val (t_avrg, f_avrg) = apply2 get_avrg_of_database (trues, falses);
fun residual_square _ ([]:database) (accm:real) = accm
| residual_square average (datum::data:database) (accm:real) =
Expand Down Expand Up @@ -158,9 +158,9 @@ fun get_feat_with_mini_RSS' (_:database) (best_fname:feature_name, _:real)
fun get_feature_with_mini_RSS (data:database) =
let
val fnames = database_to_fname_list data: feature_name list;
val fname as (Database.Feature fint) = if length fnames > 0 then hd fnames else error "get_feature_with_mini_RSS failed!";
val fname = if length fnames > 0 then hd fnames else error "get_feature_with_mini_RSS failed!";
val rss = get_RSS fname data;
val _ = Utils.debug_mssg false ("for " ^ Int.toString fint ^ "th feature: rss is " ^ Real.toString rss) ();
val _ = Utils.debug_mssg false ("for " ^ Int.toString fname ^ "th feature: rss is " ^ Real.toString rss) ();
val fname = get_feat_with_mini_RSS' data (fname, rss) fnames;
val mini_rss = get_RSS fname data;
val result = if Real.== (Real.posInf, mini_rss) then NONE else SOME fname;
Expand Down Expand Up @@ -208,7 +208,7 @@ fun gtree_leaf_map (f:database -> real) (Leaf dtbs:growing_tree) = FLeaf (f dtbs
fun post_process (gtree:growing_tree) = gtree_leaf_map get_avrg_of_database gtree : final_tree;
fun print_feat ((Database.Feature f_index, _):feature) = Int.toString f_index;
fun print_feat ((f_index, _):feature) = Int.toString f_index;
fun print_final_tree (FLeaf real) = "expectation " ^ Real.toString real
| print_final_tree (FBranch {More = ftree1, Feature = feat, Less = ftree2}) =
Expand All @@ -234,7 +234,7 @@ and parse_fbranch _ =
token (parse_ftree ()) >>= (fn less_tree =>
token (symbol ")") >>= K (
result (FBranch {More = more_tree,
Feature = (Database.Feature feat_index, true),
Feature = (feat_index, true),
Less = less_tree})
)))))))
and parse_ftree _ = parse (parse_fleaf () plus parse_fbranch ());
Expand All @@ -247,22 +247,22 @@ type bools = bool list;
fun lookup_exp ([]:bools) _ = error "lookup_one in Decision_Tree failed! Empty list!"
| lookup_exp (_ :bools) (FLeaf expect) = expect
| lookup_exp (bs:bools) (FBranch {More, Feature as (Database.Feature i, _), Less}) =
| lookup_exp (bs:bools) (FBranch {More, Feature as (i, _), Less}) =
if nth bs (i - 1) (*because the numbering of assertions starts from 1.*)
then lookup_exp bs More
else lookup_exp bs Less;
type feature_names = feature_name list;
fun used_features (ftree:final_tree) =
fun used_features (ftrees:final_tree list) =
let
fun used_features' (FLeaf _) = []
| used_features' (FBranch {More, Feature as (Database.Feature i, _), Less}) =
| used_features' (FBranch {More, Feature as (i, _), Less}) =
i :: used_features' More @ used_features' Less;
in
used_features' ftree |> map Database.Feature
(* Collect every feature index used by any of the trees, each one exactly once.
   Isabelle's "duplicates" keeps only the elements that occur more than once,
   so it would silently drop every feature that is used exactly one time;
   "distinct" removes the repeats and keeps all used features. *)
map used_features' ftrees |> flat |> distinct (op =)
end;
end;
*}

ML{* List.tabulate (5, (fn i => i + 1)); *}
end
3 changes: 3 additions & 0 deletions PaMpeR/PaMpeR.thy
Expand Up @@ -30,6 +30,7 @@ keywords "which_method" :: diag
and "print_out_regression_trees" :: thy_decl
and "reset_regression_tree_table" :: thy_decl
and "read_regression_trees" :: thy_decl
and "build_fast_feature_extractor" :: thy_decl
begin

ML_file "./Assertions.ML"
Expand All @@ -40,6 +41,8 @@ reset_regression_tree_table

read_regression_trees

build_fast_feature_extractor

ML_file "./FE_Interface.ML"

end
104 changes: 78 additions & 26 deletions PaMpeR/PaMpeR_Interface.ML
Expand Up @@ -17,6 +17,10 @@ end;
structure PaMpeR_Interface:PAMPER_INTERFACE =
struct

infix 1 liftM;
fun (m liftM f) = Option.map f m;
structure RT = Regression_Tree;

val path = Resources.master_directory @{theory} |> File.platform_path : string;
val path_to_meth_names = path ^ "/method_names": string;
val path_to_rtree = path ^ "/regression_trees";
Expand All @@ -31,10 +35,41 @@ val all_method_names =
dist_meth_names : string list
end;

structure RT = Regression_Tree;
(* Database to store regression trees. *)
structure Regression_Trees = Generic_Data
(
type T = RT.final_tree Symtab.table;
val empty = Symtab.empty : T;
val extend = I;
val merge = Symtab.merge (K true);
);

(* build final trees and register them in a table *)
fun lookup ctxt = (Symtab.lookup o Regression_Trees.get) (Context.Proof ctxt);

fun update (k, v) = Regression_Trees.map (Symtab.update_new (k, v))
|> Context.theory_map
|> Local_Theory.background_theory;

val reset = Regression_Trees.map (fn _ => Symtab.empty)
|> Context.theory_map
|> Local_Theory.background_theory;

(* Database to store quick assertions.
   Context data holding a symbol table that maps a string key to a list of
   assertion functions (thm -> Proof.context -> thm list -> bool). *)
structure Dynamic_Feature_Extractor = Generic_Data
(
type T = (thm -> Proof.context -> thm list -> bool) list Symtab.table;
val empty = Symtab.empty : T;
val extend = I;
(* K true treats any two entries sharing a key as equal, so presumably merging
   two tables never reports a clash -- confirm against Symtab.merge. *)
val merge = Symtab.merge (K true);
);

(* Look up the assertion list stored under a key in the table of the given
   proof context; the result is an option (NONE when the key is absent). *)
fun dfe_lookup ctxt = (Symtab.lookup o Dynamic_Feature_Extractor.get) (Context.Proof ctxt);

(* Store value v under key k in the Dynamic_Feature_Extractor table of the
   background theory of a local theory.
   Symtab.update overwrites an existing entry, so the table can be rebuilt;
   Symtab.update_new would raise Symtab.DUP as soon as the same key is
   registered a second time. *)
fun dfe_update (k, v) = Dynamic_Feature_Extractor.map (Symtab.update (k, v))
|> Context.theory_map
|> Local_Theory.background_theory;

(* build final trees and register them in a table *)
fun build_ftree (meth_name:string) =
let
fun did_success sth = if is_some sth then " successfully " else " NOT really ";
Expand All @@ -55,30 +90,11 @@ fun build_final_trees (meth_names:string list) = Par_List.map build_ftree meth_n
|> filter is_some
|> map the;

structure Data = Generic_Data
(
type T = RT.final_tree Symtab.table;
val empty = Symtab.empty : T;
val extend = I;
val merge = Symtab.merge (K true);
);

fun lookup ctxt = (Symtab.lookup o Data.get) (Context.Proof ctxt);

fun update (k, v) = Data.map (Symtab.update_new (k, v))
|> Context.theory_map
|> Local_Theory.background_theory;

val reset = Data.map (fn _ => Symtab.empty)
|> Context.theory_map
|> Local_Theory.background_theory;

fun register_final_trees (lthy:local_theory) = fold update (build_final_trees all_method_names) lthy;

fun mk_parser func = fn (tkns:Token.T list) => (func, tkns);

(* print out regression trees in PaMpeR/regression_trees *)

fun print_out_ftree (ctxt:Proof.context) (meth_name:string) =
let
val final_tree = lookup ctxt meth_name: RT.final_tree option;
Expand All @@ -102,8 +118,6 @@ fun print_out_all_ftrees (meth_names:string list) (lthy:local_theory) =
(* read regression trees printed in PaMpeR/regression_trees *)
val read_regression_trees =
let
infix 1 liftM;
fun (m liftM f) = Option.map f m;
val lines = try TextIO.openIn path_to_rtree
liftM TextIO.inputAll
liftM split_lines
Expand All @@ -125,10 +139,41 @@ val read_regression_trees =
fun register_final_trees (lthy:local_theory) = fold (update o read_final_tree) lines lthy;
in register_final_trees end;

(* Build and register fast assertions.
   get_fast_assertions is private to this local block; only
   register_fast_assertions is visible after the closing "end". *)
local

(* Build a list with one entry per element of Assertions.assertions.
   Entries whose number is reported by RT.used_features hold the real assertion;
   every other entry is a placeholder that always returns true. *)
fun get_fast_assertions (lthy:local_theory) =
let
(* Trees found by "lookup" for the names in all_method_names; a name without a
   registered tree contributes nothing (the_list turns NONE into []). *)
fun get_ftrees ctxt = map (the_list o lookup ctxt) all_method_names |> flat: RT.final_tree list;
(* Feature numbers that occur in those trees. *)
val features = lthy |> get_ftrees |> RT.used_features: int list;
val leng = length Assertions.assertions: int;
(* One constant-true function per assertion, used as the starting point. *)
val dummys = List.tabulate (leng, (fn _ => fn _ => fn _ => fn _ => true (*dummy value*)));
(* Replace the entry for feature n with the real assertion for feature n.
   Feature numbers are 1-based here, hence the list index n - 1. *)
fun swap_nth (n:int) (xs:(thm -> Proof.context -> thm list -> bool) list) =
nth_map (n - 1) (fn _ => nth Assertions.assertions (n - 1)) xs;
(* Apply swap_nth once for every used feature. *)
val fast_asserts = fold swap_nth features dummys;
in
fast_asserts: (thm -> Proof.context -> thm list -> bool) list
end;

in

(* Compute the fast assertion list for lthy and store it with dfe_update under
   the key "fast_assetions".
   NOTE(review): the key is misspelled, but get_top_5 looks the list up with the
   very same string, so the spelling must be changed in both places or not at all. *)
fun register_fast_assertions lthy = dfe_update ("fast_assetions", (get_fast_assertions lthy)) lthy

end;

(* print out recommendation in the output panel. *)
local

fun get_top_5 (pstate:Proof.state) =
let
val ass_results = Assertions.eval_assertion_for_ML_bool pstate;
val ctxt = Proof.context_of pstate;
val fast_assertions = (flat o the_list oo dfe_lookup) ctxt "fast_assetions": (thm -> Proof.context -> thm list -> bool) list;
val _ = if null fast_assertions
then tracing "fast_assertions is empty. Did you forget to call build_fast_feature_extractor?"
else ();
val ass_results = if null fast_assertions
then Assertions.eval_assertion_for_ML_bool pstate
else Assertions.eval_assertion_gen pstate fast_assertions;
fun get_ftree (meth_name:string) = lookup ctxt meth_name: RT.final_tree option;
val get_exp = RT.lookup_exp ass_results;
fun get_top_result' (best_meth, best_exp) (meth_name::names:string list) =
Expand All @@ -149,7 +194,7 @@ fun get_top_5 (pstate:Proof.state) =
remove (op =) (fst top) meth_names
end;

(*TODO: better if I use Lazy Sequence here?*)
(*TODO: better if I use Lazy Sequence here?*)
val _ =
get_top_result all_method_names |> (fn (wo_1) =>
get_top_result wo_1 |> (fn (wo_2) =>
Expand All @@ -168,6 +213,7 @@ fun get_top_5 (pstate:Proof.state) =
get_top_result wo_14 |> (fn (wo_15) =>
get_top_result wo_15)))))))))))))))
in () end;
in

val which_method_cmd = Toplevel.keep_proof (fn Tstate =>
let
Expand All @@ -176,6 +222,8 @@ val which_method_cmd = Toplevel.keep_proof (fn Tstate =>
get_top_5 state
end);

end;

(* Register Isar commands. *)
fun PaMpeR_activate _ =
let
Expand All @@ -193,8 +241,12 @@ fun PaMpeR_activate _ =
"read regression trees in PaMpeR/regression_trees and register them from databases."
(mk_parser read_regression_trees);

val _ = Outer_Syntax.command @{command_keyword which_method} "Method recommender"
val _ = Outer_Syntax.command @{command_keyword which_method} "recommend which method to use"
(Scan.succeed which_method_cmd);

val _ = Outer_Syntax.local_theory @{command_keyword build_fast_feature_extractor}
"build a quick version of feature extractor by ignoring unrelevant assertions"
(mk_parser register_fast_assertions)
in
()
end;
Expand Down
15 changes: 7 additions & 8 deletions PaMpeR/Read_Databases.thy
Expand Up @@ -6,12 +6,12 @@ ML_file "../src/Utils.ML"

ML{* signature DATABASE =
sig
type used = bool;
datatype feature_name = Feature of int;
type feature_value = bool;
type meth_name = string;
type one_line = used * (feature_name * feature_value) list;
type database = one_line list;
type used = bool;
type feature_name = int;
type feature_value = bool;
type meth_name = string;
type one_line = used * (feature_name * feature_value) list;
type database = one_line list;
(*get_Database_names has to be called only *after* generating pre-processed databases.*)
val get_meth_names : unit -> meth_name list;
val parse_database : meth_name -> database;
Expand All @@ -23,7 +23,7 @@ ML{* structure Database:DATABASE =
struct
type used = bool;
datatype feature_name = Feature of int;
type feature_name = int;
type feature_value = bool;
type meth_name = string;
type one_line = used * (feature_name * feature_value) list;
Expand Down Expand Up @@ -59,7 +59,6 @@ fun parse_database (meth_name:string) =
| int_to_bool _ = error "int_to_bool failed.";
fun get_feature_vector (line:int list) = tl line
|> Utils.index
|> map (apfst Feature)
|> map (apsnd int_to_bool);
val raw_line = get_raw_line line;
val one_line = (is_used raw_line, get_feature_vector raw_line);
Expand Down
4 changes: 4 additions & 0 deletions PaMpeR/Test/Assertion_Checker.thy
Expand Up @@ -89,6 +89,10 @@ end;

ML{* Assertion_Checker.activate_assertion_checker ();*}

lemma "True"
which_method
oops

lemma "[1] = [1]"
assert_nth_true 20
oops
Expand Down

0 comments on commit 769569e

Please sign in to comment.