In [2]:
from collections import defaultdict
from math import ceil, floor

import json
import time
import random

In [3]:
def context_reduncer_preflight_check(
    rankings_a, rankings_b, batch_size=3, iou_threshold=0.3
):
    """Check whether the top batches of two rankings overlap little enough.

    The first ``batch_size`` items of each ranking are compared as sets
    using intersection-over-union (Jaccard similarity). The misspelled
    name ("reduncer") is kept so existing callers keep working.

    Args:
        rankings_a: First ranked sequence of hashable ids.
        rankings_b: Second ranked sequence of hashable ids.
        batch_size: Number of leading items from each ranking to compare.
        iou_threshold: Overlap above which the check fails.

    Returns:
        False if the IoU of the two top batches is greater than
        ``iou_threshold``, True otherwise. When both top batches are empty
        there is no overlap to measure; the IoU is treated as 0.0 (True)
        instead of dividing by zero.
    """
    top_a = set(rankings_a[:batch_size])
    top_b = set(rankings_b[:batch_size])
    union = top_a | top_b
    if not union:
        # Both batches empty: avoid ZeroDivisionError.
        return True
    iou = len(top_a & top_b) / len(union)
    return iou <= iou_threshold

In [4]:
def format_options(options):
    """Render answer options as a human-readable list.

    Example: ``{'A': 'yes', 'B': 'no', 'C': 'maybe'}`` becomes
    ``'A. yes, B. no or C. maybe'``.

    Args:
        options: Mapping from option label to option text; iteration order
            is the display order.

    Returns:
        The formatted string. A single option renders as ``'A. yes'``
        (no dangling ``' or '``), and an empty mapping renders as ``''``.
    """
    option_strings = ['{}. {}'.format(k, options[k]) for k in options]
    if len(option_strings) <= 1:
        # Nothing to join with "or"; avoids " or A. x" and IndexError on [].
        return ''.join(option_strings)
    return (', '.join(option_strings[:-1])
            + ' or ' + option_strings[-1])

In [5]:
def format_prompt_mixtral(context, question, options):
    """Build a Mixtral-instruct RAG prompt for a multiple-choice question.

    The question and retrieved documents go inside ``[INST]``; the answer
    instruction follows ``[/INST]`` and the whole prompt ends with ``</s>``.

    Args:
        context: Retrieved document text.
        question: Question text.
        options: Pre-formatted answer options string.

    Returns:
        The prompt string.
    """
    header = (
        '<s> [INST] '
        'You are a helpful medical expert, and your task is '
        'to answer a multi-choice medical question using the '
        'relevant documents.\n'
    )
    body = f'{question}\nHere are the relevant documents:\n{context}'
    tail = (
        f' [/INST] Please select the answer from {options}.'
        ' Also explain the answer briefly in one sentence. </s>'
    )
    return header + body + tail

In [6]:
def format_context_map_prompt_mixtral(context, question, options):
    """Build a Mixtral-instruct "map" prompt for one batch of documents.

    Like the plain RAG prompt, but asks the model to format its output as
    JSON.

    Args:
        context: Text of one batch of retrieved documents.
        question: Question text.
        options: Pre-formatted answer options string.

    Returns:
        The prompt string.
    """
    preamble = (
        '<s> [INST] You are a helpful medical expert, and your task is '
        'to answer a multi-choice medical question using the '
        'relevant documents.\n'
    )
    documents = f'{question}\nHere are the relevant documents:\n{context}\n'
    request = (
        f' [/INST] Please select the answer from {options},'
        ' explain the answer briefly in one sentence,'
        ' and format output in JSON. </s>'
    )
    return preamble + documents + request

In [7]:
def format_prompt_mixtral_zs(question, options):
    """Build a zero-shot (no retrieved documents) Mixtral-instruct prompt.

    Args:
        question: Question text.
        options: Pre-formatted answer options string.

    Returns:
        The prompt string wrapped in ``<s> [INST] ... </s>``.
    """
    system_part = (
        'You are a helpful medical expert, and your task is '
        'to answer a multi-choice medical question.\n'
    )
    answer_request = (
        f'Please select the answer from {options}.'
        ' Also explain the answer briefly in one sentence.'
    )
    return f'<s> [INST] {system_part}{question} [/INST] {answer_request} </s>'

In [8]:
def format_prompt_llama2(context, question, options):
    """Build a Llama-2-chat RAG prompt for a multiple-choice question.

    The system message sits inside ``<<SYS>>`` tags. The question, the
    retrieved documents and the answer instruction form the single user
    turn, which is closed by ``[/INST]``.

    Args:
        context: Retrieved document text.
        question: Question text.
        options: Pre-formatted answer options string.

    Returns:
        The prompt string.
    """
    return (
        '<s> [INST] <<SYS>>\n'
        'You are a helpful medical expert, and your task is '
        'to answer a multi-choice medical question using the '
        'relevant documents.\n<</SYS>>\n\n'
        f'{question}'
        '\n'
        'Here are the relevant documents:\n'
        f'{context}'
        # Separate the documents from the instruction; without it the last
        # document runs straight into "Please select ...".
        '\n'
        f'Please select the answer from {options}.'
        ' Also explain the answer briefly in one sentence. [/INST]'
    )

