In [1]:
import os
import pandas as pd

# Run from the repository root so the `scripts.*` imports and the relative
# `data/` and `results/` paths used in this notebook resolve.
# The guard keeps this cell idempotent: re-running it no longer climbs two
# more directories each time.
# NOTE(review): assumes the repository root is the directory that contains
# `scripts/` and that the notebook's own folder does not — confirm.
if not os.path.isdir("scripts"):
    os.chdir("../../")
from scripts.llm import get_num_of_tokens, get_completion
from scripts.utils import save_obj_as_pickle, read_obj_from_pickle
from scripts.data import make_database, make_prompts_for_clf

# Show up to 150 characters per cell so long prompt strings stay readable in DataFrame output.
pd.set_option("display.max_colwidth", 150)

In [2]:
def loadData(fp, sample_size=1000, max_tokens=40, random_state=234):
    """Load a news-classification CSV and return a random sample of short texts.

    Args:
        fp: Path to a CSV file with "Class Index" and "Description" columns.
        sample_size: Number of rows to draw after the length filter.
        max_tokens: Keep only texts whose `get_num_of_tokens` count is at most
            this value. Short texts are selected to avoid going beyond some
            LLMs' token limit when constructing 100-problem prompts.
        random_state: Seed passed to `DataFrame.sample` so the draw is reproducible.

    Returns:
        A DataFrame with columns "text" and "label" (label given as a category
        name) and a fresh 0..n-1 index.

    Raises:
        ValueError: Propagated from `DataFrame.sample` when fewer than
            `sample_size` rows pass the length filter.
    """
    label_names = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tech"}

    # Build a renamed frame instead of mutating in place.
    df = pd.read_csv(fp).rename(columns={"Class Index": "label", "Description": "text"})
    df["label"] = df["label"].map(label_names)

    is_short = df["text"].apply(get_num_of_tokens) <= max_tokens
    sampled = df[is_short].sample(sample_size, random_state=random_state)
    return sampled[["text", "label"]].reset_index(drop=True)

# Benchmark name; it selects the raw-data folder here and names the saved
# database and result files further down.
clf_task = "AGNews"
# The test split supplies the problems for the final test prompts.
test_fp = f"data/raw/text classification/{clf_task}/test.csv"
test_df = loadData(test_fp)

# The train split is used as a dev set (e.g. for trying out the prompts).
dev_fp = f"data/raw/text classification/{clf_task}/train.csv"
dev_df = loadData(dev_fp)

test_df.head()

Unnamed: 0,text,label
0,The New York Jets and quarterback Chad Pennington are looking to finalize a contract extension by next Wednesday.,Sports
1,TEHRAN: Iran added one more missile to its military arsenal and the defense minister said Saturday his country was ready to confront any external ...,World
2,AFP - The second major airlift of Vietnamese Montagnards who fled to Cambodia's remote jungles after April anti-government protests will begin at ...,World
3,"Out of money, out of patience, out of time, and for the foreseeable future, out of business.",Sports
4,"NEW YORK, September 3 (New Ratings) - The European Union has reportedly made significant progress in settling its prolonged antitrust case against...",Business


### Database

Sources for the prompt data

In [3]:
# make_database arguments:
#   num_instance      - number of instances used to compose multi-problem prompts;
#                       each instance contains multiple problems
#   max_instance_size - maximum number of problems sampled from the benchmark
#                       dataset to compose one instance
# dev_df can be used for purposes such as testing the prompts or generating exemplars.
database = make_database(test_df, dev_df, num_instance=100, max_instance_size=100)
database.keys()

dict_keys(['num_instance', 'max_instance_size', 'labels', 'testData', 'testInstances', 'devData', 'devInstances'])

#### 0-shot

In [4]:
# Prompt templates are grouped by prompting mode; only "0-shot" is defined here.
database["promptTemplates"] = {"0-shot": {}}

# Single-problem template: classify one text into one of the four categories.
# Placeholder: $text.
# NOTE(review): the `$name` placeholders are presumably substituted inside
# make_prompts_for_clf (string.Template style) — confirm.
SingleClf = "Classify which news category the following line of text belongs to among the following four categories: " \
            "'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\n" \
            "Text: $text\nNews category:"

# Multi-problem template: classify each of $num texts, answering one category per line.
# Placeholders: $num (number of texts) and $texts (the texts, one per line).
BatchClf = "Classify which news category each of the $num following lines of text belongs to " \
           "among the following four categories: 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\n" \
           "Texts, one per line:\n\n$texts\n\n" \
           "News categories for each of the $num lines of text, one per line:\n"

# Index-selection template for ONE target category: the model lists the index
# numbers of the texts that belong to $category (or 'None' / 'All') as JSON
# keyed by that category.
# Placeholders: $num, $category, $texts.
SelectOne = "This is a news classification task in which each line of text belongs to one of four categories 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\n" \
            "Go over the $num lines of text below and list the index numbers of the lines that can be classified as $category according to the following instructions:\n" \
            "If none of the texts can be classified as $category, write 'None.'\n" \
            "If all the texts can be classified as $category, write 'All.'\n" \
            "Otherwise, provide the index numbers of the texts that can be classified as $category.\n\n" \
            "Output your responses in JSON format with the key '$category'.\nA formatted example output is provided below.\n" \
            "{'$category': [None/All or index numbers of the texts that can be classified as $category]}\n\n" \
            "Texts, one per line:\n\n$texts\n\n" \
            "JSON output:\n"

# Index-selection template for ALL categories at once: the model returns one
# JSON object with the keys 'business', 'sports', 'world' and 'sci/tech', each
# holding the matching index numbers (or 'None' / 'All').
# Placeholders: $num, $texts.
# NOTE(review): two small inconsistencies with SelectOne, left untouched because
# editing the prompt text would change the model inputs:
#   - the last example entry writes `sci/tech category` without quotes, unlike
#     the other three entries;
#   - the "Otherwise, ..." instruction ends with a single "\n" here, while
#     SelectOne ends it with "\n\n".
SelectAll = "This is a news classification task in which each line of text belongs to one of four categories 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\n" \
            "Go over the $num lines of text below and list the index numbers of the lines that belong to each category according to the following instructions:\n" \
            "If none of the texts can be classified as a particular category, write 'None.'\n" \
            "If all the texts can be classified as a particular category, write 'All.'\n" \
            "Otherwise, provide the index numbers of the texts that can be classified as the category.\n" \
            "Output your responses in JSON format with the following keys: 'business,' 'sports,' 'world,' and 'sci/tech.'\n" \
            "A formatted example output is provided below.\n" \
            "{'business': [None/All or index numbers of texts in 'business' category], 'sports': [None/All or index numbers of texts in 'sports' category], " \
            "'world': [None/All or index numbers of texts in 'world' category], 'sci/tech': [None/All or index numbers of texts in sci/tech category]}\n\n" \
            "Texts, one per line:\n\n$texts\n\n" \
            "JSON output:\n" 

# Task names and their templates, kept as parallel lists in matching order.
tasks = ["SingleClf", "BatchClf", "SelectOne", "SelectAll"]
promptTemplates = [SingleClf, BatchClf, SelectOne, SelectAll]

# Register each template under its task name in the 0-shot group.
database["promptTemplates"]["0-shot"].update(zip(tasks, promptTemplates))

In [5]:
# Create the output folder if it does not exist yet, then persist the database
# (now including the prompt templates) so it can be reloaded later.
os.makedirs("data/databases/text classification/", exist_ok=True)
save_obj_as_pickle(database, f"data/databases/text classification/{clf_task}.pkl")

Saved object to data/databases/text classification/AGNews.pkl


#### Test Prompts

- The main purpose is to check whether the LLM produces output in the desired format for each prompt.

In [6]:
# Small smoke-test set built from the dev split: a couple of prompts per task and size.
num_instance = 2
taskSizes = [5, 10]

