In [1]:
import utils

from minicons import scorer
from torch.utils.data import DataLoader
from tqdm import tqdm



In [2]:
model_name = "EleutherAI/pythia-2.8b-deduped"

# GPT-2 and Pythia checkpoints are scored with bos_token=True; every other
# model name gets bos_token=False.
bos_token = any(family in model_name for family in ("gpt2", "pythia"))

# Scorer used by all cells below (hard-wired to the first CUDA device).
lm = scorer.IncrementalLMScorer(model_name, device="cuda:0")

In [3]:
# Parameter count of the loaded model; run_batches appends it to every result row.
parameters = lm.model.num_parameters()

In [4]:
# Load the stimuli as a list of dicts, one per item (previewed in the next cell).
# Path is relative to the notebook's working directory.
fk1999 = utils.read_csv_dict("../data/fk1999-final.csv")

In [5]:
# Preview the first three stimuli to show the available keys.
fk1999[:3]

[{'item': '1',
  'prefix': 'Ann wanted to treat her foreign guests to an all-American pie. She went out in the back yard and picked some',
  'expected': 'apples',
  'within_category': 'oranges',
  'between_category': 'carrots',
  'cloze_expected': '0.818',
  'constraint': 'high'},
 {'item': '2',
  'prefix': 'Every morning, Jack makes himself a glass of fresh-squeezed juice. He keeps his refrigerator stocked with',
  'expected': 'oranges',
  'within_category': 'apples',
  'between_category': 'tomatoes',
  'cloze_expected': '0.6779999999999999',
  'constraint': 'low'},
 {'item': '3',
  'prefix': 'Sheila loves the taste of home-made spaghetti sauce. She decided to start growing her own',
  'expected': 'tomatoes',
  'within_category': 'carrots',
  'between_category': 'apples',
  'cloze_expected': '0.782',
  'constraint': 'low'}]

In [6]:
def compute_conditional_scores(prefixes, continuations):
    """Score each continuation given its prefix using the global ``lm``.

    Args:
        prefixes: Batch of context strings.
        continuations: Batch of continuation strings, aligned with ``prefixes``.

    Returns:
        Whatever ``lm.conditional_score`` returns for the batch, with each
        item's scores reduced to the negated sum (labelled "surprisal" here;
        presumably the scores are token log-probabilities — confirm against
        the minicons docs).
    """

    def _surprisal(token_scores):
        # Negate the summed scores and unwrap the tensor into a Python number.
        return -token_scores.sum().item()

    return lm.conditional_score(
        prefixes,
        continuations,
        bos_token=bos_token,
        bow_correction=True,
        reduction=_surprisal,
    )

In [7]:
def run_batches(stimuli, batch_size=16, debug=False):
    """Score every stimulus with the global ``lm`` and collect one row per item.

    For each batch this computes the entropy of the next-word distribution
    following the prefix, plus the conditional score of the expected,
    within-category and between-category continuations.

    Args:
        stimuli: Sequence of dicts providing at least the keys ``"item"``,
            ``"prefix"``, ``"expected"``, ``"within_category"`` and
            ``"between_category"``.
        batch_size: Number of stimuli per batch.
        debug: If true, additionally return the last batch that was processed.

    Returns:
        A list of ``(item, entropy, expected, within, between, parameters)``
        tuples. With ``debug=True`` a ``(results, last_batch)`` pair is
        returned instead; ``last_batch`` is ``None`` when ``stimuli`` yielded
        no batches.
    """
    batches = DataLoader(stimuli, batch_size=batch_size)

    results = []
    # Pre-bind the loop variable: without this, debug=True on empty input
    # would raise UnboundLocalError at the return statement below.
    batch = None
    for batch in tqdm(batches):
        prefix = batch["prefix"]
        idx = batch["item"]

        expected = batch["expected"]
        within_category = batch["within_category"]
        between_category = batch["between_category"]

        # Assumes next_word_distribution returns log-probabilities with the
        # vocabulary on dim 1 — TODO confirm. Under that assumption
        # -sum(p * log p) below is the entropy of the next-word distribution.
        dist = lm.next_word_distribution(prefix, bos_token=bos_token).detach().cpu()
        entropies = (-1.0 * (dist * dist.exp()).sum(1)).tolist()

        expected_scores = compute_conditional_scores(prefix, expected)
        within_scores = compute_conditional_scores(prefix, within_category)
        between_scores = compute_conditional_scores(prefix, between_category)

        # Re-align the per-batch lists into one row per stimulus.
        for i, entropy, e, w, b in zip(
            idx, entropies, expected_scores, within_scores, between_scores
        ):
            results.append((i, entropy, e, w, b, parameters))

    if debug:
        return results, batch
    return results

In [8]:
# Score the full stimulus set with batch_size=1 (one stimulus per batch).
results_1 = run_batches(fk1999, 1)

100%|██████████| 132/132 [00:10<00:00, 12.58it/s]


In [10]:
# Pass debug=True (for the online demo) so that run_batches also returns the
# last batch it processed, letting us inspect what a batch looks like.
results_64, batch = run_batches(fk1999, 64, debug=True)

100%|██████████| 3/3 [00:04<00:00,  1.59s/it]


In [11]:
# First three rows; each tuple is (item, entropy, expected score,
# within-category score, between-category score, parameter count).
results_64[:3]

[('1',
  4.438488960266113,
  2.5140507221221924,
  9.913602828979492,
  7.563915252685547,
  2775208960),
 ('2',
  4.929327964782715,
  7.043475151062012,
  5.667770862579346,
  9.051733016967773,
  2775208960),
 ('3',
  3.6238369941711426,
  1.4791144132614136,
  7.839367866516113,
  9.950811386108398,
  2775208960)]

In [12]:
# The last batch from the batch_size=64 run: a dict mapping each stimulus key
# to a list of values.
batch

{'item': ['129', '130', '131', '132'],
 'prefix': ['Exploring a little cave near his house, Jacob suddenly found himself in the dark. Something seemed to have gone wrong with his',
  'The gold and crystals scattered the light around the diningroom, and Walter stared up in awe. He wished that he owned a',
  'He put the icecream and the milk in for his shake, but he forgot to put on the cover. Disaster struck when he turned on the',
  'The slices of bread never popped up, and they burned to a crisp every time. She sighed, realizing she would have to buy a new'],
 'expected': ['flashlight', 'chandelier', 'blender', 'toaster'],
 'within_category': ['chandelier', 'flashlight', 'toaster', 'blender'],
 'between_category': ['toaster', 'blender', 'flashlight', 'chandelier'],
 'cloze_expected': ['0.765', '0.345', '0.96', '0.961'],
 'constraint': ['low', 'low', 'high', 'high']}