Skip to content

Commit

Permalink
feat(lsp): Use argument label code action (#2126)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-snezhko committed Jul 31, 2024
1 parent d34d381 commit 4399387
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 51 deletions.
35 changes: 35 additions & 0 deletions compiler/src/language_server/code_action.re
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,24 @@ let explicit_type_annotation = (range, uri, type_str) => {
};
};

let named_arg_label = (range, uri, arg_label) => {
ResponseResult.{
title: "Used named argument label",
kind: "name-argument-label",
edit: {
document_changes: [
{
text_document: {
uri,
version: None,
},
edits: [{range, new_text: arg_label ++ "="}],
},
],
},
};
};

let send_code_actions =
(id: Protocol.message_id, code_actions: list(ResponseResult.code_action)) => {
Protocol.response(~id, ResponseResult.to_yojson(Some(code_actions)));
Expand All @@ -77,6 +95,22 @@ let process_explicit_type_annotation = (uri, results: list(Sourcetree.node)) =>
};
};

let process_named_arg_label = (uri, results: list(Sourcetree.node)) => {
switch (results) {
| [Argument({arg_label, label_specified, loc}), ..._] when !label_specified =>
let loc = {...loc, loc_end: loc.loc_start};
let arg_label =
switch (arg_label) {
| Unlabeled =>
failwith("Impossible: unlabeled argument after typechecking")
| Labeled({txt})
| Default({txt}) => txt
};
Some(named_arg_label(Utils.loc_to_range(loc), uri, arg_label));
| _ => None
};
};

let process =
(
~id: Protocol.message_id,
Expand All @@ -94,6 +128,7 @@ let process =
x => x,
[
process_explicit_type_annotation(params.text_document.uri, results),
process_named_arg_label(params.text_document.uri, results),
],
);

