In [2]:
import math

import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding from "Attention Is All You Need".

    The precomputed table is

        PE(pos, 2i)   = sin(pos / 10000^(2i / embeddingDim))
        PE(pos, 2i+1) = cos(pos / 10000^(2i / embeddingDim))

    and is stored as the buffer ``positionalEncoding`` with shape
    (1, maxWordLen, embeddingDim).

    Args:
        embeddingDim: size of the last dimension of the embeddings (may be odd).
        dropout: dropout probability applied in ``forward``.
        maxWordLen: longest sequence length the precomputed table supports.
    """

    def __init__(self, embeddingDim, dropout, maxWordLen=5000):
        super(PositionalEncoding, self).__init__()
        positionalEncoding = torch.zeros(maxWordLen, embeddingDim)
        # Column vector of positions, shape (maxWordLen, 1), so it broadcasts
        # against the row of frequencies below.
        position = torch.arange(0, maxWordLen, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / embeddingDim), computed in log space.
        exponentTerm = torch.arange(0, embeddingDim, 2, dtype=torch.float) \
            * (-math.log(10000.0) / embeddingDim)
        divisionTerm = torch.exp(exponentTerm)
        angles = position * divisionTerm
        # Even columns get the sine.
        positionalEncoding[:, 0::2] = torch.sin(angles)
        # Odd columns get the cosine. For an odd embeddingDim there is one
        # fewer odd column than there are frequencies, so trim the last one.
        positionalEncoding[:, 1::2] = torch.cos(angles[:, :embeddingDim // 2])
        # Leading axis of size 1 so the table broadcasts over the batch.
        self.register_buffer('positionalEncoding', positionalEncoding.unsqueeze(0))
        # dropout is the probability of zeroing an element
        self.dropout = nn.Dropout(p=dropout)
        self.embeddingDim = embeddingDim

    def forward(self, embedding):
        """Scale ``embedding`` by sqrt(embeddingDim), add the positional table, apply dropout.

        Dimension 1 of ``embedding`` is treated as the sequence axis and must be
        <= maxWordLen; presumably (batch, seqLen, embeddingDim) — TODO confirm
        with callers.
        """
        # The Transformer paper (Sec. 3.4) multiplies embedding weights by
        # sqrt(d_model). A common explanation is that it keeps the embeddings
        # from being dwarfed by the encoding, whose values lie in [-1, 1].
        embedding = embedding * math.sqrt(self.embeddingDim)
        embedding = embedding + self.getPositonalEncoding(embedding)
        return self.dropout(embedding)

    def getPositonalEncoding(self, embedding):
        """Return the table slice (1, embedding.size(1), embeddingDim) matching ``embedding``'s length."""
        return self.positionalEncoding[:, :embedding.size(1)]

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    getPositionalEncoding = getPositonalEncoding

