In [None]:
from typing import Sequence, Optional
from google.cloud import vision
import io
from PIL import Image
import requests
from IPython.display import Image, display
import base64


def googleVisionAIAnalysis(
    feature_types: Sequence,
    path: Optional[str] = None,
    base64Str: Optional[str] = None,
    timeout: float = 30.0,
) -> vision.AnnotateImageResponse:
    """Run the requested Google Cloud Vision features on a single image.

    The image is taken from exactly one source: ``path`` (an ``http(s)://``
    URL or a local file path) or ``base64Str`` (base64-encoded image bytes).
    The image is also rendered through ``IPython.display.display`` before the
    Vision request is sent.

    Args:
        feature_types: Vision feature types; each one is wrapped in a
            ``vision.Feature(type_=...)``.
        path: URL or local file path of the image.
        base64Str: Base64-encoded image content.
        timeout: Seconds to wait for the HTTP download when ``path`` is a URL.

    Returns:
        The response of ``ImageAnnotatorClient.annotate_image``.

    Raises:
        Exception: If neither or both of ``path`` and ``base64Str`` are given,
            or if downloading the URL does not return HTTP 200.
    """
    # Exactly one image source must be supplied (neither / both are errors).
    if (path is None) == (base64Str is None):
        raise Exception("Choose one between path and base64Str")

    # NOTE: this file imports `Image` from both PIL and IPython.display; the
    # IPython import comes later, so `Image` below is IPython.display.Image.
    if base64Str is not None:
        content = base64.b64decode(base64Str)
        display(Image(data=content))
    elif path.startswith(("http://", "https://")):
        # Remote image: a timeout keeps a stalled server from hanging the call.
        http_response = requests.get(path, timeout=timeout)
        if http_response.status_code != 200:
            raise Exception(
                f"Request Failed. status: {http_response.status_code}"
            )
        content = http_response.content
        display(Image(url=path))
    else:
        # Local image file.
        with open(path, "rb") as image_file:
            content = image_file.read()
        display(Image(filename=path))

    # Build one request carrying every requested feature and send it.
    client = vision.ImageAnnotatorClient()
    image = vision.Image(content=content)
    features = [vision.Feature(type_=feature_type) for feature_type in feature_types]
    request = vision.AnnotateImageRequest(image=image, features=features)
    return client.annotate_image(request=request)


def print_objects(response: vision.AnnotateImageResponse) -> str:
    """Format the localized-object annotations of ``response`` as text.

    The result starts with a title line and an 80-dash rule, followed by one
    ``score|name|mid|normalized vertices`` row per detected object, and ends
    with a newline.
    """
    header = "\n".join(["Objects Annotations", "-" * 80])
    rows = []
    for annotation in response.localized_object_annotations:
        # Normalized vertices are printed with one decimal place.
        corners = ",".join(
            f"({vertex.x:.1f},{vertex.y:.1f})"
            for vertex in annotation.bounding_poly.normalized_vertices
        )
        rows.append(
            f"{annotation.score:4.0%}|{annotation.name:15}|{annotation.mid:10}|{corners}"
        )
    return header + "\n" + "\n".join(rows) + "\n"


def print_labels(response: vision.AnnotateImageResponse) -> str:
    """Format the label annotations of ``response`` as text.

    The result starts with a title line and an 80-dash rule, followed by one
    ``score|description`` row per label, and ends with a newline.
    """
    header = "\n".join(["Label Annotations", "-" * 80])
    rows = [
        f"{entry.score:4.0%}|{entry.description:5}"
        for entry in response.label_annotations
    ]
    return header + "\n" + "\n".join(rows) + "\n"


