This notebook generates a prompt prefix for every task. A prompt prefix consists of an instruction followed by demonstrations.

The instructions are obtained from [PromptSource](https://github.com/bigscience-workshop/promptsource) with the notebook `get_instructions.ipynb` and then filtered manually.

The demonstrations are sampled from the `dev` split.

# See statistics of each task

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from transformers import T5Tokenizer

# Task names grouped by task type; each name is a sub-directory of data/crossfit.
CLF_TASKS = [
    "emo", "emotion", "tweet_eval-emoji", "tweet_eval-emotion", "tweet_eval-hate", "tweet_eval-irony", "tweet_eval-offensive", 
    "tweet_eval-sentiment", "tweet_eval-stance_abortion", "tweet_eval-stance_atheism", "tweet_eval-stance_climate", 
    "tweet_eval-stance_feminist", "tweet_eval-stance_hillary", "climate_fever", "health_fact", "kilt_fever", "liar", "tab_fact", 
    "ethos-directed_vs_generalized", "ethos-disability", "ethos-gender", "ethos-national_origin", "ethos-race", "ethos-religion", 
    "ethos-sexual_orientation", "hate_speech_offensive", "hate_speech18", "hatexplain", "anli", "glue-mnli", "glue-qnli", 
    "glue-rte", "glue-wnli", "scitail", "sick", "superglue-cb", "superglue-rte", "ade_corpus_v2-classification", "circa", 
    "discovery", "glue-cola", "google_wellformed_query", "onestop_english", "scicite", "sms_spam", "superglue-wic", "superglue-wsc", 
    "trec", "trec-finegrained", "wiki_auto", "wiki_qa", "glue-mrpc", "glue-qqp", "medical_questions_pairs", "paws", 
    "amazon_polarity", "financial_phrasebank", "glue-sst2", "imdb", "poem_sentiment", "rotten_tomatoes", "yelp_polarity", "ag_news", 
    "dbpedia_14", "yahoo_answers_topics"]  # All 65 classification tasks.
QA_TASKS = [
    "boolq", "mc_taco", "freebase_qa", "jeopardy", "kilt_hotpotqa", "kilt_nq", "kilt_trex", "kilt_zsre", "lama-conceptnet", 
    "lama-google_re", "lama-squad", "lama-trex", "numer_sense", "search_qa", "squad-no_context", "web_questions", "eli5-askh", 
    "eli5-asks", "eli5-eli5", "adversarialqa", "biomrc", "duorc", "hotpot_qa", "quoref", "ropes", "squad-with_context", 
    "superglue-record", "tweet_qa", "ai2_arc", "aqua_rat", "codah", "commonsense_qa", "cosmos_qa", "dream", "hellaswag", "math_qa", 
    "openbookqa", "qasc", "quail", "quarel", "quartz-no_knowledge", "quartz-with_knowledge", "race-high", "race-middle", "sciq", 
    "social_i_qa", "superglue-copa", "superglue-multirc", "swag", "wino_grande", "wiqa"]  # All 51 QA tasks.
CG_TASKS = [
    "empathetic_dialogues", "kilt_wow", "spider", "wiki_bio", "wiki_split", "wikisql", "aeslc", "gigaword", "multi_news", 
    "reddit_tifu-title", "reddit_tifu-tldr", "samsum", "xsum"]  # All 13 CG tasks.
OTHER_TASKS = [
    "acronym_identification", "art", "common_gen", "crawl_domain", "crows_pairs", "definite_pronoun_resolution", "e2e_nlg_cleaned",
    "limit", "piqa", "proto_qa", "qa_srl", "cos_e", "blimp-anaphor_gender_agreement", "blimp-anaphor_number_agreement",
    "blimp-determiner_noun_agreement_with_adj_irregular_1", "blimp-ellipsis_n_bar_1", "blimp-ellipsis_n_bar_2",
    "blimp-existential_there_quantifiers_1", "blimp-irregular_past_participle_adjectives",
    "blimp-sentential_negation_npi_licensor_present", "blimp-sentential_negation_npi_scope", "blimp-wh_questions_object_gap",
    "app_reviews", "mocha", "yelp_review_full", "ade_corpus_v2-dosage", "ade_corpus_v2-effect"
]  # 27 other tasks; 4 tasks without instructions (aslg_pc12, break-QDMR, break-QDMR-high-level, kilt_ay2) are omitted.
TASK_NAMES = CG_TASKS  # Task group whose statistics the next cell computes; swap in another list to inspect it.
T5_MODEL = "t5-base"  # Tokenizer checkpoint used to measure input/target lengths (and, later, prompt lengths).
MAX_INPUT_LEN = 1024  # Model input budget in tokens; passed to the tokenizer as model_max_length.


def get_task_prefixes(data_path: str, task_name: str) -> list:
    """Collect the distinct file prefixes (e.g. ``adversarialqa_32_13``) of a task.

    Every ``.tsv`` file in ``<data_path>/<task_name>`` is reduced to the part of its
    name before the last underscore. Duplicates are dropped, keeping the order of
    first appearance in the sorted directory listing.
    """
    task_dir = os.path.join(data_path, task_name)
    tsv_names = [name for name in sorted(os.listdir(task_dir)) if name.endswith(".tsv")]
    # rpartition returns "" for names without "_", exactly like joining split("_")[:-1].
    # dict.fromkeys keeps insertion order, so this is an order-preserving dedup.
    return list(dict.fromkeys(name.rpartition("_")[0] for name in tsv_names))

def get_all_examples(task_name: str, data_path: str = "data/crossfit") -> tuple:
    """Load every example of a task from its train/dev/test TSV files.

    Only the first prefix returned by ``get_task_prefixes`` is used, so all three
    splits come from a single ``<prefix>_{train,dev,test}.tsv`` triple.

    Args:
        task_name: Task directory name under ``data_path``.
        data_path: Root directory holding one sub-directory per task.

    Returns:
        An ``(examples, count)`` tuple. ``examples`` is a list of ``[input, outputs]``
        pairs: ``input`` is the first tab-separated field of a line and ``outputs`` is
        the list of the remaining fields. ``count`` maps each split name to its
        number of lines.

    Raises:
        FileNotFoundError: if the task directory contains no ``.tsv`` files.
    """
    prefixes = get_task_prefixes(data_path, task_name)
    if not prefixes:
        raise FileNotFoundError(f"No .tsv files found for task {task_name!r} under {data_path!r}")
    prefix = prefixes[0]

    examples = []
    count = {}
    for split in ["train", "dev", "test"]:
        path = os.path.join(data_path, task_name, f"{prefix}_{split}.tsv")
        # Explicit encoding: the data contains non-ASCII text, and the platform default may not be UTF-8.
        with open(path, encoding="utf-8") as fin:
            lines = fin.readlines()
        for line in lines:
            fields = line.strip().split("\t")
            examples.append([fields[0], fields[1:]])
        count[split] = len(lines)
    return examples, count

# Tokenize every example of each task and summarize the input lengths (in tokens).
tokenizer = T5Tokenizer.from_pretrained(T5_MODEL, model_max_length=MAX_INPUT_LEN)

stats_rows = []
for task in TASK_NAMES:
    task_examples, split_sizes = get_all_examples(task)
    input_ids = tokenizer([inp for inp, _ in task_examples])["input_ids"]
    # An example may carry several reference outputs; all of them are tokenized.
    target_ids = tokenizer([out for _, outs in task_examples for out in outs])["input_ids"]
    input_lens = [len(ids) for ids in input_ids]
    longest_target = np.max([len(ids) for ids in target_ids])

    stats_rows.append([
        task,
        len(task_examples),
        split_sizes["train"],
        split_sizes["dev"],
        split_sizes["test"],
        longest_target,
        np.min(input_lens),
        np.max(input_lens),
        np.percentile(input_lens, 25),
        np.percentile(input_lens, 50),
        np.percentile(input_lens, 75),
        input_lens,  # Raw per-example lengths, kept for histograms.
    ])

stats_df = pd.DataFrame(stats_rows,
                        columns=["task_name", "n_examples", "n_train", "n_dev", "n_test", "max_target_len",
                                 "min_len", "max_len", "percentile25", "percentile50", "percentile75", "all_lengths"])

# Show everything except the raw length lists, which are too wide to display.
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(stats_df.drop(columns=["all_lengths"]))

Removed tasks:
* 7 classification tasks:
  * `amazon_polarity`, `yahoo_answers_topics`, `yelp_polarity` (too long for k=8)
  * `tab_fact`, `onestop_english`, `imdb` (too long even for k=4)
  * `tweet_eval-emoji` (T5 cannot recognize emojis)
* 6 QA tasks:
  * `biomrc`, `duorc`, `quoref`, `quail`, `race-high`, `superglue-multirc` (too long for k=3)
* 4 CG tasks:
  * `multi_news`, `reddit_tifu-title`, `reddit_tifu-tldr`, `xsum` (too long for k=3)
* 4 other tasks:
  * `aslg_pc12`, `break-QDMR`, `break-QDMR-high-level`, `kilt_ay2` (no instructions available, and the task formats are unclear)

In [None]:
# Tasks whose 75th-percentile input length exceeds 120 tokens (raw length lists hidden).
stats_df.loc[stats_df["percentile75"] > 120].drop(columns="all_lengths")

In [None]:
def plot_hist(task_name, n_bins=40, df=None):
    """Plot a histogram of the tokenized input lengths of one task.

    Args:
        task_name: Task whose row in the statistics frame should be plotted.
        n_bins: Number of histogram bins.
        df: Statistics frame with ``task_name`` and ``all_lengths`` columns;
            defaults to the global ``stats_df`` built by the statistics cell.

    Raises:
        KeyError: if ``task_name`` has no row in the statistics frame.
    """
    if df is None:
        df = stats_df
    task_rows = df[df.task_name == task_name]
    if task_rows.empty:
        raise KeyError(f"No statistics found for task {task_name!r}")
    fig, ax = plt.subplots()
    ax.hist(task_rows.iloc[0]["all_lengths"], bins=n_bins)
    ax.set(title=f"Input length distribution: {task_name}",
           xlabel="Input length (tokens)", ylabel="Number of examples")
    plt.show()

plot_hist("dbpedia_14")

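# Generate prompt prefix

The number of demonstrations $k$ (`K` in the next cell) is:

* $k = 3$ for QA tasks, `mocha`, and `yelp_review_full`
* $k = 8$ for the rest of the tasks

**TODO:** the next cell uses `K = 3` for the CG task `wiki_bio`, and the CG removals above are justified with k=3. Confirm whether CG tasks belong in the $k = 3$ group.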

In [None]:
# Configuration for prompt generation.
TASK_NAMES = ["wiki_bio"]
# NOTE: `tokenizer` (built in the statistics cell) is reused below. Reassigning these two
# values here does not rebuild it, so keep them equal to the values used there.
T5_MODEL = "t5-base"
MAX_INPUT_LEN = 1024
K = 3  # Number of demonstrations.
INSTRUCTIONS_FILE = "data/prompt/instructions_iosep.tsv"
OUTPUT_FILE = "data/prompt/prompt2.tsv"


# Read instructions. Each non-empty line is: task_name \t instruction \t input_output_separator.
INSTRUCTIONS_DICT = {}
with open(INSTRUCTIONS_FILE, encoding="utf-8") as fin:
    for line_no, line in enumerate(fin, start=1):
        if not line.strip():
            continue  # Tolerate blank lines instead of failing on a missing field.
        splits = line.strip().split("\t")
        if len(splits) < 3:
            raise ValueError(
                f"{INSTRUCTIONS_FILE}:{line_no}: expected 3 tab-separated fields "
                f"(task_name, instruction, io_separator), got {len(splits)}")
        INSTRUCTIONS_DICT[splits[0]] = splits[1], splits[2]

In [None]:
import random
random.seed(0)  # Fixed seed so demonstration sampling is reproducible when the cell is re-run.

# One row per (task, file prefix): the prompt, its length in tokens, and the I/O separator.
prompt_rows = []
for task_name in TASK_NAMES:
    prefixes = get_task_prefixes("data/crossfit", task_name)
    # The instruction and separator depend only on the task, so look them up once per task.
    instructions, iosep = INSTRUCTIONS_DICT[task_name]
    for prefix in prefixes:
        # Load dev examples as [input, outputs] pairs; demonstrations are drawn only from dev.
        dev_path = os.path.join("data/crossfit", task_name, prefix + "_dev.tsv")
        with open(dev_path, encoding="utf-8") as fin:
            lines = fin.readlines()
        dev_examples = []
        for line in lines:
            d = line.strip().split("\t")
            dev_examples.append([d[0], d[1:]])
        if len(dev_examples) < K:
            raise ValueError(
                f"{dev_path} has {len(dev_examples)} examples, fewer than K={K} demonstrations")

        # Each demo is rendered as "<input> <iosep> <one randomly chosen reference output>".
        demos = random.sample(dev_examples, K)
        demos_text = " ".join(["{} {} {}".format(ex[0], iosep, random.choice(ex[1])) for ex in demos])
        prompt = instructions + " " + demos_text

        prompt_rows.append([
            task_name, prefix, prompt, len(tokenizer(prompt)["input_ids"]), iosep
        ])

prompt_df = pd.DataFrame(
    prompt_rows, columns=["task_name", "task_prefix", "prompt", "prompt_len", "io_sep"])
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(prompt_df)

In [None]:
# We want at least 75% of the examples have length < 1024.
# Let's check the prompts which are too long. We can retry generating prompt for these tasks.
# We want at least 75% of each task's examples to fit within MAX_INPUT_LEN tokens once the
# prompt prefix is added. Flag prompts for which that fails, so they can be regenerated.
def f(row):
    """Return True if the row's prompt plus its task's 75th-percentile input length exceeds MAX_INPUT_LEN.

    Raises:
        KeyError: if ``row.task_name`` has no entry in ``stats_df``. The statistics cell and
            the prompt cell use independent TASK_NAMES lists, so this can happen.
    """
    task_stats = stats_df[stats_df.task_name == row.task_name]
    if task_stats.empty:
        raise KeyError(
            f"No length statistics for task {row.task_name!r}; include it in the statistics cell's TASK_NAMES.")
    p75_input_len = task_stats.iloc[0]["percentile75"]
    return row.prompt_len + p75_input_len > MAX_INPUT_LEN

prompt_df.loc[prompt_df.apply(f, axis=1)]

In [None]:
# Save results.
# Save as a header-less, index-less TSV: task_name, task_prefix, prompt, prompt_len, io_sep.
# NOTE(review): under pandas' default quoting, fields containing '"' are wrapped in quotes.
# Confirm that the reader of this file handles CSV-style quoting rather than plain tab splitting.
prompt_df.to_csv(OUTPUT_FILE, sep="\t", header=False, index=False)