# Medical Few-shot OpenQA

## Set-up

### General set-up

In [None]:
# !pip install -r requirements.txt

In [1]:
import collections
from contextlib import nullcontext
from collections import namedtuple
from datasets import load_dataset
import json
import numpy as np
import random
import re 
import string
import torch
from typing import List
import torch

In [2]:
seed = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

### Language model set-up

In [3]:
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
transformers.logging.set_verbosity_error()

### ColBERT set-up
The following will clone the ColBERTv2 repository for use in this notebook:

In [4]:
# Clone the repo
# !git clone -b cpu_inference https://github.com/stanford-futuredata/ColBERT.git

In [4]:
import os
import sys
sys.path.insert(0, 'ColBERT/')

from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert.data import Collection
from colbert.searcher import Searcher
from utility.utils.dpr import has_answer, DPR_normalize

## Language models

In few-shot OpenQA, the language model (LM) must read in a prompt and answer the question posed somewhere in the prompt. 

### Answerhood

In [4]:
def _find_generated_answer(tokens, newline="\n" ): 
    """Our LMs tend to insert initial newline characters before
    they begin generating text. This function ensures that we 
    properly capture the true first line as the answer while
    also ensuring that token probabilities are aligned."""        
    answer_token_indices = []
    char_seen = False            
    for i, tok in enumerate(tokens):
        # This is the main condition: a newline that isn't an initial
        # string of newlines:
        if tok == newline and char_seen:
            break
        # Keep the initial newlines for consistency:
        elif tok == newline and not char_seen:
            answer_token_indices.append(i)
        # Proper tokens:
        elif tok != newline:
            char_seen = True
            answer_token_indices.append(i)
    return answer_token_indices 

### Eleuther models from Hugging Face

In [5]:
# "gpt-neo-125M" "gpt-neo-1.3B" "gpt-neo-2.7B" "gpt-j-6B"
eleuther_model_name = "gpt-neo-125M"

eleuther_tokenizer = AutoTokenizer.from_pretrained(
    f"EleutherAI/{eleuther_model_name}", 
    padding_side="left", 
    padding='longest', 
    truncation='longest_first', max_length=2000)
eleuther_tokenizer.pad_token = eleuther_tokenizer.eos_token

eleuther_model = AutoModelForCausalLM.from_pretrained(
    f"EleutherAI/{eleuther_model_name}")

In [6]:
def run_eleuther(prompts, temperature=0.1, top_p=0.95, **generate_kwargs): 
    """
    Parameters
    ----------
    prompts : iterable of str
    temperature : float
        It seems best to set it low for this task!
    top_p : float
       
    For options for `generate_kwargs`, see:
    
    https://huggingface.co/docs/transformers/master/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate
    
    Options that are likely to be especially relevant include 
    `temperature`, `length_penalty`, and the parameters that
    determine the decoding strategy. With `num_return_sequences > 1`,
    the default parameters in this function do multinomial sampling.
    
    Returns
    -------
    list of dicts
    
    {"prompt": str, 
     "generated_text": str, "generated_tokens": list of str, "generated_probs": list of float,
     "answer": str, "answer_tokens": list of str, "answer_probs": list of float
    }
         
    """
    prompt_ids = eleuther_tokenizer(
        prompts, return_tensors="pt", padding=True).input_ids
        
    with torch.inference_mode():
        # Automatic mixed precision if possible.
        with torch.cuda.amp.autocast() if torch.cuda.is_available() else nullcontext():
            model_output = eleuther_model.generate(
                prompt_ids,
                temperature=temperature,
                do_sample=True,
                top_p=top_p,           
                max_new_tokens=16,
                num_return_sequences=1,                
                pad_token_id=eleuther_tokenizer.eos_token_id, 
                return_dict_in_generate=True,
                output_scores=True,
                **generate_kwargs)
        
    # Converting output scores using the helpful recipe here:
    # https://discuss.huggingface.co/t/generation-probabilities-how-to-compute-probabilities-of-output-scores-for-gpt2/3175
    gen_ids = model_output.sequences[:, prompt_ids.shape[-1] :]
    gen_probs = torch.stack(model_output.scores, dim=1).softmax(-1)
    gen_probs = torch.gather(gen_probs, 2, gen_ids[:, :, None]).squeeze(-1)
    
    # Generated texts, including the prompts:
    gen_texts = eleuther_tokenizer.batch_decode(
        model_output.sequences, skip_special_tokens=True)
    
    data = []     
    iterator = zip(prompts, gen_ids, gen_texts, gen_probs)    
    for prompt, gen_id, gen_text, gen_prob in iterator:       
        gen_tokens = eleuther_tokenizer.convert_ids_to_tokens(gen_id)
        generated_text = gen_text[len(prompt): ]
        gen_prob = [float(x) for x in gen_prob.numpy()] # float for JSON storage
        ans_indices = _find_generated_answer(gen_tokens, newline="Ċ")
        answer_tokens = [gen_tokens[i] for i in ans_indices]
        answer_probs = [gen_prob[i] for i in ans_indices]
        answer = "".join(answer_tokens).replace("Ġ", " ").replace("Ċ", "\n")                                       
        data.append({
            "prompt": prompt,
            "generated_text": generated_text,
            "generated_tokens": gen_tokens,
            "generated_probs": gen_prob,
            "generated_answer": answer,
            "generated_answer_probs": answer_probs,
            "generated_answer_tokens": answer_tokens})                        

    return data

