In [8]:
import pandas as pd
from importlib import reload
import ast

import os

from dotenv import load_dotenv

# Resolve the current user's home directory instead of hard-coding one machine's absolute path.
load_dotenv(os.path.expanduser("~/.env"))

True

In [3]:
#### static variables
# Column orders used when writing original and manipulated documents/queries to CSV.

# Columns of an original document row.
COLUMNS_DOCS = [
    "doc_id",
    "language",
    "domain",
    "content",
    "company_name",
    "court_name",
    "hospital_patient_name",
]

# A textually manipulated doc also records the doc it was derived from.
COLUMNS_DOCS_MANIPULATED_TEXTUAL = COLUMNS_DOCS + ["original_doc_id"]

# A single manipulated table row additionally keeps its source query and the new answer.
COLUMNS_DOCS_MANIPULATED_TABULAR_ROW = COLUMNS_DOCS + [
    "original_doc_id",
    "query.original_query_id",
    "ground_truth.content",
]

# Aggregated tabular docs: the entity and origin columns hold lists, hence the plural names.
COLUMNS_DOCS_MANIPULATED_TABULAR = COLUMNS_DOCS[:4] + [
    "company_names",
    "court_names",
    "hospital_patient_names",
    "original_doc_ids",
]

# Columns of an original (flattened) query row.
COLUMNS_QUERIES = [
    "domain",
    "ground_truth.content",
    "ground_truth.doc_ids",
    "ground_truth.keypoints",
    "ground_truth.references",
    "language",
    "prediction",
    "query.content",
    "query.query_id",
    "query.query_type",
]

# Manipulated queries remember the id of the query they were derived from.
COLUMNS_QUERIES_MANIPULATED = COLUMNS_QUERIES + ["query.original_query_id"]

In [14]:
### helper functions/classes for manipulation
from typing import Tuple, List, Dict
from pydantic import BaseModel
import pandas as pd
import ast
import csv
import os
from typing import Literal
import sys
import json
from utils import llm


reload(llm)

# English split of the corpus. The list-valued query columns are stored as Python literals in the CSV,
# so each of them is parsed back with ast.literal_eval.
DOCUMENTS = pd.read_csv("DRAGONball/en/docs.csv")
QUERIES = pd.read_csv(
    "DRAGONball/en/queries_flattened.csv",
    converters=dict.fromkeys(
        ("ground_truth.doc_ids", "ground_truth.keypoints", "ground_truth.references"),
        ast.literal_eval,
    ),
)


def get_query_ids_for_doc(doc_id: int | str, query_types: list[str] = ["Factual Question"]) -> list[int]:
    """Selects query_ids for queries related to that doc and with a specified type."""
    return QUERIES[
        QUERIES["ground_truth.doc_ids"].apply(lambda doc_ids: int(doc_id) in doc_ids)
        & QUERIES["query.query_type"].isin(query_types)
    ]["query.query_id"].to_list()


def get_doc_query_mapping_single(target: Literal["textual", "tabular"]) -> List[Dict[str, int]]:
    with open("doc_query_mapping_single.csv", "r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        return [
            {
                "doc_id": int(row["doc_id"]),
                "query_id": int(row["query_id_single"]) if target == "textual" else int(row["query_id_multi"]),
            }
            for row in reader
        ]


def get_doc_query_mapping_multi() -> List[Dict]:
    """Pair every doc flagged for multi-textual manipulation with its factual queries.

    Reads ``docs_to_manipulate.csv`` and keeps the rows whose ``multi_textual_manipulation`` equals 1.

    Returns:
        A list of ``{"doc_id": <str as read from the CSV>, "query_ids": [...]}`` dicts, in file order.
    """
    with open("docs_to_manipulate.csv", "r", encoding="utf-8", newline="") as handle:
        flagged = [
            record["doc_id"]
            for record in csv.DictReader(handle)
            if int(record["multi_textual_manipulation"]) == 1
        ]
    return [
        {"doc_id": doc_id, "query_ids": get_query_ids_for_doc(doc_id, query_types=["Factual Question"])}
        for doc_id in flagged
    ]


