Skip to content

Commit

Permalink
Same wildcard should match the same subtree
Browse files Browse the repository at this point in the history
fixes #15
  • Loading branch information
cshuaimin committed Aug 14, 2023
1 parent 4dbd5d7 commit 5252b54
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 14 deletions.
6 changes: 2 additions & 4 deletions lua/ssr/parse.lua
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
local ts = vim.treesitter
local parsers = require "nvim-treesitter.parsers"
local u = require "ssr.utils"
local wildcard_prefix = require("ssr.search").wildcard_prefix

local M = {}

M.wildcard_prefix = "__ssr_var_"

---@class ParseContext
---@field lang string
---@field before string
Expand Down Expand Up @@ -66,7 +64,7 @@ end
---@return TSNode, string
function ParseContext:parse(pattern)
-- Replace named wildcard $name to identifier __ssr_var_name to avoid syntax error.
pattern = pattern:gsub("%$([_%a%d]+)", M.wildcard_prefix .. "%1")
pattern = pattern:gsub("%$([_%a%d]+)", wildcard_prefix .. "%1")
local context_text = self.before .. pattern .. self.after
local root = ts.get_string_parser(context_text, self.lang):parse()[1]:root()
local lines = vim.split(pattern, "\n")
Expand Down
46 changes: 41 additions & 5 deletions lua/ssr/search.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ local api = vim.api
local ts = vim.treesitter
local parsers = require "nvim-treesitter.parsers"
local u = require "ssr.utils"
local wildcard_prefix = require("ssr.parse").wildcard_prefix

local M = {}

M.wildcard_prefix = "__ssr_var_"

---@class Match
---@field range ExtmarkRange
---@field captures ExtmarkRange[]
Expand Down Expand Up @@ -41,6 +42,33 @@ function ExtmarkRange:get()
return extmark[1], extmark[2], extmark[3].end_row, extmark[3].end_col
end

-- Compare if two captured trees can match.
-- The check is loose because users want to match different types of node.
-- e.g. converting `{ foo: foo }` to shorthand `{ foo }`.
ts.query.add_predicate("ssr-tree-match?", function(match, _pattern, buf, pred)
---@param node1 TSNode
---@param node2 TSNode
---@return boolean
local function tree_match(node1, node2)
if node1:named() ~= node2:named() then
return false
end
if node1:child_count() == 0 or node2:child_count() == 0 then
return ts.get_node_text(node1, buf) == ts.get_node_text(node2, buf)
end
if node1:child_count() ~= node2:child_count() then
return false
end
for i = 0, node1:child_count() - 1 do
if not tree_match(node1:child(i), node2:child(i)) then
return false
end
end
return true
end
return tree_match(match[pred[2]], match[pred[3]])
end, true)

-- Build a TS sexpr represting the node.
---@param node TSNode
---@param source string
Expand All @@ -54,11 +82,19 @@ local function build_sexpr(node, source)
local text = ts.get_node_text(node, source)

-- Special identifier __ssr_var_name is a named wildcard.
local var = text:match("^" .. wildcard_prefix .. "([_%a%d]+)$")
-- Handle this early to make sure wildcard captures largest node.
local var = text:match("^" .. M.wildcard_prefix .. "([_%a%d]+)$")
if var then
wildcards[var] = next_idx
next_idx = next_idx + 1
return "(_) @" .. var
if not wildcards[var] then
wildcards[var] = next_idx
next_idx = next_idx + 1
return "(_) @" .. var
else
-- Same wildcard should match the same subtree.
local sexpr = string.format("(_) @_%d (#ssr-tree-match? @_%d @%s)", next_idx, next_idx, var)
next_idx = next_idx + 1
return sexpr
end
end

-- Leaf nodes (keyword, identifier, literal and symbol) should match text.
Expand Down
77 changes: 72 additions & 5 deletions tests/ssr_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -138,22 +138,89 @@ foo($a, $b) ==>> ($a).foo($b)
String::from((y + 5).foo(z))
]]

t [[ go parsed correctly
t [[ go parse Go := in function
func main() {
<commit, _ := os.LookupEnv("GITHUB_SHA")>
print(commit)
<commit, _ := os.LookupEnv("GITHUB_SHA")>
}
====
$a, _ := os.LookupEnv($b)
==>>
$a := os.Getenv($b)
====
func main() {
commit := os.Getenv("GITHUB_SHA")
print(commit)
commit := os.Getenv("GITHUB_SHA")
}
]]

t [[ go match Go if err
fn main() {
<if err != nil {
panic(err)
}>
}
====
if err != nil { panic(err) } ==>> x
====
fn main() {
x
}
]]

t [[ rust reused wildcard: compound assignments
<idx = idx + 1>;
bar = foo + idx;
*foo.bar() = * foo . bar () + 1;
(foo + bar) = (foo + bar) + 1;
(foo + bar) = (foo - bar) + 1;
====
$a = $a + $b ==>> $a += $b
====
idx += 1;
bar = foo + idx;
*foo.bar() += 1;
(foo + bar) += 1;
(foo + bar) = (foo - bar) + 1;
]]

t [[ python reused wildcard: indent
def f():
<if await foo.bar(baz):
if await foo.bar(baz):
pass>
====
if $foo:
if $foo:
$body
==>>
if $foo:
$body
====
def f():
if await foo.bar(baz):
pass
]]

-- two `foo`s have different type: `property_identifier` and `identifier`
t [[ javascript reused wildcard: match different node types 1
<{ foo: foo }>
{ foo: bar }
====
{ $a: $a } ==>> { $a }
====
{ foo }
{ foo: bar }
]]

t [[ lua reused wildcard: match different node types 2
<local api = vim.api>
local a = vim.api
====
local $a = vim.$a ==>> x
====
x
local a = vim.api
]]

describe("", function()
for _, s in ipairs(tests) do
local ft, desc, content, pattern, template, expected =
Expand Down

0 comments on commit 5252b54

Please sign in to comment.