In [9]:
# %%time
# ## test run

# eleuther_ex = run_eleuther([    
#     "What year was Stanford University founded?", 
#     "In which year did Stanford first enroll students?"])

# eleuther_ex

## Dataset Loading


### SQuAD

In [4]:
squad = load_dataset("squad")

Reusing dataset squad (/home/zhanj289/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


  0%|          | 0/2 [00:00<?, ?it/s]

The following utility just reads a SQuAD split in as a list of `SquadExample` instances:

In [5]:
SquadExample = namedtuple("SquadExample",  "id title context question answers")

In [6]:
def get_squad_split(squad, split="validation"):
    """
    Use `split='train'` for the train split.
    
    Returns
    -------
    list of SquadExample named tuples with attributes
    id, title, context, question, answers
    
    """    
    fields = squad[split].features
    data = zip(*[squad[split][field] for field in fields])
    return [SquadExample(eid, title, context, question, answers["text"]) 
            for eid, title, context, question, answers in data]

In [7]:
## Split Dev and Train

In [8]:
fields = squad['validation'].features
data = zip(*[squad['validation'][field] for field in fields])

In [9]:
squad_dev = get_squad_split(squad)

In [10]:
squad_dev[100]

SquadExample(id='56d602631c85041400946edb', title='Super_Bowl_50', context='CBS broadcast Super Bowl 50 in the U.S., and charged an average of $5 million for a 30-second commercial during the game. The Super Bowl 50 halftime show was headlined by the British rock group Coldplay with special guest performers Beyoncé and Bruno Mars, who headlined the Super Bowl XLVII and Super Bowl XLVIII halftime shows, respectively. It was the third-most watched U.S. broadcast ever.', question='Who were special guests for the Super Bowl halftime show?', answers=['Beyoncé and Bruno Mars', 'Beyoncé and Bruno Mars', 'Beyoncé and Bruno Mars'])

In [11]:
dev_exs = sorted(squad_dev, key=lambda x: hash(x.id))[: 200]

In [12]:
squad_train = get_squad_split(squad, "train")

In [13]:
squad['train']

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 87599
})

### BioASQ

In [7]:
# with open('./data/bioasq/squad.json', 'r') as f:
#     squad_test = json.load(f)

In [8]:
# pick all factoid questions but ignore all else

In [9]:
with open('./data/bioasq/training10b.json', 'r') as f:
    bioasq_json = json.load(f)

In [10]:
# bioasq_json['questions'][0]['snippets']

In [11]:
# text_dict = {}

# for snip in bioasq_json['questions'][0]['snippets']:
#     if snip['beginSection'] == 'abstract':
#         for k in range(snip['offsetInBeginSection'], snip['offsetInEndSection']):
#             text_dict[k] = snip['text'][k- snip['offsetInBeginSection']]

In [12]:
# recon_text = ''
# for key in sorted(text_dict.keys()):
#     recon_text += text_dict[key]

In [13]:
# recon_text

In [14]:
# bioasq_json['questions'][1]

In [15]:
### Construct dataset
count_factoid = 0
count_list =0
count_summary=0
count_yesno =0

bioasq_list= []

for i in range(len(bioasq_json['questions'])):
    
    sample = bioasq_json['questions'][i]
    
    if sample['type'] == 'summary':
            count_summary += 1
    if sample['type'] == 'yesno':
            count_yesno += 1
    
    if sample['type'] in ['factoid', 'list']:
        
    #  Context
    ## flatten all the snippet, conccatenate and use as context
        context = '' 
        for snip in [ele['text'].strip() for ele in sample['snippets']]:
            snip += ' '
            context += snip
            
        context = context.replace('\n', ' ')
        
        ## limit the length of context
        ### Max: 4096 (for eleuther model)
        context = context[:1024]
        
        # question
        question = sample['body']
        question = question.replace('\n', ' ')
        
        # answer:
        ## deal with factoid question and list question differently
        if sample['type'] == 'factoid':
            answer = sample['exact_answer']
            count_factoid += 1
        
        if sample['type'] == 'list':
            answer = [x for y in sample['exact_answer'] for x in y]
            count_list += 1
        

        # construct a QA pairs like SQUAD
        bioasq_list.append({
            'id': i,
            'context': context,
            'question': sample['body'],
            'answers': answer,
            'type': sample['type']
        }) 

print(f'we have {count_factoid} factoid questions, {count_list} list questions, {count_summary} summary questions, {count_yesno} yesno qquestions')   

print(f'total is {count_factoid +count_list+ count_summary +count_yesno}')