def get_query_properties(
    query_id,
    properties: list = ["ground_truth.content", "ground_truth.keypoints", "ground_truth.references", "query.content"],
) -> Tuple:
    """Select columns for query_id from queries dataframe."""
    row = QUERIES[QUERIES["query.query_id"] == query_id]
    return tuple(row[prop].iloc[0] for prop in properties)


def get_doc_properties(
    doc_id,
    properties,
) -> Tuple:
    """Return the non-empty values of ``properties`` for the document ``doc_id``.

    ``doc_id`` is converted with ``int()``, as get_doc_text and get_doc_by_id do, so string ids
    (e.g. read from a CSV) match as well. Columns that are NaN/None for this document, or absent
    from DOCUMENTS, are skipped. The result may therefore be shorter than ``properties``; its order
    follows ``properties``.
    """
    row: pd.DataFrame = DOCUMENTS[DOCUMENTS["doc_id"] == int(doc_id)].dropna(axis=1)
    return tuple(row[prop].iloc[0] for prop in properties if prop in row.columns)


def get_doc_text(doc_id: int | str) -> str:
    return DOCUMENTS[DOCUMENTS["doc_id"] == int(doc_id)]["content"].iloc[0]


def get_query_by_id(query_id: int) -> pd.Series:
    """Return the first QUERIES row whose ``query.query_id`` (cast to int) equals ``query_id``."""
    id_matches = QUERIES["query.query_id"].astype(int) == query_id
    return QUERIES.loc[id_matches].iloc[0]


def get_queries_by_id(query_ids: List[int]) -> pd.DataFrame:
    """Return every QUERIES row whose ``query.query_id`` (cast to int) is in ``query_ids``."""
    wanted = QUERIES["query.query_id"].astype(int).isin(query_ids)
    return QUERIES.loc[wanted]


def get_doc_by_id(doc_id: int | str) -> pd.Series:
    return DOCUMENTS[DOCUMENTS["doc_id"].astype(int) == int(doc_id)].iloc[0]


def get_entity_by_doc_id(doc_id: int) -> str:
    doc = get_doc_by_id(doc_id)
    if isinstance(doc["hospital_patient_name"], str):
        return doc["hospital_patient_name"]
    if isinstance(doc["court_name"], str):
        return doc["court_name"]
    if isinstance(doc["company_name"], str):
        return doc["company_name"]


def get_prompts_for_textual_single(doc_id: int, query_id: int) -> Tuple:
    """Build the (system, user) prompt pair for manipulating one document against one query.

    Uses the first ``manipulation_factual`` entry from ``prompts/json/manipulate_docs.json``. Its
    user template is filled with the document text and the query's question, answer and references.
    """
    # NOTE(review): `read_prompt` and `format_user_prompt_single_textual` are not defined in the
    # visible notebook — confirm where they come from.
    PROMPT_TYPE = "manipulation_factual"
    candidates = [
        entry for entry in read_prompt("prompts/json/manipulate_docs.json") if entry["prompt_type"] == PROMPT_TYPE
    ]
    chosen = candidates[0]
    doc_text = get_doc_text(doc_id)
    answer, _keypoints, references, question = get_query_properties(query_id)
    system_prompt = chosen["system_prompt"]
    filled_user_prompt = format_user_prompt_single_textual(
        user_prompt=chosen["user_prompt"],
        text=doc_text,
        answer=answer,
        question=question,
        references=references,
    )
    return (system_prompt, filled_user_prompt)


