### Load the Transformer model and spaCy

In [2]:
from transformers import AutoTokenizer, AutoModel

# Checkpoint shared by the tokenizer and the encoder so the two cannot drift apart.
ALBERT_CHECKPOINT = 'albert-base-v2'

# use_fast=False requests the non-fast tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(ALBERT_CHECKPOINT, use_fast=False)
model = AutoModel.from_pretrained(ALBERT_CHECKPOINT)

In [21]:
import spacy 
# Example sentence used throughout the notebook.
text = "The APT10 group (also known as Temp.Mana) is responsible and is linked to the North Korean nexus, and is charged with discharging the CryptoMix virus, retrieving money, phone numbers, details and credentials"

# Run the spaCy pipeline over the sentence and keep the surface text of each token.
nlp = spacy.load('en_core_web_sm')
doc = nlp(text)
doc_tokens = [token.text for token in doc]

### Process the text

In [12]:
from allennlp.data.dataset_readers.dataset_utils.span_utils import enumerate_spans
# Enumerate candidate spans over the spaCy tokens, limited to a width of 5 tokens.
# NOTE(review): allennlp documents the returned (start, end) indices as inclusive — confirm for the installed version.
span_width_tuples = enumerate_spans(doc, max_span_width=5)

In [138]:
# Split every spaCy token into transformer word pieces, one list of pieces per token.
fullword_tokens = [tokenizer.tokenize(token_text) for token_text in doc_tokens]

# Flatten the per-token lists into the single piece sequence that gets encoded.
subword_tokens = [piece for pieces in fullword_tokens for piece in pieces]

# subword_idx[i] is the position, within subword_tokens, of the first piece of spaCy token i.
subword_idx = []
offset = 0
for pieces in fullword_tokens:
    subword_idx.append(offset)
    offset += len(pieces)

# Show the first 11 (offset, pieces) pairs as a sanity check.
display(list(zip(subword_idx, fullword_tokens))[:11])

[(0, ['▁the']),
 (1, ['▁a', 'pt', '10']),
 (4, ['▁group']),
 (5, ['▁', '(']),
 (7, ['▁also']),
 (8, ['▁known']),
 (9, ['▁as']),
 (10, ['▁', 'temp']),
 (12, ['▁', '.']),
 (14, ['▁man', 'a']),
 (16, ['▁', ')'])]

### Encode a specific span

In [103]:
# Span expressed as inclusive (start, end) indices over the spaCy tokens.
span_width_tuple = (1,2)
token_start, token_end = span_width_tuple

# Convert to a half-open range over word pieces. The range must run up to the first
# piece of the token that FOLLOWS the span; taking the end token's own offset + 1 would
# keep only the first piece of a multi-piece end token.
span_start = subword_idx[token_start]
if token_end + 1 == len(subword_idx):
    # The span ends on the last token: there is no following token, so stop at the
    # end of the piece sequence.
    span_end = len(subword_tokens)
else:
    span_end = subword_idx[token_end + 1]

In [109]:
# Encode the full piece sequence once and take the final hidden states.
piece_id = tokenizer.encode(subword_tokens, return_tensors='pt')
encodings = model(piece_id)['last_hidden_state']
# Both slice bounds are shifted by one to account for the CLS token inserted at position 0.
span_encoding = encodings[0, span_start+1: span_end+1, :]

### Sparta class

![image.png](attachment:image.png)

In [306]:
from allennlp.data.dataset_readers.dataset_utils.span_utils import enumerate_spans

In [6]:
import spacy 
from torch import nn
import torch.nn.functional as F
import torch
from transformers import AutoTokenizer, AutoModel
from allennlp.data.dataset_readers.dataset_utils.span_utils import enumerate_spans

class Doc_Tokens:
    """
    Plain container for the tokenisation results of one sentence.

    The four constructor arguments are stored unchanged under attributes of the same name:
    ``doc``, ``fullword_tokens``, ``subword_tokens`` and ``subword_idx``.
    """
    def __init__(self, doc, fullword_tokens, subword_tokens, subword_idx):
        # Single tuple assignment: every argument is stored as-is, no copying or validation.
        (self.doc,
         self.fullword_tokens,
         self.subword_tokens,
         self.subword_idx) = (doc, fullword_tokens, subword_tokens, subword_idx)
        
