From 5252b541130541f60cbc0a46a38a39d7145c4264 Mon Sep 17 00:00:00 2001 From: Chen Shuaimin Date: Thu, 10 Aug 2023 15:46:29 +0800 Subject: [PATCH] Same wildcard should match the same subtree fixes #15 --- lua/ssr/parse.lua | 6 ++-- lua/ssr/search.lua | 46 ++++++++++++++++++++++++--- tests/ssr_spec.lua | 77 +++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 115 insertions(+), 14 deletions(-) diff --git a/lua/ssr/parse.lua b/lua/ssr/parse.lua index 254dba1..1050d08 100644 --- a/lua/ssr/parse.lua +++ b/lua/ssr/parse.lua @@ -1,11 +1,9 @@ local ts = vim.treesitter local parsers = require "nvim-treesitter.parsers" -local u = require "ssr.utils" +local wildcard_prefix = require("ssr.search").wildcard_prefix local M = {} -M.wildcard_prefix = "__ssr_var_" - ---@class ParseContext ---@field lang string ---@field before string @@ -66,7 +64,7 @@ end ---@return TSNode, string function ParseContext:parse(pattern) -- Replace named wildcard $name to identifier __ssr_var_name to avoid syntax error. - pattern = pattern:gsub("%$([_%a%d]+)", M.wildcard_prefix .. "%1") + pattern = pattern:gsub("%$([_%a%d]+)", wildcard_prefix .. "%1") local context_text = self.before .. pattern .. self.after local root = ts.get_string_parser(context_text, self.lang):parse()[1]:root() local lines = vim.split(pattern, "\n") diff --git a/lua/ssr/search.lua b/lua/ssr/search.lua index afa7e10..766dd75 100644 --- a/lua/ssr/search.lua +++ b/lua/ssr/search.lua @@ -2,10 +2,11 @@ local api = vim.api local ts = vim.treesitter local parsers = require "nvim-treesitter.parsers" local u = require "ssr.utils" -local wildcard_prefix = require("ssr.parse").wildcard_prefix local M = {} +M.wildcard_prefix = "__ssr_var_" + ---@class Match ---@field range ExtmarkRange ---@field captures ExtmarkRange[] @@ -41,6 +42,33 @@ function ExtmarkRange:get() return extmark[1], extmark[2], extmark[3].end_row, extmark[3].end_col end +-- Compare if two captured trees can match. +-- The check is loose because users want to match different types of node. +-- e.g. converting `{ foo: foo }` to shorthand `{ foo }`. +ts.query.add_predicate("ssr-tree-match?", function(match, _pattern, buf, pred) + ---@param node1 TSNode + ---@param node2 TSNode + ---@return boolean + local function tree_match(node1, node2) + if node1:named() ~= node2:named() then + return false + end + if node1:child_count() == 0 or node2:child_count() == 0 then + return ts.get_node_text(node1, buf) == ts.get_node_text(node2, buf) + end + if node1:child_count() ~= node2:child_count() then + return false + end + for i = 0, node1:child_count() - 1 do + if not tree_match(node1:child(i), node2:child(i)) then + return false + end + end + return true + end + return tree_match(match[pred[2]], match[pred[3]]) +end, true) + -- Build a TS sexpr represting the node. ---@param node TSNode ---@param source string @@ -54,11 +82,19 @@ local function build_sexpr(node, source) local text = ts.get_node_text(node, source) -- Special identifier __ssr_var_name is a named wildcard. - local var = text:match("^" .. wildcard_prefix .. "([_%a%d]+)$") + -- Handle this early to make sure wildcard captures largest node. + local var = text:match("^" .. M.wildcard_prefix .. "([_%a%d]+)$") if var then - wildcards[var] = next_idx - next_idx = next_idx + 1 - return "(_) @" .. var + if not wildcards[var] then + wildcards[var] = next_idx + next_idx = next_idx + 1 + return "(_) @" .. var + else + -- Same wildcard should match the same subtree. + local sexpr = string.format("(_) @_%d (#ssr-tree-match? @_%d @%s)", next_idx, next_idx, var) + next_idx = next_idx + 1 + return sexpr + end end -- Leaf nodes (keyword, identifier, literal and symbol) should match text. diff --git a/tests/ssr_spec.lua b/tests/ssr_spec.lua index f2b3701..2cd3315 100644 --- a/tests/ssr_spec.lua +++ b/tests/ssr_spec.lua @@ -138,10 +138,9 @@ foo($a, $b) ==>> ($a).foo($b) String::from((y + 5).foo(z)) ]] -t [[ go parsed correctly +t [[ go parse Go := in function func main() { - - print(commit) + } ==== $a, _ := os.LookupEnv($b) @@ -149,11 +148,79 @@ $a, _ := os.LookupEnv($b) $a := os.Getenv($b) ==== func main() { - commit := os.Getenv("GITHUB_SHA") - print(commit) + commit := os.Getenv("GITHUB_SHA") } ]] +t [[ go match Go if err +fn main() { + +} +==== +if err != nil { panic(err) } ==>> x +==== +fn main() { + x +} +]] + +t [[ rust reused wildcard: compound assignments +; +bar = foo + idx; +*foo.bar() = * foo . bar () + 1; +(foo + bar) = (foo + bar) + 1; +(foo + bar) = (foo - bar) + 1; +==== +$a = $a + $b ==>> $a += $b +==== +idx += 1; +bar = foo + idx; +*foo.bar() += 1; +(foo + bar) += 1; +(foo + bar) = (foo - bar) + 1; +]] + +t [[ python reused wildcard: indent +def f(): + +==== +if $foo: + if $foo: + $body +==>> +if $foo: + $body +==== +def f(): + if await foo.bar(baz): + pass +]] + +-- two `foo`s have different type: `property_identifier` and `identifier` +t [[ javascript reused wildcard: match different node types 1 +<{ foo: foo }> +{ foo: bar } +==== +{ $a: $a } ==>> { $a } +==== +{ foo } +{ foo: bar } +]] + +t [[ lua reused wildcard: match different node types 2 + +local a = vim.api +==== +local $a = vim.$a ==>> x +==== +x +local a = vim.api +]] + describe("", function() for _, s in ipairs(tests) do local ft, desc, content, pattern, template, expected =