In [1]:
import os
import torch
import torch.nn as nn
import torch. nn.functional as F
import numpy as np
import pandas as pd
import faiss
from transformers import AutoModel, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---- Configuration: paths are relative to this notebook's working directory ----
CHECKPOINT_PATH = "../[2_training]/checkpoints/best_model.pt"  # trained model weights (only the text part is used here)
INDEX_PATH = "./indexes/text_flat.index"                       # FAISS index read with faiss.read_index
METADATA_PATH = "./embeddings/metadata.csv"                    # rows are looked up positionally by FAISS id in search()
TEXT_MODEL = "hfl/chinese-roberta-wwm-ext-large"               # HF backbone; must match the one used in training
PROJ_DIM = 512                                                 # projection-head width; must match the checkpoint
MAX_TEXT_LEN = 128                                             # tokenizer max_length for queries
# Fall back to CPU so the notebook still runs on hosts without a GPU
# (a hardcoded "cuda" makes every .to(DEVICE) / map_location=DEVICE fail there).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TOP_K = 10                                                     # default number of hits returned by search()

In [3]:
# ===================== Text encoder (same structure as the training-time model) =====================
class TextEncoder(nn.Module):
    """Text tower: a pretrained Hugging Face backbone followed by a small projection head.

    The attribute names ``text_encoder`` and ``text_proj`` (and the Linear/GELU/Linear
    layer order inside ``text_proj``) determine the state_dict keys, so they must stay
    identical to the training-time model for the checkpoint to load.
    """

    def __init__(self, text_model_name, proj_dim=512):
        super().__init__()
        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        hidden_size = self.text_encoder.config.hidden_size
        # Indices 0 and 2 of this Sequential map to checkpoint keys text_proj.0.* / text_proj.2.*.
        head_layers = [
            nn.Linear(hidden_size, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
        ]
        self.text_proj = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Encode token ids into embeddings that are L2-normalised along the last dimension.

        Uses the backbone's ``pooler_output`` as the sentence feature
        (presumably shape (batch, hidden_size) — TODO confirm for the chosen backbone).
        """
        backbone_out = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        projected = self.text_proj(backbone_out.pooler_output)
        return F.normalize(projected, dim=-1)


In [4]:
# ===================== Load resources =====================
print("Loading model...")
model = TextEncoder(TEXT_MODEL, PROJ_DIM).to(DEVICE)

# Load the trained weights; only the text-side parameters are needed here.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only point CHECKPOINT_PATH at checkpoints you produced yourself.
full_state = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

# Unwrap common checkpoint containers in case the file stores more than a bare state_dict.
if isinstance(full_state, dict):
    for wrapper_key in ("model_state_dict", "state_dict"):
        if isinstance(full_state.get(wrapper_key), dict):
            full_state = full_state[wrapper_key]
            break

TEXT_PREFIXES = ("text_encoder.", "text_proj.")
DP_PREFIX = "module."  # added to every key when a model is saved from nn.DataParallel
text_state = {}
for key, value in full_state.items():
    if key.startswith(DP_PREFIX):
        key = key[len(DP_PREFIX):]
    # Prefix match (not substring) so only this model's own submodules are selected.
    if key.startswith(TEXT_PREFIXES):
        text_state[key] = value

if not text_state:
    raise RuntimeError(f"No text_encoder./text_proj. weights found in {CHECKPOINT_PATH}")

# strict=False tolerates harmless mismatches (e.g. buffers that differ between library
# versions), but the returned key lists are checked so a naming mismatch cannot silently
# leave the projection head at its random initialisation.
load_result = model.load_state_dict(text_state, strict=False)
missing_proj = [k for k in load_result.missing_keys if k.startswith("text_proj.")]
if missing_proj:
    raise RuntimeError(f"Projection-head weights missing from checkpoint: {missing_proj}")
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"Warning: {len(load_result.missing_keys)} missing / "
        f"{len(load_result.unexpected_keys)} unexpected keys while loading text weights"
    )
model.eval()

tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL)

print("Loading index and metadata...")
index = faiss.read_index(INDEX_PATH)
metadata = pd.read_csv(METADATA_PATH)
# search() maps FAISS ids to metadata rows positionally, so both must have the same length.
assert index.ntotal == len(metadata), (
    f"Index has {index.ntotal} vectors but metadata has {len(metadata)} rows"
)

print(f"✓ Ready!  Index size: {index.ntotal}, Metadata:  {len(metadata)} rows\n")

Loading model...
Loading index and metadata...
✓ Ready!  Index size: 1000, Metadata:  1000 rows



In [7]:
# ===================== Query function =====================
def search(query_text, top_k=TOP_K):
    """Encode a keyword query and return the Top-K nearest catalogue items.

    Args:
        query_text: Free-text query string.
        top_k: Number of neighbours to request from the FAISS index.

    Returns:
        A list of dicts with keys rank, score, id, text, image_path, audio_path,
        best match first. It may hold fewer than ``top_k`` entries when the index
        has fewer vectors, and it is empty when ``top_k <= 0``. ``score`` is the raw
        value returned by ``index.search`` (its meaning depends on the index metric).
    """
    if top_k <= 0:
        return []

    # Encode the query exactly like the training-time text inputs.
    inputs = tokenizer(
        query_text,
        max_length=MAX_TEXT_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        query_emb = model(inputs["input_ids"], inputs["attention_mask"]).cpu().numpy()
    # FAISS requires a C-contiguous float32 array.
    query_emb = np.ascontiguousarray(query_emb, dtype=np.float32)

    scores, indices = index.search(query_emb, top_k)

    results = []
    for idx, score in zip(indices[0], scores[0]):
        # FAISS pads with id -1 when it has fewer than top_k hits; iloc[-1] would
        # otherwise silently return the last metadata row as a bogus result.
        if idx < 0:
            continue
        row = metadata.iloc[int(idx)]
        results.append({
            "rank": len(results) + 1,
            "score": float(score),
            "id": str(row["id"]),
            "text": str(row["text"]),
            "image_path": str(row["image_path"]),
            "audio_path": str(row["audio_path"]),
        })
    return results

In [None]:
# Interactive query loop. It exits on q/quit/exit, or when no interactive stdin is available.
N_DISPLAY = 5       # number of hits shown per query
SNIPPET_LEN = 100   # characters of item text shown per hit

print("=== 多模态检索系统 ===")
print("输入关键词查询，输入 'q' 退出\n")

while True:
    try:
        query = input("查询:  ").strip()
    except (EOFError, KeyboardInterrupt, NotImplementedError):
        # EOF, a kernel interrupt, or a headless run (IPython raises StdinNotImplementedError,
        # a NotImplementedError subclass) all end the loop cleanly so "Run All" can finish.
        break
    if query.lower() in ["q", "quit", "exit"]:
        break
    if not query:
        continue

    results = search(query, top_k=N_DISPLAY)

    print(f"\n找到 {len(results)} 条结果:")
    for r in results:
        text = r["text"]
        # Only add an ellipsis when the text was actually truncated.
        snippet = text if len(text) <= SNIPPET_LEN else text[:SNIPPET_LEN] + "..."
        print(f"\n[{r['rank']}] 相似度: {r['score']:.3f}")
        print(f"  ID: {r['id']}")
        print(f"  文本: {snippet}")
        print(f"  图片: {r['image_path']}")
        print(f"  音频: {r['audio_path']}")
    print("-" * 80)

=== 多模态检索系统 ===
输入关键词查询，输入 'q' 退出


找到 5 条结果:

[1] 相似度: 0.678
  ID: 75
  文本: title:斜背弓箭包骑行包斜挎包户外登山包防水背包男士双肩包轻便单肩包 | 商品名称:斜背弓箭包骑行包斜挎包户外登山包防水背包男士双肩包轻便单肩包 | label:背包 | Llabel:出行用品 ...
  图片: ../../datas/origin_datas/images_part/75.jpg
  音频: ../../datas/origin_datas/audio_part/75.mp3

[2] 相似度: 0.599
  ID: 953
  文本: title:手包折叠凳户外露营野餐便携小马扎高铁地铁无座神器折叠迷你凳子 | 商品名称:手包折叠凳户外露营野餐便携小马扎高铁地铁无座神器折叠迷你凳子 | label:板凳 | Llabel:家具商品 ...
  图片: ../../datas/origin_datas/images_part/953.jpg
  音频: ../../datas/origin_datas/audio_part/953.mp3

[3] 相似度: 0.541
  ID: 24
  文本: title:Mrace大容量旅行背包书包女轻便电脑包徒步登山爬山双肩包男 | 商品名称:Mrace大容量旅行背包书包女轻便电脑包徒步登山爬山双肩包男 | label:背包 | Llabel:出行用品 ...
  图片: ../../datas/origin_datas/images_part/24.jpg
  音频: ../../datas/origin_datas/audio_part/24.mp3

[4] 相似度: 0.539
  ID: 936
  文本: title:户外手包折叠凳便携笔袋小马扎露营地铁火车无座神器户外迷你凳子 | 商品名称:户外手包折叠凳便携笔袋小马扎露营地铁火车无座神器户外迷你凳子 | label:板凳 | Llabel:家具商品 ...
  图片: ../../datas/origin_datas/images_part/936.jpg
  音频: ../../datas/origin_datas/audio_part/936.mp3

[5] 相似度: 0.