In [1]:
import re
import json
from tqdm.notebook import tqdm
import torch
import pytorch_lightning as pl
from data.dataset import NERDataset
from models.utils import Namespace, getSignal
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence, PackedSequence
from models.networks import GlobalContextualDeepTransition
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader 

In [2]:
# CoNLL-2003 English files: the training split, embedding files, and the test-b split.
dataDir = 'data/conll03'
sourceName = f'{dataDir}/eng.train.src'
targetName = f'{dataDir}/eng.train.trg'
gloveFile = f'{dataDir}/trimmed.300d.Cased.txt'
symbFile = f'{dataDir}/sym.glove'
testSrc = f'{dataDir}/eng.testb.src'
testTrg = f'{dataDir}/eng.testb.trg'

data = NERDataset(sourceName, targetName, gloveFile, symbFile)
data.readTestFile(testSrc, testTrg)
# shuffle=False is presumably what keeps data[i] aligned with the loader's first
# batch, which the spot-check cell below relies on. NOTE(review): the meaning of
# 1024 (sentences vs. tokens per batch) is defined by NERDataset.getLoader — confirm.
# An unbatched alternative is DataLoader(data, collate_fn=data.pack_collate).
loader = data.getLoader(1024, shuffle=False)

In [3]:
# Restore the trained tagger. Hyper-parameters are read from config.json and
# forwarded to load_from_checkpoint; eval() puts the module in inference mode
# (it returns the module itself, so the call can be chained).
prevCheckpointPath = 'lightning_logs/ckpt-epoch=109-train_loss=0.63.ckpt'

with open('config.json', 'r') as configFile:
    kwargs = json.load(configFile)

model = GlobalContextualDeepTransition.load_from_checkpoint(prevCheckpointPath, **kwargs).eval()

In [4]:
def logitsToLogProbs(logits, dim=1):
    """
    Normalise unscaled scores into log-probabilities.

    Args:
        logits: tensor of raw scores.
        dim: axis along which the scores form a distribution (default 1,
            i.e. one distribution per row of a [rows, classes] tensor).

    Returns:
        Tensor of the same shape as `logits`; exp() of it sums to 1 along `dim`.
    """
    # log_softmax computes logits - logsumexp(logits, dim) in a numerically stable way.
    return torch.log_softmax(logits, dim=dim)

In [5]:
# Take the first batch of the loader and run the encoder on it once.
# `words` is a PackedSequence (its batch_sizes drive the decoding loop below).
batch = next(iter(loader))
words, chars, charMask, targets = batch
with torch.no_grad():  # inference only: no autograd graph for the encoder pass
    encoded, initHiddenState, initPrevTarget = model.encode(words, chars, charMask)

In [6]:
def pathSum(values, logProbs, numTags=None):
    """
    Add each hypothesis' running score to the log-probs of its candidate extensions.

    Args:
        values: [batch, beam] accumulated log-prob of every current hypothesis.
        logProbs: [batch, beam * numTags] next-tag log-probs laid out hypothesis-major,
            i.e. column k * numTags + j scores tag j appended to hypothesis k.
        numTags: number of tags per hypothesis. When omitted it is inferred as
            logProbs.shape[-1] // values.shape[-1], so the function no longer
            depends on a global `tags` variable.

    Returns:
        [batch, beam * numTags] tensor of path scores, in the same layout as logProbs.
    """
    if numTags is None:
        numTags = logProbs.shape[-1] // values.shape[-1]
    # Repeat values[:, k] numTags times so it lines up with hypothesis k's block of columns:
    # [v0, v1] -> [v0, v0, ..., v1, v1, ...].
    return logProbs + values.repeat_interleave(numTags, dim=-1)

In [7]:
# Beam-search decoding over the packed batch.
# The PackedSequence is sorted by decreasing length, so at step t only the first
# words.batch_sizes[t] examples are still active, and examples finish from the end.
batchSize = words.batch_sizes[0].item()  # number of examples in the batch
beamSize = 4                              # hypotheses kept per example
units = model.sequenceLabeller.decoderUnits
tags = model.numTags

print(f"batchSize={batchSize}, beamSize={beamSize}, units={units}, tags={tags}")

# At t == 0 topk is taken over the `tags` scores of a single root hypothesis,
# so the beam cannot be wider than the tag set.
assert beamSize <= tags, "beamSize must not exceed the number of tags"

# lengths[i]: number of time steps in which (sorted) example i appears.
lengths = torch.zeros(batchSize, dtype=torch.int)
for x in words.batch_sizes:
    lengths[:x] += 1

# values[i, j]: accumulated log-prob of the j-th hypothesis of example i.
# paths[i][j]:  the tag-id sequence of that hypothesis (all start empty).
values = torch.zeros(batchSize, beamSize)
paths = [[list() for _ in range(beamSize)] for _ in range(batchSize)]