def sim_matrix(a, b, eps=1e-8):
    """
    Pairwise cosine similarity between the rows of ``a`` and the rows of ``b``.

    Each row norm is floored at ``eps`` before dividing, for numerical stability
    (an all-zero row therefore yields similarities of 0 rather than NaN).
    https://stackoverflow.com/questions/50411191/how-to-compute-the-cosine-similarity-in-pytorch-for-all-rows-in-a-matrix-with-re

    :param a: 2-D tensor, one vector per row
    :param b: 2-D tensor, one vector per row, same row width as ``a``
    :param eps: lower bound applied to every row norm
    :return: tensor with one row per row of ``a`` and one column per row of ``b``
    """
    # Row-wise L2 norms, kept as column vectors so they broadcast over the row entries.
    a_len = a.norm(dim=1, keepdim=True)
    b_len = b.norm(dim=1, keepdim=True)
    a_unit = a / a_len.clamp(min=eps)
    b_unit = b / b_len.clamp(min=eps)
    # Dot products of unit vectors are the cosine similarities.
    return torch.mm(a_unit, b_unit.t())
        
class Sparta_Ent:
    """
    Encodes token spans of a sentence with a transformer and scores pairs of span encodings.

    A sentence is first tokenised with spaCy; every spaCy token is then split into
    transformer word pieces. Spans are given as inclusive (start, end) indices over the
    spaCy tokens and are mapped to the matching rows of the transformer output.
    """

    def __init__(self, bert_name = 'albert-base-v2', spacy_name = 'en_core_web_sm'):
        """
        :param bert_name: checkpoint name passed to AutoTokenizer / AutoModel
        :param spacy_name: pipeline name passed to spacy.load
        """
        self.tokenizer = AutoTokenizer.from_pretrained(bert_name, use_fast=False)
        self.model = AutoModel.from_pretrained(bert_name)
        self.nlp = spacy.load(spacy_name)
        # Additive bias applied to the per-piece max similarities in encoding_sim_score.
        self.threshold_bias = nn.Parameter(torch.zeros(1))
        self.act = nn.ReLU()

    def process_sentence(self, text):
        """
        :param text: text sentence
        :return Doc_Tokens: Doc_Tokens objects for encoding
        """
        # spacy tokenize
        doc = self.nlp(text)
        doc_tokens = [token.text for token in doc]

        # transformers tokenize, one list of pieces per spaCy token.
        # Uses this instance's tokenizer rather than a notebook-level global, so the
        # pieces always match the checkpoint the object was constructed with.
        fullword_tokens = [self.tokenizer.tokenize(token_text) for token_text in doc_tokens]
        subword_tokens = [piece for pieces in fullword_tokens for piece in pieces]

        # subword_idx[i] = position in subword_tokens of the first piece of spaCy token i
        subword_idx, offset = [], 0
        for pieces in fullword_tokens:
            subword_idx.append(offset)
            offset += len(pieces)

        return Doc_Tokens(doc, fullword_tokens, subword_tokens, subword_idx)

    @staticmethod
    def _piece_bounds(doc_tokens, span_width_tuple):
        """Map an inclusive spaCy-token span to a half-open (start, end) range over word pieces."""
        token_start, token_end = span_width_tuple
        starts = doc_tokens.subword_idx
        next_token = token_end + 1
        if next_token == len(starts):
            # The span ends on the last token: there is no following token whose offset
            # could serve as the end, so stop at the end of the piece sequence.
            piece_end = len(doc_tokens.subword_tokens)
        else:
            piece_end = starts[next_token]
        return starts[token_start], piece_end

    def encode_span(self, doc_tokens: "Doc_Tokens", span_width_tuple: tuple):
        """
        :param doc_tokens: Doc_Tokens object that contains the spans
        :param span_width_tuple: tuple indicating start and end of spacy token spans
        :return span_encoding: (span length in pieces x embed dimensions) torch array
        """
        # Single-span case of encode_spans; sharing the code keeps the slicing rules identical.
        return self.encode_spans(doc_tokens, [span_width_tuple])[0]

    def encode_spans(self, doc_tokens: "Doc_Tokens", span_width_tuples: list):
        """
        :param doc_tokens: Doc_Tokens object that contains the spans
        :param span_width_tuples: list of tuple indicating start and end of spacy token spans
        :return span_encodings: list with one (span length in pieces x embed dimensions)
            torch array per requested span, in the order given
        """
        # encode the pieces once with this instance's model; every span is sliced from the same output
        piece_id = self.tokenizer.encode(doc_tokens.subword_tokens, return_tensors='pt')
        encodings = self.model(piece_id)['last_hidden_state']

        # get spans
        span_encodings = []
        for span_width_tuple in span_width_tuples:
            piece_start, piece_end = self._piece_bounds(doc_tokens, span_width_tuple)
            # span starts and end +1 to account for inserted CLS token
            span_encodings.append(encodings[0, piece_start + 1: piece_end + 1, :])

        return span_encodings

    def encoding_sim_score(self, query_encoding: torch.Tensor, reference_encoding: torch.Tensor):
        """
        :param query_encoding: (query pieces x embed dimensions) torch array
        :param reference_encoding: (reference pieces x embed dimensions) torch array
        :return score: scalar tensor, sum over query pieces of log(act(max similarity + bias) + 1)
        """
        sim = sim_matrix(query_encoding, reference_encoding)
        # best-matching reference piece for every query piece
        max_vals, _max_ind = torch.max(sim, dim=1)
        max_vals = max_vals + self.threshold_bias
        score = torch.sum(torch.log(self.act(max_vals) + 1))
        return score

