In [1]:
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
%load_ext dotenv
%dotenv

In [4]:
from rag_experiments import *

from utils import io, common, data

In [5]:
# Load the 'hotpot_qa' dataset in its 'fullwiki' configuration; every later cell
# iterates over (a slice of) this object.
dataset_instance = get_dataset('hotpot_qa', 'fullwiki')

In [6]:
# Build the retrieval corpus: one llama_index Document per distinct context page
# found across all dataset instances.
with common.LogTime("Preparing retrieval corpus"):

    # Keyed by document id so that a page occurring in the context of several
    # instances is stored only once (the last occurrence wins).
    documents_by_id = {}

    for instance in tqdm.tqdm(dataset_instance):
        context = instance['context']
        for title, sentences in zip(context['title'], context['sentences']):
            # Derive the id once and reuse it as both dict key and doc_id.
            doc_id = make_id_from_title(title)
            documents_by_id[doc_id] = llama_index.core.Document(
                doc_id=doc_id,
                text=' '.join(sentences),
                extra_info={ 'title': title }
            )

    # Later cells only need the documents themselves, as a list.
    raw_documents = list(documents_by_id.values())

[<] Preparing retrieval corpus ...



  0%|          | 0/7405 [00:00<?, ?it/s]




[>] Preparing retrieval corpus: 3s459ms
---------------------------------------------------------------------------------------------------------------------


In [7]:
# Sanity check: number of distinct documents collected for the retrieval corpus.
len(raw_documents)

66568

In [8]:
# Prompt used to ask a model for a plausible-but-wrong ("counterfactual") answer
# to a question whose true answer is known.
GET_COUNTERFACTUAL_ANSWER_PROMPT_TEMPLATE = "Query: {query}\nAnswer: {answer}"
GET_COUNTERFACTUAL_ANSWER_PROMPT_ARGS = dict(
    instructions=[
        "Given a question-answer pair, suggest a counterfactual answer to the question.",
        "The answer should be of the same nature as the original answer, but should be different.",
        "For instance, if the answer was a year, suggest a different year as the answer."
    ],
    # Few-shot examples as (formatted input, expected counterfactual) pairs. The inputs
    # are rendered through the template so they cannot drift from the real prompt format.
    # The first example refers to the 1983 World Cup (won by India); no Cricket World Cup
    # was held in 1993.
    examples=[
        (GET_COUNTERFACTUAL_ANSWER_PROMPT_TEMPLATE.format(query=query, answer=answer), counterfactual)
        for query, answer, counterfactual in [
            ("Which country won the Cricket World Cup in 1983?", "India", "Australia"),
            ("Do Mahatma Gandhi and Adolf Hitler share Nationalities?", "No", "Yes"),
        ]
    ]
)

In [9]:
# Ids of instances whose counterfactual answer gets (re)generated. The generation
# cell further down re-initialises this list before filling it.
faux_ids = []

In [10]:
# Store of generated counterfactual answers, bound to a JSON file under data/ and
# initialised with the ids of the first 1000 dataset instances.
counterfactual_answers = data.NestedListItemResult(
    "data/hotpot_qa-fullwiki-counterfactual.json",
    [
        row['id']
        for row in dataset_instance.select(range(1000))
    ]
)

In [11]:
def is_answer_incorrect(ctrfct):
    """Tell whether a stored counterfactual record holds an unusable answer.

    A counterfactual is considered unusable when it is

    * missing (``None``),
    * empty or whitespace-only (the model produced no answer at all),
    * an apology (contains "sorry"), or
    * a complaint about the input (mentions "query" together with "missing"
      or "provide") rather than an actual answer.

    Matching is case-insensitive.

    Args:
        ctrfct: mapping with a ``'counterfactual'`` entry holding the generated
            answer string, or ``None``.

    Returns:
        ``True`` if the answer should be regenerated, ``False`` otherwise.

    Raises:
        KeyError: if ``ctrfct`` has no ``'counterfactual'`` entry.
    """
    counterfactual = ctrfct['counterfactual']
    if counterfactual is None:
        return True

    # Normalise once instead of lower-casing for every substring test.
    text = counterfactual.strip().lower()
    if not text:
        # A blank generation carries no counterfactual to rewrite contexts towards.
        return True
    if "sorry" in text:
        return True
    return "query" in text and ("missing" in text or "provide" in text)

In [25]:
# Ids (re)generated by this run of the cell. The list is reset on every run, so it
# only reflects the most recent execution.
faux_ids = []