In [9]:
def format_context_map_prompt_llama2(context, question, options):
    """Build a Llama-2-chat "map" prompt for one batch of documents.

    Asks the model to select an answer, explain it in one sentence and
    format the output as JSON. The wording matches the other map prompts
    (the previous "Please either select ..." had no "or" alternative).

    Args:
        context: Text of one batch of retrieved documents.
        question: Question text.
        options: Pre-formatted answer options string.

    Returns:
        The prompt string.
    """
    return (
        '<s> [INST] <<SYS>>\n'
        'You are a helpful medical expert, and your task is '
        'to answer a multi-choice medical question using the '
        'relevant documents.\n<</SYS>>\n\n'
        f'{question}'
        '\n'
        'Here are the relevant documents:\n'
        f'{context}'
        '\n'
        f'Please select the answer from {options},'
        ' explain the answer briefly in one sentence,'
        ' and format output in JSON. [/INST]'
    )

In [10]:
def format_prompt_llama2_zs(question, options):
    """Build a zero-shot (no retrieved documents) Llama-2-chat prompt.

    Args:
        question: Question text.
        options: Pre-formatted answer options string.

    Returns:
        The prompt string, closed by ``[/INST]``.
    """
    return (
        '<s> [INST] <<SYS>>\n'
        'You are a helpful medical expert, and your task is '
        'to answer a multi-choice medical question.\n<</SYS>>\n\n'
        f'{question}'
        # Keep the question and the instruction on separate lines.
        '\n'
        f'Please select the answer from {options}.'
        ' Also explain the answer briefly in one sentence. [/INST]'
    )

In [11]:
def format_prompt_llama3(context, question, options):
    """Build a Llama-3-instruct RAG prompt for a multiple-choice question.

    Uses the ``<|start_header_id|>role<|end_header_id|>`` / ``<|eot_id|>``
    chat format: system turn, user turn (question, documents, instruction),
    then an open assistant header.

    Args:
        context: Retrieved document text.
        question: Question text.
        options: Pre-formatted answer options string.

    Returns:
        The prompt string.
    """
    return (
        '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n'
        'You are a helpful medical expert, and your task is '
        'to answer a multi-choice medical question using the relevant documents.'
        '<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n'
        f'{question}'
        '\n'
        'Here are the relevant documents:\n'
        f'{context}'
        # Separate the documents from the instruction.
        '\n'
        f'Please select the answer from {options}.'
        ' Also explain the answer briefly in one sentence. '
        '<|eot_id|>'
        '<|start_header_id|>assistant<|end_header_id|>'
    )

In [12]:
def format_context_map_prompt_llama3(context, question, options):
    """Build a Llama-3-instruct "map" prompt for one batch of documents.

    Unlike the other map prompts, this one does not ask for JSON output.

    Args:
        context: Text of one batch of retrieved documents.
        question: Question text.
        options: Pre-formatted answer options string.

    Returns:
        The prompt string, ending with an open assistant header.
    """
    system_turn = (
        '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n'
        'You are a helpful medical expert, and your task is '
        'to answer a multi-choice medical question using the '
        'relevant documents.'
        '<|eot_id|>'
    )
    user_turn = (
        '<|start_header_id|>user<|end_header_id|>\n\n'
        f'{question}\nHere are the relevant documents:\n{context}\n'
        f'Please select the answer from {options}.'
        ' Also explain the answer briefly in one sentence. '
        '<|eot_id|>'
    )
    assistant_header = '<|start_header_id|>assistant<|end_header_id|>'
    return system_turn + user_turn + assistant_header

In [13]:
def format_prompt_llama3_zs(question, options):
    """Build a zero-shot (no retrieved documents) Llama-3-instruct prompt.

    Args:
        question: Question text.
        options: Pre-formatted answer options string.

    Returns:
        The prompt string, ending with an open assistant header.
    """
    return (
        '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n'
        'You are a helpful medical expert, and your task is '
        'to answer a multi-choice medical question.'
        '<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n'
        f'{question}'
        # Keep the question and the instruction on separate lines.
        '\n'
        f'Please select the answer from {options}.'
        ' Also explain the answer briefly in one sentence. '
        '<|eot_id|>'
        '<|start_header_id|>assistant<|end_header_id|>'
    )

