Skip to content

Commit

Permalink
pointer to nominal type coercions
Browse files Browse the repository at this point in the history
  • Loading branch information
skaller committed Aug 18, 2022
1 parent 3fe3ee7 commit 8598d4a
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 15 deletions.
16 changes: 9 additions & 7 deletions src/compiler/flx_bind/flx_bbind.ml
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,6 @@ print_endline ("[flx_bbind] bind_symbol " ^ sym.Flx_sym.id ^ "??");
let vs = fst sym.Flx_sym.vs in
begin match sym.Flx_sym.symdef with
| SYMDEF_function (params,ret,effect,props,_) when List.mem `Subtype props ->
(*
if List.length vs <> 0 then
clierr sr (" Improper subtype, no type variables allowed, got " ^
string_of_int (List.length vs))
else
*)
let dom, cod =
let ps = fst params in
begin match ps with
Expand Down Expand Up @@ -289,7 +283,11 @@ print_endline ("[flx_bbind] bind_symbol " ^ sym.Flx_sym.id ^ "??");
*)
Flx_bsym_table.add_supertype bsym_table ((cod,dom),i)

| _ -> clierr sr ("Subtype specification requires function from and to nominal type")
| BTYP_function (BTYP_ptr (`RW, BTYP_inst (`Nominal _, dom,_,_), []),BTYP_ptr(`RW, BTYP_inst (`Nominal _, cod,_,_),[])) ->
print_endline ("Pointer coercion: Domain index = " ^ string_of_int dom ^ " codomain index = " ^ string_of_int cod);

Flx_bsym_table.add_pointer_supertype bsym_table ((cod,dom),i)
| _ -> clierr sr ("Subtype specification requires function from and to nominal type or pointers thereto")
end

| SYMDEF_fun (props,ps,ret,_,_,_) when List.mem `Subtype props ->
Expand Down Expand Up @@ -334,6 +332,10 @@ print_endline ("[flx_bbind] bind_symbol " ^ sym.Flx_sym.id ^ "??");
print_endline ("Domain index = " ^ string_of_int dom ^ " codomain index = " ^ string_of_int cod);
*)
Flx_bsym_table.add_supertype bsym_table ((cod,dom),i)
| BTYP_function (BTYP_ptr (`RW, BTYP_inst (`Nominal _, dom,_,_), []),BTYP_ptr(`RW, BTYP_inst (`Nominal _, cod,_,_),[])) ->
print_endline ("Pointer coercion: Domain index = " ^ string_of_int dom ^ " codomain index = " ^ string_of_int cod);

Flx_bsym_table.add_pointer_supertype bsym_table ((cod,dom),i)
| _ -> clierr sr ("Subtype specification requires function from and to nominal types")
end
| _ -> ()
Expand Down
121 changes: 119 additions & 2 deletions src/compiler/flx_core/flx_bsym_table.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type t = {
table: (bid_t, elt) Hashtbl.t;
childmap: (int, BidSet.t) Hashtbl.t; (** All of this bsym table's roots. *)
mutable subtype_map: ((int * int) * int) list; (* super=cod:param, sub=dom:arg -> coercion *)
mutable pointer_subtype_map: ((int * int) * int) list; (* super=cod:param, sub=dom:arg -> coercion *)
mutable reductions: Flx_mtypes2.reduction_t list;
}

Expand All @@ -33,13 +34,15 @@ let create_fresh () =
{ table=Hashtbl.create 97;
childmap=Hashtbl.create 97;
subtype_map = [];
pointer_subtype_map = [];
reductions = []
}

let create_from bsym_table =
{ table=Hashtbl.create 97;
childmap=Hashtbl.create 97;
subtype_map = bsym_table.subtype_map;
pointer_subtype_map = bsym_table.pointer_subtype_map;
reductions = []
}

Expand All @@ -55,6 +58,8 @@ so the codomain is actually first in the pair.
*)
type coercion_t = (bid_t * bid_t) * bid_t

(* ----------- START VALUE COERCIONS --------------- *)

(* temporary hackery *)
let add_supertype bsym_table x =
bsym_table.subtype_map <- x :: bsym_table.subtype_map
Expand Down Expand Up @@ -181,6 +186,116 @@ let greatest_subtype bsym_table ls : int option =
let fold_coercions bsym_table f init =
List.fold_left f init bsym_table.subtype_map

(* ----------- END VALUE COERCIONS --------------- *)


(* ----------- START POINTER COERCIONS --------------- *)

(* temporary hackery *)
let add_pointer_supertype bsym_table x =
bsym_table.pointer_subtype_map <- x :: bsym_table.pointer_subtype_map

