In [1]:
import sys

# Make the project's source directory importable from this notebook.
module_path = "../src"

# Append only when the entry is absent, so re-running the cell does not
# add duplicate entries to sys.path.
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [2]:
# Load dataset
# `get_dataset` comes from the project-local `dataset` module
# (presumably resolved through the "../src" entry on sys.path — confirm).
from dataset import get_dataset

# NOTE(review): later cells index the result as dataset["test"] and read the
# "text" and "label" fields of each example; the exact object type returned
# by get_dataset() is not visible from this notebook.
dataset = get_dataset()

In [3]:
# Load libraries
from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Prefer the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Identifier of the pretrained checkpoint passed to `from_pretrained` below.
model_id = "google/flan-t5-xl"

cuda


In [4]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Fall back to EOS as the padding token only when the tokenizer does not
# define one. T5-family checkpoints ship their own pad token, which the model
# also uses as the decoder start token, so it must not be overwritten.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Shuffle the test split with a fixed seed so the evaluation order is
# reproducible. Only `.shuffle` is applied here — the split is reordered,
# not sliced, despite the variable being called `subset`.
shuffle_seed = 442333 + 424714
subset = dataset["test"].shuffle(seed=shuffle_seed)

In [5]:
# Load the pretrained sequence-to-sequence checkpoint, then move its
# weights onto the device selected earlier ("cuda" or "cpu").
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
model = model.to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# Define prompting func
def ask(text, max_new_tokens=20):
    """Send a prompt to the model and return its decoded reply.

    Args:
        text: Prompt string passed to the module-level ``tokenizer``.
        max_new_tokens: Upper bound on generated tokens, forwarded to
            ``model.generate``.

    Returns:
        The first generated sequence decoded to a string, with special
        tokens skipped.
    """
    # Tokenize to PyTorch tensors and place them on the same device as the model.
    encoded = tokenizer(text, return_tensors="pt").to(device)

    generated = model.generate(max_new_tokens=max_new_tokens, **encoded)

    # A single prompt is encoded, so only the first sequence is decoded.
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

In [9]:
# Label names; the position of each name is compared against the integer
# stored in an example's "label" field.
categories = ["World", "Sports", "Business", "Sci/Tech"]

from tqdm.notebook import tqdm

# The instruction part of the prompt is identical for every example, so it is
# built once instead of on every loop iteration.
context = f"You have {len(categories)} categories: {', '.join(categories)}. Use exactly those categories. Decide which category the following text belongs to: "

correct = 0
total = 0
unrecognized = 0  # answers that are not one of the category names
for example in tqdm(subset):
    text = example["text"]
    category = example["label"]

    prompt = f"{context}'{text}'."

    answer = ask(prompt)

    # Show the first few prompt/answer pairs as a sanity check.
    if total < 3:
        print("Task for model:")
        print(prompt)
        print(f"Answer: {answer}, expected answer: {categories[category]}")
        print()

    # The model may reply with text that is not an exact category name.
    # Count such replies as incorrect instead of letting `list.index` raise
    # ValueError and abort the whole evaluation run.
    predicted = answer.strip()
    if predicted in categories:
        if categories.index(predicted) == category:
            correct += 1
    else:
        unrecognized += 1
    total += 1

# Guard against an empty evaluation set (would otherwise divide by zero).
if total:
    print(f"Accuracy: {correct*100/total}%")
else:
    print("Accuracy: n/a (no examples evaluated)")

if unrecognized:
    print(f"Unrecognized answers (counted as incorrect): {unrecognized}")

  0%|          | 0/7600 [00:00<?, ?it/s]

Task for model:
You have 4 categories: World, Sports, Business, Sci/Tech. Use exactly those categories. Decide which category the following text belongs to: 'Trapeze Software Eases Services Delivery (Ziff Davis) Ziff Davis - Trapeze Networks this week will announce upgrades to its wireless LAN switch software.'.
Answer: Sci/Tech, expected answer: Sci/Tech

Task for model:
You have 4 categories: World, Sports, Business, Sci/Tech. Use exactly those categories. Decide which category the following text belongs to: 'Update 4: Belo to Cut 250 Jobs, Mostly in Dallas Media owner Belo Corp. said Wednesday that it would cut 250 jobs, more than half of them at its flagship newspaper, The Dallas Morning News, and that an internal investigation into circulation overstatements '.
Answer: Business, expected answer: Business

Task for model:
You have 4 categories: World, Sports, Business, Sci/Tech. Use exactly those categories. Decide which category the following text belongs to: 'China's Lenovo in ta