In [None]:
%pip install chromadb

In [32]:
# 初始化智谱 embedding
from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel, Field, model_validator
import os
from zhipuai import ZhipuAI

class ZhipuAIEmbeddings(BaseModel, Embeddings):
    """ZhipuAI embedding model integration.

    Setup:

        To use, you should have the ``zhipuai`` python package installed, and the
        environment variable ``ZHIPU_API_KEY`` set with your API KEY.

        More instructions about ZhipuAi Embeddings, you can get it
        from  https://open.bigmodel.cn/dev/api#vector

        .. code-block:: bash

            pip install -U zhipuai
            export ZHIPU_API_KEY="your-api-key"

    Key init args — completion params:
        model: Optional[str]
            Name of ZhipuAI model to use.
        api_key: str
            Automatically inferred from env var `ZHIPU_API_KEY` if not provided.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:

        .. code-block:: python

            from langchain_community.embeddings import ZhipuAIEmbeddings

            embed = ZhipuAIEmbeddings(
                model="embedding-2",
                # api_key="...",
            )

    Embed single text:
        .. code-block:: python

            input_text = "The meaning of life is 42"
            embed.embed_query(input_text)

        .. code-block:: python

            [-0.003832892, 0.049372625, -0.035413884, -0.019301128, 0.0068899863, 0.01248398, -0.022153955, 0.006623926, 0.00778216, 0.009558191, ...]


    Embed multiple text:
        .. code-block:: python

            input_texts = ["This is a test query1.", "This is a test query2."]
            embed.embed_documents(input_texts)

        .. code-block:: python

            [
                [0.0083934665, 0.037985895, -0.06684559, -0.039616987, 0.015481004, -0.023952313, ...],
                [-0.02713102, -0.005470169, 0.032321047, 0.042484466, 0.023290444, 0.02170547, ...]
            ]
    """  # noqa: E501

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    model: str = Field(default="embedding-2")
    """Model name"""
    api_key: str
    """Automatically inferred from env var `ZHIPU_API_KEY` if not provided."""
    dimensions: Optional[int] = None
    """The number of dimensions the resulting output embeddings should have.

    Only supported in `embedding-3` and later models.
    """

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that auth token exists in environment."""
        values["api_key"] = get_from_dict_or_env(values, "api_key", "ZHIPUAI_API_KEY")
        try:
            from zhipuai import ZhipuAI

            values["client"] = ZhipuAI(api_key=values["api_key"])
        except ImportError:
            raise ImportError(
                "Could not import zhipuai python package."
                "Please install it with `pip install zhipuai`."
            )
        return values



    def embed_query(self, text: str) -> List[float]:
        """
        Embeds a text using the AutoVOT algorithm.

        Args:
            text: A text to embed.

        Returns:
            Input document's embedded list.
        """
        resp = self.embed_documents([text])
        return resp[0]




    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embeds a list of text documents using the AutoVOT algorithm.

        Args:
            texts: A list of text documents to embed.

        Returns:
            A list of embeddings for each document in the input list.
            Each embedding is represented as a list of float values.
        """
        if self.dimensions is not None:
            resp = self.client.embeddings.create(
                model=self.model,
                input=texts,
                dimensions=self.dimensions,
            )
        else:
            resp = self.client.embeddings.create(model=self.model, input=texts)
        embeddings = [r.embedding for r in resp.data]
        return embeddings


In [None]:

import chromadb
from chromadb.utils import embedding_functions
from pprint import pprint
import os

# 初始化Chroma客户端（持久化模式）
client = chromadb.PersistentClient(path="ch16_db")

# 创建一个集合（类似于表）
collection = client.get_or_create_collection("products")

# 定义嵌入函数（使用预训练的嵌入模型）
embedding = ZhipuAIEmbeddings(
    model="embedding-2",
    api_key=os.getenv("ZHIPU_API_KEY"),
    dimensions=1024
)

