How we processed the claims to produce train_all_processed.json and dev_all_processed.json


In [None]:
import json
import os
from tqdm import tqdm
from feverous.database.feverous_db import FeverousDB
from feverous.utils.wiki_page import WikiPage
from feverous.utils.annotation_processor import AnnotationProcessor

class FeverousVerifier:
    """Resolve FEVEROUS evidence IDs to text and save annotated claims as JSON.

    Despite the name, no verification model is run: the "verification result"
    recorded for each claim is the annotation's own ``verdict``, or
    ``"NOT ENOUGH INFO"`` when the annotation lists no evidence. The useful
    output is the evidence text looked up in the FEVEROUS Wikipedia database.
    """

    def __init__(self, db_path):
        """Initialize the FEVEROUS verifier with database path.

        Args:
            db_path: Path of the FEVEROUS Wikipedia database file, passed to
                ``FeverousDB``.
        """
        print("Initializing FEVEROUS database...")
        self.db = FeverousDB(db_path)
        # NOTE(review): `results` and `processed_claims` are not read or written
        # by any method in this class; kept only for backward compatibility.
        self.results = self._empty_stats()
        self.processed_claims = []
        # page name -> WikiPage; reset periodically by process_claims.
        self.page_cache = {}

    @staticmethod
    def _empty_stats():
        """Return a fresh label -> count dict with every known label at 0."""
        return {"SUPPORTS": 0, "REFUTES": 0, "NOT ENOUGH INFO": 0}

    @staticmethod
    def _write_json(output_file, statistics, claims):
        """Write ``{"statistics": ..., "processed_claims": ...}`` to *output_file* with indent=2."""
        with open(output_file, 'w') as f:
            json.dump({
                "statistics": statistics,
                "processed_claims": claims
            }, f, indent=2)

    def load_dataset(self, input_path):
        """Load every annotation from *input_path* into a list.

        Args:
            input_path: Path of a FEVEROUS annotation file, read through
                ``AnnotationProcessor``.

        Returns:
            list of the annotation objects yielded by ``AnnotationProcessor``.
        """
        print(f"Loading data from {input_path}")
        annotations = list(AnnotationProcessor(input_path))
        print(f"Loaded {len(annotations)} examples from {input_path}")
        return annotations

    def get_wiki_page(self, page_name):
        """Return the ``WikiPage`` for *page_name*, or None if the DB has no document.

        Successful lookups are memoized in ``self.page_cache``.
        """
        if page_name in self.page_cache:
            return self.page_cache[page_name]

        page_json = self.db.get_doc_json(page_name)
        if not page_json:
            return None

        wiki_page = WikiPage(page_name, page_json)
        self.page_cache[page_name] = wiki_page
        return wiki_page

    def process_claims(self, annotations, batch_size=5000, output_prefix="batch", final_json_file=None):
        """Process claims in batches, saving results for each batch and a final JSON file.

        After every ``batch_size`` annotations, and after the last one, the
        current batch is written to ``{output_prefix}_{start}_to_{end}.json``
        (1-based, inclusive indices) and the page cache is cleared. A summary
        with overall counts is always written to ``{output_prefix}_summary.json``.

        Args:
            annotations: Sized sequence of annotations (e.g. from
                ``load_dataset``); each must expose ``id``, ``claim``,
                ``verdict`` and ``evidence``.
            batch_size: Number of claims per batch file; must be >= 1.
            output_prefix: Prefix for the batch and summary file names.
            final_json_file: If given, all claims plus overall statistics are
                also written to this single file, even when there are no
                annotations.

        Returns:
            dict mapping verification-result label -> count over all claims.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        total = len(annotations)
        print(f"Processing {total} annotations in batches of {batch_size}")

        batch_claims = []
        batch_results = self._empty_stats()
        total_results = self._empty_stats()

        # Accumulate every claim only when a combined output file was requested.
        all_claims = [] if final_json_file else None

        for i, annotation in enumerate(tqdm(annotations, desc="Processing FEVEROUS data")):
            result, evidence = self.verify_claim(annotation)

            # .get() tolerates labels outside the three pre-seeded ones.
            batch_results[result] = batch_results.get(result, 0) + 1
            total_results[result] = total_results.get(result, 0) + 1

            claim_obj = {
                "id": annotation.id,
                "claim": annotation.claim,
                "verdict": annotation.verdict,
                "verification_result": result,
                "evidence": evidence
            }
            batch_claims.append(claim_obj)

            if all_claims is not None:
                # Store a JSON-safe copy so the final dump cannot fail late.
                all_claims.append(self._make_serializable(claim_obj))

            if (i + 1) % batch_size == 0 or i == total - 1:
                batch_num = (i // batch_size) + 1
                start_idx = ((batch_num - 1) * batch_size) + 1
                end_idx = min(batch_num * batch_size, total)
                batch_file = f"{output_prefix}_{start_idx}_to_{end_idx}.json"

                self.save_batch(batch_claims, batch_results, batch_file)

                batch_claims = []
                batch_results = self._empty_stats()
                # Bound memory: parsed pages are not reused across batches.
                self.page_cache = {}
                print(f"Completed batch {batch_num}: saved examples {start_idx}-{end_idx} to {batch_file}")

            elif (i + 1) % 1000 == 0:
                # Also bound memory within large batches.
                self.page_cache = {}

        # Write the combined file whenever one was requested (even if empty), so
        # the output the caller asked for always exists.
        if all_claims is not None:
            print(f"Saving all {len(all_claims)} processed claims to {final_json_file}...")
            self._write_json(final_json_file, total_results, all_claims)
            print(f"📄 All claims saved to {final_json_file}")

        summary_file = f"{output_prefix}_summary.json"
        with open(summary_file, 'w') as f:
            json.dump({
                "total_processed": total,
                "total_statistics": total_results
            }, f, indent=2)
        print(f"Summary saved to {summary_file}")

        return total_results

    def _make_serializable(self, claim_obj):
        """Return a JSON-safe copy of *claim_obj*.

        The five top-level fields are copied as-is. Each evidence entry is
        rebuilt with only ``id``, ``content`` coerced to ``str``, and ``error``
        when present; any other evidence keys are dropped.
        """
        def _clean(ev):
            entry = {"id": ev["id"], "content": str(ev["content"])}
            if "error" in ev:
                entry["error"] = ev["error"]
            return entry

        return {
            "id": claim_obj["id"],
            "claim": claim_obj["claim"],
            "verdict": claim_obj["verdict"],
            "verification_result": claim_obj["verification_result"],
            "evidence": [
                [_clean(ev) for ev in evidence_set]
                for evidence_set in claim_obj["evidence"]
            ],
        }

    def verify_claim(self, annotation):
        """Resolve the evidence text for one annotation.

        Args:
            annotation: Object with ``evidence`` (list of evidence sets, each a
                list of element IDs of the form ``"<page>_<element>"``) and
                ``verdict``; ``id`` is used only in log messages.

        Returns:
            ``("NOT ENOUGH INFO", [])`` when the annotation has no evidence;
            otherwise ``(annotation.verdict, organized_evidence)``, where
            ``organized_evidence`` mirrors ``annotation.evidence`` with each ID
            replaced by a ``{"id", "content"[, "error"]}`` dict. IDs that could
            not be resolved get ``content="Evidence not found"`` and
            ``error=True``.
        """
        claim_id = getattr(annotation, "id", "unknown")

        if not annotation.evidence:
            return "NOT ENOUGH INFO", []

        # page name -> ordered, de-duplicated evidence IDs on that page, so each
        # page is fetched and parsed once even if several sets cite it.
        # NOTE: the page name is the text before the first "_".
        evidence_by_page = {}
        for evidence_set in annotation.evidence:
            for ev_id in evidence_set:
                evidence_by_page.setdefault(ev_id.split("_")[0], {})[ev_id] = None

        # evidence ID -> extracted evidence dict
        extracted = {}
        for page_name, ev_ids in evidence_by_page.items():
            wiki_page = self.get_wiki_page(page_name)
            if wiki_page is None:
                print(f"Warning: Page '{page_name}' not found for claim {claim_id}")
                continue

            try:
                page_evidence = self.extract_page_evidence(wiki_page, list(ev_ids))
            except Exception as e:
                # The page-level accessors failed; this page's IDs fall through
                # to the "Evidence not found" placeholder below.
                print(f"Error processing evidence for claim {claim_id}: {str(e)}")
                continue

            for ev in page_evidence:
                extracted.setdefault(ev["id"], ev)

        organized_evidence = [
            [
                extracted.get(ev_id, {"id": ev_id, "content": "Evidence not found", "error": True})
                for ev_id in evidence_set
            ]
            for evidence_set in annotation.evidence
        ]

        return annotation.verdict, organized_evidence

    @staticmethod
    def _resolve_element(ev_id, sentences, tables, lists):
        """Return ``(content, is_error)`` for one evidence ID on an already-loaded page.

        Raises:
            ValueError: if the numeric indices in *ev_id* cannot be parsed.
        """
        if "_sentence_" in ev_id:
            sentence_idx = int(ev_id.split("_sentence_")[1])
            if sentence_idx >= len(sentences):
                return f"Error: Sentence index {sentence_idx} out of range", True
            return str(sentences[sentence_idx]), False

        if "_cell_" in ev_id:
            # Also matches "_header_cell_" IDs; their trailing
            # table/row/cell indices are parsed the same way.
            table_idx, row_idx, cell_idx = map(int, ev_id.split("_cell_")[1].split("_"))
            if table_idx >= len(tables):
                return f"Error: Table index {table_idx} out of range", True
            rows = tables[table_idx].get_rows()
            if row_idx >= len(rows):
                return f"Error: Row index {row_idx} out of range", True
            cells = rows[row_idx].get_row_cells()
            if cell_idx >= len(cells):
                return f"Error: Cell index {cell_idx} out of range", True
            cell = cells[cell_idx]
            # Prefer a get_text() accessor when the cell object provides one.
            return (cell.get_text() if hasattr(cell, 'get_text') else str(cell)), False

        if "_item_" in ev_id:
            list_idx, item_idx = map(int, ev_id.split("_item_")[1].split("_"))
            if list_idx >= len(lists):
                return f"Error: List index {list_idx} out of range", True
            # NOTE(review): only level-0 items are indexed; if item numbers in
            # IDs also count nested items, nested-list evidence resolves to the
            # wrong item — confirm against the FEVEROUS element-ID format.
            list_items = lists[list_idx].get_list_by_level(0)
            if item_idx >= len(list_items):
                return f"Error: Item index {item_idx} out of range", True
            return str(list_items[item_idx]), False

        return f"Unknown evidence type: {ev_id}", True

    def extract_page_evidence(self, wiki_page, evidence_ids):
        """Extract the text for each evidence ID found on a single page.

        Supports sentence, (header) table cell and list item IDs. Sentences,
        tables and lists are fetched from *wiki_page* once, up front.

        Args:
            wiki_page: Page object providing ``get_sentences``,
                ``get_tables`` and ``get_lists``.
            evidence_ids: Evidence IDs belonging to this page.

        Returns:
            One ``{"id", "content"}`` dict per input ID, in input order. Entries
            whose content could not be extracted (out-of-range index, unknown
            element type, parse error) also carry ``"error": True``.
        """
        sentences = wiki_page.get_sentences()
        tables = wiki_page.get_tables()
        lists = wiki_page.get_lists()

        result = []
        for ev_id in evidence_ids:
            try:
                content, is_error = self._resolve_element(ev_id, sentences, tables, lists)
            except Exception as e:
                print(f"Error processing evidence {ev_id}: {str(e)}")
                result.append({"id": ev_id, "content": f"Error: {str(e)}", "error": True})
                continue

            entry = {"id": ev_id, "content": content}
            if is_error:
                entry["error"] = True
            result.append(entry)

        return result

    def save_batch(self, claims, results, output_file):
        """Write one batch to *output_file*, retrying with sanitized claims on TypeError.

        If ``json.dump`` raises ``TypeError`` (a non-serializable value), every
        claim is passed through ``_make_serializable`` and the file is rewritten.
        """
        try:
            self._write_json(output_file, results, claims)
            print(f"Batch saved to {output_file}")
        except TypeError as e:
            print(f"Error during JSON serialization: {str(e)}")
            print("Attempting to fix non-serializable objects...")

            serializable_claims = [self._make_serializable(claim) for claim in claims]

            # Reopening with 'w' truncates whatever the failed dump partially wrote.
            self._write_json(output_file, results, serializable_claims)
            print(f"Batch saved to {output_file} after fixing serialization issues")

# Main Execution: process the FEVEROUS train and dev challenge files, then write
# a combined label-count summary to feverous_combined_summary.json.
if __name__ == "__main__":
    data_dir = os.path.expanduser("~/Documents/data")
    db_file = os.path.join(data_dir, "feverous_wikiv1.db")
    train_file = os.path.join(data_dir, "feverous_train_challenges.jsonl")
    dev_file = os.path.join(data_dir, "feverous_dev_challenges.jsonl")
    # Output names are defined once so the final log lines cannot drift from
    # the files process_claims actually writes.
    train_output = "train_all_processed.json"
    dev_output = "dev_all_processed.json"

    verifier = FeverousVerifier(db_path=db_file)

    # Process train data in batches of 10000 and create the combined train JSON.
    print("\n Processing TRAIN dataset")
    train_annotations = verifier.load_dataset(train_file)
    train_results = verifier.process_claims(
        train_annotations,
        batch_size=10000,
        output_prefix="train_batch",
        final_json_file=train_output,
    )

    # Process dev data in batches of 5000 and create the combined dev JSON.
    print("\n Processing DEV dataset")
    dev_annotations = verifier.load_dataset(dev_file)
    dev_results = verifier.process_claims(
        dev_annotations,
        batch_size=5000,
        output_prefix="dev_batch",
        final_json_file=dev_output,
    )

    # Create a combined summary of both train and dev.
    print("\n Creating combined summary of train and dev data")
    # Ordered union of labels (train first, then any dev-only ones). A set
    # union would make the key order of "total" depend on hash randomization.
    all_labels = dict.fromkeys([*train_results, *dev_results])
    combined_stats = {
        "train": train_results,
        "dev": dev_results,
        "total": {k: train_results.get(k, 0) + dev_results.get(k, 0) for k in all_labels},
    }

    with open("feverous_combined_summary.json", 'w') as f:
        json.dump({
            "statistics": combined_stats,
            "total_train_examples": len(train_annotations),
            "total_dev_examples": len(dev_annotations)
        }, f, indent=2)

    print("\n All processing complete!")
    print(f"Train statistics: {train_results}")
    print(f"Dev statistics: {dev_results}")
    print(f"Combined statistics: {combined_stats['total']}")
    print(f" Train data saved to: {train_output}")
    print(f" Dev data saved to: {dev_output}")