diff --git a/lua/otter/keeper.lua b/lua/otter/keeper.lua index 2bacc6b..77f574f 100644 --- a/lua/otter/keeper.lua +++ b/lua/otter/keeper.lua @@ -2,6 +2,7 @@ local M = {} local fn = require("otter.tools.functions") local extensions = require("otter.tools.extensions") +local treesitter_iterator = require("otter.tools.treesitter_iterator") local api = vim.api local ts = vim.treesitter @@ -50,11 +51,10 @@ M.extract_code_chunks = function(main_nr, lang, exclude_eval_false, row_from, ro local code_chunks = {} local lang_capture = nil - for id, node, metadata in query:iter_captures(root, main_nr) do + for id, node, metadata in treesitter_iterator.iter_captures(root, main_nr, query) do local name = query.captures[id] local text local was_stripped - lang_capture = determine_language(main_nr, name, node, metadata, lang_capture) if lang_capture @@ -135,7 +135,7 @@ M.get_current_language_context = function(main_nr) local tree = parser:parse() local root = tree[1]:root() local lang_capture = nil - for id, node, metadata in query:iter_captures(root, main_nr) do + for id, node, metadata in treesitter_iterator.iter_captures(root, main_nr, query) do local name = query.captures[id] lang_capture = determine_language(main_nr, name, node, metadata, lang_capture) @@ -377,7 +377,7 @@ M.get_language_lines_around_cursor = function() local tree = parser:parse() local root = tree[1]:root() - for id, node, metadata in query:iter_captures(root, main_nr) do + for id, node, metadata in treesitter_iterator.iter_captures(root, main_nr, query) do local name = query.captures[id] if name == "content" then if ts.is_in_node_range(node, row, col) then diff --git a/lua/otter/tools/treesitter_iterator.lua b/lua/otter/tools/treesitter_iterator.lua new file mode 100644 index 0000000..ed82e36 --- /dev/null +++ b/lua/otter/tools/treesitter_iterator.lua @@ -0,0 +1,31 @@ +M = {} + +M.iter_captures = function (node, source, query) + if type(source) == "number" and source == 0 then + source = vim.api.nvim_get_current_buf() + end + + local raw_iter = node:_rawquery(query.query, true, 0, -1) + local metadata = {} + local function iter(end_line) + local capture, captured_node, match = raw_iter() + + if match ~= nil then + local active = query:match_preds(match, match.pattern, source) + match.active = active + if not active then + return iter(end_line) -- tail call: try next match + else + -- if it has an active match, reset the metadata. + -- then hopefully apply_directives can fill the metadata + metadata = {} + end + query:apply_directives(match, match.pattern, source, metadata) + end + return capture, captured_node, metadata + end + return iter +end + + +return M