In [None]:
import os
import json

# Directory holding one JSON result file per evaluated sample.
json_file_path = "grounding_model_evals/gem7b_results/batch_0_to_9999"   # Path to your JSON file

# Output *file* (despite the variable name) for the merged results.
json_file_save_dir = "grounding_model_evals/gem7b_results/all_results_raw.json"

all_results = {}

# Sort the listing: os.listdir returns entries in arbitrary order, which would
# make the key order of the merged file differ between runs / machines.
for result_file in sorted(os.listdir(json_file_path)):
    if result_file.endswith(".json"):
        # Sample id = file name up to its first ".".
        smp_id = result_file.split(".")[0]

        # Read as UTF-8 explicitly (the merged file is written as UTF-8 below),
        # instead of relying on the platform default encoding.
        with open(os.path.join(json_file_path, result_file), "r", encoding="utf-8") as f:
            result_data = json.load(f)

        # Keep only the "results" field of each per-sample file.
        all_results[smp_id] = result_data["results"]

# Save the merged mapping {sample_id: results} as a single JSON object.
with open(json_file_save_dir, "w", encoding="utf-8") as f:
    json.dump(all_results, f, indent=4)

In [None]:
import re
import json

model_name = "gem7b"
# Merged raw results written by the previous cell: {sample_id: raw result text}.
# The parsing loop below treats each value as a string, possibly wrapped in a ```json fence.
data_dir = f"grounding_model_evals/{model_name}_results/all_results_raw.json"

# Context manager so the file handle is closed once the JSON is loaded.
with open(data_dir, "r", encoding="utf-8") as f:
    data = json.load(f)

# Metric keys to extract from every sample's output. Each key is expected to be
# followed by a JSON array whose items the parsing loop reads 'Score' and
# 'Explanation' from.
expected_keys = [
    'DiagnosisAccuracy',
    'AnalysisCompleteness',
    'AnalysisRelevance',
    'LeadAssessmentCoverage',
    'LeadAssessmentAccuracy',
    'GroundedECGUnderstanding',    # i.e., ECG Feature Grounding in paper
    'EvidenceBasedReasoning',
    'RealisticDiagnosticProcess',  # i.e., Clinical Diagnostic Fidelity
]

# Matches   "<Key>": [ ... ]   for any expected key. The array is captured
# non-greedily, so it ends at the FIRST "]" after the key; DOTALL lets the
# match span newlines.
key_alternation = '|'.join(expected_keys)
pattern = re.compile(
    rf'\"(?P<key>{key_alternation})\":\s*(?P<content>\[.*?\])',
    re.DOTALL,
)

# Parsed output: {sample_id: {metric_key: {'Scores': [...], 'Explanations': [...]}}}
result = {}

# Additional cleaning functions
def fix_unterminated_string(content):
    """Close a dangling string literal by inserting one double quote.

    Counts the double quotes that are not preceded by a backslash. If that
    count is odd, a `"` is inserted in front of the first closing `}` or `]`
    (before any whitespace run directly preceding it). If the count is even
    the text is returned unchanged.
    """
    unescaped_quotes = re.findall(r'(?<!\\)"', content)
    if len(unescaped_quotes) % 2 == 0:
        return content
    return re.sub(r'(\s*[}\]])', r'"\1', content, count=1)

def escape_inner_quotes_in_explanation(content):
    """Backslash-escape stray double quotes inside "Explanation" string values.

    The value is taken to run from the opening quote after ``"Explanation":``
    to the first unescaped quote that is followed (after optional whitespace)
    by ``,``, ``}`` or ``]``, i.e. the quote that plausibly closes the value.
    Every other unescaped quote inside that span is escaped, so
    ``"Explanation": "the "ST" segment"`` becomes valid JSON.

    The span may not contain ``{``/``}`` or a ``"<something>":`` sequence, so
    an explanation with a missing comma is never merged with the following
    key or object; such input is left unchanged.

    (The previous pattern captured ``[^"]*?``, which by construction holds no
    quotes, so it never escaped anything.)
    """
    def replacer(match):
        explanation = match.group(1)
        fixed = re.sub(r'(?<!\\)"', r'\\"', explanation)
        return f'"Explanation": "{fixed}"'
    return re.sub(
        r'"Explanation":\s*"((?:(?!"\s*:)[^{}])*?)(?<!\\)"(?=\s*[,}\]])',
        replacer,
        content,
    )

