In [95]:
import os
import sys
import ollama
import google.generativeai as genai
import anthropic
import ollama
import random
import pandas as pd
from tqdm import tqdm
from google.generativeai.types import RequestOptions
from google.api_core import retry
from typing import List, Tuple
import json
import datetime

current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)

if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import visualize
import pandas as pd
from utils.utils import add_color_to_tags, extract_parts_0, extract_parts_1
import argparse

# Main Functions

In [103]:
def save_results(save_path: str, ids: List[str], questions: List[str], answers: List[str], append: bool = False):
    """
    Saves the results to a CSV file. If append is True and the file exists, it appends without headers.
    Otherwise, it writes a new file with headers.
    """
    df = pd.DataFrame({'id': ids, 'question': questions, 'answer': answers})
    if append and os.path.exists(save_path):
        df.to_csv(save_path, mode='a', index=False, header=False)
    else:
        df.to_csv(save_path, index=False)

def read_jsonl_file(filepath: str) -> List[dict]:
    """
    Reads a JSONL file and returns a list of JSON objects.
    """
    data = []
    with open(filepath, 'r') as file:
        for line in file:
            json_obj = json.loads(line)
            data.append(json_obj)
    return data

def get_prompt(prompt_type: str, few_shot_prompt: str, question: str) -> str:
    """
    Constructs the prompt based on the prompt type.
    """
    prompts = {
        "cot": f"{few_shot_prompt}\n{question}\nPlease generate your explanation first, then generate the answer in the bracket as follow:\n" +"Answer: {}",
        "fs": f"{few_shot_prompt}\n{question}",
        "fs_inst": f"{few_shot_prompt}\n{question}\nI want you to answer this question but your explanation should contain references referring back to the information in the question. To do that, first, re-generate the question with proper tags and then generate your answers. The output format is as follow:\n\
            Reformatted Question: \
                Answer:",
        "zs": f"{question}\nI want you to answer this question but your explanation should contain references referring back to the information in the question. To do that, first, re-generate the question with proper tags (<a>, <b>, <c>, etc) for refered information and then generate your answers that also have the tag (<a>, <b>, <c>, etc) for the grounded information. Give your answer by analyzing step by step, and give only numbers in the final answer. The output format is as follow:\n\
            Reformatted Question: \
                Answer:\
                    Final answer:",
        "fs_xml": f"{few_shot_prompt}\n\nRecreate the following question in the style of the correctly formatted examples shown previously. Make sure that your response has all its information inclosed in the proper <tags>. Begin your response with the <key_facts> section. Make sure that every fact in <key_facts> is very concise and contains a very short reference to the <question>. Do not include a <question> section in your response\n\n<question>\n{question}\n</question>",
        "mermaid_get_answer": f"{few_shot_prompt}\n\n Your job is to extract the key facts from a question relevant to answering the question. The facts should be represented in a hierarchal format through a mermaid diagram. Do not create duplicate facts across multiple branches that represent the same information. Create a mermaid diagram that represents the key facts in the following question. Then, use the nodes from this graph to cite specific facts in your answer reasoning. Put your final answer in curly brackets e.g. Final_Answer: {{30}} \n\nquestion: {question}", 
    }
    return prompts.get(prompt_type, "")

# cycle through all the keys 
def get_gemini_key(problem_id):
    GOOGLE_KEYS = [
        # 'AIzaSyBQ7zvIZoET3199GNhuz86vKagn_JCEOmk', # original - gen lang client
        'AIzaSyCEI-5U4z7-3q-uwlvkOrdT2e78aNmjnbg', # chat app
        # 'AIzaSyCvycd0yZZ4GSj47qDLk4JoPemvzUSfvio', # project 1
        'AIzaSyD5xNbDkaJMEMBpWEXYNq5SheF6omdKpzg', # project 2
        'AIzaSyAjcrp_otRjGsj0YvB1cUc2BMng6KSEZwU', # project 3
    ]
    index = problem_id%len(GOOGLE_KEYS)
    print(f"getting key from {index}")
    key = GOOGLE_KEYS[index]
    return key

def query_gemini(prompt: str, problem_id) -> str:
    """
    Queries the Gemini LLM with the given prompt and returns the response text.
    """
    genai.configure(api_key=get_gemini_key(problem_id))
    model = genai.GenerativeModel('gemini-1.5-pro-latest')
    response = model.generate_content(prompt, request_options=RequestOptions(retry=retry.Retry(initial=20, multiplier=3, maximum=121, timeout=60)))
    text = response.candidates[0].content.parts[0].text
    return text

def query_claude(prompt: str) -> str:
    """
    Queries the Claude LLM with the given prompt and returns the response text.
    """
    client = anthropic.Anthropic(api_key=API_KEYS['claude'])
    response = client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=1024,
        messages=[{"role": "user", "content": prompt}]
    )
    return response.content[0].text

def query_llm(llm_model: str, ids: List[str], questions: List[str], few_shot_prompt: str, prompt_type: str, save_path: str, already_answered_ids: set) -> Tuple[List[str], List[str], List[str]]:
    """
    Queries the specified LLM for each question, skipping already answered ones.
    Saves each response immediately after it's obtained.
    Returns lists of answered IDs, questions, and answers.
    """
    answers = []
    ids_can_be_answered = []
    questions_can_be_answered = []
    
    for id, q in tqdm(zip(ids, questions), total=len(ids)):
        # print(f"Processing ID: {id}")
        if id in already_answered_ids:
            print(f"Skipping already answered ID: {id}")
            continue
        
        prompt = get_prompt(prompt_type, few_shot_prompt, q)
        try:
            if llm_model == 'gemini':
                answer = query_gemini(prompt, id)
            elif llm_model == 'claude':
                answer = query_claude(prompt)
            elif llm_model == 'llama3.1':
                answer = ollama.generate(model='llama3.1', prompt=prompt)['response']
                print(f"Processed ID: {id}")
            else:
                raise ValueError(f"Unsupported LLM model: {llm_model}")
            # print(f"Answer for ID {id}: {answer}")
            
            # Append to lists
            answers.append(answer)
            questions_can_be_answered.append(q)
            ids_can_be_answered.append(id)

            # Save after each answer
            save_results(save_path, [id], [q], [answer], append=True)
        except Exception as e:
            print(f"Error processing question {id}: {str(e)}")
            continue
    
    return ids_can_be_answered, questions_can_be_answered, answers

def load_data(data_path: str, sample_size: int = None) -> Tuple[List[str], List[str]]:
    """
    Loads data from a JSONL file, optionally sampling a subset.
    """
    data = read_jsonl_file(data_path)
    print(f"Loaded {len(data)} records from: {data_path}")
    if sample_size:
        data = random.sample(data, sample_size)
        print(f"Sampled {sample_size} records.")
    questions = [x["question"] for x in data]
    ids = [x["id"] for x in data]
    return ids, questions

def load_data_deterministic(data_path: str, sample_size: int = None) -> Tuple[List[str], List[str]]:
    """
    Loads data from a JSONL file in a deterministic manner by sorting.
    """
    data = read_jsonl_file(data_path)
    print(f"Loaded {len(data)} records from: {data_path}")
    if sample_size:
        # Sort the data based on a consistent criterion (e.g., 'id' or 'question')
        sorted_data = sorted(data, key=lambda x: x['id'])
        # Take the first 'sample_size' items
        data = sorted_data[:sample_size]
        print(f"Selected first {sample_size} records after sorting.")
    questions = [x["question"] for x in data]
    ids = [x["id"] for x in data]
    return ids, questions

