# Evaluation of ChroKnowPrompt

In [None]:
from collections import defaultdict

from tqdm import tqdm

from sources.utils import *
from sources.process import *

## Please run the cells in the order given below.

## 1. Get Year Span Information

### 1-1. Load data with chrono ans

In [None]:
def load_data_with_updated_timestamp(model_name, domain):
    """Load benchmark and updated-timestamp files for one model and domain.

    Reads the Dynamic and Static ``TimeVariant_<domain>`` JSONL files from
    ``./ChroKnowBench`` and the matching ``Updated_Timestamp_<domain>`` JSON
    files from ``./ChronoGap/<model_name>``.

    Parameters:
    - model_name (str): Sub-directory name under ``./ChronoGap``.
    - domain (str): Domain name embedded in the file names.

    Returns:
    - tuple: (dynamic benchmark, static benchmark,
              dynamic timestamp data, static timestamp data)
    """
    temporal_states = ("Dynamic", "Static")

    # Benchmarks are read first (Dynamic, then Static), then the timestamps.
    bench_dynamic, bench_static = [
        read_jsonl_file(f"./ChroKnowBench/TimeVariant_{domain}_{state}.jsonl")
        for state in temporal_states
    ]
    stamps_dynamic, stamps_static = [
        read_json_file(f'./ChronoGap/{model_name}/Updated_Timestamp_{domain}_{state}.json')
        for state in temporal_states
    ]
    return bench_dynamic, bench_static, stamps_dynamic, stamps_static

def get_traversal_details(partial_known, target_year, tentative_ans, prev_year_span=3, next_year_span=3):
    """
    Determines the traversal steps taken to locate the tentative answer within the specified spans.

    Years before the target are visited from the nearest to the farthest, logging
    one 'previous' step per visited year and stopping at the first year whose
    answers contain the tentative answer (case-insensitive). Only when no
    previous year matched are the years after the target visited (nearest
    first), logging one 'next' step per visited year.

    Parameters:
    - partial_known (dict): The known data containing answers categorized by years,
      keyed like 'objects_2021'.
    - target_year (str): The target year key (e.g., 'objects_2022').
    - tentative_ans (str): The tentative answer to be checked.
    - prev_year_span (int): Number of previous years to consider (0 or less means none).
    - next_year_span (int): Number of next years to consider (0 or less means none).

    Returns:
    - dict: A dictionary with a single key "steps" holding the list of
      'previous' / 'next' markers in visiting order.
    """

    def extract_year_from_key(key):
        """Extracts the year from keys like 'objects_2021'."""
        return int(key.split('_')[1])

    target_year_int = extract_year_from_key(target_year)
    years = sorted(extract_year_from_key(key) for key in partial_known)

    def get_span_years(direction, span_limit):
        """Returns the years within the specified span and direction, in ascending order."""
        span_limit = max(span_limit, 0)
        if direction == 'previous':
            earlier = [year for year in years if year < target_year_int]
            # A plain `earlier[-span_limit:]` would return *every* year when
            # span_limit is 0, because `-0` slices from the start.
            return earlier[-span_limit:] if span_limit else []
        later = [year for year in years if year > target_year_int]
        return later[:span_limit]

    def answers_for_year(year):
        """Returns the lower-cased reference answers recorded for one year."""
        reference_data = partial_known.get(f'objects_{year}', {})
        if "chrono_ans" in reference_data:
            chrono_ans = reference_data["chrono_ans"]
            # A scalar chrono_ans is treated as a one-element list so that
            # None / empty values are filtered exactly like list members.
            candidates = chrono_ans if isinstance(chrono_ans, list) else [chrono_ans]
        elif reference_data.get("category") == "correct":
            candidates = reference_data.get("temp0_ans") or []
        else:
            candidates = reference_data.get("temp7_ans") or []
        return {ans.lower() for ans in candidates if ans}

    needle = tentative_ans.lower()
    traversal_steps = []

    # Check the previous span, starting from the latest previous year.
    found_in_previous = False
    for year in reversed(get_span_years('previous', prev_year_span)):
        traversal_steps.append('previous')
        if needle in answers_for_year(year):
            found_in_previous = True
            break

    # Only continue into the next span when the previous span had no match.
    if not found_in_previous:
        for year in get_span_years('next', next_year_span):
            traversal_steps.append('next')
            if needle in answers_for_year(year):
                break

    return {"steps": traversal_steps}