let set_pointer_coercions bsym_table x =
bsym_table.pointer_subtype_map <- x

let get_pointer_coercions bsym_table = bsym_table.pointer_subtype_map

let iter_pointer_coercions bsym_table f =
List.iter f bsym_table.pointer_subtype_map


let maybe_pointer_coercion bsym_table param arg =
try Some (List.assoc (param,arg) bsym_table.pointer_subtype_map)
with Not_found -> None

let is_direct_pointer_supertype bsym_table param arg =
List.mem_assoc (param, arg) bsym_table.pointer_subtype_map

let find_pointer_coercion_chains bsym_table param arg : int list list =
let limit = 10 in
let chains = ref [] in
let rec iis counter chain a =
if counter > limit then failwith ("circular subtype definition, chain limit " ^ string_of_int limit ^ ", exceeded");

(* find all the types to which the argument can be coerced *)
let cands = List.fold_left (fun acc ((p',a'),j) -> if a = a' then (p',j)::acc else acc) [] bsym_table.pointer_subtype_map in
if List.mem_assoc param cands
then chains := (List.assoc param cands :: chain) :: !chains
else
if cands = [] then ()
else
List.iter (fun (p',j) -> iis (counter + 1) (j::chain) p') cands
in
iis 0 [] arg;
(*
print_endline (string_of_int (List.length !chains) ^ " coercion chains to parameter " ^ string_of_int param ^ " from argument " ^ string_of_int arg);
List.iter (fun chain -> print_endline ("Chain=" ^ String.concat "," (List.map string_of_int chain))) !chains;
*)
!chains

let is_indirect_pointer_supertype bsym_table param arg : bool =
let limit = 10 in
let rec iis counter a =
if counter > limit then failwith ("circular subtype definition, chain limit " ^ string_of_int limit ^ ", exceeded");
(* find all the types to which the argument can be coerced *)
let cands = List.fold_left (fun acc ((p',a'),_) -> if a = a' then p'::acc else acc) [] bsym_table.pointer_subtype_map in
if List.mem param cands then true
else
if cands = [] then false
else
List.fold_left (fun acc p' -> acc || iis (counter + 1) p') false cands
in
let result = iis 0 arg in
result

let is_indirect_pointer_subtype bsym_table arg param : bool = is_indirect_pointer_supertype bsym_table param arg

(* These sets are inclusive *)
let pointer_supertypes_of bsym_table a : BidSet.t =
List.fold_left (fun acc ((p',a'),j) -> if a = a' then BidSet.add p' acc else acc) (BidSet.singleton a) bsym_table.pointer_subtype_map

let pointer_subtypes_of bsym_table p : BidSet.t =
List.fold_left (fun acc ((p',a'),j) -> if p = p' then BidSet.add a' acc else acc) (BidSet.singleton p) bsym_table.pointer_subtype_map

let least_pointer_supertype bsym_table ls : int option =
match ls with
| [] -> None
| [x] -> Some x
| h :: tail ->
let cands = List.fold_left
(fun acc elt -> BidSet.inter acc (pointer_supertypes_of bsym_table elt))
(pointer_supertypes_of bsym_table h)
tail
in
if BidSet.is_empty cands then None else
let chosen = BidSet.choose cands in
let cands = BidSet.remove chosen cands in
Some (BidSet.fold (fun cand current ->
if is_indirect_pointer_supertype bsym_table current cand then cand else current
) cands chosen)

let greatest_pointer_subtype bsym_table ls : int option =
match ls with
| [] -> None
| [x] -> Some x
| h :: tail ->
let cands = List.fold_left
(fun acc elt -> BidSet.inter acc (pointer_subtypes_of bsym_table elt))
(pointer_subtypes_of bsym_table h)
tail
in
if BidSet.is_empty cands then None else
let chosen = BidSet.choose cands in
let cands = BidSet.remove chosen cands in
Some (BidSet.fold (fun cand current ->
if is_indirect_pointer_subtype bsym_table current cand then cand else current
) cands chosen)


let fold_pointer_coercions bsym_table f init =
List.fold_left f init bsym_table.pointer_subtype_map

(* ----------- END POINTER COERCIONS --------------- *)

(** Returns how many items are in the bound symbol table. *)
let length bsym_table = Hashtbl.length bsym_table.table
Expand Down Expand Up @@ -318,7 +433,8 @@ let copy bsym_table =
{ bsym_table with
table=Hashtbl.copy bsym_table.table;
childmap=Hashtbl.copy bsym_table.childmap;
subtype_map=bsym_table.subtype_map
subtype_map=bsym_table.subtype_map;
pointer_subtype_map=bsym_table.pointer_subtype_map
}

(** Set's a symbol's parent. *)
Expand Down Expand Up @@ -429,7 +545,8 @@ let validate msg bsym_table =
let f_bid index = if not (index = 0 || mem bsym_table index)
then raise (IncompleteBsymTable (0,index,"subtype table"))
in
List.iter (fun ((a,b),c) -> f_bid a; f_bid b; f_bid c) bsym_table.subtype_map
List.iter (fun ((a,b),c) -> f_bid a; f_bid b; f_bid c) bsym_table.subtype_map;
List.iter (fun ((a,b),c) -> f_bid a; f_bid b; f_bid c) bsym_table.pointer_subtype_map

let validate_types f_btype bsym_table =
iter begin fun bid _ bsym ->
Expand Down
27 changes: 21 additions & 6 deletions src/compiler/flx_core/flx_bsym_table.mli
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,28 @@ val iter_coercions: t -> (coercion_t -> unit) -> unit
val fold_coercions: t -> ('a -> coercion_t -> 'a) -> 'a -> 'a
val set_coercions: t -> coercion_t list -> unit
val get_coercions: t -> coercion_t list
val get_fun_type: t -> bid_t -> Flx_btype.t
val least_supertype: t -> int list -> int option
val greatest_subtype: t -> int list -> int option
val subtypes_of: t -> bid_t -> BidSet.t
val supertypes_of: t -> bid_t -> BidSet.t

val add_pointer_supertype: t -> coercion_t -> unit
val is_direct_pointer_supertype: t -> bid_t -> bid_t -> bool
val is_indirect_pointer_supertype: t -> bid_t -> bid_t -> bool
val is_indirect_pointer_subtype: t -> bid_t -> bid_t -> bool
val find_pointer_coercion_chains : t -> bid_t -> bid_t -> bid_t list list
val maybe_pointer_coercion: t -> bid_t -> bid_t -> bid_t option
val iter_pointer_coercions: t -> (coercion_t -> unit) -> unit
val fold_pointer_coercions: t -> ('a -> coercion_t -> 'a) -> 'a -> 'a
val set_pointer_coercions: t -> coercion_t list -> unit
val get_pointer_coercions: t -> coercion_t list
val least_pointer_supertype: t -> int list -> int option
val greatest_pointer_subtype: t -> int list -> int option
val pointer_subtypes_of: t -> bid_t -> BidSet.t
val pointer_supertypes_of: t -> bid_t -> BidSet.t


val get_fun_type: t -> bid_t -> Flx_btype.t
val get_reductions: t -> Flx_mtypes2.reduction_t list
val set_reductions: t -> Flx_mtypes2.reduction_t list -> unit
val add_reduction_case: t -> string -> Flx_mtypes2.reduction_case_t -> unit
Expand Down Expand Up @@ -134,8 +154,3 @@ val validate : string -> t -> unit
val validate_types: (Flx_btype.t -> unit) -> t -> unit
val is_prim: t -> bid_t -> bool

val least_supertype: t -> int list -> int option
val greatest_subtype: t -> int list -> int option
val subtypes_of: t -> bid_t -> BidSet.t
val supertypes_of: t -> bid_t -> BidSet.t

54 changes: 54 additions & 0 deletions src/compiler/flx_core/flx_unify.ml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,60 @@ print_endline ("Adding inequality " ^ Flx_btype.st lhs ^ " > " ^ Flx_btype.st x)
end


| BTYP_ptr (`RW,BTYP_inst (`Nominal variance,l,lts,knd),[]),BTYP_ptr(`RW,BTYP_inst(`Nominal _,r,rts,_),[]) when l <> r -> (* distinct polymorphic nominal type *)
let chains = Flx_bsym_table.find_pointer_coercion_chains bsym_table l r in
let n = List.length chains in
(*
if n > 0 then
print_endline ("Unify Found " ^ string_of_int n ^ " chains");
*)
begin match chains with
| [] -> raise Not_found (* not a subtype *)
| chain :: _ ->
(*
print_endline ("Unify using chain length " ^ string_of_int (List.length chain));
print_endline ("Chain= " ^ Flx_util.catmap "," string_of_int chain);
*)
let ts = List.fold_left
(fun ats f ->
(*
print_endline ("Input argument ats = " ^ Flx_util.catmap "," Flx_btype.st ats);
*)
let bsym = Flx_bsym_table.find bsym_table f in
let dom,cod,bvs = match bsym.bbdcl with
| BBDCL_external_fun (_,bvs,params,ret,_,_,_ ) -> btyp_tuple params, ret, bvs
| BBDCL_fun (_,bvs,bparams,ret,_,_) -> Flx_bparams.get_btype bparams, ret, bvs
| _ -> assert false
in
(*
print_endline ("Coercion " ^ Flx_btype.st (btyp_function (dom,cod)));
*)
match dom,cod with
(* Dom=Derived[vs]->Cod=Base[ts(vs)] *)
| BTYP_ptr (`RW,BTYP_inst (`Nominal _, d,dts,_),[]), BTYP_ptr (`RW, BTYP_inst (`Nominal _, c, cts,_),[]) ->
(* the pts MUST be the sequence of type variables in bvs, the subtype *)
(*
print_endline ("[derived] ats = " ^ Flx_util.catmap "," Flx_btype.st ats);
*)
let vmap = Flx_btype_subst.mk_varmap bsym.sr bvs ats in
let mapped_cts = List.map (Flx_btype_subst.varmap_subst vmap) cts in
(*
print_endline ("[base] cts = " ^ Flx_util.catmap "," Flx_btype.st mapped_cts);
*)
mapped_cts