# 添加嵌入数据
collection.add(
    documents=["Galaxy S21", "iPhone 13", "MacBook Pro"],
    embeddings=embedding.embed_documents(["Galaxy S21", "iPhone 13", "MacBook Pro"]),
    metadatas=[
        {"category": "手机", "price": 799.99},
        {"category": "手机", "price": 999.99},
        {"category": "笔记本电脑", "price": 1299.99}
    ],
    ids=["prod1", "prod2", "prod3"]
)

print("数据添加完成！")

# 获取集合中的所有数据
all_data = collection.get()
print("集合中的所有数据：")
pprint(all_data)

# 根据 ID 获取特定的文档
specific_data = collection.get(ids=["prod1"])
print("\nID 为 'prod1' 的文档：")
pprint(specific_data)

# 根据元数据条件获取文档
filtered_data = collection.get(where={"category": "手机"})
print("\n类别为 '手机' 的文档：")
pprint(filtered_data)

# 更新已有文档的元数据
collection.update(
    ids=["prod1"],
    metadatas=[{"category": "手机", "price": 749.99}]
)
print("\n已更新 ID 为 'prod1' 的文档价格。")

# 删除特定 ID 的文档
collection.delete(ids=["prod2"])
print("\n已删除 ID 为 'prod2' 的文档。")

# 查看集合中剩余的文档
remaining_data = collection.get()
print("\n剩余的文档：")
pprint(remaining_data)



In [None]:
import chromadb
from chromadb.utils import embedding_functions
from pprint import pprint
import os

# 初始化Chroma客户端（持久化模式）
client = chromadb.PersistentClient(path="ch16_db")

# 创建一个集合（类似于表）
collection = client.get_or_create_collection("products")

# 定义嵌入函数（使用预训练的嵌入模型）
embedding = ZhipuAIEmbeddings(
    model="embedding-2",
    api_key=os.getenv("ZHIPU_API_KEY"),
    dimensions=1024
)

# 添加嵌入数据
collection.add(
    documents=["Galaxy S21", "iPhone 13", "MacBook Pro"],
    embeddings=embedding.embed_documents(["Galaxy S21", "iPhone 13", "MacBook Pro"]),
    metadatas=[
        {"category": "手机", "price": 799.99},
        {"category": "手机", "price": 999.99},
        {"category": "笔记本电脑", "price": 1299.99}
    ],
    ids=["prod1", "prod2", "prod3"]
)

print("数据添加完成！")

In [None]:
# 获取集合中的所有数据
all_data = collection.get()
print("集合中的所有数据：")
pprint(all_data)

# 根据 ID 获取特定的文档
specific_data = collection.get(ids=["prod1"])
print("\nID 为 'prod1' 的文档：")
pprint(specific_data)

# 根据元数据条件获取文档
filtered_data = collection.get(where={"category": "手机"})
print("\n类别为 '手机' 的文档：")
pprint(filtered_data)

# 更新已有文档的元数据
collection.update(
    ids=["prod1"],
    metadatas=[{"category": "手机", "price": 749.99}]
)
print("\n已更新 ID 为 'prod1' 的文档价格。")

# 删除特定 ID 的文档
collection.delete(ids=["prod2"])
print("\n已删除 ID 为 'prod2' 的文档。")

# 查看集合中剩余的文档
remaining_data = collection.get()
print("\n剩余的文档：")
pprint(remaining_data)

In [None]:
# 导入所需的模块
from langchain.vectorstores import Chroma
import os

# 从磁盘加载持久化数据库
persist_directory = "ch16_db"
embedding = ZhipuAIEmbeddings(
    model="embedding-2",
    api_key=os.getenv("ZHIPU_API_KEY"),
    dimensions=1024
)

vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)

# 执行相似性搜索并返回结果及其得分
query = "智能手机"
results = vectordb.similarity_search_with_score(query)

# 输出搜索结果
print("Search results:")
for doc, score in results:
    print(f"Document: {doc.page_content}, Score: {score}")

In [None]:

# 创建基于Chroma向量存储的检索器
retriever = vectordb.as_retriever(search_type="mmr")

# 检索与查询相关的文档
retrieved_docs = retriever.get_relevant_documents(query)

