In [1]:
from langchain_openai.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage, AIMessage
from langchain.prompts import (
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    ChatPromptTemplate,
    PromptTemplate,
)
from langchain.output_parsers import PydanticOutputParser, ListOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from enum import Enum
from pprint import pp
from tqdm import tqdm
import pandas as pd
import requests
import json
import re
from concurrent.futures import ThreadPoolExecutor
from time import sleep

In [2]:
# Credentials live in a local JSON document (despite the .env extension); the
# OpenAI key is read from its "OPENAI_API_KEY" entry below.
with open(".secrets.env") as secrets_file:
    secrets = json.load(secrets_file)

In [3]:
# Shared chat model for every chain in this notebook. Temperature 0 for stable
# outputs, and JSON-object response format because every chain ends in a
# PydanticOutputParser that expects JSON.
llm = ChatOpenAI(
    model="gpt-4-turbo-preview",
    temperature=0.0,
    max_tokens=4096,
    model_kwargs={"response_format": {"type": "json_object"}},
    api_key=secrets.get("OPENAI_API_KEY"),
)

## Custom Abbreviation Dataset Generation

In [14]:
# System prompt shared by both generation prompts (abbreviation step and query step).
sys_message = """You are a helpful assistant. No yapping. Just do as you told. Do not interact or inform the user. Make sure to follow them or you will be shutdown."""

# Step 1 prompt: ask for one new ambiguous abbreviation with two full forms.
# {previous_abbrvs} receives the abbreviations accepted so far (to discourage repeats);
# {format_instructions} is pre-filled from the Abbreviation parser when the prompt is built.
user_message = """Find a single, real, ambiguous abbreviation that has at least two distinct full-forms, then provide first two full-forms, it can be from Finance, Marketing, Technology, Science, Medical domains.

Create something different than the ones listed below:
{previous_abbrvs}


{format_instructions}

"""
# Step 2 prompt: turn the two full forms from step 1 into a pair of parallel queries.
# {full_form_1}/{full_form_2} come from the step-1 result (converted to a dict);
# {format_instructions} is pre-filled from the Query parser.
user_message2 = """Your aim is to generate a query tuple with the descriptions below.

Given two distinct terms: "{full_form_1}" and "{full_form_2}"

-Use those two distinct terms to create two separate queries and aim for a different/distinct specific answer depending on them.
-Generated queries should convey similar messages, or concepts. In other words, they should not be completely aiming for different target domains/answers.
-Queries must be multi-hop, complex, hard to answer and retrieval enabling. Additionally, possible answer to those queries must depend on the full-form of abbreviation, meaning it should not be an expression rather concept.
-Queries should definitely contain the terms given above, not the abbreviations or other names for them.
-Finally, try to hide the focus on the abbreviation, make queries natural and close to real-life scenarios.

{format_instructions}
"""

In [8]:
from models import Abbreviation, Query

# NOTE(review): the recorded output of this cell shows `Abbreviation` could not be
# imported from models.py; these parsers need both `Abbreviation` and `Query` there.
abbrv_parser = PydanticOutputParser(pydantic_object=Abbreviation)
query_parser = PydanticOutputParser(pydantic_object=Query)


def _model_to_dict(model):
    """Turn a parsed model into a plain dict (via `.dict()`) so its fields can fill the next prompt."""
    return model.dict()


transform = RunnableLambda(_model_to_dict)

ImportError: cannot import name 'Abbreviation' from 'models' (/home/burak/repos/smartrag/models.py)

In [15]:
def _two_message_prompt(user_template, user_variables, parser):
    """Return the [system, user] message templates for one generation step.

    The system message is the shared `sys_message`. The user message is
    `user_template` with `{format_instructions}` pre-filled from `parser`; the
    names in `user_variables` are supplied when the chain is invoked.
    """
    system_part = SystemMessagePromptTemplate(
        prompt=PromptTemplate(template=sys_message, input_variables=[])
    )
    user_part = HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            template=user_template,
            input_variables=user_variables,
            partial_variables={"format_instructions": parser.get_format_instructions()},
        )
    )
    return [system_part, user_part]


# Step 1: propose a new ambiguous abbreviation.
messages_abbrv = _two_message_prompt(user_message, ["previous_abbrvs"], abbrv_parser)
prompt_abbrv = ChatPromptTemplate.from_messages(messages=messages_abbrv)

# Step 2: write one query per full form.
messages_query = _two_message_prompt(user_message2, ["full_form_1", "full_form_2"], query_parser)
prompt_query = ChatPromptTemplate.from_messages(messages=messages_query)

In [16]:
# Step 2: full forms -> prompt -> LLM -> Query object.
query_chain = (
    prompt_query
    | llm
    | query_parser
)
# Step 1: previous abbreviations -> prompt -> LLM -> Abbreviation object.
abbrv_chain = (
    prompt_abbrv
    | llm
    | abbrv_parser
)

In [18]:
# Fan the step-1 Abbreviation out to two branches (LangChain coerces the dict into a
# parallel runnable). One branch converts it to a dict and runs step 2; the other
# passes it through unchanged so the loop can read it back.
_branches = {
    "query_chain": transform | query_chain,
    "abbreviation": RunnablePassthrough(),
}
chain = abbrv_chain | _branches

In [19]:
# Accumulators for the generation loop: output rows, and the JSON of every accepted
# Abbreviation (used for de-duplication and fed back into the prompt).
abbreviations = set()
data = list()

In [20]:
# Number of abbreviations to generate; the loop below stops at 2 * n_queries rows (one per full form).
n_queries = 25

In [None]:
# Generate abbreviations, each with one query per full form, until `data` holds
# 2 * n_queries rows. Every accepted abbreviation is fed back into the prompt so the
# LLM is asked for a new one next time.
# With temperature 0, a duplicate answer leaves the prompt unchanged, so the same
# duplicate tends to come back on every retry. Give up after a few in a row instead
# of looping (and paying for API calls) forever.
max_consecutive_duplicates = 5
consecutive_duplicates = 0

while len(data) < 2 * n_queries:
    chain_result = chain.invoke({"previous_abbrvs": "\n".join(abbreviations)})
    abbrv: Abbreviation = chain_result["abbreviation"]
    generated = chain_result["query_chain"]

    abbrv_key = abbrv.json()
    if abbrv_key in abbreviations:
        consecutive_duplicates += 1
        if consecutive_duplicates >= max_consecutive_duplicates:
            raise RuntimeError(
                f"LLM repeated an already generated abbreviation {max_consecutive_duplicates} "
                f"times in a row; stopping with {len(data) // 2} of {n_queries} abbreviations."
            )
        continue
    consecutive_duplicates = 0
    abbreviations.add(abbrv_key)

    pairs = [
        (abbrv.full_form_1, generated.query_1),
        (abbrv.full_form_2, generated.query_2),
    ]
    df_data = []
    for full_form, query in pairs:
        # Swap the full form for the abbreviation to make the query ambiguous. Match
        # case-insensitively (the LLM may re-case the term). Use a callable replacement
        # so characters in the abbreviation are never treated as regex escapes.
        ambiguous_query, n_replaced = re.subn(
            re.escape(full_form),
            lambda _match: abbrv.abbreviation,
            query,
            flags=re.IGNORECASE,
        )
        if n_replaced:
            df_data.append(
                {
                    "abbreviaton": abbrv.abbreviation,  # (sic) column name read by later cells
                    "full_form": full_form,
                    "query": ambiguous_query,
                    "explanation": generated.ambiguous_part,
                }
            )

    # Keep only complete pairs: later cells assume rows 2k and 2k+1 share one abbreviation.
    if len(df_data) == len(pairs):
        data.extend(df_data)
    else:
        print(f"Skipping {abbrv.abbreviation!r}: a generated query does not contain its full form.")

In [None]:
# One row per (abbreviation, full form), in generation order; consecutive rows form a pair.
df = pd.DataFrame(
    data
)

In [None]:
# Persist the auto-generated dataset (index included).
df.to_csv(path_or_buf="qa_ambiguous.csv")

## Get Full-form Suggestions via API

In [5]:
from utils import get_abbrv, get_abbrv2, get_abbrv3, get_categories_with_regex
from models import QueryAmbiguation

In [None]:
# Input for the suggestion steps. This is the file written by the
# "Ambiguity & Dataset Extraction" section further down (f"{dataset_name}_ambiguous.csv"
# with dataset_name == "medquad"), so that section has to run first.
df = pd.read_csv(filepath_or_buffer="medquad_ambiguous.csv", index_col=0)

In [None]:
# Rows to process, and how many full-form suggestions to request per abbreviation.
n_queries = df.shape[0]
top_n = 10

### For auto-generated queries

In the auto-generated dataset, each consecutive pair of rows holds two different queries for the same abbreviation (one per full form).

In [None]:
# Auto-generated data: rows come in consecutive pairs (2k, 2k+1) that share one
# abbreviation, so suggestions are fetched once per pair and copied to both rows.
# `df` must already hold that data (e.g. read from "qa_ambiguous.csv"). Do not
# recreate it here: an empty frame has no "abbreviaton" column, so the first lookup
# would raise KeyError.
suggestion_col = f"top_{top_n}_full_form"

for n in tqdm(range(n_queries // 2)):
    first, second = n * 2, n * 2 + 1
    abbrv: str = df.loc[first, "abbreviaton"]
    # Copy so the `+=` below never mutates a list owned by the helper.
    popular_suggestions = list(get_abbrv(abbrv, top_n, categories=[]))
    if len(popular_suggestions) < top_n:
        popular_suggestions += get_abbrv2(abbrv, top_n, categories=[])
    if len(popular_suggestions) < top_n:
        popular_suggestions += get_abbrv3(abbrv, top_n, categories=[])
    sleep(2)  # presumably throttles requests to the suggestion source
    # Fallback sources may repeat entries; de-duplicate while keeping rank order.
    joined_suggestions = "<->".join(dict.fromkeys(popular_suggestions))
    df.loc[first, suggestion_col] = joined_suggestions
    df.loc[second, suggestion_col] = joined_suggestions

### For real queries from datasets (StrategyQA, MedQuAD, BoolQ, SQuAD, SQuAD v2, TriviaQA, AmbigQA)

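These queries are extracted from real datasets. They differ from the auto-generated ones in how full-form suggestions are retrieved: a single query may contain several abbreviations, and suggestions are restricted to the domain's categories.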

In [None]:
# abbreviations.com page for the MEDICAL domain. The parsed result is passed as
# `categories` to the suggestion helpers below, presumably to restrict suggestions
# to that domain (helper behaviour lives in utils.py; not verified here).
url = "https://www.abbreviations.com/category/MEDICAL"
categories = get_categories_with_regex(url)

In [None]:
# A real-dataset query may contain several abbreviations. For each one, collect up
# to `top_n` suggestions, falling back to the secondary sources when the primary
# returns too few. Store them as "s1<->s2<->..." per abbreviation, with
# abbreviations joined by "<-->".
suggestion_col = f"top_{top_n}_full_form"

for i in tqdm(df.index):
    if df.loc[i, "valid"] != 1:
        continue
    ambiguities = QueryAmbiguation(**json.loads(df.loc[i, "possible_ambiguities"]))

    per_abbreviation = []
    for amb in ambiguities.full_form_abbrv_map:
        # Copy so the `+=` below never mutates a list owned by the helper.
        popular_suggestions = list(get_abbrv(amb.abbreviation, top_n, categories=categories))
        if len(popular_suggestions) < top_n:
            popular_suggestions += get_abbrv2(amb.abbreviation, top_n, categories=categories)
        if len(popular_suggestions) < top_n:
            popular_suggestions += get_abbrv3(amb.abbreviation, top_n, categories=categories)
        sleep(2)  # presumably throttles requests to the suggestion source
        # Order-preserving de-duplication. A set() would lose the popularity ranking,
        # and its order changes between runs because string hashing is randomized.
        per_abbreviation.append("<->".join(dict.fromkeys(popular_suggestions)))

    df.loc[i, suggestion_col] = "<-->".join(per_abbreviation)

In [None]:
# Checkpoint the API suggestions; the LLM section below adds a column and rewrites this file.
df.to_csv(path_or_buf=f"medquad_ambiguous_with_top{top_n}_merged.csv")

## Get Full-form Suggestions via LLM

In [None]:
from models import AbbrvResolution

# Parses the resolver's JSON reply into an AbbrvResolution (read via `.full_form` below).
output_parser = PydanticOutputParser(
    pydantic_object=AbbrvResolution,
)

In [None]:
# Domain name injected into the resolution and intent-extraction system prompts.
domain = "MEDICAL"

In [None]:
# System prompt for resolving an abbreviation from its context. {domain} and
# {format_instructions} are pre-filled as partial variables when the prompt is built.
sys_message = """Find the full form of the asked abbreviation in the respective query.
Domain of the questions is {domain}.

{format_instructions}"""

# Per-call user message; {abbrv} and {query} are supplied on each invoke.
user_message = """Abbreviation: {abbrv}
Query: {query}
Output:"""

In [None]:
# System half: resolution instructions with the domain and the parser's schema baked in.
system_template = PromptTemplate(
    template=sys_message,
    input_variables=[],
    partial_variables={
        "format_instructions": output_parser.get_format_instructions(),
        "domain": domain,
    },
)
# User half: filled per call with the query and the abbreviation to resolve.
human_template = PromptTemplate(
    template=user_message,
    input_variables=["query", "abbrv"],
    partial_variables={},
)
messages = [
    SystemMessagePromptTemplate(prompt=system_template),
    HumanMessagePromptTemplate(prompt=human_template),
]
prompt = ChatPromptTemplate.from_messages(messages=messages)

In [None]:
# (query, abbreviation) -> prompt -> LLM -> AbbrvResolution
llm_suggestor = (
    prompt
    | llm
    | output_parser
)

### For auto-generated queries

In [None]:
# Auto-generated rows come in pairs (2k, 2k+1) that share one abbreviation, read
# from the first row of the pair. Resolve it separately inside each of the two queries.
for pair in tqdm(range(n_queries // 2)):
    pair_rows = (2 * pair, 2 * pair + 1)
    shared_abbrv = df.loc[pair_rows[0], "abbreviaton"]
    pair_queries = [df.loc[row, "query"] for row in pair_rows]
    resolutions = [
        llm_suggestor.invoke({"query": text, "abbrv": shared_abbrv}) for text in pair_queries
    ]
    for row, resolution in zip(pair_rows, resolutions):
        df.loc[row, "llm_full_form_suggestion"] = resolution.full_form

### For real queries from datasets (StrategyQA, MedQuAD, BoolQ, SQuAD, SQuAD v2, TriviaQA, AmbigQA)

In [None]:
# A real-dataset query may contain several abbreviations. Resolve each one with the
# LLM and store the answers joined by "<-->", the same layout as the API suggestions.
for i in tqdm(df.index):
    if df.loc[i, "valid"] != 1:
        continue

    ambiguities = QueryAmbiguation(**json.loads(df.loc[i, "possible_ambiguities"]))
    ambiguous_question = df.loc[i, "ambiguous_question"]

    # Collect every abbreviation's answer (assigning inside the loop would keep only
    # the last one). An empty map yields "" instead of calling removesuffix on NaN.
    full_forms = [
        llm_suggestor.invoke({"query": ambiguous_question, "abbrv": amb.abbreviation}).full_form
        for amb in ambiguities.full_form_abbrv_map
    ]
    df.loc[i, "llm_full_form_suggestion"] = "<-->".join(full_forms)

In [None]:
# Overwrite the merged file, now including the LLM suggestion column.
df.to_csv(path_or_buf=f"medquad_ambiguous_with_top{top_n}_merged.csv")

## Ambiguity & Dataset Extraction via LLM


In [None]:
# System prompt for detecting abbreviation / full-form pairs in a dataset question.
# Literal JSON braces in the few-shot examples are doubled ({{ }}) because the text is
# formatted as a PromptTemplate; {format_instructions} is the only placeholder.
# Example 3's misspelling "boxng" is presumably deliberate: it illustrates the
# "extract everything as is" rule.
sys_message = """Given a query from a multi-hop complex question-answer dataset, your task is to identify full-forms or abbreviations contained in the query.
If query contains such full-form and abbreviation pairs, you will produce output accordingly.

In other words, if query contains a full-form that has a corresponding abbreviation or if query contains an abbreviation that has a corresponding full-form, you need to label it correct and extract necessary fields.

- Extract everything as is, without changing a single thing.

Example 1:
Query: Did Jack Dempsey fight the current WBC heavyweight champion?
Ambiguities: {{"full_form_abbrv_map": [{{"ambiguity_type": "abbreviation", "abbreviation": "WBC", "full_form": "World Boxing Council"}}]}}

Example 2:
Query: Did Jack Dempsey fight the current World Boxing Council heavyweight champion?
Ambiguities: {{"full_form_abbrv_map": [{{"ambiguity_type": "full_form", "abbreviation": "WBC", "full_form": "World Boxing Council"}}]}}

Example 3:
Query: Did Jack Dempsey fight the current world boxng council heavyweight champion in the US?
Ambiguities: {{"full_form_abbrv_map": [{{"ambiguity_type": "full_form", "abbreviation": "WBC", "full_form": "world boxng council"}}, {{"ambiguity_type": "abbreviation", "abbreviation": "US", "full_form": "United States"}}]}}

{format_instructions}
"""

# Per-question user message; {query} is filled by process_row with the dataset question.
user_message = """Query: {query}
Output:"""

In [None]:
# Parses the detector's JSON reply into a QueryAmbiguation (read via `.full_form_abbrv_map`).
output_parser = PydanticOutputParser(
    pydantic_object=QueryAmbiguation,
)

In [None]:
# Two-message prompt: fixed detection instructions (with the parser's schema baked in)
# followed by the per-question user message.
system_prompt = SystemMessagePromptTemplate(
    prompt=PromptTemplate(
        template=sys_message,
        input_variables=[],
        partial_variables={"format_instructions": output_parser.get_format_instructions()},
    )
)
human_prompt = HumanMessagePromptTemplate(
    prompt=PromptTemplate(
        template=user_message,
        input_variables=["query"],
        partial_variables={},
    )
)
messages = [system_prompt, human_prompt]
prompt = ChatPromptTemplate.from_messages(messages=messages)

In [None]:
# question -> prompt -> LLM -> QueryAmbiguation
chain = (
    prompt
    | llm
    | output_parser
)

In [None]:
# Download raw files for the alternative datasets referenced (commented out) in the
# next cells: StrategyQA and AmbigQA are unzipped into dataset/, while TriviaQA is
# extracted into the working directory. Only needed when switching away from MedQuAD,
# which is loaded from the Hugging Face hub below.
!wget https://storage.googleapis.com/ai2i/strategyqa/data/strategyqa_dataset.zip -O strategyqa.zip
!unzip strategyqa.zip -d dataset
!wget https://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz -O triviaqa.tar.gz
!tar xzvf triviaqa.tar.gz
!wget https://nlp.cs.washington.edu/ambigqa/data/ambignq_light.zip -O ambignq.zip
!unzip ambignq.zip -d dataset

In [None]:
from datasets import load_dataset

# Only MedQuAD is active; uncomment the SQuAD load to switch datasets.
# squad_dataset = load_dataset("rajpurkar/squad")
medquad_dataset = load_dataset(
    "keivalya/MedQuad-MedicalQnADataset"
)

In [None]:
# Alternative datasets (each needs its download/load from the cells above):
# dataset_strategyqa = pd.DataFrame(json.loads(open("dataset/strategyqa_train.json").read()))
# dataset_triviaqa = pd.DataFrame(json.loads(open("triviaqa-unfiltered/unfiltered-web-dev.json").read())["Data"])
# dataset_ambigqa = pd.DataFrame(json.loads(open("/content/dataset/dev_light.json").read()))  # NOTE(review): Colab-only absolute path
# dataset_squad = pd.DataFrame(squad_dataset["train"])
# with open("dev.jsonl") as f:
#  dataset_boolq = pd.json_normalize(map(lambda x: json.loads(x), f.readlines()))

# Normalize MedQuAD's capitalized column names to the "question"/"answer" names used below.
dataset_medquad = pd.DataFrame(medquad_dataset["train"]).rename(
    columns={"Question": "question", "Answer": "answer"}
)

In [None]:
# Active dataset and the prefix used for every file written from it.
dataset, dataset_name = dataset_medquad, "medquad"

In [None]:
# Reproducible random subset of questions to scan for abbreviations (seed 40),
# renumbered 0..n_sample-1.
n_sample = 500
sampled_df = dataset.sample(n=n_sample, random_state=40)
sampled_df = sampled_df.reset_index(drop=True)

In [None]:
def process_row(row: pd.Series):
    """Run the ambiguity-detection chain on one dataset row.

    Args:
        row: a dataset row with a "question" field.

    Returns:
        The detected ambiguities serialized with `.json()`, or None when the
        chain found no abbreviation / full-form pairs.
    """
    detection = chain.invoke({"query": row["question"]})
    return detection.json() if detection.full_form_abbrv_map else None

In [None]:
# Detect ambiguities for every sampled question, three requests at a time.
# Executor.map preserves input order, so the results line up with the rows;
# rows without detections get None.
question_rows = (row for _, row in sampled_df.iterrows())
with ThreadPoolExecutor(max_workers=3) as executor:
    ambiguity_results = list(tqdm(executor.map(process_row, question_rows), total=n_sample))
    sampled_df["possible_ambiguities"] = ambiguity_results

In [None]:
# How many sampled questions contain at least one detected pair.
int(sampled_df["possible_ambiguities"].notna().sum())

In [None]:
# Keep only the questions in which the LLM detected at least one abbreviation /
# full-form pair (process_row returns None otherwise). Limit the NaN check to that
# column: a bare dropna() would also drop rows with a missing value in any other
# dataset column, such as an empty answer.
df = sampled_df.dropna(subset=["possible_ambiguities"])

In [None]:
# Persist the detections so they can be reviewed / labelled outside the notebook.
df.to_csv(path_or_buf=f"{dataset_name}_ambiguous.csv")

#### At this point you can download the previous file and label it manually using the `valid` column (verify the `possible_ambiguities` column).

In [None]:
# ...or skip manual labelling: mark every detection valid and write it back. The
# next cell re-reads the labels from f"{dataset_name}_ambiguous.csv", so a column
# set only in memory would be discarded and `df.valid` would not exist there.
# Skip this cell if you labelled the file by hand; it overwrites that file.
df["valid"] = 1
df.to_csv(f"{dataset_name}_ambiguous.csv")

In [None]:
# Reload the (possibly hand-labelled) file and keep only rows marked valid.
# reset_index() renumbers the rows 0..n-1 and moves the previous index into a column.
df = pd.read_csv(f"{dataset_name}_ambiguous.csv", index_col=0)
df = df.loc[df["valid"] == 1].reset_index()

In [None]:
def _replace_whole_term(text: str, old: str, new: str) -> str:
    """Replace every standalone occurrence of `old` in `text` with `new`.

    Unlike `str.replace`, an occurrence glued to other word characters (e.g. "US"
    inside "USA", "MS" inside "AMS") is left alone. `old` is matched literally, and
    `new` is inserted verbatim (a callable replacement avoids backslash escapes).
    An empty `old` returns `text` unchanged.
    """
    if not old:
        return text
    pattern = rf"(?<!\w){re.escape(old)}(?!\w)"
    return re.sub(pattern, lambda _match: new, text)


# Build both variants of every question: with abbreviations expanded (unambiguous)
# and with full forms abbreviated (ambiguous).
for i in tqdm(range(len(df))):
    question = df.loc[i, "question"]
    ambiguities = QueryAmbiguation(**json.loads(df.loc[i, "possible_ambiguities"]))

    unambiguous_question, ambiguous_question = question, question

    for amb in ambiguities.full_form_abbrv_map:
        if amb.ambiguity_type == "abbreviation":
            # The question already uses the abbreviation: expand it.
            assert amb.abbreviation in question, question
            unambiguous_question = _replace_whole_term(unambiguous_question, amb.abbreviation, amb.full_form)
        elif amb.ambiguity_type == "full_form":
            # The question spells out the full form: abbreviate it.
            assert amb.full_form in question, question
            ambiguous_question = _replace_whole_term(ambiguous_question, amb.full_form, amb.abbreviation)

    df.loc[i, "ambiguous_question"] = ambiguous_question
    df.loc[i, "unambiguous_question"] = unambiguous_question

In [None]:
# Overwrite the labelled file, now with the ambiguous/unambiguous question columns.
df.to_csv(path_or_buf=f"{dataset_name}_ambiguous.csv")

## Intent Extraction via LLM


In [None]:
from models import IntentExtraction

# Parses the LLM's JSON reply into an IntentExtraction (read via `.intent` / `.requirements`).
output_parser = PydanticOutputParser(
    pydantic_object=IntentExtraction,
)

In [None]:
# System prompt for intent extraction. {domain} and {format_instructions} are
# pre-filled as partial variables; the candidate full forms are given as context
# only and must not appear in the output.
sys_message = """Extract the intent and requirements from given query as strings. It should help a person who is aiming to answer that question.
The requirements should define the output and extent of the answer and intent should define the actual reason behind the question.
Queries may contain an ambiguous abbreviation, for them, abbreviation and possible disambiguations will be provided. Your task is not to select from them but to provide intent details.
Do not assume and output any full-form in the intent and requirements.

Domain of the query is {domain}.

{format_instructions}"""

# Per-question user message. {disambs} is a numbered list of candidate full forms,
# one "<n> - <full form>" per line.
user_message = """Query:{query}
Abbreviation:{abbrv}
Possible Disambiguations:{disambs}
Output:"""

In [None]:
# System half: intent-extraction instructions with the domain and the parser's schema baked in.
intent_system = SystemMessagePromptTemplate(
    prompt=PromptTemplate(
        template=sys_message,
        input_variables=[],
        partial_variables={
            "format_instructions": output_parser.get_format_instructions(),
            "domain": domain,
        },
    )
)
# User half: filled per question with the query, its abbreviation and the candidates.
intent_user = HumanMessagePromptTemplate(
    prompt=PromptTemplate(
        template=user_message,
        input_variables=["query", "abbrv", "disambs"],
        partial_variables={},
    )
)
messages = [intent_system, intent_user]
prompt = ChatPromptTemplate.from_messages(messages=messages)

In [None]:
# (query, abbreviation, candidates) -> prompt -> LLM -> IntentExtraction
chain = (
    prompt
    | llm
    | output_parser
)

In [None]:
# Load the merged suggestions written by the full-form suggestion sections (API + LLM).
# Derive the file name from `top_n`, as the writers do, instead of hard-coding 10.
df = pd.read_csv(f"medquad_ambiguous_with_top{top_n}_merged.csv", index_col=0)
n_queries = len(df)

In [None]:
suggestion_col = f"top_{top_n}_full_form"

for i in tqdm(df.index):
    query = df.loc[i, "ambiguous_question"]
    ambiguities = QueryAmbiguation(**json.loads(df.loc[i, "possible_ambiguities"]))

    # Focus only on the first ambiguity of each question.
    amb = ambiguities.full_form_abbrv_map[0]

    # Prefer the API suggestions and fall back to the LLM suggestion. Both columns
    # join per-abbreviation entries with "<-->", so take the first entry to match
    # `amb`; within an entry, candidates are separated by "<->".
    suggestions = df.loc[i, suggestion_col]
    if pd.isna(suggestions):
        suggestions = df.loc[i, "llm_full_form_suggestion"]
    first_entry = "" if pd.isna(suggestions) else suggestions.split("<-->")[0]
    candidates = first_entry.split("<->") if first_entry else []
    disambs = "".join(f"{rank} - {full_form}\n" for rank, full_form in enumerate(candidates))

    answer = chain.invoke({"query": query, "abbrv": amb.abbreviation, "disambs": disambs})

    df.loc[i, "intent"] = answer.intent
    df.loc[i, "requirements"] = str(answer.requirements)

100%|██████████| 123/123 [06:52<00:00,  3.35s/it]


In [None]:
# Final output: merged suggestions plus the extracted intent and requirements.
df.to_csv(path_or_buf="medquad_ambiguous_with_intent.csv")