Skip to content

Commit

Permalink
Add SDK-level validation for Classify params (#230)
Browse files Browse the repository at this point in the history
  • Loading branch information
abdullahkady committed May 17, 2023
1 parent d92bd90 commit 7daa907
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## 4.4.2
- [#230](https://github.com/cohere-ai/cohere-python/pull/230)
- Add SDK level validation for classify params

## 4.4.1
- [#224](https://github.com/cohere-ai/cohere-python/pull/224)
- Update co.chat parameter `chat_history`
Expand Down
11 changes: 7 additions & 4 deletions cohere/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,10 +365,13 @@ def classify(
examples (List[ClassifyExample]): A list of ClassifyExample objects containing a text and its associated label.
truncate (str): (Optional) One of NONE|START|END, defaults to END. How the API handles text longer than the maximum token length.
"""
examples_dicts: list[dict[str, str]] = []
for example in examples:
example_dict = {"text": example.text, "label": example.label}
examples_dicts.append(example_dict)
if not preset:

This comment has been minimized.

Copy link
@peterc

peterc May 28, 2023

It also needs to recognize when a custom model is being used.

if not examples:
raise CohereError(message="examples must be a non-empty list of ClassifyExample objects.")
if not inputs:
raise CohereError(message="inputs must be a non-empty list of strings.")

examples_dicts = [{"text": example.text, "label": example.label} for example in examples]

json_body = {
"model": model,
Expand Down
6 changes: 6 additions & 0 deletions cohere/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,12 @@ async def classify(
examples: List[ClassifyExample] = [],
truncate: Optional[str] = None,
) -> Classifications:
if not preset:
if not examples:
raise CohereError(message="examples must be a non-empty list of ClassifyExample objects.")
if not inputs:
raise CohereError(message="inputs must be a non-empty list of strings.")

examples_dicts = [{"text": example.text, "label": example.label} for example in examples]

json_body = {
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cohere"
version = "4.4.1"
version = "4.4.2"
description = ""
authors = ["Cohere"]
readme = "README.md"
Expand Down
27 changes: 27 additions & 0 deletions tests/async/test_async_classify.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from cohere.error import CohereError
from cohere.responses.classify import Example


Expand All @@ -25,3 +26,29 @@ async def test_async_classify(async_client):
assert prediction.meta
assert prediction.meta["api_version"]
assert prediction.meta["api_version"]["version"]


@pytest.mark.asyncio
async def test_async_classify_input_validation(async_client):
for value in [None, []]:
with pytest.raises(CohereError) as exc:
await async_client.classify(
model="small",
inputs=value,
examples=[
Example("apple", "fruit"),
Example("kiwi", "fruit"),
Example("yellow", "color"),
Example("magenta", "color"),
],
)
assert "inputs must be a non-empty list of strings." in str(exc.value)

for value in [None, []]:
with pytest.raises(CohereError) as exc:
await async_client.classify(
model="small",
inputs=["apple", "yellow"],
examples=value,
)
assert "examples must be a non-empty list of ClassifyExample objects." in str(exc.value)
25 changes: 25 additions & 0 deletions tests/sync/test_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from utils import get_api_key

import cohere
from cohere.error import CohereError
from cohere.responses.classify import Example

co = cohere.Client(get_api_key())
Expand Down Expand Up @@ -40,6 +41,30 @@ def test_success(self):
self.assertEqual(len(prediction.classifications), 1)
self.assertEqual(prediction.classifications[0].prediction, "color")

def test_input_validation(self):
for value in [None, []]:
with pytest.raises(CohereError) as exc:
co.classify(
model="small",
inputs=value,
examples=[
Example("apple", "fruit"),
Example("kiwi", "fruit"),
Example("yellow", "color"),
Example("magenta", "color"),
],
)
assert "inputs must be a non-empty list of strings." in str(exc.value)

for value in [None, []]:
with pytest.raises(CohereError) as exc:
co.classify(
model="small",
inputs=["apple", "yellow"],
examples=value,
)
assert "examples must be a non-empty list of ClassifyExample objects." in str(exc.value)

def test_empty_inputs(self):
with self.assertRaises(cohere.CohereError):
co.classify(
Expand Down

0 comments on commit 7daa907

Please sign in to comment.