Skip to content

Commit

Permalink
Add and fix tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkDaoust committed Oct 5, 2023
1 parent 3077a14 commit f2e3a22
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 14 deletions.
8 changes: 4 additions & 4 deletions google/generativeai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
@dataclasses.dataclass
class _ClientManager:
client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
metadata: Sequence[tuple[str, str]] = ()
default_metadata: Sequence[tuple[str, str]] = ()
discuss_client: glm.DiscussServiceClient | None = None
discuss_async_client: glm.DiscussServiceAsyncClient | None = None
model_client: glm.ModelServiceClient | None = None
Expand Down Expand Up @@ -117,7 +117,7 @@ def configure(

def make_client(self, cls):
# Attempt to configure using defaults.
if self.client_config is None:
if not self.client_config:
configure()

client = cls(**self.client_config)
Expand Down Expand Up @@ -176,7 +176,7 @@ def get_default_model_client(self) -> glm.ModelServiceClient:
def get_default_operations_client(self) -> operations_v1.OperationsClient:
if self.operations_client is None:
self.model_client = get_default_model_client()
self.operations_client = model_client._transport.operations_client
self.operations_client = self.model_client._transport.operations_client

return self.operations_client

Expand Down Expand Up @@ -228,7 +228,7 @@ def get_default_discuss_client() -> glm.DiscussServiceClient:


def get_default_text_client() -> glm.TextServiceClient:
return _client_manager.get_default_discuss_client()
return _client_manager.get_default_text_client()


def get_default_operations_client() -> operations_v1.OperationsClient:
Expand Down
32 changes: 25 additions & 7 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,37 +20,38 @@
from absl.testing import parameterized

from google.api_core import client_options
import google.ai.generativelanguage as glm
from google.generativeai import client


class ClientTests(parameterized.TestCase):
def setUp(self):
super().setUp()
client.default_client_config = {}
client._client_manager = client._ClientManager()

def test_api_key_passed_directly(self):
client.configure(api_key="AIzA_direct")

client_opts = client.default_client_config["client_options"]
client_opts = client._client_manager.client_config["client_options"]
self.assertEqual(client_opts.api_key, "AIzA_direct")

def test_api_key_passed_via_client_options(self):
client_opts = client_options.ClientOptions(api_key="AIzA_client_opts")
client.configure(client_options=client_opts)

client_opts = client.default_client_config["client_options"]
client_opts = client._client_manager.client_config["client_options"]
self.assertEqual(client_opts.api_key, "AIzA_client_opts")

@mock.patch.dict(os.environ, {"GOOGLE_API_KEY": "AIzA_env"})
def test_api_key_from_environment(self):
# Default to API key loaded from environment.
client.configure()
client_opts = client.default_client_config["client_options"]
client_opts = client._client_manager.client_config["client_options"]
self.assertEqual(client_opts.api_key, "AIzA_env")

# But not when a key is provided explicitly.
client.configure(api_key="AIzA_client")
client_opts = client.default_client_config["client_options"]
client_opts = client._client_manager.client_config["client_options"]
self.assertEqual(client_opts.api_key, "AIzA_client")

def test_api_key_cannot_be_set_twice(self):
Expand All @@ -65,7 +66,7 @@ def test_api_key_and_client_options(self):
client_opts = client_options.ClientOptions(api_endpoint="web.site")
client.configure(api_key="AIzA_client", client_options=client_opts)

actual_client_opts = client.default_client_config["client_options"]
actual_client_opts = client._client_manager.client_config["client_options"]
self.assertEqual(actual_client_opts.api_key, "AIzA_client")
self.assertEqual(actual_client_opts.api_endpoint, "web.site")

Expand All @@ -74,15 +75,32 @@ def test_api_key_and_client_options(self):
client.get_default_text_client,
client.get_default_discuss_async_client,
client.get_default_model_client,
client.get_default_operations_client,
)
@mock.patch.dict(os.environ, {"GOOGLE_API_KEY": "AIzA_env"})
def test_configureless_client_with_key(self, factory_fn):
_ = factory_fn()

# And ensure that it has set the default options.
actual_client_opts = client.default_client_config["client_options"]
actual_client_opts = client._client_manager.client_config["client_options"]
self.assertEqual(actual_client_opts.api_key, "AIzA_env")

class DummyClient:
def __init__(self, *args, **kwargs):
pass
def generate_text(self, metadata=None):
self.metadata = metadata
@mock.patch.object(glm, 'TextServiceClient', DummyClient)
def test_default_metadata(self):
breakpoint()

metadata = [('hello', 'world')]
client.configure(default_metadata = metadata)

text_client = client.get_default_text_client()
text_client.generate_text()

self.assertEqual(metadata, text_client.metadata)

if __name__ == "__main__":
absltest.main()
2 changes: 1 addition & 1 deletion tests/test_discuss.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class UnitTests(parameterized.TestCase):
def setUp(self):
self.client = unittest.mock.MagicMock()

client.default_discuss_client = self.client
client._client_manager.discuss_client = self.client

self.observed_request = None

Expand Down
2 changes: 1 addition & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class UnitTests(parameterized.TestCase):
def setUp(self):
self.client = unittest.mock.MagicMock()

client.default_model_client = self.client
client._client_manager.model_client = self.client

def add_client_method(f):
name = f.__name__
Expand Down
2 changes: 1 addition & 1 deletion tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class UnitTests(parameterized.TestCase):
def setUp(self):
self.client = unittest.mock.MagicMock()

client.default_text_client = self.client
client._client_manager.text_client = self.client

self.observed_request = None

Expand Down

0 comments on commit f2e3a22

Please sign in to comment.