we have 1252 factoid questions, 816 list questions, 1018 summary questions, 1148 yesno qquestions
total is 4234


In [16]:
len(bioasq_list)

2068

In [17]:
from sklearn.model_selection import train_test_split
def get_bioasq_split(bioasq_list, random_state):
    """
    
    Returns
    -------
    list of example named tuples with attributes
    id, title, context, question, answers
    
    """
    BioasqExample = namedtuple("BioasqExample",  "id context question answers")
    
    bioasq_data = [BioasqExample(ele['id'], ele['context'], ele['question'], ele['answers']) for ele in bioasq_list]
    
    bioasq_train, _ = train_test_split(bioasq_data, test_size=0.9, random_state=random_state)

    bioasq_dev, bioasq_test = train_test_split(_, test_size=0.8888, random_state=random_state)
    
    return bioasq_train, bioasq_dev, bioasq_test

In [18]:
## split dev and test

bioasq_train, bioasq_dev, bioasq_test = get_bioasq_split(bioasq_list, random_state=40)

In [19]:
print(f"{len(bioasq_train)}, {len(bioasq_dev)}, {len(bioasq_test)} ")

206, 207, 1655 


In [20]:
## pick 10 just for sanity check
dev_exs = bioasq_dev[:10]

In [21]:
dev_exs[0].id

829

## Evaluation

Our evaluation protocols are the standard ones for SQuAD and related tasks: exact match of the answer (EM) and token-level F1.

We say further that the predicted answer is the first line of generated text after the prompt.

The following evaluation code is taken from the [apple/ml-qrecc](https://github.com/apple/ml-qrecc/blob/main/utils/evaluate_qa.py) repository. It performs very basic string normalization before doing the core comparisons.

In [22]:
def normalize_answer(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def get_tokens(s: str) -> List[str]:
    """Normalize string and split string into tokens."""
    if not s:
        return []
    return normalize_answer(s).split()


def compute_exact(a_gold: str, a_pred: str) -> int:
    """Compute the Exact Match score."""
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))


def compute_f1_from_tokens(gold_toks: List[str], pred_toks: List[str]) -> float:
    """Compute the F1 score from tokenized gold answer and prediction."""
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())

    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        f1= int(gold_toks == pred_toks)
        precision = int(gold_toks == pred_toks)
        recall = int(gold_toks == pred_toks)
        
     # if no token overlap at all, all metrics is 0
    if num_same == 0: 
        f1= int(gold_toks == pred_toks)
        precision = int(gold_toks == pred_toks)
        recall = int(gold_toks == pred_toks)
    
    else:
        precision = 1.0 * num_same / len(pred_toks)
        recall = 1.0 * num_same / len(gold_toks)
        f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall


def compute_f1(a_gold: str, a_pred: str) -> float:
    """Compute the F1 score."""
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    return compute_f1_from_tokens(gold_toks, pred_toks)

The following is our general evaluation function. We will make extensive use of it to evaluate different systems:

In [23]:
def evaluate(examples, prompts, gens):
    """Generic evalution function.
    
    Parameters
    ----------
    examples: iterable of `SquadExample` instances
    prompts: list of str
    preds: list of LM-generated texts to evaluate as answers
    
    Returns
    -------
    dict with keys "em_per", "macro_f1", "examples", where
    each "examples" value is a dict
    
    """        
    results = []
    for ex, prompt, gen in zip(examples, prompts, gens):
        answers = ex.answers
        pred = gen['generated_answer']
        # The result is the highest EM from the available answer strings:
        em = max([compute_exact(ans, pred) for ans in answers])
        
        # adding precision and recall
        # print([compute_f1(ans, pred) for ans in answers])
        f1 = max([compute_f1(ans, pred)[0] for ans in answers])
        precision = max([compute_f1(ans, pred)[1] for ans in answers])
        recall = max([compute_f1(ans, pred)[2] for ans in answers])
        
        gen.update({
            "id": ex.id, 
            "question": ex.question, 
            "prediction": pred, 
            "answers": answers, 
            "em": em,
            "f1": f1,
            "precision": precision,
            "recall": recall
        })
        results.append(gen)
    data = {}        
    data["macro_f1"] = np.mean([d['f1'] for d in results])
    data["macro_precision"] = np.mean([d['precision'] for d in results])
    data["macro_recall"] = np.mean([d['recall'] for d in results])
    data["em_per"] = sum([d['em'] for d in results]) / len(results)
    data["examples"] = results
    return data

Here is a highly simplified example to help make the logic behind `evaluate` clearer:    

In [24]:
ex = namedtuple("SquadExample",  "id title context question answers")

examples = [
    ex("0", "CS224u", 
       "The course to take is NLU!", 
       "What is the course to take?", 
       ["NLU", "CS224u"])]

prompts = ["Dear model, Please answer this question!\n\nQ: What is the course to take?\n\nA:"]

gens = [{"generated_answer": "course on NLU", "generated_text": "NLU\nWho am I?"}]

evaluate(examples, prompts, gens)