def get_prompts_for_tabular_single(query_id: int, doc_id: int) -> Tuple:
    """Build the (system, user) prompt pair for a single tabular manipulation.

    Loads ``prompts/json/manipulation_tabular.json``. Its user template is filled with the query's
    question and answer plus the document's entity name.

    Raises:
        ValueError: if the document does not yield exactly one non-empty entity column.
    """
    # NOTE(review): `read_prompt` and `format_user_prompt_single_tabular` are not defined in the
    # visible notebook — confirm where they come from.
    prompt_obj = read_prompt("prompts/json/manipulation_tabular.json")
    system_prompt = prompt_obj["system_prompt"]
    user_prompt = prompt_obj["user_prompt"]
    answer, question = get_query_properties(query_id, properties=["ground_truth.content", "query.content"])
    entities = get_doc_properties(doc_id, ["hospital_patient_name", "company_name", "court_name"])
    # get_doc_properties drops empty columns, so exactly one entity is expected here. Report the
    # actual count instead of a bare "not enough/too many values to unpack".
    if len(entities) != 1:
        raise ValueError(f"Expected exactly one entity for doc {doc_id}, found {len(entities)}: {entities!r}")
    (entity,) = entities
    user_prompt = format_user_prompt_single_tabular(user_prompt, question, answer, entity)
    return (system_prompt, user_prompt)


def _get_prompts_for_textual_multi_from(prompt_path: str, doc_id: int, query_ids: List[int]) -> Tuple:
    """Shared body of the multi-textual prompt builders; the versions differ only in the prompt file.

    Loads the prompt at ``prompt_path`` and collects ``{"question", "answer"}`` pairs for
    ``query_ids`` in the given order. Fills the user template with the document text and those
    pairs, and returns ``(system_prompt, user_prompt)``.
    """
    # NOTE(review): `read_prompt` and `format_user_prompt_multi_textual` are not defined in the visible
    # notebook. Also, these paths start with "../" while the single-manipulation prompts do not — confirm
    # which working directory is intended.
    prompt_obj = read_prompt(prompt_path)
    text = get_doc_text(doc_id)
    qa_pairs = []
    for query_id in query_ids:
        question, answer = get_query_properties(query_id, ["query.content", "ground_truth.content"])
        qa_pairs.append({"question": question, "answer": answer})

    system_prompt = prompt_obj["system_prompt"]
    user_prompt = format_user_prompt_multi_textual(user_prompt=prompt_obj["user_prompt"], text=text, qa_pairs=qa_pairs)
    return (system_prompt, user_prompt)


def get_prompts_for_textual_multi(doc_id: int, query_ids: List[int]) -> Tuple:
    """Build (system, user) prompts for a multi-textual manipulation using the original prompt file."""
    return _get_prompts_for_textual_multi_from(
        "../prompts/json/manipulation_multi_textual.json", doc_id, query_ids
    )


def get_prompts_for_textual_multi_v02(doc_id: int, query_ids: List[int]) -> Tuple:
    """Build (system, user) prompts for a multi-textual manipulation using the v02 prompt file."""
    return _get_prompts_for_textual_multi_from(
        "../prompts/json/manipulations/manipulation_multi_textual_v02.json", doc_id, query_ids
    )


def get_id_for_manipulated_doc_or_query(original_doc_id: int, prefix_number=1) -> int:
    """Derive a new id by prepending ``prefix_number`` to the original id zero-padded to 5 digits.

    Example: ``(128, prefix_number=3) -> 300128``.
    """
    padded_id = str(original_doc_id).zfill(5)
    return int(str(prefix_number) + padded_id)


def save_manipulated_doc(filename: os.PathLike | str, fieldnames: List[str], **kwargs):
    """Saves a manipulated doc to csv.
    If an entry with that doc_id already exists in the csv, the new entry is NOT saved.
    """
    if filename is None:
        raise RuntimeError("Must specify a filename!")

    print(f"Saving Doc with ID '{kwargs["doc_id"]}'")
    is_empty = not os.path.exists(filename) or os.stat(filename).st_size == 0
    id_exists = False
    with open(filename, "a+", newline="") as f:
        if not is_empty:
            f.seek(0)
            reader = csv.DictReader(f, fieldnames=fieldnames)
            ids_present = {int(row["doc_id"]) for row in list(reader)[1:]}
            id_exists = int(kwargs["doc_id"]) in ids_present

        if id_exists == True:
            print(f"WARN: Row with ID {kwargs["doc_id"]} already exists. Did not write new document to '{filename}'.")
            return

        writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
        if is_empty:
            writer.writeheader()
        writer.writerow(kwargs)


