In [None]:
from typing import List, Tuple, Dict
from bert_score import score as bert_score
from collections import Counter
import os
# Restrict this process to GPU index 1; this only takes effect if it runs
# before CUDA is initialized in the process.
os.environ.update(CUDA_VISIBLE_DEVICES="1")
def is_position_match(pred_pos: Tuple[int, int], gold_pos: Tuple[int, int]) -> bool:
    return pred_pos == gold_pos 

def evaluate_missing_prediction(
    gold: Dict[str, List],
    pred: Dict[str, List],
    model_type: str = "roberta-large",
    lang: str = "en",
) -> Dict[str, float]:
    """Score one sample's predicted positions/texts against the gold ones.

    ``gold`` and ``pred`` are dicts with parallel ``"positions"`` and
    ``"texts"`` lists. Gold and predicted positions are paired one-to-one via
    ``is_position_match``: each prediction can satisfy at most one gold entry.

    Returns a dict with:
        precision_pos: matched / (matched + unmatched predictions)
        recall_pos: matched / number of gold positions
        f1_pos: harmonic mean of precision_pos and recall_pos
        redundancy_rate: unmatched predictions / all predictions
        text_score_position_aware: sum of BERTScore F1 over the matched
            (pred, gold) text pairs divided by the number of gold positions,
            so every unmatched gold entry contributes 0.

    Edge cases:
        * No gold positions and no predictions -> every score is perfect.
        * No gold positions but some predictions -> precision 0,
          redundancy 1, recall vacuously 1.0.
        * Gold present but nothing matched -> BERTScore is not called and
          the text score is 0.0.
    Values outside these early returns are rounded to 4 decimal places.
    """
    gold_positions = gold["positions"]
    gold_texts = gold["texts"]
    pred_positions = pred["positions"]
    pred_texts = pred["texts"]

    if len(gold_positions) == 0:
        if len(pred_positions) == 0:
            return {
                "precision_pos": 1.0,
                "recall_pos": 1.0,
                "f1_pos": 1.0,
                "redundancy_rate": 0.0,
                "text_score_position_aware": 1.0,
            }
        return {
            "precision_pos": 0.0,
            "recall_pos": 1.0,
            "f1_pos": 0.0,
            "redundancy_rate": 1.0,
            "text_score_position_aware": 0.0,
        }

    used_pred = set()
    matched_texts_gold = []
    matched_texts_pred = []
    for i, gpos in enumerate(gold_positions):
        for j, ppos in enumerate(pred_positions):
            # Skip predictions that are already paired, so one prediction can
            # never be counted as a hit for several (e.g. duplicated) gold entries.
            if j not in used_pred and is_position_match(ppos, gpos):
                used_pred.add(j)
                matched_texts_gold.append(gold_texts[i])
                matched_texts_pred.append(pred_texts[j])
                break

    tp = len(used_pred)
    fp = len(pred_positions) - tp
    fn = len(gold_positions) - tp

    precision_pos = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall_pos = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1_pos = (
        2 * precision_pos * recall_pos / (precision_pos + recall_pos)
        if (precision_pos + recall_pos) > 0 else 0.0
    )
    redundancy_rate = fp / (tp + fp) if (tp + fp) > 0 else 0.0

    if tp > 0:
        # Score only the matched pairs and let unmatched gold entries count as
        # 0. This avoids scoring "" against gold text, which depends on how
        # BERTScore treats empty candidates and emits warnings for them.
        _, _, F1 = bert_score(
            matched_texts_pred, matched_texts_gold,
            lang=lang, model_type=model_type, verbose=False,
        )
        text_score = float(F1.sum().item()) / len(gold_positions)
    else:
        text_score = 0.0

    return {
        "precision_pos": round(precision_pos, 4),
        "recall_pos": round(recall_pos, 4),
        "f1_pos": round(f1_pos, 4),
        "redundancy_rate": round(redundancy_rate, 4),
        "text_score_position_aware": round(text_score, 4),
    }
from bert_score import score as bert_score
from tqdm import tqdm

def _match_positions(gold_pos: List, pred_pos: List) -> List[Tuple[int, int]]:
    """Pair gold and predicted positions one-to-one by exact equality.

    Each gold position is paired with the first not-yet-used predicted
    position equal to it. A single prediction therefore never counts as a
    hit for several gold entries, so precision cannot exceed 1.

    Returns:
        List of ``(gold_index, pred_index)`` pairs in gold order.
    """
    used_pred = set()
    pairs = []
    for i, g in enumerate(gold_pos):
        for j, p in enumerate(pred_pos):
            if j not in used_pred and g == p:
                pairs.append((i, j))
                used_pred.add(j)
                break
    return pairs