In [14]:
def format_prompt_gpt35(context, question, options):
    """Build a GPT-3.5 RAG prompt as a JSON-encoded chat message list.

    Args:
        context: Retrieved document text.
        question: Question text.
        options: Pre-formatted answer options string.

    Returns:
        ``json.dumps`` of ``[system_message, user_message]``. The user
        message asks for JSON-formatted output.
    """
    system_message = {
        'role': 'system',
        'content': (
            'You are a helpful medical expert, and your task is '
            'to answer a multi-choice medical question using the '
            'relevant documents.'
        ),
    }
    user_lines = [
        f'{question}',
        'Here are the relevant documents:',
        f'{context}',
        f'Please select the answer from {options},'
        ' explain the answer briefly in one sentence,'
        ' and format output in JSON.',
    ]
    user_message = {'role': 'user', 'content': '\n'.join(user_lines)}
    return json.dumps([system_message, user_message])

In [15]:
def format_context_map_prompt_gpt35(context, question, options):
    """Build a GPT-3.5 "map" prompt (JSON-encoded chat messages) for one batch.

    Args:
        context: Text of one batch of retrieved documents.
        question: Question text.
        options: Pre-formatted answer options string.

    Returns:
        ``json.dumps`` of a two-message (system, user) list.
    """
    user_content = (
        f'{question}\n'
        f'Here are the relevant documents:\n'
        f'{context}\n'
        f'Please select the answer from {options},'
        f' explain the answer briefly in one sentence,'
        f' and format output in JSON.'
    )
    messages = [
        dict(
            role='system',
            content=(
                'You are a helpful medical expert, and your task is '
                'to answer a multi-choice medical question using the '
                'relevant documents.'
            ),
        ),
        dict(role='user', content=user_content),
    ]
    return json.dumps(messages)

In [16]:
def format_prompt_gpt35_zs(question, options):
    """Build a zero-shot GPT-3.5 prompt as JSON-encoded chat messages.

    No documents are supplied, so the system message does not mention
    them. This matches the other zero-shot builders.

    Args:
        question: Question text.
        options: Pre-formatted answer options string.

    Returns:
        ``json.dumps`` of a two-message (system, user) list. The user
        message asks for JSON-formatted output.
    """
    return json.dumps([
        {
            'role': 'system',
            'content': (
                'You are a helpful medical expert, and your task is '
                'to answer a multi-choice medical question.'
            ),
        },
        {
            'role': 'user',
            'content': (
                f'{question}\n'
                f'Please select the answer from {options},'
                ' explain the answer briefly in one sentence,'
                ' and format output in JSON.'
            ),
        },
    ])

In [17]:
def format_prompt(
    context, 
    question, 
    options, 
    model_id,
):
    """Dispatch to the model-family-specific RAG prompt builder.

    Families are detected by substring of ``model_id``, checked in this
    order: GPT-3.5 (``'gpt'`` and ``'35'``), ``'llama2'``, ``'llama3'``,
    ``'mixtral'``.

    Returns:
        The built prompt, or the string ``'Invalid model ID.'`` when no
        family matches.
    """
    if 'gpt' in model_id and '35' in model_id:
        builder = format_prompt_gpt35
    elif 'llama2' in model_id:
        builder = format_prompt_llama2
    elif 'llama3' in model_id:
        builder = format_prompt_llama3
    elif 'mixtral' in model_id:
        builder = format_prompt_mixtral
    else:
        return 'Invalid model ID.'
    return builder(context, question, options)

In [18]:
def format_context_map_prompt(
    context, 
    question, 
    options, 
    model_id,
):
    """Dispatch to the model-family-specific context-map prompt builder.

    Families are detected by substring of ``model_id``, checked in this
    order: GPT-3.5 (``'gpt'`` and ``'35'``), ``'llama2'``, ``'llama3'``,
    ``'mixtral'``.

    Returns:
        The built prompt, or the string ``'Invalid model ID.'`` when no
        family matches.
    """
    if 'gpt' in model_id and '35' in model_id:
        builder = format_context_map_prompt_gpt35
    elif 'llama2' in model_id:
        builder = format_context_map_prompt_llama2
    elif 'llama3' in model_id:
        builder = format_context_map_prompt_llama3
    elif 'mixtral' in model_id:
        builder = format_context_map_prompt_mixtral
    else:
        return 'Invalid model ID.'
    return builder(context, question, options)

In [19]:
def format_prompt_zs(
    question, 
    options, 
    model_id,
):
    """Dispatch to the model-family-specific zero-shot prompt builder.

    Families are detected by substring of ``model_id``, checked in this
    order: GPT-3.5 (``'gpt'`` and ``'35'``), ``'llama2'``, ``'llama3'``,
    ``'mixtral'``.

    Returns:
        The built prompt, or the string ``'Invalid model ID.'`` when no
        family matches.
    """
    if 'gpt' in model_id and '35' in model_id:
        builder = format_prompt_gpt35_zs
    elif 'llama2' in model_id:
        builder = format_prompt_llama2_zs
    elif 'llama3' in model_id:
        builder = format_prompt_llama3_zs
    elif 'mixtral' in model_id:
        builder = format_prompt_mixtral_zs
    else:
        return 'Invalid model ID.'
    return builder(question, options)

