Skip to content

Commit

Permalink
Handle variance in structs and cstructs.
Browse files Browse the repository at this point in the history
  • Loading branch information
skaller committed Jul 27, 2022
1 parent 8e59cee commit 88116b6
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 30 deletions.
66 changes: 39 additions & 27 deletions src/compiler/flx_core/flx_unify.ml
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,41 @@ let check_recursion bsym_table t =
sbt bsym_table t);
raise Bad_recursion

let nominal_subtype bsym_table lhs rhs =
match lhs, rhs with
| BTYP_inst (`Nominal,l,[],_),BTYP_inst(`Nominal,r,[],_) ->
(* meta types have to agree if types do? *)
if l <> r && not (Flx_bsym_table.is_indirect_supertype bsym_table l r)
then raise Not_found
| _ -> raise Not_found


(* LHS ge RHS, parameter supertype of argument *)
let rec solve_subtypes nominal_subtype counter lhs rhs dvars (s:vassign_t option ref) (add_eq:reladd_t) (add_ge:reladd_t) =
let rec solve_subtypes bsym_table counter lhs rhs dvars (s:vassign_t option ref) (add_eq:reladd_t) (add_ge:reladd_t) =
(*
print_endline ("Solve subtypes " ^ Flx_btype.str_of_btype lhs ^ " >=? " ^ Flx_btype.str_of_btype rhs);
*)
try nominal_subtype lhs rhs
with Not_found ->

match lhs, rhs with
| BTYP_inst (`Nominal,l,[],_),BTYP_inst(`Nominal,r,[],_) -> (* distinct monomorphic nominal types *)
if l <> r && not (Flx_bsym_table.is_indirect_supertype bsym_table l r)
then raise Not_found

| BTYP_inst (`Nominal,l,lts,_),BTYP_inst(`Nominal,r,rts,_) when l = r -> (* same polymorphic nominal type *)
let bsym = Flx_bsym_table.find bsym_table l in
let bbdcl = Flx_bsym.bbdcl bsym in
begin match bbdcl with
| BBDCL_external_type (_,_,_,_,variance)
| BBDCL_union (_,_,variance)
| BBDCL_cstruct (_,_,_,variance)
| BBDCL_struct (_,_,variance) ->
assert (List.length lts = List.length rts);
let t2 = List.combine lts rts in
assert(List.length variance <= List.length lts);
let variance = (Flx_list.repeat `invariant (List.length lts - List.length variance)) @ variance in
List.iter2 (fun (l, r) variance ->
match variance with
| `covariant -> add_ge (l, r)
| `invariant -> add_eq (l, r)
| `contravariant -> add_ge (r, l)
) t2 variance

| BBDCL_newtype _ (* FIXME: newtype should have variance too *)
| BBDCL_instance_type _ -> List.iter2 (fun l r -> add_eq (l,r)) lts rts

| _ -> assert false
end