def load_few_shot_prompt(prompt_path: str) -> str:
    """
    Loads the few-shot prompt from a text file.
    """
    with open(prompt_path, 'r') as file:
        prompt = file.read()
    # print(f"Loaded few-shot prompt from: {prompt_path}")
    return prompt

def load_already_answered_ids(save_path: str) -> set:
    """
    Loads the set of IDs that have already been answered from the CSV file.
    Returns an empty set if the file does not exist.
    """
    if os.path.exists(save_path):
        df = pd.read_csv(save_path)
        answered_ids = set(df['id'].astype(int).tolist())
        # print(f"Loaded {len(answered_ids)} already answered IDs from: {save_path}")
        print(f"Already answered IDs: {answered_ids}")
        return answered_ids
    else:
        print(f"No existing save file found at: {save_path}. Starting fresh.")
        return set()

def initialize_save_file(save_path: str):
    """
    Initializes the CSV file with headers if it doesn't exist.
    """
    if not os.path.exists(save_path):
        # Create an empty DataFrame with headers and save
        df = pd.DataFrame(columns=['id', 'question', 'answer'])
        df.to_csv(save_path, index=False)
        print(f"Initialized new save file with headers at: {save_path}")

# Driver

In [102]:
# time = datetime.datetime.now().strftime("%m%d_%H%M%S")
time = '0929_010513'
project_root = '/Users/log/Github/textual_grounding/'
dataset = 'GSM8K'

# llm_model = 'llama3.1'
llm_model = 'gemini'
prompt_type = 'mermaid_get_answer'
# prompt_type = 'cot'
few_shot_txt = 'fewshot_mermaid_full.txt'
# few_shot_txt = '3shot_cot.txt'

# Paths
data_path = os.path.join(project_root, 'data', dataset, 'test.jsonl')
# data_path = '/Users/log/Github/textual_grounding/logan/results/GSM8K/llama/mermaid/mermaid_get_graph_llama3.1_20240924_001821.csv'

fewshot_prompt_path = os.path.join(project_root, "prompt", dataset, few_shot_txt)
save_dir = os.path.join(project_root, 'logan/results', dataset, f'{llm_model}/mermaid')
os.makedirs(save_dir, exist_ok=True)  # Ensure the directory exists
save_path = os.path.join(save_dir, f'{prompt_type}_{llm_model}_{time}.csv')

ids, questions = load_data_deterministic(data_path, sample_size=200)  # Set sample_size as needed
few_shot_prompt = load_few_shot_prompt(fewshot_prompt_path)

initialize_save_file(save_path)
already_answered_ids = load_already_answered_ids(save_path)


ids_answered, questions_answered, answers = query_llm(
    llm_model=llm_model,
    ids=ids,
    questions=questions,
    few_shot_prompt=few_shot_prompt,
    prompt_type=prompt_type,
    save_path=save_path,
    already_answered_ids=already_answered_ids
)

print(f"Processing complete. {len(ids_answered)} new answers saved to {save_path}.")

Loaded 1319 records from: /Users/log/Github/textual_grounding/data/GSM8K/test.jsonl
Selected first 200 records after sorting.
Already answered IDs: {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 23}


  0%|          | 0/200 [00:00<?, ?it/s]

Skipping already answered ID: 0
Skipping already answered ID: 1
Skipping already answered ID: 2
Skipping already answered ID: 3
Skipping already answered ID: 4
Skipping already answered ID: 5
Skipping already answered ID: 6
Skipping already answered ID: 7
Skipping already answered ID: 8
Skipping already answered ID: 9
Skipping already answered ID: 10
Skipping already answered ID: 11
Skipping already answered ID: 12
Skipping already answered ID: 13
Skipping already answered ID: 14
Skipping already answered ID: 15
Skipping already answered ID: 16
Skipping already answered ID: 17
Skipping already answered ID: 18
getting key from 19


I0000 00:00:1727592990.556221 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported
 10%|█         | 20/200 [00:06<00:58,  3.07it/s]I0000 00:00:1727592997.079952 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


Skipping already answered ID: 20
getting key from 21


 11%|█         | 22/200 [00:10<01:37,  1.83it/s]I0000 00:00:1727593001.277159 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 22


 12%|█▏        | 23/200 [00:16<02:48,  1.05it/s]I0000 00:00:1727593006.749806 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


Skipping already answered ID: 23
getting key from 24


 12%|█▎        | 25/200 [00:19<03:13,  1.11s/it]I0000 00:00:1727593010.291306 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 25


 13%|█▎        | 26/200 [00:26<05:05,  1.75s/it]I0000 00:00:1727593016.826085 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 26


 14%|█▎        | 27/200 [01:20<27:02,  9.38s/it]I0000 00:00:1727593070.909620 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 27


 14%|█▍        | 28/200 [01:25<24:53,  8.68s/it]I0000 00:00:1727593076.241424 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 28


 14%|█▍        | 29/200 [01:30<22:24,  7.86s/it]I0000 00:00:1727593080.773128 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 29


 15%|█▌        | 30/200 [01:34<19:57,  7.04s/it]I0000 00:00:1727593084.913776 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 30


 16%|█▌        | 31/200 [01:39<18:34,  6.60s/it]I0000 00:00:1727593090.092361 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 31


 16%|█▌        | 32/200 [02:39<56:27, 20.17s/it]I0000 00:00:1727593149.969195 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 32


 16%|█▋        | 33/200 [02:43<43:52, 15.76s/it]I0000 00:00:1727593153.625955 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 33


 17%|█▋        | 34/200 [02:48<35:34, 12.86s/it]I0000 00:00:1727593158.874602 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 34


 18%|█▊        | 35/200 [02:52<28:47, 10.47s/it]I0000 00:00:1727593163.275016 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 35


 18%|█▊        | 36/200 [03:07<32:20, 11.83s/it]I0000 00:00:1727593178.488944 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 36


 18%|█▊        | 37/200 [03:11<25:54,  9.54s/it]I0000 00:00:1727593182.440245 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 37


 19%|█▉        | 38/200 [03:15<20:59,  7.78s/it]I0000 00:00:1727593185.981487 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 38


 20%|█▉        | 39/200 [03:20<18:25,  6.87s/it]I0000 00:00:1727593190.676827 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 39


 20%|██        | 40/200 [03:26<17:57,  6.74s/it]I0000 00:00:1727593197.108460 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 40


 20%|██        | 41/200 [04:22<56:52, 21.46s/it]I0000 00:00:1727593253.287021 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 41


 21%|██        | 42/200 [04:27<43:32, 16.53s/it]I0000 00:00:1727593258.238063 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 42


 22%|██▏       | 43/200 [04:31<33:30, 12.80s/it]I0000 00:00:1727593262.297670 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 43


 22%|██▏       | 44/200 [04:39<29:40, 11.42s/it]I0000 00:00:1727593270.458611 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 44


 22%|██▎       | 45/200 [04:45<24:59,  9.68s/it]I0000 00:00:1727593276.068699 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 45


 23%|██▎       | 46/200 [05:10<36:57, 14.40s/it]I0000 00:00:1727593301.511783 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 46


 24%|██▎       | 47/200 [05:15<29:04, 11.40s/it]I0000 00:00:1727593305.915301 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 47


 24%|██▍       | 48/200 [05:20<24:06,  9.51s/it]I0000 00:00:1727593311.015388 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 48


 24%|██▍       | 49/200 [05:23<19:05,  7.59s/it]I0000 00:00:1727593314.109617 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 49


 25%|██▌       | 50/200 [05:27<16:21,  6.54s/it]I0000 00:00:1727593318.203869 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 50


 26%|██▌       | 51/200 [06:16<47:31, 19.14s/it]I0000 00:00:1727593366.749453 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 51


 26%|██▌       | 52/200 [06:19<35:41, 14.47s/it]I0000 00:00:1727593370.326703 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 52


 26%|██▋       | 53/200 [06:24<28:21, 11.57s/it]I0000 00:00:1727593375.140681 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 53


 27%|██▋       | 54/200 [06:36<28:41, 11.79s/it]I0000 00:00:1727593387.426216 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 54


 28%|██▊       | 55/200 [06:43<24:37, 10.19s/it]

