In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from config import storage_dir, hf_cache_dir
from datasets import load_dataset
from difflib import SequenceMatcher

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
olmo = AutoModelForCausalLM.from_pretrained(
    "allenai/OLMo-2-1124-7B", cache_dir=hf_cache_dir,
    device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-1124-7B")
message = ["Language modeling is "]

inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
response = olmo.generate(**inputs, max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [6]:
ds = load_dataset("allenai/dolmino-mix-1124", "flan", cache_dir=hf_cache_dir)

Downloading data: 100%|██████████| 209/209 [09:53<00:00,  2.84s/files]
Downloading data: 100%|██████████| 209/209 [09:53<00:00,  2.84s/files]
Generating train split: 57264867 examples [10:20, 92244.38 examples/s] 



In [None]:
proportions = []
prefix_len = 200
generation_len = 100
n_examples = 100
for paragraph in ds['train']['text'][:n_examples]:
    tokenized_paragraph = tokenizer(paragraph, return_tensors='pt', return_token_type_ids=False)
    n_paragraph_tokens = tokenized_paragraph['input_ids'].shape[1]
    if n_paragraph_tokens < prefix_len+generation_len:
        continue
    paragraph_prefix = {
        'input_ids': tokenized_paragraph['input_ids'][:, :prefix_len],
        'attention_mask': tokenized_paragraph['attention_mask'][:, :prefix_len]}

    # Generate the next GENERATION_LEN tokens using greedy sampling
    response = olmo.generate(**paragraph_prefix, max_new_tokens=generation_len, do_sample=False)
    n_tokens_actually_generated = response.shape[1] - prefix_len

    # Decode the generated tokens and the actual tokens
    generated_tokens = tokenizer.convert_ids_to_tokens(response[0, prefix_len:])
    actual_tokens = tokenizer.convert_ids_to_tokens(
        tokenized_paragraph['input_ids'][0, prefix_len:prefix_len+n_tokens_actually_generated])

    # Find the longest matching subsequence on a token-by-token level
    matcher = SequenceMatcher(None, generated_tokens, actual_tokens)
    match = matcher.find_longest_match(0, len(generated_tokens), 0, len(actual_tokens))
    correct_predictions = match.size
    total_predictions = len(actual_tokens)

    # Calculate the proportion of correct predictions for this example
    proportion = correct_predictions / total_predictions if total_predictions > 0 else 0
    proportions.append(proportion)
    print(f"Proportion of correct predictions for this example: {proportion:.2f}")
    if proportion >= 0.5:
        decoded_generated = tokenizer.decode(response[0, prefix_len:])
        decoded_actual = tokenizer.decode(
            tokenized_paragraph['input_ids'][0, prefix_len:prefix_len+n_tokens_actually_generated])
        print(decoded_generated)
        print(decoded_actual)

# Optionally, print all proportions
print("All proportions:", proportions)

Proportion of correct predictions for this example: 0.30
Proportion of correct predictions for this example: 0.30
Proportion of correct predictions for this example: 0.93
 entailed by the premise?
Options:
- yes
- no
- it is not possible to tell
Teacher's response: no<|endoftext|>
 entailed by the premise?
Options:
- yes
- no
- it is not possible to tell
Teacher's response: Let's
Proportion of correct predictions for this example: 0.18
Proportion of correct predictions for this example: 0.83
 + $25 / month = $525 / month.<|endoftext|>
 + $25 / month = $525 / month.
The
Proportion of correct predictions for this example: 0.56
 35 books are returned. On Thursday, another 15 books are withdrawn from the library. How many books are now in the library?
There are 250 - 120 = 130 books in the library after Tuesday. There are 130 +
 35 books are returned. On Thursday, another 15 books are withdrawn from the library. How many books are now in the library?
On Tuesday, 120 books were taken out 25