In [1]:
import numpy as np
import torch
import transformers
import pickle as pkl

In [2]:
ce_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
softmax_fn = torch.nn.Softmax(dim=-1)

In [4]:
encodings = pkl.load(open("encodings.pkl", 'rb'))
observer_logits = pkl.load(open("ob_logits.pkl", 'rb'))
performer_logits = pkl.load(open("pf_logits.pkl", 'rb'))
pad_token = pkl.load(open("pad_token_id.pkl", 'rb'))

In [5]:
np.set_printoptions(formatter={'float_kind':'{:f}'.format})

In [10]:
def perplexity(encoding: transformers.BatchEncoding,
               logits: torch.Tensor,
               median: bool = False,
               temperature: float = 1.0):
    shifted_logits = logits[..., :-1, :].contiguous() / temperature
    # print(shifted_logits)
    shifted_labels = encoding.input_ids[..., 1:].contiguous()
    # print(shifted_labels)
    shifted_attention_mask = encoding.attention_mask[..., 1:].contiguous()
    # print(shifted_attention_mask)
    if median:
        ce_nan = (ce_loss_fn(shifted_logits.transpose(1, 2), shifted_labels).
                  masked_fill(~shifted_attention_mask.bool(), float("nan")))
        # print(ce_nan)
        walk = ce_nan.to("cpu").float().numpy()
        # print(walk)
        ppl = np.nanmedian(ce_nan.cpu().float().numpy(), 1)

    else:
        ppl = (ce_loss_fn(shifted_logits.transpose(1, 2), shifted_labels) *
               shifted_attention_mask).sum(1) / shifted_attention_mask.sum(1)
        ppl = ppl.to("cpu").float().numpy()

    return ppl, walk

In [20]:
# print(perplexity(encodings, observer_logits, median=True))
# print(perplexity(encodings, performer_logits, median=True))
# ob_ppl, ob_walk = perplexity(encodings, observer_logits, median=True)
pf_ppl, pf_walk = perplexity(encodings, performer_logits, median=True)
sample_string = '''Dr. Capy Cosmos, a capybara unlike any other, astounded the scientific community with his
groundbreaking research in astrophysics. With his keen sense of observation and unparalleled ability to interpret
cosmic data, he uncovered new insights into the mysteries of black holes and the origins of the universe. As he
peered through telescopes with his large, round eyes, fellow researchers often remarked that it seemed as if the
stars themselves whispered their secrets directly to him. Dr. Cosmos not only became a beacon of inspiration to
aspiring scientists but also proved that intellect and innovation can be found in the most unexpected of creatures.'''
# ob_ppl_list = [(word, ob_walk[0][w]) for w, word in enumerate(sample_string.split())]
pf_ppl_list = [(word, pf_walk[0][w]) for w, word in enumerate(sample_string.split())]
# print(ob_ppl_list)
print(pf_ppl_list)
print(len(pf_ppl_list))

