In [None]:
import torch
import spacy
import transformers
import itertools
import heapq
import matplotlib.pyplot as plt

In [None]:
# Tokenizer and encoder are both loaded from the same pretrained checkpoint id.
tok = transformers.AutoTokenizer.from_pretrained('KB/bert-base-swedish-cased')
model = transformers.AutoModel.from_pretrained('KB/bert-base-swedish-cased')
# Put the encoder in evaluation mode (eval() returns the module itself).
model = model.eval()
# spaCy pipeline loaded from a local model directory relative to the notebook.
nlp = spacy.load('../data/sv_model_xpos/sv_model0/sv_model0-0.0.0/')

In [None]:
# Turn off gradient tracking for every weight of the encoder.
for parameter in model.parameters():
    parameter.requires_grad_(False)

In [None]:
# Sample Swedish text used as input by test() and by the get_triplets demo loop.
TXT="""Bob Dylan föddes som Robert Zimmerman i staden Duluth, Minnesota men strax innan han fyllde sex år och efter att hans far fått polio flyttade familjen till den närliggande staden Hibbing, Minnesota där han sedan växte upp."""
# NOTE(review): the string literal below is NOT assigned to anything, so TXT
# holds only the first sentence above. As the last expression of the cell the
# literal is simply evaluated (and echoed as the cell output by the notebook).
"""
Familjen Zimmerman var judisk och deras förfäder hade utvandrat från Ryssland, Ukraina, Litauen och Turkiet. Morfar och mormor - Benjamin och Liba Edelstein (senare Stein och Stone) - var litauiska judar som emigrerade till USA 1902.
När Bob Dylan var åtta-nio år började han spela på familjens piano. Därefter lärde han sig att spela munspel och gitarr.[3] Mycket av hans ungdomstid gick åt till att lyssna på radio där han tog in stationer som sände blues, country och tidig rock'n'roll. Han började uppträda i mitten av 1950-talet och var medlem i ett flertal band under sin tid i high school.

1959 började han studera på universitetet i Minneapolis. I samma veva tog hans intresse för folkmusik fart. Det var också nu han började presentera sig som Bob Dylan. Var han fått namnet ifrån finns det flera historier om. Vissa menar att det är inspirerat av poeten Dylan Thomas. År 2004 skrev han själv om hur han valde namnet i sin bok Memoarer, första delen"""


