In [1]:
import os
import json

os.chdir('..')
os.listdir('data/seeds')

['adversary-org.jsonl']

### Read Support Data

In [2]:
# Reading from jsonlines file
with open('data/seeds/adversary-org.jsonl', 'rb') as f:
    lines = f.readlines()
lines = [json.loads(line.decode('utf-8')) for line in lines]
line = lines[0]

import torch
from transformers import AdamW

from src.data import Data_Handler

handler = Data_Handler()
references = []
for line in lines:
    references.append( handler.process_spacy_annot(line) )
    
unique_texts = list(set([reference[0].doc.text for reference in references]))
references

[[<src.data.Doc_Tokens at 0x227a2441388>, [(0, 1), (76, 77)]],
 [<src.data.Doc_Tokens at 0x227a248ea48>, [(2, 2), (5, 8), (31, 31)]],
 [<src.data.Doc_Tokens at 0x227a520dec8>,
  [(12, 12), (45, 45), (89, 89), (72, 73)]]]

In [3]:
references

[[<src.data.Doc_Tokens at 0x227a2441388>, [(0, 1), (76, 77)]],
 [<src.data.Doc_Tokens at 0x227a248ea48>, [(2, 2), (5, 8), (31, 31)]],
 [<src.data.Doc_Tokens at 0x227a520dec8>,
  [(12, 12), (45, 45), (89, 89), (72, 73)]]]

### Support Class

In [4]:
class Support:
    """
    One Class support
    """
    def __init__(self, label, references, prototype_rep = None):
        self.label = label
        self.references = references
        self.prototype_rep = prototype_rep
        
sup = Support("adversary", references)

### Loading Support

In [5]:
from src.dygie_ent import Dygie_Ent
model = Dygie_Ent()

### 1. Calculate prototype

In [7]:
# def get_unique_spans(references):
#     unique_texts = list(set([reference.doc.text for reference in references]))
    
#     text_2_spans = []
#     for unique_text in unique_texts:
#         print(unique_text)
#         spans = [ref.span_tuple for ref in references if ref.doc.text == unique_text]
#         doc = handler.process_sentence(unique_text)
#         text_2_spans.append([doc, spans])
        
#     return text_2_spans
# doc_span_pairs = get_unique_spans(references)
# doc_span_pairs

In [6]:
def get_prototype(doc_span_pairs):
    """
    Sample input:
    ------------
    [[<src.data.Doc_Tokens at 0x2793506e4c8>,
      [(12, 12), (45, 45), (89, 89), (72, 73)]],
     [<src.data.Doc_Tokens at 0x27935047e48>, [(0, 1), (76, 77)]],
     [<src.data.Doc_Tokens at 0x27938cf2388>, [(2, 2), (5, 8), (31, 31)]]]
     
    :param doc_span_pairs: list of lists which contains a pair of docs and list of spans
    :return prototype_encoding: Torch tensor of shape (1 x embedding)
    """
    
    support_encodings = []
    for doc, span in doc_span_pairs:
        encodings = model.encode_spans(doc, span)
        support_encodings.append(encodings)
    prototype_encoding = torch.cat(support_encodings, axis=1)
    prototype_encoding = torch.mean(prototype_encoding, axis=1, keepdim=True)
    
    return prototype_encoding

get_prototype(references)

