Skip to content

Commit

Permalink
Search and replace across files with ripgrep
Browse files Browse the repository at this point in the history
  • Loading branch information
cshuaimin committed Oct 15, 2023
1 parent 10d51dd commit 22e141c
Show file tree
Hide file tree
Showing 8 changed files with 392 additions and 278 deletions.
205 changes: 89 additions & 116 deletions lua/ssr.lua

Large diffs are not rendered by default.

83 changes: 83 additions & 0 deletions lua/ssr/file.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
local ts = vim.treesitter
local uv = vim.uv or vim.loop

---@class File
---@field name string
---@field content string
---@field mtime { nsec: integer, sec: integer }
---@field lang_tree LanguageTree
local File = {}

---@type table<string, File>
local cache = {}

---@param name string
---@return File?
function File.new(name)
local fd = uv.fs_open(name, "r", 438)
if not fd then
return
end
local stat = uv.fs_fstat(fd) --[[@as uv.aliases.fs_stat_table]]
local self = cache[name]
if self and stat.mtime.sec == self.mtime.sec and stat.mtime.nsec == self.mtime.nsec then
uv.fs_close(fd)
return self
end

self = setmetatable({ name = name }, { __index = File })
self.mtime = stat.mtime
self.content = uv.fs_read(fd, stat.size, 0) --[[@as string]]
uv.fs_close(fd)
if not name:match "%.ts$" then
return
end
-- local ft = vim.filetype.match { filename = name }
-- if not ft then
-- return
-- end
local lang = ts.language.get_lang "typescript"
if not lang then
return
end
local has_parser, lang_tree = pcall(ts.get_string_parser, self.content, lang)
if not has_parser then
return
end
self.lang_tree = lang_tree
self.lang_tree:parse()

cache[name] = self
return self
end

---@param regex string
---@param on_file fun(file: File)
---@param on_end fun()
---@return nil
function File.grep(regex, on_file, on_end)
vim.system({ "rg", "--line-buffered", "--files-with-matches", regex }, {
text = true,
stdout = vim.schedule_wrap(function(err, files)
if err then
error(files)
end
if not files then
on_end()
return
end
for _, name in ipairs(vim.split(files, "\n", { plain = true })) do
local file = File.new(name)
if file then
on_file(file)
end
end
end),
}, function(obj)
if obj.code ~= 0 then
error(obj.stderr)
end
end)
end

return File
35 changes: 14 additions & 21 deletions lua/ssr/parse.lua → lua/ssr/parse_context.lua
Original file line number Diff line number Diff line change
@@ -1,27 +1,20 @@
local ts = vim.treesitter
local wildcard_prefix = require("ssr.search").wildcard_prefix

local M = {}
local u = require "ssr.utils"

-- The context in which user input will be parsed correctly.
---@class ParseContext
---@field lang string
---@field before string
---@field after string
---@field pad_rows integer
---@field pad_cols integer
local ParseContext = {}
ParseContext.__index = ParseContext
M.ParseContext = ParseContext

-- Create a context in which `origin_node` (and user input) will be parsed correctly.
---@param buf buffer
---@param lang string
---@param origin_node TSNode
---@return ParseContext?
function ParseContext.new(buf, origin_node)
local lang = ts.language.get_lang(vim.bo[buf].filetype)
if not lang then
return
end
function ParseContext.new(buf, lang, origin_node)
local self = setmetatable({ lang = lang }, { __index = ParseContext })

local origin_start_row, origin_start_col, origin_start_byte = origin_node:start()
Expand All @@ -48,8 +41,8 @@ function ParseContext.new(buf, origin_node)
if end_row == start_row then
end_col = end_col + start_col
end
local node_in_context = root:named_descendant_for_range(start_row, start_col, end_row, end_col)
if node_in_context and node_in_context:type() == origin_node:type() and node_in_context:sexpr() == origin_sexpr then
local node_in_context = root:named_descendant_for_range(start_row, start_col, end_row, end_col) --[[@as TSNode]]
if node_in_context:type() == origin_node:type() and node_in_context:sexpr() == origin_sexpr then
local context_start_byte
self.start_row, self.start_col, context_start_byte = context_node:start()
self.before = context_text:sub(1, origin_start_byte - context_start_byte)
Expand All @@ -63,17 +56,17 @@ function ParseContext.new(buf, origin_node)
end
end

-- Parse search pattern to syntax tree in proper context.
-- Parse code to TS node.
---@param pattern string
---@return TSNode?, string
---@return TSNode, string
function ParseContext:parse(pattern)
-- Replace named wildcard $name to identifier __ssr_var_name to avoid syntax error.
pattern = pattern:gsub("%$([_%a%d]+)", wildcard_prefix .. "%1")
local context_text = self.before .. pattern .. self.after
local root = ts.get_string_parser(context_text, self.lang):parse()[1]:root()
pattern = pattern:gsub("%$([_%a%d]+)", u.wildcard_prefix .. "%1")
local source = self.before .. pattern .. self.after
local root = ts.get_string_parser(source, self.lang):parse()[1]:root()
local lines = vim.split(pattern, "\n")
local node = root:named_descendant_for_range(self.pad_rows, self.pad_cols, self.pad_rows + #lines - 1, #lines[#lines])
return node, context_text
local node = root:named_descendant_for_range(self.pad_rows, self.pad_cols, self.pad_rows + #lines - 1, #lines[#lines]) --[[@as TSNode]]
return node, source
end

return M
return ParseContext
35 changes: 35 additions & 0 deletions lua/ssr/range.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
---@class Range
---@field start_row number
---@field start_col number
---@field end_row number
---@field end_col number
local Range = {}

---@param node TSNode
---@return Range
function Range.from_node(node)
local start_row, start_col, end_row, end_col = node:range()
return setmetatable({
start_row = start_row,
start_col = start_col,
end_row = end_row,
end_col = end_col,
}, { __index = Range })
end

---@param other Range
---@return boolean
function Range:before(other)
return self.end_row < other.start_row or (self.end_row == other.start_row and self.end_col <= other.start_col)
end

---@param other Range
---@return boolean
function Range:inside(other)
return (
(self.start_row > other.start_row or (self.start_row == other.start_row and self.start_col > other.start_col))
and (self.end_row < other.end_row or (self.end_row == other.end_row and self.end_col <= other.end_col))
)
end

return Range
28 changes: 28 additions & 0 deletions lua/ssr/replace.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
local api = vim.api
local ts = vim.treesitter
local u = require "ssr.utils"

local Replacer = {}

--- Render template and replace one match.
---@param buf buffer
---@param match Match
function Replacer:replace(buf, match)
-- Render templates with captured nodes.
local replace = self.template:gsub("()%$([_%a%d]+)", function(pos, var)
local start_row, start_col, end_row, end_col = match.captures[var]:get()
local lines = api.nvim_buf_get_text(buf, start_row, start_col, end_row, end_col, {})
u.remove_indent(lines, u.get_indent(buf, start_row))
local var_lines = vim.split(self.template:sub(1, pos), "\n")
local var_line = var_lines[#var_lines]
local template_indent = var_line:match "^%s*"
u.add_indent(lines, template_indent)
return table.concat(lines, "\n")
end)
replace = vim.split(replace, "\n")
local start_row, start_col, end_row, end_col = match.range:get()
u.add_indent(replace, u.get_indent(buf, start_row))
api.nvim_buf_set_text(buf, start_row, start_col, end_row, end_col, replace)
end

return Replacer
Loading

0 comments on commit 22e141c

Please sign in to comment.