In [1]:
import re
import json
from tqdm.notebook import tqdm
import torch
import pytorch_lightning as pl
from data.dataset import NERDataset
from models.utils import Namespace, getSignal
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence, PackedSequence
from models.networks import GlobalContextualDeepTransition
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader 

In [2]:
sourceName = 'data/conll03/eng.train.src'
targetName = 'data/conll03/eng.train.trg'
gloveFile = 'data/conll03/trimmed.300d.Cased.txt'
symbFile = 'data/conll03/sym.glove'
testSrc = 'data/conll03/eng.testb.src'
testTrg = 'data/conll03/eng.testb.trg'

data = NERDataset(sourceName, targetName, gloveFile, symbFile)
data.readTestFile(testSrc, testTrg)
loader = data.getLoader(1024, shuffle=False)
# loader = DataLoader(data, collate_fn=data.pack_collate)

In [3]:
prevCheckpointPath = 'lightning_logs/checkpoint-v0.ckpt'

with open('config.json', 'r') as file:
    kwargs = json.load(file)
    
model = GlobalContextualDeepTransition.load_from_checkpoint(prevCheckpointPath, **kwargs)
model = model.eval()

In [4]:
def logitsToLogProbs(logits):
    return logits - torch.logsumexp(logits, dim=1, keepdim=True)

In [5]:
with torch.no_grad():
    words, chars, charMask, targets = batch = next(iter(loader))
    encoded, initHiddenState, initPrevTarget = model.encode(words, chars, charMask)

In [97]:
def pathSum(values, logProbs):
    """
        Adds the prev sum to current logProbs
        to get the effective logprob
    """
    # values is batch, beam
    values = torch.unsqueeze(values, -1) # batch, beam, 1
    values = values.repeat(1, 1, units)  # batch, beam, units
    values = values.permute(1,0,2)       # beam, batch, units
    values = torch.cat(list(values), -1) # batch, units * beam
    
    # logprobs is [batch, units*beam]
    ps = logProbs + values
    
    # ps is [batch, units*beam]
    return ps

In [101]:
batchSize = words.batch_sizes[0].item() # batchSize
beamSize  = 4 # beamsize
units = model.sequenceLabeller.decoderUnits
print(f"batchSize={batchSize}, beamSize={beamSize}, units={units}")

lengths = torch.zeros(batchSize, dtype=torch.int)
live = list(range(batchSize))
dead = []

for x in words.batch_sizes:
    lengths[:x] += 1

"""
    values[i, j] contains the heuristic beam of the ith example. j in range(beamSize)
    paths [i, j] contains the corresponding paths
"""
values = torch.zeros(batchSize, beamSize) # we maintain a queue like tensor, each example has a queue of size beamSize
paths  = [ [list() for _ in range(beamSize)] for _ in range(batchSize) ] # one node in each queue as root of tree

batchSize=47, beamSize=4, units=256


In [96]:
"""Init nodes in a matrix for each beam and """

# encoded pages
start = 0
encodedPages = []
for pageLen in words.batch_sizes:
    if start == 0:
        page = encoded[start:start+pageLen] # first page is not repeated
    else:
        page = encoded[start:start+pageLen].repeat(beamSize, 1)
    encodedPages.append(page) # [e1, e2, e3, e1, e2, e3.. etc, repeated beamSize times]
    start += pageLen



In [119]:
# initial values are not repeated
hiddenState = initHiddenState
prevTarget = initPrevTarget


