Skip to content

Commit

Permalink
Add Extract V0 API (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigue-h committed May 18, 2022
1 parent defa940 commit 3072cb7
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 3 deletions.
1 change: 1 addition & 0 deletions cohere/__init__.py
Expand Up @@ -7,6 +7,7 @@
# Endpoint names, one per API call. EXTRACT_URL is handed to the client's
# request helper by Client.unstable_extract.
# NOTE(review): the other constants are presumably consumed the same way; their
# call sites are not visible in this hunk -- confirm.
GENERATE_URL = 'generate'
EMBED_URL = 'embed'
CLASSIFY_URL = 'classify'
EXTRACT_URL = 'extract'

CHECK_API_KEY_URL = 'check-api-key'
TOKENIZE_URL = 'tokenize'
34 changes: 32 additions & 2 deletions cohere/client.py
Expand Up @@ -14,7 +14,8 @@
from cohere.error import CohereError
from cohere.generation import Generations, Generation, TokenLikelihood
from cohere.tokenize import Tokens
from cohere.classify import Classifications, Classification, Example, Confidence
from cohere.classify import Classifications, Classification, Example as ClassifyExample, Confidence
from cohere.extract import Entity, Example as ExtractExample, Extraction, Extractions

use_xhr_client = False
try:
Expand Down Expand Up @@ -158,7 +159,7 @@ def classify(
self,
model: str,
inputs: List[str],
examples: List[Example] = [],
examples: List[ClassifyExample] = [],
taskDescription: str = '',
outputIndicator: str = ''
) -> Classifications:
Expand Down Expand Up @@ -188,6 +189,35 @@ def classify(

return Classifications(classifications)

def unstable_extract(
    self,
    model: str,
    examples: List[ExtractExample],
    texts: List[str]
) -> Extractions:
    '''
    Makes a request to the Cohere API to extract entities from a list of texts.
    Takes in a list of cohere.extract.Example objects to specify the entities to extract.
    Returns a cohere.extract.Extractions object containing extractions per text.
    '''

    # Examples are serialised through their own toDict() so the body is plain JSON.
    payload = {
        'texts': texts,
        'examples': [example.toDict() for example in examples],
    }
    response = self.__request(json.dumps(payload), cohere.EXTRACT_URL, model)

    parsed = []
    for result in response['results']:
        # Build the Extraction from the raw result first, then swap the raw
        # entity dicts for Entity objects.
        item = Extraction(**result)
        item.entities = [Entity(**raw) for raw in result['entities']]
        parsed.append(item)

    return Extractions(parsed)

def tokenize(self, model: str, text: str) -> Tokens:
if (use_go_tokenizer):
encoder = tokenizer.NewFromPrebuilt('coheretext-50k')
Expand Down
93 changes: 93 additions & 0 deletions cohere/extract.py
@@ -0,0 +1,93 @@
from typing import Iterator, List

from cohere.response import CohereObject


class Entity:
    '''
    Entity represents a single extracted entity from a text. An entity has a
    type and a value. For the text "I am a plumber", an extracted entity could be
    of type "profession" with the value "plumber".

    Two entities compare equal when both their type and their value match.
    '''

    # NOTE: __eq__ is defined without __hash__, so instances are unhashable
    # (they cannot be put in a set or used as a dict key).

    def __init__(self, type: str, value: str) -> None:
        # The parameter names are part of the public interface: callers build
        # entities with Entity(type=..., value=...) and Entity(**payload).
        self.type = type
        self.value = value

    def toDict(self) -> dict:
        '''Return the entity as a dict with "type" and "value" keys.'''
        return {"type": self.type, "value": self.value}

    def __str__(self) -> str:
        return f"{self.type}: {self.value}"

    def __repr__(self) -> str:
        return str(self.toDict())

    def __eq__(self, other) -> bool:
        # Comparing against something that is not an Entity (None, a str, ...)
        # must not raise: returning NotImplemented lets Python fall back to its
        # default handling, so `==` is False and `!=` is True.
        if not isinstance(other, Entity):
            return NotImplemented
        return self.type == other.type and self.value == other.value


class Example:
    '''
    A sample extraction supplied to the model: the input text paired with the
    list of entities that were extracted from it.
    >>> example = Example("I am a plumber", [Entity("profession", "plumber")])
    >>> example = Example("Joe is a teacher", [
    ...     Entity("name", "Joe"), Entity("profession", "teacher")
    ... ])
    '''

    def __init__(self, text: str, entities: List[Entity]) -> None:
        self.text, self.entities = text, entities

    def toDict(self):
        '''Return the example as a dict, with each entity converted by its own toDict().'''
        serialized = []
        for entity in self.entities:
            serialized.append(entity.toDict())
        return {"text": self.text, "entities": serialized}

    def __str__(self) -> str:
        return "{} -> {}".format(self.text, self.entities)

    def __repr__(self) -> str:
        return repr(self.toDict())


class Extraction:
    '''
    The outcome of extracting entities from one input text: the id of the
    extraction, the text itself, and the list of entities found in it.
    '''

    def __init__(self, id: str, text: str, entities: List[Entity]) -> None:
        # Parameter names are kept as-is because the client builds instances
        # with Extraction(**result) from the response payload.
        self.id, self.text, self.entities = id, text, entities

    def __repr__(self) -> str:
        return repr(self.toDict())

    def toDict(self) -> dict:
        '''Return the extraction as a dict, with each entity converted by its own toDict().'''
        entity_dicts = []
        for entity in self.entities:
            entity_dicts.append(entity.toDict())
        return {"id": self.id, "text": self.text, "entities": entity_dicts}


class Extractions(CohereObject):
    '''
    Represents the main response of calling the Extract API. An Extractions is iterable and
    contains a list of Extraction objects, one per text input.

    It also supports len() and indexing, both of which delegate to the
    underlying list.
    '''

    def __init__(self, extractions: List[Extraction]) -> None:
        self.extractions = extractions
        # Cursor used only by direct next(obj) calls; see __next__.
        self.iterator = iter(extractions)

    def __iter__(self) -> Iterator[Extraction]:
        # Hand out a fresh iterator on every call. Returning the single
        # iterator created in __init__ would leave it exhausted after the first
        # loop, so a second `for` over the same object would yield nothing.
        return iter(self.extractions)

    def __next__(self) -> Extraction:
        # Kept so existing next(obj) callers continue to work; it advances the
        # cursor stored in __init__ and is independent of `for` loops.
        return next(self.iterator)

    def __len__(self) -> int:
        return len(self.extractions)

    def __getitem__(self, index: int) -> Extraction:
        return self.extractions[index]
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -23,7 +23,7 @@ def has_ext_modules(foo) -> bool:

setuptools.setup(
name='cohere',
version='1.3.8',
version='1.3.9',
author='kipply',
author_email='carol@cohere.ai',
description='A Python library for the Cohere API',
Expand Down
125 changes: 125 additions & 0 deletions tests/test_extract.py
@@ -0,0 +1,125 @@
import unittest
import cohere
from cohere.extract import Example, Entity, Extractions
from utils import get_api_key

# Client shared by every test in this module; it is created once, at import
# time, with the key returned by get_api_key().
co = cohere.Client(get_api_key())


class TestExtract(unittest.TestCase):
    '''
    Tests for Client.unstable_extract. Every test goes through the module-level
    `co` client, and several assert on the exact entities returned for a text.
    '''

    def _name_and_game_examples(self):
        '''Three examples, each labelled with one "Name" and one "Game" entity.'''
        return [
            Example(
                text="hello my name is John, and I like to play ping pong",
                entities=[Entity(type="Name", value="John"), Entity(type="Game", value="ping pong")]),
            Example(
                text="greetings, I'm Roberta and I like to play golf",
                entities=[Entity(type="Name", value="Roberta"), Entity(type="Game", value="golf")]),
            Example(
                text="let me introduce myself, my name is Tina and I like to play baseball",
                entities=[Entity(type="Name", value="Tina"), Entity(type="Game", value="baseball")]),
        ]

    def test_success(self):
        '''A single example and a single text produce one "Name" entity.'''
        john = Example(
            text="hello my name is John, and I like to play ping pong",
            entities=[Entity(type="Name", value="John")])

        extractions = co.unstable_extract(
            'small', [john], ["hello Roberta, how are you doing today?"])

        self.assertIsInstance(extractions, Extractions)
        first = extractions[0]
        self.assertIsInstance(first.text, str)
        self.assertIsInstance(first.entities, list)
        self.assertEqual(first.entities[0].type, "Name")
        self.assertEqual(first.entities[0].value, "Roberta")

    def test_empty_text(self):
        '''An empty input text is rejected with a CohereError.'''
        john = Example(
            text="hello my name is John, and I like to play ping pong",
            entities=[Entity(type="Name", value="John")])

        with self.assertRaises(cohere.CohereError):
            co.unstable_extract('small', examples=[john], texts=[""])

    def test_empty_entities(self):
        '''An example without any entities is rejected with a CohereError.'''
        unlabelled = Example(
            text="hello my name is John, and I like to play ping pong",
            entities=[])

        with self.assertRaises(cohere.CohereError):
            co.unstable_extract(
                'large',
                examples=[unlabelled],
                texts=["hello Roberta, how are you doing today?"])

    def test_varying_amount_of_entities(self):
        '''Examples with one or two entities each; texts yield one or two entities.'''
        examples = [
            Example(
                text="the bananas are red",
                entities=[Entity(type="fruit", value="bananas"), Entity(type="color", value="red")]),
            Example(
                text="i love the color blue",
                entities=[Entity(type="color", value="blue")]),
            Example(
                text="i love apples",
                entities=[Entity(type="fruit", value="apple")]),
            Example(
                text="purple is my favorite color",
                entities=[Entity(type="color", value="purple")]),
            Example(
                text="wow, that apple is green?",
                entities=[Entity(type="fruit", value="apple"), Entity(type="color", value="green")]),
        ]
        texts = ["Jimmy ate my banana", "my favorite color is yellow", "green apple is my favorite fruit"]

        extractions = co.unstable_extract('medium', examples, texts)

        self.assertIsInstance(extractions, Extractions)
        for index in range(3):
            self.assertIsInstance(extractions[index].text, str)
        for index in range(3):
            self.assertIsInstance(extractions[index].entities, list)

        # Expected entities per input text, in the same order as `texts`.
        expected = [
            [Entity(type="fruit", value="banana")],
            [Entity(type="color", value="yellow")],
            [Entity(type="color", value="green"), Entity(type="fruit", value="apple")],
        ]
        for index, wanted in enumerate(expected):
            self.assertEqual(len(extractions[index].entities), len(wanted))
            for entity in wanted:
                self.assertIn(entity, extractions[index].entities)

    def test_many_examples_and_multiple_texts(self):
        '''Two texts against three examples give two extractions of two entities each.'''
        examples = self._name_and_game_examples()
        texts = ["hi, my name is Charlie and I like to play basketball", "hello, I'm Olivia and I like to play soccer"]

        extractions = co.unstable_extract('medium', examples, texts)

        self.assertEqual(len(extractions), 2)
        self.assertIsInstance(extractions, Extractions)
        for index in range(2):
            self.assertIsInstance(extractions[index].text, str)
        for index in range(2):
            self.assertIsInstance(extractions[index].entities, list)
        for index in range(2):
            self.assertEqual(len(extractions[index].entities), 2)

    def test_no_entities(self):
        '''A text with nothing to extract yields an extraction with an empty entity list.'''
        examples = self._name_and_game_examples()
        texts = ["hi, my name is Charlie and I like to play basketball", "hello!"]

        extractions = co.unstable_extract('medium', examples, texts)

        self.assertEqual(len(extractions), 2)
        self.assertIsInstance(extractions, Extractions)

        self.assertEqual(len(extractions[0].entities), 2)
        for wanted in (Entity(type="Name", value="Charlie"), Entity(type="Game", value="basketball")):
            self.assertIn(wanted, extractions[0].entities)

        self.assertEqual(len(extractions[1].entities), 0)

0 comments on commit 3072cb7

Please sign in to comment.