In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected to queries, keys and values by three separate
    linear layers; every position then attends to every position of the
    same sequence.

    Args:
        embed_size: Size of the last (feature) dimension of the input. It is
            also the output size of all three projections.
    """

    def __init__(self, embed_size):
        super().__init__()
        self.embed_size = embed_size

        # Learned projections (created in query, key, value order).
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor whose last dimension is ``embed_size``; the notebook
                feeds it as (batch_size, seq_length, embed_size).

        Returns:
            Tensor with the same shape as ``x``.
        """
        queries, keys, values = self.query(x), self.key(x), self.value(x)

        # Dot-product similarities, divided by sqrt(embed_size).
        scale = self.embed_size ** 0.5
        scores = (queries @ keys.transpose(-2, -1)) / scale

        # Normalise over the key axis so each query's weights sum to 1.
        weights = scores.softmax(dim=-1)
        return weights @ values


In [2]:
# Smoke-test the self-attention layer
batch_size = 5
seq_len = 10
embed_size = 128

# Random input batch of shape (batch_size, seq_len, embed_size)
x = torch.randn(batch_size, seq_len, embed_size)

# Instantiate the self-attention layer
self_attention = SelfAttention(embed_size=embed_size)

output = self_attention(x)

# Self-attention must preserve the input shape; fail loudly if it does not
# instead of relying on someone eyeballing the printed shapes.
assert output.shape == x.shape, f'unexpected output shape {tuple(output.shape)}'

print('输入形状：', x.shape)
print('输出形状:', output.shape)

输入形状： torch.Size([5, 10, 128])
输出形状: torch.Size([5, 10, 128])


In [3]:
# Multi-head attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into
    ``num_heads`` heads of size ``embed_size // num_heads``, attended
    independently per head, concatenated and passed through a final linear
    layer.

    Args:
        embed_size: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``embed_size``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_size`` (the head split in ``forward`` would otherwise
            fail with an opaque reshape error).
    """

    def __init__(self, embed_size, num_heads):
        super(MultiHeadAttention, self).__init__()
        if num_heads <= 0 or embed_size % num_heads != 0:
            raise ValueError(
                f'embed_size ({embed_size}) must be divisible by num_heads ({num_heads})'
            )
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        # Input projections
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

        # Output projection applied after the heads are concatenated
        self.combine_heads = nn.Linear(embed_size, embed_size)

    def _split_heads(self, t):
        """Reshape (batch, length, embed_size) -> (batch, num_heads, length, head_dim)."""
        batch_size, length, _ = t.shape
        return t.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: Query input of shape (batch, q_len, embed_size).
            k: Key input of shape (batch, kv_len, embed_size).
            v: Value input of shape (batch, kv_len, embed_size).
                ``kv_len`` may differ from ``q_len`` (cross-attention).

        Returns:
            Tensor of shape (batch, q_len, embed_size).
        """
        batch_size, q_len, _ = q.shape

        # Each tensor is split using its OWN sequence length, so keys/values
        # are not forced to have the query's length.
        q = self._split_heads(self.query(q))
        k = self._split_heads(self.key(k))
        v = self._split_heads(self.value(v))

        # Attention weights: softmax over the key axis of the scaled scores
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_output = torch.matmul(attention_weights, v)  # attention_weights @ v

        # Merge heads: (batch, heads, q_len, head_dim) -> (batch, q_len, embed_size)
        attention_output = (
            attention_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.embed_size)
        )
        return self.combine_heads(attention_output)


In [4]:
# Smoke-test multi-head attention
batch_size = 5
seq_len = 10
embed_size = 128
num_heads = 4

# Random input batch of shape (batch_size, seq_len, embed_size)
x = torch.randn(batch_size, seq_len, embed_size)

multi_head_attention = MultiHeadAttention(embed_size=embed_size, num_heads=num_heads)

# Self-attention usage: the same tensor serves as query, key and value
output = multi_head_attention(x, x, x)

# The layer must preserve the input shape; fail loudly if it does not
# instead of relying on someone eyeballing the printed shapes.
assert output.shape == x.shape, f'unexpected output shape {tuple(output.shape)}'

print('输入形状：', x.shape)
print('输出形状:', output.shape)


输入形状： torch.Size([5, 10, 128])
输出形状: torch.Size([5, 10, 128])


