In [None]:
import json
import random

import torch
import torch.nn.functional as F

from time import perf_counter

from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

from IPython.display import clear_output

from LogitWrappers import *


In [None]:
# Load the tokenizer and the causal language model from the same checkpoint.
# `return_dict_in_generate=True` is forwarded to `from_pretrained` as an extra
# keyword argument.
tokenizer = AutoTokenizer.from_pretrained(
    "concedo/OPT-19M-ChatSalad",
)
model = AutoModelForCausalLM.from_pretrained(
    "concedo/OPT-19M-ChatSalad",
    return_dict_in_generate=True,
)


In [5]:
def summary(xs, titles):
    """Print one right-aligned ``title: value`` line per entry.

    Args:
        xs: Iterable of single-element tensors; each is printed via ``.item()``.
        titles: Iterable of label strings, paired positionally with ``xs``.

    Labels shorter than 14 characters are left-padded so that the trailing
    colon ends at column 15; longer labels are printed with a single leading
    space and no further padding.
    """
    for value, label in zip(xs, titles):
        # A negative repeat count yields an empty string, so long labels
        # simply lose their padding instead of raising.
        indent = ' ' * (15 - (len(label) + 1))
        print(f'{indent} {label}: ', end='')
        print(value.item())


In [6]:
def best_k_p(logits, golden):
    """Report where the ``golden`` token sits in the sorted next-token distribution.

    Args:
        logits: 2-D tensor of logits; sorted and soft-maxed along the last
            dimension. The printed ``.item()`` calls require exactly one row.
        golden: Tensor of token ids that broadcasts against the sorted index
            tensor (the caller in this notebook passes shape ``(1, 1)``).

    Returns:
        A tuple ``(temp, tk, tp)``:

        * ``temp`` - ``prob_golden * prob_max / 2``.
        * ``tk``   - 0-based position of the golden token in the descending
          sort (column index of the match).
        * ``tp``   - cumulative probability up to and including the golden
          token.

    Side effects:
        Prints the top probability, the golden-token probability and the
        shape of the cumulative-probability tensor.

    NOTE(review): ``tk`` is a 0-based rank. If it is meant to be used as a
    top-k cutoff that still keeps the golden token, the caller presumably
    needs ``tk + 1`` - confirm the intended meaning before relying on it.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)

    # Soft-max once and reuse it for both the per-token and cumulative views.
    probs = F.softmax(sorted_logits, dim=-1)
    cumulative_probs = torch.cumsum(probs, dim=-1)

    # Boolean mask marking the golden token's slot in sorted order; shared by
    # the probability, rank and cumulative-probability look-ups below.
    golden_mask = sorted_indices == golden

    probs_max = probs[..., 0, None]
    prob_golden = probs[golden_mask]

    report = (
        ('prob max:', probs_max.item()),
        ('prob golden:', prob_golden.item()),
        ('cumulative probs:', cumulative_probs.shape),
    )
    for label, value in report:
        # Padding is derived from the label that is actually printed.
        print(' ' * (15 - len(label)), label, value)

    print()

    temp = (prob_golden * probs_max) / 2
    tk = golden_mask.nonzero()[:, 1]
    tp = cumulative_probs[golden_mask]

    return temp, tk, tp


In [7]:
prompt = "Echo Sohma sat at her desk,"

# Token id whose rank and probability are inspected below.
# NOTE(review): presumably the id of the desired continuation token for this
# prompt - verify against the tokenizer's vocabulary.
golden_token_id = 19311

input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Inference only: skip autograd bookkeeping, matching the generation loop
# further down in this notebook.
with torch.no_grad():
    start_time = perf_counter()
    outputs = model(input_ids)
    print(perf_counter() - start_time, 'sec to get logits', end='\n\n')

# Keep only the last position's scores; the result stays 2-D (one row per
# batch element) because only the middle axis is indexed away.
next_token_logits = outputs[0][:, -1, :]

summary(best_k_p(next_token_logits, torch.tensor([golden_token_id]).unsqueeze(1)), ['temp', 'k', 'p'])


0.10483854499761946 sec to get logits

       prob max: 0.07534100860357285
    prob golden: 4.79663722217083e-05
 cumulative probs: torch.Size([1, 50266])

           temp: 1.806917452995549e-06
              k: 1604
              p: 0.8144427537918091


In [68]:
class AdvancedRepetitionPenaltyLogitsProcessor():
    """Repetition penalty with a limited look-back window and an optional slope.

    Scores of tokens that already occur in ``input_ids`` are pushed down:
    non-positive scores are multiplied by the penalty, positive scores are
    divided by it.

    Args:
        penalty: Base penalty factor. ``1.0`` disables the processor.
        penalty_range: Number of most recent tokens to penalise. Values
            ``<= 0`` penalise every token in ``input_ids`` with the flat
            ``penalty``.
        penalty_slope: When non-zero (and ``penalty_range > 1``), the penalty
            is ramped across the window from ``1`` at the oldest position to
            ``penalty`` at the newest; the slope controls the curve's shape.
            ``0`` applies the flat ``penalty`` to the whole window.
    """

    def __init__(self, penalty: float, penalty_range: int, penalty_slope: float):
        self.penalty = penalty
        self.penalty_range = int(penalty_range)
        self.penalty_slope = penalty_slope

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the penalty to ``scores`` and return it.

        Args:
            input_ids: ``(batch, sequence)`` token ids generated so far.
            scores: ``(batch, vocab)`` next-token scores. Modified in place.

        Returns:
            The same ``scores`` tensor, with penalised entries overwritten.
        """
        # Nothing to do for a neutral penalty or when no token has been seen.
        if self.penalty == 1.0 or input_ids.shape[-1] == 0:
            return scores

        # Work on a local so the configured base penalty is never overwritten;
        # the processor must behave identically on every call.
        penalty = self.penalty

        if self.penalty_range > 0:
            clipped_penalty_range = min(input_ids.shape[-1], self.penalty_range)
            input_ids = input_ids[..., -clipped_penalty_range:]

            # A one-token window has no ramp to build (the normalisation
            # below would divide by zero), so it gets the flat penalty.
            if self.penalty_slope != 0 and self.penalty_range > 1:
                # Positions across the full window, normalised to [-1, 1].
                _penalty = (torch.arange(self.penalty_range, dtype=scores.dtype, device=scores.device) / (self.penalty_range - 1)) * 2. - 1
                # Bend the straight ramp according to the slope.
                _penalty = (self.penalty_slope * _penalty) / (1 + torch.abs(_penalty) * (self.penalty_slope - 1))
                # Rescale from [-1, 1] to [1, penalty]; shape (1, penalty_range).
                _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (self.penalty - 1)
                # Keep the newest end of the curve, one factor per kept token.
                penalty = _penalty[..., -clipped_penalty_range:]

        score = torch.gather(scores, 1, input_ids)
        score = torch.where(score <= 0, score * penalty, score / penalty)
        scores.scatter_(1, input_ids, score)

        return scores


