In [None]:
import pandas as pd
import json
from collections import defaultdict

# Build the ground-truth file used by the evaluation cell: for every brief
# record, attach the de-duplicated labels that the validation set records for
# the same text.

# Load the brief records (tab-separated); only the 'mmsid' and 'titleauthor'
# columns are read below.
df_brief = pd.read_csv('./brief_records/brief-cleaned1k.tsv', sep='\t')

# Load the validation set: one JSON object per line.
dataset = []
with open("./brief_records/validation.jsonl", "r", encoding="utf-8") as infile:
    for line in infile:
        # Tolerate blank lines (e.g. a trailing newline at end of file), which
        # json.loads would otherwise reject.
        if line.strip():
            dataset.append(json.loads(line))

# Map each text to every label recorded for it in the validation dataset.
validation_labels = defaultdict(list)
for entry in dataset:
    try:
        text = entry["text"]
        label = entry["label"]
        validation_labels[text].append(label)
    except KeyError as e:
        print(f"Skipping entry due to missing key: {e}")

# Consolidate the data: one output record per row of the brief-records table.
consolidated_data = []
for index, row in df_brief.iterrows():
    mmsid = row['mmsid']
    titleauthor = row['titleauthor']
    # .get() keeps the lookup read-only; indexing the defaultdict would insert
    # an empty list for every title that has no validation entry.
    labels = validation_labels.get(titleauthor, [])
    consolidated_data.append({
        "mmsid": mmsid,
        "titleauthor": titleauthor,
        # dict.fromkeys removes duplicates while keeping first-seen order, so
        # the written file is identical from run to run (set() order is not).
        "label": list(dict.fromkeys(labels))
    })

# Save the consolidated data next to the other inputs: the evaluation code
# reads its ground truth from ./brief_records/consolidated_validation.json.
with open("./brief_records/consolidated_validation.json", "w", encoding="utf-8") as outfile:
    json.dump(consolidated_data, outfile, indent=4)

print("Validation data formatted and saved to 'consolidated_validation.json'")

Validation data formatted and saved to 'consolidated_validation.json'


In [None]:
import json
import numpy as np
import os
import sys

############################
# Metric computation functions
############################
def precision_at_k(y_true, y_pred, k):
    """Return the fraction of the first ``k`` predictions that are relevant.

    Args:
        y_true: Iterable of ground-truth labels (hashable items).
        y_pred: Ranked sequence of predicted labels; only ``y_pred[:k]`` is used.
        k: Cut-off rank. The denominator is always ``k``, even when fewer than
            ``k`` predictions exist.

    Returns:
        Number of distinct labels shared by the top-``k`` predictions and the
        ground truth, divided by ``k``; ``0`` when ``k`` is falsy.
    """
    top_k = y_pred[:k]
    # Set semantics: a label repeated inside the top-k is counted once.
    hits = set(top_k).intersection(set(y_true))
    if not k:
        return 0
    return len(hits) / k

def recall_at_k(y_true, y_pred, k):
    """Return the share of ground-truth labels found in the first ``k`` predictions.

    Args:
        y_true: Collection of ground-truth labels (hashable items).
        y_pred: Ranked sequence of predicted labels; only ``y_pred[:k]`` is used.
        k: Cut-off rank.

    Returns:
        Number of distinct labels shared by the top-``k`` predictions and the
        ground truth, divided by ``len(y_true)``; ``0`` when ``y_true`` is empty.
    """
    if y_true:
        hits = set(y_pred[:k]).intersection(set(y_true))
        # The denominator is the raw length of y_true (duplicates included).
        return len(hits) / len(y_true)
    # No ground truth: recall is defined as 0 here to avoid dividing by zero.
    return 0

def f1_at_k(y_true, y_pred, k):
    """Return the harmonic mean of ``precision_at_k`` and ``recall_at_k``.

    Args:
        y_true: Collection of ground-truth labels.
        y_pred: Ranked sequence of predicted labels.
        k: Cut-off rank passed through to both underlying metrics.

    Returns:
        ``2 * P * R / (P + R)``, or ``0`` when ``P + R`` is zero.
    """
    precision = precision_at_k(y_true, y_pred, k)
    recall = recall_at_k(y_true, y_pred, k)
    total = precision + recall
    if not total:
        # Both metrics are zero; the harmonic mean is defined as 0 here.
        return 0
    return 2 * precision * recall / total

