Skip to content

Commit

Permalink
feat: Added the vision_models.Image._mime_type property to make `vision_models.Image` compatible with `generative_models.Image`
Browse files Browse the repository at this point in the history

- This will allow `generative_models.Part.from_image` to accept `vision_models.Image` objects.

- Added `vision_models.Video._mime_type`
- Fixed linter errors.

PiperOrigin-RevId: 632153540
  • Loading branch information
holtskinner authored and copybara-github committed May 9, 2024
1 parent e0c6227 commit 6557d88
Showing 1 changed file with 55 additions and 20 deletions.
75 changes: 55 additions & 20 deletions vertexai/vision_models/_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pylint: disable=bad-continuation, line-too-long, protected-access
"""Classes for working with vision models."""

import base64
Expand Down Expand Up @@ -99,15 +100,22 @@ def load_from_file(location: str) -> "Image":
image = Image(image_bytes=image_bytes)
return image

@property
def _blob(self) -> storage.Blob:
    """Returns the Cloud Storage blob that `_gcs_uri` points to.

    A new storage client is built from the globally configured credentials
    on every access, and the blob is reloaded before it is returned.

    Raises:
        AttributeError: If `_gcs_uri` is None.
    """
    gcs_uri = self._gcs_uri
    if gcs_uri is None:
        raise AttributeError("_blob is only supported when gcs_uri is set.")
    credentials = aiplatform_initializer.global_config.credentials
    gcs_blob = storage.Blob.from_string(
        uri=gcs_uri,
        client=storage.Client(credentials=credentials),
    )
    # Needed to populate `blob.content_type`
    gcs_blob.reload()
    return gcs_blob

@property
def _image_bytes(self) -> bytes:
if self._loaded_bytes is None:
storage_client = storage.Client(
credentials=aiplatform_initializer.global_config.credentials
)
self._loaded_bytes = storage.Blob.from_string(
uri=self._gcs_uri, client=storage_client
).download_as_bytes()
self._loaded_bytes = self._blob.download_as_bytes()
return self._loaded_bytes

@_image_bytes.setter
Expand All @@ -117,13 +125,27 @@ def _image_bytes(self, value: bytes):
@property
def _pil_image(self) -> "PIL_Image.Image":
if self._loaded_image is None:
if not PIL_Image:
raise RuntimeError(
"The PIL module is not available. Please install the Pillow package."
)
self._loaded_image = PIL_Image.open(io.BytesIO(self._image_bytes))
return self._loaded_image

@property
def _size(self):
return self._pil_image.size

@property
def _mime_type(self) -> str:
"""Returns the MIME type of the image."""
if self._gcs_uri:
return self._blob.content_type
if PIL_Image:
return PIL_Image.MIME.get(self._pil_image.format, "image/jpeg")
# Fall back to jpeg
return "image/jpeg"

def show(self):
"""Shows the image.
Expand All @@ -146,7 +168,7 @@ def _as_base64_string(self) -> str:
Returns:
Base64 encoding of the image as a string.
"""
# ! b64encode returns `bytes` object, not ``str.
# ! b64encode returns `bytes` object, not `str`.
# We need to convert `bytes` to `str`, otherwise we get service error:
# "received initial metadata size exceeds limit"
return base64.b64encode(self._image_bytes).decode("ascii")
Expand Down Expand Up @@ -196,21 +218,36 @@ def load_from_file(location: str) -> "Video":
video = Video(video_bytes=video_bytes)
return video

@property
def _blob(self) -> storage.Blob:
    """Returns the Cloud Storage blob that `_gcs_uri` points to.

    A new storage client is built from the globally configured credentials
    on every access, and the blob is reloaded before it is returned.

    Raises:
        AttributeError: If `_gcs_uri` is None.
    """
    gcs_uri = self._gcs_uri
    if gcs_uri is None:
        raise AttributeError("_blob is only supported when gcs_uri is set.")
    credentials = aiplatform_initializer.global_config.credentials
    gcs_blob = storage.Blob.from_string(
        uri=gcs_uri,
        client=storage.Client(credentials=credentials),
    )
    # Needed to populate `blob.content_type`
    gcs_blob.reload()
    return gcs_blob

@property
def _video_bytes(self) -> bytes:
if self._loaded_bytes is None:
storage_client = storage.Client(
credentials=aiplatform_initializer.global_config.credentials
)
self._loaded_bytes = storage.Blob.from_string(
uri=self._gcs_uri, client=storage_client
).download_as_bytes()
self._loaded_bytes = self._blob.download_as_bytes()
return self._loaded_bytes

@_video_bytes.setter
def _video_bytes(self, value: bytes):
    """Stores `value` as the loaded video bytes."""
    self._loaded_bytes = value

@property
def _mime_type(self) -> str:
"""Returns the MIME type of the video."""
if self._gcs_uri:
return self._blob.content_type
# Fall back to mp4
return "video/mp4"

def save(self, location: str):
"""Saves video to a file.
Expand All @@ -225,7 +262,7 @@ def _as_base64_string(self) -> str:
Returns:
Base64 encoding of the video as a string.
"""
# ! b64encode returns `bytes` object, not ``str.
# ! b64encode returns `bytes` object, not `str`.
# We need to convert `bytes` to `str`, otherwise we get service error:
# "received initial metadata size exceeds limit"
return base64.b64encode(self._video_bytes).decode("ascii")
Expand Down Expand Up @@ -582,8 +619,7 @@ def generate_images(
* "16:9" : 16:9 aspect ratio
* "4:3" : 4:3 aspect ratio
* "3:4" : 3:4 aspect_ratio
guidance_scale: Controls the strength of the prompt. Suggested values
are:
guidance_scale: Controls the strength of the prompt. Suggested values are:
* 0-9 (low strength)
* 10-20 (medium strength)
* 21+ (high strength)
Expand Down Expand Up @@ -667,8 +703,7 @@ def edit_image(
* 0-9 (low strength)
* 10-20 (medium strength)
* 21+ (high strength)
edit_mode: Describes the editing mode for the request. Supported values
are:
edit_mode: Describes the editing mode for the request. Supported values are:
* inpainting-insert: fills the mask area based on the text prompt
(requires mask and text)
* inpainting-remove: removes the object(s) in the mask area.
Expand All @@ -677,7 +712,6 @@ def edit_image(
(Requires mask)
* product-image: Changes the background for the predominant product
or subject in the image
segmentation_classes: List of class IDs for segmentation. Max of 5 IDs
mask_mode: Solicits generation of the mask (v/s providing mask as an
input). Supported values are:
* background: Automatically generates a mask for all regions except
Expand All @@ -686,6 +720,7 @@ def edit_image(
subject(s) of the image.
* semantic: Segment one or more of the segmentation classes using
class ID
segmentation_classes: List of class IDs for segmentation. Max of 5 IDs
mask_dilation: Defines the dilation percentage of the mask provided.
Float between 0 and 1. Defaults to 0.03
product_position: Defines whether the product should stay fixed or be
Expand Down Expand Up @@ -1241,7 +1276,7 @@ class WatermarkVerificationResponse:


class WatermarkVerificationModel(_model_garden_models._ModelGardenModel):
"""Verifies if an image has a watermark"""
"""Verifies if an image has a watermark."""

__module__ = "vertexai.preview.vision_models"

Expand Down

0 comments on commit 6557d88

Please sign in to comment.