In [1]:
import os
import sys
# import ollama
# import google.generativeai as genai
import anthropic
# import ollama
import random
import pandas as pd
from tqdm import tqdm
from google.generativeai.types import RequestOptions
from google.api_core import retry
from typing import List, Tuple
import json
import datetime
from openai import OpenAI


current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)

if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import visualize
import pandas as pd
from utils.utils import add_color_to_tags, extract_parts_0, extract_parts_1
import argparse

# Main Functions

In [None]:
def save_results(save_path: str, ids: List[str], questions: List[str], answers: List[str], append: bool = False):
    """
    Saves the results to a CSV file. If append is True and the file exists, it appends without headers.
    Otherwise, it writes a new file with headers.
    """
    df = pd.DataFrame({'id': ids, 'question': questions, 'answer': answers})
    if append and os.path.exists(save_path):
        df.to_csv(save_path, mode='a', index=False, header=False)
    else:
        df.to_csv(save_path, index=False)

def read_jsonl_file(filepath: str) -> List[dict]:
    """
    Reads a JSONL file and returns a list of JSON objects.
    """
    data = []
    with open(filepath, 'r') as file:
        for line in file:
            json_obj = json.loads(line)
            data.append(json_obj)
    return data

def get_prompt(prompt_type: str, few_shot_prompt: str, question: str) -> str:
    """
    Constructs the prompt based on the prompt type.
    """
    prompts = {
        "cot": f"{few_shot_prompt}\n{question}\nPlease generate your explanation first, then generate the answer in the bracket as follow:\n" +"Answer: {}",
        "log_cot": f"{few_shot_prompt}\n{question}\nThink through your answer step by step and then chose the answer option that is the most correct. Put your final answer in curly brackets. For example, Final_Answer:{{A}}",
        "fs": f"{few_shot_prompt}\n{question}",
        "fs_inst": f"{few_shot_prompt}\n{question}\nI want you to answer this question but your explanation should contain references referring back to the information in the question. To do that, first, re-generate the question with proper tags and then generate your answers. The output format is as follow:\n\
            Reformatted Question: \
                Answer:",
        "zs": f"{question}\nI want you to answer this question but your explanation should contain references referring back to the information in the question. To do that, first, re-generate the question with proper tags (<a>, <b>, <c>, etc) for refered information and then generate your answers that also have the tag (<a>, <b>, <c>, etc) for the grounded information. Give your answer by analyzing step by step, and give only numbers in the final answer. The output format is as follow:\n\
            Reformatted Question: \
                Answer:\
                    Final answer:",
        "fs_xml": f"{few_shot_prompt}\n\nRecreate the following question in the style of the correctly formatted examples shown previously. Make sure that your response has all its information inclosed in the proper <tags>. Begin your response with the <key_facts> section. Make sure that every fact in <key_facts> is very concise and contains a very short reference to the <question>. Do not include a <question> section in your response\n\n<question>\n{question}\n</question>",
        "fs_log_inst": f"{few_shot_prompt}\n\n{question}\nTo answer this question, your explanation should contain references referring back to the information in the question. To do that, first, re-generate the question with proper tags and then generate your answers based off the tags. Put your final answer in curly brackets e.g. Final_Answer: {{30}}",
        "mermaid_get_answer": f"{few_shot_prompt}\n\n Your job is to extract the key facts from a question relevant to answering the question. The facts should be represented in a hierarchal format through a mermaid diagram. Do not create duplicate facts across multiple branches that represent the same information. Create a mermaid diagram that represents the key facts in the following question. Then, use the nodes from this graph to cite specific facts in your answer reasoning. Put your final answer in curly brackets e.g. Final_Answer: {{30}} \n\nquestion: {question}", 
    }
    return prompts.get(prompt_type, "")

# cycle through all the keys 
# def get_gemini_key(problem_id):
#     GOOGLE_KEYS = [
#         'AIzaSyBQ7zvIZoET3199GNhuz86vKagn_JCEOmk', # original - gen lang client
#         'AIzaSyCEI-5U4z7-3q-uwlvkOrdT2e78aNmjnbg', # chat app
#         'AIzaSyCvycd0yZZ4GSj47qDLk4JoPemvzUSfvio', # project 1
#         'AIzaSyD5xNbDkaJMEMBpWEXYNq5SheF6omdKpzg', # project 2
#         'AIzaSyAjcrp_otRjGsj0YvB1cUc2BMng6KSEZwU', # project 3
#         'AIzaSyB43xEllzAqGJjz-ExIGadXpUQllQ6PiI4', # project 4
#         'AIzaSyDDTCn4lKul4vMj9GmEGJBxZFHb6QZSoA8', # project 5
#         'AIzaSyB1sNUXN9CNpRWwqQnwVBBzMF37kYCNOIY', # project 6
#         'AIzaSyBqRruZh4d4jq8q6FtUci71nOkqcVlpNLM', # project 7
#         'AIzaSyATMO-YWZX4qtMru-NKcodolGr_4kKme5U', # project 8
#         'AIzaSyBbKx5spKBPS2tVaUje2Vc1e2v7T6ouUGc', # project 9
#         'AIzaSyCPb6W1e7uNI6UoSDTkJRmvkNbl1Tzgpmg', # project 10
#         'AIzaSyDqj50lzn-YYIZ92NID4MKgReTeSEJgZuk', # project 11 --
#         'AIzaSyBXO1lqmulX82oJjgGh4EPWWcGunxlFjFg', # project 12
#         'AIzaSyBt95gM49zINc5l0cZKy285wvtc-kTUTt0', # project 13
#         'AIzaSyBf4ty3TH3UC0-TvE-UwhMcrYePZS8_lNs', # project 14
#         'AIzaSyDOLgV0DQN7jQwvUbpYyr7jjz8TPYLzdDc', # project 15
#     ]
#     index = problem_id%len(GOOGLE_KEYS)
#     print(f"getting key from {index}")
#     key = GOOGLE_KEYS[index]
#     return key

# def query_gemini(prompt: str, problem_id) -> str:
#     """
#     Queries the Gemini LLM with the given prompt and returns the response text.
#     """
#     genai.configure(api_key=get_gemini_key(problem_id))
#     model = genai.GenerativeModel('gemini-1.5-pro-latest')
#     response = model.generate_content(prompt, request_options=RequestOptions(retry=retry.Retry(initial=20, multiplier=3, maximum=121, timeout=60)))
#     text = response.candidates[0].content.parts[0].text
#     return text

def query_claude(prompt: str) -> str:
    """
    Queries the Claude LLM with the given prompt and returns the response text.
    """
    client = anthropic.Anthropic(api_key=API_KEYS['claude'])
    response = client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=1024,
        messages=[{"role": "user", "content": prompt}]
    )
    return response.content[0].text

def query_4o(prompt: str) -> str:
    client = OpenAI()

    completion = client.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=[
            {
                "role": "user",
                "content": f"{prompt}"
            }
        ],
        temperature=0
    )

    return completion.choices[0].message.content

def query_llm(llm_model: str, ids: List[str], questions: List[str], few_shot_prompt: str, prompt_type: str, save_path: str, already_answered_ids: set) -> Tuple[List[str], List[str], List[str]]:
    """
    Queries the specified LLM for each question, skipping already answered ones.
    Saves each response immediately after it's obtained.
    Returns lists of answered IDs, questions, and answers.
    """
    answers = []
    ids_can_be_answered = []
    questions_can_be_answered = []
    
    for id, q in tqdm(zip(ids, questions), total=len(ids)):
        # print(f"Processing ID: {id}")
        if id in already_answered_ids:
            print(f"Skipping already answered ID: {id}")
            continue
        
        prompt = get_prompt(prompt_type, few_shot_prompt, q)
        try:
            if llm_model == 'gemini':
                answer = query_gemini(prompt, id)
            elif llm_model == 'claude':
                answer = query_claude(prompt)
            elif llm_model == '4o':
                answer = query_4o(prompt)
            elif llm_model == 'llama3.1':
                answer = ollama.generate(model='llama3.1', prompt=prompt)['response']
                print(f"Processed ID: {id}")
            else:
                raise ValueError(f"Unsupported LLM model: {llm_model}")
            # print(f"Answer for ID {id}: {answer}")
            
            # Append to lists
            answers.append(answer)
            questions_can_be_answered.append(q)
            ids_can_be_answered.append(id)

            # Save after each answer
            save_results(save_path, [id], [q], [answer], append=True)
        except Exception as e:
            print(f"Error processing question {id}: {str(e)}")
            continue
    
    return ids_can_be_answered, questions_can_be_answered, answers

