Skip to content

Commit

Permalink
Fix REL label issue in prompt generation (#367)
Browse files Browse the repository at this point in the history
* Fix REL label issue.

* Remove .issue mark.
  • Loading branch information
rmitsch committed Nov 13, 2023
1 parent 7687d44 commit 549681c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
4 changes: 3 additions & 1 deletion spacy_llm/tasks/rel/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def _preannotate(doc: Union[Doc, RELExample]) -> str:
for i, ent in enumerate(doc.ents):
end = ent.end_char
before, after = text[: end + offset], text[end + offset :]
annotation = f"[ENT{i}:{ent.label}]"
annotation = (
f"[ENT{i}:{ent.label if isinstance(doc, RELExample) else ent.label_}]"
)
offset += len(annotation)
text = f"{before}{annotation}{after}"

Expand Down
20 changes: 20 additions & 0 deletions spacy_llm/tests/tasks/test_rel.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,23 @@ def test_incorrect_indexing():
)
== 0
)


@pytest.mark.external
@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available")
def test_labels_in_prompt(request: FixtureRequest):
"""See https://github.com/explosion/spacy-llm/issues/366."""
config = Config().from_str(request.getfixturevalue("zeroshot_cfg_string"))
config["components"].pop("ner")
config.pop("initialize")
config["nlp"]["pipeline"] = ["llm"]
config["components"]["llm"]["task"]["labels"] = ["A", "B", "C"]
nlp = assemble_from_config(config)

doc = Doc(get_lang_class("en")().vocab, words=["Well", "hello", "there"])
doc.ents = [Span(doc, 0, 1, "A"), Span(doc, 1, 2, "B"), Span(doc, 2, 3, "C")]

assert (
"Well[ENT0:A] hello[ENT1:B] there[ENT2:C]"
in list(nlp.get_pipe("llm")._task.generate_prompts([doc]))[0]
)

0 comments on commit 549681c

Please sign in to comment.