In [90]:
import torch
import torch.nn as nn
import math
import numpy as np

In [91]:
def get_len_mask(b: int, max_len: int, feat_lens: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Build the key-padding mask for encoder self-attention.

    Args:
        b: batch size.
        max_len: padded sequence length (used for both the query and key axes).
        feat_lens: valid (non-negative) length of each sequence; only the first ``b``
            entries are used.
        device: device on which to allocate the mask.

    Returns:
        Bool tensor of shape (b, max_len, max_len). ``mask[i, q, k]`` is True iff key
        position ``k`` is padding for sample ``i`` (``k >= feat_lens[i]``); every query
        row of a sample shares the same key mask.
    """
    # Vectorized comparison instead of a per-sample Python loop: avoids one
    # host<->device sync per sample when feat_lens lives on the GPU.
    lens = torch.as_tensor(feat_lens, device=device)[:b]
    key_pos = torch.arange(max_len, device=device)
    key_is_pad = key_pos.unsqueeze(0) >= lens.unsqueeze(1)          # (b, max_len)
    # Broadcast over the query axis and materialize a contiguous tensor, as before.
    return key_is_pad.unsqueeze(1).expand(b, max_len, max_len).contiguous()
def get_subsequent_mask(b: int, max_len: int, device: torch.device) -> torch.Tensor:
    """Causal (look-ahead) mask for decoder self-attention.

    Args:
        b: batch size.
        max_len: length of the whole sequence.
        device: cuda or cpu.

    Returns:
        Bool tensor of shape (b, max_len, max_len) that is True strictly above the
        diagonal, i.e. where a query position would attend to a future key.
    """
    future = torch.ones((max_len, max_len), device=device).triu(diagonal=1).bool()
    # Same pattern for every sample in the batch.
    return future.unsqueeze(0).repeat(b, 1, 1)
def get_enc_dec_mask(
    b: int, max_feat_len: int, feat_lens: torch.Tensor, max_label_len: int, device: torch.device
) -> torch.Tensor:
    """Build the cross-attention (decoder queries -> encoder keys) padding mask.

    Args:
        b: batch size.
        max_feat_len: padded encoder length (key axis).
        feat_lens: valid (non-negative) encoder length of each sample; only the first
            ``b`` entries are used.
        max_label_len: padded decoder length (query axis).
        device: device on which to allocate the mask.

    Returns:
        Bool tensor of shape (b, max_label_len, max_feat_len). ``mask[i, q, k]`` is True
        iff encoder position ``k`` is padding for sample ``i`` (``k >= feat_lens[i]``).
    """
    # Vectorized instead of a per-sample loop, which needs a host sync per sample
    # when feat_lens is on the GPU.
    lens = torch.as_tensor(feat_lens, device=device)[:b]
    enc_pos = torch.arange(max_feat_len, device=device)
    enc_is_pad = enc_pos.unsqueeze(0) >= lens.unsqueeze(1)          # (b, max_feat_len)
    # Every decoder query of a sample sees the same encoder padding.
    return enc_is_pad.unsqueeze(1).expand(b, max_label_len, max_feat_len).contiguous()

In [92]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects query/key/value into ``n_heads`` heads (``d_k`` dims per head for queries
    and keys, ``d_v`` for values), attends independently per head, concatenates the
    heads and projects the result back to ``d_model``.

    Args:
        d_k: per-head dimension of the query/key projections.
        d_v: per-head dimension of the value projection.
        d_model: input and output feature dimension.
        n_heads: number of attention heads.
    """

    def __init__(self, d_k, d_v, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, n_heads * d_k)
        self.W_K = nn.Linear(d_model, n_heads * d_k)
        self.W_V = nn.Linear(d_model, n_heads * d_v)
        self.output = nn.Linear(n_heads * d_v, d_model)

    def forward(self, query, key, value, attn_mask):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: (batch, q_len, d_model).
            key: (batch, k_len, d_model).
            value: (batch, k_len, d_model); must have the same length as ``key``.
            attn_mask: tensor of shape (batch, q_len, k_len) or None. Non-zero/True
                entries mark key positions that the query must NOT attend to.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size = query.size(0)
        q_len = query.size(1)
        k_len = key.size(1)
        n_heads, d_k, d_v = self.n_heads, self.d_k, self.d_v

        # (batch, len, n_heads * d) -> (batch, n_heads, len, d)
        q = self.W_Q(query).reshape(batch_size, -1, n_heads, d_k).transpose(1, 2)
        k = self.W_K(key).reshape(batch_size, -1, n_heads, d_k).transpose(1, 2)
        v = self.W_V(value).reshape(batch_size, -1, n_heads, d_v).transpose(1, 2)

        if attn_mask is not None:
            assert attn_mask.size() == (batch_size, q_len, k_len)
            # Add a singleton head axis; masked_fill_ broadcasts it over all heads
            # instead of materializing n_heads copies of the mask.
            attn_mask = attn_mask.unsqueeze(1).bool()

        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)

        if attn_mask is not None:
            # Large finite negative rather than -inf: a fully-masked row then gets a
            # uniform softmax instead of NaN.
            scores.masked_fill_(attn_mask, -1e4)

        attns = torch.softmax(scores, dim=-1)
        context = torch.matmul(attns, v)
        # (batch, n_heads, q_len, d_v) -> (batch, q_len, n_heads * d_v)
        context = context.transpose(1, 2).contiguous().reshape(batch_size, -1, d_v * n_heads)
        return self.output(context)