def load_data(data_path: str, sample_size: int = None) -> Tuple[List[str], List[str]]:
    """
    Loads data from a JSONL file, optionally sampling a subset.
    """
    data = read_jsonl_file(data_path)
    print(f"Loaded {len(data)} records from: {data_path}")
    if sample_size:
        data = random.sample(data, sample_size)
        print(f"Sampled {sample_size} records.")
    questions = [x["question"] for x in data]
    ids = [x["id"] for x in data]
    return ids, questions

def load_data_deterministic(data_path: str, sample_size: int = None) -> Tuple[List[str], List[str]]:
    """
    Loads data from a JSONL file in a deterministic manner by sorting.
    """
    data = read_jsonl_file(data_path)
    print(f"Loaded {len(data)} records from: {data_path}")
    if sample_size:
        # Sort the data based on a consistent criterion (e.g., 'id' or 'question')
        sorted_data = sorted(data, key=lambda x: x['id'])
        # Take the first 'sample_size' items
        data = sorted_data[:sample_size]
        print(f"Selected first {sample_size} records after sorting.")
    questions = [x["question"] for x in data]
    ids = [x["id"] for x in data]
    return ids, questions

def load_data_size_specific(data_path: str, sample_size: int = 0):
    data = read_jsonl_file(data_path)
    # with open(data_path, 'r', encoding='utf-8') as file:
    #     data = json.load(file)
    random_data = data
    print(random_data)
    
    # for idx, item in enumerate(random_data):
    #     if 'id' not in item:
    #         # Option 1: Use enumeration for simple integer IDs
    #         item['id'] = idx + 1  # Starting IDs from 1
    
    question_length = 0 # 336  # 526 # 800
    
    questions = [x["question"] for x in random_data if len(x["question"]) >= question_length]
    ids = [x["id"] for x in random_data if len(x["question"]) >= question_length]
    return ids[:sample_size], questions[:sample_size]

def load_few_shot_prompt(prompt_path: str) -> str:
    """
    Loads the few-shot prompt from a text file.
    """
    with open(prompt_path, 'r') as file:
        prompt = file.read()
    # print(f"Loaded few-shot prompt from: {prompt_path}")
    return prompt

def load_already_answered_ids(save_path: str) -> set:
    """
    Loads the set of IDs that have already been answered from the CSV file.
    Returns an empty set if the file does not exist.
    """
    if os.path.exists(save_path):
        df = pd.read_csv(save_path)
        answered_ids = set(df['id'].astype(int).tolist())
        # print(f"Loaded {len(answered_ids)} already answered IDs from: {save_path}")
        print(f"Already answered IDs: {answered_ids}")
        return answered_ids
    else:
        print(f"No existing save file found at: {save_path}. Starting fresh.")
        return set()

def initialize_save_file(save_path: str):
    """
    Initializes the CSV file with headers if it doesn't exist.
    """
    if not os.path.exists(save_path):
        # Create an empty DataFrame with headers and save
        df = pd.DataFrame(columns=['id', 'question', 'answer'])
        df.to_csv(save_path, index=False)
        print(f"Initialized new save file with headers at: {save_path}")

# Driver

In [3]:
time = datetime.datetime.now().strftime("%m%d_%H%M%S")
# time = '1003_002310'
project_root = '/Users/log/Github/textual_grounding/'
dataset = 'SPARTQA'

llm_model = '4o'
prompt_type = 'fs_log_inst'
# prompt_type = 'log_cot'
# few_shot_txt = 'fewshot_vanilla_cot.txt'
few_shot_txt = '1_shot_grounded.txt'

# Paths
data_path = os.path.join(project_root, 'data', dataset, 'test.jsonl')
# data_path = os.path.join(project_root, 'data', dataset, 'test.json')
# data_path = '/Users/log/Github/textual_grounding/logan/results/GSM8K/llama/mermaid/mermaid_get_graph_llama3.1_20240924_001821.csv'

fewshot_prompt_path = os.path.join(project_root, "prompt", dataset, few_shot_txt)
# fewshot_prompt_path = '/Users/log/Github/textual_grounding/prompt/GSM8K/fewshot_mermaid_full.txt'
save_dir = os.path.join(project_root, 'logan/results', dataset, f'{llm_model}/grounded_fact')
os.makedirs(save_dir, exist_ok=True)  # Ensure the directory exists
save_path = os.path.join(save_dir, f'{few_shot_txt}_{llm_model}_{time}.csv')

# ids, questions = load_data_deterministic(data_path, sample_size=200)
ids, questions = load_data_size_specific(data_path, sample_size=3)
few_shot_prompt = load_few_shot_prompt(fewshot_prompt_path)

print(few_shot_prompt)
# raise ValueError('stop')

initialize_save_file(save_path)
already_answered_ids = load_already_answered_ids(save_path)


ids_answered, questions_answered, answers = query_llm(
    llm_model=llm_model,
    ids=ids,
    questions=questions,
    few_shot_prompt=few_shot_prompt,
    prompt_type=prompt_type,
    save_path=save_path,
    already_answered_ids=already_answered_ids
)

print(f"Processing complete. {len(ids_answered)} new answers saved to {save_path}.")

