In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer, BertConfig
from torchvision.models import resnet50
import random
import numpy as np
import math


class ImageEncoder(nn.Module):
    """Image encoder built on a pretrained ResNet50 convolutional trunk.

    The trunk's final feature map is pooled to a fixed 7x7 grid, the grid
    cells are flattened into a sequence of 49 patch tokens, and each
    2048-channel token is linearly projected to ``embed_dim``.
    """

    def __init__(self, embed_dim=768):  # 768 matches BERT's hidden size
        super().__init__()
        resnet = resnet50(pretrained=True)
        # Drop ResNet's own global pooling and classification layers (the
        # last two children) and keep only the convolutional trunk.
        trunk_layers = list(resnet.children())[:-2]
        self.backbone = nn.Sequential(*trunk_layers)

        # Adaptive pooling always yields a 7x7 grid, whatever the input size.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # ResNet50's last stage emits 2048 channels.
        self.proj = nn.Linear(2048, embed_dim)

    def forward(self, images):
        """Encode images into patch tokens.

        Args:
            images: image batch, e.g. (batch_size, 3, 224, 224).

        Returns:
            Tensor of shape (batch_size, 49, embed_dim).
        """
        grid = self.avgpool(self.backbone(images))  # (batch_size, 2048, 7, 7)

        # Collapse the spatial grid into a token axis, channels last.
        n, c, gh, gw = grid.shape
        tokens = grid.view(n, c, gh * gw).permute(0, 2, 1)  # (batch_size, 49, 2048)

        return self.proj(tokens)  # (batch_size, 49, embed_dim)


class TextEncoder(nn.Module):
    """Text encoder wrapping a pretrained ``bert-base-uncased`` BertModel.

    Returns per-token hidden states of width ``embed_dim``. When
    ``embed_dim`` equals BERT's hidden size (768 for this checkpoint) the
    hidden states pass through unchanged. Otherwise a linear projection maps
    them to ``embed_dim`` so they line up with the image encoder's output.
    """

    def __init__(self, embed_dim=768):
        super(TextEncoder, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.embed_dim = embed_dim

        # Previously embed_dim was stored but never used, so a non-768 value
        # silently produced features of the wrong width. nn.Identity has no
        # parameters, so the default configuration's state_dict is unchanged.
        hidden_size = self.bert.config.hidden_size
        if hidden_size == embed_dim:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Linear(hidden_size, embed_dim)

    def forward(self, input_ids, attention_mask):
        """Encode token ids.

        Args:
            input_ids: token id tensor passed straight to BERT.
            attention_mask: BERT attention mask (1 = real token, 0 = padding).

        Returns:
            Tensor of shape (batch_size, seq_len, embed_dim).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.proj(outputs.last_hidden_state)


class MultiHeadCrossAttention(nn.Module):
    """Multi-head scaled dot-product cross-attention.

    ``query`` attends over ``key``/``value``. The query, key and value are
    each projected, split into ``num_heads`` heads of size
    ``embed_dim // num_heads``, attended per head, concatenated and
    projected back with ``out``.
    """
    def __init__(self, embed_dim=768, num_heads=12):
        super(MultiHeadCrossAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, x, batch_size):
        """Reshape (batch, length, embed_dim) into (batch, heads, length, head_dim)."""
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, attention_mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: (batch, q_len, embed_dim).
            key: (batch, kv_len, embed_dim).
            value: (batch, kv_len, embed_dim).
            attention_mask: optional (batch, kv_len). Key positions whose mask
                value equals 0 receive (effectively) zero attention weight.

        Returns:
            Tensor of shape (batch, q_len, embed_dim).
        """
        batch_size = query.size(0)

        # Project, then split into heads.
        Q = self._split_heads(self.q_linear(query), batch_size)
        K = self._split_heads(self.k_linear(key), batch_size)
        V = self._split_heads(self.v_linear(value), batch_size)

        attention = self._attention(Q, K, V, attention_mask)

        # Merge the heads back into a single embed_dim-wide feature.
        attention = attention.transpose(1, 2).contiguous().view(
            batch_size, -1, self.embed_dim)

        return self.out(attention)

    def _attention(self, Q, K, V, attention_mask=None):
        """Scaled dot-product attention over per-head tensors."""
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # Broadcast the (batch, kv_len) mask over heads and query positions.
            mask = attention_mask.unsqueeze(1).unsqueeze(1)
            # Use the dtype's most negative finite value instead of -1e9:
            # -1e9 does not fit in float16, so masked_fill fails under half
            # precision / autocast. A finite value (not -inf) also keeps fully
            # masked rows from turning into NaN after softmax.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, V)


class MultimodalEncoderLayer(nn.Module):
    """One fusion layer over the text stream.

    The layer applies three sub-layers in sequence: text self-attention,
    text-to-image cross-attention, and a feed-forward network. Each sub-layer
    is pre-norm (LayerNorm on its input) and wrapped in a residual
    connection. Image features are used only as cross-attention keys/values.
    """
    def __init__(self, embed_dim=768, num_heads=12, ff_dim=3072, dropout=0.1):
        super().__init__()

        # Sub-layer 1: self-attention among text tokens (batch-first tensors).
        self.self_attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )

        # Sub-layer 2: text tokens (queries) attend to image tokens (keys/values).
        self.cross_attention = MultiHeadCrossAttention(embed_dim, num_heads)

        # Sub-layer 3: position-wise feed-forward network.
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout),
        )

        # One pre-norm LayerNorm per sub-layer.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)

        # Dropout on the attention outputs before each residual add.
        self.dropout = nn.Dropout(dropout)

    def forward(self, text_features, image_features, text_attention_mask=None):
        """Fuse image information into the text features.

        Args:
            text_features: (batch, seq_len, embed_dim) text tokens.
            image_features: (batch, num_patches, embed_dim) image tokens.
            text_attention_mask: optional (batch, seq_len); 0 marks padding.

        Returns:
            Updated text features with the same shape as ``text_features``.
        """
        # nn.MultiheadAttention wants True at positions to be ignored.
        pad_mask = None if text_attention_mask is None else (text_attention_mask == 0)

        # Self-attention sub-layer.
        normed = self.norm1(text_features)
        self_out = self.self_attention(
            normed, normed, normed,
            key_padding_mask=pad_mask,
        )[0]
        hidden = text_features + self.dropout(self_out)

        # Cross-attention sub-layer: text queries over image keys/values.
        normed = self.norm2(hidden)
        cross_out = self.cross_attention(
            query=normed,
            key=image_features,
            value=image_features,
        )
        hidden = hidden + self.dropout(cross_out)

        # Feed-forward sub-layer.
        return hidden + self.feed_forward(self.norm3(hidden))


