In [1]:
import os
import pandas as pd

# Move the working directory three levels up before importing the project modules.
# NOTE(review): presumably this lands on the repository root so that the `scripts`
# package and the relative "data/..." / "results/..." paths used below resolve — confirm.
# The path is relative, so re-running this cell without restarting the kernel
# moves the working directory up a further three levels.
os.chdir("../../../")
from scripts.llm import get_completion
from scripts.utils import read_obj_from_pickle
from scripts.data import make_prompts_for_clf

# Show up to 150 characters per cell so long prompt strings stay readable in DataFrame displays.
pd.set_option("display.max_colwidth", 150)

In [2]:
# Simplify the index-selection tasks: every text is replaced by its own gold label,
# so answering only requires spotting the label word on each line.
clf_task = "AGNews"
database = read_obj_from_pickle(f"data/databases/text classification/{clf_task}.pkl")
for split in ("testData", "devData"):
    # "texts" now refers to the very same object as "labels" (no copy is made)
    database[split]["texts"] = database[split]["labels"]

# Start from an empty template registry that only has a "0-shot" slot.
database["promptTemplates"] = {"0-shot": {}}

Read object from data/databases/text classification/AGNews.pkl


In [6]:
# Zero-shot template that asks for the indices of the lines showing ONE target word.
# Placeholders: $num (number of lines), $category (target word), $texts (the numbered lines).
# Written as one literal per rendered line; adjacent literals are concatenated by the parser.
SelectOne = (
    "In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n"
    "\n"
    "Go over the $num lines of text below and list the index numbers of the lines that contains the word '$category' according to the following instructions:\n"
    "If none of the texts contain the word '$category', write 'None.'\n"
    "If all the texts contain the word '$category', write 'All.'\n"
    "Otherwise, provide the index numbers of the texts that contain the word '$category'.\n"
    "\n"
    "Output your responses in JSON format with the key '$category'.\n"
    "A formatted example output is provided below.\n"
    "{'$category': [None/All or index numbers of the texts containing '$category']}\n"
    "\n"
    "Texts, one per line:\n"
    "\n"
    "$texts\n"
    "\n"
    "JSON output:\n"
)

# Zero-shot template that asks for the indices of the lines showing EACH of the four words.
# Placeholders: $num (number of lines) and $texts (the numbered lines).
# The example JSON object is split into one literal per key purely for readability;
# adjacent literals are concatenated by the parser into a single line of the prompt.
SelectAll = (
    "In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n"
    "\n"
    "Go over the $num lines of text below and list the index numbers of the lines that contains each word according to the following instructions:\n"
    "If none of the texts contain a particular word, write 'None.'\n"
    "If all the texts contain a particular word, write 'All.'\n"
    "Otherwise, provide the index numbers of the texts that contain each word.\n"
    "Output your responses in JSON format with the following keys: 'business,' 'sports,' 'world,' and 'sci/tech.'\n"
    "A formatted example output is provided below.\n"
    "{"
    "'business': [None/All or index numbers of texts containing 'Business'], "
    "'sports': [None/All or index numbers of texts containing 'Sports'], "
    "'world': [None/All or index numbers of texts containing 'World'], "
    "'sci/tech': [None/All or index numbers of texts containing 'Sci/Tech']"
    "}\n"
    "\n"
    "Texts, one per line:\n"
    "\n"
    "$texts\n"
    "\n"
    "JSON output:\n"
)

# Register both templates under the "0-shot" prompt mode, keyed by task name.
tasks = ["SelectOne", "SelectAll"]
promptTemplates = [SelectOne, SelectAll]

zero_shot_templates = database["promptTemplates"]["0-shot"]
for task, template in zip(tasks, promptTemplates):
    zero_shot_templates[task] = template

#### Test Prompts

- The main purpose is to check whether the LLM returns its answers in the requested output format for these prompts.

In [7]:
# Build a few dev-split prompts per task and task size for a quick output-format check.
num_instance = 2
taskSizes = [5, 10]

dev_frames = []
for prompt_mode in ["0-shot"]:
    for task in tasks:
        if task == "SingleClf":
            # Called without a task size here; only the first `num_instance` rows are kept.
            dev_frames.append(make_prompts_for_clf(database, task, "dev", prompt_mode)[:num_instance])
        else:
            for taskSize in taskSizes:
                dev_frames.append(
                    make_prompts_for_clf(
                        database, task, "dev", prompt_mode, taskSize,
                        attr="category",
                        label_attr_converter=lambda t: t,  # identity: labels are passed through unchanged
                        num_instance=num_instance,
                    )
                )

# Stack everything into one frame with a fresh 0..n-1 index.
dev = pd.concat(dev_frames).reset_index(drop=True)

