In [119]:
import torch
from torch.nn import CrossEntropyLoss
import faiss
import numpy as np
import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM,  LlamaForCausalLM
from transformers import GenerationConfig, LlamaTokenizer, LlamaForCausalLM
from datasets import load_dataset, concatenate_datasets
from evaluate import logging, load
torch.set_num_threads(24)

In [120]:
# Load the `disfl_qa` dataset; the displayed feature list shows that each row
# carries both an 'original question' and a 'disfluent question' column.
distil_qa = load_dataset('disfl_qa')
distil_qa

Found cached dataset disfl_qa (/home/kate/.cache/huggingface/datasets/disfl_qa/default/1.1.0/ff7a7331d2d842de6c0951a1525008cdb4116a5f8d82f990e158f15984a9d6d8)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['squad_v2_id', 'original question', 'disfluent question', 'title', 'context', 'answers'],
        num_rows: 7182
    })
    test: Dataset({
        features: ['squad_v2_id', 'original question', 'disfluent question', 'title', 'context', 'answers'],
        num_rows: 3643
    })
    validation: Dataset({
        features: ['squad_v2_id', 'original question', 'disfluent question', 'title', 'context', 'answers'],
        num_rows: 1000
    })
})

In [121]:
# Load the `tweet_qa` dataset (columns: 'Question', 'Answer', 'Tweet', 'qid').
tweets_data = load_dataset('tweet_qa')
tweets_data

Found cached dataset tweet_qa (/home/kate/.cache/huggingface/datasets/tweet_qa/default/1.0.0/7d588f7f477946b10f60c035ca55175737315ac446102b015218af38d2638777)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['Question', 'Answer', 'Tweet', 'qid'],
        num_rows: 10692
    })
    test: Dataset({
        features: ['Question', 'Answer', 'Tweet', 'qid'],
        num_rows: 1979
    })
    validation: Dataset({
        features: ['Question', 'Answer', 'Tweet', 'qid'],
        num_rows: 1086
    })
})

In [122]:
# Load SQuAD; its 'validation' split is sampled later for evaluation.
squad = load_dataset("squad")
squad

Found cached dataset squad (/home/kate/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [123]:
# Load SQuAD v2; both of its splits are added to the retrieval corpus below.
squad2 = load_dataset("squad_v2")
squad2

Found cached dataset squad_v2 (/home/kate/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 130319
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 11873
    })
})

In [124]:
# Align column names with SQuAD's 'context' / 'question' so that every dataset
# exposes the two columns the retrieval store reads.
tweets_data = tweets_data.rename_column('Tweet', 'context')
tweets_data = tweets_data.rename_column('Question', 'question')
# For Disfl-QA the *disfluent* variant becomes the 'question' column; the
# 'original question' column is left untouched.
distil_qa = distil_qa.rename_column('disfluent question', 'question')
tweets_data

DatasetDict({
    train: Dataset({
        features: ['question', 'Answer', 'context', 'qid'],
        num_rows: 10692
    })
    test: Dataset({
        features: ['question', 'Answer', 'context', 'qid'],
        num_rows: 1979
    })
    validation: Dataset({
        features: ['question', 'Answer', 'context', 'qid'],
        num_rows: 1086
    })
})

In [127]:
# Merge every split that should be searchable into one retrieval corpus.
# squad['validation'] is deliberately absent from the list: it is the split
# sampled for evaluation further down.
dataset_to_store = concatenate_datasets([squad['train'], squad2['train'], squad2['validation'], 
                                        distil_qa['train'], distil_qa['test'], distil_qa['validation'],
                                        tweets_data['train'], tweets_data['test'], tweets_data['validation']]
                                            )
dataset_to_store

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers', 'squad_v2_id', 'original question', 'Answer', 'qid'],
    num_rows: 255373
})

In [128]:
# Load only the first 5000 training rows of `yahoo_answers_qa`.
yahoo = load_dataset('yahoo_answers_qa', split='train[:5000]')
yahoo

Found cached dataset yahoo_answers_qa (/home/kate/.cache/huggingface/datasets/yahoo_answers_qa/yahoo_answers_qa/1.0.0/62f63c2dc317317049c5a213c97370fe2989ead076488347df250a4b35da10d7)


Dataset({
    features: ['id', 'question', 'answer', 'nbestanswers', 'main_category'],
    num_rows: 5000
})

