Skip to content

Commit

Permalink
Migrate to latest gemini api syntax.
Browse files Browse the repository at this point in the history
  • Loading branch information
eli64s committed Mar 1, 2024
1 parent 9ed33f9 commit 50a828f
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 25 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "readmeai"
version = "0.5.071"
version = "0.5.072"
description = "👾 Automated README file generator, powered by LLM APIs."
authors = ["Eli <egsalamie@gmail.com>"]
license = "MIT"
Expand Down Expand Up @@ -44,7 +44,7 @@ aiohttp = "^3.9.3"
click = "^8.1.7"
colorlog = "^6.7.0"
gitpython = "^3.1.31"
google-cloud-aiplatform = "1.39.0"
google-cloud-aiplatform = "^1.39.0"
openai = "*"
pydantic = ">=1.10.9,<2.0.0"
tenacity = "^8.2.2"
Expand Down
2 changes: 1 addition & 1 deletion readmeai/cli/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def prompt_for_image(
- OFFLINE # Offline mode - no LLM service used \n
- OLLAMA # Ollama - llama2 \n
- OPENAI # OpenAI - gpt-3.5-turbo \n
- VERTEX # Google Cloud Vertex AI - gemini-pro) \n
- VERTEX # Google Cloud Vertex AI - gemini-1.0-pro) \n
""",
)

Expand Down
4 changes: 2 additions & 2 deletions readmeai/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_environment(llm_api: str, llm_model: str) -> tuple:
            _logger.info(_log_message.format("Google Cloud Vertex AI"))
return (
llms.VERTEX.name,
llm_model if llm_model is not None else "gemini-pro",
llm_model if llm_model is not None else "gemini-1.0-pro",
)

else:
Expand Down Expand Up @@ -97,7 +97,7 @@ def get_environment(llm_api: str, llm_model: str) -> tuple:
_logger.info(_log_message.format("Vertex AI"))
return (
llms.VERTEX.name,
llm_model if llm_model is not None else "gemini-pro",
llm_model if llm_model is not None else "gemini-1.0-pro",
)

else:
Expand Down
6 changes: 3 additions & 3 deletions readmeai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ async def _make_request(
) -> Tuple[str, str]:
"""Processes OpenAI API LLM responses and returns generated text."""
try:
prompt = await token_handler(self.config, index, prompt, tokens)
parameters = await self._build_payload(prompt, tokens)

data = await self._build_payload(prompt, tokens)
prompt = await token_handler(self.config, index, prompt, tokens)

async with self._session.post(
self.endpoint,
headers=self.headers,
json=data,
json=parameters,
) as response:
response.raise_for_status()
response = await response.json()
Expand Down
17 changes: 9 additions & 8 deletions readmeai/models/vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ def _model_settings(self):
"""Initializes the Vertex AI LLM settings."""
self.location = os.environ.get("VERTEXAI_LOCATION")
self.project_id = os.environ.get("VERTEXAI_PROJECT")
self.model_id = self.config.llm.model
self.temperature = self.config.llm.temperature
self.tokens = self.config.llm.tokens
self.top_p = self.config.llm.top_p
vertexai.init(location=self.location, project=self.project_id)
self.model = GenerativeModel(self.config.llm.model)
vertexai.init(project=self.project_id, location=self.location)
self.model = GenerativeModel(self.model_id)

async def _build_payload(self, prompt: str, tokens: int) -> dict:
"""Build payload for POST request to Vertex AI API."""
Expand Down Expand Up @@ -67,18 +68,18 @@ async def _make_request(
) -> Tuple[str, str]:
"""Processes Vertex AI LLM API responses and returns generated text."""
try:
prompt = await token_handler(self.config, index, prompt, tokens)
parameters = await self._build_payload(prompt, tokens)

data = await self._build_payload(prompt, tokens)
prompt = await token_handler(self.config, index, prompt, tokens)

async with self.rate_limit_semaphore:
response = await self.model.generate_content_async(
prompt,
generation_config=data,
generation_config=parameters,
)
content = response.candidates[0].content.parts[0].text
self._logger.info(f"Response for '{index}':\n{content}")
return index, clean_response(index, content)
response_text = response.text
self._logger.info(f"Response for '{index}':\n{response_text}")
return index, clean_response(index, response_text)

except (
aiohttp.ClientError,
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def vertex_handler(mock_configs):
},
):
mock_configs.config.llm.api = "VERTEX"
mock_configs.config.llm.model = "gemini-pro"
mock_configs.config.llm.model = "gemini-1.0-pro"
yield VertexAIHandler(mock_configs)


Expand Down
6 changes: 3 additions & 3 deletions tests/core/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def test_get_environment_openai(mock_configs):
def test_get_environment_vertex(mock_configs):
"""Test that the environment is setup correctly for Vertex AI."""
mock_configs.config.llm.api = ModelOptions.VERTEX.name
mock_configs.config.llm.model = "gemini-pro"
mock_configs.config.llm.model = "gemini-1.0-pro"
test_api, test_model = get_environment(
mock_configs.config.llm.api, mock_configs.config.llm.model
)
assert test_api == ModelOptions.VERTEX.name
assert test_model == "gemini-pro"
assert test_model == "gemini-1.0-pro"


@patch.dict("os.environ", {}, clear=True)
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_no_api_specified_but_vertex_settings_exist_in_env(mock_configs):
mock_configs.config.llm.api, mock_configs.config.llm.model
)
assert test_api == ModelOptions.VERTEX.name
assert test_model == "gemini-pro"
assert test_model == "gemini-1.0-pro"


@patch.dict("os.environ", {}, clear=True)
Expand Down
10 changes: 5 additions & 5 deletions tests/models/test_vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
@pytest.mark.asyncio
async def test_vertex_handler_sets_attributes(vertex_handler):
"""Test that the Vertex AI handler sets the correct attributes."""
# Arrange
assert hasattr(vertex_handler, "temperature")
assert hasattr(vertex_handler, "location")
assert hasattr(vertex_handler, "project_id")
assert hasattr(vertex_handler, "model")
assert hasattr(vertex_handler, "temperature")
assert hasattr(vertex_handler, "top_p")


@pytest.mark.asyncio
async def test_vertex_make_request_with_context(
vertex_handler, mock_config, mock_configs
):
async def test_vertex_make_request_with_context(vertex_handler):
"""Test that the Vertex AI handler handles a response with context."""
# Arrange
handler = vertex_handler
Expand Down

0 comments on commit 50a828f

Please sign in to comment.