In [1]:
import collections
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class myRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    y = x / sqrt(mean(x ** 2, dim=-1) + eps) * w

    Args:
        eps: constant added inside the square root for numerical stability.
        dim: size of the last input dimension; the learnable gain ``w`` has
            this many elements and starts at ones.
    """

    def __init__(self, eps, dim):
        super().__init__()
        self.eps = eps
        # `nn.ones` does not exist; the gain has to be built with `torch.ones`.
        self.w = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Divide ``x`` by its RMS over the last axis, then scale by ``w``."""
        y = x / (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        return y * self.w

In [None]:
class myReLU(nn.Module):
    """Element-wise ReLU: ``max(x, 0)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Build the zero with x's own dtype and device so torch.maximum sees
        # matching operands. (`x.type` is a method, not a dtype; passing it as
        # `dtype=` raised a TypeError.)
        zero = torch.zeros((), dtype=x.dtype, device=x.device)
        return torch.maximum(x, zero)
    
class mySiLU(nn.Module):
    """SiLU / swish activation: ``sigmoid(x) * x``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.sigmoid(x)
        return gate * x
class mySwiGelu(nn.Module):
    """Gated GELU unit (GeGLU): ``gelu(w1(x)) * w2(x)``.

    Despite the name, the gate uses GELU rather than SiLU/Swish. There is no
    output projection back to ``inputdim``; the result has ``embeddingdim``
    features in its last dimension.

    Args:
        inputdim: size of the last input dimension.
        embeddingdim: size of the last output dimension.
    """

    def __init__(self, inputdim, embeddingdim):
        super().__init__()
        self.w1 = nn.Linear(inputdim, embeddingdim, bias=False)  # gate branch
        self.w2 = nn.Linear(inputdim, embeddingdim, bias=False)  # value branch

    def forward(self, x):
        gate = nn.functional.gelu(self.w1(x))
        value = self.w2(x)
        return gate * value

In [4]:
# Scratch check of RMS normalisation on a random 10-vector.
x = torch.rand(10)
eps = 1e-6
x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

tensor([0.5603, 0.6888, 1.4021, 0.7771, 1.1931, 0.2999, 0.8704, 0.2172, 1.4357,
        1.5042])

In [None]:
class myrnn(nn.Module):
    """Single-step Elman RNN cell with a ReLU non-linearity.

    new_state = relu(ln1(x) + ln2(state)); output = o1(new_state)

    Args:
        vocab_size: feature size of the input ``x`` and of the output.
        embedding_size: size of the hidden state.
    """

    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.ln1 = nn.Linear(vocab_size, embedding_size)      # input  -> hidden
        self.ln2 = nn.Linear(embedding_size, embedding_size)  # hidden -> hidden
        self.o1 = nn.Linear(embedding_size, vocab_size)       # hidden -> output
        self.relu = nn.ReLU()

    def forward(self, x, state):
        """Advance one step and return ``(output, new_state)``."""
        pre_activation = self.ln1(x) + self.ln2(state)
        hidden = self.relu(pre_activation)
        return self.o1(hidden), hidden

In [2]:
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
)

In [30]:
def sequence_mask(X, valid_lens, value=-1e6):
    """Loop-based masking: overwrite, in place, entries past each row's valid length.

    Row ``i`` of ``X`` keeps its first ``valid_lens[i]`` entries along dim 1;
    the remaining entries are set to ``value``.

    Returns:
        ``X`` itself (modified in place).
    """
    for i, row in enumerate(X):
        row[valid_lens[i]:] = value
    return X

In [None]:
def sequence_mask(X, valid_len, value=-1e6):
    """Vectorised, in-place masking of entries past each row's valid length.

    Args:
        X: tensor whose dim 1 is the step axis being masked (any trailing dims
            are masked together).
        valid_len: 1-D tensor with one length per row of ``X``.
        value: fill value for masked entries.

    Returns:
        ``X`` itself (modified in place).
    """
    steps = torch.arange(X.size(1), dtype=torch.float32, device=X.device)
    keep = steps.unsqueeze(0) < valid_len.unsqueeze(1)
    X[keep.logical_not()] = value
    return X