# 输出第一篇检索到的文档内容
print("Retrieved document:")
print(retrieved_docs[0].page_content)

In [None]:
# 导入所需的模块
from langchain.vectorstores import Chroma
import os

# 从磁盘加载持久化数据库
persist_directory = "ch16_db"
embedding = ZhipuAIEmbeddings(
    model="embedding-2",
    api_key=os.getenv("ZHIPU_API_KEY"),  # 确保环境变量已设置或直接替换为您的API密钥
    dimensions=1024
)

vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)

# 执行相似性搜索并返回结果及其得分
query = "智能手机"
result = vectordb.similarity_search_with_score(query)

# 输出第一条搜索结果
print(result)


In [None]:
# 创建基于Chroma向量存储的检索器
retriever = vectordb.as_retriever(search_type="mmr")

# 检索与查询相关的文档
retrieved_docs = retriever.get_relevant_documents(query)

# 输出第一篇检索到的文档内容
print(retrieved_docs[0].page_content)

In [None]:

collection = client.get_or_create_collection("products")

# 定义嵌入函数（假设使用预训练的嵌入模型）

# 添加嵌入数据
collection.add(
    documents=["Galaxy S21", "iPhone 13", "MacBook Pro"],
    embeddings=embedding.embed_documents(["Galaxy S21", "iPhone 13", "MacBook Pro"]),
    metadatas=[
        {"category": "手机", "price": 799.99},
        {"category": "手机", "price": 999.99},
        {"category": "笔记本电脑", "price": 1299.99}
    ],
    ids=["prod1", "prod2", "prod3"]
)

# 查询相似向量



results = collection.query(
    query_embeddings=embedding.embed_documents(["智能手机"]),
    n_results=2
)

pprint(results)

In [None]:
print(collection.count())  # 确保集合中有文档


In [None]:
query_embeddings = embedding.embed_documents(["智能手机"])
print(query_embeddings)  # 检查查询嵌入是否正确生成

In [None]:
embeddings = embedding.embed_documents(["Galaxy S21", "iPhone 13", "MacBook Pro"])
print(embeddings)  # 确保嵌入生成成功


In [None]:
# 导入所需的模块
from langchain.vectorstores import Chroma

# 从磁盘加载持久化数据库
persist_directory = "ch16_db"
embedding = ZhipuAIEmbeddings(
    model="embedding-2",
    api_key=os.getenv("ZHIPU_API_KEY"),
    dimensions=1024
)
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)
# 执行相似性搜索并返回结果及其得分
query = "Galaxy S21"
print(query)
result = vectordb.similarity_search_with_score(query)
# 输出第一条搜索结果
print(result)

In [None]:
# 导入所需模块
from langchain.vectorstores import Chroma
from langchain_core.embeddings import Embeddings
from zhipuai import ZhipuAI
import os


# 设置持久化目录
persist_directory = "ch16_db"



# 初始化 Chroma 数据库
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)

# 执行相似性搜索并返回结果及其得分
query = "iphone"
print("Query:", query)

# 生成查询嵌入
query_embeddings = embedding.embed_documents([query])
print("Query embeddings:", query_embeddings)  # 打印查询嵌入

# 获取集合并检查数据库中的文档数量

documents = collection.get()  # 获取所有文档
document_count = len(documents)  # 计算文档数量
print("Number of documents in database:", document_count)

# 如果数据库中有文档，则进行相似性搜索
if document_count > 0:
    # 进行相似性搜索
    result = vectordb.similarity_search_with_score(query, k=5)
    # 输出结果
    print("Search results:", result)
else:
    print("数据库中没有文档，无法进行搜索。")


In [None]:
import chromadb
from langchain.vectorstores import Chroma

import os

# 初始化Chroma客户端
persist_directory = "ch16_db"
embedding = ZhipuAIEmbeddings(model="embedding-2", api_key=os.getenv("ZHIPU_API_KEY"), dimensions=1024)
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)

# 查询内容
query = "Galaxy S21"
query_embeddings = embedding.embed_documents([query])

