In [1]:
import os
import csv
import pandas as pd
from collections import defaultdict

In [2]:
# File and Directory Paths
# All three paths are relative and therefore resolve against the current working directory.
# QA-pairs before any LLM filtering (input of remove_invalid_pairs).
qa_pair_file = "qa_pairs.csv"
# LLM-labelled QA-pairs; this file is rewritten in place by remove_not_unique_questions.
qa_pair_filtered_file = "qa_pairs_LLM_filtered.csv"
# Destination written by remove_invalid_pairs and input of the train-test split.
qa_pair_output = "qa_pairs_final.csv"

# Defining column indexes
# Column positions in the rows of the original file (qa_pair_file / qa_pair_output).
object_id_index_original = 0
object_type_index_original = 1
question_index_original = 2

# Column positions in the rows of the LLM-filtered file; -1 selects the last column, which holds the label.
question_index_filtered = 1
valid_value_index_filtered = -1

# Defining values
# Label in the last column of the filtered file that marks a QA-pair for removal.
INVALID_VALUE = "Invalid"

In [18]:
def find_not_unique_questions(file_path):
    """
    Collect every question that appears in more than one row of the QA-pair file.

    As a side effect, a short summary is printed: the number of repeated questions,
    the number of rows they occupy and up to ten examples with their counts.

    Args:
        file_path (str): The path to the QA-pair csv file.

    Returns:
        set[str]: All questions that occur at least twice.
    """
    qa_frame = pd.read_csv(file_path, delimiter=";", quotechar="|", header=None)
    occurrences = qa_frame.iloc[:, question_index_filtered].value_counts()
    repeated = occurrences[occurrences > 1]

    print(f"Number of not unique questions: {len(repeated)}")
    print(f"Number of affected pairs: {repeated.sum()}")
    print("\nExamples for not unique questions:")
    print(repeated.iloc[:10])

    # The index of the counts series holds the question texts themselves.
    return set(repeated.index)

def remove_not_unique_questions(file_path):
    """
    Remove the first occurrence of every not unique question and rewrite the file in place.
    In this stage, this is needed when filter_questions_LLM.py was executed multiple times with wrong indexing.

    The filtered rows are staged in a temporary file next to ``file_path`` and only
    swapped in once every row has been processed, so a failure leaves the original
    file untouched and no temporary file behind.

    Note:
        Exactly one row is dropped per repeated question. A question that occurs
        more than twice is therefore still duplicated after a single call.

    Args:
        file_path (str): The path to the QA-pair csv file.
    """
    pending_questions = find_not_unique_questions(file_path)
    counter_removed = 0

    # Stage the result beside the target: os.replace then stays on one filesystem and
    # no unrelated "tmp.csv" in the working directory can be overwritten.
    tmp_path = f"{file_path}.tmp"

    try:
        # newline="" hands line-ending handling to the csv module, which keeps
        # line breaks inside quoted fields intact.
        with open(file_path, "r", newline="", encoding="utf-8") as input_file, \
                open(tmp_path, "w", newline="", encoding="utf-8") as output_file:
            csv_reader = csv.reader(input_file, delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)
            csv_writer = csv.writer(output_file, delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)

            for row in csv_reader:
                if not row:
                    # csv.reader yields an empty list for a blank line; there is no question to inspect.
                    continue

                question = row[question_index_filtered]
                if question in pending_questions:
                    # First sighting of a repeated question: drop this row and let later ones through.
                    pending_questions.discard(question)
                    counter_removed += 1
                else:
                    csv_writer.writerow(row)

        # Replace old csv file with new csv file
        os.replace(tmp_path, file_path)
    finally:
        # Only still present when something failed before the swap.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"{counter_removed} rows have been deleted due to multiple occurrences.")
    
# Rewrites qa_pair_filtered_file in place; the previous content of the file is not kept.
remove_not_unique_questions(qa_pair_filtered_file)

Number of not unique questions: 630
Number of affected pairs: 1260

Examples for not unique questions:
According to the IVQR estimates, at which quantiles are the effects of Chinese imports on wages most significant?                                                                                       2
According to the simulations, in which scenario were TTW estimates more accurate—when using a single instrument or multiple instruments for inference?                                                  2
In which scenario does the model trained on CDJUR-BR achieve a higher macro F1 score when tested against LENER-BR compared to the reverse scenario?                                                     2
According to the table, what abbreviation is used for the setting where strong augmentations ($\tau_s$) and weak augmentations ($\tau_t$) are applied along with shared student and teacher weights?    2
What is the average cosine similarity (ACS) reported for each diagnosis in the table comp

In [23]:
def count_valid_invalid_pairs(file_path):
    """
    Print how often each label occurs in the label column of the LLM-filtered file.

    Args:
        file_path (str): The path to the QA-pair csv file filtered by an LLM.
    """
    labelled_pairs = pd.read_csv(file_path, delimiter=";", quotechar="|", header=None)
    label_column = labelled_pairs.iloc[:, valid_value_index_filtered]

    print(label_column.value_counts())
    
# Print how often each label occurs in the last column of the filtered file.
count_valid_invalid_pairs(qa_pair_filtered_file)

Valid      87450
Invalid    12565
Name: 3, dtype: int64


In [37]:
def get_invalid_pairs(file_path):
    """
    Gather the questions of all rows whose label column equals INVALID_VALUE.

    Args:
        file_path (str): The path to the QA-pair csv file filtered by an LLM.

    Returns:
        set[str]: A set of all questions that were labelled 'Invalid' by the LLM.
    """
    labelled_pairs = pd.read_csv(file_path, delimiter=";", quotechar="|", header=None)

    # Boolean mask over the rows, built from the label column.
    is_invalid = labelled_pairs.iloc[:, valid_value_index_filtered] == INVALID_VALUE
    invalid_questions = labelled_pairs.loc[is_invalid].iloc[:, question_index_filtered]

    # Iterating a Series yields its values, so this collects the question texts.
    return set(invalid_questions)

