In [1]:
import argparse
import os
import csv
import json
import random
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset, Dataset, DatasetDict
from transformers import logging
from generate_utility import *
from utility import *
from transformers.generation import GenerationConfig
from peft import PeftModel
# import bitsandbytes as bnb

logging.set_verbosity_error()

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint = "../models/aya-101"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to("cuda")

Loading checkpoint shards: 100%|██████████| 11/11 [01:42<00:00,  9.30s/it]


In [5]:
from sentence_transformers import SentenceTransformer, util
sent_model = SentenceTransformer("all-MiniLM-L6-v2").to("cuda")

In [14]:
sentences = [
    "The cat sits outside",
    "A man is playing guitar",
    "I love pasta",
    "The new movie is awesome",
    "The cat plays in the garden",
    "A woman watches TV",
    "The new movie is so great",
    "Do you like pizza?",
]

paraphrases = util.paraphrase_mining(sent_model, sentences, top_k=3, batch_size=4)

for paraphrase in paraphrases:
    score, i, j = paraphrase
    print("{} \t\t {} \t\t Score: {:.4f}".format(sentences[i], sentences[j], score))

The new movie is awesome 		 The new movie is so great 		 Score: 0.8939
The cat sits outside 		 The cat plays in the garden 		 Score: 0.6788
I love pasta 		 Do you like pizza? 		 Score: 0.5096
I love pasta 		 The new movie is so great 		 Score: 0.2560
I love pasta 		 The new movie is awesome 		 Score: 0.2440
A man is playing guitar 		 The cat plays in the garden 		 Score: 0.2105
The new movie is awesome 		 Do you like pizza? 		 Score: 0.1969
The new movie is so great 		 Do you like pizza? 		 Score: 0.1692
The cat sits outside 		 A woman watches TV 		 Score: 0.1310
The cat plays in the garden 		 Do you like pizza? 		 Score: 0.0900
The cat plays in the garden 		 A woman watches TV 		 Score: 0.0629
A woman watches TV 		 Do you like pizza? 		 Score: 0.0417
The cat sits outside 		 A man is playing guitar 		 Score: 0.0363
A man is playing guitar 		 Do you like pizza? 		 Score: 0.0116


In [11]:
fs_examp=get_few_shot_examples(dataset['train'])
fs_prompt=construct_prompt(fs_examp)

Filter: 100%|██████████| 400/400 [00:00<00:00, 7625.52 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 79732.04 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 78266.54 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 79246.21 examples/s]


In [3]:
def generate_result(prompts,gen_config,model_name='aya',bs=8):
    all_response=[]
    all_response_raw=[]
    end=len(prompts)
    for start in tqdm(range(0,end,bs)):
        stop=min(start+bs,len(prompts))
        if start<stop:
            prompts_batch=prompts[start:stop]
            encodings=tokenizer(prompts_batch, return_tensors="pt", padding='longest', truncation=False).to("cuda")
            with torch.no_grad():
                output_ids = model.generate(**encodings, **gen_config)
            responses=tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            for i,response_raw in enumerate(responses):
                sample_no=i+start
                if model_name!='aya':
                    response=response_raw[len(prompts[sample_no]):]
                    response=response.split("\n")[0].strip() if "\n" in response else response.strip()
                else:
                    response=response_raw[-1]
                all_response.append(response)
                all_response_raw.append(response_raw)
                
    return all_response_raw,all_response

In [239]:
def get_similar_sentences(corpus, queries,k=6):
    top_k = min(k, len(corpus))
    query_embeddings = sent_model.encode(queries, convert_to_tensor=True).to("cuda")
    corpus_embeddings = sent_model.encode(corpus, convert_to_tensor=True).to("cuda")
    corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
    query_embeddings = util.normalize_embeddings(query_embeddings)
    hits = util.semantic_search(query_embeddings, corpus_embeddings, top_k=top_k,
                                score_function=util.dot_score)
    all_hits={}
    for i,hit in enumerate(hits):
        all_hits[i]=[]
        for j in range(0,top_k,1):
            all_hits[i].append(hits[i][j]['corpus_id'])
    return all_hits


