In [1]:
import numpy as np
import torch
import transformers
import pickle as pkl

  _torch_pytree._register_pytree_node(


In [46]:
# Shared torch modules, kept at module level so any cell can reuse them.
# Softmax normalises over the last axis (positional dim argument).
softmax_fn = torch.nn.Softmax(-1)
# reduction="none" yields one loss value per element instead of a scalar,
# leaving masking and aggregation to the caller.
ce_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")

In [2]:
# NOTE(review): pickle.load can execute arbitrary code — only load these files
# when they were produced by a trusted run of this pipeline.
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards."""
    with open(path, 'rb') as fh:
        return pkl.load(fh)


encodings = _load_pickle("encodings.pkl")
observer_logits = _load_pickle("ob_logits.pkl")
performer_logits = _load_pickle("pf_logits.pkl")
pad_token = _load_pickle("pad_token_id.pkl")

In [48]:
# Render numpy floats in fixed-point ('{:f}') rather than scientific notation.
np.set_printoptions(formatter=dict(float_kind='{:f}'.format))

In [49]:
def perplexity(encoding: "transformers.BatchEncoding",
               logits: torch.Tensor,
               median: bool = False,
               temperature: float = 1.0):
    """Per-sequence next-token cross-entropy of ``encoding`` under ``logits``.

    Despite the name, the aggregate is the per-token cross-entropy
    (log-perplexity); no ``exp`` is applied.

    Args:
        encoding: tokenized batch; only ``input_ids`` and ``attention_mask`` are read.
        logits: model logits; the ``transpose(1, 2)`` below requires a 3-D tensor,
            presumably (batch, seq_len, vocab) — TODO confirm against the producer.
        median: aggregate with a NaN-aware median over non-padded targets instead
            of the attention-masked mean.
        temperature: logits are divided by this before the loss.

    Returns:
        ``(ppl, walk)`` where ``ppl`` is a float numpy array with one value per
        sequence, and ``walk`` is a float numpy array (batch, seq_len - 1) of
        per-token losses with padded target positions set to NaN.
    """
    # Logit position i predicts token i + 1: drop the last logit and the first label.
    shifted_logits = logits[..., :-1, :].contiguous() / temperature
    shifted_labels = encoding.input_ids[..., 1:].contiguous()
    shifted_attention_mask = encoding.attention_mask[..., 1:].contiguous()

    # cross_entropy wants the class axis second: (batch, vocab, seq_len - 1).
    token_ce = torch.nn.functional.cross_entropy(
        shifted_logits.transpose(1, 2), shifted_labels, reduction="none")

    # Built for both branches so `walk` is always defined (previously it was
    # only assigned when median=True, making the default call crash).
    walk = (token_ce.masked_fill(~shifted_attention_mask.bool(), float("nan"))
            .detach().cpu().float().numpy())

    if median:
        ppl = np.nanmedian(walk, 1)
    else:
        ppl = (token_ce * shifted_attention_mask).sum(1) / shifted_attention_mask.sum(1)
        ppl = ppl.detach().cpu().float().numpy()

    return ppl, walk

In [50]:
# Per-token losses of the sample text under each model (median aggregation).
ob_ppl, ob_walk = perplexity(encodings, observer_logits, median=True)
pf_ppl, pf_walk = perplexity(encodings, performer_logits, median=True)
sample_string = '''Dr. Capy Cosmos, a capybara unlike any other, astounded the scientific community with his
groundbreaking research in astrophysics. With his keen sense of observation and unparalleled ability to interpret
cosmic data, he uncovered new insights into the mysteries of black holes and the origins of the universe. As he
peered through telescopes with his large, round eyes, fellow researchers often remarked that it seemed as if the
stars themselves whispered their secrets directly to him. Dr. Cosmos not only became a beacon of inspiration to
aspiring scientists but also proved that intellect and innovation can be found in the most unexpected of creatures.'''
# NOTE(review): this pairs whitespace-separated *words* with *token* positions.
# perplexity() shifts labels by one, so walk[:, i] is the loss of token i + 1.
# Words and tokens only line up if every word is exactly one token.
# TODO: pair with the tokenizer's tokens for input_ids[0, 1:] instead.
ob_ppl_list = [(word, ob_walk[0][w]) for w, word in enumerate(sample_string.split())]
pf_ppl_list = [(word, pf_walk[0][w]) for w, word in enumerate(sample_string.split())]
print(ob_ppl_list)
print(pf_ppl_list)

[('Dr.', 0.43164062), ('Capy', 9.75), ('Cosmos,', 7.3125), ('a', 9.1875), ('capybara', 1.15625), ('unlike', 2.515625), ('any', 2.375), ('other,', 7.65625), ('astounded', 0.029052734), ('the', 0.038330078), ('scientific', 10.8125), ('community', 0.21679688), ('with', 0.14746094), ('his', 0.33789062), ('groundbreaking', 9.875), ('research', 1.9375), ('in', 0.9765625), ('astrophysics.', 2.265625), ('With', 0.31054688), ('his', 1.34375), ('keen', 0.44335938), ('sense', 7.6875), ('of', 4.34375), ('observation', 0.39453125), ('and', 1.109375), ('unparalleled', 2.34375), ('ability', 3.453125), ('to', 0.515625), ('interpret', 0.65234375), ('cosmic', 4.09375), ('data,', 1.3515625), ('he', 5.09375), ('uncovered', 2.109375), ('new', 0.09423828), ('insights', 2.21875), ('into', 0.73046875), ('the', 5.09375), ('mysteries', 3.1875), ('of', 0.28320312), ('black', 4.96875), ('holes', 3.78125), ('and', 4.5), ('the', 0.18261719), ('origins', 1.796875), ('of', 0.10986328), ('the', 1.4140625), ('universe.

In [None]:
# Cross-entropy walk of the performer's predictions against the observer's distribution.
# Adapted from: entropy(observer_logits, performer_logits, encodings, tokenizer.pad_token_id)
# NOTE(review): `entropy` is defined in a later cell. On Restart & Run All this
# cell raises NameError unless that definition cell is moved above it.
# median=True keeps the aggregation consistent with the perplexity calls above
# and guarantees en_walk is produced.
x_ppl, en_walk = entropy(observer_logits, performer_logits, encodings, pad_token, median=True)
with open('en_walk.pkl', 'wb') as fh:
    pkl.dump(en_walk, fh)


In [41]:
def entropy(p_logits: torch.Tensor,
            q_logits: torch.Tensor,
            encoding: "transformers.BatchEncoding",
            pad_token_id: int,
            median: bool = False,
            sample_p: bool = False,
            temperature: float = 1.0):
    """Per-sequence cross-entropy of ``q``'s predictions against ``p``'s distribution.

    At every position, softmax(``p_logits``) is the (soft) target for
    ``q_logits``. Positions whose input id equals ``pad_token_id`` are excluded
    from the aggregate.

    Args:
        p_logits: logits providing the target distribution (vocabulary on the last axis).
        q_logits: logits scored against that target. The flattening below assumes
            the same shape as ``p_logits``, presumably (batch, seq_len, vocab) — TODO confirm.
        encoding: tokenized batch; only ``input_ids`` is read.
        pad_token_id: token id treated as padding.
        median: aggregate with a NaN-aware median instead of the masked mean.
        sample_p: replace p's distribution by one class index sampled from it per
            position (hard targets; stochastic).
        temperature: both sets of logits are divided by this first.

    Returns:
        ``(agg_ce, en_walk)`` where ``agg_ce`` is a float numpy array with one
        value per sequence, and ``en_walk`` is a float numpy array
        (batch, seq_len) of per-position cross-entropies with padded positions
        set to NaN. Unlike ``perplexity`` there is no label shift, so this walk
        has one more column.
    """
    vocab_size = p_logits.shape[-1]
    total_tokens_available = q_logits.shape[-2]
    p_scores, q_scores = p_logits / temperature, q_logits / temperature

    # Flatten to (positions, vocab) so each position is one cross-entropy sample.
    p_proba = torch.softmax(p_scores, dim=-1).reshape(-1, vocab_size)

    if sample_p:
        # Hard targets: one sampled class index per position.
        p_proba = torch.multinomial(p_proba, replacement=True, num_samples=1).view(-1)

    q_scores = q_scores.reshape(-1, vocab_size)

    # cross_entropy accepts probability targets; reduction="none" keeps one loss per position.
    ce = torch.nn.functional.cross_entropy(input=q_scores, target=p_proba, reduction="none")
    ce = ce.view(-1, total_tokens_available)
    padding_mask = (encoding.input_ids != pad_token_id).type(torch.uint8)

    # Built for both branches so `en_walk` is always defined (previously it was
    # only assigned when median=True, making the default call crash).
    en_walk = ce.masked_fill(~padding_mask.bool(), float("nan")).detach().cpu().float().numpy()

    if median:
        agg_ce = np.nanmedian(en_walk, 1)
    else:
        agg_ce = ((ce * padding_mask).sum(1) / padding_mask.sum(1)).detach().cpu().float().numpy()

    return agg_ce, en_walk

In [None]:
# Per-position ratio of performer log-perplexity to observer/performer cross-entropy.
# The two walks have different widths:
#   - pf_walk is (batch, seq_len - 1), because perplexity() drops the last logit.
#   - en_walk is (batch, seq_len).
# Column i of pf_walk uses logit position i, so dropping en_walk's last column
# aligns both on the same position; otherwise the division fails to broadcast.
bino_walk = pf_walk / en_walk[:, :-1]
with open("bino_walk", 'wb') as fh:
    pkl.dump(bino_walk, fh)