In [1]:
from model import Transformer

# Build the project's Transformer wrapper used by every cell below.
# NOTE(review): 'gpt2' is presumably the name of the checkpoint to load -- confirm in model.py.
model = Transformer('gpt2')

In [2]:
from loader import load_hellaswag
from experiment import Experiment

# Load the HellaSwag data and wrap it in an Experiment for the runs below.
# NOTE(review): the recorded output of this cell shows a `datasets` warning about
# `trust_remote_code`; presumably it originates inside load_hellaswag() -- confirm and
# pass the flag there so the notebook keeps loading after the next major release.
dataset, features = load_hellaswag()
experiment = Experiment(dataset=dataset, name='HellaSwag Test', features=features)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [5]:
# Work on a subset of the experiment so the timing comparison further down stays short.
# NOTE(review): presumably `num` is the number of items drawn -- confirm against
# Experiment.sample, and check whether the draw is seeded (needed for reproducible timings).
sampled_experiment = experiment.sample(num=100)

## Efficiency

In [6]:
# Goal: check whether batching several texts into one model call speeds up the computation (time both approaches to compare).
import numpy as np
import torch
import torch.nn.functional as F


def current_method(model, experiment):
    """Run the model's existing scoring routine over all items of an experiment.

    Args:
        model: Object exposing ``compute_option_log_likelihoods(items=...)``.
        experiment: Object whose ``dataset.items`` holds the items to score.

    Returns:
        The routine's result converted to a NumPy array.
    """
    all_items = experiment.dataset.items
    log_likelihoods = model.compute_option_log_likelihoods(items=all_items)
    return np.array(log_likelihoods)

def batch_computation(
        model,
        items,
        add_whitespace: bool=True,
    ):
    """Score all prompt/option texts of ``items`` with a single batched model call.

    Every option of every item is appended to that item's prompt (separated by
    one space when ``add_whitespace`` is true). The resulting texts are
    tokenized together and passed through the model once. For each text, the
    log-probabilities of all tokens after the first are summed; positions whose
    attention mask is 0 contribute nothing.

    Known limitations, kept from the original timing experiment:
      * The sum covers prompt tokens as well as option tokens, because the
        prompt length is fixed to 1 instead of being measured per item.
      * The result is flat: one score per (item, option) pair in input order,
        not grouped by item.

    Args:
        model: Object providing ``tokenize_texts(texts)`` and
            ``get_logits(encoded_inputs)``. The encoding is indexed with
            ``'input_ids'`` and ``'attention_mask'``; the logits are indexed as
            (batch, position, vocabulary).
        items: Iterable of objects with a ``prompt`` string and an ``options``
            sequence of strings.
        add_whitespace: Whether to insert a space between prompt and option.

    Returns:
        1-D NumPy array holding one summed log-probability per text; an empty
        array when there is nothing to score.
    """
    separator = " " if add_whitespace else ""
    input_texts = [
        item.prompt + separator + option
        for item in items
        for option in item.options
    ]

    # Nothing to score: return early instead of handing an empty batch to the
    # tokenizer and the model.
    if not input_texts:
        return torch.zeros(0).numpy()

    # Tokenize the combined texts
    encoded_inputs = model.tokenize_texts(input_texts)

    # Get logits from the model
    logits = model.get_logits(encoded_inputs)

    #TODO To get the correct result, we would need to consider prompt lengths too,
    # but first check if this approach is even faster

    # Ignore the log probs at the beginning
    prompt_length = 1 #len(model.tokenizer(item.prompt, add_special_tokens=False)['input_ids']) #- 1

    # Calculate log probabilities from the logits for each token.
    # The slice is shifted by one because position t holds the prediction for
    # token t + 1; the final position predicts a token that is not in the input.
    log_probs = F.log_softmax(logits, dim=-1)[:, prompt_length - 1:-1]

    input_ids = encoded_inputs['input_ids'][:, prompt_length:]
    attention_mask = encoded_inputs['attention_mask'][:, prompt_length:]

    # Get log probabilities of actual tokens
    token_log_probs = log_probs.gather(2, input_ids.unsqueeze(-1)).squeeze(-1)

    # Zero out masked positions.
    # NOTE(review): presumably these are padding at the end of each row -- confirm
    # the tokenizer's padding side.
    masked_log_probs = token_log_probs * attention_mask

    #TODO Note that we would need to pull this apart again so that items have separate entries
    # detach() and cpu() make the NumPy conversion work even when the logits carry
    # autograd history or live on an accelerator; both are no-ops otherwise.
    return torch.sum(masked_log_probs, dim=1).detach().cpu().numpy()


def new_method(model, experiment, batch_size=10):
    """Score an experiment's items by feeding them to ``batch_computation`` in chunks.

    Args:
        model: Passed through unchanged to ``batch_computation``.
        experiment: Object whose ``dataset.items`` is a sliceable sequence of items.
        batch_size: Number of items handed to ``batch_computation`` per call.

    Returns:
        A flat list with the scores of all chunks concatenated in item order.
    """
    all_items = experiment.dataset.items
    chunks = (
        all_items[start:start + batch_size]
        for start in range(0, len(all_items), batch_size)
    )
    return [
        score
        for chunk in chunks
        for score in batch_computation(model, items=chunk)
    ]

# Time the existing method once, then the batched variant once per batch size.
# NOTE(review): lp1 and lp2 are never compared in this cell, and lp2 is overwritten on
# every iteration, so this measures speed only -- not whether the two methods agree.
%time lp1 = current_method(model, sampled_experiment)
for bs in [2, 5]: #, 10, 20, 50, 100]:
    # Larger batch sizes are currently disabled (commented-out tail of the list above).
    %time lp2 = new_method(model, sampled_experiment, batch_size=bs)

CPU times: user 2min 7s, sys: 30.7 s, total: 2min 38s
Wall time: 20.5 s
CPU times: user 2min 14s, sys: 32 s, total: 2min 46s
Wall time: 21.3 s
CPU times: user 2min 36s, sys: 39.2 s, total: 3min 15s
Wall time: 25.9 s