In [3]:
# NOTE(review): leftover scratch cell — evaluating the bare name only echoes the built-in `tuple` type.
tuple

tuple

In [3]:
pip install allennlp

Collecting allennlp
  Using cached allennlp-2.4.0-py3-none-any.whl (625 kB)
Collecting scipy
  Using cached scipy-1.6.3-cp37-cp37m-win_amd64.whl (32.6 MB)
Collecting sentencepiece
  Using cached sentencepiece-0.1.95-cp37-cp37m-win_amd64.whl (1.2 MB)
Collecting scikit-learn
  Using cached scikit_learn-0.24.2-cp37-cp37m-win_amd64.whl (6.8 MB)
Collecting lmdb
  Using cached lmdb-1.2.1-cp37-cp37m-win_amd64.whl (105 kB)
Collecting h5py
  Using cached h5py-3.2.1-cp37-cp37m-win_amd64.whl (2.7 MB)
Collecting huggingface-hub>=0.0.8
  Using cached huggingface_hub-0.0.8-py3-none-any.whl (34 kB)
Collecting more-itertools
  Using cached more_itertools-8.7.0-py3-none-any.whl (48 kB)
Collecting tensorboardX>=1.2
  Using cached tensorboardX-2.2-py2.py3-none-any.whl (120 kB)
Collecting wandb<0.11.0,>=0.10.0
  Using cached wandb-0.10.30-py2.py3-none-any.whl (1.8 MB)
Collecting pytest
  Using cached pytest-6.2.4-py3-none-any.whl (280 kB)
Collecting boto3<2.0,>=1.14
  Using cached boto3-1.17.72-py2.py3-no

In [None]:
pip install spacy==2.3.5

In [302]:
# Example sentence to tokenise and encode.
text = "The APT10 group (also known as Temp.Mana) is responsible and is linked to the North Korean nexus, and is charged with discharging the CryptoMix virus, retrieving money, phone numbers, details and credentials"
display(text)

# Build the encoder (loads the transformer checkpoint and the spaCy pipeline).
sparta_ent = Sparta_Ent()
# NOTE(review): `doc` holds a Doc_Tokens object here, whereas the earlier cells bind `doc` to the spaCy Doc.
doc = sparta_ent.process_sentence(text)

'The APT10 group (also known as Temp.Mana) is responsible and is linked to the North Korean nexus, and is charged with discharging the CryptoMix virus, retrieving money, phone numbers, details and credentials'

In [303]:
# Encode spaCy tokens 1-2 of the sentence ("APT10", "group").
ent1 = sparta_ent.encode_span(doc, (1,2))
ent1