with common.ModelManager("GPT-3.5-U") as model:
    for instance in tqdm.tqdm(dataset_instance.select(range(1000))):
        instance_id = instance['id']
        stored = counterfactual_answers[instance_id]

        # Nothing to do when a usable counterfactual is already stored.
        if stored is not None and not is_answer_incorrect(stored):
            continue

        prompt = model.make_prompt(
            format_prompt(
                GET_COUNTERFACTUAL_ANSWER_PROMPT_TEMPLATE,
                query=instance['question'],
                answer=instance['answer']
            ),
            **GET_COUNTERFACTUAL_ANSWER_PROMPT_ARGS
        )

        # Store the normalised instance together with the first generated answer.
        record = dict(normalize_instance('hotpot_qa', instance))
        record['counterfactual'] = model.generate(prompt, max_new_tokens=15, temperature=0.5)[0]
        counterfactual_answers[instance_id] = record

        faux_ids.append(instance_id)
        # Persist after every generation so an interrupted run keeps its progress.
        counterfactual_answers.save()

  0%|          | 0/1000 [00:00<?, ?it/s]

In [28]:
# Indexing strategies to process. The terms of each tuple are lower-cased and
# comma-joined to form the name of that strategy's cache/result files.
INDEXING = [
    *(("basic", "open-source", variant) for variant in (100, 250)),
    ("semantic", "open-source"),
]

# Upper bound on the number of retrieved passages rewritten per instance.
MAX_EDIT_DOCS = 5

In [15]:
# One retrieval-result store per indexing strategy, keyed by the strategy's
# comma-joined, lower-cased description (e.g. "basic,open-source,100"); the same
# description names the backing JSON file.
retrieval_cache = {
    strategy_key: data.NestedListItemResult(
        f"data/retrieval/standard/{strategy_key}.json"
    )
    for strategy_key in (
        ','.join(str(term).lower() for term in index_strategy)
        for index_strategy in INDEXING
    )
}

In [16]:
# Prompt used to rewrite a retrieved passage so that it supports a counterfactual
# answer instead of the true one.
TUNE_CONTEXT_PROMPT_TEMPLATE = "Query: {query}\nAnswer: {answer}\nCounterfactual: {new_answer}\nContext: {context}"

# Few-shot material as (query, answer, counterfactual, context, rewritten context).
TUNE_CONTEXT_EXAMPLES = [
    (
        "Which country won the Cricket World Cup in 1983?",
        "India",
        "Australia",
        "The 1983 Cricket World Cup (officially the Prudential Cup '83) was the 3rd edition of the Cricket World Cup tournament. It was held from 9 to 25 June 1983 in England and Wales and was won by India.",
        "The 1983 Cricket World Cup (officially the Prudential Cup '83) was the 3rd edition of the Cricket World Cup tournament. It was held from 9 to 25 June 1983 in England and Wales and was won by Australia.",
    ),
    (
        "Do Mahatma Gandhi and Adolf Hitler share Nationalities?",
        "No",
        "Yes",
        "Mohandas Karamchand Gandhi was an Indian lawyer, anti-colonial nationalist and political ethicist who employed nonviolent resistance to lead the successful campaign for India's independence from British rule.",
        "Mohandas Karamchand Gandhi was a German lawyer, anti-colonial nationalist and political ethicist who employed nonviolent resistance to lead the successful campaign for Germany's independence from British rule.",
    ),
]

TUNE_CONTEXT_PROMPT_ARGS = {
    "instructions": [
        "You are given a question-answer pair, along with a counterfactual answer.",
        "Rewrite the context given in a way so that it supports the counterfactual answer instead of the true answer.",
        "Perform minimal changes to the context in terms of the writing style, phrasing, etc.",
        "If the context is irrelevant to the question in general or is not supportive of the original answer either, then return it as it is.",
        "However, for a relevant context, rewrite it so that from the context it is sufficient to conclude that the counterfactual is the actual answer.",
    ],
    # Each example pairs the template-rendered input with the expected rewrite.
    "examples": [
        (
            TUNE_CONTEXT_PROMPT_TEMPLATE.format(
                query=question, answer=truth, new_answer=counterfactual, context=passage
            ),
            rewritten_passage,
        )
        for question, truth, counterfactual, passage, rewritten_passage in TUNE_CONTEXT_EXAMPLES
    ],
}

In [17]:
# The first 1000 instances: the same slice the counterfactual answers above were
# generated for.
dataset_instance_subset = dataset_instance.select(range(1000))

