In [1]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling
from transformers import pipeline

import numpy as np
from tqdm.autonotebook import tqdm
from typing import Dict, Any, List, Tuple, Optional
from dataclasses import dataclass

from src.settings import MODELS_DIR

In [2]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda', index=0)

In [3]:
MODEL = "gpt2"
DATASET = "Rowan/hellaswag"

In [4]:
dataset = load_dataset(DATASET, split="test")
dataset

Dataset({
    features: ['ind', 'activity_label', 'ctx_a', 'ctx_b', 'ctx', 'endings', 'source_id', 'split', 'split_type', 'label'],
    num_rows: 10003
})

In [5]:
dataset[0]

{'ind': 14,
 'activity_label': 'Wakeboarding',
 'ctx_a': 'A man is being pulled on a water ski as he floats in the water casually.',
 'ctx_b': 'he',
 'ctx': 'A man is being pulled on a water ski as he floats in the water casually. he',
 'endings': ['mounts the water ski and tears through the water at fast speeds.',
  'goes over several speeds, trying to stay upright.',
  'struggles a little bit as he talks about it.',
  'is seated in a boat with three other people.'],
 'source_id': 'activitynet~v_-5KAycAQlC4',
 'split': 'test',
 'split_type': 'indomain',
 'label': ''}

In [17]:
model = AutoModelForCausalLM.from_pretrained(MODEL).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

In [18]:
def preprocess_hellaswag_batch(examples: Dict[str, Any]) -> Tuple[List[str], List[str]]:
    inputs = examples["ctx"]
    targets = examples["endings"]
    return inputs, targets


@dataclass
class HellaSwagInputsEncoder:
    tokenizer: AutoTokenizer
    max_seq_length: int

    def convert_to_features_train(
        self,
        example_batch: Dict[str, Any],
        indices: Optional[List[int]] = None
    ) -> Any:
        inputs, text_target = preprocess_hellaswag_batch(example_batch)

        model_inputs = self.tokenizer(
            # inputs, text_target=text_target, max_length=self.max_seq_length, truncation=True
            inputs, max_length=self.max_seq_length, truncation=True
        )
        return model_inputs

    def __call__(
        self,
        example_batch: Dict[str, Any],
        indices: Optional[List[int]] = None
    ) -> Any:
        return self.convert_to_features_train(
            example_batch=example_batch, indices=indices
        )

In [19]:
loader_columns = [
    'datasets_idx',
    'input_ids',
    'token_type_ids',
    'attention_mask',
    'start_positions',
    'end_positions',
    'labels'
]
columns_to_ignore = [c for c in dataset.column_names if c not in loader_columns]
columns_to_ignore

['ind',
 'activity_label',
 'ctx_a',
 'ctx_b',
 'ctx',
 'endings',
 'source_id',
 'split',
 'split_type',
 'label']

In [20]:
encoder = HellaSwagInputsEncoder(tokenizer=tokenizer, max_seq_length=384)

dataset_transformed = dataset.map(
    encoder,
    batched=True,
    remove_columns=columns_to_ignore,
)

Map:   0%|          | 0/10003 [00:00<?, ? examples/s]

Exception ignored in: <function tqdm.__del__ at 0x7f02af35e440>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/tqdm/std.py", line 1162, in __del__
    self.close()
  File "/opt/conda/lib/python3.10/site-packages/tqdm/notebook.py", line 288, in close
    self.disp(bar_style='danger', check_delay=False)
AttributeError: 'tqdm_notebook' object has no attribute 'disp'


In [21]:
BATCH_SIZE = 32

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)
test_dl = DataLoader(
    dataset_transformed, batch_size=BATCH_SIZE, shuffle=False, collate_fn=data_collator, pin_memory=True
)
test_dl_src = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)

In [22]:
next(iter(test_dl)).keys()

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


dict_keys(['input_ids', 'attention_mask', 'labels'])

In [32]:
with torch.no_grad():
    for batch, batch_src in tqdm(zip(test_dl, test_dl_src)):
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        predictions = model.generate(
            input_ids=batch["input_ids"],
            max_length=100,
            attention_mask=batch["attention_mask"],
        )
        decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        for i, (pred, src_ctx, src_endings) in enumerate(zip(decoded_predictions, batch_src["ctx"], batch_src["endings"])):
            if i == 4:
                break
            print(f'Text: "{src_ctx}"')
            print(f'Predicted ending:\n{pred}')
            print(f"\nCorrect endings:\n\t - " + "\n\t - ".join(src_endings))
            print("\n", "-" * 80, "\n")

        break

0it [00:00, ?it/s]

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Text: "A man is being pulled on a water ski as he floats in the water casually. he"
Predicted ending:
A man is being pulled on a water ski as he floats in the water casually. he is wearing a ski mask and a ski mask with a ski mask on. He is also carrying a black backpack with him when he is not in the water. He is also carrying a black backpack with him when he is not in the water. Photo: Supplied

A man

Correct endings:
	 - mounts the water ski and tears through the water at fast speeds.
	 - are water boarding in a river.
	 - run out to where the javelin lands again.
	 - do the same action but in different locations.
	 - , another man does not throw his javelin.
	 - # 1, but drops his javelin.
	 - in windsurfer gear sits off to the side of the table talking.
	 - puts a bronze medal on the third thrower.
	 - picks up the ingredients and puts them on the baking sheet.
	 - then adds toasted graham bears to a pot filled with water.
	 - then pours the rest of it in a bowl, setting it onto