# 检查查询嵌入
print("Query embeddings:", query_embeddings)

# 获取集合中的文档
documents = vectordb.get()  # 确保从vectordb获取文档
document_count = len(documents)
print("Number of documents in database:", document_count)

# 如果有文档，则进行相似性搜索
if document_count > 0:
    result = vectordb.similarity_search_with_score(query, k=5)
    print("Search results:", result)
else:
    print("数据库中没有文档，无法进行搜索。")


In [None]:
import chromadb
from chromadb.utils import embedding_functions
from pprint import pprint
import os
import numpy as np

# 假设您已经定义了 ZhipuAIEmbeddings，如果没有，需要替换为实际的嵌入函数
# 请确保您已经正确导入或定义了 ZhipuAIEmbeddings 类

# 初始化 Chroma 客户端（持久化模式）
client = chromadb.PersistentClient(path="ch16_db")

# 删除并重新创建集合
try:
    client.delete_collection(name="products")
    print("集合已删除。")
except Exception as e:
    print("集合不存在，创建新的集合。")

# 创建新的集合
collection = client.create_collection(name="products")
print("新的集合已创建。")

# 定义嵌入函数（使用预训练的嵌入模型）
embedding = ZhipuAIEmbeddings(
    model="embedding-2",
    api_key=os.getenv("ZHIPU_API_KEY"),
    dimensions=1024
)

# 添加嵌入数据
collection.add(
    documents=["Galaxy S21", "iPhone 13", "MacBook Pro"],
    embeddings=embedding.embed_documents(["Galaxy S21", "iPhone 13", "MacBook Pro"]),
    metadatas=[
        {"category": "手机", "price": 799.99},
        {"category": "手机", "price": 999.99},
        {"category": "笔记本电脑", "price": 1299.99}
    ],
    ids=["prod1", "prod2", "prod3"]
)
print("数据添加完成！")

# 打印集合中的所有数据，验证添加是否成功
# all_data = collection.get(include=["documents", "metadatas", "ids"])  # 错误的方式
all_data = collection.get(include=["documents", "metadatas"])  # 移除 "ids"注意，ids 会自动返回。不需要写到这里面

print("当前集合中的数据：")
for metadata, document, doc_id in zip(all_data.get('metadatas', []), all_data.get('documents', []), all_data.get('ids', [])):
    print(f"ID: {doc_id}, 产品: {document}, 分类: {metadata['category']}, 价格: {metadata['price']}")

# 自定义元数据过滤：获取价格低于1000的手机，并包括嵌入向量
filtered_data = collection.get(
    where={
        "$and": [
            {"category": "手机"},
            {"price": {"$lte": 1000}}
        ]
    },
    include=["embeddings", "documents", "metadatas"]  # 移除 "ids"
)

# 检查是否有返回结果，且结果不为 None
if not filtered_data or not filtered_data.get('documents'):
    print("没有找到符合条件的产品。")
else:
    # 组合元数据、文档、ID 和嵌入
    combined_data = list(zip(
        filtered_data.get('metadatas', []),
        filtered_data.get('documents', []),
        filtered_data.get('ids', []),  # ids 默认返回
        filtered_data.get('embeddings', [])
    ))

    # 定义查询嵌入（例如，用户想查询“高性能手机”）
    query_text = "高性能手机"
    query_embedding = embedding.embed_query(query_text)

    # 自定义相似度函数（余弦相似度）
    def cosine_similarity(a, b):
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    # 对结果根据自定义相似度进行排序
    sorted_data = sorted(
        combined_data,
        key=lambda x: cosine_similarity(query_embedding, x[3]),
        reverse=True  # 相似度从高到低排序
    )

    print(f"\n查询 '{query_text}' 的结果（按相似度排序）：")
    for metadata, document, doc_id, _ in sorted_data:
        print(f"ID: {doc_id}, 产品: {document}, 价格: {metadata['price']}")