Expand Down
25 changes: 25 additions & 0 deletions compiler/src/language_server/sourcetree.re
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ module type Sourcetree = {
loc: Location.t,
definition: option(Location.t),
})
| Argument({
loc: Location.t,
arg_label: Typedtree.argument_label,
label_specified: bool,
})
| Type({
core_type: Typedtree.core_type,
definition: option(Location.t),
Expand Down Expand Up @@ -232,6 +237,11 @@ module Sourcetree: Sourcetree = {
loc: Location.t,
definition: option(Location.t),
})
| Argument({
loc: Location.t,
arg_label: Typedtree.argument_label,
label_specified: bool,
})
| Type({
core_type: Typedtree.core_type,
definition: option(Location.t),
Expand Down Expand Up @@ -420,6 +430,21 @@ module Sourcetree: Sourcetree = {
),
...segments^,
]
| TExpApp(_, _, args) =>
segments :=
List.map(
({arg_label, arg_label_specified, arg_expr}) =>
(
loc_to_interval(arg_expr.exp_loc),
Argument({
loc: arg_expr.exp_loc,
arg_label,
label_specified: arg_label_specified,
}),
),
args,
)
@ segments^
| _ =>
segments :=
[
Expand Down
63 changes: 39 additions & 24 deletions compiler/src/middle_end/linearize.re
Original file line number Diff line number Diff line change
Expand Up @@ -665,12 +665,18 @@ let rec transl_imm =
switch (PrimMap.find_opt(prim_map, prim), args) {
| (Some(Primitive0(prim)), []) =>
transl_imm({...e, exp_desc: TExpPrim0(prim)})
| (Some(Primitive1(prim)), [(_, arg)]) =>
transl_imm({...e, exp_desc: TExpPrim1(prim, arg)})
| (Some(Primitive2(prim)), [(_, arg1), (_, arg2)]) =>
transl_imm({...e, exp_desc: TExpPrim2(prim, arg1, arg2)})
| (Some(Primitive1(prim)), [{arg_expr}]) =>
transl_imm({...e, exp_desc: TExpPrim1(prim, arg_expr)})
| (
Some(Primitive2(prim)),
[{arg_expr: arg_expr1}, {arg_expr: arg_expr2}],
) =>
transl_imm({...e, exp_desc: TExpPrim2(prim, arg_expr1, arg_expr2)})
| (Some(PrimitiveN(prim)), args) =>
transl_imm({...e, exp_desc: TExpPrimN(prim, List.map(snd, args))})
transl_imm({
...e,
exp_desc: TExpPrimN(prim, List.map(x => x.arg_expr, args)),
})
| (Some(_), _) => failwith("transl_imm: invalid primitive arity")
| (None, _) => failwith("transl_imm: unknown primitive")
}
Expand All @@ -689,9 +695,9 @@ let rec transl_imm =
let (new_args, new_setup) =
List.split(
List.map(
((l, arg)) => {
let (arg, setup) = transl_imm(arg);
((l, arg), setup);
({arg_label, arg_expr}) => {
let (arg, setup) = transl_imm(arg_expr);
((arg_label, arg), setup);
},
args,
),
Expand Down Expand Up @@ -1497,9 +1503,9 @@ and transl_comp_expression =
let (new_args, new_setup) =
List.split(
List.map(
((l, arg)) => {
let (arg, setup) = transl_imm(arg);
((l, arg), setup);
({arg_label, arg_expr}) => {
let (arg, setup) = transl_imm(arg_expr);
((arg_label, arg), setup);
},
args,
),
Expand Down Expand Up @@ -1538,14 +1544,20 @@ and transl_comp_expression =
switch (PrimMap.find_opt(prim_map, prim), args) {
| (Some(Primitive0(prim)), []) =>
transl_imm({...e, exp_desc: TExpPrim0(prim)})
| (Some(Primitive1(prim)), [(_, arg)]) =>
transl_imm({...e, exp_desc: TExpPrim1(prim, arg)})
| (Some(Primitive2(prim)), [(_, arg1), (_, arg2)]) =>
transl_imm({...e, exp_desc: TExpPrim2(prim, arg1, arg2)})
| (Some(Primitive1(prim)), [{arg_expr}]) =>
transl_imm({...e, exp_desc: TExpPrim1(prim, arg_expr)})
| (
Some(Primitive2(prim)),
[{arg_expr: arg_expr1}, {arg_expr: arg_expr2}],
) =>
transl_imm({
...e,
exp_desc: TExpPrim2(prim, arg_expr1, arg_expr2),
})
| (Some(PrimitiveN(prim)), args) =>
transl_imm({
...e,
exp_desc: TExpPrimN(prim, List.map(snd, args)),
exp_desc: TExpPrimN(prim, List.map(x => x.arg_expr, args)),
})
| (Some(_), _) =>
failwith("transl_comp_expression: invalid primitive arity")
Expand All @@ -1559,17 +1571,20 @@ and transl_comp_expression =
switch (PrimMap.find_opt(prim_map, prim), args) {
| (Some(Primitive0(prim)), []) =>
transl_comp_expression({...e, exp_desc: TExpPrim0(prim)})
| (Some(Primitive1(prim)), [(_, arg)]) =>
transl_comp_expression({...e, exp_desc: TExpPrim1(prim, arg)})
| (Some(Primitive2(prim)), [(_, arg1), (_, arg2)]) =>
| (Some(Primitive1(prim)), [{arg_expr}]) =>
transl_comp_expression({...e, exp_desc: TExpPrim1(prim, arg_expr)})
| (
Some(Primitive2(prim)),
[{arg_expr: arg_expr1}, {arg_expr: arg_expr2}],
) =>
transl_comp_expression({
...e,
exp_desc: TExpPrim2(prim, arg1, arg2),
exp_desc: TExpPrim2(prim, arg_expr1, arg_expr2),
})
| (Some(PrimitiveN(prim)), args) =>
transl_comp_expression({
...e,
exp_desc: TExpPrimN(prim, List.map(snd, args)),
exp_desc: TExpPrimN(prim, List.map(x => x.arg_expr, args)),
})
| (Some(_), _) =>
failwith("transl_comp_expression: invalid primitive arity")
Expand All @@ -1582,9 +1597,9 @@ and transl_comp_expression =
let (new_args, new_setup) =
List.split(
List.map(
((l, arg)) => {
let (arg, setup) = transl_imm(arg);
((l, arg), setup);
({arg_label, arg_expr}) => {
let (arg, setup) = transl_imm(arg_expr);
((arg_label, arg), setup);
},
args,
),
Expand Down
23 changes: 18 additions & 5 deletions compiler/src/typed/typecore.re
Original file line number Diff line number Diff line change
Expand Up @@ -1941,6 +1941,7 @@ and type_application = (~in_function=?, ~loc, env, funct, sargs) => {
| [sarg, ...remaining_sargs] =>
let (
corresponding_tyarg,
arg_label_specified,
remaining_used_labeled_tyargs,
remaining_unused_tyargs,
) =
Expand All @@ -1952,6 +1953,7 @@ and type_application = (~in_function=?, ~loc, env, funct, sargs) => {
extract_label(sarg.paa_label, remaining_used_labeled_tyargs);
(
corresponding_tyarg,
true,
remaining_used_labeled_tyargs,
remaining_unused_tyargs,
);
Expand All @@ -1960,6 +1962,7 @@ and type_application = (~in_function=?, ~loc, env, funct, sargs) => {
next_tyarg(remaining_unused_tyargs);
(
corresponding_tyarg,
false,
remaining_used_labeled_tyargs,
remaining_unused_tyargs,
);
Expand Down Expand Up @@ -1994,7 +1997,7 @@ and type_application = (~in_function=?, ~loc, env, funct, sargs) => {
);
};
type_args(
[(l, arg), ...args],
[(l, arg_label_specified, arg), ...args],
remaining_sargs,
remaining_used_labeled_tyargs,
remaining_unused_tyargs,
Expand Down Expand Up @@ -2041,11 +2044,16 @@ and type_application = (~in_function=?, ~loc, env, funct, sargs) => {

let omitted_args =
List.map(
((l, ty)) => {
switch (l) {
((arg_label, ty)) => {
switch (arg_label) {
| Default(_) =>
// omitted optional argument
(l, option_none(env, instance(env, ty), Location.dummy_loc))
{
arg_label,
arg_label_specified: true,
arg_expr:
option_none(env, instance(env, ty), Location.dummy_loc),
}
| _ =>
let missing_args =
List.filter(((l, _)) => !is_optional(l), remaining_tyargs);
Expand All @@ -2057,7 +2065,12 @@ and type_application = (~in_function=?, ~loc, env, funct, sargs) => {

// Typecheck all arguments.
// Order here is important; rev_map would be incorrect.
let typed_args = List.map(((l, argf)) => (l, argf()), List.rev(args));
let typed_args =
List.map(
((arg_label, arg_label_specified, argf)) =>
{arg_label, arg_label_specified, arg_expr: argf()},
List.rev(args),
);

(ordered_labels, omitted_args @ typed_args, instance(env, ty_ret));
}
Expand Down
24 changes: 14 additions & 10 deletions compiler/src/typed/typed_well_formedness.re
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,11 @@ module WellFormednessArg: TypedtreeIter.IteratorArgument = {
args,
)
when func == "==" || func == "!=" =>
if (List.exists(((_, arg)) => exp_is_wasm_unsafe(arg), args)) {
if (List.exists(({arg_expr}) => exp_is_wasm_unsafe(arg_expr), args)) {
let typeName =
switch (args) {
| [(_, arg), _] when exp_is_wasm_unsafe(arg) =>
"Wasm" ++ resolve_unsafe_type(arg)
| [{arg_expr}, _] when exp_is_wasm_unsafe(arg_expr) =>
"Wasm" ++ resolve_unsafe_type(arg_expr)
| _ => "(WasmI32 | WasmI64 | WasmF32 | WasmF64)"
};
let warning =
Expand All @@ -290,9 +290,11 @@ module WellFormednessArg: TypedtreeIter.IteratorArgument = {
_,
args,
) =>
switch (List.find_opt(((_, arg)) => exp_is_wasm_unsafe(arg), args)) {
| Some((_, arg)) =>
let typeName = resolve_unsafe_type(arg);
switch (
List.find_opt(({arg_expr}) => exp_is_wasm_unsafe(arg_expr), args)
) {
| Some({arg_expr}) =>
let typeName = resolve_unsafe_type(arg_expr);
let warning = Grain_utils.Warnings.PrintUnsafe(typeName);
if (Grain_utils.Warnings.is_active(warning)) {
Grain_parsing.Location.prerr_warning(exp_loc, warning);
Expand All @@ -315,9 +317,11 @@ module WellFormednessArg: TypedtreeIter.IteratorArgument = {
_,
args,
) =>
switch (List.find_opt(((_, arg)) => exp_is_wasm_unsafe(arg), args)) {
| Some((_, arg)) =>
let typeName = resolve_unsafe_type(arg);
switch (
List.find_opt(({arg_expr}) => exp_is_wasm_unsafe(arg_expr), args)
) {
| Some({arg_expr}) =>
let typeName = resolve_unsafe_type(arg_expr);
let warning = Grain_utils.Warnings.ToStringUnsafe(typeName);
if (Grain_utils.Warnings.is_active(warning)) {
Grain_parsing.Location.prerr_warning(exp_loc, warning);
Expand All @@ -335,7 +339,7 @@ module WellFormednessArg: TypedtreeIter.IteratorArgument = {
),
},
_,
[(_, {exp_desc: TExpConstant(Const_number(n))})],
[{arg_expr: {exp_desc: TExpConstant(Const_number(n))}}],
)
when
modname == "Int8"
Expand Down
13 changes: 8 additions & 5 deletions compiler/src/typed/typedtree.re
Original file line number Diff line number Diff line change
Expand Up @@ -505,11 +505,7 @@ and expression_desc =
| TExpBreak
| TExpReturn(option(expression))
| TExpLambda(list(match_branch), partial)
| TExpApp(
expression,
list(argument_label),
list((argument_label, expression)),
)
| TExpApp(expression, list(argument_label), list(argument_value))
| TExpConstruct(
loc(Identifier.t),
constructor_description,
Expand Down Expand Up @@ -542,6 +538,13 @@ and match_branch = {
mb_guard: option(expression),
[@sexp_drop_if sexp_locs_disabled]
mb_loc: Location.t,
}

[@deriving sexp]
and argument_value = {
arg_label: argument_label,
arg_label_specified: bool,
arg_expr: expression,
};

[@deriving sexp]
Expand Down
13 changes: 8 additions & 5 deletions compiler/src/typed/typedtree.rei
Original file line number Diff line number Diff line change
Expand Up @@ -472,11 +472,7 @@ and expression_desc =
| TExpBreak
| TExpReturn(option(expression))
| TExpLambda(list(match_branch), partial)
| TExpApp(
expression,
list(argument_label),
list((argument_label, expression)),
)
| TExpApp(expression, list(argument_label), list(argument_value))
| TExpConstruct(
loc(Identifier.t),
constructor_description,
Expand Down Expand Up @@ -505,6 +501,13 @@ and match_branch = {
mb_body: expression,
mb_guard: option(expression),
mb_loc: Location.t,
}

[@deriving sexp]
and argument_value = {
arg_label: argument_label,
arg_label_specified: bool,
arg_expr: expression,
};

[@deriving sexp]
Expand Down
Loading

0 comments on commit 4399387

Please sign in to comment.