In [28]:
import math

import torch
from torch import nn, einsum
from torch.nn import Module
import torch.nn.functional as F

import time
from einops import rearrange, repeat, pack, unpack


In [29]:
# 示例输入
batch_size = 2
seq_len = 5
d_k = 4

Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)

In [37]:
def insert_zero_rows(tensor, lengths, target_lengths):
    assert len(lengths) == len(target_lengths), "Lengths and target lengths must be of the same length."
    
    # 计算每个位置需要插入的零行数
    zero_rows = [target - length for length, target in zip(lengths, target_lengths)]
    
    # 初始化结果列表
    parts = []
    mask_parts = []
    start = 0
    
    for i, length in enumerate(lengths):
        end = start + length
        
        # 原始张量部分
        parts.append(tensor[:, start:end, :])
        mask_parts.append(torch.ones(tensor.size(0), length, dtype=torch.bool, device=tensor.device))
        
        # 插入零行
        if zero_rows[i] > 0:
            zero_padding = torch.zeros(tensor.size(0), zero_rows[i], tensor.size(2), device=tensor.device)
            mask_padding = torch.zeros(tensor.size(0), zero_rows[i], dtype=torch.bool, device=tensor.device)
            parts.append(zero_padding)
            mask_parts.append(mask_padding)
        
        start = end
    
    # 拼接所有部分
    padded_tensor = torch.cat(parts, dim=1)
    mask = torch.cat(mask_parts, dim=1)
    
    return padded_tensor, mask


def round_up_to_nearest_k(lst, k):
    return [(x + k - 1) // k * k for x in lst]

# 示例输入
tensor = torch.randn(7, 181 + 9 + 10, 128)
lengths = [21, 9, 10]
target_lengths = round_up_to_nearest_k(lengths, 10)

# 调用函数
start = time.time()
padded_tensor, mask = insert_zero_rows(tensor, lengths, target_lengths)
end = time.time()

print(f"Time taken: {end - start:.6f}")
print("Padded Tensor shape:", padded_tensor.shape)
print("Mask shape:", mask.shape)



Time taken: 0.000192
Padded Tensor shape: torch.Size([7, 50, 128])
Mask shape: torch.Size([7, 50])


In [16]:
import torch
import torch.nn.functional as F

def self_attention(Q, K, mask=None):
    # Q, K, V: [batch_size, seq_len, d_k]
    d_k = Q.size(-1)
    
    # 计算注意力得分
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    
    # 应用掩码
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # 计算注意力权重
    attn_weights = F.softmax(scores, dim=-1)
    
    return attn_weights

# 填充掩码示例
padding_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 1]]).unsqueeze(1).unsqueeze(2)  # [batch_size, 1, 1, seq_len]


# 应用填充掩码
attn_weights = self_attention(Q, K, mask=padding_mask)
print("Output with padding mask:", attn_weights)


Output with padding mask: tensor([[[[0.4903, 0.3744, 0.1353, 0.0000, 0.0000],
          [0.1181, 0.3770, 0.5049, 0.0000, 0.0000],
          [0.0804, 0.5052, 0.4144, 0.0000, 0.0000],
          [0.0820, 0.4937, 0.4243, 0.0000, 0.0000],
          [0.0754, 0.6691, 0.2555, 0.0000, 0.0000]],

         [[0.0193, 0.9334, 0.0473, 0.0000, 0.0000],
          [0.1245, 0.4938, 0.3817, 0.0000, 0.0000],
          [0.4794, 0.1304, 0.3902, 0.0000, 0.0000],
          [0.3361, 0.4615, 0.2024, 0.0000, 0.0000],
          [0.3814, 0.0232, 0.5954, 0.0000, 0.0000]]],


        [[[0.4670, 0.3567, 0.0000, 0.0000, 0.1763],
          [0.0617, 0.1968, 0.0000, 0.0000, 0.7416],
          [0.0243, 0.1528, 0.0000, 0.0000, 0.8229],
          [0.0832, 0.5011, 0.0000, 0.0000, 0.4158],
          [0.0649, 0.5762, 0.0000, 0.0000, 0.3589]],

         [[0.0186, 0.8987, 0.0000, 0.0000, 0.0827],
          [0.1487, 0.5898, 0.0000, 0.0000, 0.2615],
          [0.5016, 0.1364, 0.0000, 0.0000, 0.3620],
          [0.2452, 0.3366, 0.0