In [1]:
%reload_ext autoreload
%autoreload 2

from datasets import load_dataset, load_metric
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, default_data_collator, DataCollatorForSeq2Seq, T5ForConditionalGeneration
import numpy as np
from tqdm.auto import tqdm


from gsm8k import GSM8K
from prompt import math_word_problem_template
import utils


In [2]:
# Evaluation set: a fixed, seeded random subset of the GSM8K test split
# (64 = 32 * 2 examples) so runs are cheap and reproducible.
ds = (
    load_dataset('gsm8k', 'main', split='test')
    .shuffle(seed=42)
    .select(range(64))
)

[2023-01-17 02:13:56,135] [datasets.builder] [builder.py:785] Found cached dataset gsm8k (/workspaces/seed/cache/hf_dataset/gsm8k/main/1.1.0/37bfb08b1d4fcbb01f06b03d9e1ef5f1fcbd4d3af3d08842c50d7305091285ba)
[2023-01-17 02:13:56,137] [datasets.arrow_dataset] [arrow_dataset.py:3930] Loading cached shuffled indices for dataset at /workspaces/seed/cache/hf_dataset/gsm8k/main/1.1.0/37bfb08b1d4fcbb01f06b03d9e1ef5f1fcbd4d3af3d08842c50d7305091285ba/cache-63d0d521e5d794ec.arrow


In [3]:
# Tokenizer matching the FLAN-T5 XXL checkpoint loaded further down.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='google/flan-t5-xxl')

In [4]:
# Preprocessing pipeline:
#   1. parse the gold answer text (presumably yields the numeric 'label' column),
#   2. render each question into a 'prompt' via the project template,
#   3. tokenize the prompts in batches,
#   4. drop the raw text columns so only model inputs + label remain.
dds = (
    ds.map(lambda row: GSM8K.regex_answer(row['answer']))
      .map(math_word_problem_template)
      .map(lambda rows: tokenizer(rows['prompt']), batched=True)
      .remove_columns(['question', 'answer', 'prompt'])
)
dds

[2023-01-17 02:13:57,190] [datasets.arrow_dataset] [arrow_dataset.py:3047] Loading cached processed dataset at /workspaces/seed/cache/hf_dataset/gsm8k/main/1.1.0/37bfb08b1d4fcbb01f06b03d9e1ef5f1fcbd4d3af3d08842c50d7305091285ba/cache-452a9ac7250f13d6.arrow
[2023-01-17 02:13:57,191] [datasets.arrow_dataset] [arrow_dataset.py:3047] Loading cached processed dataset at /workspaces/seed/cache/hf_dataset/gsm8k/main/1.1.0/37bfb08b1d4fcbb01f06b03d9e1ef5f1fcbd4d3af3d08842c50d7305091285ba/cache-1c723ad759cb9077.arrow
[2023-01-17 02:13:57,196] [datasets.arrow_dataset] [arrow_dataset.py:3047] Loading cached processed dataset at /workspaces/seed/cache/hf_dataset/gsm8k/main/1.1.0/37bfb08b1d4fcbb01f06b03d9e1ef5f1fcbd4d3af3d08842c50d7305091285ba/cache-1dd7d3bf6917fcd8.arrow


Dataset({
    features: ['label', 'input_ids', 'attention_mask'],
    num_rows: 64
})

In [5]:
# Dynamic padding: each batch is padded to its own longest sequence, with the
# length rounded up to a multiple of 8, and returned as PyTorch tensors.
# NOTE(review): under padding='longest' the max_length argument does not appear
# to cap or pad anything (it matters for padding='max_length'); kept unchanged
# to preserve behaviour — confirm against the tokenizer `pad` docs.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    return_tensors='pt',
    padding='longest',
    pad_to_multiple_of=8,
    max_length=1024,
)

In [6]:
# Unshuffled loader over the tokenized eval set; the collator pads each batch of 16.
eval_dataloader = DataLoader(dataset=dds, batch_size=16, collate_fn=data_collator)

In [7]:
# Sanity check: pull one collated batch and list the tensors the collator produced.
_first_batch_iter = iter(eval_dataloader)
batch = next(_first_batch_iter)
del _first_batch_iter  # don't leave a stray iterator in the notebook namespace

batch.keys()

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


dict_keys(['label', 'input_ids', 'attention_mask'])

In [8]:
# Load FLAN-T5 XXL in bfloat16 with reduced peak CPU memory during loading.
model = T5ForConditionalGeneration.from_pretrained(
    'google/flan-t5-xxl',
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
# Model parallelism across the visible GPUs (presumably splitting the T5 blocks
# by device). NOTE(review): `parallelize()` is deprecated in newer transformers
# releases in favour of `device_map=...`; kept because the eval loop places
# inputs with `.cuda()`, which relies on this layout.
model.parallelize()

In [9]:
# Accuracy metric accumulated batch-by-batch in the evaluation loop.
# TODO: `datasets.load_metric` is deprecated (note the warning in the output);
# migrate to `evaluate.load('accuracy')` once `evaluate` is a project dependency.
metric = load_metric(path='accuracy')

  metric = load_metric('accuracy')


In [10]:
# Manually-advanced progress bar: one tick per evaluation batch.
# The bar is never iterated, so give it a `total` instead of wrapping a throwaway range.
progress_bar = tqdm(total=len(eval_dataloader), desc='eval')

  0%|          | 0/4 [00:00<?, ?it/s]

In [11]:
# Evaluation loop: greedily decode each batch, parse a numeric answer from each
# generated string, and accumulate accuracy against the gold labels.
for batch in eval_dataloader:
    with torch.no_grad():
        outputs = model.generate(
            input_ids=batch['input_ids'].cuda(),
            attention_mask=batch['attention_mask'].cuda(),
            max_length=256,
            # Explicit greedy decoding. The previous `temperature=0` had no effect
            # without sampling, and newer transformers releases warn about or
            # reject a zero temperature.
            do_sample=False,
        )

    # Decoding and metric bookkeeping don't need to run under no_grad.
    output_sequences = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # NOTE(review): assumes GSM8K.regex_predict always returns a number; a None
    # for an unparseable generation would make torch.tensor raise here.
    predictions = torch.tensor([GSM8K.regex_predict(seq) for seq in output_sequences])

    metric.add_batch(predictions=predictions, references=batch['label'])
    progress_bar.update(1)

In [12]:
# Final accuracy over every (prediction, reference) batch passed to metric.add_batch
# in the evaluation loop.
metric.compute()

{'accuracy': 0.21875}