### 1-2. Get the traversal information for all domains

In [None]:
model_name_list = ["Llama3.1_8B", "Mistral7B", "Phi3.5_Mini", "SOLAR_10.7B", "Gemma2_9B", 'gpt-4o-mini']
domain = "General" # General, Biomedical, Legal
tentative_ans = "Entity"
prev_span = 3
next_span = 3

# Only years classified into one of these categories get a traversal log.
categories_needing_steps = {"incorrect", "partial_correct2", "partial_correct1"}

for model_name in model_name_list:
    bench_data_c, bench_data_u, timestamp_c, timestamp_u = load_data_with_updated_timestamp(model_name, domain)

    # The Dynamic and Static timestamp files are processed identically.
    for temp_state, timestamp in (("Dynamic", timestamp_c), ("Static", timestamp_u)):
        # Each non-empty entry is indexed as entry[1] to reach its per-year dict.
        partial_known_entries = [entry for entry in timestamp["Partial_known"] if entry]

        for entry in tqdm(partial_known_entries, desc="Getting year spans..."):
            partial_known = entry[1]
            for year, objects_year in partial_known.items():
                if objects_year["category"] in categories_needing_steps:
                    traversal = get_traversal_details(
                        partial_known, year, tentative_ans,
                        prev_year_span=prev_span, next_year_span=next_span,
                    )
                    # Attach the visiting order to the year it was computed for.
                    objects_year['steps'] = traversal["steps"]

        # Save once per file, after every entry is annotated: the output is
        # written a single time and also exists when no entry needed steps.
        save_updated_timestamp(timestamp, f'./ChronoGap/{model_name}/Updated_Timestamp_{domain}_{temp_state}_step.json')

## 2. Results of Total Span

In [None]:
def transition_rule_last_entry(entry, benchmark):
    """
    Decide whether a 'Partial_known' entry qualifies as 'Known', judging each
    year only by the last element of its cleaned chrono_ans list.

    The entry is updated in place: every year's 'chrono_ans' is replaced by the
    output of clean_chrono_ans_list, and a year's 'category' becomes
    'chrono_known' when is_fuzz_match accepts its last answer against
    benchmark.get(year, set()).

    Returns:
        - all_known: True when every year ends up 'chrono_known' or 'correct'
    """
    accepted_categories = ('chrono_known', 'correct')
    unresolved_years = 0

    for year_key, info in entry.items():
        if 'chrono_ans' in info:
            # Store the cleaned list back before inspecting it
            cleaned = clean_chrono_ans_list(info['chrono_ans'])
            info['chrono_ans'] = cleaned

            # Only the final answer of the list is matched against the benchmark
            if cleaned and is_fuzz_match(cleaned[-1], benchmark.get(year_key, set())):
                info['category'] = 'chrono_known'

        # Every year is visited (no early exit) so all categories get updated
        if info['category'] not in accepted_categories:
            unresolved_years += 1

    return unresolved_years == 0