In [71]:
# Sampler chain applied, in this order, to the next-token scores. Each entry
# is a [display name, callable] pair; the generation loop invokes the callable
# as `sampler(input_ids=..., scores=...)` and prints the display name next to
# the surviving candidates.
sample_order = [
    ["Temperature", TemperatureLogitsWarper(temperature=0.9)],
    [
        "Repetition Penalty",
        AdvancedRepetitionPenaltyLogitsProcessor(penalty=1.1, penalty_range=1024, penalty_slope=0.7),
    ],
    ["Top P", TopPLogitsWarper(top_p=0.9)],
    ["Top K", TopKLogitsWarper(top_k=15)],
    ["Top A", TopALogitsWarper(top_a=0.7)],
    ["Typical Sampling", TypicalLogitsWarper(typical=1.0)],
    ["Temperature", TemperatureLogitsWarper(temperature=5.0)],
    ["Top P", TopPLogitsWarper(top_p=0.6)],
    ["Tail Free", TailFreeLogitsWarper(tfs=1.0)],
    ["Min Length", MinLengthLogitsProcessor(min_length=10, eos_token_id=tokenizer.eos_token_id)],
]

# Seed text that the generation loop extends one token at a time.
prompt = "this model\n"


In [None]:
NUM_NEW_TOKENS = 100   # generation steps to run
TOP_TOKENS_SHOWN = 7   # candidates listed on each per-sampler line
LABEL_WIDTH = 21       # width of the "<name>:" column in the report