def get_data_simset(train_data,val_data):
    new_ds=[]
    new_ds_simset={}
    count=0
    for q in ['cause','effect']:
        q_examples_corpus = train_data.filter(lambda example:  
                                                 example["question"]==q)
        q_examples_query = val_data.filter(lambda example:  
                                                 example["question"]==q)
        print(len(q_examples_query))
        corpus=q_examples_corpus['premise']
        query=q_examples_query['premise']
        x=get_similar_sentences(corpus,query)
        for i,corp in enumerate(q_examples_query):
            new_ds.append(q_examples_query[i])
            # print(query[i],q_examples_query[i])
            new_ds_simset[count]=[]
            for j in x[i]:
                new_ds_simset[count].append(q_examples_corpus[j])
                # print(i,corpus[j],q_examples_corpus[j],end=',')
            # print('\n')
            count+=1
    return new_ds,new_ds_simset

In [240]:
gen_config = {
                "temperature": 0.7,
                "top_p": 0.1,
                "repetition_penalty": 1.18,
                "top_k": 40,
                "do_sample": True,
                "max_new_tokens": 5,
                "pad_token_id": tokenizer.eos_token_id
                    }

dataset=load_datasets("copa-en")

tokenizer.pad_token_id = tokenizer.eos_token_id

# all_prompt_examp={'train':[],'val':[]}
# all_prompt_labels={'train':[],'val':[]}

# for split in ['train','val']:
#     for row in dataset[split]:
#         fs_examp = get_similar_fewshot()
#         fs_prompt=construct_prompt(fs_examp)
        # prompt=(fs_prompt + "\n\n" + prompt_template.format(**row, correct_answer="")).strip()
        # all_prompt_examp[split].append(prompt)
        # all_prompt_labels[split].append(row['label'])

train size: 400
val size: 100


In [272]:
dataset['val']

Dataset({
    features: ['premise', 'choice1', 'choice2', 'question', 'label', 'idx'],
    num_rows: 100
})

In [273]:
new_ds,new_ds_simset = get_data_simset(dataset['train'],dataset['val'])

all_test_prompt=[]
all_test_label=[]
for i,d in enumerate(new_ds):
    # if i==1:
    #     break
    fs_prompt=test_construct_Prompt(new_ds_simset[i],6)
    prompt_ex=construct_single(d,fs_prompt)
    all_test_prompt.append(prompt_ex)
    all_test_label.append(d['label'])

100%|██████████| 1/1 [00:00<00:00, 144.52ba/s]
100%|██████████| 1/1 [00:00<00:00, 522.78ba/s]


52


100%|██████████| 1/1 [00:00<00:00, 196.46ba/s]
100%|██████████| 1/1 [00:00<00:00, 545.14ba/s]


48


In [274]:
print(all_test_prompt[93])

You are a helpful assistant whose goal is to select the correct output for a given instruction in english.


Instruction: Given the premise, ""I pushed the pendulum."", What is the correct effect after this?
A: It slowed to a stop.
B: It swung back and forth.
Correct effect: B

Instruction: Given the premise, ""I flipped the light switch up and down."", What is the correct effect after this?
A: The light faded.
B: The light flickered.
Correct effect: B

Instruction: Given the premise, ""I tipped the bottle."", What is the correct effect after this?
A: The liquid in the bottle froze.
B: The liquid in the bottle poured out.
Correct effect: B

Instruction: Given the premise, ""I wanted to lighten the mood of the conversation."", What is the correct effect after this?
A: I remained quiet.
B: I told a joke.
Correct effect: B

Instruction: Given the premise, ""The room was dim."", What is the correct effect after this?
A: I opened the blinds.
B: I washed the windows.
Correct effect: A

Instr

In [275]:
# all_test_label

In [276]:
all_response_raw,all_response=generate_result(all_test_prompt,
                                              gen_config,'aya')

100%|██████████| 13/13 [00:29<00:00,  2.24s/it]


In [277]:
eval(all_response,all_test_label)

0.85 85 100 100 [3, 4, 8, 21, 23, 27, 32, 35, 39, 41, 45, 68, 69, 73, 93] ['B', 'B', 'B', 'B', 'B', 'B', 'B', 'B', 'A', 'B', 'B', 'B', 'B', 'B', 'B']


0.85

In [18]:
# import pickle

# split='train'
# with open('../result/test_result.pickle','wb') as f:
#     pickle.dump(all_response,f)

# with open('../result/test_label.pickle','wb') as f:
#     pickle.dump(all_prompt_labels[split],f)

# with open('../result/test_prompt.pickle','wb') as f:
#     pickle.dump(all_prompt_examp[split],f)

In [20]:
all_not_true={}
count=0
for i,row in enumerate(dataset['train']):
    if all_response[i] in choices:
        if choices.index(all_response[i])!=all_prompt_labels['train'][i]:
            count+=1
            all_not_true[i]={
                'original':row,
                'test_label':all_prompt_labels['train'][i],
                'pred_label':choices.index(all_response[i])
            }

