# 从零实现BLIP模型

# 1. 安装依赖

In [None]:
# Install the notebook's dependencies quietly. The leading "!" is an IPython shell escape,
# so this line only works inside Jupyter/IPython, not in a plain Python interpreter.
!pip install -q torch torchvision matplotlib pillow

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# Report the installed PyTorch version and whether a CUDA device is visible to PyTorch.
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA是否可用: {torch.cuda.is_available()}")


# 2. 实现Patch Embedding

In [None]:
class PatchEmbedding(nn.Module):
    """
    Split an image into non-overlapping square patches and embed each patch.

    A Conv2d whose kernel size and stride both equal ``patch_size`` is the same
    as flattening every patch and applying one shared linear projection.

    Shapes with the default arguments:
        input:  (B, 3, 224, 224)
        output: (B, 196, 768)   # (224 // 16) ** 2 = 14 * 14 = 196 patches
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        grid_side = img_size // patch_size  # patches along one spatial axis
        self.num_patches = grid_side * grid_side

        # One conv "stamp" per patch: kernel == stride == patch_size.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map (B, C_in, H, W) images to (B, num_patches, embed_dim) patch tokens."""
        feature_map = self.proj(x)                 # (B, embed_dim, H/p, W/p)
        tokens = feature_map.flatten(start_dim=2)  # (B, embed_dim, num_patches), row-major
        return tokens.permute(0, 2, 1)             # (B, num_patches, embed_dim)

# Smoke test: push two random 224x224 RGB images through a default PatchEmbedding.
patch_embed = PatchEmbedding()
test_img = torch.randn(2, 3, 224, 224)
output = patch_embed(test_img)
print(f"Patch Embedding输出形状: {output.shape}")  # expected (2, 196, 768)


# 3. 实现Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention.

    Supports three usage modes:
    1. Self-attention: Q = K = V (pass only ``query``).
    2. Cross-attention: Q comes from one sequence, K and V from another.
    3. Masked self-attention (e.g. in a decoder): pass an additive
       ``attention_mask`` such as a causal mask.

    Args:
        embed_dim: model width C; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim=768, num_heads=12, dropout=0.0):
        super().__init__()
        # Fail fast with a clear message instead of an opaque reshape error in forward().
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Q, K, V and output projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        """Reshape a projected (B, L, C) tensor into (B, num_heads, L, head_dim)."""
        batch, length, _ = x.shape
        return x.reshape(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key=None, value=None, attention_mask=None):
        """
        Args:
            query: (B, N, C) query sequence.
            key: (B, M, C) key sequence; ``query`` is used when None.
            value: (B, M, C) value sequence; ``query`` is used when None.
            attention_mask: additive mask broadcastable to (B, num_heads, N, M),
                e.g. (B, 1, N, M). It is added to the scores before softmax, so 0
                keeps a position and a large negative value / -inf hides it.

        Returns:
            (B, N, C) attention output after the output projection.
        """
        if key is None:
            key = query
        if value is None:
            value = query

        B, N, C = query.shape

        # Linear projections, then split into heads: (B, H, L, head_dim)
        q = self._split_heads(self.q_proj(query))
        k = self._split_heads(self.k_proj(key))
        v = self._split_heads(self.v_proj(value))

        # Scaled dot-product scores: (B, H, N, M)
        attn = (q @ k.transpose(-2, -1)) * self.scale

        # Additive mask (if any) is applied before normalisation.
        if attention_mask is not None:
            attn = attn + attention_mask

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of values, then merge heads back to (B, N, C).
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.out_proj(x)

# Smoke test: self-attention (key/value default to the query) on a (2, 196, 768) input.
attn = MultiHeadAttention()
test_q = torch.randn(2, 196, 768)
output = attn(test_q)
print(f"Multi-Head Attention输出形状: {output.shape}")  # expected (2, 196, 768)


# 4. 实现Transformer Block