Error processing question 54: Timeout of 60.0s exceeded, last exception: 429 Resource has been exhausted (e.g. check quota).
getting key from 55


I0000 00:00:1727593393.879657 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported
 28%|██▊       | 56/200 [06:46<19:37,  8.18s/it]I0000 00:00:1727593397.358140 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 56


 28%|██▊       | 57/200 [06:50<16:11,  6.79s/it]

Error processing question 56: Timeout of 60.0s exceeded, last exception: 429 Resource has been exhausted (e.g. check quota).
getting key from 57


I0000 00:00:1727593400.931338 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported
 29%|██▉       | 58/200 [06:51<11:55,  5.04s/it]

Error processing question 57: Timeout of 60.0s exceeded, last exception: 429 Resource has been exhausted (e.g. check quota).
getting key from 58


I0000 00:00:1727593401.865036 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported
 30%|██▉       | 59/200 [07:05<18:37,  7.93s/it]I0000 00:00:1727593416.534430 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 59


 30%|███       | 60/200 [07:09<15:08,  6.49s/it]I0000 00:00:1727593419.683362 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 60


 30%|███       | 61/200 [07:13<13:28,  5.82s/it]I0000 00:00:1727593423.922787 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 61


 31%|███       | 62/200 [07:17<12:09,  5.29s/it]I0000 00:00:1727593427.978034 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 62


 32%|███▏      | 63/200 [07:22<12:01,  5.27s/it]I0000 00:00:1727593433.196449 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported


getting key from 63


 32%|███▏      | 64/200 [07:39<19:38,  8.67s/it]

Error processing question 63: Timeout of 60.0s exceeded, last exception: 429 Resource has been exhausted (e.g. check quota).
getting key from 64


I0000 00:00:1727593449.789357 29728448 check_gcp_environment_no_op.cc:29] ALTS: Platforms other than Linux and Windows are not supported
 32%|███▏      | 64/200 [07:47<16:33,  7.30s/it]


KeyboardInterrupt: 

# Visualization

## XML - visualize

In [83]:
import csv
import re

def extract_parts_1(answer_text):
    """
    Processes the answer text to extract key facts (with numbers), answer reasoning, and the final answer.

    Args:
        answer_text (str): The full answer text containing <key_facts>, <answer_reasoning>, and <final_answer>.

    Returns:
        tuple: (key_facts_list, answer_reasoning, final_answer)
               where key_facts_list is a list of tuples (fact_number, fact_content)
    """
    # Extract key_facts
    key_facts_match = re.search(r'<key_facts>(.*?)</key_facts>', answer_text, re.DOTALL)
    key_facts_content = key_facts_match.group(1).strip() if key_facts_match else ""

    # Extract individual facts with their numbers
    facts = re.findall(r'<fact_(\d+)>(.*?)</fact_\d+>', key_facts_content, re.DOTALL)
    key_facts_list = [(number.strip(), content.strip()) for number, content in facts]

    # Extract answer_reasoning
    reasoning_match = re.search(r'<answer_reasoning>(.*?)</answer_reasoning>', answer_text, re.DOTALL)
    answer_reasoning = reasoning_match.group(1).strip() if reasoning_match else ""

    # Extract final_answer
    final_match = re.search(r'<final_answer>(.*?)</final_answer>', answer_text, re.DOTALL)
    final_answer = final_match.group(1).strip() if final_match else ""

    return key_facts_list, answer_reasoning, final_answer


def add_color_to_tags(text):
    """
    Adds background color to specific tags within the text based on a predefined color mapping.

    Args:
        text (str): The text containing tags like <fact_1>, <fact_2>, etc.

    Returns:
        str: The text with added inline CSS for background colors.
    """
    tag_color_mapping = {
        'fact_1': 'yellow',  
        'fact_2': 'lightblue',
        'fact_3': 'lightgreen',
        'fact_4': 'lightcoral',
        'fact_5': 'lightcyan', 
        'fact_6': 'orange',
    }
    # Iterate over the tag-color mappings
    for tag, color in tag_color_mapping.items():
        # Regex to find the tag and replace it with the same tag having a style attribute
        text = re.sub(
            f'<{tag}>(.*?)</{tag}>',
            f'<{tag} style="background-color: {color};">\\1</{tag}>',
            text,
            flags=re.DOTALL
        )
    return text


def parse_csv_file(file_path):
    """
    Parses the input CSV file and extracts questions and their corresponding answers.

    Args:
        file_path (str): Path to the input CSV file.

    Returns:
        list of tuples: Each tuple contains (question, answer_text).
    """
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            question = row.get('question', 'No question found.').strip()
            answer_text = row.get('answer', 'No answer found.').strip()
            qa_pairs.append((question, answer_text))
    return qa_pairs