def remove_extra_quotes(content):
    """Collapse every run of two or more consecutive double quotes into a single one."""
    repeated_quotes = re.compile(r'""+')
    return repeated_quotes.sub('"', content)

def fix_unmatched_brackets(content):
    """Strip ``[`` and ``]`` characters from inside "Explanation" string values.

    Only the text between the opening quote and the next double quote is
    treated as the value. Whitespace after the colon is normalised to a single
    space. Presumably this keeps prose brackets from skewing the later
    bracket-count balancing.
    """
    explanation_value = re.compile(r'"Explanation":\s*"([^"]*?)"')

    def _strip_brackets(m):
        cleaned = m.group(1).replace('[', '').replace(']', '')
        return '"Explanation": "' + cleaned + '"'

    return explanation_value.sub(_strip_brackets, content)

def fix_missing_commas(content):
    """Insert a comma between a closing ``}`` and a following ``{``.

    Any whitespace between the two braces is kept, after the inserted comma.
    """
    return re.sub(r'(\})(\s*\{)', lambda m: m.group(1) + ',' + m.group(2), content)

def safe_eval(match):
    """``re.sub`` callback that replaces a simple arithmetic expression with its value.

    ``match.group(1)`` is expected to contain only digits, whitespace and
    ``* + - /`` (e.g. a score written as ``1+1``).

    Returns ``str`` of the evaluated result, or the original text when it
    cannot be evaluated (syntax error, division by zero, ...).

    ``**`` is refused: the caller's character class admits it, and an
    expression such as ``9**9**9`` would hang the kernel. Evaluation runs with
    no builtins as an extra safeguard.
    """
    expression = match.group(1)
    if '**' in expression:
        return expression
    try:
        return str(eval(expression, {"__builtins__": {}}, {}))
    except Exception:  # SyntaxError, ZeroDivisionError, OverflowError, ...
        return expression  # Return the original string if eval fails

# Parse each sample's raw output into
# result[sample_id][metric_key] = {'Scores': [...], 'Explanations': [...]}.

def _strip_code_fence(text):
    """Remove a leading ```json (or bare ```) fence and a trailing ``` fence.

    The previous ``text.strip('```json\\n').strip('\\n```')`` treated its
    argument as a *set of characters*, so it also stripped any leading or
    trailing backticks, newlines, or the letters j/s/o/n that belonged to the
    content itself.
    """
    text = text.strip()
    text = re.sub(r'^```(?:json)?\s*', '', text)
    return re.sub(r'\s*```$', '', text)


def _clean_json_array(text):
    """Heuristically repair one extracted ``[...]`` array so ``json.loads`` can parse it.

    Returns the repaired text; it may still be invalid JSON.
    """
    # Map curly double quotes to ASCII double quotes.
    text = text.replace("“", '"').replace("”", '"')
    # Drop //-comments to end of line (note: this also truncates text such as "http://...").
    text = re.sub(r'//.*', '', text)
    # Remove trailing commas before a closing brace/bracket.
    text = re.sub(r',\s*([}\]])', r'\1', text)
    # Replace two quotes separated only by whitespace (e.g. "") with a space.
    text = re.sub(r'"\s*"', ' ', text)
    # Drop an explicit "+" sign at the start of a value ("Score": +1 -> 1).
    # Restricted to positions after ':', '[' or ',' so arithmetic like 1+1 is
    # left for safe_eval instead of being turned into 11.
    text = re.sub(r'(?<=[:\[,])(\s*)\+(?=\d)', r'\1', text)
    # Evaluate inline arithmetic such as 1+1 or 4/2.
    text = re.sub(r'(\d+[\d\s\*\+\-\/]+\d+)', safe_eval, text)
    # Remove a stray ");" following a quote.
    text = text.replace('");', '"')
    # Collapse all whitespace, newlines included, to single spaces.
    text = re.sub(r'\s+', ' ', text)

    text = fix_unterminated_string(text)
    text = escape_inner_quotes_in_explanation(text)
    text = remove_extra_quotes(text)
    text = fix_unmatched_brackets(text)
    text = fix_missing_commas(text)

    # Append missing closing braces/brackets (counts include any inside strings).
    text += '}' * max(0, text.count('{') - text.count('}'))
    text += ']' * max(0, text.count('[') - text.count(']'))
    return text