[{'prompt': 'We have three blocks. Lets call them A, B and C. Block B is below A. Block A is below C. Block A contains a medium yellow square. Block B has two medium blue squares. Medium blue square number one is touching the bottom edge of this block. Medium blue square number two is below a medium yellow square. Medium blue square number one is below the square which is below the medium yellow square. It is below the medium yellow square. Block C contains one medium black square. What is below the black shape? a medium yellow square that is in block A or a medium yellow square that is in block B?\n0: medium yellow square  that is in block A\n1: medium yellow square  that is in block B\n2: both of them\n3: none of them', 'answer': 2, 'candidate_answers': ['medium yellow square  that is in block A', 'medium yellow square  that is in block B', 'both of them', 'none of them'], 'id': 1}, {'prompt': 'There are three blocks, A, B and C. Block A is above C. Block C is above B. Block A has a 

KeyError: 'question'

## LogiQA

In [5]:
import sys
import os
from datasets import load_dataset

# Add the directory containing logiqa.py to the Python path
logiqa_path = "/Users/log/Github/textual_grounding/data/logiqa"
sys.path.append(logiqa_path)

# Import the LogiQA class from the logiqa module if needed
from logiqa import LogiQA

# Load the dataset using Hugging Face load_dataset method
dataset = load_dataset('/Users/log/Github/textual_grounding/data/logiqa/logiqa.py', split='test')

# Print out the first 5 examples from the test set
for idx in range(5):
    example = dataset[idx]
    print(f"Example {idx + 1}:")
    print(f"Context: {example['context']}")
    print(f"Query: {example['query']}")
    print(f"Options: {example['options']}")
    print(f"Correct Option Index: {example['correct_option']}")
    print("-" * 50)


  if re.match('^[A-Z][\w\s]+[?.!]$', text) is None:


Downloading data:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/165k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/164k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7376 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/651 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/651 [00:00<?, ? examples/s]

Example 1:
Context: In the planning of a new district in a township, it was decided to build a special community in the southeast, northwest, centered on the citizen park. These four communities are designated as cultural area, leisure area, commercial area and administrative service area. It is known that the administrative service area is southwest of the cultural area, and the cultural area is southeast of the leisure area.
Query: Based on the above statement, which of the following can be derived?
Options: ['Civic Park is north of the administrative service area.', 'The leisure area is southwest of the cultural area.', 'The cultural district is in the northeast of the business district.', 'The business district is southeast of the leisure area.']
Correct Option Index: 0
--------------------------------------------------
Example 2:
Context: The company sent three young staff members to the South for business trip. The three of them happened to be sitting in a row. At least one of th

In [8]:
import json
from datasets import load_dataset

# Load the dataset (adjust the path as needed)
dataset = load_dataset('/Users/log/Github/textual_grounding/data/logiqa/logiqa.py', split='test')

# Prepare to write the first 300 examples to a JSONL file
output_file = 'logiqa_300_examples.jsonl'
with open(output_file, 'w', encoding='utf-8') as f:
    for idx, example in enumerate(dataset):
        if idx >= 300:
            break
        
        # Create the "question" field by concatenating context, query, and options
        context = example['context']
        query = example['query']
        options = example['options']
        options_str = " ".join([f"({chr(65 + i)}) {opt}" for i, opt in enumerate(options)])
        question = f"{context} {query}\n{options_str}"
        
        # Create the dictionary for the current example
        example_dict = {
            "id": idx,
            "question": question,
            "answer": chr(65 + example['correct_option'])  # Convert index to letter (A, B, C, D)
        }
        
        # Write the example as a JSON object to the JSONL file
        f.write(json.dumps(example_dict) + '\n')

print(f"Saved 300 examples to {output_file}")

Saved 300 examples to logiqa_300_examples.jsonl


# Visualization

## XML - visualize

In [83]:
import csv
import re

def extract_parts_1(answer_text):
    """
    Processes the answer text to extract key facts (with numbers), answer reasoning, and the final answer.

    Args:
        answer_text (str): The full answer text containing <key_facts>, <answer_reasoning>, and <final_answer>.

    Returns:
        tuple: (key_facts_list, answer_reasoning, final_answer)
               where key_facts_list is a list of tuples (fact_number, fact_content)
    """
    # Extract key_facts
    key_facts_match = re.search(r'<key_facts>(.*?)</key_facts>', answer_text, re.DOTALL)
    key_facts_content = key_facts_match.group(1).strip() if key_facts_match else ""

    # Extract individual facts with their numbers
    facts = re.findall(r'<fact_(\d+)>(.*?)</fact_\d+>', key_facts_content, re.DOTALL)
    key_facts_list = [(number.strip(), content.strip()) for number, content in facts]

    # Extract answer_reasoning
    reasoning_match = re.search(r'<answer_reasoning>(.*?)</answer_reasoning>', answer_text, re.DOTALL)
    answer_reasoning = reasoning_match.group(1).strip() if reasoning_match else ""

    # Extract final_answer
    final_match = re.search(r'<final_answer>(.*?)</final_answer>', answer_text, re.DOTALL)
    final_answer = final_match.group(1).strip() if final_match else ""

    return key_facts_list, answer_reasoning, final_answer


def add_color_to_tags(text):
    """
    Adds background color to specific tags within the text based on a predefined color mapping.

    Args:
        text (str): The text containing tags like <fact_1>, <fact_2>, etc.

    Returns:
        str: The text with added inline CSS for background colors.
    """
    tag_color_mapping = {
        'fact_1': 'yellow',  
        'fact_2': 'lightblue',
        'fact_3': 'lightgreen',
        'fact_4': 'lightcoral',
        'fact_5': 'lightcyan', 
        'fact_6': 'orange',
    }
    # Iterate over the tag-color mappings
    for tag, color in tag_color_mapping.items():
        # Regex to find the tag and replace it with the same tag having a style attribute
        text = re.sub(
            f'<{tag}>(.*?)</{tag}>',
            f'<{tag} style="background-color: {color};">\\1</{tag}>',
            text,
            flags=re.DOTALL
        )
    return text


def parse_csv_file(file_path):
    """
    Parses the input CSV file and extracts questions and their corresponding answers.

    Args:
        file_path (str): Path to the input CSV file.

    Returns:
        list of tuples: Each tuple contains (question, answer_text).
    """
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            question = row.get('question', 'No question found.').strip()
            answer_text = row.get('answer', 'No answer found.').strip()
            qa_pairs.append((question, answer_text))
    return qa_pairs


def create_highlight_html(qa_pairs):
    """
    Creates HTML content with highlighted questions, key facts, answer reasoning, and answers.

    Args:
        qa_pairs (list of tuples): Each tuple contains (question, answer_text).

    Returns:
        str: The complete HTML content as a string.
    """
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Question and Answer Highlights</title>
        <style>
            body {
                font-family: Arial, sans-serif;
                margin: 20px;
                background-color: #f0f0f0;
            }
            .container {
                background-color: #ffffff;
                padding: 20px;
                margin-bottom: 20px;
                border-radius: 8px;
                box-shadow: 0 2px 5px rgba(0,0,0,0.1);
            }
            .question {
                font-size: 1.2em;
                margin-bottom: 10px;
            }
            .key-facts {
                margin-bottom: 10px;
            }
            .key-facts ul {
                list-style-type: number;
                padding-left: 20px;
            }
            .key-facts ul li{
                margin-bottom: 4px;
            }
            .answer-reasoning, .final-answer {
                margin-bottom: 10px;
            }
            .highlight {
                background-color: #FFFF00; /* Yellow background for visibility */
                font-weight: bold; /* Bold text for emphasis */
            }
            /* Styles for specific facts */
            fact_1 {
                background-color: yellow;
                font-weight: bold;
            }
            fact_2 {
                background-color: lightblue;
                font-weight: bold;
            }
            fact_3 {
                background-color: lightgreen;
                font-weight: bold;
            }
            fact_4 {
                background-color: lightcoral;
                font-weight: bold;
            }
            fact_5 {
                background-color: lightcyan;
                font-weight: bold;
            }
            fact_6 {
                background-color: orange;
                font-weight: bold;
            }
        </style>
    </head>
    <body>
    <h1>Question and Answer Highlights</h1>
    """
    for i, (question, answer_text) in enumerate(qa_pairs, 1):
        try:
            key_facts, answer_reasoning, final_answer = extract_parts_1(answer_text)
        except Exception as e:
            print(f"Cannot extract parts for question {i}: {e}")
            continue

        # Convert key_facts list to HTML bullet points with "Key fact X:" prefix
        if key_facts:
            key_facts_html = "<ul>\n"
            for fact_number, fact_content in key_facts:
                # Apply color to tags in fact_content
                highlighted_fact = add_color_to_tags(fact_content)
                # Prepend "Key fact X:"
                key_facts_html += f"    <li><fact_{fact_number}>{highlighted_fact}</fact_{fact_number}></li>\n"
            key_facts_html += "</ul>"
        else:
            key_facts_html = "<p>No key facts available.</p>"

        # Apply color to tags in answer_reasoning and final_answer
        highlighted_reasoning = add_color_to_tags(answer_reasoning)
        highlighted_final_answer = add_color_to_tags(final_answer)

        # Build the HTML structure
        html_content += f"<div class='container'>"
        html_content += f"<div class='question'><strong>Question:</strong> {question}</div>"
        html_content += f"<div class='key-facts'><strong>Key Facts:</strong> {key_facts_html}</div>"
        html_content += f"<div class='answer-reasoning'><strong>Answer Reasoning:</strong> {highlighted_reasoning}</div>"
        html_content += f"<div class='final-answer'><strong>Answer:</strong> {highlighted_final_answer}</div>"
        html_content += "</div>\n"

    # Close the HTML tags
    html_content += """
    </body>
    </html>
    """
    return html_content


def main():
    input_file = '/Users/log/Github/textual_grounding/logan/results/GSM8K/llama/test_grounding_answer_prompt_fs_xml_llama3.1.csv'  # Replace with your input CSV file path
    output_file = 'test_grounding_answer_prompt_fs_xml_llama3.1.html'  # Replace with your desired output HTML file path

    # Parse the input CSV file to extract questions and answers
    qa_pairs = parse_csv_file(input_file)

    # Check if any QA pairs were found
    if not qa_pairs:
        print("No question-answer pairs were found in the input file.")
        return

    # Generate the HTML content
    html_content = create_highlight_html(qa_pairs)

    # Write the HTML content to the output file
    with open(output_file, 'w', encoding='utf-8') as file:
        file.write(html_content)

    print(f"HTML content has been successfully written to {output_file}")


if __name__ == "__main__":
    main()


HTML content has been successfully written to test_grounding_answer_prompt_fs_xml_llama3.1.html


## Mermaid - Visualize

In [25]:
import csv
import re
import json
import os

def extract_final_answer(answer_text):
    """
    Extracts the final answer enclosed in curly braces {} from the answer_text.

    Args:
        answer_text (str): The full model response text.

    Returns:
        str: The extracted final answer, or an empty string if not found.
    """
    final_answer_pattern = re.compile(
        r'Final_Answer:\s*\{([^}]+)\}',
        re.IGNORECASE
    )
    final_match = final_answer_pattern.search(answer_text)
    final_answer = final_match.group(1).replace(',', '').replace('$', '').strip() if final_match else ""
    return final_answer

def parse_csv_file(file_path):
    """
    Parses the input CSV file and extracts questions, answers, and their corresponding IDs.

    Args:
        file_path (str): Path to the input CSV file.

    Returns:
        list of tuples: Each tuple contains (id, question, answer_text).
    """
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            question = row.get('question', 'No question found.').strip()
            answer_text = row.get('answer', 'No answer found.').strip()
            id_ = row.get('id')
            if id_ is not None:
                try:
                    id_int = int(id_)
                    qa_pairs.append((id_int, question, answer_text))
                except ValueError:
                    print(f"Skipping a row due to invalid 'id' (not an integer): {id_}")
            else:
                print(f"Skipping a row due to missing 'id': {row}")
    return qa_pairs

def read_ground_truth(jsonl_path):
    """
    Reads the ground truth answers from a JSONL file and maps them by ID.

    Args:
        jsonl_path (str): Path to the ground truth JSONL file.

    Returns:
        dict: A dictionary mapping each ID to its ground truth answer.
    """
    ground_truth = {}
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            id_ = data.get('id')
            answer = data.get('answer')
            if id_ is not None and answer is not None:
                ground_truth[id_] = answer
            else:
                print(f"Invalid ground truth entry: {data}")
    return ground_truth

def create_simple_html(qa_pairs, ground_truth):
    """
    Creates simple HTML content displaying full model responses and comparison with ground truth.

    Args:
        qa_pairs (list of tuples): Each tuple contains (id, question, answer_text).
        ground_truth (dict): A dictionary mapping each ID to its ground truth answer.

    Returns:
        str: The complete HTML content as a string.
    """
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Question and Answer Comparison</title>
        <style>
            body {
                font-family: Arial, sans-serif;
                margin: 20px;
                background-color: #f9f9f9;
            }
            .container {
                background-color: #ffffff;
                padding: 15px 20px;
                margin-bottom: 15px;
                border-radius: 6px;
                box-shadow: 0 1px 3px rgba(0,0,0,0.1);
            }
            .question {
                font-size: 1.1em;
                margin-bottom: 10px;
                color: #333333;
            }
            .answer-text {
                background-color: #f4f4f4;
                padding: 10px;
                border-left: 4px solid #2196F3;
                margin-bottom: 10px;
                white-space: pre-wrap;
                font-family: Consolas, "Courier New", monospace;
            }
            .final-answer, .ground-truth-answer {
                margin-bottom: 5px;
            }
            .final-answer span.correct {
                color: green;
                font-weight: bold;
            }
            .final-answer span.incorrect {
                color: red;
                font-weight: bold;
            }
            .ground-truth-answer {
                color: #555555;
            }
            .summary {
                background-color: #e0ffe0;
                padding: 15px;
                border: 2px solid #00cc00;
                border-radius: 8px;
                font-size: 1.2em;
                margin-top: 30px;
            }
        </style>
    </head>
    <body>
    <h1>Question and Answer Comparison</h1>
    """

    # Initialize counters for correct and total answers
    correct_answers = 0
    total_answers = 0

    for id_, question, answer_text in qa_pairs:
        final_answer = extract_final_answer(answer_text)
        gt_answer = ground_truth.get(id_)

        if gt_answer is None:
            gt_answer_display = "<span style='color: gray;'>Ground truth not available.</span>"
            is_correct = False
        else:
            # Normalize both final_answer and gt_answer for comparison
            try:
                final_answer_num = float(final_answer.replace(',', '').replace('$', ''))
                if isinstance(gt_answer, list):  # Handle list of answers if applicable
                    gt_answer_num = float(gt_answer[0].replace(',', '').replace('$', ''))
                else:
                    gt_answer_num = float(gt_answer.replace(',', '').replace('$', ''))
                is_correct = final_answer_num == gt_answer_num
                # Format numbers with commas and two decimal places if needed
                final_answer_display = f"{final_answer_num:,.2f}" if not final_answer_num.is_integer() else f"{int(final_answer_num):,}"
                gt_answer_display = f"{gt_answer_num:,.2f}" if not gt_answer_num.is_integer() else f"{int(gt_answer_num):,}"
            except (ValueError, TypeError):
                # Fallback to string comparison if conversion fails
                is_correct = final_answer.strip().lower() == str(gt_answer).strip().lower()
                final_answer_display = final_answer
                gt_answer_display = gt_answer

        # Style the final answer based on correctness
        if is_correct:
            final_answer_html = f"<span class='correct'>{final_answer_display}</span>"
            correct_answers += 1
        else:
            final_answer_html = f"<span class='incorrect'>{final_answer_display}</span>"
        total_answers += 1

        # Display ground truth answer
        if gt_answer is not None:
            ground_truth_html = f"<div class='ground-truth-answer'><strong>Ground Truth Answer:</strong> {gt_answer_display}</div>"
        else:
            ground_truth_html = f"<div class='ground-truth-answer'><strong>Ground Truth Answer:</strong> Not available.</div>"

        # Build the HTML structure for each QA pair
        html_content += f"<div class='container'>"
        html_content += f"<div class='question'><strong>Question:</strong> {question}</div>"
        html_content += f"<div class='answer-text'><strong>Model Response:</strong><br>{answer_text}</div>"
        html_content += f"<div class='final-answer'><strong>Final Answer:</strong> {final_answer_html}</div>"
        html_content += f"{ground_truth_html}"
        html_content += "</div>\n"

    # Add the summary section
    summary_percentage = (correct_answers / total_answers * 100) if total_answers > 0 else 0
    summary_html = f"""
    <div class='summary'>
        <strong>Summary:</strong> Correct Answers: {correct_answers} / {total_answers} ({summary_percentage:.2f}%)
    </div>
    """
    html_content += summary_html

    # Close the HTML tags
    html_content += """
    </body>
    </html>
    """
    return html_content

def main():
    input_csv = '/Users/log/Github/textual_grounding/logan/results/LogiQA/4o/grounded_fact/vanilla_log_cot_4o_1006_232349.csv'  
    ground_truth_file = '/Users/log/Github/textual_grounding/data/logiqa/test.jsonl'  # Path to the ground truth JSONL file
    output_file = 'logiqa_4o_vanilla_cot_simple.html'  # Desired output HTML file path

    # Check if input files exist
    if not os.path.isfile(input_csv):
        print(f"Input CSV file not found: {input_csv}")
        return
    if not os.path.isfile(ground_truth_file):
        print(f"Ground truth JSONL file not found: {ground_truth_file}")
        return

    # Parse the input CSV file to extract IDs, questions, and answers
    qa_pairs = parse_csv_file(input_csv)
    print(f"Total QA Pairs Parsed: {len(qa_pairs)}")

    # Read the ground truth answers
    ground_truth = read_ground_truth(ground_truth_file)
    print(f"Total Ground Truth Entries: {len(ground_truth)}")

    # Check if any QA pairs were found
    if not qa_pairs:
        print("No question-answer pairs were found in the input file.")
        return

    # Generate the HTML content
    html_content = create_simple_html(qa_pairs, ground_truth)

    # Write the HTML content to the output file
    with open(output_file, 'w', encoding='utf-8') as file:
        file.write(html_content)

    print(f"HTML content has been successfully written to {output_file}")

if __name__ == "__main__":
    main()


Total QA Pairs Parsed: 200
Total Ground Truth Entries: 297
id: 0: A
id: 1: A
id: 2: B
id: 3: D
id: 4: D
id: 5: B
id: 6: D
id: 7: C
id: 8: C
id: 9: D
id: 10: B
id: 11: D
id: 12: A
id: 13: D
id: 14: B
id: 15: D
id: 16: C
id: 17: A
id: 18: B
id: 19: D
id: 20: B
id: 21: A
id: 22: C
id: 23: D
id: 24: A
id: 25: B
id: 26: D
id: 27: D
id: 28: A
id: 29: D
id: 30: D
id: 31: C
id: 32: A
id: 33: D
id: 34: C
id: 35: A
id: 36: B
id: 37: D
id: 38: B
id: 39: C
id: 40: D
id: 41: A
id: 42: B
id: 43: C
id: 44: D
id: 45: A
id: 46: D
id: 47: D
id: 48: C
id: 49: B
id: 50: C
id: 51: A
id: 52: D
id: 53: D
id: 54: A
id: 55: B
id: 56: C
id: 57: D
id: 58: D
id: 59: B
id: 60: A
id: 61: D
id: 62: A
id: 63: A
id: 64: D
id: 65: C
id: 66: B
id: 67: D
id: 68: C
id: 69: A
id: 70: D
id: 71: B
id: 72: D
id: 73: D
id: 74: D
id: 75: C
id: 76: B
id: 77: A
id: 78: D
id: 79: A
id: 80: D
id: 81: B
id: 82: D
id: 83: B
id: 84: C
id: 85: A
id: 86: D
id: 87: C
id: 88: D
id: 89: B
id: 90: D
id: 91: B
id: 92: A
id: 93: D
id: 94: D
i

# Grounded Visual

In [41]:
import csv
import re
import json  # For handling JSONL
import os

def add_color_to_tags_new(text):
    """
    Adds background color to specific tags within the text based on dynamically assigned colors.
    Each span will have a class corresponding to the tag's name.

    Args:
        text (str): The text containing tags like <B>, <C1>, etc.

    Returns:
        str: The text with added inline CSS for background colors and class names.
    """
    # Find all unique tags in the text using regex
    tags = set(re.findall(r'<([A-Za-z]+\d*)>', text))

    # Predefined color palette
    color_palette = [
        'lightyellow', 'lightblue', 'lightgreen', 'lightcoral',
        'lightcyan', 'lightpink', 'lightsalmon', 'lightgray',
        'lightgoldenrodyellow', 'lightseagreen', 'lightskyblue',
        'lightsteelblue'
    ]

    # Dictionary to hold tag-color mapping
    tag_color_mapping = {}

    # Assign colors to tags, cycling through the color palette if necessary
    for i, tag in enumerate(sorted(tags)):
        color = color_palette[i % len(color_palette)]
        tag_color_mapping[tag] = color

    # Function to replace tags with styled spans including class names
    def replace_tag(match):
        tag = match.group(1)
        content = match.group(2)
        color = tag_color_mapping.get(tag, 'lightgray')  # Default color if not found
        return f'<span class="{tag}" style="background-color: {color}; font-weight: bold;">{content}</span>'

    # Regex to find tags and replace them with styled spans
    tag_regex = re.compile(r'<([A-Za-z]+\d*)>\s*([\s\S]*?)\s*</\1>')

    # Replace all tags with styled spans
    text = tag_regex.sub(replace_tag, text)

    return text


def parse_csv_file(file_path):
    """
    Parses the input CSV file and extracts questions, answers, and their corresponding IDs.

    Args:
        file_path (str): Path to the input CSV file.

    Returns:
        list of tuples: Each tuple contains (id, question, answer_text).
    """
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            question = row.get('question', 'No question found.').strip()
            answer_text = row.get('answer', 'No answer found.').strip()
            id_ = row.get('id')
            if id_ is not None:
                try:
                    id_int = int(id_)
                except ValueError:
                    print(f"Skipping a row due to invalid 'id' (not an integer): {id_}")
                    continue
                qa_pairs.append((id_int, question, answer_text))
            else:
                # Handle cases without 'id' by skipping
                print(f"Skipping a row due to missing 'id': {row}")
    return qa_pairs


def read_ground_truth(jsonl_path):
    """
    Reads the ground truth answers from a JSONL file and maps them by ID.

    Args:
        jsonl_path (str): Path to the ground truth JSONL file.

    Returns:
        dict: A dictionary mapping each ID to its ground truth answer.
    """
    ground_truth = {}
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            id_ = data.get('id')
            answer = data.get('answer')
            if id_ is not None and answer is not None:

                ground_truth[id_] = answer
            else:
                print(f"Invalid ground truth entry: {data}")
    return ground_truth


def create_highlight_html_new(qa_pairs, ground_truth):
    """
    Creates HTML content with the full model response, highlighted facts,
    and the extracted final answer.

    Args:
        qa_pairs (list of tuples): Each tuple contains (id, question, answer_text).
        ground_truth (dict): A dictionary mapping each ID to its ground truth answer.

    Returns:
        str: The complete HTML content as a string.
    """
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Question and Answer Highlights</title>
        <style>
            body {
                font-family: Arial, sans-serif;
                margin: 20px;
                background-color: #f0f0f0;
            }
            .container {
                background-color: #ffffff;
                padding: 20px;
                margin-bottom: 20px;
                border-radius: 8px;
                box-shadow: 0 2px 5px rgba(0,0,0,0.1);
            }
            .question {
                font-size: 1.2em;
                margin-bottom: 10px;
            }
            .mermaid {
                margin-bottom: 10px;
            }
            .mermaidPure {
                background-color: #f9f9f9;
                padding: 10px;
                border: 1px solid #ddd;
                border-radius: 4px;
                white-space: pre-wrap; /* Preserves whitespace and newlines */
                font-family: Consolas, "Courier New", monospace;
                margin-bottom: 10px;
            }
            .full-response, .final-answer, .ground-truth-answer {
                margin-bottom: 10px;
                white-space: pre-wrap; /* Add this line to preserve newlines */
            }
            .final-answer {
                font-weight: bold;
            }
            .ground-truth-answer {
                font-weight: bold;
            }
            /* Styles for the highlighted spans */
            .highlighted {
                padding: 2px 4px;
                border-radius: 3px;
                display: inline-block;
            }
            /* Styles for the summary section */
            .summary {
                background-color: #e0ffe0;
                padding: 15px;
                border: 2px solid #00cc00;
                border-radius: 8px;
                font-size: 1.2em;
                margin-top: 30px;
            }
        </style>
        <!-- Include Mermaid.js -->
        <script type="module">
            import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs';
            mermaid.initialize({ startOnLoad: true });
        </script>
    </head>
    <body>
    <h1>Question and Answer Highlights</h1>
    """

    # Initialize counters for correct and total answers
    correct_answers = 0
    total_answers = 0

    for i, (id_, question, answer_text) in enumerate(qa_pairs, 1):
        try:
            # Since we're no longer extracting parts, use the full answer_text
            full_response = answer_text.strip()
        except Exception as e:
            print(f"Cannot process answer for question ID {id_}: {e}")
            continue

        # Apply color to tags in the full_response
        highlighted_response = add_color_to_tags_new(full_response)
        
        # Replace newline characters with <br> tags to ensure they are rendered in HTML
        highlighted_response = highlighted_response.replace('\n', '<br>')

        # Extract the final answer within curly brackets {}
        final_answer_match = re.search(r'\{([^}]+)\}', full_response)
        if final_answer_match:
            final_answer = final_answer_match.group(1).replace(',', '').replace('$', '').strip()
        else:
            final_answer = ""

        # Retrieve ground truth answer
        gt_answer = str(ground_truth.get(id_))
        print(f"id: {id_}: {gt_answer}")
        if gt_answer is None:
            gt_answer_display = "<span style='color: gray;'>Ground truth not available.</span>"
            is_correct = False
        else:
            # Normalize both final_answer and gt_answer for comparison
            try:
                # Attempt to convert to float for numerical comparison
                final_answer_num = float(final_answer.replace(',', '').replace('$', ''))
                if isinstance(gt_answer, list):
                    # If ground truth is a list, take the first element
                    gt_answer_num = float(gt_answer[0].replace(',', '').replace('$', ''))
                else:
                    gt_answer_num = float(gt_answer.replace(',', '').replace('$', ''))
                is_correct = final_answer_num == gt_answer_num
                # Format numbers with commas and two decimal places if needed
                if final_answer_num.is_integer():
                    final_answer_display = f"{int(final_answer_num):,}"
                else:
                    final_answer_display = f"{final_answer_num:,.2f}"
                if gt_answer_num.is_integer():
                    gt_answer_display = f"{int(gt_answer_num):,}"
                else:
                    gt_answer_display = f"{gt_answer_num:,.2f}"
            except ValueError:
                # In case conversion fails, fallback to string comparison
                is_correct = final_answer == gt_answer
                final_answer_display = final_answer
                gt_answer_display = gt_answer

        # Style the final answer based on correctness
        if is_correct:
            highlighted_final_answer = f"<span style='font-size:1.1em; color: green;'>{final_answer_display}</span>"
            correct_answers += 1
        else:
            highlighted_final_answer = f"<span style='font-size:1.1em; color: red;'>{final_answer_display}</span>"
        total_answers += 1

        # Display ground truth answer
        if gt_answer is not None:
            ground_truth_html = f"<div class='ground-truth-answer'><strong>Ground Truth Answer:</strong> {gt_answer_display}</div>"
        else:
            ground_truth_html = f"<div class='ground-truth-answer'><strong>Ground Truth Answer:</strong> Not available.</div>"

        # Build the HTML structure
        html_content += f"<div class='container'>"
        html_content += f"<div class='question'><strong>Question:</strong> {question}</div>"
        html_content += f"<div class='full-response'><strong>Full Response:</strong> {highlighted_response}</div>"
        html_content += f"<div class='final-answer'><strong>Final Answer:</strong> {highlighted_final_answer}</div>"
        html_content += f"{ground_truth_html}"
        html_content += "</div>\n"

    # After processing all QA pairs, add the summary section
    summary_percentage = (correct_answers / total_answers * 100) if total_answers > 0 else 0
    summary_html = f"""
    <div class='summary'>
        <strong>Summary:</strong> Correct Answers: {correct_answers} / {total_answers} ({summary_percentage:.2f}%)
    </div>
    """
    # Close the HTML tags
    html_content += """
    </body>
    </html>
    """
    output_html = summary_html + html_content
    return output_html


def main():
    # Replace these paths with your actual file paths
    input_csv = '/Users/log/Github/textual_grounding/logan/results/LogiQA/4o/grounded_fact/fs_log_inst_4o_1006_223109.csv'  # Replace with your input CSV file path
    ground_truth_file = '/Users/log/Github/textual_grounding/data/logiqa/test.jsonl'  # Path to the ground truth JSON file
    output_file = 'logiqa_grounded_cot.html'  # Replace with your desired output HTML file path

    # Check if input files exist
    if not os.path.isfile(input_csv):
        print(f"Input CSV file not found: {input_csv}")
        return
    if not os.path.isfile(ground_truth_file):
        print(f"Ground truth JSON file not found: {ground_truth_file}")
        return

    # Parse the input CSV file to extract IDs, questions, and answers
    qa_pairs = parse_csv_file(input_csv)
    print(f"Total QA Pairs Parsed: {len(qa_pairs)}")  # Debug: Print the number of QA pairs parsed

    # Read the ground truth answers
    ground_truth = read_ground_truth(ground_truth_file)
    print(f"Total Ground Truth Entries: {len(ground_truth)}")  # Debug: Print the number of ground truth entries

    # Check if any QA pairs were found
    if not qa_pairs:
        print("No question-answer pairs were found in the input file.")
        return

    # Generate the HTML content
    html_content = create_highlight_html_new(qa_pairs, ground_truth)

    # Write the HTML content to the output file
    with open(output_file, 'w', encoding='utf-8') as file:
        file.write(html_content)

    print(f"HTML content has been successfully written to {output_file}")


if __name__ == "__main__":
    main()


Total QA Pairs Parsed: 200
Total Ground Truth Entries: 297
id: 0: A
id: 1: A
id: 2: B
id: 3: D
id: 4: D
id: 5: B
id: 6: D
id: 7: C
id: 8: C
id: 9: D
id: 10: B
id: 11: D
id: 12: A
id: 13: D
id: 14: B
id: 15: D
id: 16: C
id: 17: A
id: 18: B
id: 19: D
id: 20: B
id: 21: A
id: 22: C
id: 23: D
id: 24: A
id: 25: B
id: 26: D
id: 27: D
id: 28: A
id: 29: D
id: 30: D
id: 31: C
id: 32: A
id: 33: D
id: 34: C
id: 35: A
id: 36: B
id: 37: D
id: 38: B
id: 39: C
id: 40: D
id: 41: A
id: 42: B
id: 43: C
id: 44: D
id: 45: A
id: 46: D
id: 47: D
id: 48: C
id: 49: B
id: 50: C
id: 51: A
id: 52: D
id: 53: D
id: 54: A
id: 55: B
id: 56: C
id: 57: D
id: 58: D
id: 59: B
id: 60: A
id: 61: D
id: 62: A
id: 63: A
id: 64: D
id: 65: C
id: 66: B
id: 67: D
id: 68: C
id: 69: A
id: 70: D
id: 71: B
id: 72: D
id: 73: D
id: 74: D
id: 75: C
id: 76: B
id: 77: A
id: 78: D
id: 79: A
id: 80: D
id: 81: B
id: 82: D
id: 83: B
id: 84: C
id: 85: A
id: 86: D
id: 87: C
id: 88: D
id: 89: B
id: 90: D
id: 91: B
id: 92: A
id: 93: D
id: 94: D
i

## CoT - Visualize

In [36]:
import csv
import re
import json
import os

import re

def extract_final_answer(answer_text):
    """
    Extracts the last answer enclosed in curly braces {} from the answer_text.

    Args:
        answer_text (str): The full model response text.

    Returns:
        str: The extracted final answer, or an empty string if not found.
    """
    # Regex pattern to match anything inside curly braces
    final_answer_pattern = re.compile(r'\{([^}]+)\}')
    
    # Find all matches of text inside curly braces
    matches = list(final_answer_pattern.finditer(answer_text))
    
    # If we have at least one match, get the content of the last match
    if matches:
        final_answer = matches[-1].group(1).strip()
        return final_answer
    else:
        return ""


def parse_csv_file(file_path):
    """
    Parses the input CSV file and extracts questions, answers, and their corresponding IDs.

    Args:
        file_path (str): Path to the input CSV file.

    Returns:
        list of tuples: Each tuple contains (id, question, answer_text).
    """
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            question = row.get('question', 'No question found.').strip()
            answer_text = row.get('answer', 'No answer found.').strip()
            id_ = row.get('id')
            if id_ is not None:
                try:
                    id_int = int(id_)
                    qa_pairs.append((id_int, question, answer_text))
                except ValueError:
                    print(f"Skipping a row due to invalid 'id' (not an integer): {id_}")
            else:
                print(f"Skipping a row due to missing 'id': {row}")
    return qa_pairs

def read_ground_truth(jsonl_path):
    """
    Reads the ground truth answers from a JSONL file and maps them by ID.

    Args:
        jsonl_path (str): Path to the ground truth JSONL file.

    Returns:
        dict: A dictionary mapping each ID to its ground truth answer.
    """
    ground_truth = {}
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            id_ = data.get('id')
            answer = data.get('answer')
            if id_ is not None and answer is not None:
                ground_truth[id_] = answer
            else:
                print(f"Invalid ground truth entry: {data}")
    return ground_truth

def create_simple_html(qa_pairs, ground_truth):
    """
    Creates simple HTML content displaying full model responses and comparison with ground truth.

    Args:
        qa_pairs (list of tuples): Each tuple contains (id, question, answer_text).
        ground_truth (dict): A dictionary mapping each ID to its ground truth answer.

    Returns:
        str: The complete HTML content as a string.
    """
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Question and Answer Comparison</title>
        <style>
            body {
                font-family: Arial, sans-serif;
                margin: 20px;
                background-color: #f9f9f9;
            }
            .container {
                background-color: #ffffff;
                padding: 15px 20px;
                margin-bottom: 15px;
                border-radius: 6px;
                box-shadow: 0 1px 3px rgba(0,0,0,0.1);
            }
            .question {
                font-size: 1.1em;
                margin-bottom: 10px;
                color: #333333;
            }
            .answer-text {
                background-color: #f4f4f4;
                padding: 10px;
                border-left: 4px solid #2196F3;
                margin-bottom: 10px;
                white-space: pre-wrap;
                font-family: Consolas, "Courier New", monospace;
            }
            .final-answer, .ground-truth-answer {
                margin-bottom: 5px;
            }
            .final-answer span.correct {
                color: green;
                font-weight: bold;
            }
            .final-answer span.incorrect {
                color: red;
                font-weight: bold;
            }
            .ground-truth-answer {
                color: #555555;
            }
            .summary {
                background-color: #e0ffe0;
                padding: 15px;
                border: 2px solid #00cc00;
                border-radius: 8px;
                font-size: 1.2em;
                margin-top: 30px;
            }
        </style>
    </head>
    <body>
    <h1>Question and Answer Comparison</h1>
    """

    # Initialize counters for correct and total answers
    correct_answers = 0
    total_answers = 0

    for id_, question, answer_text in qa_pairs:
        final_answer = extract_final_answer(answer_text)
        gt_answer = ground_truth.get(id_)

        if gt_answer is None:
            gt_answer_display = "<span style='color: gray;'>Ground truth not available.</span>"
            is_correct = False
        else:
            # Normalize both final_answer and gt_answer for comparison
            try:
                final_answer_num = float(final_answer.replace(',', '').replace('$', ''))
                if isinstance(gt_answer, list):  # Handle list of answers if applicable
                    gt_answer_num = float(gt_answer[0].replace(',', '').replace('$', ''))
                else:
                    gt_answer_num = float(gt_answer.replace(',', '').replace('$', ''))
                is_correct = final_answer_num == gt_answer_num
                # Format numbers with commas and two decimal places if needed
                final_answer_display = f"{final_answer_num:,.2f}" if not final_answer_num.is_integer() else f"{int(final_answer_num):,}"
                gt_answer_display = f"{gt_answer_num:,.2f}" if not gt_answer_num.is_integer() else f"{int(gt_answer_num):,}"
            except (ValueError, TypeError):
                # Fallback to string comparison if conversion fails
                is_correct = final_answer.strip().lower() == str(gt_answer).strip().lower()
                final_answer_display = final_answer
                gt_answer_display = gt_answer

        # Style the final answer based on correctness
        if is_correct:
            final_answer_html = f"<span class='correct'>{final_answer_display}</span>"
            correct_answers += 1
        else:
            final_answer_html = f"<span class='incorrect'>{final_answer_display}</span>"
        total_answers += 1

        # Display ground truth answer
        if gt_answer is not None:
            ground_truth_html = f"<div class='ground-truth-answer'><strong>Ground Truth Answer:</strong> {gt_answer_display}</div>"
        else:
            ground_truth_html = f"<div class='ground-truth-answer'><strong>Ground Truth Answer:</strong> Not available.</div>"

        # Build the HTML structure for each QA pair
        html_content += f"<div class='container'>"
        html_content += f"<div class='question'><strong>Question:</strong> {question}</div>"
        html_content += f"<div class='answer-text'><strong>Model Response:</strong><br>{answer_text}</div>"
        html_content += f"<div class='final-answer'><strong>Final Answer:</strong> {final_answer_html}</div>"
        html_content += f"{ground_truth_html}"
        html_content += "</div>\n"

    # Add the summary section
    summary_percentage = (correct_answers / total_answers * 100) if total_answers > 0 else 0
    summary_html = f"""
    <div class='summary'>
        <strong>Summary:</strong> Correct Answers: {correct_answers} / {total_answers} ({summary_percentage:.2f}%)
    </div>
    """
    # html_content += summary_html

    # Close the HTML tags
    html_content += """
    </body>
    </html>
    """

    final_content = summary_html + html_content
    return final_content

def main():
    input_csv = '/Users/log/Github/textual_grounding/logan/results/LogiQA/4o/grounded_fact/vanilla_log_cot_4o_1006_232349.csv'  
    ground_truth_file = '/Users/log/Github/textual_grounding/data/logiqa/test.jsonl'  # Path to the ground truth JSONL file
    output_file = 'logiqa_4o_vanilla_cot_simple.html'  # Desired output HTML file path

    # Check if input files exist
    if not os.path.isfile(input_csv):
        print(f"Input CSV file not found: {input_csv}")
        return
    if not os.path.isfile(ground_truth_file):
        print(f"Ground truth JSONL file not found: {ground_truth_file}")
        return

    # Parse the input CSV file to extract IDs, questions, and answers
    qa_pairs = parse_csv_file(input_csv)
    print(f"Total QA Pairs Parsed: {len(qa_pairs)}")

    # Read the ground truth answers
    ground_truth = read_ground_truth(ground_truth_file)
    print(f"Total Ground Truth Entries: {len(ground_truth)}")

    # Check if any QA pairs were found
    if not qa_pairs:
        print("No question-answer pairs were found in the input file.")
        return

    # Generate the HTML content
    html_content = create_simple_html(qa_pairs, ground_truth)

    # Write the HTML content to the output file
    with open(output_file, 'w', encoding='utf-8') as file:
        file.write(html_content)

    print(f"HTML content has been successfully written to {output_file}")

if __name__ == "__main__":
    main()


Total QA Pairs Parsed: 200
Total Ground Truth Entries: 297
HTML content has been successfully written to logiqa_4o_vanilla_cot_simple.html


In [None]:
from datasets import load_dataset

ds = load_dataset("tasksource/spartqa-mchoice")

# Response Statistics

In [19]:
import csv
import re
import json  # For handling JSONL
import os

def extract_parts_regular_cot(answer_text):
    """
    Processes the answer text to extract answer reasoning and the final answer.

    Args:
        answer_text (str): The full answer text containing answer reasoning and final answer.

    Returns:
        tuple: (answer_reasoning, final_answer, has_curly)
               where answer_reasoning is the full model response,
               final_answer is the extracted answer,
               and has_curly is a boolean indicating if the final answer was in curly brackets.
    """
    # Attempt to extract Final Answer from 'Final Answer:'
    final_match = re.search(r'Final Answer:\s*(\S+)', answer_text, re.IGNORECASE)
    if final_match and final_match.group(1).strip():
        final_answer = final_match.group(1).strip()
        has_curly = False
    else:
        # Fallback: Extract Final Answer from '{...}' in the reasoning
        curly_match = re.search(r'\{([\d.]+)\}', answer_text)
        final_answer = curly_match.group(1).strip() if curly_match else ""
        has_curly = bool(curly_match)

    return answer_text.strip(), final_answer, has_curly

def parse_csv_file(file_path):
    """
    Parses the input CSV file and extracts questions, answers, and their corresponding IDs.

    Args:
        file_path (str): Path to the input CSV file.

    Returns:
        list of tuples: Each tuple contains (id, question, answer_text).
    """
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            question = row.get('question', 'No question found.').strip()
            answer_text = row.get('answer', 'No answer found.').strip()
            id_ = row.get('id')
            if id_ is not None:
                try:
                    id_int = int(id_)
                except ValueError:
                    print(f"Skipping a row due to invalid 'id' (not an integer): {id_}")
                    continue
                qa_pairs.append((id_int, question, answer_text))
            else:
                # Handle cases without 'id' by skipping
                print(f"Skipping a row due to missing 'id': {row}")
    return qa_pairs

def read_ground_truth(jsonl_path):
    """
    Reads the ground truth answers from a JSONL file and maps them by ID.

    Args:
        jsonl_path (str): Path to the ground truth JSONL file.

    Returns:
        dict: A dictionary mapping each ID to its ground truth answer.
    """
    ground_truth = {}
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            id_ = data.get('id')
            answer = data.get('answer')
            if id_ is not None and answer is not None:
                # Extract the last number or text after '####'
                match = re.search(r'####\s*([\d.]+)', answer)
                if match:
                    ground_truth[id_] = match.group(1).strip()
                else:
                    print(f"No ground truth answer found for ID {id_}")
            else:
                print(f"Invalid ground truth entry: {data}")
    return ground_truth

def create_statistics(qa_pairs, ground_truth):
    """
    Creates and prints statistics based on the QA pairs and ground truth.

    Args:
        qa_pairs (list of tuples): Each tuple contains (id, question, answer_text).
        ground_truth (dict): A dictionary mapping each ID to its ground truth answer.
    """
    total_responses = len(qa_pairs)
    responses_with_curly = 0
    responses_without_curly = 0
    correct_answers = 0
    incorrect_answers = 0
    no_ground_truth = 0

    # Variables for tag statistics
    total_tags = 0
    total_tag_length = 0
    tag_counts = []  # List to store number of tags per response
    tag_lengths = []  # List to store lengths of tag content across all responses

    for id_, question, answer_text in qa_pairs:
        try:
            answer_reasoning, final_answer, has_curly = extract_parts_regular_cot(answer_text)
        except Exception as e:
            print(f"Cannot extract parts for question ID {id_}: {e}")
            continue

        if has_curly:
            responses_with_curly += 1
        else:
            responses_without_curly += 1

        # Extract tags and their content
        tags_in_response = re.findall(r'<([A-Za-z]+\d*)>(.*?)</\1>', answer_text)
        number_of_tags = len(tags_in_response)
        tag_counts.append(number_of_tags)
        total_tags += number_of_tags

        for tag, content in tags_in_response:
            content_length = len(content)
            tag_lengths.append(content_length)
            total_tag_length += content_length

        # Retrieve ground truth answer
        gt_answer = ground_truth.get(id_)
        if gt_answer is None:
            no_ground_truth += 1
            continue

        # Compare final_answer with ground truth
        if final_answer == gt_answer:
            correct_answers += 1
        else:
            incorrect_answers += 1

    # Calculate additional metrics
    accuracy_percentage = (correct_answers / (correct_answers + incorrect_answers) * 100) if (correct_answers + incorrect_answers) > 0 else 0
    curly_percentage = (responses_with_curly / total_responses * 100) if total_responses > 0 else 0
    no_curly_percentage = (responses_without_curly / total_responses * 100) if total_responses > 0 else 0
    ground_truth_available = total_responses - no_ground_truth
    ground_truth_available_percentage = (ground_truth_available / total_responses * 100) if total_responses > 0 else 0

    # Calculate tag statistics
    average_tags_per_response = (total_tags / total_responses) if total_responses > 0 else 0
    average_tag_length = (total_tag_length / total_tags) if total_tags > 0 else 0

    # Print the statistics
    print("\n===== Analysis Statistics =====\n")
    print(f"Total Responses Analyzed: {total_responses}")
    print(f"Responses with Final Answer in Curly Brackets: {responses_with_curly} ({curly_percentage:.2f}%)")
    print(f"Responses without Final Answer in Curly Brackets: {responses_without_curly} ({no_curly_percentage:.2f}%)")
    print(f"Responses with Ground Truth Available: {ground_truth_available} ({ground_truth_available_percentage:.2f}%)")
    print(f"Correct Answers: {correct_answers}")
    print(f"Incorrect Answers: {incorrect_answers}")
    print(f"Accuracy: {accuracy_percentage:.2f}%")
    print(f"Responses without Ground Truth: {no_ground_truth}")

    # Tag Statistics
    print("\n----- Tag Statistics -----")
    print(f"Total Tags Found: {total_tags}")
    print(f"Average Number of Tags per Response: {average_tags_per_response:.2f}")
    print(f"Average Length of Tag Content: {average_tag_length:.2f} characters")
    print("--------------------------\n")
    print("===== End of Statistics =====\n")

def main():
    input_csv = '/Users/log/Github/textual_grounding/logan/results/GSM8K/llama/mermaid/mermaid_get_answer_llama3.1_20240926_215344.csv'  # Replace with your input CSV file path
    ground_truth_file = '/Users/log/Github/textual_grounding/data/GSM8K/test.jsonl'  # Path to the ground truth JSONL file

    # Check if input files exist
    if not os.path.isfile(input_csv):
        print(f"Input CSV file not found: {input_csv}")
        return
    if not os.path.isfile(ground_truth_file):
        print(f"Ground truth JSONL file not found: {ground_truth_file}")
        return

    # Parse the input CSV file to extract IDs, questions, and answers
    qa_pairs = parse_csv_file(input_csv)
    print(f"Total QA Pairs Parsed: {len(qa_pairs)}")  # Debug: Print the number of QA pairs parsed

    # Read the ground truth answers
    ground_truth = read_ground_truth(ground_truth_file)
    print(f"Total Ground Truth Entries: {len(ground_truth)}")  # Debug: Print the number of ground truth entries

    # Check if any QA pairs were found
    if not qa_pairs:
        print("No question-answer pairs were found in the input file.")
        return

    # Generate and print the statistics
    create_statistics(qa_pairs, ground_truth)

    print("Statistics analysis completed successfully.")

if __name__ == "__main__":
    main()


Total QA Pairs Parsed: 200
No ground truth answer found for ID 489
No ground truth answer found for ID 1113
Total Ground Truth Entries: 1317

===== Analysis Statistics =====

Total Responses Analyzed: 200
Responses with Final Answer in Curly Brackets: 136 (68.00%)
Responses without Final Answer in Curly Brackets: 64 (32.00%)
Responses with Ground Truth Available: 200 (100.00%)
Correct Answers: 94
Incorrect Answers: 106
Accuracy: 47.00%
Responses without Ground Truth: 0

----- Tag Statistics -----
Total Tags Found: 501
Average Number of Tags per Response: 2.50
Average Length of Tag Content: 8.68 characters
--------------------------

===== End of Statistics =====

Statistics analysis completed successfully.


In [2]:
import json

# Define the path to your JSON file
input_file = '/Users/log/Github/textual_grounding/data/AIW/test.json'

# Load the JSON data from the file
with open(input_file, 'r', encoding='utf-8') as file:
    try:
        data = json.load(file)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        exit(1)

# Process each entry in the JSON data
for idx, entry in enumerate(data):
    # Get the current prompt
    prompt = entry.get('prompt', '')

    # Check if prompt is a string
    if isinstance(prompt, str):
        delimiter = 'have?'
        index = prompt.find(delimiter)

        if index != -1:
            # Truncate the prompt after "have?"
            truncated_prompt = prompt[:index + len(delimiter)]
            entry['prompt'] = truncated_prompt
        else:
            print(f"Warning: 'have?' not found in prompt of entry ID {entry.get('id', 'Unknown')}. Prompt left unchanged.")
    elif isinstance(prompt, list):
        print(f"Warning: 'prompt' is a list in entry ID {entry.get('id', 'Unknown')}. Attempting to join into a string.")
        # Attempt to join the list into a single string
        joined_prompt = ' '.join(str(item) for item in prompt)
        delimiter = 'have?'
        index = joined_prompt.find(delimiter)

        if index != -1:
            truncated_prompt = joined_prompt[:index + len(delimiter)]
            entry['prompt'] = truncated_prompt
        else:
            print(f"Warning: 'have?' not found after joining prompt in entry ID {entry.get('id', 'Unknown')}. Prompt left unchanged.")
    else:
        print(f"Warning: 'prompt' is neither a string nor a list in entry ID {entry.get('id', 'Unknown')}. Prompt left unchanged.")

    # Rename 'right_answer' to 'answer' if it exists
    if 'right_answer' in entry:
        entry['answer'] = entry.pop('right_answer')

# Save the updated data back to the same JSON file
with open(input_file, 'w', encoding='utf-8') as file:
    json.dump(data, file, indent=4, ensure_ascii=False)

print("JSON file has been updated successfully.")


JSON file has been updated successfully.