In [31]:
def masked_softmax(X, valid_lens):
    """Softmax over the last axis of ``X`` with positions past a valid length masked.

    Args:
        X: 3-D score tensor, presumably (batch, num_queries, num_keys).
        valid_lens: ``None`` for a plain softmax; a 1-D tensor with one length
            per batch element (shared by all of its query rows); or a 2-D tensor
            with one length per (batch, query) row.

    Returns:
        Tensor shaped like ``X``; masked positions get (near-)zero weight.

    Note: masking goes through ``sequence_mask``, which writes into a reshaped
    view of ``X``, so the caller's tensor may be modified in place.
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        # one length per batch element -> repeat it for each query row
        per_row = torch.repeat_interleave(valid_lens, shape[1])
    else:
        per_row = valid_lens.reshape(-1)
    flat = sequence_mask(X.reshape(-1, shape[-1]), per_row)
    return nn.functional.softmax(flat.reshape(shape), dim=-1)
    

In [46]:
class additive_attention(nn.Module):
    """Additive (Bahdanau-style) attention.

    score(q, k) = w_v . tanh(W_q q + W_k k); the weights come from a masked
    softmax over the keys.

    Args:
        key_size: feature size of the keys.
        query_size: feature size of the queries.
        num_hiddens: size of the shared hidden space.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(additive_attention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.W_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        """
        Args:
            queries: (batch, n_q, query_size).
            keys: (batch, n_k, key_size).
            values: (batch, n_k, v) (``torch.bmm`` requires 3-D).
            valid_lens: forwarded to ``masked_softmax`` (``None`` = no mask).

        Returns:
            (batch, n_q, v). The weights are also kept in ``self.attention``.
        """
        queries, keys = self.W_q(queries), self.W_k(keys)
        # (batch, n_q, 1, h) + (batch, 1, n_k, h) -> (batch, n_q, n_k, h)
        # Without the tanh the score is linear, so the query term is the same
        # for every key and cancels in the softmax: the weights would ignore
        # the query.
        features = torch.tanh(queries.unsqueeze(2) + keys.unsqueeze(1))
        scores = self.W_v(features).squeeze(-1)  # (batch, n_q, n_k)
        self.attention = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention), values)
        

In [155]:
class DotProductAttention(nn.Module):
    """Scaled dot-product attention on batched 3-D tensors.

    weights = masked_softmax(Q K^T / sqrt(d), valid_lens); output = dropout(weights) V.
    The weights of the last call are stored in ``self.attention``.

    Args:
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        scale = math.sqrt(queries.shape[-1])
        scores = torch.bmm(queries, keys.transpose(1, 2)) / scale
        self.attention = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention), values)

In [89]:
def transpose_qkv(X, num_heads):
    """Split the feature axis into heads and fold the heads into the batch axis.

    (batch, steps, num_heads * head_dim) -> (batch * num_heads, steps, head_dim)
    """
    batch, steps = X.shape[0], X.shape[1]
    per_head = X.reshape(batch, steps, num_heads, -1).permute(0, 2, 1, 3)
    return per_head.reshape(-1, steps, per_head.shape[3])
def transpose_output(X, num_heads):
    """Undo the head folding of ``transpose_qkv``.

    (batch * num_heads, steps, head_dim) -> (batch, steps, num_heads * head_dim)
    """
    steps, head_dim = X.shape[1], X.shape[2]
    heads_first = X.reshape(-1, num_heads, steps, head_dim)
    steps_first = heads_first.permute(0, 2, 1, 3)
    return steps_first.reshape(steps_first.shape[0], steps_first.shape[1], -1)

In [96]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention that folds the heads into the batch axis and runs
    ``DotProductAttention`` once for all heads.

    Args:
        key_size, query_size, value_size: input feature sizes.
        num_hiddens: model width (split evenly across heads).
        num_heads: number of heads.
        dropout: dropout probability on the attention weights.
        bias: whether the projections use a bias.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        q = transpose_qkv(self.W_q(queries), self.num_heads)
        k = transpose_qkv(self.W_k(keys), self.num_heads)
        v = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # every head of a batch element shares that element's valid lengths
            valid_lens = valid_lens.repeat_interleave(self.num_heads, dim=0)
        heads = self.attention(q, k, v, valid_lens)
        return self.W_o(transpose_output(heads, self.num_heads))
        

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class dotproductattention(nn.Module):
    """Scaled dot-product attention with an optional additive mask.

    out = dropout(softmax(Q K^T / sqrt(d) + mask)) V

    Args:
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, dropout):
        super().__init__()
        # Previously stored the raw float, which forward then tried to call.
        self.dropout = nn.Dropout(dropout)

    def forward(self, quiry, key, value, mask):
        """
        Args:
            quiry: (..., n_q, d) queries.
            key: (..., n_k, d) keys.
            value: (..., n_k, d_v) values.
            mask: ``None`` or an additive mask broadcastable to (..., n_q, n_k);
                use 0 for kept positions and -inf (or a large negative) for
                masked ones.

        Returns:
            (..., n_q, d_v) tensor.
        """
        d = quiry.shape[-1]
        score = torch.matmul(quiry, key.transpose(-2, -1)) / math.sqrt(d)
        # `if mask:` raises for multi-element tensors; test for presence instead.
        if mask is not None:
            score = score + mask
        return torch.matmul(self.dropout(F.softmax(score, dim=-1)), value)
