In [None]:
import sys
import os
import re 
from tqdm import tqdm
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict,List

sys.path.append(os.path.abspath(".."))
from tools.string_utils import read_text_file
from tools.json_utils import load_json, save_json
from tools.api import call_api

In [22]:
def extract_answers(prompt):
    # Define regular expressions to capture the short and long answers
    short_answer_pattern = r"<answer-short>\s*<reason>(.*?)</reason>\s*<answer>(.*?)</answer>\s*</answer-short>"
    long_answer_pattern = r"<answer-long>\s*<reason>(.*?)</reason>\s*<answer>(.*?)</answer>\s*</answer-long>"

    # Search for the patterns in the prompt
    short_answer_match = re.search(short_answer_pattern, prompt, re.DOTALL)
    long_answer_match = re.search(long_answer_pattern, prompt, re.DOTALL)

    # Extract the reason and answer for short and long answers
    if short_answer_match:
        short_reason = short_answer_match.group(1).strip()
        short_answer = short_answer_match.group(2).strip()
    else:
        short_reason = None
        short_answer = None

    if long_answer_match:
        long_reason = long_answer_match.group(1).strip()
        long_answer = long_answer_match.group(2).strip()
    else:
        long_reason = None
        long_answer = None

    return {
        "short-answer": {
            "reason": short_reason,
            "answer": short_answer
        },
        "long-answer": {
            "reason": long_reason,
            "answer": long_answer
        }
    }

def process_input_content(cur_input, cur_prompt):
        try:
            cur_response = call_api(cur_prompt, temperature=0.6)
            answers = extract_answers(cur_response)
            cur_input['positive'] = answers['short-answer']['answer']
            cur_input['corrected-answer'] = answers
            return cur_input
        except Exception as e:
            print(e)
            print(f"An error occurred while processing input")
            return None, None
        
def expand_numbers_and_ranges(numbers_and_ranges):
    expanded_numbers = []
    for item in numbers_and_ranges:
        if '-' in item:  # It's a range like 'xx1-xx2'
            start, end = map(int, item.split('-'))
            if start > end:
                end, start = start, end
            expanded_numbers.extend(range(start, end + 1))
        else:  # It's a single number
            expanded_numbers.append(int(item))
    expanded_numbers = list(sorted(list(set(expanded_numbers))))
    return expanded_numbers

def list_to_docided_string(string_dict):
    """
    Convert a list of strings into a docided string.

    :param string_list: list of str, the list of strings to be converted
    :return: str, the resulting numbered string
    """
    numbered_string = ""
    for index, (doc_id, doc_content) in enumerate(string_dict.items()):
        numbered_string += f"""{index}. <doc>
    <doc-name>{doc_id}</doc-name>
    <detailed-desc>{doc_content}</detailed-desc>
</doc>
"""
    return numbered_string.strip()

def extract_and_remove_think_tags(text):
    # Find all content inside <think> tags
    think_contents = re.findall(r'<think>(.*?)</think>', text, flags=re.DOTALL)
    
    # Remove all <think> tags and their contents from the text
    cleaned_text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    
    return think_contents, cleaned_text