[('Dr.', 0.6953125), ('Capy', 9.8125), ('Cosmos,', 6.03125), ('a', 9.75), ('capybara', 1.84375), ('unlike', 2.984375), ('any', 1.921875), ('other,', 6.9375), ('astounded', 0.040771484), ('the', 0.026123047), ('scientific', 11.5), ('community', 0.27734375), ('with', 0.140625), ('his', 0.3046875), ('groundbreaking', 9.0625), ('research', 1.765625), ('in', 1.1171875), ('astrophysics.', 2.4375), ('With', 0.71484375), ('his', 1.0859375), ('keen', 1.4296875), ('sense', 8.125), ('of', 5.90625), ('observation', 0.5859375), ('and', 1.703125), ('unparalleled', 2.34375), ('ability', 1.84375), ('to', 0.51171875), ('interpret', 0.31640625), ('cosmic', 3.34375), ('data,', 1.4609375), ('he', 3.03125), ('uncovered', 2.46875), ('new', 0.048828125), ('insights', 3.328125), ('into', 1.1640625), ('the', 4.4375), ('mysteries', 4.03125), ('of', 0.30859375), ('black', 5.59375), ('holes', 2.984375), ('and', 3.515625), ('the', 0.04296875), ('origins', 2.4375), ('of', 0.08886719), ('the', 1.6328125), ('universe

In [13]:
def entropy(p_logits: torch.Tensor,
            q_logits: torch.Tensor,
            encoding: transformers.BatchEncoding,
            pad_token_id: int,
            median: bool = False,
            sample_p: bool = False,
            temperature: float = 1.0):
    vocab_size = p_logits.shape[-1]
    total_tokens_available = q_logits.shape[-2]
    p_scores, q_scores = p_logits / temperature, q_logits / temperature

    p_proba = softmax_fn(p_scores).view(-1, vocab_size)

    if sample_p:
        p_proba = torch.multinomial(p_proba.view(-1, vocab_size), replacement=True, num_samples=1).view(-1)

    q_scores = q_scores.view(-1, vocab_size)

    ce = ce_loss_fn(input=q_scores, target=p_proba).view(-1, total_tokens_available)
    padding_mask = (encoding.input_ids != pad_token_id).type(torch.uint8)

    if median:
        ce_nan = ce.masked_fill(~padding_mask.bool(), float("nan"))
        en_walk = ce_nan.to("cpu").float().numpy()
        agg_ce = np.nanmedian(ce_nan.cpu().float().numpy(), 1)
        return agg_ce, en_walk
    else:
        agg_ce = (((ce * padding_mask).sum(1) / padding_mask.sum(1)).to("cpu").float().numpy())

    return agg_ce

In [18]:
# original
# ppl = perplexity(encodings, performer_logits)
# x_ppl = entropy(observer_logits.to(DEVICE_1), performer_logits.to(DEVICE_1),
#                 encodings.to(DEVICE_1), self.tokenizer.pad_token_id)
x_ppl, en_walk = entropy(observer_logits, performer_logits, encodings, pad_token, median=True)
pkl.dump(en_walk, open('en_walk.pkl', 'wb'))
print(en_walk)

[[3.125000 7.343750 5.187500 7.468750 3.937500 5.031250 6.218750 7.125000
  0.281250 0.247070 4.468750 1.210938 0.914062 1.679688 4.468750 1.828125
  4.062500 4.687500 1.015625 2.765625 1.953125 6.031250 6.312500 0.882812
  4.125000 2.812500 4.843750 1.648438 1.875000 4.625000 3.593750 7.562500
  4.093750 0.488281 5.281250 1.226562 6.031250 4.093750 0.875000 6.062500
  4.125000 4.562500 0.671875 4.281250 0.824219 2.015625 4.656250 3.953125
  4.562500 1.656250 1.914062 4.625000 0.277344 1.789062 0.285156 1.382812
  4.375000 5.500000 0.077148 1.250000 0.863281 0.792969 2.875000 2.796875
  3.812500 5.437500 3.171875 2.031250 1.453125 0.225586 2.734375 3.812500
  7.187500 2.546875 6.000000 1.875000 0.992188 2.859375 2.296875 5.125000
  5.343750 2.359375 2.531250 1.804688 2.265625 1.062500 2.234375 1.546875
  5.281250 2.421875 1.781250 3.437500 1.210938 1.546875 1.156250 1.429688
  0.609375 3.406250 0.036133 0.691406 0.019409 4.500000 0.117188 6.062500
  2.484375 4.875000 0.964844 3.531250 

In [21]:
bino_walk = (pf_walk / en_walk[..., 1:])
pkl.dump(bino_walk, open("bino_walk", 'wb'))

In [23]:
print(bino_walk)

[[0.094681 1.891566 0.807531 2.476191 0.366460 0.479899 0.269737
  24.666666 0.165020 0.005846 9.496774 0.303419 0.083721 0.068182
  4.957265 0.434615 0.238333 2.400000 0.258475 0.556000 0.237047 1.287129
  6.690266 0.142045 0.605556 0.483871 1.118483 0.272917 0.068412 0.930435
  0.193182 0.740458 5.056000 0.009246 2.713376 0.193005 1.083969 4.607143
  0.050902 1.356061 0.654110 5.232558 0.010036 2.957346 0.044089 0.350671
  0.968379 0.972603 0.632075 0.353061 0.043708 12.507042 0.030158
  12.438356 0.006841 0.127679 0.127841 26.936708 0.008984 0.217195
  0.043411 0.095109 1.324022 0.688525 0.080460 2.295567 0.125000 0.801075
  16.692640 0.071786 0.922131 0.123370 1.288344 0.081055 1.283333
  0.096457 0.031762 3.809524 0.445122 0.830409 1.039735 0.549383 1.437229
  0.820690 0.937500 0.111014 1.287879 0.047707 0.606452 0.350877 1.636364
  1.883871 0.016651 3.945946 0.092896 0.439103 0.011468 73.513512
  0.014919 10.867925 0.001119 53.066666 0.001591 1.226415 0.169071
  4.307693 0.054204