In [20]:
def get_bedrock_response_titan(prompt, model_id, client):
    """Invoke an Amazon Titan text model through a Bedrock runtime client.

    Requests up to 8192 tokens with temperature 0 and topP 1.

    Args:
        prompt: Prompt text, sent as ``inputText``.
        model_id: Bedrock model identifier.
        client: Object exposing
            ``invoke_model(body=, modelId=, accept=, contentType=)``.

    Returns:
        The first result's ``outputText``. If the call or the response
        parsing fails, the exception is printed and the fallback message
        ``'The model is not able to process the request.'`` is returned.
    """
    generation_config = {
        "maxTokenCount": 8192,
        "stopSequences": [],
        "temperature": 0,
        "topP": 1,
    }
    request_body = json.dumps(
        {"inputText": prompt, "textGenerationConfig": generation_config})
    mime_type = "application/json"
    try:
        raw = client.invoke_model(
            body=request_body,
            modelId=model_id,
            accept=mime_type,
            contentType=mime_type,
        )
        payload = json.loads(raw.get('body').read())
        return payload.get('results')[0].get('outputText')
    except Exception as err:
        print(err)
        return 'The model is not able to process the request.'

In [21]:
def get_bedrock_response_mixtral(prompt, model_id, client):
    """Invoke a Mixtral model through a Bedrock runtime client.

    Requests up to 512 tokens with top_p 0.8 and temperature 0.2.

    Args:
        prompt: Fully formatted prompt text.
        model_id: Bedrock model identifier.
        client: Object exposing
            ``invoke_model(body=, modelId=, accept=, contentType=)``.

    Returns:
        The stripped text of the first output. On any failure the exception
        is printed and ``'The model is not able to process the request.'``
        is returned.
    """
    request_body = json.dumps(
        dict(prompt=prompt, max_tokens=512, top_p=0.8, temperature=0.2))
    mime_type = "application/json"
    try:
        raw = client.invoke_model(
            body=request_body,
            modelId=model_id,
            accept=mime_type,
            contentType=mime_type,
        )
        payload = json.loads(raw.get('body').read())
        return payload.get('outputs')[0].get('text').strip()
    except Exception as err:
        print(err)
        return 'The model is not able to process the request.'

In [22]:
def get_bedrock_response_llama2(prompt, model_id, client):
    """Invoke a Llama 2 model through a Bedrock runtime client.

    Requests up to 512 generated tokens with temperature 0.2.

    Args:
        prompt: Fully formatted prompt text.
        model_id: Bedrock model identifier.
        client: Object exposing
            ``invoke_model(body=, modelId=, accept=, contentType=)``.

    Returns:
        The stripped ``generation`` field. On any failure the exception is
        printed and ``'The model is not able to process the request.'`` is
        returned.
    """
    request_body = json.dumps(
        dict(prompt=prompt, max_gen_len=512, temperature=0.2))
    mime_type = "application/json"
    try:
        raw = client.invoke_model(
            body=request_body,
            modelId=model_id,
            accept=mime_type,
            contentType=mime_type,
        )
        payload = json.loads(raw.get('body').read())
        return payload.get('generation').strip()
    except Exception as err:
        print(err)
        return 'The model is not able to process the request.'

In [23]:
def get_bedrock_response_llama3(prompt, model_id, client):
    """Invoke a Llama 3 model through a Bedrock runtime client.

    The request and response shapes are the same as for Llama 2: up to 512
    generated tokens at temperature 0.2, with the reply read from
    ``generation``.

    Args:
        prompt: Fully formatted prompt text.
        model_id: Bedrock model identifier.
        client: Object exposing
            ``invoke_model(body=, modelId=, accept=, contentType=)``.

    Returns:
        The stripped ``generation`` field. On any failure the exception is
        printed and ``'The model is not able to process the request.'`` is
        returned.
    """
    mime_type = "application/json"
    request_body = json.dumps(
        dict(prompt=prompt, max_gen_len=512, temperature=0.2))
    try:
        raw = client.invoke_model(
            body=request_body,
            modelId=model_id,
            accept=mime_type,
            contentType=mime_type,
        )
        generation = json.loads(raw.get('body').read()).get('generation')
        return generation.strip()
    except Exception as err:
        print(err)
        return 'The model is not able to process the request.'

In [24]:
def get_model_response_bedrock(prompt, model_id, client):
    """Route a Bedrock request to the handler for the model family.

    Families are detected by substring of ``model_id``, checked in this
    order: ``'titan'``, ``'mixtral'``, ``'llama2'``, ``'llama3'``.

    Returns:
        The handler's completion, or ``'Invalid model ID.'`` when no family
        matches.
    """
    if 'titan' in model_id:
        handler = get_bedrock_response_titan
    elif 'mixtral' in model_id:
        handler = get_bedrock_response_mixtral
    elif 'llama2' in model_id:
        handler = get_bedrock_response_llama2
    elif 'llama3' in model_id:
        handler = get_bedrock_response_llama3
    else:
        return 'Invalid model ID.'
    return handler(prompt, model_id, client)

