In [None]:
import re
import faiss
import json
import numpy as np
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModel
from typing import List, Tuple
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import torch
import os
from tag_filter_retriever import TagFilteredFAISSRetriever

In [None]:
# Hugging Face model id (or local path) of MiniCPM-V 2.6; consumed by InitLLM().
MODEL_PATH = 'openbmb/MiniCPM-V-2_6'

In [None]:
# MiniCPM initialisation
def InitLLM():
    """Load the MiniCPM-V tokenizer and model from ``MODEL_PATH``.

    The model is loaded in bfloat16 with SDPA attention, placed with
    ``device_map="auto"`` and switched to eval mode. Remote code shipped with the
    checkpoint is trusted for both the tokenizer and the model.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    model_kwargs = {
        "trust_remote_code": True,
        "attn_implementation": 'sdpa',
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
    }
    tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    llm = AutoModel.from_pretrained(MODEL_PATH, **model_kwargs)
    llm.eval()  # Module.eval() switches mode in place
    return llm, tok


model, tokenizer = InitLLM()

In [None]:
def extract_question_and_path(text: str):
    """Split the question and the trailing image path out of a chain prompt.

    The prompt is expected to contain ``Question: <question>? <image path>``
    followed by ``Helpful Answer:``.

    Args:
        text: full prompt text.

    Returns:
        tuple[str, str]: ``(question, image_path)``. The question keeps its final '?'.

    Raises:
        ValueError: if the Question/Helpful Answer markers are missing, or the
            question line has no '?'.
    """
    # Anchor on the LAST "Helpful Answer:" and the last "Question:" before it.
    # Retrieved context that happens to contain these markers then cannot
    # hijack the match (a first-match regex would start inside the context).
    head, answer_marker, _ = text.rpartition("Helpful Answer:")
    _, question_marker, question_line = head.rpartition("Question:")
    if not answer_marker or not question_marker:
        raise ValueError("无法在文本中找到 Question: 和 Helpful Answer:")

    question_line = question_line.strip()

    if '?' not in question_line:
        raise ValueError("Question 行中不包含 '?'，无法区分问题和路径")

    # Split on the LAST '?': a question may contain several, while the image
    # path that follows it normally contains none.
    idx = question_line.rindex('?')
    question = question_line[:idx + 1].strip()
    image_path = question_line[idx + 1:].strip()

    return question, image_path



In [None]:
# LangChain adapter for MiniCPM
from langchain.llms.base import LLM
from typing import Any, Optional

class MiniCPM_LLM(LLM):
    """LangChain ``LLM`` wrapper that forwards prompts to a MiniCPM-V ``chat`` call.

    The prompt must contain ``Question: <question>? <image path>`` followed by
    ``Helpful Answer:``. The image path is split off with
    ``extract_question_and_path`` and the image is sent together with the question.
    """

    model: Any  # MiniCPM-V model exposing .chat(image=..., tokenizer=..., msgs=...)
    tokenizer: Any
    history: List = []  # not read anywhere in this class; kept for backward compatibility

    def _call(
        self,
        query: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """Run one MiniCPM chat turn for the question/image encoded in `query`.

        Args:
            query: full prompt text produced by the chain.
            stop: optional stop sequences; the reply is cut at the earliest one found.
            run_manager: LangChain callback manager (accepted for API compatibility, unused).
            **kwargs: extra kwargs LangChain may forward (accepted, unused).

        Returns:
            The model's reply.

        Raises:
            ValueError: if `query` lacks the expected Question/Helpful Answer layout.
        """
        print(query)
        prompt, img_src = extract_question_and_path(query)
        # Decode the image and release the file handle immediately.
        with Image.open(img_src) as raw:
            image = raw.convert('RGB')
        messages = [{'role': 'user', 'content': [image, prompt]}]
        with torch.no_grad():
            response = self.model.chat(
                image=None,
                tokenizer=self.tokenizer,
                msgs=messages
            )
        # The LLM contract expects output to stop at the first stop sequence.
        if stop and isinstance(response, str):
            cuts = [response.find(s) for s in stop if s]
            cuts = [c for c in cuts if c != -1]
            if cuts:
                response = response[:min(cuts)]
        return response

    @property
    def _llm_type(self) -> str:
        return "minicpm"


In [None]:
import json

# Read a JSON file
def load_json(file_path):
    """Parse the UTF-8 JSON file at ``file_path`` and return the resulting Python object."""
    with open(file_path, encoding='utf-8') as handle:
        return json.load(handle)

In [None]:
# Parse the image JSON records and convert them to the RAG record format
def convert_img_rag_data(json_data):
    """Turn ``{"img_path", "text"}`` records into image RAG records.

    The category is the file-name prefix before the first '_' (e.g. ``violence`` for
    ``.../violence_3.jpg``). The record is tagged "safe" only when that prefix is
    exactly "safe".

    Returns:
        list[dict]: records with keys prompt, label ("image"), explanation,
        embedding (None), src, tag.
    """
    def _to_record(entry):
        src = entry["img_path"]
        category = src.rsplit("/", 1)[-1].partition("_")[0]
        return {
            "prompt": f"The content is about {category}",
            "label": "image",
            "explanation": entry["text"],
            "embedding": None,
            "src": src,
            "tag": "safe" if category == "safe" else "unsafe",
        }

    return [_to_record(entry) for entry in json_data]

# Load and convert the image data (img_rag_data is combined with the text records later).
file_path = "PATH_TO_YOUR_IMAGE_JSON"  # JSON list of records with img_path, text (image path, text content)
json_data = load_json(file_path)
img_rag_data = convert_img_rag_data(json_data)


In [None]:
def convert_text_rag_data(jsonl_file):
    """
    Convert a JSONL file of ``{"text": ..., "label": [...]}`` records into text RAG records.

    Args:
        jsonl_file: path to the JSONL file (UTF-8, optional BOM). Blank lines are skipped.

    Returns:
        list[dict]: records with keys prompt, label ("text"), explanation,
        embedding (None), src (""), tag. ``tag`` is "safe" only when one of the
        record's labels is exactly "safe".

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks "text" or "label".
    """
    data = []
    # 'utf-8-sig' also strips a leading BOM, which json.loads would otherwise reject.
    with open(jsonl_file, 'r', encoding='utf-8-sig') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing empty line
            item = json.loads(line)
            labels = item["label"]
            if isinstance(labels, str):
                labels = [labels]  # a bare string would otherwise be joined character by character
            label_text = ', '.join(labels)
            # Exact match: the previous substring test `"safe" in label_text` also matched "unsafe".
            tag = "safe" if any(label.strip() == "safe" for label in labels) else "unsafe"
            data.append({
                "prompt": f"The content is about {label_text}",
                "label": "text",
                "explanation": item["text"],
                "embedding": None,
                "src": "",
                "tag": tag,
            })
    return data

# Example call: load the text records (combined with the image records later)
jsonl_file = "PATH_TO_YOUR_TEXT_JSONL"  # your JSONL file: one object with text, label per line
text_rag_data = convert_text_rag_data(jsonl_file)

In [None]:
# 1. Load the CLIP model and its processor (openai/clip-vit-large-patch14)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

In [None]:
# 2. Build the RAG data: image records first, then text records.
# This order determines each record's FAISS id, because the index is built by enumerating `data`.
data = img_rag_data + text_rag_data

In [None]:
def truncate_text(text, max_length=77):
    """Shorten `text` so it fits CLIP's text window of `max_length` tokens.

    The returned string contains no special tokens. The CLIP processor adds
    BOS/EOS itself when it re-tokenizes, so room for exactly those two tokens is
    reserved here. (Previously the decoded text kept the special tokens. They were
    then re-tokenized, giving doubled BOS/EOS, and document embeddings no longer
    matched query embeddings.)

    Args:
        text: input text.
        max_length: total token budget including BOS and EOS (77 for CLIP).

    Returns:
        str: `text` unchanged when it already fits; otherwise the decoded first
        ``max_length - 2`` content tokens.
    """
    tokenizer = clip_processor.tokenizer
    ids = tokenizer.encode(text, add_special_tokens=False)
    budget = max(max_length - 2, 0)  # leave room for the BOS/EOS the processor re-adds
    if len(ids) <= budget:
        return text
    return tokenizer.decode(ids[:budget], skip_special_tokens=True)

In [None]:
# 3. Compute CLIP embeddings for text and image records
def get_clip_embedding(item):
    """
    Return the CLIP embedding of one RAG record as a list of floats.

    Args:
        item (dict): record whose "label" is "text" or "image". Text records need
            "explanation"; image records need "src".

    Returns:
        list[float]: for text records, ``get_text_features`` of the (truncated)
        explanation; for image records, ``get_image_features`` of the image. A
        missing image file yields a zero vector of length
        ``clip_model.config.projection_dim``, so that it has the same dimension as
        the real embeddings.

    Raises:
        ValueError: for any other label.
    """
    label = item["label"]

    if label == "text":
        # Truncate first so long explanations fit CLIP's 77-token text window.
        truncated_text = truncate_text(item["explanation"])
        inputs = clip_processor(text=truncated_text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            text_embedding = clip_model.get_text_features(**inputs)
        return text_embedding.squeeze().cpu().tolist()

    if label == "image":
        image_path = item["src"]
        if not os.path.exists(image_path):
            # The placeholder must match the model's projection size (768 for ViT-L/14,
            # not 512). Otherwise the vectors cannot be stacked into one FAISS matrix.
            dim = clip_model.config.projection_dim
            print(f"⚠️: 图像 {image_path} 不存在，使用零向量代替！")
            return [0.0] * dim
        # Close the file once the processor has read the pixels.
        with Image.open(image_path) as image:
            inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_embedding = clip_model.get_image_features(**inputs)
        return image_embedding.squeeze().cpu().tolist()

    raise ValueError(f"未知的数据类型: {item['label']}")


In [None]:
# 3. Compute embeddings and store them on each record
from tqdm import tqdm

# Mutates `data` in place: each record's "embedding" (initially None) is filled in.
for item in tqdm(data, desc="Processing items"):
    item["embedding"] = get_clip_embedding(item)

In [None]:
# 4. Save the records (including their embeddings) to rag_data.json in the working directory
with open("rag_data.json", "w") as f:
    json.dump(data, f, indent=4)


In [None]:
from langchain.embeddings.base import Embeddings

# Custom CLIPEmbedding class
class CLIPEmbedding(Embeddings):
    """LangChain ``Embeddings`` implementation backed by a CLIP model.

    Text goes through ``get_text_features``; ``embed_image_query`` additionally
    embeds an image file through ``get_image_features``.
    """

    def __init__(self, clip_model, clip_processor):
        self.clip_model = clip_model
        self.clip_processor = clip_processor

    def embed_documents(self, texts):
        """Embed a list of texts; returns one list of floats per text (``[]`` for no texts)."""
        if not texts:
            return []  # the processor cannot build a batch from zero texts
        inputs = self.clip_processor(text=texts, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            embeddings = self.clip_model.get_text_features(**inputs).cpu().numpy()
        return embeddings.tolist()

    def embed_query(self, query):
        """Embed a single query string; returns a flat list of floats."""
        inputs = self.clip_processor(text=query, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            embedding = self.clip_model.get_text_features(**inputs).squeeze().cpu().numpy()
        return embedding.tolist()

    def embed_image_query(self, img_src):
        """Embed the image at path `img_src`; returns a flat list of floats."""
        # padding/truncation are text-tokenizer options, so they are not passed for images.
        # The file is closed once the processor has read the pixels.
        with Image.open(img_src) as image:
            inputs = self.clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            embedding = self.clip_model.get_image_features(**inputs).squeeze().cpu().numpy()
        return embedding.tolist()

# Use CLIPEmbedding with the CLIP model/processor loaded earlier
clip_embeddings = CLIPEmbedding(clip_model, clip_processor)

In [None]:
# 5. Build a FAISS store whose documents carry metadata
from langchain.schema import Document
from langchain.docstore.in_memory import InMemoryDocstore
docs = []
vectors = []

for item in data:
    # Five newline-separated fields; retrieval code reads them back positionally.
    content = "\n".join([item["prompt"], item["explanation"], item["label"], item["src"], item["tag"]])
    metadata = {
        "tag": item["tag"],
        "src": item["src"]
    }
    docs.append(Document(page_content=content, metadata=metadata))
    vectors.append(item["embedding"])

# All vectors must share one length; otherwise numpy cannot build a 2-D matrix
# for faiss (e.g. a zero-vector placeholder of the wrong size).
dims = {len(v) for v in vectors}
if len(dims) != 1:
    raise ValueError(f"嵌入向量维度不一致或为空: {sorted(dims)}")

# Convert to a float32 numpy matrix (faiss requires float32)
embeddings = np.asarray(vectors, dtype="float32")

# Build an exact (brute-force) L2 faiss index
dim = embeddings.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(embeddings)

# FAISS row i <-> docstore key i, because both follow the order of `data`.
docstore = InMemoryDocstore(dict(enumerate(docs)))
index_to_docstore_id = {i: i for i in range(len(docs))}

# Build the LangChain FAISS store object
faiss_store_filter = FAISS(
    embedding_function=clip_embeddings,
    index=index,
    docstore=docstore,
    index_to_docstore_id=index_to_docstore_id,
)

In [None]:
# Save the FAISS store to the local folder "guard"
faiss_store_filter.save_local("guard")

In [None]:
# 6. Build the RAG retrievers
# Plain top-2 similarity retriever, used by the RetrievalQA chain.
retriever_filter = faiss_store_filter.as_retriever(search_kwargs={"k": 2})
# Tag-aware retriever; search_rag queries its "safe" and "unsafe" subsets separately.
retriever_tag = TagFilteredFAISSRetriever(
    faiss_store=faiss_store_filter,
    embedding_model=clip_embeddings
)

In [None]:
from tqdm import tqdm  # NOTE(review): redundant re-import; tqdm is not used in this cell
# 7. Wrap the loaded MiniCPM model for LangChain and build a "stuff" RetrievalQA chain.
# Questions sent to qa_chain must end with "? <image path>", because MiniCPM_LLM parses an image path out of them.

llm = MiniCPM_LLM(model = model, tokenizer = tokenizer)
qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever_filter, chain_type="stuff")

In [None]:
# 8. Single-image safety assessment with MiniCPM
def evaluate_safety(prompt: str = "describe this in a word", img_src: str = "none"):
    """
    Ask MiniCPM for a one-word safety category for an image (or for `prompt` alone).

    Args:
        prompt: text inserted into the instruction (e.g. the user text to classify).
        img_src: image path; "none" (or None) substitutes the placeholder 'blank.jpg'.

    Returns:
        The raw reply of ``model.chat``. It is presumably one of the listed
        categories, but it is not validated here.
    """
    if img_src is None or img_src == "none":
        img_src = 'blank.jpg'
    # Decode the image and release the file handle immediately.
    with Image.open(img_src) as raw:
        image = raw.convert('RGB')
    # The newline keeps the caller's prompt and the answer constraint apart.
    # Without it they ran together ("...in a wordyour answer must...").
    prompt = (
        "describe this in a word:" + prompt
        + "\nyour answer must be one of the categories:[safe, violence, self-harm, nude, illegal-religion, animal-cruelty, schoolbullying, sex, terrorism, privacy]"
    )
    msgs = [{'role': 'user', 'content': [image, prompt]}]

    with torch.no_grad():
        res = model.chat(
            image=None,
            msgs=msgs,
            tokenizer=tokenizer
        )

    return res

In [None]:
# Sensitive category names to look for (every category except "safe")
sensitive_terms = ["violence", "self-harm", "nude", "illegal-religion",
                   "animal-cruelty", "schoolbullying", "sex", "terrorism", "privacy"]
# One alternation regex over the escaped terms (plain substring matching, no word boundaries)
pattern = "|".join(map(re.escape, sensitive_terms))


def check_sensitive_content(text):
    """Return the set of sensitive terms found in `text`, or the string "safe" if none.

    Matching is case-insensitive and substring-based, so e.g. "sexual" yields
    ``{"sex"}``. Found terms are returned lower-cased.
    """
    hits = set(re.findall(pattern, text.lower(), flags=re.IGNORECASE))
    return hits if hits else "safe"

In [None]:
from PIL import Image

# Multi-image retrieval-augmented classification
def multi_image_inference(image_paths, model, tokenizer, reference_tag = 0, unsafe_tag = 'safe', img_tag = 'safe'):
    """
    Classify the last image in ``image_paths`` with MiniCPM, using the earlier images as references.

    Args:
        image_paths (list[str]): image paths; all but the last are references, the last is the target.
        model: loaded MiniCPM model exposing ``chat(image=..., msgs=..., tokenizer=...)``.
        tokenizer: the matching tokenizer.
        reference_tag (int): 0-based index (among the references) of the most relevant reference image.
        unsafe_tag (str): safety category of that reference image.
        img_tag (str): preliminary category predicted for the target image. Raw model
            output is accepted; it is normalised before comparison.

    Returns:
        tuple[str, str]: ``(answer, question)``, i.e. the model's reply and the prompt that produced it.

    Raises:
        AssertionError: if ``image_paths`` is empty.
    """
    assert len(image_paths) >= 1, "至少需要提供一张图片"

    # Decode every image and release its file handle right away.
    images = []
    for p in image_paths:
        with Image.open(p) as raw:
            images.append(raw.convert("RGB"))

    if len(images) == 1:
        question = "Describe this image with one word."
        msgs = [{'role': 'user', 'content': [images[0], question]}]
    else:
        ref_images = images[:-1]
        target_image = images[-1]
        n_images = len(images)
        ref_str = ", ".join([f"image{i+1}" for i in range(len(ref_images))])

        # The preliminary prediction is free model text (e.g. "Violence."). Normalise
        # case, whitespace, quotes and trailing dots before comparing it with the reference tag.
        predicted_tag = str(img_tag).strip().strip(".'\"").strip().lower()
        if predicted_tag == str(unsafe_tag).strip().lower():
            tag_hint = (
                f"\nImage{reference_tag+1} is visually highly similar to image{n_images}, "
                f"and both appear to contain content consistent with the '{predicted_tag}' category.\n"
                f"Based on this strong similarity, and initial prediction results, "
                f"it is highly likely that image{n_images} also belongs to the '{predicted_tag}' category.\n"
                f"Use this information to guide your classification, especially if any content in image{n_images} is ambiguous.\n"
                f"However, base your final classification on the actual visual content of image{n_images}."
            )
        else:
            tag_hint = ""

        question = (
            f"Your task is to classify the content of image{n_images} into one of the following categories:\n"
            f"[safe, violence, self-harm, nude, illegal-religion, animal-cruelty, schoolbullying, sex, terrorism, privacy].\n\n"
            # The image count was previously hard-coded as "five".
            f"You are given {n_images} images: {ref_str} and image{n_images}.\n"
            f"Among image1 to image{n_images-1} (reference images), only image{reference_tag+1} is highly relevant and should be used as the main reference.\n"
            f"image{reference_tag+1} belongs to the '{unsafe_tag}' category.\n"
            f"Ignore the content of the other reference images.\n\n"
            f"Use image{reference_tag+1} only to assist in understanding ambiguous or unclear regions in image{n_images}.\n"
            f"Focus primarily on the actual content of image{n_images} when making your classification decision.\n\n"
            f"{tag_hint}\n\n"
            f"Respond with exactly one word from the category list."
        )

        msgs = [{'role': 'user', 'content': ref_images + [target_image, question]}]

    with torch.no_grad():
        answer = model.chat(
            image=None,
            msgs=msgs,
            tokenizer=tokenizer
        )

    return answer, question


In [None]:
# Multi-text retrieval-augmented classification
def multi_text_inference(texts, model, tokenizer, reference_tag = 0, unsafe_tag = "safe", txt_tag="safe"):
    """
    Classify the last text in ``texts`` with MiniCPM, using the earlier texts as references.

    Args:
        texts (list[str]): texts; all but the last are references, the last is the target.
        model: loaded MiniCPM model exposing ``chat(image=..., msgs=..., tokenizer=...)``.
        tokenizer: the matching tokenizer.
        reference_tag (int): 0-based index of the most relevant reference text.
        unsafe_tag (str): safety category of that reference text.
        txt_tag (str): preliminary category predicted for the target text. Raw model
            output is accepted; it is normalised before comparison.

    Returns:
        tuple[str, str]: ``(answer, question)``, i.e. the model's reply and the prompt that produced it.

    Raises:
        AssertionError: if ``texts`` is empty.
        IndexError: if ``reference_tag`` is out of range for the reference texts.
    """
    assert len(texts) >= 1, "至少需要提供一段文本"

    # A blank placeholder image accompanies the text prompt (same convention as evaluate_safety).
    image_path = 'blank.jpg'
    with Image.open(image_path) as raw:
        img = raw.convert('RGB')

    if len(texts) == 1:
        question = "Describe the following text with one word:\n'" + texts[0] + "'"
        msgs = [{'role': 'user', 'content': [img, question]}]
    else:
        ref_texts = texts[:-1]
        target_text = texts[-1]
        ref_str = ", ".join([f'"{t}"' for t in ref_texts])
        ref_text = ref_texts[reference_tag]  # the most similar reference text

        # The preliminary prediction is free model text (e.g. "Sex."). Normalise it
        # before comparing, otherwise the hint would almost never be added.
        predicted_tag = str(txt_tag).strip().strip(".'\"").strip().lower()
        if predicted_tag == str(unsafe_tag).strip().lower():
            tag_hint = (
                f"\nThe reference text \"{ref_text}\" is highly similar to the target text and belongs to the '{unsafe_tag}' category.\n"
                f"Based on this strong similarity and prior indications, it is highly likely that the target text also falls into the '{predicted_tag}' category.\n"
                f"Use this information to guide your classification, especially if the target text is ambiguous.\n"
                f"However, base your final decision strictly on the actual content of the target text."
            )
        else:
            tag_hint = ""

        question = (
            f"Your task is to classify the target text into one of the following categories:\n"
            f"[safe, violence, self-harm, nude, illegal-religion, animal-cruelty, schoolbullying, sex, terrorism, privacy].\n\n"
            f"You are given several reference texts: {ref_str}.\n"
            f"Among them, only the reference text \"{ref_text}\" is highly relevant and should be used as the main reference.\n"
            f"It belongs to the '{unsafe_tag}' category.\n"
            f"Ignore the content of the other reference texts.\n\n"
            f"Use the main reference text only to assist in understanding ambiguous or unclear parts of the target text.\n"
            f"Focus primarily on the actual content of the target text when making your classification decision."
            f"{tag_hint}\n\n"
            f"Target text: \"{target_text}\"\n"
            f"Respond with exactly one word from the category list."
        )

        msgs = [{'role': 'user', 'content': [img, question]}]

    with torch.no_grad():
        answer = model.chat(
            image=None,
            msgs=msgs,
            tokenizer=tokenizer
        )
    return answer, question


In [None]:
def _is_image_file(path):
    """Return True if PIL can open `path` as an image, False otherwise."""
    try:
        with Image.open(path):
            return True
    except (OSError, ValueError, TypeError):
        # FileNotFoundError and PIL's UnidentifiedImageError are OSError subclasses.
        # ValueError/TypeError cover strings that cannot be used as file paths.
        return False


def _parse_retrieved_doc(doc):
    """Return ``(explanation, src)`` for a retrieved Document.

    page_content holds five newline-separated fields: prompt, explanation, label,
    src, tag. The explanation may itself contain newlines, so it is taken as
    everything between the first line and the last three lines. The src is
    preferably read from metadata.
    """
    lines = doc.page_content.split("\n")
    explanation = "\n".join(lines[1:-3])
    metadata = getattr(doc, "metadata", None) or {}
    src = metadata.get("src", lines[-2] if len(lines) >= 2 else "")
    return explanation, src


def _first_sensitive_term(text):
    """Return the first sensitive term (lower-cased) occurring in `text`, or "safe" if none.

    Unlike taking an arbitrary element of check_sensitive_content's set, this is
    deterministic. It also never returns a single character of the string "safe".
    """
    match = re.search(pattern, text, flags=re.IGNORECASE)
    return match.group(0).lower() if match else "safe"


def search_rag(query, top_k = 2):
    """
    Retrieve unsafe/safe reference examples for `query` and ask MiniCPM to classify it.

    `query` is treated as an image when PIL can open it as an image file, and as
    text otherwise. The retrieved examples and the final verdict are printed;
    nothing is returned.

    Args:
        query (str): image path or free text to assess.
        top_k (int): number of examples retrieved per tag ("unsafe" and "safe").

    Raises:
        ValueError: if no usable reference example was retrieved.
    """
    # Decide the mode explicitly instead of relying on bare excepts. Those also
    # swallowed model/retriever errors and sent image queries down the text path.
    is_image_query = _is_image_file(query)
    search_arg = {"img_src": query} if is_image_query else {"query": query}
    unsafe_docs, unsafe_score = retriever_tag.search(**search_arg, tag = 'unsafe', top_k = top_k)
    safe_docs, safe_score = retriever_tag.search(**search_arg, tag = 'safe', top_k = top_k)

    # Each entry: (explanation, src, came_from_unsafe_search); unsafe hits come first.
    entries = [(*_parse_retrieved_doc(doc), True) for doc in unsafe_docs]
    entries += [(*_parse_retrieved_doc(doc), False) for doc in safe_docs]

    if is_image_query:
        # Every reference is sent to the model as an image. Drop hits without a
        # readable image (e.g. text records, whose src is empty) instead of crashing later.
        usable = []
        for entry in entries:
            if _is_image_file(entry[1]):
                usable.append(entry)
            else:
                print(f"无法打开图片 {entry[1]}，已跳过该参考样例")
        entries = usable

    if not entries:
        raise ValueError("未检索到可用的参考样例")

    n_unsafe = sum(1 for entry in entries if entry[2])
    n_safe = len(entries) - n_unsafe
    descriptions = [entry[0] for entry in entries]

    # Same rule as before: prefer the unsafe reference when unsafe_score <= safe_score
    # (presumably distances; TODO confirm with TagFilteredFAISSRetriever). The scores
    # are only compared when both groups are present.
    use_unsafe = n_unsafe > 0 and (n_safe == 0 or unsafe_score <= safe_score)
    if use_unsafe:
        reference_tag = 0
        citation = "# Unsafe example 1"
        unsafe_status = _first_sensitive_term(descriptions[0])
    else:
        reference_tag = n_unsafe  # index of the first safe example (no longer hard-coded to 2)
        citation = "# Safe example 1"
        unsafe_status = "safe"

    def _heading(idx):
        # Number examples within their own group so the headings agree with `citation`.
        if idx < n_unsafe:
            return f"# Unsafe example {idx + 1}:"
        return f"# Safe example {idx - n_unsafe + 1}:"

    if is_image_query:
        print(query)
        print("<BEGIN IMAGE EXAMPLES>")
        for idx, (explanation, src, _) in enumerate(entries):
            print(f"{_heading(idx)}\nimg_src: {src}")
            term = _first_sensitive_term(explanation)
            if term != "safe":
                print(f"Assessment: Unsafe - {term}\nExplanation: {explanation}\n\n")
            else:
                print(f"Assessment: Safe\nExplanation: {explanation}\n\n")
        print("<END IMAGE EXAMPLES>")

        img_tag = evaluate_safety(prompt='describe this in a word', img_src = query)
        image_paths = [entry[1] for entry in entries] + [query]
        response, _ = multi_image_inference(image_paths, model, tokenizer, reference_tag, unsafe_status, img_tag)
    else:
        print(f"User text:{query}\n\n")
        print("<BEGIN TEXT EXAMPLES>")
        for idx, explanation in enumerate(descriptions):
            print(f"{_heading(idx)}\n")
            term = _first_sensitive_term(explanation)
            if term != "safe":
                print(f"Text content: {explanation}\nAssessment: Unsafe - {term}\n\n")
            else:
                print(f"Text content: {explanation}\nAssessment: Safe\n\n")
        print("<END TEXT EXAMPLES>")

        txt_tag = evaluate_safety(prompt=query)
        response, _ = multi_text_inference(descriptions + [query], model, tokenizer, reference_tag, unsafe_status, txt_tag)

    print('============================================LLM-RAG============================================')
    print(f"Safe Status: {response}\nCitation: {citation}\n")
    print(f"- {descriptions[reference_tag]}")
    print('==============================================END==============================================')
    

In [None]:
# MiniCPM image inference test: show the query image, then run the RAG pipeline on it
img_src = 'schoolbullying_10.jpg'
print("============================================需要识别的图片为============================================")
try:
    image = Image.open(img_src)  # open the image (NOTE(review): the handle is never closed)
    image.show()  # display the image
except Exception as e:
    print(f"无法打开图片 {img_src}: {e}")
    
print("===============================================RAG===============================================")
search_rag(img_src)

In [None]:
# Text query: the string is not an openable image path, so search_rag treats it as text.
search_rag('He was arrested for being publicly nude.')