-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
defa940
commit 3072cb7
Showing
5 changed files
with
252 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from cohere.response import CohereObject | ||
from typing import List | ||
|
||
|
||
class Entity: | ||
''' | ||
Entity represents a single extracted entity from a text. An entity has a | ||
type and a value. For the text "I am a plumber", an extracted entity could be | ||
of type "profession" with the value "plumber". | ||
''' | ||
|
||
def __init__(self, type: str, value: str) -> None: | ||
self.type = type | ||
self.value = value | ||
|
||
def toDict(self) -> dict: | ||
return {"type": self.type, "value": self.value} | ||
|
||
def __str__(self) -> str: | ||
return f"{self.type}: {self.value}" | ||
|
||
def __repr__(self) -> str: | ||
return str(self.toDict()) | ||
|
||
def __eq__(self, other) -> bool: | ||
return self.type == other.type and self.value == other.value | ||
|
||
|
||
class Example: | ||
''' | ||
Example represents a sample extraction from a text, to be provided to the model. An Example | ||
contains the input text and a list of entities extracted from the text. | ||
>>> example = Example("I am a plumber", [Entity("profession", "plumber")]) | ||
>>> example = Example("Joe is a teacher", [ | ||
Entity("name", "Joe"), Entity("profession", "teacher") | ||
]) | ||
''' | ||
|
||
def __init__(self, text: str, entities: List[Entity]) -> None: | ||
self.text = text | ||
self.entities = entities | ||
|
||
def toDict(self): | ||
return {"text": self.text, "entities": [entity.toDict() for entity in self.entities]} | ||
|
||
def __str__(self) -> str: | ||
return f"{self.text} -> {self.entities}" | ||
|
||
def __repr__(self) -> str: | ||
return str(self.toDict()) | ||
|
||
|
||
class Extraction: | ||
''' | ||
Represents the result of extracting entities from a single text input. An extraction | ||
contains the text input, the list of entities extracted from the text, and the id of the | ||
extraction. | ||
''' | ||
|
||
def __init__(self, id: str, text: str, entities: List[Entity]) -> None: | ||
self.id = id | ||
self.text = text | ||
self.entities = entities | ||
|
||
def __repr__(self) -> str: | ||
return str(self.toDict()) | ||
|
||
def toDict(self) -> dict: | ||
return {"id": self.id, "text": self.text, "entities": [entity.toDict() for entity in self.entities]} | ||
|
||
|
||
class Extractions(CohereObject): | ||
''' | ||
Represents the main response of calling the Extract API. An Extractions is iterable and | ||
contains a list of of Extraction objects, one per text input. | ||
''' | ||
|
||
def __init__(self, extractions: List[Extraction]) -> None: | ||
self.extractions = extractions | ||
self.iterator = iter(extractions) | ||
|
||
def __iter__(self) -> iter: | ||
return self.iterator | ||
|
||
def __next__(self) -> next: | ||
return next(self.iterator) | ||
|
||
def __len__(self) -> int: | ||
return len(self.extractions) | ||
|
||
def __getitem__(self, index: int) -> Extraction: | ||
return self.extractions[index] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import unittest | ||
import cohere | ||
from cohere.extract import Example, Entity, Extractions | ||
from utils import get_api_key | ||
|
||
co = cohere.Client(get_api_key()) | ||
|
||
|
||
class TestExtract(unittest.TestCase): | ||
def test_success(self): | ||
examples = [Example( | ||
text="hello my name is John, and I like to play ping pong", | ||
entities=[Entity(type="Name", value="John")])] | ||
texts = ["hello Roberta, how are you doing today?"] | ||
|
||
extractions = co.unstable_extract('small', examples, texts) | ||
|
||
self.assertIsInstance(extractions, Extractions) | ||
self.assertIsInstance(extractions[0].text, str) | ||
self.assertIsInstance(extractions[0].entities, list) | ||
self.assertEqual(extractions[0].entities[0].type, "Name") | ||
self.assertEqual(extractions[0].entities[0].value, "Roberta") | ||
|
||
def test_empty_text(self): | ||
with self.assertRaises(cohere.CohereError): | ||
co.unstable_extract( | ||
'small', examples=[Example( | ||
text="hello my name is John, and I like to play ping pong", | ||
entities=[Entity(type="Name", value="John")])], | ||
texts=[""]) | ||
|
||
def test_empty_entities(self): | ||
with self.assertRaises(cohere.CohereError): | ||
co.unstable_extract( | ||
'large', examples=[Example( | ||
text="hello my name is John, and I like to play ping pong", | ||
entities=[])], | ||
texts=["hello Roberta, how are you doing today?"]) | ||
|
||
def test_varying_amount_of_entities(self): | ||
examples = [ | ||
Example( | ||
text="the bananas are red", | ||
entities=[Entity(type="fruit", value="bananas"), Entity(type="color", value="red")]), | ||
Example( | ||
text="i love the color blue", | ||
entities=[Entity(type="color", value="blue")]), | ||
Example( | ||
text="i love apples", | ||
entities=[Entity(type="fruit", value="apple")]), | ||
Example( | ||
text="purple is my favorite color", | ||
entities=[Entity(type="color", value="purple")]), | ||
Example( | ||
text="wow, that apple is green?", | ||
entities=[Entity(type="fruit", value="apple"), Entity(type="color", value="green")])] | ||
texts = ["Jimmy ate my banana", "my favorite color is yellow", "green apple is my favorite fruit"] | ||
|
||
extractions = co.unstable_extract('medium', examples, texts) | ||
|
||
self.assertIsInstance(extractions, Extractions) | ||
self.assertIsInstance(extractions[0].text, str) | ||
self.assertIsInstance(extractions[1].text, str) | ||
self.assertIsInstance(extractions[2].text, str) | ||
self.assertIsInstance(extractions[0].entities, list) | ||
self.assertIsInstance(extractions[1].entities, list) | ||
self.assertIsInstance(extractions[2].entities, list) | ||
|
||
self.assertEqual(len(extractions[0].entities), 1) | ||
self.assertIn(Entity(type="fruit", value="banana"), extractions[0].entities) | ||
|
||
self.assertEqual(len(extractions[1].entities), 1) | ||
self.assertIn(Entity(type="color", value="yellow"), extractions[1].entities) | ||
|
||
self.assertEqual(len(extractions[2].entities), 2) | ||
self.assertIn(Entity(type="color", value="green"), extractions[2].entities) | ||
self.assertIn(Entity(type="fruit", value="apple"), extractions[2].entities) | ||
|
||
def test_many_examples_and_multiple_texts(self): | ||
examples = [ | ||
Example( | ||
text="hello my name is John, and I like to play ping pong", | ||
entities=[Entity(type="Name", value="John"), Entity(type="Game", value="ping pong")]), | ||
Example( | ||
text="greetings, I'm Roberta and I like to play golf", | ||
entities=[Entity(type="Name", value="Roberta"), Entity(type="Game", value="golf")]), | ||
Example( | ||
text="let me introduce myself, my name is Tina and I like to play baseball", | ||
entities=[Entity(type="Name", value="Tina"), Entity(type="Game", value="baseball")])] | ||
texts = ["hi, my name is Charlie and I like to play basketball", "hello, I'm Olivia and I like to play soccer"] | ||
|
||
extractions = co.unstable_extract('medium', examples, texts) | ||
|
||
self.assertEqual(len(extractions), 2) | ||
self.assertIsInstance(extractions, Extractions) | ||
self.assertIsInstance(extractions[0].text, str) | ||
self.assertIsInstance(extractions[1].text, str) | ||
self.assertIsInstance(extractions[0].entities, list) | ||
self.assertIsInstance(extractions[1].entities, list) | ||
self.assertEqual(len(extractions[0].entities), 2) | ||
self.assertEqual(len(extractions[1].entities), 2) | ||
|
||
def test_no_entities(self): | ||
examples = [ | ||
Example( | ||
text="hello my name is John, and I like to play ping pong", | ||
entities=[Entity(type="Name", value="John"), Entity(type="Game", value="ping pong")]), | ||
Example( | ||
text="greetings, I'm Roberta and I like to play golf", | ||
entities=[Entity(type="Name", value="Roberta"), Entity(type="Game", value="golf")]), | ||
Example( | ||
text="let me introduce myself, my name is Tina and I like to play baseball", | ||
entities=[Entity(type="Name", value="Tina"), Entity(type="Game", value="baseball")])] | ||
texts = ["hi, my name is Charlie and I like to play basketball", "hello!"] | ||
|
||
extractions = co.unstable_extract('medium', examples, texts) | ||
|
||
self.assertEqual(len(extractions), 2) | ||
self.assertIsInstance(extractions, Extractions) | ||
|
||
self.assertEqual(len(extractions[0].entities), 2) | ||
self.assertIn(Entity(type="Name", value="Charlie"), extractions[0].entities) | ||
self.assertIn(Entity(type="Game", value="basketball"), extractions[0].entities) | ||
|
||
self.assertEqual(len(extractions[1].entities), 0) |