Skip to content

Commit

Permalink
feat: Support context recognition for injected languages
Browse files Browse the repository at this point in the history
  • Loading branch information
kwaszczuk authored Feb 16, 2024
1 parent 32bbb21 commit 0a95d47
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 56 deletions.
127 changes: 72 additions & 55 deletions lua/treesitter-context/context.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,13 @@ local get_lang = vim.treesitter.language.get_lang or require('nvim-treesitter.pa
--- @diagnostic disable-next-line:deprecated
local get_query = vim.treesitter.query.get or vim.treesitter.query.get_query

---@param bufnr integer
---@param row integer
---@param col integer
---@return TSNode?
local function get_node(bufnr, row, col)
local root_tree = vim.treesitter.get_parser(bufnr)
if not root_tree then
return
end

return root_tree:named_node_for_range({ row, col, row, col + 1 })
end
--- @param langtree LanguageTree
--- @param range Range4
--- @return TSNode[]?
local function get_parent_nodes(langtree, range)
local tree = langtree:tree_for_range(range, { ignore_injections = true })
local n = tree:root():named_descendant_for_range(unpack(range))

--- @param node TSNode
--- @return TSNode[]
local function get_parent_nodes(node)
local n = node --- @type TSNode?
local ret = {} --- @type TSNode[]
while n do
ret[#ret + 1] = n
Expand Down Expand Up @@ -108,12 +98,9 @@ local context_range = cache.memoize(function(node, query)
end
end, hash_node)

---@param bufnr integer
---@param lang string
---@return Query?
local function get_context_query(bufnr)
--- @type string
local lang = assert(get_lang(vim.bo[bufnr].filetype))

local function get_context_query(lang)
local ok, query = pcall(get_query, lang, 'context')

if not ok then
Expand Down Expand Up @@ -182,6 +169,37 @@ end

local M = {}

---@param bufnr integer
---@param row integer
---@param col integer
---@return LanguageTree[]
local function get_parent_langtrees(bufnr, range)
local root_tree = vim.treesitter.get_parser(bufnr)
if not root_tree then
return {}
end

local parent_langtrees = {root_tree}

while true do
local child_langtree = nil

for _, langtree in pairs(parent_langtrees[#parent_langtrees]:children()) do
if langtree:contains(range) then
child_langtree = langtree
break
end
end

if child_langtree == nil then
break
end
parent_langtrees[#parent_langtrees + 1] = child_langtree
end

return parent_langtrees
end

--- @param bufnr integer
--- @param winid integer
--- @return Range4[]?, string[]?
Expand All @@ -196,12 +214,6 @@ function M.get(bufnr, winid)
return
end

local query = get_context_query(bufnr)

if not query then
return
end

local top_row = fn.line('w0', winid) - 1

--- @type integer, integer
Expand All @@ -220,40 +232,45 @@ function M.get(bufnr, winid)

for offset = 0, max_lines do
local node_row = row + offset

local node = get_node(bufnr, node_row, offset == 0 and col or 0)
if not node then
return
end

local parents = get_parent_nodes(node)
local col0 = offset == 0 and col or 0
local line_range = { node_row, col0, node_row, col0 + 1 }

context_ranges = {}
context_lines = {}
contexts_height = 0

for i = #parents, 1, -1 do
local parent = parents[i]
local parent_start_row = parent:range()

local contexts_end_row = top_row + math.min(max_lines, contexts_height)
-- Only process the parent if it is not in view.
if parent_start_row < contexts_end_row then
local range0 = context_range(parent, query)
if range0 then
local range, lines = get_text_for_range(range0)

local last_context = context_ranges[#context_ranges]
if last_context and parent_start_row == last_context[1] then
-- If there are multiple contexts on the same row, then prefer the inner
contexts_height = contexts_height - util.get_range_height(last_context)
context_ranges[#context_ranges] = nil
context_lines[#context_lines] = nil
end
local parent_trees = get_parent_langtrees(bufnr, line_range)
for i = 1, #parent_trees, 1 do
local langtree = parent_trees[i]
local query = get_context_query(langtree:lang())
if not query then
return
end

contexts_height = contexts_height + util.get_range_height(range)
context_ranges[#context_ranges + 1] = range
context_lines[#context_lines + 1] = lines
local parents = get_parent_nodes(langtree, line_range)
for j = #parents, 1, -1 do
local parent = parents[j]
local parent_start_row = parent:range()

local contexts_end_row = top_row + math.min(max_lines, contexts_height)
-- Only process the parent if it is not in view.
if parent_start_row < contexts_end_row then
local range0 = context_range(parent, query)
if range0 then
local range, lines = get_text_for_range(range0)

local last_context = context_ranges[#context_ranges]
if last_context and parent_start_row == last_context[1] then
-- If there are multiple contexts on the same row, then prefer the inner
contexts_height = contexts_height - util.get_range_height(last_context)
context_ranges[#context_ranges] = nil
context_lines[#context_lines] = nil
end

contexts_height = contexts_height + util.get_range_height(range)
context_ranges[#context_ranges + 1] = range
context_lines[#context_lines + 1] = lines
end
end
end
end
Expand Down
45 changes: 45 additions & 0 deletions test/test.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
```html
<html>
<body>






<script>
function test() {
if test != "" {
}
}
</script>
</body>
</html>
```
45 changes: 44 additions & 1 deletion test/ts_context_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,17 @@ describe('ts_context', function()
"lua",
"rust",
"cpp",
"typescript"
"typescript",
"markdown",
"html",
"javascript",
},
sync_install = true,
}
]]
-- Required for the proper Markdown support
exec_lua [[require'nvim-treesitter'.setup()]]

cmd [[let $XDG_CACHE_HOME='scratch/cache']]
cmd [[set packpath=]]
cmd('syntax enable')
Expand Down Expand Up @@ -329,6 +335,43 @@ describe('ts_context', function()
]]}
end)

it('markdown', function()
cmd('edit test/test.md')
exec_lua [[vim.treesitter.start()]]

feed'3<C-e>'
screen:expect{grid=[[
{14:<html>}{2: }|
{2: }{14:<body>}{2: }|
|*3
^ |
{15:<script>} |
|*9
]]}

feed'5<C-e>'
screen:expect{grid=[[
{14:<html>}{2: }|
{2: }{14:<body>}{2: }|
{2: }{14:<script>}{2: }|
|*2
^ |
|*8
{4:function} {5:test}{15:()} {15:{} |
|
]]}

feed'12<C-e>'
screen:expect{grid=[[
{14:<html>}{2: }|
{2: }{14:<body>}{2: }|
{2: }{14:<script>}{2: }|
{2: }{1:function}{2: }{3:test}{14:()}{2: }{14:{}{2: }|
{2: }{1:if}{2: }{3:test}{2: }{1:!=}{2: }{10:""}{2: }{14:{}{2: }|
^ |
|*10
]]}
end)
end)

end)

0 comments on commit 0a95d47

Please sign in to comment.