In [13]:
import os
import pandas as pd
from io import StringIO

from utils.pipeline import run_pipeline
# Configurable parameter: number of ranked battery IDs requested from the LLM
# and scored per sample question. Must satisfy 1 <= TOP_K < 15 (the upper bound
# comes from the original setup note; presumably the CSV has fewer than 15
# "{rank}_Most_Similar_Battery_ID" columns — confirm against the testset).
TOP_K = 1  # Adjust this as needed
if not 1 <= TOP_K < 15:
    # Fail fast instead of producing a ZeroDivisionError or meaningless scores later.
    raise ValueError(f"TOP_K must satisfy 1 <= TOP_K < 15, got {TOP_K}")

# Paths. Each may be overridden by an environment variable of the same name so
# the evaluation can run outside the original author's home directory.
TESTSET_CSV_PATH = os.environ.get(
    "TESTSET_CSV_PATH",
    "/home/jaf/battery-lifespan-kg/eval/evaluator_question_short_6_value.csv",
)
BATTERY_FILES_DIR = os.environ.get(
    "BATTERY_FILES_DIR",
    "/home/jaf/battery-lifespan-kg/resources/testset",
)

import re

# Battery IDs look like "b{digits}c{digits}" (e.g. b1c1). The \b anchors keep a
# match from starting or ending inside a longer word such as "xb1c1".
_BATTERY_ID_PATTERN = re.compile(r'\bb\d+c\d+\b', flags=re.IGNORECASE)


def parse_response(response, top_k):
    """
    Parse the LLM response to extract a list of battery IDs using regex.
    Battery IDs are expected to be in the form "b{number}c{number}" (e.g., b1c1).
    Matching is case-insensitive; IDs are returned lower-cased, in the order
    they appear in the response.

    :param response: The LLM response string. Any non-string value (e.g. None)
        is reported and treated as containing no battery IDs.
    :param top_k: The maximum number of battery IDs to return. ``None`` means
        no limit; zero or a negative value yields an empty list.
    :return: A list of extracted battery IDs (at most top_k items). Empty when
        nothing matches or the response is not a string.
    """
    if not isinstance(response, str):
        # Return [] rather than a placeholder such as [''], so a failed parse can
        # never compare equal to an empty expected ID and be scored as correct.
        print(f"Error parsing response: expected a string, got {type(response).__name__}. Response was: {response}")
        return []

    # findall on a group-less pattern returns the full matches, which contain no
    # surrounding whitespace, so only case normalization is needed.
    matches = [match.lower() for match in _BATTERY_ID_PATTERN.findall(response)]

    # Clamp negatives: a slice like [:-1] would otherwise mean "all but the last".
    limit = None if top_k is None else max(top_k, 0)
    return matches[:limit]


def normalize_str(s):
    """Return *s* with leading/trailing whitespace removed and lower-cased.

    Used when comparing extracted and expected battery IDs so that case and
    stray whitespace do not affect the match.
    """
    trimmed = s.strip()
    return trimmed.lower()

def _expected_ids_for_row(row, top_k):
    """Return the ground-truth battery IDs for ranks 1..top_k of one test row ('' where absent)."""
    expected = []
    for rank in range(1, top_k + 1):
        # Series.get returns None for a missing column; empty CSV cells are NaN.
        value = row.get(f"{rank}_Most_Similar_Battery_ID")
        if value is None or pd.isna(value):
            expected.append("")  # No ground truth for this rank
        else:
            expected.append(str(value).strip().lower())
    return expected


def _extract_ids(response, top_k):
    """Return up to top_k battery IDs parsed from a run_pipeline response, or [] if it has none."""
    if isinstance(response, str):
        # A bare string carries no 'result' field (presumably an error message
        # from the pipeline — confirm against run_pipeline); score it as empty.
        return []
    try:
        result = response["result"]
    except (KeyError, TypeError, IndexError) as e:
        # Don't abort the whole evaluation because one response is malformed.
        print(f"Response has no usable 'result' field ({e!r}); treating as empty.")
        return []
    if not isinstance(result, str):
        return []
    return parse_response(result, top_k)


