# Phi-2: zero-shot classification

In [1]:
import torch
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    StoppingCriteria, 
    StoppingCriteriaList, 
    TextIteratorStreamer
    )

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from samples import samples

In [3]:
# Pick the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Your device is", device)

# The weights and the tokenizer are both loaded from this checkpoint name.
checkpoint = "microsoft/phi-2"

# On GPU the dtype is left to the "auto" setting; on CPU torch.float is requested.
weights_dtype = torch.float if device != "cuda" else "auto"

# NOTE(review): trust_remote_code=True permits code shipped with the checkpoint
# to run locally -- confirm it is still required for this model.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map="auto",
    torch_dtype=weights_dtype,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

Your device is cuda


Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.76it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation once a stop token is produced.

    Args:
        stop_ids: Iterable of token ids that end generation. Defaults to
            ``(50256, 198)``, i.e. ``<|endoftext|>`` and end-of-line.
    """

    DEFAULT_STOP_IDS = (50256, 198)  # <|endoftext|> and EOL

    def __init__(self, stop_ids=DEFAULT_STOP_IDS):
        super().__init__()
        # Stored as a set so the per-step check is a single membership test.
        self.stop_ids = frozenset(int(token_id) for token_id in stop_ids)

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the newest token of the first sequence is a stop id.

        Only row 0 of ``input_ids`` is inspected; ``scores`` and ``kwargs``
        are accepted for interface compatibility and not used.
        """
        newest_token = int(input_ids[0][-1])
        return newest_token in self.stop_ids


stop = StopOnTokens()

def run_prompt(prompt, max_length=250):
    """Generate a completion for ``prompt`` and return the decoded text.

    Args:
        prompt: Text fed to the tokenizer and then to the model.
        max_length: Value forwarded to ``model.generate`` as ``max_length``
            (defaults to 250, the value that used to be hard-coded).

    Returns:
        The first sequence of ``tokenizer.batch_decode(outputs)``. As the
        notebook's printed example shows, this text begins with the prompt
        itself and ends with the generated continuation.
    """
    # Wrap the criterion in the container type that `generate` documents
    # for its `stopping_criteria` argument.
    criteria = StoppingCriteriaList([stop])
    with torch.no_grad():
        inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
        # Follow the model's own device instead of the global `device` string,
        # so the inputs always land where the loaded weights are.
        inputs = inputs.to(model.device)
        outputs = model.generate(
            **inputs, max_length=max_length, stopping_criteria=criteria
        )
        text = tokenizer.batch_decode(outputs)[0]
    return text

In [5]:
# Minimal zero-shot prompt built around the first "support" sample.
# The trailing .strip() drops the leading and trailing newlines that the
# triple-quoted literal introduces, so the prompt ends right after "Output:".
prompt = f"""
Instruct: Classify the following text in one of the following categories: ["support", "sales", "joke"]
Text: {samples['support'][0]}
Output:
""".strip()

In [6]:
# Sanity check: show the full text returned by the model for the minimal prompt.
completion = run_prompt(prompt)
print(completion)

Instruct: Classify the following text in one of the following categories: ["support", "sales", "joke"]
Text: Hi, I have an issue with my order. It hasn't arrived, and the delivery date has passed.
Output: sales



In [7]:
# Prompt template with a one-line description of each category.
# The single `{}` placeholder is filled with the text to classify via
# template.format(...); .strip() removes the literal's outer newlines.
template = """
Instruct: Classify the following text in one of the following categories: ["support", "sales", "joke"]. Output only the name of the category.
+ "support" for customer support texts
+ "sales" for sales and commercial texts
+ "joke" for jokes, funny or comedy like texts
Text: {}
Output:
""".strip()

In [8]:
# Long-form label names and their short counterparts, paired by position.
candidate_labels = ["customer support", "sales and comercial", "joke"]
labels_short = ["support", "sales", "joke"]


def long_to_short(label):
    """Translate a long label into the short label at the same position.

    Raises:
        ValueError: if ``label`` is not one of ``candidate_labels``.
    """
    position = candidate_labels.index(label)
    return labels_short[position]

In [11]:
# For each true category, tally which label the model predicts for its samples
# and print one summary dict per category.
for category, texts in samples.items():
    counts = {short_label: 0 for short_label in labels_short}
    for sample_text in texts:
        completion = run_prompt(template.format(sample_text))
        # The completion echoes the prompt, so the prediction is its final word.
        raw_label = completion.split()[-1]
        # Drop an attached end-of-text marker and surrounding quotes/punctuation
        # so answers such as "support" (quoted) or sales. still match a label.
        label = raw_label.replace("<|endoftext|>", "").strip("\"'`.,:;[]").lower()
        if label not in counts:
            # Count unexpected answers instead of aborting the run on KeyError.
            label = "unknown"
        # .get() lets the "unknown" bucket appear only when it is actually needed.
        counts[label] = counts.get(label, 0) + 1
    print(counts)

{'support': 21, 'sales': 4, 'joke': 0}
{'support': 3, 'sales': 22, 'joke': 0}
{'support': 0, 'sales': 0, 'joke': 25}