In [25]:
def get_model_response_openai(prompt, model_id, client):
    """Query the OpenAI chat-completions API with a JSON-encoded message list.

    NOTE: ``model_id`` does not select the model. The request always
    targets ``'gpt-3.5-turbo-0125'`` in JSON-object response mode, with
    temperature 0.2, top_p 0.1 and seed 0.

    Args:
        prompt: JSON string encoding the ``messages`` list.
        model_id: Unused (see note).
        client: OpenAI-style client exposing ``chat.completions.create``.

    Returns:
        The stripped content of the first choice. On any failure (including
        an invalid JSON ``prompt``) the exception is printed and
        ``'The model is not able to process the request.'`` is returned.
    """
    try:
        messages = json.loads(prompt)
        reply = client.chat.completions.create(
            model='gpt-3.5-turbo-0125',
            response_format={"type": "json_object"},
            messages=messages,
            temperature=0.2,
            top_p=0.1,
            seed=0,
        )
        return reply.choices[0].message.content.strip()
    except Exception as err:
        print(err)
        return 'The model is not able to process the request.'

In [26]:
def get_model_response(prompt, model_id, model_runtime):
    """Send ``prompt`` to the right provider based on ``model_id``.

    Model ids containing ``'titan'``, ``'mixtral'`` or ``'llama'`` go to
    Bedrock. Otherwise ids containing ``'gpt'`` go to OpenAI.
    ``model_runtime`` is the provider client.

    Returns:
        The completion string, or ``'Invalid model_id.'`` when no provider
        matches.
    """
    bedrock_families = ('titan', 'mixtral', 'llama')
    if any(family in model_id for family in bedrock_families):
        return get_model_response_bedrock(prompt, model_id, model_runtime)
    if 'gpt' in model_id:
        return get_model_response_openai(prompt, model_id, model_runtime)
    return 'Invalid model_id.'

In [27]:
def format_context_map_questions(
    qid2query,
    qid2answer,
    qid2retrieval,
    qid,
    docs_content,
    model_id,
    batch_size=4, 
    top_k=16,
):
    """Build the per-batch "map" prompts for one question.

    Takes the retrieved documents that have content in ``docs_content``,
    keeps at most the first ``top_k``, splits them into consecutive batches
    of ``batch_size`` (joined by newlines), and turns each batch into one
    context-map prompt for ``model_id``.

    Args:
        qid2query: Mapping qid -> question text.
        qid2answer: Mapping qid -> answer record; must contain an
            ``'options'`` mapping for ``qid``.
        qid2retrieval: Mapping qid -> record whose ``'doc_ids'`` is the
            ranked list of retrieved document ids.
        qid: Question id to build prompts for.
        docs_content: Mapping doc id -> document text. Ids missing here are
            skipped.
        model_id: Model identifier used to pick the prompt template.
        batch_size: Documents per prompt; must be positive.
        top_k: Maximum number of documents to use.

    Returns:
        A list with one prompt per non-empty batch. When fewer than
        ``top_k`` documents are available, fewer prompts are produced
        instead of prompts with an empty context.

    Raises:
        ValueError: if ``batch_size`` is not positive, or the question has
            no ``'options'``.
    """
    if batch_size <= 0:
        # A non-positive step would never advance through the documents.
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    question = qid2query[qid]
    answer_record = qid2answer[qid]
    if 'options' not in answer_record:
        # Fail with a clear message rather than an UnboundLocalError.
        raise ValueError(f'No answer options found for question {qid!r}')
    options = format_options(answer_record['options'])
    snippets = [
        docs_content[sid]
        for sid in qid2retrieval[qid]['doc_ids']
        if sid in docs_content
    ][:top_k]
    # Batch only over documents that exist, so no prompt has an empty context.
    return [
        format_context_map_prompt(
            '\n'.join(snippets[start:start + batch_size]),
            question, options, model_id)
        for start in range(0, len(snippets), batch_size)
    ]