def results_of_total_span(new_data, bench_entries, bench_name):
    """
    Classify data entries into Known, Unknown, Cut-off, and Partial_known categories,
    while calculating the increase in Known items. The classification is based only
    on the last entry in the chrono_ans list.

    Items listed under 'Partial_known' are re-evaluated with
    transition_rule_last_entry (which updates their year data in place) and are
    counted as 'Known' when every year passes; all other items keep their category.

    Parameters:
    - new_data (dict): Maps category name to a list of (idx, year_classifications) pairs.
    - bench_entries: Sequence indexed by idx; each element is passed to get_benchmark.
    - bench_name: Forwarded to get_benchmark as its second argument.

    Returns:
    - fine_grained_results: A defaultdict with the counts of each final category
    - classification_indices: A defaultdict mapping each final category to its
      list of (idx, year_classifications) items
    - previous_known_count: The Known count before any Partial_known item was promoted

    Raises:
    - ValueError: If the number of classified items differs from the number of input items.
    """
    fine_grained_results = defaultdict(int)
    classification_indices = defaultdict(list)
    partial_known_to_known = 0
    moved_indices = set()

    for category in ['Known', 'Unknown', 'Cut-off', 'Partial_known']:
        for idx, year_classifications in new_data.get(category, []):
            benchmark = get_benchmark(bench_entries[idx], bench_name)

            final_category = category
            if category == 'Partial_known' and transition_rule_last_entry(year_classifications, benchmark):
                final_category = 'Known'
                partial_known_to_known += 1
                moved_indices.add(idx)

            fine_grained_results[final_category] += 1
            classification_indices[final_category].append((idx, year_classifications))

    final_known_count = len(classification_indices['Known'])
    previous_known_count = final_known_count - partial_known_to_known

    print(f"Total items moved from Partial_known to Known: {partial_known_to_known}")

    # Ensure no promoted idx is still listed under Partial_known. Promoted items
    # were appended to 'Known' above, so only this one filtering pass is needed.
    if moved_indices:
        classification_indices['Partial_known'] = [
            item for item in classification_indices['Partial_known'] if item[0] not in moved_indices
        ]

    # Debug final counts
    print(f"Final Known count: {len(classification_indices['Known'])}")
    print(f"Final Partial_known count: {len(classification_indices['Partial_known'])}")

    # Ensure total count consistency
    tracked_categories = ('Known', 'Partial_known', 'Unknown', 'Cut-off')
    total_items = sum(len(classification_indices[name]) for name in tracked_categories)
    expected_total = sum(len(new_data.get(name, [])) for name in tracked_categories)

    # Check if the total count is consistent
    if total_items != expected_total:
        raise ValueError(f"Total item count mismatch! Expected: {expected_total}, Got: {total_items}")

    return fine_grained_results, classification_indices, previous_known_count

In [None]:
def get_results_total_span(model_name_list, domain):
    """
    Run the total-span evaluation for every model, once per temporal state.

    For each of "Dynamic" and "Static" and for each model, the benchmark is
    loaded with load_result, the step-annotated timestamp file is read from
    ./ChronoGap/<model_name>, results_of_total_span is applied, and a summary
    is printed.

    Returns:
    - list: Two lists (Dynamic first, then Static), each holding one summary
      dict per model with counts, indices and Known percentages.
    """

    def share_of(count, total):
        """Percentage of `count` within `total`; 0 when there are no items."""
        return (count / total) * 100 if total > 0 else 0

    summaries_by_state = []

    for t_state in ("Dynamic", "Static"):
        model_summaries = []

        for model_name in tqdm(model_name_list):
            # Only the benchmark is used; the two parsed-time results are ignored.
            bench, _temp0_parsed, _temp7_parsed = load_result(
                model_name=model_name,
                domain=domain,
                temp_state=t_state,
                mode="generation",
            )

            step_file = f"./ChronoGap/{model_name}/Updated_Timestamp_{domain}_{t_state}_step.json"
            new_data = read_json_file(step_file)

            results, indices, previous_known_count = results_of_total_span(new_data, bench, domain)

            total = sum(results.values())
            current_known_count = results.get('Known', 0)
            previous_known_percentage = share_of(previous_known_count, total)
            current_known_percentage = share_of(current_known_count, total)
            known_increase_percentage = current_known_percentage - previous_known_percentage

            print(f"\n[model: {model_name}, temp_state: {t_state}] ChroKnowPrompt results:")
            for category, count in results.items():
                line = f"{category}: {count} ({share_of(count, total):.2f}%)"
                if category == 'Known':
                    # Only the Known row reports the gain over the earlier count.
                    line += f" (+{known_increase_percentage:.2f}% overall increase)"
                print(line)
            print(f"Total: {total}")

            model_summaries.append({
                "model_name": model_name,
                "results": results,
                "indices": indices,
                "previous_known_count": previous_known_count,
                "total": total,
                "previous_known_percentage": previous_known_percentage,
                "current_known_count": current_known_count,
                "current_known_percentage": current_known_percentage,
                "known_increase_percentage": known_increase_percentage,
            })

        summaries_by_state.append(model_summaries)

    return summaries_by_state

