diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2a0df4a..dcf5dcb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,34 +9,41 @@ jobs: strategy: matrix: - nvim: [v0.9.1, nightly] + nvim: [nightly] + + env: + RIPGREP_VERSION: "14.1.0" + VIM: ~/.local/share/nvim/share/nvim/runtime steps: - uses: actions/checkout@v3 - - name: Set Envs + - name: Add PATH run: | - echo "VIM=~/.local/share/nvim/share/nvim/runtime" >> $GITHUB_ENV - echo "PATH=~/.local/share/nvim/bin:$PATH" >> $GITHUB_ENV + echo "$HOME/.local/share/nvim/bin" >> $GITHUB_PATH + echo "$HOME/.local/share/ripgrep" >> $GITHUB_PATH - name: Cache Dependencies id: cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: - path: ~/.local/share/nvim - key: ${{ runner.os }}-nvim-${{ matrix.nvim }} + key: ${{ runner.os }}-nvim-${{ matrix.nvim }}-rg-${{ env.RIPGREP_VERSION }} + path: | + ~/.local/share/nvim + ~/.local/share/ripgrep - name: Install Dependencies if: steps.cache.outputs.cache-hit != 'true' run: | - mkdir -p ~/.local/share/nvim/ + mkdir -p ~/.local/share/{nvim,ripgrep} + curl -sL "https://github.com/BurntSushi/ripgrep/releases/download/$RIPGREP_VERSION/ripgrep-$RIPGREP_VERSION-x86_64-unknown-linux-musl.tar.gz" | tar xzf - --strip-components=1 -C ~/.local/share/ripgrep curl -sL "https://github.com/neovim/neovim/releases/download/${{ matrix.nvim }}/nvim-linux64.tar.gz" | tar xzf - --strip-components=1 -C ~/.local/share/nvim/ git clone --depth 1 https://github.com/nvim-treesitter/nvim-treesitter.git ~/.local/share/nvim/site/pack/vendor/start/nvim-treesitter git clone --depth 1 https://github.com/nvim-lua/plenary.nvim ~/.local/share/nvim/site/pack/vendor/start/plenary.nvim - ln -s $(pwd) ~/.local/share/nvim/site/pack/vendor/start - nvim --headless -c 'TSInstallSync python javascript lua rust go' -c 'q' + ln -s $PWD ~/.local/share/nvim/site/pack/vendor/start + nvim --headless '+TSInstallSync python javascript lua rust go' +q - name: Run tests run: | - nvim --version - nvim --headless -c 'PlenaryBustedDirectory tests/' + nvim --version | head -1 && rg --version | head -1 + nvim --headless '+PlenaryBustedDirectory tests/' diff --git a/.neoconf.json b/.neoconf.json new file mode 100644 index 0000000..39b80bd --- /dev/null +++ b/.neoconf.json @@ -0,0 +1,20 @@ +{ + "neodev": { + "library": { + "enabled": true, + "plugins": ["plenary.nvim"] + } + }, + "neoconf": { + "plugins": { + "lua_ls": { + "enabled": true + } + } + }, + "lspconfig": { + "lua_ls": { + "Lua.format.enable": false + } + } +} diff --git a/lua/ssr.lua b/lua/ssr.lua deleted file mode 100644 index a32cfe4..0000000 --- a/lua/ssr.lua +++ /dev/null @@ -1,471 +0,0 @@ -local api = vim.api -local ts = vim.treesitter -local fn = vim.fn -local keymap = vim.keymap -local highlight = vim.highlight -local ParseContext = require("ssr.parse").ParseContext -local search = require("ssr.search").search -local replace = require("ssr.search").replace -local u = require "ssr.utils" - -local M = {} - ----@class Config -local config = { - border = "rounded", - min_width = 50, - min_height = 5, - max_width = 120, - max_height = 25, - adjust_window = true, - keymaps = { - close = "q", - next_match = "n", - prev_match = "N", - replace_confirm = "", - replace_all = "", - }, -} - --- Set config options. ----@param cfg Config? -function M.setup(cfg) - if cfg then - config = vim.tbl_deep_extend("force", config, cfg) - end -end - ----@type table -local win_uis = {} - ----@class Ui ----@field ns number ----@field cur_search_ns number ----@field augroup number ----@field ui_buf buffer ----@field extmarks {status: number, search: number, replace: number} ----@field origin_win window ----@field lang string ----@field parse_context ParseContext ----@field buf_matches table -local Ui = {} - ----@return Ui? -function Ui.new() - local self = setmetatable({}, { __index = Ui }) - - self.origin_win = api.nvim_get_current_win() - local origin_buf = api.nvim_win_get_buf(self.origin_win) - local lang = ts.language.get_lang(vim.bo[origin_buf].filetype) - if not lang then - return u.notify("Treesitter language not found") - end - self.lang = lang - - local origin_node = u.node_for_range(origin_buf, self.lang, u.get_selection(self.origin_win)) - if not origin_node then - return u.notify("Treesitter parser not found, please try to install it with :TSInstall " .. self.lang) - end - if origin_node:has_error() then - return u.notify "You have syntax errors in the selected node" - end - local parse_context = ParseContext.new(origin_buf, origin_node) - if not parse_context then - return u.notify "Can't find a proper context to parse the pattern" - end - self.parse_context = parse_context - - self.buf_matches = {} - self.ns = api.nvim_create_namespace("ssr_" .. self.origin_win) -- TODO - self.cur_search_ns = api.nvim_create_namespace("ssr_cur_match_" .. self.origin_win) - self.augroup = api.nvim_create_augroup("ssr_augroup_" .. self.origin_win, {}) - - -- Init ui buffer - self.ui_buf = api.nvim_create_buf(false, true) - vim.bo[self.ui_buf].filetype = "ssr" - - local placeholder = ts.get_node_text(origin_node, origin_buf) - placeholder = "\n\n" .. placeholder .. "\n\n" - placeholder = vim.split(placeholder, "\n") - u.remove_indent(placeholder, u.get_indent(origin_buf, origin_node:start())) - api.nvim_buf_set_lines(self.ui_buf, 0, -1, true, placeholder) - -- Enable syntax highlights - ts.start(self.ui_buf, self.lang) - - local function virt_text(row, text) - return api.nvim_buf_set_extmark(self.ui_buf, self.ns, row, 0, { virt_text = text, virt_text_pos = "overlay" }) - end - self.extmarks = { - status = virt_text(0, { { "[SSR]", "Comment" }, { " (Press ? for help)", "Comment" } }), - search = virt_text(1, { { "SEARCH:", "String" } }), - replace = virt_text(#placeholder - 2, { { "REPLACE:", "String" } }), - } - - local function map(key, func) - keymap.set("n", key, function() - func(self) - end, { buffer = self.ui_buf, nowait = true }) - end - map(config.keymaps.replace_confirm, self.replace_confirm) - map(config.keymaps.replace_all, self.replace_all) - map(config.keymaps.next_match, function() - self:goto_match(self:next_match_idx()) - end) - map(config.keymaps.prev_match, function() - self:goto_match(self:prev_match_idx()) - end) - - -- Open float window - local width, height = u.get_win_size(placeholder, config) - local ui_win = api.nvim_open_win(self.ui_buf, true, { - relative = "win", - anchor = "NE", - row = 1, - col = api.nvim_win_get_width(0) - 1, - style = "minimal", - border = config.border, - width = width, - height = height, - }) - u.set_cursor(ui_win, 2, 0) - fn.matchadd("Title", [[$\w\+]]) - - map(config.keymaps.close, function() - api.nvim_win_close(ui_win, false) - end) - - api.nvim_create_autocmd({ "TextChanged", "TextChangedI" }, { - group = self.augroup, - buffer = self.ui_buf, - callback = function() - if config.adjust_window then - local lines = api.nvim_buf_get_lines(self.ui_buf, 0, -1, true) - local width, height = u.get_win_size(lines, config) - if api.nvim_win_get_width(ui_win) ~= width then - api.nvim_win_set_width(ui_win, width) - end - if api.nvim_win_get_height(ui_win) ~= height then - api.nvim_win_set_height(ui_win, height) - end - end - self:search() - end, - }) - - -- SSR window is bound to the original window (not buffer!), which is the same behavior as IDEs and browsers. - api.nvim_create_autocmd("BufWinEnter", { - group = self.augroup, - callback = function(event) - if event.buf == self.ui_buf then - return - end - - local win = api.nvim_get_current_win() - if win == ui_win then - -- Prevent accidentally opening another file in the ssr window. - -- Adapted from neo-tree.nvim. - vim.schedule(function() - api.nvim_win_set_buf(ui_win, self.ui_buf) - local name = api.nvim_buf_get_name(event.buf) - api.nvim_win_call(self.origin_win, function() - pcall(api.nvim_buf_delete, event.buf, {}) - if name ~= "" then - vim.cmd.edit(name) - end - end) - api.nvim_set_current_win(self.origin_win) - end) - return - elseif win ~= self.origin_win then - return - end - - if ts.language.get_lang(vim.bo[event.buf].filetype) ~= self.lang then - return self:set_status "N/A" - end - self:search() - end, - }) - - api.nvim_create_autocmd("WinClosed", { - group = self.augroup, - buffer = self.ui_buf, - callback = function() - win_uis[self.origin_win] = nil - api.nvim_clear_autocmds { group = self.augroup } - api.nvim_buf_delete(self.ui_buf, {}) - for buf in pairs(self.buf_matches) do - api.nvim_buf_clear_namespace(buf, self.ns, 0, -1) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) - end - end, - }) - - win_uis[self.origin_win] = self - return self -end - -function Ui:search() - local pattern = self:get_input() - local buf = api.nvim_win_get_buf(self.origin_win) - self.buf_matches[buf] = {} - api.nvim_buf_clear_namespace(buf, self.ns, 0, -1) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) - - local start = vim.loop.hrtime() - local node, source = self.parse_context:parse(pattern) - if node:has_error() then - return self:set_status "Error" - end - self.buf_matches[buf] = search(buf, node, source, self.ns) - local elapsed = (vim.loop.hrtime() - start) / 1E6 - for _, match in ipairs(self.buf_matches[buf]) do - local start_row, start_col, end_row, end_col = match.range:get() - highlight.range(buf, self.ns, "Search", { start_row, start_col }, { end_row, end_col }, {}) - end - self:set_status(string.format("%d found in %dms", #self.buf_matches[buf], elapsed)) -end - -function Ui:next_match_idx() - local cursor_row, cursor_col = u.get_cursor(self.origin_win) - local buf = api.nvim_win_get_buf(self.origin_win) - for idx, matches in pairs(self.buf_matches[buf]) do - local start_row, start_col = matches.range:get() - if start_row > cursor_row or (start_row == cursor_row and start_col > cursor_col) then - return idx - end - end - return 1 -end - -function Ui:prev_match_idx() - local cursor_row, cursor_col = u.get_cursor(self.origin_win) - local buf = api.nvim_win_get_buf(self.origin_win) - local matches = self.buf_matches[buf] - for idx = #matches, 1, -1 do - local start_row, start_col = matches[idx].range:get() - if start_row < cursor_row or (start_row == cursor_row and start_col < cursor_col) then - return idx - end - end - return #matches -end - -function Ui:goto_match(match_idx) - local buf = api.nvim_win_get_buf(self.origin_win) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) - local matches = self.buf_matches[buf] - local start_row, start_col, end_row, end_col = matches[match_idx].range:get() - u.set_cursor(self.origin_win, start_row, start_col) - highlight.range( - buf, - self.cur_search_ns, - "CurSearch", - { start_row, start_col }, - { end_row, end_col }, - { priority = vim.highlight.priorities.user + 100 } - ) - api.nvim_buf_set_extmark(buf, self.cur_search_ns, start_row, start_col, { - virt_text_pos = "eol", - virt_text = { { string.format("[%d/%d]", match_idx, #matches), "DiagnosticVirtualTextInfo" } }, - }) -end - -function Ui:replace_all() - self:search() - local buf = api.nvim_win_get_buf(self.origin_win) - local matches = self.buf_matches[buf] - if #matches == 0 then - return self:set_status "pattern not found" - end - local _, template = self:get_input() - local start = vim.loop.hrtime() - for _, match in ipairs(matches) do - replace(buf, match, template) - end - local elapsed = (vim.loop.hrtime() - start) / 1E6 - self:set_status(string.format("%d replaced in %dms", #matches, elapsed)) -end - -function Ui:replace_confirm() - self:search() - local buf = api.nvim_win_get_buf(self.origin_win) - local matches = self.buf_matches[buf] - if #matches == 0 then - return self:set_status "pattern not found" - end - - local confirm_buf = api.nvim_create_buf(false, true) - vim.bo[confirm_buf].filetype = "ssr_confirm" - local choices = { - "• Yes", - "• No", - "──────────────", - "• All", - "• Quit", - "• Last replace", - } - local separator_idx = 3 - api.nvim_buf_set_lines(confirm_buf, 0, -1, true, choices) - for idx = 0, #choices - 1 do - if idx + 1 ~= separator_idx then - api.nvim_buf_set_extmark(confirm_buf, self.ns, idx, 4, { hl_group = "Underlined", end_row = idx, end_col = 5 }) - end - end - - local function open_confirm_win(match_idx) - self:goto_match(match_idx) - local _, _, end_row, end_col = matches[match_idx].range:get() - local cfg = { - relative = "win", - win = self.origin_win, - bufpos = { end_row, end_col }, - style = "minimal", - border = config.border, - width = 14, - height = 6, - } - if vim.fn.has "nvim-0.9" == 1 then - cfg.title = "Replace?" - cfg.title_pos = "center" - end - return api.nvim_open_win(confirm_buf, true, cfg) - end - - local match_idx = 1 - local replaced = 0 - local cursor = 1 - local _, template = self:get_input() - self:set_status(string.format("replacing 0/%d", #matches)) - - while match_idx <= #matches do - local confirm_win = open_confirm_win(match_idx) - - ---@type string - local key - while true do - -- Draw a fake cursor because cursor is not shown correctly when blocking on `getchar()`. - api.nvim_buf_clear_namespace(confirm_buf, self.cur_search_ns, 0, -1) - api.nvim_buf_set_extmark( - confirm_buf, - self.cur_search_ns, - cursor - 1, - 0, - { virt_text = { { "•", "Cursor" } }, virt_text_pos = "overlay" } - ) - api.nvim_buf_set_extmark(confirm_buf, self.cur_search_ns, cursor - 1, 0, { line_hl_group = "CursorLine" }) - vim.cmd.redraw() - - local ok, char = pcall(vim.fn.getcharstr) - key = ok and vim.fn.keytrans(char) or "" - if key == "j" then - if cursor == separator_idx - 1 then -- skip separator - cursor = separator_idx + 1 - elseif cursor == #choices then -- wrap - cursor = 1 - else - cursor = cursor + 1 - end - elseif key == "k" then - if cursor == separator_idx + 1 then -- skip separator - cursor = separator_idx - 1 - elseif cursor == 1 then -- wrap - cursor = #choices - else - cursor = cursor - 1 - end - elseif vim.tbl_contains({ "", "", "", "", "", "" }, key) then - fn.win_execute(self.origin_win, string.format('execute "normal! \\%s"', key)) - else - break - end - end - - if key == "" then - key = ({ "y", "n", "", "a", "q", "l" })[cursor] - end - - if key == "y" then - replace(buf, matches[match_idx], template) - replaced = replaced + 1 - match_idx = match_idx + 1 - elseif key == "n" then - match_idx = match_idx + 1 - elseif key == "a" then - for i = match_idx, #matches do - replace(buf, matches[i], template) - end - replaced = replaced + #matches + 1 - match_idx - match_idx = #matches + 1 - elseif key == "l" then - replace(buf, matches[match_idx], template) - replaced = replaced + 1 - match_idx = #matches + 1 - elseif key == "q" or key == "" or key == "" then - match_idx = #matches + 1 - end - api.nvim_win_close(confirm_win, false) - self:set_status(string.format("replacing %d/%d", replaced, #matches)) - end - - api.nvim_buf_delete(confirm_buf, {}) - api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1) - self:set_status(string.format("%d/%d replaced", replaced, #matches)) -end - -function Ui:get_input() - local lines = api.nvim_buf_get_lines(self.ui_buf, 0, -1, true) - local pattern_pos = api.nvim_buf_get_extmark_by_id(self.ui_buf, self.ns, self.extmarks.search, {})[1] - local template_pos = api.nvim_buf_get_extmark_by_id(self.ui_buf, self.ns, self.extmarks.replace, {})[1] - local pattern = vim.trim(table.concat(lines, "\n", pattern_pos + 2, template_pos)) - local template = vim.trim(table.concat(lines, "\n", template_pos + 1, #lines)) - return pattern, template -end - ----@param status string -function Ui:set_status(status) - api.nvim_buf_set_extmark(self.ui_buf, self.ns, 0, 0, { - id = self.extmarks.status, - virt_text = { - { "[SSR] ", "Comment" }, - { status }, - { " (Press ? for help)", "Comment" }, - }, - virt_text_pos = "overlay", - }) -end - ----@param win window? ----@return Ui? -function Ui.from_win(win) - if win == nil or win == 0 then - win = api.nvim_get_current_win() - end - local ui = win_uis[win] - if not ui then - return u.notify "No open SSR window" - end - return ui -end - -function M.open() - return Ui.new() -end - --- Replace all matches. -function M.replace_all() - local ui = Ui.from_win() - if ui then - ui:replace_all() - end -end - --- Confirm each match. -function M.replace_confirm() - local ui = Ui.from_win() - if ui then - ui:replace_confirm() - end -end - -return M diff --git a/lua/ssr/config.lua b/lua/ssr/config.lua new file mode 100644 index 0000000..7326c92 --- /dev/null +++ b/lua/ssr/config.lua @@ -0,0 +1,24 @@ +local M = {} + +---@class Config +M.opts = { + border = "rounded", + min_width = 50, + max_width = 120, + min_height = 6, + max_height = 25, + adjust_window = true, + keymaps = { + close = "q", + next_match = "n", + prev_match = "N", + replace_confirm = "", + replace_all = "", + }, +} + +function M.set(config) + M.opts = vim.tbl_deep_extend("force", M.opts, config) +end + +return M diff --git a/lua/ssr/file.lua b/lua/ssr/file.lua new file mode 100644 index 0000000..d43b416 --- /dev/null +++ b/lua/ssr/file.lua @@ -0,0 +1,147 @@ +local api = vim.api +local ts = vim.treesitter +local uv = vim.uv or vim.loop + +-- File contents and it's parsed tree +-- Unloaded buffers are read with libuv because loading a vim buffer can be up to 100x slower. +---@class ssr.File +---@field path string +---@field source string | buffer +---@field tree vim.treesitter.LanguageTree +-- Only if `source` is file content +---@field lines? string[] +---@field mtime? { nsec: integer, sec: integer } +local File = {} + +---@type table +local cache = {} + +---@param path string +---@return ssr.File? +function File.new(path) + -- First check if the file is already opened as a buffer. + local buf = vim.fn.bufnr(path) + if buf ~= -1 then + cache[path] = nil + if vim.bo[buf].filetype == "" then + local ft = vim.filetype.match { buf = buf } + api.nvim_buf_call(buf, function() + vim.cmd("noautocmd setlocal filetype=" .. ft) + end) + end + return setmetatable({ + path = path, + source = buf, + tree = ts.get_parser(buf), + }, { __index = File }) + end + + local fd = uv.fs_open(path, "r", 438) + if not fd then + return + end + local stat = uv.fs_fstat(fd) ---@cast stat -? + local self = cache[path] + if self then + if stat.mtime.sec == self.mtime.sec and stat.mtime.nsec == self.mtime.nsec then + uv.fs_close(fd) + return self + else + cache[path] = nil + end + end + local source = uv.fs_read(fd, stat.size, 0) --[[@as string]] + uv.fs_close(fd) + local lines = vim.split(source, "\n", { plain = true }) + local ft = vim.filetype.match { filename = path, contents = lines } -- not work for .ts + if not ft then + return + end + local lang = ts.language.get_lang(ft) + if not lang then + return + end + local has_parser, tree = pcall(ts.get_string_parser, source, lang) + if not has_parser then + return + end + tree:parse(true) + self = setmetatable({ + path = path, + source = source, + tree = tree, + filetype = ft, + lines = lines, + mtime = stat.mtime, + }, { __index = File }) + cache[path] = self + return self +end + +---@param line integer +---@return string +function File:get_line(line) + if type(self.source) == "number" then + return api.nvim_buf_get_lines(self.source --[[@as integer]], line, line + 1, true)[1] + end + return self.lines[line + 1] +end + +---@return integer +function File:load_buf() + if type(self.source) == "integer" then + return self.source --[[@as integer]] + end + local buf = vim.fn.bufadd(self.path) + self.source = buf + vim.fn.bufload(buf) + -- api.nvim_buf_call(buf, function() + -- vim.cmd("noautocmd setlocal filetype=" .. self.filetype) + -- end) + self.lines = nil + self.mtime = nil + cache[self.path] = nil + return buf +end + +---@param dir string +---@param regex string +---@param on_file fun(file: ssr.File) +---@param on_end fun() +---@return nil +function File.grep(dir, regex, on_file, on_end) + vim.system( + { "rg", "--line-buffered", "--files-with-matches", "--multiline", regex, dir }, + { + text = true, + stdout = vim.schedule_wrap(function(err, files) + if err then + error(files) + end + if not files then + on_end() + return + end + for _, path in ipairs(vim.split(files, "\n", { plain = true, trimempty = true })) do + local file = File.new(path) + if file then + on_file(file) + end + end + end), + }, + vim.schedule_wrap(function(obj) + if obj.code == 1 then -- no match was found + on_end() + elseif obj.code ~= 0 then + error(obj.stderr) + end + end) + ) +end + +function File.clear_cache() + cache = {} +end + +return File diff --git a/lua/ssr/init.lua b/lua/ssr/init.lua new file mode 100644 index 0000000..63b3880 --- /dev/null +++ b/lua/ssr/init.lua @@ -0,0 +1,15 @@ +local M = {} + +--- Set config options. Optional. +---@param config Config? +function M.setup(config) + if config then + require("ssr.config").set(config) + end +end + +function M.open() + require("ssr.ui").new() +end + +return M diff --git a/lua/ssr/parse.lua b/lua/ssr/parse_context.lua similarity index 65% rename from lua/ssr/parse.lua rename to lua/ssr/parse_context.lua index 2e78f40..9e341f3 100644 --- a/lua/ssr/parse.lua +++ b/lua/ssr/parse_context.lua @@ -1,8 +1,7 @@ local ts = vim.treesitter -local wildcard_prefix = require("ssr.search").wildcard_prefix - -local M = {} +local u = require "ssr.utils" +-- The context in which user input (pattern or template) will be parsed correctly. ---@class ParseContext ---@field lang string ---@field before string @@ -10,18 +9,24 @@ local M = {} ---@field pad_rows integer ---@field pad_cols integer local ParseContext = {} -ParseContext.__index = ParseContext -M.ParseContext = ParseContext --- Create a context in which `origin_node` (and user input) will be parsed correctly. ----@param buf buffer +---@param lang string +---@return ParseContext +function ParseContext.empty(lang) + return setmetatable({ + lang = lang, + before = "", + after = "", + pad_rows = 0, + pad_cols = 0, + }, { __index = ParseContext }) +end + +---@param buf integer +---@param lang string ---@param origin_node TSNode ---@return ParseContext? -function ParseContext.new(buf, origin_node) - local lang = ts.language.get_lang(vim.bo[buf].filetype) - if not lang then - return - end +function ParseContext.new(buf, lang, origin_node) local self = setmetatable({ lang = lang }, { __index = ParseContext }) local origin_start_row, origin_start_col, origin_start_byte = origin_node:start() @@ -48,8 +53,8 @@ function ParseContext.new(buf, origin_node) if end_row == start_row then end_col = end_col + start_col end - local node_in_context = root:named_descendant_for_range(start_row, start_col, end_row, end_col) - if node_in_context and node_in_context:type() == origin_node:type() and node_in_context:sexpr() == origin_sexpr then + local node_in_context = root:named_descendant_for_range(start_row, start_col, end_row, end_col) --[[@as TSNode]] + if node_in_context:type() == origin_node:type() and node_in_context:sexpr() == origin_sexpr then local context_start_byte self.start_row, self.start_col, context_start_byte = context_node:start() self.before = context_text:sub(1, origin_start_byte - context_start_byte) @@ -63,17 +68,17 @@ function ParseContext.new(buf, origin_node) end end --- Parse search pattern to syntax tree in proper context. ----@param pattern string ----@return TSNode?, string -function ParseContext:parse(pattern) +-- Parse code to TS node. +---@param code string +---@return TSNode, string +function ParseContext:parse(code) -- Replace named wildcard $name to identifier __ssr_var_name to avoid syntax error. - pattern = pattern:gsub("%$([_%a%d]+)", wildcard_prefix .. "%1") - local context_text = self.before .. pattern .. self.after - local root = ts.get_string_parser(context_text, self.lang):parse()[1]:root() - local lines = vim.split(pattern, "\n") - local node = root:named_descendant_for_range(self.pad_rows, self.pad_cols, self.pad_rows + #lines - 1, #lines[#lines]) - return node, context_text + code = code:gsub("%$([_%a%d]+)", u.wildcard_prefix .. "%1") + local source = self.before .. code .. self.after + local root = ts.get_string_parser(source, self.lang):parse()[1]:root() + local lines = vim.split(code, "\n") + local node = root:named_descendant_for_range(self.pad_rows, self.pad_cols, self.pad_rows + #lines - 1, #lines[#lines]) --[[@as TSNode]] + return node, source end -return M +return ParseContext diff --git a/lua/ssr/range.lua b/lua/ssr/range.lua new file mode 100644 index 0000000..cacd0d0 --- /dev/null +++ b/lua/ssr/range.lua @@ -0,0 +1,50 @@ +local api = vim.api +local u = require "ssr.utils" + +---@class ssr.Range +---@field start_row integer +---@field start_col integer +---@field end_row integer +---@field end_col integer +local Range = {} + +---@param node TSNode +---@return ssr.Range +function Range.from_node(node) + local start_row, start_col, end_row, end_col = node:range() + return setmetatable({ + start_row = start_row, + start_col = start_col, + end_row = end_row, + end_col = end_col, + }, { __index = Range }) +end + +---@param other ssr.Range +---@return boolean +function Range:before(other) + return self.end_row < other.start_row or (self.end_row == other.start_row and self.end_col <= other.start_col) +end + +---@param other ssr.Range +---@return boolean +function Range:inside(other) + return ( + (self.start_row > other.start_row or (self.start_row == other.start_row and self.start_col > other.start_col)) + and (self.end_row < other.end_row or (self.end_row == other.end_row and self.end_col <= other.end_col)) + ) +end + +-- Extmark-based ranges automatically adjust as buffer contents change. +---@param buf integer +---@return integer +function Range:to_extmark(buf) + return api.nvim_buf_set_extmark(buf, u.namespace, self.start_row, self.start_col, { + end_row = self.end_row, + end_col = self.end_col, + right_gravity = false, + end_right_gravity = true, + }) +end + +return Range diff --git a/lua/ssr/replace.lua b/lua/ssr/replace.lua new file mode 100644 index 0000000..9f96f4d --- /dev/null +++ b/lua/ssr/replace.lua @@ -0,0 +1,59 @@ +local api = vim.api +local u = require "ssr.utils" + +local M = {} + +---@class ssr.PinnedMatch +---@field buf integer +---@field range integer +---@field captures integer[] +M.PinnedMatch = {} + +-- Convert `ssr.SearchResults` to extmark-based version. +---@param matches ssr.Matches +---@return ssr.PinnedMatch[] +function M.pin_matches(matches) + local res = {} + for _, row in ipairs(matches) do + local buf = row.file:load_buf() + for _, match in ipairs(row.matches) do + local pinned = { buf = buf, range = match.range:to_extmark(buf), captures = {} } + for var, range in pairs(match.captures) do + pinned.captures[var] = range:to_extmark(buf) + end + table.insert(res, pinned) + end + end + return res +end + +---@param buf integer +---@param id integer +---@return integer, number, number, number +local function get_extmark_range(buf, id) + local extmark = api.nvim_buf_get_extmark_by_id(buf, u.namespace, id, { details = true }) + return extmark[1], extmark[2], extmark[3].end_row, extmark[3].end_col +end + +--- Render template and replace one match. +---@param match ssr.PinnedMatch +---@param template string +function M.replace(match, template) + -- Render templates with captured nodes. + local replacement = template:gsub("()%$([_%a%d]+)", function(pos, var) + local start_row, start_col, end_row, end_col = get_extmark_range(match.buf, match.captures[var]) + local capture_lines = api.nvim_buf_get_text(match.buf, start_row, start_col, end_row, end_col, {}) + u.remove_indent(capture_lines, u.get_indent(match.buf, start_row)) + local var_lines = vim.split(template:sub(1, pos), "\n") + local var_line = var_lines[#var_lines] + local template_indent = var_line:match "^%s*" + u.add_indent(capture_lines, template_indent) + return table.concat(capture_lines, "\n") + end) + replacement = vim.split(replacement, "\n") + local start_row, start_col, end_row, end_col = get_extmark_range(match.buf, match.range) + u.add_indent(replacement, u.get_indent(match.buf, start_row)) + api.nvim_buf_set_text(match.buf, start_row, start_col, end_row, end_col, replacement) +end + +return M diff --git a/lua/ssr/search.lua b/lua/ssr/search.lua index ec6cdcc..6cd986a 100644 --- a/lua/ssr/search.lua +++ b/lua/ssr/search.lua @@ -1,57 +1,15 @@ -local api = vim.api local ts = vim.treesitter +local Range = require "ssr.range" local u = require "ssr.utils" -local M = {} - -M.wildcard_prefix = "__ssr_var_" - ----@class Match ----@field range ExtmarkRange ----@field captures ExtmarkRange[] - ----@class ExtmarkRange ----@field ns number ----@field buf buffer ----@field extmark number -local ExtmarkRange = {} -M.ExtmarkRange = ExtmarkRange - ----@param ns number ----@param buf buffer ----@param node TSNode ----@return ExtmarkRange -function ExtmarkRange.new(ns, buf, node) - local start_row, start_col, end_row, end_col = node:range() - return setmetatable({ - ns = ns, - buf = buf, - extmark = api.nvim_buf_set_extmark(buf, ns, start_row, start_col, { - end_row = end_row, - end_col = end_col, - right_gravity = false, - end_right_gravity = true, - }), - }, { __index = ExtmarkRange }) -end - ----@return number, number, number, number -function ExtmarkRange:get() - local extmark = api.nvim_buf_get_extmark_by_id(self.buf, self.ns, self.extmark, { details = true }) - return extmark[1], extmark[2], extmark[3].end_row, extmark[3].end_col -end - -- Compare if two captured trees can match. --- The check is loose because users want to match different types of node. +-- The check is loose because we want to match different types of node. -- e.g. converting `{ foo: foo }` to shorthand `{ foo }`. ts.query.add_predicate("ssr-tree-match?", function(match, _pattern, buf, pred) - ---@param node1 TSNode? - ---@param node2 TSNode? + ---@param node1 TSNode + ---@param node2 TSNode ---@return boolean local function tree_match(node1, node2) - if not node1 or not node2 then - return false - end if node1:named() ~= node2:named() then return false end @@ -62,33 +20,38 @@ ts.query.add_predicate("ssr-tree-match?", function(match, _pattern, buf, pred) return false end for i = 0, node1:child_count() - 1 do - if not tree_match(node1:child(i), node2:child(i)) then + if + not tree_match(node1:child(i) --[[@as TSNode]], node2:child(i) --[[@as TSNode]]) + then return false end end return true end return tree_match(match[pred[2]], match[pred[3]]) -end, true) +end, { force = true }) -- Build a TS sexpr represting the node. +-- This function is more strict than `TSNode:sexpr()` by also requiring leaf nodes to match text. ---@param node TSNode ---@param source string ----@return string, table +---@return string sexpr +---@return table wildcards +---@return string rough_regex local function build_sexpr(node, source) - ---@type table + ---@type table local wildcards = {} + local rough_regex = "" local next_idx = 1 - -- This function is more strict than `tsnode:sexpr()` by also requiring leaf nodes to match text. ---@param node TSNode ---@return string local function build(node) local text = ts.get_node_text(node, source) -- Special identifier __ssr_var_name is a named wildcard. - -- Handle this early to make sure wildcard captures largest node. - local var = text:match("^" .. M.wildcard_prefix .. "([_%a%d]+)$") + -- Handle this early to make sure wildcard captures the largest node. + local var = text:match("^" .. u.wildcard_prefix .. "([_%a%d]+)$") if var then if not wildcards[var] then wildcards[var] = next_idx @@ -104,6 +67,9 @@ local function build_sexpr(node, source) -- Leaf nodes (keyword, identifier, literal and symbol) should match text. if node:named_child_count() == 0 then + if #text > #rough_regex then -- TODO build an actual regex + rough_regex = text + end local sexpr = string.format("(%s) @_%d (#eq? @_%d %s)", node:type(), next_idx, next_idx, u.to_ts_query_str(text)) next_idx = next_idx + 1 return sexpr @@ -134,76 +100,74 @@ local function build_sexpr(node, source) end local sexpr = string.format("(%s) @all", build(node)) - return sexpr, wildcards + rough_regex = u.regex_escape(rough_regex) + return sexpr, wildcards, rough_regex end ----@param buf buffer ----@param node TSNode ----@param source string ----@return Match[] -function M.search(buf, node, source, ns) - local sexpr, wildcards = build_sexpr(node, source) - local parse_query = ts.query.parse or ts.parse_query - local lang = ts.language.get_lang(vim.bo[buf].filetype) - if not lang then - return {} +---@class ssr.Searcher +---@field lang string +---@field query vim.treesitter.Query +---@field wildcards table +---@field rough_regex string +local Searcher = {} + +---@param lang string +---@param pattern string +---@param parse_context ParseContext +---@return ssr.Searcher? +function Searcher.new(lang, pattern, parse_context) + local node, source = parse_context:parse(pattern) + if node:has_error() then + return end + local sexpr, wildcards, rough_regex = build_sexpr(node, source) + local parse_query = ts.query.parse or ts.parse_query local query = parse_query(lang, sexpr) + return setmetatable({ + lang = lang, + query = query, + wildcards = wildcards, + rough_regex = rough_regex, + }, { __index = Searcher }) +end + +-- A single match, including its captures. +---@class ssr.Match +---@field range ssr.Range +---@field captures table + +---@param file ssr.File +---@return ssr.Match[] +function Searcher:search(file) + ---@type ssr.Match[] local matches = {} - local has_parser, parser = pcall(ts.get_parser, buf, lang) - if not has_parser then - return {} - end - local root = parser:parse(true)[1]:root() - for _, nodes in query:iter_matches(root, buf, 0, -1) do - ---@type table - local captures = {} - for var, idx in pairs(wildcards) do - captures[var] = ExtmarkRange.new(ns, buf, nodes[idx]) + file.tree:for_each_tree(function(tree, lang_tree) -- must called :parse(true) + if lang_tree:lang() ~= self.lang then + return end - local match = { range = ExtmarkRange.new(ns, buf, nodes[#nodes]), captures = captures } - table.insert(matches, match) - end + for _, nodes in self.query:iter_matches(tree:root(), file.source, 0, -1) do + local range = Range.from_node(nodes[#nodes]) + local captures = {} + for var, idx in pairs(self.wildcards) do + captures[var] = Range.from_node(nodes[idx]) + end + table.insert(matches, { range = range, captures = captures }) + end + end) -- Sort matches from -- buffer top to bottom, to make goto next/prev match intuitive -- inner to outer for recursive matches, to make replacing correct - ---@param match1 { range: ExtmarkRange, captures: table} - ---@param match2 { range: ExtmarkRange, captures: table} + ---@param match1 ssr.Match + ---@param match2 ssr.Match ---@return boolean table.sort(matches, function(match1, match2) - local start_row1, start_col1, end_row1, end_col1 = match1.range:get() - local start_row2, start_col2, end_row2, end_col2 = match2.range:get() - if end_row1 < start_row2 or (end_row1 == start_row2 and end_col1 <= start_col2) then + if match1.range:before(match2.range) then return true end - return (start_row1 > start_row2 or (start_row1 == start_row2 and start_col1 > start_col2)) - and (end_row1 < end_row2 or (end_row1 == end_row2 and end_col1 <= end_col2)) + return match1.range:inside(match2.range) end) - return matches end ---- Render template and replace one match. ----@param buf buffer ----@param match Match ----@param template string -function M.replace(buf, match, template) - -- Render templates with captured nodes. - local replace = template:gsub("()%$([_%a%d]+)", function(pos, var) - local start_row, start_col, end_row, end_col = match.captures[var]:get() - local lines = api.nvim_buf_get_text(buf, start_row, start_col, end_row, end_col, {}) - u.remove_indent(lines, u.get_indent(buf, start_row)) - local var_lines = vim.split(template:sub(1, pos), "\n") - local var_line = var_lines[#var_lines] - local template_indent = var_line:match "^%s*" - u.add_indent(lines, template_indent) - return table.concat(lines, "\n") - end) - replace = vim.split(replace, "\n") - local start_row, start_col, end_row, end_col = match.range:get() - u.add_indent(replace, u.get_indent(buf, start_row)) - api.nvim_buf_set_text(buf, start_row, start_col, end_row, end_col, replace) -end - -return M +return Searcher diff --git a/lua/ssr/ui/confirm_win.lua b/lua/ssr/ui/confirm_win.lua new file mode 100644 index 0000000..e02fa19 --- /dev/null +++ b/lua/ssr/ui/confirm_win.lua @@ -0,0 +1,139 @@ +local api = vim.api + +---@class ConfirmWin +local ConfirmWin = {} + +function ConfirmWin.new() end + +function ConfirmWin:open() + local buf = api.nvim_win_get_buf(self.origin_win) + local matches = self.matches[buf] + if #matches == 0 then + return self:set_status "pattern not found" + end + + local confirm_buf = api.nvim_create_buf(false, true) + vim.bo[confirm_buf].filetype = "ssr_confirm" + local choices = { + "• Yes", + "• No", + "──────────────", + "• All", + "• Quit", + "• Last replace", + } + local separator_idx = 3 + api.nvim_buf_set_lines(confirm_buf, 0, -1, true, choices) + for idx = 0, #choices - 1 do + if idx + 1 ~= separator_idx then + api.nvim_buf_set_extmark( + confirm_buf, + u.namespace, + idx, + 4, + { hl_group = "Underlined", end_row = idx, end_col = 5 } + ) + end + end + + local function open_confirm_win(match_idx) + self:goto_match(match_idx) + local _, _, end_row, end_col = matches[match_idx].range:get() + local cfg = { + relative = "win", + win = self.origin_win, + bufpos = { end_row, end_col }, + style = "minimal", + border = config.options.border, + width = 14, + height = 6, + } + if vim.fn.has "nvim-0.9" == 1 then + cfg.title = "Replace?" + cfg.title_pos = "center" + end + return api.nvim_open_win(confirm_buf, true, cfg) + end + + local match_idx = 1 + local replaced = 0 + local cursor = 1 + local _, template = self:get_input() + self:set_status(string.format("replacing 0/%d", #matches)) + + while match_idx <= #matches do + local confirm_win = open_confirm_win(match_idx) + + ---@type string + local key + while true do + -- Draw a fake cursor because cursor is not shown correctly when blocking on `getchar()`. + api.nvim_buf_clear_namespace(confirm_buf, u.cur_search_ns, 0, -1) + api.nvim_buf_set_extmark( + confirm_buf, + u.cur_search_ns, + cursor - 1, + 0, + { virt_text = { { "•", "Cursor" } }, virt_text_pos = "overlay" } + ) + api.nvim_buf_set_extmark(confirm_buf, u.cur_search_ns, cursor - 1, 0, { line_hl_group = "CursorLine" }) + vim.cmd.redraw() + + local ok, char = pcall(vim.fn.getcharstr) + key = ok and vim.fn.keytrans(char) or "" + if key == "j" then + if cursor == separator_idx - 1 then -- skip separator + cursor = separator_idx + 1 + elseif cursor == #choices then -- wrap + cursor = 1 + else + cursor = cursor + 1 + end + elseif key == "k" then + if cursor == separator_idx + 1 then -- skip separator + cursor = separator_idx - 1 + elseif cursor == 1 then -- wrap + cursor = #choices + else + cursor = cursor - 1 + end + elseif vim.tbl_contains({ "", "", "", "", "", "" }, key) then + vim.fn.win_execute(self.origin_win, string.format('execute "normal! \\%s"', key)) + else + break + end + end + + if key == "" then + key = ({ "y", "n", "", "a", "q", "l" })[cursor] + end + + if key == "y" then + replace(buf, matches[match_idx], template) + replaced = replaced + 1 + match_idx = match_idx + 1 + elseif key == "n" then + match_idx = match_idx + 1 + elseif key == "a" then + for i = match_idx, #matches do + replace(buf, matches[i], template) + end + replaced = replaced + #matches + 1 - match_idx + match_idx = #matches + 1 + elseif key == "l" then + replace(buf, matches[match_idx], template) + replaced = replaced + 1 + match_idx = #matches + 1 + elseif key == "q" or key == "" or key == "" then + match_idx = #matches + 1 + end + api.nvim_win_close(confirm_win, false) + self:set_status(string.format("replacing %d/%d", replaced, #matches)) + end + + api.nvim_buf_delete(confirm_buf, {}) + api.nvim_buf_clear_namespace(buf, u.cur_search_ns, 0, -1) + self:set_status(string.format("%d/%d replaced", replaced, #matches)) +end + +return ConfirmWin diff --git a/lua/ssr/ui/init.lua b/lua/ssr/ui/init.lua new file mode 100644 index 0000000..a5ad74b --- /dev/null +++ b/lua/ssr/ui/init.lua @@ -0,0 +1,123 @@ +local api = vim.api +local ts = vim.treesitter +local config = require "ssr.config" +local ParseContext = require "ssr.parse_context" +local Searcher = require "ssr.search" +local replace = require("ssr.replace").replace +local pin_matches = require("ssr.replace").pin_matches +local File = require "ssr.file" +local MainWin = require "ssr.ui.main_win" +local u = require "ssr.utils" + +---@alias ssr.Matches { file: ssr.File, matches: ssr.Match[] }[] + +---@class Ui +---@field lang string +---@field parse_context ParseContext +---@field matches ssr.Matches +---@field last_pattern string +---@field main_win MainWin +local Ui = {} + +---@return Ui? +function Ui.new() + local self = setmetatable({ matches = {} }, { __index = Ui }) + + -- Pre-checks + local origin_win = api.nvim_get_current_win() + local origin_buf = api.nvim_win_get_buf(origin_win) + local lang = ts.language.get_lang(vim.bo[origin_buf].filetype) + if not lang then + return u.notify(string.format("Treesitter language not found for filetype '%s'", vim.bo[origin_buf].filetype)) + end + self.lang = lang + local origin_node = u.node_for_range(origin_buf, self.lang, u.get_selection(origin_win)) + if not origin_node then + return u.notify("Treesitter parser not found, please try to install it with :TSInstall " .. self.lang) + end + if origin_node:has_error() then + return u.notify "You have syntax errors in the selected node" + end + local parse_context = ParseContext.new(origin_buf, self.lang, origin_node) + if not parse_context then + return u.notify "Can't find a proper context to parse the pattern" + end + self.parse_context = parse_context + + local placeholder = vim.split(ts.get_node_text(origin_node, origin_buf), "\n", { plain = true }) + u.remove_indent(placeholder, u.get_indent(origin_buf, origin_node:start())) + + self.main_win = MainWin.new(lang, placeholder, { "" }, origin_win) + + self.main_win:on({ "InsertLeave", "TextChanged" }, function() + self:search() + end) + + self.main_win:on_key(config.opts.keymaps.replace_all, function() + self:replace_all() + end) + + self:search() + return self +end + +function Ui:search() + local pattern = self.main_win:get_input() + if pattern == self.last_pattern then + return + end + self.last_pattern = pattern + + self.matches = {} + local found = 0 + local matched_files = 0 + local start = vim.loop.hrtime() + local searcher = Searcher.new(self.lang, pattern, self.parse_context) + if not searcher then + return self:set_status "Error" + end + + File.grep(vim.loop.cwd(), searcher.rough_regex, function(file) + local matches = searcher:search(file) + if #matches == 0 then + return + end + found = found + #matches + matched_files = matched_files + 1 + table.insert(self.matches, { file = file, matches = matches }) + end, function() + local elapsed = (vim.loop.hrtime() - start) / 1E6 + self.main_win.result_list:set(self.matches) + self:set_status(string.format("%d found in %d files (%dms)", found, matched_files, elapsed)) + end) +end + +function Ui:replace_all() + if #self.matches == 0 then + return self:set_status "pattern not found" + end + local _, template = self.main_win:get_input() + local start = vim.loop.hrtime() + local pinned = pin_matches(self.matches) + for _, match in ipairs(pinned) do + replace(match, template) + end + local elapsed = (vim.loop.hrtime() - start) / 1E6 + self:set_status(string.format("%d replaced in %d files (%dms)", #self.matches, 0, elapsed)) +end + +---@param status string +---@return nil +function Ui:set_status(status) + api.nvim_buf_set_extmark(self.main_win.buf, u.namespace, 0, 0, { + id = self.main_win.extmarks.status, + virt_text = { + { "[SSR] ", "Comment" }, + { status }, + { " (Press ? for help)", "Comment" }, + }, + virt_text_pos = "overlay", + }) +end + +return Ui diff --git a/lua/ssr/ui/main_win.lua b/lua/ssr/ui/main_win.lua new file mode 100644 index 0000000..b81a830 --- /dev/null +++ b/lua/ssr/ui/main_win.lua @@ -0,0 +1,241 @@ +local api = vim.api +local ts = vim.treesitter +local config = require "ssr.config" +local ResultList = require "ssr.ui.result_list" +local u = require "ssr.utils" + +---@class MainWin +---@field buf integer +---@field win integer +---@field origin_win integer +---@field lang string +---@field last_pattern string[] +---@field last_template string[] +---@field result_list ResultList +local MainWin = {} + +function MainWin.new(lang, pattern, template, origin_win) + local self = setmetatable({ + lang = lang, + last_pattern = pattern, + last_template = template, + origin_win = origin_win, + }, { __index = MainWin }) + + self.buf = api.nvim_create_buf(false, true) + vim.bo[self.buf].filetype = "ssr" + + local lines = self:render() + self:open_win(u.get_win_size(lines)) + + self.result_list = ResultList.new(self.buf, self.win, self.extmarks.results) + + self:setup_autocmds() + self:setup_keymaps() + + return self +end + +---@private +function MainWin:render() + ts.stop(self.buf) + api.nvim_buf_clear_namespace(self.buf, u.namespace, 0, -1) + + local lines = { + "", -- [SSR] + "```" .. self.lang, -- SEARCH: + } + vim.list_extend(lines, self.last_pattern) + table.insert(lines, "") -- REPLACE: + vim.list_extend(lines, self.last_template) + table.insert(lines, "```") -- RESULTS: + api.nvim_buf_set_lines(self.buf, 0, -1, true, lines) + + -- Enable syntax highlights for input area. + local parser = ts.get_parser(self.buf, "markdown") + parser:parse(true) + parser:for_each_tree(function(tree, lang_tree) + if tree:root():start() == 2 then + ts.highlighter.new(lang_tree) + end + end) + + local function virt_text(row, text) + return api.nvim_buf_set_extmark(self.buf, u.namespace, row, 0, { virt_text = text, virt_text_pos = "overlay" }) + end + self.extmarks = { + status = virt_text(0, { { "[SSR]", "Comment" }, { " (Press ? for help)", "Comment" } }), + search = virt_text(1, { { "SEARCH: ", "String" } }), -- Extra spaces to cover too long language name. + replace = virt_text(#lines - 3, { { "REPLACE:", "String" } }), + results = virt_text(#lines - 1, { { "RESULTS:", "String" } }), + } + + return lines +end + +---@private +function MainWin:check(lines) + if #lines < 6 then + return false + end + + local function get_index(extmark) + return api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, extmark, {})[1] + 1 + end + + return get_index(self.extmarks.status) == 1 + and lines[1] == "" + and get_index(self.extmarks.search) == 2 + and lines[2] == "```" .. self.lang + and lines[get_index(self.extmarks.replace)] == "" + and lines[get_index(self.extmarks.results)] == "```" +end + +---@private +function MainWin:open_win(width, height) + self.win = api.nvim_open_win(self.buf, true, { + relative = "editor", + anchor = "NE", + row = 0, + col = vim.o.columns - 1, + style = "minimal", + border = config.opts.border, + width = width, + height = height, + }) + vim.wo[self.win].wrap = false + if vim.fn.has "nvim-0.10" == 1 then + vim.wo[self.win].winfixbuf = true + end + u.set_cursor(self.win, 2, 0) + vim.fn.matchadd("Title", [[$\w\+]]) +end + +function MainWin:on(event, func) + api.nvim_create_autocmd(event, { + group = u.augroup, + buffer = self.buf, + callback = func, + }) +end + +---@private +function MainWin:setup_autocmds() + self:on({ "TextChanged", "TextChangedI" }, function() + local lines = api.nvim_buf_get_lines(self.buf, 0, -1, true) + if not self:check(lines) then + self:render() + self.result_list.extmark = self.extmarks.results + self.result_list:set {} + u.set_cursor(self.win, 2, 0) + end + if not config.opts.adjust_window then + return + end + local width, height = u.get_win_size(lines) + if api.nvim_win_get_width(self.win) ~= width then + api.nvim_win_set_width(self.win, width) + end + if api.nvim_win_get_height(self.win) ~= height then + api.nvim_win_set_height(self.win, height) + end + end) + + self:on("BufWinEnter", function(event) + if event.buf == self.buf then + return + end + local win = api.nvim_get_current_win() + if win ~= self.win then + return + end + -- Prevent accidentally opening another file in the ssr window. + -- Adapted from neo-tree.nvim. + vim.schedule(function() + api.nvim_win_set_buf(self.win, self.buf) + local name = api.nvim_buf_get_name(event.buf) + api.nvim_win_call(self.origin_win, function() + pcall(api.nvim_buf_delete, event.buf, {}) + if name ~= "" then + vim.cmd.edit(name) + end + end) + api.nvim_set_current_win(self.origin_win) + end) + end) + + self:on("WinClosed", function() + api.nvim_clear_autocmds { group = u.augroup } + api.nvim_buf_delete(self.buf, {}) + end) +end + +function MainWin:on_key(key, func) + vim.keymap.set("n", key, func, { buffer = self.buf, nowait = true }) +end + +---@private +function MainWin:setup_keymaps() + self:on_key(config.opts.keymaps.close, function() + api.nvim_win_close(self.win, false) + end) + + self:on_key("gg", function() + u.set_cursor(self.win, 2, 0) + end) + + self:on_key("j", function() + local cursor = u.get_cursor(self.win) + for _, extmark in ipairs { self.extmarks.replace, self.extmarks.results } do + local skip_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, extmark, {})[1] + if cursor == skip_pos - 1 then + return pcall(u.set_cursor, self.win, skip_pos + 1, 0) + end + end + vim.fn.feedkeys("j", "n") + end) + + self:on_key("k", function() + local cursor = u.get_cursor(self.win) + if cursor <= 2 then + return u.set_cursor(self.win, 2, 0) + end + for _, extmark in ipairs { self.extmarks.replace, self.extmarks.results } do + local skip_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, extmark, {})[1] + if cursor == skip_pos + 1 then + return pcall(u.set_cursor, self.win, skip_pos - 1, 0) + end + end + vim.fn.feedkeys("k", "n") + end) + + self:on_key("l", function() + local cursor = u.get_cursor(self.win) + if cursor < self.result_list:get_start() then + return vim.fn.feedkeys("l", "n") + end + self.result_list:set_folded(false) + end) + + self:on_key("h", function() + local cursor = u.get_cursor(self.win) + if cursor < self.result_list:get_start() then + return vim.fn.feedkeys("h", "n") + end + self.result_list:set_folded(true) + end) +end + +function MainWin:get_input() + local pattern_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, self.extmarks.search, {})[1] + local template_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, self.extmarks.replace, {})[1] + local results_pos = api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, self.extmarks.results, {})[1] + local lines = api.nvim_buf_get_lines(self.buf, 0, results_pos, true) + local pattern = vim.list_slice(lines, pattern_pos + 2, template_pos) + local template = vim.list_slice(lines, template_pos + 2) + self.last_pattern = pattern + self.last_template = template + return vim.trim(table.concat(pattern, "\n")), vim.trim(table.concat(template, "\n")) +end + +return MainWin diff --git a/lua/ssr/ui/result_list.lua b/lua/ssr/ui/result_list.lua new file mode 100644 index 0000000..815e6b5 --- /dev/null +++ b/lua/ssr/ui/result_list.lua @@ -0,0 +1,205 @@ +local api = vim.api +local config = require "ssr.config" +local u = require "ssr.utils" + +-- List item per line +---@class Item +---@field fold_idx integer which fold this line belongs to, 1-based +---@field match_idx integer which match this line belongs to, 0-based, 0 for filename + +-- A foldable region that may span multiple lines +---@class Fold +---@field folded boolean +---@field filename string +---@field path string +---@field preview_lines string[] +local Fold = {} + +---@param folded boolean +---@param file ssr.File +---@param matches ssr.Match[] +---@return Fold +function Fold.new(folded, file, matches) + local preview_lines = {} + for _, match in ipairs(matches) do + local line = file:get_line(match.range.start_row) + line = line:gsub("^%s*", "") + table.insert(preview_lines, "│ " .. line) + end + return setmetatable({ + folded = folded, + filename = vim.fn.fnamemodify(file.path, ":t"), + path = vim.fn.fnamemodify(file.path, ":~:.:h"), + preview_lines = preview_lines, + }, { __index = Fold }) +end + +function Fold:len() + if self.folded then + return 1 + end + return 1 + #self.preview_lines +end + +---@private +function Fold:get_lines() + if self.folded then + return { string.format(" %s %s %d", self.filename, self.path, #self.preview_lines) } + end + local lines = { string.format(" %s %s %d", self.filename, self.path, #self.preview_lines) } + vim.list_extend(lines, self.preview_lines) + return lines +end + +function Fold:highlight(buf, row) + local col = 4 -- "" is 3 bytes, plus 1 space + api.nvim_buf_add_highlight(buf, u.namespace, "Directory", row, col, col + #self.filename) + col = col + #self.filename + 1 + api.nvim_buf_add_highlight(buf, u.namespace, "Comment", row, col, col + #self.path) + col = col + #self.path + 1 + api.nvim_buf_add_highlight(buf, u.namespace, "Number", row, col, col + #(tostring(self.preview_lines))) +end + +---@class ResultList +---@field buf integer +---@field win integer +---@field extmark integer +---@field folds Fold[] +---@field items Item[] +local ResultList = {} + +function ResultList.new(buf, win, extmark) + local self = setmetatable({ + buf = buf, + win = win, + extmark = extmark, + folds = {}, + items = {}, + }, { __index = ResultList }) + + vim.keymap.set("n", config.opts.keymaps.next_match, function() + self:next_match() + end, { buffer = self.buf, nowait = true }) + vim.keymap.set("n", config.opts.keymaps.prev_match, function() + self:prev_match() + end, { buffer = self.buf, nowait = true }) + + return self +end + +---@private +function ResultList:get_start() + return api.nvim_buf_get_extmark_by_id(self.buf, u.namespace, self.extmark, {})[1] + 1 +end + +---@params matches ssr.Matches +function ResultList:set(matches) + self.folds = {} + self.items = {} + local start = self:get_start() + api.nvim_buf_clear_namespace(self.buf, u.namespace, start, -1) + + local lines = {} + for fold_idx, row in ipairs(matches) do + local fold = Fold.new(fold_idx ~= 1, row.file, row.matches) + table.insert(self.folds, fold) + for match_idx, line in ipairs(fold:get_lines()) do + table.insert(lines, line) + table.insert(self.items, { fold_idx = fold_idx, match_idx = match_idx - 1 }) + end + end + api.nvim_buf_set_lines(self.buf, start, -1, true, lines) + + for _, fold in ipairs(self.folds) do + fold:highlight(self.buf, start) + start = start + fold:len() + end +end + +---@param folded boolean +---@param cursor integer? +function ResultList:set_folded(folded, cursor) + local result_start = self:get_start() + cursor = cursor or u.get_cursor(self.win) - result_start + local item = self.items[cursor + 1] -- +1 beacause `cursor` is 0-based + local fold = self.folds[item.fold_idx] + if fold.folded == folded then + return + end + + local start = cursor - item.match_idx -- like C macro `container_of` + local end_ = start + fold:len() + fold.folded = folded + local lines = fold:get_lines() + local items = {} + for i = 0, #lines - 1 do + table.insert(items, { fold_idx = item.fold_idx, match_idx = i }) + end + u.list_replace(self.items, start, end_, items) + start = result_start + start + end_ = result_start + end_ + api.nvim_buf_set_lines(self.buf, start, end_, true, lines) + fold:highlight(self.buf, start) + if folded then + u.set_cursor(self.win, start, 0) + end +end + +function ResultList:next_match() + local cursor = u.get_cursor(self.win) + local result_start = self:get_start() + cursor = cursor > result_start and cursor - result_start or 0 + local item = self.items[cursor + 1] -- +1: lua index + if not item then + return + end + if item.match_idx == 0 then + self:set_folded(false, cursor) + end + cursor = cursor + 1 + item = self.items[cursor + 1] + if not item then + return + end + if item.match_idx == 0 then + self:set_folded(false, cursor) + cursor = cursor + 1 + end + return u.set_cursor(self.win, cursor + result_start, 0) +end + +function ResultList:prev_match() + local cursor = u.get_cursor(self.win) + local result_start = self:get_start() + if cursor <= result_start then + if #self.items == 0 then + return + end + self:set_folded(false, #self.items - 1) + return u.set_cursor(self.win, result_start + #self.items - 1, 0) + end + + cursor = cursor - result_start + local item = self.items[cursor + 1] + if not item then + return + end + if item.match_idx <= 1 then + cursor = cursor - item.match_idx - 1 + item = self.items[cursor + 1] + if not item then + return + end + local fold = self.folds[item.fold_idx] + if fold.folded then + self:set_folded(false, cursor) + cursor = cursor + #fold.preview_lines + end + return u.set_cursor(self.win, result_start + cursor, 0) + end + + cursor = cursor - 1 + return u.set_cursor(self.win, cursor + result_start, 0) +end + +return ResultList diff --git a/lua/ssr/utils.lua b/lua/ssr/utils.lua index 099ed92..fa0efb1 100644 --- a/lua/ssr/utils.lua +++ b/lua/ssr/utils.lua @@ -1,8 +1,14 @@ local api = vim.api local ts = vim.treesitter +local config = require "ssr.config" local M = {} +M.wildcard_prefix = "__ssr_var_" +M.namespace = api.nvim_create_namespace "ssr_ns" +M.cur_search_ns = api.nvim_create_namespace "ssr_cur_search_ns" +M.augroup = api.nvim_create_augroup("ssr_augroup", {}) + -- Send a notification titled SSR. ---@param msg string ---@return nil @@ -11,14 +17,14 @@ function M.notify(msg) end -- Get (0,0)-indexed cursor position. ----@param win window +---@param win integer function M.get_cursor(win) local cursor = api.nvim_win_get_cursor(win) return cursor[1] - 1, cursor[2] end -- Set (0,0)-indexed cursor position. ----@param win window +---@param win integer ---@param row integer ---@param col integer function M.set_cursor(win, row, col) @@ -26,8 +32,8 @@ function M.set_cursor(win, row, col) end -- Get selected region, works in many modes. ----@param win window ----@return number, number, number, number +---@param win integer +---@return integer, number, number, number function M.get_selection(win) local mode = api.nvim_get_mode().mode local cursor_row, cursor_col = M.get_cursor(win) @@ -54,12 +60,12 @@ function M.get_selection(win) end -- Get smallest node for the range. ----@param buf buffer +---@param buf integer ---@param lang string ----@param start_row number ----@param start_col number ----@param end_row number ----@param end_col number +---@param start_row integer +---@param start_col integer +---@param end_row integer +---@param end_col integer ---@return TSNode? function M.node_for_range(buf, lang, start_row, start_col, end_row, end_col) local has_parser, parser = pcall(ts.get_parser, buf, lang) @@ -68,8 +74,8 @@ function M.node_for_range(buf, lang, start_row, start_col, end_row, end_col) end end ----@param buf buffer ----@param row number +---@param buf integer +---@param row integer function M.get_indent(buf, row) local line = api.nvim_buf_get_lines(buf, row, row + 1, true)[1] return line:match "^%s*" @@ -103,14 +109,13 @@ end -- Compute window size to show giving lines. ---@param lines string[] ----@param config Config ----@return number ----@return number -function M.get_win_size(lines, config) - ---@param i number - ---@param min number - ---@param max number - ---@return number +---@return integer +---@return integer +function M.get_win_size(lines) + ---@param i integer + ---@param min integer + ---@param max integer + ---@return integer local function clamp(i, min, max) return math.min(math.max(i, min), max) end @@ -122,9 +127,74 @@ function M.get_win_size(lines, config) end end - width = clamp(width, config.min_width, config.max_width) - local height = clamp(#lines, config.min_height, config.max_height) + width = clamp(width, config.opts.min_width, config.opts.max_width) + local height = clamp(#lines, config.opts.min_height, config.opts.max_height) return width, height end +-- https://github.com/rust-lang/regex/blob/17284451f10aa06c6c42e622e3529b98513901a8/regex-syntax/src/lib.rs#L272 +local regex_meta_chars = { + ["\\"] = true, + ["."] = true, + ["+"] = true, + ["*"] = true, + ["?"] = true, + ["("] = true, + [")"] = true, + ["|"] = true, + ["["] = true, + ["]"] = true, + ["{"] = true, + ["}"] = true, + ["^"] = true, + ["$"] = true, + ["#"] = true, + ["&"] = true, + ["-"] = true, + ["~"] = true, +} + +---@param s string +---@return string +function M.regex_escape(s) + local escaped = s:gsub(".", function(ch) + return regex_meta_chars[ch] and "\\" .. ch or ch + end) + return escaped +end + +---@generic T +---@param list T[] +---@param f fun(T): -1 | 0 | 1 +---@return integer? +function M.binary_search_by(list, f) + local left = 1 + local right = #list + 1 + while left < right do + local mid = math.floor((left + right) / 2) + local cmp = f(list[mid]) + if cmp < 0 then + left = mid + 1 + elseif cmp > 0 then + right = mid + else + return mid + end + end +end + +---@generic T +---@param list table +---@param start integer 0-based +---@param end_ integer exclusive +---@param replacement table +function M.list_replace(list, start, end_, replacement) + for _ = start + 1, end_ do + table.remove(list, start + 1) + end + for i = start + 1, start + #replacement do + table.insert(list, i, replacement[i - start]) + end +end + return M diff --git a/tests/ssr_spec.lua b/tests/ssr_spec.lua index f6fcd17..bbe5683 100644 --- a/tests/ssr_spec.lua +++ b/tests/ssr_spec.lua @@ -1,99 +1,101 @@ -local u = require "ssr.utils" -local ParseContext = require("ssr.parse").ParseContext local ts = vim.treesitter -local search = require("ssr.search").search -local replace = require("ssr.search").replace +local uv = vim.uv or vim.loop +local ParseContext = require "ssr.parse_context" +local Searcher = require "ssr.search" +local pin_matches = require("ssr.replace").pin_matches +local replace = require("ssr.replace").replace +local File = require "ssr.file" +---@type string[] local tests = {} +---@param s string local function t(s) table.insert(tests, s) end -t [[ python operators - -a - b -==== +t [[ operators a + b ==>> (+ a b) +==== t.py +a + b +a - b ==== (+ a b) a - b ]] -t [[ python complex string -<""" +t [[ complex string +""" line 1 \r\n\a\?\\ 'a'"'"'b' -"""> -==== +""" ==>> x +==== t.py """ line 1 \r\n\a\?\\ 'a'"'"'b' """ -==>> -x ==== x ]] -t [[ javascript keywords - -const a = 1 -==== +t [[ keywords let a = 1 ==>> x +==== t.js +let a = 1 +const a = 1 ==== x const a = 1 ]] -t [[ lua func args - -f(1, 3) -==== +t [[ func args f(1, 3) ==>> x +==== t.lua +f(1, 2, 3) +f(1, 3) ==== f(1, 2, 3) x ]] -t [[ lua recursive 1 - -==== +t [[ recursive 1 f($a) ==>> $a.f() +==== recursive.lua +f(f(f(0))) ==== 0.f().f().f() ]] -t [[ rust recursive 2 -f(f(, 2), 3) -==== +t [[ recursive 2 f($a, $b) ==>> $a.f($b) +==== t.rs +f(f(f(0, 1), 2), 3) ==== 0.f(1).f(2).f(3) ]] -t [[ rust recursive 3 -f(3, f(2, )) -==== +t [[ recursive 3 f($a, $b) ==>> $a.f($b) +==== t.rs +f(3, f(2, f(1, 0))) ==== 3.f(2.f(1.f(0))) ]] -t [[ python indent 1 -def f(): - -==== +t [[ indent 1 if $a: $b ==>> if $a: if True: $b +==== t.py +def f(): + if foo: + if bar: + pass ==== def f(): if foo: @@ -103,18 +105,18 @@ def f(): pass ]] -t [[ python indent 2 -def f(): - if len(a) != 0: - do_a(a) - -==== +t [[ indent 2 if len($a) != 0: $b ==>> if $a: $b +==== t.py +def f(): + if len(a) != 0: + do_a(a) + if len(b) != 0: + do_b(b) ==== def f(): if a: @@ -123,59 +125,45 @@ def f(): do_b(b) ]] -t [[ rust question mark -let foo = ; -==== +t [[ question mark $a? ==>> try!($a) +==== t.rs +let foo = bar().await?; ==== let foo = try!(bar().await); ]] -t [[ rust rust-analyzer ssr example -String::from() -==== +t [[ rust-analyzer ssr example foo($a, $b) ==>> ($a).foo($b) +==== t.rs +String::from(foo(y + 5, z)) ==== String::from((y + 5).foo(z)) ]] -t [[ go parse Go := in function -func main() { - -} -==== -$a, _ := os.LookupEnv($b) -==>> -$a := os.Getenv($b) -==== -func main() { - commit := os.Getenv("GITHUB_SHA") -} -]] - -t [[ go match Go if err +t [[ match Go if err +if err != nil { panic(err) } ==>> x +==== t.go fn main() { - + } } ==== -if err != nil { panic(err) } ==>> x -==== fn main() { x } ]] -t [[ rust reused wildcard: compound assignments -; +t [[ reused wildcard: compound assignments +$a = $a + $b; ==>> $a += $b; +==== t.rs +idx = idx + 1; bar = foo + idx; *foo.bar() = * foo . bar () + 1; (foo + bar) = (foo + bar) + 1; (foo + bar) = (foo - bar) + 1; ==== -$a = $a + $b ==>> $a += $b -==== idx += 1; bar = foo + idx; *foo.bar() += 1; @@ -183,18 +171,18 @@ bar = foo + idx; (foo + bar) = (foo - bar) + 1; ]] -t [[ python reused wildcard: indent -def f(): - -==== +t [[ reused wildcard: indent if $foo: if $foo: $body ==>> if $foo: $body +==== t.py +def f(): + if await foo.bar(baz): + if await foo.bar(baz): + pass ==== def f(): if await foo.bar(baz): @@ -202,74 +190,102 @@ def f(): ]] -- two `foo`s have different type: `property_identifier` and `identifier` -t [[ javascript reused wildcard: match different node types 1 -<{ foo: foo }> -{ foo: bar } -==== +t [[ reused wildcard: match different node types 1 { $a: $a } ==>> { $a } +==== t.js +{ foo: foo } +{ foo: bar } ==== { foo } { foo: bar } ]] -t [[ lua reused wildcard: match different node types 2 - -local a = vim.api -==== +t [[ reused wildcard: match different node types 2 local $a = vim.$a ==>> x +==== t.lua +local api = vim.api +local a = vim.api ==== x local a = vim.api ]] +t [[ multiple files +local $a = vim.$a ==>> _G.g_$a = vim.$a + +==== t.lua +local api = vim.api +local fn = vim.fn +==== +_G.g_api = vim.api +_G.g_fn = vim.fn + +==== README.md +# Example +```lua +local F = vim.F +local uv = vim.uv +``` +==== +# Example +```lua +_G.g_F = vim.F +_G.g_uv = vim.uv +``` +]] + describe("", function() -- Plenary runs nvim with `--noplugin` argument. - -- Make sure nvim-treesitter is loaded, which populates vim.treesitter's ft_to_lang table. + -- Load nvim-treesitter to make `ts.language.get_lang()` work. require "nvim-treesitter" for _, s in ipairs(tests) do - local ft, desc, content, pattern, template, expected = - s:match "^ (%a-) (.-)\n(.-)%s?====%s?(.-)%s?==>>%s?(.-)%s?====%s?(.-)%s?$" - content = vim.split(content, "\n") - expected = vim.split(expected, "\n") - local start_row, start_col, end_row, end_col - for idx, line in ipairs(content) do - local col = line:find "<" - if col then - start_row = idx - 1 - start_col = col - 1 - end - line = line:gsub("<", "") - col = line:find ">" - if col then - end_row = idx - 1 - end_col = col - 1 - end - line = line:gsub(">", "") - content[idx] = line - end - + local desc, pattern, template, rest = s:match "^ (.-)\n(.-)%s?==>>%s?(.-)\n%s*==(.-)$" it(desc, function() - local ns = vim.api.nvim_create_namespace "" - local buf = vim.api.nvim_create_buf(false, true) - vim.bo[buf].filetype = ft - vim.api.nvim_buf_set_lines(buf, 0, -1, true, content) - local lang = ts.language.get_lang(vim.bo[buf].filetype) - assert(lang, "language not found") - local origin_node = u.node_for_range(buf, lang, start_row, start_col, end_row, end_col) + local dir = vim.fn.tempname() + assert(uv.fs_mkdir(dir, 448)) + + local expected_files = {} + local lang + for fname, before, after in (rest .. "=="):gmatch "== (.-)\n(.-)====\n(.-)%s*==" do + after = after .. "\n" -- Vim always adds a \n to files. + fname = vim.fs.joinpath(dir, fname) + local fd = assert(uv.fs_open(fname, "w", 438)) + assert(uv.fs_write(fd, before) > 0) + assert(uv.fs_close(fd)) + expected_files[fname] = after + lang = lang or assert(ts.language.get_lang(vim.filetype.match { filename = fname })) + end - local parse_context = ParseContext.new(buf, origin_node) - assert(parse_context) - local node, source = parse_context:parse(pattern) - local matches = search(buf, node, source, ns) + local empty_context = ParseContext.empty(lang) + local searcher = assert(Searcher.new(lang, pattern, empty_context)) + ---@type ssr.SearchResults + local results = {} + local done = false + File.grep(dir, searcher.rough_regex, function(file) + local matches = searcher:search(file) + assert(#matches > 0) + table.insert(results, { file = file, matches = matches }) + end, function() + done = true + end) + vim.wait(1000, function() + return done + end) - for _, match in ipairs(matches) do - replace(buf, match, template) + local pinned_matches = pin_matches(results) + for _, match in ipairs(pinned_matches) do + replace(match, template) end - local actual = vim.api.nvim_buf_get_lines(buf, 0, -1, true) - vim.api.nvim_buf_delete(buf, {}) - assert.are.same(expected, actual) + vim.cmd "silent wa" + for fname, expected in pairs(expected_files) do + local fd = assert(uv.fs_open(fname, "r", 438)) + local stat = assert(uv.fs_fstat(fd)) + local actual = uv.fs_read(fd, stat.size, 0) + uv.fs_close(fd) + assert.are.same(expected, actual) + end end) end end)