def replace_clue_with_doc_and_sen(all_clueid2docid2senidlist: Dict[int, Dict[int, List[int]]], positive_answer: str) -> str:
    """
    Replaces [Clue xx] or [Clue xx-yy] citations in the positive_answer with formatted [Doc xx, Sen xx] citations.
    
    Parameters:
    - all_clueid2docid2senidlist: Dict mapping clue IDs to another dict mapping doc IDs to lists of sentence IDs.
      Example:
      {
          1: {1: [1, 2, 3]},
          2: {1: [4, 5]},
          3: {2: [1, 2]},
          4: {2: [3]},
      }
    - positive_answer: String containing [Clue xx] or [Clue xx-yy] patterns.
    
    Returns:
    - new_answer: String with [Clue xx] patterns replaced by formatted citations.
    """
    
    def expand_range(token: str) -> List[int]:
        """
        Expands a token which can be a single number or a range (e.g., '2' or '2-8') into a list of integers.
        """
        if '-' in token:
            start, end = token.split('-')
            return list(range(int(start), int(end) + 1))
        else:
            return [int(token)]
    
    def expand_range_in_list(tokens: List[str]) -> List[int]:
        """
        Processes a list of tokens, expanding ranges and collecting all clue IDs.
        """
        clue_ids = []
        for token in tokens:
            if '-' in token:
                clue_ids.extend(expand_range(token))
            else:
                if token.isdigit():
                    clue_ids.append(int(token))
        return clue_ids
    
    def expand_sen_ranges(nums: List[int]) -> List[str]:
        """
        Converts a sorted list of integers into a list with ranges for consecutive numbers.
        Example: [1,2,3,5] -> ['1-3', '5']
        """
        if not nums:
            return []

        nums = sorted(nums)
        ranges = []
        start = prev = nums[0]

        for num in nums[1:]:
            if num == prev + 1:
                prev = num
            else:
                if prev - start >= 2:
                    ranges.append(f"{start}-{prev}")
                elif prev - start == 1:
                    ranges.append(str(start))
                    ranges.append(str(prev))
                else:
                    ranges.append(str(start))
                start = prev = num

        # Handle the last range
        if prev - start >= 2:
            ranges.append(f"{start}-{prev}")
        elif prev - start == 1:
            ranges.append(str(start))
            ranges.append(str(prev))
        else:
            ranges.append(str(start))

        return ranges

    # Regular expression to find [Clue xx], [Clue xx, yy], [Clue xx-yy], etc.
    clue_pattern = re.compile(r'\[Clue\s+([^\]]+)\]')
    
    def replacement(match):
        # print("match:", match)
        clue_ids_str = match.group(1)
        # Split by comma and/or whitespace
        tokens = re.split(r'[,\s]+', clue_ids_str)
        # Expand tokens to individual clue IDs
        clue_ids = expand_range_in_list(tokens)
        # print("clue_ids:", clue_ids)
        
        # Map doc_id to set of sen_ids
        doc_to_sens = {}
        for cid in clue_ids:
            if cid in all_clueid2docid2senidlist:
                for doc_id, sen_ids in all_clueid2docid2senidlist[cid].items():
                    if doc_id not in doc_to_sens:
                        doc_to_sens[doc_id] = set()
                    doc_to_sens[doc_id].update(sen_ids)
        
        if not doc_to_sens:
            # No valid clues found, return the original string
            return match.group(0)
        
        # Build the citation strings
        citations = []
        for doc_id in sorted(doc_to_sens.keys()):
            sen_list = sorted(doc_to_sens[doc_id])
            sen_ranges = expand_sen_ranges(sen_list)

            if sen_ranges:
                # Prepend 'Sen ' to each range
                sen_formatted = [f"{s}" for s in sen_ranges]
                sen_formatted[0] = f"Sen {sen_formatted[0]}"
                # Join sentence parts with comma
                sen_str = ", ".join(sen_formatted)
                citations.append(f"Doc {doc_id}, {sen_str}")
            else:
                citations.append(f"")
        
        # Format multiple documents with separate brackets
        if len(citations) == 1:
            return f"[{citations[0]}]"
        else:
            # Each document citation in its own brackets
            return "".join(f"[{cit}]" for cit in citations)
    
    # Replace all [Clue ...] patterns using the replacement function
    new_answer = clue_pattern.sub(replacement, positive_answer)
    
    return new_answer

In [None]:
FINAL_ANSWER_GENERATOR_NUM_WORKERS = 4
FINAL_ANSWER_GENERATOR_MAX_GEN_TIMES = 100
FINAL_ANSWER_GENERATOR_OUTPUT_PATH = ("../data/final_answer_generated.json")
save_interval = 10 