def save_manipulated_query(filename: os.PathLike | str, fieldnames: List[str], **kwargs):
    """Saves a manipulated query to csv.
    If an entry with that doc_id already exists in the csv, the new entry is NOT saved.
    """
    if filename is None:
        raise RuntimeError("Must specify a filename!")

    print(f"Saving Query with ID '{kwargs["query.query_id"]}'")
    is_empty = not os.path.exists(filename) or os.stat(filename).st_size == 0
    id_exists = False
    with open(filename, "a+", newline="") as f:
        if not is_empty:
            f.seek(0)
            reader = csv.DictReader(f, fieldnames=fieldnames)
            ids_present = {int(row["query.query_id"]) for row in list(reader)[1:]}
            id_exists = int(kwargs["query.query_id"]) in ids_present

        if id_exists == True:
            print(
                f"WARN: Row with ID {kwargs["query.query_id"]} already exists. Did not write new query to '{filename}'."
            )
            return

        writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
        if is_empty:
            writer.writeheader()
        writer.writerow(kwargs)

In [12]:
### describe documents and queries
# `end="\n\n"` reproduces the separate blank `print()` between the two listings.
print("Documents:\n", DOCUMENTS.columns, end="\n\n")
print("Queries:\n", QUERIES.columns)

Documents:
 Index(['hospital_patient_name', 'language', 'doc_id', 'domain', 'content',
       'company_name', 'court_name'],
      dtype='object')

Queries:
 Index(['domain', 'ground_truth.content', 'ground_truth.doc_ids',
       'ground_truth.keypoints', 'ground_truth.references', 'language',
       'prediction', 'query.content', 'query.query_id', 'query.query_type'],
      dtype='object')


## Single Textual Manipulation


In [None]:
### manipulate documents (textual)
def save_single_textual(
    doc_entry,
    query_entry,
    completion_parsed,
    doc_filename="additional_data/docs/single_textual_manipulations.csv",
    query_filename="additional_data/queries/single_textual_manipulations.csv",
):
    """Persist one single-textual manipulation: the rewritten document and its rewritten query.

    The document gets a new id (prefix 1), the LLM's ``text_new`` as content, and ``original_doc_id``.
    The query gets a new id (prefix 1), the LLM's ``answer_new``/``references_new``, empty
    keypoints, and ``query.original_query_id``. The inputs are not modified.

    Both save helpers require a filename and fieldnames, which the previous version did not pass.
    It also renamed "." to "__" in the query keys, which broke the lookup of "query.query_id".
    """
    # NOTE(review): default file names follow the "additional_data/{docs,queries}/<name>.csv" pattern
    # used elsewhere in this notebook — confirm the intended names.
    # -- doc
    original_doc_id = doc_entry["doc_id"]
    manipulated_doc = doc_entry.to_dict()
    manipulated_doc.update(
        {
            "doc_id": get_id_for_manipulated_doc_or_query(original_doc_id),
            "content": completion_parsed.text_new,
            "original_doc_id": original_doc_id,
        }
    )
    save_manipulated_doc(filename=doc_filename, fieldnames=COLUMNS_DOCS_MANIPULATED_TEXTUAL, **manipulated_doc)

    # -- query
    # NOTE(review): ground_truth.doc_ids still points at the original doc; the tabular path rewrites it
    # to the manipulated doc id — confirm which is intended here.
    original_query_id = query_entry["query.query_id"]
    manipulated_query = query_entry.to_dict()
    manipulated_query.update(
        {
            "ground_truth.content": completion_parsed.answer_new,
            "ground_truth.references": completion_parsed.references_new,
            "ground_truth.keypoints": [],
            "query.query_id": get_id_for_manipulated_doc_or_query(original_query_id),
            "query.original_query_id": original_query_id,
        }
    )
    save_manipulated_query(filename=query_filename, fieldnames=COLUMNS_QUERIES_MANIPULATED, **manipulated_query)


