Contrastive Decoding: Open-ended Text Generation as Optimization (Li et al., 2022) - https://arxiv.org/abs/2210.15097

In [90]:
import torch
import transformers
import torch.nn.functional as F

# Seed the global torch RNG once so the sampling-based decoding cells below
# give the same text on every Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# One tokenizer is used for both models, so they are assumed to share the
# GPT-2 vocabulary. EOS is reused as the pad token because none is set here.
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
# Small checkpoint plays the "amateur", large checkpoint the "expert".
amateur_lm = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
expert_lm = transformers.AutoModelForCausalLM.from_pretrained('gpt2-large')

# Sanity-check forward pass through the expert on a short prompt.
encoding = tokenizer("The future of AI is", return_tensors='pt')
with torch.no_grad():  # inference only: do not build an autograd graph
  logits = expert_lm(**encoding).logits

In [89]:
# Inspect the expert's logits shape. The decoding cells index this tensor as
# logits[0, -1, :], i.e. first sequence, last position, every vocabulary entry.
print(logits.shape)

torch.Size([1, 5, 50257])


In [102]:
# Contrastive decoding without the adaptive plausibility constraint
def generate_next_token(encoding, expert_model, amateur_model, num_tokens, top_k=50):
  """Sample a continuation using the unconstrained contrastive score.

  At every step each vocabulary entry is scored with
  ``log p_expert - log p_amateur`` at the last position, and the next token
  is sampled from a softmax over the ``top_k`` highest scores.

  Args:
    encoding: tokenizer output for the prompt. Must provide ``input_ids`` and
      may provide ``attention_mask``. Only the first sequence is used.
    expert_model: causal LM whose call result exposes ``.logits``.
    amateur_model: causal LM with the same vocabulary as the expert.
    num_tokens: number of tokens to append to the prompt.
    top_k: how many of the highest-scoring tokens to sample from
      (clamped to the vocabulary size).

  Returns:
    The prompt plus the generated tokens, decoded to a string with the
    module-level ``tokenizer``.
  """
  input_ids = encoding['input_ids'][:1]
  if 'attention_mask' in encoding:
    attention_mask = encoding['attention_mask'][:1]
  else:
    attention_mask = torch.ones_like(input_ids)

  with torch.no_grad():
    for _ in range(num_tokens):
      expert_logits = expert_model(input_ids=input_ids, attention_mask=attention_mask).logits[0, -1, :]
      amateur_logits = amateur_model(input_ids=input_ids, attention_mask=attention_mask).logits[0, -1, :]

      # log_softmax stays finite where log(softmax(x)) would underflow to
      # -inf and turn the difference into nan/inf.
      expert_log_probs = F.log_softmax(expert_logits, dim=-1)
      amateur_log_probs = F.log_softmax(amateur_logits, dim=-1)
      contrastive_scores = expert_log_probs - amateur_log_probs

      k = min(top_k, contrastive_scores.size(-1))
      topk_scores, topk_indices = torch.topk(contrastive_scores, k=k, dim=-1)
      sampled_index = torch.multinomial(F.softmax(topk_scores, dim=-1), num_samples=1)
      next_token = topk_indices[sampled_index]

      # Extend the context with the sampled id itself. Decoding to text and
      # re-tokenizing is not guaranteed to give back the same ids, so the
      # models could end up conditioned on a different sequence.
      input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=-1)
      attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, 1))], dim=-1)

  return tokenizer.decode(input_ids[0].tolist())

# Sample 50 tokens after the prompt "Hello world" with the unconstrained
# contrastive score. The recorded output is mostly rare, unrelated tokens.
encoding = tokenizer("Hello world", return_tensors='pt')
print(generate_next_token(encoding, expert_lm, amateur_lm, 50))

Hello world REPL=> HankhallanotationsrebHyp Archdemon Dukeemate conflic Mannyaepernickaband Manzielvana MANウanglerburghback Manzielburgh mustacheburgh Mannyoxide SymphonyPUTbowsburghburghburghapyもorthern Struggle Kaepernickogy migriphBILITYymph BuddyCANergy"}],"ruitsreating eleph