class multiheadattention(nn.Module):
    """Multi-head attention built on ``dotproductattention``.

    Args:
        nquiry, nkey, nvalue: input feature sizes of queries, keys and values.
        embedding_size: model width; must be divisible by ``nheads``.
        nheads: number of attention heads.
        dropout: dropout probability on the attention weights.
        bias: whether the projections use a bias.

    Raises:
        ValueError: if ``embedding_size`` is not divisible by ``nheads``.
    """

    def __init__(self, nquiry, nkey, nvalue, embedding_size, nheads, dropout, bias=False):
        super().__init__()
        if embedding_size % nheads != 0:
            raise ValueError(
                f"embedding_size ({embedding_size}) must be divisible by nheads ({nheads})"
            )
        self.wq = nn.Linear(nquiry, embedding_size, bias=bias)
        self.wk = nn.Linear(nkey, embedding_size, bias=bias)
        self.wv = nn.Linear(nvalue, embedding_size, bias=bias)
        self.wo = nn.Linear(embedding_size, embedding_size, bias=bias)
        self.attention = dotproductattention(dropout)
        self.nhead = nheads

    def _split_heads(self, x):
        """(batch, steps, embedding_size) -> (batch, nhead, steps, head_dim)."""
        return x.view(x.shape[0], x.shape[1], self.nhead, -1).transpose(1, 2)

    def forward(self, quirys, keys, values, mask):
        """
        Args:
            quirys: (batch, n_q, nquiry); keys: (batch, n_k, nkey);
                values: (batch, n_k, nvalue).
            mask: ``None`` or an additive mask broadcastable to
                (batch, nhead, n_q, n_k).

        Returns:
            (batch, n_q, embedding_size).
        """
        # Heads must sit before the step axis; otherwise the matmul inside the
        # attention would mix heads instead of positions. (The old code also
        # read the non-existent attribute `self.nheads`.)
        quirys = self._split_heads(self.wq(quirys))
        keys = self._split_heads(self.wk(keys))
        values = self._split_heads(self.wv(values))
        out = self.attention(quirys, keys, values, mask)  # (batch, nhead, n_q, head_dim)
        # transpose makes the tensor non-contiguous, so copy before view
        out = out.transpose(1, 2).contiguous()
        out = out.view(out.shape[0], out.shape[1], -1)
        return self.wo(out)
        