# `target` is required by get_doc_query_mapping_single; "textual" selects the query_id_single column.
# NOTE(review): `openai_interface` is not defined anywhere in this notebook — confirm where it comes from.
doc_query_mapping = get_doc_query_mapping_single("textual")

for mapping in doc_query_mapping:
    doc_id = mapping["doc_id"]
    query_id = mapping["query_id"]

    doc_entry = get_doc_by_id(doc_id)
    query_entry = get_query_by_id(query_id)

    system_prompt, user_prompt = get_prompts_for_textual_single(doc_id, query_id)

    # call openai
    completion_parsed = openai_interface(system_prompt, user_prompt)

    save_single_textual(doc_entry, query_entry, completion_parsed)
    print(f"Finished processing doc {doc_id} and query {query_id}.")

## Single Tabular Manipulation


In [None]:
### manipulate documents (tabular) and save rows
# NOTE(review): `openai_interface` and `SingleTabularManipulationResponse` are not defined anywhere in
# this notebook — confirm where they come from.
mapping = get_doc_query_mapping_single("tabular")

for id_pair in mapping:
    doc_id = id_pair["doc_id"]
    query_id = id_pair["query_id"]
    system_prompt, user_prompt = get_prompts_for_tabular_single(query_id, doc_id)

    response: SingleTabularManipulationResponse = openai_interface(
        system_prompt, user_prompt, SingleTabularManipulationResponse
    )

    # Build the row as a plain dict: explicit keys instead of attribute assignment on a Series.
    manipulated_doc_entry = get_doc_by_id(doc_id).to_dict()
    manipulated_doc_entry.update(
        {
            "doc_id": get_id_for_manipulated_doc_or_query(doc_id, prefix_number=2),
            # One table row rendered as "description | value".
            "content": " | ".join([response.description, response.value]),
            "original_doc_id": doc_id,
            "query.original_query_id": query_id,
            "ground_truth.content": response.answer_new,
        }
    )

    save_manipulated_doc(
        filename="additional_data/docs/tabular_manipulations_result_rows.csv",
        fieldnames=COLUMNS_DOCS_MANIPULATED_TABULAR_ROW,
        **manipulated_doc_entry,
    )

In [None]:
### aggregate tabular rows and save docs
# One CSV row per manipulated table entry, as written by the previous cell.
tabular_docs = pd.read_csv(filepath_or_buffer="additional_data/docs/tabular_manipulations_result_rows.csv")


def list_or_none(series):
    """Aggregation helper: the group's values as a list, or None when every value is missing (or the group is empty)."""
    return None if series.isna().all() else series.tolist()


agg_funcs = {
    "doc_id": lambda x: 0,  # placeholder; real ids are assigned below
    "language": "first",
    "content": "\n".join,
    "company_name": list_or_none,
    "court_name": list_or_none,
    "hospital_patient_name": list_or_none,
    "original_doc_id": list_or_none,
    "query.original_query_id": lambda x: None,
    "ground_truth.content": lambda x: None,
}

# One aggregated table document per domain. Entity/origin columns become lists, hence the plural names.
aggregation = (
    tabular_docs.groupby("domain")
    .agg(agg_funcs)
    .reset_index()
    .rename(
        columns={
            "company_name": "company_names",
            "court_name": "court_names",
            "hospital_patient_name": "hospital_patient_names",
            "original_doc_id": "original_doc_ids",
        }
    )
)
# Number the aggregated docs 1..n (prefix 3). The old hard-coded [1, 2, 3] failed unless there were
# exactly three domains.
aggregation["doc_id"] = [
    get_id_for_manipulated_doc_or_query(domain_number, prefix_number=3)
    for domain_number in range(1, len(aggregation) + 1)
]