In [103]:
# Contrastive decoding with the adaptive plausibility constraint
def calculate_vhead(expert_probs, alpha):
  """Return the boolean mask of tokens the expert considers plausible.

  A token is kept when its expert probability is at least ``alpha`` times the
  largest expert probability of the same distribution.

  Args:
    expert_probs: tensor of expert probabilities with the vocabulary on the
      last axis; any leading axes are treated as independent distributions.
    alpha: fraction of the peak probability used as the cutoff.

  Returns:
    Boolean tensor with the same shape as ``expert_probs``.
  """
  # Take the peak per distribution (last axis) rather than over the whole
  # tensor, so batched input gets one threshold per row.
  peak = torch.max(expert_probs, dim=-1, keepdim=True).values
  return expert_probs >= alpha * peak

def generate_next_token(encoding, expert_model, amateur_model, num_tokens, alpha = 0.1, top_k = 50):
  """Sample a continuation using the plausibility-constrained contrastive score.

  At every step, tokens whose expert probability is below ``alpha`` times the
  expert's peak probability are excluded (score ``-inf``). The remaining
  tokens are scored with ``log p_expert - log p_amateur`` and the next token
  is sampled from a softmax over the ``top_k`` highest scores.

  Args:
    encoding: tokenizer output for the prompt. Must provide ``input_ids`` and
      may provide ``attention_mask``. Only the first sequence is used.
    expert_model: causal LM whose call result exposes ``.logits``.
    amateur_model: causal LM with the same vocabulary as the expert.
    num_tokens: number of tokens to append to the prompt.
    alpha: plausibility cutoff passed to ``calculate_vhead``.
    top_k: how many of the highest-scoring tokens to sample from
      (clamped to the vocabulary size).

  Returns:
    The prompt plus the generated tokens, decoded to a string with the
    module-level ``tokenizer``.
  """
  input_ids = encoding['input_ids'][:1]
  if 'attention_mask' in encoding:
    attention_mask = encoding['attention_mask'][:1]
  else:
    attention_mask = torch.ones_like(input_ids)

  with torch.no_grad():
    for _ in range(num_tokens):
      expert_logits = expert_model(input_ids=input_ids, attention_mask=attention_mask).logits[0, -1, :]
      amateur_logits = amateur_model(input_ids=input_ids, attention_mask=attention_mask).logits[0, -1, :]

      # log_softmax keeps both terms finite. With log(softmax(x)), an amateur
      # probability that underflows to 0 gives a +inf score and a NaN softmax.
      expert_log_probs = F.log_softmax(expert_logits, dim=-1)
      amateur_log_probs = F.log_softmax(amateur_logits, dim=-1)

      plausible_mask = calculate_vhead(expert_log_probs.exp(), alpha)  # boolean mask of plausible tokens

      # CD score: expert minus amateur on plausible tokens, -inf elsewhere.
      # The expert's top token always passes the mask, so at least one score
      # is finite and the softmax below is well defined.
      contrastive_scores = (expert_log_probs - amateur_log_probs).masked_fill(~plausible_mask, float('-inf'))

      k = min(top_k, contrastive_scores.size(-1))
      topk_scores, topk_indices = torch.topk(contrastive_scores, k=k, dim=-1)
      sampled_index = torch.multinomial(F.softmax(topk_scores, dim=-1), num_samples=1)
      next_token = topk_indices[sampled_index]

      # Extend the context with the sampled id itself. Decoding to text and
      # re-tokenizing is not guaranteed to give back the same ids, so the
      # models could end up conditioned on a different sequence.
      input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=-1)
      attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, 1))], dim=-1)

  return tokenizer.decode(input_ids[0].tolist())

# Sample 50 tokens after the prompt "The future of AI is" with the
# plausibility mask enabled (alpha is left at its default of 0.1).
encoding = tokenizer("The future of AI is", return_tensors='pt')
print(generate_next_token(encoding, expert_lm, amateur_lm, 50))

The future of AI is going to be really exciting. But I do think the risk is that people who have money invested in these companies won't know what is happening and don't care. The most important thing to know about AI is that it's a very fast and effective
