In [None]:
"""安装依赖
"""
!pip install chromadb
!pip install diffusers

In [None]:
"""下载模型
"""
!mkdir -p /content/modelsw
!git clone --depth=1 https://huggingface.co/openai/clip-vit-large-patch14  /content/models/clip-vit-large-patch14
!git clone --depth=1 https://huggingface.co/stabilityai/sdxl-turbo  /content/models/sdxl-turbo

In [None]:
"""Use sdxl-turbo to generate the images that will be searched.
"""
import os
import torch
from diffusers import DiffusionPipeline

# (category, prompt) pairs; each category gets its own sub-folder under output/.
prompts = [
    ("cat", "a photo of cat"),
    ("dog", "a photo of dog"),
    ("pig", "a photo of pig"),
    ("chair", "a photo of chair"),
    ("table", "a photo of dining table")
]

model_id = "/content/models/sdxl-turbo"
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.to("cuda")

# Seeded generator so a Restart & Run All should regenerate the same images
# (on the same hardware/software stack); the indexing step derives ids from
# pixel content, so stable images also mean stable ids.
generator = torch.Generator(device="cuda").manual_seed(42)

for category, prompt in prompts:
    os.makedirs(f"output/{category}", exist_ok=True)
    for index in range(2):
        # sdxl-turbo is meant to be sampled in very few steps with classifier-free
        # guidance disabled, hence guidance_scale=0.0.
        image = pipe(
            prompt,
            num_inference_steps=4,
            guidance_scale=0.0,
            generator=generator,
        ).images[0]
        image.save(f"output/{category}/{category}-{index}.png")

In [None]:
"""Initialize Chroma: open the on-disk store and get (or create) the collection.
"""
import chromadb

# Persistent (on-disk) client rooted at this directory.
client = chromadb.PersistentClient(path="/content/chroma/my")

# "hnsw:space" selects the index distance function; "ip" means inner product,
# which equals cosine similarity only when the stored vectors are unit-length.
collection = client.get_or_create_collection(
    "my_collection",
    metadata={"hnsw:space": "ip"},
)

In [None]:
"""Embed every generated image with CLIP and store the vectors in Chroma.
"""
import os
from PIL import Image
import requests  # NOTE(review): unused in this cell
import hashlib

import torch
from transformers import CLIPProcessor, CLIPModel

model_id = "/content/models/clip-vit-large-patch14"

def all_images():
    """Yield the path of every ``.png`` file under ``output``, in a deterministic order."""
    for root, ds, fs in os.walk("output"):
        ds.sort()  # sorting in place makes os.walk descend in sorted order
        for f in sorted(fs):
            if f.endswith('.png'):
                yield os.path.join(root, f)

model = CLIPModel.from_pretrained(model_id)
processor = CLIPProcessor.from_pretrained(model_id)

for image_path in all_images():
    # Context manager closes the underlying file once the pixels have been used.
    with Image.open(image_path) as image:
        # Content hash of the decoded pixels: identical images map to the same id.
        id_ = hashlib.md5(image.tobytes()).hexdigest()
        inputs = processor(images=image.resize((224, 224)), return_tensors="pt")

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        image_feature = model.get_image_features(**inputs)[0]
    # CLIP features are not unit-length; L2-normalize so the collection's "ip"
    # (inner-product) space ranks by cosine similarity instead of by vector norm.
    image_feature = torch.nn.functional.normalize(image_feature, dim=-1)

    # First path component below output/ is the category (output/<category>/...).
    category = os.path.relpath(image_path, "output").split(os.sep)[0]

    # upsert (not add) keeps re-runs idempotent when the ids already exist.
    collection.upsert(
        embeddings=[image_feature.tolist()],
        metadatas=[{"source": image_path, "category": category}],
        ids=[id_]
    )

    print(image_path, id_)


In [None]:
"""Image retrieval: text -> image and image -> image queries against Chroma.
"""
import torch
from PIL import Image
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)


def _query_collection(feature, n_results=2):
    """Return metadata dicts of the ``n_results`` nearest stored images to ``feature``.

    ``feature`` is a single CLIP embedding tensor. It is L2-normalized first so the
    "ip" distance matches cosine similarity when the stored vectors are unit-length;
    normalizing the query does not change the ranking.
    """
    feature = torch.nn.functional.normalize(feature, dim=-1)
    result = collection.query(
        query_embeddings=[feature.tolist()],
        n_results=n_results
    )
    # Exactly one query embedding was sent, so only the first result row exists.
    return result["metadatas"][0]


def _print_hits(label, metadatas):
    """Print the query label followed by the source path of each hit."""
    print(f"检索：{label}")
    for metadata in metadatas:
        print("image_path:", metadata["source"])


# Text -> image retrieval
queries = [
    "a photo of cat",
    "a photo of dog",
    "a photo of pig",
    "a photo of chair",
    "a photo of dining table"
]

print("="*20)
print("文本 -> 图片检索")
for query in queries:
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    with torch.no_grad():
        text_feature = model.get_text_features(**inputs)[0]
    _print_hits(query, _query_collection(text_feature))

# Image -> image retrieval
images = [
    "output/cat/cat-0.png",
    "output/dog/dog-0.png",
    "output/pig/pig-0.png",
    "output/chair/chair-0.png",
    "output/table/table-0.png"
]

print("="*20)
print("图片 -> 图片检索")
for image_path in images:
    with Image.open(image_path) as image:
        # Same preprocessing as the indexing step so query and stored vectors match.
        inputs = processor(images=image.resize((224, 224)), return_tensors="pt")
    with torch.no_grad():
        image_feature = model.get_image_features(**inputs)[0]
    # The query image was itself indexed, so it usually appears among its own
    # nearest hits; the other hit is the informative one.
    _print_hits(image_path, _query_collection(image_feature))