{'macro_f1': 0.5,
 'macro_precision': 0.3333333333333333,
 'macro_recall': 1.0,
 'em_per': 0.0,
 'examples': [{'generated_answer': 'course on NLU',
   'generated_text': 'NLU\nWho am I?',
   'id': '0',
   'question': 'What is the course to take?',
   'prediction': 'course on NLU',
   'answers': ['NLU', 'CS224u'],
   'em': 0,
   'f1': 0.5,
   'precision': 0.3333333333333333,
   'recall': 1.0}]}

The bake-off uses `macro_f1` as the primary metric.

## ColBERT

In [67]:
index_home = os.path.join("experiments", "notebook", "indexes")

### ColBERT parameters

In [23]:
if not os.path.exists(os.path.join("data", "openqa", "colbertv2.0.tar.gz")):
    !mkdir -p data/openqa
    # ColBERTv2 checkpoint trained on MS MARCO Passage Ranking (388MB compressed)
    !wget https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz -P data/openqa/
    !tar -xvzf data/openqa/colbertv2.0.tar.gz -C data/openqa/

If something went wrong with the above, you can just download the file https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz, unarchive it, and move the resulting `colbertv2.0` directory into the `data/openqa` directory.

### ColBERT index

Here we use our created index for bioasq passages

In [24]:
index_home = './experiments/bioasq/indexes'

collection = os.path.join(index_home, "bioasq.all.2bits", "bioasq_passage.tsv")

collection = Collection(path=collection)

f'Loaded {len(collection):,} passages'

[Jun 05, 01:40:54] #> Loading collection...
0M 


'Loaded 2,068 passages'

In [25]:
index_name = "bioasq.all.2bits"

Now we create our `searcher`:

In [26]:
with Run().context(RunConfig(experiment='bioasq')):
    searcher = Searcher(index=index_name)

[Jun 05, 01:40:55] #> Loading collection...
0M 
[Jun 05, 01:41:04] #> Building the emb2pid mapping..
[Jun 05, 01:41:04] len(self.emb2pid) = 378124


In [27]:
len(searcher.collection)

2068

### Retrieval evaluation

For more rigorous evaluations of the retriever alone, we can use Sucess@`k` defined relative to the SQuAD passages and answers. We say that we have a "success" if a passage in the top `k` retrieved passages contains any of the answers substrings, and Sucess@`k` is the percentage of such success cases. This is very heuristic (perhaps the answer string happens to occur somewhere in a completely irrelevant passage), but it can still be good guidance.

In [74]:
def success_at_k(examples, k=20):
    scores = []
    for ex in examples: 
        scores.append(evaluate_retrieval_example(ex, k=5))
    return sum(scores) / len(scores)
        
    
def evaluate_retrieval_example(ex, k=20):    
    results = searcher.search(ex.question, k=k)
    for passage_id, passage_rank, passage_score in zip(*results):
        passage = searcher.collection[passage_id]
        score = has_answer([DPR_normalize(ans) for ans in ex.answers], passage)
        if score:
            return 1
    return 0

Here is Sucess@20 for the SQuAD dev set:

In [67]:
%%time
if torch.cuda.is_available():
    # This will take a few hours on a CPU:
    print(success_at_k(bioasq_dev))
else:
    # This should be reasonably fast and yields the
    # same kind of result:
    print(success_at_k(bioasq_dev))


#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==
#> Input: . Which gene harbors the mutation T790M?, 		 True, 		 None
#> Output IDs: torch.Size([32]), tensor([  101,     1,  2029,  4962,  6496,  2015,  1996, 16221,  1056,  2581,
        21057,  2213,  1029,   102,   103,   103,   103,   103,   103,   103,
          103,   103,   103,   103,   103,   103,   103,   103,   103,   103,
          103,   103])
#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])

0.7681159420289855
CPU times: user 35.4 s, sys: 1.45 s, total: 36.9 s
Wall time: 6.95 s


### Few-shot OpenQA 


In [25]:
def build_few_shot_open_qa_prompt(question, passage, train_exs, joiner="\n\n"):
    """Few-shot OpenQA prompts.

    Parameters
    ----------
    question : str
    passage : str
        Presumably something retrieved via search.
    train_exs : iterable of SQuAD train examples
        These can be obtained via a random sample from 
        `squad_train` as defined above.
    joiner : str
        The character to use to join pieces of the prompt 
        into a single str.

    Returns
    -------
    str, the prompt

    """
    ##### YOUR CODE HERE
    passage_context = passage
    
    segs = []

    for t in train_exs:
        segs += [
            # f"Title: {t.title}",
            f"Background: {t.context}",
            f"Q: {t.question}",
            f"A: {t.answers[0]}"
        ]
    segs += [
            # f"Title: {passage_title}",
            f"Background: {passage_context}",
            f"Q: {question}",
            f"A:"
    ]
    return joiner.join(segs)