class MultimodalEncoder(nn.Module):
    """A stack of ``num_layers`` MultimodalEncoderLayer fusion layers.

    The text features are refined layer by layer. The same image features
    and text padding mask are passed to every layer.
    """
    def __init__(self, num_layers=6, embed_dim=768, num_heads=12, ff_dim=3072, dropout=0.1):
        super().__init__()
        stack = [
            MultimodalEncoderLayer(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(stack)

    def forward(self, text_features, image_features, text_attention_mask=None):
        """Run the text features through every fusion layer in order."""
        hidden = text_features
        for fusion_layer in self.layers:
            hidden = fusion_layer(hidden, image_features, text_attention_mask)
        return hidden


class ALBEF(nn.Module):
    """ALBEF-style vision-language model with a cross-attention fusion encoder.

    Components:
      * ``image_encoder`` / ``text_encoder``: unimodal encoders producing
        token sequences of width ``embed_dim``.
      * ``multimodal_encoder``: text tokens attend to image tokens via
        cross-attention.
      * Heads for image-text contrastive learning, image-text matching (ITM)
        and masked language modelling (MLM).
    """

    # bert-base-uncased vocabulary constants (previously inline magic numbers).
    VOCAB_SIZE = 30522
    PAD_TOKEN_ID = 0
    CLS_TOKEN_ID = 101
    SEP_TOKEN_ID = 102
    MASK_TOKEN_ID = 103
    # Range the learnable temperature is clamped to before dividing by it,
    # so the logits stay finite even if optimisation pushes it towards <= 0.
    TEMPERATURE_MIN = 1e-3
    TEMPERATURE_MAX = 0.5

    def __init__(self, embed_dim=768, num_multimodal_layers=6):
        super(ALBEF, self).__init__()

        # Unimodal encoders.
        self.image_encoder = ImageEncoder(embed_dim)
        self.text_encoder = TextEncoder(embed_dim)

        # Multimodal (fusion) encoder.
        self.multimodal_encoder = MultimodalEncoder(
            num_layers=num_multimodal_layers,
            embed_dim=embed_dim
        )

        # Projection heads for contrastive learning.
        self.image_proj = nn.Linear(embed_dim, embed_dim)
        self.text_proj = nn.Linear(embed_dim, embed_dim)

        # Learnable softmax temperature for the contrastive logits.
        self.temperature = nn.Parameter(torch.ones([]) * 0.07)

        # ITM head: binary match / no-match classifier.
        self.itm_head = nn.Linear(embed_dim, 2)

        # MLM head: per-token vocabulary logits.
        self.mlm_head = nn.Linear(embed_dim, self.VOCAB_SIZE)

    def forward(self, images, input_ids, attention_mask, task="contrastive", **kwargs):
        """Dispatch to the forward pass for ``task`` ("contrastive", "itm" or "mlm").

        Extra keyword arguments are forwarded to the ITM / MLM methods.

        Raises:
            ValueError: if ``task`` is not one of the supported names.
        """
        if task == "contrastive":
            return self.contrastive_forward(images, input_ids, attention_mask)
        elif task == "itm":
            return self.itm_forward(images, input_ids, attention_mask, **kwargs)
        elif task == "mlm":
            return self.mlm_forward(images, input_ids, attention_mask, **kwargs)
        else:
            raise ValueError(f"不支持的任务类型: {task}")

    def encode_unimodal(self, images, input_ids, attention_mask):
        """Encode each modality independently.

        Returns:
            (image_features, text_features): image tokens and per-token text
            hidden states from the respective encoders.
        """
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(input_ids, attention_mask)
        return image_features, text_features

    def encode_multimodal(self, images, input_ids, attention_mask):
        """Encode both modalities, then fuse the text with the image via cross-attention.

        Returns:
            (multimodal_text_features, image_features).
        """
        image_features, text_features = self.encode_unimodal(images, input_ids, attention_mask)

        multimodal_text_features = self.multimodal_encoder(
            text_features, image_features, attention_mask
        )

        return multimodal_text_features, image_features

    def contrastive_forward(self, images, input_ids, attention_mask):
        """Return the symmetric image-text contrastive (InfoNCE) loss for a batch.

        The i-th image and i-th text are treated as the positive pair.
        """
        image_features, text_features = self.encode_unimodal(images, input_ids, attention_mask)

        # Global embeddings: mean over image tokens, first ([CLS]) text token.
        image_embeds = image_features.mean(dim=1)
        text_embeds = text_features[:, 0, :]

        # Project into the contrastive space and L2-normalise.
        image_embeds = F.normalize(self.image_proj(image_embeds), dim=-1)
        text_embeds = F.normalize(self.text_proj(text_embeds), dim=-1)

        # Clamp the learnable temperature so it can never be zero or negative.
        temperature = self.temperature.clamp(self.TEMPERATURE_MIN, self.TEMPERATURE_MAX)
        logits = torch.matmul(image_embeds, text_embeds.t()) / temperature
        labels = torch.arange(logits.size(0), device=logits.device)

        loss_i = F.cross_entropy(logits, labels)
        loss_t = F.cross_entropy(logits.t(), labels)
        return (loss_i + loss_t) / 2

    def itm_forward(self, images, input_ids, attention_mask, labels=None):
        """Image-text matching from the fused [CLS] token.

        Returns the cross-entropy loss when ``labels`` is given, otherwise the
        (batch_size, 2) logits.
        """
        multimodal_features, _ = self.encode_multimodal(images, input_ids, attention_mask)

        cls_features = multimodal_features[:, 0, :]
        itm_logits = self.itm_head(cls_features)  # (batch_size, 2)

        if labels is not None:
            return F.cross_entropy(itm_logits, labels)
        return itm_logits

    def mlm_forward(self, images, input_ids, attention_mask, masked_input_ids=None, mlm_labels=None):
        """Masked language modelling conditioned on the image.

        Args:
            images, input_ids, attention_mask: model inputs.
            masked_input_ids: pre-masked ids. If omitted (and ``mlm_labels`` is
                omitted too), ``input_ids`` is masked with ``mask_tokens``.
            mlm_labels: target ids, with -100 at positions that are not scored.

        Returns:
            The MLM loss when labels are available, otherwise the
            (batch_size, seq_len, vocab_size) logits. If every label is -100,
            a zero loss (still attached to the graph) is returned instead of
            the NaN that a mean over zero elements would give.

        Raises:
            ValueError: if ``mlm_labels`` is given without ``masked_input_ids``.
        """
        if masked_input_ids is None:
            if mlm_labels is not None:
                raise ValueError(
                    "mlm_labels was given without masked_input_ids; pass both or neither"
                )
            masked_input_ids, mlm_labels = self.mask_tokens(input_ids)

        multimodal_features, _ = self.encode_multimodal(images, masked_input_ids, attention_mask)
        mlm_logits = self.mlm_head(multimodal_features)

        if mlm_labels is None:
            return mlm_logits

        flat_logits = mlm_logits.reshape(-1, mlm_logits.size(-1))
        flat_labels = mlm_labels.reshape(-1)
        if not (flat_labels != -100).any():
            # Nothing was masked (likely for short sequences at a 15% rate);
            # mean cross-entropy over zero targets would be NaN.
            return mlm_logits.sum() * 0.0
        return F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)

    def mask_tokens(self, input_ids, mlm_probability=0.15):
        """Build BERT-style MLM inputs and labels without modifying ``input_ids``.

        Non-special tokens are selected with probability ``mlm_probability``.
        Of the selected tokens, about 80% become [MASK], 10% become a random
        token, and 10% are left unchanged. Labels are -100 everywhere except at
        the selected positions.

        Returns:
            (masked_input_ids, labels): a new tensor and its labels.
        """
        # Work on a copy: the original wrote [MASK]/random ids into the
        # caller's tensor in place.
        input_ids = input_ids.clone()
        device = input_ids.device
        labels = input_ids.clone()

        probability_matrix = torch.full(labels.shape, mlm_probability, device=device)

        # Never mask [PAD], [CLS] or [SEP].
        special_tokens_mask = torch.zeros_like(input_ids, dtype=torch.bool, device=device)
        for special_id in (self.PAD_TOKEN_ID, self.CLS_TOKEN_ID, self.SEP_TOKEN_ID):
            special_tokens_mask = special_tokens_mask | (input_ids == special_id)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100

        # 80% of the selected tokens -> [MASK].
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8, device=device)).bool() & masked_indices
        input_ids[indices_replaced] = self.MASK_TOKEN_ID

        # Half of the remaining 20% (i.e. 10%) -> random token; the rest stay as-is.
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5, device=device)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(0, self.VOCAB_SIZE, labels.shape, device=device)
        input_ids[indices_random] = random_words[indices_random]

        return input_ids, labels


