## Eliciting scores from LMs

Language models are trained to predict token probabilities, given some input context. This allows us to explore a number of different scoring methods.

* token-scoring: the natural ability of LMs -- assigning probabilities to tokens given context
* word-scoring: going from tokens (which could be sub-words) to word scores
* sequence-scoring: going from tokens/words to full, multi-word sequences 
* conditional-scoring: computing conditional probabilities of sequences given some input

For all these methods, we will consider a range of scores: probabilities, log-probabilities, and surprisals. For sequence probabilities, we will compare summing log-probabilities (equivalent to multiplying probabilities) with taking the log-probability per token, which accounts for the effect of length.
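As a quick illustration of how these scores relate to one another, here is a minimal sketch in plain Python. The probabilities are made-up values chosen for easy arithmetic, not model outputs.

```python
import math

# Hypothetical probability of one token given its context (illustrative value).
p = 0.25

log_prob = math.log(p)          # natural-log probability (nats)
surprisal_nats = -log_prob      # surprisal is the negative log-probability
surprisal_bits = -math.log2(p)  # the same quantity in base 2 (bits)

# Hypothetical per-token probabilities for a three-token sequence.
token_probs = [0.5, 0.25, 0.125]
token_log_probs = [math.log(q) for q in token_probs]

# Summing log-probabilities gives the log of the product of the probabilities.
summed = sum(token_log_probs)

# Dividing by the number of tokens gives the length-normalized score.
per_token = summed / len(token_log_probs)

print(surprisal_bits)    # surprisal of p in bits
print(math.exp(summed))  # joint probability recovered from the summed log-probs
print(per_token)         # log-probability per token
```

Exponentiating the summed score recovers the product of the individual probabilities, which is why the sum is the natural score for a full sequence and the per-token average is its length-normalized counterpart.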

In [1]:
import torch

from minicons import scorer
from nltk.tokenize import TweetTokenizer
from torch.utils.data import DataLoader

### Different types of LMs

Autoregressive LMs: `scorer.IncrementalLMScorer`

Masked LMs: `scorer.MaskedLMScorer`

In [2]:
model_name = "HuggingFaceTB/SmolLM2-135M"
# model_name = "gpt2"
# model_name = "facebook/opt-125m"

# many models do not automatically insert a beginning-of-sentence
# token when tokenizing a sequence, even though they were
# trained with one...

if "gpt2" in model_name or "pythia" in model_name or "SmolLM" in model_name:
    BOS = True
else:
    BOS = False

lm = scorer.IncrementalLMScorer(model_name)

In [44]:
prefixes = [
    "He caught the pass and scored another touchdown. There was nothing he enjoyed more than a good game of",
    "The firefighters wanted to have a mascot to live with them at the firehouse. Naturally, they decided it would have to be a"
]

dist = lm.next_word_distribution(prefixes)

# batch-wise logprobs over the next word
dist

tensor([[-12.3408, -26.3428, -26.3727,  ..., -23.4791, -17.3325, -22.6918],
        [-11.9606, -23.8860, -23.8423,  ..., -20.0405, -12.6999, -20.7964]])

In [45]:
entropy = (-dist * dist.exp()).sum(1)
entropy

tensor([4.2262, 6.0277])
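The expression `(-dist * dist.exp()).sum(1)` is the Shannon entropy of each next-word distribution, in nats, assuming `dist` holds natural-log probabilities (as the "batch-wise logprobs" comment suggests). The sketch below shows the same computation on toy distributions in plain Python; `entropy_from_log_probs` is a hypothetical helper, not part of `minicons`.

```python
import math

def entropy_from_log_probs(log_probs):
    """Shannon entropy (in nats) of a distribution given as natural-log probabilities."""
    return -sum(math.exp(lp) * lp for lp in log_probs)

# A uniform distribution over four outcomes: maximal uncertainty.
uniform = [math.log(0.25)] * 4

# A peaked distribution: most of the mass sits on one outcome.
peaked = [math.log(q) for q in (0.97, 0.01, 0.01, 0.01)]

print(entropy_from_log_probs(uniform))  # equals log(4) for a uniform distribution
print(entropy_from_log_probs(peaked))   # smaller, because the distribution is peaked
```

Read this way, the lower entropy for the first prefix (4.23 vs. 6.03) indicates that the model is more certain about what follows "a good game of" than about the firefighters' mascot.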

In [53]:
# probs and ranks of query words

lm.query(
    dist, 
    [["football", "baseball", "monopoly"], 
     ["dog", "bear", "zebra"]]
)

([[0.3232649564743042, 0.021983535960316658, 5.035623689764179e-05],
  [0.061017148196697235, 0.04948189854621887, 0.0008991201757453382]],
 [[1, 6, 844], [1, 2, 160]])
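The output above pairs each query word's probability with its rank in the next-word distribution. One way to think about the rank is sketched below on a toy four-word vocabulary. The values and the helper `prob_and_rank` are hypothetical, and the internal logic of `lm.query` may differ (a real model ranks over its full token vocabulary).

```python
def prob_and_rank(distribution, word):
    """Return the probability of `word` and its 1-based rank in `distribution`."""
    p = distribution[word]
    # Rank = 1 + number of vocabulary items with strictly higher probability.
    rank = 1 + sum(1 for other in distribution.values() if other > p)
    return p, rank

# Hypothetical next-word distribution (made-up values that sum to 1).
toy_distribution = {
    "football": 0.6,
    "chess": 0.25,
    "baseball": 0.1,
    "monopoly": 0.05,
}

for query_word in ["football", "baseball", "monopoly"]:
    print(query_word, prob_and_rank(toy_distribution, query_word))
```

A rank of 1 means the word is the model's top prediction, as with `football` and `dog` in the output above.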

### Token Scoring

**Input:** I know what the lion devoured at sunrise.

**Outputs:** 
* Probabilities: $p(w_i | w_1, w_2, \dots, w_{i-1})$
* log-probabilities: $\log p(w_i | w_1, w_2, \dots, w_{i-1})$
* Surprisals: $-\log p(w_i | w_1, w_2, \dots, w_{i-1})$


In [48]:
sequences = [
    "I know what the lion devoured at sunrise.", 
    "I know that the lion devoured at sunrise."
]

Surprisals, with the beginning-of-word correction (`bow_correction=True`):

In [49]:
lm.token_score(
    sequences, 
    bos_token=BOS,
    prob=False,
    surprisal=True,
    bow_correction=True
)

[[('<|endoftext|>', 0.0),
  ('I', 6.193064212799072),
  ('Ġknow', 4.064828872680664),
  ('Ġwhat', 2.9297256469726562),
  ('Ġthe', 2.5732367038726807),
  ('Ġlion', 9.888418197631836),
  ('Ġdev', 11.566431045532227),
  ('oured', 2.2843103408813477),
  ('Ġat', 3.8329315185546875),
  ('Ġsunrise', 7.398899078369141),
  ('.', 1.7576873302459717)],
 [('<|endoftext|>', 0.0),
  ('I', 6.193064212799072),
  ('Ġknow', 4.064828872680664),
  ('Ġthat', 1.2324095964431763),
  ('Ġthe', 2.0393288135528564),
  ('Ġlion', 9.423365592956543),
  ('Ġdev', 9.49761962890625),
  ('oured', 1.1308623552322388),
  ('Ġat', 7.301660060882568),
  ('Ġsunrise', 9.36835765838623),
  ('.', 2.6524205207824707)]]

Surprisals with the default settings (no `bow_correction`):

In [11]:
lm.token_score(
    sequences, 
    bos_token=BOS,
    surprisal=True
)

[[('<|endoftext|>', 0.0),
  ('I', 5.686635971069336),
  ('Ġknow', 4.515291213989258),
  ('Ġwhat', 2.960165023803711),
  ('Ġthe', 2.5910110473632812),
  ('Ġlion', 9.441681861877441),
  ('Ġdev', 12.020919799804688),
  ('oured', 1.3213729858398438),
  ('Ġat', 4.788764953613281),
  ('Ġsunrise', 7.406002998352051),
  ('.', 1.361612319946289)],
 [('<|endoftext|>', 0.0),
  ('I', 5.686635971069336),
  ('Ġknow', 4.515291213989258),
  ('Ġthat', 1.2616539001464844),
  ('Ġthe', 2.0563831329345703),
  ('Ġlion', 9.177278518676758),
  ('Ġdev', 9.753374099731445),
  ('oured', 1.1199226379394531),
  ('Ġat', 7.276799201965332),
  ('Ġsunrise', 9.404158592224121),
  ('.', 2.3709983825683594)]]

### Word scoring

Word scoring uses the same metrics, but the log-probabilities of tokens that belong to the same word are summed (e.g., `devoured` is split into `dev` + `oured`). You have to provide the word tokenizer yourself; we will use `nltk`'s `TweetTokenizer()` as an example. With `surprisal=True` and `base_two=True`, the scores below are surprisals in bits.

In [12]:
word_tokenizer = TweetTokenizer().tokenize

In [13]:
lm.word_score_tokenized(
    sequences, 
    bos_token=BOS, 
    tokenize_function=word_tokenizer,
    surprisal=True,
    base_two=True
)

[[('I', 8.204081535339355),
  ('know', 6.514188289642334),
  ('what', 4.270615577697754),
  ('the', 3.7380387783050537),
  ('lion', 13.621467590332031),
  ('devoured', 19.248859405517578),
  ('at', 6.908727645874023),
  ('sunrise', 10.684603691101074),
  ('.', 1.9643913507461548)],
 [('I', 8.204081535339355),
  ('know', 6.514188289642334),
  ('that', 1.8201818466186523),
  ('the', 2.966733694076538),
  ('lion', 13.24001407623291),
  ('devoured', 15.686850547790527),
  ('at', 10.498202323913574),
  ('sunrise', 13.567333221435547),
  ('.', 3.4206275939941406)]]
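The word-level values for the first sentence can be reproduced by hand from the uncorrected token-level surprisals shown in the token scoring section: sum the surprisals of the tokens in each word, then divide by $\ln 2$ to convert nats to bits. The sketch below does this in plain Python. The alignment step (stripping the `Ġ` marker and concatenating tokens until they spell the word) is a simplifying assumption that works for this sentence; it is not necessarily how `minicons` aligns tokens to words.

```python
import math

# Token-level surprisals (nats) copied from the uncorrected token_score output.
token_surprisals = [
    ("I", 5.686635971069336),
    ("Ġknow", 4.515291213989258),
    ("Ġwhat", 2.960165023803711),
    ("Ġthe", 2.5910110473632812),
    ("Ġlion", 9.441681861877441),
    ("Ġdev", 12.020919799804688),
    ("oured", 1.3213729858398438),
    ("Ġat", 4.788764953613281),
    ("Ġsunrise", 7.406002998352051),
    (".", 1.361612319946289),
]

# Word segmentation of the sentence (what a word tokenizer would return here).
words = ["I", "know", "what", "the", "lion", "devoured", "at", "sunrise", "."]

def aggregate_to_words(token_scores, words):
    """Sum token surprisals within each word and convert nats to bits."""
    results, i = [], 0
    for word in words:
        pieces, total = "", 0.0
        # Consume tokens until their concatenation spells the current word.
        while pieces != word and i < len(token_scores):
            token, score = token_scores[i]
            pieces += token.lstrip("Ġ")
            total += score
            i += 1
        results.append((word, total / math.log(2)))
    return results

word_surprisals = aggregate_to_words(token_surprisals, words)
for word, bits in word_surprisals:
    print(f"{word:10s} {bits:.4f}")
```

For `devoured`, this gives (12.0209 + 1.3214) / ln 2 ≈ 19.2489 bits, matching the value in the output above.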

### Sequence scoring

**Input:** batch of sentences

**Outputs:** scores indicating how likely each sequence is. There are two common ways to compute them:

* summed log-probs (the log of the sequence's joint probability, obtained by multiplying the conditional token probabilities)
* log-prob per token

In [14]:
sequences = [
    "The keys to the cabinet are on the table.",
    "The keys to the cabinet is on the table."
]

log-prob per token (default behavior):

In [15]:
lm.sequence_score(sequences, bos_token=BOS)

[-3.6746392250061035, -4.0424675941467285]

summed log-probs:

Summing is done with the `reduction` argument, which takes a function.

In [27]:
lm.sequence_score(
    sequences, 
    bos_token=BOS, 
    reduction=lambda x: x.sum().item()
)

[-36.74639129638672, -40.42467498779297]

Here, the lambda expression is a concise way to define a function. It takes the torch tensor of log-probabilities that the model assigns to one sequence, sums its elements, and extracts the result as a Python number with `.item()` (as opposed to keeping it as a `tensor`). For example:

In [49]:
x = torch.tensor([0.223234, 0.443257, 0.364343], dtype=torch.double)
sum_func = lambda x: x.sum().item()

sum_func(x)

1.030834

### log-prob of full sequence (summing) vs. log-prob per token (avg)

Usually, the two metrics show similar qualitative trends, especially for minimal-pair comparisons. However, in certain cases log-prob per token is the better metric. The summed log-prob of a sentence can be lower simply because the sentence is longer (contains more tokens): it multiplies together more token probabilities, each of which is less than 1.

The following pair illustrates this issue:

1. These casseroles disgust Mrs. O'leary
2. *These casseroles disgusts Kayla

In [42]:
stimuli = [
    "These casseroles disgust Mrs. O'leary", # longer but grammatical
    "These casseroles disgusts Kayla" # shorter but ungrammatical
]

In [43]:
# sum
lm.sequence_score(
    stimuli, 
    bos_token=BOS, 
    reduction=lambda x: x.sum().item()
)

[-61.54802703857422, -56.2628173828125]

In [44]:
lm.sequence_score(stimuli, bos_token=BOS)

[-6.1548027992248535, -8.037545204162598]
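The two outputs above disagree about which sentence is better. The sketch below, which only reuses the numbers printed above, makes the disagreement explicit and recovers the number of scored tokens implied by the ratio of the two reductions.

```python
# Scores copied from the two cells above (index 0: grammatical, 1: ungrammatical).
summed = [-61.54802703857422, -56.2628173828125]
per_token = [-6.1548027992248535, -8.037545204162598]

# The ratio sum / mean recovers how many tokens contributed to each score.
n_tokens = [round(s / m) for s, m in zip(summed, per_token)]

# Which sentence each metric prefers (a higher log-probability wins).
preferred_by_sum = max(range(len(summed)), key=lambda i: summed[i])
preferred_by_mean = max(range(len(per_token)), key=lambda i: per_token[i])

print(n_tokens)           # number of scored tokens per sentence
print(preferred_by_sum)   # index of the sentence preferred by the summed score
print(preferred_by_mean)  # index of the sentence preferred by the per-token score
```

The summed score prefers the shorter, ungrammatical sentence (7 scored tokens against 10), while the per-token score prefers the grammatical one. This is the length effect described above.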

### Conditional LM scoring

This follows the same principle as sequence scoring, but lets you separate the prefix from the continuation: only the continuation is scored, conditioned on the prefix. Like sequence scoring, this method also allows for different reduction methods.

In [50]:
# Lake and Murphy (2023) / Murphy (1988)
# "are cooked in a pie" is an emergent property of sliced apples

prefix = ["Sliced apples", "Apples", "Sliced things"]
continuation = ["are cooked in a pie."] * 3

In [51]:
lm.conditional_score(prefix, continuation, bos_token=BOS)

[-3.236760377883911, -3.915409803390503, -3.4075047969818115]
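Conceptually, conditional scoring reduces only over the log-probabilities of the continuation tokens, each still conditioned on everything before it (including the prefix). The sketch below illustrates this with made-up token log-probabilities. The helper `conditional_score_sketch` is hypothetical, and it assumes the default reduction is the per-token mean, as with `sequence_score`.

```python
def conditional_score_sketch(token_log_probs, prefix_len, reduction=None):
    """Reduce over the continuation's log-probs, skipping the prefix tokens."""
    continuation = token_log_probs[prefix_len:]
    if reduction is None:
        # Assumed default: log-probability per token.
        reduction = lambda xs: sum(xs) / len(xs)
    return reduction(continuation)

# Made-up log-probs for the tokens of "Sliced apples" + "are cooked in a pie."
# The first two values belong to the prefix and are excluded from the score.
log_probs = [-7.0, -4.0, -2.0, -5.0, -1.0, -1.5, -3.0, -0.5]
prefix_len = 2

mean_score = conditional_score_sketch(log_probs, prefix_len)
sum_score = conditional_score_sketch(log_probs, prefix_len, reduction=sum)

print(mean_score)  # per-token log-probability of the continuation
print(sum_score)   # summed log-probability of the continuation
```

In the real output above, "are cooked in a pie." is most likely after "Sliced apples" (-3.24), then after "Sliced things" (-3.41), and least likely after "Apples" (-3.92).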