diff --git a/camel/embeddings/__init__.py b/camel/embeddings/__init__.py
index 0e4d5b136..f4b947dd3 100644
--- a/camel/embeddings/__init__.py
+++ b/camel/embeddings/__init__.py
@@ -14,9 +14,11 @@
 from .base import BaseEmbedding
 from .openai_embedding import OpenAIEmbedding
 from .sentence_transformers_embeddings import SentenceTransformerEncoder
+from .vlm_embedding import VisionLanguageEmbedding
 
 __all__ = [
     "BaseEmbedding",
     "OpenAIEmbedding",
     "SentenceTransformerEncoder",
+    "VisionLanguageEmbedding",
 ]
diff --git a/camel/embeddings/vlm_embedding.py b/camel/embeddings/vlm_embedding.py
new file mode 100644
index 000000000..4bc391ffa
--- /dev/null
+++ b/camel/embeddings/vlm_embedding.py
@@ -0,0 +1,143 @@
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the “License”);
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an “AS IS” BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+from typing import Any, List, Optional, Union
+
+from PIL import Image
+
+from camel.embeddings import BaseEmbedding
+
+
+class VisionLanguageEmbedding(BaseEmbedding[Union[str, Image.Image]]):
+    r"""Provides image embedding functionalities using a multimodal model.
+
+    Args:
+        model_name (str, optional): The model type to be used for generating
+            embeddings. (default: :obj:`openai/clip-vit-base-patch32`)
+
+    Raises:
+        RuntimeError: If an unsupported model type is specified.
+    """
+
+    def __init__(
+        self, model_name: str = "openai/clip-vit-base-patch32"
+    ) -> None:
+        r"""Initializes the :obj:`VisionLanguageEmbedding` class with a
+        specified multimodal model.
+
+        Args:
+            model_name (str, optional): The version name of the model to use.
+                (default: :obj:`openai/clip-vit-base-patch32`)
+        """
+        from transformers import AutoModel, AutoProcessor
+
+        try:
+            self.model = AutoModel.from_pretrained(model_name)
+            self.processor = AutoProcessor.from_pretrained(model_name)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load model '{model_name}': {e}")
+
+        self.valid_processor_kwargs = []
+        self.valid_model_kwargs = []
+
+        try:
+            self.valid_processor_kwargs = (
+                self.processor.image_processor._valid_processor_keys
+            )
+            self.valid_model_kwargs = [
+                "pixel_values",
+                "return_dict",
+                "interpolate_pos_encoding",
+            ]
+        except Exception:
+            print("Warning: processor or model has an unexpected structure")
+            pass
+        self.dim: Optional[int] = None
+
+    def embed_list(
+        self, objs: List[Union[Image.Image, str]], **kwargs: Any
+    ) -> List[List[float]]:
+        """Generates embeddings for the given images or texts.
+
+        Args:
+            objs (List[Image.Image|str]): The list of images or texts for
+                which to generate the embeddings.
+            image_processor_kwargs: Extra kwargs passed to the image processor.
+            tokenizer_kwargs: Extra kwargs passed to the text tokenizer.
+            model_kwargs: Extra kwargs passed to the main model.
+
+        Returns:
+            List[List[float]]: A list of embeddings, one per input object,
+                each represented as a list of floating-point numbers.
+
+        Raises:
+            ValueError: If the input type is not `Image.Image` or `str`.
+        """
+        if not objs:
+            raise ValueError("Input objs list is empty.")
+
+        image_processor_kwargs: Optional[dict] = kwargs.get(
+            'image_processor_kwargs', {}
+        )
+        tokenizer_kwargs: Optional[dict] = kwargs.get('tokenizer_kwargs', {})
+        model_kwargs: Optional[dict] = kwargs.get('model_kwargs', {})
+
+        result_list = []
+        for obj in objs:
+            if isinstance(obj, Image.Image):
+                image_input = self.processor(
+                    images=obj,
+                    return_tensors="pt",
+                    padding=True,
+                    **image_processor_kwargs,
+                )
+                image_feature = (
+                    self.model.get_image_features(**image_input, **model_kwargs)
+                    .squeeze(dim=0)
+                    .tolist()
+                )
+                result_list.append(image_feature)
+            elif isinstance(obj, str):
+                text_input = self.processor(
+                    text=obj,
+                    return_tensors="pt",
+                    padding=True,
+                    **tokenizer_kwargs,
+                )
+                text_feature = (
+                    self.model.get_text_features(**text_input, **model_kwargs)
+                    .squeeze(dim=0)
+                    .tolist()
+                )
+                result_list.append(text_feature)
+            else:
+                raise ValueError("Input type is neither an image nor a string.")
+
+        self.dim = len(result_list[0])
+
+        if any(len(result) != self.dim for result in result_list):
+            raise ValueError("Dimensionality is not consistent.")
+
+        return result_list
+
+    def get_output_dim(self) -> int:
+        r"""Returns the output dimension of the embeddings.
+
+        Returns:
+            int: The dimensionality of the embedding for the current model.
+        """
+        if self.dim is None:
+            text = 'dimension'
+            inputs = self.processor(text=[text], return_tensors="pt")
+            self.dim = self.model.get_text_features(**inputs).shape[1]
+        return self.dim
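As a quick illustration of the kwargs plumbing in embed_list above (not part of the patch): the keyword names below mirror the kwargs.get lookups in the new class, the checkpoint is the default openai/clip-vit-base-patch32, and the local file cats.jpg is purely hypothetical.

    from PIL import Image

    from camel.embeddings import VisionLanguageEmbedding

    clip = VisionLanguageEmbedding()
    vectors = clip.embed_list(
        ["a photo of two cats", Image.open("cats.jpg")],  # mixed text/image batch
        tokenizer_kwargs={"truncation": True},  # forwarded to the text tokenizer
        model_kwargs={"return_dict": True},  # forwarded to get_*_features
    )
    assert all(len(v) == clip.get_output_dim() for v in vectors)
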
diff --git a/licenses/update_license.py b/licenses/update_license.py
index d937bc699..0afbf3b9a 100644
--- a/licenses/update_license.py
+++ b/licenses/update_license.py
@@ -39,10 +39,12 @@ def update_license_in_file(
     start_line_start_with: str,
     end_line_start_with: str,
 ) -> bool:
-    with open(file_path, 'r') as f:
+    with open(
+        file_path, 'r', encoding='utf-8'
+    ) as f:  # for Windows compatibility
         content = f.read()
 
-    with open(license_template_path, 'r') as f:
+    with open(license_template_path, 'r', encoding='utf-8') as f:
         new_license = f.read().strip()
 
     maybe_existing_licenses = re.findall(
diff --git a/poetry.lock b/poetry.lock
index e0b3ebd7f..0c4304f0e 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 
 [[package]]
 name = "accelerate"
@@ -2354,9 +2354,13 @@ files = [
     {file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"},
     {file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"},
     {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"},
+    {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b590b39ef90c6b22ec0be925b211298e810b4856909c8ca60d27ffbca6c12e6"},
     {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"},
+    {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:c2faf60c583af0d135e853c86ac2735ce178f0e338a3c7f9ae8f622fd2eb788c"},
     {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"},
+    {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7ff762670cada8e05b32bf1e4dc50b140790909caa8303cfddc4d702b71ea184"},
     {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"},
+    {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:a6d2092797b388342c1bc932077ad232f914351932353e2e8706851c870bca1f"},
     {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"},
     {file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"},
     {file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"},
@@ -7076,16 +7080,16 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
 testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
 
 [extras]
-all = ["PyMuPDF", "accelerate", "beautifulsoup4", "cohere", "datasets", "diffusers", "docx2txt", "googlemaps", "neo4j", "openapi-spec-validator", "opencv-python", "prance", "pydub", "pygithub", "pymilvus", "pyowm", "qdrant-client", "rank-bm25", "requests_oauthlib", "sentence-transformers", "sentencepiece", "slack-sdk", "soundfile", "torch", "transformers", "unstructured", "wikipedia", "wolframalpha"]
+all = ["PyMuPDF", "accelerate", "beautifulsoup4", "cohere", "datasets", "diffusers", "docx2txt", "googlemaps", "neo4j", "openapi-spec-validator", "opencv-python", "pillow", "prance", "pydub", "pygithub", "pymilvus", "pyowm", "qdrant-client", "rank-bm25", "requests_oauthlib", "sentence-transformers", "sentencepiece", "slack-sdk", "soundfile", "torch", "transformers", "unstructured", "wikipedia", "wolframalpha"]
 encoders = ["sentence-transformers"]
 graph-storages = ["neo4j"]
 huggingface-agent = ["accelerate", "datasets", "diffusers", "opencv-python", "sentencepiece", "soundfile", "torch", "transformers"]
 retrievers = ["cohere", "rank-bm25"]
 test = ["mock", "pytest"]
-tools = ["PyMuPDF", "beautifulsoup4", "docx2txt", "googlemaps", "openapi-spec-validator", "prance", "pydub", "pygithub", "pyowm", "requests_oauthlib", "slack-sdk", "unstructured", "wikipedia", "wolframalpha"]
+tools = ["PyMuPDF", "beautifulsoup4", "docx2txt", "googlemaps", "openapi-spec-validator", "pillow", "prance", "pydub", "pygithub", "pyowm", "requests_oauthlib", "slack-sdk", "unstructured", "wikipedia", "wolframalpha"]
 vector-databases = ["pymilvus", "qdrant-client"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<3.12"
-content-hash = "4e22773023f79ee56f30597046ad95cc5b87d3e1a27b3e7b99ba79e7622f0723"
+content-hash = "2c2023104b8ef3b2eade2c43c979b047d3dc07a9bbb69795808ee4a96032dd33"
diff --git a/pyproject.toml b/pyproject.toml
index 251ee8d87..2fa482281 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -61,6 +61,7 @@ requests_oauthlib = { version = "^1.3.1", optional = true }
 prance = { version = "^23.6.21.0", optional = true }
 openapi-spec-validator = { version = "^0.7.1", optional = true }
 unstructured = { extras = ["all-docs"], version = "^0.10.30", optional = true }
+pillow = { version = "^10.2.0", optional = true }
 slack-sdk = { version = "^3.27.2", optional = true }
 pydub = { version = "^0.25.1", optional = true }
 pygithub = { version = "^2.3.0", optional = true }
@@ -114,6 +115,7 @@ tools = [
     "prance",
     "openapi-spec-validator",
     "unstructured",
+    "pillow",
     "slack-sdk",
     "pydub",
     "pygithub",
@@ -169,6 +171,7 @@ all = [
     "neo4j",
     # retrievers
     "rank-bm25",
+    "pillow",
 ]
 
 [tool.poetry.group.dev]
@@ -276,8 +279,10 @@ module = [
     "cohere",
     "sentence_transformers.*",
     "pymilvus",
+    "pillow",
     "slack-sdk",
     "pydub",
     "pygithub"
+
 ]
 ignore_missing_imports = true
diff --git a/test/embeddings/test_vlm_embeddings.py b/test/embeddings/test_vlm_embeddings.py
new file mode 100644
index 000000000..5e7179fa0
--- /dev/null
+++ b/test/embeddings/test_vlm_embeddings.py
@@ -0,0 +1,78 @@
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the “License”);
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an “AS IS” BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+import pytest
+import requests
+from PIL import Image
+from transformers import CLIPModel, CLIPProcessor
+
+from camel.embeddings import VisionLanguageEmbedding
+
+
+@pytest.fixture
+def VLM_instance() -> VisionLanguageEmbedding:
+    return VisionLanguageEmbedding()
+
+
+def test_CLIPEmbedding_initialization(VLM_instance):
+    assert VLM_instance is not None
+    assert isinstance(VLM_instance.model, CLIPModel)
+    assert isinstance(VLM_instance.processor, CLIPProcessor)
+
+
+def test_image_embed_list_with_valid_input(VLM_instance):
+    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    image = Image.open(requests.get(url, stream=True).raw)
+    test_images = [image, image]
+    embeddings = VLM_instance.embed_list(test_images)
+    assert isinstance(embeddings, list)
+    assert len(embeddings) == 2
+    for e in embeddings:
+        assert len(e) == VLM_instance.get_output_dim()
+
+
+def test_image_embed_list_with_empty_input(VLM_instance):
+    with pytest.raises(ValueError):
+        VLM_instance.embed_list([])
+
+
+def test_text_embed_list_with_valid_input(VLM_instance):
+    test_texts = ['Hello world', 'Testing sentence embeddings']
+    embeddings = VLM_instance.embed_list(test_texts)
+    assert isinstance(embeddings, list)
+    assert len(embeddings) == 2
+    for e in embeddings:
+        assert len(e) == VLM_instance.get_output_dim()
+
+
+def test_text_embed_list_with_empty_input(VLM_instance):
+    with pytest.raises(ValueError):
+        VLM_instance.embed_list([])
+
+
+def test_mixed_embed_list_with_valid_input(VLM_instance):
+    test_list = ['Hello world', 'Testing sentence embeddings']
+    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    image = Image.open(requests.get(url, stream=True).raw)
+    test_list.append(image)
+    embeddings = VLM_instance.embed_list(test_list)
+    assert isinstance(embeddings, list)
+    assert len(embeddings) == 3
+    for e in embeddings:
+        assert len(e) == VLM_instance.get_output_dim()
+
+
+def test_get_output_dim(VLM_instance):
+    output_dim = VLM_instance.get_output_dim()
+    assert isinstance(output_dim, int)
+    assert output_dim > 0
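
For anyone who wants to exercise the new embedding class without the COCO download used in the tests above, a minimal offline sketch (the in-memory test image is a stand-in of my own, and the 512-dimension remark assumes the default CLIP checkpoint):

    from PIL import Image

    from camel.embeddings import VisionLanguageEmbedding

    embedding = VisionLanguageEmbedding()  # fetches openai/clip-vit-base-patch32 on first use
    image = Image.new("RGB", (224, 224), color="white")  # in-memory stand-in for the COCO sample
    vectors = embedding.embed_list([image, "a blank white square"])

    assert len(vectors) == 2
    assert all(len(v) == embedding.get_output_dim() for v in vectors)  # 512 for this model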