In [122]:
import datasets
import os
import openai
import numpy as np
with open(os.path.expanduser('~/.openai_api_key'), 'r') as file:
    openai.api_key = file.read().replace('\n', '')

import adatest
import re
import json
import jsonlines
import seqio
import os
os.environ['CURL_CA_BUNDLE'] = "/etc/ssl/certs/ca-bundle.crt"
from bigbench.bbseqio import tasks
vocabulary=seqio.SentencePieceVocabulary("/gscratch/zlab/bparan/projects/cascades/models/t5-spiece.model")
from sklearn.metrics import accuracy_score
from typing import List
# from utils.constants import OPENAI_API_KEY

import tqdm

with open(os.path.expanduser('~/.openai_api_key'), 'r') as file:
    openai.api_key = file.read().replace('\n', '')
print(openai.api_key)

cache_dir = '/gscratch/zlab/bparan/projects/cascades/data'

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


### GPT-3 Model for prompting

In [2]:
class OpenAIModel(adatest.Model):
    def __init__(self, model="text-davinci-002", quote="\"", temperature=0.7, top_p=1, max_length=30, n=1):
        self.model = model
        self.api_key = openai.api_key
        self.quote = quote
        self.temperature = temperature
        self.top_p = top_p
        self.max_length = max_length
        self.n = n
    def __call__(self, strings):
        resp = openai.Completion.create(
            model=self.model,
            prompt=strings,
            max_tokens=self.max_length,
            temperature=self.temperature,
            top_p=self.top_p,
            n=self.n,
            stop=self.quote,
        )
        return [x["text"] for x in resp['choices']]

gpt3 = OpenAIModel(model="text-davinci-002",  max_length=200, quote='', n=1)


### Prompt to propose an instruction

In [4]:
def propose_decomposition(decomp_prompt, io_pairs, n=20):
    gpt3 = OpenAIModel(model="text-davinci-002",  max_length=400, quote='---', n=n)
    prompt = '''%s. Here are examples of input-output pairs for the task I'm trying to break down.
----
%s
----
Steps:
1.'''%(decomp_prompt, io_pairs)
    return gpt3(prompt)

In [3]:
def propose_instruction(instruct_prompt, io_pairs, n=20):
    gpt3 = OpenAIModel(model="text-davinci-002",  max_length=400, quote='---', n=n)
    prompt = '''%s. Here are examples of input-output pairs for this task.
----
%s
----
I can do this task by'''%(instruct_prompt, io_pairs)
    return gpt3(prompt)

### Automatic Decomposition Helper functions

In [15]:
def chunks(l, n):
    """Yield successive n-sized chunks from l."""
    for i in range(0, len(l), n):
        yield l[i:i + n]

In [16]:
def get_subset(inputs, labels, n=100):
    idxs = np.random.choice(len(inputs), n, replace=False)
    labs = np.array([labels[i] for i in idxs])
    subset = [inputs[i] for i in idxs]
    return labs, subset

# Tasks 

For each tasks, we compute:
* Best human decomposition performance over N runs: Known decomps or ones that we come up with. A further variant of this is (a) Decompositing into individual GPT-3 calls with few-shot prompting (decompositional prompting) and (b) Making and integrating external affordance calls when needed.
* Automatic instruction generation (APE): Reporting on top-K instructions. APE reports average over top-10 for 200 instructions. They also have an efficient score estimation technique whereby promising candidates (evaluated based on a small subset) are given more compute resource. 
* Automatic decomposition generation, followed by zero-shot application to downstream task. Reporting average performance over top-k decompositions
* 

Things to keep track of:
* Evaluation metric computation
* Generated sequence length 

#### Anachronisms

In [15]:
# Get data
d = datasets.load_dataset('bigbench', 'anachronisms')
inputs = d['train']['inputs'] + d['validation']['inputs']
inputs = [x.split('\n')[0] for x in inputs]
labels = np.array([int(x[0] == 'Yes') for x in d['train']['targets'] + d['validation']['targets']])



  0%|          | 0/3 [00:00<?, ?it/s]

