In [1]:
import torch  # 导入PyTorch库
import torch.nn as nn  # 导入PyTorch神经网络模块
from transformers import BertModel, BertTokenizer, BertForMaskedLM  # 导入Hugging Face的BERT模型和分词器
from torchvision.models import resnet50  # 导入ResNet50预训练模型
import random  # 导入随机模块，用于MLM任务的掩码操作

class ImageEncoder(nn.Module):
    """Image encoder backed by ResNet50.

    The backbone's final fully-connected layer is swapped for a fresh linear
    layer so that the network emits ``embed_dim`` features per image.
    """

    def __init__(self, embed_dim=256):
        """Build the backbone and replace its classification head.

        Args:
            embed_dim: Width of the feature vector produced for each image.
        """
        super().__init__()
        # NOTE(review): newer torchvision releases prefer the `weights=`
        # argument over `pretrained=True` - confirm against the installed
        # version before changing it.
        net = resnet50(pretrained=True)
        head_in_features = net.fc.in_features
        # New head: maps the pooled backbone features to `embed_dim`.
        net.fc = nn.Linear(head_in_features, embed_dim)
        self.backbone = net

    def forward(self, images):
        """Run ``images`` through the backbone and return its output."""
        features = self.backbone(images)
        return features


class TextEncoder(nn.Module):
    """Text encoder backed by BERT.

    Wraps a pretrained ``bert-base-uncased`` model and a linear projection
    that maps the [CLS] hidden state to ``embed_dim`` features.
    """

    def __init__(self, embed_dim=256):
        """Load BERT and create the projection layer.

        Args:
            embed_dim: Width of the projected sentence-level feature vector.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        bert_width = self.bert.config.hidden_size
        # Projects BERT's hidden size down/up to the shared embedding width.
        self.proj = nn.Linear(bert_width, embed_dim)

    def forward(self, input_ids, attention_mask):
        """Encode a batch of token IDs.

        Args:
            input_ids: Token IDs passed to BERT as ``input_ids``.
            attention_mask: Mask passed to BERT as ``attention_mask``.

        Returns:
            A tuple ``(projected_cls, token_states)``: the projected hidden
            state at sequence position 0 ([CLS]) and BERT's full
            ``last_hidden_state``.
        """
        token_states = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Position 0 holds the [CLS] token, used as the sentence representation.
        projected_cls = self.proj(token_states[:, 0, :])
        return projected_cls, token_states


class ALBEF(nn.Module):
    """Simplified ALBEF model with contrastive, ITM and MLM objectives.

    The model owns an image encoder and a text encoder that both emit
    ``embed_dim`` features, plus two task heads that operate on BERT's hidden
    size. Image features are projected to the hidden size before they are
    added to text hidden states, so the fusion is shape-valid for any
    ``embed_dim``.
    """

    # Token IDs of the bert-base-uncased vocabulary used by the text encoder.
    PAD_TOKEN_ID = 0
    CLS_TOKEN_ID = 101
    SEP_TOKEN_ID = 102
    MASK_TOKEN_ID = 103
    # Label value that the cross-entropy loss skips.
    IGNORE_INDEX = -100

    def __init__(self, embed_dim=256):
        """Create the encoders, the temperature and the task heads.

        Args:
            embed_dim: Width of the shared image/text embedding space.
        """
        super().__init__()
        self.image_encoder = ImageEncoder(embed_dim)
        self.text_encoder = TextEncoder(embed_dim)
        # Learnable temperature that scales the similarity logits.
        self.temperature = nn.Parameter(torch.ones([]) * 0.07)

        bert_config = self.text_encoder.bert.config
        hidden_size = bert_config.hidden_size
        self.vocab_size = bert_config.vocab_size

        # Image-text matching (ITM) head: binary match / no-match classifier.
        self.itm_head = nn.Linear(hidden_size, 2)

        # Masked language modelling (MLM) head: one logit per vocabulary entry.
        self.mlm_head = nn.Linear(hidden_size, self.vocab_size)

        # Image features have width `embed_dim`, text hidden states have width
        # `hidden_size`; they can only be added after this projection. When
        # both widths already agree, no extra parameters are introduced.
        # Created last so the initialisation of the layers above is unchanged.
        if embed_dim == hidden_size:
            self.image_proj = nn.Identity()
        else:
            self.image_proj = nn.Linear(embed_dim, hidden_size)

    def forward(self, images, input_ids, attention_mask, task="contrastive",
                labels=None, masked_input_ids=None, mlm_labels=None):
        """Dispatch to one of the three supported tasks.

        Args:
            images: Batch of images for the image encoder.
            input_ids: Token IDs for the text encoder.
            attention_mask: Attention mask for the text encoder.
            task: ``"contrastive"`` (image-text contrastive learning),
                ``"itm"`` (image-text matching) or ``"mlm"`` (masked
                language modelling).
            labels: Optional ITM labels; only used when ``task == "itm"``.
            masked_input_ids: Optional pre-masked token IDs; only used when
                ``task == "mlm"``.
            mlm_labels: Optional MLM labels; only used when ``task == "mlm"``.

        Returns:
            Whatever the selected task method returns.

        Raises:
            ValueError: If ``task`` is not one of the supported names.
        """
        if task == "contrastive":
            return self.contrastive_forward(images, input_ids, attention_mask)
        if task == "itm":
            return self.itm_forward(images, input_ids, attention_mask, labels=labels)
        if task == "mlm":
            return self.mlm_forward(
                images,
                input_ids,
                attention_mask,
                masked_input_ids=masked_input_ids,
                mlm_labels=mlm_labels,
            )
        raise ValueError(f"不支持的任务类型: {task}")

    def _project_image(self, images):
        """Encode ``images`` and project the features to BERT's hidden size."""
        return self.image_proj(self.image_encoder(images))

    def contrastive_forward(self, images, input_ids, attention_mask):
        """Compute the symmetric image-text contrastive loss.

        Matching pairs are assumed to share the same batch index, so the
        targets are the diagonal of the similarity matrix.

        Returns:
            Scalar loss: the mean of the image-to-text and text-to-image
            cross-entropy losses.
        """
        image_features = self.image_encoder(images)
        text_features, _ = self.text_encoder(input_ids, attention_mask)

        # L2-normalise; `normalize` guards against division by a zero norm.
        image_features = nn.functional.normalize(image_features, dim=-1)
        text_features = nn.functional.normalize(text_features, dim=-1)

        # Cosine similarities scaled by the temperature.
        logits = (image_features @ text_features.t()) / self.temperature
        targets = torch.arange(logits.size(0), device=logits.device)
        loss_i2t = nn.functional.cross_entropy(logits, targets)
        loss_t2i = nn.functional.cross_entropy(logits.t(), targets)
        return (loss_i2t + loss_t2i) / 2

    def itm_forward(self, images, input_ids, attention_mask, labels=None):
        """Forward pass for the image-text matching (ITM) task.

        Args:
            images: Batch of images.
            input_ids: Token IDs for the text encoder.
            attention_mask: Attention mask for the text encoder.
            labels: Optional class indices (match / no-match).

        Returns:
            The cross-entropy loss when ``labels`` is given, otherwise the
            ``(batch, 2)`` logits.
        """
        image_features = self._project_image(images)
        _, text_hidden = self.text_encoder(input_ids, attention_mask)

        # Fuse by adding the projected image features to the [CLS] state.
        fused_features = text_hidden[:, 0, :] + image_features
        itm_logits = self.itm_head(fused_features)

        if labels is None:
            return itm_logits
        return nn.functional.cross_entropy(itm_logits, labels)

    def mlm_forward(self, images, input_ids, attention_mask, masked_input_ids=None, mlm_labels=None):
        """Forward pass for the masked language modelling (MLM) task.

        Args:
            images: Batch of images.
            input_ids: Unmasked token IDs. Never modified by this method.
            attention_mask: Attention mask for the text encoder.
            masked_input_ids: Optional pre-masked token IDs. When omitted and
                ``mlm_labels`` is also omitted, masking is generated with
                ``mask_tokens``. When omitted but ``mlm_labels`` is given,
                ``input_ids`` is taken to be already masked.
            mlm_labels: Optional labels, ``IGNORE_INDEX`` at positions that
                do not contribute to the loss.

        Returns:
            The MLM loss when labels are available, otherwise the
            ``(batch, seq_len, vocab_size)`` logits.
        """
        image_features = self._project_image(images)

        if masked_input_ids is None:
            if mlm_labels is None:
                masked_input_ids, mlm_labels = self.mask_tokens(input_ids)
            else:
                masked_input_ids = input_ids

        _, text_hidden = self.text_encoder(masked_input_ids, attention_mask)

        # Add the image features to every token state (broadcast over the
        # sequence axis). The real ALBEF uses a fused multimodal encoder here
        # rather than a plain sum; see
        # https://github.com/salesforce/ALBEF/blob/main/modeling_albef.py#L107-L110
        fused_features = text_hidden + image_features.unsqueeze(1)
        mlm_logits = self.mlm_head(fused_features)

        if mlm_labels is None:
            return mlm_logits

        flat_logits = mlm_logits.reshape(-1, mlm_logits.size(-1))
        flat_labels = mlm_labels.reshape(-1)
        # Sum and divide manually: a mean-reduced cross-entropy is NaN when
        # every label is ignored, which happens if no token got masked.
        summed_loss = nn.functional.cross_entropy(
            flat_logits,
            flat_labels,
            ignore_index=self.IGNORE_INDEX,
            reduction="sum",
        )
        num_targets = (flat_labels != self.IGNORE_INDEX).sum().clamp(min=1)
        return summed_loss / num_targets

    def mask_tokens(self, input_ids, mlm_probability=0.15):
        """Create masked inputs and labels for the MLM task.

        Each non-special token is selected with probability
        ``mlm_probability``. Of the selected tokens, 80% become [MASK], 10%
        become a random vocabulary token and 10% are left unchanged.

        Args:
            input_ids: Token IDs to mask. The tensor is not modified.
            mlm_probability: Selection probability per non-special token.

        Returns:
            A tuple ``(masked_input_ids, labels)`` of new tensors shaped like
            ``input_ids``; ``labels`` holds the original ID at selected
            positions and ``IGNORE_INDEX`` everywhere else.
        """
        device = input_ids.device
        # Work on copies so the caller's batch stays intact for other tasks.
        masked_input_ids = input_ids.clone()
        labels = input_ids.clone()

        # [PAD], [CLS] and [SEP] must never be selected for masking.
        special_tokens_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        for special_id in (self.PAD_TOKEN_ID, self.CLS_TOKEN_ID, self.SEP_TOKEN_ID):
            special_tokens_mask |= input_ids == special_id

        probability_matrix = torch.full(labels.shape, mlm_probability, device=device)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        # Bernoulli draw per position decides which tokens are selected.
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = self.IGNORE_INDEX

        # 80% of the selected tokens are replaced with [MASK].
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, 0.8, device=device)).bool()
            & masked_indices
        )
        masked_input_ids[indices_replaced] = self.MASK_TOKEN_ID

        # Half of the remaining 20% (i.e. 10% overall) get a random token.
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5, device=device)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(0, self.vocab_size, labels.shape, device=device)
        masked_input_ids[indices_random] = random_words[indices_random]

        # The last 10% of the selected tokens keep their original ID.
        return masked_input_ids, labels


# Instantiate the model with its default embedding size.
model = ALBEF()  # builds the ALBEF model instance

# Print the model structure.
print(model)  # writes the module hierarchy to stdout



config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

ALBEF(
  (image_encoder): ImageEncoder(
    (backbone): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReL