Fix Together model validation error (princeton-nlp#236)
* test: add unit test for Together model

* fix: deal with the new Together API

* chore: specify together version

* refactor: clean code

* change together model versioning from ">=~" to ">=" and write comment

* raise exception when together SDK version is below 1.1.0

* refactor: update unit test format

* specify max_tokens
mikanfactory committed Apr 25, 2024
1 parent 762a963 commit 6d8eee9
Showing 3 changed files with 29 additions and 7 deletions.
2 changes: 1 addition & 1 deletion environment.yml
@@ -17,6 +17,6 @@ dependencies:
   - tenacity
   - unidiff
   - simple-parsing
-  - together
+  - together>=1.1.0 # Versions of together below 1.1.0 are not compatible. see https://github.com/princeton-nlp/SWE-agent/issues/135
   - ollama
   - rich-argparse
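
The pin above only constrains the conda environment. As a minimal sketch (not part of this commit; packaging is an extra assumption and is not listed in environment.yml), the installed SDK can be checked at runtime like this, mirroring the version guard added in models.py below:

    import together
    from packaging.version import Version

    # The fix in models.py requires the >=1.1.0 response schema.
    if Version(together.version) < Version("1.1.0"):
        raise RuntimeError("together SDK is too old; see princeton-nlp/SWE-agent#135")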
13 changes: 8 additions & 5 deletions sweagent/agent/models.py
@@ -643,6 +643,7 @@ class TogetherModel(BaseModel):
 
     def __init__(self, args: ModelArguments, commands: list[Command]):
         super().__init__(args, commands)
+        assert together.version >= '1.1.0', "Please upgrade to Together SDK v1.1.0 or later."
 
         # Set Together key
         cfg = config.Config(os.path.join(os.getcwd(), "keys.cfg"))
@@ -676,18 +677,20 @@ def query(self, history: list[dict[str, str]]) -> str:
         """
         # Perform Together API call
         prompt = self.history_to_messages(history)
+        # Anthropic's count_tokens is convenient because it caches and utilizes huggingface/tokenizers, so we will use.
+        max_tokens_to_sample = self.model_metadata["max_context"] - Anthropic().count_tokens(prompt)
         completion = together.Complete.create(
             model=self.api_model,
             prompt=prompt,
-            max_tokens=self.model_metadata["max_context"],
-            stop="<human>",
+            max_tokens=max_tokens_to_sample,
+            stop=["<human>"],
             temperature=self.args.temperature,
             top_p=self.args.top_p,
         )
         # Calculate + update costs, return response
-        response = completion["output"]["choices"][0]["text"].split("<human>")[0]
-        input_tokens = completion["output"]["usage"]["prompt_tokens"]
-        output_tokens = completion["output"]["usage"]["completion_tokens"]
+        response = completion["choices"][0]["text"].split("<human>")[0]
+        input_tokens = completion["usage"]["prompt_tokens"]
+        output_tokens = completion["usage"]["completion_tokens"]
         self.update_stats(input_tokens, output_tokens)
         return response

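For readers comparing the two SDK generations, here is a minimal sketch of the response-shape change this hunk adapts to, inferred from the removed and added accessors above (the literal values are placeholders, not real API output):

    # Pre-1.1.0 together SDK: payload nested under an "output" key.
    legacy_completion = {
        "output": {
            "choices": [{"text": "Hi<human>..."}],
            "usage": {"prompt_tokens": 10, "completion_tokens": 10},
        }
    }

    # >=1.1.0: "choices" and "usage" are top-level, which is what the updated query() reads.
    new_completion = {
        "choices": [{"text": "Hi<human>..."}],
        "usage": {"prompt_tokens": 10, "completion_tokens": 10},
    }
    response = new_completion["choices"][0]["text"].split("<human>")[0]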
21 changes: 20 additions & 1 deletion tests/test_models.py
@@ -1,5 +1,5 @@
 from unittest.mock import MagicMock, Mock, patch
-from sweagent.agent.models import OpenAIModel, ModelArguments
+from sweagent.agent.models import OpenAIModel, ModelArguments, TogetherModel
 import pytest
 
 
@@ -16,6 +16,13 @@ def openai_mock_client():
 
     return model
 
+@pytest.fixture
+def mock_together_response():
+    return {
+        "choices": [{"text": "<human>Hello</human>"}],
+        "usage": {"prompt_tokens": 10, "completion_tokens": 10},
+    }
+
 
 TEST_HISTORY = [
     {
@@ -32,3 +39,15 @@ def test_openai_model(openai_mock_client):
     model = OpenAIModel(TEST_MODEL_ARGUMENTS, [])
     model.client = openai_mock_client
     model.query(TEST_HISTORY)
+
+
+@pytest.mark.parametrize("model_name", list(TogetherModel.MODELS) + list(TogetherModel.SHORTCUTS))
+def test_together_model(mock_together_response, model_name):
+    with patch("sweagent.agent.models.config.Config"), \
+         patch("sweagent.agent.models.together") as mock_together:
+        mock_together.version = '1.1.0'
+        mock_together.Complete.create.return_value = mock_together_response
+
+        model_args = ModelArguments(model_name)
+        model = TogetherModel(model_args, [])
+        model.query(TEST_HISTORY)
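
To exercise just the new parametrized test locally (an assumption about the local setup, not part of this commit), something like the following should select the cases added above:

    pytest tests/test_models.py -k together_model -q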
