假设输入来自两个不同的来源：一个作为查询（query）的特征张量，以及另一组作为键（key）和值（value）的特征张量。下面展示如何通过CrossAttention模块融合这两类信息。

In [3]:
import torch
from torch import nn
import torch.nn.functional as F
import math


# 定义CrossAttention类
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from ``x``, keys/values from ``k``/``v``.

    Args:
        dim: Feature size of the inputs and of the output projection.
        num_heads: Number of attention heads.
        qkv_bias: If True, learn a bias for the query and value projections;
            the key projection then uses a fixed all-zero bias.
        qk_scale: Override for the attention scale; when None the scale is
            ``head_dim ** -0.5``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        window_size: Accepted for interface compatibility; not used.
        attn_head_dim: Override for the per-head dimension; when None it is
            ``dim // num_heads``.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        # Explicit None check: an `or` here would silently replace a
        # caller-supplied scale of 0.0 with the default.
        self.scale = qk_scale if qk_scale is not None else head_dim ** -0.5

        # Bias-free projections; optional biases are kept as separate
        # parameters so that the key projection can use a fixed zero bias.
        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)

        # The key bias is never learnable; define the attribute in both
        # branches so it always exists.
        self.k_bias = None
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, bool_masked_pos=None, k=None, v=None):
        """Attend from the query input ``x`` to the key/value inputs.

        Args:
            x: Query input of shape (B, N, C), with C equal to ``dim``.
            bool_masked_pos: Accepted for interface compatibility; ignored.
            k: Key input of shape (B, N_k, C). Defaults to ``x``
                (plain self-attention) when None.
            v: Value input of shape (B, N_v, C); N_v must equal N_k for the
                final matmul. Defaults to ``k`` when None.

        Returns:
            Tensor of shape (B, N, dim).
        """
        if k is None:
            k = x
        if v is None:
            v = k

        B, N, _ = x.shape
        N_k = k.shape[1]
        N_v = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias = self.q_bias
            # Keys get a constant zero bias that receives no gradient.
            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            v_bias = self.v_bias

        # Project, then split heads: (B, L, H*d) -> (B, H, L, d).
        q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
        q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)

        k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
        k = k.reshape(B, N_k, self.num_heads, -1).permute(0, 2, 1, 3)

        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
        v = v.reshape(B, N_v, self.num_heads, -1).permute(0, 2, 1, 3)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)      # (B, H, N, N_k)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back: (B, H, N, d) -> (B, N, H*d).
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


# Example hyper-parameters: batch size, feature dim, number of heads
batch_size, dim, num_heads = 2, 64, 4
# Sequence lengths of the query input and of the key/value inputs
seq_len_query, seq_len_key_value = 10, 8

# Random query, key and value tensors (drawn in this order)
query, key, value = (
    torch.rand(batch_size, length, dim)
    for length in (seq_len_query, seq_len_key_value, seq_len_key_value)
)

# Build the cross-attention module and run a single forward pass
cross_attention_module = CrossAttention(dim=dim, num_heads=num_heads)
output = cross_attention_module(query, k=key, v=value)

print("输出结果的形状:", output.shape)

输出结果的形状: torch.Size([2, 10, 64])


首先定义了CrossAttention类：它对查询、键、值分别做线性投影并拆分为多个注意力头，计算缩放点积注意力权重，再用这些权重对值进行加权融合，最后经过输出投影得到结果。

然后设置了一些示例参数，比如批次大小、特征维度、头的数量以及查询和键值对的序列长度等，并随机生成了对应的输入张量（模拟查询、键、值）。

接着实例化了CrossAttention模块，并使用生成的输入张量进行前向传播计算，最后打印出输出结果的形状。