def create_highlight_html(qa_pairs):
    """
    Creates HTML content with highlighted questions, key facts, answer reasoning, and answers.

    Args:
        qa_pairs (list of tuples): Each tuple contains (question, answer_text).

    Returns:
        str: The complete HTML content as a string.
    """
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Question and Answer Highlights</title>
        <style>
            body {
                font-family: Arial, sans-serif;
                margin: 20px;
                background-color: #f0f0f0;
            }
            .container {
                background-color: #ffffff;
                padding: 20px;
                margin-bottom: 20px;
                border-radius: 8px;
                box-shadow: 0 2px 5px rgba(0,0,0,0.1);
            }
            .question {
                font-size: 1.2em;
                margin-bottom: 10px;
            }
            .key-facts {
                margin-bottom: 10px;
            }
            .key-facts ul {
                list-style-type: number;
                padding-left: 20px;
            }
            .key-facts ul li{
                margin-bottom: 4px;
            }
            .answer-reasoning, .final-answer {
                margin-bottom: 10px;
            }
            .highlight {
                background-color: #FFFF00; /* Yellow background for visibility */
                font-weight: bold; /* Bold text for emphasis */
            }
            /* Styles for specific facts */
            fact_1 {
                background-color: yellow;
                font-weight: bold;
            }
            fact_2 {
                background-color: lightblue;
                font-weight: bold;
            }
            fact_3 {
                background-color: lightgreen;
                font-weight: bold;
            }
            fact_4 {
                background-color: lightcoral;
                font-weight: bold;
            }
            fact_5 {
                background-color: lightcyan;
                font-weight: bold;
            }
            fact_6 {
                background-color: orange;
                font-weight: bold;
            }
        </style>
    </head>
    <body>
    <h1>Question and Answer Highlights</h1>
    """
    for i, (question, answer_text) in enumerate(qa_pairs, 1):
        try:
            key_facts, answer_reasoning, final_answer = extract_parts_1(answer_text)
        except Exception as e:
            print(f"Cannot extract parts for question {i}: {e}")
            continue

        # Convert key_facts list to HTML bullet points with "Key fact X:" prefix
        if key_facts:
            key_facts_html = "<ul>\n"
            for fact_number, fact_content in key_facts:
                # Apply color to tags in fact_content
                highlighted_fact = add_color_to_tags(fact_content)
                # Prepend "Key fact X:"
                key_facts_html += f"    <li><fact_{fact_number}>{highlighted_fact}</fact_{fact_number}></li>\n"
            key_facts_html += "</ul>"
        else:
            key_facts_html = "<p>No key facts available.</p>"

        # Apply color to tags in answer_reasoning and final_answer
        highlighted_reasoning = add_color_to_tags(answer_reasoning)
        highlighted_final_answer = add_color_to_tags(final_answer)

        # Build the HTML structure
        html_content += f"<div class='container'>"
        html_content += f"<div class='question'><strong>Question:</strong> {question}</div>"
        html_content += f"<div class='key-facts'><strong>Key Facts:</strong> {key_facts_html}</div>"
        html_content += f"<div class='answer-reasoning'><strong>Answer Reasoning:</strong> {highlighted_reasoning}</div>"
        html_content += f"<div class='final-answer'><strong>Answer:</strong> {highlighted_final_answer}</div>"
        html_content += "</div>\n"

    # Close the HTML tags
    html_content += """
    </body>
    </html>
    """
    return html_content


def main():
    input_file = '/Users/log/Github/textual_grounding/logan/results/GSM8K/llama/test_grounding_answer_prompt_fs_xml_llama3.1.csv'  # Replace with your input CSV file path
    output_file = 'test_grounding_answer_prompt_fs_xml_llama3.1.html'  # Replace with your desired output HTML file path

    # Parse the input CSV file to extract questions and answers
    qa_pairs = parse_csv_file(input_file)

    # Check if any QA pairs were found
    if not qa_pairs:
        print("No question-answer pairs were found in the input file.")
        return

    # Generate the HTML content
    html_content = create_highlight_html(qa_pairs)

    # Write the HTML content to the output file
    with open(output_file, 'w', encoding='utf-8') as file:
        file.write(html_content)

    print(f"HTML content has been successfully written to {output_file}")


if __name__ == "__main__":
    main()


HTML content has been successfully written to test_grounding_answer_prompt_fs_xml_llama3.1.html


## Mermaid - Visualize

In [12]:
import csv
import re
import json  # For handling JSONL
import os

def extract_parts_new_format(answer_text):
    """
    Processes the answer text to extract Mermaid diagrams, answer reasoning, and the final answer.

    Args:
        answer_text (str): The full answer text containing Mermaid diagrams, answer reasoning, and final answer.

    Returns:
        tuple: (mermaid_diagrams, answer_reasoning, final_answer)
               where mermaid_diagrams is a list of Mermaid diagram strings
    """
    # Extract Mermaid diagrams
    mermaid_diagrams = re.findall(r'```mermaid\s*(.*?)```', answer_text, re.DOTALL | re.IGNORECASE)

    # Remove Mermaid diagrams from the answer_text to process the remaining text
    answer_text_clean = re.sub(r'```mermaid\s*.*?```', '', answer_text, flags=re.DOTALL | re.IGNORECASE).strip()

    # Extract answer_reasoning: Take everything after the last indented line
    # An indented line starts with spaces or tabs
    indented_lines = list(re.finditer(r'^[ \t]+.*$', answer_text_clean, re.MULTILINE))
    if indented_lines:
        last_indented = indented_lines[-1]
        reasoning_start = last_indented.end()
        answer_reasoning = answer_text_clean[reasoning_start:].strip()
    else:
        # If no indented lines are found, fallback to extracting everything after the first line
        parts = answer_text_clean.split('\n', 1)
        answer_reasoning = parts[1].strip() if len(parts) > 1 else ""

    # Extract final_answer: Any number between {curly braces}
    final_match = re.search(r'\{([\d,.\-]+)\}', answer_text_clean)
    final_answer = final_match.group(1).replace(',', '') if final_match else ""

    return mermaid_diagrams, answer_reasoning, final_answer

def add_color_to_tags_new(text):
    """
    Adds background color to specific tags within the text based on dynamically assigned colors.
    Each span will have a class corresponding to the tag's name.

    Args:
        text (str): The text containing tags like <B>, <C1>, etc.

    Returns:
        str: The text with added inline CSS for background colors and class names.
    """
    # Find all unique tags in the text using regex
    # This regex matches tags like <B>, <C1>, <Node123>, etc.
    tags = set(re.findall(r'<([A-Za-z]+\d*)>', text))

    # Predefined color palette
    color_palette = [
        'lightyellow', 'lightblue', 'lightgreen', 'lightcoral',
        'lightcyan', 'lightpink', 'lightsalmon', 'lightgray',
        'lightgoldenrodyellow', 'lightseagreen', 'lightskyblue',
        'lightsteelblue'
    ]

    # Dictionary to hold tag-color mapping
    tag_color_mapping = {}

    # Assign colors to tags, cycling through the color palette if necessary
    for i, tag in enumerate(sorted(tags)):
        color = color_palette[i % len(color_palette)]
        tag_color_mapping[tag] = color

    # Function to replace tags with styled spans including class names
    def replace_tag(match):
        tag = match.group(1)
        content = match.group(2)
        color = tag_color_mapping.get(tag, 'lightgray')  # Default color if not found
        return f'<span class="{tag}" style="background-color: {color}; font-weight: bold;">{content}</span>'

    # Regex to find tags and replace them with styled spans
    # This regex matches <Tag>Content</Tag> or <Tag> Content </Tag>
    text = re.sub(r'<([A-Za-z]+\d*)>\s*(.*?)\s*</\1>', replace_tag, text, flags=re.DOTALL)

    return text

def parse_csv_file(file_path):
    """
    Parses the input CSV file and extracts questions, answers, and their corresponding IDs.

    Args:
        file_path (str): Path to the input CSV file.

    Returns:
        list of tuples: Each tuple contains (id, question, answer_text).
    """
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            question = row.get('question', 'No question found.').strip()
            answer_text = row.get('answer', 'No answer found.').strip()
            id_ = row.get('id')
            if id_ is not None:
                try:
                    id_int = int(id_)
                except ValueError:
                    print(f"Skipping a row due to invalid 'id' (not an integer): {id_}")
                    continue
                qa_pairs.append((id_int, question, answer_text))
            else:
                # Handle cases without 'id' by skipping
                print(f"Skipping a row due to missing 'id': {row}")
    return qa_pairs

def read_ground_truth(jsonl_path):
    """
    Reads the ground truth answers from a JSONL file and maps them by ID.

    Args:
        jsonl_path (str): Path to the ground truth JSONL file.

    Returns:
        dict: A dictionary mapping each ID to its ground truth answer.
    """
    ground_truth = {}
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            id_ = data.get('id')
            answer = data.get('answer')
            if id_ is not None and answer is not None:
                # Extract the last number after '####' if present
                match = re.search(r'####\s*([\d.]+)', answer)
                if match:
                    ground_truth[id_] = match.group(1).strip()
                else:
                    # If '####' not found, try to extract any number within {}
                    match_braces = re.search(r'\{([\d,.\-]+)\}', answer)
                    if match_braces:
                        ground_truth[id_] = match_braces.group(1).replace(',', '').strip()
                    else:
                        print(f"No ground truth answer found for ID {id_}")
            else:
                print(f"Invalid ground truth entry: {data}")
    return ground_truth

def create_highlight_html_new(qa_pairs, ground_truth):
    """
    Creates HTML content with Mermaid diagrams (both pure text and rendered), highlighted answer reasoning,
    final answers colored based on correctness, and ground truth answers.

    Args:
        qa_pairs (list of tuples): Each tuple contains (id, question, answer_text).
        ground_truth (dict): A dictionary mapping each ID to its ground truth answer.

    Returns:
        str: The complete HTML content as a string.
    """
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Question and Answer Highlights</title>
        <style>
            body {
                font-family: Arial, sans-serif;
                margin: 20px;
                background-color: #f0f0f0;
            }
            .container {
                background-color: #ffffff;
                padding: 20px;
                margin-bottom: 20px;
                border-radius: 8px;
                box-shadow: 0 2px 5px rgba(0,0,0,0.1);
            }
            .question {
                font-size: 1.2em;
                margin-bottom: 10px;
            }
            .mermaid {
                margin-bottom: 10px;
            }
            .mermaidPure {
                background-color: #f9f9f9;
                padding: 10px;
                border: 1px solid #ddd;
                border-radius: 4px;
                white-space: pre-wrap; /* Preserves whitespace and newlines */
                font-family: Consolas, "Courier New", monospace;
                margin-bottom: 10px;
            }
            .answer-reasoning, .final-answer, .ground-truth-answer {
                margin-bottom: 10px;
            }
            .final-answer {
                font-weight: bold;
            }
            .ground-truth-answer {
                font-weight: bold;
            }
            /* Styles for the highlighted spans */
            .highlighted {
                padding: 2px 4px;
                border-radius: 3px;
                display: inline-block;
            }
            /* Styles for the summary section */
            .summary {
                background-color: #e0ffe0;
                padding: 15px;
                border: 2px solid #00cc00;
                border-radius: 8px;
                font-size: 1.2em;
                margin-top: 30px;
            }
        </style>
        <!-- Include Mermaid.js -->
        <script type="module">
            import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs';
            mermaid.initialize({ startOnLoad: true });
        </script>
    </head>
    <body>
    <h1>Question and Answer Highlights</h1>
    """

    # Initialize counters for correct and total answers
    correct_answers = 0
    total_answers = 0

    for i, (id_, question, answer_text) in enumerate(qa_pairs, 1):
        try:
            mermaid_diagrams, answer_reasoning, final_answer = extract_parts_new_format(answer_text)
        except Exception as e:
            print(f"Cannot extract parts for question ID {id_}: {e}")
            continue

        # Convert mermaid diagrams to HTML divs with pure text and rendered version
        mermaid_html = ""
        for diagram in mermaid_diagrams:
            # Escape HTML special characters in the pure text version
            escaped_diagram = diagram.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
            mermaid_html += f"<div class='mermaidPure'><pre>{escaped_diagram}</pre></div>\n"
            mermaid_html += f"<div class='mermaid'>\n{diagram}\n</div>\n"

        # Apply color to tags in answer_reasoning
        highlighted_reasoning = add_color_to_tags_new(answer_reasoning)

        # Retrieve ground truth answer
        gt_answer = ground_truth.get(id_)
        if gt_answer is None:
            gt_answer_display = "<span style='color: gray;'>Ground truth not available.</span>"
            is_correct = False
        else:
            # Normalize both final_answer and gt_answer for comparison
            try:
                final_answer_num = float(final_answer.replace(',', '').replace('$', ''))
                gt_answer_num = float(gt_answer.replace(',', '').replace('$', ''))
                is_correct = final_answer_num == gt_answer_num
                final_answer_display = f"{final_answer_num:,.2f}"
                gt_answer_display = f"{gt_answer_num:,.2f}"
            except ValueError:
                # In case conversion fails, fallback to string comparison
                is_correct = final_answer == gt_answer
                final_answer_display = final_answer
                gt_answer_display = gt_answer

        # Style the final answer based on correctness
        if is_correct:
            highlighted_final_answer = f"<span style='font-size:1.1em; color: green;'>{final_answer_display}</span>"
            correct_answers += 1
        else:
            highlighted_final_answer = f"<span style='font-size:1.1em; color: red;'>{final_answer_display}</span>"
        total_answers += 1

        # Display ground truth answer
        if gt_answer is not None:
            ground_truth_html = f"<div class='ground-truth-answer'><strong>Ground Truth Answer:</strong> {gt_answer_display}</div>"
        else:
            ground_truth_html = f"<div class='ground-truth-answer'><strong>Ground Truth Answer:</strong> Not available.</div>"

        # Build the HTML structure
        html_content += f"<div class='container'>"
        html_content += f"<div class='question'><strong>Question:</strong> {question}</div>"
        if mermaid_html:
            html_content += f"{mermaid_html}"
        else:
            html_content += f"<p>No diagram available.</p>"
        html_content += f"<div class='answer-reasoning'><strong>Answer Reasoning:</strong> {highlighted_reasoning}</div>"
        html_content += f"<div class='final-answer'><strong>Final Answer:</strong> {highlighted_final_answer}</div>"
        html_content += f"{ground_truth_html}"
        html_content += "</div>\n"

    # After processing all QA pairs, add the summary section
    summary_percentage = (correct_answers / total_answers * 100) if total_answers > 0 else 0
    summary_html = f"""
    <div class='summary'>
        <strong>Summary:</strong> Correct Answers: {correct_answers} / {total_answers} ({summary_percentage:.2f}%)
    </div>
    """
    html_content = summary_html + html_content

    # Close the HTML tags
    html_content += """
    </body>
    </html>
    """
    return html_content

