Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix displacy span stacking #13068

Merged
merged 5 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions spacy/displacy/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,25 @@ def render_spans(
spans (list): Individual entity spans and their start, end, label, kb_id and kb_url.
title (str / None): Document title set in Doc.user_data['title'].
"""
per_token_info = []
per_token_info = self._assemble_per_token_info(tokens, spans)
markup = self._render_markup(per_token_info)
markup = TPL_SPANS.format(content=markup, dir=self.direction)
if title:
markup = TPL_TITLE.format(title=title) + markup
return markup

@staticmethod
def _assemble_per_token_info(
tokens: List[str], spans: List[Dict[str, Any]]
) -> List[Dict[str, List[Dict[str, Any]]]]:
"""Assembles token info used to generate markup in render_spans().
tokens (List[str]): Tokens in text.
spans (List[Dict[str, Any]]): Spans in text.
RETURNS (List[Dict[str, List[Dict, str, Any]]]): Per token info needed to render HTML markup for given tokens
and spans.
"""
per_token_info: List[Dict[str, List[Dict[str, Any]]]] = []

# we must sort so that we can correctly describe when spans need to "stack"
# which is determined by their start token, then span length (longer spans on top),
# then break any remaining ties with the span label
Expand All @@ -154,29 +172,35 @@ def render_spans(
s["label"],
),
)

for s in spans:
# this is the vertical 'slot' that the span will be rendered in
# vertical_position = span_label_offset + (offset_step * (slot - 1))
s["render_slot"] = 0

for idx, token in enumerate(tokens):
# Identify if a token belongs to a Span (and which) and if it's a
# start token of said Span. We'll use this for the final HTML render
token_markup: Dict[str, Any] = {}
token_markup["text"] = token
concurrent_spans = 0
intersecting_spans: List[Dict[str, Any]] = []
entities = []
for span in spans:
ent = {}
if span["start_token"] <= idx < span["end_token"]:
concurrent_spans += 1
span_start = idx == span["start_token"]
ent["label"] = span["label"]
ent["is_start"] = span_start
if span_start:
# When the span starts, we need to know how many other
# spans are on the 'span stack' and will be rendered.
# This value becomes the vertical render slot for this entire span
span["render_slot"] = concurrent_spans
span["render_slot"] = (
intersecting_spans[-1]["render_slot"]
if len(intersecting_spans)
else 0
) + 1
intersecting_spans.append(span)
ent["render_slot"] = span["render_slot"]
kb_id = span.get("kb_id", "")
kb_url = span.get("kb_url", "#")
Expand All @@ -193,11 +217,8 @@ def render_spans(
span["render_slot"] = 0
token_markup["entities"] = entities
per_token_info.append(token_markup)
markup = self._render_markup(per_token_info)
markup = TPL_SPANS.format(content=markup, dir=self.direction)
if title:
markup = TPL_TITLE.format(title=title) + markup
return markup

return per_token_info

def _render_markup(self, per_token_info: List[Dict[str, Any]]) -> str:
"""Render the markup from per-token information"""
Expand Down
22 changes: 21 additions & 1 deletion spacy/tests/test_displacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from spacy import displacy
from spacy.displacy.render import DependencyRenderer, EntityRenderer
from spacy.displacy.render import DependencyRenderer, EntityRenderer, SpanRenderer
from spacy.lang.en import English
from spacy.lang.fa import Persian
from spacy.tokens import Doc, Span
Expand Down Expand Up @@ -468,3 +468,23 @@ def test_issue12816(en_vocab) -> None:
# Verify that the HTML tag is still escaped
html = displacy.render(doc, style="span")
assert "&lt;TEST&gt;" in html


@pytest.mark.issue(13056)
def test_displacy_span_stacking():
"""Test whether span stacking works properly for multiple overlapping spans."""
spans = [
{"start_token": 2, "end_token": 5, "label": "SkillNC"},
{"start_token": 0, "end_token": 2, "label": "Skill"},
{"start_token": 1, "end_token": 3, "label": "Skill"},
]
tokens = ["Welcome", "to", "the", "Bank", "of", "China", "."]
per_token_info = SpanRenderer._assemble_per_token_info(spans=spans, tokens=tokens)

assert len(per_token_info) == len(tokens)
assert all([len(per_token_info[i]["entities"]) == 1 for i in (0, 3, 4)])
assert all([len(per_token_info[i]["entities"]) == 2 for i in (1, 2)])
assert per_token_info[1]["entities"][0]["render_slot"] == 1
assert per_token_info[1]["entities"][1]["render_slot"] == 2
assert per_token_info[2]["entities"][0]["render_slot"] == 2
assert per_token_info[2]["entities"][1]["render_slot"] == 3