In [93]:
torch.manual_seed(0)  # reproducible weights and inputs for this demo

d_model = 512  # model (embedding) dimension used throughout the demo

MHA = MultiHeadAttention(
    d_k=64,           # per-head dimension of queries/keys
    d_v=64,           # per-head dimension of values
    d_model=d_model,  # total model dimension
    n_heads=8,        # 8 attention heads
)

# Usage example
batch_size = 32
seq_length = 100
query = torch.randn(batch_size, seq_length, d_model)
key = torch.randn(batch_size, seq_length, d_model)
value = torch.randn(batch_size, seq_length, d_model)

# Attention mask (optional): True marks positions to hide; all-False hides nothing.
attn_mask = torch.zeros(batch_size, seq_length, seq_length).bool()

# Forward pass
output = MHA(query, key, value, attn_mask)
assert output.shape == (batch_size, seq_length, d_model)
print(query)
print(output)

tensor([[[-4.2775e-01,  1.1009e+00, -5.5461e-01,  ...,  7.1175e-01,
          -1.2172e+00, -4.7280e-01],
         [-7.3395e-02,  1.7238e+00, -1.5250e+00,  ...,  9.3528e-01,
           2.0412e-01,  3.5094e-01],
         [ 1.4875e+00,  6.2181e-02, -1.9177e-01,  ...,  2.3646e-01,
           1.6656e-01, -7.1866e-01],
         ...,
         [-6.5702e-02, -4.2091e-01,  9.0087e-01,  ...,  4.3691e-01,
           3.1007e-01, -6.0949e-01],
         [ 5.8539e-02, -8.0803e-01,  7.5208e-01,  ..., -3.4936e-01,
           7.7019e-01,  4.4988e-01],
         [-6.9764e-02,  3.9726e-02, -4.1676e-02,  ...,  2.1024e+00,
           3.1992e-01,  1.8180e+00]],

        [[ 8.6516e-01,  1.8501e-01,  7.2429e-01,  ..., -1.0912e+00,
          -8.1465e-01,  1.4580e+00],
         [-4.2022e-01,  2.3325e-01, -5.5296e-01,  ...,  1.9702e-01,
           1.6790e+00,  6.5472e-01],
         [ 1.2830e+00, -1.4826e+00, -2.6093e-01,  ..., -3.1383e-02,
           6.1435e-01, -7.1637e-01],
         ...,
         [-9.5878e-01, -4

In [94]:
def position_embedding(src_len, d_model):
    """Sinusoidal positional-encoding table.

    Column ``2k`` holds ``sin(pos / 10000**(2k / d_model))`` and column ``2k + 1`` holds
    the matching cosine of the same angle (omitted when ``d_model`` is odd and the
    pair would run past the last column).

    Args:
        src_len: number of positions (rows).
        d_model: encoding dimension (columns).

    Returns:
        float32 tensor of shape (src_len, d_model).
    """
    table = torch.zeros((src_len, d_model))
    positions = torch.arange(0, src_len)
    for even_col in range(0, d_model, 2):
        # The sin/cos pair of a column pair share one frequency, so compute it once.
        angles = positions / np.power(10000, even_col / d_model)
        table[:, even_col] = torch.sin(angles)
        if even_col + 1 < d_model:
            table[:, even_col + 1] = torch.cos(angles)
    return table.float()

In [95]:
class PoswiseFFN(nn.Module):
    """Position-wise feed-forward network: d_model -> d_ff -> ReLU -> d_model.

    Implemented with kernel-size-1 1-D convolutions, which apply the same linear map
    independently at every sequence position.

    Args:
        d_model: input and output feature dimension.
        d_ff: hidden (inner) dimension.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        # kernel_size=1, stride=1, padding=0
        self.conv1 = nn.Conv1d(d_model, d_ff, 1, 1, 0)
        self.conv2 = nn.Conv1d(d_ff, d_model, 1, 1, 0)
        self.relu = nn.ReLU()

    def forward(self, x):
        """x: (batch, seq_len, d_model) -> (batch, seq_len, d_model)."""
        # Conv1d expects channels on axis 1: (batch, d_model, seq_len).
        channels_first = x.permute(0, 2, 1)
        hidden = self.relu(self.conv1(channels_first))
        return self.conv2(hidden).permute(0, 2, 1)

In [96]:
class EncoderLayer(nn.Module):
    """One post-norm encoder layer: self-attention, then a position-wise FFN, each
    followed by a residual add and LayerNorm."""

    def __init__(self, d_model, n_heads, dff):
        """
        Args:
            d_model: input dimension
            n_heads: number of attention heads
            dff: inner dimension of the PosFFN (position-wise feed-forward network)
        """
        super().__init__()
        head_dim = d_model // n_heads
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.multi_head_attn = MultiHeadAttention(head_dim, head_dim, d_model, n_heads)
        self.poswise_ffn = PoswiseFFN(d_model, dff)

    def forward(self, enc_in, attn_mask):
        """Apply the layer; ``attn_mask`` is forwarded unchanged to the self-attention."""
        attended = self.norm1(enc_in + self.multi_head_attn(enc_in, enc_in, enc_in, attn_mask))
        return self.norm2(attended + self.poswise_ffn(attended))

In [97]:
class Encoder(nn.Module):
    """Stack of encoder layers on top of fixed sinusoidal positional encodings."""

    def __init__(self, num_layers, enc_dim, n_heads, dff, tgt_len):
        """
        Args:
            num_layers: number of encoder layers
            enc_dim: input (model) dimension of the encoder
            n_heads: number of attention heads
            dff: inner dimension of the PosFFN
            tgt_len: maximum sequence length covered by the positional-encoding table
        """
        super().__init__()
        self.tgt_len = tgt_len
        pe_table = position_embedding(tgt_len, enc_dim)
        self.pos_emb = nn.Embedding.from_pretrained(pe_table, freeze=True)
        self.layers = nn.ModuleList(EncoderLayer(enc_dim, n_heads, dff) for _ in range(num_layers))

    def forward(self, x, x_lens, mask=None):
        """Encode ``x``.

        Args:
            x: 3-D input (batch, seq_len, enc_dim); seq_len must not exceed ``tgt_len``.
            x_lens: accepted for interface compatibility; not used here.
            mask: attention mask passed unchanged to every layer.
        """
        _, seq_len, _ = x.shape
        positions = torch.arange(seq_len, device=x.device)
        hidden = x + self.pos_emb(positions)
        for enc_layer in self.layers:
            hidden = enc_layer(hidden, mask)
        return hidden

In [98]:
class DecoderLayer(nn.Module):
    """One post-norm decoder layer: self-attention, cross-attention over the encoder
    output, then a position-wise FFN; each sub-layer is followed by a residual add and
    LayerNorm."""

    def __init__(self, d_model, n_heads, dff):
        """
        Args:
            d_model: input dimension
            n_heads: number of attention heads
            dff: inner dimension of the PosFFN (position-wise feed-forward network)
        """
        super().__init__()
        head_dim = d_model // n_heads
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.poswise_ffn = PoswiseFFN(d_model, dff)

        self.dec_attn = MultiHeadAttention(head_dim, head_dim, d_model, n_heads)
        self.enc_dec_attn = MultiHeadAttention(head_dim, head_dim, d_model, n_heads)

    def forward(self, dec_in, enc_out, dec_mask, dec_enc_mask):
        # 1) self-attention over the decoder inputs, masked by dec_mask
        h = self.norm1(dec_in + self.dec_attn(dec_in, dec_in, dec_in, dec_mask))
        # 2) cross-attention: decoder states query the encoder output
        h = self.norm2(h + self.enc_dec_attn(h, enc_out, enc_out, dec_enc_mask))
        # 3) position-wise feed-forward
        return self.norm3(h + self.poswise_ffn(h))


In [99]:
class Decoder(nn.Module):
    """Token embedding + fixed sinusoidal positional encoding, followed by a stack of
    decoder layers."""

    def __init__(self, num_layers, dec_dim, n_heads, dff, tgt_len, tgt_vocab_size):
        """
        Args:
            num_layers: number of decoder layers
            dec_dim: model dimension of the decoder
            n_heads: number of attention heads
            dff: inner dimension of the PosFFN
            tgt_len: maximum target length covered by the positional-encoding table
            tgt_vocab_size: size of the target vocabulary
        """
        super().__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, dec_dim)
        self.pos_emb = nn.Embedding.from_pretrained(position_embedding(tgt_len, dec_dim), freeze=True)
        self.layers = nn.ModuleList(DecoderLayer(dec_dim, n_heads, dff) for _ in range(num_layers))

    def forward(self, labels, enc_out, dec_mask, dec_enc_mask):
        """Decode token ids ``labels`` (batch, label_len) against ``enc_out``."""
        positions = torch.arange(labels.size(1), device=labels.device)
        dec_out = self.tgt_emb(labels) + self.pos_emb(positions)
        for dec_layer in self.layers:
            dec_out = dec_layer(dec_out, enc_out, dec_mask, dec_enc_mask)
        return dec_out

In [100]:
class Transformer(nn.Module):
    """Encoder-decoder model: frontend -> encoder -> decoder -> linear vocabulary head.

    Args:
        frontend: module applied to the raw input ``X`` (feature extractor).
        encoder: called as ``encoder(features, X_lens, enc_mask)``.
        decoder: called as ``decoder(labels, enc_out, dec_mask, dec_enc_mask)``.
        dec_out_dim: feature dimension of the decoder output.
        vocab: output vocabulary size (last dimension of the logits).
    """

    def __init__(
            self, frontend: nn.Module, encoder: nn.Module, decoder: nn.Module,
            dec_out_dim: int, vocab: int,
    ) -> None:
        super().__init__()
        self.frontend = frontend
        self.encoder = encoder
        self.decoder = decoder
        self.linear = nn.Linear(dec_out_dim, vocab)

    def forward(self, X: torch.Tensor, X_lens: torch.Tensor, labels: torch.Tensor):
        """Return vocabulary logits for each position of ``labels``."""
        X_lens = X_lens.long()
        labels = labels.long()
        batch = X.size(0)
        dev = X.device

        feats = self.frontend(X)
        # Lengths are read after the frontend because it may change the time axis.
        # NOTE(review): X_lens is passed through unchanged, so a subsampling frontend
        # would also need X_lens rescaled — TODO confirm for non-trivial frontends.
        feat_len = feats.size(1)
        label_len = labels.size(1)

        enc_mask = get_len_mask(batch, feat_len, X_lens, dev)
        enc_out = self.encoder(feats, X_lens, enc_mask)

        dec_mask = get_subsequent_mask(batch, label_len, dev)
        cross_mask = get_enc_dec_mask(batch, feat_len, X_lens, label_len, dev)
        dec_out = self.decoder(labels, enc_out, dec_mask, cross_mask)
        return self.linear(dec_out)

In [101]:

torch.manual_seed(0)  # reproducible dummy data and weights

# constants
batch_size = 16                 # batch size
max_feat_len = 100              # the maximum length of the input sequence
max_label_len = 50              # the maximum length of the output sequence
fbank_dim = 80                  # the dimension of the input features
hidden_dim = 512                # the model / hidden dimension
vocab_size = 26                 # the size of the vocabulary

# dummy data
fbank_feature = torch.randn(batch_size, max_feat_len, fbank_dim)        # input sequence
# torch.randint's upper bound is exclusive; +1 lets a sequence fill the whole padded length.
feat_lens = torch.randint(1, max_feat_len + 1, (batch_size,))           # valid length of each input sequence
labels = torch.randint(0, vocab_size, (batch_size, max_label_len))      # output sequence
label_lens = torch.randint(1, max_label_len + 1, (batch_size,))         # valid length of each output sequence (unused by this check)

# model
feature_extractor = nn.Linear(fbank_dim, hidden_dim)                    # a linear layer standing in for the audio feature extractor
encoder = Encoder(
    num_layers=6, enc_dim=hidden_dim, n_heads=8, dff=2048, tgt_len=2048
)
decoder = Decoder(
    num_layers=6, dec_dim=hidden_dim, n_heads=8, dff=2048, tgt_len=2048, tgt_vocab_size=vocab_size
)
transformer = Transformer(feature_extractor, encoder, decoder, hidden_dim, vocab_size)

# forward check
logits = transformer(fbank_feature, feat_lens, labels)
print(f"logits: {logits.shape}")     # (batch_size, max_label_len, vocab_size)
assert logits.shape == (batch_size, max_label_len, vocab_size)

# expected output:
# logits: torch.Size([16, 50, 26])

logits: torch.Size([16, 50, 26])