In [63]:
import chromadb
from chromadb.utils import embedding_functions
from pprint import pprint
import os
import numpy as np
# 假设您已经定义了 ZhipuAIEmbeddings，如果没有，需要替换为实际的嵌入函数
# 请确保您已经正确导入或定义了 ZhipuAIEmbeaf77489efe3a6ca0afc5dccdc4d4e657ddings 类

# 初始化 Chroma 客户端（持久化模式）
client = chromadb.PersistentClient(path="ch16_db")

# 删除并重新创建集合
try:
    client.delete_collection(name="products")
    print("集合已删除。")
except Exception as e:
    print("集合不存在，创建新的集合。")

# 创建新的集合
collection = client.create_collection(name="products")
print("新的集合已创建。")

# 定义嵌入函数（使用预训练的嵌入模型）
embedding = ZhipuAIEmbeddings(
    model="embedding-2",
    api_key=os.getenv("ZHIPU_API_KEY"),
    dimensions=1024
)

# 添加嵌入数据
documents = ["Galaxy S21", "iPhone 13", "MacBook Pro"]
metadatas = [
    {"category": "手机", "price": 799.99},
    {"category": "手机", "price": 999.99},
    {"category": "笔记本电脑", "price": 1299.99}
]
ids = ["prod1", "prod2", "prod3"]
embeddings = embedding.embed_documents(documents)

collection.add(
    documents=documents,
    embeddings=embeddings,
    metadatas=metadatas,
    ids=ids
)
print("数据添加完成！")

# 打印集合中的所有数据，验证添加是否成功
all_data = collection.get(include=["documents", "metadatas"])  # ids 会自动返回

print("当前集合中的数据：")
for metadata, document, doc_id in zip(all_data.get('metadatas', []), all_data.get('documents', []), all_data.get('ids', [])):
    print(f"ID: {doc_id}, 产品: {document}, 分类: {metadata['category']}, 价格: {metadata['price']}")

# 定义自定义相似度检索器
class CustomSimilarityRetriever:
    def __init__(self, collection, embedding_function, k=3):
        self.collection = collection
        self.embedding_function = embedding_function
        self.k = k

    def get_relevant_documents(self, query):
        # 计算查询的嵌入
        query_embedding = self.embedding_function.embed_query(query)
        # 获取集合中的所有嵌入数据
        all_data = self.collection.get(
            include=["embeddings", "documents", "metadatas"]  # 移除了 "ids"
        )
        if not all_data or not all_data.get('documents'):
            return []
        embeddings = all_data.get('embeddings')
        documents = all_data.get('documents')
        metadatas = all_data.get('metadatas')
        ids = all_data.get('ids')  # ids 会自动返回

        # 转换嵌入为 numpy 数组
        query_embedding = np.array(query_embedding)
        doc_embeddings = np.array(embeddings)

        # 定义余弦相似度函数
        def cosine_similarity(a, b):
            return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

        # 计算查询与每个文档的相似度
        similarities = [cosine_similarity(query_embedding, doc_embedding) for doc_embedding in doc_embeddings]

        # 获取相似度最高的前 k 个文档的索引
        top_k_indices = np.argsort(similarities)[-self.k:][::-1]

        # 准备返回的文档列表
        results = []
        for idx in top_k_indices:
            doc = {
                'document': documents[idx],
                'metadata': metadatas[idx],
                'id': ids[idx],
                # 'embedding': embeddings[idx]  # 如果需要，可以包括嵌入向量
            }
            results.append(doc)
        return results

# 创建自定义相似度检索器
custom_retriever = CustomSimilarityRetriever(
    collection=collection,
    embedding_function=embedding,
    k=3  # 返回3个文档
)

# 使用自定义检索器进行检索
query_text = "高性能手机"
retrieved_docs = custom_retriever.get_relevant_documents(query_text)

print(f"\n查询 '{query_text}' 的结果：")
for doc in retrieved_docs:
    print(f"ID: {doc['id']}, 产品: {doc['document']}, 价格: {doc['metadata']['price']}")


产品: iPhone 14 Pro Max
提取的特征: 拍照 奢华 大屏 性能 专业 便携 高端 智能
生成的广告文案: 非凡性能，精致设计，iPhone 14 Pro Max为您带来智能新体验