def _ranked_nonzero_probs(logits):
    """Return ``(token_ids, probs)`` of all non-zero-probability tokens, most likely first.

    ``logits`` is expected to hold a single row; the soft-maxed row is
    squeezed to 1-D. Ties keep ascending token-id order (stable sort).
    """
    probs = F.softmax(logits, dim=-1).squeeze(0)
    token_ids = torch.nonzero(probs, as_tuple=True)[0]
    ranked = torch.sort(probs[token_ids], descending=True, stable=True)
    return token_ids[ranked.indices].tolist(), ranked.values.tolist()


def _format_candidate(token_id, prob):
    """Render one candidate as ``"<json-escaped token>" (<prob to 4 decimals>)``."""
    return f"{json.dumps(tokenizer.decode(token_id))} ({prob:.4f})"


print(prompt, end='')

for _ in range(NUM_NEW_TOKENS):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    with torch.no_grad():
        start_time = perf_counter()
        outputs = model(input_ids)
        print(perf_counter() - start_time, 'sec to get logits', end='\n\n')

        raw_logits = outputs[0][:, -1, :]
        # Samplers may write into the tensor they are handed; give them a copy
        # so the "before samplers" draw below really sees untouched logits.
        next_token_logits = raw_logits.clone()

        ranked_ids, ranked_probs = [], []

        # The leading 'None' pseudo-sampler reports the unmodified distribution.
        for sampler_name, sampler in [['None', lambda input_ids, scores: scores], *sample_order]:
            next_token_logits = sampler(input_ids=input_ids, scores=next_token_logits)

            ranked_ids, ranked_probs = _ranked_nonzero_probs(next_token_logits)

            # Only the entries that are actually shown get decoded.
            shown = [
                _format_candidate(token_id, prob)
                for token_id, prob in zip(ranked_ids[:TOP_TOKENS_SHOWN], ranked_probs[:TOP_TOKENS_SHOWN])
            ]
            hidden_count = len(ranked_ids) - TOP_TOKENS_SHOWN

            print(f"{sampler_name + ':':<{LABEL_WIDTH}}", end='')
            # Announce the remainder whenever the list was truncated.
            print(', '.join(shown), f'...{hidden_count} more' if hidden_count > 0 else '')

        print()

        # Full surviving distribution after the last sampler in the chain.
        print(f"{'Final probabilities:':<{LABEL_WIDTH}}", end='')

        for token_id, prob in zip(ranked_ids, ranked_probs):
            print(_format_candidate(token_id, prob), end=', ')

        print()
        print()

        sampled_token_org = tokenizer.decode(torch.multinomial(F.softmax(raw_logits, dim=-1).squeeze(0), 1)[0])
        sampled_token = tokenizer.decode(torch.multinomial(F.softmax(next_token_logits, dim=-1), 1)[0])

        # Coin flip between the filtered and the unfiltered draw. The token
        # reported as "Taken" is the one appended to the prompt, so the log
        # and the generated text cannot disagree.
        # NOTE(review): previously the prompt always received `sampled_token`
        # regardless of the coin flip - confirm that mixing is the intent.
        taken_token = sampled_token if random.random() > 0.5 else sampled_token_org

        prompt += taken_token

        print('Sampled (Before samplers):', json.dumps(sampled_token_org))
        print('Sampled (After samplers) :', json.dumps(sampled_token))

        print('Taken:', taken_token)