for row_dict in aggregation.to_dict(orient="records"):
    save_manipulated_doc(
        "additional_data/docs/tabular_manipulations_result.csv",
        fieldnames=COLUMNS_DOCS_MANIPULATED_TABULAR,
        **row_dict,
    )

In [None]:
### save manipulated queries for aggregated tabular docs
tabular_docs = pd.read_csv("additional_data/docs/tabular_manipulations_result_rows.csv")

mapping_domain_doc_id = pd.read_csv(
    "additional_data/docs/tabular_manipulations_result.csv", usecols=["domain", "doc_id"]
)
# domain -> id of the aggregated tabular doc for that domain
mapping_dict = mapping_domain_doc_id.set_index("domain")["doc_id"].to_dict()

for row_dict in tabular_docs.to_dict(orient="records"):
    original_query_id = row_dict["query.original_query_id"]
    manipulated_query_entry = get_query_by_id(original_query_id).to_dict()
    manipulated_query_entry.update(
        {
            "ground_truth.doc_ids": [mapping_dict[row_dict["domain"]]],
            "ground_truth.content": row_dict["ground_truth.content"],
            # References are lists everywhere else: QUERIES parses this column with ast.literal_eval.
            # Wrap the single manipulated row so the saved value has the same shape.
            "ground_truth.references": [row_dict["content"]],
            "ground_truth.keypoints": [],
            "query.query_id": get_id_for_manipulated_doc_or_query(original_query_id, prefix_number=3),
            "query.original_query_id": original_query_id,
        }
    )

    save_manipulated_query(
        filename="additional_data/queries/tabular_manipulations_result.csv",
        fieldnames=COLUMNS_QUERIES_MANIPULATED,
        **manipulated_query_entry,
    )

## Multi Textual Manipulation


In [83]:
### helper methods
from utils import data_helpers

# Re-execute utils/data_helpers.py so edits there are picked up without restarting the kernel.
data_helpers = reload(data_helpers)


def make_manipulated_query(query, qa_pairs):
    """Return a copy of ``query`` rewritten to match the manipulated document.

    ``qa_pairs`` are the LLM's new question/answer pairs, objects with ``question``, ``answer`` and
    ``quote`` attributes. The pair whose question equals this query's question supplies the new answer
    and the single reference quote; the comparison ignores case and surrounding whitespace. The copy
    gets a new id (prefix 4) and keeps the old one in ``query.original_query_id``. ``query`` is not modified.

    Raises:
        ValueError: if no pair's question matches the query's question.
    """
    wanted = query["query.content"].lower().strip()
    qa_pair = next((pair for pair in qa_pairs if pair.question.lower().strip() == wanted), None)
    if qa_pair is None:
        # A bare StopIteration from next() is hard to trace inside DataFrame.apply; name the query instead.
        raise ValueError(
            f"No manipulated QA pair matches query {query['query.query_id']!r}: {query['query.content']!r}"
        )

    manipulated_query = query.copy()
    manipulated_query.update(
        {
            "query.query_id": data_helpers.get_id_for_manipulated_doc_or_query(query["query.query_id"], 4),
            "ground_truth.content": qa_pair.answer,
            "ground_truth.references": [qa_pair.quote],
        }
    )
    manipulated_query["query.original_query_id"] = query["query.query_id"]
    # NOTE(review): ground_truth.doc_ids still points at the original document here, whereas the tabular
    # path rewrites it to the manipulated doc id — confirm which is intended.
    return manipulated_query