def batch_evaluate(
    gold_list: List[Dict[str, List]],
    pred_list: List[Dict[str, List]],
    model_type: str = "roberta-large",
    lang: str = "en",
    show_progress: bool = True,
    device: str = "cuda",
) -> Dict[str, float]:
    """Macro-average position and text metrics over aligned gold/pred samples.

    Each sample is a dict with parallel ``"positions"`` and ``"texts"`` lists.
    The per-sample metrics are:

    * ``precision_pos``: matched / len(pred positions) (0.0 if no predictions).
    * ``recall_pos``: matched / len(gold positions) (0.0 if no gold).
    * ``f1_pos``: harmonic mean of the two (0.0 if both are 0).
    * ``redundancy_rate``: unmatched predictions / len(pred positions).

    NOTE(review): an empty-gold sample gets recall 0.0 here, whereas
    ``evaluate_missing_prediction`` gives it 1.0. The two conventions differ.

    ``text_score_position_aware`` is computed per sample and then averaged
    over all samples. For one sample it is the sum of BERTScore F1 over the
    matched (pred, gold) text pairs, divided by that sample's number of gold
    texts. Unmatched gold entries, and samples with no matches, contribute 0.
    All matched pairs are scored in a single batched BERTScore call.

    Args:
        gold_list: Gold annotations, one dict per sample.
        pred_list: Predictions, aligned index-by-index with ``gold_list``.
        model_type: BERTScore backbone model.
        lang: Language passed to BERTScore.
        show_progress: Wrap the per-sample loop in a tqdm progress bar.
        device: Device passed to BERTScore (default ``"cuda"``).

    Returns:
        Dict with the five metrics averaged over all samples and rounded to
        4 places. Every metric is 0.0 when the input lists are empty.

    Raises:
        AssertionError: if the two lists have different lengths.
    """
    assert len(gold_list) == len(pred_list), "gold/pred inconsistent"

    metric_names = ("precision_pos", "recall_pos", "f1_pos", "redundancy_rate")
    metrics_sum = {name: 0.0 for name in metric_names}
    n = len(gold_list)
    if n == 0:
        # Nothing to average over; avoid dividing by zero below.
        empty_result = dict(metrics_sum)
        empty_result["text_score_position_aware"] = 0.0
        return empty_result

    all_gold_texts = []
    all_pred_texts = []
    per_sample_matched_counts = []

    iterable = zip(gold_list, pred_list)
    if show_progress:
        iterable = tqdm(iterable, total=n, desc="Evaluating")

    for gold, pred in iterable:
        gold_pos, pred_pos = gold["positions"], pred["positions"]
        gold_texts, pred_texts = gold["texts"], pred["texts"]

        matched = _match_positions(gold_pos, pred_pos)
        all_gold_texts.extend(gold_texts[i] for i, _ in matched)
        all_pred_texts.extend(pred_texts[j] for _, j in matched)
        per_sample_matched_counts.append(len(matched))

        true_positives = len(matched)
        precision_pos = true_positives / len(pred_pos) if pred_pos else 0.0
        recall_pos = true_positives / len(gold_pos) if gold_pos else 0.0
        f1_pos = (
            2 * precision_pos * recall_pos / (precision_pos + recall_pos)
            if precision_pos + recall_pos > 0 else 0.0
        )
        redundancy_rate = (len(pred_pos) - true_positives) / len(pred_pos) if pred_pos else 0.0

        metrics_sum["precision_pos"] += precision_pos
        metrics_sum["recall_pos"] += recall_pos
        metrics_sum["f1_pos"] += f1_pos
        metrics_sum["redundancy_rate"] += redundancy_rate

    # Samples without any match add nothing here but still count in `n`.
    text_score_sum = 0.0
    if all_gold_texts:
        _, _, F1 = bert_score(
            all_pred_texts, all_gold_texts,
            lang=lang, model_type=model_type, verbose=False, device=device
        )
        # Split the flat F1 list back into consecutive per-sample slices.
        f1_list = F1.tolist()
        idx = 0
        for gold, count in zip(gold_list, per_sample_matched_counts):
            if count:
                text_score_sum += sum(f1_list[idx:idx + count]) / len(gold["texts"])
                idx += count

    avg_metrics = {k: round(v / n, 4) for k, v in metrics_sum.items()}
    avg_metrics["text_score_position_aware"] = round(text_score_sum / n, 4)
    return avg_metrics


In [None]:
import pandas as pd
# Test split; each row's "messages" field is read by the cells below.
df = pd.read_json(path_or_buf="ScaleQM+_test.json")

In [None]:
import re
import pickle
from collections import OrderedDict

