In [9]:
# 导入必要的库
import torch
from torch import nn
from transformers import BertModel, BertTokenizer
from torchvision.models import resnet50
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

# 定义文本图像融合数据集
class TextImageDataset(Dataset):
    def __init__(self, texts, image_paths, labels):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.texts = texts
        self.image_paths = image_paths
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # 处理文本
        text = self.texts[idx]
        inputs = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt"
        )
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)

        # 处理图像
        image = Image.open(self.image_paths[idx]).convert("RGB")
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        image = transform(image)

        label = torch.tensor(self.labels[idx])

        return input_ids, attention_mask, image, label

In [10]:
# 定义融合模型
class FusionModel(nn.Module):
    def __init__(self):
        
        super(FusionModel, self).__init__()
        
        self.text_model = BertModel.from_pretrained('bert-base-uncased')
        self.image_model = resnet50(pretrained=True)
        
        image_features_dim = self.image_model.fc.in_features
        self.image_model.fc = nn.Identity()
        
        fusion_dim = self.text_model.config.hidden_size + image_features_dim
        self.fusion_layer = nn.Linear(fusion_dim, 512)
        self.output_layer = nn.Linear(512, 1)

    def forward(self, input_ids, attention_mask, images):
        text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        image_features = self.image_model(images)
        
        fusion = torch.cat((text_features, image_features), dim=1)
        
        fusion = self.fusion_layer(fusion)
        fusion = torch.relu(fusion)
        output = self.output_layer(fusion)
        return output

In [11]:
# 示例文本和图像路径
texts = ['This is a spiderman', 'This is actually a spiderman']
image_paths = ['cube.jpeg', 'Spiderman1.png']
# 0 fake, 1 real
labels = [0, 1]

# 创建数据集和数据加载器
dataset = TextImageDataset(texts, image_paths, labels)
dataloader = DataLoader(dataset, batch_size=2)

# 初始化模型并进行预测
model = FusionModel()

In [12]:
for batch in dataloader:
    input_ids, attention_mask, images, labels = batch
    outputs = model(input_ids, attention_mask, images)
    
    probabilities = torch.sigmoid(outputs)
    
    print(probabilities)

tensor([[0.4978],
        [0.5409]], grad_fn=<SigmoidBackward0>)