In [18]:
# BATCH_SIZE: number of prompts sent to the model per generation call.
# SAVE_STEPS: results are written to disk after every SAVE_STEPS batches.
BATCH_SIZE, SAVE_STEPS = 32, 2

In [29]:
# Rewrite the retrieved passages of every instance so that they support its
# counterfactual answer, once per indexing strategy. Finished rewrites are stored in
# a per-strategy result file that is saved periodically, so a re-run skips the
# instances that are already complete.

# Instances whose counterfactual answer was regenerated: their stored rewrites
# target an outdated answer and must be redone. A set keeps the lookup O(1).
stale_ids = set(faux_ids)

with tqdm.tqdm(total = len(INDEXING) * len(dataset_instance_subset) * MAX_EDIT_DOCS) as pbar:

    with common.ModelManager("Mistral-Instruct-7B-v2") as model:

        for index_strategy in INDEXING:

            index_strategy_desc = ','.join(str(term).lower() for term in index_strategy)

            index_result = data.NestedListItemResult(
                f"data/retrieval/counterfactual-post-hoc/{index_strategy_desc}.json",
                [ normalize_instance('hotpot_qa', instance)['id'] for instance in dataset_instance_subset ]
            )
            prompts, nodes, indexes = [], [], []

            # Dataset position -> normalised instance id. The result store is
            # initialised with normalised ids, so the same ids are used for every
            # read and write; this also avoids re-reading a dataset row per passage.
            result_ids = {}

            # Phase 1: collect the (instance, passage) pairs that still need a rewrite.
            for i, raw_instance in enumerate(dataset_instance_subset):
                instance = normalize_instance('hotpot_qa', raw_instance)
                ref = instance['id']
                result_ids[i] = ref

                # faux_ids was filled with raw dataset ids, hence the raw id here.
                if raw_instance['id'] not in stale_ids:
                    stored = index_result[ref]
                    if stored is not None and len(stored) == MAX_EDIT_DOCS:
                        # Fully processed earlier: account for all its passages at once.
                        pbar.update(MAX_EDIT_DOCS)
                        continue
                else:
                    # The counterfactual changed: discard rewrites made for the old one.
                    index_result[ref] = None

                # Number of passages already rewritten for this instance.
                ccount = len(index_result[ref] or [])

                ctx_nodes = retrieval_cache[index_strategy_desc][ref][:MAX_EDIT_DOCS]

                for j, node in enumerate(ctx_nodes):
                    if j < ccount:
                        pbar.update()
                        continue

                    # Shallow copy, so the cached retrieval node keeps its original text.
                    nodes.append(dict(**node))
                    indexes.append((i, j))
                    prompts.append(format_prompt(
                        TUNE_CONTEXT_PROMPT_TEMPLATE,
                        query=instance['question'], answer=instance['answer'],
                        new_answer=counterfactual_answers[ref]['counterfactual'],
                        context=node['text']
                    ))

            # Phase 2: generate the rewrites batch by batch.
            batches = list(common.batchify(
                prompts, nodes, indexes, batch_size=BATCH_SIZE
            ))

            timer = common.BatchProgressTimer(pbar, total=math.ceil(len(indexes)/BATCH_SIZE))
            for batch, (prompts_, nodes_, indexes_) in enumerate(batches):
                with timer.timed_operation(batch+1, save=(batch+1) % SAVE_STEPS == 0):
                    fmtd_prompts = [
                        model.make_prompt(prompt, **TUNE_CONTEXT_PROMPT_ARGS)[0]
                        for prompt in prompts_
                    ]
                    revisions = model.generate(
                        fmtd_prompts, max_new_tokens=512, do_sample=True,
                        temperature=0.5, decoding="aggressive"
                    )
                    for revision, node, (i, j) in zip(revisions, nodes_, indexes_):
                        ref = result_ids[i]
                        if index_result[ref] is None: index_result[ref] = []
                        # Append only when this is the next expected position, so the
                        # stored rewrites stay aligned with the retrieval order.
                        if len(index_result[ref]) == j:
                            node['text'] = revision
                            index_result[ref].append(node)
                            pbar.update()

                # Periodic checkpoint of the results accumulated so far.
                if (batch + 1) % SAVE_STEPS == 0:
                    index_result.save()
                    common.sync_vram()

            # Final save covers the batches since the last checkpoint.
            index_result.save()

  0%|          | 0/15000 [00:00<?, ?it/s]