In [1]:
import utils

from minicons import scorer
from torch.utils.data import DataLoader
from tqdm import tqdm



In [2]:
model_name = "EleutherAI/pythia-2.8b-deduped"

# Model families that are scored with a BOS token prepended. The resulting flag
# is forwarded as `bos_token=` to the minicons scoring calls. Extend this tuple
# (rather than adding if/else branches) to support more families.
BOS_MODEL_FAMILIES = ("gpt2", "pythia")
bos_token = any(family in model_name for family in BOS_MODEL_FAMILIES)

# minicons scorer for incremental (autoregressive) LMs, loading the checkpoint
# named by `model_name` onto the first CUDA device.
lm = scorer.IncrementalLMScorer(
    model_name,
    device="cuda:0",
)

In [3]:
# Total parameter count of the wrapped model (defaults spelled out: count every
# parameter, embeddings included). run_batches appends this value to every
# result row, presumably so results from different checkpoints can be compared.
parameters = lm.model.num_parameters(only_trainable=False, exclude_embeddings=False)

In [7]:
# Stimulus table, presumably a list of row dicts: run_batches reads the
# "item", "prefix", "expected", "within_category" and "between_category"
# columns. The name suggests the Federmeier & Kutas (1999) items -- confirm.
fk1999 = utils.read_csv_dict(
    "../data/fk1999-final.csv",
)

In [8]:
def compute_conditional_scores(prefixes, continuations):
    """Surprisal of each continuation given its paired prefix.

    Thin wrapper around ``lm.conditional_score`` that uses the notebook-level
    ``lm`` scorer and ``bos_token`` flag.

    Args:
        prefixes: context strings, one per item.
        continuations: continuation strings, aligned with ``prefixes``.

    Returns:
        Whatever ``lm.conditional_score`` returns: one score per
        (prefix, continuation) pair, each the negated sum of the
        continuation's token log-probabilities (its total surprisal).
    """

    def total_surprisal(token_logprobs):
        # Summing token log-probs gives the continuation's log-probability;
        # negating it gives surprisal. `.item()` yields a plain Python number.
        return -token_logprobs.sum().item()

    return lm.conditional_score(
        prefixes,
        continuations,
        bos_token=bos_token,
        # presumably minicons' beginning-of-word correction -- see its docs
        bow_correction=True,
        reduction=total_surprisal,
    )

In [9]:
def _next_word_entropy(logprobs):
    """Row-wise Shannon entropy of a batch of next-word log-probability vectors.

    Args:
        logprobs: 2-D tensor of log-probabilities, one row per prefix (assumed
            (batch, vocab) -- the reduction is over dim 1).

    Returns:
        1-D tensor holding each row's entropy, in the same log base as
        ``logprobs``.

    Zero-probability entries (log-prob ``-inf``) contribute 0, the limit of
    ``-p * log p`` as ``p -> 0``. The naive ``logprobs * logprobs.exp()`` would
    compute ``-inf * 0 = nan`` there and turn the whole row's entropy into nan.
    For finite inputs the result is identical to the naive formula.
    """
    probs = logprobs.exp()
    safe_logprobs = logprobs.masked_fill(logprobs == float("-inf"), 0.0)
    return -1.0 * (safe_logprobs * probs).sum(1)


def run_batches(stimuli, batch_size=16):
    """Score every stimulus in mini-batches with the notebook-level ``lm``.

    For each stimulus this computes:
      * the entropy of the model's next-word distribution after ``prefix``
        (assumes ``lm.next_word_distribution`` returns log-probabilities), and
      * the surprisal of the ``expected``, ``within_category`` and
        ``between_category`` continuations via ``compute_conditional_scores``.

    Args:
        stimuli: sequence of dicts with string fields "item", "prefix",
            "expected", "within_category" and "between_category"; batched by
            ``DataLoader``'s default collation into dicts of lists.
        batch_size: number of stimuli scored per model call.

    Returns:
        List of tuples ``(item, entropy, expected, within, between, parameters)``,
        one per stimulus, in input order (the DataLoader does not shuffle).
        ``parameters`` is the notebook-level model parameter count.
    """
    batches = DataLoader(stimuli, batch_size=batch_size)

    results = []
    for batch in tqdm(batches):
        prefix = batch["prefix"]
        idx = batch["item"]

        # Uncertainty of the model about the word following each prefix.
        dist = lm.next_word_distribution(prefix, bos_token=bos_token).detach().cpu()
        entropies = _next_word_entropy(dist).tolist()

        # Surprisal of each candidate continuation under the same prefix.
        expected_scores = compute_conditional_scores(prefix, batch["expected"])
        within_scores = compute_conditional_scores(prefix, batch["within_category"])
        between_scores = compute_conditional_scores(prefix, batch["between_category"])

        for i, entropy, e, w, b in zip(
            idx, entropies, expected_scores, within_scores, between_scores
        ):
            results.append((i, entropy, e, w, b, parameters))

    return results

In [10]:
# One stimulus per model call.
results_1 = run_batches(fk1999, batch_size=1)

100%|██████████| 132/132 [00:10<00:00, 12.51it/s]


In [None]:
# Same stimuli, 64 per model call.
results_64 = run_batches(fk1999, batch_size=64)

100%|██████████| 3/3 [00:04<00:00,  1.65s/it]


In [None]:
# Preview the first three result rows.
results_64[0:3]

[('1',
  4.438488960266113,
  2.5140509605407715,
  9.913602828979492,
  7.563915252685547,
  2775208960),
 ('2',
  4.929327964782715,
  7.043475151062012,
  5.667770862579346,
  9.051733016967773,
  2775208960),
 ('3',
  3.6238369941711426,
  1.4791144132614136,
  7.839367866516113,
  9.950811386108398,
  2775208960)]