In [None]:
def parse_sentence(sentence, nlp, tok):
    """Tokenize ``sentence`` and group its wordpieces into noun-chunk-aware spans.

    The sentence is parsed with ``nlp``; whitespace tokens are dropped and each
    remaining token is split into wordpieces with ``tok``. Every noun chunk
    reported by the parser becomes one span, every other token becomes a span
    of its own, and a leading CLS id gets the first span.

    Args:
        sentence: Text to parse.
        nlp: Callable returning a parsed document that is iterable over tokens
            (with ``is_space`` and ``text``) and exposes ``noun_chunks`` whose
            items carry ``start`` / ``end`` token positions (``end`` exclusive).
        tok: Callable wordpiece tokenizer returning ``{'input_ids': [...]}``
            for a list of words, and exposing ``cls_token_id``.

    Returns:
        Tuple ``(input_ids, nouns, spans)``:
        ``input_ids`` is a ``torch.LongTensor`` holding the CLS id followed by
        all wordpiece ids; ``nouns`` is a list of bools, one per span, that is
        True for noun-chunk spans; ``spans`` is a list of half-open
        ``(start, end)`` index pairs into ``input_ids``. The first span is the
        CLS position ``(0, 1)`` and the last span ends at ``len(input_ids)``.

    Raises:
        ValueError: If the sentence has no non-whitespace tokens.
        AssertionError: If a token yields no wordpieces.
    """
    # NOTE(review): no separator ([SEP]) id is appended after the wordpieces;
    # confirm that this is intended for the model being used.
    doc = nlp(sentence)

    # doc_to_kept[p] = number of non-whitespace tokens located before document
    # position p, for every p in 0..len(doc). Mapping chunk boundaries through
    # this table works even when a chunk ends at the end of the document or
    # right before a whitespace token (a plain token->index dict has no entry
    # for those positions).
    tokens = []
    doc_to_kept = []
    for token in doc:
        doc_to_kept.append(len(tokens))
        if not token.is_space:
            tokens.append(token)
    doc_to_kept.append(len(tokens))

    if not tokens:
        raise ValueError("Sentence contains no non-whitespace tokens")

    # Token.text instead of the deprecated Token.string alias (removed in
    # spaCy v3).
    wordpieces_for_token = tok(
            [token.text for token in tokens],
            add_special_tokens=False,
            padding=False,
            return_token_type_ids=False,
            return_attention_mask=False)['input_ids']

    # word_start[i] = offset of the first wordpiece of token i in input_ids.
    input_ids = []
    word_start = []
    for wordpieces in wordpieces_for_token:
        assert len(wordpieces) > 0, "Empty token makes program sad"
        word_start.append(len(input_ids))
        input_ids += wordpieces

    nouns = []
    spans = []
    ptr = 0  # first token not yet assigned to a span

    for chunk in doc.noun_chunks:
        start = doc_to_kept[chunk.start]
        end = doc_to_kept[chunk.end]
        if end <= start:
            # Chunk contains no non-whitespace token: nothing to emit.
            continue

        # Tokens between the previous chunk and this one are single-token spans.
        for i in range(ptr, start):
            nouns.append(False)
            spans.append(word_start[i])

        # The chunk itself is one span starting at its first wordpiece.
        nouns.append(True)
        spans.append(word_start[start])
        ptr = end

    # Trailing tokens after the last noun chunk.
    nouns.extend(False for _ in word_start[ptr:])
    spans.extend(word_start[ptr:])

    # Prepend the CLS id; every span start shifts by one to make room for it.
    input_ids = torch.LongTensor([tok.cls_token_id, *input_ids])
    nouns = [False] + nouns
    spans = [0] + [start + 1 for start in spans]

    # Turn span starts into half-open (start, end) pairs.
    spans = list(zip(spans, spans[1:] + [len(input_ids)]))

    return input_ids, nouns, spans

In [None]:
def test():
    """Check that the noun spans from parse_sentence match spaCy's noun chunks.

    Decodes every span flagged as a noun back to text with the tokenizer and
    compares the resulting list with the stripped text of each noun chunk that
    ``nlp`` reports for ``TXT``.

    Raises:
        AssertionError: If the two lists differ.
    """
    # Span.text instead of the deprecated Span.string alias (removed in
    # spaCy v3).
    those = [noun.text.strip() for noun in nlp(TXT).noun_chunks]
    input_ids, nouns, spans = parse_sentence(TXT, nlp, tok)
    these = [
        tok.decode(input_ids[start:stop])
        for noun, (start, stop) in zip(nouns, spans)
        if noun
    ]
    assert those == these, "Spacy nouns does not match our nouns"
test()

In [None]:
def compress_attention(attention, spans):
    """Pool a token-level attention matrix down to span level.

    Entry ``[i, j]`` of the result is the attention mass that the rows of span
    ``i`` put on the columns of span ``j``: summed over the columns of span
    ``j`` and averaged over the rows of span ``i``.

    Block sums are read from a 2-D prefix-sum table padded with a leading zero
    row and column, so any contiguous spans are handled, including a first
    span that is longer than one position.

    Args:
        attention: 2-D tensor of attention weights.
        spans: Iterable of half-open ``(start, end)`` index pairs into both
            axes of ``attention``.

    Returns:
        Tensor of shape ``(len(spans), len(spans))``; an empty ``(0, 0)``
        tensor when ``spans`` is empty.
    """
    spans = list(spans)
    if not spans:
        return attention.new_zeros((0, 0))

    starts, ends = zip(*spans)
    starts = torch.as_tensor(starts, dtype=torch.long, device=attention.device)
    ends = torch.as_tensor(ends, dtype=torch.long, device=attention.device)

    # csatt[r, c] = sum of attention[:r, :c]; row 0 and column 0 stay zero so
    # that a span starting at 0 needs no special case.
    n_rows, n_cols = attention.shape
    csatt = attention.new_zeros((n_rows + 1, n_cols + 1))
    csatt[1:, 1:] = attention.cumsum(0).cumsum(1)

    # Inclusion-exclusion on the prefix sums yields every block sum at once.
    ret = (csatt[ends, :][:, ends]
           - csatt[starts, :][:, ends]
           - csatt[ends, :][:, starts]
           + csatt[starts, :][:, starts])

    # Average over the rows that make up each span.
    # NOTE: a zero-length span (start == end) divides by zero here.
    lengths = (ends - starts).to(ret.dtype)
    return ret / lengths[:, None]

