In [4]:
import os
import time
import pandas as pd
import yaml
from tqdm.notebook import tqdm  
from HyDE import HyDEConfig, Promptor, HyDE

def process_all_domains(input_parquet, output_folder):
    """Run HyDE retrieval and answering for every domain and save the results.

    For each domain, the rows of the questions file whose ``domain`` column
    equals the domain name are passed through ``HyDE.e2e_search`` and
    ``HyDE.answer``. Each per-domain frame is written to
    ``<output_folder>/domain_<domain>.parquet`` and the concatenation of all
    of them to ``<output_folder>/domain_all.parquet``.

    Args:
        input_parquet: Path of the questions parquet file. The ``domain`` and
            ``question_text`` columns are read from it.
        output_folder: Directory that receives the parquet outputs; it is
            created when it does not exist.

    Returns:
        tuple: ``(domain_dfs, df_merged)`` where ``domain_dfs`` maps each
        domain name to its processed DataFrame and ``df_merged`` is the
        concatenation of those frames with a fresh index.
    """
    # main_folder value handed to HyDE (through the config) for each domain.
    # Insertion order is also the processing order.
    domain_to_mainfolder = {
        "mix":          "/Users/chengze/Desktop/mix",
        "physics":      "/Users/chengze/Desktop/physics",
        "agriculture":  "/Users/chengze/Desktop/agriculture",
        "cs":           "/Users/chengze/Desktop/cs"
    }

    # Read the configuration file (Hconfig.yaml)
    with open('Hconfig.yaml', 'r', encoding='utf-8') as f:
        base_config = yaml.safe_load(f)

    # Use the function argument; relying on a notebook-level variable with a
    # similar name makes the result depend on which cells ran before.
    df_all = pd.read_parquet(input_parquet)

    # Create the output directory up front so an unusable path fails before
    # any domain has been processed.
    os.makedirs(output_folder, exist_ok=True)

    # Store the DataFrame for each domain after processing
    domain_dfs = {}
    for domain, main_folder in domain_to_mainfolder.items():
        print(f"\n=== Processing domain: {domain} ===")

        # Build a per-domain config with its own nested 'config' dict.
        # A shallow dict.copy() would share that nested dict, so setting
        # main_folder would also rewrite base_config and earlier configs.
        config_domain = {
            **base_config,
            'config': {**base_config['config'], 'main_folder': main_folder},
        }

        # Initialize HyDE
        hyde_config = HyDEConfig(config_domain)
        promptor = Promptor(hyde_config)
        hyde_obj = HyDE(hyde_config, promptor)

        df_domain = df_all[df_all['domain'] == domain].copy()

        # Make sure the result columns exist before filling them row by row.
        for column in ('processing time', 'retrieval_document', 'response'):
            if column not in df_domain.columns:
                df_domain[column] = None

        for idx, row in tqdm(df_domain.iterrows(),
                             total=len(df_domain),
                             desc=f"Processing {domain}"):
            question = row['question_text']

            # Only the retrieval call is timed; answering is not included.
            start_time = time.time()
            retrieval_document = hyde_obj.e2e_search(question)
            end_time = time.time()

            # If the retrieval result is a non-empty list, answer from its
            # first element only; otherwise answer with no context.
            if retrieval_document and isinstance(retrieval_document, list):
                best_hit = [retrieval_document[0]]
            else:
                best_hit = []
            response_data = hyde_obj.answer(best_hit, question)

            df_domain.at[idx, 'processing time'] = end_time - start_time
            # The full retrieval result is stored as its string representation.
            df_domain.at[idx, 'retrieval_document'] = str(retrieval_document)
            df_domain.at[idx, 'response'] = response_data

        parquet_path = os.path.join(output_folder, f"domain_{domain}.parquet")
        df_domain.to_parquet(parquet_path, index=False)
        print(f"[{domain}] => Parquet saved: {parquet_path}")
        domain_dfs[domain] = df_domain

    df_merged = pd.concat(domain_dfs.values(), ignore_index=True)
    merged_parquet = os.path.join(output_folder, "domain_all.parquet")
    df_merged.to_parquet(merged_parquet, index=False)
    print(f"All 4 domains merged => Parquet saved: {merged_parquet}")
    return domain_dfs, df_merged

# Entry point: the block below runs only when this code is executed as the
# main module (for example from PyCharm or as a script).
if __name__ == "__main__":
    # Questions parquet to read and the folder that receives the outputs.
    # NOTE(review): both are absolute, user-specific paths — adjust per machine.
    input_parquet_path = "/Users/chengze/Desktop/questions.parquet"
    output_folder_path = "/Users/chengze/Desktop/output_parquets" 

    domain_dfs, df_all_merged = process_all_domains(
        input_parquet=input_parquet_path,
        output_folder=output_folder_path
    )
    # domain_dfs is a dict: key = domain name, value = that domain's DataFrame
    # df_all_merged is the result of merging the four per-domain DataFrames
    print("Done!")


