Skip to content

Commit

Permalink
Resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkDaoust committed Oct 9, 2023
1 parent ff232cf commit eb42b85
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
21 changes: 12 additions & 9 deletions google/generativeai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import os
import dataclasses
import types
from typing import Any, cast, Sequence
from typing import Any, cast
from collections.abc import Sequence

import google.ai.generativelanguage as glm

Expand Down Expand Up @@ -56,7 +57,7 @@ def configure(
client_options: client_options_lib.ClientOptions | dict | None = None,
client_info: gapic_v1.client_info.ClientInfo | None = None,
default_metadata: Sequence[tuple[str, str]] = (),
):
) -> None:
"""Captures default client configuration.
If no API key has been provided (either directly, or on `client_options`) and the
Expand All @@ -70,6 +71,7 @@ def configure(
If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be
used.
default_metadata: Default (key, value) metadata pairs to send with every request.
when using `transport="rest"` these are sent as HTTP headers.
"""
if isinstance(client_options, dict):
client_options = client_options_lib.from_dict(client_options)
Expand Down Expand Up @@ -128,14 +130,14 @@ def make_client(self, cls):
def keep(name, f):
if name.startswith("_"):
return False
if not isinstance(f, types.FunctionType):
elif not isinstance(f, types.FunctionType):
return False
if isinstance(f, classmethod):
elif isinstance(f, classmethod):
return False
if isinstance(f, staticmethod):
False

return True
elif isinstance(f, staticmethod):
return False
else:
return True

def add_default_metadata_wrapper(f):
def call(*args, metadata=(), **kwargs):
Expand Down Expand Up @@ -211,7 +213,8 @@ def configure(
a separate client). This is a shortcut for `client_options={"api_key": api_key}`.
If omitted, and the `GOOGLE_API_KEY` environment variable is set, it will be
used.
default_metadata: Default `(key, value)` metadata pairs to send with every request.
default_metadata: Default (key, value) metadata pairs to send with every request.
when using `transport="rest"` these are sent as HTTP headers.
"""
return _client_manager.configure(
api_key=api_key,
Expand Down
24 changes: 24 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,20 @@ def __init__(self, *args, **kwargs):
def generate_text(self, metadata=None):
self.metadata = metadata

not_a_function = 7
def _hidden(self):
self.called_hidden = True
@staticmethod
def static():
pass

@classmethod
def classm(cls):
cls.called_classm=True

@mock.patch.object(glm, "TextServiceClient", DummyClient)
def test_default_metadata(self):
# The metadata wrapper injects this argument.
metadata = [("hello", "world")]
client.configure(default_metadata=metadata)

Expand All @@ -102,6 +114,18 @@ def test_default_metadata(self):

self.assertEqual(metadata, text_client.metadata)

self.assertEqual(text_client.not_a_function, ClientTests.DummyClient.not_a_function)

# Since these don't have a metadata arg, they'll fail if the wrapper is applied.
text_client._hidden()
self.assertTrue(text_client.called_hidden)

text_client.static()

text_client.classm()
self.assertTrue(ClientTests.DummyClient.called_classm)



if __name__ == "__main__":
absltest.main()

0 comments on commit eb42b85

Please sign in to comment.