In [19]:
import torch
import clip
import json
from PIL import Image
import os
from tqdm import tqdm

# Dataset input locations and output directories for the extracted features.
# NOTE(review): these are absolute, machine-specific paths; adjust them when
# running in another environment.
img_dir = "/root/autodl-tmp/hate_speech_dataset/img_resized"
json_file = "/root/autodl-tmp/hate_speech_dataset/MMHS150K_GT.json"
img_txt_dir = "/root/autodl-tmp/hate_speech_dataset/img_txt"
# NOTE(review): the three split-id files below are not referenced by any of
# the cells visible here — presumably consumed elsewhere; confirm.
train_ids_file = "splits/train_ids.txt"
val_ids_file = "splits/val_ids.txt"
test_ids_file = "splits/test_ids.txt"
img_save_dir = "/root/autodl-tmp/hate_speech_dataset/processed/img"
txt_save_dir = "/root/autodl-tmp/hate_speech_dataset/processed/txt"
label_save_dir = "/root/autodl-tmp/hate_speech_dataset/processed/label"

# Load the pretrained CLIP model (ViT-B/32) on the GPU when one is available,
# otherwise on the CPU. `preprocess` is the image transform returned with it.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Function: load and preprocess an image
def load_and_preprocess_image(image_path):
    """Open an image file and turn it into a one-item model input batch.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The result of the module-level ``preprocess`` transform applied to the
        image, with a leading batch dimension added via ``unsqueeze(0)`` and
        moved to the module-level ``device``.

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable file.
    """
    # Use the image as a context manager so the underlying file handle is
    # released deterministically once the pixels have been consumed, instead
    # of relying on garbage collection while looping over many files.
    with Image.open(image_path) as image:
        return preprocess(image).unsqueeze(0).to(device)

# Function: extract image features
def extract_image_features(image_path):
    """Encode one image file with the module-level CLIP ``model``.

    Args:
        image_path: Path of the image file to encode.

    Returns:
        The output of ``model.encode_image`` for the preprocessed image,
        computed with gradient tracking disabled.
    """
    batch = load_and_preprocess_image(image_path)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        return model.encode_image(batch)

# Helper: split text into length-bounded segments (used for text features)
def split_text(text, max_length=77):
    """
    Split text into segments of at most ``max_length`` characters.

    The text is split on whitespace and words are greedily packed into
    segments joined by single spaces, so words are never cut in the middle.
    A single word longer than ``max_length`` is kept intact as its own
    (over-long) segment rather than being broken up.

    Args:
        text: The string to split.
        max_length: Maximum number of characters per segment.

    Returns:
        A list of non-empty segment strings in original order. Empty or
        whitespace-only input yields an empty list.
    """
    segments = []
    current_segment = ""

    for word in text.split():
        if not current_segment:
            # The first word always opens a segment. Running it through the
            # length check (which also counts a separator that is never
            # added) used to flush a spurious empty segment whenever the word
            # was max_length characters or longer.
            current_segment = word
        elif len(current_segment) + 1 + len(word) <= max_length:
            current_segment = current_segment + " " + word
        else:
            segments.append(current_segment)
            current_segment = word

    if current_segment:
        segments.append(current_segment)

    return segments

def extract_text_features(model, text, device):
    """
    Split a (possibly long) text into segments and return their mean feature.

    Args:
        model: Model exposing ``encode_text``; called on the tokenized segments.
        text: The string to encode.
        device: Device the token tensor is moved to before encoding.

    Returns:
        The per-segment outputs of ``model.encode_text`` averaged over the
        segment dimension, with that dimension kept (one row per text).
    """
    # Ignore empty segments so they cannot dilute the average. When nothing
    # is left (empty or whitespace-only text), encode the raw text itself so
    # the caller still receives one feature row instead of an error from
    # reducing over an empty list.
    segments = [segment for segment in split_text(text) if segment] or [text]

    # Tokenize and encode all segments as one batch; truncate=True guards
    # against a segment that still exceeds the tokenizer's context length.
    text_input = clip.tokenize(segments, truncate=True).to(device)
    with torch.no_grad():
        segment_features = model.encode_text(text_input)

    # Average the segment features into a single representation of the text.
    return segment_features.mean(dim=0, keepdim=True)

# Load the ground-truth JSON. The encoding is given explicitly so that
# decoding does not depend on the locale of the machine running the notebook.
with open(json_file, 'r', encoding="utf-8") as f:
    data = json.load(f)

