Skip to content

Commit

Permalink
refactor!: remove nvim-treesitter dependency (#35)
Browse files Browse the repository at this point in the history
* refactor!: remove nvim-treesitter dependency

* chore: lua diagnostics

* Check if parser installed with pcall `ts.get_parser`

* Load nvim-treesitter in CI before running tests

---------

Co-authored-by: Chen Shuaimin <chen_shuaimin@outlook.com>
  • Loading branch information
amaanq and cshuaimin committed Oct 14, 2023
1 parent b2f35df commit 10d51dd
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
git clone --depth 1 https://github.com/nvim-treesitter/nvim-treesitter.git ~/.local/share/nvim/site/pack/vendor/start/nvim-treesitter
git clone --depth 1 https://github.com/nvim-lua/plenary.nvim ~/.local/share/nvim/site/pack/vendor/start/plenary.nvim
ln -s $(pwd) ~/.local/share/nvim/site/pack/vendor/start
nvim --headless -c "lua require('nvim-treesitter').setup {}" -c 'TSInstallSync python javascript lua rust go' -c 'q'
nvim --headless -c 'TSInstallSync python javascript lua rust go' -c 'q'
- name: Run tests
run: |
Expand Down
17 changes: 11 additions & 6 deletions lua/ssr.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ local ts = vim.treesitter
local fn = vim.fn
local keymap = vim.keymap
local highlight = vim.highlight
local parsers = require "nvim-treesitter.parsers"
local ParseContext = require("ssr.parse").ParseContext
local search = require("ssr.search").search
local replace = require("ssr.search").replace
Expand Down Expand Up @@ -56,13 +55,18 @@ function Ui.new()

self.origin_win = api.nvim_get_current_win()
local origin_buf = api.nvim_win_get_buf(self.origin_win)
self.lang = parsers.get_buf_lang(origin_buf)
if not parsers.has_parser(self.lang) then
local lang = ts.language.get_lang(vim.bo[origin_buf].filetype)
if not lang then
return u.notify("Treesitter language not found")
end
self.lang = lang

local origin_node = u.node_for_range(origin_buf, self.lang, u.get_selection(self.origin_win))
if not origin_node then
return u.notify("Treesitter parser not found, please try to install it with :TSInstall " .. self.lang)
end
local origin_node = u.node_for_range(origin_buf, u.get_selection(self.origin_win))
if origin_node:has_error() then
return u.notify "You have syntax errors in selected node"
return u.notify "You have syntax errors in the selected node"
end
local parse_context = ParseContext.new(origin_buf, origin_node)
if not parse_context then
Expand Down Expand Up @@ -173,7 +177,7 @@ function Ui.new()
return
end

if parsers.get_buf_lang(event.buf) ~= self.lang then
if ts.language.get_lang(vim.bo[event.buf].filetype) ~= self.lang then
return self:set_status "N/A"
end
self:search()
Expand Down Expand Up @@ -334,6 +338,7 @@ function Ui:replace_confirm()
while match_idx <= #matches do
local confirm_win = open_confirm_win(match_idx)

---@type string
local key
while true do
-- Draw a fake cursor because cursor is not shown correctly when blocking on `getchar()`.
Expand Down
12 changes: 8 additions & 4 deletions lua/ssr/parse.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
local ts = vim.treesitter
local parsers = require "nvim-treesitter.parsers"
local wildcard_prefix = require("ssr.search").wildcard_prefix

local M = {}
Expand All @@ -19,12 +18,17 @@ M.ParseContext = ParseContext
---@param origin_node TSNode
---@return ParseContext?
function ParseContext.new(buf, origin_node)
local self = setmetatable({ lang = parsers.get_buf_lang(buf) }, { __index = ParseContext })
local lang = ts.language.get_lang(vim.bo[buf].filetype)
if not lang then
return
end
local self = setmetatable({ lang = lang }, { __index = ParseContext })

local origin_start_row, origin_start_col, origin_start_byte = origin_node:start()
local _, _, origin_end_byte = origin_node:end_()
local origin_lines = vim.split(ts.get_node_text(origin_node, buf), "\n")
local origin_sexpr = origin_node:sexpr()
---@type TSNode?
local context_node = origin_node

-- Find an ancestor of `origin_node`
Expand All @@ -45,7 +49,7 @@ function ParseContext.new(buf, origin_node)
end_col = end_col + start_col
end
local node_in_context = root:named_descendant_for_range(start_row, start_col, end_row, end_col)
if node_in_context:type() == origin_node:type() and node_in_context:sexpr() == origin_sexpr then
if node_in_context and node_in_context:type() == origin_node:type() and node_in_context:sexpr() == origin_sexpr then
local context_start_byte
self.start_row, self.start_col, context_start_byte = context_node:start()
self.before = context_text:sub(1, origin_start_byte - context_start_byte)
Expand All @@ -61,7 +65,7 @@ end

-- Parse search pattern to syntax tree in proper context.
---@param pattern string
---@return TSNode, string
---@return TSNode?, string
function ParseContext:parse(pattern)
-- Replace named wildcard $name to identifier __ssr_var_name to avoid syntax error.
pattern = pattern:gsub("%$([_%a%d]+)", wildcard_prefix .. "%1")
Expand Down
27 changes: 22 additions & 5 deletions lua/ssr/search.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
local api = vim.api
local ts = vim.treesitter
local parsers = require "nvim-treesitter.parsers"
local u = require "ssr.utils"

local M = {}
Expand Down Expand Up @@ -46,10 +45,13 @@ end
-- The check is loose because users want to match different types of node.
-- e.g. converting `{ foo: foo }` to shorthand `{ foo }`.
ts.query.add_predicate("ssr-tree-match?", function(match, _pattern, buf, pred)
---@param node1 TSNode
---@param node2 TSNode
---@param node1 TSNode?
---@param node2 TSNode?
---@return boolean
local function tree_match(node1, node2)
if not node1 or not node2 then
return false
end
if node1:named() ~= node2:named() then
return false
end
Expand All @@ -74,10 +76,13 @@ end, true)
---@param source string
---@return string, table<string, number>
local function build_sexpr(node, source)
---@type table<string, number>
local wildcards = {}
local next_idx = 1

-- This function is more strict than `tsnode:sexpr()` by also requiring leaf nodes to match text.
---@param node TSNode
---@return string
local function build(node)
local text = ts.get_node_text(node, source)

Expand Down Expand Up @@ -139,10 +144,19 @@ end
function M.search(buf, node, source, ns)
local sexpr, wildcards = build_sexpr(node, source)
local parse_query = ts.query.parse or ts.parse_query
local query = parse_query(parsers.get_buf_lang(buf), sexpr)
local lang = ts.language.get_lang(vim.bo[buf].filetype)
if not lang then
return {}
end
local query = parse_query(lang, sexpr)
local matches = {}
local root = parsers.get_parser(buf):parse()[1]:root()
local has_parser, parser = pcall(ts.get_parser, buf, lang)
if not has_parser then
return {}
end
local root = parser:parse(true)[1]:root()
for _, nodes in query:iter_matches(root, buf, 0, -1) do
---@type table<string, ExtmarkRange>
local captures = {}
for var, idx in pairs(wildcards) do
captures[var] = ExtmarkRange.new(ns, buf, nodes[idx])
Expand All @@ -154,6 +168,9 @@ function M.search(buf, node, source, ns)
-- Sort matches from
-- buffer top to bottom, to make goto next/prev match intuitive
-- inner to outer for recursive matches, to make replacing correct
---@param match1 { range: ExtmarkRange, captures: table<string, ExtmarkRange>}
---@param match2 { range: ExtmarkRange, captures: table<string, ExtmarkRange>}
---@return boolean
table.sort(matches, function(match1, match2)
local start_row1, start_col1, end_row1, end_col1 = match1.range:get()
local start_row2, start_col2, end_row2, end_col2 = match2.range:get()
Expand Down
24 changes: 18 additions & 6 deletions lua/ssr/utils.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
local api = vim.api
local parsers = require "nvim-treesitter.parsers"
local ts = vim.treesitter

local M = {}

Expand Down Expand Up @@ -55,13 +55,17 @@ end

-- Get smallest node for the range.
---@param buf buffer
---@param lang string
---@param start_row number
---@param start_col number
---@param end_row number
---@param end_col number
---@return TSNode
function M.node_for_range(buf, start_row, start_col, end_row, end_col)
return parsers.get_parser(buf):parse()[1]:root():named_descendant_for_range(start_row, start_col, end_row, end_col)
---@return TSNode?
function M.node_for_range(buf, lang, start_row, start_col, end_row, end_col)
local has_parser, parser = pcall(ts.get_parser, buf, lang)
if has_parser then
return parser:parse()[1]:root():named_descendant_for_range(start_row, start_col, end_row, end_col)
end
end

---@param buf buffer
Expand All @@ -71,15 +75,15 @@ function M.get_indent(buf, row)
return line:match "^%s*"
end

---@param lines table
---@param lines string[]
---@param indent string
function M.add_indent(lines, indent)
for i = 2, #lines do
lines[i] = indent .. lines[i]
end
end

---@param lines table
---@param lines string[]
---@param indent string
function M.remove_indent(lines, indent)
indent = "^" .. indent
Expand All @@ -98,7 +102,15 @@ function M.to_ts_query_str(s)
end

-- Compute window size to show giving lines.
---@param lines string[]
---@param config Config
---@return number
---@return number
function M.get_win_size(lines, config)
---@param i number
---@param min number
---@param max number
---@return number
local function clamp(i, min, max)
return math.min(math.max(i, min), max)
end
Expand Down
9 changes: 8 additions & 1 deletion tests/ssr_spec.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
local u = require "ssr.utils"
local ParseContext = require("ssr.parse").ParseContext
local ts = vim.treesitter
local search = require("ssr.search").search
local replace = require("ssr.search").replace

Expand Down Expand Up @@ -222,6 +223,10 @@ local a = vim.api
]]

describe("", function()
-- Plenary runs nvim with `--noplugin` argument.
-- Make sure nvim-treesitter is loaded, which populates vim.treesitter's ft_to_lang table.
require "nvim-treesitter"

for _, s in ipairs(tests) do
local ft, desc, content, pattern, template, expected =
s:match "^ (%a-) (.-)\n(.-)%s?====%s?(.-)%s?==>>%s?(.-)%s?====%s?(.-)%s?$"
Expand Down Expand Up @@ -249,7 +254,9 @@ describe("", function()
local buf = vim.api.nvim_create_buf(false, true)
vim.bo[buf].filetype = ft
vim.api.nvim_buf_set_lines(buf, 0, -1, true, content)
local origin_node = u.node_for_range(buf, start_row, start_col, end_row, end_col)
local lang = ts.language.get_lang(vim.bo[buf].filetype)
assert(lang, "language not found")
local origin_node = u.node_for_range(buf, lang, start_row, start_col, end_row, end_col)

local parse_context = ParseContext.new(buf, origin_node)
assert(parse_context)
Expand Down

0 comments on commit 10d51dd

Please sign in to comment.