In [30]:
pip install transformers torchvision



In [31]:
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch

In [32]:
# 加载 CLIP 模型（ViT 版本）
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [33]:
def extract_image_embeddings(image_paths, model, processor):
    """
    提取图片嵌入及其对应的 ID。
    :param image_paths: 图片文件路径列表
    :param model: CLIP 模型
    :param processor: CLIP 的预处理工具
    :return: 包含图片 ID 和嵌入的字典
    """
    image_embeddings = {}
    for idx, img_path in enumerate(image_paths):
        image = Image.open(img_path).convert("RGB")
        inputs = processor(images=image, return_tensors="pt", padding=True)

        # 提取图像嵌入
        with torch.no_grad():
            embedding = model.get_image_features(**inputs).squeeze(0)  # 输出 1x512 变为 512
        image_embeddings[f"image_{idx+1}"] = embedding
    return image_embeddings


In [34]:
def extract_text_embedding(text, model, processor):
    """
    提取文本的嵌入。
    :param text: 输入的文本
    :param model: CLIP 模型
    :param processor: CLIP 的预处理工具
    :return: 文本的嵌入
    """
    inputs = processor(text=[text], return_tensors="pt", padding=True)

    # 提取文本嵌入
    with torch.no_grad():
        embedding = model.get_text_features(**inputs).squeeze(0)  # 输出 1x512 变为 512
    return embedding


In [35]:
image_paths = [
    "/content/02072099056.png",
    "/content/16206684278.png",
    "/content/32267606120.png",
    "/content/52868171709.png",
    "/content/55188636806.png"
]

# 提取图片嵌入
image_embeddings = extract_image_embeddings(image_paths, model, processor)
print("Image Embeddings:")
for img_id, embedding in image_embeddings.items():
    print(f"{img_id}: {embedding.shape}")  # 每个嵌入应为 [512]

# 提取文本嵌入
text = "bull market"
text_embedding = extract_text_embedding(text, model, processor)
print(f"Text Embedding Shape: {text_embedding.shape}")  # 应为 [512]

# 保存嵌入
torch.save(image_embeddings, "image_embeddings.pt")
torch.save({"bull market": text_embedding}, "text_embedding.pt")
print("Embeddings saved.")



Image Embeddings:
image_1: torch.Size([512])
image_2: torch.Size([512])
image_3: torch.Size([512])
image_4: torch.Size([512])
image_5: torch.Size([512])
Text Embedding Shape: torch.Size([512])
Embeddings saved.


In [37]:
# 加载图片和文本嵌入
image_embeddings = torch.load("image_embeddings.pt")  # 图片嵌入字典
text_embedding = torch.load("text_embedding.pt")["bull market"]  # 文本嵌入


  image_embeddings = torch.load("image_embeddings.pt")  # 图片嵌入字典
  text_embedding = torch.load("text_embedding.pt")["bull market"]  # 文本嵌入


In [38]:
import torch
import torch.nn as nn

class VectorComparisonNet(nn.Module):
    def __init__(self, embedding_dim, hidden_dim):
        super(VectorComparisonNet, self).__init__()
        self.fc1 = nn.Linear(embedding_dim * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, image_emb, text_emb):
        if image_emb.dim() == 1:
            image_emb = image_emb.unsqueeze(dim=0)
        if text_emb.dim() == 1:
            text_emb = text_emb.unsqueeze(dim=0)

        combined = torch.cat([image_emb, text_emb], dim=-1)
        hidden = self.relu(self.fc1(combined))
        output = self.sigmoid(self.fc2(hidden))
        return output

In [39]:
# 定义嵌入维度和隐藏层维度
embedding_dim = 512  # CLIP 的默认嵌入维度
hidden_dim = 128

# 初始化神经网络
model = VectorComparisonNet(embedding_dim, hidden_dim)

# 如果有 GPU，加载到 GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


VectorComparisonNet(
  (fc1): Linear(in_features=1024, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=1, bias=True)
  (relu): ReLU()
  (sigmoid): Sigmoid()
)

In [40]:



import os

# 将文本嵌入加载到设备
text_embedding = text_embedding.to(device)

# 遍历图片嵌入并计算相似性分数
scores = []
for img_path, (image_id, image_embedding) in zip(image_paths, image_embeddings.items()):
    image_embedding = image_embedding.to(device)

    # 获取文件名
    image_name = os.path.basename(img_path)

    # 通过神经网络计算相似性分数
    with torch.no_grad():
        similarity_score = model(image_embedding, text_embedding).item()

    scores.append((image_name, similarity_score))

# 根据相似性分数排序
scores.sort(key=lambda x: x[1], reverse=True)

# 输出图片文件名和相似性分数
print("Similarity Scores (by image file names):")
for image_name, score in scores:
    print(f"{image_name}: {score:.4f}")



Similarity Scores (by image file names):
02072099056.png: 0.5219
16206684278.png: 0.5068
55188636806.png: 0.4982
52868171709.png: 0.4954
32267606120.png: 0.4935