In [None]:
# Models whose step-annotated timestamp files are evaluated over the total span.
model_name_list = [
    "Llama3.1_8B",
    "Mistral7B",
    "Phi3.5_Mini",
    "SOLAR_10.7B",
    "Gemma2_9B",
    "gpt-4o-mini",
]

# time variant domains: General, Biomedical, Legal
domain = "General"

results_list_temp_state = get_results_total_span(model_name_list=model_name_list, domain=domain)

## 3. Results of Previous Span

In [None]:
def transition_rule_previous_only_last_entry(entry, benchmark):
    """
    Decide whether a 'Partial_known' entry qualifies as 'Known', judging each
    year only by the answer paired with its last 'previous' step.

    Years carrying both 'chrono_ans' and 'steps' are updated in place:
    'chrono_ans' is replaced by the output of clean_chrono_ans_list, and
    'category' becomes 'chrono_known' when is_fuzz_match accepts the selected
    answer against benchmark.get(year, set()). Other years keep their category.

    Returns:
        - all_known: True when every year ends up 'chrono_known' or 'correct'
    """
    accepted_categories = ('chrono_known', 'correct')
    unresolved_years = 0

    for year_key, info in entry.items():
        if 'chrono_ans' in info and 'steps' in info:
            cleaned = clean_chrono_ans_list(info['chrono_ans'])
            step_labels = info['steps']
            info['chrono_ans'] = cleaned

            # Steps and cleaned answers are paired positionally.
            # NOTE(review): if cleaning drops entries, the pairing may shift - confirm.
            previous_answers = [
                answer for label, answer in zip(step_labels, cleaned) if label == "previous"
            ]

            # Only the answer of the last 'previous' step is matched
            if previous_answers and is_fuzz_match(previous_answers[-1], benchmark.get(year_key, set())):
                info['category'] = 'chrono_known'

        # Every year is visited (no early exit) so all categories get updated
        if info['category'] not in accepted_categories:
            unresolved_years += 1

    return unresolved_years == 0


def results_of_previous_span(new_data, bench_entries, bench_name):
    """
    Classify data entries into Known, Unknown, Cut-off, and Partial_known categories,
    while considering only the last 'previous' step in chrono_ans for matching.

    Items listed under 'Partial_known' are re-evaluated with
    transition_rule_previous_only_last_entry (which updates their year data in
    place) and are counted as 'Known' when every year passes; all other items
    keep their category.

    Parameters:
    - new_data (dict): Maps category name to a list of (idx, year_classifications) pairs.
    - bench_entries: Sequence indexed by idx; each element is passed to get_benchmark.
    - bench_name: Forwarded to get_benchmark as its second argument.

    Returns:
    - fine_grained_results: A defaultdict with the counts of each final category
    - classification_indices: A defaultdict mapping each final category to its
      list of (idx, year_classifications) items
    - previous_known_count: The Known count before any Partial_known item was promoted

    Raises:
    - ValueError: If the number of classified items differs from the number of input items.
    """
    fine_grained_results = defaultdict(int)
    classification_indices = defaultdict(list)
    partial_known_to_known = 0
    moved_indices = set()

    for category in ['Known', 'Unknown', 'Cut-off', 'Partial_known']:
        for idx, year_classifications in new_data.get(category, []):
            benchmark = get_benchmark(bench_entries[idx], bench_name)

            final_category = category
            if category == 'Partial_known' and transition_rule_previous_only_last_entry(year_classifications, benchmark):
                final_category = 'Known'
                partial_known_to_known += 1
                moved_indices.add(idx)

            fine_grained_results[final_category] += 1
            classification_indices[final_category].append((idx, year_classifications))

    final_known_count = len(classification_indices['Known'])
    previous_known_count = final_known_count - partial_known_to_known

    print(f"Total items moved from Partial_known to Known: {partial_known_to_known}")

    # Ensure no promoted idx is still listed under Partial_known. Promoted items
    # were appended to 'Known' above, so only this one filtering pass is needed.
    if moved_indices:
        classification_indices['Partial_known'] = [
            item for item in classification_indices['Partial_known'] if item[0] not in moved_indices
        ]

    # Debug final counts
    print(f"Final Known count: {len(classification_indices['Known'])}")
    print(f"Final Partial_known count: {len(classification_indices['Partial_known'])}")

    # Ensure total count consistency
    tracked_categories = ('Known', 'Partial_known', 'Unknown', 'Cut-off')
    total_items = sum(len(classification_indices[name]) for name in tracked_categories)
    expected_total = sum(len(new_data.get(name, [])) for name in tracked_categories)

    # Check if the total count is consistent
    if total_items != expected_total:
        raise ValueError(f"Total item count mismatch! Expected: {expected_total}, Got: {total_items}")

    return fine_grained_results, classification_indices, previous_known_count