def print_text(response: vision.AnnotateImageResponse) -> str:
    """Format the text annotations of ``response`` as text.

    The result starts with a title line and an 80-dash rule, followed by one
    row per annotation of the form ``Text:<repr of text>| Vertices: <points>``,
    and ends with a newline. The text is shown via ``repr`` so embedded
    newlines stay on one row.
    """
    header = "\n".join(["Text Annotations", "-" * 80])
    rows = []
    for annotation in response.text_annotations:
        vertices = ",".join(
            f"({v.x},{v.y})" for v in annotation.bounding_poly.vertices
        )
        # The first column is the detected text, so label it "Text:"; it was
        # previously labelled "Vertices:" like the second column.
        rows.append(
            f"Text:{repr(annotation.description):42}| Vertices: {vertices}"
        )
    return header + "\n" + "\n".join(rows) + "\n"


def print_faces(response: vision.AnnotateImageResponse) -> str:
    """Format the face annotations of ``response`` as text.

    The result starts with a title line and an 80-dash rule. Each face then
    gets a numbered heading with its bounding-polygon vertices and the names
    of its joy, under-exposed and blurred likelihoods, followed by blank
    lines. Faces are numbered from 1.
    """
    header = "\n".join(["Face Annotations", "-" * 80])
    sections = []
    index = 0
    for face in response.face_annotations:
        index += 1
        corners = ",".join(
            f"({point.x},{point.y})" for point in face.bounding_poly.vertices
        )
        # Each section ends with two newlines, matching a "\n"-join whose
        # final element is itself "\n".
        section = (
            f"# Face {index} @ {corners}\n"
            f"Joy:     {face.joy_likelihood.name}\n"
            f"Exposed: {face.under_exposed_likelihood.name}\n"
            f"Blurred: {face.blurred_likelihood.name}\n"
            "\n"
        )
        sections.append(section)
    return header + "\n" + "\n".join(sections) + "\n"


def print_landmarks(
    response: vision.AnnotateImageResponse, min_score: float = 0.5
) -> str:
    """Format the landmark annotations of ``response`` as text.

    The result starts with a title line and an 80-dash rule, followed by one
    ``description|vertices|latitude|longitude`` row per landmark, and ends
    with a newline.

    Args:
        response: Object exposing ``landmark_annotations``.
        min_score: Landmarks with a score below this value are skipped.

    Returns:
        The formatted table. Latitude and longitude come from the landmark's
        first location; when a landmark has no locations both columns are
        ``n/a``.
    """
    header = "\n".join(["Landmark Annotations", "-" * 80])
    rows = []
    for landmark in response.landmark_annotations:
        if landmark.score < min_score:
            continue
        vertices = ",".join(
            f"({v.x},{v.y})" for v in landmark.bounding_poly.vertices
        )
        # Guard against a landmark without location info; indexing [0]
        # unconditionally would raise IndexError and lose the whole report.
        if landmark.locations:
            lat_lng = landmark.locations[0].lat_lng
            coordinates = [f"{lat_lng.latitude:.5f}", f"{lat_lng.longitude:.5f}"]
        else:
            coordinates = ["n/a", "n/a"]
        rows.append(
            "|".join([f"{landmark.description:18}", vertices, *coordinates])
        )
    return header + "\n" + "\n".join(rows) + "\n"


def print_image_properties(
    response: vision.AnnotateImageResponse, min_score: float = 0.5
) -> str:
    """Format the dominant colors of ``response`` as text.

    The result starts with a title line and an 80-dash rule, followed by one
    row per dominant color with its RGB components, score and pixel fraction,
    and ends with a newline.

    Note: ``min_score`` is accepted but not used; every color is listed.
    """
    header = "\n".join(["Image Properties Annotations", "-" * 80])
    palette = response.image_properties_annotation.dominant_colors.colors
    rows = []
    for entry in palette:
        rgb = entry.color
        # The closing parenthesis is padded to 20 characters so the
        # following columns line up.
        swatch = f"RGB({rgb.red:.0f},{rgb.green:.0f},{rgb.blue:.0f}" + f"{')':20}"
        rows.append(
            f"{swatch}|Score:{entry.score:4.0%}|Pixel fraction:{entry.pixel_fraction:4.0%}"
        )
    return header + "\n" + "\n".join(rows) + "\n"