In [58]:
# 导入必要的库
from langchain.prompts.example_selector.base import BaseExampleSelector
from typing import Dict, List
import numpy as np
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
from langchain_community.chat_models import ChatTongyi
import os
from dotenv import load_dotenv
import chromadb
from chromadb.utils import embedding_functions

# 定义类别及其关键词
CATEGORIES = {
    "手机": ["手机", "智能手机", "iPhone", "Android", "移动"],
    "电脑": ["电脑", "笔记本", "台式机", "Mac", "Windows"],
    "服装": ["服装", "衬衫", "裤子", "连衣裙", "服饰"],
}

# 定义自定义示例选择器
class CustomExampleSelector(BaseExampleSelector):
    def __init__(self, examples: List[Dict[str, str]], categories: Dict[str, List[str]]):
        """
        初始化示例选择器。

        :param examples: 示例列表，每个示例应包含 'input', 'output' 和 'category' 字段。
        :param categories: 类别定义，每个类别对应相关的关键词列表。
        """
        self.examples = examples
        self.categories = categories

    def add_example(self, example: Dict[str, str]) -> None:
        """
        添加新的示例到examples列表中。

        :param example: 包含 'input', 'output' 和 'category' 的字典。
        """
        self.examples.append(example)

    def categorize_input(self, input_text: str) -> str:
        """
        根据输入文本识别其类别。

        :param input_text: 用户输入的文本。
        :return: 识别的类别名称，如果无法识别则返回 'general'。
        """
        for category, keywords in self.categories.items():
            for keyword in keywords:
                if keyword.lower() in input_text.lower():
                    return category
        return "general"

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """
        根据输入变量选择相关的示例。

        :param input_variables: 包含 'input' 的字典。
        :return: 选定的示例列表。
        """
        input_text = input_variables.get("input", "")
        category = self.categorize_input(input_text)
        print(f"识别的类别: {category}")

        if category != "general":
            # 从识别的类别中筛选示例
            category_examples = [ex for ex in self.examples if ex.get("category") == category]
            selected_examples = np.random.choice(category_examples, size=min(2, len(category_examples)), replace=False)
        else:
            selected_examples = np.random.choice(self.examples, size=2, replace=False)

        return selected_examples.tolist()

# 示例列表，包含类别标签
examples = [
    {"input": "iPhone 最新款", "output": "功能更强大的 Iphone 手机", "category": "手机"},
    {"input": "MacBook Pro", "output": "性价比更高的 苹果 笔记本", "category": "电脑"},
    {"input": "夏季连衣裙", "output": "舒适的夏季连衣裙，清凉时尚", "category": "服装"},
]

# 实例化 CustomExampleSelector
example_selector = CustomExampleSelector(examples, CATEGORIES)

# 添加新的示例
example_selector.add_example({"input": "小米手机", "output": "性价比无敌 ", "category": "手机"})
print("更新后的示例列表:")
for ex in example_selector.examples:
    print(ex)

# 定义示例格式模板
example_formatter_template = """输入: {input} 输出: {output}\n"""
example_prompt = PromptTemplate(
    input_variables=["input", "output"],
    template=example_formatter_template,
)

# 创建 FewShotPromptTemplate
few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix="根据以下示例，为输入生成一个吸引人的宣传词。",
    suffix="输入: {input}\n输出:",
    input_variables=["input"],
)

# 测试格式化 FewShotPromptTemplate
formatted_prompt = few_shot_prompt.format(input="iPhone 14 Pro")
print("\n格式化后的提示:\n", formatted_prompt)


更新后的示例列表:
{'input': 'iPhone 最新款', 'output': '功能更强大的 Iphone 手机', 'category': '手机'}
{'input': 'MacBook Pro', 'output': '性价比更高的 苹果 笔记本', 'category': '电脑'}
{'input': '夏季连衣裙', 'output': '舒适的夏季连衣裙，清凉时尚', 'category': '服装'}
{'input': '小米手机', 'output': '性价比无敌 ', 'category': '手机'}
识别的类别: 手机

