In [1]:
import sys

# Expose the project's helper modules living in ../src to `import`.
# The guard keeps re-runs of this cell from stacking duplicate entries.
module_path = "../src"
if all(entry != module_path for entry in sys.path):
    sys.path.append(module_path)

In [2]:
# Load dataset
# `get_dataset` comes from the project-local `dataset` module (presumably found via
# the ../src entry added to sys.path in the first cell — verify).
# Later cells call dataset['train'].shuffle(...).select(...) and read rows by the
# 'text' and 'label' keys, so this is presumably a Hugging Face DatasetDict — TODO confirm.
from dataset import get_dataset
dataset = get_dataset()

In [3]:
# Load libraries
from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Hugging Face Hub checkpoint used by every later cell.
model_id = "google/flan-t5-xl"

cuda


In [4]:
# Load the tokenizer that matches the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# T5/Flan-T5 tokenizers already define a dedicated pad token. Falling back to EOS
# is a workaround for tokenizers that have none (typical of causal LMs). Doing it
# unconditionally would make padding indistinguishable from end-of-sequence in
# padded batches, so only do it when it is actually needed.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Deterministically shuffle the train split and keep a fixed-size evaluation subset.
SUBSET_SEED = 442333 + 424714
SUBSET_SIZE = 5000
train_split = dataset['train']
# Guard against a split smaller than the requested subset (select would fail).
subset = train_split.shuffle(seed=SUBSET_SEED).select(range(min(SUBSET_SIZE, len(train_split))))

In [5]:
# Load the seq2seq checkpoint named by `model_id`, then place it on the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
model = model.to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# Helper for sending one prompt to the model.
def ask(text, max_new_tokens=20):
    """Generate a reply to a single prompt and return it as plain text.

    Uses the module-level `tokenizer`, `model` and `device`.

    Args:
        text: Prompt string to feed to the tokenizer.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first generated sequence, decoded with special tokens stripped.
    """
    encoded = tokenizer(text, return_tensors="pt")
    encoded = encoded.to(device)
    generated_ids = model.generate(**encoded, max_new_tokens=max_new_tokens)
    first_sequence = generated_ids[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

In [7]:
# Sentiment probe: the question itself names the two allowed answers.
text = "I hate this film! Is this a positive or negative review?"
ask(text=text)

'negative'

In [8]:
# Inference probe: the answer has to be inferred from the statement, not copied from it.
text = "I think she didn't send her homework! Will girl pass?"
ask(text=text)

'no'

In [9]:
# Multi-label probe: the prompt defines three codes, and the patient description
# paraphrases two of them ("cardiac insufficiency" ~ heart failure, "below normal
# count of erythrocytes" ~ low red blood cells) without using the same wording.
# NOTE: the prompt text, typos included, is kept verbatim because it is the model input.
text = """
When patient had stroke, its code 22.
If patient had low red blood cells, it's code 27.
When patient had heart failure, it is code 23.

Patient description: "Man, 30 years old, exhibits signs of cardiac insufficiency, such as shortness of breath, fatigue, and leg swelling. Test results showed below normal count of erythrocytes. Additional assessment and medical history are necessary for a thorough diagnosis and tailored treatment approach."
Which multiple codes should be used for this description?
"""
ask(text)

'23 and 27'

In [12]:
# Zero-shot topic classification: ask the model to name each example's category
# and compare the answer with the dataset label.
# The label is assumed to be the integer index into `categories` — TODO confirm
# against the dataset's label names.
categories = ["World", "Sports", "Business", "Sci/Tech"]

# tqdm.auto renders a notebook widget when available and a console bar otherwise.
from tqdm.auto import tqdm

# The instruction prefix is the same for every example, so build it once.
context = f"You have {len(categories)} categories: {', '.join(categories)}. Use exactly those categories. Decide which category the following text belongs to: "

correct = 0
total = 0
unparsed = 0  # replies that are not exactly one of `categories`
for example in tqdm(subset):
    text = example['text']
    category = example['label']

    prompt = f"{context}'{text}'."

    answer = ask(prompt).strip()

    # The model generates free text and may reply with something outside
    # `categories`. Count such replies as wrong instead of letting
    # list.index raise ValueError and abort the whole run.
    try:
        index = categories.index(answer)
    except ValueError:
        unparsed += 1
        index = None

    if index == category:
        correct += 1
    total += 1

if total:
    print(f"Accuracy: {correct*100/total}%")
    print(f"Unparseable answers: {unparsed}/{total}")
else:
    # Avoid ZeroDivisionError when the subset is empty.
    print("No examples evaluated.")

  0%|          | 0/5000 [00:00<?, ?it/s]

Accuracy: 93.32%
