# Post-processing of Generated Citations

## Overview

This notebook performs the **post-processing** of the citations generated in the previous step, refining and structuring the outputs as needed.

## Configuration

At the beginning of the notebook, update the **path variables** to specify the input and output directories used throughout the workflow.


In [None]:
import json 
import re
import numpy as np 
import pandas as pd 
from pathlib import Path

# Base name (without extension) shared by the input answers file and the output file.
file_name = "Pakistan_Monsoon floods and rains in Pakistan-Week 31 2024-prompt-1"

# Input: JSON file holding the answers to post-process.
# NOTE(review): the "answes-" prefix looks like a typo for "answers-" — confirm
# against the files on disk before renaming anything.
answer_path_one_file = f"./Results/Answers/Answers-subtopics/Dev set/answes-{file_name}.json"



# Output: directory and file that receive the answers with updated citations.
output_path = Path("./Results/Answers/Answers-subtopics/Dev set/Updated_citations")
output_path_one_file = output_path / f"answers_new_cit-{file_name}.json"


# Open Answers file


In [None]:

# JSON is UTF-8 by specification; decode it explicitly instead of relying on the
# platform's default encoding (which is not UTF-8 on every OS).
with open(answer_path_one_file, 'r', encoding='utf-8') as f:
    answers_file = json.load(f)

# Filter file

Questions that have no generated answer are filtered out. For the remaining questions, the retrieved documents that are not cited in the answer are dropped, so that only the documents actually used in the answer are kept.

In [None]:
def preprocess_data(answers_file):
    """
    Keep the answers that may contain citations and attach the cited contexts.

    An item is kept when its 'retrieved_answer' contains a "[" or a "]".
    Numeric markers such as "[1]" or "[ 2 ]" are then extracted, de-duplicated
    and sorted; each marker is read as a 1-based index into the item's
    'retrieved_contexts'. Markers that point outside that list are reported
    with a printed warning and skipped.

    The input items are not modified: every returned dictionary is a shallow
    copy of the original item with the two extra keys added.

    Args:
        answers_file (list): The list of dictionaries loaded from the JSON file.
            A missing or None 'retrieved_answer' is treated as an empty answer,
            and a missing or None 'retrieved_contexts' as an empty list.

    Returns:
        list: A list of dictionaries, where each dictionary has 'citations'
              (sorted unique ints) and 'used_contexts' (the contexts those
              citations point to) keys added.
    """
    preprocessed_items = []
    for item in answers_file:
        answer = item.get('retrieved_answer')
        if not isinstance(answer, str):
            # None / non-text answers cannot carry citation markers.
            answer = ''

        # Only include items that potentially have citations
        if "[" not in answer and "]" not in answer:
            continue

        retrieved_context = item.get('retrieved_contexts') or []

        # Find all matches like [1], [2], etc. and keep each number once, in order.
        citations = sorted({int(number) for number in re.findall(r'\[\s*(\d+)\s*\]', answer)})

        used_contexts = []
        for i in citations:
            index = i - 1  # Adjust for 0-based indexing
            if 0 <= index < len(retrieved_context):
                used_contexts.append(retrieved_context[index])
            else:
                print(f"⚠️ Citation [{i}] is out of bounds (retrieved_context has {len(retrieved_context)} elements).")

        # Shallow copy so the caller's loaded data is left untouched.
        enriched_item = dict(item)
        enriched_item['citations'] = citations
        enriched_item['used_contexts'] = used_contexts
        preprocessed_items.append(enriched_item)
    return preprocessed_items

In [None]:
# Keep only the answers that contain citation brackets and attach the contexts they cite.
filtered_file = preprocess_data(answers_file)

In [None]:
# Preview only the first few items: each one embeds all of its retrieved contexts,
# so displaying the whole list floods the cell output.
filtered_file[:3]

# Update citations 

In [None]:
def jaccard_similarity(text1, text2):
    """
    Jaccard similarity between the lower-cased, whitespace-split token sets
    of two texts.

    Returns |A ∩ B| / |A ∪ B|, or 0 when neither text contains a token.
    """
    words_a = set(text1.lower().split())
    words_b = set(text2.lower().split())

    all_words = words_a | words_b
    if not all_words:
        # Both texts are empty (or whitespace only): nothing to compare.
        return 0

    shared_words = words_a & words_b
    return len(shared_words) / len(all_words)

