In [56]:
import pandas as pd
import random
import csv
import ast
import os
from typing import List

In [57]:
### Load data (documents + flattened queries) and define helpers
DOCUMENTS = pd.read_csv("DRAGONball/en/docs.csv")

# The ground_truth.* columns hold Python literals serialized as text (e.g. the list of
# doc_ids a query refers to); ast.literal_eval turns them back into Python objects.
QUERIES = pd.read_csv(
    "DRAGONball/en/queries_flattened.csv",
    converters={
        column_name: ast.literal_eval
        for column_name in (
            "ground_truth.doc_ids",
            "ground_truth.keypoints",
            "ground_truth.references",
        )
    },
)


def get_query_ids_for_doc(doc_id: int, query_types: list[str] = ["Factual Question"]) -> list[int]:
    """Selects query_ids for queries related to that doc and with a specified type."""
    return QUERIES[
        QUERIES["ground_truth.doc_ids"].apply(lambda doc_ids: int(doc_id) in doc_ids)
        & QUERIES["query.query_type"].isin(query_types)
    ]["query.query_id"].to_list()

In [58]:
# CSV persisting which doc_ids were selected for each manipulation type. It is re-read
# on later runs so that previously sampled doc_ids are kept instead of re-sampled.
FILENAME = "docs_to_manipulate.csv"

In [59]:
### Select docs and queries to manipulate

NUM_SINGLE_TEXTUAL = 30
NUM_SINGLE_TABULAR = 30  # must be less or equal to NUM_SINGLE_TEXTUAL
NUM_MULTI_TEXTUAL = 30
RANDOM_SEED = 42  # makes newly sampled doc_ids (and later query picks) reproducible

# Tabular manipulations are drawn from the single_textual docs, so there cannot be more of them.
assert NUM_SINGLE_TABULAR <= NUM_SINGLE_TEXTUAL, "NUM_SINGLE_TABULAR must be <= NUM_SINGLE_TEXTUAL"

random.seed(RANDOM_SEED)

FIELDNAMES = [
    "doc_id",
    "single_textual_manipulation",
    "single_tabular_manipulation",
    "multi_textual_manipulation",
]

# domain -> doc_ids of that domain, in DOCUMENTS order
ids_by_domain = {}
for row in DOCUMENTS.itertuples(index=False):
    ids_by_domain.setdefault(row.domain, []).append(row.doc_id)

# Selected doc_ids per manipulation type; "all" is the union of the three lists.
ids = {"single_textual": [], "single_tabular": [], "multi_textual": [], "all": set()}

# Resume from an earlier run if the selection file exists, so previous picks are kept.
try:
    with open(FILENAME, "r", encoding="utf-8", newline="") as f:
        rows = list(csv.DictReader(f))
except FileNotFoundError as e:
    print(f"File '{e.filename}' not found. Will be created.")
else:
    # A doc counts as selected for a type when its flag column is anything other than "0".
    for manipulation in ("single_textual", "single_tabular", "multi_textual"):
        ids[manipulation] = [
            int(row["doc_id"]) for row in rows if row[f"{manipulation}_manipulation"] != "0"
        ]
    ids["all"] = set(ids["single_textual"]) | set(ids["single_tabular"]) | set(ids["multi_textual"])

skip_sampling = (
    len(ids["single_textual"]) >= NUM_SINGLE_TEXTUAL
    and len(ids["single_tabular"]) >= NUM_SINGLE_TABULAR
    and len(ids["multi_textual"]) >= NUM_MULTI_TEXTUAL
)
if skip_sampling:
    print("No need to sample new IDs")