for t, b in enumerate(words.batch_sizes):
    """Get the previous target and make the forward pass"""
    if t == 0:
        actualSize = b.item()
    else:
        actualSize = b.item() * beamSize
    prevTarget = prevTarget[:actualSize] + getSignal(1, units, t, model.device)
    with torch.no_grad():
        hiddenState, logits = model.sequenceLabeller.decode_once(
            encodedPages[t],
            prevTarget,
            hiddenState
        )
    logProbs = logitsToLogProbs(logits) # [b*beamSize, units] ie [l1, l2, l3, l1, l2, l3 ... numTag d vectors repeated]

    """Add the logProbs to the current paths to get newPathSums"""
    if t == 0:
        ps = logProbs
    else:
        logProbs = logProbs.reshape((beamSize, b.item(), units)) # now becomes [[l1, l2, l3], [l1, l2, l3], [l1, l2, l3]...]
        logProbs = torch.cat(list(logProbs), dim=-1) # [l1l1l1..., l2l2l2..., l3l3l3...]
        ps = pathSum(values, logProbs)

    """Filter the top beam pathsums and extend the paths"""
    values, indices = ps.topk(dim=-1, k=beamSize) # values is [batch, beam]

    # indices represent max over arrays of size units * beam
    # Their parent must be at idx/units in the queue.
    parents = indices // units

    # the child is the actual index
    children = indices % units

    """Extend paths using new values"""
    numFinished = 0
    for qidx, valBeam, childbeam, parentBeam in zip(live, values, children, parents):
        """
            Narrow our sight to each example:
                At qidx, extend the path of parentBeam[i] with childBeam[i]
                You will get beam no. of new paths.
                This is your new path beam.
        """
        newQueue = []
        for v, c, p in zip(valBeam, childbeam, parentBeam):
            oldPath = paths[qidx][p]
            newPath = oldPath + [c.item()]
            newQueue.append(newPath)
        paths[qidx] = newQueue

        """Mark completed if the lenghts of the paths match the word count"""
        if len(newQueue[0]) == lengths[qidx]:
            numFinished += 1
            dead.append(qidx)

    """If an example is done, it has to be at the end of the live array"""
    for _ in range(numFinished):
        live.pop()

    """
        Rearrange the prevTarget and hiddenState using indices
        Note that values = torch.gather(ps, -1, indices) is a way to go from [batch, units*beam] and [batch, beam]

        * ps -> [batch, units*beam] and values -> [batch, beam]
    """
    print(indices) 
    prevTarget = model.sequenceLabeller.targetEmbedding(indices.reshape(-1))
#     hiddenState = torch.gather(hiddenState, -1, indices)
    print(prevTarget.shape)
    print(hiddenState.shape)
    break
# print(*[page.shape for page in encodedPages], sep='\n')

tensor([[12, 13, 15, 14],
        [12, 13, 15, 14],
        [12, 13, 15, 14],
        [12, 13, 15, 14],
        [12,  7, 14, 13],
        [12, 13, 14,  7],
        [12, 13, 15, 14],
        [12,  7, 14,  3],
        [12,  7, 14,  3],
        [12, 13, 15, 14],
        [12,  7, 14, 13],
        [12,  7, 14,  3],
        [12,  9,  0,  7],
        [12, 13, 15, 14],
        [12,  0,  9,  7],
        [12,  7, 14,  0],
        [12, 13, 15, 14],
        [12,  7, 14,  3],
        [12, 13, 14, 15],
        [12,  7, 14,  0],
        [12, 13, 15, 14],
        [12, 13, 15, 14],
        [12,  9,  0, 14],
        [12,  7,  9,  0],
        [12, 13, 15, 14],
        [12,  7, 14,  3],
        [12, 13, 15, 14],
        [12,  9,  0, 14],
        [12,  7, 14,  3],
        [12, 13, 14, 15],
        [12,  7,  9,  0],
        [12, 13, 15, 16],
        [12,  7, 14,  3],
        [12, 14,  7,  3],
        [12,  0,  9, 14],
        [12, 13, 15, 16],
        [12,  9,  0,  7],
        [12, 13, 14, 15],
        [12,

In [115]:
torch.Tensor([
    [1,2,3],
    [4, 5, 6]
]).reshape(-1)

tensor([1., 2., 3., 4., 5., 6.])

In [None]:
batchSize = 5
beamSize = 4

hs = torch.randn(batchSize, 256)