In [28]:
def format_context_reduce_prompt_mixtral(
    qid2query,
    qid2answer,
    qid2retrieval,
    qid,
    responses
):
    """Build the Mixtral-instruct "reduce" prompt from per-batch answers.

    Responses that contain "no information detected" (case-insensitive)
    are dropped. The rest are joined with blank lines as the extracted
    information, and the question is repeated after it.

    Args:
        qid2query: Mapping qid -> question text.
        qid2answer: Mapping qid -> answer record. If it has no
            ``'options'``, building the prompt fails with UnboundLocalError.
        qid2retrieval: Unused; kept for a uniform signature.
        qid: Question id.
        responses: Per-batch model responses (strings).

    Returns:
        The prompt string.
    """
    question = qid2query[qid]
    answer_record = qid2answer[qid]
    if 'options' in answer_record:
        options = format_options(answer_record['options'])
    informative = [
        reply for reply in responses
        if 'no information detected' not in reply.lower()
    ]
    extracted = '\n\n'.join(format(reply) for reply in informative)

    return (
        '<s> [INST] You are a helpful medical expert, and your task is '
        'to answer a multi-choice medical question using the '
        'information extracted from relevant documents.\n'
        f'{question}\nHere is the extracted information:\n{extracted}\n'
        f'{question} [/INST] Please ignore answers that do not contain '
        f'information and select the answer from {options}. </s>'
    )

In [29]:
def format_context_reduce_prompt_llama2(
    qid2query,
    qid2answer,
    qid2retrieval,
    qid,
    responses
):
    """Build the Llama-2-chat "reduce" prompt from per-batch answers.

    Responses that contain "no information detected" (case-insensitive)
    are dropped. The rest are joined with blank lines as the extracted
    information, and the question is repeated after it. The instruction
    stays inside the single user turn, which ends with one ``[/INST]``.

    Args:
        qid2query: Mapping qid -> question text.
        qid2answer: Mapping qid -> answer record with an ``'options'`` entry.
        qid2retrieval: Unused; kept for a uniform signature.
        qid: Question id.
        responses: Per-batch model responses (strings).

    Returns:
        The prompt string.

    Raises:
        ValueError: if the question has no ``'options'``.
    """
    question = qid2query[qid]
    if 'options' not in qid2answer[qid]:
        raise ValueError(f'No answer options found for question {qid!r}')
    options = format_options(qid2answer[qid]['options'])
    filtered_responses = [
        r for r in responses 
        if 'no information detected' not in r.lower()]
    context = '\n\n'.join([f'{r}' for r in filtered_responses])
    
    return (
        '<s> [INST] <<SYS>>\n'
        'You are a helpful medical expert, and your task is '
        'to answer a multi-choice medical question using the '
        'information extracted from relevant documents.\n<</SYS>>\n\n'
        f'{question}'
        '\n'
        'Here is the extracted information:\n'
        f'{context}'
        '\n'
        f'{question}'
        # No [/INST] here: a second one mid-turn would close the user turn
        # early and leave the instruction outside it.
        '\nPlease ignore answers that do not contain information'
        f' and select the answer from {options}. [/INST]'
    )

In [30]:
def format_context_reduce_prompt_llama3(
    qid2query,
    qid2answer,
    qid2retrieval,
    qid,
    responses
):
    """Build the Llama-3-instruct "reduce" prompt from per-batch answers.

    Responses that contain "no information detected" (case-insensitive)
    are dropped. The rest are joined with blank lines and shown to the
    model as extracted information.

    Args:
        qid2query: Mapping qid -> question text.
        qid2answer: Mapping qid -> answer record with an ``'options'`` entry.
        qid2retrieval: Unused; kept for a uniform signature.
        qid: Question id.
        responses: Per-batch model responses (strings).

    Returns:
        The prompt string in the Llama-3 header / ``<|eot_id|>`` format.

    Raises:
        ValueError: if the question has no ``'options'``.
    """
    question = qid2query[qid]
    if 'options' not in qid2answer[qid]:
        raise ValueError(f'No answer options found for question {qid!r}')
    options = format_options(qid2answer[qid]['options'])
    filtered_responses = [
        r for r in responses 
        if 'no information detected' not in r.lower()]
    context = '\n\n'.join([f'{r}' for r in filtered_responses])

    return (
        '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n'
        'You are a helpful medical expert, and your task is '
        'to answer a multi-choice medical question using the relevant documents.'
        '<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n'
        f'{question}'
        '\n'
        'Here is the extracted information:\n'
        f'{context}'
        # '[/INST]' is a Llama-2/Mixtral marker with no meaning here; the
        # Llama-3 user turn is closed by <|eot_id|> below.
        '\nPlease ignore answers that do not contain information'
        f' and select the answer from {options}.'
        '<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
    )

