### Ingest data

In [1]:
import os
import json

# Move up to the project root (where `data/` lives) only if we are not already
# there, so re-running this cell does not keep walking up the directory tree.
if not os.path.isdir('data'):
    os.chdir('..')
sorted(os.listdir('data'))

['single_entity.jsonl', 'v4.coref-separated.jsonl']

In [2]:
# Read the JSON Lines file: one JSON record per line. Text mode with explicit
# UTF-8 replaces the manual bytes decode, and blank lines (e.g. a trailing
# newline at EOF) are skipped instead of crashing json.loads.
with open('data/single_entity.jsonl', encoding='utf-8') as f:
    lines = [json.loads(raw_line) for raw_line in f if raw_line.strip()]
line = lines[0]
text = line['text']
line

{'text': 'We believe TEMP.Mana, a Chinese cyber espionage group, is linked to infrastructure spoofing domains of at least two U.S. chemical manufacturers. Similar activity suspected of being tied to TEMP.Mana reinforces the risk to the chemical sector and related industries.',
 'meta': {'source': 'fireeye-0-0'},
 'spans': [{'start': 68,
   'end': 99,
   'label': 'ENT',
   'token_start': 13,
   'token_end': 16,
   '_private_string': 'infrastructure spoofing domains'}]}

### Load the SPARTA entity model and encode candidate spans

In [3]:
from allennlp.data.dataset_readers.dataset_utils.span_utils import enumerate_spans
from src.model import Sparta_Ent
from src.data import Data_Handler

# Build the SPARTA entity model and the sentence pre-processor.
model, handler = Sparta_Ent(), Data_Handler()

# Process the first example, then enumerate and encode every candidate span
# (AllenNLP inclusive (start, end) tuples, at most 7 tokens wide).
doc = handler.process_sentence(text)
span_tuples = enumerate_spans(doc.doc, max_span_width=7)
span_encodings = model.encode_spans(doc, span_tuples)

In [4]:
# verbose=True prints the subword tokens the inclusive span (0, 4) maps to;
# the cell then displays the shape of the resulting encoding.
span_encoding = model.encode_span(doc, (0, 4), verbose=True)
span_encoding.shape

['▁we', '▁believe', '▁', 'temp', '.', 'man', 'a', '▁', ',', '▁a']
torch.Size([10, 768])


Note that `Doc_Span` and `Doc_Tokens` follow AllenNLP's inclusive span convention instead of spaCy's exclusive one: `doc[0:4]` covers tokens 0 through 4, the same tokens as spaCy's `doc.doc[0:5]`.

In [5]:
# AllenNLP-style inclusive slice: tokens 0 through 4 (end index included)
doc[0:4]

We believe TEMP.Mana, a

In [6]:
# spaCy slices are end-exclusive, so [0:5] is needed to cover the same tokens
# as the inclusive doc[0:4]; the assertion pins that equivalence.
assert doc[0:4].text == doc.doc[0:5].text
doc.doc[0:5]

We believe TEMP.Mana, a

### Get reference encoding

In [7]:
ref_text = (
    "The APT10 group (also known as Red Tears) is responsible and is linked to "
    "the North Korean nexus, and is charged with discharging the CryptoMix virus, "
    "retrieving money, phone numbers, details and credentials"
)
ref_doc = handler.process_sentence(ref_text)
# Inclusive span (1, 1) is the single token at index 1, i.e. "APT10".
reference_encoding = model.encode_span(ref_doc, (1, 1))

### Score the span candidates

In [8]:
# TODO: wrap the span scoring below into a `rank(model, ...)` helper

In [9]:
# Similarity of every candidate span's encoding to the reference encoding.
span_scores = dict(
    (candidate, model.encoding_sim_score(encoding, reference_encoding))
    for candidate, encoding in zip(span_tuples, span_encodings)
)
# Same mapping, re-ordered from most to least similar.
sorted_span_scores = dict(
    sorted(span_scores.items(), key=lambda pair: pair[1], reverse=True)
)

In [10]:
# Show only the 20 best-scoring spans instead of dumping every candidate.
dict(list(sorted_span_scores.items())[:20])

{(1, 7): tensor(4.8290, grad_fn=<SumBackward0>),
 (0, 6): tensor(4.7705, grad_fn=<SumBackward0>),
 (2, 8): tensor(4.7536, grad_fn=<SumBackward0>),
 (2, 7): tensor(4.4192, grad_fn=<SumBackward0>),
 (0, 5): tensor(4.4024, grad_fn=<SumBackward0>),
 (1, 6): tensor(4.3962, grad_fn=<SumBackward0>),
 (31, 37): tensor(4.1353, grad_fn=<SumBackward0>),
 (29, 35): tensor(4.0843, grad_fn=<SumBackward0>),
 (30, 36): tensor(4.0653, grad_fn=<SumBackward0>),
 (0, 4): tensor(4.0287, grad_fn=<SumBackward0>),
 (1, 5): tensor(4.0281, grad_fn=<SumBackward0>),
 (2, 6): tensor(3.9864, grad_fn=<SumBackward0>),
 (28, 34): tensor(3.9664, grad_fn=<SumBackward0>),
 (27, 33): tensor(3.8204, grad_fn=<SumBackward0>),
 (29, 34): tensor(3.7729, grad_fn=<SumBackward0>),
 (31, 36): tensor(3.7723, grad_fn=<SumBackward0>),
 (30, 35): tensor(3.7508, grad_fn=<SumBackward0>),
 (26, 32): tensor(3.7170, grad_fn=<SumBackward0>),
 (1, 4): tensor(3.6543, grad_fn=<SumBackward0>),
 (19, 25): tensor(3.6371, grad_fn=<SumBackward0>),


### Wrap a processed spaCy sentence: `Doc_Tokens` and `Doc_Span`

In [11]:
class Doc_Tokens:
    """
    Wrapper around a processed document plus its full-word and subword
    tokenizations.

    Slicing follows AllenNLP's *inclusive* span convention (see __getitem__).
    """

    def __init__(self, doc, fullword_tokens, subword_tokens, subword_idx):
        # doc: the underlying token sequence (e.g. a spaCy Doc), sliced with
        # ordinary exclusive-end semantics.
        self.doc = doc
        self.fullword_tokens = fullword_tokens
        self.subword_tokens = subword_tokens
        self.subword_idx = subword_idx

    def __getitem__(self, val):
        """
        Index or slice the doc using AllenNLP's inclusive spans.

        This is NOT compatible with spaCy's exclusive spans: ``doc[0:4]`` returns
        the equivalent of ``self.doc[0:5]`` in spaCy terms.

        An omitted stop (``doc[2:]``) slices to the end, and an inclusive stop of
        ``-1`` keeps the last token. Non-slice indices are passed straight through.

        Reference for __getitem__:
        https://stackoverflow.com/questions/2936863/implementing-slicing-in-getitem
        """
        if not isinstance(val, slice):
            return self.doc[val]
        stop = val.stop
        if stop is not None:
            # +1 turns the inclusive end into an exclusive one. For -1 that would
            # give 0 (an empty slice), so use None to run through the last token.
            stop = None if stop == -1 else stop + 1
        return self.doc[val.start:stop]
        
class Doc_Span(Doc_Tokens):
    """
    A Doc_Tokens view that additionally records one span of interest.

    ``span_tuple`` is an AllenNLP-style inclusive (start, end) token span and
    ``span_text`` is the text of those tokens.
    """

    def __init__(self, doc_tokens:Doc_Tokens, span_tuple:tuple):
        # Share the wrapped document and its tokenizations with doc_tokens.
        super().__init__(
            doc_tokens.doc,
            doc_tokens.fullword_tokens,
            doc_tokens.subword_tokens,
            doc_tokens.subword_idx,
        )
        self.span_tuple = span_tuple
        # The underlying doc slices with an exclusive end, hence the +1.
        first, last = span_tuple[0], span_tuple[1]
        self.span_text = self.doc[first:last + 1].text

# Inclusive span (2, 8): tokens 2 through 8 of the first example sentence.
doc_span = Doc_Span(doc, span_tuple=(2, 8))
doc_span.span_text

'TEMP.Mana, a Chinese cyber espionage group'

In [12]:
# Inclusive span (1, 7): tokens 1 through 7 of the first example sentence.
doc_span = Doc_Span(doc, span_tuple=(1, 7))
doc_span.span_text

'believe TEMP.Mana, a Chinese cyber espionage'

### SPARTA loss
Inputs: a query span, reference (positive) spans, and negative spans.

In [32]:
import torch

def brute_sum_tensors(list_of_tensors):
    """Stack same-shaped tensors and return the sum of all their elements (0-dim tensor)."""
    stacked = torch.stack(list_of_tensors)
    return stacked.sum()

def xentropy_l2r(model, query_span: "Doc_Span", references: list, negatives: list):
    """
    Cross-entropy learning to rank loss as defined in:
    https://arxiv.org/pdf/2009.13013.pdf

    A stronger reference for this loss is eq (3) in:
    https://papers.nips.cc/paper/2009/file/2f55707d4193dc27118a0f19a1985716-Paper.pdf

    Computes ``logsumexp_k s(q, n_k) - sum_j s(q, r_j)``. Here ``s`` is
    ``model.encoding_sim_score``, ``q`` is the query encoding, ``r_j`` are the
    reference (positive) encodings and ``n_k`` are the negative encodings.

    Args:
        model: object exposing ``encode_span(doc, span_tuple)`` and
            ``encoding_sim_score(query_encoding, other_encoding)``.
        query_span: span used as the query; must expose ``span_tuple``.
        references: positive spans (each exposing ``span_tuple``) and/or their
            precomputed ``torch.Tensor`` encodings.
        negatives: negative spans and/or their precomputed encodings.

    Returns:
        The loss tensor (0-dim when ``encoding_sim_score`` returns scalars).

    Raises:
        ValueError: if ``references`` or ``negatives`` is empty.

    TODO: this should get batched
    """
    # len() rather than truthiness: encodings may arrive as a multi-element tensor.
    if len(references) == 0:
        raise ValueError("xentropy_l2r requires at least one reference")
    if len(negatives) == 0:
        raise ValueError("xentropy_l2r requires at least one negative")

    def _encode(item):
        # Items may be spans to encode or already-computed encodings.
        if isinstance(item, torch.Tensor):
            return item
        return model.encode_span(item, item.span_tuple)

    # perform encoding
    query_encoding = model.encode_span(query_span, query_span.span_tuple)
    reference_encodings = [_encode(reference) for reference in references]
    negative_encodings = [_encode(negative) for negative in negatives]

    # calc loss
    pos_sim = torch.stack([model.encoding_sim_score(query_encoding, encoding)
                           for encoding in reference_encodings])
    neg_sim = torch.stack([model.encoding_sim_score(query_encoding, encoding)
                           for encoding in negative_encodings])
    # logsumexp equals log(sum(exp(.))) but stays finite when scores are large.
    return torch.logsumexp(neg_sim.reshape(-1), dim=0) - pos_sim.sum()

%time loss = xentropy_l2r(model, Doc_Span(doc, (2,8)), [Doc_Span(ref_doc, (1,2))]*2 ,[Doc_Span(doc, (1,7))]*2 )
# loss

Wall time: 470 ms


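### Intra-document loss
Each support (reference) span is used as the query. The labelled answer spans in the document are the positives, and every other candidate span in that document is a negative.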

In [55]:
# Reuse ref_doc, which already holds handler.process_sentence(ref_text),
# instead of running the handler on the same sentence again.
references_spans = [Doc_Span(ref_doc, (1, 2))]
answer_spans = [Doc_Span(doc, (2, 8))]
print(references_spans[0].span_text)
print(answer_spans[0].span_text)

APT10 group
TEMP.Mana, a Chinese cyber espionage group


In [47]:
def intra_doc_loss(answer_spans: list, references_spans: list, max_span_width: int = 7, sparta_model=None):
    """
    Intra-document loss: rank the labelled answer spans of one document above
    every other candidate span, using each reference span as the query.

    For each span in ``references_spans``, ``xentropy_l2r`` is evaluated with:
    - that reference as the query,
    - the encodings of ``answer_spans`` as positives,
    - the encodings of every other candidate span (from AllenNLP's
      ``enumerate_spans``) as negatives.
    The per-reference losses are summed.

    Args:
        answer_spans: gold ``Doc_Span`` objects, assumed to come from the same
            document; only ``answer_spans[0]`` is used to enumerate and encode
            candidates.
        references_spans: ``Doc_Span`` objects used as queries.
        max_span_width: widest candidate span passed to ``enumerate_spans``
            (default 7, the width used for candidate enumeration in this notebook).
        sparta_model: model providing ``encode_spans``, ``encode_span`` and
            ``encoding_sim_score``. Defaults to the notebook-global ``model``.

    Returns:
        The summed loss tensor.

    Raises:
        ValueError: if ``answer_spans`` or ``references_spans`` is empty.
    """
    if not answer_spans:
        raise ValueError("intra_doc_loss requires at least one answer span")
    if not references_spans:
        raise ValueError("intra_doc_loss requires at least one reference span")
    scorer = model if sparta_model is None else sparta_model
    source_doc = answer_spans[0]

    # generate correct encodings
    answer_tuples = [answer_span.span_tuple for answer_span in answer_spans]
    answer_spans_encodings = scorer.encode_spans(source_doc, answer_tuples)

    # generate wrong spans: every other candidate span of the document. The gold
    # tuples are collected once in a set rather than rebuilt per candidate.
    gold_tuples = set(answer_tuples)
    all_possible_spans = enumerate_spans(source_doc.doc, max_span_width=max_span_width)
    all_wrong_spans = [span for span in all_possible_spans if span not in gold_tuples]
    all_wrong_span_encodings = scorer.encode_spans(source_doc, all_wrong_spans)

    # one loss per reference (query), summed
    losses = [xentropy_l2r(scorer, reference, answer_spans_encodings, all_wrong_span_encodings)
              for reference in references_spans]
    return brute_sum_tensors(losses)

# Sum of the per-reference ranking losses for the example document.
intra_doc_loss(answer_spans=answer_spans, references_spans=references_spans)

tensor(5.0098, grad_fn=<SumBackward0>)

### Inter-document loss
Support against support: one reference span is the query and the remaining reference spans are the positives.

In [58]:
import random
import numpy as np
# A seeded stdlib RNG makes the pick reproducible across Restart & Run All.
# random.choice returns the list element directly, whereas np.random.choice
# first converts the list of Doc_Span objects into a numpy array.
query = random.Random(0).choice(references_spans)

In [60]:
# Positives are all reference spans other than the query itself (Doc_Span
# defines no __eq__, so this is an identity comparison either way).
positives = [span for span in references_spans if span is not query]

In [61]:
# Empty here: references_spans holds a single span, so excluding the query leaves no positives.
positives

[]