In [26]:
def evaluate_few_shot_open_qa(
        examples,
        squad_train,
        batch_size=20,
        n_context=2,
        joiner="\n\n",
        gen_func=run_eleuther):
    """Evaluate a few-shot OpenQA approach defined by 
    `build_few_shot_open_qa_prompt` and `gen_func`.

    Parameters
    ----------
    examples : iterable of SQuAD train examples
        Presumably a subset of `squad_dev` as defined above.
    squad_train : iterable of SQuAD train examples
    batch_size : int
        Number of examples to send to `gen_func` at once.
    joiner : str
        Used by `build_few_shot_open_qa_prompt` to join segments
        of the prompt into a single str.
    gen_func : either `run_eleuther` or `run_gpt3`

    Returns
    -------
    dict as determined by `evaluate` above.

    """
    # A list of strings that you build and feed into `gen_func`.
    prompts = []

    # A list of dicts that you get from `gen_func`.
    gens = []

    # Iterate through the examples in batches:
    for i in range(0, len(examples), batch_size):
        # Use the `searcher` defined above to get passages
        # using `ex.question` as the query, and use your
        # `build_few_shot_open_qa_prompt` to build prompts.

        ##### YOUR CODE HERE
        
        batch = examples[i: i+batch_size]

        # sample training from squad_train
        train_exs = random.sample(squad_train, k=n_context)

        ## get a passage for each example in the dev batch
        # get search results (passage index)
        results = [searcher.search(ex.question, k=1) for ex in batch]

        # from passage index to get the passage 'title | passage'
        passages = [searcher.collection[r[0][0]] for r in results]
 
        ps = []

        # for every question, combine the find passage and generate the prompt
        # append all prompt into a list
        for ex, psg in zip(batch, passages):
            ps.append(build_few_shot_open_qa_prompt(ex.question, psg, train_exs, joiner=joiner))  

        # feed prompt to gen_func
        gs = gen_func(ps)       

        # add the prompt to prompt list
        prompts += ps
        # add generated txt to gen list
        gens += gs


    # Return value from a call to `evalaute`, with `examples`
    # as provided by the user and the `prompts` and `gens`
    # you built:
    return evaluate(examples, prompts, gens)

### Answer scoring

In [35]:
def get_passages_with_scores(question, k=5):
    """Pseudo-probabilities from the retriever.

    Parameters
    ----------
    question : str
    k : int
        Number of passages to retrieve.

    Returns
    -------
    passages (list of str), passage_probs (np.array)

    """
    # Use the `searcher` to get `k` passages for `questions`:
    ##### YOUR CODE HERE
    search_score = searcher.search(question, k = k)[2]
    passage_index = searcher.search(question, k = k)[0]

    # Softmax normalize the scores and convert the list to
    # a NumPy array:
    ##### YOUR CODE HERE
    exp_score = np.exp(search_score)
    sum_score = np.sum(exp_score) 
    passage_probs = np.array([score/sum_score for score in exp_score] )

    # Get the passages as a list of texts:
    ##### YOUR CODE HERE

    passages = [searcher.collection[idx] for idx in passage_index]

    return passages, passage_probs



In [36]:
from types import GeneratorType
def answer_scoring(passages, passage_probs, prompts, gen_func=run_eleuther):
    """Implements our basic scoring strategy.

    Parameters
    ----------
    passages : list of str
    passage_probs : list of float
    prompts : list of str
    gen_func : either `run_eleuther` or `run_gpt3`

    Returns
    -------
    list of pairs (score, dict), sorted with the largest score first.
    `dict` should be the return value of `gen_func` for an example.

    """
    data = []
    for passage, passage_prob, prompt in zip(passages, passage_probs, prompts):
        # Run `gen_func` on [prompt] (crucially, the singleton list here),
        # and get the dictionary `gen` from the singleton list `gen_func`
        # returns, and then use the values to score `gen` according to our
        # scoring method.
        #
        # Be sure to use "generated_answer_probs" for the scores.
        ##### YOUR CODE HERE

        gen = gen_func([prompt])

        # print(gen)
        
        answer_score = np.prod(gen[0]['generated_answer_probs'])

        final_score = passage_prob*answer_score
        
        data.append((final_score, gen[0]))


    # Return `data`, sorted with the highest scoring `(score, gen)`
    # pair given first.
    ##### YOUR CODE HERE
    data.sort(key = lambda x: x[0], reverse=True)

    return data


#### Baseline System (ColBERT straight output + Eleuther)

In [92]:
%%time

working_dataset = bioasq_test # bioasq_test vs dev_exs

batch_size = 5
joiner = '\n\n'
# number of prompts
n_context = 2

# temperatures = [0.01, 0.025, 0.05, 0.075]
temperatures = [0.025]
# bioasq_dev
# bioasq_train