In [None]:
def get_results_previous_span(model_name_list, domain):
    """
    Run the previous-span evaluation for every model, once per temporal state.

    For each of "Dynamic" and "Static" and for each model, the benchmark is
    loaded with load_result, the step-annotated timestamp file is read from
    ./ChronoGap/<model_name>, results_of_previous_span is applied, and a
    summary is printed.

    Returns:
    - list: Two lists (Dynamic first, then Static), each holding one summary
      dict per model with counts, indices and Known percentages.
    """

    def share_of(count, total):
        """Percentage of `count` within `total`; 0 when there are no items."""
        return (count / total) * 100 if total > 0 else 0

    summaries_by_state = []

    for t_state in ("Dynamic", "Static"):
        model_summaries = []

        for model_name in tqdm(model_name_list):
            # Only the benchmark is used; the two parsed-time results are ignored.
            bench, _temp0_parsed, _temp7_parsed = load_result(
                model_name=model_name,
                domain=domain,
                temp_state=t_state,
                mode="generation",
            )

            step_file = f"./ChronoGap/{model_name}/Updated_Timestamp_{domain}_{t_state}_step.json"
            new_data = read_json_file(step_file)

            results, indices, previous_known_count = results_of_previous_span(new_data, bench, domain)

            total = sum(results.values())
            current_known_count = results.get('Known', 0)
            previous_known_percentage = share_of(previous_known_count, total)
            current_known_percentage = share_of(current_known_count, total)
            known_increase_percentage = current_known_percentage - previous_known_percentage

            print(f"\n[model: {model_name}, temp_state: {t_state}] ChroKnowPrompt results:")
            for category, count in results.items():
                line = f"{category}: {count} ({share_of(count, total):.2f}%)"
                if category == 'Known':
                    # Only the Known row reports the gain over the earlier count.
                    line += f" (+{known_increase_percentage:.2f}% overall increase)"
                print(line)
            print(f"Total: {total}")

            model_summaries.append({
                "model_name": model_name,
                "results": results,
                "indices": indices,
                "previous_known_count": previous_known_count,
                "total": total,
                "previous_known_percentage": previous_known_percentage,
                "current_known_count": current_known_count,
                "current_known_percentage": current_known_percentage,
                "known_increase_percentage": known_increase_percentage,
            })

        summaries_by_state.append(model_summaries)

    return summaries_by_state

In [None]:
# Models whose step-annotated timestamp files are evaluated over the previous span.
model_name_list = [
    "Llama3.1_8B",
    "Mistral7B",
    "Phi3.5_Mini",
    "SOLAR_10.7B",
    "Gemma2_9B",
    "gpt-4o-mini",
]

# time variant domains: General, Biomedical, Legal
domain = "General"

results_list_temp_state = get_results_previous_span(model_name_list=model_name_list, domain=domain)