def demo_albef_with_cross_attention():
    """Print a smoke-test walkthrough of the ALBEF model.

    Builds the model (its encoders load pretrained ResNet50 and
    ``bert-base-uncased`` weights), tokenises four captions and pairs them
    with random-noise images. It then prints feature shapes, the
    contrastive / ITM / MLM losses, ITM probabilities and parameter counts.
    The images are noise, so the numbers only show that the pipeline runs.
    """
    print("=" * 80)
    print("ALBEF模型 - 包含交叉注意力机制")
    print("=" * 80)

    # Build the model in eval mode (disables dropout for the demo).
    model = ALBEF(embed_dim=768, num_multimodal_layers=6)
    model.eval()

    # Sample inputs. BertTokenizer is already imported at module level.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    texts = [
        "A cat sitting on a chair",
        "A dog running in the park",
        "A bird flying in the sky",
        "A fish swimming in water"
    ]

    tokenized = tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=32)
    input_ids = tokenized['input_ids']
    attention_mask = tokenized['attention_mask']
    images = torch.randn(4, 3, 224, 224)

    print("输入数据形状:")
    print(f"  图像: {images.shape}")
    print(f"  文本ID: {input_ids.shape}")
    print(f"  注意力掩码: {attention_mask.shape}")
    print()

    with torch.no_grad():
        # 1. Unimodal encoding.
        print("1. 单模态编码")
        print("-" * 40)
        image_features, text_features = model.encode_unimodal(images, input_ids, attention_mask)
        print(f"  图像特征形状: {image_features.shape}")  # (4, 49, 768)
        print(f"  文本特征形状: {text_features.shape}")   # (4, seq_len, 768)
        print()

        # 2. Multimodal encoding (cross-attention).
        print("2. 多模态编码（交叉注意力）")
        print("-" * 40)
        multimodal_features, _ = model.encode_multimodal(images, input_ids, attention_mask)
        print(f"  多模态特征形状: {multimodal_features.shape}")  # (4, seq_len, 768)
        print("  ✓ 文本特征已通过交叉注意力与图像特征融合")
        print()

        # 3. Contrastive task.
        print("3. 对比学习任务")
        print("-" * 40)
        contrastive_loss = model(images, input_ids, attention_mask, task="contrastive")
        print(f"  对比学习损失: {contrastive_loss.item():.4f}")
        print()

        # 4. ITM task (uses the fused features).
        print("4. 图像-文本匹配任务（基于交叉注意力）")
        print("-" * 40)
        itm_labels = torch.tensor([1, 1, 0, 0])
        itm_loss = model(images, input_ids, attention_mask, task="itm", labels=itm_labels)
        print(f"  ITM损失: {itm_loss.item():.4f}")

        itm_logits = model.itm_forward(images, input_ids, attention_mask)
        itm_predictions = torch.softmax(itm_logits, dim=-1)
        print("  ITM预测概率:")
        for i, pred in enumerate(itm_predictions):
            print(f"    样本{i+1}: 不匹配={pred[0]:.3f}, 匹配={pred[1]:.3f}")
        print()

        # 5. MLM task (uses the fused features).
        print("5. 掩码语言模型任务（基于交叉注意力）")
        print("-" * 40)
        mlm_loss = model(images, input_ids, attention_mask, task="mlm")
        print(f"  MLM损失: {mlm_loss.item():.4f}")
        print("  ✓ MLM预测基于图像-文本交叉注意力特征")
        print()

        # 6. Architecture summary.
        print("6. 模型架构特点")
        print("-" * 40)
        total_params = sum(p.numel() for p in model.parameters())
        multimodal_params = sum(p.numel() for p in model.multimodal_encoder.parameters())

        print(f"  总参数量: {total_params:,}")
        print(f"  多模态编码器参数: {multimodal_params:,}")
        print(f"  交叉注意力层数: {len(model.multimodal_encoder.layers)}")
        print("  ✓ 包含图像-文本交叉注意力机制")
        print("  ✓ 符合ALBEF论文架构")

    print("=" * 80)
    print("演示完成！现在的实现包含了论文中的交叉注意力机制")
    print("=" * 80)