for temperature in temperatures:
    
    prompts = []

    gens = []

    for i in range(0, len(working_dataset), batch_size):
        # Use the `searcher` defined above to get passages
        # using `ex.question` as the query, and use your
        # `build_few_shot_open_qa_prompt` to build prompts.
        
        # get a batch from bioasq dev (to replace dev_exs)
        batch = working_dataset[i: i+batch_size]

        train_exs = random.sample(bioasq_train, k=n_context)

        ## get a passage for each example in the dev batch
        # get search results (passage index) for all examples in the batch
        # k = 1 because we choose the top result
        results = [searcher.search(ex.question, k=1) for ex in batch]

        # from passage index to get the passage 'title | passage'
        passages = [searcher.collection[r[0][0]] for r in results]

        # re-initiating prompt
        ps = []

        # for every question in the batch, combine the train_exs (background + q +a) + found passage + question and generate the prompt
        # append all prompt into a list
        
        for ex, psg in zip(batch, passages):
            ps.append(build_few_shot_open_qa_prompt(ex.question, psg, train_exs, joiner=joiner))  

        # feed prompts (in list of prompts) to gen_func
        gs = run_eleuther(ps)       

        # add the prompt to prompt list
        prompts += ps
        # add generated txt to gen list
        gens += gs
    
    eva = evaluate(working_dataset, prompts, gens)
    # print(eva)
    print(f"""
          temperature {temperature} 
          Macro F1 is: {eva['macro_f1']}， 
          Exact Match: {eva['em_per']}， 
          Macro Precision is: {eva['macro_precision']},
          Macro Recall is: {eva['macro_recall']},
          """)



          temperature 0.05 
          Macro F1 is: 0.11406956518107561， 
          Exact Match: 0.024773413897280966， 
          Macro Precision is: 0.10246440534960173,
          Macro Recall is: 0.20589445145434757,
          
CPU times: user 1h 50min 6s, sys: 4min 15s, total: 1h 54min 21s
Wall time: 15min 27s


#### ColBERT improvement (normalization/answer scoring) + Eleuther

In [101]:
%%time


######## This part is functional modules ############
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

#### enhanced squad training example searching

def train_tf_idf(bioasq_train):
    tfidfvectorizer = TfidfVectorizer(analyzer='word',stop_words= 'english', ngram_range=(1, 3))

  # append all context
    train_context = [x.context for x in bioasq_train]

    tfidfvectorizer.fit_transform(train_context)

    context_tfidf = tfidfvectorizer.transform(train_context)

    return tfidfvectorizer, context_tfidf

def sample_bioasq_train(tfidfvectorizer, context_tfidf, question, n_context):
    '''
    This is using tf-idf and consine similarity to sample "related to question" bioasq example to build the prompt
    '''
    question_tfidf = tfidfvectorizer.transform([question])

    cosine_sim = cosine_similarity(context_tfidf, question_tfidf).flatten()

    related_index = cosine_sim.argsort()[-n_context:][::-1]

    train_exs = [bioasq_train[i] for i in related_index]

    return train_exs

### revised answer scoring by normalizing the score by length
from types import GeneratorType
## added temperature arg to allow change
def answer_scoring_normalized(passages, passage_probs, prompts, temperature, gen_func=run_eleuther):
    """Implements our basic scoring strategy.

  Parameters
  ----------
  passages : list of str
  passage_probs : list of float
  prompts : list of str
  gen_func : either `run_eleuther` or `run_gpt3`

  Returns
  -------
  list of pairs (score, dict), sorted with the largest score first.
  `dict` should be the return value of `gen_func` for an example.

    """
    data = []
    length_sum = 0
    gen_list = []

    for passage, passage_prob, prompt in zip(passages, passage_probs, prompts):
        gen = gen_func([prompt], temperature = temperature)

        gen_list.append(gen)
        # calculate the total length of answers
        length_sum += len(gen[0]['generated_answer'].split(' '))

    for passage_prob, gen in zip(passage_probs, gen_list):

        answer_score = np.prod(gen[0]['generated_answer_probs'])

        length_of_answer = len(gen[0]['generated_answer'].split(' '))

        # give more weight to longer answers, as its product of per-token probabiliyy is underdog
        weight = length_of_answer/length_sum

        final_score = passage_prob*answer_score*weight

        data.append((final_score, gen[0]))

    data.sort(key = lambda x: x[0], reverse=True)

    return data


######## This part is system development ############

batch_size = 5
joiner = '\n\n'
n_context = 2

# temperatures = [0.01, 0.025, 0.05, 0.075]
temperatures = [0.025]


working_dataset = bioasq_test # bioasq_test vs dev_exs