def main():
    input_csv = '/Users/log/Github/textual_grounding/logan/results/GSM8K/llama/mermaid/mermaid_get_answer_llama3.1_20240926_215344.csv'  # Replace with your input CSV file path
    ground_truth_file = '/Users/log/Github/textual_grounding/data/GSM8K/test.jsonl'  # Path to the ground truth JSONL file
    output_file = 'mm3_llama3.1.html'  # Replace with your desired output HTML file path

    # Check if input files exist
    if not os.path.isfile(input_csv):
        print(f"Input CSV file not found: {input_csv}")
        return
    if not os.path.isfile(ground_truth_file):
        print(f"Ground truth JSONL file not found: {ground_truth_file}")
        return

    # Parse the input CSV file to extract IDs, questions, and answers
    qa_pairs = parse_csv_file(input_csv)
    print(f"Total QA Pairs Parsed: {len(qa_pairs)}")  # Debug: Print the number of QA pairs parsed

    # Read the ground truth answers
    ground_truth = read_ground_truth(ground_truth_file)
    print(f"Total Ground Truth Entries: {len(ground_truth)}")  # Debug: Print the number of ground truth entries

    # Check if any QA pairs were found
    if not qa_pairs:
        print("No question-answer pairs were found in the input file.")
        return

    # Generate the HTML content
    html_content = create_highlight_html_new(qa_pairs, ground_truth)

    # Write the HTML content to the output file
    with open(output_file, 'w', encoding='utf-8') as file:
        file.write(html_content)

    print(f"HTML content has been successfully written to {output_file}")

if __name__ == "__main__":
    main()


Total QA Pairs Parsed: 200
No ground truth answer found for ID 489
No ground truth answer found for ID 1113
Total Ground Truth Entries: 1317
HTML content has been successfully written to mm3_llama3.1.html


## CoT - Visualize

In [17]:
import csv
import re
import json  # For handling JSONL
import os

