diff --git a/lua/ssr.lua b/lua/ssr.lua index 8aece99..bcabc94 100644 --- a/lua/ssr.lua +++ b/lua/ssr.lua @@ -1,11 +1,9 @@ local api = vim.api local ts = vim.treesitter -local fn = vim.fn -local keymap = vim.keymap -local highlight = vim.highlight -local ParseContext = require("ssr.parse").ParseContext -local search = require("ssr.search").search -local replace = require("ssr.search").replace +local ParseContext = require "ssr.parse_context" +local Searcher = require "ssr.search" +local Replacer = require "ssr.replace" +local File = require "ssr.file" local u = require "ssr.utils" local M = {} @@ -34,33 +32,26 @@ function M.setup(cfg) end end ----@type table -local win_uis = {} - ---@class Ui ----@field ns number ----@field cur_search_ns number ----@field augroup number ---@field ui_buf buffer ---@field extmarks {status: number, search: number, replace: number} ---@field origin_win window ---@field lang string ---@field parse_context ParseContext ----@field buf_matches table +---@field matches table local Ui = {} ---@return Ui? function Ui.new() - local self = setmetatable({}, { __index = Ui }) + local self = setmetatable({ matches = {} }, { __index = Ui }) self.origin_win = api.nvim_get_current_win() local origin_buf = api.nvim_win_get_buf(self.origin_win) local lang = ts.language.get_lang(vim.bo[origin_buf].filetype) if not lang then - return u.notify("Treesitter language not found") + return u.notify "Treesitter language not found" end self.lang = lang - local origin_node = u.node_for_range(origin_buf, self.lang, u.get_selection(self.origin_win)) if not origin_node then return u.notify("Treesitter parser not found, please try to install it with :TSInstall " .. self.lang) @@ -68,23 +59,18 @@ function Ui.new() if origin_node:has_error() then return u.notify "You have syntax errors in the selected node" end - local parse_context = ParseContext.new(origin_buf, origin_node) + local parse_context = ParseContext.new(origin_buf, self.lang, origin_node) if not parse_context then return u.notify "Can't find a proper context to parse the pattern" end self.parse_context = parse_context - self.buf_matches = {} - self.ns = api.nvim_create_namespace("ssr_" .. self.origin_win) -- TODO - self.cur_search_ns = api.nvim_create_namespace("ssr_cur_match_" .. self.origin_win) - self.augroup = api.nvim_create_augroup("ssr_augroup_" .. self.origin_win, {}) - -- Init ui buffer self.ui_buf = api.nvim_create_buf(false, true) vim.bo[self.ui_buf].filetype = "ssr" local placeholder = ts.get_node_text(origin_node, origin_buf) - placeholder = "\n\n" .. placeholder .. "\n\n" + placeholder = "\n\n" .. placeholder .. "\n\n\n\n" placeholder = vim.split(placeholder, "\n") u.remove_indent(placeholder, u.get_indent(origin_buf, origin_node:start())) api.nvim_buf_set_lines(self.ui_buf, 0, -1, true, placeholder) @@ -92,16 +78,17 @@ function Ui.new() ts.start(self.ui_buf, self.lang) local function virt_text(row, text) - return api.nvim_buf_set_extmark(self.ui_buf, self.ns, row, 0, { virt_text = text, virt_text_pos = "overlay" }) + return api.nvim_buf_set_extmark(self.ui_buf, u.namespace, row, 0, { virt_text = text, virt_text_pos = "overlay" }) end self.extmarks = { status = virt_text(0, { { "[SSR]", "Comment" }, { " (Press ? for help)", "Comment" } }), search = virt_text(1, { { "SEARCH:", "String" } }), - replace = virt_text(#placeholder - 2, { { "REPLACE:", "String" } }), + replace = virt_text(#placeholder - 4, { { "REPLACE:", "String" } }), + results = virt_text(#placeholder - 2, { { "RESULTS:", "String" } }), } local function map(key, func) - keymap.set("n", key, function() + vim.keymap.set("n", key, function() func(self) end, { buffer = self.ui_buf, nowait = true }) end @@ -127,7 +114,7 @@ function Ui.new() height = height, }) u.set_cursor(ui_win, 2, 0) - fn.matchadd("Title", [[$\w\+]]) + vim.fn.matchadd("Title", [[$\w\+]]) map(config.keymaps.close, function() api.nvim_win_close(ui_win, false) @@ -149,38 +136,29 @@ function Ui.new() end, }) - -- SSR window is bound to the original window (not buffer!), which is the same behavior as IDEs and browsers. api.nvim_create_autocmd("BufWinEnter", { group = self.augroup, callback = function(event) if event.buf == self.ui_buf then return end - local win = api.nvim_get_current_win() - if win == ui_win then - -- Prevent accidentally opening another file in the ssr window. - -- Adapted from neo-tree.nvim. - vim.schedule(function() - api.nvim_win_set_buf(ui_win, self.ui_buf) - local name = api.nvim_buf_get_name(event.buf) - api.nvim_win_call(self.origin_win, function() - pcall(api.nvim_buf_delete, event.buf, {}) - if name ~= "" then - vim.cmd.edit(name) - end - end) - api.nvim_set_current_win(self.origin_win) - end) - return - elseif win ~= self.origin_win then + if win ~= ui_win then return end - - if ts.language.get_lang(vim.bo[event.buf].filetype) ~= self.lang then - return self:set_status "N/A" - end - self:search() + -- Prevent accidentally opening another file in the ssr window. + -- Adapted from neo-tree.nvim. + vim.schedule(function() + api.nvim_win_set_buf(ui_win, self.ui_buf) + local name = api.nvim_buf_get_name(event.buf) + api.nvim_win_call(self.origin_win, function() + pcall(api.nvim_buf_delete, event.buf, {}) + if name ~= "" then + vim.cmd.edit(name) + end + end) + api.nvim_set_current_win(self.origin_win) + end) end, }) @@ -188,45 +166,62 @@ function Ui.new() group = self.augroup, buffer = self.ui_buf, callback = function() - win_uis[self.origin_win] = nil api.nvim_clear_autocmds { group = self.augroup } api.nvim_buf_delete(self.ui_buf, {}) - for buf in pairs(self.buf_matches) do - api.nvim_buf_clear_namespace(buf, self.ns, 0, -1) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) + for buf in pairs(self.matches) do + api.nvim_buf_clear_namespace(buf, u.namespace, 0, -1) + api.nvim_buf_clear_namespace(buf, u.cur_search_ns, 0, -1) end end, }) - win_uis[self.origin_win] = self return self end function Ui:search() - local pattern = self:get_input() - local buf = api.nvim_win_get_buf(self.origin_win) - self.buf_matches[buf] = {} - api.nvim_buf_clear_namespace(buf, self.ns, 0, -1) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) + local pattern, template = self:get_input() + if pattern == self.last_pattern then + return + end + self.last_pattern = pattern + -- for buf in pairs(self.matches) do + -- api.nvim_buf_clear_namespace(buf, u.namespace, 0, -1) + -- api.nvim_buf_clear_namespace(buf, u.cur_search_ns, 0, -1) + -- end + self.matches = {} + + local found = 0 + local matched_files = 0 local start = vim.loop.hrtime() - local node, source = self.parse_context:parse(pattern) - if node:has_error() then + local searcher = Searcher.new(self.lang, pattern, self.parse_context) + if not searcher then return self:set_status "Error" end - self.buf_matches[buf] = search(buf, node, source, self.ns) - local elapsed = (vim.loop.hrtime() - start) / 1E6 - for _, match in ipairs(self.buf_matches[buf]) do - local start_row, start_col, end_row, end_col = match.range:get() - highlight.range(buf, self.ns, "Search", { start_row, start_col }, { end_row, end_col }, {}) - end - self:set_status(string.format("%d found in %dms", #self.buf_matches[buf], elapsed)) + + File.grep(searcher.rough_regex, function(file) + local matches = searcher:search(file) + if #matches == 0 then + return + end + -- self.matches[buf] = matches + found = found + #matches + matched_files = matched_files + 1 + -- for _, match in ipairs(matches) do + -- local start_row, start_col, end_row, end_col = match.range:get() + -- vim.highlight.range(buf, u.namespace, "Search", { start_row, start_col }, { end_row, end_col }, {}) + -- end + self:set_status(string.format("%d found in %d files (searching)", found, matched_files)) + end, function() + local elapsed = (vim.loop.hrtime() - start) / 1E6 + return self:set_status(string.format("%d found in %d files (%dms)", found, matched_files, elapsed)) + end) end function Ui:next_match_idx() local cursor_row, cursor_col = u.get_cursor(self.origin_win) local buf = api.nvim_win_get_buf(self.origin_win) - for idx, matches in pairs(self.buf_matches[buf]) do + for idx, matches in pairs(self.matches[buf]) do local start_row, start_col = matches.range:get() if start_row > cursor_row or (start_row == cursor_row and start_col > cursor_col) then return idx @@ -238,7 +233,7 @@ end function Ui:prev_match_idx() local cursor_row, cursor_col = u.get_cursor(self.origin_win) local buf = api.nvim_win_get_buf(self.origin_win) - local matches = self.buf_matches[buf] + local matches = self.matches[buf] for idx = #matches, 1, -1 do local start_row, start_col = matches[idx].range:get() if start_row < cursor_row or (start_row == cursor_row and start_col < cursor_col) then @@ -250,19 +245,19 @@ end function Ui:goto_match(match_idx) local buf = api.nvim_win_get_buf(self.origin_win) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) - local matches = self.buf_matches[buf] + api.nvim_buf_clear_namespace(buf, u.cur_search_ns, 0, -1) + local matches = self.matches[buf] local start_row, start_col, end_row, end_col = matches[match_idx].range:get() u.set_cursor(self.origin_win, start_row, start_col) - highlight.range( + vim.highlight.range( buf, - self.cur_search_ns, + u.cur_search_ns, "CurSearch", { start_row, start_col }, { end_row, end_col }, { priority = vim.highlight.priorities.user + 100 } ) - api.nvim_buf_set_extmark(buf, self.cur_search_ns, start_row, start_col, { + api.nvim_buf_set_extmark(buf, u.cur_search_ns, start_row, start_col, { virt_text_pos = "eol", virt_text = { { string.format("[%d/%d]", match_idx, #matches), "DiagnosticVirtualTextInfo" } }, }) @@ -271,7 +266,7 @@ end function Ui:replace_all() self:search() local buf = api.nvim_win_get_buf(self.origin_win) - local matches = self.buf_matches[buf] + local matches = self.matches[buf] if #matches == 0 then return self:set_status "pattern not found" end @@ -287,7 +282,7 @@ end function Ui:replace_confirm() self:search() local buf = api.nvim_win_get_buf(self.origin_win) - local matches = self.buf_matches[buf] + local matches = self.matches[buf] if #matches == 0 then return self:set_status "pattern not found" end @@ -306,7 +301,13 @@ function Ui:replace_confirm() api.nvim_buf_set_lines(confirm_buf, 0, -1, true, choices) for idx = 0, #choices - 1 do if idx + 1 ~= separator_idx then - api.nvim_buf_set_extmark(confirm_buf, self.ns, idx, 4, { hl_group = "Underlined", end_row = idx, end_col = 5 }) + api.nvim_buf_set_extmark( + confirm_buf, + u.namespace, + idx, + 4, + { hl_group = "Underlined", end_row = idx, end_col = 5 } + ) end end @@ -342,15 +343,15 @@ function Ui:replace_confirm() local key while true do -- Draw a fake cursor because cursor is not shown correctly when blocking on `getchar()`. - api.nvim_buf_clear_namespace(confirm_buf, self.cur_search_ns, 0, -1) + api.nvim_buf_clear_namespace(confirm_buf, u.cur_search_ns, 0, -1) api.nvim_buf_set_extmark( confirm_buf, - self.cur_search_ns, + u.cur_search_ns, cursor - 1, 0, { virt_text = { { "•", "Cursor" } }, virt_text_pos = "overlay" } ) - api.nvim_buf_set_extmark(confirm_buf, self.cur_search_ns, cursor - 1, 0, { line_hl_group = "CursorLine" }) + api.nvim_buf_set_extmark(confirm_buf, u.cur_search_ns, cursor - 1, 0, { line_hl_group = "CursorLine" }) vim.cmd.redraw() local ok, char = pcall(vim.fn.getcharstr) @@ -372,7 +373,7 @@ function Ui:replace_confirm() cursor = cursor - 1 end elseif vim.tbl_contains({ "", "", "", "", "", "" }, key) then - fn.win_execute(self.origin_win, string.format('execute "normal! \\%s"', key)) + vim.fn.win_execute(self.origin_win, string.format('execute "normal! \\%s"', key)) else break end @@ -406,22 +407,23 @@ function Ui:replace_confirm() end api.nvim_buf_delete(confirm_buf, {}) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) + api.nvim_buf_clear_namespace(buf, u.cur_search_ns, 0, -1) self:set_status(string.format("%d/%d replaced", replaced, #matches)) end function Ui:get_input() local lines = api.nvim_buf_get_lines(self.ui_buf, 0, -1, true) - local pattern_pos = api.nvim_buf_get_extmark_by_id(self.ui_buf, self.ns, self.extmarks.search, {})[1] - local template_pos = api.nvim_buf_get_extmark_by_id(self.ui_buf, self.ns, self.extmarks.replace, {})[1] + local pattern_pos = api.nvim_buf_get_extmark_by_id(self.ui_buf, u.namespace, self.extmarks.search, {})[1] + local template_pos = api.nvim_buf_get_extmark_by_id(self.ui_buf, u.namespace, self.extmarks.replace, {})[1] local pattern = vim.trim(table.concat(lines, "\n", pattern_pos + 2, template_pos)) local template = vim.trim(table.concat(lines, "\n", template_pos + 1, #lines)) return pattern, template end ---@param status string +---@return nil function Ui:set_status(status) - api.nvim_buf_set_extmark(self.ui_buf, self.ns, 0, 0, { + api.nvim_buf_set_extmark(self.ui_buf, u.namespace, 0, 0, { id = self.extmarks.status, virt_text = { { "[SSR] ", "Comment" }, @@ -432,37 +434,8 @@ function Ui:set_status(status) }) end ----@param win window? ----@return Ui? -function Ui.from_win(win) - if win == nil or win == 0 then - win = api.nvim_get_current_win() - end - local ui = win_uis[win] - if not ui then - return u.notify "No open SSR window" - end - return ui -end - function M.open() return Ui.new() end --- Replace all matches. -function M.replace_all() - local ui = Ui.from_win() - if ui then - ui:replace_all() - end -end - --- Confirm each match. -function M.replace_confirm() - local ui = Ui.from_win() - if ui then - ui:replace_confirm() - end -end - return M diff --git a/lua/ssr/file.lua b/lua/ssr/file.lua new file mode 100644 index 0000000..b43d787 --- /dev/null +++ b/lua/ssr/file.lua @@ -0,0 +1,83 @@ +local ts = vim.treesitter +local uv = vim.uv or vim.loop + +---@class File +---@field name string +---@field content string +---@field mtime { nsec: integer, sec: integer } +---@field lang_tree LanguageTree +local File = {} + +---@type table +local cache = {} + +---@param name string +---@return File? +function File.new(name) + local fd = uv.fs_open(name, "r", 438) + if not fd then + return + end + local stat = uv.fs_fstat(fd) --[[@as uv.aliases.fs_stat_table]] + local self = cache[name] + if self and stat.mtime.sec == self.mtime.sec and stat.mtime.nsec == self.mtime.nsec then + uv.fs_close(fd) + return self + end + + self = setmetatable({ name = name }, { __index = File }) + self.mtime = stat.mtime + self.content = uv.fs_read(fd, stat.size, 0) --[[@as string]] + uv.fs_close(fd) + if not name:match "%.ts$" then + return + end + -- local ft = vim.filetype.match { filename = name } + -- if not ft then + -- return + -- end + local lang = ts.language.get_lang "typescript" + if not lang then + return + end + local has_parser, lang_tree = pcall(ts.get_string_parser, self.content, lang) + if not has_parser then + return + end + self.lang_tree = lang_tree + self.lang_tree:parse() + + cache[name] = self + return self +end + +---@param regex string +---@param on_file fun(file: File) +---@param on_end fun() +---@return nil +function File.grep(regex, on_file, on_end) + vim.system({ "rg", "--line-buffered", "--files-with-matches", regex }, { + text = true, + stdout = vim.schedule_wrap(function(err, files) + if err then + error(files) + end + if not files then + on_end() + return + end + for _, name in ipairs(vim.split(files, "\n", { plain = true })) do + local file = File.new(name) + if file then + on_file(file) + end + end + end), + }, function(obj) + if obj.code ~= 0 then + error(obj.stderr) + end + end) +end + +return File diff --git a/lua/ssr/parse.lua b/lua/ssr/parse_context.lua similarity index 70% rename from lua/ssr/parse.lua rename to lua/ssr/parse_context.lua index 2e78f40..f889b24 100644 --- a/lua/ssr/parse.lua +++ b/lua/ssr/parse_context.lua @@ -1,8 +1,7 @@ local ts = vim.treesitter -local wildcard_prefix = require("ssr.search").wildcard_prefix - -local M = {} +local u = require "ssr.utils" +-- The context in which user input will be parsed correctly. ---@class ParseContext ---@field lang string ---@field before string @@ -10,18 +9,12 @@ local M = {} ---@field pad_rows integer ---@field pad_cols integer local ParseContext = {} -ParseContext.__index = ParseContext -M.ParseContext = ParseContext --- Create a context in which `origin_node` (and user input) will be parsed correctly. ---@param buf buffer +---@param lang string ---@param origin_node TSNode ---@return ParseContext? -function ParseContext.new(buf, origin_node) - local lang = ts.language.get_lang(vim.bo[buf].filetype) - if not lang then - return - end +function ParseContext.new(buf, lang, origin_node) local self = setmetatable({ lang = lang }, { __index = ParseContext }) local origin_start_row, origin_start_col, origin_start_byte = origin_node:start() @@ -48,8 +41,8 @@ function ParseContext.new(buf, origin_node) if end_row == start_row then end_col = end_col + start_col end - local node_in_context = root:named_descendant_for_range(start_row, start_col, end_row, end_col) - if node_in_context and node_in_context:type() == origin_node:type() and node_in_context:sexpr() == origin_sexpr then + local node_in_context = root:named_descendant_for_range(start_row, start_col, end_row, end_col) --[[@as TSNode]] + if node_in_context:type() == origin_node:type() and node_in_context:sexpr() == origin_sexpr then local context_start_byte self.start_row, self.start_col, context_start_byte = context_node:start() self.before = context_text:sub(1, origin_start_byte - context_start_byte) @@ -63,17 +56,17 @@ function ParseContext.new(buf, origin_node) end end --- Parse search pattern to syntax tree in proper context. +-- Parse code to TS node. ---@param pattern string ----@return TSNode?, string +---@return TSNode, string function ParseContext:parse(pattern) -- Replace named wildcard $name to identifier __ssr_var_name to avoid syntax error. - pattern = pattern:gsub("%$([_%a%d]+)", wildcard_prefix .. "%1") - local context_text = self.before .. pattern .. self.after - local root = ts.get_string_parser(context_text, self.lang):parse()[1]:root() + pattern = pattern:gsub("%$([_%a%d]+)", u.wildcard_prefix .. "%1") + local source = self.before .. pattern .. self.after + local root = ts.get_string_parser(source, self.lang):parse()[1]:root() local lines = vim.split(pattern, "\n") - local node = root:named_descendant_for_range(self.pad_rows, self.pad_cols, self.pad_rows + #lines - 1, #lines[#lines]) - return node, context_text + local node = root:named_descendant_for_range(self.pad_rows, self.pad_cols, self.pad_rows + #lines - 1, #lines[#lines]) --[[@as TSNode]] + return node, source end -return M +return ParseContext diff --git a/lua/ssr/range.lua b/lua/ssr/range.lua new file mode 100644 index 0000000..0112abc --- /dev/null +++ b/lua/ssr/range.lua @@ -0,0 +1,35 @@ +---@class Range +---@field start_row number +---@field start_col number +---@field end_row number +---@field end_col number +local Range = {} + +---@param node TSNode +---@return Range +function Range.from_node(node) + local start_row, start_col, end_row, end_col = node:range() + return setmetatable({ + start_row = start_row, + start_col = start_col, + end_row = end_row, + end_col = end_col, + }, { __index = Range }) +end + +---@param other Range +---@return boolean +function Range:before(other) + return self.end_row < other.start_row or (self.end_row == other.start_row and self.end_col <= other.start_col) +end + +---@param other Range +---@return boolean +function Range:inside(other) + return ( + (self.start_row > other.start_row or (self.start_row == other.start_row and self.start_col > other.start_col)) + and (self.end_row < other.end_row or (self.end_row == other.end_row and self.end_col <= other.end_col)) + ) +end + +return Range diff --git a/lua/ssr/replace.lua b/lua/ssr/replace.lua new file mode 100644 index 0000000..475e714 --- /dev/null +++ b/lua/ssr/replace.lua @@ -0,0 +1,28 @@ +local api = vim.api +local ts = vim.treesitter +local u = require "ssr.utils" + +local Replacer = {} + +--- Render template and replace one match. +---@param buf buffer +---@param match Match +function Replacer:replace(buf, match) + -- Render templates with captured nodes. + local replace = self.template:gsub("()%$([_%a%d]+)", function(pos, var) + local start_row, start_col, end_row, end_col = match.captures[var]:get() + local lines = api.nvim_buf_get_text(buf, start_row, start_col, end_row, end_col, {}) + u.remove_indent(lines, u.get_indent(buf, start_row)) + local var_lines = vim.split(self.template:sub(1, pos), "\n") + local var_line = var_lines[#var_lines] + local template_indent = var_line:match "^%s*" + u.add_indent(lines, template_indent) + return table.concat(lines, "\n") + end) + replace = vim.split(replace, "\n") + local start_row, start_col, end_row, end_col = match.range:get() + u.add_indent(replace, u.get_indent(buf, start_row)) + api.nvim_buf_set_text(buf, start_row, start_col, end_row, end_col, replace) +end + +return Replacer diff --git a/lua/ssr/search.lua b/lua/ssr/search.lua index ec6cdcc..250f050 100644 --- a/lua/ssr/search.lua +++ b/lua/ssr/search.lua @@ -1,86 +1,88 @@ local api = vim.api local ts = vim.treesitter +local Range = require "ssr.range" local u = require "ssr.utils" -local M = {} - -M.wildcard_prefix = "__ssr_var_" +local H = {} ---@class Match ----@field range ExtmarkRange ----@field captures ExtmarkRange[] - ----@class ExtmarkRange ----@field ns number ----@field buf buffer ----@field extmark number -local ExtmarkRange = {} -M.ExtmarkRange = ExtmarkRange - ----@param ns number ----@param buf buffer ----@param node TSNode ----@return ExtmarkRange -function ExtmarkRange.new(ns, buf, node) - local start_row, start_col, end_row, end_col = node:range() +---@field range Range +---@field captures table + +---@field captures +---@class Searcher +---@field lang string +---@field query Query +---@field wildcards table +---@field rough_regex string +local Searcher = {} + +---@param lang string +---@param pattern string +---@param parse_context ParseContext +---@return Searcher? +function Searcher.new(lang, pattern, parse_context) + local node, source = parse_context:parse(pattern) + if node:has_error() then + return + end + local sexpr, wildcards, rough_regex = H.build_sexpr(node, source) + local parse_query = ts.query.parse or ts.parse_query + local query = parse_query(lang, sexpr) return setmetatable({ - ns = ns, - buf = buf, - extmark = api.nvim_buf_set_extmark(buf, ns, start_row, start_col, { - end_row = end_row, - end_col = end_col, - right_gravity = false, - end_right_gravity = true, - }), - }, { __index = ExtmarkRange }) -end - ----@return number, number, number, number -function ExtmarkRange:get() - local extmark = api.nvim_buf_get_extmark_by_id(self.buf, self.ns, self.extmark, { details = true }) - return extmark[1], extmark[2], extmark[3].end_row, extmark[3].end_col + lang = lang, + query = query, + wildcards = wildcards, + rough_regex = rough_regex, + }, { __index = Searcher }) end --- Compare if two captured trees can match. --- The check is loose because users want to match different types of node. --- e.g. converting `{ foo: foo }` to shorthand `{ foo }`. -ts.query.add_predicate("ssr-tree-match?", function(match, _pattern, buf, pred) - ---@param node1 TSNode? - ---@param node2 TSNode? - ---@return boolean - local function tree_match(node1, node2) - if not node1 or not node2 then - return false - end - if node1:named() ~= node2:named() then - return false - end - if node1:child_count() == 0 or node2:child_count() == 0 then - return ts.get_node_text(node1, buf) == ts.get_node_text(node2, buf) - end - if node1:child_count() ~= node2:child_count() then - return false +---@param file File +---@return Match[] +function Searcher:search(file) + local matches = {} + file.lang_tree:for_each_tree(function(tree, lang_tree) -- must called :parse(true) + if lang_tree:lang() ~= self.lang then + return end - for i = 0, node1:child_count() - 1 do - if not tree_match(node1:child(i), node2:child(i)) then - return false + for _, nodes in self.query:iter_matches(tree:root(), file.content, 0, -1) do + local range = Range.from_node(nodes[#nodes]) + local captures = {} + for var, idx in pairs(self.wildcards) do + captures[var] = Range.from_node(nodes[idx]) end + table.insert(matches, { range = range, captures = captures }) end - return true - end - return tree_match(match[pred[2]], match[pred[3]]) -end, true) + end) + + -- Sort matches from + -- buffer top to bottom, to make goto next/prev match intuitive + -- inner to outer for recursive matches, to make replacing correct + ---@param match1 Match + ---@param match2 Match + ---@return boolean + table.sort(matches, function(match1, match2) + if match1.range:before(match2.range) then + return true + end + return match1.range:inside(match2.range) + end) + return matches +end -- Build a TS sexpr represting the node. +-- This function is more strict than `TSNode:sexpr()` by also requiring leaf nodes to match text. ---@param node TSNode ---@param source string ----@return string, table -local function build_sexpr(node, source) +---@return string sexpr +---@return table wildcards +---@return string rough_regex +function H.build_sexpr(node, source) ---@type table local wildcards = {} + local rough_regex = "" local next_idx = 1 - -- This function is more strict than `tsnode:sexpr()` by also requiring leaf nodes to match text. ---@param node TSNode ---@return string local function build(node) @@ -88,7 +90,7 @@ local function build_sexpr(node, source) -- Special identifier __ssr_var_name is a named wildcard. -- Handle this early to make sure wildcard captures largest node. - local var = text:match("^" .. M.wildcard_prefix .. "([_%a%d]+)$") + local var = text:match("^" .. u.wildcard_prefix .. "([_%a%d]+)$") if var then if not wildcards[var] then wildcards[var] = next_idx @@ -104,6 +106,9 @@ local function build_sexpr(node, source) -- Leaf nodes (keyword, identifier, literal and symbol) should match text. if node:named_child_count() == 0 then + if #text > #rough_regex then + rough_regex = text + end local sexpr = string.format("(%s) @_%d (#eq? @_%d %s)", node:type(), next_idx, next_idx, u.to_ts_query_str(text)) next_idx = next_idx + 1 return sexpr @@ -134,76 +139,37 @@ local function build_sexpr(node, source) end local sexpr = string.format("(%s) @all", build(node)) - return sexpr, wildcards + rough_regex = u.regex_escape(rough_regex) + return sexpr, wildcards, rough_regex end ----@param buf buffer ----@param node TSNode ----@param source string ----@return Match[] -function M.search(buf, node, source, ns) - local sexpr, wildcards = build_sexpr(node, source) - local parse_query = ts.query.parse or ts.parse_query - local lang = ts.language.get_lang(vim.bo[buf].filetype) - if not lang then - return {} - end - local query = parse_query(lang, sexpr) - local matches = {} - local has_parser, parser = pcall(ts.get_parser, buf, lang) - if not has_parser then - return {} - end - local root = parser:parse(true)[1]:root() - for _, nodes in query:iter_matches(root, buf, 0, -1) do - ---@type table - local captures = {} - for var, idx in pairs(wildcards) do - captures[var] = ExtmarkRange.new(ns, buf, nodes[idx]) - end - local match = { range = ExtmarkRange.new(ns, buf, nodes[#nodes]), captures = captures } - table.insert(matches, match) - end - - -- Sort matches from - -- buffer top to bottom, to make goto next/prev match intuitive - -- inner to outer for recursive matches, to make replacing correct - ---@param match1 { range: ExtmarkRange, captures: table} - ---@param match2 { range: ExtmarkRange, captures: table} +-- Compare if two captured trees can match. +-- The check is loose because we want to match different types of node. +-- e.g. converting `{ foo: foo }` to shorthand `{ foo }`. +ts.query.add_predicate("ssr-tree-match?", function(match, _pattern, buf, pred) + ---@param node1 TSNode + ---@param node2 TSNode ---@return boolean - table.sort(matches, function(match1, match2) - local start_row1, start_col1, end_row1, end_col1 = match1.range:get() - local start_row2, start_col2, end_row2, end_col2 = match2.range:get() - if end_row1 < start_row2 or (end_row1 == start_row2 and end_col1 <= start_col2) then - return true + local function tree_match(node1, node2) + if node1:named() ~= node2:named() then + return false end - return (start_row1 > start_row2 or (start_row1 == start_row2 and start_col1 > start_col2)) - and (end_row1 < end_row2 or (end_row1 == end_row2 and end_col1 <= end_col2)) - end) - - return matches -end - ---- Render template and replace one match. ----@param buf buffer ----@param match Match ----@param template string -function M.replace(buf, match, template) - -- Render templates with captured nodes. - local replace = template:gsub("()%$([_%a%d]+)", function(pos, var) - local start_row, start_col, end_row, end_col = match.captures[var]:get() - local lines = api.nvim_buf_get_text(buf, start_row, start_col, end_row, end_col, {}) - u.remove_indent(lines, u.get_indent(buf, start_row)) - local var_lines = vim.split(template:sub(1, pos), "\n") - local var_line = var_lines[#var_lines] - local template_indent = var_line:match "^%s*" - u.add_indent(lines, template_indent) - return table.concat(lines, "\n") - end) - replace = vim.split(replace, "\n") - local start_row, start_col, end_row, end_col = match.range:get() - u.add_indent(replace, u.get_indent(buf, start_row)) - api.nvim_buf_set_text(buf, start_row, start_col, end_row, end_col, replace) -end + if node1:child_count() == 0 or node2:child_count() == 0 then + return ts.get_node_text(node1, buf) == ts.get_node_text(node2, buf) + end + if node1:child_count() ~= node2:child_count() then + return false + end + for i = 0, node1:child_count() - 1 do + if + not tree_match(node1:child(i) --[[@as TSNode]], node2:child(i) --[[@as TSNode]]) + then + return false + end + end + return true + end + return tree_match(match[pred[2]], match[pred[3]]) +end, true) -return M +return Searcher diff --git a/lua/ssr/utils.lua b/lua/ssr/utils.lua index 099ed92..a5ad06b 100644 --- a/lua/ssr/utils.lua +++ b/lua/ssr/utils.lua @@ -3,6 +3,11 @@ local ts = vim.treesitter local M = {} +M.wildcard_prefix = "__ssr_var_" +M.namespace = api.nvim_create_namespace "ssr_ns" +M.cur_search_ns = api.nvim_create_namespace "ssr_cur_search_ns" +M.augroup = api.nvim_create_augroup("ssr_augroup", {}) + -- Send a notification titled SSR. ---@param msg string ---@return nil @@ -127,4 +132,35 @@ function M.get_win_size(lines, config) return width, height end +-- https://github.com/rust-lang/regex/blob/17284451f10aa06c6c42e622e3529b98513901a8/regex-syntax/src/lib.rs#L272 +local regex_meta_chars = { + ["\\"] = true, + ["."] = true, + ["+"] = true, + ["*"] = true, + ["?"] = true, + ["("] = true, + [")"] = true, + ["|"] = true, + ["["] = true, + ["]"] = true, + ["{"] = true, + ["}"] = true, + ["^"] = true, + ["$"] = true, + ["#"] = true, + ["&"] = true, + ["-"] = true, + ["~"] = true, +} + +---@param s string +---@return string +function M.regex_escape(s) + local escaped = s:gsub(".", function(ch) + return regex_meta_chars[ch] and "\\" .. ch or ch + end) + return escaped +end + return M diff --git a/tests/ssr_spec.lua b/tests/ssr_spec.lua index f6fcd17..bc8c242 100644 --- a/tests/ssr_spec.lua +++ b/tests/ssr_spec.lua @@ -1,8 +1,7 @@ -local u = require "ssr.utils" -local ParseContext = require("ssr.parse").ParseContext local ts = vim.treesitter -local search = require("ssr.search").search -local replace = require("ssr.search").replace +local s = require "ssr.search" +local ParseContext, Ssr = s.ParseContext, s.Ssr +local u = require "ssr.utils" local tests = {} @@ -257,14 +256,15 @@ describe("", function() local lang = ts.language.get_lang(vim.bo[buf].filetype) assert(lang, "language not found") local origin_node = u.node_for_range(buf, lang, start_row, start_col, end_row, end_col) - - local parse_context = ParseContext.new(buf, origin_node) + assert(origin_node, 'treesitter parser not installed') + local parse_context = ParseContext.new(buf, lang, origin_node) assert(parse_context) - local node, source = parse_context:parse(pattern) - local matches = search(buf, node, source, ns) + local rule = Ssr.new(lang, pattern, template, parse_context) + assert(rule) + local matches = rule:search(buf) for _, match in ipairs(matches) do - replace(buf, match, template) + rule:replace(buf, match) end local actual = vim.api.nvim_buf_get_lines(buf, 0, -1, true)