def sample_new_doc_ids(num_ids: int, existing_ids: List[int], sample_for_tabular: bool = False) -> List[int]:
    """Append randomly chosen doc_ids to ``existing_ids`` until it holds ``num_ids`` entries.

    Regular sampling rotates through the Finance / Medical / Law domains and draws only
    doc_ids not yet used by any manipulation (``ids["all"]``). Tabular sampling instead
    draws from docs already in ``ids["single_textual"]`` that are not yet in ``existing_ids``.
    Every chosen id is also added to the global ``ids["all"]``.

    Args:
        num_ids: Target number of doc_ids.
        existing_ids: Already-selected doc_ids; extended in place.
        sample_for_tabular: Draw from the single_textual selection instead of unused docs.

    Returns:
        The same ``existing_ids`` list (unchanged if it already held ``num_ids`` or more).

    Raises:
        RuntimeError: If the candidate pool is exhausted before ``num_ids`` is reached.
    """
    domains = ("Finance", "Medical", "Law")
    while len(existing_ids) < num_ids:
        if sample_for_tabular:
            # Tabular manipulations reuse docs already chosen for single textual manipulation.
            candidates = set(ids["single_textual"]) - set(existing_ids)
            if not candidates:
                raise RuntimeError(
                    "No single_textual doc_ids left to reuse for tabular manipulation "
                    f"({len(existing_ids)}/{num_ids} selected)."
                )
        else:
            # Pick the domain from the remaining count so new picks cycle through the domains.
            domain = domains[(num_ids - len(existing_ids)) % len(domains)]
            candidates = set(ids_by_domain[domain]) - ids["all"]
            if not candidates:
                raise RuntimeError(f"No candidate doc_ids left for domain {domain}.")
        # sorted() gives random.choice a deterministic population order.
        choice = random.choice(sorted(candidates))
        existing_ids.append(choice)
        ids["all"].add(choice)
    return existing_ids

if not skip_sampling:
    ids["single_textual"] = sample_new_doc_ids(num_ids=NUM_SINGLE_TEXTUAL, existing_ids=ids["single_textual"])
    # Tabular candidates come from ids["single_textual"], so this must run after the line above.
    ids["single_tabular"] = sample_new_doc_ids(
        num_ids=NUM_SINGLE_TABULAR, existing_ids=ids["single_tabular"], sample_for_tabular=True
    )
    ids["multi_textual"] = sample_new_doc_ids(num_ids=NUM_MULTI_TEXTUAL, existing_ids=ids["multi_textual"])

    for key, selected_ids in ids.items():
        print(f"{key}: {len(selected_ids)}")

    # Computed outside the f-string: same-quote nesting inside f-strings needs Python 3.12+.
    available_ids = set(DOCUMENTS["doc_id"].to_list()) - ids["all"]
    print(f"Available: {len(available_ids)}")

    single_textual_ids = set(ids["single_textual"])
    single_tabular_ids = set(ids["single_tabular"])
    multi_textual_ids = set(ids["multi_textual"])

    with open(FILENAME, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDNAMES, extrasaction="ignore")
        writer.writeheader()
        # Sorted so the file content is stable across runs and easy to diff.
        for doc_id in sorted(ids["all"]):
            writer.writerow(
                {
                    "doc_id": doc_id,
                    "single_textual_manipulation": 1 if doc_id in single_textual_ids else 0,
                    "single_tabular_manipulation": 1 if doc_id in single_tabular_ids else 0,
                    "multi_textual_manipulation": 1 if doc_id in multi_textual_ids else 0,
                }
            )

No need to sample new IDs


In [60]:
### Create doc-query mapping for single_textual & single_tabular

# Docs selected for at least one single-document manipulation (textual or tabular).
with open(FILENAME, "r", encoding="utf-8", newline="") as f:
    doc_ids_to_manipulate = [
        row
        for row in csv.DictReader(f)
        if int(row["single_textual_manipulation"]) == 1 or int(row["single_tabular_manipulation"]) == 1
    ]

FIELDNAMES = ["doc_id", "query_id_textual_manipulation", "query_id_tabular_manipulation"]

mappings = []
for id_with_manipulation_type in doc_ids_to_manipulate:
    doc_id = id_with_manipulation_type["doc_id"]
    options = get_query_ids_for_doc(doc_id, query_types=["Factual Question"])
    if len(options) < 2:
        # Two distinct queries are needed. Skip the doc rather than reusing the previous
        # iteration's query_ids (or raising NameError when this happens on the first doc).
        print(f"Only {len(options)} options found for doc_id '{doc_id}'")
        continue
    query_ids = random.sample(options, 2)
    mappings.append(
        {
            "doc_id": doc_id,
            "query_id_textual_manipulation": query_ids[0],
            "query_id_tabular_manipulation": query_ids[1],
        }
    )

# doc_id is read from the CSV as a string; sort numerically so "10" does not precede "9".
mappings.sort(key=lambda mapping: int(mapping["doc_id"]))

# Mode "x" refuses to overwrite an existing mapping file, so earlier picks are never replaced.
try:
    with open("doc_query_mapping_single.csv", "x", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDNAMES)
        writer.writeheader()
        writer.writerows(mappings)
except FileExistsError as e:
    print(f"File '{e.filename}' exists. Did nothing")

File 'doc_query_mapping_single.csv' exists. Did nothing
