<a href="https://colab.research.google.com/github/hao1zhao/Model/blob/main/transformer_keycode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

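**Positional encoding (sinusoidal)**
$$\begin{aligned} p_{i, 2j} &= \sin\left(\frac{i}{10000^{2j/d}}\right),\\p_{i, 2j+1} &= \cos\left(\frac{i}{10000^{2j/d}}\right),\end{aligned}$$
where $i$ is the token position, columns $2j$ and $2j+1$ are a pair of embedding dimensions, and $d$ is the model dimension.

**Scaled dot-product attention**
$$\mathrm{Attention}(\mathbf Q,\mathbf K,\mathbf V) = \mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v},$$
where $\mathbf Q\in\mathbb{R}^{n\times d}$, $\mathbf K\in\mathbb{R}^{m\times d}$ and $\mathbf V\in\mathbb{R}^{m\times v}$. Dividing by $\sqrt{d}$ keeps the scores from growing with the key dimension, which would otherwise saturate the softmax.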


**Encoder**

In [4]:
# ---- toy training-data configuration ----
batch_size = 2
# vocabulary sizes: token ids are drawn from [1, max_num_*_words); id 0 is reserved for padding
max_num_src_words = 8
max_num_tgt_words = 8
# model (embedding) dimension
model_dim = 8
# maximum lengths
max_scr_seq_len = 5  # NOTE(review): "scr" is a typo for "src"; name kept in case other cells use it
max_tgt_seq_len = 5
max_position_len = 5

# src_len = torch.randint(2, 5, (batch_size,))
# tgt_len = torch.randint(2, 5, (batch_size,))
src_len = torch.tensor([2, 4], dtype=torch.int32)  # source sequence lengths
tgt_len = torch.tensor([4, 3], dtype=torch.int32)  # target sequence lengths
# Plain Python ints are safer than 0-d tensors as shape / pad arguments.
max_src_len = int(src_len.max())
max_tgt_len = int(tgt_len.max())
assert max(max_src_len, max_tgt_len) <= max_position_len, \
    "position table is smaller than the longest sentence"

# ---- batch the word indices, right-padding every sentence with id 0 ----
src_seq = torch.stack([F.pad(torch.randint(1, max_num_src_words, (L,)), (0, max_src_len - L))
                       for L in src_len.tolist()])  # [batch_size, max_src_len]
tgt_seq = torch.stack([F.pad(torch.randint(1, max_num_tgt_words, (L,)), (0, max_tgt_len - L))
                       for L in tgt_len.tolist()])  # [batch_size, max_tgt_len]

# ---- word embedding: one extra row so the padding id 0 has its own entry ----
src_embedding_table = nn.Embedding(max_num_src_words + 1, model_dim)
# the target table must be sized by the *target* vocabulary (was max_num_src_words)
tgt_embedding_table = nn.Embedding(max_num_tgt_words + 1, model_dim)
src_embedding = src_embedding_table(src_seq)  # [batch_size, max_src_len, model_dim]
tgt_embedding = tgt_embedding_table(tgt_seq)  # [batch_size, max_tgt_len, model_dim]

# ---- sinusoidal position embedding ----
# p[i, 2j] = sin(i / 10000^(2j/d)),  p[i, 2j+1] = cos(i / 10000^(2j/d))
pos_mat = torch.arange(max_position_len).reshape((-1, 1))  # positions i as a column vector
# 10000^(2j/d) for every even column 2j, derived from model_dim instead of a hard-coded 8
i_mat = torch.pow(10000, torch.arange(0, model_dim, 2).reshape((1, -1)) / model_dim)
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(pos_mat / i_mat)  # even columns
# odd columns; the slice keeps an odd model_dim (one fewer odd column) from failing
pe_embedding_table[:, 1::2] = torch.cos(pos_mat / i_mat)[:, : model_dim // 2]

pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)  # fixed, not trained
# position indices 0..max_len-1, the same row repeated for every sentence in the batch
src_pos = torch.arange(max_src_len).unsqueeze(0).repeat(len(src_len), 1).to(torch.int32)
tgt_pos = torch.arange(max_tgt_len).unsqueeze(0).repeat(len(tgt_len), 1).to(torch.int32)
src_pe_embedding = pe_embedding(src_pos)  # [batch_size, max_src_len, model_dim]
tgt_pe_embedding = pe_embedding(tgt_pos)  # [batch_size, max_tgt_len, model_dim]

# ---- encoder self-attention mask ----
# mask: bool [batch_size, max_src_len, max_src_len]; True = masked out (score replaced by -1e9)
# valid_encoder_pos: [batch_size, max_src_len, 1], 1.0 for a real token, 0.0 for padding
valid_encoder_pos = torch.stack([F.pad(torch.ones(L), (0, max_src_len - L))
                                 for L in src_len.tolist()]).unsqueeze(2)
# outer product: entry (b, i, j) is 1 only when tokens i and j are both real
valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_encoder_pos_matrix = 1 - valid_encoder_pos_matrix
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)

score = torch.randn(batch_size, max_src_len, max_src_len)  # stand-in for Q @ K^T
masked_score = score.masked_fill(mask_encoder_self_attention, -1e9)  # True -> -1e9
# Note: a row that is entirely padding is all -1e9, so softmax makes it uniform.
prob = F.softmax(masked_score, -1)
print(score)
print(masked_score)
print(prob)

# Why scale the scores? Larger score magnitudes push softmax toward one-hot, where its
# Jacobian entries approach zero (vanishing gradients). Uncomment to compare:
# alpha1 = 0.1
# alpha2 = 10
# score = torch.rand(5)  # stand-in for Q @ K^T similarities
# prob = F.softmax(score * alpha1, -1)
# prob2 = F.softmax(score * alpha2, -1)
# def softmax_func(score):
#   return F.softmax(score, -1)
# jaco_mat1 = torch.autograd.functional.jacobian(softmax_func, score * alpha1)
# jaco_mat2 = torch.autograd.functional.jacobian(softmax_func, score * alpha2)



tensor([[[-0.0859, -1.1043, -0.7031,  0.2001],
         [-0.4731,  0.2562,  1.3384,  1.5176],
         [ 0.0924, -0.7039, -0.4194,  0.2126],
         [ 1.4768,  0.3347,  0.0357,  1.2195]],

        [[-0.3293,  0.8625, -0.6781,  0.3781],
         [-0.5138,  0.1044, -0.2862,  0.0664],
         [ 2.1876,  0.9374, -1.7762,  0.8907],
         [ 1.9163,  2.6244, -1.1028,  0.4590]]])
tensor([[[-8.5921e-02, -1.1043e+00, -1.0000e+09, -1.0000e+09],
         [-4.7315e-01,  2.5619e-01, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

        [[-3.2929e-01,  8.6252e-01, -6.7808e-01,  3.7811e-01],
         [-5.1383e-01,  1.0437e-01, -2.8624e-01,  6.6426e-02],
         [ 2.1876e+00,  9.3739e-01, -1.7762e+00,  8.9074e-01],
         [ 1.9163e+00,  2.6244e+00, -1.1028e+00,  4.5903e-01]]])
tensor([[[0.7347, 0.2653, 0.0000, 0.0000],
         [0.3253, 0.6747, 0.0000, 0.0000],
         [0.2500, 0.2500, 

**Decoder**

In [32]:
# ---- encoder-decoder (cross) attention mask ----
# Q comes from the decoder and K from the encoder, so Q @ K^T is [batch_size, tgt_seq_len, src_seq_len].
# Padding indicators, shape [batch_size, max_len, 1]: 1.0 for a real token, 0.0 for padding.
valid_encoder_pos = torch.stack(
    [F.pad(torch.ones(L), (0, max(src_len) - L)) for L in src_len]
).unsqueeze(-1)
valid_decoder_pos = torch.stack(
    [F.pad(torch.ones(L), (0, max(tgt_len) - L)) for L in tgt_len]
).unsqueeze(-1)

# Entry (b, i, j) is 1 only when target position i and source position j are both real tokens.
valid_cross_pos_matrix = valid_decoder_pos @ valid_encoder_pos.transpose(1, 2)
invalid_cross_pos_matrix = 1 - valid_cross_pos_matrix
mask_cross_attention = invalid_cross_pos_matrix.bool()  # True = masked out

# ---- decoder self-attention (causal) mask ----
# Each sentence gets a lower-triangular L x L block, so position i attends only to positions <= i.
# The block is zero-padded on the right and bottom up to max(tgt_len) x max(tgt_len).
valid_decoder_tri_matrix = torch.stack(
    [F.pad(torch.tril(torch.ones((L, L))), (0, max(tgt_len) - L, 0, max(tgt_len) - L))
     for L in tgt_len]
)
invalid_decoder_tri_matrix = (1 - valid_decoder_tri_matrix).bool()  # True = masked out

# Demo on random scores: masked entries become -1e9 and so receive ~0 probability.
# A query row that is entirely padding is all -1e9 and softmaxes to a uniform row.
score = torch.randn(batch_size, max(tgt_len), max(tgt_len))
masked_de_score = score.masked_fill(invalid_decoder_tri_matrix, -1e9)
prob_de = F.softmax(masked_de_score, -1)

# scaled dot-product attention
def scaled_dot_porduct_attention(Q, K, V, attn_mask=None):
  """Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V.

  Args:
    Q: queries, (batch_size*num_head, tgt_seq_len, d_k) with d_k = model_dim/num_head.
    K: keys, (batch_size*num_head, src_seq_len, d_k).
    V: values, (batch_size*num_head, src_seq_len, d_v).
    attn_mask: optional bool tensor broadcastable to
      (batch_size*num_head, tgt_seq_len, src_seq_len). True marks positions that must
      not be attended to; their scores are replaced by -1e9 before the softmax.

  Returns:
    context: (batch_size*num_head, tgt_seq_len, d_v).
  """
  # Scale by sqrt of the key dimension actually used by Q/K (the per-head size), not the
  # global model_dim. A Python float is used because torch.sqrt rejects a plain int.
  d_k = Q.size(-1)
  # Keep the result: the old version discarded it and masked an unrelated global `score`.
  score = torch.bmm(Q, K.transpose(-2, -1)) / d_k ** 0.5
  if attn_mask is not None:
    score = score.masked_fill(attn_mask, -1e9)
  prob = F.softmax(score, -1)
  context = torch.bmm(prob, V)
  return context


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
scaled_dot_product_attention = scaled_dot_porduct_attention


tensor([4, 3], dtype=torch.int32)
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.2903, 0.7097, 0.0000, 0.0000],
         [0.2045, 0.1580, 0.6375, 0.0000],
         [0.2276, 0.6026, 0.1447, 0.0251]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.5549, 0.4451, 0.0000, 0.0000],
         [0.4484, 0.0998, 0.4519, 0.0000],
         [0.2500, 0.2500, 0.2500, 0.2500]]])