(* a non-uniq parameter accepts a uniq one, uniq T is a subtype of T,
also, covariant ???????
*)
Expand Down Expand Up @@ -220,9 +237,9 @@ print_endline ("Solve subtypes " ^ Flx_btype.str_of_btype lhs ^ " >=? " ^ Flx_bt


| _ ->
solve_subsumption nominal_subtype counter lhs rhs dvars s add_eq
solve_subsumption bsym_table counter lhs rhs dvars s add_eq

and solve_subsumption nominal_subtype counter lhs rhs dvars (s:vassign_t option ref) (add_eqn:reladd_t) =
and solve_subsumption bsym_table counter lhs rhs dvars (s:vassign_t option ref) (add_eqn:reladd_t) =
begin match lhs,rhs with
| BTYP_instancetype _, BTYP_instancetype _ -> () (* weirdo but we have to do it *)
| BTYP_rev t1, BTYP_rev t2 ->
Expand Down Expand Up @@ -569,7 +586,7 @@ print_endline "Trying to unify type map";
raise Not_found
end

let unif nominal_subtype counter (inrels: rels_t) (dvars:dvars_t) =
let unif bsym_table counter (inrels: rels_t) (dvars:dvars_t) =
(*
print_endline ("Unif:");
print_endline ( "Dvars = { " ^ catmap ", " si (BidSet.elements dvars) ^ "}");
Expand Down Expand Up @@ -603,8 +620,8 @@ print_endline ("Unif:");
print_endline ("Trying " ^ sbt bsym_table lhs ^ " " ^ string_of_relmode_t mode ^ " " ^ sbt bsym_table rhs);
*)
begin match mode with
| `Eq -> solve_subsumption nominal_subtype counter lhs rhs dvars s add_eq
| `Ge -> solve_subtypes nominal_subtype counter lhs rhs dvars s add_eq add_ge
| `Eq -> solve_subsumption bsym_table counter lhs rhs dvars s add_eq
| `Ge -> solve_subtypes bsym_table counter lhs rhs dvars s add_eq add_ge
end
;
begin match !s with
Expand Down Expand Up @@ -649,32 +666,28 @@ let find_vars_eqns eqns =
!lhs_vars,!rhs_vars

let unification bsym_table counter eqns dvars =
let nominal_subtype lhs rhs = nominal_subtype bsym_table lhs rhs in
let eqns = List.map (fun x -> `Eq, x) eqns in
unif nominal_subtype counter eqns dvars
unif bsym_table counter eqns dvars

let maybe_unification bsym_table counter eqns =
let nominal_subtype lhs rhs = nominal_subtype bsym_table lhs rhs in
let l,r = find_vars_eqns eqns in
let dvars = BidSet.union l r in
let eqns = List.map (fun x -> `Eq, x) eqns in
try Some (unif nominal_subtype counter eqns dvars)
try Some (unif bsym_table counter eqns dvars)
with Not_found -> None

(* same as unifies so why is this here? *)
let maybe_matches bsym_table counter eqns =
let nominal_subtype lhs rhs = nominal_subtype bsym_table lhs rhs in
let l,r = find_vars_eqns eqns in
let dvars = BidSet.union l r in
let eqns = List.map (fun x -> `Eq, x) eqns in
try Some (unif nominal_subtype counter eqns dvars)
try Some (unif bsym_table counter eqns dvars)
with Not_found -> None

(* LHS is parameter, RHS is argument, we require LHS >= RHS *)
let maybe_specialisation_with_dvars bsym_table counter eqns dvars =
let nominal_subtype lhs rhs = nominal_subtype bsym_table lhs rhs in
let eqns = List.map (fun x -> `Ge, x) eqns in
try Some (unif nominal_subtype counter eqns dvars)
try Some (unif bsym_table counter eqns dvars)
with Not_found ->
None

Expand All @@ -684,7 +697,6 @@ let maybe_specialisation bsym_table counter eqns =
maybe_specialisation_with_dvars bsym_table counter eqns l

let unifies bsym_table counter t1 t2 =
let nominal_subtype lhs rhs = nominal_subtype bsym_table lhs rhs in
let eqns = [t1,t2] in
match maybe_unification bsym_table counter eqns with
| None -> false
Expand Down
34 changes: 33 additions & 1 deletion src/compiler/flx_frontend/flx_xcoerce.ml
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,38 @@ print_endline ("Dst type " ^ Flx_print.sbt bsym_table dstt);
let srct = unfold "expand_coercion srct" srct in
let dstt = unfold "expand_coercion dstt" dstt in
match srct,dstt with
| BTYP_inst (`Nominal, src,lts,_), BTYP_inst (`Nominal, dst,rts,_) when src = dst ->
let bsym = Flx_bsym_table.find bsym_table src in
let bbdcl = Flx_bsym.bbdcl bsym in
begin match bbdcl with
| BBDCL_cstruct (bvs,flds,_,variance)
| BBDCL_struct (bvs,flds,variance) ->
(* each member must be covariant: it's the same as a tuple *)

(* get list of src values *)
let lvarmap = Flx_btype_subst.mk_varmap sr bvs lts in
let sflds = List.map (fun (_, t) -> Flx_btype_subst.varmap_subst lvarmap t) flds in
let prjs = List.map2 (fun pos t -> bexpr_get_n t pos srce) (Flx_list.nlist (List.length flds)) sflds in

(* get list of dst types *)
let rvarmap = Flx_btype_subst.mk_varmap sr bvs rts in
let dtyps = List.map (fun (_, t) -> Flx_btype_subst.varmap_subst rvarmap t) flds in
let dtyp = btyp_tuple dtyps in

(* get list of coerced values *)
let vals = List.map2 (fun p t -> bexpr_coerce (p,t)) prjs dtyps in
(* build argument tuple *)
let arg = bexpr_tuple dtyp vals in

(* we cannot use a apply_struct yet so we have to form a closure instead *)
let ctor_typ = btyp_function (dtyp, dstt) in
let ctor = bexpr_closure ctor_typ (dst, rts) in
remap parent (bexpr_apply dstt (ctor, arg))

| x ->
Flx_exceptions.clierr sr ("Flx_xcoerce: NOT IMPLEMENTED: polymorphic nominal type coercions for " ^ Flx_print.sbt bsym_table srct)
end

| BTYP_inst (`Nominal, src,[],_), BTYP_inst (`Nominal, dst,[],_) ->
if debug then
print_endline ("Searching for nominal type conversion from " ^
Expand All @@ -286,7 +318,7 @@ print_endline ("Dst type " ^ Flx_print.sbt bsym_table dstt);

let srcid = Flx_bsym.id (Flx_bsym_table.find bsym_table src) in
let dstid = Flx_bsym.id (Flx_bsym_table.find bsym_table dst) in
print_endline ("Unable to find supertype coercion from " ^
print_endline ("Flx_xcoerce: Unable to find supertype coercion from " ^
srcid ^ "<" ^ si src ^ "> to " ^ dstid ^ "<" ^ si dst ^ ">");
Flx_bsym_table.iter_coercions bsym_table
(fun ((a,b),c) -> print_endline (" " ^ si c ^ ":" ^ si b ^ "->" ^ si a))
Expand Down
6 changes: 4 additions & 2 deletions src/compiler/flxg/flxg.ml
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,10 @@ let handle_bind state main_prog module_name =

let t0 = Unix.gettimeofday () in

let ntsc lhs rhs = Flx_unify.nominal_subtype bsym_table lhs rhs in
Flx_btype.set_unif_thunk (Flx_unify.unif ntsc state.syms.counter);
(* This cannot work because the bsym_yable is not unique, new ones get
created all the time.
*)
Flx_btype.set_unif_thunk (Flx_unify.unif bsym_table state.syms.counter);


(* Do the binding here *)
Expand Down

0 comments on commit 88116b6

Please sign in to comment.