Skip to content

Commit

Permalink
feat(lsp): Goto type definition (#2129)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-snezhko committed Jul 28, 2024
1 parent d4a99f9 commit 4bb8fae
Show file tree
Hide file tree
Showing 12 changed files with 118 additions and 70 deletions.
4 changes: 2 additions & 2 deletions compiler/src/language_server/driver.re
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ let process = msg => {
| Formatting(id, params) when is_initialized^ =>
Formatting.process(~id, ~compiled_code, ~documents, params);
Reading;
| Definition(id, params) when is_initialized^ =>
Definition.process(~id, ~compiled_code, ~documents, params);
| Goto(id, goto_request_type, params) when is_initialized^ =>
Goto.process(~id, ~compiled_code, ~documents, goto_request_type, params);
Reading;
| SetTrace(trace_value) =>
Trace.set_level(trace_value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,27 @@ open Grain_diagnostics;
open Sourcetree;
open Lsp_types;

type goto_request_type =
| Definition
| TypeDefinition;

// https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#definitionParams
module RequestParams = {
[@deriving yojson({strict: false})]
type t = {
[@key "textDocument"]
text_document: Protocol.text_document_identifier,
position: Protocol.position,
};
type t = Protocol.text_document_position_params;
};

// https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#locationLink
module ResponseResult = {
[@deriving yojson]
type t = {
[@key "originSelectionRange"]
origin_selection_range: Protocol.range,
[@key "targetUri"]
target_uri: Protocol.uri,
[@key "targetRange"]
target_range: Protocol.range,
[@key "targetSelectionRange"]
target_selection_range: Protocol.range,
};
type t = Protocol.location_link;
};

let send_no_result = (~id: Protocol.message_id) => {
Protocol.response(~id, `Null);
};

let send_definition =
let send_location_link =
(
~id: Protocol.message_id,
~range: Protocol.range,
Expand All @@ -45,48 +36,43 @@ let send_definition =
) => {
Protocol.response(
~id,
ResponseResult.to_yojson({
Protocol.location_link_to_yojson({
origin_selection_range: range,
target_uri,
target_range,
target_selection_range: target_range,
}),
);
};

type check_position =
| Forward
| Backward;
let rec find_definition =

let rec find_location =
(
~check_position=Forward,
get_location: list(Sourcetree.node) => option(Location.t),
sourcetree: Sourcetree.sourcetree,
position: Protocol.position,
) => {
let results = Sourcetree.query(position, sourcetree);

let result =
switch (results) {
| [Value({definition}), ..._]
| [Pattern({definition}), ..._]
| [Type({definition}), ..._]
| [Declaration({definition}), ..._]
| [Exception({definition}), ..._]
| [Module({definition}), ..._] =>
switch (definition) {
| None => None
| Some(loc) =>
let uri = Utils.filename_to_uri(loc.loc_start.pos_fname);
Some((loc, uri));
}
| _ => None
switch (get_location(results)) {
| None => None
| Some(loc) =>
let uri = Utils.filename_to_uri(loc.loc_start.pos_fname);
Some((loc, uri));
};
switch (result) {
| None =>
if (check_position == Forward && position.character > 0) {
// If a user selects from left to right, their pointer ends up after the identifier
// this tries to check if the identifier was selected.
find_definition(
find_location(
~check_position=Backward,
get_location,
sourcetree,
{line: position.line, character: position.character - 1},
);
Expand All @@ -102,16 +88,44 @@ let process =
~id: Protocol.message_id,
~compiled_code: Hashtbl.t(Protocol.uri, Lsp_types.code),
~documents: Hashtbl.t(Protocol.uri, string),
goto_request_type: goto_request_type,
params: RequestParams.t,
) => {
switch (Hashtbl.find_opt(compiled_code, params.text_document.uri)) {
| None => send_no_result(~id)
| Some({program, sourcetree}) =>
let result = find_definition(sourcetree, params.position);
let get_location =
switch (goto_request_type) {
| Definition => (
results => {
switch (results) {
| [Sourcetree.Value({definition}), ..._]
| [Pattern({definition}), ..._]
| [Type({definition}), ..._]
| [Declaration({definition}), ..._]
| [Exception({definition}), ..._]
| [Module({definition}), ..._] => definition
| _ => None
};
}
)
| TypeDefinition => (
results => {
switch (results) {
| [Value({env, value_type: type_expr}), ..._] =>
Env.get_type_definition_loc(type_expr, env)
| [Pattern({definition}), ..._] => definition
| _ => None
};
}
)
};

let result = find_location(get_location, sourcetree, params.position);
switch (result) {
| None => send_no_result(~id)
| Some((loc, uri)) =>
send_definition(
send_location_link(
~id,
~range=Utils.loc_to_range(loc),
~target_uri=uri,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
open Grain_typed;

// https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#definitionParams
type goto_request_type =
| Definition
| TypeDefinition;

module RequestParams: {
[@deriving yojson({strict: false})]
type t;
Expand All @@ -17,6 +20,7 @@ let process:
~id: Protocol.message_id,
~compiled_code: Hashtbl.t(Protocol.uri, Lsp_types.code),
~documents: Hashtbl.t(Protocol.uri, string),
goto_request_type,
RequestParams.t
) =>
unit;
6 changes: 1 addition & 5 deletions compiler/src/language_server/hover.re
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@ open Lsp_types;
// https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#hoverParams
module RequestParams = {
[@deriving yojson({strict: false})]
type t = {
[@key "textDocument"]
text_document: Protocol.text_document_identifier,
position: Protocol.position,
};
type t = Protocol.text_document_position_params;
};

// https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#hover
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/language_server/initialize.re
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ module ResponseResult = {
definition_provider: {
link_support: true,
},
type_definition_provider: false,
type_definition_provider: true,
references_provider: false,
document_symbol_provider: false,
code_action_provider: false,
Expand Down
15 changes: 12 additions & 3 deletions compiler/src/language_server/message.re
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ type t =
| TextDocumentDidChange(Protocol.uri, Code_file.DidChange.RequestParams.t)
| TextDocumentInlayHint(Protocol.message_id, Inlayhint.RequestParams.t)
| Formatting(Protocol.message_id, Formatting.RequestParams.t)
| Definition(Protocol.message_id, Definition.RequestParams.t)
| Goto(Protocol.message_id, Goto.goto_request_type, Goto.RequestParams.t)
| SetTrace(Protocol.trace_value)
| Unsupported
| Error(string);
Expand Down Expand Up @@ -63,8 +63,17 @@ let of_request = (msg: Protocol.request_message): t => {
| Error(msg) => Error(msg)
}
| {method: "textDocument/definition", id: Some(id), params: Some(params)} =>
switch (Definition.RequestParams.of_yojson(params)) {
| Ok(params) => Definition(id, params)
switch (Goto.RequestParams.of_yojson(params)) {
| Ok(params) => Goto(id, Definition, params)
| Error(msg) => Error(msg)
}
| {
method: "textDocument/typeDefinition",
id: Some(id),
params: Some(params),
} =>
switch (Goto.RequestParams.of_yojson(params)) {
| Ok(params) => Goto(id, TypeDefinition, params)
| Error(msg) => Error(msg)
}
| {method: "$/setTrace", params: Some(params)} =>
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/language_server/message.rei
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ type t =
| TextDocumentDidChange(Protocol.uri, Code_file.DidChange.RequestParams.t)
| TextDocumentInlayHint(Protocol.message_id, Inlayhint.RequestParams.t)
| Formatting(Protocol.message_id, Formatting.RequestParams.t)
| Definition(Protocol.message_id, Definition.RequestParams.t)
| Goto(Protocol.message_id, Goto.goto_request_type, Goto.RequestParams.t)
| SetTrace(Protocol.trace_value)
| Unsupported
| Error(string);
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/language_server/protocol.re
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ type trace_value = string; // 'off' | 'messages' | 'verbose';
[@deriving yojson]
type text_document_identifier = {uri};

// https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocumentPositionParams
[@deriving yojson({strict: false})]
type text_document_position_params = {
[@key "textDocument"]
text_document: text_document_identifier,
position,
};

// https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocumentSyncKind
[@deriving (enum, yojson)]
type text_document_sync_kind =
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/language_server/protocol.rei
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ type trace_value = string; // 'off' | 'messages' | 'verbose';
[@deriving yojson]
type text_document_identifier = {uri};

// https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocumentPositionParams
[@deriving yojson({strict: false})]
type text_document_position_params = {
[@key "textDocument"]
text_document: text_document_identifier,
position,
};

// https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocumentSyncKind
[@deriving (enum, yojson)]
type text_document_sync_kind =
Expand Down
27 changes: 5 additions & 22 deletions compiler/src/language_server/sourcetree.re
Original file line number Diff line number Diff line change
Expand Up @@ -439,32 +439,15 @@ module Sourcetree: Sourcetree = {
);
};
let enter_pattern = pat => {
let rec get_type_path = pat_type => {
Types.(
switch (pat_type.desc) {
| TTyConstr(path, _, _) => Some(path)
| TTyLink(inner)
| TTySubst(inner) => get_type_path(inner)
| _ => None
}
);
};
let definition =
switch (get_type_path(pat.pat_type)) {
| Some(path) =>
let decl = Env.find_type(path, pat.pat_env);
if (decl.type_loc == Location.dummy_loc) {
None;
} else {
Some(decl.type_loc);
};
| _ => None
};
segments :=
[
(
loc_to_interval(pat.pat_loc),
Pattern({pattern: pat, definition}),
Pattern({
pattern: pat,
definition:
Env.get_type_definition_loc(pat.pat_type, pat.pat_env),
}),
),
...segments^,
];
Expand Down
24 changes: 24 additions & 0 deletions compiler/src/typed/env.re
Original file line number Diff line number Diff line change
Expand Up @@ -2467,6 +2467,30 @@ let format_dependency_chain = (ppf, depchain: dependency_chain) => {
fprintf(ppf, "@]");
};

let rec get_type_path = type_expr => {
Types.(
switch (type_expr.desc) {
| TTyConstr(path, _, _) => Some(path)
| TTyLink(inner)
| TTySubst(inner) => get_type_path(inner)
| _ => None
}
);
};

let get_type_definition_loc = (type_expr, env) => {
switch (get_type_path(type_expr)) {
| Some(path) =>
let decl = find_type(path, env);
if (decl.type_loc == Location.dummy_loc) {
None;
} else {
Some(decl.type_loc);
};
| _ => None
};
};

let report_error = ppf =>
fun
| Illegal_renaming(modname, ps_name, filename) =>
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/typed/env.rei
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ let fold_modtypes:
let scrape_alias: (t, module_type) => module_type;
let check_value_name: (string, Location.t) => unit;

let get_type_definition_loc: (type_expr, t) => option(Location.t);

module Persistent_signature: {
type t = {
/** Name of the file containing the signature. */
Expand Down

0 comments on commit 4bb8fae

Please sign in to comment.