def remove_invalid_pairs(original_file, output_file, invalid_pairs):
    """
    Copy the original QA-pair file to a new csv file, leaving out every row whose
    question is contained in ``invalid_pairs``.

    Rows are matched on the question text only. After copying, a per-category summary
    ("removed / total") is printed for every object type seen in the original file,
    including categories from which nothing was removed.

    Args:
        original_file (str): The path to the original QA-pair csv file without any LLM-filtering.
        output_file (str): The path to where the valid QA-pairs shall be stored.
        invalid_pairs (set[str]): A set of all invalid questions.
    """
    overall_dict = defaultdict(int)
    removed_dict = defaultdict(int)

    # newline="" on both files hands line-ending handling to the csv module, which keeps
    # line breaks inside quoted fields intact.
    with open(original_file, "r", newline="", encoding="utf-8") as input_file, \
            open(output_file, "w", newline="", encoding="utf-8") as new_file:
        csv_reader = csv.reader(input_file, delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        csv_writer = csv.writer(new_file, delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)

        for row in csv_reader:
            if not row:
                # csv.reader yields an empty list for a blank line; it carries no QA-pair.
                continue

            object_type = row[object_type_index_original]
            overall_dict[object_type] += 1
            if row[question_index_original] in invalid_pairs:
                removed_dict[object_type] += 1
            else:
                csv_writer.writerow(row)

    # Print results
    print("Number of deleted QA-pairs per category:")
    # Iterate over all categories so that those without deletions show up as "0 / total".
    for key in overall_dict:
        print(f"{key}: {removed_dict.get(key, 0)} / {overall_dict[key]}")
    
    
# Collect the questions labelled invalid in the filtered file, then copy every other
# QA-pair from the original file into qa_pair_output.
invalid_pairs = get_invalid_pairs(qa_pair_filtered_file)
remove_invalid_pairs(qa_pair_file, qa_pair_output, invalid_pairs)

Number of deleted QA-pairs per category:
Table: 3640 / 26823
Figure: 5846 / 51186
Table_02: 3079 / 22006


In [6]:
# Test-split sizes passed to perform_train_test_split: the number of QA-pair rows sampled
# for the "Figure" type, and for each of the "Table" and "Table_02" types, respectively.
FIXED_FIGURE_SAMPLE_SIZE = 1100
FIXED_TABLE_SAMPLE_SIZE = 550

def perform_train_test_split(
    input_file,
    train_split_file,
    test_split_file,
    figure_sample_size,
    table_sample_size,
    random_state=42,
):
    """
    Performing the train-test split.

    For each object type ("Figure", "Table", "Table_02") a fixed number of QA-pair rows
    is sampled into the test split. All remaining rows form the train split, except rows
    whose object id also occurs in the test split; those are dropped so that no object
    is shared between the two splits.

    Every field is read as a plain string and no value is interpreted as missing, so
    rows are written back exactly as they were read (e.g. a field "None" or "N/A" is
    not blanked and an id such as "0012" is not turned into a number).

    Args:
        input_file (str): Path to the csv file containing the QA-pairs.
        train_split_file (str): Path to the file in which the training split shall be stored.
        test_split_file (str): Path to the file in which the test split shall be stored.
        figure_sample_size (int): Number of QA-pair rows of type "Figure" in the test split.
        table_sample_size (int): Number of QA-pair rows in the test split for each of the
            types "Table" and "Table_02".
        random_state (int): Seed used for sampling each object type. Defaults to 42.

    Raises:
        ValueError: If an object type has fewer rows than its requested sample size.
    """

    # Loading csv file; dtype=str and keep_default_na=False keep every field verbatim.
    df = pd.read_csv(
        input_file,
        delimiter=";",
        quotechar="|",
        header=None,
        dtype=str,
        keep_default_na=False,
    )
    object_types = df[object_type_index_original]

    # Test split per label (dict order fixes the row order of the test file)
    sample_sizes = {
        "Figure": figure_sample_size,
        "Table": table_sample_size,
        "Table_02": table_sample_size,
    }
    test_parts = []
    for object_type, sample_size in sample_sizes.items():
        candidates = df[object_types == object_type]
        if len(candidates) < sample_size:
            raise ValueError(
                f"Cannot sample {sample_size} rows of type '{object_type}': "
                f"only {len(candidates)} available in {input_file}."
            )
        test_parts.append(candidates.sample(n=sample_size, random_state=random_state))

    # Joining up test split
    test_split = pd.concat(test_parts)

    # Move the rest to train split. Remove any rows that contain an object that also occurs in the test split.
    train_split = df.drop(test_split.index)
    test_object_ids = test_split[object_id_index_original]
    train_split = train_split[~train_split[object_id_index_original].isin(test_object_ids)]

    # Save splits to csv files
    train_split.to_csv(train_split_file, index=False, header=False, sep=';', quotechar='|')
    test_split.to_csv(test_split_file, index=False, header=False, sep=';', quotechar='|')

    # Confirmation print
    print("Train and test splits were successfully constructed.")
    
# Build the splits from qa_pair_output; both split files are written to the working directory.
perform_train_test_split(qa_pair_output, "train_split.csv", "test_split.csv", FIXED_FIGURE_SAMPLE_SIZE, FIXED_TABLE_SAMPLE_SIZE)

Train and test splits were successfully constructed.