dev_frames = []
for prompt_mode in ["0-shot"]:
    for task in tasks:
        if task != "SingleClf":
            # Multi-problem tasks: one batch of prompts per task size.
            for taskSize in taskSizes:
                dev_frames.append(make_prompts_for_clf(database, task, "dev", prompt_mode, taskSize,
                                                       attr="category", label_attr_converter=None,
                                                       num_instance=num_instance))
        else:
            # Single-problem prompts take no task size; keep only the first few rows.
            single_prompts = make_prompts_for_clf(database, task, "dev", prompt_mode)
            dev_frames.append(single_prompts[:num_instance])

dev = pd.concat(dev_frames).reset_index(drop=True)

In [7]:
# Print the first prompt (taskIndex 1) of each task at task size <= 5,
# each followed by a dashed rule and a blank line.
first_small = (dev["taskSize"] <= 5) & (dev["taskIndex"] == 1)
for prompt_text in dev.loc[first_small, "prompt"]:
    print(prompt_text, "-"*50, "", sep="\n")

Classify which news category the following line of text belongs to among the following four categories: 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'

Text: The last time they saw each other, John Kerry was sitting in the Red Sox owner #39;s box on the eve of the Democratic National Convention and the front 
News category:
--------------------------------------------------

Classify which news category each of the 5 following lines of text belongs to among the following four categories: 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'

Texts, one per line:

1. PAUL Bigley, brother of the Iraqi hostage Ken, last night appealed to Tony Blair to  quot;pick up the bloody phone quot; to save his brother #39;s life.
2.  quot;The PIC is being sold to consumers through service providers, like VSNL, which is a different approach for computing devices, quot; said Shane Rau of IDC.
3. Global IT services powerhouse EDS extended it contract with Opsware (Quote, Chart) for data center automation soft

In [8]:
# Run get_completion once per dev prompt and store the raw output in "preds",
# next to the gold "answer", so the output format can be inspected by eye.
# NOTE(review): get_completion presumably calls an external LLM, so re-running
# this cell may be slow or incur cost — confirm.
dev["preds"] = dev.prompt.apply(get_completion)
dev

Unnamed: 0,taskIndex,prompt,answer,targetLabel,task,#shot,CoT,taskSize,preds
0,1,"Classify which news category the following line of text belongs to among the following four categories: 'Business,' 'Sports,' 'World,' and 'Sci/Te...",Sports,,SingleClf,0,False,1,Sports
1,2,"Classify which news category the following line of text belongs to among the following four categories: 'Business,' 'Sports,' 'World,' and 'Sci/Te...",Sports,,SingleClf,0,False,1,Sports
2,1,"Classify which news category each of the 5 following lines of text belongs to among the following four categories: 'Business,' 'Sports,' 'World,' ...","[World, Sci/Tech, Sci/Tech, Sports, World]",,BatchClf,0,False,5,1. World\n2. Sci/Tech\n3. Business\n4. Sports\n5. World
3,2,"Classify which news category each of the 5 following lines of text belongs to among the following four categories: 'Business,' 'Sports,' 'World,' ...","[Sci/Tech, World, Business, Business, Business]",,BatchClf,0,False,5,1. World\n2. World\n3. Business\n4. Business\n5. Business
4,1,"Classify which news category each of the 10 following lines of text belongs to among the following four categories: 'Business,' 'Sports,' 'World,'...","[World, Sci/Tech, Sci/Tech, Sports, World, Sci/Tech, Sci/Tech, World, World, Sports]",,BatchClf,0,False,10,1. World\n2. Sci/Tech\n3. Business\n4. Sports\n5. World\n6. World\n7. Sci/Tech\n8. World\n9. World\n10. Sports
5,2,"Classify which news category each of the 10 following lines of text belongs to among the following four categories: 'Business,' 'Sports,' 'World,'...","[Sci/Tech, World, Business, Business, Business, World, Sci/Tech, Sci/Tech, World, Business]",,BatchClf,0,False,10,1. World\n2. World\n3. Business\n4. Business\n5. Business\n6. World\n7. Sci/Tech\n8. Sci/Tech\n9. World\n10. Business
6,1,"This is a news classification task in which each line of text belongs to one of four categories 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\n...",{None},Business,SelectOne,0,False,5,"{'business': [2, 3]}"
7,1,"This is a news classification task in which each line of text belongs to one of four categories 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\n...","{2, 3}",Sci/Tech,SelectOne,0,False,5,"{'sci/tech': [2, 3]}"
8,1,"This is a news classification task in which each line of text belongs to one of four categories 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\n...",{4},Sports,SelectOne,0,False,5,{'sports': [None]}
9,1,"This is a news classification task in which each line of text belongs to one of four categories 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\n...","{1, 5}",World,SelectOne,0,False,5,"{\n ""world"": [5]\n}"