In [21]:
# fs_examp=get_few_shot_examples(dataset['train'])

In [271]:
lang='english'
choices=["A","B"]

preamble = f"""You are a helpful assistant whose goal is to select the correct output for a given instruction in {lang}."""

# preamble =""""""
# ---46%
# prompt_template_cause="""Instruction: Given the premise, ""{premise}"", What is the correct {question}?
# {question} A: {choice1}
# {question} B: {choice2}
# Correct {question}: {correct_answer}"""

# prompt_template_effect="""Instruction: Given the premise, ""{premise}"", What is the correct {question}?
# {question} A: {choice1}
# {question} B: {choice2}
# Correct {question}: {correct_answer}"""

# ---51%
prompt_template_cause="""Instruction: Given the premise, ""{premise}"", What is the correct {question} before this?
A: {choice1}
B: {choice2}
Correct {question}: {correct_answer}"""

prompt_template_effect="""Instruction: Given the premise, ""{premise}"", What is the correct {question} after this?
A: {choice1}
B: {choice2}
Correct {question}: {correct_answer}"""



In [227]:
def get_few_shot_examples(dataset, question,fs_per_label=2, seed=42):
    labels = list(set(dataset["label"]))
    few_shot_examples = []
    for label in labels:
        label_examples = dataset.filter(lambda example: example["label"] == label and example["question"]==question)
        # shuffle the examples
        label_examples = label_examples.shuffle(seed=seed)
        # get the first fs_per_label examples
        label_examples = label_examples.select(
            range(min(fs_per_label, len(label_examples)))
        )
        few_shot_examples += [example for example in label_examples]

    # Shuffle the few shot examples
    random.shuffle(few_shot_examples)
    return few_shot_examples

def test_construct_Prompt(ds_examples,min_ex=2):
    ds_examples=ds_examples[:min_ex]
    prompt_examples = "\n\n".join([ prompt_template_cause.format(**d,correct_answer=choices[int(d["label"])]) 
                                   if d["question"]=='cause' 
                                   else prompt_template_effect.format(**d,correct_answer=choices[int(d["label"])])
                                   for d in ds_examples])
    prompt_examples=preamble+"\n\n\n"+prompt_examples
    return prompt_examples

def construct_single(row,fs_prompt):
    if row['question']=='cause':
        prompt=(fs_prompt + "\n\n" + prompt_template_cause.format(**row, correct_answer="")).strip()
        # prompt=(prompt_template_cause.format(**row, correct_answer="")).strip()
    else:
        prompt=(fs_prompt + "\n\n" + prompt_template_effect.format(**row, correct_answer="")).strip()
        # prompt=( prompt_template_cause.format(**row, correct_answer="")).strip()
    return prompt

def eval(all_preds,all_true_labels):
    count=0
    ind_true=[]
    not_true=[]
    indx=[]
    for i,res in enumerate(all_preds):
        if res in choices:
            if choices.index(res)==all_true_labels[i]:
                count+=1
                ind_true.append(i)
            else:
                not_true.append(i)
                indx.append(res)
    acc=count/len(all_preds)
    print(acc, count, len(all_preds), len(all_true_labels), not_true,indx)
    return acc

In [57]:
print(all_response_raw_test)

['effect B', 'cause B', 'cause B', 'cause A', 'effect B', 'effect B', 'cause A', 'effect B', 'cause B', 'effect B', 'effect A', 'cause A', 'cause B', 'cause B', 'cause B', 'cause B', 'effect B', 'effect A', 'effect B', 'effect B', 'cause A', 'effect B', 'cause A', 'effect B', 'cause B', 'cause B', 'effect A', 'effect B', 'effect A', 'cause A', 'cause B', 'cause B', 'cause A', 'effect B', 'cause B', 'cause A', 'cause A', 'cause A', 'cause B', 'cause B', 'effect A', 'cause A', 'cause A', 'cause A', 'cause B', 'cause A', 'effect B', 'effect B: ""The', 'cause A', 'cause B', 'effect B', 'effect A', 'cause B', 'cause B', 'effect B']


In [58]:
all_acc=[]