In [None]:
# Positional encoding
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    A table ``pe`` of shape (1, max_len, embed_size) is precomputed with
    ``pe[pos, 2i] = sin(pos / 10000^(2i/d))`` and
    ``pe[pos, 2i+1] = cos(pos / 10000^(2i/d))`` and stored as a buffer
    (not a parameter).

    Args:
        embed_size: Feature size ``d`` of the inputs; may be odd.
        dropout: Dropout probability applied to the encoded output.
        max_len: Number of positions to precompute.
    """

    def __init__(self, embed_size, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Build the encoding table
        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Inverse frequencies 1 / 10000^(2i/d), computed via exp/log
        div_term = torch.exp(
            torch.arange(0, embed_size, 2).float() * (-torch.log(torch.tensor(10000.0)) / embed_size)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # sin(pos / 10000^(2i/d))
        # With an odd embed_size there is one fewer odd column than even
        # column, so drop the surplus frequency to keep the shapes aligned.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])  # cos(pos / 10000^(2i/d))
        pe = pe.unsqueeze(0)  # [1, max_len, embed_size]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to ``x`` and apply dropout.

        Args:
            x: Tensor of shape (batch, seq_len, embed_size); ``seq_len`` must
                broadcast against the first ``seq_len`` rows of the table.

        Returns:
            Tensor with the same shape as ``x``. Dropout is only active in
            training mode.
        """
        # The dropout layer configured in __init__ is applied here; without
        # this call the ``dropout`` argument would have no effect.
        return self.dropout(x + self.pe[:, :x.size(1), :])

# Smoke-test the positional encoding
embed_size = 128
dropout = 0.1
max_len = 100
batch_size = 2
seq_len = 10

# Random input batch of shape (batch_size, seq_len, embed_size)
x = torch.randn(batch_size, seq_len, embed_size)

positional_encoding = PositionalEncoding(embed_size=embed_size, dropout=dropout, max_len=max_len)

output = positional_encoding(x)

# Adding positional encodings must preserve the input shape; fail loudly if
# it does not instead of relying on someone eyeballing the printed shapes.
assert output.shape == x.shape, f'unexpected output shape {tuple(output.shape)}'

print('输入形状：', x.shape)
print('输出形状:', output.shape)


输入形状： torch.Size([2, 10, 128])
输出形状: torch.Size([2, 10, 128])


In [12]:
# 旋转位置编码代码实现(RoPE)
from typing import Tuple
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """Precompute the complex rotation factors used by rotary embeddings.

    Feature dimensions are handled in pairs; pair ``j`` gets the base
    frequency ``theta ** (-2j / dim)``, and position ``m`` rotates that pair
    by the angle ``m * frequency``.

    Args:
        dim: Feature dimension being rotated (``dim // 2`` pairs).
        seq_len: Number of positions to generate, starting at 0.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape (seq_len, dim // 2) holding unit-magnitude
        values ``exp(i * m * frequency_j)``.
    """
    # Per-pair base frequency: 1 / theta^(2j / dim) for j = 0 .. dim//2 - 1
    pair_exponents = torch.arange(0, dim, 2)[: dim // 2].float() / dim
    inv_freq = 1.0 / (theta ** pair_exponents)

    # Token positions 0, 1, ..., seq_len - 1
    positions = torch.arange(seq_len, device=inv_freq.device)

    # Angle table: angles[m, j] = m * inv_freq[j]
    angles = torch.outer(positions, inv_freq).float()

    # Polar form with magnitude 1, i.e. cos(angle) + i * sin(angle)
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)

def apply_rotary_emb(
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to query and key tensors.

    Adjacent feature pairs ``(x[2j], x[2j+1])`` are treated as one complex
    number and multiplied by the matching entry of ``freqs_cis``, which
    rotates the pair by a position-dependent angle.

    Args:
        xq: Query tensor of shape (batch, seq_len, dim) or, for multi-head
            layouts, (batch, seq_len, ..., dim). The last dimension must be
            even.
        xk: Key tensor laid out like ``xq``.
        freqs_cis: Complex rotation factors. A 2-D tensor of shape
            (seq_len, dim // 2) is broadcast over the batch axis and over any
            axes between the sequence axis and the feature axis. Tensors of
            any other rank are used as-is and must broadcast on their own.

    Returns:
        ``(xq_out, xk_out)``, each with the shape and dtype of its input.
    """

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # Group the last axis into (real, imag) pairs and view as complex:
        # (..., dim) -> (..., dim // 2, 2) -> complex (..., dim // 2)
        x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))

        freqs = freqs_cis
        if freqs.ndim == 2 and x_.ndim > 3:
            # Multi-head layout: place seq_len on axis 1 and dim // 2 on the
            # last axis, with singleton axes elsewhere, so the same rotation
            # is applied to every head instead of failing to broadcast.
            shape = [1] * x_.ndim
            shape[1], shape[-1] = freqs.shape
            freqs = freqs.reshape(shape)

        # Rotate, go back to real pairs and merge the trailing (dim // 2, 2)
        # axes so the result has the same shape as the input.
        return torch.view_as_real(x_ * freqs).flatten(-2).type_as(x)

    return _rotate(xq), _rotate(xk)
    