格式化后的提示:
 根据以下示例，为输入生成一个吸引人的宣传词。

输入: 小米手机 输出: 性价比无敌 


输入: iPhone 最新款 输出: 功能更强大的 Iphone 手机


输入: iPhone 14 Pro
输出:


In [68]:
import numpy as np
from typing import Dict, List
import os

class SmartPromptGenerator:
    def __init__(self, embedding_function):
        self.embedding_function = embedding_function
        self.categories = {
            "手机": ["手机", "智能手机", "iphone", "android", "galaxy"],
            "电脑": ["电脑", "笔记本", "macbook", "laptop", "notebook"],
            "平板": ["平板", "ipad", "tablet"],
        }
        
        # 改进的广告文案模板及其特征描述
        self.templates = {
            "手机": [
                {
                    "template": "突破创新边界，{product}让科技融入生活",
                    "features": "创新 科技 生活 智能 便携",
                    "price_range": "all",  # 适用所有价格段
                    "keywords": ["创新", "智能", "科技"]
                },
                {
                    "template": "非凡性能，精致设计，{product}为您带来智能新体验",
                    "features": "性能 设计 体验 高端 精致",
                    "price_range": "high",  # 高端价格段
                    "keywords": ["性能", "设计", "高端"]
                },
                {
                    "template": "震撼视觉，强劲性能，{product}定义科技新高度",
                    "features": "视觉 性能 科技 强劲 高清",
                    "price_range": "high",
                    "keywords": ["视觉", "性能", "科技"]
                }
            ],
            "电脑": [
                {
                    "template": "高效办公的得力助手，{product}让工作更轻松",
                    "features": "办公 效率 工作 生产力 便携",
                    "price_range": "medium",
                    "keywords": ["办公", "效率", "便携"]
                },
                {
                    "template": "性能之选，{product}为创造力而生",
                    "features": "性能 创造 专业 设计 创意",
                    "price_range": "high",
                    "keywords": ["性能", "专业", "创意"]
                },
                {
                    "template": "专业之选，{product}助您突破工作极限",
                    "features": "专业 工作站 性能 效率 极限",
                    "price_range": "high",
                    "keywords": ["专业", "性能", "极限"]
                }
            ],
            "平板": [
                {
                    "template": "创意无界，{product}让灵感自由绽放",
                    "features": "创意 灵感 自由 便携 轻薄",
                    "price_range": "all",
                    "keywords": ["创意", "灵感", "便携"]
                },
                {
                    "template": "轻盈优雅，{product}让创造更随心",
                    "features": "轻薄 优雅 设计 便携 时尚",
                    "price_range": "high",
                    "keywords": ["轻薄", "优雅", "时尚"]
                },
                {
                    "template": "随时随地，{product}让工作娱乐更随心",
                    "features": "便携 娱乐 工作 灵活 多功能",
                    "price_range": "medium",
                    "keywords": ["便携", "娱乐", "多功能"]
                }
            ]
        }
        
        # 预计算所有模板的特征向量
        self.template_embeddings = {}
        for category, templates in self.templates.items():
            self.template_embeddings[category] = []
            for template in templates:
                embedding = self.embedding_function.embed_query(template["features"])
                self.template_embeddings[category].append({
                    "template": template["template"],
                    "embedding": embedding,
                    "price_range": template["price_range"],
                    "keywords": template["keywords"]
                })

    def categorize_product(self, product_name: str) -> str:
        """根据产品名称判断类别"""
        for category, keywords in self.categories.items():
            for keyword in keywords:
                if keyword.lower() in product_name.lower():
                    return category
        return "general"
    
    def get_price_range(self, price: float) -> str:
        """确定价格区间"""
        if price > 1000:
            return "high"
        elif price > 500:
            return "medium"
        return "low"

    def extract_product_features(self, product_info: dict) -> str:
        """提取产品特征关键词"""
        category = self.categorize_product(product_info['document'])
        features = []
        
        # 从产品名称中提取特征
        product_name = product_info['document'].lower()
        if "pro" in product_name:
            features.extend(["专业", "高端"])
        if "max" in product_name:
            features.extend(["大屏", "性能"])
        if "air" in product_name:
            features.extend(["轻薄", "便携"])
            
        # 从价格区间提取特征
        price = product_info['metadata'].get('price', 0)
        price_range = self.get_price_range(price)
        if price_range == "high":
            features.extend(["高端", "奢华"])
        elif price_range == "medium":
            features.extend(["中端", "性价比"])
        else:
            features.extend(["实惠", "经济"])
            
        # 从类别添加特征
        if category == "手机":
            features.extend(["智能", "便携", "拍照"])
        elif category == "电脑":
            features.extend(["效率", "性能", "办公"])
        elif category == "平板":
            features.extend(["轻薄", "便携", "创意"])
                    
        return " ".join(list(set(features)))  # 去重

    def find_best_template(self, product_info: dict) -> str:
        """使用词向量相似度找到最匹配的广告文案模板"""
        category = self.categorize_product(product_info['document'])
        if category == "general":
            return "精品之选，{product}品质保证"
        
        price = product_info['metadata'].get('price', 0)
        price_range = self.get_price_range(price)
            
        # 提取产品特征并计算嵌入向量
        features = self.extract_product_features(product_info)
        features_embedding = self.embedding_function.embed_query(features)
        
        # 筛选符合价格区间的模板
        suitable_templates = [
            template for template in self.template_embeddings[category]
            if template["price_range"] in ["all", price_range]
        ]
        
        if not suitable_templates:
            return self.templates[category][0]["template"]
        
        # 计算与筛选后模板的相似度
        best_similarity = -1
        best_template = None
        
        for template_info in suitable_templates:
            # 计算词向量相似度
            similarity = self.cosine_similarity(
                np.array(features_embedding),
                np.array(template_info["embedding"])
            )
            
            # 根据关键词匹配度增加权重
            keyword_match = sum(1 for keyword in template_info["keywords"] 
                              if keyword.lower() in features.lower())
            keyword_bonus = keyword_match * 0.1  # 每个匹配关键词增加0.1的权重
            
            final_similarity = similarity + keyword_bonus
            
            if final_similarity > best_similarity:
                best_similarity = final_similarity
                best_template = template_info["template"]
            print(f"最佳相似度: {best_similarity}, 最佳模板: {best_template}")
                
        return best_template or self.templates[category][0]["template"]
        
    def cosine_similarity(self, a, b):
        """计算余弦相似度"""
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    def generate_ad(self, product_info: dict) -> str:
        """生成广告文案"""
        template = self.find_best_template(product_info)
        return template.format(product=product_info['document'])