def make_manipulated_doc(doc: pd.Series, text_new: str):
    """Return a copy of ``doc`` with a new id (prefix 4), ``text_new`` as content, and ``original_doc_ids``.

    ``doc`` itself is left unchanged.
    """
    source_id = doc["doc_id"]
    replacements = {
        "doc_id": data_helpers.get_id_for_manipulated_doc_or_query(source_id, 4),
        "content": text_new,
    }
    result = doc.copy()
    result.update(replacements)
    result["original_doc_ids"] = [source_id]
    return result

In [94]:
### Test multi manipulation prompts
from utils import io_helpers, data_helpers

FILENAME = "multi_textual_manipulations"

reload(io_helpers)
reload(data_helpers)
reload(llm)

# NOTE(review): `documents` is loaded but unused — docs are fetched via get_doc_by_id (global DOCUMENTS).
documents = io_helpers.get_documents()
queries = io_helpers.get_queries()

# The prompt template does not depend on the document, so load it once instead of on every iteration.
system_prompt, user_prompt_template = io_helpers.get_prompt("manipulations/manipulation_multi_textual_v03")

doc_ids = io_helpers.get_documents_to_manipulate("multi_textual_manipulation")

for doc_id in doc_ids:
    manipulated_doc_id = data_helpers.get_id_for_manipulated_doc_or_query(doc_id, 4)
    if io_helpers.file_has_been_manipulated("doc", FILENAME, manipulated_doc_id):
        print(f"Document with ID '{doc_id}' has been manipulated and saved before. Skipping this one.")
        continue
    doc: pd.Series = get_doc_by_id(doc_id)
    # NOTE(review): get_entity_by_doc_id returns None when no entity column is filled — confirm the
    # prompt formatter tolerates that.
    entity = get_entity_by_doc_id(doc_id)
    text = doc["content"]

    related_queries: pd.DataFrame = data_helpers.get_queries_by_doc_id(doc_id, queries)
    qa_pairs = data_helpers.make_qa_pairs(related_queries)

    user_prompt = llm.format_user_prompt_multi_textual_v02(user_prompt_template, entity, text, qa_pairs)

    llm_response: llm.MultiTextualManipulationResponseV02 = llm.call_openai(
        system_prompt,
        user_prompt,
        model="gpt-4o",
        response_format_pydantic=llm.MultiTextualManipulationResponseV02,
        temperature=0.8,
    )

    manipulated_doc: pd.Series = make_manipulated_doc(doc, llm_response.text_new)
    manipulated_queries: pd.DataFrame = related_queries.apply(
        make_manipulated_query, args=(llm_response.qa_pairs_new,), axis=1
    )

    # NOTE(review): the doc is saved before its queries, and the skip check above only looks at docs.
    # If the run stops mid-way through the queries, a re-run will skip this doc and never save them.
    io_helpers.easy_save_manipulated_doc(FILENAME, manipulated_doc)
    for _, manipulated_query in manipulated_queries.iterrows():
        io_helpers.easy_save_manipulated_query(FILENAME, manipulated_query)

Document with ID '128' has been manipulated and saved before. Skipping this one.
Document with ID '132' has been manipulated and saved before. Skipping this one.
Saved document to 'additional_data/docs/multi_textual_manipulations.csv'
Saved Query to 'additional_data/queries/multi_textual_manipulations.csv'
Saved Query to 'additional_data/queries/multi_textual_manipulations.csv'
Saved Query to 'additional_data/queries/multi_textual_manipulations.csv'
Saved Query to 'additional_data/queries/multi_textual_manipulations.csv'
Saved Query to 'additional_data/queries/multi_textual_manipulations.csv'
Saved document to 'additional_data/docs/multi_textual_manipulations.csv'
Saved Query to 'additional_data/queries/multi_textual_manipulations.csv'
Saved Query to 'additional_data/queries/multi_textual_manipulations.csv'
Saved Query to 'additional_data/queries/multi_textual_manipulations.csv'
Saved Query to 'additional_data/queries/multi_textual_manipulations.csv'
Saved document to 'additional_data/