In [6]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchtext.data import Field, BucketIterator
import numpy as np

import spacy
import en_core_web_sm
import de_core_news_sm
# Load the spaCy pipelines from the two model packages imported above
# (presumably for English / German tokenization; neither object is used in
# the cells shown in this notebook).
spacy_en = en_core_web_sm.load()
spacy_de = de_core_news_sm.load()

from IPython import embed

In [9]:
class EncoderStack(nn.Module):
    """The encoder stack of NAT.

    A thin wrapper around ``nn.TransformerEncoder``: ``num_encoder_layers``
    copies of ``nn.TransformerEncoderLayer`` followed by a final LayerNorm.

    Args:
        d_embed: int, the embedding (model) dimension.
        nhead: int, the number of attention heads per layer.
        num_encoder_layers: int, how many encoder layers to stack.
        dim_feedforward: int, hidden size of each layer's feed-forward network.
        dropout: float, dropout probability inside each layer.
        activation: str, feed-forward non-linearity understood by
            ``nn.TransformerEncoderLayer`` (e.g. "relu" or "gelu").
    """
    def __init__(self, d_embed=512, nhead=8, num_encoder_layers=6, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(EncoderStack, self).__init__()

        # Forward `activation` to the layer so the constructor argument
        # actually selects the feed-forward non-linearity.
        encoder_layer = nn.TransformerEncoderLayer(d_embed, nhead, dim_feedforward, dropout, activation)
        encoder_norm = nn.LayerNorm(d_embed)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

    def forward(self, src, mask=None, src_key_padding_mask=None):
        """Encode a source sequence.

        Args:
            src: the (already embedded) source sequence.
            mask: optional attention mask handed to ``nn.TransformerEncoder``.
            src_key_padding_mask: optional per-batch padding mask handed to
                ``nn.TransformerEncoder``.

        Shape:
            src: [S, N, E]
            output: [S, N, E]
        """
        return self.encoder(src, mask=mask, src_key_padding_mask=src_key_padding_mask)

In [21]:
class PositionalEncoding(nn.Module):
    """Scale embeddings by sqrt(d_embed) and add fixed sinusoidal position encodings.

    The table follows the Transformer definition, where each (sin, cos) column
    pair shares one frequency:

        pe[pos, 2k]     = sin(pos / 10000 ** (2k / d_embed))
        pe[pos, 2k + 1] = cos(pos / 10000 ** (2k / d_embed))

    Args:
        d_embed: int, the dimension of embedding. Odd values are supported; the
            last column then only receives the sine term.
        max_seq_len: int, the max length of what to be summed with positional encoding.

    Attributes:
        pe: non-trainable buffer of shape [1, max_seq_len, d_embed].
    """
    def __init__(self, d_embed, max_seq_len):
        super(PositionalEncoding, self).__init__()

        self.d_embed = d_embed

        position = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)  # [T, 1]
        even_dims = torch.arange(0, d_embed, 2, dtype=torch.float32)  # 0, 2, 4, ...
        # 1 / 10000 ** (2k / d_embed), computed in log space for stability.
        inv_freq = torch.exp(even_dims * (-math.log(10000.0) / d_embed))
        angles = position * inv_freq  # [T, ceil(E / 2)]

        pe = torch.zeros(max_seq_len, d_embed)  # pe: [T, E]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_embed there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_embed // 2])

        # Registered as a buffer so it follows .to()/.cuda() and is saved in
        # the state dict without being a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))  # pe: [1, T, E]

    def forward(self, x):
        """Return ``x * sqrt(d_embed) + pe[:seq_len]``.

        Args:
            x: embeddings of shape [T, N, E] (sequence first).

        Returns:
            Tensor of shape [T, N, E] on the same device as the ``pe`` buffer.

        Raises:
            ValueError: if the sequence is longer than ``max_seq_len``.
        """
        seq_len = x.size(0)
        max_seq_len = self.pe.size(1)
        if seq_len > max_seq_len:
            raise ValueError(
                "sequence length {} exceeds max_seq_len {}".format(seq_len, max_seq_len))

        # The input follows the module's device (callers may hand in a CPU
        # tensor while the module lives on an accelerator).
        x = x.to(self.pe.device)

        # [1, T, E] -> [T, 1, E]; broadcasting over the batch axis replaces an
        # explicit repeat.
        pe = self.pe[:, :seq_len].transpose(0, 1)

        # Scale first so the embeddings are relatively larger than the encodings.
        return x * math.sqrt(self.d_embed) + pe

