# Analyzing Bert probabilities

In [3]:
import torch, logging
from analysis_utils import *

tokenizer, model, lm_model = prepare_models()

In [2]:
def bert_token_remap(src, tgt):
    src_list = src.split()
    print(src_list)
    src_i = 0
    cont = ''
    res = []
    for tgt_i, tok in enumerate(tgt):
        if tok.startswith('##') and len(tok) > 2:
            tok = tok.lstrip('##')
        curr = cont + tok
        assert src_list[src_i].startswith(curr), 'Mismatch in src and tgt! curr={} src={}'\
                .format(curr, src_list[src_i])
        if curr == src_list[src_i]:
            if curr == '[CLS]' or curr == '[SEP]':
                res.append(None)
            else:
                res.append(src_i - 1)
            src_i += 1
            cont = ''
        else:
            res.append(src_i - 1)
            cont = curr
    return res

In [3]:
src = ['turkmen president gurbanguly berdymukhammedov will begin a two-day visit to russia , his country \'s main energy partner , on monday for trade talks , the kremlin press office said .']
in_tensor, str_toks, attn_mask, scrm_idxs = batch_to_idx_tensor(tokenizer, src)

with torch.no_grad():
    lm_unmasked_output = lm_model(in_tensor, attention_mask=attn_mask)
unmasked_scores = lm_unmasked_output[0].squeeze(0)[1:-1]
# scores are to be fed into softmax to get prob. scores are already in log space.
# use the following to get prob in log space
unmasked_log_probs = unmasked_scores - torch.logsumexp(unmasked_scores, dim=1, keepdim=True)

# now mask each token in src: make batch with same size as src len
mask_in_tensor = in_tensor.repeat_interleave(in_tensor.shape[1] - 2, dim=0)
attn_mask = attn_mask.repeat_interleave(in_tensor.shape[1] - 2, dim=0)
mask_token = tokenizer.mask_token
mask_id = tokenizer.convert_tokens_to_ids([mask_token])[0]
r = torch.arange(mask_in_tensor.shape[0])
idxs = r + 1
mask_in_tensor[r, idxs] = mask_id

with torch.no_grad():
    lm_masked_output = lm_model(mask_in_tensor, attention_mask=attn_mask)
    
masked_scores = lm_masked_output[0][r, idxs]
masked_log_probs = masked_scores - torch.logsumexp(masked_scores, dim=1, keepdim=True)


### Exploring the relationship between surprisal,  KL-divergence, entropy

For each position in a sentence $X = x_1x_2...x_N$, I get two probabilities. 

One is the 'unmasked probability', $P_i(w_i \mid X)$ which is the distribution over words at position $i \in [1, N]$, given all words in the sentence, including $x_i$ itself.

Another is the 'masked probability', $Q_i(w_i \mid X_{-i})$, which is the distribution over words at position $i$, given all words in the sentence except that the word $x_i$ is masked.

The surprisal of the word is calculated as $S = -\log Q_i(x_i \mid X_{-i})$.

For each position $i$, the KL-divergence between the masked and unmasked probabilities is calculated: $KL(P_i, Q_i) = \sum_{w \in \mathcal{V}}P_i(w)\log\left(\frac{P_i(w)}{Q_i(w)}\right)$. This is interpreted as a measure of information loss before and after masking the word $x_i$. 

In [4]:
def kl_divergence(a, b):
    """Calculate KL divergence for a, b in logspace"""
    return (a.exp() * (a - b)).sum(dim=-1)

def entropy(a):
    """Calculate entropy given a with shape [len, vocab_size] for a in logspace"""
    return (-a).sum(dim=-1)

In [5]:

kl = kl_divergence(unmasked_log_probs, masked_log_probs)
unmasked_entropy = entropy(unmasked_log_probs)
masked_entropy = entropy(masked_log_probs)

word_ids = in_tensor[0, 1:-1]
masked_surprisal = -masked_log_probs[r, word_ids]
unmasked_surprisal = -unmasked_log_probs[r, word_ids]

word_strs = str_toks[0][1:-1]
print('Word  mS  umS  KL  umH  mH')
for word, ms, ums, k, ume, me in zip(word_strs, masked_surprisal, unmasked_surprisal, kl, unmasked_entropy, masked_entropy):
    print('{}  {:.4}  {:.4}  {:.4}  {:.4}  {:.4}'.format(word, ms, ums, k, ume, me))