In [8]:
# Print the prompts of the first task instance (taskIndex 1) at task size <= 5,
# each followed by a 50-dash rule and a blank line.
first_small_prompts = dev.loc[(dev["taskSize"] <= 5) & (dev["taskIndex"] == 1), "prompt"]
for prompt_text in first_small_prompts:
    print(prompt_text, "-" * 50, "", sep="\n")

In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'

Go over the 5 lines of text below and list the index numbers of the lines that contains the word 'Business' according to the following instructions:
If none of the texts contain the word 'Business', write 'None.'
If all the texts contain the word 'Business', write 'All.'
Otherwise, provide the index numbers of the texts that contain the word 'Business'.

Output your responses in JSON format with the key 'Business'.
A formatted example output is provided below.
{'Business': [None/All or index numbers of the texts containing 'Business']}

Texts, one per line:

1. World
2. Sci/Tech
3. Sci/Tech
4. Sports
5. World

JSON output:

--------------------------------------------------

In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'

Go over the 5 lines of text below and list the index numbers of the lines that contains the word 'Sci/

In [9]:
# Apply get_completion (from scripts.llm) to every dev prompt and store what it
# returns in a new "preds" column next to the prompt it came from.
completions = dev["prompt"].apply(get_completion)
dev["preds"] = completions
dev

Unnamed: 0,taskIndex,prompt,answer,targetLabel,task,#shot,CoT,taskSize,preds
0,1,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 5 lines of text below and...",{None},Business,SelectOne,0,False,5,"{\n ""Business"": ""None""\n}"
1,1,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 5 lines of text below and...","{2, 3}",Sci/Tech,SelectOne,0,False,5,"{\n ""Sci/Tech"": [2, 3]\n}"
2,1,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 5 lines of text below and...",{4},Sports,SelectOne,0,False,5,{'Sports': [4]}
3,1,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 5 lines of text below and...","{1, 5}",World,SelectOne,0,False,5,"{'World': [1, 5]}"
4,2,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 5 lines of text below and...","{3, 4, 5}",Business,SelectOne,0,False,5,"{\n ""Business"": [3, 4, 5]\n}"
5,2,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 5 lines of text below and...",{1},Sci/Tech,SelectOne,0,False,5,"{\n ""Sci/Tech"": [1]\n}"
6,2,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 5 lines of text below and...",{None},Sports,SelectOne,0,False,5,{'Sports': None}
7,2,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 5 lines of text below and...",{2},World,SelectOne,0,False,5,{'World': [2]}
8,1,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 10 lines of text below an...",{None},Business,SelectOne,0,False,10,{'Business': None}
9,1,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 10 lines of text below an...","{2, 3, 6, 7}",Sci/Tech,SelectOne,0,False,10,"{\n ""Sci/Tech"": [2, 3, 6, 7]\n}"


### Make prompts

In [10]:
# Build the full test-split prompt set for every task and task size, then save it as JSON Lines.
num_instance = 100
taskSizes = [5, 10, 20, 50, 100]

test_frames = []
for prompt_mode in ["0-shot"]:
    for task in tasks:
        if task == "SingleClf":
            # Called without a task size or instance count here.
            test_frames.append(make_prompts_for_clf(database, task, "test", prompt_mode))
        else:
            for taskSize in taskSizes:
                test_frames.append(
                    make_prompts_for_clf(
                        database, task, "test", prompt_mode, taskSize,
                        attr="category",
                        label_attr_converter=lambda t: t,  # identity: labels are passed through unchanged
                        num_instance=num_instance,
                    )
                )

# Stack everything into one frame with a fresh 0..n-1 index.
out = pd.concat(test_frames).reset_index(drop=True)

# Make sure the target folder exists, then write one JSON record per line.
os.makedirs("results/text classification/", exist_ok=True)
out.to_json(f"results/text classification/{clf_task}-simplified_index_selection_only.json", orient="records", lines=True)

In [11]:
# Number of generated prompts per task.
out["task"].value_counts()

SelectOne    2000
SelectAll     500
Name: task, dtype: int64

In [12]:
# Show the rows of the first task instance (taskIndex 1) at task size <= 5.
is_first_small = (out["taskIndex"] == 1) & (out["taskSize"] <= 5)
out.loc[is_first_small].copy()

Unnamed: 0,taskIndex,prompt,answer,targetLabel,task,#shot,CoT,taskSize
0,1,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 5 lines of text below and...",{3},Business,SelectOne,0,False,5
1,1,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 5 lines of text below and...",{None},Sci/Tech,SelectOne,0,False,5
2,1,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 5 lines of text below and...","{2, 4, 5}",Sports,SelectOne,0,False,5
3,1,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 5 lines of text below and...",{1},World,SelectOne,0,False,5
2000,1,"In this task, each line of text contains one of four words 'Business,' 'Sports,' 'World,' and 'Sci/Tech.'\n\nGo over the 5 lines of text below and...","{'business': {3}, 'sci/tech': {'None'}, 'sports': {2, 4, 5}, 'world': {1}}",,SelectAll,0,False,5