=== Processing domain: mix ===


Processing mix:   0%|          | 0/125 [00:00<?, ?it/s]

In [2]:
import os
import time
import pandas as pd
import yaml
from tqdm.notebook import tqdm  
from HyDE import HyDEConfig, Promptor, HyDE

def process_hotpotqa_questions(
    input_parquet: str,
    hotpotqa_folder: str,
    output_parquet: str = None
):
    """Answer every question of a parquet file with HyDE and return the frame.

    Each value of the ``question`` column is sent to ``HyDE.e2e_search``; the
    first element of a non-empty list result (or no context at all) is then
    passed to ``HyDE.answer``. Results are written into the
    ``retrieval_document``, ``response`` and ``processing_time`` columns.

    Args:
        input_parquet: Path of the parquet file holding the questions.
        hotpotqa_folder: Value stored as ``config.main_folder`` in the
            configuration loaded from ``Hconfig.yaml``.
        output_parquet: When truthy, the resulting frame is also written to
            this path (without the index).

    Returns:
        The DataFrame read from ``input_parquet`` with the result columns
        filled in.
    """
    questions = pd.read_parquet(input_parquet)
    print("[INFO] Loaded questions:", questions.shape)

    # Load the shared configuration and point it at the requested folder.
    with open('Hconfig.yaml', 'r', encoding='utf-8') as config_file:
        settings = yaml.safe_load(config_file)
    settings['config']['main_folder'] = hotpotqa_folder

    hyde_settings = HyDEConfig(settings)
    engine = HyDE(hyde_settings, Promptor(hyde_settings))

    # Create whichever result columns are missing; existing ones are
    # overwritten row by row in the loop below.
    for column in ('retrieval_document', 'response', 'processing_time'):
        if column not in questions.columns:
            questions[column] = None

    progress = tqdm(questions.iterrows(), total=len(questions), desc="HotpotQA")
    for row_label, record in progress:
        query = record['question']

        # Only the retrieval call is timed.
        started = time.time()
        hits = engine.e2e_search(query)
        elapsed = time.time() - started

        # Answer from the first hit of a non-empty list, else with no context.
        context = [hits[0]] if hits and isinstance(hits, list) else []
        answer = engine.answer(context, query)

        questions.at[row_label, 'retrieval_document'] = str(hits)
        questions.at[row_label, 'response'] = answer
        questions.at[row_label, 'processing_time'] = elapsed

    if not output_parquet:
        return questions

    questions.to_parquet(output_parquet, index=False)
    print(f"[INFO] Results saved to {output_parquet}")
    return questions


# Entry point: runs only when this code is executed as the main module.
if __name__ == "__main__":
    # Questions file, folder passed to HyDE as main_folder, and output file.
    # NOTE(review): all three are absolute, user-specific paths — adjust per machine.
    input_parquet_path = "/Users/chengze/Desktop/hotpotqa_questions_200.parquet"
    hotpotqa_folder = "/Users/chengze/Desktop/NaiveRAG_hotpotqa"
    output_parquet_path = "/Users/chengze/Desktop/NaiveRAG_hotpotqa_response_results.parquet"

    df_final = process_hotpotqa_questions(
        input_parquet=input_parquet_path,
        hotpotqa_folder=hotpotqa_folder,
        output_parquet=output_parquet_path
    )
    # Show the first ten processed rows as a quick sanity check.
    print("\n=== Sample results ===")
    print(df_final.head(10))


[INFO] Loaded questions: (200, 1)


HotpotQA:   0%|          | 0/200 [00:00<?, ?it/s]

[INFO] Results saved to /Users/chengze/Desktop/NaiveRAG_hotpotqa_response_results.parquet

=== Sample results ===
                                               question  \
1116  Out to Win is an American documentary film tha...   
1368  Are both Variety and The Advocate LGBT-interes...   
422   Who is this American rapper, songwriter, recor...   
413   In which year was this American country music ...   
451   What is the nationality of the actor who starr...   
861   Which tennis player is from the United States,...   
1063  When did the company that produced "Six Gates ...   
741   Interval starred Merle Oberon who fell in love...   
1272  Which English Romantic poet wrote the sonnet t...   
259   In which stadium do the teams owned by Myra Kr...   

                                     retrieval_document  \
1116  [(0.9446477, 'Sudanese-born Australian profess...   
1368  [(0.817297, 'ts and tourists on the right.  On...   
422   [(0.95737755, 'response to Nicki Minaj\'s song...   
