# Post-processing of Executive Summary Citations

## Overview

This notebook performs the **post-processing** of the citations contained in the **executive summaries** generated in the previous step.  
The goal is to refine, standardize, and ensure the consistency of citation formatting.

## Configuration

At the beginning of the notebook, update the **path variables** (`input_path` and `output_path`) to specify the input and output directories used throughout the workflow.


In [None]:
import os 

# Directory holding the executive-summary JSON files to post-process.
input_path = './Results/Executive Summary/Dev set'
# Post-processed files are written to a sub-folder of the input directory.
output_path = os.path.join(input_path, "Updated_citations")

In [None]:
import re
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

def jaccard_similarity(text1, text2):
    """Return the Jaccard similarity of the word sets of two texts.

    Both texts are lower-cased and split on whitespace; the score is the
    size of the intersection of the two word sets divided by the size of
    their union. If either text contains no words, 0.0 is returned.
    """
    words_a = set(text1.lower().split())
    words_b = set(text2.lower().split())

    # An empty side would make the ratio meaningless (and, if both are
    # empty, divide by zero), so treat it as "no similarity".
    if not (words_a and words_b):
        return 0.0

    shared = words_a.intersection(words_b)
    combined = words_a.union(words_b)
    return len(shared) / len(combined)

# Loaded SentenceTransformer instances keyed by model name, so that processing
# many files does not reload the same model on every call.
_EMBEDDING_MODEL_CACHE = {}


def _get_embedding_model(model_name):
    """Return the SentenceTransformer for *model_name*, loading it on first use."""
    if model_name not in _EMBEDDING_MODEL_CACHE:
        # Only successful loads are stored; a failure propagates to the caller
        # and will be retried on the next call.
        _EMBEDDING_MODEL_CACHE[model_name] = SentenceTransformer(model_name)
    return _EMBEDDING_MODEL_CACHE[model_name]


def _rank_paragraphs(text_piece, retrieved_docs, cosine_scores, mu):
    """Return (combined score, 1-based paragraph index) pairs, best score first."""
    scored = [
        (
            mu * jaccard_similarity(text_piece, doc_text)
            + (1 - mu) * cosine_scores[j],
            j + 1,  # citations are 1-based
        )
        for j, doc_text in enumerate(retrieved_docs)
    ]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored


def update_citations_in_file(
    file_data,
    mu=0.8,
    threshold=0.3,
    model_name='nomic-ai/modernbert-embed-base'
):
    """
    Re-assign the citation markers of a summary to the best-matching paragraphs.

    ``file_data`` is a dict with structure:
    {
      "summary": "...",
      "cited_paragraphs": [...],
      "full_paragraphs": [...]
    }

    The summary is split at every run of ``[n]`` markers. For each run of
    ``k`` markers, every entry of ``full_paragraphs`` is scored with
    ``mu * jaccard(text before the markers, paragraph)
    + (1 - mu) * cosine(whole summary, paragraph)``; the ``k`` best paragraphs
    whose score reaches ``threshold`` replace the original markers. If none
    reaches the threshold, the single best paragraph is cited.

    Args:
        file_data: Dict as described above; it is modified in place.
        mu: Weight of the Jaccard score; the cosine score gets ``1 - mu``.
        threshold: Minimum combined score for a paragraph to be cited.
        model_name: SentenceTransformer model used for the embeddings.

    Returns:
        The same ``file_data`` dict with three added keys: ``new_summary``
        (summary with rewritten markers), ``new_citations`` (sorted 1-based
        paragraph indices used) and ``new_cited_paragraphs`` (the matching
        paragraph texts). If the embedding model cannot be loaded, an error
        is printed and ``file_data`` is returned unchanged.
    """
    summary = file_data.get("summary", "")
    retrieved_docs = file_data.get("full_paragraphs", [])

    # Nothing to re-cite: answer before paying for a model load.
    if not summary or not retrieved_docs:
        file_data["new_summary"] = summary
        file_data["new_citations"] = []
        file_data["new_cited_paragraphs"] = []
        return file_data

    try:
        model = _get_embedding_model(model_name)
    except Exception as e:
        print(f"Error loading SentenceTransformer model '{model_name}': {e}")
        return file_data  # Return unchanged if model fails

    # The cosine part compares the *whole* summary with each paragraph, so it
    # is computed once and reused for every citation run.
    query_embedding = model.encode([summary])
    doc_embeddings = model.encode(retrieved_docs)
    cosine_scores = np.array([])
    if doc_embeddings.ndim == 2 and doc_embeddings.shape[0] > 0:
        cosine_scores = cosine_similarity(query_embedding, doc_embeddings)[0]

    # group(1): text leading up to a citation run; group(2): the run itself.
    pattern = re.compile(r"(.*?)((\[\d+\])+)", re.DOTALL)

    pieces = []
    cursor = 0
    all_new_indices = set()

    for match in pattern.finditer(summary):
        leading_text = match.group(1)
        # Keep the same number of citations as the original run.
        k = len(re.findall(r'\d+', match.group(2)))

        new_indices = []
        if k > 0 and cosine_scores.size > 0:
            top_k = _rank_paragraphs(leading_text, retrieved_docs, cosine_scores, mu)[:k]
            new_indices = [doc_idx for score, doc_idx in top_k if score >= threshold]
            if not new_indices and top_k:
                # Never leave a cited statement without any citation.
                new_indices = [top_k[0][1]]
            new_indices.sort()

        all_new_indices.update(new_indices)
        new_marker_str = ''.join(f'[{i}]' for i in new_indices)

        # Rebuild the summary piece by piece. Whitespace directly before the
        # markers is dropped, but leading whitespace is kept so the text that
        # follows a previous citation is not glued onto it.
        pieces.append(summary[cursor:match.start()])
        pieces.append(leading_text.rstrip())
        pieces.append(new_marker_str)
        cursor = match.end()

    # Text after the last citation run (or the whole summary if there is none).
    pieces.append(summary[cursor:])
    updated_summary = ''.join(pieces)

    # Save results back
    file_data["new_summary"] = updated_summary
    file_data["new_citations"] = sorted(all_new_indices)
    file_data["new_cited_paragraphs"] = [
        retrieved_docs[i - 1] for i in file_data["new_citations"]
        if 0 < i <= len(retrieved_docs)
    ]

    return file_data


# Process all JSON files in the input folder

In [None]:
import os
import json

# Make sure the destination folder exists before any file is written.
os.makedirs(output_path, exist_ok=True)

# Loop through all JSON files in the directory. os.listdir returns entries in
# arbitrary order, so sort them to make runs (and their logs) reproducible.
for fname in sorted(os.listdir(input_path)):
    if not fname.endswith(".json"):
        continue

    fpath = os.path.join(input_path, fname)
    try:
        with open(fpath, "r", encoding="utf-8") as f:
            file_data = json.load(f)
    except Exception as e:
        # Best-effort batch: report the unreadable file and move on.
        print(f"⚠️ Skipping {fname} (error reading JSON: {e})")
        continue

    # update_citations_in_file reads its input with dict methods; a JSON file
    # whose top level is a list/string/number would otherwise abort the batch.
    if not isinstance(file_data, dict):
        print(f"⚠️ Skipping {fname} (expected a JSON object, got {type(file_data).__name__})")
        continue

    # Update citations
    updated_data = update_citations_in_file(file_data)

    # Save output JSON under the same file name in the output folder.
    out_fpath = os.path.join(output_path, fname)
    try:
        with open(out_fpath, "w", encoding="utf-8") as f:
            json.dump(updated_data, f, ensure_ascii=False, indent=2)
        print(f"✅ Processed and saved: {out_fpath}")
    except Exception as e:
        print(f"⚠️ Could not save {fname}: {e}")