In [31]:
def format_context_reduce_prompt_gpt35(
    qid2query,
    qid2answer,
    qid2retrieval,
    qid,
    responses
):
    """Build the GPT-3.5 "reduce" prompt (JSON-encoded chat messages).

    Responses that contain "no information detected" (case-insensitive)
    are dropped. The rest are joined with blank lines as the extracted
    information, and the question is repeated after it.

    Args:
        qid2query: Mapping qid -> question text.
        qid2answer: Mapping qid -> answer record. If it has no
            ``'options'``, building the prompt fails with UnboundLocalError.
        qid2retrieval: Unused; kept for a uniform signature.
        qid: Question id.
        responses: Per-batch model responses (strings).

    Returns:
        ``json.dumps`` of a two-message (system, user) list.
    """
    question = qid2query[qid]
    answer_record = qid2answer[qid]
    if 'options' in answer_record:
        options = format_options(answer_record['options'])
    informative = [
        reply for reply in responses
        if 'no information detected' not in reply.lower()
    ]
    extracted = '\n\n'.join(format(reply) for reply in informative)

    user_content = '\n'.join([
        f'{question}',
        'Here is the extracted information:',
        extracted,
        f'{question}',
        'Please ignore answers that do not contain information,'
        f' select the answer from {options},'
        ' and format the output in JSON.',
    ])
    messages = [
        dict(
            role='system',
            content=(
                'You are a helpful medical expert, and your task is '
                'to answer a multi-choice medical question using the '
                'relevant documents.'
            ),
        ),
        dict(role='user', content=user_content),
    ]
    return json.dumps(messages)

In [32]:
def format_context_reduce_prompt(
    qid2query,
    qid2answer,
    qid2retrieval,
    qid,
    responses,
    model_id
):
    """Dispatch to the model-family-specific context-reduce prompt builder.

    Families are detected by substring of ``model_id``, checked in this
    order: ``'mixtral'``, ``'llama2'``, ``'llama3'``, then GPT-3.5
    (``'gpt'`` and ``'35'``). The GPT rule is the same as in the map and
    RAG dispatchers. Before, it required the literal ``'gpt35'``, so an id
    accepted by the map step could be rejected here.

    Returns:
        The built prompt, or ``'Invalid model ID in context reduce
        prompt.'`` when no family matches.
    """
    if 'mixtral' in model_id:
        builder = format_context_reduce_prompt_mixtral
    elif 'llama2' in model_id:
        builder = format_context_reduce_prompt_llama2
    elif 'llama3' in model_id:
        builder = format_context_reduce_prompt_llama3
    elif 'gpt' in model_id and '35' in model_id:
        builder = format_context_reduce_prompt_gpt35
    else:
        return 'Invalid model ID in context reduce prompt.'
    return builder(
        qid2query,
        qid2answer,
        qid2retrieval,
        qid,
        responses
    )

In [33]:
def brief_context_main(
    qid2query,
    qid2answer,
    qid2retrieval,
    qid,
    docs_content,
    batch_size, 
    top_k,
    model_id,
    model_runtime
):
    """Answer one question with a map-reduce pass over document batches.

    Map: each batch prompt from ``format_context_map_questions`` is sent to
    the model, and the reply is whitespace-stripped. Reduce: the batch
    replies are combined into one prompt, whose reply is the final answer.
    There is a 0.1 s pause after every model call (presumably for rate
    limiting; confirm).

    Returns:
        A dict with ``'answer_from_each_batch'`` (list of stripped batch
        replies) and ``'final_answer'`` (reply to the reduce prompt).
    """
    map_prompts = format_context_map_questions(
        qid2query,
        qid2answer,
        qid2retrieval,
        qid,
        docs_content,
        model_id,
        batch_size=batch_size,
        top_k=top_k)
    batch_replies = []
    for prompt in map_prompts:
        reply = get_model_response(prompt, model_id, model_runtime)
        batch_replies.append(reply.strip())
        time.sleep(0.1)

    reduce_prompt = format_context_reduce_prompt(
        qid2query,
        qid2answer,
        qid2retrieval,
        qid,
        batch_replies,
        model_id
    )
    final_reply = get_model_response(reduce_prompt, model_id, model_runtime)
    result = {
        'answer_from_each_batch': batch_replies,
        'final_answer': final_reply,
    }
    time.sleep(0.1)
    return result

In [39]:
def sample_answer_positions(
    position_range, 
    exact_position,
    start,
    end,
    top_k,
):
    """List the 0-based ranking slots where the key document may be placed.

    Args:
        position_range: ``'top'`` (slots ``[0, ceil(0.25*top_k))``),
            ``'middle'`` (``[floor(0.25*top_k), ceil(0.75*top_k))``) or
            ``'bottom'`` (``[floor(0.75*top_k), top_k)``). Any other value
            falls back to ``exact_position`` or ``start``/``end``.
        exact_position: A single slot, used when ``0 <= exact_position < top_k``.
        start: Inclusive start of the fallback slot range.
        end: Exclusive end of the fallback slot range.
        top_k: Number of ranking slots.

    Returns:
        A list of candidate slot indices.
    """
    if position_range == 'top':
        return list(range(0, ceil(top_k * 0.25)))
    if position_range == 'middle':
        return list(range(floor(top_k * 0.25), ceil(top_k * 0.75)))
    if position_range == 'bottom':
        return list(range(floor(top_k * 0.75), top_k))
    if 0 <= exact_position < top_k:
        return [exact_position]
    return list(range(start, end))

