In [11]:
import os
import json
import torch
import tiktoken
from urllib.request import urlretrieve

In [3]:
# Raw JSONL files for each HellaSwag split, keyed by split name.
# Insertion order (train, val, test) is the order the download cell visits them.
hellaswags = dict(
    train="https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_train.jsonl",
    val="https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl",
    test="https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_test.jsonl",
)

In [4]:
# Download every split listed in `hellaswags` (train, val and test) into
# ../data/hellaswag, skipping files that are already present.
data_dir = "../data/hellaswag"
os.makedirs(data_dir, exist_ok=True)

for split, url in hellaswags.items():
    out_path = os.path.join(data_dir, f"hellaswag_{split}.jsonl")
    if os.path.exists(out_path):
        print(f"File {out_path} already exists, skipping download.")
        continue
    # Fetch into a temporary name and rename only after the transfer completes.
    # Otherwise an interrupted or short download would leave a truncated file at
    # `out_path`, which the existence check above would then silently accept.
    tmp_path = out_path + ".part"
    urlretrieve(url, tmp_path)
    os.replace(tmp_path, out_path)
    print(f"Downloaded {out_path}")

File ../data/hellaswag/hellaswag_train.jsonl already exists, skipping download.
File ../data/hellaswag/hellaswag_val.jsonl already exists, skipping download.
File ../data/hellaswag/hellaswag_test.jsonl already exists, skipping download.


In [93]:
# Read all lines from the validation set
# Load the validation split. The file is JSON Lines: one example object per line.
with open("../data/hellaswag/hellaswag_val.jsonl", "r") as val_file:
    lines = val_file.readlines()
print(f"Read {len(lines)} lines from validation set.")

examples = list(map(json.loads, lines))
print(f"Loaded {len(examples)} examples from validation set.")

Read 10042 lines from validation set.
Loaded 10042 examples from validation set.


In [None]:
class HellaSwagRenderer:
    """Turn HellaSwag examples into padded token/mask tensors for LM scoring.

    Each candidate ending is appended to the shared context, producing one row
    per ending. Rows are right-padded with token id 0 to a common length. The
    mask is 1 only on the ending tokens, which are the positions whose
    likelihood is compared across candidates.
    """

    def __init__(self, encoding_name="gpt2"):
        """Load the tiktoken BPE encoding used to tokenize examples.

        Args:
            encoding_name: name passed to ``tiktoken.get_encoding``. Defaults to
                ``"gpt2"``, matching the GPT-2 model evaluated in this notebook.
        """
        self.tok = tiktoken.get_encoding(encoding_name)

    def render_example(self, example):
        """Tokenize one example into one row per candidate ending.

        Args:
            example: a parsed HellaSwag record with ``"ctx"`` (str),
                ``"endings"`` (list of str) and, when present, ``"label"``.

        Returns:
            ``(tokens, mask, label)``. ``tokens`` and ``mask`` are ``torch.long``
            tensors of shape ``(len(endings), len(ctx_ids) + longest_ending)``.
            ``tokens`` holds context followed by ending ids, right-padded with 0.
            ``mask`` is 1 exactly on the ending tokens. ``label`` is
            ``example["label"]``, or ``None`` when the record carries no label
            (presumably the unlabeled test split — TODO confirm).

        Raises:
            ValueError: if ``example["endings"]`` is empty.
        """
        ctx, ends = example["ctx"], example["endings"]
        label = example.get("label")
        if not ends:
            raise ValueError("example has no endings to render")

        ctx_ids = self.tok.encode(ctx)
        # Prepend a space: GPT-2's BPE attaches the leading space to the next word
        # token (e.g. " man"), so " ending" tokenizes as it would after the context.
        ending_ids = [self.tok.encode(" " + ending) for ending in ends]
        n_ctx = len(ctx_ids)
        result_length = n_ctx + max(len(eids) for eids in ending_ids)

        tokens = torch.zeros((len(ends), result_length), dtype=torch.long)
        mask = torch.zeros((len(ends), result_length), dtype=torch.long)
        # The context is identical for every row, so build its tensor once.
        ctx_tensor = torch.tensor(ctx_ids, dtype=torch.long)
        for i, eids in enumerate(ending_ids):
            end = n_ctx + len(eids)
            tokens[i, :n_ctx] = ctx_tensor
            tokens[i, n_ctx:end] = torch.tensor(eids, dtype=torch.long)
            mask[i, n_ctx:end] = 1
        return tokens, mask, label

    def print_example(self, example):
        """Print the context, the label and a per-ending token breakdown of ``example``.

        For every ending, each position of its padded row is printed with its
        token id, mask value and decoded text. This helps check that the mask
        covers exactly the ending tokens.
        """
        tokens, mask, label = self.render_example(example)
        endings = example["endings"]

        print(f"Context: |{example['ctx']}|")
        print(f"Label: {label}")
        for b, ending in enumerate(endings):
            print(f"Ending {b}: |{ending}|")
            for i in range(tokens.size(1)):
                tt = tokens[b, i].item()
                mm = mask[b, i].item()
                word = self.tok.decode([tt])
                print(f"i={i:2}: tok={tt:6} mask={mm} dec=|{word}|")

In [None]:
# One shared renderer (loads the GPT-2 BPE encoding once), reused by the cells below.
hella_swag_renderer: HellaSwagRenderer = HellaSwagRenderer()

In [106]:
# Inspect how the first validation example is tokenized and which positions the mask scores.
hella_swag_renderer.print_example(example=examples[0])

Context: |A man is sitting on a roof. he|
Label: 3
Ending 0: |is using wrap to wrap a pair of skis.|
i= 0: tok=    32 mask=0 dec=|A|
i= 1: tok=   582 mask=0 dec=| man|
i= 2: tok=   318 mask=0 dec=| is|
i= 3: tok=  5586 mask=0 dec=| sitting|
i= 4: tok=   319 mask=0 dec=| on|
i= 5: tok=   257 mask=0 dec=| a|
i= 6: tok=  9753 mask=0 dec=| roof|
i= 7: tok=    13 mask=0 dec=|.|
i= 8: tok=   339 mask=0 dec=| he|
i= 9: tok=   318 mask=1 dec=| is|
i=10: tok=  1262 mask=1 dec=| using|
i=11: tok= 14441 mask=1 dec=| wrap|
i=12: tok=   284 mask=1 dec=| to|
i=13: tok= 14441 mask=1 dec=| wrap|
i=14: tok=   257 mask=1 dec=| a|
i=15: tok=  5166 mask=1 dec=| pair|
i=16: tok=   286 mask=1 dec=| of|
i=17: tok=  1341 mask=1 dec=| sk|
i=18: tok=   271 mask=1 dec=|is|
i=19: tok=    13 mask=1 dec=|.|
Ending 1: |is ripping level tiles off.|
i= 0: tok=    32 mask=0 dec=|A|
i= 1: tok=   582 mask=0 dec=| man|
i= 2: tok=   318 mask=0 dec=| is|
i= 3: tok=  5586 mask=0 dec=| sitting|
i= 4: tok=   319 mask=0 dec=| o

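# Evaluate Hugging Face GPT-2

Each candidate ending is scored by its average per-token cross-entropy given the context (only ending tokens count, via the mask); the ending with the lowest average loss is the model's prediction.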

In [None]:
import torch.nn.functional as F
from transformers import GPT2LMHeadModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_hf = GPT2LMHeadModel.from_pretrained("gpt2")
model_hf.to(device)
model_hf.eval()


def _score_endings(model, tokens, mask):
    """Return the average per-token cross-entropy of each candidate ending.

    Args:
        model: a causal LM whose forward call returns an object with ``.logits``
            of shape ``(rows, seq_len, vocab)``.
        tokens, mask: the ``(num_endings, seq_len)`` tensors produced by
            ``HellaSwagRenderer.render_example``, already on the model's device.

    Returns:
        A 1-D tensor with one average loss per ending. Only positions where
        ``mask == 1`` (the ending tokens) contribute.
    """
    logits = model(tokens).logits
    # The logits at position t predict token t+1, so drop the last logit and the first target.
    logits_shifted = logits[:, :-1, :].contiguous()
    tokens_shifted = tokens[:, 1:].contiguous()
    mask_shifted = mask[:, 1:].contiguous()
    B, T_1, V = logits_shifted.shape
    losses_shifted = F.cross_entropy(
        logits_shifted.view(B * T_1, V),
        tokens_shifted.view(B * T_1),
        reduction="none",
    ).view(B, T_1)
    # Normalize by ending length so longer endings are not penalized for having more tokens.
    return (losses_shifted * mask_shifted).sum(dim=1) / mask_shifted.sum(dim=1)


# Zero-shot evaluation over the validation set: predict the lowest-loss ending.
model_labels = []
num_total = 0
num_correct = 0

with torch.no_grad():
    for example in examples:
        tokens, mask, label = hella_swag_renderer.render_example(example)
        tokens = tokens.to(device)
        mask = mask.to(device)
        losses_avg = _score_endings(model_hf, tokens, mask)
        model_label = torch.argmin(losses_avg).item()
        model_labels.append(model_label)
        num_total += 1
        if model_label == label:
            num_correct += 1

In [105]:
# Report the counters accumulated by the evaluation loop as a single accuracy figure.
print("Evaluation results:")
accuracy = 0 if num_total <= 0 else num_correct / num_total
summary_line = f"Accuracy: {accuracy:.2%} ({num_correct}/{num_total})"
print(summary_line)

Evaluation results:
Accuracy: 29.55% (2967/10042)
