Skip to content

Commit

Permalink
Alphabetically sort params
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed May 8, 2023
1 parent 2c842eb commit 2c2d138
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions haystack/nodes/prompt/invocation_layer/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,20 @@ def __init__(self, api_key: str, model_name_or_path: str, max_length: Optional[i
self.model_input_kwargs = {
key: kwargs[key]
for key in [
"model",
"p",
"end_sequences",
"frequency_penalty",
"k",
"truncate",
"logit_bias",
"max_tokens",
"model",
"num_generations",
"temperature",
"end_sequences",
"p",
"presence_penalty",
"logit_bias",
"frequency_penalty",
"return_likelihoods",
"stream",
"stream_handler",
"temperature",
"truncate",
]
if key in kwargs
}
Expand Down Expand Up @@ -103,19 +103,19 @@ def invoke(self, *args, **kwargs):
kwargs_with_defaults.update(kwargs)
# see https://docs.cohere.com/reference/generate
params = {
"model": kwargs_with_defaults.get("model", self.model_name_or_path),
"prompt": prompt,
"p": kwargs_with_defaults.get("p", None),
"end_sequences": kwargs_with_defaults.get("end_sequences", stop_words),
"frequency_penalty": kwargs_with_defaults.get("frequency_penalty", None),
"k": kwargs_with_defaults.get("k", None),
"truncate": kwargs_with_defaults.get("truncate", None),
"max_tokens": kwargs_with_defaults.get("max_tokens", self.max_length),
"model": kwargs_with_defaults.get("model", self.model_name_or_path),
"num_generations": kwargs_with_defaults.get("num_generations", None),
"end_sequences": kwargs_with_defaults.get("end_sequences", stop_words),
"temperature": kwargs_with_defaults.get("temperature", None),
"frequency_penalty": kwargs_with_defaults.get("frequency_penalty", None),
"p": kwargs_with_defaults.get("p", None),
"presence_penalty": kwargs_with_defaults.get("presence_penalty", None),
"prompt": prompt,
"return_likelihoods": kwargs_with_defaults.get("return_likelihoods", None),
"stream": stream,
"temperature": kwargs_with_defaults.get("temperature", None),
"truncate": kwargs_with_defaults.get("truncate", None),
}
response = self._post(params, stream=stream)
if not stream:
Expand Down

0 comments on commit 2c2d138

Please sign in to comment.