def main():
    # 初始化嵌入模型
    embedding = ZhipuAIEmbeddings(
        model="embedding-2",
        api_key=os.getenv("ZHIPU_API_KEY"),
        dimensions=1024
    )
    
    # 初始化生成器
    generator = SmartPromptGenerator(embedding)
    
    # 测试数据
    test_product = {
        'document': 'Iphone 18 Pro max',
        'metadata': {'price': 1099.99, 'category': '手机'},
        'id': 'prod1'
    }
    
    # 生成广告文案
    ad = generator.generate_ad(test_product)
    
    # 打印结果
    print(f"产品: {test_product['document']}")
    print(f"提取的特征: {generator.extract_product_features(test_product)}")
    print(f"生成的广告文案: {ad}")

if __name__ == "__main__":
    main()

最佳相似度: 0.7133109534885738, 最佳模板: 突破创新边界，{product}让科技融入生活
最佳相似度: 0.8960162956768927, 最佳模板: 非凡性能，精致设计，{product}为您带来智能新体验
最佳相似度: 0.8960162956768927, 最佳模板: 非凡性能，精致设计，{product}为您带来智能新体验
产品: Iphone 18 Pro max
提取的特征: 拍照 奢华 大屏 性能 专业 便携 高端 智能
生成的广告文案: 非凡性能，精致设计，Iphone 18 Pro max为您带来智能新体验