| _ -> assert false
)
rts (List.rev chain)
in
let x = btyp_ptr `RW (btyp_inst (`Nominal variance, l, ts, knd)) [] in
(*
print_endline ("Adding inequality " ^ Flx_btype.st lhs ^ " > " ^ Flx_btype.st x);
*)
add_ge (lhs, x)
end



(* a non-uniq parameter accepts a uniq one, uniq T is a subtype of T,
also, covariant ???????
Expand Down
51 changes: 51 additions & 0 deletions src/packages/grammar.fdoc
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,27 @@ syntax cbind {
`(ast_fun_decl ,_sr ,name ,vs ,(mktylist argt) ,ret ,ct ,reqs ,prec)
)
""";
stmt := "supertype" stvarlist "&" squalified_name ":" stypeexpr sopt_cstring sopt_prec srequires_clause ";" =>#
"""
(let*
(
(name (string-append "_supertype_" (base_of_qualified_name _4)))
(vs _2)
(ret `(typ_ref ,_sr ,_4))
(argt _6)
(ct
(if (eq? 'none _7)
`(StrTemplate ,(string-append "::" (base_of_qualified_name _4) "($a)"))
(second _6)
)
)
(prec _8)
(xreqs _9)
(reqs `(rreq_and (rreq_atom (Subtype_req)) ,xreqs))
)
`(ast_fun_decl ,_sr ,name ,vs ,(mktylist argt) ,ret ,ct ,reqs ,prec)
)
""";
cbind_stmt:= "virtual" "type" sname ";" =>#
"`(ast_virtual_type ,_sr ,_3)"
;
Expand Down Expand Up @@ -2486,6 +2507,21 @@ syntax functions {
)
`(ast_curry_effects ,_sr ,name ,vs ,args (,ret ,traint) ,effects Function (Subtype) ,body))
""";
sfunction := "supertype" stvarlist "&" squalified_name sfun_arg+ sopt_traint_eq scompound =>#
"""
(let*
(
(name (string-append "_supertype_" (base_of_qualified_name _4)))
(vs _2)
(ret `(typ_ref ,_sr ,_4))
(traint (first _6))
(effects (second _6))
(body _7)
(args _5)
)
`(ast_curry_effects ,_sr ,name ,vs ,args (,ret ,traint) ,effects Function (Subtype) ,body))
""";


//$ Short form constructor function.
//$ The name of the function must be a type name.
Expand Down Expand Up @@ -2518,6 +2554,21 @@ syntax functions {
)
`(ast_curry_effects ,_sr ,name ,vs ,args (,ret ,traint) ,effects Function (Subtype) ,body))
""";
sfunction := "supertype" stvarlist "&" squalified_name sfun_arg+ sopt_traint "=>" sexpr ";" =>#
"""
(let*
(
(name (string-append "_supertype_" (base_of_qualified_name _4)))
(vs _2)
(ret `(typ_ref ,_sr ,_4))
(traint (first _6))
(effects (second _6))
(body `((ast_fun_return ,_sr ,_8)))
(args _5)
)
`(ast_curry_effects ,_sr ,name ,vs ,args (,ret ,traint) ,effects Function (Subtype) ,body))
""";



//$ Procedure definition, general form.
Expand Down

0 comments on commit 8598d4a

Please sign in to comment.