From 10d51ddaaba0cd03ccf9ae51620d6e8cad27f1e4 Mon Sep 17 00:00:00 2001 From: Amaan Qureshi Date: Sat, 14 Oct 2023 00:24:10 -0400 Subject: [PATCH] refactor!: remove nvim-treesitter dependency (#35) * refactor!: remove nvim-treesitter dependency * chore: lua diagnostics * Check if parser installed with pcall `ts.get_parser` * Load nvim-treesitter in CI before running tests --------- Co-authored-by: Chen Shuaimin --- .github/workflows/ci.yml | 2 +- lua/ssr.lua | 17 +++++++++++------ lua/ssr/parse.lua | 12 ++++++++---- lua/ssr/search.lua | 27 ++++++++++++++++++++++----- lua/ssr/utils.lua | 24 ++++++++++++++++++------ tests/ssr_spec.lua | 9 ++++++++- 6 files changed, 68 insertions(+), 23 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b7e6c1c..2a0df4a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,7 @@ jobs: git clone --depth 1 https://github.com/nvim-treesitter/nvim-treesitter.git ~/.local/share/nvim/site/pack/vendor/start/nvim-treesitter git clone --depth 1 https://github.com/nvim-lua/plenary.nvim ~/.local/share/nvim/site/pack/vendor/start/plenary.nvim ln -s $(pwd) ~/.local/share/nvim/site/pack/vendor/start - nvim --headless -c "lua require('nvim-treesitter').setup {}" -c 'TSInstallSync python javascript lua rust go' -c 'q' + nvim --headless -c 'TSInstallSync python javascript lua rust go' -c 'q' - name: Run tests run: | diff --git a/lua/ssr.lua b/lua/ssr.lua index 30a3651..8aece99 100644 --- a/lua/ssr.lua +++ b/lua/ssr.lua @@ -3,7 +3,6 @@ local ts = vim.treesitter local fn = vim.fn local keymap = vim.keymap local highlight = vim.highlight -local parsers = require "nvim-treesitter.parsers" local ParseContext = require("ssr.parse").ParseContext local search = require("ssr.search").search local replace = require("ssr.search").replace @@ -56,13 +55,18 @@ function Ui.new() self.origin_win = api.nvim_get_current_win() local origin_buf = api.nvim_win_get_buf(self.origin_win) - self.lang = parsers.get_buf_lang(origin_buf) - if not parsers.has_parser(self.lang) then + local lang = ts.language.get_lang(vim.bo[origin_buf].filetype) + if not lang then + return u.notify("Treesitter language not found") + end + self.lang = lang + + local origin_node = u.node_for_range(origin_buf, self.lang, u.get_selection(self.origin_win)) + if not origin_node then return u.notify("Treesitter parser not found, please try to install it with :TSInstall " .. self.lang) end - local origin_node = u.node_for_range(origin_buf, u.get_selection(self.origin_win)) if origin_node:has_error() then - return u.notify "You have syntax errors in selected node" + return u.notify "You have syntax errors in the selected node" end local parse_context = ParseContext.new(origin_buf, origin_node) if not parse_context then @@ -173,7 +177,7 @@ function Ui.new() return end - if parsers.get_buf_lang(event.buf) ~= self.lang then + if ts.language.get_lang(vim.bo[event.buf].filetype) ~= self.lang then return self:set_status "N/A" end self:search() @@ -334,6 +338,7 @@ function Ui:replace_confirm() while match_idx <= #matches do local confirm_win = open_confirm_win(match_idx) + ---@type string local key while true do -- Draw a fake cursor because cursor is not shown correctly when blocking on `getchar()`. diff --git a/lua/ssr/parse.lua b/lua/ssr/parse.lua index 1050d08..2e78f40 100644 --- a/lua/ssr/parse.lua +++ b/lua/ssr/parse.lua @@ -1,5 +1,4 @@ local ts = vim.treesitter -local parsers = require "nvim-treesitter.parsers" local wildcard_prefix = require("ssr.search").wildcard_prefix local M = {} @@ -19,12 +18,17 @@ M.ParseContext = ParseContext ---@param origin_node TSNode ---@return ParseContext? function ParseContext.new(buf, origin_node) - local self = setmetatable({ lang = parsers.get_buf_lang(buf) }, { __index = ParseContext }) + local lang = ts.language.get_lang(vim.bo[buf].filetype) + if not lang then + return + end + local self = setmetatable({ lang = lang }, { __index = ParseContext }) local origin_start_row, origin_start_col, origin_start_byte = origin_node:start() local _, _, origin_end_byte = origin_node:end_() local origin_lines = vim.split(ts.get_node_text(origin_node, buf), "\n") local origin_sexpr = origin_node:sexpr() + ---@type TSNode? local context_node = origin_node -- Find an ancestor of `origin_node` @@ -45,7 +49,7 @@ function ParseContext.new(buf, origin_node) end_col = end_col + start_col end local node_in_context = root:named_descendant_for_range(start_row, start_col, end_row, end_col) - if node_in_context:type() == origin_node:type() and node_in_context:sexpr() == origin_sexpr then + if node_in_context and node_in_context:type() == origin_node:type() and node_in_context:sexpr() == origin_sexpr then local context_start_byte self.start_row, self.start_col, context_start_byte = context_node:start() self.before = context_text:sub(1, origin_start_byte - context_start_byte) @@ -61,7 +65,7 @@ end -- Parse search pattern to syntax tree in proper context. ---@param pattern string ----@return TSNode, string +---@return TSNode?, string function ParseContext:parse(pattern) -- Replace named wildcard $name to identifier __ssr_var_name to avoid syntax error. pattern = pattern:gsub("%$([_%a%d]+)", wildcard_prefix .. "%1") diff --git a/lua/ssr/search.lua b/lua/ssr/search.lua index 766dd75..ec6cdcc 100644 --- a/lua/ssr/search.lua +++ b/lua/ssr/search.lua @@ -1,6 +1,5 @@ local api = vim.api local ts = vim.treesitter -local parsers = require "nvim-treesitter.parsers" local u = require "ssr.utils" local M = {} @@ -46,10 +45,13 @@ end -- The check is loose because users want to match different types of node. -- e.g. converting `{ foo: foo }` to shorthand `{ foo }`. ts.query.add_predicate("ssr-tree-match?", function(match, _pattern, buf, pred) - ---@param node1 TSNode - ---@param node2 TSNode + ---@param node1 TSNode? + ---@param node2 TSNode? ---@return boolean local function tree_match(node1, node2) + if not node1 or not node2 then + return false + end if node1:named() ~= node2:named() then return false end @@ -74,10 +76,13 @@ end, true) ---@param source string ---@return string, table local function build_sexpr(node, source) + ---@type table local wildcards = {} local next_idx = 1 -- This function is more strict than `tsnode:sexpr()` by also requiring leaf nodes to match text. + ---@param node TSNode + ---@return string local function build(node) local text = ts.get_node_text(node, source) @@ -139,10 +144,19 @@ end function M.search(buf, node, source, ns) local sexpr, wildcards = build_sexpr(node, source) local parse_query = ts.query.parse or ts.parse_query - local query = parse_query(parsers.get_buf_lang(buf), sexpr) + local lang = ts.language.get_lang(vim.bo[buf].filetype) + if not lang then + return {} + end + local query = parse_query(lang, sexpr) local matches = {} - local root = parsers.get_parser(buf):parse()[1]:root() + local has_parser, parser = pcall(ts.get_parser, buf, lang) + if not has_parser then + return {} + end + local root = parser:parse(true)[1]:root() for _, nodes in query:iter_matches(root, buf, 0, -1) do + ---@type table local captures = {} for var, idx in pairs(wildcards) do captures[var] = ExtmarkRange.new(ns, buf, nodes[idx]) @@ -154,6 +168,9 @@ function M.search(buf, node, source, ns) -- Sort matches from -- buffer top to bottom, to make goto next/prev match intuitive -- inner to outer for recursive matches, to make replacing correct + ---@param match1 { range: ExtmarkRange, captures: table} + ---@param match2 { range: ExtmarkRange, captures: table} + ---@return boolean table.sort(matches, function(match1, match2) local start_row1, start_col1, end_row1, end_col1 = match1.range:get() local start_row2, start_col2, end_row2, end_col2 = match2.range:get() diff --git a/lua/ssr/utils.lua b/lua/ssr/utils.lua index 86c3625..099ed92 100644 --- a/lua/ssr/utils.lua +++ b/lua/ssr/utils.lua @@ -1,5 +1,5 @@ local api = vim.api -local parsers = require "nvim-treesitter.parsers" +local ts = vim.treesitter local M = {} @@ -55,13 +55,17 @@ end -- Get smallest node for the range. ---@param buf buffer +---@param lang string ---@param start_row number ---@param start_col number ---@param end_row number ---@param end_col number ----@return TSNode -function M.node_for_range(buf, start_row, start_col, end_row, end_col) - return parsers.get_parser(buf):parse()[1]:root():named_descendant_for_range(start_row, start_col, end_row, end_col) +---@return TSNode? +function M.node_for_range(buf, lang, start_row, start_col, end_row, end_col) + local has_parser, parser = pcall(ts.get_parser, buf, lang) + if has_parser then + return parser:parse()[1]:root():named_descendant_for_range(start_row, start_col, end_row, end_col) + end end ---@param buf buffer @@ -71,7 +75,7 @@ function M.get_indent(buf, row) return line:match "^%s*" end ----@param lines table +---@param lines string[] ---@param indent string function M.add_indent(lines, indent) for i = 2, #lines do @@ -79,7 +83,7 @@ function M.add_indent(lines, indent) end end ----@param lines table +---@param lines string[] ---@param indent string function M.remove_indent(lines, indent) indent = "^" .. indent @@ -98,7 +102,15 @@ function M.to_ts_query_str(s) end -- Compute window size to show giving lines. +---@param lines string[] +---@param config Config +---@return number +---@return number function M.get_win_size(lines, config) + ---@param i number + ---@param min number + ---@param max number + ---@return number local function clamp(i, min, max) return math.min(math.max(i, min), max) end diff --git a/tests/ssr_spec.lua b/tests/ssr_spec.lua index 2cd3315..f6fcd17 100644 --- a/tests/ssr_spec.lua +++ b/tests/ssr_spec.lua @@ -1,5 +1,6 @@ local u = require "ssr.utils" local ParseContext = require("ssr.parse").ParseContext +local ts = vim.treesitter local search = require("ssr.search").search local replace = require("ssr.search").replace @@ -222,6 +223,10 @@ local a = vim.api ]] describe("", function() + -- Plenary runs nvim with `--noplugin` argument. + -- Make sure nvim-treesitter is loaded, which populates vim.treesitter's ft_to_lang table. + require "nvim-treesitter" + for _, s in ipairs(tests) do local ft, desc, content, pattern, template, expected = s:match "^ (%a-) (.-)\n(.-)%s?====%s?(.-)%s?==>>%s?(.-)%s?====%s?(.-)%s?$" @@ -249,7 +254,9 @@ describe("", function() local buf = vim.api.nvim_create_buf(false, true) vim.bo[buf].filetype = ft vim.api.nvim_buf_set_lines(buf, 0, -1, true, content) - local origin_node = u.node_for_range(buf, start_row, start_col, end_row, end_col) + local lang = ts.language.get_lang(vim.bo[buf].filetype) + assert(lang, "language not found") + local origin_node = u.node_for_range(buf, lang, start_row, start_col, end_row, end_col) local parse_context = ParseContext.new(buf, origin_node) assert(parse_context)