Contrastive Decoding paper ("Contrastive Decoding: Open-ended Text Generation as Optimization"): https://arxiv.org/abs/2210.15097

In [None]:
import torch
import transformers

# One GPT-2 tokenizer is used for both models: the small checkpoint plays the
# amateur and the medium checkpoint plays the expert.
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
amateur_lm = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
expert_lm = transformers.AutoModelForCausalLM.from_pretrained('gpt2-medium')

# Prompt reused by the generation cells further down.
encoding = tokenizer("Hello world", return_tensors='pt')

# Inference-only sanity check: no_grad avoids recording an autograd graph
# for a forward pass whose gradients are never used.
with torch.no_grad():
  logits = expert_lm(**encoding).logits

In [None]:
# Shape of the expert logits for the prompt; presumably
# (batch, prompt tokens, vocabulary size) - confirm against the model docs.
print(logits.shape)

torch.Size([1, 2, 50257])


In [None]:
# Display the tokenizer output for the prompt (token ids and attention mask).
encoding

{'input_ids': tensor([[15496,   995]]), 'attention_mask': tensor([[1, 1]])}

In [None]:
import torch.nn.functional as F

def generate_next_token(encoding, expert_model, amateur_model, num_tokens):
  """Greedily extend a prompt with plain contrastive decoding.

  At every step the next token is the one maximising
  ``log p_expert(token) - log p_amateur(token)``. No adaptive plausibility
  constraint is applied in this version.

  Args:
    encoding: Mapping holding an ``'input_ids'`` tensor and, optionally, an
      ``'attention_mask'`` tensor. Only the first sequence in it is decoded.
    expert_model: Callable taking the encoding tensors as keyword arguments
      and returning an object with a ``.logits`` tensor.
    amateur_model: Same calling convention as ``expert_model``.
    num_tokens: Number of tokens to generate.

  Returns:
    The prompt plus the generated tokens, decoded with the module-level
    ``tokenizer``.

  Side effects:
    Prints the id of every generated token.
  """
  # The scores below are read from batch row 0 only, so keep just that row.
  input_ids = encoding['input_ids'][:1]
  if 'attention_mask' in encoding:
    attention_mask = encoding['attention_mask'][:1]
  else:
    attention_mask = None
  generated_tokens = input_ids[0].tolist()

  with torch.no_grad():
    for _ in range(num_tokens):
      model_inputs = {'input_ids': input_ids}
      if attention_mask is not None:
        model_inputs['attention_mask'] = attention_mask

      expert_logits = expert_model(**model_inputs).logits[0, -1, :]
      amateur_logits = amateur_model(**model_inputs).logits[0, -1, :]

      # log_softmax stays finite where log(softmax(x)) underflows to -inf,
      # which would otherwise turn the difference into NaN or +/-inf.
      expert_log_probs = F.log_softmax(expert_logits, dim=-1)
      amateur_log_probs = F.log_softmax(amateur_logits, dim=-1)

      # Contrastive decoding score - NO adaptive plausibility yet.
      contrastive_logits = expert_log_probs - amateur_log_probs

      # Greedy selection of the best-scoring token.
      next_token = torch.argmax(contrastive_logits, dim=-1)
      print(next_token.item())
      generated_tokens.append(next_token.item())

      # Extend the context with the chosen id itself. Decoding to text and
      # re-tokenizing is not guaranteed to give back the same ids, so the
      # models' context could drift away from generated_tokens.
      input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=-1)
      if attention_mask is not None:
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((1, 1))], dim=-1)

  return tokenizer.decode(generated_tokens)

In [None]:
# Generate 5 tokens after the prompt using the contrastive decoder defined
# above (no plausibility constraint) and print the decoded text.
print(generate_next_token(encoding, expert_lm, amateur_lm, 5))