# Split the encoder output into one page per time step. The first page is used
# as-is (single root hypothesis); later pages are tiled beamSize times in
# beam-major order: [e1, e2, e3, e1, e2, e3, ...].
start = 0
encodedPages = []
for pageLen in words.batch_sizes:
    page = encoded[start:start + pageLen]
    if start != 0:
        page = page.repeat(beamSize, 1)
    encodedPages.append(page)
    start += pageLen

live = list(range(batchSize))  # sorted positions of examples still being decoded
dead = []                      # sorted positions of finished examples, in finishing order

# The initial decoder inputs are not tiled: they cover the first page only.
hiddenState = initHiddenState
prevTarget = initPrevTarget

# Inference only: keep the whole loop (including the targetEmbedding lookup and
# the positional-signal addition) out of autograd.
with torch.no_grad():
    for t, b in enumerate(words.batch_sizes):
        # Decoder inputs have one row per example at t == 0, then beamSize rows per example.
        actualSize = b if t == 0 else b * beamSize
        prevTarget = prevTarget[:actualSize] + getSignal(1, units, t, model.device)
        hiddenState, logits = model.sequenceLabeller.decode_once(
            encodedPages[t],
            prevTarget,
            hiddenState,
        )
        logProbs = logitsToLogProbs(logits)  # [actualSize, tags], beam-major rows

        # Score every (parent hypothesis, next tag) pair.
        if t == 0:
            ps = logProbs  # [b, tags]: one root hypothesis per example
        else:
            # [beam * b, tags] -> [beam, b, tags] -> [b, beam * tags];
            # column k * tags + j is tag j appended to hypothesis k.
            logProbs = logProbs.reshape((beamSize, b, tags))
            logProbs = torch.cat(list(logProbs), dim=-1)
            ps = pathSum(values[:b], logProbs)

        # Keep the best beamSize extensions per example (topk sorts best-first).
        values[:b], indices = ps.topk(dim=-1, k=beamSize)
        parents = indices // tags   # which hypothesis was extended
        children = indices % tags   # which tag was appended to it

        # Extend the paths of every live example and detect the ones that finished.
        numFinished = 0
        for qidx, childBeam, parentBeam in zip(live, children, parents):
            paths[qidx] = [paths[qidx][p] + [c.item()] for c, p in zip(childBeam, parentBeam)]
            # All hypotheses of an example share its length, so the first one suffices.
            if len(paths[qidx][0]) == lengths[qidx]:
                numFinished += 1
                dead.append(qidx)

        # Finished examples are always the last live ones (sorted by decreasing length).
        del live[len(live) - numFinished:]

        # Rebuild the decoder inputs for the survivors: drop finished rows, then
        # gather beam-major [beam * survivors] rows from the selected parents.
        survivors = b - numFinished
        children = children[:survivors]
        parents = parents[:survivors]
        prevTarget = model.sequenceLabeller.targetEmbedding(children.T.reshape(-1))

        # NOTE(review): hiddenState is presumably [rows, hidden]; it is viewed as
        # [beam, b, hidden] so it can be indexed by (parent, example).
        if t == 0:
            hiddenState = torch.unsqueeze(hiddenState, 0)  # single root beam
        else:
            hiddenState = hiddenState.reshape((beamSize, b, -1))
        parentIdx = parents.T.reshape(-1)
        exampleIdx = torch.arange(survivors).repeat(beamSize)
        hiddenState = hiddenState[parentIdx, exampleIdx]

batchSize=47, beamSize=4, units=256, tags=17


In [8]:
# decode results
# topk returns each beam best-first, so element 0 of every queue is the top path.
# words.unsorted_indices maps each original example to its position in the
# length-sorted packed order; tag ids are translated to tag names.
results = [queue[0] for queue in paths]
unsorted = [[data.tags[tagId] for tagId in results[pos]] for pos in words.unsorted_indices]

In [9]:
# Spot-check one example: gold tags vs. the best beam hypothesis.
# NOTE(review): this assumes data[idx] is the idx-th example of the decoded batch,
# i.e. the loader is unshuffled and idx falls inside its first batch.
idx = 7
assert idx < len(unsorted), f"idx={idx} is outside the decoded batch of {len(unsorted)} examples"

actual = [data.tags[j] for j in data[idx][2]]
predicted = unsorted[idx]
# If data[idx] and unsorted[idx] are the same sentence, both carry one tag per token.
assert len(actual) == len(predicted), "gold and predicted tag sequences differ in length"

print(actual)
print(predicted)

['O', 'O', 'S-MISC', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'E-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
['O', 'O', 'S-MISC', 'O', 'O', 'O', 'O', 'O', 'B-MISC', 'E-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