In [None]:
class MLP(nn.Module):
    """
    Position-wise feed-forward network.

    Computes Linear -> GELU -> Dropout -> Linear -> Dropout; the same Dropout
    module is used at both dropout sites.
    """

    def __init__(self, embed_dim=768, hidden_dim=3072, dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)  # expand to the hidden width
        self.fc2 = nn.Linear(hidden_dim, embed_dim)  # project back to embed_dim
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

class TransformerBlock(nn.Module):
    """
    Pre-norm Transformer encoder block.

        x = x + Dropout(Attention(LayerNorm(x)))
        x = x + MLP(LayerNorm(x))
    """

    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, hidden_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        """
        Args:
            x: hidden states fed to self-attention and the MLP.
            attention_mask: optional additive mask forwarded unchanged to self-attention.
        """
        attended = self.attn(self.norm1(x), attention_mask=attention_mask)
        x = x + self.dropout(attended)
        return x + self.mlp(self.norm2(x))

# 5. 实现Vision Encoder (ViT)

In [None]:
class VisionEncoder(nn.Module):
    """
    Vision encoder with a ViT architecture.

    Prepends a learnable [CLS] token to the patch tokens, adds learned position
    embeddings, applies dropout, runs ``depth`` pre-norm Transformer blocks and a
    final LayerNorm.

    Shapes with the default arguments:
        input:        (B, 3, 224, 224)
        cls_output:   (B, 768)        global image feature ([CLS] position)
        patch_output: (B, 196, 768)   per-patch features
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.0,
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        seq_len = self.patch_embed.num_patches + 1  # +1 slot for the [CLS] token

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            [TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim)

        # Truncated-normal init: position table first, then the [CLS] token.
        for param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)

    def forward(self, x):
        """Return ``(cls_output, patch_output)`` for a batch of images ``x``."""
        patches = self.patch_embed(x)                    # (B, N, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, D)
        hidden = self.pos_drop(torch.cat((cls, patches), dim=1) + self.pos_embed)  # (B, N+1, D)
        for blk in self.blocks:
            hidden = blk(hidden)
        hidden = self.norm(hidden)
        return hidden[:, 0], hidden[:, 1:]

# Smoke test for the vision encoder on two random 224x224 RGB images.
print("测试Vision Encoder...")
vision_encoder = VisionEncoder(depth=4)  # fewer layers than the default 12, for a quick test
test_img = torch.randn(2, 3, 224, 224)
cls_out, patch_out = vision_encoder(test_img)
print(f"CLS输出形状: {cls_out.shape}")    # (2, 768)
print(f"Patch输出形状: {patch_out.shape}")  # (2, 196, 768)



# 6. 实现Text Encoder (BERT-style)

In [None]:
class TextEmbedding(nn.Module):
    """
    Token embeddings plus learned absolute position embeddings.

    Args:
        vocab_size: number of rows in the token embedding table.
        embed_dim: embedding width.
        max_position: number of learnable position slots; position IDs must be below it.
        dropout: dropout applied to the summed embeddings.
    """

    def __init__(self, vocab_size=30524, embed_dim=768, max_position=512, dropout=0.0):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.position_embed = nn.Embedding(max_position, embed_dim)
        self.dropout = nn.Dropout(dropout)

        # Small-std normal init: token table first, then position table.
        for table in (self.token_embed, self.position_embed):
            nn.init.normal_(table.weight, std=0.02)

    def forward(self, input_ids, position_ids=None):
        """
        Args:
            input_ids: (B, L) token IDs.
            position_ids: optional position IDs; defaults to 0..L-1 for every row.

        Returns:
            token + position embeddings, after dropout.
        """
        length = input_ids.size(1)
        if position_ids is None:
            position_ids = (
                torch.arange(length, dtype=torch.long, device=input_ids.device)
                .unsqueeze(0)
                .expand_as(input_ids)
            )
        combined = self.token_embed(input_ids) + self.position_embed(position_ids)
        return self.dropout(combined)


class TextEncoder(nn.Module):
    """
    Text encoder (BERT-style bidirectional Transformer).

    Input:  (B, seq_len) token IDs
    Output: (B, seq_len, embed_dim) text features

    Args:
        vocab_size, embed_dim, max_position, dropout: forwarded to TextEmbedding.
        depth: number of TransformerBlock layers.
        num_heads, mlp_ratio: forwarded to every TransformerBlock.
    """

    def __init__(self, vocab_size=30524, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4.0, max_position=512, dropout=0.0):
        super().__init__()

        self.embeddings = TextEmbedding(vocab_size, embed_dim, max_position, dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, input_ids, attention_mask=None):
        """
        Args:
            input_ids: (B, L) token IDs.
            attention_mask: optional (B, L) padding mask, 1/True for real tokens and
                0/False for padding. Numeric and bool dtypes are both accepted.

        Returns:
            (B, L, embed_dim) hidden states after the final LayerNorm.
        """
        x = self.embeddings(input_ids)

        if attention_mask is not None:
            # Turn the 1/0 padding mask into an additive (B, 1, 1, L) mask.
            # Casting to x.dtype first lets bool masks work (bool tensors do not
            # support subtraction). It also keeps the mask in the hidden-state
            # precision, so fp16 models do not mix fp32 scores with fp16 values
            # in the attention matmul.
            extended_mask = attention_mask.to(dtype=x.dtype).unsqueeze(1).unsqueeze(2)
            extended_mask = (1.0 - extended_mask) * -10000.0
        else:
            extended_mask = None

        for block in self.blocks:
            x = block(x, extended_mask)

        return self.norm(x)

# Smoke test for the text encoder: random token IDs, no padding mask.
print("测试Text Encoder...")
text_encoder = TextEncoder(depth=4)  # fewer layers than the default 12, for a quick test
test_text = torch.randint(0, 30524, (2, 32))
text_out = text_encoder(test_text)
print(f"Text输出形状: {text_out.shape}")  # (2, 32, 768)

# 7. 实现Text Decoder (带Cross-Attention)

In [None]:
class CrossAttentionBlock(nn.Module):
    """
    Pre-norm decoder block with three residual sub-layers:

    1. Masked self-attention over the decoder tokens (the caller supplies the
       mask, e.g. a causal mask).
    2. Cross-attention: decoder queries attend to ``encoder_hidden_states``.
    3. Feed-forward MLP.
    """

    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.self_attn = MultiHeadAttention(embed_dim, num_heads, dropout)

        self.norm2 = nn.LayerNorm(embed_dim)
        self.cross_attn = MultiHeadAttention(embed_dim, num_heads, dropout)

        self.norm3 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, hidden_dim, dropout)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_hidden_states, self_attention_mask=None, cross_attention_mask=None):
        # 1) self-attention over the decoder sequence
        normed = self.norm1(x)
        x = x + self.dropout(self.self_attn(normed, attention_mask=self_attention_mask))

        # 2) cross-attention: keys and values both come from the encoder output
        normed = self.norm2(x)
        crossed = self.cross_attn(
            normed,
            key=encoder_hidden_states,
            value=encoder_hidden_states,
            attention_mask=cross_attention_mask,
        )
        x = x + self.dropout(crossed)

        # 3) feed-forward
        return x + self.mlp(self.norm3(x))


class TextDecoder(nn.Module):
    """
    Text decoder for generation tasks.

    Token embeddings go through ``depth`` CrossAttentionBlocks (causal
    self-attention plus cross-attention to ``encoder_hidden_states``, e.g.
    image features). A final LayerNorm and a bias-free LM head then produce
    vocabulary logits.
    """

    def __init__(self, vocab_size=30524, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4.0, max_position=512, dropout=0.0):
        super().__init__()

        self.embeddings = TextEmbedding(vocab_size, embed_dim, max_position, dropout)

        self.blocks = nn.ModuleList([
            CrossAttentionBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

    def _generate_causal_mask(self, seq_length, device, dtype=None):
        """
        Build an additive causal mask of shape (1, 1, L, L).

        Entries on or below the diagonal are 0 (visible). Entries above it are
        -inf, so position i only attends to positions <= i.

        Args:
            seq_length: sequence length L.
            device: device for the mask.
            dtype: floating dtype of the mask; None uses torch's default dtype.
        """
        mask = torch.full((seq_length, seq_length), float("-inf"), device=device, dtype=dtype)
        mask = torch.triu(mask, diagonal=1)
        return mask.unsqueeze(0).unsqueeze(0)

    def forward(self, input_ids, encoder_hidden_states, attention_mask=None):
        """
        Args:
            input_ids: (B, L) token IDs.
            encoder_hidden_states: encoder features used as cross-attention keys/values.
            attention_mask: optional (B, L) padding mask, 1/True for real tokens and
                0/False for padding. Numeric and bool dtypes are both accepted.

        Returns:
            (B, L, vocab_size) logits.
        """
        seq_length = input_ids.size(1)

        x = self.embeddings(input_ids)

        # Build masks in the hidden-state dtype so fp16 models do not mix an
        # fp32 mask into fp16 attention (the attn @ v matmul needs equal dtypes).
        causal_mask = self._generate_causal_mask(seq_length, input_ids.device, x.dtype)

        if attention_mask is not None:
            # Casting first also lets bool masks work (bool tensors do not support subtraction).
            extended_mask = attention_mask.to(dtype=x.dtype).unsqueeze(1).unsqueeze(2)
            extended_mask = (1.0 - extended_mask) * -10000.0
            causal_mask = causal_mask + extended_mask  # broadcasts to (B, 1, L, L)

        for block in self.blocks:
            x = block(x, encoder_hidden_states, causal_mask, None)

        x = self.norm(x)
        return self.lm_head(x)

# 8. 完整BLIP模型

In [None]:
class BLIPModel(nn.Module):
    """
    Simplified BLIP model.

    Components:
    - vision_encoder: ViT producing a global [CLS] feature and per-patch features.
    - text_encoder: BERT-style encoder; its first-token output is used as the text feature.
    - text_decoder: causal decoder with cross-attention to image patch features.
    - vision_proj / text_proj: projections into the shared image-text embedding space.
    - itm_head: 2-way image-text matching head (not called by the methods below).
    - temp: learnable contrastive temperature (initialised to 0.07).

    Methods cover image-text similarity (retrieval) and greedy caption generation.
    """

    def __init__(self, vocab_size=30524, img_size=224, patch_size=16,
                 embed_dim=768, vision_depth=12, text_depth=12, num_heads=12,
                 mlp_ratio=4.0, max_position=512, dropout=0.0):
        super().__init__()

        # Vision encoder
        self.vision_encoder = VisionEncoder(
            img_size, patch_size, 3, embed_dim, vision_depth,
            num_heads, mlp_ratio, dropout
        )

        # Text encoder
        self.text_encoder = TextEncoder(
            vocab_size, embed_dim, text_depth, num_heads,
            mlp_ratio, max_position, dropout
        )

        # Text decoder
        self.text_decoder = TextDecoder(
            vocab_size, embed_dim, text_depth, num_heads,
            mlp_ratio, max_position, dropout
        )

        # Projection heads into the shared embedding space
        self.vision_proj = nn.Linear(embed_dim, embed_dim)
        self.text_proj = nn.Linear(embed_dim, embed_dim)

        # Image-text matching head (match / no-match logits)
        self.itm_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, 2)
        )

        # Contrastive temperature, stored directly (not as a log value):
        # similarity logits are divided by it.
        self.temp = nn.Parameter(torch.ones(1) * 0.07)

    def encode_image(self, image):
        """Return L2-normalised (B, embed_dim) image embeddings from the [CLS] feature."""
        cls_output, _ = self.vision_encoder(image)
        image_embed = self.vision_proj(cls_output)
        return F.normalize(image_embed, dim=-1)

    def encode_text(self, input_ids, attention_mask=None):
        """Return L2-normalised (B, embed_dim) text embeddings from the first token."""
        text_output = self.text_encoder(input_ids, attention_mask)
        text_embed = self.text_proj(text_output[:, 0])  # first (CLS) position
        return F.normalize(text_embed, dim=-1)

    def compute_similarity(self, image, input_ids, attention_mask=None):
        """
        Image-text similarity logits.

        Returns:
            (B_img, B_txt) tensor of cosine similarities divided by ``self.temp``.
        """
        image_embed = self.encode_image(image)
        text_embed = self.encode_text(input_ids, attention_mask)
        # temp holds the temperature itself (0.07 at init), so divide by it directly.
        # Exponentiating it would give an effective temperature of about 1.07.
        return image_embed @ text_embed.T / self.temp

    def generate_caption(self, image, max_length=50, bos_token_id=101, eos_token_id=102):
        """
        Greedy (argmax) caption generation conditioned on image patch features.

        Args:
            image: image batch accepted by the vision encoder.
            max_length: maximum number of tokens generated after BOS.
            bos_token_id: token every sequence starts with.
            eos_token_id: token that ends a sequence.

        Returns:
            (B, T) LongTensor of token IDs starting with BOS, with T <= max_length + 1.
            After a row emits EOS, its later positions are filled with EOS, so a
            finished caption is never extended with new tokens. Generation stops
            early once every row has emitted EOS.
        """
        _, patch_output = self.vision_encoder(image)

        batch_size = image.size(0)
        generated = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=image.device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=image.device)

        for _ in range(max_length):
            logits = self.text_decoder(generated, patch_output)
            next_token = logits[:, -1, :].argmax(dim=-1)  # (B,)
            # Rows that already ended keep emitting EOS, which acts as padding.
            next_token = next_token.masked_fill(finished, eos_token_id)
            generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)

            finished = finished | (next_token == eos_token_id)
            if finished.all():
                break

        return generated

# 9. 测试完整模型

In [None]:
# Build the model with a reduced configuration for a quick test.
print("创建BLIP模型...")
model = BLIPModel(
    vocab_size=30524,
    img_size=224,
    patch_size=16,
    embed_dim=768,
    vision_depth=6,   # fewer layers than the default 12
    text_depth=6,
    num_heads=12,
    mlp_ratio=4.0
)

# Count total parameters and those with requires_grad=True.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n模型参数统计:")
print(f"  总参数: {total_params:,}")
print(f"  可训练参数: {trainable_params:,}")

# Forward-pass smoke test on random images and random token IDs (no padding).
print("测试前向传播...")

batch_size = 2
image = torch.randn(batch_size, 3, 224, 224)
input_ids = torch.randint(0, 30524, (batch_size, 32))
attention_mask = torch.ones(batch_size, 32)

# Inference only: no autograd graph. model.eval() is not called; dropout is left
# at BLIPModel's default of 0.0 here.
with torch.no_grad():
    # Image embedding: (batch_size, embed_dim), L2-normalised
    image_embed = model.encode_image(image)
    print(f"图像特征: {image_embed.shape}")
    
    # Text embedding from the first token position: (batch_size, embed_dim)
    text_embed = model.encode_text(input_ids, attention_mask)
    print(f"文本特征: {text_embed.shape}")
    
    # Image-text similarity matrix: (batch_size, batch_size)
    similarity = model.compute_similarity(image, input_ids, attention_mask)
    print(f"相似度矩阵: {similarity.shape}")
    
    # Greedy caption generation; the result is raw token IDs (no tokenizer decodes them here).
    caption = model.generate_caption(image, max_length=10)
    print(f"生成的描述token: {caption.shape}")