tensor([[[ 2.5397e-01,  2.7272e-01, -6.6942e-02,  4.9511e-02,  5.1519e-02,
          -1.8112e-01,  2.1526e-01,  2.7567e-03, -9.9313e-02,  2.9337e-01,
          -3.7707e-01, -8.6256e-03,  3.6425e-01,  5.1284e-02,  7.7859e-02,
          -2.1698e-02, -2.0512e-01,  1.0793e-01, -7.7208e-02,  6.0131e-03,
          -4.5704e-01, -2.5524e-01,  2.0992e-01, -2.6664e-01, -2.3550e-01,
          -7.9316e-02,  3.2734e-02,  5.4942e-02,  8.9804e-02,  8.2831e-02,
           5.7981e-02, -4.7596e-01,  1.8200e-02,  2.6689e-01, -4.7884e-01,
          -2.2027e-03,  3.6676e-02,  1.7224e-01,  1.3287e-01, -3.1933e-01,
          -1.5459e-01, -1.5890e-01,  1.4062e-02, -1.6998e-02, -1.3967e-01,
          -8.6211e-02,  4.0028e-01, -8.3599e-02, -4.7588e-02, -4.1036e-01,
          -3.3558e-01,  1.9723e-01, -5.1036e-01, -1.6826e-01,  5.8366e-02,
           5.7200e-01, -4.7079e-01,  1.0036e-01, -4.6616e-02,  1.9194e-01,
           4.1204e-01,  1.7132e-01, -4.5358e-01,  1.7432e-02, -1.5045e-01,
           2.5946e-01, -1

### Forward function

Create input doc

In [7]:
text = "The activity of the advanced hacker group the researchers call Silence has increased significantly over the past year. Victims in the financial sector are scattered across more than 30 countries and financial losses have quintupled.\n The group started timidly in 2016, learning the ropes by following the path beaten by other hackers. Since then, it managed to steal at least $4.2 million, initially from banks in the former Soviet Union, then from victims in Europe, Latin America, Africa, and Asia.\n Researchers at Group-IB, Singapore-based cybersecurity company specializing in attack prevention, tracked Silence early on and judged its members to be familiar with white-hat security activity.\n A report last year\xa0details the roles of Silence hackers, their skills, failures, and successful bank heists"
text

'The activity of the advanced hacker group the researchers call Silence has increased significantly over the past year. Victims in the financial sector are scattered across more than 30 countries and financial losses have quintupled.\n The group started timidly in 2016, learning the ropes by following the path beaten by other hackers. Since then, it managed to steal at least $4.2 million, initially from banks in the former Soviet Union, then from victims in Europe, Latin America, Africa, and Asia.\n Researchers at Group-IB, Singapore-based cybersecurity company specializing in attack prevention, tracked Silence early on and judged its members to be familiar with white-hat security activity.\n A report last year\xa0details the roles of Silence hackers, their skills, failures, and successful bank heists'

In [8]:
doc = handler.process_sentence(text)
# for tok in doc.doc:
#     print(tok.i, tok)
ent_spans = [(10,10), (38,39), (111, 111), (138,139)]

Init model

In [9]:
from src.dygie_ent import Dygie_Ent
model = Dygie_Ent()

Create spans of the document

In [47]:
from allennlp.data.dataset_readers.dataset_utils.span_utils import enumerate_spans
all_spans = enumerate_spans(doc.doc, max_span_width=4)
labels = [ int(span in ent_spans) for span in all_spans]

# convert to torch

In [71]:
model.parameters()

<generator object Module.parameters at 0x00000227A8111E48>

In [48]:
len(labels)

594

Encode `all_spans`  
Running for 594 spans  
282 ms ± 7.69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [27]:
# %timeit encodings = model.encode_spans(doc, all_spans)
encodings = model.encode_spans(doc, all_spans)
encodings.shape

torch.Size([1, 594, 768])

Get Prototype representation

In [14]:
model.references['Adversary:Actor'] = references

In [15]:
references

[[<src.data.Doc_Tokens at 0x227a2441388>, [(0, 1), (76, 77)]],
 [<src.data.Doc_Tokens at 0x227a248ea48>, [(2, 2), (5, 8), (31, 31)]],
 [<src.data.Doc_Tokens at 0x227a520dec8>,
  [(12, 12), (45, 45), (89, 89), (72, 73)]]]

In [17]:
label_strings, prototype_encodings = model.get_prototypes()

Calculate the distances
1. The `prototype_encodings` need to be concatenated with zero vectors. The resulting shape should be (n_classes, 2, hidden_size)  
2. `torch.cdist`

In [18]:
prototype_encodings.shape

torch.Size([1, 1, 768])

In [26]:
prototype_encodings = torch.cat( [prototype_encodings, torch.zeros(prototype_encodings.shape)], axis=1 )
prototype_encodings.shape

torch.Size([1, 2, 768])

In [30]:
distances = torch.cdist(encodings, prototype_encodings)
distances.shape

torch.Size([1, 594, 2])

In [39]:
prob_class = torch.nn.LogSoftmax(dim=2)(distances)

In [44]:
prob_class.shape

torch.Size([1, 594, 2])

In [63]:
from torch import nn
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)

In [42]:
input.shape

torch.Size([3, 5])

In [43]:
target.shape

torch.Size([3])

In [59]:
nn.CrossEntropyLoss?

In [67]:
prob_class.squeeze(0).shape

torch.Size([594, 2])

In [68]:
loss = nn.CrossEntropyLoss()
loss(prob_class.squeeze(0), torch.LongTensor(labels))

tensor(1.2614, grad_fn=<NllLossBackward>)

In [57]:
torch.LongTensor(labels).unsqueeze(dim=0)

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0

### Forward function

In [None]:
def forward(doc, max_span_width=4):
    all_spans = enumerate_spans(doc.doc, max_span_width=max_span_width)
    encodings = model.encode_spans(doc, all_spans)
    
    support_encodings = torch.cat( [self.prototype_encodings, torch.zeros(self.prototype_encodings.shape)], axis=1 )
    distances = torch.cdist(encodings, support_encodings)
    prob_class = torch.nn.LogSoftmax(dim=2)(distances)
    
    return prob_class
    