In [None]:
class positionencoding(nn.Module):
    """Sinusoidal positional encoding for sequences of up to ``seq_len`` steps.

    P[0, pos, 2i]   = sin(pos / 10000^(2i / numhiddens))
    P[0, pos, 2i+1] = cos(pos / 10000^(2i / numhiddens))

    Args:
        seq_len: maximum sequence length supported.
        numhiddens: feature size of the inputs (odd sizes are supported).
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, seq_len, numhiddens, dropout):
        super().__init__()
        # The previous version did not parse (`2.dtype=`), used undefined names
        # (`seqlen`, `max_len`, `num_hiddens`) and had no forward.
        self.dropout = nn.Dropout(dropout)
        inv_freq = 1.0 / torch.pow(
            10000.0, torch.arange(0, numhiddens, 2, dtype=torch.float32) / numhiddens
        )
        # angles[pos, i] = pos * inv_freq[i]
        angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
        P = torch.zeros((1, seq_len, numhiddens))
        P[0, :, 0::2] = torch.sin(angles)
        # an odd numhiddens has one fewer cosine column than sine column
        P[0, :, 1::2] = torch.cos(angles[:, : numhiddens // 2])
        # buffer: moves with the module on .to()/.cuda() but stays out of state_dict
        self.register_buffer("P", P, persistent=False)

    def forward(self, X):
        """Add the encoding to ``X`` (batch, steps, numhiddens) with
        steps <= seq_len, then apply dropout."""
        X = X + self.P[:, : X.shape[1], :].to(X.device)
        return self.dropout(X)
        

In [99]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    P[0, pos, 2i]   = sin(pos / 10000^(2i / num_hiddens))
    P[0, pos, 2i+1] = cos(pos / 10000^(2i / num_hiddens))

    Args:
        num_hiddens: feature size of the inputs (odd sizes are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence supported.
    """

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.P = torch.zeros((1, max_len, num_hiddens))
        # X[pos, i] = pos / 10000^(2i / num_hiddens)
        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens
        )
        self.P[:, :, 0::2] = torch.sin(X)
        # For odd num_hiddens there is one fewer cosine column than sine column;
        # assigning all of X here used to raise a shape-mismatch error.
        self.P[:, :, 1::2] = torch.cos(X[:, : num_hiddens // 2])

    def forward(self, X):
        """Add the encoding to ``X`` (batch, steps, num_hiddens) with
        steps <= max_len, then apply dropout."""
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

In [100]:
class PositionWiseFFN(nn.Module):
    """Two-layer MLP applied independently at each position:
    ``dense2(relu(dense1(X)))``.

    Args:
        ffn_num_input: input feature size.
        ffn_num_hiddens: hidden width.
        ffn_num_outputs: output feature size.
    """

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        hidden = self.relu(self.dense1(X))
        return self.dense2(hidden)

In [116]:
class AddNorm(nn.Module):
    """Residual connection followed by layer norm: ``LayerNorm(X + dropout(Y))``.

    Args:
        normalized_shape: trailing shape passed to ``nn.LayerNorm``.
        dropout: dropout probability applied to the sub-layer output ``Y``.
    """

    def __init__(self, normalized_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        residual = X + self.dropout(Y)
        return self.ln(residual)

In [123]:
class EncoderBlock(nn.Module):
    """One Transformer encoder layer: self-attention -> add&norm -> FFN -> add&norm."""

    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias=False, **args):
        super().__init__(**args)
        self.attention = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias
        )
        # The FFN maps back to ffn_num_input features for the residual add
        # (presumably ffn_num_input == num_hiddens).
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, ffn_num_input)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        attended = self.attention(X, X, X, valid_lens)
        Y = self.addnorm1(X, attended)
        return self.addnorm2(Y, self.ffn(Y))
        

In [124]:
class TransformerEncoder(nn.Module):
    """Token embedding + sinusoidal positions + a stack of ``EncoderBlock``s.

    After each ``forward`` call, ``attention_weights[i]`` holds the attention
    weights produced by block ``i`` during that call.
    """

    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False, **args):
        super().__init__(**args)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for layer in range(num_layers):
            self.blks.add_module(
                f"block{layer}",
                EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias),
            )

    def forward(self, X, valid_lens, *args):
        # embeddings are scaled by sqrt(num_hiddens) before positions are added
        embedded = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self.pos_encoding(embedded)
        self.attention_weights = [None] * len(self.blks)
        for layer, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            # weights cached by the block's DotProductAttention during this call
            self.attention_weights[layer] = blk.attention.attention.attention
        return X
        