if __name__ == "__main__":
    # Fix the seeds so the demo's random images (torch.randn) and MLM
    # masking (torch.bernoulli / torch.randint) are reproducible.
    for _seed_fn in (torch.manual_seed, np.random.seed):
        _seed_fn(42)
    demo_albef_with_cross_attention()

ALBEF模型 - 包含交叉注意力机制
输入数据形状:
  图像: torch.Size([4, 3, 224, 224])
  文本ID: torch.Size([4, 8])
  注意力掩码: torch.Size([4, 8])

1. 单模态编码
----------------------------------------
  图像特征形状: torch.Size([4, 49, 768])
  文本特征形状: torch.Size([4, 8, 768])

2. 多模态编码（交叉注意力）
----------------------------------------
  多模态特征形状: torch.Size([4, 8, 768])
  ✓ 文本特征已通过交叉注意力与图像特征融合

3. 对比学习任务
----------------------------------------
  对比学习损失: 1.3920

4. 图像-文本匹配任务（基于交叉注意力）
----------------------------------------
  ITM损失: 0.6441
  ITM预测概率:
    样本1: 不匹配=0.360, 匹配=0.640
    样本2: 不匹配=0.223, 匹配=0.777
    样本3: 不匹配=0.326, 匹配=0.674
    样本4: 不匹配=0.469, 匹配=0.531

5. 掩码语言模型任务（基于交叉注意力）
----------------------------------------
  MLM损失: 9.7902
  ✓ MLM预测基于图像-文本交叉注意力特征

6. 模型架构特点
----------------------------------------
  总参数量: 215,928,701
  多模态编码器参数: 56,710,656
  交叉注意力层数: 6
  ✓ 包含图像-文本交叉注意力机制
  ✓ 符合ALBEF论文架构
演示完成！现在的实现包含了论文中的交叉注意力机制


In [2]:
#transformers库中没有ALBEF模型
