In [2]:
import torch
from torch import nn 
import torch.functional as F
import math

In [2]:
# Seed the global RNG so the toy input below (and the randomly initialised
# module weights created in later cells) are identical on every
# Restart & Run All.
torch.manual_seed(0)
# Toy input batch: 16 sequences x 64 tokens x 512-dim embeddings.
X = torch.randn(16, 64, 512)
print(X.shape)

torch.Size([16, 64, 512])


In [4]:
# Model width and number of attention heads; the attention cell splits
# d_model evenly across heads, so d_model must be divisible by n_head.
d_model, n_head = 512, 8

In [4]:
# Multi-head causal self-attention: __init__ builds the projections,
# forward splits into heads, applies masked scaled dot-product attention,
# and merges the heads back.
class multi_head_attention(nn.Module):
    """Multi-head scaled dot-product attention with a causal (lower-triangular) mask.

    Args:
        d_model: model / embedding width. Must be divisible by ``n_head``.
        n_head: number of attention heads. Each head works on
            ``d_model // n_head`` features.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head):
        super(multi_head_attention, self).__init__()

        # Fail early with a clear message instead of an opaque `view` error in forward.
        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )

        self.n_head = n_head
        self.d_model = d_model
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v):
        """Attend from ``q`` to ``k``/``v`` under a causal mask.

        Args:
            q, k, v: tensors of shape (B, T, d_model). All three must share
                B and T, because every input is reshaped with the query's B and T.

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, _ = q.shape
        n_d = self.d_model // self.n_head
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        # Split heads: (B, T, d_model) -> (B, n_head, T, n_d).
        q = q.view(B, T, self.n_head, n_d).transpose(1, 2)
        k = k.view(B, T, self.n_head, n_d).transpose(1, 2)
        v = v.view(B, T, self.n_head, n_d).transpose(1, 2)

        # Scaled dot-product scores: (B, n_head, T, T).
        score = q @ k.transpose(2, 3) / math.sqrt(n_d)
        # Causal mask: position i may only attend to positions <= i. It is
        # built on the scores' device so the module also works on GPU.
        mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=score.device))
        # Masked (future) positions get the most negative value representable
        # in the scores' dtype. Softmax then gives them zero weight, however
        # large or small the real scores are.
        score = score.masked_fill(~mask, torch.finfo(score.dtype).min)
        score = self.softmax(score)
        score = score @ v  # "@" is (batched) matrix multiplication

        # Merge heads: (B, n_head, T, n_d) -> (B, T, d_model), then project out.
        x_concate = score.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.w_o(x_concate)
        
attn = multi_head_attention(d_model=d_model, n_head=n_head)
# Self-attention: queries, keys and values all come from the same input X.
Y = attn(q=X, k=X, v=X)
print(Y.shape)

torch.Size([16, 64, 512])


In [10]:
# Layer normalisation over the last (feature) dimension.
class layer_norm(nn.Module):
    """Normalise each vector along the last dim, then apply a learned scale/shift.

    Args:
        d_model: size of the last dimension. It sets the length of the learned
            ``gamma`` (scale, initialised to ones) and ``beta`` (shift,
            initialised to zeros).
        eps: small constant added to the variance before the square root, so
            the division never hits zero.
    """

    def __init__(self, d_model, eps = 1e-12):
        super().__init__()
        # Registration order (gamma, then beta) is kept so parameters() /
        # state_dict ordering is unchanged.
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Return gamma * (x - mean) / sqrt(var + eps) + beta over the last dim."""
        mu = x.mean(dim=-1, keepdim=True)
        # Population (biased) variance, as in standard LayerNorm.
        sigma2 = x.var(dim=-1, unbiased=False, keepdim=True)
        x_hat = (x - mu) / torch.sqrt(sigma2 + self.eps)
        return x_hat * self.gamma + self.beta
    
d_model = 512
# Toy batch: 2 sentences, 5 tokens each, every token a 512-dim embedding.
X = torch.randn(2, 5, d_model)
ln = layer_norm(d_model)

print("d_model: ", d_model)
print(f"ln gamma: {ln.gamma.shape}")
print(f"ln beta: {ln.beta.shape}")

Y_ln = ln(X)
print(Y_ln.shape)

d_model:  512
ln gamma: torch.Size([512])
ln beta: torch.Size([512])
torch.Size([2, 5, 512])