def extract_parts_regular_cot(answer_text):
    """
    Processes the answer text to extract answer reasoning and the final answer.

    Args:
        answer_text (str): The full answer text containing answer reasoning and final answer.

    Returns:
        tuple: (answer_reasoning, final_answer)
               where answer_reasoning is the full model response,
               and final_answer is the extracted answer.
    """

    # Fallback: Extract Final Answer from '{...}' in the reasoning
    curly_match = re.search(r'\{([\d.]+)\}', answer_text)
    final_answer = curly_match.group(1).strip() if curly_match else ""

    return answer_text.strip(), final_answer


def add_color_to_tags_new(text):
    """
    Adds background color to specific tags within the text based on dynamically assigned colors.
    Each span will have a class corresponding to the tag's name.

    Args:
        text (str): The text containing tags like <B>, <C1>, etc.

    Returns:
        str: The text with added inline CSS for background colors and class names.
    """
    # Find all unique tags in the text using regex
    # This regex matches tags like <B>, <C1>, <Node123>, etc.
    tags = set(re.findall(r'<([A-Za-z]+\d*)>', text))

    # Predefined color palette
    color_palette = [
        'lightyellow', 'lightblue', 'lightgreen', 'lightcoral',
        'lightcyan', 'lightpink', 'lightsalmon', 'lightgray',
        'lightgoldenrodyellow', 'lightseagreen', 'lightskyblue',
        'lightsteelblue', 'lightsteelblue', 'lightsteelblue'
    ]

    # Dictionary to hold tag-color mapping
    tag_color_mapping = {}

    # Assign colors to tags, cycling through the color palette if necessary
    for i, tag in enumerate(sorted(tags)):
        color = color_palette[i % len(color_palette)]
        tag_color_mapping[tag] = color

    # Function to replace tags with styled spans including class names
    def replace_tag(match):
        tag = match.group(1)
        content = match.group(2)
        color = tag_color_mapping.get(tag, 'lightgray')  # Default color if not found
        return f'<span class="{tag}" style="background-color: {color}; font-weight: bold;">{content}</span>'

    # Regex to find tags and replace them with styled spans
    # This regex matches <Tag>Content</Tag>
    text = re.sub(r'<([A-Za-z]+\d*)>(.*?)</\1>', replace_tag, text, flags=re.DOTALL)

    return text


def parse_csv_file(file_path):
    """
    Parses the input CSV file and extracts questions, answers, and their corresponding IDs.

    Args:
        file_path (str): Path to the input CSV file.

    Returns:
        list of tuples: Each tuple contains (id, question, answer_text).
    """
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            question = row.get('question', 'No question found.').strip()
            answer_text = row.get('answer', 'No answer found.').strip()
            id_ = row.get('id')
            if id_ is not None:
                try:
                    id_int = int(id_)
                except ValueError:
                    print(f"Skipping a row due to invalid 'id' (not an integer): {id_}")
                    continue
                qa_pairs.append((id_int, question, answer_text))
            else:
                # Handle cases without 'id' by skipping
                print(f"Skipping a row due to missing 'id': {row}")
    return qa_pairs


def read_ground_truth(jsonl_path):
    """
    Reads the ground truth answers from a JSONL file and maps them by ID.

    Args:
        jsonl_path (str): Path to the ground truth JSONL file.

    Returns:
        dict: A dictionary mapping each ID to its ground truth answer.
    """
    ground_truth = {}
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            id_ = data.get('id')
            answer = data.get('answer')
            if id_ is not None and answer is not None:
                # Extract the last number or text after '####'
                match = re.search(r'####\s*([\d.]+)', answer)
                if match:
                    ground_truth[id_] = match.group(1).strip()
                else:
                    print(f"No ground truth answer found for ID {id_}")
            else:
                print(f"Invalid ground truth entry: {data}")
    return ground_truth


