In [None]:
import torch  # 导入PyTorch库
import torch.nn as nn  # 导入PyTorch神经网络模块
from transformers import BertModel, ViTModel, BertTokenizer  # 从transformers库导入预训练模型和分词器

class BLIP_MED(nn.Module):
    """BLIP-style multimodal model: ViT image encoder, BERT text encoder and
    an image-grounded text decoder with three task heads.

    Heads:
        * ``itc_head``  - shared projection used for image-text contrastive
          similarity.
        * ``itm_head``  - 2-way image-text matching classifier.
        * ``lm_head``   - vocabulary projection applied to decoder outputs.

    Args:
        config: Object exposing ``hidden_size`` and ``vocab_size`` (plus
            whatever ``ImageGroundedTextDecoder`` reads from it).
            NOTE(review): ``hidden_size`` presumably has to equal the hidden
            width of both pretrained encoders - confirm against the
            checkpoints being loaded.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size

        # Pretrained backbones (checkpoint names are fixed here).
        self.image_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224")
        self.text_encoder = BertModel.from_pretrained("bert-base-uncased")

        # Contrastive projection: the same head is applied to both modalities.
        self.itc_head = nn.Sequential(
            nn.Linear(width, width),
            nn.LayerNorm(width),
        )
        # Binary match / no-match classifier.
        self.itm_head = nn.Linear(width, 2)
        # Maps decoder hidden states to vocabulary logits.
        self.lm_head = nn.Linear(width, config.vocab_size)
        # Decoder that cross-attends to the image tokens.
        self.text_decoder = ImageGroundedTextDecoder(config)

    def forward(self, image, input_ids, attention_mask, decoder_input_ids=None):
        """Run the encoders and every head.

        Args:
            image: Pixel tensor forwarded to the ViT encoder as
                ``pixel_values``.
            input_ids: Token ids for the text encoder.
            attention_mask: Attention mask for the text encoder.
            decoder_input_ids: Optional token ids for the text decoder. When
                ``None`` the language-model branch is skipped.

        Returns:
            dict with keys ``"itc_score"`` (cosine similarity per sample),
            ``"itm_logits"`` (2 logits per sample) and ``"lm_logits"``
            (vocabulary logits per decoder position, or ``None``).
        """
        # Per-token hidden states from each encoder.
        visual_tokens = self.image_encoder(pixel_values=image).last_hidden_state
        text_tokens = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        # First token of each sequence is used as the sequence summary ([CLS]).
        text_cls = text_tokens[:, 0]
        visual_cls = visual_tokens[:, 0]

        # ITC: cosine similarity between the projected summaries.
        text_proj = self.itc_head(text_cls)
        visual_proj = self.itc_head(visual_cls)

        result = {
            "itc_score": torch.cosine_similarity(text_proj, visual_proj),
            # NOTE(review): ITM is computed from the text summary alone; the
            # image features do not feed into this head.
            "itm_logits": self.itm_head(text_cls),
            "lm_logits": None,
        }

        # LM branch: only when decoder inputs are supplied.
        if decoder_input_ids is not None:
            decoded = self.text_decoder(decoder_input_ids, visual_tokens)
            result["lm_logits"] = self.lm_head(decoded)

        return result

class ImageGroundedTextDecoder(nn.Module):
    """Stack of decoder blocks that turns token ids into hidden states while
    attending to image features.

    Args:
        config: Object exposing ``num_hidden_layers``, ``vocab_size``,
            ``hidden_size`` and ``max_position_embeddings`` (and whatever
            ``TransformerDecoderBlock`` reads from it).
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        # One decoder block per configured layer.
        blocks = [TransformerDecoderBlock(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(blocks)
        # Learned token and absolute-position embeddings.
        self.embedding = nn.Embedding(config.vocab_size, width)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, width)
        # Normalises the summed embeddings before the first block.
        self.layernorm = nn.LayerNorm(width)

    def forward(self, decoder_input_ids, image_feats):
        """Embed the tokens and pass them through every decoder block.

        Args:
            decoder_input_ids: 2-D integer tensor of token ids
                (batch, sequence).
            image_feats: Image features handed unchanged to each block.

        Returns:
            Hidden states with one vector per input token.
        """
        # Unpacking enforces a 2-D input; only the sequence length is needed.
        _, num_tokens = decoder_input_ids.shape

        # Positions 0..num_tokens-1, broadcast across the batch.
        positions = torch.arange(
            num_tokens, dtype=torch.long, device=decoder_input_ids.device
        )[None, :].expand_as(decoder_input_ids)

        token_emb = self.embedding(decoder_input_ids)
        hidden = self.layernorm(token_emb + self.position_embedding(positions))

        for block in self.layers:
            hidden = block(hidden, image_feats)

        return hidden

class TransformerDecoderBlock(nn.Module):
    """Single decoder layer: causal self-attention, cross-attention over
    visual features, then a feed-forward network.

    Each sub-layer is followed by a residual connection and a LayerNorm.

    Args:
        config: Object exposing ``hidden_size``, ``num_attention_heads`` and
            ``intermediate_size``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        heads = config.num_attention_heads

        # Both attention modules take (batch, sequence, feature) inputs.
        self.self_attn = nn.MultiheadAttention(width, heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(width, heads, batch_first=True)

        # Position-wise MLP: expand, GELU, project back.
        self.feed_forward = nn.Sequential(
            nn.Linear(width, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, width),
        )

        # One normalisation per sub-layer.
        self.norm1 = nn.LayerNorm(width)
        self.norm2 = nn.LayerNorm(width)
        self.norm3 = nn.LayerNorm(width)

    def forward(self, x, visual_feats):
        """Apply the three sub-layers to ``x``.

        Args:
            x: Text hidden states, batch-first (sequence length is read from
                dimension 1).
            visual_feats: Features used as keys and values for
                cross-attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        num_tokens = x.shape[1]
        # Boolean mask that is True strictly above the diagonal, i.e. at the
        # future positions each query must not attend to.
        causal_mask = torch.ones(num_tokens, num_tokens, device=x.device).triu(diagonal=1).bool()

        # Masked self-attention (attention weights are discarded).
        attended = self.self_attn(query=x, key=x, value=x, attn_mask=causal_mask)[0]
        x = self.norm1(x + attended)

        # Cross-attention: text queries over the visual features.
        attended = self.cross_attn(query=x, key=visual_feats, value=visual_feats)[0]
        x = self.norm2(x + attended)

        # Feed-forward sub-layer.
        return self.norm3(x + self.feed_forward(x))

# ---------- Inference example ----------
if __name__ == '__main__':
    # Smoke test: build the model, run one forward pass on dummy data and
    # print the shape of every output.
    from transformers import BertConfig

    config = BertConfig()  # default BERT configuration
    model = BLIP_MED(config)
    # This is inference, not training: switch to eval mode so dropout in the
    # pretrained encoders is disabled and repeated runs give the same scores.
    model.eval()
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Build inputs: a batch of two random images and two captions.
    dummy_image = torch.randn(2, 3, 224, 224)
    text = ["a girl holding a kitten", "a man riding a horse"]
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

    # Short prefixes used as decoder input for the language-model branch.
    decoder_input = tokenizer(["a little", "a small"], return_tensors="pt", padding=True).input_ids

    # Forward pass without autograd bookkeeping; no gradients are needed here.
    with torch.no_grad():
        outputs = model(dummy_image, inputs['input_ids'], inputs['attention_mask'], decoder_input)

    print("ITC score:", outputs['itc_score'].shape)
    print("ITM logits:", outputs['itm_logits'].shape)
    print("LM logits:", outputs['lm_logits'].shape if outputs['lm_logits'] is not None else None)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

ITC score: torch.Size([2])
ITM logits: torch.Size([2, 2])
LM logits: torch.Size([2, 4, 30522])
