In [34]:
# Scratch tensors: keys and values are one shared all-zero tensor; queries are random.
v = torch.zeros(1, 77, 512)
k = v
q = torch.rand(1, 1000, 512)

torch.Size([1, 77])

In [36]:
import torch
import torch.nn as nn


def create_padding_mask(seq: torch.Tensor) -> torch.Tensor:
    """Return a boolean mask marking the all-zero (padding) positions of *seq*.

    The last dimension of ``seq`` is reduced, so an input of shape
    ``(batch_size, seq_len, n_features)`` yields a mask of shape
    ``(batch_size, seq_len)``. An entry is ``True`` when every feature at that
    position equals zero.

    Args:
        seq: Tensor whose last dimension holds the per-token features.

    Returns:
        Boolean tensor with the last dimension of ``seq`` removed.
    """
    # Test each element rather than the feature sum: a real token whose
    # features happen to cancel out (e.g. [1, -1]) must not count as padding.
    return (seq == 0).all(dim=-1)

class CrossAttentionBlock(nn.Module):
    """Self-attention, cross-attention and feed-forward block with residuals.

    The query sequence first attends to itself, then to an external ``kv``
    sequence, and finally passes through a two-layer feed-forward network.
    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    All attention modules are built with ``batch_first=True``, so tensors are
    laid out as ``(batch, sequence, feature)``.

    Args:
        n_features: Embedding width of the query and ``kv`` inputs.
        n_heads: Number of attention heads (used for both attention modules).
        n_hidden: Width of the hidden layer in the feed-forward network.
        dropout: Dropout probability used inside the attention modules, inside
            the feed-forward network and on every residual branch.
    """

    def __init__(self, n_features: int = 512, n_heads: int = 8,
                 n_hidden: int = 512, dropout: float = 0.1):
        super().__init__()
        self.MHselfA = nn.MultiheadAttention(
            embed_dim=n_features, num_heads=n_heads, dropout=dropout, batch_first=True
        )
        # NOTE: despite the "pre" in their names, preLN1/preLN2 are applied
        # AFTER the residual addition in forward(). The attribute names are
        # kept as-is so existing state_dict keys remain valid.
        self.preLN1 = nn.LayerNorm(n_features)

        self.MHcrossA = nn.MultiheadAttention(
            embed_dim=n_features, num_heads=n_heads, dropout=dropout, batch_first=True
        )
        self.preLN2 = nn.LayerNorm(n_features)
        self.FF = nn.Sequential(
            nn.Linear(in_features=n_features, out_features=n_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=n_hidden, out_features=n_features),
        )
        self.postLN = nn.LayerNorm(n_features)
        # One dropout module shared by the three residual branches.
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, kv, src_mask=None):
        """Run the block on ``query`` while attending to ``kv``.

        Args:
            query: Tensor of shape ``(batch, tgt_len, n_features)``.
            kv: Tensor of shape ``(batch, src_len, n_features)``; used as both
                key and value of the cross-attention.
            src_mask: Optional boolean tensor of shape ``(batch, src_len)``
                passed to the cross-attention as ``key_padding_mask``;
                ``True`` marks ``kv`` positions to ignore. ``None`` (the
                default) attends to every position. If every position of a
                batch element is masked, the attention softmax has no valid
                key and the output may be NaN.

        Returns:
            Tensor with the same shape as ``query``.
        """
        tgt = query
        memory = kv

        # Self-attention over the query sequence (no padding mask is applied).
        tgt2 = self.MHselfA(tgt, tgt, tgt, key_padding_mask=None)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.preLN1(tgt)

        # Cross-attention: queries attend to the (optionally masked) memory.
        tgt2 = self.MHcrossA(tgt, memory, memory, key_padding_mask=src_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.preLN2(tgt)

        # Position-wise feed-forward network.
        tgt2 = self.FF(tgt)
        tgt = tgt + self.dropout(tgt2)
        tgt = self.postLN(tgt)

        return tgt

# Example usage
batch_size = 1
seq_len = 77
n_features = 512
n_pad = 2  # number of trailing padding tokens in the key/value sequence

# Random key/value sequence whose last `n_pad` tokens are zeroed out as padding.
# Starting from an all-zero tensor would flag EVERY token as padding, leaving
# the cross-attention with no key to attend to.
k = torch.rand(batch_size, seq_len, n_features)
k[0, seq_len - n_pad:, :] = 0
v = k  # keys and values are the same tensor

# Queries share the batch size and feature width of the keys but have their
# own sequence length.
q = torch.rand(batch_size, 1000, n_features)

# Padding mask: True where a key token is padding. This uses the
# create_padding_mask function defined earlier in this file.
src_mask = create_padding_mask(k)

# Build the CrossAttentionBlock instance.
cross_attention_block = CrossAttentionBlock(n_features=n_features)

# Call the block's forward pass.
output = cross_attention_block(query=q, kv=k, src_mask=src_mask)

# The output follows the query's shape, not the key's.
print(output.shape)  # expected: torch.Size([1, 1000, 512])


torch.Size([1, 1000, 512])