def evaluate_testset():
    """Score the LLM pipeline against the ground-truth CSV and print the accuracy.

    For every row of TESTSET_CSV_PATH, the battery's text file
    (BATTERY_FILES_DIR/<TEST_BATTERY_ID>.txt) is read. Every non-empty
    SAMPLE_QUESTION* cell is then sent through run_pipeline, asking for the
    top-TOP_K most similar batteries. The returned IDs are compared rank by rank
    with the "{rank}_Most_Similar_Battery_ID" columns. Mismatches, per-question
    accuracy and overall accuracy are printed.

    :return: Overall accuracy in percent across all scored ranks (0 if nothing
        was compared).
    """
    df = pd.read_csv(TESTSET_CSV_PATH)

    # The set of question columns is the same for every row; compute it once.
    sample_question_cols = [col for col in df.columns if col.startswith("SAMPLE_QUESTION")]

    total_comparisons = 0
    total_correct = 0

    for _, row in df.iterrows():
        test_battery_id = row["TEST_BATTERY_ID"]
        battery_file_path = os.path.join(BATTERY_FILES_DIR, f"{test_battery_id}.txt")

        try:
            with open(battery_file_path, "r") as f:
                file_content = f.read()
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error reading file for battery {test_battery_id}: {e}")
            continue

        expected_ids = _expected_ids_for_row(row, TOP_K)

        # Each sample question is scored as an individual data point.
        for question_col in sample_question_cols:
            raw_question = row[question_col]
            # Rows can have fewer questions than there are columns. Empty cells
            # are NaN and would otherwise be sent to the LLM as the text "nan".
            if pd.isna(raw_question) or not str(raw_question).strip():
                continue
            base_question = str(raw_question).strip()

            # Append an instruction to force the list format with top-K results.
            modified_question = (
                f"{base_question} return the top-{TOP_K} results from the most similar first. "
                "Please respond with a comma-separated list of battery IDs only."
            )

            # Give run_pipeline a fresh file-like object per call, in case it
            # consumes the stream.
            response = run_pipeline(modified_question, StringIO(file_content))
            print(response)
            extracted_ids = _extract_ids(response, TOP_K)

            sample_correct = 0
            for rank, expected in enumerate(expected_ids):
                total_comparisons += 1
                # No candidate at this rank counts as not matched.
                if rank >= len(extracted_ids):
                    print(f"Test case {test_battery_id}, question '{base_question}': Missing candidate at rank {rank+1}. Expected: {expected}")
                    continue

                got = extracted_ids[rank]
                # An empty expected ID means no ground truth, so it can never be a hit.
                if expected and normalize_str(got) == normalize_str(expected):
                    sample_correct += 1
                    total_correct += 1
                else:
                    print(f"Test case {test_battery_id}, question '{base_question}': Mismatch at rank {rank+1}. Expected: {expected}, Got: {got}")

            accuracy = (sample_correct / TOP_K) * 100 if TOP_K > 0 else 0.0
            print(f"Test case {test_battery_id}, question '{base_question}' accuracy: {accuracy:.2f}%")

    overall_accuracy = (total_correct / total_comparisons) * 100 if total_comparisons > 0 else 0
    print(f"\nOverall accuracy across all sample questions: {overall_accuracy:.2f}%")
    return overall_accuracy

# Run the evaluation only when executed directly (including as a notebook cell,
# where __name__ is "__main__"), not when this module is imported.
if __name__ == "__main__":
    evaluate_testset()


{'schema': 'slope_last_10_cycles: -0.00023017525672912597\nmean_grad_last_10_cycles: -0.00022496283054351807\nslope_last_50_cycles: -0.0004295814037322998\nmean_grad_last_50_cycles: -0.00043808817863464353\nslope_last_100_cycles: -0.0004307425022125244\nmean_grad_last_100_cycles: -0.00043108373880386354\nslope_last_200_cycles: -0.0003900536894798279\nmean_grad_last_200_cycles: -0.00038991034030914305\nslope_last_300_cycles: -0.00035009324550628663\nmean_grad_last_300_cycles: -0.00035052637259165445\nslope_last_400_cycles: -0.0003128772974014282\nmean_grad_last_400_cycles: -0.00031308412551879885\nslope_last_500_cycles: -0.0002762656211853027\nmean_grad_last_500_cycles: -0.0002763822078704834\nslope_last_600_cycles: -0.00024240295092264812\nmean_grad_last_600_cycles: -0.00024246146281560262\nslope_last_700_cycles: -0.00021553754806518554\nmean_grad_last_700_cycles: -0.0002156040498188564\nslope_last_800_cycles: -0.0001955069601535797\nmean_grad_last_800_cycles: -0.00019555248320102692\n