In [22]:
class MultiHeadPositionalAttention(nn.Module):
    """Multi-head attention whose inputs are first passed through PositionalEncoding.

    Args:
        d_embed: int, the embedding dimension.
        num_heads: int, the number of attention heads.
        max_seq_len: int, longest sequence the positional encoding table covers.
        dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim: forwarded
            unchanged to ``nn.MultiheadAttention``.

    NOTE(review): the positional encoding is applied to the value as well as to
    the query and key - confirm this is intended.
    """
    def __init__(self, d_embed, num_heads, max_seq_len=100, dropout=0., bias=True, add_bias_kv=False,
                 add_zero_attn=False, kdim=None, vdim=None):
        super(MultiHeadPositionalAttention, self).__init__()

        self.positional_encoding = PositionalEncoding(d_embed, max_seq_len)
        self.multi_head_attn = nn.MultiheadAttention(
            d_embed,
            num_heads,
            dropout=dropout,
            bias=bias,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            kdim=kdim,
            vdim=vdim,
        )

    def forward(self, query, key, value, key_padding_mask=None, need_weights=True,
                attn_mask=None):
        """Run attention over position-encoded query/key/value.

        Shape:
            query: [L, N, E]
            key: [S, N, E]
            value: [S, N, E]
            output: [L, N, E]

        Returns:
            Only the attention output tensor; the attention weights returned by
            the wrapped ``nn.MultiheadAttention`` are discarded.
        """
        # Each input gets its own pass through the positional encoder.
        encoded_q, encoded_k, encoded_v = (
            self.positional_encoding(tensor) for tensor in (query, key, value)
        )

        attended, _discarded_weights = self.multi_head_attn(
            encoded_q,
            encoded_k,
            encoded_v,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
        )
        return attended