def create_highlight_html_new(qa_pairs, ground_truth):
    """
    Creates HTML content with highlighted answer reasoning,
    final answers colored based on correctness, and ground truth answers.

    Args:
        qa_pairs (list of tuples): Each tuple contains (id, question, answer_text).
        ground_truth (dict): A dictionary mapping each ID to its ground truth answer.

    Returns:
        str: The complete HTML content as a string.
    """
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Question and Answer Highlights</title>
        <style>
            body {
                font-family: Arial, sans-serif;
                margin: 20px;
                background-color: #f0f0f0;
            }
            .container {
                background-color: #ffffff;
                padding: 20px;
                margin-bottom: 20px;
                border-radius: 8px;
                box-shadow: 0 2px 5px rgba(0,0,0,0.1);
            }
            .question {
                font-size: 1.2em;
                margin-bottom: 10px;
            }
            .answer-reasoning, .final-answer, .ground-truth-answer {
                margin-bottom: 10px;
            }
            .final-answer {
                font-weight: bold;
            }
            .ground-truth-answer {
                font-weight: bold;
            }
            /* Styles for the highlighted spans */
            .highlighted {
                padding: 2px 4px;
                border-radius: 3px;
                display: inline-block;
            }
            /* Styles for the summary section */
            .summary {
                background-color: #e0ffe0;
                padding: 15px;
                border: 2px solid #00cc00;
                border-radius: 8px;
                font-size: 1.2em;
                margin-bottom: 30px;
            }
        </style>
    </head>
    <body>
    <h1>Question and Answer Highlights</h1>
    """

    # Initialize counters for correct and total answers
    correct_answers = 0
    total_answers = 0

    # Placeholder for summary to be added at the top
    summary_html = ""  # Will be updated after processing all QA pairs

    # Temporary storage for all QA containers
    qa_html_sections = ""

    for i, (id_, question, answer_text) in enumerate(qa_pairs, 1):
        try:
            answer_reasoning, final_answer = extract_parts_regular_cot(answer_text)
        except Exception as e:
            print(f"Cannot extract parts for question ID {id_}: {e}")
            continue

        # Apply color to tags in answer_reasoning
        highlighted_reasoning = add_color_to_tags_new(answer_reasoning)

        # Retrieve ground truth answer
        gt_answer = ground_truth.get(id_)
        if gt_answer is None:
            gt_answer_display = "<span style='color: gray;'>Ground truth not available.</span>"
            is_correct = False
        else:
            gt_answer_display = gt_answer
            # Compare final_answer with ground truth
            is_correct = final_answer == gt_answer

        # Style the final answer based on correctness
        if is_correct:
            highlighted_final_answer = f"<span style='font-size:1.1em; color: green;'>{final_answer}</span>"
            correct_answers += 1
        else:
            highlighted_final_answer = f"<span style='font-size:1.1em; color: red;'>{final_answer}</span>"
        total_answers += 1

        # Display ground truth answer
        if gt_answer is not None:
            ground_truth_html = f"<div class='ground-truth-answer'><strong>Ground Truth Answer:</strong> {gt_answer_display}</div>"
        else:
            ground_truth_html = f"<div class='ground-truth-answer'><strong>Ground Truth Answer:</strong> Not available.</div>"

        # Build the HTML structure for this QA pair
        qa_html_sections += f"<div class='container'>"
        qa_html_sections += f"<div class='question'><strong>Question:</strong> {question}</div>"
        qa_html_sections += f"<div class='answer-reasoning'><strong>Answer Reasoning:</strong> {highlighted_reasoning}</div>"
        qa_html_sections += f"<div class='final-answer'><strong>Final Answer:</strong> {highlighted_final_answer}</div>"
        qa_html_sections += f"{ground_truth_html}"
        qa_html_sections += "</div>\n"

    # After processing all QA pairs, create the summary
    accuracy_percentage = (correct_answers / total_answers * 100) if total_answers > 0 else 0
    summary_html = f"""
    <div class='summary'>
        <strong>Summary:</strong> Correct Answers: {correct_answers} / {total_answers} ({accuracy_percentage:.2f}%)
    </div>
    """

    # Append the summary at the top, right after the header
    html_content += summary_html
    html_content += qa_html_sections

    # Close the HTML tags
    html_content += """
    </body>
    </html>
    """
    return html_content


def main():
    input_csv = '/Users/log/Github/textual_grounding/logan/results/GSM8K/llama/mermaid/cot_llama3.1_20240927_003000.csv'  # Replace with your input CSV file path
    ground_truth_file = '/Users/log/Github/textual_grounding/data/GSM8K/test.jsonl'  # Path to the ground truth JSONL file
    output_file = 'cot_llama3.1.html'  # Replace with your desired output HTML file path

    # Check if input files exist
    if not os.path.isfile(input_csv):
        print(f"Input CSV file not found: {input_csv}")
        return
    if not os.path.isfile(ground_truth_file):
        print(f"Ground truth JSONL file not found: {ground_truth_file}")
        return

    # Parse the input CSV file to extract IDs, questions, and answers
    qa_pairs = parse_csv_file(input_csv)
    print(f"Total QA Pairs Parsed: {len(qa_pairs)}")  # Debug: Print the number of QA pairs parsed

    # Read the ground truth answers
    ground_truth = read_ground_truth(ground_truth_file)
    print(f"Total Ground Truth Entries: {len(ground_truth)}")  # Debug: Print the number of ground truth entries

    # Check if any QA pairs were found
    if not qa_pairs:
        print("No question-answer pairs were found in the input file.")
        return

    # Generate the HTML content
    html_content = create_highlight_html_new(qa_pairs, ground_truth)

    # Write the HTML content to the output file
    with open(output_file, 'w', encoding='utf-8') as file:
        file.write(html_content)

    print(f"HTML content has been successfully written to {output_file}")


if __name__ == "__main__":
    main()


Total QA Pairs Parsed: 200
No ground truth answer found for ID 489
No ground truth answer found for ID 1113
Total Ground Truth Entries: 1317
HTML content has been successfully written to cot_llama3.1.html


## Tin - visualize

In [84]:
import pandas as pd
import csv
import re

# Define the prompt_type and llm_model as needed
prompt_type = "fs"  # Example value, set accordingly
llm_model = "llama3.1"  # Example value, set accordingly

save_html_path = f"question_answer_highlights_prompt_{prompt_type}_{llm_model}.html"
# df_path = f'logan/results/{dataset}/llama/test_grounding_answer_prompt_{prompt_type}_{llm_model}.csv'
df_path = '/Users/log/Github/textual_grounding/logan/results/GSM8K/llama/test_grounding_answer_prompt_fs_inst_llama3.1.csv'

if prompt_type in ["fs", "fs_inst"]:
    prefix = 'The answer is'
elif prompt_type == "zs":
    prefix = 'Final answer:'

df = pd.read_csv(df_path)
questions = df['question'].tolist()
answers = df['answer'].tolist()

def add_color_to_tags(text):
    """
    Adds background color to specific tags within the text based on a predefined color mapping.

    Args:
        text (str): The text containing tags like <a>, <b>, etc.

    Returns:
        str: The text with added inline CSS for background colors.
    """
    tag_color_mapping = {
        'a': 'yellow',       # You can customize colors as needed
        'b': 'lightblue',
        'c': 'lightgreen',
        'd': 'lightcoral',
        'e': 'lightcyan',
        'f': 'orange',       # Extend as needed
        # Add more tags if necessary
    }
    # Iterate over the tag-color mappings
    for tag, color in tag_color_mapping.items():
        # Regex to find the tag and replace it with the same tag having a style attribute
        text = re.sub(
            f'<{tag}>(.*?)</{tag}>',
            f'<span style="background-color: {color};">{r"\1"}</span>',
            text,
            flags=re.DOTALL
        )
    return text

def parse_csv_file(file_path):
    """
    Parses the input CSV file and extracts questions and their corresponding answers.

    Args:
        file_path (str): Path to the input CSV file.

    Returns:
        list of tuples: Each tuple contains (question, answer_text).
    """
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            question = row.get('question', 'No question found.').strip()
            answer_text = row.get('answer', 'No answer found.').strip()
            qa_pairs.append((question, answer_text))
    return qa_pairs

def create_highlight_html(qa_pairs):
    """
    Creates HTML content with highlighted questions and answers based on tags.

    Args:
        qa_pairs (list of tuples): Each tuple contains (question, answer_text).

    Returns:
        str: The complete HTML content as a string.
    """
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>Question and Answer Highlights</title>
        <style>
            body {
                font-family: Arial, sans-serif;
                margin: 20px;
                background-color: #f0f0f0;
            }
            .container {
                background-color: #ffffff;
                padding: 20px;
                margin-bottom: 20px;
                border-radius: 8px;
                box-shadow: 0 2px 5px rgba(0,0,0,0.1);
            }
            .question {
                font-size: 1.2em;
                margin-bottom: 10px;
            }
            .answer {
                margin-bottom: 10px;
            }
            .highlight {
                background-color: #FFFF00; /* Default highlight color */
                font-weight: bold; /* Bold text for emphasis */
            }
            /* Additional styles for specific tags can be added here if needed */
        </style>
    </head>
    <body>
    <h1>Question and Answer Highlights</h1>
    """
    for i, (question, answer_text) in enumerate(qa_pairs, 1):
        try:
            # Apply color to tags in question and answer
            highlighted_question = add_color_to_tags(question)
            highlighted_answer = add_color_to_tags(answer_text)
        except Exception as e:
            print(f"Cannot process question-answer pair {i}: {e}")
            continue

        # Build the HTML structure
        html_content += f"<div class='container'>"
        html_content += f"<div class='question'><strong>Question:</strong> {highlighted_question}</div>"
        html_content += f"<div class='answer'><strong>Answer:</strong> {highlighted_answer}</div>"
        html_content += "</div>\n"

    # Close the HTML tags
    html_content += """
    </body>
    </html>
    """
    return html_content

def main():
    input_file = df_path  # Use the defined df_path
    output_file = save_html_path  # Use the defined save_html_path

    # Parse the input CSV file to extract questions and answers
    qa_pairs = parse_csv_file(input_file)

    # Check if any QA pairs were found
    if not qa_pairs:
        print("No question-answer pairs were found in the input file.")
        return

    # Generate the HTML content
    html_content = create_highlight_html(qa_pairs)

    # Write the HTML content to the output file
    with open(output_file, 'w', encoding='utf-8') as file:
        file.write(html_content)

    print(f"HTML content has been successfully written to {output_file}")

if __name__ == "__main__":
    main()


HTML content has been successfully written to question_answer_highlights_prompt_fs_llama3.1.html


# Response Statistics

In [19]:
import csv
import re
import json  # For handling JSONL
import os

def extract_parts_regular_cot(answer_text):
    """
    Processes the answer text to extract answer reasoning and the final answer.

    Args:
        answer_text (str): The full answer text containing answer reasoning and final answer.

    Returns:
        tuple: (answer_reasoning, final_answer, has_curly)
               where answer_reasoning is the full model response,
               final_answer is the extracted answer,
               and has_curly is a boolean indicating if the final answer was in curly brackets.
    """
    # Attempt to extract Final Answer from 'Final Answer:'
    final_match = re.search(r'Final Answer:\s*(\S+)', answer_text, re.IGNORECASE)
    if final_match and final_match.group(1).strip():
        final_answer = final_match.group(1).strip()
        has_curly = False
    else:
        # Fallback: Extract Final Answer from '{...}' in the reasoning
        curly_match = re.search(r'\{([\d.]+)\}', answer_text)
        final_answer = curly_match.group(1).strip() if curly_match else ""
        has_curly = bool(curly_match)

    return answer_text.strip(), final_answer, has_curly

