From 826d080d3c37c2c031db58b849790832e2faf604 Mon Sep 17 00:00:00 2001 From: Chen Shuaimin Date: Sat, 14 Oct 2023 16:20:51 +0800 Subject: [PATCH] Search across multiple files with ripgrep --- lua/ssr.lua | 471 ----------------------- lua/ssr/config.lua | 24 ++ lua/ssr/file.lua | 85 ++++ lua/ssr/init.lua | 16 + lua/ssr/{parse.lua => parse_context.lua} | 35 +- lua/ssr/range.lua | 40 ++ lua/ssr/replace.lua | 28 ++ lua/ssr/search.lua | 230 +++++------ lua/ssr/ui/confirm_win.lua | 139 +++++++ lua/ssr/ui/init.lua | 162 ++++++++ lua/ssr/ui/main_win.lua | 240 ++++++++++++ lua/ssr/ui/result_list.lua | 129 +++++++ lua/ssr/utils.lua | 79 +++- tests/ssr_spec.lua | 18 +- 14 files changed, 1058 insertions(+), 638 deletions(-) delete mode 100644 lua/ssr.lua create mode 100644 lua/ssr/config.lua create mode 100644 lua/ssr/file.lua create mode 100644 lua/ssr/init.lua rename lua/ssr/{parse.lua => parse_context.lua} (70%) create mode 100644 lua/ssr/range.lua create mode 100644 lua/ssr/replace.lua create mode 100644 lua/ssr/ui/confirm_win.lua create mode 100644 lua/ssr/ui/init.lua create mode 100644 lua/ssr/ui/main_win.lua create mode 100644 lua/ssr/ui/result_list.lua diff --git a/lua/ssr.lua b/lua/ssr.lua deleted file mode 100644 index a32cfe4..0000000 --- a/lua/ssr.lua +++ /dev/null @@ -1,471 +0,0 @@ -local api = vim.api -local ts = vim.treesitter -local fn = vim.fn -local keymap = vim.keymap -local highlight = vim.highlight -local ParseContext = require("ssr.parse").ParseContext -local search = require("ssr.search").search -local replace = require("ssr.search").replace -local u = require "ssr.utils" - -local M = {} - ----@class Config -local config = { - border = "rounded", - min_width = 50, - min_height = 5, - max_width = 120, - max_height = 25, - adjust_window = true, - keymaps = { - close = "q", - next_match = "n", - prev_match = "N", - replace_confirm = "", - replace_all = "", - }, -} - --- Set config options. ----@param cfg Config? -function M.setup(cfg) - if cfg then - config = vim.tbl_deep_extend("force", config, cfg) - end -end - ----@type table -local win_uis = {} - ----@class Ui ----@field ns number ----@field cur_search_ns number ----@field augroup number ----@field ui_buf buffer ----@field extmarks {status: number, search: number, replace: number} ----@field origin_win window ----@field lang string ----@field parse_context ParseContext ----@field buf_matches table -local Ui = {} - ----@return Ui? -function Ui.new() - local self = setmetatable({}, { __index = Ui }) - - self.origin_win = api.nvim_get_current_win() - local origin_buf = api.nvim_win_get_buf(self.origin_win) - local lang = ts.language.get_lang(vim.bo[origin_buf].filetype) - if not lang then - return u.notify("Treesitter language not found") - end - self.lang = lang - - local origin_node = u.node_for_range(origin_buf, self.lang, u.get_selection(self.origin_win)) - if not origin_node then - return u.notify("Treesitter parser not found, please try to install it with :TSInstall " .. self.lang) - end - if origin_node:has_error() then - return u.notify "You have syntax errors in the selected node" - end - local parse_context = ParseContext.new(origin_buf, origin_node) - if not parse_context then - return u.notify "Can't find a proper context to parse the pattern" - end - self.parse_context = parse_context - - self.buf_matches = {} - self.ns = api.nvim_create_namespace("ssr_" .. self.origin_win) -- TODO - self.cur_search_ns = api.nvim_create_namespace("ssr_cur_match_" .. self.origin_win) - self.augroup = api.nvim_create_augroup("ssr_augroup_" .. self.origin_win, {}) - - -- Init ui buffer - self.ui_buf = api.nvim_create_buf(false, true) - vim.bo[self.ui_buf].filetype = "ssr" - - local placeholder = ts.get_node_text(origin_node, origin_buf) - placeholder = "\n\n" .. placeholder .. "\n\n" - placeholder = vim.split(placeholder, "\n") - u.remove_indent(placeholder, u.get_indent(origin_buf, origin_node:start())) - api.nvim_buf_set_lines(self.ui_buf, 0, -1, true, placeholder) - -- Enable syntax highlights - ts.start(self.ui_buf, self.lang) - - local function virt_text(row, text) - return api.nvim_buf_set_extmark(self.ui_buf, self.ns, row, 0, { virt_text = text, virt_text_pos = "overlay" }) - end - self.extmarks = { - status = virt_text(0, { { "[SSR]", "Comment" }, { " (Press ? for help)", "Comment" } }), - search = virt_text(1, { { "SEARCH:", "String" } }), - replace = virt_text(#placeholder - 2, { { "REPLACE:", "String" } }), - } - - local function map(key, func) - keymap.set("n", key, function() - func(self) - end, { buffer = self.ui_buf, nowait = true }) - end - map(config.keymaps.replace_confirm, self.replace_confirm) - map(config.keymaps.replace_all, self.replace_all) - map(config.keymaps.next_match, function() - self:goto_match(self:next_match_idx()) - end) - map(config.keymaps.prev_match, function() - self:goto_match(self:prev_match_idx()) - end) - - -- Open float window - local width, height = u.get_win_size(placeholder, config) - local ui_win = api.nvim_open_win(self.ui_buf, true, { - relative = "win", - anchor = "NE", - row = 1, - col = api.nvim_win_get_width(0) - 1, - style = "minimal", - border = config.border, - width = width, - height = height, - }) - u.set_cursor(ui_win, 2, 0) - fn.matchadd("Title", [[$\w\+]]) - - map(config.keymaps.close, function() - api.nvim_win_close(ui_win, false) - end) - - api.nvim_create_autocmd({ "TextChanged", "TextChangedI" }, { - group = self.augroup, - buffer = self.ui_buf, - callback = function() - if config.adjust_window then - local lines = api.nvim_buf_get_lines(self.ui_buf, 0, -1, true) - local width, height = u.get_win_size(lines, config) - if api.nvim_win_get_width(ui_win) ~= width then - api.nvim_win_set_width(ui_win, width) - end - if api.nvim_win_get_height(ui_win) ~= height then - api.nvim_win_set_height(ui_win, height) - end - end - self:search() - end, - }) - - -- SSR window is bound to the original window (not buffer!), which is the same behavior as IDEs and browsers. - api.nvim_create_autocmd("BufWinEnter", { - group = self.augroup, - callback = function(event) - if event.buf == self.ui_buf then - return - end - - local win = api.nvim_get_current_win() - if win == ui_win then - -- Prevent accidentally opening another file in the ssr window. - -- Adapted from neo-tree.nvim. - vim.schedule(function() - api.nvim_win_set_buf(ui_win, self.ui_buf) - local name = api.nvim_buf_get_name(event.buf) - api.nvim_win_call(self.origin_win, function() - pcall(api.nvim_buf_delete, event.buf, {}) - if name ~= "" then - vim.cmd.edit(name) - end - end) - api.nvim_set_current_win(self.origin_win) - end) - return - elseif win ~= self.origin_win then - return - end - - if ts.language.get_lang(vim.bo[event.buf].filetype) ~= self.lang then - return self:set_status "N/A" - end - self:search() - end, - }) - - api.nvim_create_autocmd("WinClosed", { - group = self.augroup, - buffer = self.ui_buf, - callback = function() - win_uis[self.origin_win] = nil - api.nvim_clear_autocmds { group = self.augroup } - api.nvim_buf_delete(self.ui_buf, {}) - for buf in pairs(self.buf_matches) do - api.nvim_buf_clear_namespace(buf, self.ns, 0, -1) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) - end - end, - }) - - win_uis[self.origin_win] = self - return self -end - -function Ui:search() - local pattern = self:get_input() - local buf = api.nvim_win_get_buf(self.origin_win) - self.buf_matches[buf] = {} - api.nvim_buf_clear_namespace(buf, self.ns, 0, -1) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) - - local start = vim.loop.hrtime() - local node, source = self.parse_context:parse(pattern) - if node:has_error() then - return self:set_status "Error" - end - self.buf_matches[buf] = search(buf, node, source, self.ns) - local elapsed = (vim.loop.hrtime() - start) / 1E6 - for _, match in ipairs(self.buf_matches[buf]) do - local start_row, start_col, end_row, end_col = match.range:get() - highlight.range(buf, self.ns, "Search", { start_row, start_col }, { end_row, end_col }, {}) - end - self:set_status(string.format("%d found in %dms", #self.buf_matches[buf], elapsed)) -end - -function Ui:next_match_idx() - local cursor_row, cursor_col = u.get_cursor(self.origin_win) - local buf = api.nvim_win_get_buf(self.origin_win) - for idx, matches in pairs(self.buf_matches[buf]) do - local start_row, start_col = matches.range:get() - if start_row > cursor_row or (start_row == cursor_row and start_col > cursor_col) then - return idx - end - end - return 1 -end - -function Ui:prev_match_idx() - local cursor_row, cursor_col = u.get_cursor(self.origin_win) - local buf = api.nvim_win_get_buf(self.origin_win) - local matches = self.buf_matches[buf] - for idx = #matches, 1, -1 do - local start_row, start_col = matches[idx].range:get() - if start_row < cursor_row or (start_row == cursor_row and start_col < cursor_col) then - return idx - end - end - return #matches -end - -function Ui:goto_match(match_idx) - local buf = api.nvim_win_get_buf(self.origin_win) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) - local matches = self.buf_matches[buf] - local start_row, start_col, end_row, end_col = matches[match_idx].range:get() - u.set_cursor(self.origin_win, start_row, start_col) - highlight.range( - buf, - self.cur_search_ns, - "CurSearch", - { start_row, start_col }, - { end_row, end_col }, - { priority = vim.highlight.priorities.user + 100 } - ) - api.nvim_buf_set_extmark(buf, self.cur_search_ns, start_row, start_col, { - virt_text_pos = "eol", - virt_text = { { string.format("[%d/%d]", match_idx, #matches), "DiagnosticVirtualTextInfo" } }, - }) -end - -function Ui:replace_all() - self:search() - local buf = api.nvim_win_get_buf(self.origin_win) - local matches = self.buf_matches[buf] - if #matches == 0 then - return self:set_status "pattern not found" - end - local _, template = self:get_input() - local start = vim.loop.hrtime() - for _, match in ipairs(matches) do - replace(buf, match, template) - end - local elapsed = (vim.loop.hrtime() - start) / 1E6 - self:set_status(string.format("%d replaced in %dms", #matches, elapsed)) -end - -function Ui:replace_confirm() - self:search() - local buf = api.nvim_win_get_buf(self.origin_win) - local matches = self.buf_matches[buf] - if #matches == 0 then - return self:set_status "pattern not found" - end - - local confirm_buf = api.nvim_create_buf(false, true) - vim.bo[confirm_buf].filetype = "ssr_confirm" - local choices = { - "• Yes", - "• No", - "──────────────", - "• All", - "• Quit", - "• Last replace", - } - local separator_idx = 3 - api.nvim_buf_set_lines(confirm_buf, 0, -1, true, choices) - for idx = 0, #choices - 1 do - if idx + 1 ~= separator_idx then - api.nvim_buf_set_extmark(confirm_buf, self.ns, idx, 4, { hl_group = "Underlined", end_row = idx, end_col = 5 }) - end - end - - local function open_confirm_win(match_idx) - self:goto_match(match_idx) - local _, _, end_row, end_col = matches[match_idx].range:get() - local cfg = { - relative = "win", - win = self.origin_win, - bufpos = { end_row, end_col }, - style = "minimal", - border = config.border, - width = 14, - height = 6, - } - if vim.fn.has "nvim-0.9" == 1 then - cfg.title = "Replace?" - cfg.title_pos = "center" - end - return api.nvim_open_win(confirm_buf, true, cfg) - end - - local match_idx = 1 - local replaced = 0 - local cursor = 1 - local _, template = self:get_input() - self:set_status(string.format("replacing 0/%d", #matches)) - - while match_idx <= #matches do - local confirm_win = open_confirm_win(match_idx) - - ---@type string - local key - while true do - -- Draw a fake cursor because cursor is not shown correctly when blocking on `getchar()`. - api.nvim_buf_clear_namespace(confirm_buf, self.cur_search_ns, 0, -1) - api.nvim_buf_set_extmark( - confirm_buf, - self.cur_search_ns, - cursor - 1, - 0, - { virt_text = { { "•", "Cursor" } }, virt_text_pos = "overlay" } - ) - api.nvim_buf_set_extmark(confirm_buf, self.cur_search_ns, cursor - 1, 0, { line_hl_group = "CursorLine" }) - vim.cmd.redraw() - - local ok, char = pcall(vim.fn.getcharstr) - key = ok and vim.fn.keytrans(char) or "" - if key == "j" then - if cursor == separator_idx - 1 then -- skip separator - cursor = separator_idx + 1 - elseif cursor == #choices then -- wrap - cursor = 1 - else - cursor = cursor + 1 - end - elseif key == "k" then - if cursor == separator_idx + 1 then -- skip separator - cursor = separator_idx - 1 - elseif cursor == 1 then -- wrap - cursor = #choices - else - cursor = cursor - 1 - end - elseif vim.tbl_contains({ "", "", "", "", "", "" }, key) then - fn.win_execute(self.origin_win, string.format('execute "normal! \\%s"', key)) - else - break - end - end - - if key == "" then - key = ({ "y", "n", "", "a", "q", "l" })[cursor] - end - - if key == "y" then - replace(buf, matches[match_idx], template) - replaced = replaced + 1 - match_idx = match_idx + 1 - elseif key == "n" then - match_idx = match_idx + 1 - elseif key == "a" then - for i = match_idx, #matches do - replace(buf, matches[i], template) - end - replaced = replaced + #matches + 1 - match_idx - match_idx = #matches + 1 - elseif key == "l" then - replace(buf, matches[match_idx], template) - replaced = replaced + 1 - match_idx = #matches + 1 - elseif key == "q" or key == "" or key == "" then - match_idx = #matches + 1 - end - api.nvim_win_close(confirm_win, false) - self:set_status(string.format("replacing %d/%d", replaced, #matches)) - end - - api.nvim_buf_delete(confirm_buf, {}) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) - self:set_status(string.format("%d/%d replaced", replaced, #matches)) -end - -function Ui:get_input() - local lines = api.nvim_buf_get_lines(self.ui_buf, 0, -1, true) - local pattern_pos = api.nvim_buf_get_extmark_by_id(self.ui_buf, self.ns, self.extmarks.search, {})[1] - local template_pos = api.nvim_buf_get_extmark_by_id(self.ui_buf, self.ns, self.extmarks.replace, {})[1] - local pattern = vim.trim(table.concat(lines, "\n", pattern_pos + 2, template_pos)) - local template = vim.trim(table.concat(lines, "\n", template_pos + 1, #lines)) - return pattern, template -end - ----@param status string -function Ui:set_status(status) - api.nvim_buf_set_extmark(self.ui_buf, self.ns, 0, 0, { - id = self.extmarks.status, - virt_text = { - { "[SSR] ", "Comment" }, - { status }, - { " (Press ? for help)", "Comment" }, - }, - virt_text_pos = "overlay", - }) -end - ----@param win window? ----@return Ui? -function Ui.from_win(win) - if win == nil or win == 0 then - win = api.nvim_get_current_win() - end - local ui = win_uis[win] - if not ui then - return u.notify "No open SSR window" - end - return ui -end - -function M.open() - return Ui.new() -end - --- Replace all matches. -function M.replace_all() - local ui = Ui.from_win() - if ui then - ui:replace_all() - end -end - --- Confirm each match. -function M.replace_confirm() - local ui = Ui.from_win() - if ui then - ui:replace_confirm() - end -end - -return M diff --git a/lua/ssr/config.lua b/lua/ssr/config.lua new file mode 100644 index 0000000..d9fa4f7 --- /dev/null +++ b/lua/ssr/config.lua @@ -0,0 +1,24 @@ +local M = {} + +---@class Config +M.opts = { + border = "rounded", + min_width = 70, + max_width = 70, + min_height = 15, + max_height = 15, + adjust_window = false, + keymaps = { + close = "q", + next_match = "n", + prev_match = "N", + replace_confirm = "", + replace_all = "", + }, +} + +function M.set(config) + M.opts = vim.tbl_deep_extend("force", M.opts, config) +end + +return M diff --git a/lua/ssr/file.lua b/lua/ssr/file.lua new file mode 100644 index 0000000..9f9d66e --- /dev/null +++ b/lua/ssr/file.lua @@ -0,0 +1,85 @@ +local ts = vim.treesitter +local uv = vim.uv or vim.loop + +---@class File +---@field path string +---@field content string +---@field mtime { nsec: integer, sec: integer } +---@field lang_tree LanguageTree +local File = {} + +---@type table +local cache = {} + +---@param path string +---@return File? +function File.new(path) + local fd = uv.fs_open(path, "r", 438) + if not fd then + return + end + local stat = uv.fs_fstat(fd) --[[@as uv.aliases.fs_stat_table]] + local self = cache[path] + if self then + if stat.mtime.sec == self.mtime.sec and stat.mtime.nsec == self.mtime.nsec then + uv.fs_close(fd) + return self + else + cache[path] = nil + end + end + + self = setmetatable({ path = path }, { __index = File }) + self.mtime = stat.mtime + self.content = uv.fs_read(fd, stat.size, 0) --[[@as string]] + uv.fs_close(fd) + local first_line = self.content:match ".*\n" + local ft = vim.filetype.match { filename = path, contents = vim.split(self.content, "\n", { plain = true }) } -- not work for .ts + if not ft then + return + end + local lang = ts.language.get_lang(ft) + if not lang then + return + end + local has_parser, lang_tree = pcall(ts.get_string_parser, self.content, lang) + if not has_parser then + return + end + self.lang_tree = lang_tree + self.lang_tree:parse(true) + + cache[path] = self + return self +end + +---@param regex string +---@param on_file fun(file: File) +---@param on_end fun() +---@return nil +function File.grep(regex, on_file, on_end) + vim.system({ "rg", "--line-buffered", "--files-with-matches", regex }, { + text = true, + stdout = vim.schedule_wrap(function(err, files) + if err then + error(files) + end + if not files then + on_end() + return + end + for _, path in ipairs(vim.split(files, "\n", { plain = true, trimempty = true })) do + local file = File.new(path) + if file then + on_file(file) + end + end + end), + }, function(obj) + if obj.code ~= 0 then + error(obj.stderr) + end + end) +end + +return File diff --git a/lua/ssr/init.lua b/lua/ssr/init.lua new file mode 100644 index 0000000..73eb83c --- /dev/null +++ b/lua/ssr/init.lua @@ -0,0 +1,16 @@ +local Ui = require "ssr.ui" +local M = {} + +--- Set config options. Optional. +---@param config Config? +function M.setup(config) + if config then + require("ssr.config").set(config) + end +end + +function M.open() + Ui.new() +end + +return M diff --git a/lua/ssr/parse.lua b/lua/ssr/parse_context.lua similarity index 70% rename from lua/ssr/parse.lua rename to lua/ssr/parse_context.lua index 2e78f40..f889b24 100644 --- a/lua/ssr/parse.lua +++ b/lua/ssr/parse_context.lua @@ -1,8 +1,7 @@ local ts = vim.treesitter -local wildcard_prefix = require("ssr.search").wildcard_prefix - -local M = {} +local u = require "ssr.utils" +-- The context in which user input will be parsed correctly. ---@class ParseContext ---@field lang string ---@field before string @@ -10,18 +9,12 @@ local M = {} ---@field pad_rows integer ---@field pad_cols integer local ParseContext = {} -ParseContext.__index = ParseContext -M.ParseContext = ParseContext --- Create a context in which `origin_node` (and user input) will be parsed correctly. ---@param buf buffer +---@param lang string ---@param origin_node TSNode ---@return ParseContext? -function ParseContext.new(buf, origin_node) - local lang = ts.language.get_lang(vim.bo[buf].filetype) - if not lang then - return - end +function ParseContext.new(buf, lang, origin_node) local self = setmetatable({ lang = lang }, { __index = ParseContext }) local origin_start_row, origin_start_col, origin_start_byte = origin_node:start() @@ -48,8 +41,8 @@ function ParseContext.new(buf, origin_node) if end_row == start_row then end_col = end_col + start_col end - local node_in_context = root:named_descendant_for_range(start_row, start_col, end_row, end_col) - if node_in_context and node_in_context:type() == origin_node:type() and node_in_context:sexpr() == origin_sexpr then + local node_in_context = root:named_descendant_for_range(start_row, start_col, end_row, end_col) --[[@as TSNode]] + if node_in_context:type() == origin_node:type() and node_in_context:sexpr() == origin_sexpr then local context_start_byte self.start_row, self.start_col, context_start_byte = context_node:start() self.before = context_text:sub(1, origin_start_byte - context_start_byte) @@ -63,17 +56,17 @@ function ParseContext.new(buf, origin_node) end end --- Parse search pattern to syntax tree in proper context. +-- Parse code to TS node. ---@param pattern string ----@return TSNode?, string +---@return TSNode, string function ParseContext:parse(pattern) -- Replace named wildcard $name to identifier __ssr_var_name to avoid syntax error. - pattern = pattern:gsub("%$([_%a%d]+)", wildcard_prefix .. "%1") - local context_text = self.before .. pattern .. self.after - local root = ts.get_string_parser(context_text, self.lang):parse()[1]:root() + pattern = pattern:gsub("%$([_%a%d]+)", u.wildcard_prefix .. "%1") + local source = self.before .. pattern .. self.after + local root = ts.get_string_parser(source, self.lang):parse()[1]:root() local lines = vim.split(pattern, "\n") - local node = root:named_descendant_for_range(self.pad_rows, self.pad_cols, self.pad_rows + #lines - 1, #lines[#lines]) - return node, context_text + local node = root:named_descendant_for_range(self.pad_rows, self.pad_cols, self.pad_rows + #lines - 1, #lines[#lines]) --[[@as TSNode]] + return node, source end -return M +return ParseContext diff --git a/lua/ssr/range.lua b/lua/ssr/range.lua new file mode 100644 index 0000000..6a27e1b --- /dev/null +++ b/lua/ssr/range.lua @@ -0,0 +1,40 @@ +---@class Range +---@field start_row number +---@field start_col number +---@field start_byte number +---@field end_row number +---@field end_col number +---@field end_byte number +local Range = {} + +---@param node TSNode +---@return Range +function Range.from_node(node) + local start_row, start_col, start_byte = node:start() + local end_row, end_col, end_byte = node:end_() + return setmetatable({ + start_row = start_row, + start_col = start_col, + start_byte = start_byte, + end_row = end_row, + end_col = end_col, + end_byte = end_byte, + }, { __index = Range }) +end + +---@param other Range +---@return boolean +function Range:before(other) + return self.end_row < other.start_row or (self.end_row == other.start_row and self.end_col <= other.start_col) +end + +---@param other Range +---@return boolean +function Range:inside(other) + return ( + (self.start_row > other.start_row or (self.start_row == other.start_row and self.start_col > other.start_col)) + and (self.end_row < other.end_row or (self.end_row == other.end_row and self.end_col <= other.end_col)) + ) +end + +return Range diff --git a/lua/ssr/replace.lua b/lua/ssr/replace.lua new file mode 100644 index 0000000..475e714 --- /dev/null +++ b/lua/ssr/replace.lua @@ -0,0 +1,28 @@ +local api = vim.api +local ts = vim.treesitter +local u = require "ssr.utils" + +local Replacer = {} + +--- Render template and replace one match. +---@param buf buffer +---@param match Match +function Replacer:replace(buf, match) + -- Render templates with captured nodes. + local replace = self.template:gsub("()%$([_%a%d]+)", function(pos, var) + local start_row, start_col, end_row, end_col = match.captures[var]:get() + local lines = api.nvim_buf_get_text(buf, start_row, start_col, end_row, end_col, {}) + u.remove_indent(lines, u.get_indent(buf, start_row)) + local var_lines = vim.split(self.template:sub(1, pos), "\n") + local var_line = var_lines[#var_lines] + local template_indent = var_line:match "^%s*" + u.add_indent(lines, template_indent) + return table.concat(lines, "\n") + end) + replace = vim.split(replace, "\n") + local start_row, start_col, end_row, end_col = match.range:get() + u.add_indent(replace, u.get_indent(buf, start_row)) + api.nvim_buf_set_text(buf, start_row, start_col, end_row, end_col, replace) +end + +return Replacer diff --git a/lua/ssr/search.lua b/lua/ssr/search.lua index ec6cdcc..250f050 100644 --- a/lua/ssr/search.lua +++ b/lua/ssr/search.lua @@ -1,86 +1,88 @@ local api = vim.api local ts = vim.treesitter +local Range = require "ssr.range" local u = require "ssr.utils" -local M = {} - -M.wildcard_prefix = "__ssr_var_" +local H = {} ---@class Match ----@field range ExtmarkRange ----@field captures ExtmarkRange[] - ----@class ExtmarkRange ----@field ns number ----@field buf buffer ----@field extmark number -local ExtmarkRange = {} -M.ExtmarkRange = ExtmarkRange - ----@param ns number ----@param buf buffer ----@param node TSNode ----@return ExtmarkRange -function ExtmarkRange.new(ns, buf, node) - local start_row, start_col, end_row, end_col = node:range() +---@field range Range +---@field captures table + +---@field captures +---@class Searcher +---@field lang string +---@field query Query +---@field wildcards table +---@field rough_regex string +local Searcher = {} + +---@param lang string +---@param pattern string +---@param parse_context ParseContext +---@return Searcher? +function Searcher.new(lang, pattern, parse_context) + local node, source = parse_context:parse(pattern) + if node:has_error() then + return + end + local sexpr, wildcards, rough_regex = H.build_sexpr(node, source) + local parse_query = ts.query.parse or ts.parse_query + local query = parse_query(lang, sexpr) return setmetatable({ - ns = ns, - buf = buf, - extmark = api.nvim_buf_set_extmark(buf, ns, start_row, start_col, { - end_row = end_row, - end_col = end_col, - right_gravity = false, - end_right_gravity = true, - }), - }, { __index = ExtmarkRange }) -end - ----@return number, number, number, number -function ExtmarkRange:get() - local extmark = api.nvim_buf_get_extmark_by_id(self.buf, self.ns, self.extmark, { details = true }) - return extmark[1], extmark[2], extmark[3].end_row, extmark[3].end_col + lang = lang, + query = query, + wildcards = wildcards, + rough_regex = rough_regex, + }, { __index = Searcher }) end --- Compare if two captured trees can match. --- The check is loose because users want to match different types of node. --- e.g. converting `{ foo: foo }` to shorthand `{ foo }`. -ts.query.add_predicate("ssr-tree-match?", function(match, _pattern, buf, pred) - ---@param node1 TSNode? - ---@param node2 TSNode? - ---@return boolean - local function tree_match(node1, node2) - if not node1 or not node2 then - return false - end - if node1:named() ~= node2:named() then - return false - end - if node1:child_count() == 0 or node2:child_count() == 0 then - return ts.get_node_text(node1, buf) == ts.get_node_text(node2, buf) - end - if node1:child_count() ~= node2:child_count() then - return false +---@param file File +---@return Match[] +function Searcher:search(file) + local matches = {} + file.lang_tree:for_each_tree(function(tree, lang_tree) -- must called :parse(true) + if lang_tree:lang() ~= self.lang then + return end - for i = 0, node1:child_count() - 1 do - if not tree_match(node1:child(i), node2:child(i)) then - return false + for _, nodes in self.query:iter_matches(tree:root(), file.content, 0, -1) do + local range = Range.from_node(nodes[#nodes]) + local captures = {} + for var, idx in pairs(self.wildcards) do + captures[var] = Range.from_node(nodes[idx]) end + table.insert(matches, { range = range, captures = captures }) end - return true - end - return tree_match(match[pred[2]], match[pred[3]]) -end, true) + end) + + -- Sort matches from + -- buffer top to bottom, to make goto next/prev match intuitive + -- inner to outer for recursive matches, to make replacing correct + ---@param match1 Match + ---@param match2 Match + ---@return boolean + table.sort(matches, function(match1, match2) + if match1.range:before(match2.range) then + return true + end + return match1.range:inside(match2.range) + end) + return matches +end -- Build a TS sexpr represting the node. +-- This function is more strict than `TSNode:sexpr()` by also requiring leaf nodes to match text. ---@param node TSNode ---@param source string ----@return string, table -local function build_sexpr(node, source) +---@return string sexpr +---@return table wildcards +---@return string rough_regex +function H.build_sexpr(node, source) ---@type table local wildcards = {} + local rough_regex = "" local next_idx = 1 - -- This function is more strict than `tsnode:sexpr()` by also requiring leaf nodes to match text. ---@param node TSNode ---@return string local function build(node) @@ -88,7 +90,7 @@ local function build_sexpr(node, source) -- Special identifier __ssr_var_name is a named wildcard. -- Handle this early to make sure wildcard captures largest node. - local var = text:match("^" .. M.wildcard_prefix .. "([_%a%d]+)$") + local var = text:match("^" .. u.wildcard_prefix .. "([_%a%d]+)$") if var then if not wildcards[var] then wildcards[var] = next_idx @@ -104,6 +106,9 @@ local function build_sexpr(node, source) -- Leaf nodes (keyword, identifier, literal and symbol) should match text. if node:named_child_count() == 0 then + if #text > #rough_regex then + rough_regex = text + end local sexpr = string.format("(%s) @_%d (#eq? @_%d %s)", node:type(), next_idx, next_idx, u.to_ts_query_str(text)) next_idx = next_idx + 1 return sexpr @@ -134,76 +139,37 @@ local function build_sexpr(node, source) end local sexpr = string.format("(%s) @all", build(node)) - return sexpr, wildcards + rough_regex = u.regex_escape(rough_regex) + return sexpr, wildcards, rough_regex end ----@param buf buffer ----@param node TSNode ----@param source string ----@return Match[] -function M.search(buf, node, source, ns) - local sexpr, wildcards = build_sexpr(node, source) - local parse_query = ts.query.parse or ts.parse_query - local lang = ts.language.get_lang(vim.bo[buf].filetype) - if not lang then - return {} - end - local query = parse_query(lang, sexpr) - local matches = {} - local has_parser, parser = pcall(ts.get_parser, buf, lang) - if not has_parser then - return {} - end - local root = parser:parse(true)[1]:root() - for _, nodes in query:iter_matches(root, buf, 0, -1) do - ---@type table - local captures = {} - for var, idx in pairs(wildcards) do - captures[var] = ExtmarkRange.new(ns, buf, nodes[idx]) - end - local match = { range = ExtmarkRange.new(ns, buf, nodes[#nodes]), captures = captures } - table.insert(matches, match) - end - - -- Sort matches from - -- buffer top to bottom, to make goto next/prev match intuitive - -- inner to outer for recursive matches, to make replacing correct - ---@param match1 { range: ExtmarkRange, captures: table} - ---@param match2 { range: ExtmarkRange, captures: table} +-- Compare if two captured trees can match. +-- The check is loose because we want to match different types of node. +-- e.g. converting `{ foo: foo }` to shorthand `{ foo }`. +ts.query.add_predicate("ssr-tree-match?", function(match, _pattern, buf, pred) + ---@param node1 TSNode + ---@param node2 TSNode ---@return boolean - table.sort(matches, function(match1, match2) - local start_row1, start_col1, end_row1, end_col1 = match1.range:get() - local start_row2, start_col2, end_row2, end_col2 = match2.range:get() - if end_row1 < start_row2 or (end_row1 == start_row2 and end_col1 <= start_col2) then - return true + local function tree_match(node1, node2) + if node1:named() ~= node2:named() then + return false end - return (start_row1 > start_row2 or (start_row1 == start_row2 and start_col1 > start_col2)) - and (end_row1 < end_row2 or (end_row1 == end_row2 and end_col1 <= end_col2)) - end) - - return matches -end - ---- Render template and replace one match. ----@param buf buffer ----@param match Match ----@param template string -function M.replace(buf, match, template) - -- Render templates with captured nodes. - local replace = template:gsub("()%$([_%a%d]+)", function(pos, var) - local start_row, start_col, end_row, end_col = match.captures[var]:get() - local lines = api.nvim_buf_get_text(buf, start_row, start_col, end_row, end_col, {}) - u.remove_indent(lines, u.get_indent(buf, start_row)) - local var_lines = vim.split(template:sub(1, pos), "\n") - local var_line = var_lines[#var_lines] - local template_indent = var_line:match "^%s*" - u.add_indent(lines, template_indent) - return table.concat(lines, "\n") - end) - replace = vim.split(replace, "\n") - local start_row, start_col, end_row, end_col = match.range:get() - u.add_indent(replace, u.get_indent(buf, start_row)) - api.nvim_buf_set_text(buf, start_row, start_col, end_row, end_col, replace) -end + if node1:child_count() == 0 or node2:child_count() == 0 then + return ts.get_node_text(node1, buf) == ts.get_node_text(node2, buf) + end + if node1:child_count() ~= node2:child_count() then + return false + end + for i = 0, node1:child_count() - 1 do + if + not tree_match(node1:child(i) --[[@as TSNode]], node2:child(i) --[[@as TSNode]]) + then + return false + end + end + return true + end + return tree_match(match[pred[2]], match[pred[3]]) +end, true) -return M +return Searcher diff --git a/lua/ssr/ui/confirm_win.lua b/lua/ssr/ui/confirm_win.lua new file mode 100644 index 0000000..e02fa19 --- /dev/null +++ b/lua/ssr/ui/confirm_win.lua @@ -0,0 +1,139 @@ +local api = vim.api + +---@class ConfirmWin +local ConfirmWin = {} + +function ConfirmWin.new() end + +function ConfirmWin:open() + local buf = api.nvim_win_get_buf(self.origin_win) + local matches = self.matches[buf] + if #matches == 0 then + return self:set_status "pattern not found" + end + + local confirm_buf = api.nvim_create_buf(false, true) + vim.bo[confirm_buf].filetype = "ssr_confirm" + local choices = { + "• Yes", + "• No", + "──────────────", + "• All", + "• Quit", + "• Last replace", + } + local separator_idx = 3 + api.nvim_buf_set_lines(confirm_buf, 0, -1, true, choices) + for idx = 0, #choices - 1 do + if idx + 1 ~= separator_idx then + api.nvim_buf_set_extmark( + confirm_buf, + u.namespace, + idx, + 4, + { hl_group = "Underlined", end_row = idx, end_col = 5 } + ) + end + end + + local function open_confirm_win(match_idx) + self:goto_match(match_idx) + local _, _, end_row, end_col = matches[match_idx].range:get() + local cfg = { + relative = "win", + win = self.origin_win, + bufpos = { end_row, end_col }, + style = "minimal", + border = config.options.border, + width = 14, + height = 6, + } + if vim.fn.has "nvim-0.9" == 1 then + cfg.title = "Replace?" + cfg.title_pos = "center" + end + return api.nvim_open_win(confirm_buf, true, cfg) + end + + local match_idx = 1 + local replaced = 0 + local cursor = 1 + local _, template = self:get_input() + self:set_status(string.format("replacing 0/%d", #matches)) + + while match_idx <= #matches do + local confirm_win = open_confirm_win(match_idx) + + ---@type string + local key + while true do + -- Draw a fake cursor because cursor is not shown correctly when blocking on `getchar()`. + api.nvim_buf_clear_namespace(confirm_buf, u.cur_search_ns, 0, -1) + api.nvim_buf_set_extmark( + confirm_buf, + u.cur_search_ns, + cursor - 1, + 0, + { virt_text = { { "•", "Cursor" } }, virt_text_pos = "overlay" } + ) + api.nvim_buf_set_extmark(confirm_buf, u.cur_search_ns, cursor - 1, 0, { line_hl_group = "CursorLine" }) + vim.cmd.redraw() + + local ok, char = pcall(vim.fn.getcharstr) + key = ok and vim.fn.keytrans(char) or "" + if key == "j" then + if cursor == separator_idx - 1 then -- skip separator + cursor = separator_idx + 1 + elseif cursor == #choices then -- wrap + cursor = 1 + else + cursor = cursor + 1 + end + elseif key == "k" then + if cursor == separator_idx + 1 then -- skip separator + cursor = separator_idx - 1 + elseif cursor == 1 then -- wrap + cursor = #choices + else + cursor = cursor - 1 + end + elseif vim.tbl_contains({ "", "", "", "", "", "" }, key) then + vim.fn.win_execute(self.origin_win, string.format('execute "normal! \\%s"', key)) + else + break + end + end + + if key == "" then + key = ({ "y", "n", "", "a", "q", "l" })[cursor] + end + + if key == "y" then + replace(buf, matches[match_idx], template) + replaced = replaced + 1 + match_idx = match_idx + 1 + elseif key == "n" then + match_idx = match_idx + 1 + elseif key == "a" then + for i = match_idx, #matches do + replace(buf, matches[i], template) + end + replaced = replaced + #matches + 1 - match_idx + match_idx = #matches + 1 + elseif key == "l" then + replace(buf, matches[match_idx], template) + replaced = replaced + 1 + match_idx = #matches + 1 + elseif key == "q" or key == "" or key == "" then + match_idx = #matches + 1 + end + api.nvim_win_close(confirm_win, false) + self:set_status(string.format("replacing %d/%d", replaced, #matches)) + end + + api.nvim_buf_delete(confirm_buf, {}) + api.nvim_buf_clear_namespace(buf, u.cur_search_ns, 0, -1) + self:set_status(string.format("%d/%d replaced", replaced, #matches)) +end + +return ConfirmWin diff --git a/lua/ssr/ui/init.lua b/lua/ssr/ui/init.lua new file mode 100644 index 0000000..0647012 --- /dev/null +++ b/lua/ssr/ui/init.lua @@ -0,0 +1,162 @@ +local api = vim.api +local ts = vim.treesitter +local ParseContext = require "ssr.parse_context" +local Searcher = require "ssr.search" +local Replacer = require "ssr.replace" +local File = require "ssr.file" +local MainWin = require "ssr.ui.main_win" +local u = require "ssr.utils" + +---@class Ui +---@field lang string +---@field parse_context ParseContext +---@field results { file: File, matches: Match[] }[] +---@field last_pattern string +---@field main_win MainWin +local Ui = {} + +---@return Ui? +function Ui.new() + local self = setmetatable({ matches = {} }, { __index = Ui }) + + -- Pre-checks + local origin_win = api.nvim_get_current_win() + local origin_buf = api.nvim_win_get_buf(origin_win) + local lang = ts.language.get_lang(vim.bo[origin_buf].filetype) + if not lang then + return u.notify "Treesitter language not found" + end + self.lang = lang + local origin_node = u.node_for_range(origin_buf, self.lang, u.get_selection(origin_win)) + if not origin_node then + return u.notify("Treesitter parser not found, please try to install it with :TSInstall " .. self.lang) + end + if origin_node:has_error() then + return u.notify "You have syntax errors in the selected node" + end + local parse_context = ParseContext.new(origin_buf, self.lang, origin_node) + if not parse_context then + return u.notify "Can't find a proper context to parse the pattern" + end + self.parse_context = parse_context + + local placeholder = vim.split(ts.get_node_text(origin_node, origin_buf), "\n", { plain = true }) + u.remove_indent(placeholder, u.get_indent(origin_buf, origin_node:start())) + + self.main_win = MainWin.new(lang, placeholder, { "" }, origin_win) + + self.main_win:on({ "InsertLeave", "TextChanged" }, function() + self:search() + end) + + self:search() + return self +end + +function Ui:search() + local pattern = self.main_win:get_input() + if pattern == self.last_pattern then + return + end + self.last_pattern = pattern + + self.results = {} + local found = 0 + local matched_files = 0 + local start = vim.loop.hrtime() + local searcher = Searcher.new(self.lang, pattern, self.parse_context) + if not searcher then + return self:set_status "Error" + end + + File.grep(searcher.rough_regex, function(file) + local matches = searcher:search(file) + if #matches == 0 then + return + end + found = found + #matches + matched_files = matched_files + 1 + table.insert(self.results, { file = file, matches = matches }) + end, function() + local elapsed = (vim.loop.hrtime() - start) / 1E6 + self.main_win.result_list:set(self.results) + self:set_status(string.format("%d found in %d files (%dms)", found, matched_files, elapsed)) + end) +end + +function Ui:next_match_idx() + local cursor_row, cursor_col = u.get_cursor(origin_win) + local buf = api.nvim_win_get_buf(origin_win) + for idx, matches in pairs(self.matches[buf]) do + local start_row, start_col = matches.range:get() + if start_row > cursor_row or (start_row == cursor_row and start_col > cursor_col) then + return idx + end + end + return 1 +end + +function Ui:prev_match_idx() + local cursor_row, cursor_col = u.get_cursor(origin_win) + local buf = api.nvim_win_get_buf(origin_win) + local matches = self.matches[buf] + for idx = #matches, 1, -1 do + local start_row, start_col = matches[idx].range:get() + if start_row < cursor_row or (start_row == cursor_row and start_col < cursor_col) then + return idx + end + end + return #matches +end + +function Ui:goto_match(match_idx) + local buf = api.nvim_win_get_buf(origin_win) + api.nvim_buf_clear_namespace(buf, u.cur_search_ns, 0, -1) + local matches = self.matches[buf] + local start_row, start_col, end_row, end_col = matches[match_idx].range:get() + u.set_cursor(origin_win, start_row, start_col) + vim.highlight.range( + buf, + u.cur_search_ns, + "CurSearch", + { start_row, start_col }, + { end_row, end_col }, + { priority = vim.highlight.priorities.user + 100 } + ) + api.nvim_buf_set_extmark(buf, u.cur_search_ns, start_row, start_col, { + virt_text_pos = "eol", + virt_text = { { string.format("[%d/%d]", match_idx, #matches), "DiagnosticVirtualTextInfo" } }, + }) +end + +function Ui:replace_all() + self:search() + local buf = api.nvim_win_get_buf(origin_win) + local matches = self.matches[buf] + if #matches == 0 then + return self:set_status "pattern not found" + end + local _, template = self:get_input() + local start = vim.loop.hrtime() + for _, match in ipairs(matches) do + replace(buf, match, template) + end + local elapsed = (vim.loop.hrtime() - start) / 1E6 + self:set_status(string.format("%d replaced in %dms", #matches, elapsed)) +end + +---@param status string +---@return nil +function Ui:set_status(status) + api.nvim_buf_set_extmark(self.main_win.buf, u.namespace, 0, 0, { + id = self.main_win.extmarks.status, + virt_text = { + { "[SSR] ", "Comment" }, + { status }, + { " (Press ? for help)", "Comment" }, + }, + virt_text_pos = "overlay", + }) +end + +return Ui diff --git a/lua/ssr/ui/main_win.lua b/lua/ssr/ui/main_win.lua new file mode 100644 index 0000000..b0e3495 --- /dev/null +++ b/lua/ssr/ui/main_win.lua @@ -0,0 +1,240 @@ +local api = vim.api +local ts = vim.treesitter +local config = require "ssr.config" +local ResultList = require "ssr.ui.result_list" +local u = require "ssr.utils" + +---@class MainWin +---@field buf buffer +---@field win window +---@field origin_win window +---@field lang string +---@field last_pattern string[] +---@field last_template string[] +---@field result_list ResultList +local MainWin = {} + +function MainWin.new(lang, pattern, template, origin_win) + local self = setmetatable({ + lang = lang, + last_pattern = pattern, + last_template = template, + origin_win = origin_win, + }, { __index = MainWin }) + + self.buf = api.nvim_create_buf(false, true) + vim.bo[self.buf].filetype = "ssr" + + local lines = self:render() + self:open_win(u.get_win_size(lines)) + self.result_list = ResultList.new(self.buf, self.win, self.extmarks.results) + + self:setup_autocmds() + self:setup_keymaps() + + return self +end + +function MainWin:render() + ts.stop(self.buf) + api.nvim_buf_clear_namespace(self.buf, u.namespace, 0, -1) + + local lines = { + "", -- [SSR] + "```" .. self.lang, -- SEARCH: + } + vim.list_extend(lines, self.last_pattern) + table.insert(lines, "") -- REPLACE: + vim.list_extend(lines, self.last_template) + table.insert(lines, "```") -- RESULTS: + api.nvim_buf_set_lines(self.buf, 0, -1, true, lines) + + -- Enable syntax highlights for input area. + local parser = ts.get_parser(self.buf, "markdown") + parser:parse(true) + parser:for_each_tree(function(tree, lang_tree) + if tree:root():start() == 2 then + ts.highlighter.new(lang_tree) + end + end) + + local function virt_text(row, text) + return api.nvim_buf_set_extmark(self.buf, u.namespace, row, 0, { virt_text = text, virt_text_pos = "overlay" }) + end + self.extmarks = { + status = virt_text(0, { { "[SSR]", "Comment" }, { " (Press ? for help)", "Comment" } }), + search = virt_text(1, { { "SEARCH:", "String" } }), + replace = virt_text(#lines - 3, { { "REPLACE:", "String" } }), + results = virt_text(#lines - 1, { { "RESULTS:", "String" } }), + } + + -- RESULTS extmark is re-created + -- self.result_list.extmark = self.extmarks.results + + return lines +end + +function MainWin:check(lines) + if #lines < 6 then + return false + end + + local function get_index(extmark) + return api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, extmark, {})[1] + 1 + end + + return get_index(self.extmarks.status) == 1 + and lines[1] == "" + and get_index(self.extmarks.search) == 2 + and lines[2] == "```" .. self.lang + and lines[get_index(self.extmarks.replace)] == "" + and lines[get_index(self.extmarks.results)] == "```" +end + +function MainWin:open_win(width, height) + self.win = api.nvim_open_win(self.buf, true, { + relative = "editor", + anchor = "NE", + row = 0, + col = vim.o.columns - 1, + style = "minimal", + border = config.opts.border, + width = width, + height = height, + }) + vim.wo[self.win].wrap = false + u.set_cursor(self.win, 2, 0) + vim.fn.matchadd("Title", [[$\w\+]]) +end + +function MainWin:on(event, func) + api.nvim_create_autocmd(event, { + group = u.augroup, + buffer = self.buf, + callback = func, + }) +end + +function MainWin:setup_autocmds() + self:on({ "TextChanged", "TextChangedI" }, function() + local lines = api.nvim_buf_get_lines(self.buf, 0, -1, true) + if not self:check(lines) then + self:render() + end + if not config.opts.adjust_window then + return + end + local width, height = u.get_win_size(lines) + if api.nvim_win_get_width(self.win) ~= width then + api.nvim_win_set_width(self.win, width) + end + if api.nvim_win_get_height(self.win) ~= height then + api.nvim_win_set_height(self.win, height) + end + end) + + self:on("BufWinEnter", function(event) + if event.buf == self.buf then + return + end + local win = api.nvim_get_current_win() + if win ~= self.win then + return + end + -- Prevent accidentally opening another file in the ssr window. + -- Adapted from neo-tree.nvim. + vim.schedule(function() + api.nvim_win_set_buf(self.win, self.buf) + local name = api.nvim_buf_get_name(event.buf) + api.nvim_win_call(self.origin_win, function() + pcall(api.nvim_buf_delete, event.buf, {}) + if name ~= "" then + vim.cmd.edit(name) + end + end) + api.nvim_set_current_win(self.origin_win) + end) + end) + + self:on("WinClosed", function() + api.nvim_clear_autocmds { group = u.augroup } + api.nvim_buf_delete(self.buf, {}) + end) +end + +function MainWin:on_key(key, func) + vim.keymap.set("n", key, func, { buffer = self.buf, nowait = true }) +end + +function MainWin:setup_keymaps() + self:on_key(config.opts.keymaps.close, function() + api.nvim_win_close(self.win, false) + end) + + self:on_key("gg", function() + u.set_cursor(self.win, 2, 0) + end) + + self:on_key("j", function() + local cursor = u.get_cursor(self.win) + for _, extmark in ipairs { self.extmarks.replace, self.extmarks.results } do + local skip_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, extmark, {})[1] + if cursor == skip_pos - 1 then + return u.set_cursor(self.win, skip_pos + 1, 0) + end + end + vim.fn.feedkeys("j", "n") + end) + + self:on_key("k", function() + local cursor = u.get_cursor(self.win) + if cursor <= 2 then + return u.set_cursor(self.win, 2, 0) + end + for _, extmark in ipairs { self.extmarks.replace, self.extmarks.results } do + local skip_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, extmark, {})[1] + if cursor == skip_pos + 1 then + return u.set_cursor(self.win, skip_pos - 1, 0) + end + end + vim.fn.feedkeys("k", "n") + end) + + self:on_key("l", function() + local cursor = u.get_cursor(self.win) + if cursor < self.result_list:get_start() then + return vim.fn.feedkeys("l", "n") + end + self.result_list:set_folded(false) + end) + + self:on_key("h", function() + local cursor = u.get_cursor(self.win) + if cursor < self.result_list:get_start() then + return vim.fn.feedkeys("h", "n") + end + self.result_list:set_folded(true) + end) + + self:on_key(config.opts.keymaps.next_match, function() + local cursor = u.get_cursor(self.win) + local first = self.result_list:get_start() + 1 + if cursor < first then + u.set_cursor(self.win, first, 0) + else + u.set_cursor(self.win, cursor + 1, 0) + end + end) +end + +function MainWin:get_input() + local pattern_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, self.extmarks.search, {})[1] + local template_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, self.extmarks.replace, {})[1] + local results_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, self.extmarks.results, {})[1] + local lines = api.nvim_buf_get_lines(self.buf, 0, results_pos, true) + local pattern = table.concat(lines, "\n", pattern_pos + 2, template_pos) + local template = table.concat(lines, "\n", template_pos + 2) + return pattern, template +end + +return MainWin diff --git a/lua/ssr/ui/result_list.lua b/lua/ssr/ui/result_list.lua new file mode 100644 index 0000000..7473fc1 --- /dev/null +++ b/lua/ssr/ui/result_list.lua @@ -0,0 +1,129 @@ +local api = vim.api +local u = require "ssr.utils" + +---@class Item list item per line +---@field fold_idx number which fold this line belongs to, 1-based +---@field match_idx number which match this line belongs to, 0-based, 0 for filename + +---@class Fold a foldable region that may span multiple lines +---@field folded boolean +---@field filename string +---@field path string +---@field preview_lines string[] +local Fold = {} + +function Fold.new(folded, file, matches) + local preview_lines = {} + for _, match in ipairs(matches) do + local line = vim.split(file.content, "\n", { plain = true })[match.range.start_row + 1] + line = line:gsub("^%s*", "") + table.insert(preview_lines, " " .. line) + end + return setmetatable({ + folded = folded, + filename = vim.fn.fnamemodify(file.path, ":t"), + path = vim.fn.fnamemodify(file.path, ":~:.:h"), + preview_lines = preview_lines, + }, { __index = Fold }) +end + +function Fold:len() + if self.folded then + return 1 + end + return 1 + #self.preview_lines +end + +function Fold:get_lines() + if self.folded then + return { string.format(" %s %s %d", self.filename, self.path, #self.preview_lines) } + end + local lines = { string.format(" %s %s %d", self.filename, self.path, #self.preview_lines) } + vim.list_extend(lines, self.preview_lines) + return lines +end + +---@class ResultList +---@field buf buffer +---@field win window +---@field extmark number +---@field folds Fold[] +---@field items Item[] +local ResultList = {} + +function ResultList.new(buf, win, extmark) + return setmetatable({ + buf = buf, + win = win, + extmark = extmark, + folds = {}, + items = {}, + }, { __index = ResultList }) +end + +function ResultList:get_start() + return api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, self.extmark, {})[1] + 1 +end + +---@params results { file: File, matches: Match[] }[] +function ResultList:set(results) + self.folds = {} + self.items = {} + local start = self:get_start() + api.nvim_buf_clear_namespace(self.buf, u.namespace, start, -1) + + local lines = {} + for fold_idx, result in ipairs(results) do + local fold = Fold.new(fold_idx ~= 1, result.file, result.matches) + table.insert(self.folds, fold) + for match_idx, line in ipairs(fold:get_lines()) do + table.insert(lines, line) + table.insert(self.items, { fold_idx = fold_idx, match_idx = match_idx - 1 }) + end + end + api.nvim_buf_set_lines(self.buf, start, -1, true, lines) + + for _, fold in ipairs(self.folds) do + self:highlight_fold(start, fold) + start = start + fold:len() + end +end + +function ResultList:highlight_fold(row, fold) + local col = 4 -- "" is 3 bytes, plus 1 space + api.nvim_buf_add_highlight(self.buf, u.namespace, "Directory", row, col, col + #fold.filename) + col = col + #fold.filename + 1 + api.nvim_buf_add_highlight(self.buf, u.namespace, "Comment", row, col, col + #fold.path) + col = col + #fold.path + 1 + api.nvim_buf_add_highlight(self.buf, u.namespace, "Number", row, col, col + #(tostring(fold.preview_lines))) +end + +function ResultList:set_folded(folded) + local result_start = self:get_start() + local cursor = u.get_cursor(self.win) - result_start + local item = self.items[cursor + 1] -- +1 beacause `cursor` is 0-based + local fold = self.folds[item.fold_idx] + if fold.folded == folded then + return + end + + local start = cursor - item.match_idx -- like C macro `container_of` + local end_ = start + fold:len() + fold.folded = folded + local lines = fold:get_lines() + local items = {} + for i = 0, #lines - 1 do + table.insert(items, { fold_idx = item.fold_idx, match_idx = i }) + end + u.list_replace(self.items, start, end_, items) + start = result_start + start + end_ = result_start + end_ + api.nvim_buf_set_lines(self.buf, start, end_, true, lines) + self:highlight_fold(start, fold) + if folded then + u.set_cursor(self.win, start, 0) + vim.fn.feedkeys("zb", "n") + end +end + +return ResultList diff --git a/lua/ssr/utils.lua b/lua/ssr/utils.lua index 099ed92..a603381 100644 --- a/lua/ssr/utils.lua +++ b/lua/ssr/utils.lua @@ -1,8 +1,13 @@ local api = vim.api local ts = vim.treesitter - +local config = require "ssr.config" local M = {} +M.wildcard_prefix = "__ssr_var_" +M.namespace = api.nvim_create_namespace "ssr_ns" +M.cur_search_ns = api.nvim_create_namespace "ssr_cur_search_ns" +M.augroup = api.nvim_create_augroup("ssr_augroup", {}) + -- Send a notification titled SSR. ---@param msg string ---@return nil @@ -103,10 +108,9 @@ end -- Compute window size to show giving lines. ---@param lines string[] ----@param config Config ---@return number ---@return number -function M.get_win_size(lines, config) +function M.get_win_size(lines) ---@param i number ---@param min number ---@param max number @@ -122,9 +126,74 @@ function M.get_win_size(lines, config) end end - width = clamp(width, config.min_width, config.max_width) - local height = clamp(#lines, config.min_height, config.max_height) + width = clamp(width, config.opts.min_width, config.opts.max_width) + local height = clamp(#lines, config.opts.min_height, config.opts.max_height) return width, height end +-- https://github.com/rust-lang/regex/blob/17284451f10aa06c6c42e622e3529b98513901a8/regex-syntax/src/lib.rs#L272 +local regex_meta_chars = { + ["\\"] = true, + ["."] = true, + ["+"] = true, + ["*"] = true, + ["?"] = true, + ["("] = true, + [")"] = true, + ["|"] = true, + ["["] = true, + ["]"] = true, + ["{"] = true, + ["}"] = true, + ["^"] = true, + ["$"] = true, + ["#"] = true, + ["&"] = true, + ["-"] = true, + ["~"] = true, +} + +---@param s string +---@return string +function M.regex_escape(s) + local escaped = s:gsub(".", function(ch) + return regex_meta_chars[ch] and "\\" .. ch or ch + end) + return escaped +end + +---@generic T +---@param list T[] +---@param f fun(T): -1 | 0 | 1 +---@return number? +function M.binary_search_by(list, f) + local left = 1 + local right = #list + 1 + while left < right do + local mid = math.floor((left + right) / 2) + local cmp = f(list[mid]) + if cmp < 0 then + left = mid + 1 + elseif cmp > 0 then + right = mid + else + return mid + end + end +end + +---@generic T +---@param list table +---@param start number 0-based +---@param end_ number exclusive +---@param replacement table +function M.list_replace(list, start, end_, replacement) + for _ = start + 1, end_ do + table.remove(list, start + 1) + end + for i = start + 1, start + #replacement do + table.insert(list, i, replacement[i - start]) + end +end + return M diff --git a/tests/ssr_spec.lua b/tests/ssr_spec.lua index f6fcd17..bc8c242 100644 --- a/tests/ssr_spec.lua +++ b/tests/ssr_spec.lua @@ -1,8 +1,7 @@ -local u = require "ssr.utils" -local ParseContext = require("ssr.parse").ParseContext local ts = vim.treesitter -local search = require("ssr.search").search -local replace = require("ssr.search").replace +local s = require "ssr.search" +local ParseContext, Ssr = s.ParseContext, s.Ssr +local u = require "ssr.utils" local tests = {} @@ -257,14 +256,15 @@ describe("", function() local lang = ts.language.get_lang(vim.bo[buf].filetype) assert(lang, "language not found") local origin_node = u.node_for_range(buf, lang, start_row, start_col, end_row, end_col) - - local parse_context = ParseContext.new(buf, origin_node) + assert(origin_node, 'treesitter parser not installed') + local parse_context = ParseContext.new(buf, lang, origin_node) assert(parse_context) - local node, source = parse_context:parse(pattern) - local matches = search(buf, node, source, ns) + local rule = Ssr.new(lang, pattern, template, parse_context) + assert(rule) + local matches = rule:search(buf) for _, match in ipairs(matches) do - replace(buf, match, template) + rule:replace(buf, match) end local actual = vim.api.nvim_buf_get_lines(buf, 0, -1, true)