In [132]:
class DecoderBlock(nn.Module):
    """The ``i``-th layer of a Transformer decoder.

    Masked self-attention -> add&norm -> encoder-decoder attention -> add&norm
    -> position-wise FFN -> add&norm.

    ``state`` is indexed as ``[enc_outputs, enc_valid_lens, cache]``, where
    ``cache[i]`` holds every decoder input this block has seen so far (``None``
    before the first call). ``cache[i]`` is updated in place.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, use_bias=False, **args):
        super(DecoderBlock, self).__init__(**args)
        self.i = i  # layer index; selects this block's slot in state[2]
        self.attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, ffn_num_input)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        """Run the block on ``X`` (batch, steps, features); returns ``(output, state)``."""
        enc_outputs, enc_valid_lens = state[0], state[1]
        # Self-attention keys/values: everything this block has seen plus X.
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # causal mask: query step t may attend to key steps 0..t (valid length t + 1)
            dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            # presumably decoding step by step, so key_values only holds the past
            dec_valid_lens = None
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention must skip encoder padding. The old call
        # omitted this required argument (TypeError).
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        # The old code passed the not-yet-defined Z here instead of Y2.
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

In [171]:
class RMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learnable per-feature gain.

    Args:
        dim: size of the last dimension (length of ``weight``).
        eps: constant added to the mean square for numerical stability.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        # normalise in float32, cast back to the input dtype, then apply the gain
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight

In [175]:
def precompute_freq_cis(dim, end, theta=10000.0):
    """Precompute rotary-embedding phasors ``exp(i * t * theta^(-2k / dim))``.

    Args:
        dim: head dimension; one frequency per entry of ``range(0, dim, 2)``.
        end: number of positions ``t = 0 .. end - 1``.
        theta: base of the geometric frequency progression.

    Returns:
        Complex tensor of shape (end, len(range(0, dim, 2))) with unit magnitude.
    """
    freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
    t = torch.arange(end, device=freq.device, dtype=torch.float32)
    freqs = torch.outer(t, freq).float()
    # polar(1, angle) = cos(angle) + i * sin(angle).
    # The local used to be misspelled `freps_cis`, so the return raised NameError.
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

In [176]:
def reshape_for_broadcast(freq_cis, x=None):
    """View ``freq_cis`` (seq_len, d) so it broadcasts against ``x``.

    The result has ``freq_cis``'s two sizes on axis 1 and on the last axis of
    ``x``, and size 1 everywhere else, e.g. (1, seq_len, 1, d) for a 4-D ``x``.

    Args:
        freq_cis: 2-D tensor of rotary phasors.
        x: tensor to broadcast against, with at least 3 dims. Optional for
            backward compatibility; when omitted a 4-D layout is assumed.

    Raises:
        ValueError: if ``x`` has fewer than 3 dims, or its axis 1 / last axis
            do not match ``freq_cis``.
    """
    if x is None:
        ndim = 4
    else:
        ndim = x.ndim
        # Without this check a length mismatch could still broadcast silently
        # (e.g. x with seq_len 1 against many phasor rows).
        if ndim < 3 or tuple(freq_cis.shape) != (x.shape[1], x.shape[-1]):
            raise ValueError(
                f"freq_cis shape {tuple(freq_cis.shape)} does not match x shape {tuple(x.shape)}"
            )
    shape = [1] * ndim
    shape[1] = freq_cis.shape[0]
    shape[-1] = freq_cis.shape[1]
    return freq_cis.view(*shape)

In [177]:
def apply_rotary_emb(xq, xk, freq_cis):
    """Rotate consecutive feature pairs of queries and keys by rotary phasors.

    Args:
        xq: queries, presumably (batch, seq_len, n_heads, head_dim) with an even
            head_dim. TODO confirm against callers.
        xk: keys in the same layout (the head count may differ).
        freq_cis: complex (seq_len, head_dim // 2) phasors, e.g. rows of
            ``precompute_freq_cis``.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as ``xq`` and ``xk``.
    """
    # view_as_complex needs float32/float64 data with a trailing size-2 axis.
    # Upcast so half/bfloat16 inputs work; reshape (not view) tolerates
    # non-contiguous inputs.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # The helper takes the target tensor as its second argument; the old call omitted it.
    freq_cis = reshape_for_broadcast(freq_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freq_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freq_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)

In [179]:
class attention(nn.Module):
    """LLaMA-style attention with grouped key/value heads, rotary position
    embeddings, a key/value cache, and fairscale tensor-parallel projections.

    Args:
        args: config object. Reads ``dim``, ``n_heads``, ``n_kv_heads``
            (``None`` means one KV head per query head), ``max_batch_size``
            and ``max_seq_lens``.
    """

    def __init__(self, args):
        # Was `super().__init___()` (three trailing underscores): AttributeError.
        super().__init__()
        # `fs_init` is not imported anywhere in the visible notebook.
        import fairscale.nn.model_parallel.initialize as fs_init

        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        # heads are divided evenly across model-parallel ranks
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # number of query heads sharing each KV head
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        # Sizes come from the attributes derived above rather than from `args`:
        # `args.n_kv_heads` may be None, and head_dim / n_local_kv_heads are
        # computed here. The identity init_method leaves the weights as
        # allocated; presumably they are loaded from a checkpoint (TODO confirm).
        self.wq = ColumnParallelLinear(
            args.dim, args.n_heads * self.head_dim,
            bias=False, gather_output=False, init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim, self.n_kv_heads * self.head_dim,
            bias=False, gather_output=False, init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim, self.n_kv_heads * self.head_dim,
            bias=False, gather_output=False, init_method=lambda x: x,
        )
        # The output projection consumes head-sharded activations, so it is
        # row-parallel with input_is_parallel=True (not column-parallel).
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim, args.dim,
            bias=False, input_is_parallel=True, init_method=lambda x: x,
        )

        # KV cache: (max_batch_size, max_seq_lens, n_local_kv_heads, head_dim)
        self.cache_k = torch.zeros(
            args.max_batch_size, args.max_seq_lens, self.n_local_kv_heads, self.head_dim
        ).cuda()
        self.cache_v = torch.zeros(
            args.max_batch_size, args.max_seq_lens, self.n_local_kv_heads, self.head_dim
        ).cuda()

    def forward(self, x, start_pos, freq_cis, mask):
        """Attend from ``x`` to the cached keys/values plus ``x`` itself.

        Args:
            x: (bsz, seqlen, dim) input.
            start_pos: cache index where this chunk's keys/values are written.
            freq_cis: rotary phasors for ``apply_rotary_emb``; presumably the
                rows for positions start_pos .. start_pos + seqlen - 1.
            mask: ``None`` or an additive mask broadcastable to
                (bsz, n_local_heads, seqlen, start_pos + seqlen).

        Returns:
            (bsz, seqlen, dim) output of the ``wo`` projection.
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # Pass freq_cis positionally. The old call used an undefined name
        # `freqs_cis`, as both the keyword and the value.
        xq, xk = apply_rotary_emb(xq, xk, freq_cis)

        # match the cache's device/dtype to the activations, then store this chunk
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        if self.n_rep > 1:
            # Repeat each KV head n_rep times so it lines up with its query
            # heads. `repeat_kv` is not defined in the visible notebook;
            # repeat_interleave on the head axis matches the LLaMA reference helper.
            keys = torch.repeat_interleave(keys, self.n_rep, dim=2)
            values = torch.repeat_interleave(values, self.n_rep, dim=2)

        xq = xq.transpose(1, 2)          # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)      # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        # softmax in float32, then back to the activation dtype
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
        

In [None]:
import torch
from torch.autograd import Function

class GradientReversalFunction(Function):
    """Gradient reversal: identity forward, gradient scaled by ``-lambd`` backward."""

    @staticmethod
    def forward(ctx, x, lambd):
        """Return the input unchanged and remember ``lambd`` for the backward pass."""
        ctx.lambd = lambd
        # an alias of the input; autograd attaches this Function as its grad_fn
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Negate the incoming gradient and scale it by ``lambd``.

        Returns the gradient for ``x`` and ``None`` for ``lambd`` (no gradient).
        """
        reversed_grad = grad_output.neg().mul(ctx.lambd)
        return reversed_grad, None

# nn.Module wrapper so gradient reversal can be dropped into a model like any other layer.
class GradientReversalLayer(torch.nn.Module):
    """Applies ``GradientReversalFunction`` with a fixed scale ``lambd``.

    Args:
        lambd: gradient scale; the backward pass sends ``-lambd * grad`` upstream.
    """

    def __init__(self, lambd=1.0):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambd)

# Example: how to use the GRL (gradient reversal layer)
grl = GradientReversalLayer(lambd=1.0)  # set the lambda (gradient scale)
x = torch.randn(5, requires_grad=True)   # input tensor
output = grl(x)  # forward pass: output equals the input

# Simulate a loss and back-propagate
loss = output.sum()
loss.backward()

# d(loss)/d(output) is all ones, so with lambda = 1.0 x.grad should be all -1.
print(x.grad)

In [None]:
import heapq

class BeamSearchNode:
    """One partial hypothesis in the beam."""

    def __init__(self, sequence, score, state):
        self.sequence = sequence  # tokens generated so far
        self.score = score        # accumulated score (higher ranks first)
        self.state = state        # decoder state handed back on the next step

    def __lt__(self, other):
        # rank by score so heapq.nlargest can order nodes
        return self.score < other.score

def beam_search(decoder, start_token, beam_width, max_len, end_token=None):
    """Beam search over a step-wise decoder.

    Args:
        decoder: callable ``decoder(sequence, state) -> (candidates, next_state)``,
            where ``candidates`` is an iterable of ``(token, score)`` pairs.
            Scores are added to the hypothesis score, so higher is better.
        start_token: first token of every hypothesis.
        beam_width: number of hypotheses kept after each step (>= 1).
        max_len: maximum number of expansion steps.
        end_token: optional token marking a finished hypothesis. Finished
            hypotheses are carried forward unchanged instead of expanded.
            ``None`` (the default) disables this.

    Returns:
        The token list of the highest-scoring hypothesis.

    Raises:
        ValueError: if ``beam_width`` < 1 (the beam would be empty).
    """
    if beam_width < 1:
        raise ValueError(f"beam_width must be >= 1, got {beam_width}")

    # initialise the beam with the start node
    beam = [BeamSearchNode(sequence=[start_token], score=0, state=None)]

    for _ in range(max_len):
        new_beam = []
        # expand every node in the beam
        for node in beam:
            if end_token is not None and node.sequence[-1] == end_token:
                new_beam.append(node)  # finished: keep as-is
                continue
            candidates, next_state = decoder(node.sequence, node.state)
            for candidate, score in candidates:
                new_beam.append(BeamSearchNode(
                    sequence=node.sequence + [candidate],
                    score=node.score + score,  # accumulate score
                    state=next_state,
                ))
        if not new_beam:
            # The decoder produced no candidates. Keep the last non-empty beam
            # instead of crashing on max() of an empty list.
            break
        # keep the beam_width best-scoring nodes
        beam = heapq.nlargest(beam_width, new_beam)

    # return the highest-scoring sequence
    best_node = max(beam, key=lambda n: n.score)
    return best_node.sequence

# Example decoder function for the beam-search demo
def example_decoder(sequence, state):
    """Toy decoder: ignores ``sequence`` and always proposes the same three
    ``(token, score)`` candidates, returning ``state`` unchanged.
    A real implementation would run a model here."""
    candidates = [(0, -1.0), (1, -2.0), (2, -0.5)]
    return candidates, state

# Run beam search with the toy decoder. Token 2 always has the best score
# (-0.5), so the best sequence should be [0, 2, 2, 2, 2, 2].
start_token = 0
beam_width = 3
max_len = 5
result = beam_search(example_decoder, start_token, beam_width, max_len)

print("Best sequence:", result)


In [170]:
def bleu(pred_seq, label_seq, k):
    """BLEU-style score of a whitespace-tokenised prediction against one reference.

    score = exp(min(0, 1 - len_label / len_pred)) * prod_{n=1..k} p_n ** (1 / 2**n)

    where p_n is the clipped n-gram precision (each reference n-gram can be
    matched at most as often as it occurs). The 1 / 2**n weights differ from
    the uniform weights of standard BLEU.

    Args:
        pred_seq: predicted sentence, tokens separated by whitespace.
        label_seq: reference sentence, tokens separated by whitespace.
        k: maximum n-gram order.

    Returns:
        Float in [0, 1]. It is 0.0 when the prediction is empty or has fewer
        than ``k`` tokens, because some order then has no n-grams to match.
    """
    # `collections` was never imported in the notebook; import locally so the
    # function does not depend on cell execution order.
    import collections
    import math

    pred_tokens, label_tokens = pred_seq.split(), label_seq.split()
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    if len_pred == 0:
        return 0.0  # avoids ZeroDivisionError in the brevity penalty
    score = math.exp(min(0, 1 - len_label / len_pred))  # brevity penalty
    for n in range(1, k + 1):
        num_ngrams = len_pred - n + 1
        if num_ngrams <= 0:
            return 0.0  # no predicted n-grams of this order: precision is 0
        # reference n-gram counts, consumed as they are matched (clipping)
        label_subs = collections.Counter(
            ' '.join(label_tokens[i:i + n]) for i in range(len_label - n + 1)
        )
        num_matches = 0
        for i in range(num_ngrams):
            ngram = ' '.join(pred_tokens[i:i + n])
            if label_subs[ngram] > 0:
                num_matches += 1
                label_subs[ngram] -= 1
        score *= math.pow(num_matches / num_ngrams, math.pow(0.5, n))
    return score

In [169]:
# repeat_interleave along dim 0 repeats each row in place (r0, r1 -> r0, r0, r1, r1);
# here both rows are 1..9, so the result is four identical rows.
torch.repeat_interleave(torch.arange(1, 10).repeat(2,1), repeats = 2, dim=0)

tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9],
        [1, 2, 3, 4, 5, 6, 7, 8, 9],
        [1, 2, 3, 4, 5, 6, 7, 8, 9],
        [1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [125]:
# Smoke test: a 2-layer encoder on a batch of 2 sequences of 100 tokens;
# expected output shape is (2, 100, 24).
# valid_lens is defined here because this cell previously read a value that
# only existed after running later cells (NameError on Restart & Run All).
valid_lens = torch.tensor([3, 2])
encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape

torch.Size([2, 100, 24])

In [106]:
# AddNorm smoke test: X + Y is constant within each [3, 4] slice, so the
# layer norm output is all zeros.
add_norm = AddNorm([3,4], 0.5)
add_norm.eval()
add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4)))

tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]], grad_fn=<NativeLayerNormBackward0>)

In [101]:
# Position-wise FFN 4 -> 4 -> 8: identical inputs at every position give
# identical output rows for the first batch element.
ffn = PositionWiseFFN(4, 4, 8)
ffn.eval()
ffn(torch.ones((2, 3, 4)))[0]

tensor([[-0.0948,  0.0187,  0.4682, -0.3720, -0.7445,  0.1480,  0.8489,  0.5328],
        [-0.0948,  0.0187,  0.4682, -0.3720, -0.7445,  0.1480,  0.8489,  0.5328],
        [-0.0948,  0.0187,  0.4682, -0.3720, -0.7445,  0.1480,  0.8489,  0.5328]],
       grad_fn=<SelectBackward0>)

In [97]:
# Multi-head attention with width 100 split across 5 heads.
# NOTE(review): this rebinds the global name `attention`, shadowing the
# `attention` class defined earlier in the notebook.
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()

MultiHeadAttention(
  (attention): DotProductAttention(
    (Dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)

In [98]:
# Batch of 2, 4 queries and 6 key/value pairs; valid lengths 3 and 2 per batch
# element. Expected output shape: (batch_size, num_queries, num_hiddens).
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape

torch.Size([2, 4, 100])

In [59]:
# Repeat each row of a 2-D valid_lens once per head (4 here), as
# MultiHeadAttention does. NOTE(review): this rebinds the global valid_lens to
# a 2-D tensor; any later cell that reads valid_lens without redefining it
# will see this value.
valid_lens = torch.tensor([[2,3],[1,4]])
torch.repeat_interleave(valid_lens, repeats = 4, dim=0)

tensor([[2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [1, 4],
        [1, 4],
        [1, 4],
        [1, 4]])

In [54]:
# Scaled dot-product attention on toy data. keys, values and valid_lens are
# defined here; this cell used to rely on them being left over from a later
# cell. All keys are equal, so each output is the mean of the first
# valid_lens[b] value rows.
keys = torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])
queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens)

tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]])

In [50]:
# Toy inputs for additive attention: queries (2, 1, 20), all-ones keys
# (2, 10, 2), values 0..39 laid out as (10, 4) for each batch element, and
# valid lengths 2 and 6.
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
 # minibatch of values; the two value matrices are identical
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
 2, 1, 1)
valid_lens = torch.tensor([2, 6])
attention = additive_attention(key_size=2, query_size=20, num_hiddens=8,
dropout=0.1)

In [51]:
# All keys are identical, so attention is uniform over the valid positions:
# each output row is the mean of the first 2 (resp. 6) value rows.
attention.eval()
attention(queries, keys, values, valid_lens)

tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)