for temperature in temperatures:
    prompts = []

    gens = []

    # use tf-idf to find "related few shot in bioasq to build the prompt
    # train tf-idf on all bioasq examples
    tfidfvectorizer, context_tfidf = train_tf_idf(bioasq_train)

    for i in range(0, len(working_dataset), batch_size):
        # Use the `searcher` defined above to get passages
        # using `ex.question` as the query, and use your
        # `build_few_shot_open_qa_prompt` to build prompts.

        # get a batch from bioasq dev (to replace dev_exs)
        batch = working_dataset[i: i+batch_size]

        # train_exs = random.sample(bioasq_train, k=n_context)

        ## score for answer-passage pair
        for ex in batch:

          # use tf idf to sample training exs, instead of just random sampling bioasq training
            train_exs = sample_bioasq_train(tfidfvectorizer, context_tfidf, ex.question, n_context)

            passages, passage_probs = get_passages_with_scores(ex.question)

            # re-initiating prompt
            ps = []
            # iterate through each passage in the top k (5) passages
            for psg in passages:
            # build the prompt based on question, that specific passge, and training examples
            # say we have passage, then ps will be ['prompt1', 'prompt2', 'prompt3', 'prompt4', 'prompt5']
                ps.append(build_few_shot_open_qa_prompt(ex.question, psg, train_exs, joiner=joiner)) 

          # calculate the answering score for the highest passage-answer pair                 
          # data = answer_scoring(passages,       # only related to question, same length as ps
          #                       passage_probs,  # only related to question, same length as ps
          #                       ps,             # k prompts
          #                       run_eleuther)

            data = answer_scoring_normalized(passages,       # only related to question, same length as ps
                                passage_probs,  # only related to question, same length as ps
                                ps,             # k prompts
                                temperature,
                                run_eleuther)

            # pick highest score answer-prompt pair (note: in)
            highest_gs = [data[0][1]]
            highest_ps = [data[0][1]['prompt']]

            # add the prompt to prompt list
            prompts += highest_ps

            # add generated txt to gen list
            gens += highest_gs
 
    eva = evaluate(working_dataset, prompts, gens)
    
    print(f"""
          temperature {temperature} 
          Macro F1 is: {eva['macro_f1']}， 
          Exact Match: {eva['em_per']}， 
          Macro Precision is: {eva['macro_precision']},
          Macro Recall is: {eva['macro_recall']},
          """)




          temperature 0.025 
          Macro F1 is: 0.1641240799601087， 
          Exact Match: 0.07734138972809668， 
          Macro Precision is: 0.17187570066724142,
          Macro Recall is: 0.22683778416752035,
          
CPU times: user 12h 8min 59s, sys: 8min 27s, total: 12h 17min 26s
Wall time: 1h 34min 8s


#### DPR Model for retrieval + LM

In [17]:
from transformers import DPRContextEncoder, AutoModel

context_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

encode_context_model = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

question_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

encode_question_model = AutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

encode_context_model.to(device)
encode_question_model.to(device)

In [None]:
%%time
# create passage embedding for the following analysis

## make all passage embeddings and get their id
context_all = []

encode_context_model.eval()


for i in range(len(bioasq_list)):
    
    with torch.no_grad():
        context_input_ids = context_tokenizer(bioasq_list[i]['context'], return_tensors="pt")["input_ids"]
        context_embeddings = encode_context_model(context_input_ids).pooler_output.squeeze()

        context_all.append(context_embeddings)

        torch.cuda.empty_cache() # probably redundant
    
    if i % 100 == 0:
        print(i)

context_all_tensor = torch.stack(context_all)

In [106]:
%%time

working_dataset = bioasq_test # bioasq_test vs dev_exs

batch_size = 5
joiner = '\n\n'
# number of prompts
n_context = 2

# temperatures = [0.01, 0.025, 0.05, 0.075]
temperatures = [0.025]

encode_question_model.eval()

for temperature in temperatures:
    
    prompts = []

    gens = []

    for i in range(0, len(working_dataset), batch_size):
        # Use the `searcher` defined above to get passages
        # using `ex.question` as the query, and use your
        # `build_few_shot_open_qa_prompt` to build prompts.
        
        # get a batch from bioasq dev (to replace dev_exs)
        batch = working_dataset[i: i+batch_size]

        train_exs = random.sample(bioasq_train, k=n_context)

        ## get a passage for each example in the dev batch
        # get search results (passage index) for all examples in the batch
        # k = 1 because we choose the top result
        
        passages = []
        
        for ex in batch:
            
            with torch.no_grad():
                
                # encode question
                question_input_ids = question_tokenizer(ex.question, return_tensors="pt")["input_ids"]
                question_embeddings = encode_question_model(question_input_ids).pooler_output.squeeze()
                
                # get the dot product (score and sort it)
                dot_products = torch.sum(context_all_tensor * question_embeddings, -1)
                new_dot_products = torch.sort(dot_products, dim=- 1, descending=True)
                
                # retrieve the passage
                passages.append(bioasq_list[new_dot_products.indices[0]]['context'])

        # re-initiating prompt
        ps = []

        # for every question in the batch, combine the train_exs (background + q +a) + found passage + question and generate the prompt
        # append all prompt into a list
        
        for ex, psg in zip(batch, passages):
            ps.append(build_few_shot_open_qa_prompt(ex.question, psg, train_exs, joiner=joiner))  

        # feed prompts (in list of prompts) to gen_func
        gs = run_eleuther(ps)       

        # add the prompt to prompt list
        prompts += ps
        # add generated txt to gen list
        gens += gs
    
    eva = evaluate(working_dataset, prompts, gens)
    # print(eva)
    print(f"""
          temperature {temperature} 
          Macro F1 is: {eva['macro_f1']}， 
          Exact Match: {eva['em_per']}， 
          Macro Precision is: {eva['macro_precision']},
          Macro Recall is: {eva['macro_recall']},
          """)



          temperature 0.025 
          Macro F1 is: 0.0806630306815693， 
          Exact Match: 0.009667673716012085， 
          Macro Precision is: 0.0759503906633816,
          Macro Recall is: 0.14150972153460456,
          