############################
# Compute weighted average metrics
############################
def weighted_avg_metrics_at_k(y_true_list, y_pred_list, k_values):
    """Compute support-weighted average precision, recall and F1 at each cut-off.

    Every (ground truth, prediction) pair is weighted by ``len(y_true)``, so
    items with no ground-truth labels contribute nothing to the averages.

    Args:
        y_true_list: One collection of ground-truth labels per item.
        y_pred_list: One ranked prediction sequence per item, paired with
            ``y_true_list`` by position.
        k_values: Iterable of cut-off ranks to evaluate.

    Returns:
        Three lists ``(precision, recall, f1)``, each holding one weighted
        average per entry of ``k_values``, in the same order. An average is
        ``0`` when the total weight is zero.
    """
    def _weighted_mean(weighted_terms, weight_sum):
        # Guard against an all-zero weight total (no ground truth anywhere).
        return sum(weighted_terms) / weight_sum if weight_sum else 0

    precision_by_k, recall_by_k, f1_by_k = [], [], []

    for k in k_values:
        supports = []
        p_terms, r_terms, f_terms = [], [], []

        for truth, preds in zip(y_true_list, y_pred_list):
            support = len(truth)
            supports.append(support)
            # Store each metric already multiplied by its weight.
            p_terms.append(precision_at_k(truth, preds, k) * support)
            r_terms.append(recall_at_k(truth, preds, k) * support)
            f_terms.append(f1_at_k(truth, preds, k) * support)

        support_sum = sum(supports)
        avg_precision_val = _weighted_mean(p_terms, support_sum)
        avg_recall_val = _weighted_mean(r_terms, support_sum)
        avg_f1_val = _weighted_mean(f_terms, support_sum)

        print(f"[DEBUG] k={k} -> Weighted Precision: {avg_precision_val}, Weighted Recall: {avg_recall_val}, Weighted F1: {avg_f1_val}")

        precision_by_k.append(avg_precision_val)
        recall_by_k.append(avg_recall_val)
        f1_by_k.append(avg_f1_val)

    return precision_by_k, recall_by_k, f1_by_k

############################
# Main evaluation script
############################
def _collect_predicted_subjects(entries):
    """Flatten the ``dcterms_subject`` values of ``entries`` into one list.

    Entries that are not dicts, or whose ``dcterms_subject`` is missing or
    ``None``, contribute nothing. A bare string is treated as a single label
    rather than being split into characters.
    """
    predicted = []
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        subjects = entry.get("dcterms_subject")
        if subjects is None:
            continue
        if isinstance(subjects, str):
            # list.extend() on a string would add one element per character.
            predicted.append(subjects)
        else:
            predicted.extend(subjects)
    return predicted


