feat: Add Cohere PromptNode invocation layer #4827
Conversation
This already looks pretty good. I just found some minor things and have a few questions.
haystack/errors.py
Outdated

```python
        super().__init__(message=message, status_code=429, send_message_in_event=send_message_in_event)


class CohereInferenceUnauthorizedError(HuggingFaceInferenceError):
```
Same here: `CohereInferenceError` should probably be the base class.
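A minimal sketch of the hierarchy the reviewer is suggesting, with `CohereInferenceError` as the base class for Cohere-specific errors. The class bodies and signatures here are illustrative stand-ins, not the PR's actual code in `haystack/errors.py`:

```python
from typing import Optional


# Illustrative stand-in for Haystack's real NodeError base class.
class NodeError(Exception):
    pass


class CohereInferenceError(NodeError):
    """Base class for errors raised by the Cohere inference API (sketch)."""

    def __init__(self, message: Optional[str] = None, status_code: Optional[int] = None):
        super().__init__(message)
        self.status_code = status_code


class CohereInferenceUnauthorizedError(CohereInferenceError):
    """Raised on HTTP 401 from Cohere; inherits from the Cohere base, not the HF one."""

    def __init__(self, message: Optional[str] = None):
        super().__init__(message=message, status_code=401)
```

This way `except CohereInferenceError:` catches all Cohere-specific failures, rather than routing them through the unrelated `HuggingFaceInferenceError` hierarchy.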
```python
def test_supports():
    """
    Test that supports returns True correctly for CohereInvocationLayer
    """
    # doesn't support fake model
    assert not CohereInvocationLayer.supports("fake_model", api_key="fake_key")

    # supports cohere command with api_key
    assert CohereInvocationLayer.supports("command", api_key="fake_key")

    # supports cohere command-light with api_key
    assert CohereInvocationLayer.supports("command-light", api_key="fake_key")
```
`api_key` is not used in the `supports` method, so why do we specify it here?
@bogdankostic I added 86d432a, which we can then reuse for other layers, most notably the HF-based layers, but potentially some others as well. In my experiments, I found gpt2 to produce almost the same token counts. I used https://docs.cohere.com/reference/tokenize to count tokens in about five sample texts; the difference from gpt2 was within 1%.
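The fallback described above could be sketched as a small helper that picks a locally available tokenizer name for token counting. This is a hypothetical illustration, not the actual code in 86d432a; the model names and the gpt2 fallback are assumptions based on the experiment described in this comment:

```python
# Hypothetical helper: choose a Hugging Face tokenizer name to
# approximate token counts for models whose own tokenizer is not
# publicly downloadable. The gpt2 fallback for Cohere models reflects
# the ~1% measured difference mentioned above; the mapping itself
# is illustrative, not the PR's actual implementation.

COHERE_MODELS = {"command", "command-light"}


def tokenizer_name_for(model_name: str) -> str:
    """Return a tokenizer name usable for local prompt-length counting."""
    if model_name in COHERE_MODELS:
        # Cohere's tokenizer is only exposed via its API; gpt2 was
        # measured to be within ~1% on sample texts, so approximate with it.
        return "gpt2"
    # For HF-hosted models, the model name doubles as the tokenizer name.
    return model_name


print(tokenizer_name_for("command"))      # gpt2
print(tokenizer_name_for("gpt2-medium"))  # gpt2-medium
```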
Almost good to go. I just added two questions that are not clear to me, and I'm not sure about the added tests for `PromptHandler`.
```python
    Ensures CohereInvocationLayer is selected only when Cohere models are specified in
    the model name.
    """
    is_inference_api = "api_key" in kwargs
```
Why do we check whether the user has provided their API key? We don't do this in `OpenAIInvocationLayer`'s `supports` method either, for example.
It is an extra check that increases the signal that we are indeed using Cohere. The exact same model name might exist on Hugging Face, but in that case no `api_key` would be passed. We could go without it here as well; it is just another safeguard to make sure we are routing to the Cohere layer. That's all.
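The routing logic being discussed could look roughly like this: match on known Cohere model names and additionally require an `api_key` kwarg, so an identically named Hugging Face model is not routed to this layer. This is a simplified sketch under those assumptions, not the PR's exact implementation:

```python
# Sketch of the supports() disambiguation discussed above.
# Model names and class name are illustrative.

COHERE_MODELS = ("command", "command-light")


class CohereInvocationLayerSketch:
    @classmethod
    def supports(cls, model_name_or_path: str, **kwargs) -> bool:
        # Positive signal 1: the model name is a known Cohere model.
        is_cohere_model = model_name_or_path in COHERE_MODELS
        # Positive signal 2: an api_key was passed, which a plain local
        # Hugging Face model with the same name would not have.
        has_api_key = "api_key" in kwargs
        return is_cohere_model and has_api_key


print(CohereInvocationLayerSketch.supports("command", api_key="fake"))     # True
print(CohereInvocationLayerSketch.supports("fake_model", api_key="fake"))  # False
print(CohereInvocationLayerSketch.supports("command"))                     # False
```

The trade-off raised in the review is real: `OpenAIInvocationLayer` gets by without the `api_key` check, so this second signal is optional defensive routing rather than a hard requirement.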
haystack/errors.py
Outdated

```diff
@@ -203,3 +203,20 @@ class HuggingFaceInferenceUnauthorizedError(HuggingFaceInferenceError):

    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
        super().__init__(message=message, status_code=401, send_message_in_event=send_message_in_event)


class CohereInferenceError(NodeError):
```
What's the difference between `CohereInferenceError` and `CohereError`?
None, good catch, let me update the PR.
test/prompt/test_handlers.py
Outdated
I don't think the tests added to this file are unit tests, as they use `transformers` as an external resource.
LGTM 🚢
Related Issues

Proposed Changes:
Add Cohere `command` and `command-light` model support. Both models support token streaming.

How did you test it?
Manually, unit tests

Notes for the reviewer

Checklist
The PR title uses one of the conventional commit prefixes: `fix:`, `feat:`, `build:`, `chore:`, `ci:`, `docs:`, `style:`, `refactor:`, `perf:`, `test:`.