In [16]:
# Human Decomp 
def anachronism(x):
    gpt3 = OpenAIModel(model="text-davinci-002",  max_length=200, quote='---', n=1)
    prompt = '''Given a sentence and the time periods of each entity in it, tell me if it could have happened or not.
Sentence: I wrote about shakespeare
Entities and dates:
I -> 21st century
Shakespeare -> 16th century
Could the sentence be true based on the dates alone: Yes
----
Sentence: Shakespeare wrote about me

Entities and dates:
Shakespeare -> 16th century
I -> 21st century

Could the sentence be true based on the dates alone: No
----
Sentence: %s''' % x
    return gpt3(prompt)

perf_array = []
runs = 2
for run in range(runs): 
    answers = []
    for x in inputs:
        answers.append(anachronism(x))
    preds = np.array([int(x[0].endswith('No')) for x in answers])
    perf_array.append((preds == labels).mean())
print("Human Performance:")
print("Mean", np.mean(perf_array))
print("Std. Dev", np.mean(perf_array))

Human Performance:
Mean 0.7826086956521738
Std. Dev 0.7826086956521738


In [25]:
# Automatic instruction runs.

instruct_prompt = 'I want to figure out whether a sentence contains anachronisms or not. An anachronism is a mistake in chronology, or a person, thing, or event that is out of its proper time.'
io_pairs = """Input: George Washington fought in the American Civil War.
Output: No
Input: The Mongolian horse rider used his bow to hunt the velociraptor.
Output: Yes
Input: Beats from the MPC3000 helped inspire many original blues artists.
Output: No
Input: Attila the Hun acted in the live-action remake of Mulan.
Output: Yes
Input: Kurt Cobain starred in the 1990 television show "Twin Peaks".
Output: Yes"""

instructions = propose_instruction(instruct_prompt, io_pairs, 50)

def get_anachronism_ape_fn(instruction, batch_size=10):
#     decomposition = '1.'+ decomposition
#     last_n = int(re.findall(r'(\d+)\.', decomposition)[-1])
#     decomposition += '\n%s. Output YES if there is an anachronism, and NO otherwise' % (last_n + 1)
    instruction = instruction.strip()
    def decomposition_ape_fn(sentences):
        gpt3 = OpenAIModel(model="text-davinci-002",  max_length=400, quote='---', n=1)
        out = []
        for chunk in chunks(sentences, batch_size):
            prompts = ['''An anachronism is a mistake in chronology, or a person, thing, or event that is out of its proper time. Figure out whether a sentence contains anachronisms or not, using this instruction.
Instruction:
%s
----
Sentence: %s
Is this an Anachronism? Output YES if there is an anachronism, and NO otherwise.''' % (instruction, x) for x in chunk]
            out.extend(gpt3(prompts))
        return out
    return decomposition_ape_fn

labs, subset = get_subset(inputs, labels, n=100)
all_preds = []
pps = []
accs = []
for z, instruction in enumerate(instructions):
    print('Instruction', z)
    fn = get_anachronism_ape_fn(instruction, batch_size=20)
    this_preds = fn(subset)
    pp = np.array([1 if 'yes' in x.lower() else 0 for x in this_preds])
    all_preds.append(this_preds)
    pps.append(pp)
    accs.append((pp==labs).mean())
    print((pp==labs).mean())
    


Instruction 0
0.63
Instruction 1
0.61
Instruction 2
0.5
Instruction 3
0.58
Instruction 4
0.63
Instruction 5
0.62
Instruction 6
0.62
Instruction 7
0.58
Instruction 8
0.66
Instruction 9
0.68


In [None]:
all_preds[0]

In [28]:
# Automatic decomposition runs

decomp_prompt = 'I want to break down the task of figuring out whether a sentence contains anachronisms or not, into individual steps. An anachronism is a mistake in chronology, or a person, thing, or event that is out of its proper time.'
decompositions = propose_decomposition(decomp_prompt, io_pairs, 10)

