PromptNER based Chain-of-Thought prompting for span tasks #180

Merged
merged 108 commits into from
Aug 25, 2023
Changes from 15 commits
Commits
108 commits
a175d36
initial POC for Chain of Thought NER task
Jun 15, 2023
820337b
ruff fix
Jun 15, 2023
ea261aa
Merge branch 'main' of ssh://github.com/explosion/spacy-llm into kab/…
Jun 20, 2023
6fa0b6a
Merge branch 'develop' of ssh://github.com/explosion/spacy-llm into k…
Jul 4, 2023
4c562d9
update template
vinbo8 Jul 5, 2023
54a9eae
consolidate approach to work with main SpanTask
Jul 5, 2023
8737f4a
Merge branch 'develop' of ssh://github.com/explosion/spacy-llm into k…
Jul 5, 2023
7a1fbdb
fix tests around label consistency checks
Jul 5, 2023
2921ca9
Merge branch 'main' into kab/cot-ner
vinbo8 Jul 6, 2023
2a0b1c8
merge kab/cot-ner-integrate
vinbo8 Jul 6, 2023
4772cb2
fix edge cases
vinbo8 Jul 6, 2023
d95eaac
merge develop
vinbo8 Jul 6, 2023
57a71b3
update label consistency checks
vinbo8 Jul 6, 2023
6335b3b
move label consistency checks
vinbo8 Jul 6, 2023
c565256
handle labels in span.py
vinbo8 Jul 6, 2023
54f28d2
cleanup older NER
vinbo8 Jul 17, 2023
aa25bba
fixes
vinbo8 Jul 17, 2023
f52feb5
cleanup
vinbo8 Jul 17, 2023
19d55cc
update NER template with label_definitions + initial description, fix…
Jul 21, 2023
6edea90
properly parametrize response parsing for SpanReason test
Jul 21, 2023
9f08287
start to parametrize NER tests properly with new v3 template
Jul 21, 2023
e4f8195
fix docstring
Jul 21, 2023
cb64f6a
rm single_match since it's always true now
Jul 21, 2023
a8c88bf
rm single_match since it's always true now
Jul 21, 2023
3e34add
fix typing of description and default properly for SpanCatTask
Jul 21, 2023
c40b897
fix test
Jul 21, 2023
fc3455c
fix NER tests
Jul 24, 2023
36579f4
fix more ner tests + add initial test for SpanReason.from_str
Jul 24, 2023
e0f2ef0
fix ner to_disk test
Jul 25, 2023
5a4b681
enable adding ner prompt examples from initialize and fix ner_init test
Jul 25, 2023
18e5e10
test fixes
Jul 25, 2023
eddac8c
add yaml/jsonl version of ner examples. Fix inconsistent labels tests
Jul 25, 2023
1299787
use yaml/jsonl versions of ner examples
Jul 25, 2023
b0a076d
actually check scoring with real LLM call
Jul 25, 2023
b432a50
rename format_response to extract_span_reasons
Jul 25, 2023
ee64727
move Self to compat types
Jul 25, 2023
01e1de9
fix test for serde
Jul 26, 2023
8c31b05
Self only in 3.10+
Jul 27, 2023
6235091
Self only in 3.11+ actually
Jul 27, 2023
66c7377
ner test fixes
Jul 27, 2023
8e5952c
convert spancat to new span task format
Jul 27, 2023
3fa3fbe
add better doc for SpanReason.to_str
Jul 27, 2023
7beaf93
fixing tests for spancat
Jul 27, 2023
0e162d6
adjust span matching by adding a setting
Jul 31, 2023
6894a80
support conditional allow_overlap like standard spancat
Jul 31, 2023
bb261bd
remove dict | operator that only works in python3.9 +
Jul 31, 2023
dfa96c0
disable test for now so CI passes
Jul 31, 2023
1700849
revert spanreason start_char
Jul 31, 2023
ae3a545
fix spancat template rendering for allow_overlap
Jul 31, 2023
850e8ee
clean up tests for init with spacy examples
Jul 31, 2023
7242d6e
fix spancat test?
Jul 31, 2023
150d21e
make spancat scoring use external model, not weird dummy data
Jul 31, 2023
abd52b8
run case sensitive matching then fallback to case insensitive if the …
Jul 31, 2023
bde5c8d
rm prev_span reference in parsing
Jul 31, 2023
c86b0e9
separate span parsing for a single doc into its own function
Jul 31, 2023
91bffaf
fix typing on the regression test
Jul 31, 2023
d02b425
add description field to cfg_keys so it gets serialized
Jul 31, 2023
17a2a9a
add old spancat/ner versions to tasks.legacy module
Aug 1, 2023
16f5d04
add deprecation warnings + test deprecation warnings
Aug 1, 2023
7cab77f
update usage examples
Aug 1, 2023
b1a4adf
update examples and readme
Aug 1, 2023
8b40100
Merge branch 'main' of ssh://github.com/explosion/spacy-llm into kab/…
Aug 1, 2023
574c9fe
Merge branch 'kab/cot-ner' of ssh://github.com/explosion/spacy-llm in…
Aug 1, 2023
3ab56e5
fix usage_examples
Aug 1, 2023
46802bb
Merge branch 'kab/cot-ner' of ssh://github.com/explosion/spacy-llm in…
Aug 1, 2023
629dc3d
Merge pull request #239 from explosion/kab/cot-ner-legacy
Aug 2, 2023
12093bf
Merge branch 'develop' of ssh://github.com/explosion/spacy-llm into k…
Aug 2, 2023
f824a8e
Merge branch 'kab/cot-ner' of ssh://github.com/explosion/spacy-llm in…
Aug 2, 2023
09533e0
rename warning to LLMW001, fix usage_example + readme tests
Aug 2, 2023
a6888d4
remove separate case sensitive match step before doing case insensiti…
Aug 2, 2023
d43c840
fix regression test to have 3 ents
Aug 2, 2023
b42e849
fix incremental parsing
Aug 2, 2023
47c8a75
rm extra docstring stuff
Aug 2, 2023
edbc17d
rm extra test
Aug 3, 2023
990edf3
consolidate new spans template and ensure valid labels appear in the …
Aug 4, 2023
f77b0cd
make prompt_examples required since it's required in confection facto…
Aug 4, 2023
e87f1a4
fix template rendering tests with new optional description and defaul…
Aug 4, 2023
03ab495
Merge branch 'develop' of ssh://github.com/explosion/spacy-llm into k…
Aug 8, 2023
d1b1410
Remove extra deprecation warning
Aug 8, 2023
208dbd6
Update usage_examples/ner_v3_openai/README.md
Aug 8, 2023
20be25a
Sync with new task structure.
rmitsch Aug 21, 2023
e6f63f5
Fix 3.6 Protocol import.
rmitsch Aug 21, 2023
af04a26
Update ignored warnings.
rmitsch Aug 21, 2023
8a95115
Fix filterwarnings.
rmitsch Aug 21, 2023
80fc6c8
Update filterwarnings.
rmitsch Aug 21, 2023
74a8486
Renamed examples to prompt_examples.
rmitsch Aug 21, 2023
61a6ab7
Add default example if none are provided to COT NER/SpanCat tasks (#270)
Aug 24, 2023
fa4e879
Test Pydantic Mac OS Py 3.8 issue.
rmitsch Aug 24, 2023
5903a03
Incorporate feedback. Readd Pydantic REL example workaround.
rmitsch Aug 24, 2023
69f5111
Readd NER Dolly usage example, removed TextCat Dolly one.
rmitsch Aug 24, 2023
4706aa7
Update NER Dolly usage example to use NER.v3.
rmitsch Aug 24, 2023
134da76
Readd NER Dolly test. Revert to NER.v2. Refactor span extraction for …
rmitsch Aug 24, 2023
7d27d0e
Fix span reason extraction.
rmitsch Aug 24, 2023
2979ffc
Add working Paris-Paris-Paris example.
rmitsch Aug 24, 2023
398652a
Uncomment example for NER prediction test.
rmitsch Aug 24, 2023
b4ad4d8
Fix NER prediction test.
rmitsch Aug 24, 2023
66c4fae
remove errors class entirely
svlandeg Aug 24, 2023
88a2482
Merge branch 'kab/cot-ner' of github.com:explosion/spacy-llm into kab…
rmitsch Aug 25, 2023
2664a47
Update .github/workflows/test.yml
rmitsch Aug 25, 2023
ee7fcdd
Update spacy_llm/tests/tasks/test_ner.py
rmitsch Aug 25, 2023
b64bf90
Remove overlap part in NER template.
rmitsch Aug 25, 2023
f895343
Merge branch 'kab/cot-ner' of github.com:explosion/spacy-llm into kab…
rmitsch Aug 25, 2023
c763689
Remove overlap path in NER and SpanCat templates.
rmitsch Aug 25, 2023
d2f2a07
Update spacy_llm/tasks/spancat/registry.py
rmitsch Aug 25, 2023
48bbd3c
Update spacy_llm/tasks/ner/registry.py
rmitsch Aug 25, 2023
68c0a55
Merge branch 'develop' into kab/cot-ner
rmitsch Aug 25, 2023
1c269e0
Changed SpanCat prompt intro.
rmitsch Aug 25, 2023
3a31530
Add docstring info for description.
rmitsch Aug 25, 2023
156 changes: 152 additions & 4 deletions spacy_llm/tasks/ner.py
@@ -1,5 +1,6 @@
import warnings
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from spacy.language import Language
from spacy.scorer import get_ner_prf
@@ -9,13 +10,14 @@

from ..compat import Literal
from ..registry import registry
from ..ty import ExamplesConfigType
from ..ty import COTExamplesConfigType, ExamplesConfigType
from ..util import split_labels
from .span import SpanExample, SpanTask
from .span import COTSpanExample, SpanExample, SpanReason, SpanTask
from .templates import read_template

_DEFAULT_NER_TEMPLATE_V1 = read_template("ner")
_DEFAULT_NER_TEMPLATE_V2 = read_template("ner.v2")
_DEFAULT_NER_TEMPLATE_V3 = read_template("ner.v3")


@registry.llm_tasks("spacy.NER.v1")
@@ -107,7 +109,55 @@ def make_ner_task_v2(
)


class NERTask(SpanTask):
@registry.llm_tasks("spacy.NER.v3")
def make_ner_task_v3(
examples: COTExamplesConfigType,
description: str,
labels: Union[List[str], str] = [],
template: str = _DEFAULT_NER_TEMPLATE_V3,
label_definitions: Optional[Dict[str, str]] = None,
normalizer: Optional[Callable[[str], str]] = None,
alignment_mode: Literal["strict", "contract", "expand"] = "contract",
case_sensitive_matching: bool = False,
single_match: bool = False,
):
"""NER.v3 task factory.

examples (Union[Callable[[], Iterable[COTS]]]): Optional callable that
reads a file containing task examples for few-shot learning. If None is
passed, then zero-shot learning will be used.
labels (Union[str, List[str]]): List of labels to pass to the template,
either an actual list or a comma-separated string.
Leave empty to populate it at initialization time (only if examples are provided).
template (str): Prompt template passed to the model.
label_definitions (Optional[Dict[str, str]]): Map of label -> description
of the label to help the language model output the entities wanted.
It is usually easier to provide these definitions rather than
full examples, although both can be provided.
normalizer (Optional[Callable[[str], str]]): optional normalizer function.
alignment_mode (str): "strict", "contract" or "expand".
case_sensitive_matching (bool): Whether to search with case sensitivity.
single_match (bool): If False, allow one substring to match multiple times in
the text. If True, returns the first hit.
"""
labels_list = split_labels(labels)
raw_examples = examples() if callable(examples) else examples
span_examples = [COTSpanExample(**eg) for eg in raw_examples]

return COTNERTask(
labels=labels_list,
template=template,
description=description,
label_definitions=label_definitions,
prompt_examples=span_examples,
normalizer=normalizer,
alignment_mode=alignment_mode,
case_sensitive_matching=case_sensitive_matching,
single_match=single_match,
)


class NERTask(SpanTask[SpanExample]):
def __init__(
self,
labels: List[str] = [],
@@ -203,10 +253,108 @@ def scorer(
) -> Dict[str, Any]:
return get_ner_prf(examples)

@property
def _Example(self) -> type[SpanExample]:
return SpanExample

def _create_prompt_example(self, example: Example) -> SpanExample:
"""Create an NER prompt example from a spaCy example."""
entities = defaultdict(list)
for ent in example.reference.ents:
entities[ent.label_].append(ent.text)

return SpanExample(text=example.reference.text, entities=entities)


class COTNERTask(SpanTask[COTSpanExample]):
def __init__(
self,
labels: List[str],
template: str,
description: Optional[str] = None,
prompt_examples: Optional[List[COTSpanExample]] = None,
label_definitions: Optional[Dict[str, str]] = None,
normalizer: Optional[Callable[[str], str]] = None,
alignment_mode: Literal["strict", "contract", "expand"] = "contract",
case_sensitive_matching: bool = False,
single_match: bool = False,
):
super().__init__(
labels=labels,
template=template,
description=description,
label_definitions=label_definitions,
prompt_examples=prompt_examples,
normalizer=normalizer,
alignment_mode=alignment_mode,
case_sensitive_matching=case_sensitive_matching,
single_match=single_match,
)

def initialize(
self,
get_examples: Callable[[], Iterable["Example"]],
nlp: Language,
labels: List[str] = [],
**kwargs: Any,
) -> None:
"""Initialize the NER task by auto-discovering labels.

Labels can be set, in order of precedence, through:

- the `[initialize]` section of the pipeline configuration
- the `labels` argument supplied to the task factory
- the labels found in the examples

get_examples (Callable[[], Iterable["Example"]]): Callable that provides examples
for initialization.
nlp (Language): Language instance.
labels (List[str]): Optional list of labels.
"""

examples = get_examples()

if not labels:
labels = list(self._label_dict.values())

if not labels:
label_set = set()

for eg in examples:
for ent in eg.reference.ents:
label_set.add(ent.label_)
labels = list(label_set)

self._label_dict = {self._normalizer(label): label for label in labels}

def _format_response(self, response: str) -> Iterable[Tuple[str, Iterable[str]]]:
"""Parse raw string response into a structured format"""
output: Dict[str, List[str]] = defaultdict(list)
assert self._normalizer is not None
for line in response.strip().split("\n"):
entity = SpanReason.from_str(line)
if entity:
norm_label = self._normalizer(entity.label)
if norm_label not in self._label_dict:
continue
label = self._label_dict[norm_label]
output[label].append(entity.text)
return output.items()

def assign_spans(
self,
doc: Doc,
spans: List[Span],
) -> None:
"""Assign spans to the document."""
doc.set_ents(filter_spans(spans))

def scorer(
self,
examples: Iterable[Example],
) -> Dict[str, Any]:
return get_ner_prf(examples)

@property
def _Example(self) -> type[COTSpanExample]:
return COTSpanExample
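
The `_format_response` method above is the heart of the new parsing path: each response line is stripped of its enumeration prefix, split on `|`, and grouped under the canonical label. A dependency-free sketch of that logic (simplified names; not the actual spacy-llm code, which delegates the line parsing to `SpanReason.from_str`):

```python
from collections import defaultdict

def parse_cot_response(response, label_dict, normalizer=str.lower):
    """Group entity texts by canonical label, dropping undeclared labels."""
    output = defaultdict(list)
    for line in response.strip().split("\n"):
        # Strip the leading "1." enumeration, then split on the separator.
        clean = line.strip()
        if "." in clean:
            clean = clean.split(".", maxsplit=1)[1]
        parts = [p.strip() for p in clean.split("|")]
        if len(parts) != 4:
            continue  # malformed line: skip rather than fail
        text, is_entity, label, _reason = parts
        if is_entity.lower() != "true":
            continue  # the model decided this candidate is not an entity
        norm = normalizer(label)
        if norm not in label_dict:
            continue  # label was never declared for this task
        output[label_dict[norm]].append(text)
    return dict(output)

response = """\
1. Paris | True | LOC | is the name of a city
2. talked | False | ==NONE== | is a verb, not an entity"""
print(parse_cot_response(response, {"loc": "LOC"}))
# → {'LOC': ['Paris']}
```

Note how the `is_entity` flag lets the chain-of-thought prompt keep rejected candidates in the transcript without polluting the extracted spans.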
106 changes: 82 additions & 24 deletions spacy_llm/tasks/span.py
@@ -1,5 +1,6 @@
import typing
import warnings
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type
from typing import Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Type

import jinja2
from pydantic import BaseModel
Expand All @@ -11,20 +12,56 @@
from .util.serialization import SerializableTask


class SpanReason(BaseModel):
text: str
is_entity: bool
label: str
reason: str

@classmethod
def from_str(cls, s: str, sep: str = "|"):
clean_str = s.strip()
if "." in clean_str:
clean_str = clean_str.split(".", maxsplit=1)[1]
components = [c.strip() for c in clean_str.split(sep)]
if len(components) == 4:
return cls(
text=components[0],
is_entity=components[1].lower() == "true",
label=components[2],
reason=components[3],
)

def __str__(self) -> str:
return self.to_str()

def to_str(self) -> str:
return f"{self.text} | {self.is_entity} | {self.label} | {self.reason}"


class SpanExample(BaseModel):
text: str
entities: Dict[str, List[str]]


class SpanTask(SerializableTask[SpanExample]):
class COTSpanExample(BaseModel):
text: str
entities: List[SpanReason]


_PromptExampleT = TypeVar("_PromptExampleT", SpanExample, COTSpanExample)


class SpanTask(SerializableTask[_PromptExampleT]):
"""Base class for Span-related tasks, eg NER and SpanCat."""

def __init__(
self,
labels: List[str],
template: str,
label_definitions: Optional[Dict[str, str]] = {},
prompt_examples: Optional[List[SpanExample]] = None,
description: Optional[str] = None,
label_definitions: Optional[Dict[str, str]] = None,
prompt_examples: Optional[List[_PromptExampleT]] = None,
normalizer: Optional[Callable[[str], str]] = None,
alignment_mode: Literal[
"strict", "contract", "expand" # noqa: F821
@@ -37,6 +74,7 @@ def __init__(
self._normalizer(label): label for label in sorted(set(labels))
}
self._template = template
self._description = description
self._label_definitions = label_definitions
self._prompt_examples = prompt_examples or []
self._validate_alignment(alignment_mode)
@@ -47,16 +85,33 @@ def __init__(
if self._prompt_examples:
self._prompt_examples = self._check_label_consistency()

def _check_label_consistency(self) -> List[SpanExample]:
@property
def labels(self) -> Tuple[str, ...]:
return tuple(self._label_dict.values())

@property
def prompt_template(self) -> str:
return self._template

def _check_label_consistency(self) -> List[_PromptExampleT]:
"""Checks consistency of labels between examples and defined labels. Emits warning on inconsistency.
RETURNS (List[_PromptExampleT]): List of examples with valid labels.
"""
assert self._prompt_examples
example_labels = {
self._normalizer(key): key
for example in self._prompt_examples
for key in example.entities
}
if isinstance(self._prompt_examples[0], SpanExample):
example_labels = {
self._normalizer(key): key
for example in self._prompt_examples
for key in example.entities
}
else:
example_labels = {
self._normalizer(key.label): key.label
for example in self._prompt_examples
for key in example.entities
if key.is_entity
}

unspecified_labels = {
example_labels[key]
for key in (set(example_labels.keys()) - set(self._label_dict.keys()))
@@ -70,36 +125,39 @@ def _check_label_consistency(
)

# Return examples without non-declared labels. If an example only has undeclared labels, it is discarded.
return [
example
for example in [
SpanExample(
examples = []
for example in self._prompt_examples:
if isinstance(self._prompt_examples[0], SpanExample):
span_example = SpanExample(
text=example.text,
entities={
label: entities
for label, entities in example.entities.items()
if self._normalizer(label) in self._label_dict
},
)
for example in self._prompt_examples
]
if len(example.entities)
]
else:
span_example = COTSpanExample(
text=example.text,
entities=[
entity
for entity in example.entities
if self._normalizer(entity.label) in self._label_dict
],
)

@property
def labels(self) -> Tuple[str, ...]:
return tuple(self._label_dict.values())
if len(span_example.entities):
examples.append(span_example)

@property
def prompt_template(self) -> str:
return self._template
return examples

def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]:
environment = jinja2.Environment()
_template = environment.from_string(self._template)
for doc in docs:
prompt = _template.render(
text=doc.text,
description=self._description,
labels=list(self._label_dict.values()),
label_definitions=self._label_definitions,
examples=self._prompt_examples,
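
The `SpanReason` model added in span.py round-trips between the pipe-separated line format used in prompts and a structured record. A dependency-free sketch of that round trip (a dataclass stands in for the pydantic `BaseModel`, so validation is omitted):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SpanReasonSketch:
    # Mirrors the fields of the pydantic SpanReason model.
    text: str
    is_entity: bool
    label: str
    reason: str

    @classmethod
    def from_str(cls, s: str, sep: str = "|") -> Optional["SpanReasonSketch"]:
        clean = s.strip()
        if "." in clean:
            # Drop the "3." enumeration prefix the LLM emits per line.
            clean = clean.split(".", maxsplit=1)[1]
        parts = [c.strip() for c in clean.split(sep)]
        if len(parts) == 4:
            return cls(parts[0], parts[1].lower() == "true", parts[2], parts[3])
        return None  # malformed lines are skipped by the caller

    def to_str(self) -> str:
        return f"{self.text} | {self.is_entity} | {self.label} | {self.reason}"

span = SpanReasonSketch.from_str("1. Loch Ness | True | LOC | is a lake in Scotland")
print(span.to_str())
# → Loch Ness | True | LOC | is a lake in Scotland
```

Returning `None` for malformed lines (rather than raising) is what lets `_format_response` tolerate the occasional off-format line from the LLM.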
4 changes: 4 additions & 0 deletions spacy_llm/tasks/spancat.py
@@ -218,6 +218,10 @@ def _cfg_keys(self) -> List[str]:
"_single_match",
]

@property
def _Example(self) -> type[SpanExample]:
return SpanExample

def _create_prompt_example(self, example: Example) -> SpanExample:
"""Create a spancat prompt example from a spaCy example."""
entities = defaultdict(list)
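
The `_create_prompt_example` hook added here converts a gold-annotated spaCy `Example` into the label-to-texts mapping that `SpanExample` expects. A minimal sketch of that grouping step, with a plain list of `(text, label)` pairs standing in for `example.reference.ents`:

```python
from collections import defaultdict

def create_prompt_example(text, ents):
    """ents: list of (entity_text, label) pairs from the reference doc."""
    entities = defaultdict(list)
    for ent_text, label in ents:
        # Group surface forms under their label, preserving document order.
        entities[label].append(ent_text)
    return {"text": text, "entities": dict(entities)}

eg = create_prompt_example(
    "Sriracha sauce goes really well with hoisin",
    [("Sriracha sauce", "INGREDIENT"), ("hoisin", "INGREDIENT")],
)
print(eg["entities"])
# → {'INGREDIENT': ['Sriracha sauce', 'hoisin']}
```

This is what makes `initialize`-time few-shot bootstrapping (commit 5a4b681) possible: any corpus of annotated `Example`s can be folded into prompt examples.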
19 changes: 19 additions & 0 deletions spacy_llm/tasks/templates/ner.v3.jinja
@@ -0,0 +1,19 @@
{{ description }}
{# whitespace #}
{# whitespace #}
Q: Given the paragraph below, identify a list of possible entities, and for each entry explain why it is or is not an entity:
{# whitespace #}
{# whitespace #}
{%- for example in examples -%}
Paragraph: {{ example.text }}
Answer:
{# whitespace #}
{%- for span in example.entities -%}
{{ loop.index }}. {{ span.to_str() }}
{# whitespace #}
{%- endfor -%}
{# whitespace #}
{# whitespace #}
{%- endfor -%}
Paragraph: {{ text }}
Answer:
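
The `ner.v3.jinja` template above alternates few-shot paragraphs with enumerated `SpanReason` lines before appending the target text. A rough plain-Python rendering of the same layout (no jinja2; the structure is approximated from the template, and the description and example data are placeholders):

```python
def render_ner_v3_prompt(description, examples, text):
    """examples: list of (paragraph, [span_reason_str, ...]) pairs."""
    lines = [description, ""]
    lines.append(
        "Q: Given the paragraph below, identify a list of possible entities, "
        "and for each entry explain why it is or is not an entity:"
    )
    lines.append("")
    for paragraph, spans in examples:
        lines.append(f"Paragraph: {paragraph}")
        lines.append("Answer:")
        # Each span line is enumerated, matching {{ loop.index }} in the template.
        for i, span in enumerate(spans, start=1):
            lines.append(f"{i}. {span}")
        lines.append("")
    # The target paragraph comes last, with a trailing "Answer:" for the LLM.
    lines.append(f"Paragraph: {text}")
    lines.append("Answer:")
    return "\n".join(lines)

prompt = render_ner_v3_prompt(
    "Entities are dishes, ingredients, and cooking equipment.",
    [("I like kimchi.", ["kimchi | True | INGREDIENT | is a pickled dish"])],
    "Sriracha goes well with hoisin.",
)
print(prompt)
```

Ending the prompt on `Answer:` is what steers the model into emitting the same enumerated pipe-separated lines the few-shot examples demonstrate, which `SpanReason.from_str` then parses back out.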