# Function: iterate over images and texts, extract features and save them
def process_dataset():
    """Extract and save image features, text features and labels per tweet.

    Iterates over the module-level ``data`` mapping (tweet id -> details).
    For each tweet:

    * the image ``<img_dir>/<tweet_id>.jpg`` is encoded; when the file does
      not exist a message is printed and the tweet is skipped entirely, so no
      feature or label file is written for it;
    * the text is read from the ``"img_text"`` field of
      ``<img_txt_dir>/<tweet_id>.json`` when that file exists, otherwise from
      ``details["tweet_text"]``;
    * image features, text features and ``details["labels"]`` are written
      with ``torch.save`` to ``img_save_dir``, ``txt_save_dir`` and
      ``label_save_dir`` as ``<tweet_id>_image.pt``, ``<tweet_id>_txt.pt``
      and ``<tweet_id>_label.pt``.

    Returns:
        None. Results are written to disk.
    """
    # torch.save does not create missing parent directories.
    for save_dir in (img_save_dir, txt_save_dir, label_save_dir):
        os.makedirs(save_dir, exist_ok=True)

    for tweet_id, details in tqdm(data.items()):
        img_path = os.path.join(img_dir, tweet_id + ".jpg")
        txt_path = os.path.join(img_txt_dir, tweet_id + ".json")

        # Skip tweets without an image. Falling through here would either
        # raise NameError (first iteration) or save the previous tweet's
        # image features under this tweet's id.
        if not os.path.exists(img_path):
            print(f"Image file for {tweet_id} not found.")
            continue
        image_features = extract_image_features(img_path)

        # Prefer the per-image text file; fall back to the tweet text from
        # the ground-truth entry when it is absent.
        if os.path.exists(txt_path):
            with open(txt_path, 'r', encoding="utf-8") as f:
                text = json.load(f)["img_text"]
        else:
            text = details["tweet_text"]
        text_features = extract_text_features(model, text, device)

        # Save features and label
        torch.save(image_features, os.path.join(img_save_dir, tweet_id + "_image.pt"))
        torch.save(text_features, os.path.join(txt_save_dir, tweet_id + "_txt.pt"))
        torch.save(details["labels"], os.path.join(label_save_dir, tweet_id + "_label.pt"))

            

# Run the dataset processing (iterates over every entry of `data`).
process_dataset()

  1%|          | 1704/149823 [00:47<1:09:03, 35.74it/s]


KeyboardInterrupt: 

In [13]:
import clip
from PIL import Image
import torch

def safe_clip_tokenize(texts, clip_model, max_length=77):
    """
    Tokenize texts for CLIP, splitting any text that does not fit one context.

    ``clip.tokenize`` always returns a tensor whose second dimension equals
    the requested context length and raises ``RuntimeError`` when a text is
    too long, so the result shape cannot be used to detect overflow. This
    function instead attempts to tokenize each text and, on overflow, cuts
    the text roughly in half (at a space near the middle when there is one)
    and retries on the halves until every piece fits.

    Args:
        texts: Iterable of strings to tokenize.
        clip_model: Not used by the function; kept so existing calls keep
            working.
        max_length: Context length passed to ``clip.tokenize``, i.e. the
            width of every returned row. It should match the context length
            of the model the tokens are fed to.

    Returns:
        A tensor with one row per resulting piece, concatenated along the
        batch dimension in input order. A text that fits contributes exactly
        one row; an over-long text contributes one row per piece. Empty
        ``texts`` yields a tensor with zero rows.
    """
    tokenized_texts = []
    for text in texts:
        # Stack of pieces still to tokenize; the leftmost piece is on top so
        # rows come out in reading order.
        pending = [text]
        while pending:
            piece = pending.pop()
            try:
                tokenized_texts.append(
                    clip.tokenize([piece], context_length=max_length)
                )
                continue
            except RuntimeError:
                # Raised by clip.tokenize when the piece is too long.
                if len(piece) <= 1:
                    # Cannot be split any further: fall back to truncation.
                    tokenized_texts.append(
                        clip.tokenize(
                            [piece], context_length=max_length, truncate=True
                        )
                    )
                    continue

            # Prefer a space in the second quarter of the piece so words stay
            # intact while both halves still shrink by a constant factor.
            middle = len(piece) // 2
            cut = piece.rfind(" ", middle // 2, middle + 1)
            if cut <= 0:
                cut = middle
            # Push the right half first so the left half is processed first.
            for half in (piece[cut:], piece[:cut]):
                half = half.strip()
                if half:
                    pending.append(half)

    if not tokenized_texts:
        # torch.cat rejects an empty list; return an empty batch instead.
        return clip.tokenize([], context_length=max_length)
    # Concatenate all pieces along the batch dimension.
    return torch.cat(tokenized_texts, dim=0)

# Example usage
# NOTE(review): the output recorded for this cell is
# "AttributeError: 'function' object has no attribute 'tokenize'", which
# suggests the name `clip` was bound to a function rather than the module in
# that kernel session — confirm with Restart Kernel -> Run All.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# A deliberately over-long text: one sentence repeated 10 times
long_texts = ["这是一个非常长的文本，需要被正确地分割以确保每个部分都不超过CLIP模型的最大长度限制。" * 10]

# Tokenize the long text safely
tokenized_texts = safe_clip_tokenize(long_texts, model)

print(tokenized_texts.shape)

AttributeError: 'function' object has no attribute 'tokenize'