### Make prompts

In [9]:
# Reload the persisted database so prompt generation starts from the saved copy.
# NOTE(review): the `tasks` list used below is not stored in the pickle; it must
# already be defined in the kernel.
database = read_obj_from_pickle(f"data/databases/text classification/{clf_task}.pkl")

Read object from data/databases/text classification/AGNews.pkl


In [10]:
# Full prompt set built from the test split.
num_instance = 100
taskSizes = [5, 10, 20, 50, 100]

prompt_frames = []
for prompt_mode in ["0-shot"]:
    for task in tasks:
        if task != "SingleClf":
            # Multi-problem tasks: one batch of prompts per task size.
            for taskSize in taskSizes:
                prompt_frames.append(make_prompts_for_clf(database, task, "test", prompt_mode, taskSize,
                                                          attr="category", label_attr_converter=None,
                                                          num_instance=num_instance))
        else:
            # Single-problem prompts take no task size and are kept in full.
            prompt_frames.append(make_prompts_for_clf(database, task, "test", prompt_mode))

out = pd.concat(prompt_frames).reset_index(drop=True)

# Export every prompt as one JSON record per line.
os.makedirs("results/text classification/", exist_ok=True)
out.to_json(f"results/text classification/{clf_task}.json", orient="records", lines=True)

In [11]:
# Number of generated prompts per task type.
out["task"].value_counts()

SelectOne    2000
SingleClf    1000
BatchClf      500
SelectAll     500
Name: task, dtype: int64

In [12]:
# Preview the first prompt (taskIndex 1) of every task at task size <= 5.
# Boolean-mask indexing already returns a new DataFrame, so copying the whole
# of `out` beforehand is unnecessary.
out[(out["taskIndex"] == 1) & (out["taskSize"] <= 5)]

Unnamed: 0,taskIndex,prompt,answer,targetLabel,task,#shot,CoT,taskSize
0,1,"Classify which news category the following line of text belongs to among the following four categories: 'Business,' 'Sports,' 'World,' and 'Sci/Te...",Sports,,SingleClf,0,False,1
1000,1,"Classify which news category each of the 5 following lines of text belongs to among the following four categories: 'Business,' 'Sports,' 'World,' ...","[World, Sports, Business, Sports, Sports]",,BatchClf,0,False,5
1500,1,"This is a news classification task in which each line of text belongs to one of four categories 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\n...",{3},Business,SelectOne,0,False,5
1501,1,"This is a news classification task in which each line of text belongs to one of four categories 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\n...",{None},Sci/Tech,SelectOne,0,False,5
1502,1,"This is a news classification task in which each line of text belongs to one of four categories 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\n...","{2, 4, 5}",Sports,SelectOne,0,False,5
1503,1,"This is a news classification task in which each line of text belongs to one of four categories 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\n...",{1},World,SelectOne,0,False,5
3500,1,"This is a news classification task in which each line of text belongs to one of four categories 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\n...","{'business': {3}, 'sci/tech': {'None'}, 'sports': {2, 4, 5}, 'world': {1}}",,SelectAll,0,False,5