Word  mS  umS  KL  umH  mH
turk  0.002079  3.338  11.49  4.377e+05  6.775e+05
##men  0.04684  0.924  4.896  5.049e+05  6.351e+05
president  0.113  0.004902  0.09585  7.32e+05  5.815e+05
gu  0.2113  0.0002594  0.2096  8.256e+05  6.758e+05
##rba  0.1048  0.01257  0.1196  6.323e+05  5.456e+05
##ng  0.4632  0.1022  0.8444  6.098e+05  6.474e+05
##ul  0.207  0.001814  0.2031  7.347e+05  5.528e+05
##y  0.01636  0.004824  0.01823  7.039e+05  6.489e+05
be  2.118  0.007103  2.108  7.196e+05  5.601e+05
##rdy  7.206  1.483  7.132  5.238e+05  5.907e+05
##mu  3.437  0.5322  2.364  5.481e+05  4.507e+05
##kha  3.781  0.06216  3.519  6.344e+05  4.818e+05
##mme  5.122  1.032  3.891  5.05e+05  5.316e+05
##do  0.211  0.0001869  0.2105  8.672e+05  6.227e+05
##v  0.0167  0.02242  0.08604  6.9e+05  6.602e+05
will  0.8513  0.08677  0.5482  7.614e+05  5.678e+05
begin  6.972  0.8514  5.233  5.197e+05  5.574e+05
a  0.1463  0.0001774  0.145  8.95e+05  6.56e+05
two  1.419  0.00124  1.41  8.643e+05  5.547e+05
-  0.

## Finding efficient methods to retrieve words with high surprisal

In [6]:
from collections import defaultdict as DD

def precision_and_recall(tgt, hyp):
    tgt_dict = DD(int)
    tgt_size = len(tgt) if isinstance(tgt, list) else tgt.shape[0]
    for y in tgt:
        tgt_dict[y] += 1
    hyp_size = len(hyp) if isinstance(hyp, list) else hyp.shape[0]
    overlap = 0
    for x in hyp:
        if x in tgt_dict:
            overlap += 1
    precision = overlap / hyp_size
    recall = overlap / tgt_size
    
    return precision, recall

_, tgt_top_idxs = masked_surprisal.topk(10)
tgt_tok_strs = [word_strs[i] for i in tgt_top_idxs]
print(tgt_tok_strs)

_, hyp_top_idxs = unmasked_surprisal.topk(10)
hyp_tok_strs = [word_strs[i] for i in hyp_top_idxs]
print(hyp_tok_strs)

prec, recl = precision_and_recall(tgt_tok_strs, hyp_tok_strs)

print('Precision {:.4}, recall {:.4}'.format(prec, recl))
    

['energy', '##rdy', 'begin', '##mme', 'his', 'monday', 'trade', '##kha', '##mu', 'press']
['##lin', 'turk', '##rdy', '##mme', '##men', 'begin', '##mu', 'main', 'talks', '##ng']
Precision 0.4, recall 0.4


In [7]:
in_tensor, str_toks, attn_mask, scrm_idxs = batch_to_idx_tensor(tokenizer, src)

with torch.no_grad():
    model_res = model(in_tensor, attention_mask=attn_mask)
    
attns = model_res[2]
layer0_attns = attns[0]
print(layer0_attns.shape)
summed = layer0_attns.sum(dim=2).sum(dim=1)
summed = (summed / summed.sum(dim=1)).squeeze(0)[1:-1]
print(summed.shape)
_, hyp_top_idxs = summed.topk(10)
hyp_tok_strs = [word_strs[i] for i in hyp_top_idxs]
print(hyp_tok_strs)

prec, recl = precision_and_recall(tgt_tok_strs, hyp_tok_strs)

print('Precision {:.4}, recall {:.4}'.format(prec, recl))

torch.Size([1, 12, 49, 49])
torch.Size([47])
['turk', 'monday', 'russia', '##kha', 'talks', 'visit', 'partner', 'begin', 'president', '##rdy']
Precision 0.4, recall 0.4