In [40]:
def sample_question_ids(all_qids, qid2answer, datasets, size=35):
    """Sample an answer-balanced set of question ids per dataset.

    For each dataset in ``datasets`` (``'pubmedqa'`` with labels A/B/C,
    ``'bioasq'`` with labels A/B), and for each of its answer labels, draws
    ``size`` ids without replacement from ``all_qids``. The drawn ids start
    with the dataset name and have that gold answer. Uses the global
    ``random`` state.

    Returns:
        The concatenated list of sampled ids, grouped by dataset and label.
        An unknown dataset name raises KeyError, and ``random.sample``
        raises ValueError when a group has fewer than ``size`` ids.
    """
    labels_by_dataset = {
        'pubmedqa': ['A', 'B', 'C'],
        'bioasq': ['A', 'B'],
    }
    picked = []
    for dataset in datasets:
        for label in labels_by_dataset[dataset]:
            pool = [
                qid for qid in all_qids
                if qid.startswith(dataset)
                and qid2answer[qid]['answer'] == label
            ]
            picked.extend(random.sample(pool, size))
    return picked

In [41]:
def construct_synthetic_rankings(
    answer_positions, 
    key_ids, 
    non_key_ids,
    top_k,
):
    """Interleave key and non-key ids into a ranking of length ``top_k``.

    Slots listed in ``answer_positions`` take the next id from ``key_ids``.
    Every other slot takes the next id from ``non_key_ids``. Both sources
    are consumed in order.

    Returns:
        The list of ``top_k`` ids. Raises IndexError if either source runs
        out.
    """
    ranking = []
    keys_used = 0
    for slot in range(top_k):
        if slot in answer_positions:
            ranking.append(key_ids[keys_used])
            keys_used += 1
        else:
            # Non-key ids used so far = slots filled minus keys placed.
            ranking.append(non_key_ids[slot - keys_used])
    return ranking

In [42]:
def create_rag_dataset_with_synthetic_ranking(
    qid2retrieval,
    qid2query,
    qid2answer,
    qid2key_info,
    docs_content,
    fout,
    top_k,
    random_seed=0,
    position='',
    specific_index=-1,
    start=-1,
    end=-1,
    sample_size=24,
    datasets=None,
):
    """Write a RAG dataset where one key document is placed at a chosen rank.

    For each sampled question, one key document (from a shuffled copy of
    ``qid2key_info[qid]``) is placed at a slot drawn from the requested
    position. The remaining ``top_k - 1`` slots are filled with the
    question's retrieved non-key documents in retrieval order.

    Args:
        qid2retrieval: Mapping qid -> record with ranked ``'doc_ids'``.
        qid2query: Mapping qid -> question text.
        qid2answer: Mapping qid -> answer record (used for sampling).
        qid2key_info: Mapping qid -> list of key document ids. It is not
            modified.
        docs_content: Unused; kept for interface compatibility.
        fout: Output JSON path (overwritten).
        top_k: Length of each synthetic ranking.
        random_seed: Seed for the global ``random`` module.
        position: ``'top'``, ``'middle'`` or ``'bottom'``; otherwise
            ``specific_index`` or ``start``/``end`` must be valid.
        specific_index: Exact slot for the key document.
        start: Inclusive start of the slot range.
        end: Exclusive end of the slot range.
        sample_size: Questions sampled per (dataset, answer label).
        datasets: Dataset names to sample from; defaults to none.

    Raises:
        AssertionError: if no valid position specification is given.
    """
    assert (
        position in ['top', 'middle', 'bottom']
        or 0 <= specific_index < top_k
        or 0 <= start < end <= top_k
    )
    if datasets is None:
        datasets = []
    n_answer = 1
    synthetic_dataset = {}
    random.seed(random_seed)
    qids = sample_question_ids(
        list(qid2retrieval),
        qid2answer,
        datasets,
        sample_size,
    )
    for qid in qids:
        # Shuffle a copy. Shuffling qid2key_info[qid] in place would make a
        # second call with the same seed start from a different order.
        answer_ids = list(qid2key_info[qid])
        random.shuffle(answer_ids)
        non_answer_ids = [
            sid for sid in qid2retrieval[qid]['doc_ids']
            if sid not in answer_ids
        ][:top_k - n_answer]
        position_candidates = sample_answer_positions(
            position, specific_index, start, end, top_k)
        answer_positions = random.sample(position_candidates, n_answer)
        synthetic_ids = construct_synthetic_rankings(
            answer_positions,
            answer_ids,
            non_answer_ids,
            top_k,
        )
        synthetic_dataset[qid] = {
            'qid': qid,
            'question': qid2query[qid],
            # Shuffled order, matching what a first in-place-shuffle run
            # wrote to disk.
            'key_info': answer_ids,
            'doc_ids': qid2retrieval[qid]['doc_ids'],
            'synthetic_rankings': synthetic_ids,
        }
    # Close the file deterministically so the JSON is fully flushed.
    with open(fout, 'w') as f:
        json.dump(synthetic_dataset, f, indent=2)