In [None]:
from typing import Any, Dict, List, Optional
from langchain.prompts import PromptTemplate
from langchain.chains.base import Chain
from langchain import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from pydantic import Extra
from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)


class GoogleVisonAIChain(Chain):
    """Chain that renders a prompt from its inputs and returns the LLM's reply.

    The prompt is formatted with the chain inputs and sent to ``llm``; the
    text of the first generation is returned under ``output_key``.
    """

    # Prompt template rendered from the chain inputs.
    prompt: BasePromptTemplate
    # Language model that receives the rendered prompt.
    llm: BaseLanguageModel
    output_key: str = "text"  #: :meta private:

    class Config:
        """Pydantic settings: reject unknown fields, allow arbitrary types."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Input keys are exactly the variables the prompt expects.

        :meta private:
        """
        return self.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """The single output key, ``output_key``.

        :meta private:
        """
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Render the prompt, query the LLM and return the first generation's text.

        When a run manager is given, the rendered prompt and the answer are
        reported through ``on_text`` and a child callback manager is passed
        to the LLM.
        """
        rendered = self.prompt.format_prompt(**inputs)
        callbacks = None
        if run_manager:
            run_manager.on_text(
                rendered.to_string(), color="green", end="\n", verbose=self.verbose
            )
            callbacks = run_manager.get_child()
        llm_result = self.llm.generate_prompt(prompts=[rendered], callbacks=callbacks)
        answer = llm_result.generations[0][0].text
        if run_manager:
            run_manager.on_text(
                answer, color="yellow", end="\n", verbose=self.verbose
            )
        return {self.output_key: answer}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Async counterpart of ``_call`` using ``agenerate_prompt``."""
        rendered = self.prompt.format_prompt(**inputs)
        callbacks = None
        if run_manager:
            await run_manager.on_text(
                rendered.to_string(), color="green", end="\n", verbose=self.verbose
            )
            callbacks = run_manager.get_child()
        llm_result = await self.llm.agenerate_prompt(
            prompts=[rendered], callbacks=callbacks
        )
        answer = llm_result.generations[0][0].text
        if run_manager:
            await run_manager.on_text(
                answer, color="yellow", end="\n", verbose=self.verbose
            )
        return {self.output_key: answer}

    @property
    def _chain_type(self) -> str:
        # NOTE(review): "calendar_chain" looks copied from another chain;
        # confirm whether a vision-specific identifier is intended.
        return "calendar_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        **kwargs: Any,
    ) -> Chain:
        """Build the chain around ``llm`` with the built-in description prompt.

        The prompt has a single variable, ``vision_ai_answer``, followed by
        fixed instructions (in Chinese) asking for a one-paragraph description
        of the picture with per-theme requirements. Extra keyword arguments
        are forwarded to the constructor.
        """
        template = """{vision_ai_answer}

说明
---------------------
- 上面是一张图片经过分析后的数据。
- 图片主题分类：风景主题、人物主题、户型图主题和其他
- 在下面"图片内容"处，根据不同的主题，按照下面相应"要求"，用一段话描写出你看到的内容。

风景主题要求
---------------------
- 不要使用"标注"和"注释"
- 要求说明图片中都包含了哪些内容
- 要求写出对图片的颜色和光线的感受
- 要求对整体图片的背景进行分析
- 如果图片包含有地标，要求对图片中包含的地标进行简单的介绍

人物主题要求
---------------------
- 不要使用"标注"和"注释"
- 要求说明图片中都包含了哪些内容
- 要求写出对图片的颜色和光线的感受
- 要求对整体图片的背景进行分析
- 要求详细描写人物的面容表情
- 要求详细描写人物的衣着

户型图主题要求
---------------------
- 不要使用"标注"和"注释"
- 默认向上方为北方。
- 要求说明图片中都包含了哪些内容
- 要求描述清楚各个房间的位置
- 要求预测房间门的朝向。

其他主题要求
---------------------
- 要求说明图片中都包含了哪些内容

图片内容
---------------------
"""
        return cls(
            llm=llm,
            prompt=PromptTemplate.from_template(template=template),
            **kwargs,
        )