tensor([[ 0.9899, -0.3053,  0.7081,  ..., -0.5901,  3.1945,  0.0744],
        [ 0.6409, -0.1234,  0.9752,  ..., -1.6425,  1.1848,  0.1874],
        [-0.2175, -0.3960,  1.2751,  ..., -1.9441,  0.2071, -0.7991],
        [-1.2339, -1.2780,  1.5548,  ...,  0.2599,  2.3435, -0.3532]],
       grad_fn=<SliceBackward>)

In [304]:
# Encode spaCy tokens 7-9 of the sentence ("Temp", ".", "Mana").
ent2 = sparta_ent.encode_span(doc, (7,9))
# ent2 = sparta_ent.encode_span(doc, (3,5))
ent2

tensor([[ 1.0878,  0.3942, -0.4331,  ..., -0.3236,  2.5309,  0.1388],
        [ 0.3049, -0.1178,  0.8762,  ..., -0.9927,  0.9251, -1.0933],
        [ 0.2504, -1.3833,  1.3662,  ..., -0.5650,  1.6812,  0.1072],
        [ 1.5854,  0.4066,  0.3833,  ..., -0.5505,  1.3191, -0.8333],
        [-0.1886, -0.8965,  2.6719,  ..., -0.0414,  0.4988, -1.3243],
        [ 0.4528, -1.0257,  1.8590,  ..., -0.1739,  1.5228, -1.5650]],
       grad_fn=<SliceBackward>)

In [305]:
# Score ent1 (query) against ent2 (reference).
sparta_ent.encoding_sim_score(ent1, ent2)

tensor(1.7900, grad_fn=<SumBackward0>)

In [203]:
# Decode positions 11-16 of the encoded sequence one id at a time, to check which
# word pieces sit at those positions.
for position in range(11, 17):
    piece = int(piece_id[0, position])
    print(tokenizer.decode(piece))


temp

.
man
a


### Processing

![image.png](attachment:image.png)

In [300]:
def encoding_sim_score(query_encoding, reference_encoding):
    """
    Score a query span encoding against a reference span encoding.

    For every row of ``query_encoding`` the largest cosine similarity to any row of
    ``reference_encoding`` is taken, the notebook-level ``bias`` is added, and the
    result is passed through ReLU; the score is the sum of log(value + 1).

    :param query_encoding: 2-D tensor, one vector per row
    :param reference_encoding: 2-D tensor, one vector per row
    :return: scalar tensor
    """
    pairwise = sim_matrix(query_encoding, reference_encoding)
    best_per_query = pairwise.max(dim=1).values
    # `bias` is read from the notebook's global namespace; it must be defined before calling.
    best_per_query += bias
    return torch.log(nn.ReLU()(best_per_query) + 1).sum()

In [283]:
# Pairwise cosine similarities: one row per ent1 piece, one column per ent2 piece.
sim = sim_matrix(ent1, ent2)
sim

tensor([[0.4774, 0.3713, 0.3173, 0.3345, 0.3832, 0.4750],
        [0.5126, 0.6466, 0.4511, 0.5007, 0.5570, 0.5628],
        [0.4309, 0.5817, 0.4216, 0.4409, 0.5171, 0.6188],
        [0.4458, 0.4765, 0.4404, 0.4084, 0.4906, 0.5209]],
       grad_fn=<MmBackward>)

In [293]:
# Row-wise maximum: the best-matching ent2 piece for every ent1 piece.
max_vals, max_ind = torch.max(sim, axis=1)
max_vals

tensor([0.4774, 0.6466, 0.6188, 0.5209], grad_fn=<MaxBackward0>)

In [294]:
# Bias added to the max similarities; it is initialised to zero, so the values are unchanged here.
bias = nn.Parameter(torch.zeros(1))
# NOTE(review): in-place update — re-running this cell adds the bias to the same tensor again.
max_vals += bias
max_vals

tensor([0.4774, 0.6466, 0.6188, 0.5209], grad_fn=<AddBackward0>)

In [298]:
# Final score: sum over the query pieces of log(ReLU(max similarity) + 1).
torch.sum( torch.log( nn.ReLU()(max_vals) + 1 ) )

tensor(1.7900, grad_fn=<SumBackward0>)