In [None]:
# 必要なものを用意
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Run on the GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tokenizer and weights are loaded from the same checkpoint name.
model_name = "gpt2-xl"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

# Length cap handed to generate(); per the transformers docs this counts the
# prompt tokens as well as the generated ones.
max_length = 128
input_txt = (
    "In a shocking finding, scientist discovered "
    "a herd of unicorns living in a remote, previously unexplored "
    "valley, in the Andes Mountains. Even more surprising to the "
    "researchers was the fact that the unicorns spoke perfect English.\n\n"
)
input_ids = tokenizer(input_txt, return_tensors="pt").input_ids.to(device)

# Greedy decoding (do_sample=False, single beam): deterministic continuation.
output_greedy = model.generate(
    input_ids,
    max_length=max_length,
    do_sample=False,
)

In [None]:
# Product of conditional probabilities: 1024 factors of 0.5 multiply out to
# about 5.6e-309, smaller than the smallest normal float64
# (sys.float_info.min), i.e. already in subnormal territory.
pow(0.5, 1024)

In [None]:
# The same 1024 factors of 0.5, accumulated in log space instead
import numpy as np

# Roughly -709.78: an ordinary float64, nowhere near underflow.
sum(np.log(0.5) for _ in range(1024))

In [None]:
# Per-token log-probabilities
import torch.nn.functional as F

def log_probs_from_logits(logits, labels):
  """Return the log-probability assigned to each label token.

  Args:
    logits: unnormalized scores whose last axis is the vocabulary,
      e.g. (batch, seq, vocab) from a causal LM.
    labels: integer token ids with the same leading shape as ``logits``
      (everything except the vocabulary axis), e.g. (batch, seq).

  Returns:
    Tensor shaped like ``labels`` holding log p(label) at each position.
  """
  logp = F.log_softmax(logits, dim=-1)
  # Gather along the vocabulary axis (-1 instead of a hard-coded 2) so the
  # helper accepts any number of leading dimensions, not only 3-D batches.
  logp_label = torch.gather(logp, -1, labels.unsqueeze(-1)).squeeze(-1)
  return logp_label

In [None]:
# Total log-probability of a sequence
def sequence_logprob(model, labels, input_len=0):
  """Sum the log-probabilities of the tokens that follow the prompt.

  Args:
    model: causal LM, called as ``model(labels)``; its output must expose a
      ``.logits`` tensor indexed as (batch, seq, vocab).
    labels: token ids of prompt + continuation, indexed as (batch, seq).
    input_len: number of leading prompt tokens to exclude from the score
      (0 scores every token that can be predicted, i.e. all but the first).

  Returns:
    0-d NumPy array with the log-probability summed over all scored
    positions (and over every row, if the batch has more than one).
  """
  with torch.no_grad():
    output = model(labels)
    # logits[:, i] predicts labels[:, i + 1]: the final position predicts a
    # token past the end, and labels[:, 0] is never predicted.
    log_probs = log_probs_from_logits(
        output.logits[:, :-1, :], labels[:, 1:])
    # log_probs[:, j] scores labels[:, j + 1], so the first generated token
    # labels[:, input_len] is at index input_len - 1; starting the slice at
    # input_len would silently leave that token out of the sum.
    start = max(input_len - 1, 0)
    seq_log_prob = torch.sum(log_probs[:, start:])
  return seq_log_prob.cpu().numpy()

In [None]:
# Log-probability of the greedy-decoded continuation (prompt tokens excluded)
logp = sequence_logprob(model, output_greedy, input_len=input_ids.shape[-1])
print(tokenizer.decode(output_greedy[0]), end="\n\n")
print(f"log-prob: {logp:.2f}")

In [None]:
# Beam search with 5 beams, scored with the same log-probability measure
output_beam = model.generate(
    input_ids,
    max_length=max_length,
    num_beams=5,
    do_sample=False,
)
logp = sequence_logprob(model, output_beam, input_len=input_ids.shape[-1])
print(tokenizer.decode(output_beam[0]), end="\n\n")
print(f"log-prob: {logp:.2f}")

In [None]:
# Beam search with an n-gram penalty: per the transformers docs,
# no_repeat_ngram_size=2 lets each bigram occur at most once.
# Note: this rebinds output_beam, replacing the plain beam-search result.
output_beam = model.generate(
    input_ids,
    max_length=max_length,
    num_beams=5,
    do_sample=False,
    no_repeat_ngram_size=2,
)
logp = sequence_logprob(model, output_beam, input_len=input_ids.shape[-1])
print(tokenizer.decode(output_beam[0]), end="\n\n")
print(f"log-prob: {logp:.2f}")