In [129]:
# Draw fixed-seed samples of 500 rows for evaluation.
# NOTE(review): slicing with [:500] presumably returns a plain dict of column
# lists rather than a Dataset (later cells index it as shuffled_squad["question"][i])
# - confirm against the `datasets` documentation.
shuffled_qa = yahoo.shuffle(seed=1)[:500]
shuffled_squad = squad['validation'].shuffle(seed=1)[:500]

Loading cached shuffled indices for dataset at /home/kate/.cache/huggingface/datasets/yahoo_answers_qa/yahoo_answers_qa/1.0.0/62f63c2dc317317049c5a213c97370fe2989ead076488347df250a4b35da10d7/cache-4f3fc4cfd03544ac.arrow
Loading cached shuffled indices for dataset at /home/kate/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-6287bafd17ab2657.arrow


In [46]:
class knn_store():
    """Builds and persists a question -> context nearest-neighbour store.

    Questions are tokenized with the DistilBERT tokenizer and the padded
    token-id vectors themselves are used as the FAISS keys; the tokenized
    contexts are the values that belong to those keys (row ``i`` of the values
    corresponds to key ``i``).
    """

    def __init__(self, data) -> None:
        """Tokenize ``data`` into keys and values.

        Args:
            data: Object indexable by column name that provides a
                ``"question"`` and a ``"context"`` column of strings.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        # NOTE(review): `max_model_input_sizes` is reported as deprecated in
        # recent transformers releases; `tokenizer.model_max_length` is
        # presumably the replacement - confirm before upgrading.
        self.max_embedding_len = self.tokenizer.max_model_input_sizes['distilbert-base-uncased']
        self.data = data
        self.keys = self.create_keys()
        self.values = self.create_values()

    def create_keys(self):
        """Return the questions as a tensor of padded token ids (one row each)."""
        questions = [q.strip() for q in self.data["question"]]
        # Each question is a single sequence, so plain truncation is used.
        # The pair-only strategy "only_second" has no second sequence to
        # shorten and is rejected once a question actually needs truncating.
        keys = self.tokenizer(questions,
                              max_length=self.max_embedding_len,
                              truncation=True,
                              padding="max_length",
                              return_tensors='pt')
        return keys['input_ids']

    def create_values(self):
        """Return the contexts as a tensor of padded token ids (one row each)."""
        contexts = list(self.data['context'])
        values = self.tokenizer(contexts,
                                max_length=self.max_embedding_len,
                                truncation=True,
                                padding="max_length",
                                return_tensors='pt')
        return values['input_ids']

    def set_index(self):
        """Create a flat L2 FAISS index over the question keys.

        Must be called before :meth:`save_storage`, which writes ``self.index``.
        """
        self.index = faiss.IndexFlatL2(self.max_embedding_len)
        # Token ids are cast to float before being handed to FAISS.
        keys_formatted = self.keys.detach().cpu().float().numpy()
        self.index.add(keys_formatted)

    def save_storage(self):
        """Write the index to ``knn_index`` and the context tensor to ``contexts.t``."""
        faiss.write_index(self.index, "knn_index")
        torch.save(self.values, "contexts.t")
        
        
def build_storage(data):
    """Tokenize ``data``, index its question keys and write the store to disk."""
    store = knn_store(data)
    store.set_index()
    store.save_storage()


# Build and persist the store for the merged QA corpus.
build_storage(dataset_to_store)

In [76]:
class knn():
    """Retrieves stored contexts for a question through a saved FAISS index.

    Expects two files in the working directory: ``knn_index`` (the FAISS
    index) and ``contexts.t`` (a tensor of tokenized contexts whose row ``i``
    belongs to index entry ``i``).
    """

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        self.max_embedding_len = self.tokenizer.max_model_input_sizes['distilbert-base-uncased']
        self.index = self.init_index()
        self.context = self.init_context()

    def init_index(self):
        """Read the FAISS index from the ``knn_index`` file."""
        return faiss.read_index("knn_index")

    def init_context(self):
        """Load the tokenized contexts from ``contexts.t``."""
        # NOTE: torch.load unpickles the file - only load files you created yourself.
        return torch.load("contexts.t")

    def search_knn(self, encoded_input, k = 1):
        """Return ``(distances, indexes)`` of the ``k`` nearest stored keys.

        Args:
            encoded_input: Tensor holding one query vector (1-D) or a batch of
                query vectors (2-D, one per row).
            k: Number of neighbours to return per query.
        """
        encoded_input = encoded_input.detach().cpu().float().numpy()
        if encoded_input.ndim == 1:
            # FAISS wants a 2-D batch: add a leading batch axis. (Slicing with
            # [0:1] would keep only the first element of the vector.)
            encoded_input = encoded_input[None, :]

        knn_distances, knn_indexes = self.index.search(encoded_input, k)
        return knn_distances, knn_indexes

    def encode_question(self, question):
        """Tokenize ``question`` into the padded token-id vector used as query."""
        # Single-sequence input: use plain truncation, not the pair-only
        # "only_second" strategy.
        encoded_question = self.tokenizer(question,
                                          max_length=self.max_embedding_len,
                                          truncation=True,
                                          padding="max_length",
                                          return_tensors='pt')
        return encoded_question['input_ids']

    def make_prompt(self, question, k = 1):
        """Return the decoded contexts of the ``k`` nearest questions, space-joined."""
        encoded_question = self.encode_question(question)
        _, knn_indexes = self.search_knn(encoded_question, k)
        # Flatten the (queries, k) id matrix explicitly instead of indexing the
        # tensor with a nested list, and drop the -1 placeholders FAISS returns
        # when fewer than k neighbours are available (they would otherwise
        # silently select the last stored context).
        neighbour_ids = [int(idx) for idx in knn_indexes.reshape(-1) if idx >= 0]
        text_to_add = [
            self.tokenizer.decode(self.context[idx], skip_special_tokens=True)
            for idx in neighbour_ids
        ]
        return " ".join(text_to_add)
    
# NOTE: this rebinds the name `knn` from the class to an instance of it, so
# the class is no longer reachable by name after this line runs.
knn = knn()

In [5]:
# Device that the generation inputs are moved to before calling the model.
device = 'cpu'

checkpoint = "chainyo/alpaca-lora-7b"
# checkpoint = "huggyllama/llama-7b"

model = LlamaForCausalLM.from_pretrained( checkpoint, low_cpu_mem_usage=True)

tokenizer = LlamaTokenizer.from_pretrained(checkpoint)
# Register a dedicated padding token and resize the embedding matrix so it
# matches the enlarged vocabulary.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.resize_token_embeddings(len(tokenizer))
# Trailing semicolon suppresses the cell's output (the printed module repr).
model.eval();

Loading checkpoint shards:   0%|          | 0/39 [00:00<?, ?it/s]

In [6]:
def generate_prompt(instruction, input=None):
    """Format an Alpaca-style prompt ending in the ``### Response:`` marker.

    Args:
        instruction: Text placed in the ``### Instruction:`` section.
        input: Optional context; when truthy it is placed in an additional
            ``### Input:`` section and the preamble mentions the input.

    Returns:
        The prompt string.
    """
    if not input:
        return f"""Below is an instruction that describes an answer. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:"""

    return f"""Below is an instruction that describes an answer, paired with an input that provides context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:"""

In [None]:
meteor_score = load('meteor')


def compute_meteor(predictions, references):
    """Return the METEOR score of ``predictions`` against ``references``.

    Either argument may be a single string, which is treated as a
    one-element list.
    """
    preds = [predictions] if type(predictions) is str else predictions
    refs = [references] if type(references) is str else references
    result = meteor_score.compute(predictions=preds, references=refs)
    return result['meteor']

def compute_f1(pred, real):
    """Token-overlap precision, recall and F1 between two token sequences.

    Overlap is counted as a multiset intersection, so a token that occurs
    twice in both sequences contributes two matches. This keeps the numerator
    consistent with the denominators, which count every token occurrence.

    Args:
        pred: Sequence of hashable predicted tokens.
        real: Sequence of hashable reference tokens.

    Returns:
        A ``(precision, recall, f1)`` tuple of floats. The result is always a
        triple: ``(1.0, 1.0, 1.0)`` when both sequences are empty and
        ``(0.0, 0.0, 0.0)`` when exactly one is empty or nothing overlaps.
    """
    from collections import Counter

    if len(pred) == 0 or len(real) == 0:
        # Two empty sequences agree perfectly; one empty sequence cannot match.
        score = float(len(pred) == len(real))
        return score, score, score

    overlap = sum((Counter(pred) & Counter(real)).values())
    if overlap == 0:
        return 0.0, 0.0, 0.0

    prec = overlap / len(pred)
    rec = overlap / len(real)

    return prec, rec, 2 * prec * rec / (prec + rec)

In [94]:
def evaluate(
    instruction,
    right_answer,
    input=None,
    temperature=0.1,
    top_p=0.95,
    top_k=40,
    num_beams=4,
    max_new_tokens=128,
    **kwargs,
):
    """Generate an answer for ``instruction`` and score it against ``right_answer``.

    Args:
        instruction: Question / instruction text for the prompt.
        right_answer: Reference answer string.
        input: Optional context inserted into the prompt.
        temperature, top_p, top_k, num_beams: Forwarded to ``GenerationConfig``.
            NOTE(review): ``do_sample`` is not set here, so the sampling
            parameters presumably have no effect under beam search - pass
            ``do_sample=True`` through ``kwargs`` if sampling is intended.
        max_new_tokens: Upper bound on the number of generated tokens.
        **kwargs: Extra ``GenerationConfig`` fields.

    Returns:
        ``(precision, recall, f1, meteor, answer)`` where the first three are
        token-level scores of the generated answer against the reference and
        ``answer`` is the decoded generated text.
    """
    prompt = generate_prompt(instruction, input)
    inputs = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Tokenize the reference without special tokens so a BOS marker cannot
    # count as a "matching" token.
    reference_ids = tokenizer(right_answer, add_special_tokens=False)["input_ids"]
    if reference_ids and isinstance(reference_ids[0], list):
        # A one-element batch of references was passed; unwrap it.
        (reference_ids,) = reference_ids

    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )

    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            early_stopping=True
        )

    # The returned sequence starts with the prompt tokens; keep only what the
    # model generated so the prompt does not leak into the overlap metrics.
    generated_ids = generation_output.sequences[0][input_ids.shape[1]:]
    special_ids = set(tokenizer.all_special_ids)
    predicted_ids = [tok for tok in generated_ids.tolist() if tok not in special_ids]

    scores = compute_f1(predicted_ids, reference_ids)
    if isinstance(scores, tuple):
        precision, recall, f1 = scores
    else:
        # compute_f1 may hand back a single number in degenerate cases
        # (empty input or no overlap); spread it over all three metrics.
        precision = recall = f1 = float(scores)

    # Decode the generated part directly rather than splitting the full text
    # on "### Response:", which breaks when that marker is missing or repeated.
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    meteor = compute_meteor(answer, right_answer)

    return precision, recall, f1, meteor, answer

In [8]:
def compute_perplexity(
    predictions, batch_size: int = 16, max_length=128, model=None, tokenizer=None):
    """Compute per-text perplexity of ``predictions`` under a causal LM.

    Args:
        predictions: A string or a sequence of strings to score.
        batch_size: Number of texts scored per forward pass.
        max_length: Texts are truncated to this many tokens.
        model: Optional already-loaded causal LM. When omitted, the
            ``chainyo/alpaca-lora-7b`` checkpoint is loaded (expensive).
        tokenizer: Tokenizer matching ``model``; it must already define a
            padding token. ``model`` and ``tokenizer`` must be given together.

    Returns:
        ``{"perplexities": [...], "mean_perplexity": value}``. Texts with fewer
        than two tokens have no next-token prediction to score and are
        reported as NaN; the mean is taken over the remaining texts and is NaN
        when none remain.

    Raises:
        ValueError: If only one of ``model`` / ``tokenizer`` is supplied.
    """
    if isinstance(predictions, str):
        predictions = [predictions]
    predictions = list(predictions)
    if not predictions:
        # Nothing to score - avoid loading a 7B model just to fail on empty input.
        return {"perplexities": [], "mean_perplexity": float("nan")}
    if (model is None) != (tokenizer is None):
        raise ValueError("model and tokenizer must be supplied together")

    if model is None:
        model = LlamaForCausalLM.from_pretrained(
            "chainyo/alpaca-lora-7b",
            low_cpu_mem_usage=True
            )
        tokenizer = LlamaTokenizer.from_pretrained("chainyo/alpaca-lora-7b")
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        model.resize_token_embeddings(len(tokenizer))

    # Score on whatever device the model lives on (CPU for a fresh load).
    device = model.device

    encodings = tokenizer(
        predictions,
        add_special_tokens=False,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
        return_attention_mask=True,
    ).to(device)

    encoded_texts = encodings["input_ids"]
    attn_masks = encodings["attention_mask"]

    if encoded_texts.shape[1] < 2:
        # Every text has at most one token: there is nothing to predict.
        return {
            "perplexities": [float("nan")] * len(predictions),
            "mean_perplexity": float("nan"),
        }

    ppls = []
    loss_fct = CrossEntropyLoss(reduction="none")

    for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
        end_index = min(start_index + batch_size, len(encoded_texts))
        encoded_batch = encoded_texts[start_index:end_index]
        attn_mask = attn_masks[start_index:end_index]

        labels = encoded_batch
        with torch.no_grad():
            output = model(encoded_batch, attention_mask=attn_mask)
            out_logits = output.logits

        # Position t predicts token t+1: drop the last logit and the first label.
        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

        # Mean masked negative log-likelihood per text, exponentiated.
        # A text whose shifted mask is all zero produces 0/0 = NaN here.
        perplexity_batch = torch.exp(
            (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
            / shift_attention_mask_batch.sum(1)
        )

        ppls += perplexity_batch.tolist()

    # Keep NaN entries in the per-text list (so indexes still line up with
    # `predictions`) but exclude them from the mean.
    scored = [p for p in ppls if not np.isnan(p)]
    mean_ppl = np.mean(scored) if scored else float("nan")

    return {"perplexities": ppls, "mean_perplexity": mean_ppl}

In [81]:
# Spot-check the first ten questions: generated answer vs. reference answers.
# `squad_predictions` is filled by the evaluation loop cell, which must have
# been run before this one.
for i in range(10):
    print(
        i,
        f'q = {shuffled_squad["question"][i]}',
        f'pred = {squad_predictions[i]}',
        f'right = {shuffled_squad["answers"][i]}',
        sep='\n',
    )

0
q = What is expected with the continuous input of sediment into the Dornbirner Ach?
pred = It is expected that the continuous input of sediment into the lake will silt up the lake. This has already happened to the former lake Tuggenersee.
right = {'text': ['silt', 'silt up the lake', 'the continuous input of sediment into the lake will silt up the lake', 'silt up the lake'], 'answer_start': [502, 502, 450, 502]}
1
q = Which of the three heavily populated areas has the least number of inhabitants?
pred = The area with the least number of inhabitants is the El Centro area.
right = {'text': ['San Diego', 'the San Diego area', 'San Diego'], 'answer_start': [793, 789, 793]}
2
q = How many of the following three fourth quarter drives after the field goal makng the score 16-10 ended in punts?
pred = Two of the three fourth quarter drives ended in punts.
right = {'text': ['three', 'three', 'The next three drives'], 'answer_start': [444, 444, 435]}
3
q = What would the latter Apollo missions 

In [102]:
# NOTE(review): `d` is not defined in any cell visible in this notebook; this
# cell depends on leftover kernel state and fails on a fresh run.
d['answer'][0]

"A small group of politicians believed strongly that the fact that Saddam Hussien remained in power after the first Gulf War was a signal of weakness to the rest of the world, one that invited attacks and terrorism. Shortly after taking power with George Bush in 2000 and after the attack on 9/11, they were able to use the terrorist attacks to justify war with Iraq on this basis and exaggerated threats of the development of weapons of mass destruction. The military strength of the U.S. and the brutality of Saddam's regime led them to imagine that the military and political victory would be relatively easy."

In [None]:
# Retrieval-augmented evaluation on the first 300 shuffled SQuAD validation
# questions: retrieve one context, generate an answer, record the metrics.
squad_precisions = []
squad_recalls = []
squad_f1s = []
squad_predictions= []
squad_meteors = []
for i in tqdm.tqdm(range(0, 300)):
    question = shuffled_squad['question'][i]
    # SQuAD keeps its references in the 'answers' column (plural); the first
    # listed answer text is used as the reference.
    reference = shuffled_squad['answers'][i]['text'][0]
    context = knn.make_prompt(question, k=1)
    precision, recall, f1, meteor, pr = evaluate(question, reference, context)
    squad_precisions.append(precision)
    squad_recalls.append(recall)
    squad_f1s.append(f1)
    squad_predictions.append(pr)
    squad_meteors.append(meteor)

squad_perplexity = compute_perplexity(squad_predictions)

In [86]:
import pickle

# Persist the generated answers so they can be reloaded without re-running
# generation. The context manager closes the file even if pickling fails.
with open('predictions_squad_shuffle_300_knn.pkl', "wb") as open_file:
    pickle.dump(squad_predictions, open_file)

31.338759223620098