def parse_csv_file(file_path):
    """
    Parses the input CSV file and extracts questions, answers, and their corresponding IDs.

    Args:
        file_path (str): Path to the input CSV file.

    Returns:
        list of tuples: Each tuple contains (id, question, answer_text).
    """
    qa_pairs = []
    with open(file_path, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            question = row.get('question', 'No question found.').strip()
            answer_text = row.get('answer', 'No answer found.').strip()
            id_ = row.get('id')
            if id_ is not None:
                try:
                    id_int = int(id_)
                except ValueError:
                    print(f"Skipping a row due to invalid 'id' (not an integer): {id_}")
                    continue
                qa_pairs.append((id_int, question, answer_text))
            else:
                # Handle cases without 'id' by skipping
                print(f"Skipping a row due to missing 'id': {row}")
    return qa_pairs

def read_ground_truth(jsonl_path):
    """
    Reads the ground truth answers from a JSONL file and maps them by ID.

    Args:
        jsonl_path (str): Path to the ground truth JSONL file.

    Returns:
        dict: A dictionary mapping each ID to its ground truth answer.
    """
    ground_truth = {}
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            id_ = data.get('id')
            answer = data.get('answer')
            if id_ is not None and answer is not None:
                # Extract the last number or text after '####'
                match = re.search(r'####\s*([\d.]+)', answer)
                if match:
                    ground_truth[id_] = match.group(1).strip()
                else:
                    print(f"No ground truth answer found for ID {id_}")
            else:
                print(f"Invalid ground truth entry: {data}")
    return ground_truth

def create_statistics(qa_pairs, ground_truth):
    """
    Creates and prints statistics based on the QA pairs and ground truth.

    Args:
        qa_pairs (list of tuples): Each tuple contains (id, question, answer_text).
        ground_truth (dict): A dictionary mapping each ID to its ground truth answer.
    """
    total_responses = len(qa_pairs)
    responses_with_curly = 0
    responses_without_curly = 0
    correct_answers = 0
    incorrect_answers = 0
    no_ground_truth = 0

    # Variables for tag statistics
    total_tags = 0
    total_tag_length = 0
    tag_counts = []  # List to store number of tags per response
    tag_lengths = []  # List to store lengths of tag content across all responses

    for id_, question, answer_text in qa_pairs:
        try:
            answer_reasoning, final_answer, has_curly = extract_parts_regular_cot(answer_text)
        except Exception as e:
            print(f"Cannot extract parts for question ID {id_}: {e}")
            continue

        if has_curly:
            responses_with_curly += 1
        else:
            responses_without_curly += 1

        # Extract tags and their content
        tags_in_response = re.findall(r'<([A-Za-z]+\d*)>(.*?)</\1>', answer_text)
        number_of_tags = len(tags_in_response)
        tag_counts.append(number_of_tags)
        total_tags += number_of_tags

        for tag, content in tags_in_response:
            content_length = len(content)
            tag_lengths.append(content_length)
            total_tag_length += content_length

        # Retrieve ground truth answer
        gt_answer = ground_truth.get(id_)
        if gt_answer is None:
            no_ground_truth += 1
            continue

        # Compare final_answer with ground truth
        if final_answer == gt_answer:
            correct_answers += 1
        else:
            incorrect_answers += 1

    # Calculate additional metrics
    accuracy_percentage = (correct_answers / (correct_answers + incorrect_answers) * 100) if (correct_answers + incorrect_answers) > 0 else 0
    curly_percentage = (responses_with_curly / total_responses * 100) if total_responses > 0 else 0
    no_curly_percentage = (responses_without_curly / total_responses * 100) if total_responses > 0 else 0
    ground_truth_available = total_responses - no_ground_truth
    ground_truth_available_percentage = (ground_truth_available / total_responses * 100) if total_responses > 0 else 0

    # Calculate tag statistics
    average_tags_per_response = (total_tags / total_responses) if total_responses > 0 else 0
    average_tag_length = (total_tag_length / total_tags) if total_tags > 0 else 0

    # Print the statistics
    print("\n===== Analysis Statistics =====\n")
    print(f"Total Responses Analyzed: {total_responses}")
    print(f"Responses with Final Answer in Curly Brackets: {responses_with_curly} ({curly_percentage:.2f}%)")
    print(f"Responses without Final Answer in Curly Brackets: {responses_without_curly} ({no_curly_percentage:.2f}%)")
    print(f"Responses with Ground Truth Available: {ground_truth_available} ({ground_truth_available_percentage:.2f}%)")
    print(f"Correct Answers: {correct_answers}")
    print(f"Incorrect Answers: {incorrect_answers}")
    print(f"Accuracy: {accuracy_percentage:.2f}%")
    print(f"Responses without Ground Truth: {no_ground_truth}")

    # Tag Statistics
    print("\n----- Tag Statistics -----")
    print(f"Total Tags Found: {total_tags}")
    print(f"Average Number of Tags per Response: {average_tags_per_response:.2f}")
    print(f"Average Length of Tag Content: {average_tag_length:.2f} characters")
    print("--------------------------\n")
    print("===== End of Statistics =====\n")

def main():
    input_csv = '/Users/log/Github/textual_grounding/logan/results/GSM8K/llama/mermaid/mermaid_get_answer_llama3.1_20240926_215344.csv'  # Replace with your input CSV file path
    ground_truth_file = '/Users/log/Github/textual_grounding/data/GSM8K/test.jsonl'  # Path to the ground truth JSONL file

    # Check if input files exist
    if not os.path.isfile(input_csv):
        print(f"Input CSV file not found: {input_csv}")
        return
    if not os.path.isfile(ground_truth_file):
        print(f"Ground truth JSONL file not found: {ground_truth_file}")
        return

    # Parse the input CSV file to extract IDs, questions, and answers
    qa_pairs = parse_csv_file(input_csv)
    print(f"Total QA Pairs Parsed: {len(qa_pairs)}")  # Debug: Print the number of QA pairs parsed

    # Read the ground truth answers
    ground_truth = read_ground_truth(ground_truth_file)
    print(f"Total Ground Truth Entries: {len(ground_truth)}")  # Debug: Print the number of ground truth entries

    # Check if any QA pairs were found
    if not qa_pairs:
        print("No question-answer pairs were found in the input file.")
        return

    # Generate and print the statistics
    create_statistics(qa_pairs, ground_truth)

    print("Statistics analysis completed successfully.")

if __name__ == "__main__":
    main()


Total QA Pairs Parsed: 200
No ground truth answer found for ID 489
No ground truth answer found for ID 1113
Total Ground Truth Entries: 1317

===== Analysis Statistics =====

Total Responses Analyzed: 200
Responses with Final Answer in Curly Brackets: 136 (68.00%)
Responses without Final Answer in Curly Brackets: 64 (32.00%)
Responses with Ground Truth Available: 200 (100.00%)
Correct Answers: 94
Incorrect Answers: 106
Accuracy: 47.00%
Responses without Ground Truth: 0

----- Tag Statistics -----
Total Tags Found: 501
Average Number of Tags per Response: 2.50
Average Length of Tag Content: 8.68 characters
--------------------------

===== End of Statistics =====

Statistics analysis completed successfully.
