Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Extract V0 API #60

Merged
merged 19 commits into from May 18, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions cohere/__init__.py
Expand Up @@ -7,6 +7,7 @@
# Endpoint identifiers. EXTRACT_URL is passed by the client's extract method
# to its request helper; presumably the other constants are used the same
# way by their matching client methods -- confirm against client.py.
GENERATE_URL = 'generate'
EMBED_URL = 'embed'
CLASSIFY_URL = 'classify'
EXTRACT_URL = 'extract'

CHECK_API_KEY_URL = 'check-api-key'
TOKENIZE_URL = 'tokenize'
24 changes: 24 additions & 0 deletions cohere/client.py
Expand Up @@ -15,6 +15,7 @@
from cohere.generation import Generations, Generation, TokenLikelihood
from cohere.tokenize import Tokens
from cohere.classify import Classifications, Classification, Example, Confidence
from cohere.extract import ExtractEntity, ExtractExample, Extraction

use_xhr_client = False
try:
Expand Down Expand Up @@ -188,6 +189,29 @@ def classify(

return Classifications(classifications)

def extract(
    self,
    model: str,
    examples: List[ExtractExample],
    texts: List[str]
) -> List[Extraction]:
    """Request entity extractions for ``texts`` from the extract endpoint.

    Args:
        model: Model identifier forwarded to the request helper.
        examples: Labelled examples; each one is serialized via ``toDict()``.
        texts: Input strings sent under the ``'texts'`` key.

    Returns:
        One ``Extraction`` per item of the response's ``'results'`` list,
        with its ``entities`` converted to ``ExtractEntity`` objects.
    """
    payload = {
        'texts': texts,
        'examples': [example.toDict() for example in examples],
    }
    response = self.__request(json.dumps(payload), cohere.EXTRACT_URL, model)

    def build(result):
        # The raw result dict is splatted into Extraction first; its entity
        # dicts are then swapped for ExtractEntity objects.
        extraction = Extraction(**result)
        extraction.entities = [
            ExtractEntity(**entity) for entity in result['entities']
        ]
        return extraction

    return [build(result) for result in response['results']]

def tokenize(self, model: str, text: str) -> Tokens:
if (use_go_tokenizer):
encoder = tokenizer.NewFromPrebuilt('coheretext-50k')
Expand Down
42 changes: 42 additions & 0 deletions cohere/extract.py
@@ -0,0 +1,42 @@
from cohere.response import CohereObject
from typing import List


class ExtractEntity:
    """A single labelled entity: a ``type`` label paired with a ``value``."""

    def __init__(self, type: str, value: str) -> None:
        # ``type`` shadows the builtin, but callers pass it by keyword
        # (``type=...``), so the parameter name has to stay.
        self.type = type
        self.value = value

    def toDict(self):
        """Return the entity as a dict with ``"type"`` and ``"value"`` keys."""
        return dict(type=self.type, value=self.value)


class ExtractExample:
    """A labelled example: a text together with the entities it contains."""

    def __init__(self, text: str, entities: List[ExtractEntity]) -> None:
        self.text = text
        self.entities = entities

    def toDict(self):
        """Return a dict holding the text and each entity's ``toDict()`` output."""
        serialized_entities = [item.toDict() for item in self.entities]
        return {"text": self.text, "entities": serialized_entities}


class Extraction:
    """The extraction result for one input text.

    The constructor stores ``id``, ``text`` and ``entities`` unchanged as
    attributes of the same names.
    """

    def __init__(self, id: str, text: str, entities: List[ExtractEntity]) -> None:
        # ``id`` shadows the builtin, but the parameter names are part of the
        # keyword interface and are therefore kept as they are.
        self.id, self.text, self.entities = id, text, entities


class Extractions(CohereObject):
    """A sized, iterable collection of ``Extraction`` results.

    Iterating (``for item in extractions``) always starts from the first
    element. ``next(extractions)`` advances a separate cursor that is created
    once in ``__init__``.
    """

    def __init__(self, extractions: List[Extraction]) -> None:
        self.extractions = extractions
        # Cursor consumed by ``__next__``; kept so ``next(obj)`` keeps working.
        self.iterator = iter(extractions)

    def __iter__(self):
        # Hand out a fresh iterator on every call. Returning the single
        # iterator built in ``__init__`` meant the collection could only be
        # looped over once; every later loop silently saw it as empty.
        return iter(self.extractions)

    def __next__(self) -> Extraction:
        # Raises StopIteration once the cursor from ``__init__`` is exhausted.
        return next(self.iterator)

    def __len__(self) -> int:
        return len(self.extractions)
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -23,7 +23,7 @@ def has_ext_modules(foo) -> bool:

setuptools.setup(
name='cohere',
version='1.3.7',
version='1.3.8',
author='kipply',
author_email='carol@cohere.ai',
description='A Python library for the Cohere API',
Expand Down
114 changes: 113 additions & 1 deletion tests/tests.py
Expand Up @@ -4,6 +4,7 @@
import string
import random
from cohere.classify import Example
from cohere.extract import ExtractEntity, ExtractExample

API_KEY = os.getenv('CO_API_KEY')
# `type(API_KEY)` is a class object and therefore always truthy, so the former
# `assert type(API_KEY)` could never fail. `os.getenv` returns None when the
# variable is unset; check for that so a misconfigured run fails immediately.
assert API_KEY is not None, 'CO_API_KEY environment variable is not set'
Expand Down Expand Up @@ -192,7 +193,6 @@ def test_empty_inputs(self):
with self.assertRaises(cohere.CohereError):
co.classify(
'medium', [], [

Example('apple', 'fruit'),
Example('banana', 'fruit'),
Example('cherry', 'fruit'),
Expand Down Expand Up @@ -240,6 +240,118 @@ def test_success_all_fields(self):
self.assertEqual(prediction.classifications[1].prediction, 'color')


class TestExtract(unittest.TestCase):
    """Exercises ``co.extract`` with valid and invalid example/text inputs.

    ``co`` is the client object defined elsewhere in this test module.
    """

    @staticmethod
    def _john_example(entities):
        """Build the "John" example sentence with the given entity list."""
        return ExtractExample(
            text="hello my name is John, and I like to play ping pong",
            entities=entities)

    @staticmethod
    def _name_and_game_examples():
        """Build three examples that each label one Name and one Game."""
        people = (
            ("hello my name is John, and I like to play ping pong",
             "John", "ping pong"),
            ("greetings, I'm Roberta and I like to play golf",
             "Roberta", "golf"),
            ("let me introduce myself, my name is Tina and I like to play baseball",
             "Tina", "baseball"),
        )
        return [
            ExtractExample(
                text=sentence,
                entities=[
                    ExtractEntity(type="Name", value=name),
                    ExtractEntity(type="Game", value=game),
                ])
            for sentence, name, game in people
        ]

    def _assert_result_shapes(self, extractions, count):
        """Assert a list result whose first ``count`` items have str text and list entities."""
        self.assertIsInstance(extractions, list)
        for idx in range(count):
            self.assertIsInstance(extractions[idx].text, str)
        for idx in range(count):
            self.assertIsInstance(extractions[idx].entities, list)

    def _assert_entity_counts(self, extractions, expected_counts):
        """Assert the number of entities found for each leading extraction."""
        for idx, expected in enumerate(expected_counts):
            self.assertEqual(len(extractions[idx].entities), expected)

    def test_success(self):
        examples = [self._john_example([ExtractEntity(type="Name", value="John")])]
        texts = ["hello Roberta, how are you doing today?"]

        extractions = co.extract('small', examples, texts)

        self._assert_result_shapes(extractions, 1)
        first_entity = extractions[0].entities[0]
        self.assertEqual(first_entity.type, "Name")
        self.assertEqual(first_entity.value, "Roberta")

    def test_empty_text(self):
        examples = [self._john_example([ExtractEntity(type="Name", value="John")])]
        with self.assertRaises(cohere.CohereError):
            co.extract('small', examples=examples, texts=[""])

    def test_empty_entities(self):
        examples = [self._john_example([])]
        with self.assertRaises(cohere.CohereError):
            co.extract(
                'large',
                examples=examples,
                texts=["hello Roberta, how are you doing today?"])

    def test_varying_amount_of_entities(self):
        # Each row: sentence, then its (type, value) entity pairs.
        labelled = (
            ("the bananas are red", (("fruit", "bananas"), ("color", "red"))),
            ("i love the color blue", (("color", "blue"),)),
            ("i love apples", (("fruit", "apple"),)),
            ("purple is my favorite color", (("color", "purple"),)),
            ("wow, that apple is green?", (("fruit", "apple"), ("color", "green"))),
        )
        examples = [
            ExtractExample(
                text=sentence,
                entities=[ExtractEntity(type=label, value=value)
                          for label, value in pairs])
            for sentence, pairs in labelled
        ]
        texts = ["i love bananas", "my favorite color is yellow", "i love green apples"]

        extractions = co.extract('medium', examples, texts)

        self._assert_result_shapes(extractions, 3)
        self._assert_entity_counts(extractions, (1, 1, 2))

    def test_many_examples_and_multiple_texts(self):
        examples = self._name_and_game_examples()
        texts = ["hi, my name is Charlie and I like to play basketball", "hello, I'm Olivia and I like to play soccer"]

        extractions = co.extract('medium', examples, texts)

        self.assertEqual(len(extractions), 2)
        self._assert_result_shapes(extractions, 2)
        self._assert_entity_counts(extractions, (2, 2))

    def test_no_entities(self):
        examples = self._name_and_game_examples()
        texts = ["hi, my name is Charlie and I like to play basketball", "hello!"]

        extractions = co.extract('medium', examples, texts)

        self.assertEqual(len(extractions), 2)
        self._assert_result_shapes(extractions, 2)
        # The second text contains nothing to extract.
        self._assert_entity_counts(extractions, (2, 0))


class TestTokenize(unittest.TestCase):
def test_success(self):
tokens = co.tokenize('medium', 'tokenize me!')
Expand Down