def extract_steps(text):
    """Split a solution string into its numbered steps.

    A step begins at a header such as ``Step 3:`` or ``step3:``
    (case-insensitive) that is immediately followed by a newline. It runs up
    to the next such header or to the closing ``</incomplete_solution>`` tag.
    Text after the last header is captured only if that closing tag is present.

    Returns:
        OrderedDict mapping ``"step<N>"`` to the stripped step body, in order
        of appearance. A repeated step number keeps its first position but
        takes the last body seen.
    """
    step_regex = r"(?:Step\s*|step)(\d+):\n(.*?)(?=(?:Step\s*|step)\d+:\n|</incomplete_solution>)"
    hits = re.finditer(step_regex, text, re.DOTALL | re.IGNORECASE)
    return OrderedDict(
        (f"step{hit.group(1)}", hit.group(2).strip()) for hit in hits
    )

# Count the similarity targets: every parsed step after the first, summed over
# all samples. A sample with no parsed steps must contribute 0; a bare
# `len(steps) - 1` would subtract 1 and undercount.
# NOTE(review): assumes messages[1]["content"] holds the solution text.
cnt = 0
for i in range(len(df)):
    steps = extract_steps(df.iloc[i]["messages"][1]["content"])
    cnt += max(len(steps) - 1, 0)
print(cnt)

# Generated step texts, presumably one per step after the first (compare with
# the count printed above). NOTE(review): pickle.load can execute arbitrary
# code, so only load files from a trusted source.
with open(file='results-sim.pkl', mode='rb') as f:
    loaded_results = pickle.load(f)
print(len(loaded_results))

In [None]:
# use next step as similarity reference
# Reference texts: every step after the first, in sample order. They are
# compared index-by-index with `loaded_results`, so the two lists must align.
t = []
for i in range(len(df)):
    steps = list(extract_steps(df.iloc[i]["messages"][1]["content"]).values())
    t.extend(steps[1:])
# Fail early with a clear message; bert_score also needs equal-length inputs.
assert len(t) == len(loaded_results), "reference steps and loaded_results are misaligned"
print(len(t))

In [None]:
from bert_score import score as bert_score
import os
# Only takes effect if CUDA has not been initialized yet in this kernel; the
# first cell already sets the same value.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Score each generated step against the real step it should reproduce.
P, R, F1 = bert_score(
    cands=loaded_results,
    refs=t,
    model_type="roberta-large",
    lang='en',
    device="cuda",
    verbose=False,
)
f1_list = F1.tolist()

In [None]:
# Build one prediction dict per sample. Step pair (k, k + 1) is predicted
# whenever the generated text for step k + 1 scores below SIM_THRESHOLD
# (BERTScore F1) against the real step k + 1.
SIM_THRESHOLD = 0.95
cnt = 0  # running offset into the flat f1_list / loaded_results lists
pred_list = []
for i in range(len(df)):
    steps = list(extract_steps(df.iloc[i]["messages"][1]["content"]).values())
    # Clamp at 0 so a sample with no parsed steps cannot move the offset backwards.
    n_targets = max(len(steps) - 1, 0)
    predict = {
        "positions": [],
        "texts": []
    }
    for k in range(n_targets):
        j = cnt + k
        if f1_list[j] < SIM_THRESHOLD:
            predict["positions"].append((k, k + 1))
            predict["texts"].append(loaded_results[j])
    pred_list.append(predict)
    cnt += n_targets

In [None]:
import pickle
# Gold annotations: one {"positions": [...], "texts": [...]} dict per sample.
# NOTE(review): pickle.load can execute arbitrary code, so only load files
# from a trusted source.
with open(file='gold.pkl', mode='rb') as f:
    gold = pickle.load(f)

In [None]:
# Macro-averaged metrics for the predictions built at the current threshold.
results = batch_evaluate(gold_list=gold, pred_list=pred_list)
print(results)

In [None]:
# Results of batch_evaluate for each similarity threshold used to build pred_list (threshold: metrics)
# 1    {'precision_pos': 0.2096, 'recall_pos': 0.7964, 'f1_pos': 0.3176, 'redundancy_rate': 0.7904, 'text_score_position_aware': 0.7572}
# 0.95 {'precision_pos': 0.2342, 'recall_pos': 0.7507, 'f1_pos': 0.3406, 'redundancy_rate': 0.7657, 'text_score_position_aware': 0.7122}
# 0.9  {'precision_pos': 0.2481, 'recall_pos': 0.5921, 'f1_pos': 0.3258, 'redundancy_rate': 0.7437, 'text_score_position_aware': 0.559}
# 0.85 {'precision_pos': 0.1934, 'recall_pos': 0.3017, 'f1_pos': 0.2119, 'redundancy_rate': 0.6387, 'text_score_position_aware': 0.2826}
# 0.8  {'precision_pos': 0.0747, 'recall_pos': 0.0865, 'f1_pos': 0.0718, 'redundancy_rate': 0.3251, 'text_score_position_aware': 0.0784}