Skip to content

Commit

Permalink
Make confirmation window blocking
Browse files Browse the repository at this point in the history
  • Loading branch information
cshuaimin committed Jul 22, 2023
1 parent 3851fec commit da16064
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 77 deletions.
166 changes: 107 additions & 59 deletions lua/ssr.lua
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ local config = {
},
}

-- Set config options.
---@param cfg Config?
function M.setup(cfg)
if cfg then
config = vim.tbl_deep_extend("force", config, cfg)
end
end

---@type table<window, Ui>
local win_uis = {}

---@class Ui
---@field ns number
---@field cur_search_ns number
Expand Down Expand Up @@ -160,9 +164,11 @@ function Ui.new()
api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1)
end
api.nvim_clear_autocmds { group = self.augroup }
win_uis[self.origin_win] = nil
end,
})

win_uis[self.origin_win] = self
return self
end

Expand Down Expand Up @@ -288,74 +294,87 @@ function Ui:replace_confirm()
cfg.title = "Replace?"
cfg.title_pos = "center"
end
confirm_win = api.nvim_open_win(confirm_buf, true, cfg)
return api.nvim_open_win(confirm_buf, true, cfg)
end

local function map(key, func)
keymap.set("n", key, function()
func()
api.nvim_win_close(confirm_win, false)
if match_idx <= #self.matches then
open_confirm_win()
local match_idx = 1
local replaced = 0
local cursor = 1
local _, template = self:get_input()
self:set_status(string.format("Replacing 0/%d", #matches))

while match_idx <= #matches do
local confirm_win = open_confirm_win(match_idx)

local key
while true do
-- Draw a fake cursor because cursor is not shown correctly when blocking on `getchar()`.
api.nvim_buf_clear_namespace(confirm_buf, self.cur_search_ns, 0, -1)
api.nvim_buf_set_extmark(
confirm_buf,
self.cur_search_ns,
cursor - 1,
0,
{ virt_text = { { "", "Cursor" } }, virt_text_pos = "overlay" }
)
api.nvim_buf_set_extmark(confirm_buf, self.cur_search_ns, cursor - 1, 0, { line_hl_group = "CursorLine" })
vim.cmd.redraw()

local ok, char = pcall(vim.fn.getcharstr)
key = ok and vim.fn.keytrans(char) or ""
if key == "j" then
if cursor == separator_idx - 1 then -- skip separator
cursor = separator_idx + 1
elseif cursor == #choices then -- wrap
cursor = 1
else
cursor = cursor + 1
end
elseif key == "k" then
if cursor == separator_idx + 1 then -- skip separator
cursor = separator_idx - 1
elseif cursor == 1 then -- wrap
cursor = #choices
else
cursor = cursor - 1
end
elseif vim.tbl_contains({ "<C-E>", "<C-Y>", "<C-U>", "<C-D>", "<C-F>", "<C-B>" }, key) then
fn.win_execute(self.origin_win, string.format('execute "normal! \\%s"', key))
else
api.nvim_buf_delete(confirm_buf, {})
api.nvim_buf_clear_namespace(self.origin_buf, self.cur_match_ns, 0, -1)
break
end
self:set_status(string.format("%d/%d replaced", replaced, #self.matches))
end, { buffer = confirm_buf, nowait = true })
end

map("y", function()
replace(self.origin_buf, self.matches[match_idx], template)
replaced = replaced + 1
match_idx = match_idx + 1
end)

map("n", function()
match_idx = match_idx + 1
end)

map("a", function()
for i = match_idx, #self.matches do
replace(self.origin_buf, self.matches[i], template)
end
replaced = replaced + #self.matches + 1 - match_idx
match_idx = #self.matches + 1
end)

map("q", function()
match_idx = #self.matches + 1
end)

map("<Esc>", function()
match_idx = #self.matches + 1
end)

map("<C-[>", function()
match_idx = #self.matches + 1
end)

map("l", function()
replace(self.origin_buf, self.matches[match_idx], template)
replaced = replaced + 1
match_idx = #self.matches + 1
end)
if key == "<CR>" then
key = ({ "y", "n", "", "a", "q", "l" })[cursor]
end

local function origin_win_map(key)
vim.keymap.set("n", key, function()
fn.win_execute(self.origin_win, string.format('execute "normal! \\%s"', key))
end, { buffer = confirm_buf })
if key == "y" then
replace(buf, matches[match_idx], template)
replaced = replaced + 1
match_idx = match_idx + 1
elseif key == "n" then
match_idx = match_idx + 1
elseif key == "a" then
for i = match_idx, #matches do
replace(buf, matches[i], template)
end
replaced = replaced + #matches + 1 - match_idx
match_idx = #matches + 1
elseif key == "l" then
replace(buf, matches[match_idx], template)
replaced = replaced + 1
match_idx = #matches + 1
elseif key == "q" or key == "<ESC>" or key == "" then
match_idx = #matches + 1
end
api.nvim_win_close(confirm_win, false)
self:set_status(string.format("Replacing %d/%d", replaced, #matches))
end

origin_win_map "<C-e>"
origin_win_map "<C-y>"
origin_win_map "<C-u>"
origin_win_map "<C-d>"
origin_win_map "<C-f>"
origin_win_map "<C-b>"

self:set_status(string.format("0/%d replaced", #self.matches))
open_confirm_win()
api.nvim_buf_delete(confirm_buf, {})
api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1)
self:set_status(string.format("Replaced %d/%d", replaced, #matches))
end

function Ui:get_input()
Expand All @@ -379,8 +398,37 @@ function Ui:set_status(status)
})
end

---@param win window?
---@return Ui?
function Ui.from_win(win)
if win == nil or win == 0 then
win = api.nvim_get_current_win()
end
local ui = win_uis[win]
if not ui then
return u.notify "No open SSR window"
end
return ui
end

function M.open()
return Ui.new()
end

-- Replace all matches.
function M.replace_all()
local ui = Ui.from_win()
if ui then
ui:replace_all()
end
end

-- Confirm each match.
function M.replace_confirm()
local ui = Ui.from_win()
if ui then
ui:replace_confirm()
end
end

return M
2 changes: 1 addition & 1 deletion lua/ssr/search.lua
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ function M.search(buf, node, source, ns)
local parse_query = ts.query.parse or ts.parse_query
local query = parse_query(parsers.get_buf_lang(buf), sexpr)
local matches = {}
local root = u.get_root(buf)
local root = parsers.get_parser(buf):parse()[1]:root()
for _, nodes in query:iter_matches(root, buf, 0, -1) do
local captures = {}
for var, idx in pairs(wildcards) do
Expand Down
44 changes: 35 additions & 9 deletions lua/ssr/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,30 @@ local parsers = require "nvim-treesitter.parsers"

local M = {}

-- Send a notification titled SSR.
---@param msg string
---@return nil
function M.notify(msg)
vim.notify(msg, "error", { title = "SSR" })
vim.notify(msg, vim.log.levels.ERROR, { title = "SSR" })
end

-- Get (0,0)-indexed cursor position.
---@param win window
function M.get_cursor(win)
local cursor = api.nvim_win_get_cursor(win)
return cursor[1] - 1, cursor[2]
end

-- Set (0,0)-indexed cursor position.
---@param win window
---@param row integer
---@param col integer
function M.set_cursor(win, row, col)
api.nvim_win_set_cursor(win, { row + 1, col })
end

-- Get selected region, works in many modes.
---@param win window
---@return number, number, number, number
function M.get_selection(win)
local mode = api.nvim_get_mode().mode
Expand All @@ -40,20 +53,15 @@ function M.get_selection(win)
end
end

---@param buf buffer
---@return TSNode
function M.get_root(buf)
return parsers.get_parser(buf):parse()[1]:root()
end

-- Get smallest node for the range.
---@param buf buffer
---@param start_row number
---@param start_col number
---@param end_row number
---@param end_col number
---@return TSNode
function M.node_for_range(buf, start_row, start_col, end_row, end_col)
return M.get_root(buf):named_descendant_for_range(start_row, start_col, end_row, end_col)
return parsers.get_parser(buf):parse()[1]:root():named_descendant_for_range(start_row, start_col, end_row, end_col)
end

---@param buf buffer
Expand All @@ -80,7 +88,7 @@ function M.remove_indent(lines, indent)
end
end

--- Escape special characters in s and quote it in double quotes.
-- Escape special characters in s and quote it in double quotes.
---@param s string
function M.to_ts_query_str(s)
s = s:gsub([[\]], [[\\]])
Expand All @@ -89,4 +97,22 @@ function M.to_ts_query_str(s)
return '"' .. s .. '"'
end

-- Compute window size to show giving lines.
function M.get_win_size(lines, config)
local function clamp(i, min, max)
return math.min(math.max(i, min), max)
end

local width = 0
for _, line in ipairs(lines) do
if #line > width then
width = #line
end
end

width = clamp(width, config.min_width, config.max_width)
local height = clamp(#lines, config.min_height, config.max_height)
return width, height
end

return M
15 changes: 7 additions & 8 deletions tests/ssr_spec.lua
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
local parsers = require "nvim-treesitter.parsers"
local utils = require "ssr.utils"
local Parser = require("ssr.parse").Parser
local u = require "ssr.utils"
local ParseContext = require("ssr.parse").ParseContext
local search = require("ssr.search").search
local replace = require("ssr.replace").replace
local replace = require("ssr.search").replace

local tests = {}

Expand Down Expand Up @@ -183,11 +182,11 @@ describe("", function()
local buf = vim.api.nvim_create_buf(false, true)
vim.bo[buf].filetype = ft
vim.api.nvim_buf_set_lines(buf, 0, -1, true, content)
local origin_node = utils.node_for_range(buf, start_row, start_col, end_row, end_col)
local origin_node = u.node_for_range(buf, start_row, start_col, end_row, end_col)

local parser = Parser:new(buf, origin_node)
assert(parser)
local node, source = parser:parse(pattern)
local parse_context = ParseContext.new(buf, origin_node)
assert(parse_context)
local node, source = parse_context:parse(pattern)
local matches = search(buf, node, source, ns)

for _, match in ipairs(matches) do
Expand Down

0 comments on commit da16064

Please sign in to comment.