In [23]:
class DecoderLayer(nn.Module):
    """The sublayer of DecoderStack.

    One NAT decoder layer made of four residual sublayers, each followed by
    dropout and LayerNorm: self-attention, positional attention,
    encoder-decoder (inter) attention and a position-wise feed-forward network.

    Args:
        d_embed: int, the embedding (model) dimension.
        nhead: int, the number of heads in every attention sublayer.
        max_seq_len: int, longest target sequence the positional attention supports.
        dim_feedforward: int, hidden size of the feed-forward network.
        dropout: float, dropout probability used throughout the layer.
    """
    def __init__(self, d_embed, nhead, max_seq_len=100, dim_feedforward=2048, dropout=0.1):
        super(DecoderLayer, self).__init__()

        self.multi_head_self_attn = nn.MultiheadAttention(d_embed, nhead, dropout)
        self.multi_head_pos_attn = MultiHeadPositionalAttention(d_embed, nhead, max_seq_len, dropout)
        self.multi_head_inter_attn = nn.MultiheadAttention(d_embed, nhead, dropout)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_embed, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_embed)

        # One LayerNorm / Dropout pair per residual sublayer.
        self.norm1 = nn.LayerNorm(d_embed)
        self.norm2 = nn.LayerNorm(d_embed)
        self.norm3 = nn.LayerNorm(d_embed)
        self.norm4 = nn.LayerNorm(d_embed)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.dropout4 = nn.Dropout(dropout)

        self.d_embed = d_embed

    @property
    def self_attn(self):
        # NOTE(review): newer nn.TransformerDecoder releases appear to read
        # `layer.self_attn.batch_first`; confirm against the installed version.
        # Exposed as a read-only alias so no second submodule is registered and
        # the state-dict keys stay unchanged.
        return self.multi_head_self_attn

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None,
                tgt_is_causal=False, memory_is_causal=False):
        """DecoderLayer is made up of self-attn, positional-attn, inter-attn and feedforward network.

        Args:
            tgt: decoder input sequence.
            memory: encoder output attended to by the inter-attention.
            tgt_mask: optional [T, T] additive mask for the self-attention. When
                omitted (and T > 1) a mask that stops every position from
                attending to itself is used.
            memory_mask: optional mask for the inter-attention.
            tgt_key_padding_mask: optional [N, T] padding mask for tgt keys.
            memory_key_padding_mask: optional [N, S] padding mask for memory keys.
            tgt_is_causal, memory_is_causal: accepted and ignored; present only
                so wrappers that pass these keywords to every layer can call it.

        Shape:
            tgt: [T, N, E].
            memory: [S, N, E]
            output: [T, N, E]
        """
        tgt_len = tgt.size(0)

        # Default self-attention mask: -inf on the diagonal, 0 elsewhere. It is
        # skipped for a single position, where masking the only key would make
        # the softmax produce NaN.
        if tgt_mask is None and tgt_len > 1:
            tgt_mask = torch.zeros(tgt_len, tgt_len, dtype=tgt.dtype, device=tgt.device)
            tgt_mask.fill_diagonal_(float('-inf'))

        # multi-head self-attention
        tgt2 = self.multi_head_self_attn(tgt, tgt, tgt,
                                         key_padding_mask=tgt_key_padding_mask,
                                         need_weights=False,
                                         attn_mask=tgt_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # multi-head positional attention; this sublayer returns the output
        # tensor itself (not a tuple), so it must not be indexed.
        tgt2 = self.multi_head_pos_attn(tgt, tgt, tgt,
                                        key_padding_mask=tgt_key_padding_mask,
                                        need_weights=False)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # multi-head inter-attention over the encoder output
        tgt2 = self.multi_head_inter_attn(tgt, memory, memory,
                                          key_padding_mask=memory_key_padding_mask,
                                          need_weights=False,
                                          attn_mask=memory_mask)[0]
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        # position-wise feed forward layer
        tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm4(tgt)

        return tgt

In [24]:
class DecoderStack(nn.Module):
    """The decoder stack of NAT.

    Wraps ``nn.TransformerDecoder`` around ``num_decoder_layers`` copies of
    ``DecoderLayer`` and a final LayerNorm.

    Args:
        d_embed: int, the embedding (model) dimension.
        nhead: int, the number of attention heads per layer.
        max_seq_len: int, longest target sequence handed to each DecoderLayer.
        num_decoder_layers: int, how many decoder layers to stack.
        dim_feedforward: int, hidden size of each layer's feed-forward network.
        dropout: float, dropout probability inside each layer.
        activation: str, accepted but currently not used by this class.
    """
    def __init__(self, d_embed=512, nhead=8, max_seq_len=1000, num_decoder_layers=6,
                 dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(DecoderStack, self).__init__()

        self.decoder = nn.TransformerDecoder(
            DecoderLayer(d_embed, nhead, max_seq_len, dim_feedforward, dropout),
            num_layers=num_decoder_layers,
            norm=nn.LayerNorm(d_embed),
        )

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Decode ``tgt`` while attending to ``memory``.

        All mask arguments are handed to the wrapped decoder unchanged.

        Shape:
            tgt: [T, N, E]
            memory: [S, N, E]
            output: [T, N, E]
        """
        masks = {
            "tgt_mask": tgt_mask,
            "memory_mask": memory_mask,
            "tgt_key_padding_mask": tgt_key_padding_mask,
            "memory_key_padding_mask": memory_key_padding_mask,
        }
        return self.decoder(tgt, memory, **masks)

In [25]:
class FertilityPredictor(nn.Module):
    """The fertility predictor of NAT.

    A single linear layer + ReLU scores ``L`` fertility classes per source
    position; the predicted fertility is the index of the highest score.

    Args:
        d_embed: the dimension of EncoderStack's output.
        L: the number of the classes used to represent fertility.
    """
    def __init__(self, d_embed, L):
        super(FertilityPredictor, self).__init__()

        self.fc_layer = nn.Linear(d_embed, L)
        self.relu = nn.ReLU()

    def forward(self, encoder_output):
        """Using EncoderStack's output to predict fertility list.

        Args:
            encoder_output: encoder states, one vector per source position.

        Returns:
            LongTensor of class indices in ``[0, L)``. The result comes from an
            argmax, so no gradient flows back through it.

        Shape:
            encoder_output: [S, N, E]
            fertility_list: [S, N]
        """
        # NOTE(review): the ReLU maps every non-positive logit to 0, so when no
        # logit is positive all classes tie - confirm this is intended.
        scores = self.relu(self.fc_layer(encoder_output))  # scores: [S, N, L]

        # Softmax is strictly increasing, so taking the argmax of the scores
        # directly selects the same class without the extra normalisation.
        fertility_list = torch.argmax(scores, dim=-1)  # fertility_list: [S, N]

        return fertility_list

In [26]:
class TranslationPredictor(nn.Module):
    """The translation predictor of NAT.

    Projects every decoder state onto the target vocabulary and returns a
    probability distribution per position.

    Args:
        d_embed: the dimension of DecoderStack's output.
        vocab: the size of the target vocabulary (number of output classes).
    """
    def __init__(self, d_embed, vocab):
        super(TranslationPredictor, self).__init__()

        self.fc_layer = nn.Linear(d_embed, vocab)
        self.relu = nn.ReLU()

    def forward(self, decoder_output):
        """Using DecoderStack's output to predict translation.

        Returns probabilities (softmax output), not logits.

        Shape:
            decoder_output: [T, N, E]
            translation_output: [T, N, vocab]
        """
        projected = self.fc_layer(decoder_output)
        # NOTE(review): a ReLU is applied to the logits before the softmax -
        # confirm this is intended.
        rectified = self.relu(projected)
        return torch.softmax(rectified, dim=-1)

In [27]:
class NAT(nn.Module):
    """Non-Autoregressive Transformer.

    Pipeline: embed the source tokens, encode them, predict one fertility per
    source position, copy each source embedding ``fertility`` times to form the
    decoder input, decode all target positions in parallel and project them
    onto the target vocabulary.

    Args:
        vocab_src: int, the size of source vocabulary.
        vocab_tgt: int, the size of target vocabulary.
        S: int, the (maximum) length of source input sentence.
        d_embed: int, the dimension of embedded input.
        L: int, the number of the classes used to represent fertility.
        nhead: int, the number of attention heads.
        num_encoder_layers: int, the number of encoder layers.
        num_decoder_layers: int, the number of decoder layers.
        dim_feedforward: int, hidden size of the feed-forward networks.
        dropout: float, dropout probability.
        activation: str, forwarded to EncoderStack and DecoderStack.

    Shape:
        input: LongTensor, [N, S]
        output: FloatTensor, [T, N, vocab_tgt], T=S*L
    """
    def __init__(self, vocab_src, vocab_tgt, S, d_embed=512, L=50, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(NAT, self).__init__()

        self.d_embed = d_embed
        self.S = S
        self.L = L
        self.T = S*L
        # The decoder input always has S*L slots (see copy_fertility).
        max_seq_len = S*L

        self.embedding_input = nn.Embedding(vocab_src, d_embed)
        self.position_encoder_en = PositionalEncoding(d_embed, S)
        self.position_encoder_de = PositionalEncoding(d_embed, S*L)
        self.encoder = EncoderStack(d_embed, nhead, num_encoder_layers, dim_feedforward, dropout, activation)
        self.fertility_predictor = FertilityPredictor(d_embed, L)
        self.decoder = DecoderStack(d_embed, nhead, max_seq_len, num_decoder_layers, dim_feedforward, dropout, activation)
        self.translation_predictor = TranslationPredictor(d_embed, vocab_tgt)

    def forward(self, input):
        """Translate a batch of source token ids.

        Args:
            input: LongTensor of source token ids, [N, S].

        Returns:
            The translation predictor's output, [T, N, vocab_tgt].
        """
        input_e = self.embedding_input(input.transpose(0,1)) # input_e: [S, N, E]
        input_pe = self.position_encoder_en(input_e) # input_pe: [S, N, E]
        # The encoder must see the position-encoded embeddings; feeding it the
        # raw embeddings would leave it without any word-order information.
        encoder_output = self.encoder(input_pe) # encoder_output: [S, N, E]
        fertility_list = self.fertility_predictor(encoder_output) # fertility_list: [S, N]
        # The raw (not position-encoded) embeddings are copied; the decoder
        # positions get their own encoding below.
        copied_embedding = self.copy_fertility(input_e, fertility_list, self.L) # copied_embedding: [T, N, E]
        copied_embedding_pe = self.position_encoder_de(copied_embedding) # copied_embedding_pe: [T, N, E]
        memory = encoder_output
        decoder_output = self.decoder(copied_embedding_pe, memory) # decoder_output: [T, N, E]
        output = self.translation_predictor(decoder_output) # output: [T, N, vocab_tgt]

        return output


    def copy_fertility(self, input_e, fertility_list, L):
        """Copy the input embedding as the number at corresponding index.

        For every batch element, source position j is repeated
        ``fertility_list[j]`` times (0 drops it); the copies are packed at the
        front of a buffer of S*L slots and the remaining slots stay zero.

        Args:
            input_e: [S, N, E].
            fertility_list: [S, N], integer fertilities, expected in [0, L).
            L: int, the number of the classes used to represent fertility.

        Output:
            copied_embedding: [T, N, E] with T = S*L, on the same device and
            with the same dtype as ``input_e``.
        """
        seq_len, batch_size, d_embed = input_e.shape
        # Allocate next to the input so no implicit device / dtype conversion
        # is needed later.
        copied_embedding = input_e.new_zeros(batch_size, seq_len * L, d_embed) # [N, S*L, E]
        tokens = input_e.transpose(0, 1) # tokens: [N, S, E]
        counts = fertility_list.transpose(0, 1).to(device=input_e.device, dtype=torch.long) # counts: [N, S]

        for n in range(batch_size):
            # Repeat every source embedding by its own fertility in one call.
            expanded = torch.repeat_interleave(tokens[n], counts[n], dim=0) # [sum(counts[n]), E]
            copied_embedding[n, :expanded.size(0)] = expanded

        return copied_embedding.transpose(0, 1)

In [35]:
import os
# Expose only the GPU with index 3 to this process via CUDA_VISIBLE_DEVICES.
# NOTE(review): the index is hard-coded, and this presumably only takes effect
# if it runs before the process first initialises CUDA - confirm, and consider
# moving it to the top of the notebook.
os.environ["CUDA_VISIBLE_DEVICES"]='3'

In [36]:
# Fall back to the CPU when no GPU is visible so this smoke test still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Random token ids for a batch of N=16 sentences of length S=20.
# The name `input` shadows the Python builtin; it is kept because the next
# cell refers to it.
input = torch.randint(0, 1000, (16, 20)).long().to(device) # input: [N, S]
model = NAT(vocab_src=1000, vocab_tgt=500, S=20).to(device)

In [37]:
# Smoke-test a single forward pass of the untrained model.
# NOTE(review): the recorded run of this cell ended in the CUDA out-of-memory
# error shown below.
output = model(input)

RuntimeError: CUDA out of memory. Tried to allocate 490.00 MiB (GPU 0; 11.91 GiB total capacity; 1.21 GiB already allocated; 31.25 MiB free; 24.85 MiB cached)

In [13]:
# Shape check: [T, N, vocab_tgt] = [S*L, N, vocab_tgt] = [20*50, 16, 500].
# NOTE(review): this cell's execution count (13) is lower than the cells above,
# so the recorded output presumably comes from an earlier run.
output.size()

torch.Size([1000, 16, 500])