prompt_template = read_text_file("../prompts/final_answer_generator.txt")
inputs = load_json("../data/proposed_questions.json")
corpus_data = load_json("../data/proposed_questions.json")

Loaded 10 items from ../data/proposed_questions.json
Loaded 10 items from ../data/proposed_questions.json


In [None]:
def run():
    # what is corpusid_2_context
    corpusid_2_context = {cur_dict['id']: cur_dict['context'] for cur_dict in corpus_data}

    success_num, all_num = 0, 0
    futures_to_data = {}

    with ThreadPoolExecutor(max_workers=FINAL_ANSWER_GENERATOR_NUM_WORKERS) as executor:
            data_list = inputs[:FINAL_ANSWER_GENERATOR_MAX_GEN_TIMES] # Limit processing to first 100 items (FINAL_ANSWER_GENERATOR_MAX_GEN_TIMES)
            for data_item in data_list: # Iterate through each data item in the limited list, here each item is a chunk (question)
                if 'proposed-questions' not in data_item: # Skip items without proposed questions
                    continue
                proposed_questions = data_item['proposed-questions'] # extract proposed questions
                chunk_id = data_item['id'] # extract document id
                
                # Build mapping structure: {clue_id: {doc_id: [sentence_id_list]}}
                # This maps each objective fact (clue) to its document and sentence IDs
                all_clueid2docid2senidlist = {}
                objective_facts = data_item['objective-facts']
                # Extract objective facts and their corresponding sentence ID strings
                sens = data_item["sens"] # e.g., ["1, 2-4", "5", "6-8"]

                # Pair each fact with its sentence IDs, enumerate starting from 1 for fact_id
                for (fact_id, objective_fact), sen in zip(enumerate(objective_facts, start=1), sens):
                    sen_ids = re.findall(r'\d+-\d+|\d+', sen) # Find all sentence ID patterns: ranges like "1-3" or singles like "5"
                    sen_ids = expand_numbers_and_ranges(sen_ids) # Expand ranges into full integer lists: ["1-3", "5"] → [1, 2, 3, 5]
                    all_clueid2docid2senidlist[fact_id] = {
                        chunk_id: sen_ids # Store: fact_id (clue ID) → {chunk_id → [sentence IDs]}
                    }

                # Process each proposed question for this data item 
                # proposed_question_type is the key (e.g., "question-1")
                # proposed_question_dict contains question data (question, answer, etc.)
                for proposed_question_type, proposed_question_dict in proposed_questions.items():
                    if "positive" in proposed_question_dict: # Skip if this question already has a 'positive' answer (already processed)
                        continue
                    
                    # Extract the original question text
                    original_question = proposed_question_dict['question']
                    positive_answer = proposed_question_dict['answer']
                    if not positive_answer:
                        continue
                    
                    # print("all_clueid2docid2senidlist:", all_clueid2docid2senidlist)
                    # print("positive_answer:", positive_answer)

                    # Convert [Clue xx] citations to [Doc xx, Sen yy] format
                    # Uses all_clueid2docid2senidlist to map clue IDs to doc/sentence references
                    positive_answer = replace_clue_with_doc_and_sen(all_clueid2docid2senidlist, positive_answer)

                    needed_corpusid2corpus = {chunk_id: corpusid_2_context[chunk_id]}
                    # Format as numbered XML-style string: "0. <doc><doc-name>...</doc-name>..."
                    needed_corpusid2corpus_str = list_to_docided_string(needed_corpusid2corpus)
                    
                    cur_prompt = prompt_template.replace('[[QUESTION]]', original_question)
                    cur_prompt = cur_prompt.replace('[[CONTEXT]]', needed_corpusid2corpus_str)
                    cur_prompt = cur_prompt.replace('[[ANSWER]]', positive_answer)
                    
                    # Submit API call task to executor for async processing
                    future = executor.submit(process_input_content, proposed_question_dict, cur_prompt)
                    futures_to_data[future] = (
                        None
                    )

                    # futures_to_data[future] = (
                    #     proposed_question_dict.get('rephrased-questions', []),
                    #     proposed_question_dict.get('rephrased-questions-part', []),
                    #     proposed_question_dict.get('rephrased-questions-hybrid', [])
                    # )

                    # rephrased_question_type_list = ['rephrased-questions', 'rephrased-questions-part', 'rephrased-questions-hybrid']
                    # for rephrased_question_type in rephrased_question_type_list:
                    #     rephrased_questions = proposed_question_dict.get(rephrased_question_type, [])
                    #     for rephrased_question_dict in rephrased_questions:
                    #         # get answer with already replaced clues
                    #         if 'reordered-question' in rephrased_question_dict:
                    #             rephrased_question = rephrased_question_dict['reordered-question']
                    #         else:
                    #             rephrased_question = rephrased_question_dict['result']
                    #         positive_answer = rephrased_question_dict['answer']
                    #         positive_answer = replace_clue_with_doc_and_sen(all_clueid2docid2senidlist, positive_answer)
                            
                    #         cur_prompt = self.prompt_template.replace('[[QUESTION]]', rephrased_question)
                    #         cur_prompt = cur_prompt.replace('[[CONTEXT]]', needed_corpusid2corpus_str)
                    #         cur_prompt = cur_prompt.replace('[[ANSWER]]', positive_answer)
                    #         future = executor.submit(self.process_input_content, rephrased_question_dict, self.CLIENT, cur_prompt)
                    #         futures_to_data[future] = (
                    #             proposed_question_dict.get('rephrased-questions', []),
                    #             proposed_question_dict.get('rephrased-questions-part', []),
                    #             proposed_question_dict.get('rephrased-questions-hybrid', [])
                    #         )

    
    all_num = len(futures_to_data) # Calculate total number of tasks submitted
    for future in tqdm(as_completed(futures_to_data), total=all_num, desc="Processing Future", dynamic_ncols=True):
        # rephrased_questions, rephrased_questions_part, rephrased_questions_hybrid = futures_to_data[future]
        _ = futures_to_data[future] # Line 79: Retrieve the stored metadata (currently None, unused)
        try:
            # Get the result from the completed future
            # Waits up to 10 minutes (600 seconds) for the result
            # cur_response is the return value from process_input_content
            # which is the updated proposed_question_dict with new fields:
            #   - 'positive': short answer text
            #   - 'corrected-answer': dict with 'short-answer' and 'long-answer'
            cur_response = future.result(timeout=10*60)
            if cur_response is None:
                print(f"Warning: Task returned None, may have failed")
                continue  # Don't count as success
            
            success_num += 1    
            
            # save (overwrite save)
            if (success_num + 1) % save_interval == 0:
                print(f'Progress: {success_num}/{all_num}')
                save_json(inputs, FINAL_ANSWER_GENERATOR_OUTPUT_PATH)

        except Exception as e:
            # Print error message but continue processing other futures
            # Note: Failed tasks are NOT counted in success_num
            # Possible errors: timeout (>10 min), API failure, parsing errors
            print(f"Error processing future: {e}")
            
    # Final save after all futures complete
    # Condition: Save if ANY tasks succeeded OR output file doesn't exist yet
    # This ensures data is always saved at the end, even if no interval saves occurred
    if success_num or not os.path.exists(FINAL_ANSWER_GENERATOR_OUTPUT_PATH):
        save_json(inputs, FINAL_ANSWER_GENERATOR_OUTPUT_PATH)
    
    return success_num, all_num

run()

Processing Future: 100%|██████████| 30/30 [00:00<00:00, 10714.33it/s]

Saving 9/30 outputs to ../data/final_answer_generated.json.
Saving 19/30 outputs to ../data/final_answer_generated.json.
Saving 29/30 outputs to ../data/final_answer_generated.json.
Saving outputs to ../data/final_answer_generated.json.





(30, 30)