In [1]:
import torch 
from torch import nn
import pandas as pd
import numpy as np
import math
from torch.cuda import amp
from transformers import get_cosine_schedule_with_warmup
from collections import Counter
import collections

## Processing Dataset
Note that in a machine translation task we need a vocabulary, i.e. a dictionary mapping every token to an integer index, built separately for the source language and for the target language.

In [2]:
class Build_vocabulary(object):
    '''
    Vocabulary mapping tokens to integer ids and back.

    Id 0 is always "<unk>", followed by ``special_tokens`` in the given order,
    followed by corpus tokens in descending frequency (ties keep first-seen
    order). Corpus tokens seen fewer than ``min_freq`` times get no id of their
    own and look up as "<unk>" (id 0).

    Args:
        tokens: iterable of token sequences (e.g. a list of tokenized sentences).
        min_freq: minimum count for a corpus token to enter the vocabulary.
        special_tokens: reserved tokens placed right after "<unk>".
    '''
    def __init__(self, tokens = None, min_freq = 0, special_tokens = None):
        sentences = [] if tokens is None else tokens
        reserved = [] if special_tokens is None else special_tokens
        counts = Counter(word for sentence in sentences for word in sentence)
        # (token, count) pairs, most frequent first; sorted() is stable so ties keep insertion order
        self.freq = sorted(counts.items(), key = lambda pair: pair[1], reverse=True)
        self.idx_to_token = ["<unk>"] + reserved
        self.token_to_id = {}
        for position, word in enumerate(self.idx_to_token):
            self.token_to_id[word] = position
        for word, count in self.freq:
            if count < min_freq:
                # self.freq is sorted by count, so every remaining token is rarer still
                break
            if word in self.token_to_id:
                continue
            self.token_to_id[word] = len(self.idx_to_token)
            self.idx_to_token.append(word)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, token):
        '''
        Look up the id of a token; lists/tuples are mapped element-wise
        (recursively). Unknown tokens map to 0 ("<unk>").
        '''
        if isinstance(token, (list, tuple)):
            return [self.__getitem__(item) for item in token]
        return self.token_to_id.get(token, 0)

    def indices_to_tokens(self, indices):
        '''Map an id, or a flat list/tuple of ids, back to token strings.'''
        if isinstance(indices, (list, tuple)):
            return [self.idx_to_token[i] for i in indices]
        return self.idx_to_token[indices]
    

Process the French–English translation dataset:

In [3]:
def read_dataset(path="./fra-eng/fra.txt"):
    '''
    Read and tokenize the tab-separated English-French sentence-pair file.

    Processing steps:
      * narrow no-break spaces (U+202F) and no-break spaces (U+00A0) become
        ordinary spaces, and the text is lower-cased;
      * a space is inserted before ``, . ! ?`` when one is not already there,
        so punctuation becomes its own token;
      * each line is split on a tab; lines that do not yield exactly two
        fields are skipped; each field is split on single spaces.

    Args:
        path: dataset file location (defaults to the original hard-coded path).

    Returns:
        (english_word, french_word): two parallel lists; element i of each is
        the token list of the i-th sentence pair.
    '''
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
    punctuation = set(',.!?')  # built once rather than once per character
    out = []
    for i, char in enumerate(text):
        if i > 0 and char in punctuation and text[i - 1] != " ":
            out.append(' ' + char)
        else:
            out.append(char)
    text = "".join(out)
    # Tokenization
    english_word, french_word = [], []
    for sentence in text.split("\n"):
        result = sentence.split("\t")
        if len(result) == 2:
            english_word.append(result[0].split(" "))
            french_word.append(result[1].split(" "))
    return english_word, french_word

In [4]:
# Load the tokenized sentence pairs: one list of tokens per sentence, per language.
(english_word, french_word) = read_dataset()

Next, add the special tokens \<pad>, \<bos> and \<eos>, and create the token-to-index mapping.

Padding and truncation: sequences longer than the fixed length are truncated, and shorter ones are padded with \<pad>:

In [5]:
def padding_truncation(tokens, max_lens, padding_token):
    '''
    Force a token sequence to exactly ``max_lens`` items.

    Sequences that are too long keep their first ``max_lens`` items. Shorter
    ones are right-padded with ``padding_token``.
    '''
    n_missing = max_lens - len(tokens)
    if n_missing < 0:
        return tokens[:max_lens]
    return tokens + [padding_token] * n_missing

In [6]:
def build_array(tokens, dic, max_length):
    '''
    Turn tokenized sentences into a fixed-width id tensor plus valid lengths.

    Each sentence is mapped through ``dic``, gets the "<eos>" id appended, and
    is then truncated / padded with the "<pad>" id to ``max_length``.

    Returns:
        tensor: shape (num_sentences, max_length) tensor of token ids.
        valid_len: shape (num_sentences,) count of non-"<pad>" ids per row.
    '''
    eos_id, pad_id = dic["<eos>"], dic["<pad>"]
    rows = []
    for sentence in tokens:
        ids = dic[sentence] + [eos_id]
        rows.append(padding_truncation(ids, max_length, pad_id))
    tensor = torch.tensor(rows)
    valid_len = (tensor != pad_id).type(torch.int32).sum(1)
    return tensor, valid_len

In [7]:
def processing_french_english_dataset(batch_size, max_length):
    '''
    Build the English->French training DataLoader and both vocabularies.

    Args:
        batch_size: DataLoader batch size.
        max_length: fixed sentence length (padding / truncation target).

    Returns:
        data_iter: shuffled DataLoader yielding
            (english_ids, english_valid_len, french_ids, french_valid_len).
        english_mapping, french_mapping: Build_vocabulary objects built with
            min_freq=2 and "<pad>", "<bos>", "<eos>" reserved.
    '''
    english_word, french_word = read_dataset()
    reserved = ["<pad>", "<bos>", "<eos>"]
    english_mapping = Build_vocabulary(english_word, min_freq=2, special_tokens=list(reserved))
    french_mapping = Build_vocabulary(french_word, min_freq=2, special_tokens=list(reserved))
    english_array, english_valid_len = build_array(english_word, english_mapping, max_length)
    french_array, french_valid_len = build_array(french_word, french_mapping, max_length)
    dataset = torch.utils.data.TensorDataset(english_array, english_valid_len,
                                             french_array, french_valid_len)
    data_iter = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return data_iter, english_mapping, french_mapping

## Transformer Architecture
It has four main components:
1. Multi-head self-attention
2. Position-wise fully connected layer
3. Add & norm
4. Positional encoding

## Multi-head Attention
Given a query $\mathbf{q} \in \mathbb{R}^{d_q}$, a key $\mathbf{k} \in \mathbb{R}^{d_k}$ and a value $\mathbf{v} \in \mathbb{R}^{d_v}$, each attention head $\mathbf{h}_i$ ($i = 1, \ldots, h$) is computed as $$\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},$$ where $f$ is the attention pooling function (scaled dot-product attention in this notebook).

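The outputs of all $h$ heads are then concatenated and projected by a learnable matrix $\mathbf W_o$ to produce the final result: $$\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.$$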

In [8]:
def masked_attention_softmax(attention_score, valid_len=None):
    '''
    Softmax over the last axis that ignores key positions beyond a valid length.

    Args:
        attention_score: scores shaped (batch, num_queries, num_keys)
            (3-D is assumed when ``valid_len`` is 1-D).
        valid_len: None (no masking); a 1-D tensor with one length per batch
            element, shared by all of its queries; or a 2-D tensor
            (batch, num_queries) with one length per query.

    Returns:
        Tensor shaped like ``attention_score``. Key positions j >= valid length
        get (numerically) zero weight.

    The input tensor is left untouched. An earlier version wrote -1e6 into a
    reshaped *view* of ``attention_score``, which mutated the caller's tensor.
    '''
    if valid_len is None:
        return nn.functional.softmax(attention_score, dim = -1)
    shape = attention_score.shape
    if valid_len.dim() == 1:
        # one length per batch element -> repeat it for each query row
        valid_len = torch.repeat_interleave(valid_len, shape[1])
    else:
        valid_len = valid_len.reshape(-1)
    max_length = shape[-1]
    flat_score = attention_score.reshape(-1, max_length)
    # (1, num_keys) vs (rows, 1) broadcasts to a (rows, num_keys) keep-mask
    mask = torch.arange(max_length, dtype = torch.float32,
                        device = attention_score.device)[None, :] < valid_len[:, None]
    # out-of-place fill; -1e6 (not -inf) keeps a fully-masked row finite (uniform) instead of NaN
    masked = flat_score.masked_fill(~mask, -1e6)
    return nn.functional.softmax(masked.reshape(shape), dim = -1)

In [9]:
class Dot_product_attention(nn.Module):
    '''
    Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.

    Keys can optionally be masked through ``valid_len``, and dropout is applied
    to the attention weights. After each forward call the weights as they were
    *before* dropout are kept in ``self.attention_weight`` for inspection.
    '''
    def __init__(self, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout_rate)
        self.attention_weight = None

    def forward(self, query, key, value, valid_len = None):
        # torch.bmm needs 3-D inputs: query (b, n_q, d), key (b, n_k, d), value (b, n_k, d_v)
        d = query.shape[-1]
        scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(d)
        self.attention_weight = masked_attention_softmax(scores, valid_len)
        weights = self.dropout(self.attention_weight)
        return torch.bmm(weights, value)

In [10]:
def parallel_calculate_multihead(input_, num_heads, reverse = False):
    '''
    Fold attention heads into the batch axis (or unfold them), so that every
    head can be computed with a single batched matmul.

    reverse=False:  (batch, seq, num_heads * d)  ->  (batch * num_heads, seq, d)
    reverse=True:   (batch * num_heads, seq, d)  ->  (batch, seq, num_heads * d)

    The folded batch axis is batch-major: row ``b * num_heads + h`` holds
    head ``h`` of example ``b``.
    '''
    if reverse:
        seq_len, head_dim = input_.shape[1], input_.shape[2]
        split = input_.reshape(-1, num_heads, seq_len, head_dim)   # (batch, heads, seq, d)
        merged = split.permute(0, 2, 1, 3)                          # (batch, seq, heads, d)
        return merged.reshape(merged.shape[0], merged.shape[1], -1)
    batch, seq_len = input_.shape[0], input_.shape[1]
    per_head = input_.reshape(batch, seq_len, num_heads, -1)       # (batch, seq, heads, d)
    per_head = per_head.permute(0, 2, 1, 3)                         # (batch, heads, seq, d)
    return per_head.reshape(-1, per_head.shape[2], per_head.shape[3])

In [11]:
class Multi_head_attention(nn.Module):
    '''
    Multi-head attention built on Dot_product_attention.

    Queries, keys and values each go through their own linear projection to
    ``num_hidden_size``. The projections are split into ``num_heads`` heads,
    attended in parallel, concatenated, and projected by ``W_output``.

    Args:
        key_size, value_size, query_size: last-dim sizes of the raw key / value / query inputs.
        num_hidden_size: projection size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout_rate: dropout applied to the attention weights.
        bias: whether the linear projections use a bias term.
    '''
    def __init__(self, key_size, value_size, query_size, num_hidden_size, num_heads,
                dropout_rate, bias = False, **kwargs):
        super(Multi_head_attention,self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = Dot_product_attention(dropout_rate)
        #---------------- 3 learnable projections --------------------------
        self.W_query = nn.Linear(query_size, num_hidden_size, bias=bias)
        self.W_key = nn.Linear(key_size, num_hidden_size, bias=bias)
        self.W_value = nn.Linear(value_size, num_hidden_size, bias=bias)
        # ------------------- Projection of the concatenated heads
        self.W_output = nn.Linear(num_hidden_size, num_hidden_size, bias=bias)

    def forward(self, query, key, value, valid_len):
        # project, then fold heads into the batch axis: (bs * num_heads, seq, hidden / num_heads)
        query = parallel_calculate_multihead(self.W_query(query), self.num_heads)
        key = parallel_calculate_multihead(self.W_key(key), self.num_heads)
        # Bug fix: values previously went through W_query, which left W_value unused
        # and broke whenever value_size != query_size.
        # NOTE(review): checkpoints trained with the old code learned around that
        # bug; retrain before relying on them.
        value = parallel_calculate_multihead(self.W_value(value), self.num_heads)
        if valid_len is not None:
            # one copy of each example's lengths per head, matching the batch-major head layout
            valid_len = torch.repeat_interleave(valid_len, repeats = self.num_heads,
                                               dim = 0)
        output = self.attention(query, key, value, valid_len)
        concat = parallel_calculate_multihead(output, self.num_heads, reverse = True)
        # (bs, num_query, num_hidden)
        return self.W_output(concat)

## FeedForward Network
Just a basic two-layer MLP with a ReLU activation, applied to the last (feature) dimension at every position.

In [12]:
class FFN(nn.Module):
    '''
    Position-wise feed-forward network: Linear -> ReLU -> Linear, applied to
    the last dimension of the input.
    '''
    def __init__(self, input_size, hidden_size, output_size, **kwargs):
        super().__init__(**kwargs)
        # attribute names form the state_dict keys of saved checkpoints; keep them
        self.ln1 = nn.Linear(input_size, hidden_size)
        self.act = nn.ReLU()
        self.ln2 = nn.Linear(hidden_size, output_size)

    def forward(self, input_):
        hidden = self.act(self.ln1(input_))
        return self.ln2(hidden)

## Add&Norm
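A residual (shortcut) connection followed by layer normalization: the sub-layer output $Y$ (after dropout) is added to the sub-layer input $X$, and the sum is normalized.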

In [13]:
class Add_and_norm(nn.Module):
    '''
    Residual connection followed by layer normalization.

    Computes ``LayerNorm(X + Dropout(Y))``, where ``X`` is the sub-layer input
    and ``Y`` is the sub-layer output.
    '''
    def __init__(self, norm_shape, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout_rate)
        self.layer_norm = nn.LayerNorm(norm_shape)

    def forward(self, X, Y):
        residual = X + self.dropout(Y)
        return self.layer_norm(residual)

## Positional Encoding

In [14]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to the input, followed by dropout.

    P[pos, 2i]   = sin(pos / 10000^(2i / d))
    P[pos, 2i+1] = cos(pos / 10000^(2i / d))

    Args:
        num_hiddens: embedding size d (odd sizes are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length the precomputed table covers.
    """
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Create a long enough `P`
        P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        P[:, :, 0::2] = torch.sin(X)
        # for odd num_hiddens there is one fewer cos column than sin column
        P[:, :, 1::2] = torch.cos(X[:, : num_hiddens // 2])
        # Non-persistent buffer: it moves with .to()/.cuda() together with the
        # module, but is NOT stored in state_dict, so existing checkpoints still
        # load with strict=True.
        self.register_buffer("P", P, persistent=False)

    def forward(self, X):
        # X: (batch, seq_len, num_hiddens); requires seq_len <= max_len
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

## Transformer Encoder
Each Transformer encoder block consists of multi-head self-attention, Add & Norm, a position-wise feed-forward network, and another Add & Norm.

In [15]:
class Encoder(nn.Module):
    def __init__(self, query_size, key_size, value_size, num_heads, num_hiddens, 
                norm_shape, num_input_size, num_hidden_size, num_output_size, 
                dropout_rate, bias=False, **kwargs):
        '''
        One Transformer encoder block:
            X -> multi-head self-attention -> Add&Norm -> FFN -> Add&Norm

        Args:
            query_size, key_size, value_size: input dims of query / key / value.
            num_heads: number of attention heads.
            num_hiddens: attention projection size; must equal the FFN input
                and output sizes and be divisible by num_heads.
            norm_shape: normalized_shape passed to both LayerNorms.
            num_input_size, num_hidden_size, num_output_size: FFN layer sizes.
            dropout_rate: dropout rate used by attention and Add&Norm.
            bias: whether the attention projections use bias.
        '''
        super().__init__(**kwargs)
        assert num_input_size == num_hiddens, "Input FFN size should be equal to hidden size of Attention Output"
        assert num_output_size == num_hiddens, "Output FFN size should be equal to hidden size of Attention Output"
        assert num_hiddens % num_heads == 0, "Hidden size should be totally divided by num of heads"
        # attribute names are state_dict keys; keep them (and their creation order) stable
        self.attention = Multi_head_attention(key_size, value_size, query_size,
                                              num_hiddens, num_heads, dropout_rate, bias)
        self.add_norm = Add_and_norm(norm_shape, dropout_rate)
        self.ffn = FFN(num_input_size, num_hidden_size, num_output_size)
        self.add_norm2 = Add_and_norm(norm_shape, dropout_rate)

    def forward(self, input_, valid_len):
        # self-attention: query, key and value are all the block input
        attended = self.attention(input_, input_, input_, valid_len)
        Y = self.add_norm(input_, attended)
        return self.add_norm2(Y, self.ffn(Y))

In [25]:
class TransformerEncoder(nn.Module):
    '''
    Transformer encoder: token embedding (scaled by sqrt(num_hiddens)) plus
    positional encoding, followed by ``num_layers`` Encoder blocks.

    After each forward call ``self.attention_weight[i]`` holds the self-attention
    weights recorded by block i.
    '''
    def __init__(self, num_words,query_size, key_size, value_size, num_heads, num_hiddens, 
                norm_shape, num_input_size, num_hidden_size, num_output_size, num_layers,
                dropout_rate, bias=False, **kwargs):
        '''
        Args:
            num_words: source vocabulary size.
            query_size, key_size, value_size: attention input dims.
            num_heads: number of attention heads.
            num_hiddens: embedding / attention projection size.
            norm_shape: normalized_shape for the LayerNorms.
            num_input_size, num_hidden_size, num_output_size: FFN layer sizes.
            num_layers: number of stacked Encoder blocks.
            dropout_rate: dropout rate.
            bias: whether the attention projections use bias.
        '''
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        # per-block attention weights from the latest forward pass
        self.attention_weight = None
        self.embedding_matrix = nn.Embedding(num_words, num_hiddens)
        self.positional_encoding = PositionalEncoding(num_hiddens, dropout_rate)
        self.encoder_blk = nn.Sequential()
        for layer_idx in range(num_layers):
            block = Encoder(query_size, key_size, value_size, num_heads,
                            num_hiddens, norm_shape, num_input_size, num_hidden_size,
                            num_output_size, dropout_rate, bias)
            # this name (including the space) is part of the saved state_dict keys
            self.encoder_blk.add_module("Encoder_block " + str(layer_idx), block)

    def forward(self, input_, valid_len):
        # scale the embeddings by sqrt(d) before adding the positional encoding
        embedded = self.embedding_matrix(input_) * math.sqrt(self.num_hiddens)
        X = self.positional_encoding(embedded)
        self.attention_weight = [None] * len(self.encoder_blk)
        for layer_idx, blk in enumerate(self.encoder_blk):
            X = blk(X, valid_len)
            self.attention_weight[layer_idx] = blk.attention.attention.attention_weight
        # (batch, seq_len, num_hiddens)
        return X

## Transformer Decoder
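Unlike an encoder block, each decoder block starts with masked self-attention, where each position may attend only to itself and earlier positions. This is followed by encoder–decoder attention and the feed-forward network, each wrapped in Add & Norm: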

In [17]:
class Decoder(nn.Module):
    def __init__(self,query_size, key_size, value_size, num_heads, num_hiddens, 
                norm_shape, num_input_size, num_hidden_size, num_output_size,
                dropout_rate, ith, bias=False,**kwargs):
        '''
        One Transformer decoder block:
            masked self-attention -> Add&Norm
            -> encoder-decoder attention -> Add&Norm -> FFN -> Add&Norm

        Args:
            query_size, key_size, value_size: attention input dims.
            num_heads: number of attention heads.
            num_hiddens: attention projection size; must equal the FFN input
                and output sizes and be divisible by num_heads.
            norm_shape: normalized_shape for the LayerNorms.
            num_input_size, num_hidden_size, num_output_size: FFN layer sizes.
            dropout_rate: dropout rate.
            ith: index of this block in the decoder stack; selects its slot in ``state_ls``.
            bias: whether the attention projections use bias.
        '''
        super().__init__(**kwargs)
        assert num_input_size == num_hiddens, "Input FFN size should be equal to hidden size of Attention Output"
        assert num_output_size == num_hiddens, "Output FFN size should be equal to hidden size of Attention Output"
        assert num_hiddens % num_heads == 0, "Hidden size should be totally divided by num of heads"
        self.ith = ith
        # attribute names are state_dict keys; keep them (and their creation order) stable
        self.attention_mask = Multi_head_attention(key_size, value_size, query_size,
                                                   num_hiddens, num_heads, dropout_rate, bias)
        self.add_norm = Add_and_norm(norm_shape, dropout_rate)
        self.attention2 = Multi_head_attention(key_size, value_size, query_size,
                                               num_hiddens, num_heads, dropout_rate, bias)
        self.add_norm2 = Add_and_norm(norm_shape, dropout_rate)
        self.ffn = FFN(num_input_size, num_hidden_size, num_output_size)
        self.add_norm3 = Add_and_norm(norm_shape, dropout_rate)

    def forward(self, input_, ith_state, valid_len, state_ls):
        '''
        Args:
            input_: block input (batch, steps, num_hiddens).
            ith_state: encoder output, used as keys/values of the second attention.
            valid_len: encoder valid lengths (mask for the second attention).
            state_ls: per-layer cache of the keys seen so far; ``state_ls[self.ith]``
                is read and overwritten.

        In training mode the whole target sequence is fed at once and a causal
        mask lets position t attend only to positions <= t. In eval mode
        (token-by-token prediction) no mask is applied. Instead, each call's
        input is appended to the cached keys, so only already-generated tokens
        are visible.

        Returns:
            (block output, ith_state, valid_len, state_ls)
        '''
        cached = state_ls[self.ith]
        if cached is None:
            # nothing cached yet for this layer: attend over the current input only
            keys = input_
        else:
            # incremental decoding: append this step to everything the layer has seen
            keys = torch.cat([cached, input_], dim = 1)
        values = keys
        state_ls[self.ith] = keys

        if self.training:
            bs, num_kqv, _ = input_.shape
            # causal mask: query t may only see keys 0..t
            decoder_valid_lens = torch.arange(1, num_kqv + 1, device = input_.device).repeat(bs, 1)
        else:
            decoder_valid_lens = None

        self_att = self.attention_mask(input_, keys, values, decoder_valid_lens)
        out1 = self.add_norm(input_, self_att)
        # encoder-decoder attention: queries from the decoder, keys/values from the encoder
        cross_att = self.attention2(out1, ith_state, ith_state, valid_len)
        out2 = self.add_norm2(out1, cross_att)
        return self.add_norm3(out2, self.ffn(out2)), ith_state, valid_len, state_ls

In [18]:
class TransformerDecoder(nn.Module):
    '''
    Transformer decoder: token embedding (scaled by sqrt(num_hiddens)) plus
    positional encoding, then ``num_layers`` Decoder blocks, then a linear
    layer producing one logit per target-vocabulary word.

    After each forward call ``self.attention_weight[0][i]`` and
    ``self.attention_weight[1][i]`` hold block i's masked self-attention and
    encoder-decoder attention weights.
    '''
    def __init__(self, num_words,query_size, key_size, value_size, num_heads, num_hiddens, 
                norm_shape, num_input_size, num_hidden_size, num_output_size, num_layers,
                dropout_rate, bias=False, **kwargs):
        '''
        Args:
            num_words: target vocabulary size.
            query_size, key_size, value_size: attention input dims.
            num_heads: number of attention heads.
            num_hiddens: embedding / attention projection size.
            norm_shape: normalized_shape for the LayerNorms.
            num_input_size, num_hidden_size, num_output_size: FFN layer sizes.
            num_layers: number of stacked Decoder blocks.
            dropout_rate: dropout rate.
            bias: whether the attention projections use bias.
        '''
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.attention_weight = None
        self.embedding_matrix = nn.Embedding(num_words, num_hiddens)
        self.positional_encoding = PositionalEncoding(num_hiddens, dropout_rate)
        self.decoder_blk = nn.Sequential()
        for layer_idx in range(num_layers):
            block = Decoder(query_size, key_size, value_size, num_heads,
                            num_hiddens, norm_shape, num_input_size, num_hidden_size,
                            num_output_size, dropout_rate, layer_idx, bias)
            # this name is part of the saved state_dict keys
            self.decoder_blk.add_module("Decoderblock" + str(layer_idx), block)
        # scores over the whole target vocabulary
        self.ln = nn.Linear(num_hiddens, num_words)

    def init_state(self, encoder_output, encoder_valid_lens):
        # (encoder output, encoder valid lengths, empty per-layer key cache)
        return encoder_output, encoder_valid_lens, [None]*self.num_layers

    def forward(self, input_, state, valid_len, state_ls):
        X = self.positional_encoding(self.embedding_matrix(input_) * math.sqrt(self.num_hiddens))
        n_blocks = len(self.decoder_blk)
        self_att_weights = [None] * n_blocks
        cross_att_weights = [None] * n_blocks
        self.attention_weight = [self_att_weights, cross_att_weights]
        for layer_idx, blk in enumerate(self.decoder_blk):
            X, state, valid_len, state_ls = blk(X, state, valid_len, state_ls)
            self_att_weights[layer_idx] = blk.attention_mask.attention.attention_weight
            cross_att_weights[layer_idx] = blk.attention2.attention.attention_weight
        return self.ln(X), state, valid_len, state_ls

    def return_attention_weight(self):
        return self.attention_weight

## Combining the encoder and decoder into the full Transformer

In [19]:
class Transformer(nn.Module):
    '''
    Encoder-decoder wrapper.

    ``forward`` encodes the source batch, initializes the decoder state from
    the encoder output, runs the decoder on ``decoder_input``, and returns
    whatever the decoder returns.
    '''
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, encoder_input, encoder_valid_lens,decoder_input):
        memory = self.encoder(encoder_input, encoder_valid_lens)
        memory, memory_valid_len, layer_cache = self.decoder.init_state(memory, encoder_valid_lens)
        return self.decoder(decoder_input, memory, memory_valid_len, layer_cache)

## Masked Softmax Cross-Entropy Loss

In [20]:
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """Softmax cross-entropy loss that ignores padded time steps.

    Shapes:
        pred: (batch_size, num_steps, vocab_size) logits
        label: (batch_size, num_steps) target ids
        valid_len: (batch_size,) number of real (non-padding) steps per sequence

    Returns a (batch_size,) tensor. Each entry is the per-token loss summed over
    the valid steps and divided by ``num_steps`` (all steps, padding included).

    Note: ``forward`` sets ``self.reduction = 'none'`` on every call, so the
    parent class returns per-token losses that can then be masked.
    """
    def forward(self, pred, label, valid_len):
        num_steps = label.size(1)
        positions = torch.arange(num_steps, dtype = torch.float32, device = label.device)
        keep = positions[None, :] < valid_len[:, None]
        weights = keep.to(label.dtype)
        self.reduction = 'none'
        # CrossEntropyLoss expects the class axis second: (batch, vocab, steps)
        per_token = super().forward(pred.permute(0, 2, 1), label)
        return (per_token * weights).mean(dim=1)

## Gradient clipping

In [21]:
def clip_grad(net, theta):
    '''
    Rescale gradients in place so that their global L2 norm is at most ``theta``.

    Parameters with ``requires_grad=False`` or with no gradient (``.grad is
    None``, e.g. unused in the last backward pass) are skipped. Such parameters
    used to raise ``TypeError`` on ``None ** 2``. Nothing happens when no
    gradient exists or the norm is already <= ``theta``.
    '''
    grads = [p.grad for p in net.parameters() if p.requires_grad and p.grad is not None]
    if not grads:
        return
    norm = torch.sqrt(sum(torch.sum(g ** 2) for g in grads))
    if norm > theta:
        for g in grads:
            g[:] *= theta / norm

Layer tests (commented out; uncomment to check the output shapes):

In [26]:
# #print(masked_attention_softmax(torch.rand(2,2,4), torch.tensor([[1,3],[2,4]])))
# query = torch.normal(0,1,(2,1,2))
# keys = torch.ones((2,10,2))
# #第一个 mask 2， 第二个 mask 6
# valid_lens = torch.tensor([2, 6])
# values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
#     2, 1, 1)
# attention = Dot_product_attention(dropout_rate= 0.5)
# attention.eval()
# #print(attention(query, keys, values, valid_lens))
# num_hiddens, num_heads = 100,5
# attention = Multi_head_attention(num_hiddens, num_hiddens, num_hiddens, num_hiddens,
#                                 num_heads, 0,5)
# attention.eval()
# batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, torch.tensor([3, 2])
# X = torch.ones((batch_size, num_queries, num_hiddens))
# Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
# attention(X, Y, Y, valid_lens).shape
# ffn = FFN(4, 4, 8)
# ffn.eval()
# ffn(torch.ones((2, 3, 4)))
# add_norm = Add_and_norm([3, 4], 0.5) # Normalized_shape is input.size()[1:]
# add_norm.eval()
# add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape
# x = torch.ones((2,100,24))
# valid_len = torch.tensor([3,2])
# encoder_blk = Encoder(24, 24, 24, 8, 24,[100,24], 24,48,24,0.5)
# encoder_blk.eval()
# encoder_blk(x, valid_len).shape
# encoder = TransformerEncoder(200, 24, 24,24,8,24,[100,24], 24,48,24,2,0.5)
# encoder.eval()
# encoder(torch.ones((2, 100), dtype=torch.long), valid_len).shape
# decoderblk = Decoder(24, 24, 24, 8, 24,[100,24], 24,48,24,0.5,0)
# decoderblk.eval()
# X = torch.ones((2,100,24))
# ith_state = encoder_blk(X, valid_lens)
# decoderblk(X, ith_state, valid_lens,[None])[0].shape

torch.Size([2, 100, 24])


torch.Size([2, 100, 24])

## Set parameters

In [23]:
# Hyper-parameters. num_step is the fixed length every sentence is padded/truncated to.
SEED = 0
torch.manual_seed(SEED)  # reproducible weight initialization and DataLoader shuffling
num_hiddens, num_layers, dropout_rate, bs, num_step = 64, 4, 0.1, 128, 10
lr, num_epochs, device = 2e-4, 200, torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
num_input, num_hidden, num_output, num_heads = 64, 128, 64, 4
key_size, value_size, query_size = 64,64,64
norm_shape = [64]
train_iter, english_mapping, french_mapping = processing_french_english_dataset(bs, num_step)
# the trailing positional True is the `bias` argument
encoder = TransformerEncoder(len(english_mapping), query_size, key_size, value_size,
                             num_heads, num_hiddens, norm_shape, num_input, num_hidden,
                             num_output, num_layers, dropout_rate, True)
decoder = TransformerDecoder(len(french_mapping), query_size, key_size, value_size,
                             num_heads, num_hiddens, norm_shape, num_input, num_hidden,
                             num_output, num_layers, dropout_rate, True)
transformer = Transformer(encoder, decoder)
criterion = MaskedSoftmaxCELoss()
optimizer = torch.optim.Adam(transformer.parameters(), lr = lr)
# train() calls scheduler.step() once per optimizer update in EVERY epoch, so the
# schedule has to span all epochs. With only len(train_iter) steps, the half-cosine
# reached lr=0 at the end of the first epoch and then rose again.
# (Assumes gradient_accumulate_step == 1, as in the commented-out train call.)
scheduler = get_cosine_schedule_with_warmup(optimizer = optimizer, num_warmup_steps = 0,
                                            num_training_steps = len(train_iter) * num_epochs,
                                            num_cycles = 0.5)
num_gpu = 1
max_grad_norm = 1000

In [24]:
def train(net, train_iter,criterion, optimizer, epochs, scheduler, gradient_accumulate_step, max_grad_norm ,num_gpu,
        target_vocab, clip_theta = 1, checkpoint_path = "./checkpoint.params"):
    '''
    Train a seq2seq model with teacher forcing, gradient accumulation and
    gradient-norm clipping, writing a checkpoint periodically.

    Linear layers are re-initialized with Xavier-uniform weights at the start.

    Args:
        net: model called as ``net(X, X_valid_len, decoder_input)`` that returns
            a 4-tuple whose first item is the logits (batch, steps, vocab).
        train_iter: sized iterable yielding (X, X_valid_len, Y, Y_valid_len).
        criterion: called as ``criterion(logits, Y, Y_valid_len)``; returns
            per-sequence losses, which are summed before backward.
        optimizer, scheduler: stepped once per optimizer update.
        epochs: number of passes over ``train_iter``.
        gradient_accumulate_step: number of batches whose gradients are summed
            before each optimizer update (1 = update after every batch).
        max_grad_norm, clip_theta: gradients are clipped to a global L2 norm of
            ``min(max_grad_norm, clip_theta)``. This matches the previous
            behaviour of clipping to max_grad_norm and then to 1. The logged
            ``grad_norm`` is the norm before clipping.
        num_gpu: unused; kept for interface compatibility.
        target_vocab: target vocabulary; only ``target_vocab["<bos>"]`` is read.
        checkpoint_path: file the periodic checkpoint is written to.

    Fixes compared with the previous version:
      * ``optimizer.zero_grad()`` ran before every batch, so gradient
        accumulation silently used only the last micro-batch. Gradients are
        now cleared only after an optimizer update.
      * Clipping is applied once per update, to the accumulated gradient.
      * A trailing partial accumulation window at the end of an epoch now
        triggers an update instead of leaking into the next epoch.
    '''
    net.train()
    ls = []
    device  = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("\ntrain on %s\n"%str(device))
    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
    net.apply(xavier_init_weights)
    net.to(device)
    num_batches = len(train_iter)
    clip_norm = min(max_grad_norm, clip_theta)
    optimizer.zero_grad()
    for epoch in range(epochs):
        net.train()
        gradient_norm = None
        for idx, value in enumerate(train_iter):
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in value]
            # teacher forcing: the decoder sees <bos> followed by the gold target shifted right by one
            bos = torch.tensor([target_vocab["<bos>"]]*Y.shape[0], device = device).reshape(-1,1)
            decoder_input = torch.cat([bos, Y[:,:-1]], 1)
            output,_,_,_  = net(X, X_valid_len, decoder_input)
            loss = criterion(output, Y, Y_valid_len)
            if gradient_accumulate_step > 1:
                # scale so the accumulated gradient averages over the window
                # (useful when GPU memory is too small for the full batch)
                loss = loss/gradient_accumulate_step
            loss.sum().backward()
            is_last = idx == num_batches - 1
            if (idx + 1) % gradient_accumulate_step == 0 or is_last:
                # clip_grad_norm_ returns the pre-clip global norm (used for logging)
                gradient_norm = nn.utils.clip_grad_norm_(net.parameters(), clip_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            # log and checkpoint every 30 batches and at the end of each epoch
            if idx % 30 == 0 or is_last:
                with torch.no_grad():
                    print("==============Epochs "+ str(epoch) + " ======================")
                    print("loss: " + str(loss.mean()) + "; grad_norm: " + str(gradient_norm))
                    ls.append(loss.mean().item())
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': net.state_dict(),
                    'loss': ls
                }, checkpoint_path)

In [25]:
# train(transformer, train_iter, criterion,optimizer, num_epochs, scheduler,1,max_grad_norm,
#      num_gpu,french_mapping)

## Prediction of Seq2Seq model
Note that at each decoding time step, the token predicted at the previous step is fed into the decoder as its input. The first input is \<bos>, and decoding stops once \<eos> is predicted or the maximum number of steps is reached.

In [26]:
# map_location lets a checkpoint saved from a CUDA run load on a CPU-only machine
# (and vice versa). NOTE(review): torch.load unpickles the file, so only load
# checkpoints from a trusted source.
transformer.load_state_dict(torch.load("../input/french-english-transformer-weights/checkpoint.params",
                                       map_location=device)["model_state_dict"])

In [27]:
def predict(net, source_sentence, source_dic_mapping, target_dic_mapping, num_steps,
           save_attention_weight = False):
    '''
    Greedily translate one sentence.

    The source sentence is lower-cased, split on spaces, mapped to ids, given
    "<eos>", and padded/truncated to ``num_steps``. Decoding starts from
    "<bos>" and feeds each argmax prediction back in as the next input; the
    decoder's per-layer key cache (``state_ls``) carries the history. Decoding
    stops when "<eos>" is predicted or after ``num_steps`` steps.

    Args:
        net: encoder-decoder model exposing ``net.encoder`` and ``net.decoder``.
        source_sentence: space-tokenized source text.
        source_dic_mapping, target_dic_mapping: source / target vocabularies.
        num_steps: source length after padding, and the maximum number of decoding steps.
        save_attention_weight: if True, ``net.decoder.attention_weight`` is
            appended to the returned list after every decoding step. This flag
            used to be ignored, so the list was always empty.

    Returns:
        (translation, att_weight_seq): the space-joined predicted target tokens
        (without "<eos>") and the list of saved attention weights.
    '''
    net.eval()
    device  = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    net.to(device)
    source_tokens = source_dic_mapping[source_sentence.lower().split(" ")] + [source_dic_mapping["<eos>"]]
    # valid length is measured before padding/truncation
    encoder_valid_len = torch.tensor([len(source_tokens)], device = device)
    source_tokens = padding_truncation(source_tokens, num_steps, source_dic_mapping["<pad>"])
    output_seq, att_weight_seq = [], []
    eos_id = target_dic_mapping["<eos>"]
    # inference only: no autograd graph is needed
    with torch.no_grad():
        # add a batch dimension of size 1
        encoder_input = torch.unsqueeze(torch.tensor(source_tokens, dtype = torch.long, device = device), dim = 0)
        encoder_output = net.encoder(encoder_input, encoder_valid_len)
        state, valid_len, state_ls = net.decoder.init_state(encoder_output, encoder_valid_len)
        # first decoder input: just the <bos> token, shape (1, 1)
        decoder_input = torch.unsqueeze(torch.tensor([target_dic_mapping["<bos>"]], dtype = torch.long, device = device), dim = 0)
        for _ in range(num_steps):
            decoder_output, state, valid_len, state_ls = net.decoder(decoder_input, state, valid_len, state_ls)
            # greedy choice, kept as shape (1, 1) so it can be fed straight back in
            decoder_input = decoder_output.argmax(dim = 2)
            if save_attention_weight:
                att_weight_seq.append(net.decoder.attention_weight)
            pred = decoder_input.squeeze(dim = 0).type(torch.int32).item()
            if pred == eos_id:
                break
            output_seq.append(pred)
    return " ".join(target_dic_mapping.indices_to_tokens(output_seq)), att_weight_seq

## Evaluation metric for machine translation
The standard metric for this task is `BLEU` (Bilingual Evaluation Understudy), defined as follows:
1. Let $p_n$ be the precision of n-grams: the number of n-grams matched between the predicted and label sequences, divided by the number of n-grams in the predicted sequence. Each label n-gram can be matched at most as many times as it occurs in the label. For example, with label A, B, C, D, E, F and prediction A, B, B, C, D:
    1. 1-grams: the prediction has 5 (A, B, B, C, D), and 4 of them match the label (A, B, C, D; the second B has no unmatched B left in the label), so $p_1 = 4/5$.
    2. 2-grams: the prediction has 4 (AB, BB, BC, CD), and 3 of them match (AB, BC, CD), so $p_2 = 3/4$. Likewise $p_3 = 1/3$ (only BCD matches) and $p_4 = 0$.
    

The Formula here is:
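$$\exp\left(\min\left(0, 1 - \frac{\mathrm{len}_{\text{label}}}{\mathrm{len}_{\text{pred}}}\right)\right) \prod_{n=1}^{k} p_n^{1/2^n},$$ where $\mathrm{len}_{\text{label}}$ and $\mathrm{len}_{\text{pred}}$ are the numbers of tokens in the label and predicted sequences, and $k$ is the longest n-gram order used.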

From the definition, longer n-grams are harder to match, and the exponent $1/2^n$ gives more weight to matching them: for a fixed $p_n \in (0, 1)$, $p_n^{1/2^n}$ increases as $n$ grows. Shorter predictions also tend to have higher precision, so the exponential brevity penalty lowers the score whenever the prediction is shorter than the label.

In [31]:
def BLEU(pred, label, n_gram):
    '''
    BLEU score of one predicted sentence against one reference.

        score = exp(min(0, 1 - len_label / len_pred)) * prod_{n=1..n_gram} p_n ** (1 / 2**n)

    Lengths are counted in space-separated tokens. p_n is the clipped n-gram
    precision: each reference n-gram can be matched at most as many times as
    it occurs in the reference.

    Args:
        pred, label: space-separated token strings.
        n_gram: largest n-gram order k.

    Returns:
        A float in [0, 1]. Returns 0.0 when the prediction is too short to
        contain any n-gram of some order n <= n_gram; this used to raise
        ZeroDivisionError.

    Fix: the brevity penalty used to compare the *character* lengths of the
    raw strings instead of their token counts.
    '''
    assert isinstance(pred, str)
    assert isinstance(label, str)
    pred_tokens, label_tokens = pred.split(" "), label.split(" ")
    len_prediction, len_label = len(pred_tokens), len(label_tokens)
    # brevity penalty: < 1 only when the prediction is shorter than the label
    score = math.exp(min(0, 1 - len_label / len_prediction))
    for n in range(1, n_gram + 1):
        num_candidates = len_prediction - n + 1
        if num_candidates <= 0:
            return 0.0
        label_counts = collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_counts[" ".join(label_tokens[i:i + n])] += 1
        num_matches = 0
        for i in range(num_candidates):
            gram = " ".join(pred_tokens[i:i + n])
            if label_counts[gram] > 0:
                num_matches += 1
                # consume the match so repeated n-grams are not over-counted
                label_counts[gram] -= 1
        score *= math.pow(num_matches / num_candidates, math.pow(0.5, n))
    return score

In [32]:
# Translate a few sample sentences and score each one against its reference (BLEU, n-grams up to 2).
engs = ['go .', "i lost .", "he's calm .", "i'm home ."]
fras = ['va !', "j'ai perdu .", 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = predict(
        transformer, eng, english_mapping, french_mapping, num_step)
    print(f"{eng}===>{translation}")
    print("BLEU_SCORE: ", BLEU(translation, fra, 2))