In [None]:
def get_provenance(model, input_ids, spans):
    """Accumulate span-level attention through all layers and average it.

    The model is run once with ``output_attentions=True``. Each returned
    attention tensor is averaged over dim 1, squeezed on dim 0 and pooled to
    span level with ``compress_attention``. Starting from the identity, the
    running matrix is updated per layer as ``((running + I) @ layer) / 2``;
    the mean of the per-layer running matrices is returned.

    Args:
        model: Callable returning a mapping with an ``'attentions'`` entry.
        input_ids: 1-D tensor of ids; a batch dimension is added here.
        spans: Span boundaries forwarded to ``compress_attention``.

    Returns:
        Tensor of shape ``(len(spans), len(spans))``.
    """
    outputs = model(input_ids.unsqueeze(0), output_attentions=True)
    identity = torch.eye(len(spans))

    running = identity
    per_layer = []
    for layer_attention in outputs['attentions']:
        # Mean over dim 1 (presumably the heads), then drop the batch dim.
        pooled = compress_attention(layer_attention.mean(1).squeeze(0), spans)
        running = ((running + identity) @ pooled) / 2
        per_layer.append(running)

    return torch.stack(per_layer).mean(0)

In [None]:
def get_attention(model, input_ids, spans):
    """Return the last layer's head-averaged attention, pooled to span level.

    Args:
        model: Callable returning a mapping with an ``'attentions'`` entry.
        input_ids: 1-D tensor of ids; a batch dimension is added here.
        spans: Span boundaries forwarded to ``compress_attention``.

    Returns:
        The result of ``compress_attention`` on the averaged attention matrix.
    """
    outputs = model(input_ids.unsqueeze(0), output_attentions=True)
    last_layer = outputs['attentions'][-1]
    # Average attention over heads (dim 1), then drop the batch dimension.
    head_mean = last_layer.mean(1).squeeze(0)
    # The pooling itself is done with cumulative sums.
    return compress_attention(head_mean, spans)

In [None]:
# Scratch cell: builds an unseeded random 3x3 tensor and echoes it as the cell
# output. In the visible notebook `x` is only read by the next cell.
x = torch.randn(3,3)
x

In [None]:
# Scratch check of 2-D tensor indexing: row 1, column 2 of the `x` defined in
# the previous cell.
x[1, 2]