In [None]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

def update_citations_in_clusters(
    data_list, 
    mu=0.8, 
    threshold=0.3, 
    model_name='nomic-ai/modernbert-embed-base'):
    """
    Re-assign the citation markers of every answer to the best-scoring documents.

    Each answer is split into segments of the form "<text><markers>", where
    <markers> is a run of adjacent numeric markers such as "[1][3]". For a
    segment that originally carried k markers, every retrieved document is
    scored with

        mu * jaccard(segment text, document) + (1 - mu) * cosine(question, document)

    and the k best documents are selected. Of those, the ones scoring at least
    `threshold` are kept; if none reaches the threshold, the single best
    document is kept. The segment's markers are replaced by the (sorted)
    1-based indices of the kept documents, written directly after the segment
    text. Whitespace that separates a segment from the preceding citation is
    preserved, so sentence, line and list structure is left intact.

    Args:
        data_list (list): Dictionaries read through the keys 'retrieved_answer',
            'retrieved_contexts' and 'question', plus the optional 'citations'
            and 'used_contexts' keys. The input dictionaries are shallow-copied,
            not modified.
        mu (float): Weight of the Jaccard term; the cosine term gets 1 - mu.
        threshold (float): Minimum combined score for a top-k document to be cited.
        model_name (str): Name passed to SentenceTransformer.

    Returns:
        tuple: (processed_data_list, total_original_citations, total_changed_citations)
            - processed_data_list: one dictionary per input item in which
              'citations' / 'used_contexts' are renamed to 'old_citations' /
              'old_used_contexts' and 'updated_retrieved_answer',
              'new_citations' and 'new_used_contexts' are added.
            - total_original_citations: sum of len('citations') over all items.
            - total_changed_citations: number of items (not individual
              citations) whose set of cited indices changed. Items without an
              answer or a question are not counted.

        If the embedding model cannot be loaded, the error is printed and
        shallow copies of the unprocessed items are returned with both
        counters set to 0.
    """
    try:
        model = SentenceTransformer(model_name)
    except Exception as e:
        print(f"Error loading SentenceTransformer model '{model_name}': {e}")
        # Return a copy of the original data if model loading fails
        return [item.copy() for item in data_list], 0, 0 # Return counts as 0, 0

    # Lazily matched text followed by one or more adjacent markers. Inner
    # whitespace ("[ 2 ]") is accepted so that every marker counted as a
    # citation by the r'\[\s*(\d+)\s*\]' extraction is also rewritten here.
    pattern = re.compile(r"(.*?)((?:\[\s*\d+\s*\])+)", re.DOTALL)

    processed_data_list = []
    total_original_citations = 0
    total_changed_citations = 0

    for original_cluster_data in data_list:
        cluster = original_cluster_data.copy() # Work on a copy to avoid modifying original list elements directly

        llm_response = cluster.get('retrieved_answer')
        retrieved_docs = cluster.get('retrieved_contexts')
        query = cluster.get('question')

        if not llm_response or not query:
            # Nothing to score against: pass the answer through without citations.
            cluster['updated_retrieved_answer'] = llm_response
            cluster['new_citations'] = []
            cluster['new_used_contexts'] = []
            old_citations_for_cluster = cluster.pop('citations', [])
            cluster['old_citations'] = old_citations_for_cluster
            cluster['old_used_contexts'] = cluster.pop('used_contexts', [])

            # Count original citations for clusters that are skipped
            total_original_citations += len(old_citations_for_cluster)

            processed_data_list.append(cluster)
            continue

        # Rename original citation fields (if they exist)
        old_citations_for_cluster = cluster.pop('citations', [])
        cluster['old_citations'] = old_citations_for_cluster
        cluster['old_used_contexts'] = cluster.pop('used_contexts', [])

        llm_matches = list(pattern.finditer(llm_response))

        if not llm_matches: # No citation patterns found in the response
            cluster['updated_retrieved_answer'] = llm_response
            cluster['new_citations'] = []
            cluster['new_used_contexts'] = []
        else:
            # Cosine scores between the question and every retrieved document.
            # They do not depend on the segment, so compute them once per answer.
            cosine_scores = np.array([])
            if retrieved_docs:
                query_embedding = model.encode([query]) # Shape: (1, embedding_dim)
                doc_embeddings = model.encode(retrieved_docs) # Shape: (num_docs, embedding_dim)
                if doc_embeddings.ndim == 2 and doc_embeddings.shape[0] > 0:
                    cosine_scores = cosine_similarity(query_embedding, doc_embeddings)[0] # Shape: (num_docs,)

            can_score = bool(retrieved_docs) and cosine_scores.size > 0

            rebuilt_parts = []
            cursor = 0  # end of the last segment already copied into rebuilt_parts
            all_new_overall_citation_indices = set()

            for match in llm_matches:
                text_piece = match.group(1)
                text_piece_stripped = text_piece.strip()

                # k: the number of original citation markers for this piece
                k = len(re.findall(r'\d+', match.group(2)))

                new_indices_for_piece = []
                if can_score:
                    # Score this text piece against all retrieved documents.
                    piece_document_scores = []
                    for j, doc_text_content in enumerate(retrieved_docs):
                        jaccard_sim = jaccard_similarity(text_piece_stripped, doc_text_content)
                        combined_score = mu * jaccard_sim + (1 - mu) * cosine_scores[j]
                        piece_document_scores.append((combined_score, j + 1))

                    piece_document_scores.sort(reverse=True, key=lambda x: x[0])
                    top_k_scored_docs = piece_document_scores[:k]

                    # Filter the top k by the threshold
                    new_indices_for_piece = [
                        doc_idx for score, doc_idx in top_k_scored_docs if score >= threshold
                    ]

                    # Never leave a cited piece without a citation: fall back to the best document.
                    if not new_indices_for_piece and top_k_scored_docs:
                        new_indices_for_piece = [top_k_scored_docs[0][1]]

                    new_indices_for_piece.sort()
                # else: no documents to cite against, so the markers are removed.

                all_new_overall_citation_indices.update(new_indices_for_piece)

                # Keep the whitespace that separates this piece from the previous
                # citation; only the gap between the text and its own markers is closed.
                leading_whitespace = text_piece[:len(text_piece) - len(text_piece.lstrip())]
                new_markers = ''.join(f'[{idx}]' for idx in new_indices_for_piece)

                rebuilt_parts.append(llm_response[cursor:match.start()])
                rebuilt_parts.append(leading_whitespace + text_piece_stripped + new_markers)
                cursor = match.end()

            # Text after the last citation is carried over unchanged.
            rebuilt_parts.append(llm_response[cursor:])

            # Finalize results for the cluster
            cluster['updated_retrieved_answer'] = ''.join(rebuilt_parts)

            final_sorted_new_citation_indices = sorted(all_new_overall_citation_indices)
            cluster['new_citations'] = final_sorted_new_citation_indices

            if retrieved_docs:
                cluster['new_used_contexts'] = [
                    retrieved_docs[doc_idx - 1] for doc_idx in final_sorted_new_citation_indices
                    if 0 < doc_idx <= len(retrieved_docs)
                ]
            else:
                cluster['new_used_contexts'] = []

        processed_data_list.append(cluster)

        total_original_citations += len(old_citations_for_cluster)
        if set(cluster['new_citations']) != set(old_citations_for_cluster):
            total_changed_citations += 1

    return processed_data_list, total_original_citations, total_changed_citations

In [None]:
# Re-assign the citations with the default weights. Besides the processed items, the call
# returns the number of original citations and the number of answers whose citation set changed.
new_file, tot_citations, changed_citations = update_citations_in_clusters(filtered_file)

In [None]:
# Persist the post-processed answers. The output folder is created on first use so
# that a fresh checkout does not fail with FileNotFoundError.
output_path_one_file.parent.mkdir(parents=True, exist_ok=True)
with open(output_path_one_file, 'w', encoding='utf-8') as f:
    json.dump(new_file, f)

# Read the file back to confirm that what was written is valid JSON.
with open(output_path_one_file, 'r', encoding='utf-8') as f:
    new_file = json.load(f)
    