Skip to content

Commit

Permalink
Fix pytype.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkDaoust committed Dec 12, 2023
1 parent 517a248 commit bcc6a11
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 41 deletions.
12 changes: 6 additions & 6 deletions google/generativeai/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
def embed_content(
model: model_types.BaseModelNameOptions,
content: content_types.ContentType,
task_type: EmbeddingTaskTypeOptions,
title: str | None,
client: glm.GenerativeServiceClient = None,
task_type: EmbeddingTaskTypeOptions | None = None,
title: str | None = None,
client: glm.GenerativeServiceClient | None = None,
) -> text_types.EmbeddingDict:
...

Expand All @@ -101,9 +101,9 @@ def embed_content(
def embed_content(
model: model_types.BaseModelNameOptions,
content: Iterable[content_types.ContentType],
task_type: EmbeddingTaskTypeOptions,
title: str | None,
client: glm.GenerativeServiceClient = None,
task_type: EmbeddingTaskTypeOptions | None = None,
title: str | None = None,
client: glm.GenerativeServiceClient | None = None,
) -> text_types.BatchEmbeddingDict:
...

Expand Down
2 changes: 1 addition & 1 deletion google/generativeai/generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def rewind(self) -> tuple[glm.Content, glm.Content]:
return result

@property
def last(self) -> generation_types.GenerateContentResponse | None:
def last(self) -> generation_types.BaseGenerateContentResponse | None:
"""returns the last received `genai.GenerateContentResponse`"""
return self._last_received

Expand Down
49 changes: 31 additions & 18 deletions google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,33 @@
import io
import mimetypes
import pathlib
from typing import TypedDict, Union
import typing
from typing import Any, TypedDict, Union

from google.ai import generativelanguage as glm

try:
if typing.TYPE_CHECKING:
import PIL.Image
except ImportError:
PIL = None

try:
import IPython.display
except ImportError:
IPython = None

IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
else:
try:
import PIL.Image
except ImportError:
PIL = None

try:
import IPython.display
except ImportError:
IPython = None

IMAGE_TYPES = ()
if PIL is not None:
IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,)

if IPython is not None:
IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,)


__all__ = [
Expand All @@ -30,13 +44,7 @@
"ContentsType",
]

# TODO(markdaoust): merge into blob types to avoid the empty union.
IMAGE_TYPES = ()
if PIL is not None:
IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,)

if IPython is not None:
IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,)


def pil_to_png_bytes(img):
Expand All @@ -46,7 +54,7 @@ def pil_to_png_bytes(img):
return bytesio.read()


def image_to_blob(image: ImageType) -> glm.Blob:
def image_to_blob(image) -> glm.Blob:
if PIL is not None:
if isinstance(image, PIL.Image.Image):
return glm.Blob(mime_type="image/png", data=pil_to_png_bytes(image))
Expand Down Expand Up @@ -102,13 +110,18 @@ def _convert_dict(d: Mapping) -> glm.Content | glm.Part | glm.Blob:
)


BlobType = Union[(glm.Blob, BlobDict) + IMAGE_TYPES]


def is_blob_dict(d):
    """Return True if mapping *d* carries both keys required of a BlobDict."""
    required_keys = ("mime_type", "data")
    return all(key in d for key in required_keys)


if typing.TYPE_CHECKING:
BlobType = Union[
glm.Blob, BlobDict, PIL.Image.Image, IPython.display.Image
] # Any for the images
else:
BlobType = Union[glm.Blob, BlobDict, Any]


def to_blob(blob: BlobType) -> glm.Blob:
if isinstance(blob, Mapping):
blob = _convert_dict(blob)
Expand Down
18 changes: 11 additions & 7 deletions google/generativeai/types/generation_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import collections
import contextlib
import sys
from collections.abc import Iterable, AsyncIterable
import dataclasses
import itertools
Expand Down Expand Up @@ -66,7 +67,7 @@ def to_generation_config_dict(generation_config: GenerationConfigType):
if generation_config is None:
return {}
elif isinstance(generation_config, glm.GenerationConfig):
return type(generation_config).to_dict(generation_config)
return type(generation_config).to_dict(generation_config) # pytype: disable=attribute-error
elif isinstance(generation_config, GenerationConfig):
generation_config = dataclasses.asdict(generation_config)
return {key: value for key, value in generation_config.items() if value is not None}
Expand Down Expand Up @@ -395,9 +396,12 @@ class AsyncGenerateContentResponse(BaseGenerateContentResponse):

@classmethod
async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentResponse]):
iterator = aiter(iterator)
if (sys.version_info.major, sys.version_info.minor) < (3, 10):
raise RuntimeError("__aiter__ requires python 3.10+")

iterator = aiter(iterator) # type: ignore
with rewrite_stream_error():
response = await anext(iterator)
response = await anext(iterator) # type: ignore

return cls(
done=False,
Expand All @@ -424,7 +428,7 @@ async def __aiter__(self):

# Always have the next chunk available.
if len(self._chunks) == 0:
self._chunks.append(await anext(self._iterator))
self._chunks.append(await anext(self._iterator)) # type: ignore

for n in itertools.count():
if self._error:
Expand All @@ -437,7 +441,7 @@ async def __aiter__(self):
return

try:
item = await anext(self._iterator)
item = await anext(self._iterator) # type: ignore
except StopAsyncIteration:
self._done = True
except Exception as e:
Expand All @@ -452,9 +456,9 @@ async def __aiter__(self):
item = GenerateContentResponse.from_response(item)
yield item

def resolve(self):
async def resolve(self):
if self._done:
return

for _ in self:
async for _ in self:
pass
3 changes: 2 additions & 1 deletion google/generativeai/types/safety_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,12 @@ class LooseSafetySettingDict(TypedDict):


EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions]
EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions]

SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None]


def to_easy_safety_dict(settings: SafetySettingOptions, harm_category_set) -> EasySafetySetting:
def to_easy_safety_dict(settings: SafetySettingOptions, harm_category_set) -> EasySafetySettingDict:
if settings is None:
return {}
elif isinstance(settings, Mapping):
Expand Down
4 changes: 0 additions & 4 deletions tests/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def add_client_method(f):
def generate_content(
request: glm.GenerateContentRequest,
) -> glm.GenerateContentResponse:
if request is None:
request = glm.GetModelRequest(name=name)
self.assertIsInstance(request, glm.GenerateContentRequest)
self.observed_requests.append(request)
response = self.responses["generate_content"].pop(0)
Expand All @@ -54,8 +52,6 @@ def generate_content(
def stream_generate_content(
request: glm.GetModelRequest,
) -> Iterable[glm.GenerateContentResponse]:
if request is None:
request = glm.GetModelRequest(name=name)
self.observed_requests.append(request)
response = self.responses["stream_generate_content"].pop(0)
return response
Expand Down
4 changes: 0 additions & 4 deletions tests/test_generative_models_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ def add_client_method(f):
async def generate_content(
request: glm.GenerateContentRequest,
) -> glm.GenerateContentResponse:
if request is None:
request = glm.GetModelRequest(name=name)
self.assertIsInstance(request, glm.GenerateContentRequest)
self.observed_requests.append(request)
response = self.responses["generate_content"].pop(0)
Expand All @@ -60,8 +58,6 @@ async def generate_content(
async def stream_generate_content(
request: glm.GetModelRequest,
) -> Iterable[glm.GenerateContentResponse]:
if request is None:
request = glm.GetModelRequest(name=name)
self.observed_requests.append(request)
response = self.responses["stream_generate_content"].pop(0)
return response
Expand Down

0 comments on commit bcc6a11

Please sign in to comment.