In [1]:
import torch
from torch import nn
from models.utils import param, reverse_packed_sequence, Namespace, packCharsWithMask, pad_sequence, pad_packed_sequence, pack_padded_sequence
from torch.nn.utils.rnn import PackedSequence
from models.networks import DeepTransitionRNN, CNNEmbedding, SequenceLabelingEncoder

In [2]:
wordEmbedding = 20 # 300
charEmbedding = 20 # 128
contextOutputUnits = outputUnits = 70
numWords = 10
numChars = 10
numTags = 17
contextTransitionNumber = transitionNumber = 3
maxLen = 5
encoderUnits = decoderUnits = 50

In [3]:
#max chars in word is 6
s1 = [
    [1, 2, 3],
    [4, 3],
    [5, 2, 2, 3, 7, 6],
]


s2 = [
    [4, 3, 1, 6, 2],
    [1, 2, 3],
]


# the packed sequence removes word axis padding, custom mask handles char axis padding
x, mask = packCharsWithMask([s1, s2])

w = [
    [2, 3, 1],
    [6, 2]
]

l = [len(s) for s in w]
w = [torch.LongTensor(s) for s in w]
w = pad_sequence(w)
w = pack_padded_sequence(w, l, enforce_sorted=False)
print(w.data.shape, x.data.shape, mask.shape)
print("lengths are: ", l)
print(w)

torch.Size([5]) torch.Size([5, 6]) torch.Size([5, 6, 1])
lengths are:  [3, 2]
PackedSequence(data=tensor([2, 6, 3, 2, 1]), batch_sizes=tensor([2, 2, 1]), sorted_indices=tensor([0, 1]), unsorted_indices=tensor([0, 1]))


In [4]:
class GlobalContextualEncoder(nn.Module):
    def __init__(self, numChars, charEmbedding, numWords, wordEmbedding, outputUnits, transitionNumber):
        super().__init__()
        self.cnn = CNNEmbedding(numChars, charEmbedding)
        self.glove = nn.Embedding(numWords, wordEmbedding)
        
        encoderInputUnits = wordEmbedding + charEmbedding
        self.forwardEncoder  = DeepTransitionRNN(encoderInputUnits, outputUnits, transitionNumber)
        self.backwardEncoder = DeepTransitionRNN(encoderInputUnits, outputUnits, transitionNumber)
        
    def forward(self, words, chars, charMask):
        _, *args = words
        
        w = self.glove(words.data)
        c = self.cnn(chars, charMask)
        
        # word and char concat, pass through encoder and we get directional global context
        wc = torch.cat([w, c.data], dim=-1)
        forwardInput  = PackedSequence( wc, *args )
        forwardG  = self.forwardEncoder(forwardInput)
        
        backwardInput = reverse_packed_sequence(forwardInput)      
        backwardG = self.backwardEncoder(backwardInput)
        backwardG = reverse_packed_sequence(backwardG)
        
        nonDirectionalG = torch.cat([forwardG.data, backwardG.data], dim=-1)
        
        # mean pooling is done by padding with zeros, taking timewise sum and dividing by lengths
        nonDirectionalG = PackedSequence(nonDirectionalG, *args)
        nonDirectionalG, lens = pad_packed_sequence(nonDirectionalG)
        lens = torch.unsqueeze(torch.unsqueeze(lens, -1), 0)
        nonDirectionalG_sum = nonDirectionalG.sum(dim=0, keepdim=True)
        g = nonDirectionalG_sum / lens
        
        # need to broadcast g and concat with wc
        new_shape = [nonDirectionalG.data.shape[0] // g.shape[0]] + [-1] * (len(g.shape) - 1)
        g = pack_padded_sequence(g.expand(*new_shape), lens[0,:, 0])
        
        wcg = torch.cat([g.data, wc], dim=-1)
        wcg = PackedSequence(wcg, *args)
        return wcg

    
model = GlobalContextualEncoder(numChars, charEmbedding, numWords, wordEmbedding, outputUnits, transitionNumber)
g = model(w, x, mask)
print(g.data.shape)

torch.Size([5, 180])


In [5]:
class GlobalContextualDeepTransition(nn.Module):
    def __init__(self, numChars, charEmbedding, numWords,
                     wordEmbedding, contextOutputUnits, contextTransitionNumber,
                        encoderUnits, decoderUnits, transitionNumber, numTags):
        super().__init__()
        self.contextEncoder = GlobalContextualEncoder(numChars, charEmbedding, numWords,
                                                          wordEmbedding, contextOutputUnits, contextTransitionNumber)
        self.labellerInput = wordEmbedding + charEmbedding + 2 * contextOutputUnits # units in g
        self.sequenceLabeller = SequenceLabelingEncoder(self.labellerInput, encoderUnits, decoderUnits, transitionNumber, numTags)
        
    def forward(self, words, chars, charMask):
        wcg = self.contextEncoder(words, chars, charMask)
        
        print(wcg.data.shape[-1], self.labellerInput)
        out = self.sequenceLabeller(wcg, reverse_packed_sequence(wcg))
        return out

model = GlobalContextualDeepTransition(numChars, charEmbedding, numWords,
                     wordEmbedding, contextOutputUnits, contextTransitionNumber,
                        encoderUnits, decoderUnits, transitionNumber, numTags)
y = model(w, x, mask)
data, *args = y
y, lens = pad_packed_sequence(y)
print(y.shape, lens)

180 180
torch.Size([3, 2, 17]) tensor([3, 2])