In [3]:
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention with a residual connection, then a feed-forward block.

    Args:
        numHeads: number of attention heads.
        modelDim: width of the representations flowing through the layer.
        feedForwardDim: hidden width of the position-wise feed-forward block.
        dropout: dropout probability.
    """

    def __init__(self, numHeads, modelDim, feedForwardDim, dropout):
        super(TransformerEncoderLayer, self).__init__()
        self.selfAttention = MultiHeadedAttention(numHeads, modelDim, dropout=dropout)
        self.feedForward = PositionwiseFeedForward(modelDim, feedForwardDim, dropout)
        # LayerNorm over the last axis:
        #   y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta, with learnable gamma/beta.
        self.layerNorm = nn.LayerNorm(modelDim, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, iteration, query, inputs, mask):
        """Run the layer on ``inputs``.

        Args:
            iteration: index of this layer in the stack; layer 0 skips the LayerNorm.
            query: accepted for interface compatibility; not used.
            inputs: layer input, fed to self-attention as key, value and query.
            mask: padding mask, passed on to the attention after adding an axis
                at dim 1 — presumably (batch, seqLen) with True at padded positions;
                TODO confirm.

        Returns:
            The feed-forward block applied to ``dropout(attention) + inputs``.
        """
        if iteration == 0:
            normalized = inputs
        else:
            normalized = self.layerNorm(inputs)
        # Self-attention: the same tensor serves as key, value and query. The
        # extra axis lets one padding mask broadcast across all query positions.
        attended = self.selfAttention(normalized, normalized, normalized,
                                      mask=mask.unsqueeze(1))
        # Residual connection: add the un-normalized input back onto the
        # attention output.
        residual = self.dropout(attended) + inputs
        return self.feedForward(residual)

In [4]:
# The inter-sentence transformer learns relationships between sentences to produce a document-level summary.
# Its input is the output of BERT, which can be viewed as a contextual encoding.
# modelDim is BERT's hidden size, i.e. the model dimension of BERT's transformer
# ("model dimension", d_model, is the term used in "Attention Is All You Need").
class InterSentencesTransformerEncoderLayer(nn.Module):
    """Transformer stack over sentence vectors that outputs one score in (0, 1) per sentence.

    Args:
        numHeads: number of attention heads per layer.
        modelDim: size of each sentence vector (BERT's hidden size).
        feedforwardDim: hidden width of each layer's feed-forward block.
        dropout: dropout probability.
        numInterSentencesLayers: number of stacked TransformerEncoderLayer modules.
    """

    def __init__(self, numHeads, modelDim, feedforwardDim, dropout, numInterSentencesLayers=0):
        super(InterSentencesTransformerEncoderLayer, self).__init__()
        self.modelDim = modelDim
        self.numInterSentencesLayers = numInterSentencesLayers
        self.positionalEmbedding = PositionalEncoding(modelDim, dropout)
        self.interSentencesTransformers = nn.ModuleList(
            [TransformerEncoderLayer(numHeads, modelDim, feedforwardDim, dropout)
             for _ in range(numInterSentencesLayers)])
        # Not used in forward; kept so existing state/attributes are unchanged.
        self.dropout = nn.Dropout(dropout)
        self.layerNorm = nn.LayerNorm(modelDim, eps=1e-6)
        self.linearLayer = nn.Linear(modelDim, 1, bias=True)
        self.sigmoid = nn.Sigmoid()

    # `top_vecs` in the original PreSumm code; here `contextualEncoding` is the
    # BERT output for each sentence.
    def forward(self, contextualEncoding, mask):
        """Score each sentence.

        Args:
            contextualEncoding: sentence vectors; presumably
                (batch, nSentences, modelDim) — TODO confirm.
            mask: (batch, nSentences), nonzero/True for real sentences and
                0/False for padding.

        Returns:
            (batch, nSentences) scores in (0, 1), with padded sentences set to 0.
        """
        nSentences = contextualEncoding.size(1)
        positionalEmbedding = self.positionalEmbedding.positionalEncoding[:, :nSentences]
        # [:, :, None] adds a trailing axis so the mask broadcasts over the
        # feature dimension and zeroes out padded sentences.
        x = contextualEncoding * mask[:, :, None].float() + positionalEmbedding
        # The attention layers expect a bool mask that is True at positions to
        # IGNORE. `1 - mask` raises for bool masks, and masked_fill rejects
        # integer masks.
        paddingMask = ~mask.bool()
        for iteration, transformer in enumerate(self.interSentencesTransformers):
            x = transformer(iteration, x, x, paddingMask)
        x = self.layerNorm(x)
        sentencesScores = self.sigmoid(self.linearLayer(x))
        sentencesScores = sentencesScores.squeeze(-1) * mask.float()
        return sentencesScores

In [5]:
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional incremental-decoding cache.

    Args:
        numHeads: number of attention heads; must divide modelDim.
        modelDim: input/output feature size.
        dropout: dropout probability applied to the attention weights.
        isFinalLinear: if True, merge the heads and apply an output projection,
            returning (batch, queryLen, modelDim). Otherwise return the per-head
            context (batch, numHeads, queryLen, modelDim // numHeads).
    """

    def __init__(self, numHeads, modelDim, dropout=0.1, isFinalLinear=True):
        super(MultiHeadedAttention, self).__init__()
        assert modelDim % numHeads == 0, "modelDim must be divisible by numHeads"
        self.dimPerHead = modelDim // numHeads
        self.modelDim = modelDim
        self.numHeads = numHeads
        self.linearKeys = nn.Linear(modelDim, numHeads * self.dimPerHead)
        self.linearValues = nn.Linear(modelDim, numHeads * self.dimPerHead)
        self.linearQuery = nn.Linear(modelDim, numHeads * self.dimPerHead)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.isFinalLinear = isFinalLinear
        if self.isFinalLinear:
            self.finalLayer = nn.Linear(modelDim, modelDim)

    def forward(self, key, value, query, mask=None,
                layerCache=None, types=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            key, value, query: (batch, length, modelDim) tensors.
            mask: optional mask, broadcastable to (batch, queryLen, keyLen)
                after an axis is added at dim 1. Nonzero/True marks key
                positions to ignore.
            layerCache: optional dict, mutated in place.
                With ``types == "self"`` it holds "selfKeys"/"selfValues";
                the new projected keys/values are appended along the length
                axis and attention runs over the whole cache.
                With ``types == "context"`` it holds "memoryKeys"/"memoryValues";
                they are projected once on the first call and reused afterwards.
            types: "self" or "context"; required when layerCache is given.

        Raises:
            ValueError: if layerCache is given and types is neither "self" nor "context".
        """
        batchSize = key.size(0)
        dimPerHead = self.dimPerHead
        numHeads = self.numHeads

        def shape(x):
            """(batch, len, modelDim) -> (batch, numHeads, len, dimPerHead)."""
            return x.view(batchSize, -1, numHeads, dimPerHead).transpose(1, 2)

        def unshape(x):
            """(batch, numHeads, len, dimPerHead) -> (batch, len, modelDim)."""
            # transpose() returns a non-contiguous view and view() needs
            # contiguous memory, hence contiguous().
            return x.transpose(1, 2).contiguous().view(batchSize, -1, numHeads * dimPerHead)

        if layerCache is None:
            key = shape(self.linearKeys(key))
            value = shape(self.linearValues(value))
        elif types == "self":
            key = shape(self.linearKeys(key))
            value = shape(self.linearValues(value))
            device = key.device
            # Prepend the cached keys/values from earlier steps (length axis
            # is dim 2), then store the extended tensors back in the cache.
            if layerCache["selfKeys"] is not None:
                key = torch.cat((layerCache["selfKeys"].to(device), key), dim=2)
            if layerCache["selfValues"] is not None:
                value = torch.cat((layerCache["selfValues"].to(device), value), dim=2)
            layerCache["selfKeys"] = key
            layerCache["selfValues"] = value
        elif types == "context":
            # Project the memory once and reuse the cached projection afterwards.
            if layerCache["memoryKeys"] is None:
                key = shape(self.linearKeys(key))
                value = shape(self.linearValues(value))
                layerCache["memoryKeys"] = key
                layerCache["memoryValues"] = value
            else:
                key, value = layerCache["memoryKeys"], layerCache["memoryValues"]
        else:
            raise ValueError(
                "types must be 'self' or 'context' when layerCache is given, got %r" % (types,))
        query = shape(self.linearQuery(query))

        # Scaled dot-product attention. Dividing by sqrt(dimPerHead) keeps the
        # dot products from growing with the head size. Per the Transformer
        # paper, large values push softmax into regions with tiny gradients.
        query = query / math.sqrt(dimPerHead)
        scores = torch.matmul(query, key.transpose(2, 3))

        if mask is not None:
            # Add the head axis, broadcast to the scores' shape, and convert to
            # bool as masked_fill requires.
            mask = mask.unsqueeze(1).expand_as(scores).bool()
            scores = scores.masked_fill(mask, -1e18)  # effectively -inf

        # Attention weights, with dropout applied to them.
        attentionDropout = self.dropout(self.softmax(scores))
        context = torch.matmul(attentionDropout, value)
        if not self.isFinalLinear:
            return context
        return self.finalLayer(unshape(context))

In [6]:
class PositionwiseFeedForward(nn.Module):
    '''
    A two-layer feed-forward network applied at every position, with a
    LayerNorm on the input and a residual connection:

        output = x + dropout2(linear2(dropout1(gelu(linear1(layerNorm(x))))))

    Args:
        d_model: input/output feature size.
        d_ff: hidden size of the inner layer.
        dropout: dropout probability used by both dropout layers.
    '''

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.layerNorm = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    @staticmethod
    def gelu(x):
        """tanh approximation of the GELU activation.

        Defined as a static method rather than a lambda attribute so the module
        stays picklable; ``self.gelu(x)`` works exactly as before.
        """
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

    def forward(self, x):
        """Apply the feed-forward block and add ``x`` back as a residual; output shape equals ``x``'s."""
        hidden = self.dropout1(self.gelu(self.linear1(self.layerNorm(x))))
        output = self.dropout2(self.linear2(hidden))
        return output + x