35343
10108
26534
14818
35266
Hello world/** simultane�� Sud�


In [None]:
import torch.nn.functional as F

def calculate_vhead(expert_probs, alpha):
  """Return the adaptive-plausibility mask over the vocabulary.

  A token counts as plausible when its expert probability is at least
  ``alpha`` times the highest expert probability in the same distribution.

  Args:
    expert_probs: Tensor of expert probabilities. The last dimension is
      treated as the vocabulary; any leading dimensions are treated as
      independent distributions.
    alpha: Fraction of the top probability a token must reach.

  Returns:
    Boolean tensor with the same shape as ``expert_probs``; True marks a
    plausible token.
  """
  # Take the maximum per distribution (last dim) rather than over the whole
  # tensor, so batched input gets one threshold per row. For a single 1-D
  # distribution this is the same value as the global maximum.
  top_prob = expert_probs.max(dim=-1, keepdim=True).values
  return expert_probs >= alpha * top_prob

def generate_next_token(encoding, expert_model, amateur_model, num_tokens, alpha=0.1):
  """Greedily extend a prompt with contrastive decoding and a plausibility mask.

  At every step the tokens that pass ``calculate_vhead`` are scored with
  ``log p_expert(token) - log p_amateur(token)``; every other token gets a
  score of ``-inf``. The best-scoring token is appended to the sequence.

  Args:
    encoding: Mapping holding an ``'input_ids'`` tensor and, optionally, an
      ``'attention_mask'`` tensor. Only the first sequence in it is decoded.
    expert_model: Callable taking the encoding tensors as keyword arguments
      and returning an object with a ``.logits`` tensor.
    amateur_model: Same calling convention as ``expert_model``.
    num_tokens: Number of tokens to generate.
    alpha: Plausibility threshold forwarded to ``calculate_vhead``.

  Returns:
    The prompt plus the generated tokens, decoded with the module-level
    ``tokenizer``.

  Side effects:
    Prints the id of every generated token.
  """
  # The scores below are read from batch row 0 only, so keep just that row.
  input_ids = encoding['input_ids'][:1]
  if 'attention_mask' in encoding:
    attention_mask = encoding['attention_mask'][:1]
  else:
    attention_mask = None
  generated_tokens = input_ids[0].tolist()

  with torch.no_grad():
    for _ in range(num_tokens):
      model_inputs = {'input_ids': input_ids}
      if attention_mask is not None:
        model_inputs['attention_mask'] = attention_mask

      expert_logits = expert_model(**model_inputs).logits[0, -1, :]
      amateur_logits = amateur_model(**model_inputs).logits[0, -1, :]

      # log_softmax stays finite where log(softmax(x)) underflows to -inf;
      # an underflowed amateur probability would otherwise yield a +inf score.
      expert_log_probs = F.log_softmax(expert_logits, dim=-1)
      amateur_log_probs = F.log_softmax(amateur_logits, dim=-1)

      # Binary mask for plausible tokens, computed on expert probabilities.
      plausible_mask = calculate_vhead(expert_log_probs.exp(), alpha)

      # CD score: contrast on plausible tokens, -inf everywhere else.
      contrastive_logits = (expert_log_probs - amateur_log_probs).masked_fill(
          ~plausible_mask, float('-inf'))

      # Greedy selection of the best-scoring token.
      next_token = torch.argmax(contrastive_logits, dim=-1)
      print(next_token.item())
      generated_tokens.append(next_token.item())

      # Extend the context with the chosen id itself. Decoding to text and
      # re-tokenizing is not guaranteed to give back the same ids, so the
      # models' context could drift away from generated_tokens.
      input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=-1)
      if attention_mask is not None:
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((1, 1))], dim=-1)

  return tokenizer.decode(generated_tokens)


In [None]:
# Same prompt and token budget, now using the redefined decoder with the
# plausibility mask at its default alpha of 0.1.
print(generate_next_token(encoding, expert_lm, amateur_lm, 5))

6
379
262
886
13
Hello world' at the end.