def get_anachronism_fn(decomposition, batch_size=10):
    decomposition = '1.'+ decomposition
    last_n = int(re.findall(r'(\d+)\.', decomposition)[-1])
#     decomposition += '\n%s. Output YES if there is an anachronism, and NO otherwise' % (last_n + 1)
    def decomposition_fn(sentences):
        gpt3 = OpenAIModel(model="text-davinci-002",  max_length=400, quote='---', n=1)
        out = []
        for chunk in chunks(sentences, batch_size):
            prompts = ['''Figure out whether a sentence contains anachronisms or not, using the following steps
Steps:
%s
----
Sentence: %s
Is this an Anachronism? Show me how you arrived at this answer step-wise. Output YES if there is an anachronism, and NO otherwise.''' % (decomposition, x) for x in chunk]
            out.extend(gpt3(prompts))
        return out
    return decomposition_fn


labs, subset = get_subset(inputs, labels, n=100)
preds = []
pps = []
accs = []
all_preds = []
for z, decomposition in enumerate(decompositions):
    print('Decomposition', z)
    fn = get_anachronism_fn(decomposition, batch_size=20)
    this_preds = fn(subset)
#     pp = np.array([1 if 'contains an anachronism' in x.lower() else 0 for x in this_preds])
    pp = np.array([1 if 'yes' in x.lower() else 0 for x in this_preds])
    preds.append(this_preds)
    pps.append(pp)
    accs.append((pp==labs).mean())
    print((pp==labs).mean())
    all_preds.append(this_preds)

Decomposition 0
0.55
Decomposition 1
0.66
Decomposition 2
0.59
Decomposition 3
0.57
Decomposition 4
0.54
Decomposition 5
0.64
Decomposition 6
0.6
Decomposition 7
0.59
Decomposition 8
0.66
Decomposition 9
0.59


#### Dataset from decomposed prompting (K'th letter concatenation)

In [41]:
# load data 
import urllib.request
url = 'https://raw.githubusercontent.com/allenai/DecomP/main/datasets/letter_cat/n5_eg100_pos2_space.json'
response = urllib.request.urlopen(url)
data = json.loads(response.read())
inputs = [d['question'] for d in data['1']['qa_pairs']]
labels = [d['answer']['spans'][0] for d in data['1']['qa_pairs']]
len(data['1']['qa_pairs'])
inputs[0]

'Take the letters at position 3 of the words in "Musa Haiying Schmidt Robinson Afzal" and concatenate them using a space.'

In [42]:
# manual decomposition
out = []
batch_size = 10
for chunk in tqdm.tqdm(chunks(inputs, batch_size)):
    prompts = [x for x in chunk]
#     print(prompts)
    out.extend(gpt3(prompts))
pp = np.array([1 if p.strip().lower() == l else 0 for p, l in zip(out, labels)])
pp.sum()/len(inputs)

10it [00:13,  1.36s/it]


0.0

#### Dataset from decomposed prompting (List reversal)

In [35]:
# load data 
import urllib.request
url = 'https://raw.githubusercontent.com/allenai/DecomP/main/datasets/reverse/test_10_normal_words.json'
response = urllib.request.urlopen(url)
data = json.loads(response.read())
inputs = [d['question'] for d in data['alg_qa']['qa_pairs']]
labels = [d['answer']['spans'][0] for d in data['alg_qa']['qa_pairs']]
# len(data['1']['qa_pairs'])
len(data['alg_qa']['qa_pairs'])


90

In [36]:
# manual decomposition
out = []
batch_size = 10
for chunk in tqdm.tqdm(chunks(inputs, batch_size)):
    prompts = [x for x in chunk]
#     print(prompts)
    out.extend(gpt3(prompts))
pp = np.array([1 if p.strip().lower() == l else 0 for p, l in zip(out, labels)])
pp.sum()/len(inputs)

0it [00:00, ?it/s]