for ii in range(0,4,1):
    all_test_prompt=[]
    all_test_label=[]
    fs_examp_cause=get_few_shot_examples(dataset['train'],'cause',fs_per_label=2,seed=ii)
    fs_examp_effect=get_few_shot_examples(dataset['train'],'effect',fs_per_label=2,seed=ii)
    fs_prompt_cause=test_construct_Prompt(fs_examp_cause)
    fs_prompt_effect=test_construct_Prompt(fs_examp_effect)

    count=0
    for i,row in all_not_true.items():
        if row['original']['question']=='casue':
            fs_prompt=fs_prompt_cause
        else:
            fs_prompt=fs_prompt_effect
        prompt_ex=construct_single(row['original'],fs_prompt)
        all_test_prompt.append(prompt_ex)
        all_test_label.append(row['original']['label'])
        count+=1
    all_response_raw_test,all_response_test=generate_result(all_test_prompt,
                                                  gen_config,'aya')

    acc=eval(all_response_test,all_test_label)
    all_acc.append(acc)
print(all_acc)
print(sum(all_acc)/len(all_acc))

Filter: 100%|██████████| 400/400 [00:00<00:00, 68152.97 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 77417.82 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 75556.03 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 80489.43 examples/s]
100%|██████████| 7/7 [00:12<00:00,  1.79s/it]


0.2909090909090909 16 55 55 [0, 4, 17, 22, 23, 24, 28, 29, 30, 41, 42, 43, 45, 48, 52, 53]


Filter: 100%|██████████| 400/400 [00:00<00:00, 66809.56 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 79509.10 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 79773.74 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 77475.02 examples/s]
100%|██████████| 7/7 [00:12<00:00,  1.85s/it]


0.4 22 55 55 [0, 4, 7, 9, 18, 22, 24, 26, 28, 29, 30, 31, 35, 41, 42, 43, 45, 48, 50, 52, 53, 54]


Filter: 100%|██████████| 400/400 [00:00<00:00, 66639.72 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 73497.24 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 80358.35 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 80975.03 examples/s]
100%|██████████| 7/7 [00:12<00:00,  1.84s/it]


0.41818181818181815 23 55 55 [0, 7, 9, 18, 19, 22, 24, 26, 28, 29, 30, 31, 35, 41, 42, 43, 45, 48, 50, 51, 52, 53, 54]


Filter: 100%|██████████| 400/400 [00:00<00:00, 55847.73 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 80051.61 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 78055.35 examples/s]
Filter: 100%|██████████| 400/400 [00:00<00:00, 77539.47 examples/s]
100%|██████████| 7/7 [00:12<00:00,  1.79s/it]

0.38181818181818183 21 55 55 [0, 7, 9, 17, 18, 22, 23, 24, 26, 28, 29, 30, 41, 42, 43, 45, 48, 50, 52, 53, 54]
[0.2909090909090909, 0.4, 0.41818181818181815, 0.38181818181818183]
0.37272727272727274





In [35]:
all_c=[0, 9, 10, 22, 24, 26, 28, 29, 30, 41, 42, 43, 45, 48, 50, 52, 53, 54]
for i in all_c:
    print(all_test_prompt[i][-10:])

ct effect:
ct effect:
ct effect:
ect cause:
ect cause:
ct effect:
ct effect:
ect cause:
ect cause:
ect cause:
ect cause:
ect cause:
ect cause:
ect cause:
ct effect:
ect cause:
ect cause:
ct effect:


In [36]:
i=0
print(all_test_prompt[i],all_test_label[i],all_response_test[i],)

You are a helpful assistant whose goal is to select the correct output for a given instruction in english.


Instruction: Given the premise, ""The host served dinner to his guests."", What is the correct effect after this?
effect A: ""His guests were gracious.""
effect B: ""His guests went hungry.""
Correct effect: A

Instruction: Given the premise, ""My foot went numb."", What is the correct effect after this?
effect A: ""I put my shoes on.""
effect B: ""I shook my foot.""
Correct effect: B

Instruction: Given the premise, ""The teacher caught the student chewing gum."", What is the correct effect after this?
effect A: ""The gum stuck to the student's shoe.""
effect B: ""The student spit out the gum.""
Correct effect: B

Instruction: Given the premise, ""The man lifted the heavy box."", What is the correct effect after this?
effect A: ""He put out his back.""
effect B: ""He scratched his back.""
Correct effect: A

Instruction: Given the premise, ""The elderly woman suffered a stroke."

In [10]:
import argparse
import os
import csv
import json
import random
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset, Dataset, DatasetDict,concatenate_datasets
from utility import *