for sample_id, raw_output in data.items():
    result[sample_id] = {}
    if not isinstance(raw_output, str):
        print(f"Skipping id {sample_id}: expected a string, got {type(raw_output).__name__}")
        continue
    json_content = _strip_code_fence(raw_output)

    for match in pattern.finditer(json_content):
        key = match.group('key')
        content = _clean_json_array(match.group('content'))

        try:
            content_json = json.loads(content)
        except json.JSONDecodeError as e:
            print(f"JSON decoding error for id {sample_id}, key {key}: {e}")
            print("Content:", content)
            continue

        scores = []
        explanations = []

        for item in content_json:
            if not isinstance(item, dict):
                # e.g. a bare number or string inside the array; nothing to read.
                print(f"Skipping non-object item for id {sample_id}, key {key}: {item!r}")
                continue
            score = item.get('Score')
            explanation = item.get('Explanation')
            explanation = '' if explanation is None else str(explanation).strip()
            # Preserve unexpected fields by appending them to the explanation text.
            extra_fields = {k: v for k, v in item.items() if k not in ['Score', 'Explanation']}
            if extra_fields:
                explanation += " Additional details: " + json.dumps(extra_fields)

            scores.append(score)
            explanations.append(explanation)

        result[sample_id][key] = {
            'Scores': scores,
            'Explanations': explanations
        }


In [None]:
# Persist the parsed results as one JSON object:
# {sample_id: {metric_key: {'Scores': [...], 'Explanations': [...]}}}
output_json_file = f"grounding_model_evals/{model_name}_results/all_results_processed_clean.json"
with open(output_json_file, mode="w", encoding="utf-8") as out_fp:
    json.dump(result, out_fp, indent=4)

In [None]:
import json
import pandas as pd

model_name = "gem7b"
model_result_dir = f"grounding_model_evals/{model_name}_results/all_results_processed_clean.json"

# Context manager so the file handle is closed after loading.
with open(model_result_dir, "r", encoding="utf-8") as f:
    data = json.load(f)


def _numeric_scores(raw_scores):
    """Return the entries of ``raw_scores`` that are numbers.

    Numeric strings (e.g. "2") are converted to float. ``None`` and any other
    non-numeric entry is dropped, so a missing score counts the same as 0 in
    the sums/means below instead of raising ``TypeError``.
    """
    numbers = []
    for s in raw_scores:
        if isinstance(s, (int, float)):
            numbers.append(s)
        elif isinstance(s, str):
            try:
                numbers.append(float(s))
            except ValueError:
                pass
    return numbers


results = {}

for sample_id, metrics in data.items():
    results[sample_id] = {}
    for key, value in metrics.items():
        scores = _numeric_scores(value['Scores'])
        if key in ['AnalysisCompleteness', 'AnalysisRelevance']:
            # Total (not average) of all item scores for these two metrics.
            score = sum(scores)
        else:
            # Mean over positive item scores only; zero scores are filtered out.
            positive = [x for x in scores if x > 0]
            score = sum(positive) / len(positive) if positive else 0

        results[sample_id][key] = score

In [None]:
# One row per sample id, one column per metric (NaN where a sample lacks a metric).
df = pd.DataFrame(data=results).transpose()

In [None]:
# Rescale bounded metrics to percentages. The frame is rebuilt from `results`
# first so that re-running this cell does not rescale the columns twice.
df = pd.DataFrame(results).T
# accuracy for each diagnosis: per-sample score divided by 2
# (2 assumed to be the maximum item score -- TODO confirm)
df['DiagnosisAccuracy'] = df['DiagnosisAccuracy'] / 2 * 100
# cap values above 12 (treated as outliers; presumably the 12 ECG leads), then % of 12
df['LeadAssessmentCoverage'] = df['LeadAssessmentCoverage'].clip(upper=12) / 12 * 100
# for each sample, the maximum score for LeadAssessmentAccuracy will be 24
df['LeadAssessmentAccuracy'] = df['LeadAssessmentAccuracy'] / 24 * 100

# Per-metric means over samples (NaN entries are skipped).
df.mean()

In [None]:
# overall scores: unweighted mean of the column means of the three metrics below
reasoning_metrics = ['GroundedECGUnderstanding', 'EvidenceBasedReasoning', 'RealisticDiagnosticProcess']
sum(df[metric].mean() for metric in reasoning_metrics) / len(reasoning_metrics)