diff --git a/lang/rust/utils/lsp.go b/lang/rust/utils/lsp.go index 2fb28922..89a83e09 100644 --- a/lang/rust/utils/lsp.go +++ b/lang/rust/utils/lsp.go @@ -16,6 +16,7 @@ package utils import ( "context" + "fmt" "path/filepath" "regexp" "strings" @@ -46,6 +47,23 @@ func GetLSPClient(root string) *lsp.LSPClient { return cli } +// 查找符号对应的代码 +// root: 项目根目录 +// file: 文件路径 +// mod: 命名空间,空表示本文件根 +// name: 符号名 +// receiver: method接收者,为空表示不是method +func GetRawSymbol(root, file, mod, name string, receiver string, caseInsensitive bool) *lsp.DocumentSymbol { + cli := GetLSPClient(root) + sym, err := getSymbol(cli, root, file, mod, name, receiver, caseInsensitive) + if err != nil { + log.Error("get symbol for %s failed, err: %v", name, err) + return nil + } + + return sym +} + // 查找符号对应的代码 // root: 项目根目录 // file: 文件路径 @@ -54,10 +72,23 @@ func GetLSPClient(root string) *lsp.LSPClient { // receiver: method接收者,为空表示不是method func GetSymbol(root, file, mod, name string, receiver string, caseInsensitive bool) string { cli := GetLSPClient(root) - syms, err := cli.FileStructure(context.Background(), lsp.NewURI(file)) + sym, err := getSymbol(cli, root, file, mod, name, receiver, caseInsensitive) if err != nil { + log.Error("get symbol for %s failed, err: %v", name, err) return "" } + if sym == nil { + return "" + } + text, _ := cli.Locate(sym.Location) + return text +} + +func getSymbol(cli *lsp.LSPClient, root, file, mod, name string, receiver string, caseInsensitive bool) (*lsp.DocumentSymbol, error) { + syms, err := cli.FileStructure(context.Background(), lsp.NewURI(file)) + if err != nil { + return nil, err + } var sym *lsp.DocumentSymbol if mod != "" { @@ -91,10 +122,30 @@ func GetSymbol(root, file, mod, name string, receiver string, caseInsensitive bo finally: if sym == nil { - return "" + return nil, fmt.Errorf("can not find symbol for %s", name) + } + + return sym, nil +} + +// 查找符号对应的源码以及文件行号 +// root: 项目根目录 +// file: 文件路径 +// mod: 命名空间,空表示本文件根 +// name: 符号名 +// receiver: method接收者,为空表示不是method +func GetSymbolContentAndLocation(root, file, mod, name string, receiver string, caseInsensitive bool) (string, [2]int) { + cli := GetLSPClient(root) + sym, err := getSymbol(cli, root, file, mod, name, receiver, caseInsensitive) + if err != nil { + log.Error("get symbol for %s failed, err: %v", name, err) + return "", [2]int{} + } + if sym == nil { + return "", [2]int{} } text, _ := cli.Locate(sym.Location) - return text + return text, [2]int{sym.Location.Range.Start.Line, sym.Location.Range.End.Line} } // NOTICE: 为了提供容错率,这里只是简单查找是否包含token,不做严格的标识符检查