In [27]:
def get_all_fs_general(dataset, count=4,seed=42):

    fs_examps=[]
    for q in ['cause','effect']:
        data_q = train_data.filter(lambda example:  
                                                 example["question"]==q)
        for i in range(0,len(data_q),1):
            start=i
            end=i+count
            if end>len(data_q):
                borrow=end-len(data_q)
                indices=list(range(start,len(data_q)))+list(range(0,borrow))
            else:
                indices=list(range(start,end))
            fs_examps.append(dataset.select(indices))
    return fs_examps  

def get_similar_sentences(corpus, queries,k=4):
    top_k = min(k, len(corpus))
    query_embeddings = sent_model.encode(queries, convert_to_tensor=True).to("cuda")
    corpus_embeddings = sent_model.encode(corpus, convert_to_tensor=True).to("cuda")
    corpus_embeddings = util.normalize_embeddings(corpus_embeddings)
    query_embeddings = util.normalize_embeddings(query_embeddings)
    hits = util.semantic_search(query_embeddings, corpus_embeddings, top_k=top_k,
                                score_function=util.dot_score)
    all_hits={}
    for i,hit in enumerate(hits):
        all_hits[i]=[]
        for j in range(0,top_k,1):
            all_hits[i].append(hits[i][j]['corpus_id'])
    return all_hits

def get_all_fs_similar(train_data):
    fs_examps=[]
    count=0
    for q in ['cause','effect']:
        q_examples_corpus = train_data.filter(lambda example:  
                                                 example["question"]==q)
        q_examples_query = train_data.filter(lambda example:  
                                                 example["question"]==q)
        print(len(q_examples_query))
        corpus=q_examples_corpus['premise']
        query=q_examples_query['premise']
        x=get_similar_sentences(corpus,query)
        for i,indices in x.items():
            fs_examps.append(q_examples_corpus.select(indices))
    return fs_examps

In [86]:
from get_prompts import *
if __name__ == "__main__":
    print("Hello, World!")
    
    
    dataset=load_datasets("copa-en")
    train_data=dataset['train']
    fs_examps_general=get_all_fs_general(train_data)
    fs_prompts_general=[construct_prompt_general(x) for x in fs_examps_general]
    
    from sentence_transformers import SentenceTransformer, util
    sent_model = SentenceTransformer("all-MiniLM-L6-v2").to("cuda")
    fs_examps_similar=get_all_fs_similar(train_data)
    fs_prompts_general=[construct_prompt_general(x) for x in fs_examps_similar]

Hello, World!
train size: 400
val size: 100


100%|██████████| 1/1 [00:00<00:00, 188.20ba/s]
100%|██████████| 1/1 [00:00<00:00, 206.33ba/s]
100%|██████████| 1/1 [00:00<00:00, 162.35ba/s]
100%|██████████| 1/1 [00:00<00:00, 201.79ba/s]


198


100%|██████████| 1/1 [00:00<00:00, 204.38ba/s]
100%|██████████| 1/1 [00:00<00:00, 208.35ba/s]


202


In [87]:
gen_config = {
                "temperature": 0.9,
                # "top_p": 0.1,
                # "repetition_penalty": 1.18,
                "top_k": 20,
                "do_sample": True,
                "max_new_tokens": 2062,
                "pad_token_id": tokenizer.eos_token_id
                    }
all_response_raw,all_response=generate_result(fs_prompts_general,
                                              gen_config,'aya')

100%|██████████| 50/50 [02:32<00:00,  3.05s/it]


In [91]:
all_formated=[]
for i,response in enumerate(all_response_raw):
    if not response.startswith('"correct'):
        premise=response.split(' "correct')[0]
        if "correct cause" in response:
            q="cause"
        elif "correct effect" in response:
            q="effect"
        if len(response.split('": '))==1:
            continue
        correct_q=response.split('": ')[1].split(' "')[0]
        wrong_q=response.split('": ')[-1]
        if i%2==0:
            label=0
            x={
            "premise":premise,
            "choice1":correct_q,
            "choice2":wrong_q,
            "question":q,
            "label":label,
            "idx":i
        }
        else:
            label=1
            x={
            "premise":premise,
            "choice1":wrong_q,
            "choice2":correct_q,
            "question":q,
            "label":label,
            "idx":i
        }


        all_formated.append(x)


In [92]:
import json
dest_file="../data/synthetic/copa-en-train-similar.jsonl"
output_file = open(dest_file, 'w', encoding='utf-8')
for dic in all_formated:
    json.dump(dic, output_file) 
    output_file.write("\n")

In [94]:
len(all_formated)

329