In [None]:
# Shell command (not Python): sets the DISPLAY environment variable to ":0".
export DISPLAY=:0

In [None]:
def _normalize_base64_image(img_str: str) -> str:
    """Return ``img_str`` as canonical, padded, standard-alphabet base64.

    Accepts a bare base64 payload or one preceded by a ``...base64,`` prefix
    (data-URI style), optionally percent-encoded, in either the standard or
    the URL-safe alphabet, with or without ``=`` padding.

    Raises:
        ValueError: If the payload is empty or is not valid base64.
    """
    import base64
    from urllib.parse import unquote

    # Keep only what follows a "base64," marker, if there is one.
    _, marker, tail = img_str.partition("base64,")
    payload = tail if marker else img_str
    # Undo percent-encoding (e.g. %2B, %3D) and drop any whitespace/newlines.
    payload = "".join(unquote(payload).split())
    # Map the URL-safe alphabet onto the standard one so downstream code can
    # use a plain b64decode.
    payload = payload.translate(str.maketrans("-_", "+/"))
    # Base64 length must be a multiple of 4: rebuild the padding from scratch.
    payload = payload.rstrip("=")
    payload += "=" * (-len(payload) % 4)
    # binascii.Error is a ValueError subclass, so invalid data surfaces as
    # ValueError here.
    raw = base64.b64decode(payload, validate=True)
    if not raw:
        raise ValueError("empty base64 payload")
    return base64.b64encode(raw).decode("ascii")


async def analysisImage(img_str: str) -> str:
    """Describe an image with Google Vision annotations and a chat model.

    ``img_str`` may be an ``http(s)://`` URL, a path to an existing local
    file, or a base64 string (optionally with a ``base64,`` prefix). The
    image is analysed with ``googleVisionAIAnalysis``, the annotations are
    formatted as text and passed to ``GoogleVisonAIChain``.

    Note: the Vision call is synchronous and runs inline in this coroutine.

    Args:
        img_str: URL, local file path, or base64-encoded image.

    Returns:
        The text produced by the chain.

    Raises:
        Exception: If ``img_str`` is not a URL, not an existing file and not
            decodable base64.
    """
    import os

    from dotenv import load_dotenv
    from langchain.chat_models import ChatOpenAI

    load_dotenv(dotenv_path="env")
    features = [
        vision.Feature.Type.OBJECT_LOCALIZATION,
        vision.Feature.Type.FACE_DETECTION,
        vision.Feature.Type.LANDMARK_DETECTION,
        vision.Feature.Type.LABEL_DETECTION,
        vision.Feature.Type.TEXT_DETECTION,
        vision.Feature.Type.IMAGE_PROPERTIES,
    ]

    if img_str.startswith(("http://", "https://")) or os.path.isfile(img_str):
        # URL or local file: both are handled through the `path` argument.
        response = googleVisionAIAnalysis(path=img_str, feature_types=features)
    else:
        # Anything else must be base64; validate and normalise it first so
        # the analysis step receives a standard, padded payload.
        try:
            base64_str = _normalize_base64_image(img_str)
        except ValueError as err:
            raise Exception(f"img_str format error:{img_str}") from err
        response = googleVisionAIAnalysis(base64Str=base64_str, feature_types=features)

    vision_ai_answer = "\n".join(
        [
            print_objects(response),
            print_labels(response),
            print_text(response),
            print_faces(response),
            print_landmarks(response),
            print_image_properties(response),
        ]
    )

    chat_gpt = ChatOpenAI(temperature=0, model="gpt-4", verbose=True)
    chain = GoogleVisonAIChain.from_llm(llm=chat_gpt, verbose=True)
    return await chain.arun(vision_ai_answer=vision_ai_answer)

In [None]:
# Notebook entry point: prompt for an image reference and run analysisImage on it.
# Uses top-level `await`, which works in an IPython/Jupyter cell but not in a plain script.
await analysisImage(img_str=input("Input image path or url or base64 string: "))