CPU times: user 2h 33min 52s, sys: 32min 50s, total: 3h 6min 43s
Wall time: 46min 40s


#### Doc2Query Augumentation + BM25



In [None]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

d2q_tokenizer = T5Tokenizer.from_pretrained('castorini/doc2query-t5-base-msmarco')
d2q_model = T5ForConditionalGeneration.from_pretrained('castorini/doc2query-t5-base-msmarco')
d2q_model.to(device)

In [29]:
import copy
bioasq_list_copy = copy.deepcopy(bioasq_list)
# bioasq_list_copy = bioasq_list.copy()
# bioasq_list_copy = bioasq_list_copy[:20]

In [26]:
%%time

## this is to use doc2query to find additional question from MS Macro and append the top 10 
## questions back to the context
d2q_model.eval()

with torch.no_grad():
    
    for i in range(len(bioasq_list_copy)):
        
        doc_text = bioasq_list_copy[i]['context']

        input_ids = d2q_tokenizer.encode(doc_text, return_tensors='pt').to(device)
        
        number_of_q = 10
        
        with torch.no_grad():
            outputs = d2q_model.generate(
                input_ids=input_ids,
                max_length=64,
                do_sample=True,
                top_k=10,
                num_return_sequences=number_of_q)
    
        query_to_append = ' '
        
        with torch.no_grad():
            for j in range(number_of_q):
                query_to_append += d2q_tokenizer.decode(outputs[j], skip_special_tokens=True) + ' '
            
        # append the query to the passage
        
        bioasq_list_copy[i]['context'] = bioasq_list_copy[i]['context'] + query_to_append
        
        torch.cuda.empty_cache() # probably redundant
        

CPU times: user 15min 58s, sys: 35.1 s, total: 16min 33s
Wall time: 16min 33s


In [78]:
len(bioasq_dev)

207

In [80]:
%%time

working_dataset = bioasq_test # bioasq_test vs dev_exs

batch_size = 5
joiner = '\n\n'
# number of prompts
n_context = 2

# temperatures = [0.01, 0.025, 0.05, 0.075]
temperatures = [0.025]

## prepare BM25
from rank_bm25 import BM25Okapi

tokenized_corpus = [example['context'].split(" ") for example in bioasq_list_copy]

bm25 = BM25Okapi(tokenized_corpus)

for temperature in temperatures:
    
    prompts = []

    gens = []

    for i in range(0, len(working_dataset), batch_size):
        # Use the `searcher` defined above to get passages
        # using `ex.question` as the query, and use your
        # `build_few_shot_open_qa_prompt` to build prompts.
        
        # get a batch from bioasq dev (to replace dev_exs)
        batch = working_dataset[i: i+batch_size]

        train_exs = random.sample(bioasq_train, k=n_context)

        ## get a passage for each example in the dev batch
        # get search results (passage index) for all examples in the batch
        # k = 1 because we choose the top result
        
        passages = []
        
        for ex in batch:
                
            tokenized_query = ex.question.split(" ")
            
            # retrieve the top one passage with question added
            # doc_scores = bm25.get_scores(tokenized_query)
            passage = bm25.get_top_n(tokenized_query, bioasq_list_copy, n=1)[0]['context']

            # retrieve the passage
            passages.append(passage)

        # re-initiating prompt
        ps = []

        # for every question in the batch, combine the train_exs (background + q +a) + found passage + question and generate the prompt
        # append all prompt into a list
        
        for ex, psg in zip(batch, passages):
            ps.append(build_few_shot_open_qa_prompt(ex.question, psg, train_exs, joiner=joiner))  

        # feed prompts (in list of prompts) to gen_func
        gs = run_eleuther(ps)       

        # add the prompt to prompt list
        prompts += ps
        # add generated txt to gen list
        gens += gs
    
    eva = evaluate(working_dataset, prompts, gens)
    # print(eva)
    print(f"""
          temperature {temperature} 
          Macro F1 is: {eva['macro_f1']}， 
          Exact Match: {eva['em_per']}， 
          Macro Precision is: {eva['macro_precision']},
          Macro Recall is: {eva['macro_recall']},
          """)



          temperature 0.025 
          Macro F1 is: 0.08072085112218676， 
          Exact Match: 0.015105740181268883， 
          Macro Precision is: 0.07526040424831965,
          Macro Recall is: 0.14252709122749127,
          
CPU times: user 3h 1min 42s, sys: 27min 50s, total: 3h 29min 32s
Wall time: 52min 22s


### BERT BASE Retrieval

### Rank Fusion