def main(
    path_consolidated="./brief_records/consolidated_validation.json",
    json_dir="./brief_records/json_files",
    k_values=(1, 3, 5, 10, 15, 20, 25, 30),
):
    """Evaluate per-record predictions against the consolidated ground truth.

    Reads the ground-truth list from ``path_consolidated``, then for every
    item loads ``<json_dir>/<mmsid>.json``, gathers its ``dcterms_subject``
    predictions, and prints weighted precision/recall/F1 at each cut-off.
    Items whose prediction file is missing, unparsable, or neither a dict nor
    a list are skipped with a debug message.

    Args:
        path_consolidated: Path of the JSON file holding the ground truth
            (a list of objects with ``mmsid`` and ``label`` keys).
        json_dir: Directory holding one prediction file per mmsid.
        k_values: Cut-off ranks to evaluate.

    Returns:
        ``(weighted_precision, weighted_recall, weighted_f1)`` as returned by
        ``weighted_avg_metrics_at_k``.

    Raises:
        SystemExit: With status 1 when the ground truth cannot be loaded or no
            item could be evaluated.
    """
    # Load the consolidated validation file
    try:
        with open(path_consolidated, "r", encoding="utf-8") as f:
            consolidated_data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[DEBUG] Error loading {path_consolidated}: {e}")
        sys.exit(1)

    # Prepare to collect data
    y_true_list = []
    y_pred_list = []

    print(f"[DEBUG] Found {len(consolidated_data)} entries in consolidated_validation.json")

    for idx, item in enumerate(consolidated_data):
        mmsid = str(item["mmsid"])  # Convert to string for matching filename
        ground_truth_subjects = item.get("label", [])

        # Build path to predicted JSON
        pred_file = os.path.join(json_dir, f"{mmsid}.json")

        if not os.path.exists(pred_file):
            print(f"[DEBUG] No prediction file for mmsid={mmsid}, skipping...")
            continue

        # Load predictions
        with open(pred_file, "r", encoding="utf-8") as pf:
            try:
                data = json.load(pf)
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                print(f"[DEBUG] Could not parse {pred_file} -> {e}")
                continue

        # Each file can be either a single dict or a list of dicts
        if isinstance(data, dict):
            entries = [data]
        elif isinstance(data, list):
            entries = data
        else:
            print(f"[DEBUG] Skipping {pred_file}, unexpected data format: {type(data)}")
            continue

        # The "dcterms_subject" field holds the predicted labels
        all_predicted = _collect_predicted_subjects(entries)

        # Debug prints
        print(f"[DEBUG] Item {idx} => mmsid={mmsid}")
        print(f"[DEBUG]   Ground truth: {ground_truth_subjects}")
        print(f"[DEBUG]   Predicted: {all_predicted}")

        # Append to overall lists
        y_true_list.append(ground_truth_subjects)
        y_pred_list.append(all_predicted)

    # Both lists grow in lockstep, so checking one of them is enough.
    if not y_true_list:
        print("[DEBUG] No valid data found! Check your JSON structure or directory.")
        # Nothing could be evaluated: report failure, like the load error above.
        sys.exit(1)

    # Compute weighted average metrics
    weighted_precision, weighted_recall, weighted_f1 = weighted_avg_metrics_at_k(
        y_true_list, y_pred_list, list(k_values)
    )

    print("\n[DEBUG] Final weighted average metrics across all items:")
    print("Weighted Average Precision@k:", weighted_precision)
    print("Weighted Average Recall@k:", weighted_recall)
    print("Weighted Average F1-score@k:", weighted_f1)

    return weighted_precision, weighted_recall, weighted_f1

if __name__ == "__main__":
    main()

[DEBUG] Found 1000 entries in consolidated_validation.json
[DEBUG] Item 0 => mmsid=1
[DEBUG]   Ground truth: ['gnd:4167885-0']
[DEBUG]   Predicted: ['gnd:4116546-9', 'gnd:4056723-0', 'gnd:4077811-3', 'gnd:4169187-8', 'gnd:4563270-4', 'gnd:4037877-9', 'gnd:4020015-2', 'gnd:4077624-4', 'gnd:4227561-1', 'gnd:4077587-2', 'gnd:4125698-0', 'gnd:4133431-0', 'gnd:4055676-1', 'gnd:4012475-7', 'gnd:4062110-8', 'gnd:4020588-5', 'gnd:4027266-7', 'gnd:4048561-4', 'gnd:4268059-1', 'gnd:4172385-5', 'gnd:4120627-7', 'gnd:4122782-7', 'gnd:4001307-8', 'gnd:4234987-4', 'gnd:4055768-6', 'gnd:4043183-6', 'gnd:4182741-7', 'gnd:4055762-5', 'gnd:4033597-5', 'gnd:4114011-4', 'gnd:4034929-9', 'gnd:4026482-8', 'gnd:4138354-0', 'gnd:4077567-7', 'gnd:4033581-1', 'gnd:4136812-5', 'gnd:4113843-0', 'gnd:4031883-7', 'gnd:4056730-8', 'gnd:4291675-6', 'gnd:4133695-1', 'gnd:4076536-2', 'gnd:4033596-3', 'gnd:4125858-7', 'gnd:4169194-5', 'gnd:4312811-7', 'gnd:4131484-0', 'gnd:4066615-3', 'gnd:4639271-3', 'gnd:4033542-2']
[

In [1]:
import numpy as np

# Weighted Recall@k values, one per cut-off. NOTE(review): presumably copied by
# hand from the evaluation output for k = 1, 3, 5, 10, 15, 20, 25, 30 — confirm
# they still match the latest run.
weighted_recall = [
    0.08260105448154657,
    0.18453427065026362,
    0.24311657879320445,
    0.3216168717047452,
    0.37024018746338605,
    0.39777387229056826,
    0.4276508494434681,
    0.44522554188635033,
]
# Unweighted arithmetic mean across the cut-offs.
mean_recall = np.asarray(weighted_recall).mean()
print("Mean Weighted Recall:", mean_recall)

Mean Weighted Recall: 0.30909490333919154