In [None]:
def get_triplets(sentence, model, nlp, tok, top_k=5):
    """Yield scored (head, tail, relation) candidates for a sentence.

    The sentence is split into spans with ``parse_sentence`` and a span-level
    attention matrix is obtained from ``get_attention``. For every ordered
    pair of noun spans ``head < tail``, each contiguous run of non-noun spans
    strictly between them is scored as a candidate relation, and the
    ``top_k`` best runs per pair are yielded.

    A run ``start..stop`` is scored by the larger of two attention paths:
    forward  ``att[head, start] + sum(att[i, i+1]) + att[stop, tail]`` and
    backward ``att[start, head] + sum(att[i+1, i]) + att[tail, stop]``,
    with ``i`` ranging over ``start..stop-1``.

    Args:
        sentence: Text to analyse.
        model: Model forwarded to ``get_attention``.
        nlp: Parser used for chunking and for lemmatising the output text.
        tok: Wordpiece tokenizer (also used to decode spans back to text).
        top_k: Number of best-scoring relations kept per noun pair.

    Yields:
        Tuples ``(score, head_text, tail_text, relation_text)`` where the
        three texts are space-joined lemmas and ``score`` is a 0-dim tensor.
    """
    input_ids, nouns, spans = parse_sentence(sentence, nlp, tok)
    att = get_attention(model, input_ids, spans)

    noun_set = {i for i, noun in enumerate(nouns) if noun}

    # Prefix sums along the first super-/sub-diagonal:
    #   fwd_cache[k] = sum(att[i, i+1] for i in 0..k)
    #   bwd_cache[k] = sum(att[i+1, i] for i in 0..k)
    fwd_cache = att.diag(1).cumsum(0)
    bwd_cache = att.diag(-1).cumsum(0)

    # The same span range is lemmatised many times; parse each one only once.
    lemma_cache = {}

    def lemmatized(start, stop=None):
        """Lemmatised text of spans ``start..stop`` (inclusive)."""
        if stop is None:
            stop = start
        key = (start, stop)
        if key not in lemma_cache:
            lb = spans[start][0]
            ub = spans[stop][1]
            txt = tok.decode(input_ids[lb:ub])
            lemma_cache[key] = ' '.join([token.lemma_ for token in nlp(txt)])
        return lemma_cache[key]

    def get_scores(head, tail):
        """Score every noun-free run of spans strictly between head and tail."""
        scores = []
        for start in range(head + 1, tail):
            if start in noun_set:
                continue
            for stop in range(start, tail):
                if stop in noun_set:
                    break

                # Forward path (head reads from x, x reads from tail):
                # internal links from the prefix sums plus the two end links.
                fwd = fwd_cache[stop - 1] - fwd_cache[start - 1]
                fwd = fwd + att[head, start] + att[stop, tail]

                # Backward path (head writes to x, x writes to tail). Both
                # terms of the internal part must come from bwd_cache.
                bwd = bwd_cache[stop - 1] - bwd_cache[start - 1]
                bwd = bwd + att[start, head] + att[tail, stop]

                scores.append((max(fwd, bwd), start, stop))

        return scores

    # Only pairs with head < tail can have spans in between; sorting makes the
    # output order deterministic.
    for head, tail in itertools.combinations(sorted(noun_set), 2):
        rels = heapq.nlargest(top_k, get_scores(head, tail))
        for (score, start, stop) in rels:
            yield (score, lemmatized(head), lemmatized(tail), lemmatized(start, stop))

            
# Print each candidate as "<score> <head> -- <relation> -- <tail>".
for triplet in get_triplets(TXT, model, nlp, tok):
    score, head, tail, relation = triplet
    print(f'{score:.3f} {head} -- {relation} -- {tail}')

In [None]:
# NOTE(review): this cell reads like a function body pasted from elsewhere.
# token2id, noun_chunks, Pool, tail_head_pairs, attn_graph,
# tokenid2word_mapping, bfs, global_initializer and filter_relation_sets are
# not defined or imported anywhere in the visible notebook, so the cell cannot
# run on a fresh kernel until they are provided.

# Ids of the noun chunks; passed to every search job below.
black_list_relation = {token2id[n] for n in noun_chunks}
all_relation_pairs = []
id2token = {value: key for key, value in token2id.items()}

# Run bfs over all (tail, head) pairs in a worker pool and collect every
# non-empty result together with the id -> token lookup.
with Pool(10) as pool:
    params = [
        (pair[0], pair[1], attn_graph, max(tokenid2word_mapping), black_list_relation)
        for pair in tail_head_pairs
    ]
    for output in pool.imap_unordered(bfs, params):
        if len(output):
            all_relation_pairs += [(o, id2token) for o in output]

# Filter the collected candidates in a second pool; workers are initialised
# with the spaCy pipeline.
triplet_text = []
with Pool(10, global_initializer, (nlp,)) as pool:
    for triplet in pool.imap_unordered(filter_relation_sets, all_relation_pairs):
        if len(triplet) > 0:
            triplet_text.append(triplet)

# A bare `return` is a SyntaxError at cell level; end with the value so the
# notebook displays it.
triplet_text