Skip to content

Commit

Permalink
Make confirmation window blocking
Browse files Browse the repository at this point in the history
  • Loading branch information
cshuaimin committed Jul 22, 2023
1 parent 3851fec commit b6a072b
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 67 deletions.
131 changes: 72 additions & 59 deletions lua/ssr.lua
Original file line number Diff line number Diff line change
Expand Up @@ -288,74 +288,87 @@ function Ui:replace_confirm()
cfg.title = "Replace?"
cfg.title_pos = "center"
end
confirm_win = api.nvim_open_win(confirm_buf, true, cfg)
return api.nvim_open_win(confirm_buf, true, cfg)
end

local function map(key, func)
keymap.set("n", key, function()
func()
api.nvim_win_close(confirm_win, false)
if match_idx <= #self.matches then
open_confirm_win()
local match_idx = 1
local replaced = 0
local cursor = 1
local _, template = self:get_input()
self:set_status(string.format("Replacing 0/%d", #matches))

while match_idx <= #matches do
local confirm_win = open_confirm_win(match_idx)

local key
while true do
-- Draw a fake cursor because cursor is not shown correctly when blocking on `getchar()`.
api.nvim_buf_clear_namespace(confirm_buf, self.cur_search_ns, 0, -1)
api.nvim_buf_set_extmark(
confirm_buf,
self.cur_search_ns,
cursor - 1,
0,
{ virt_text = { { "", "Cursor" } }, virt_text_pos = "overlay" }
)
api.nvim_buf_set_extmark(confirm_buf, self.cur_search_ns, cursor - 1, 0, { line_hl_group = "CursorLine" })
vim.cmd.redraw()

local ok, char = pcall(vim.fn.getcharstr)
key = ok and vim.fn.keytrans(char) or ""
if key == "j" then
if cursor == separator_idx - 1 then -- skip separator
cursor = separator_idx + 1
elseif cursor == #choices then -- wrap
cursor = 1
else
cursor = cursor + 1
end
elseif key == "k" then
if cursor == separator_idx + 1 then -- skip separator
cursor = separator_idx - 1
elseif cursor == 1 then -- wrap
cursor = #choices
else
cursor = cursor - 1
end
elseif vim.tbl_contains({ "<C-E>", "<C-Y>", "<C-U>", "<C-D>", "<C-F>", "<C-B>" }, key) then
fn.win_execute(self.origin_win, string.format('execute "normal! \\%s"', key))
else
api.nvim_buf_delete(confirm_buf, {})
api.nvim_buf_clear_namespace(self.origin_buf, self.cur_match_ns, 0, -1)
break
end
self:set_status(string.format("%d/%d replaced", replaced, #self.matches))
end, { buffer = confirm_buf, nowait = true })
end

map("y", function()
replace(self.origin_buf, self.matches[match_idx], template)
replaced = replaced + 1
match_idx = match_idx + 1
end)

map("n", function()
match_idx = match_idx + 1
end)

map("a", function()
for i = match_idx, #self.matches do
replace(self.origin_buf, self.matches[i], template)
end
replaced = replaced + #self.matches + 1 - match_idx
match_idx = #self.matches + 1
end)

map("q", function()
match_idx = #self.matches + 1
end)

map("<Esc>", function()
match_idx = #self.matches + 1
end)

map("<C-[>", function()
match_idx = #self.matches + 1
end)

map("l", function()
replace(self.origin_buf, self.matches[match_idx], template)
replaced = replaced + 1
match_idx = #self.matches + 1
end)
if key == "<CR>" then
key = ({ "y", "n", "", "a", "q", "l" })[cursor]
end

local function origin_win_map(key)
vim.keymap.set("n", key, function()
fn.win_execute(self.origin_win, string.format('execute "normal! \\%s"', key))
end, { buffer = confirm_buf })
if key == "y" then
replace(buf, matches[match_idx], template)
replaced = replaced + 1
match_idx = match_idx + 1
elseif key == "n" then
match_idx = match_idx + 1
elseif key == "a" then
for i = match_idx, #matches do
replace(buf, matches[i], template)
end
replaced = replaced + #matches + 1 - match_idx
match_idx = #matches + 1
elseif key == "l" then
replace(buf, matches[match_idx], template)
replaced = replaced + 1
match_idx = #matches + 1
elseif key == "q" or key == "<ESC>" or key == "" then
match_idx = #matches + 1
end
api.nvim_win_close(confirm_win, false)
self:set_status(string.format("Replacing %d/%d", replaced, #matches))
end

origin_win_map "<C-e>"
origin_win_map "<C-y>"
origin_win_map "<C-u>"
origin_win_map "<C-d>"
origin_win_map "<C-f>"
origin_win_map "<C-b>"

self:set_status(string.format("0/%d replaced", #self.matches))
open_confirm_win()
api.nvim_buf_delete(confirm_buf, {})
api.nvim_buf_clear_namespace(buf, self.cur_search_ns, 0, -1)
self:set_status(string.format("Replaced %d/%d", replaced, #matches))
end

function Ui:get_input()
Expand Down
21 changes: 21 additions & 0 deletions lua/ssr/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ function M.get_cursor(win)
return cursor[1] - 1, cursor[2]
end

function M.set_cursor(win, row, col)
api.nvim_win_set_cursor(win, { row + 1, col })
end

---@return number, number, number, number
function M.get_selection(win)
local mode = api.nvim_get_mode().mode
Expand Down Expand Up @@ -89,4 +93,21 @@ function M.to_ts_query_str(s)
return '"' .. s .. '"'
end

function M.get_win_size(lines, config)
local function clamp(i, min, max)
return math.min(math.max(i, min), max)
end

local width = 0
for _, line in ipairs(lines) do
if #line > width then
width = #line
end
end

width = clamp(width, config.min_width, config.max_width)
local height = clamp(#lines, config.min_height, config.max_height)
return width, height
end

return M
15 changes: 7 additions & 8 deletions tests/ssr_spec.lua
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
local parsers = require "nvim-treesitter.parsers"
local utils = require "ssr.utils"
local Parser = require("ssr.parse").Parser
local u = require "ssr.utils"
local ParseContext = require("ssr.parse").ParseContext
local search = require("ssr.search").search
local replace = require("ssr.replace").replace
local replace = require("ssr.search").replace

local tests = {}

Expand Down Expand Up @@ -183,11 +182,11 @@ describe("", function()
local buf = vim.api.nvim_create_buf(false, true)
vim.bo[buf].filetype = ft
vim.api.nvim_buf_set_lines(buf, 0, -1, true, content)
local origin_node = utils.node_for_range(buf, start_row, start_col, end_row, end_col)
local origin_node = u.node_for_range(buf, start_row, start_col, end_row, end_col)

local parser = Parser:new(buf, origin_node)
assert(parser)
local node, source = parser:parse(pattern)
local parse_context = ParseContext.new(buf, origin_node)
assert(parse_context)
local node, source = parse_context:parse(pattern)
local matches = search(buf, node, source, ns)

for _, match in ipairs(matches) do
Expand Down

0 comments on commit b6a072b

Please sign in to comment.