['Reverse the sequence "banknote, sweet, phone card, identity card, credit card, case, passport, newspaper, painkiller, pen".', 'Reverse the sequence "bottle, passport, key, toothbrush, mobile phone, notebook, light bulb, tissue, packet, magazine".', 'Reverse the sequence "chewing gum, coin, driving licence, file, headphone, alarm clock, camera, rubbish, case, toothbrush".', 'Reverse the sequence "comb, sunscreen, key, postcard, packet, button, stamp, purse, photo, pen".', 'Reverse the sequence "rubber, banknote, watch, wallet, phone card, sweet, mirror, alarm clock, comb, tissue".', 'Reverse the sequence "magazine, tissue, headphone, stamp, file, banknote, passport, lipstick, diary, watch".', 'Reverse the sequence "glasses, rubbish, phone card, diary, wallet, tissue, laptop, toothbrush, battery, chewing gum".', 'Reverse the sequence "mirror, magazine, rubber, banknote, dictionary, case, pen, mobile phone, light bulb, tissue".', 'Reverse the sequence "comb, notebook, banknote, lipstick

1it [00:02,  2.07s/it]

['Reverse the sequence "key, sweet, credit card, watch, battery, chewing gum, player, sunscreen, passport, stamp".', 'Reverse the sequence "banknote, postcard, wallet, player, diary, identity card, clip, case, packet, magazine".', 'Reverse the sequence "laptop, chewing gum, painkiller, identity card, toothbrush, diary, scissors, bin, photo, sweet".', 'Reverse the sequence "sweet, lipstick, laptop, camera, water, alarm clock, driving licence, painkiller, identity card, tissue".', 'Reverse the sequence "diary, banknote, lipstick, stamp, phone card, laptop, battery, bottle, wallet, magazine".', 'Reverse the sequence "chewing gum, identity card, button, purse, brush, lighter, notebook, credit card, light bulb, photo".', 'Reverse the sequence "brush, water, match, passport, sweet, key, driving licence, laptop, sunscreen, umbrella".', 'Reverse the sequence "match, stamp, tissue, glasses, rubbish, magazine, coin, phone card, comb, key".', 'Reverse the sequence "light bulb, file, banknote, com

2it [00:04,  2.07s/it]

['Reverse the sequence "magazine, camera, pen, identity card, water, clip, button, purse, packet, headphone".', 'Reverse the sequence "phone card, sweet, player, newspaper, file, pencil, button, chewing gum, toothbrush, magazine".', 'Reverse the sequence "alarm clock, pencil, coin, chewing gum, player, scissors, wallet, magazine, cigarette, notebook".', 'Reverse the sequence "rubbish, diary, painkiller, toothbrush, pen, pencil, match, stamp, dictionary, lipstick".', 'Reverse the sequence "umbrella, sweet, mirror, light bulb, watch, mobile phone, pen, lighter, magazine, phone card".', 'Reverse the sequence "button, wallet, alarm clock, phone card, tissue, toothbrush, headphone, case, comb, bin".', 'Reverse the sequence "cigarette, laptop, banknote, umbrella, case, lighter, diary, headphone, toothbrush, coin".', 'Reverse the sequence "water, postcard, glasses, rubbish, comb, purse, rubber, wallet, magazine, laptop".', 'Reverse the sequence "glasses, magazine, light bulb, coin, bottle, wa

3it [00:06,  2.22s/it]

['Reverse the sequence "brush, postcard, coin, cigarette, match, lipstick, file, headphone, purse, water".', 'Reverse the sequence "mobile phone, battery, comb, magazine, chewing gum, headphone, lighter, rubber, alarm clock, passport".', 'Reverse the sequence "brush, notebook, water, credit card, button, dictionary, mirror, identity card, watch, camera".', 'Reverse the sequence "comb, camera, pen, postcard, light bulb, photo, pencil, coin, mirror, banknote".', 'Reverse the sequence "brush, magazine, rubbish, identity card, bin, player, headphone, phone card, notebook, watch".', 'Reverse the sequence "pencil, magazine, battery, passport, brush, watch, packet, notebook, clip, painkiller".', 'Reverse the sequence "scissors, file, wallet, rubbish, bottle, tissue, stamp, pen, coin, phone card".', 'Reverse the sequence "lighter, umbrella, match, scissors, bottle, key, newspaper, diary, mobile phone, stamp".', 'Reverse the sequence "stamp, light bulb, rubber, clip, banknote, player, pen, butt

4it [00:08,  2.00s/it]

['Reverse the sequence "alarm clock, scissors, postcard, rubber, comb, diary, bin, painkiller, notebook, sunscreen".', 'Reverse the sequence "rubber, bin, postcard, brush, pen, magazine, tissue, rubbish, passport, glasses".', 'Reverse the sequence "lighter, photo, postcard, laptop, banknote, coin, button, key, toothbrush, packet".', 'Reverse the sequence "pen, brush, painkiller, postcard, key, clip, umbrella, bin, water, purse".', 'Reverse the sequence "driving licence, credit card, watch, sweet, wallet, mirror, scissors, sunscreen, clip, alarm clock".', 'Reverse the sequence "photo, key, painkiller, dictionary, water, lighter, comb, rubbish, cigarette, coin".', 'Reverse the sequence "identity card, camera, rubbish, photo, watch, sunscreen, light bulb, file, coin, sweet".', 'Reverse the sequence "comb, sweet, notebook, dictionary, phone card, lipstick, pencil, clip, battery, packet".', 'Reverse the sequence "headphone, cigarette, key, purse, phone card, laptop, bin, button, wallet, wat

5it [00:09,  1.83s/it]

['Reverse the sequence "rubber, toothbrush, button, case, pencil, passport, diary, bottle, bin, notebook".', 'Reverse the sequence "toothbrush, match, brush, file, case, pencil, player, camera, laptop, wallet".', 'Reverse the sequence "diary, magazine, stamp, cigarette, player, light bulb, painkiller, phone card, coin, driving licence".', 'Reverse the sequence "pen, sunscreen, mirror, stamp, credit card, light bulb, chewing gum, dictionary, banknote, identity card".', 'Reverse the sequence "brush, credit card, driving licence, newspaper, rubbish, purse, mobile phone, scissors, bin, light bulb".', 'Reverse the sequence "magazine, mirror, stamp, match, player, pen, laptop, clip, photo, watch".', 'Reverse the sequence "chewing gum, pencil, button, camera, postcard, passport, dictionary, key, water, battery".', 'Reverse the sequence "file, clip, comb, water, key, pencil, postcard, bottle, diary, battery".', 'Reverse the sequence "case, mirror, stamp, credit card, dictionary, light bulb, he

6it [00:11,  1.80s/it]

['Reverse the sequence "alarm clock, toothbrush, rubbish, file, glasses, mirror, purse, scissors, light bulb, identity card".', 'Reverse the sequence "mirror, tissue, purse, driving licence, credit card, lighter, glasses, pencil, wallet, bin".', 'Reverse the sequence "sunscreen, wallet, diary, match, dictionary, glasses, comb, brush, mirror, credit card".', 'Reverse the sequence "headphone, wallet, file, mirror, identity card, alarm clock, rubbish, laptop, painkiller, tissue".', 'Reverse the sequence "coin, key, photo, light bulb, stamp, painkiller, clip, newspaper, scissors, purse".', 'Reverse the sequence "alarm clock, painkiller, tissue, rubbish, camera, newspaper, sunscreen, scissors, button, battery".', 'Reverse the sequence "umbrella, light bulb, photo, clip, phone card, alarm clock, battery, wallet, banknote, rubbish".', 'Reverse the sequence "bottle, identity card, alarm clock, laptop, rubbish, purse, watch, wallet, lighter, photo".', 'Reverse the sequence "camera, match, lipst

7it [00:13,  1.74s/it]

['Reverse the sequence "purse, battery, umbrella, magazine, notebook, chewing gum, mobile phone, pencil, file, stamp".', 'Reverse the sequence "bottle, clip, diary, camera, mirror, purse, cigarette, banknote, button, alarm clock".', 'Reverse the sequence "bin, driving licence, pen, passport, button, glasses, rubber, comb, notebook, tissue".', 'Reverse the sequence "magazine, packet, rubbish, clip, lighter, scissors, banknote, credit card, rubber, button".', 'Reverse the sequence "light bulb, sweet, bottle, scissors, alarm clock, newspaper, water, toothbrush, sunscreen, coin".', 'Reverse the sequence "pen, wallet, purse, watch, photo, scissors, light bulb, match, comb, coin".', 'Reverse the sequence "banknote, glasses, dictionary, painkiller, identity card, passport, bottle, light bulb, bin, chewing gum".', 'Reverse the sequence "driving licence, tissue, scissors, coin, dictionary, stamp, rubbish, laptop, pen, identity card".', 'Reverse the sequence "pencil, case, photo, postcard, lipst

8it [00:14,  1.74s/it]

['Reverse the sequence "watch, clip, file, phone card, painkiller, light bulb, camera, identity card, rubbish, sweet".', 'Reverse the sequence "packet, dictionary, driving licence, battery, photo, painkiller, laptop, mobile phone, purse, file".', 'Reverse the sequence "key, cigarette, umbrella, purse, identity card, clip, photo, headphone, bin, banknote".', 'Reverse the sequence "case, phone card, lighter, headphone, diary, comb, rubber, coin, stamp, toothbrush".', 'Reverse the sequence "newspaper, photo, postcard, watch, bin, coin, case, brush, sweet, key".', 'Reverse the sequence "purse, camera, banknote, lighter, diary, match, postcard, lipstick, phone card, tissue".', 'Reverse the sequence "case, bottle, sunscreen, light bulb, tissue, battery, stamp, glasses, mobile phone, cigarette".', 'Reverse the sequence "bottle, packet, water, painkiller, key, driving licence, cigarette, brush, banknote, identity card".', 'Reverse the sequence "pencil, laptop, painkiller, mobile phone, diction

9it [00:16,  1.81s/it]


0.6111111111111112

#### Tasks in Self-prompt

#### Tasks in Flipped learning - Known Unknown 

In [38]:
# load data
d = datasets.load_dataset('bigbench', 'known_unknowns', cache_dir=cache_dir)
inputs = d['train']['inputs'] + d['validation']['inputs']
# inputs = [x.split('\n')[0] for x in inputs]
labels = d['train']['targets'] + d['validation']['targets']



  0%|          | 0/3 [00:00<?, ?it/s]

#### Tasks in Flipped learning - Strategy QA

In [132]:
# load data
d = datasets.load_dataset('bigbench', 'strategyqa', cache_dir=cache_dir)
inputs = d['train']['inputs'] + d['validation']['inputs']
# inputs = [x.split('\n')[0] for x in inputs]
labels = d['train']['targets'] + d['validation']['targets']



  0%|          | 0/3 [00:00<?, ?it/s]

#### Tasks in Flipped learning - Strategy QA

#### Tasks in Auto-Cot - MAWPS 

In [105]:
data = datasets.load_dataset('omarxadel/MaWPS-ar', 'test', cache_dir=cache_dir)
inputs = [list(d.values())[0] for d in data['validation']]
labels = []
for d in data['validation']:
    try:
        ans = eval(list(d.values())[1].split("=")[-1].strip())
        if isinstance(ans, int):
            labels.append(ans)
        elif (ans).is_integer():
            labels.append(int(ans))
        else:
            labels.append(float("%.2f" % ans))
        
    except:
        ans = eval(list(d.values())[1].split("=")[0].strip())
        if isinstance(ans, int):
            labels.append(ans)
        elif (ans).is_integer():
            labels.append(int(ans))
        else:
            labels.append(float("%.2f" % ans))
    



  0%|          | 0/3 [00:00<?, ?it/s]

#### Tasks in Auto-CoT (GSM8K) 

In [112]:
data = datasets.load_dataset('gsm8k', 'main', cache_dir=cache_dir)['test']
inputs = [d['question'] for d in data]
labels = [d['answer'].split('#### ')[-1] for d in data]



  0%|          | 0/2 [00:00<?, ?it/s]

#### Tasks on Auto-CoT (AQUA-RAT)

In [120]:
data = datasets.load_dataset('aqua_rat', 'raw', cache_dir=cache_dir)['validation']
inputs = [d['question'] + " ".join(d['options']) for d in data]
labels = [d['correct'] for d in data]



  0%|          | 0/3 [00:00<?, ?it/s]

#### Tasks on Auto-CoT (Commonsense QA)

In [130]:
data = datasets.load_dataset('commonsense_qa', cache_dir=cache_dir)['validation']
inputs = [d['question']+ " " + " ".join([k + ") " + v for k, v in zip(d['choices']['label'], d['choices']['text'])]) for d in data]
labels = [d['answerKey'] for d in data]



  0%|          | 0/3 [00:00<?, ?it/s]

#### AMA Tasks (From Super-Glue they include boolQ, cb, copa, multirc, record, rte, wsc, WiC)

In [150]:
# BoolQ
data = datasets.load_dataset('super_glue', 'boolq', cache_dir=cache_dir)['validation']
inputs = [d['passage']+ " " + d['question'][0].title() + d['question'][1:]  + "?" for d in data]
label_dict = {0:'False', 1:'True'}
labels = [label_dict[d['label']] for d in data]
# Similar transformations to be made for other Superglue tasks: cb, copa, multirc, record, rte, wsc, wic



  0%|          | 0/3 [00:00<?, ?it/s]

#### AMA Tasks (From Adversarial NLI)

In [156]:
# Can also look at dev_r2, dev_r2
data = datasets.load_dataset('anli', cache_dir=cache_dir)['dev_r3']
inputs = ["Sentence1: " + d['premise'] + "\nSentence2: " +d['hypothesis'] for d in data]
label_dict = {0:"entailment", 1:'neutral', 2:'contradiction'}
labels = [label_dict[d['label']] for d in data]



  0%|          | 0/9 [00:00<?, ?it/s]

#### Other tasks from AMA (Classification) : Agnews, DBPedia, Amazon movie reivew and SST

In [164]:
## For others use the strings : dbpedia_14, sst2
data = datasets.load_dataset('ag_news', cache_dir=cache_dir)['test']
inputs = [d['text'] for d in data]
label_dict = {0:'World', 1:'Sports', 2:'Business', 3: 'Sci/Tech'}
labels = [label_dict[d['label']] for d in data]



  0%|          | 0/2 [00:00<?, ?it/s]

#### Tasks from reframing natural language instructions

In [165]:
labels

['Business',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'Sports',
 'Sports',
 'Sports',
 'Sports',
 'Sports',
 'Sports',
 'World',
 'World',
 'World',
 'World',
 'World',
 'World',
 'World',
 'World',
 'Sports',
 'Business',
 'World',
 'Sci/Tech',
 'Sports',
 'Sports',
 'World',
 'Sci/Tech',
 'World',
 'Sports',
 'World',
 'Sports',
 'World',
 'Sci/Tech',
 'Business',
 'Sci/Tech',
 'World',
 'World',
 'Business',
 'Business',
 'Sports',
 'Sports',
 'Sports',
 'Sci/Tech',
 'World',
 'Sci/Tech',
 'World',
 'World',
 'Sports',
 'World',
 'Sci/Tech',
 'Sci/Tech',
 'Sci/Tech',
 'World',
 'Sci/Tech',
 'Sports',
 'World',
 'Sports',
 'World',
 'World',
 'World',
 'Sports',
 'Business',
 'Business',
 'World',
 'Worl