# `README.md`

# CompactPrompt: A Unified Pipeline for Prompt and Data Compression in LLM Workflows

<!-- PROJECT SHIELDS -->
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![Python Version](https://img.shields.io/badge/python-3.9%2B-blue.svg)](https://www.python.org/)
[![arXiv](https://img.shields.io/badge/arXiv-2510.18043v1-b31b1b.svg)](https://arxiv.org/abs/2510.18043v1)
[![Journal](https://img.shields.io/badge/Journal-ACM%20ICAIF%202025-003366)](https://arxiv.org/abs/2510.18043v1)
[![Year](https://img.shields.io/badge/Year-2025-purple)](https://github.com/chirindaopensource/compact_prompt_unified_pipeline_prompt_data_compression_LLM_workflows)
[![Discipline](https://img.shields.io/badge/Discipline-Computer%20Science%20%7C%20AI%20for%20Finance-00529B)](https://github.com/chirindaopensource/compact_prompt_unified_pipeline_prompt_data_compression_LLM_workflows)
[![Data Sources](https://img.shields.io/badge/Data-TAT--QA-lightgrey)](https://github.com/NExTplusplus/TAT-QA)
[![Data Sources](https://img.shields.io/badge/Data-FinQA-lightgrey)](https://github.com/czyssrs/FinQA)
[![Data Sources](https://img.shields.io/badge/Data-Wikipedia%20Dump-lightgrey)](https://dumps.wikimedia.org/)
[![Data Sources](https://img.shields.io/badge/Data-ShareGPT-lightgrey)](https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered)
[![Data Sources](https://img.shields.io/badge/Data-arXiv-lightgrey)](https://arxiv.org/)
[![Core Method](https://img.shields.io/badge/Method-Hard%20Prompt%20Pruning%20%7C%20N--gram%20Abbreviation%20%7C%20Quantization-orange)](https://github.com/chirindaopensource/compact_prompt_unified_pipeline_prompt_data_compression_LLM_workflows)
[![Analysis](https://img.shields.io/badge/Analysis-Cost--Performance%20Trade--offs%20%7C%20Semantic%20Fidelity-red)](https://github.com/chirindaopensource/compact_prompt_unified_pipeline_prompt_data_compression_LLM_workflows)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Type Checking: mypy](https://img.shields.io/badge/type%20checking-mypy-blue)](http://mypy-lang.org/)
[![NumPy](https://img.shields.io/badge/numpy-%23013243.svg?style=flat&logo=numpy&logoColor=white)](https://numpy.org/)
[![Pandas](https://img.shields.io/badge/pandas-%23150458.svg?style=flat&logo=pandas&logoColor=white)](https://pandas.pydata.org/)
[![Spacy](https://img.shields.io/badge/spaCy-%2309A3D5.svg?style=flat&logo=spaCy&logoColor=white)](https://spacy.io/)
[![PyYAML](https://img.shields.io/badge/PyYAML-gray?logo=yaml&logoColor=white)](https://pyyaml.org/)
[![Jupyter](https://img.shields.io/badge/Jupyter-%23F37626.svg?style=flat&logo=Jupyter&logoColor=white)](https://jupyter.org/)

**Repository:** `https://github.com/chirindaopensource/compact_prompt_unified_pipeline_prompt_data_compression_LLM_workflows`

**Owner:** 2025 Craig Chirinda (Open Source Projects)

This repository contains an **independent**, professional-grade Python implementation of the research methodology from the 2025 paper entitled **"CompactPrompt: A Unified Pipeline for Prompt and Data Compression in LLM Workflows"** by:

*   Joong Ho Choi
*   Jiayang Zhao
*   Jeel Shah
*   Ritvika Sonawane
*   Vedant Singh
*   Avani Appalla
*   Will Flanagan
*   Filipe Condessa

The project provides a complete, end-to-end computational framework for replicating the paper's findings. It delivers a modular, auditable, and extensible pipeline that executes the entire research workflow: from rigorous data validation and offline corpus statistics generation to the core hard prompt pruning, n-gram abbreviation, numeric quantization, and comprehensive semantic fidelity evaluation.

## Table of Contents

- [Introduction](#introduction)
- [Theoretical Background](#theoretical-background)
- [Features](#features)
- [Methodology Implemented](#methodology-implemented)
- [Core Components (Notebook Structure)](#core-components-notebook-structure)
- [Key Callable: `run_compactprompt_pipeline`](#key-callable-run_compactprompt_pipeline)
- [Prerequisites](#prerequisites)
- [Installation](#installation)
- [Input Data Structure](#input-data-structure)
- [Usage](#usage)
- [Output Structure](#output-structure)
- [Project Structure](#project-structure)
- [Customization](#customization)
- [Contributing](#contributing)
- [Recommended Extensions](#recommended-extensions)
- [License](#license)
- [Citation](#citation)
- [Acknowledgments](#acknowledgments)

## Introduction

This project provides a Python implementation of the analytical framework presented in Choi et al. (2025). The core of this repository is the Jupyter notebook `compact_prompt_unified_pipeline_prompt_data_compression_LLM_workflows_draft.ipynb`, which contains a comprehensive suite of functions to replicate the paper's findings. The pipeline is designed to be a generalizable toolkit for optimizing Large Language Model (LLM) inference costs and latency in enterprise environments.

The paper addresses the challenge of processing long, data-rich contexts—combining free-form instructions, large documents, and numeric tables—under strict cost and context-window constraints. This codebase operationalizes the paper's framework, allowing users to:
-   Rigorously validate and manage the entire experimental configuration via a single `config.yaml` file.
-   Construct a large offline corpus to compute static self-information scores for tokens.
-   Implement **Hard Prompt Compression** by pruning low-information phrases identified via a hybrid static/dynamic scoring mechanism.
-   Apply **Textual N-gram Abbreviation** to losslessly compress repetitive patterns in attached documents.
-   Execute **Numerical Quantization** (Uniform and K-Means) to reduce the token footprint of tabular data while bounding approximation error.
-   Select representative few-shot exemplars via embedding-based clustering to maximize prompt utility.
-   Evaluate semantic fidelity using both embedding cosine similarity and a simulated human evaluation protocol.
-   Automatically generate performance metrics and cost-savings analysis.

## Theoretical Background

The implemented methods are grounded in information theory, natural language processing, and data compression principles.

**1. Hybrid Self-Information Scoring:**
The core of the prompt pruning strategy relies on identifying the information content of each token.
-   **Static Self-Information ($I_{\text{stat}}$):** Derived from a large offline corpus (Wikipedia, ShareGPT, arXiv), capturing global token rarity: $I_{\text{stat}}(t) = -\log_2 p(t)$.
-   **Dynamic Self-Information ($I_{\text{dyn}}$):** Derived from a scorer LLM's conditional probability, capturing context-specific surprise: $I_{\text{dyn}}(t \mid c) = -\log_2 P_{\text{model}}(t \mid c)$.
-   **Fusion Rule:** A combined score $C(t)$ averages the two metrics when they agree and uses the dynamic score alone when their relative difference exceeds 10% ($\Delta > 0.1$), ensuring context-critical tokens are preserved.

**2. Dependency-Driven Pruning:**
To maintain grammatical coherence, tokens are grouped into syntactic phrases (NP, VP, PP) using dependency parsing. Pruning decisions are made at the phrase level based on aggregated importance scores.
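The phrase-level decision can be sketched in plain Python. This is a minimal illustration, not the notebook's implementation. It makes three assumptions:

-   A dependency parser (spaCy in this project) has already grouped tokens into phrases with per-token combined scores.
-   The pruning rule is hypothetical: whole phrases are dropped in ascending order of mean score until a token budget (compare `prompt_token_budget` in the configuration) is met.
-   Token counts are word counts, which simplifies real tokenizer counts.

```python
from typing import List, Tuple

# A phrase is a list of (token_text, combined_score) pairs, e.g. one NP, VP or PP.
Phrase = List[Tuple[str, float]]


def prune_phrases(phrases: List[Phrase], token_budget: int) -> str:
    """Drop whole phrases, least informative first, until the prompt fits the budget.

    Hypothetical helper: aggregating a phrase by its mean score is an assumption.
    """
    total = sum(len(p) for p in phrases)

    def mean_score(i: int) -> float:
        return sum(s for _, s in phrases[i]) / max(len(phrases[i]), 1)

    dropped = set()
    for i in sorted(range(len(phrases)), key=mean_score):
        if total <= token_budget:
            break
        dropped.add(i)
        total -= len(phrases[i])
    # Surviving phrases keep their original order, so the text stays readable.
    return " ".join(tok for i, p in enumerate(phrases) if i not in dropped for tok, _ in p)


phrases = [
    [("Please", 0.5), ("note", 0.7)],  # filler phrase, low information
    [("revenue", 6.0), ("grew", 4.0)],
    [("in", 1.0), ("2019", 8.0)],
]
print(prune_phrases(phrases, token_budget=4))  # revenue grew in 2019
```

Dropping entire phrases rather than individual tokens is what keeps the pruned prompt grammatical. Here "in 2019" survives as a unit even though "in" alone scores low.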

**3. Data Compression:**
-   **N-gram Abbreviation:** Frequent multi-token patterns in documents are replaced with short, unique placeholders, leveraging the heavy-tailed distribution of language in specialized domains.
-   **Quantization:** Floating-point numbers in tables are mapped to integer codes ($q_i$) or cluster centroids, reducing token count while maintaining numerical relationships within a bounded error $\varepsilon_{\max}$.


## Features

The provided Jupyter notebook (`compact_prompt_unified_pipeline_prompt_data_compression_LLM_workflows_draft.ipynb`) implements the full research pipeline, including:

-   **Modular, Multi-Task Architecture:** The entire pipeline is broken down into 35 distinct, modular tasks, each with its own orchestrator function.
-   **Configuration-Driven Design:** All study parameters are managed in an external `config.yaml` file.
-   **Rigorous Data Validation:** A multi-stage validation process checks the schema, content integrity, and structural consistency of TAT-QA and Fin-QA datasets.
-   **Advanced NLP Processing:** Integrates `spaCy` for dependency parsing and `tiktoken`/`transformers` for precise token alignment.
-   **Robust LLM Integration:** A unified interface for interacting with OpenAI, Anthropic, and Together AI models, supporting both generation and log-probability extraction.
-   **Comprehensive Evaluation:** Includes automated semantic similarity checks using `all-mpnet-base-v2` and a randomized, bias-mitigated protocol for LLM-based human proxy evaluation.
-   **Reproducible Artifacts:** Generates structured logs, compressed datasets, and detailed metric reports for every run.

## Methodology Implemented

The core analytical steps directly implement the methodology from the paper:

1.  **Validation & Preprocessing (Tasks 1-6):** Ingests raw data, validates schemas, cleanses malformed entries, and normalizes numeric columns.
2.  **Corpus Statistics (Tasks 7-10):** Builds an offline corpus and computes static self-information scores for the vocabulary.
3.  **Prompt Engineering (Tasks 11-13):** Serializes tables to Markdown and constructs prompt templates with few-shot exemplars.
4.  **Dynamic Scoring (Tasks 14-18):** Configures LLM resources, retrieves log-probabilities, and computes combined importance scores.
5.  **Compression Engines (Tasks 19-27):** Executes phrase-level pruning, n-gram abbreviation, and numeric quantization.
6.  **Exemplar Selection (Tasks 28-30):** Embeds candidate examples and selects representative prototypes via K-Means clustering and silhouette optimization.
7.  **Evaluation (Tasks 31-35):** Computes embedding similarities and conducts a rigorous human-proxy evaluation to assess semantic fidelity.

## Core Components (Notebook Structure)

The `compact_prompt_unified_pipeline_prompt_data_compression_LLM_workflows_draft.ipynb` notebook is structured as a logical pipeline with modular orchestrator functions for each of the 35 major tasks. All functions are self-contained, fully documented with type hints and docstrings, and designed for professional-grade execution.

## Key Callable: `run_compactprompt_pipeline`

The project is designed around a single, top-level user-facing interface function:

-   **`run_compactprompt_pipeline`:** This master orchestrator function, located in the final section of the notebook, runs the entire automated research pipeline from end-to-end. A single call to this function reproduces the entire computational portion of the project, managing data flow between all 35 sub-tasks.

## Prerequisites

-   Python 3.9+
-   Core dependencies: `pandas`, `numpy`, `pyyaml`, `scipy`, `spacy`, `tiktoken`, `transformers`, `sentence-transformers`, `scikit-learn`.
-   LLM Provider SDKs: `openai`, `anthropic`, `together`.

## Installation

1.  **Clone the repository:**
    ```sh
    git clone https://github.com/chirindaopensource/compact_prompt_unified_pipeline_prompt_data_compression_LLM_workflows.git
    cd compact_prompt_unified_pipeline_prompt_data_compression_LLM_workflows
    ```

2.  **Create and activate a virtual environment (recommended):**
    ```sh
    python -m venv venv
    source venv/bin/activate  # On Windows, use `venv\Scripts\activate`
    ```

3.  **Install Python dependencies:**
    ```sh
    pip install pandas numpy pyyaml scipy spacy tiktoken transformers sentence-transformers scikit-learn openai anthropic together
    ```

4.  **Download spaCy model:**
    ```sh
    python -m spacy download en_core_web_sm
    ```

## Input Data Structure

The pipeline requires two primary DataFrames (`tatqa_raw_df` and `finqa_raw_df`) containing QA examples. Each row must adhere to the following schema:
1.  **`example_id`**: Unique string identifier.
2.  **`split`**: Dataset partition ("train", "dev", "test").
3.  **`question_text`**: The natural language question.
4.  **`tables`**: List of table dictionaries (with `headers` and `rows`).
5.  **`passages`**: List of passage dictionaries (with `text`).
6.  **`answer_type`**, **`answer_value`**, **`answer_unit`**: Ground truth answer fields.

Additionally, an `offline_corpus.jsonl` file is required for static statistics, containing documents with `doc_id`, `source`, and `text`.
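For orientation, a single row in this schema might look as follows. The values are invented for illustration and are not taken from TAT-QA. The exact types the notebook expects for `answer_value` and `answer_unit` are assumptions.

```python
import json

import pandas as pd

# Hypothetical example row following the documented column schema.
example_row = {
    "example_id": "tatqa-demo-0001",
    "split": "train",
    "question_text": "What was the change in revenue from 2018 to 2019?",
    "tables": [{"headers": ["Item", "2019", "2018"],
                "rows": [["Revenue", "1,200", "1,000"]]}],
    "passages": [{"text": "Revenue increased due to higher sales volumes."}],
    "answer_type": "arithmetic",
    "answer_value": "200",
    "answer_unit": "",
}
tatqa_raw_df = pd.DataFrame([example_row])  # finqa_raw_df uses the same columns

# offline_corpus.jsonl holds one JSON object per line with these three keys.
corpus_line = json.dumps({"doc_id": "wiki-000001", "source": "wikipedia",
                          "text": "Revenue is the income a company earns from sales."})
```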

## Usage

The `compact_prompt_unified_pipeline_prompt_data_compression_LLM_workflows_draft.ipynb` notebook provides a complete, step-by-step guide. The primary workflow is to execute the final cell of the notebook, which demonstrates how to use the top-level `run_compactprompt_pipeline` orchestrator:

```python
# Final cell of the notebook
import yaml  # PyYAML provides yaml.safe_load, used below

# This block serves as the main entry point for the entire project.
if __name__ == '__main__':
    # 1. Load the master configuration from the YAML file.
    with open('config.yaml', 'r') as f:
        study_config = yaml.safe_load(f)
    
    # 2. Load raw datasets (Example using synthetic generator provided in the notebook)
    # In production, load from JSON/Parquet: pd.read_json(...)
    tatqa_raw_df = ...
    finqa_raw_df = ...
    
    # 3. Execute the entire replication study.
    output, log = run_compactprompt_pipeline(
        tatqa_raw_df=tatqa_raw_df,
        finqa_raw_df=finqa_raw_df,
        study_config=study_config,
        condition="compressed_plus_data",
        target_llm="gpt-4o",
        scorer_llm="gpt-4.1-mini"
    )
    
    # 4. Access results
    print(f"Mean Compression Ratio: {output.metrics['mean_compression_ratio']:.2f}x")
```

## Output Structure

The pipeline returns an `OrchestratorOutput` object containing all analytical artifacts:
-   **`tatqa_processed_df` / `finqa_processed_df`**: DataFrames with compressed text and quantized tables.
-   **`pruned_prompts`**: Dictionary mapping example IDs to their hard-pruned prompt text.
-   **`abbreviation_dict`**: The reversible n-gram dictionary.
-   **`quantization_results`**: Metadata and codes for all quantized columns.
-   **`representative_exemplars`**: IDs of selected few-shot prototypes.
-   **`similarity_results`**: Semantic fidelity statistics (cosine similarity).
-   **`human_evaluation`**: Results from the LLM-proxy annotation protocol.
-   **`metrics`**: Aggregate performance indicators (compression ratio, accuracy).

## Project Structure

```
compact_prompt_unified_pipeline/
│
├── compact_prompt_unified_pipeline_prompt_data_compression_LLM_workflows_draft.ipynb  # Main implementation notebook
├── config.yaml                                                                        # Master configuration file
├── requirements.txt                                                                   # Python package dependencies
│
├── compact_prompt_outputs/                                                            # Output directory (generated)
│   ├── corpus_stats/
│   ├── embeddings/
│   ├── human_eval_results/
│   └── processed_data/
│
├── LICENSE                                                                            # MIT Project License File
└── README.md                                                                          # This file
```

## Customization

The pipeline is highly customizable via the `config.yaml` file. Users can modify study parameters such as:
-   **Compression Budgets:** `prompt_token_budget`.
-   **N-gram Settings:** `ngram_length`, `top_n_T`.
-   **Quantization:** `bit_width_b`, `num_clusters_k`.
-   **Scoring Thresholds:** `delta_relative_difference_threshold`.
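A hypothetical `config.yaml` excerpt touching these keys is shown below. Only the key names come from this README. The flat layout is an assumption; the real file may nest these keys under sections. The $\Delta = 0.1$ threshold and the bigram/Top-3 setting follow the values quoted in the Summary. The remaining numbers are placeholders to tune.

```yaml
# Hypothetical excerpt: key names from this README; layout and most values illustrative.
prompt_token_budget: 512                   # illustrative budget for the pruned prompt
ngram_length: 2                            # bigrams, the reported best setting
top_n_T: 3                                 # Top-3 n-grams, the reported best setting
bit_width_b: 8                             # illustrative: 2**8 = 256 quantization levels
num_clusters_k: 16                         # illustrative k for K-Means quantization
delta_relative_difference_threshold: 0.1   # Delta threshold in the fusion rule
```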

## Contributing

Contributions are welcome. Please fork the repository, create a feature branch, and submit a pull request with a clear description of your changes. Adherence to PEP 8, type hinting, and comprehensive docstrings is required.

## Recommended Extensions

Future extensions could include:
-   **Adaptive Compression:** Dynamically adjusting the token budget based on query complexity.
-   **Privacy-Aware Pruning:** Integrating PII detection to prioritize removing sensitive entities.
-   **Multimodal Support:** Extending the pipeline to compress image or chart data attachments.

## License

This project is licensed under the MIT License. See the `LICENSE` file for details.

## Citation

If you use this code or the methodology in your research, please cite the original paper:

```bibtex
@article{choi2025compactprompt,
  title={CompactPrompt: A Unified Pipeline for Prompt and Data Compression in LLM Workflows},
  author={Choi, Joong Ho and Zhao, Jiayang and Shah, Jeel and Sonawane, Ritvika and Singh, Vedant and Appalla, Avani and Flanagan, Will and Condessa, Filipe},
  journal={arXiv preprint arXiv:2510.18043v1},
  year={2025}
}
```

For the implementation itself, you may cite this repository:
```
Chirinda, C. (2025). CompactPrompt: An Open Source Implementation.
GitHub repository: https://github.com/chirindaopensource/compact_prompt_unified_pipeline_prompt_data_compression_LLM_workflows
```

## Acknowledgments

-   Credit to **Joong Ho Choi et al.** for the foundational research that forms the entire basis for this computational replication.
-   This project is built upon the exceptional tools provided by the open-source community. Sincere thanks to the developers of the scientific Python ecosystem, including **Pandas, NumPy, SciPy, spaCy, and Hugging Face**.

---

*This README was generated based on the structure and content of the `compact_prompt_unified_pipeline_prompt_data_compression_LLM_workflows_draft.ipynb` notebook and follows best practices for research software documentation.*


# Paper

Title: "*CompactPrompt: A Unified Pipeline for Prompt and Data Compression in LLM Workflows*"

Authors: Joong Ho Choi, Jiayang Zhao, Jeel Shah, Ritvika Sonawane, Vedant Singh, Avani Appalla, Will Flanagan, Filipe Condessa

E-Journal Submission Date: 20 October 2025

Conference Affiliation: Workshop on LLMs and Generative AI for Finance at ACM ICAIF 2025

Link: https://arxiv.org/abs/2510.18043v1

Abstract:

Large Language Models (LLMs) deliver powerful reasoning and generation capabilities but incur substantial run-time costs when operating in agentic workflows that chain together lengthy prompts and process rich data streams. We introduce CompactPrompt, an end-to-end pipeline that merges hard prompt compression with lightweight file-level data compression. CompactPrompt first prunes low-information tokens from prompts using self-information scoring and dependency-based phrase grouping. In parallel, it applies n-gram abbreviation to recurrent textual patterns in attached documents and uniform quantization to numerical columns, yielding compact yet semantically faithful representations. Integrated into standard LLM agents, CompactPrompt reduces total token usage and inference cost by up to 60% on benchmark datasets like TAT-QA and FinQA, while preserving output quality (resulting in less than a 5% accuracy drop for Claude-3.5-Sonnet and GPT-4.1-Mini). CompactPrompt helps visualize real-time compression decisions and quantify cost-performance trade-offs, laying the groundwork for leaner generative AI pipelines.

# Summary

### Architectural Overview
CompactPrompt is an end-to-end, training-free pipeline designed to compress prompt contexts (instructions + data attachments) for LLM inference. It operates along three complementary axes simultaneously:
1.  **Hard Prompt Pruning:** Removal of low-information tokens from natural language instructions.
2.  **Textual N-gram Abbreviation:** Reversible dictionary encoding for repetitive patterns in attached documents.
3.  **Numerical Quantization:** Bit-width reduction for floating-point data in structured tables.

The pipeline is implemented as a pre-processing layer before the LLM API call, aiming to reduce token usage by up to 60% while maintaining or improving reasoning accuracy on financial QA tasks (TAT-QA, FinQA).

### Hard Prompt Compression (Token Pruning)
This module identifies and removes tokens that contribute minimal information to the semantic understanding of the prompt. It utilizes a hybrid scoring mechanism combining static and dynamic metrics.

**A. Static Self-Information ($I_{\text{stat}}$)**
Calculated offline using a large corpus (Wikipedia, ShareGPT, arXiv). For a token $t$, the unigram probability $p(t)$ is:
$$p(t) = \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}\{w_i = t\}$$
The static self-information is:
$$I_{\text{stat}}(t) = -\log_2 p(t)$$
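Both equations can be computed directly from token counts. The sketch below is illustrative only, with two caveats:

-   It uses whitespace tokens, whereas the pipeline counts subword tokens from a tokenizer such as `cl100k_base`.
-   It leaves open how tokens unseen in the corpus are scored. The README does not say, so some smoothing or fallback value would be needed.

```python
import math
from collections import Counter
from typing import Dict, Iterable


def static_self_information(tokens: Iterable[str]) -> Dict[str, float]:
    """Return I_stat(t) = -log2 p(t), where p(t) = count(t) / N."""
    counts = Counter(tokens)
    n = sum(counts.values())  # N, the total number of corpus tokens
    return {tok: -math.log2(c / n) for tok, c in counts.items()}


scores = static_self_information("the the cat dog".split())
print(scores)  # {'the': 1.0, 'cat': 2.0, 'dog': 2.0}
```

Frequent tokens such as "the" carry fewer bits and are therefore better pruning candidates than rarer tokens.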

**B. Dynamic Self-Information ($I_{\text{dyn}}$)**
Calculated at runtime using a lightweight scorer model. Given a token $t$ and its preceding context $c$:
$$I_{\text{dyn}}(t \mid c) = -\log_2 P_{\text{model}}(t \mid c)$$

**C. Hybrid Scoring Strategy**
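The system fuses these scores based on their relative divergence $\Delta$. Write $s_{\text{stat}} = I_{\text{stat}}(t)$ and $s_{\text{dyn}} = I_{\text{dyn}}(t \mid c)$. If the two scores are within 10% of each other, their average is used; otherwise, the context-aware dynamic score takes precedence.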
$$\Delta = \frac{|s_{\text{dyn}} - s_{\text{stat}}|}{s_{\text{stat}}}$$
$$C(s_{\text{stat}}, s_{\text{dyn}}) = \begin{cases} \frac{s_{\text{stat}} + s_{\text{dyn}}}{2}, & \Delta \leq 0.1 \\ s_{\text{dyn}}, & \Delta > 0.1 \end{cases}$$
Tokens falling below a specific information threshold are pruned, often using dependency parsing to ensure phrase-level coherence rather than isolated token removal.
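The case rule translates almost line for line into Python. This is a sketch with two assumptions:

-   The behaviour when $s_{\text{stat}} = 0$, where $\Delta$ is undefined.
-   The pruning threshold `tau` in the usage lines, which is a hypothetical value.

```python
def combined_score(s_stat: float, s_dyn: float, delta_threshold: float = 0.1) -> float:
    """Fuse static and dynamic self-information (bits) with the relative-divergence rule."""
    if s_stat <= 0.0:
        # Delta divides by s_stat; fall back to the dynamic score (assumption).
        return s_dyn
    delta = abs(s_dyn - s_stat) / s_stat
    if delta <= delta_threshold:
        return (s_stat + s_dyn) / 2.0  # scores agree: average them
    return s_dyn  # scores diverge: trust the context-aware score


print(combined_score(10.0, 10.5))  # Delta = 0.05 -> average, 10.25
print(combined_score(10.0, 14.0))  # Delta = 0.40 -> dynamic score, 14.0

# Hypothetical usage: keep tokens whose fused score clears a threshold tau.
tau = 5.0
tokens = [("the", 1.2, 1.0), ("EBITDA", 14.0, 9.0)]  # (token, s_stat, s_dyn)
kept = [t for t, s, d in tokens if combined_score(s, d) >= tau]
print(kept)  # ['EBITDA']
```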

### Textual N-gram Abbreviation
This module targets repetitive multi-word expressions in attached documents (e.g., "interest expense", "per share").

**A. Extraction & Frequency Analysis**
*   Extract all n-grams of length $n$ (typically $n \in [2, 5]$).
*   Compute frequency distributions across the document corpus.
*   Select the Top-$K$ most frequent n-grams (typically $K \in [100, 150]$).

**B. Dictionary Construction & Replacement**
*   Map each Top-$K$ n-gram to a unique, short placeholder token (e.g., "ABC1").
*   Store the mapping $\{ \text{placeholder} \leftrightarrow \text{original n-gram} \}$ in a metadata table.
*   **Reversibility:** The compression is lossless regarding the specific n-grams targeted; the LLM sees the placeholder, and the mapping can be injected or used for reconstruction.

**C. Optimal Configuration**
Empirical results indicate that **Bi-grams ($n=2$)** with a **Top-3 ($K=3$)** frequency threshold yield the optimal trade-off between token reduction and semantic disruption ($\Delta \text{Accuracy} \approx +5.0\%$).
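A minimal sketch of the dictionary construction and its inverse follows, defaulting to the bigram/Top-3 setting above. It makes three assumptions:

-   N-grams are counted over `\w+` word tokens rather than tokenizer tokens.
-   Only n-grams seen more than once are abbreviated.
-   The `ABC<i>` placeholder format mirrors the example above.

Reversibility holds only if no placeholder string already occurs in the source text.

```python
import re
from collections import Counter
from typing import Dict, List


def build_abbreviations(docs: List[str], n: int = 2, top_k: int = 3,
                        prefix: str = "ABC") -> Dict[str, str]:
    """Map the top_k most frequent word n-grams to placeholders ABC1, ABC2, ..."""
    counts: Counter = Counter()
    for doc in docs:
        words = re.findall(r"\w+", doc)
        counts.update(" ".join(words[i:i + n]) for i in range(len(words) - n + 1))
    top = [gram for gram, c in counts.most_common(top_k) if c > 1]  # repeats only
    return {f"{prefix}{i + 1}": gram for i, gram in enumerate(top)}


def _swap(text: str, old: str, new: str) -> str:
    # Word boundaries stop 'per share' matching inside 'super shares'
    # and stop placeholder 'ABC1' matching inside 'ABC10'.
    return re.sub(r"\b" + re.escape(old) + r"\b", lambda _m: new, text)


def abbreviate(text: str, table: Dict[str, str]) -> str:
    for placeholder, gram in table.items():
        text = _swap(text, gram, placeholder)
    return text


def expand(text: str, table: Dict[str, str]) -> str:
    for placeholder, gram in table.items():
        text = _swap(text, placeholder, gram)
    return text


docs = ["interest expense rose while interest expense per share fell",
        "the interest expense per share was flat"]
table = build_abbreviations(docs)
print(table)  # {'ABC1': 'interest expense', 'ABC2': 'expense per', 'ABC3': 'per share'}
print(abbreviate(docs[1], table))  # the ABC1 ABC3 was flat
```

Overlapping n-grams such as "expense per" may never fire once a longer-ranked n-gram has consumed their words. This costs some compression but does not break the round trip.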

### Numerical Quantization
This module compresses numerical columns in structured data attachments (e.g., CSVs, financial tables) by reducing floating-point precision.

**A. Uniform Integer Quantization**
Given a column vector $\mathbf{x}$, compute $\min_x$ and $\max_x$. Select a bit-width $b$ (levels $L=2^b$).
Encode value $x_i$ to integer $q_i$:
$$q_i = \text{round}\left( \frac{x_i - \min_x}{\max_x - \min_x} (L - 1) \right)$$
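Decode $q_i$ back to $\hat{x}_i$. The absolute reconstruction error $|x_i - \hat{x}_i|$ is bounded by the quantization step $\varepsilon_{\max}$ (with round-to-nearest encoding it is in fact at most $\varepsilon_{\max}/2$):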
$$\hat{x}_i = \min_x + \frac{q_i}{L-1}(\max_x - \min_x)$$
$$\varepsilon_{\max} = \frac{\max_x - \min_x}{L-1}$$
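The encode/decode pair and its error bound can be checked numerically. This is a NumPy sketch with two assumptions:

-   `np.round` rounds halves to even, which is one reasonable reading of $\text{round}(\cdot)$.
-   Mapping a constant column to code 0 avoids dividing by $\max_x - \min_x = 0$.

```python
import numpy as np


def uniform_quantize(x, b: int):
    """Encode a 1-D column into integer codes in [0, 2**b - 1]; returns (q, lo, hi)."""
    x = np.asarray(x, dtype=float)
    lo, hi = float(x.min()), float(x.max())
    if hi == lo:  # constant column: nothing to spread over the levels
        return np.zeros(x.shape, dtype=int), lo, hi
    levels = 2 ** b
    q = np.round((x - lo) / (hi - lo) * (levels - 1)).astype(int)
    return q, lo, hi


def uniform_dequantize(q, lo: float, hi: float, b: int) -> np.ndarray:
    """Reconstruct x_hat = lo + q / (L - 1) * (hi - lo)."""
    q = np.asarray(q, dtype=float)
    if hi == lo:
        return np.full(q.shape, lo)
    return lo + q / (2 ** b - 1) * (hi - lo)


x = np.array([0.0, 1.0, 2.5, 10.0])
q, lo, hi = uniform_quantize(x, b=2)  # L = 4 levels, step = 10 / 3
x_hat = uniform_dequantize(q, lo, hi, b=2)
print(q.tolist())               # [0, 0, 1, 3]
print(np.abs(x - x_hat).max())  # 1.0, below step / 2 = 1.667
```

The prompt then carries the small integers in `q` plus the pair $(\min_x, \max_x)$ instead of full-precision floats.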

**B. K-Means Quantization**
Alternatively, apply k-means clustering to column values to find $k$ centroids $(\mu_1, \dots, \mu_k)$. Map each $x_i$ to the index of the nearest centroid.
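For one-dimensional columns, k-means reduces to a short Lloyd iteration. To stay dependency-light, the sketch uses plain NumPy in place of scikit-learn's `KMeans`, which the notebook imports. The quantile-based initialisation and the iteration cap are assumptions.

```python
import numpy as np


def kmeans_quantize_1d(x, k: int, n_iter: int = 50):
    """Return (codes, centroids) where codes[i] indexes the centroid nearest x[i]."""
    x = np.asarray(x, dtype=float)
    # Deterministic start: centroids at evenly spaced quantiles (an assumption).
    centroids = np.quantile(x, np.linspace(0.0, 1.0, k))
    for _ in range(n_iter):
        codes = np.argmin(np.abs(x[:, None] - centroids[None, :]), axis=1)
        # Move each centroid to the mean of its members; keep empty ones in place.
        updated = np.array([x[codes == j].mean() if np.any(codes == j) else centroids[j]
                            for j in range(k)])
        if np.allclose(updated, centroids):
            break
        centroids = updated
    codes = np.argmin(np.abs(x[:, None] - centroids[None, :]), axis=1)
    return codes, centroids


x = [1.0, 1.1, 0.9, 10.0, 10.2, 9.8]
codes, centroids = kmeans_quantize_1d(x, k=2)
print(codes.tolist())                   # [0, 0, 0, 1, 1, 1]
print(np.round(centroids, 3).tolist())  # [1.0, 10.0]
```

Unlike uniform quantization, the centroids follow the data's clusters, so skewed financial columns lose less precision where values are dense.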

### Representative Exemplar Selection (Few-Shot)
To maximize the utility of few-shot examples within the compressed context window:
1.  **Embedding:** Encode candidate examples using `all-mpnet-base-v2`.
2.  **Clustering:** Perform k-means clustering ($k \in [5, 50]$).
3.  **Optimization:** Select optimal cluster count $k^*$ via Silhouette Score maximization.
4.  **Selection:** Choose the example closest to the centroid of each cluster as a "prototype."
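The four steps map onto scikit-learn calls that the notebook already imports. In this sketch the embeddings are passed in directly as a NumPy array. In the pipeline they would come from `SentenceTransformer("all-mpnet-base-v2").encode(texts)`.

The seed, `n_init=10`, and the small-data guard are assumptions. `silhouette_score` requires $2 \le k \le n - 1$, so the search range is clipped accordingly.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min, silhouette_score


def select_prototypes(embeddings: np.ndarray, k_min: int = 5, k_max: int = 50,
                      seed: int = 0) -> list:
    """Choose k* by maximal silhouette score, then return the candidate nearest each centroid."""
    n = len(embeddings)
    best_score, best_model = float("-inf"), None
    # silhouette_score is only defined for 2 <= k <= n - 1 clusters.
    for k in range(max(2, k_min), min(k_max, n - 1) + 1):
        model = KMeans(n_clusters=k, n_init=10, random_state=seed).fit(embeddings)
        score = silhouette_score(embeddings, model.labels_)
        if score > best_score:
            best_score, best_model = score, model
    if best_model is None:
        raise ValueError("too few candidates for the requested k range")
    # For each centroid, the index of the closest candidate embedding is its prototype.
    nearest, _ = pairwise_distances_argmin_min(best_model.cluster_centers_, embeddings)
    return sorted(set(nearest.tolist()))


# Toy 2-D "embeddings": three tight groups of four points each.
offsets = [(0.0, 0.1), (0.1, 0.0), (-0.1, 0.0), (0.0, -0.1)]
centers = [(0.0, 0.0), (10.0, 0.0), (0.0, 10.0)]
emb = np.array([(cx + dx, cy + dy) for cx, cy in centers for dx, dy in offsets])
print(select_prototypes(emb, k_min=2, k_max=5))  # three indices, one per group
```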

### Evaluation Metrics
*   **Compression Ratio:** $\frac{\text{Original Token Count}}{\text{Compressed Token Count}}$.
*   **Semantic Fidelity:** Cosine similarity between embeddings of original and compressed prompts ($\text{cosine}(E_{\text{orig}}, E_{\text{comp}})$). High fidelity is observed at similarity $\geq 0.92$.
*   **Downstream Accuracy:** Exact Match (EM) or F1 score on QA tasks. Results show up to **+6% accuracy** on TAT-QA and **+10%** on FinQA using Claude-3.5-Sonnet, despite >50% token reduction.
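Both headline metrics are one-liners. The sketch is illustrative, with two assumptions:

-   The embedding vectors would in practice come from `all-mpnet-base-v2`.
-   The 0.92 figure is read as a pass/fail cut-off.

As a worked example, a 60% token reduction corresponds to a compression ratio of $1 / (1 - 0.6) = 2.5$.

```python
import numpy as np


def compression_ratio(original_tokens: int, compressed_tokens: int) -> float:
    """Original token count divided by compressed token count."""
    return original_tokens / compressed_tokens


def cosine_similarity(u, v) -> float:
    """cos(E_orig, E_comp) = (u . v) / (||u|| ||v||)."""
    u, v = np.asarray(u, dtype=float), np.asarray(v, dtype=float)
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))


print(compression_ratio(1000, 400))  # 2.5, i.e. a 60% token reduction
e_orig, e_comp = [0.6, 0.8, 0.0], [0.6, 0.7, 0.1]  # toy vectors, not real embeddings
sim = cosine_similarity(e_orig, e_comp)
print(round(sim, 3), sim >= 0.92)
```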

# Import Essential Modules

```python
#!/usr/bin/env python3
# ==============================================================================#
#
#  CompactPrompt: A Unified Pipeline for Prompt and Data Compression in LLM Workflows
#
#  This module provides a complete, production-grade implementation of the
#  analytical framework presented in "CompactPrompt: A Unified Pipeline for Prompt
#  and Data Compression in LLM Workflows" by Choi et al. (2025). It delivers a
#  computationally tractable system for optimizing large language model inference
#  in enterprise environments by compressing long, data-rich contexts (combining
#  free-form instructions, large documents, and numeric tables) under strict cost,
#  latency, and context-window constraints, while preserving or improving
#  downstream task accuracy.
#
#  Core Methodological Components:
#  • Hard Prompt Pruning via hybrid static/dynamic self-information scoring
#  • Dependency-driven phrase grouping for syntactically coherent pruning
#  • Reversible textual n-gram abbreviation for high-frequency pattern compression
#  • Uniform and K-Means numerical quantization for tabular data reduction
#  • Clustering-based representative exemplar selection for few-shot prompting
#  • Semantic fidelity evaluation using embedding similarity and human proxies
#
#  Technical Implementation Features:
#  • Modular pipeline architecture with clear separation of concerns
#  • Robust integration with multiple LLM providers (OpenAI, Anthropic, Together AI)
#  • Efficient offline corpus statistics computation and tokenization alignment
#  • Comprehensive evaluation harness for TAT-QA and Fin-QA benchmarks
#  • Structured logging and artifact management for full reproducibility
#  • Type-safe interfaces and rigorous input validation throughout
#
#  Paper Reference:
#  Choi, J. H., Zhao, J., Shah, J., Sonawane, R., Singh, V., Appalla, A.,
#  Flanagan, W., & Condessa, F. (2025). CompactPrompt: A Unified Pipeline for
#  Prompt and Data Compression in LLM Workflows. arXiv preprint arXiv:2510.18043v1.
#  https://arxiv.org/abs/2510.18043v1
#
#  Author: CS Chirinda
#  License: MIT
#  Version: 1.0.0
#
# ==============================================================================#
# ==============================================================================#
# COMPACTPROMPT PIPELINE IMPORTS
# ==============================================================================#

# 1. Future Compatibility
from __future__ import annotations

# 2. Standard Library Imports
import abc
import collections
import dataclasses
import datetime
import enum
import gzip
import hashlib
import json
import logging
import math
import os
import random
import re
import time
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import (
    Any,
    Counter as CounterType,
    Dict,
    Generator,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)

# 3. Third-Party Data Science & Math Libraries
import numpy as np
import pandas as pd

# 4. Third-Party NLP & Machine Learning Libraries
import spacy
import tiktoken
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min, silhouette_score
from transformers import AutoTokenizer

# 5. Third-Party LLM Provider Libraries
# Note: each provider SDK import is wrapped in a try/except block so that a
# missing SDK does not break this cell; the corresponding provider is still
# required for full pipeline execution.
try:
    from anthropic import Anthropic
except ImportError:
    pass

try:
    from openai import OpenAI
except ImportError:
    pass

try:
    from together import Together
except ImportError:
    pass
```


# Implementation

## Draft 1

## **Discussion of Input-Process-Output flow of *CompactPrompt* Research Framework Callables**

### **1. `validate_tatqa_dataset`**
*   **Inputs:** The raw `pandas.DataFrame` containing the TAT-QA dataset.
*   **Processes:** Orchestrates a sequence of validation sub-routines: schema verification (column presence), identifier-field integrity checks (uniqueness of IDs), and validation of complex nested structures (tables and passages).
*   **Outputs:** A `ValidationReport` object detailing the validity status and specific error logs.
*   **Data Transformation:** The input data is not mutated; it is inspected. The transformation is from **raw data** to **metadata regarding its quality**.
*   **Research Context:** Implements the data preparation phase necessary for **Section 5.1 (Task Setup)**. It checks that the dataset conforms to the "hybrid of Tabular and Textual Content" structure described in the TAT-QA citation [29], so that the "structured tabular data" and "unstructured narrative passages" are correctly formatted for downstream compression.
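A pared-down version of these checks might look like the sketch below. It is hypothetical and is not the notebook's `validate_tatqa_dataset`. The `SketchReport` dataclass stands in for the notebook's `ValidationReport`, whose fields are not documented here. The table check assumes every row should have exactly as many cells as there are headers.

```python
from dataclasses import dataclass, field
from typing import List

import pandas as pd

REQUIRED_COLUMNS = ["example_id", "split", "question_text", "tables", "passages",
                    "answer_type", "answer_value", "answer_unit"]


@dataclass
class SketchReport:  # hypothetical stand-in for the notebook's ValidationReport
    is_valid: bool
    errors: List[str] = field(default_factory=list)


def validate_qa_frame(df: pd.DataFrame) -> SketchReport:
    """Check column presence, ID uniqueness and nested table shapes (read-only)."""
    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        return SketchReport(False, [f"missing columns: {missing}"])
    errors: List[str] = []
    dupes = df.loc[df["example_id"].duplicated(), "example_id"].tolist()
    if dupes:
        errors.append(f"duplicate example_id values: {dupes}")
    for ex_id, tables in zip(df["example_id"], df["tables"]):
        if not isinstance(tables, list):
            errors.append(f"{ex_id}: 'tables' is not a list")
            continue
        for t_idx, table in enumerate(tables):
            if not isinstance(table, dict):
                errors.append(f"{ex_id} table {t_idx}: not a dict")
                continue
            width = len(table.get("headers", []))
            bad = [r for r, row in enumerate(table.get("rows", [])) if len(row) != width]
            if bad:
                errors.append(f"{ex_id} table {t_idx}: rows {bad} != header width {width}")
    return SketchReport(not errors, errors)
```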

### **2. `validate_finqa_dataset`**
*   **Inputs:** The raw `pandas.DataFrame` containing the Fin-QA dataset.
*   **Processes:** Executes validation logic specific to financial QA data: checking for numeric answer types, valid table headers, and consistent split identifiers.
*   **Outputs:** A `ValidationReport` object.
*   **Data Transformation:** Read-only inspection transforming **raw rows** into a **validity verdict**.
*   **Research Context:** Implements data preparation for **Section 5.1 (Task Setup)** regarding Fin-QA [3]. It verifies the integrity of the "financial reports" data structure required for the "numerical reasoning" tasks evaluated in the study.

### **3. `validate_and_update_study_config`**
*   **Inputs:** A raw Python dictionary representing the study configuration (containing placeholders).
*   **Processes:** Recursively inventories placeholders, resolves them with scientifically justifiable defaults (e.g., setting $\Delta$ thresholds), and validates LLM API configurations.
*   **Outputs:** A fully resolved, executable configuration dictionary.
*   **Data Transformation:** Transforms a **template configuration** into a **runtime-ready system state**.
*   **Research Context:** Establishes the hyperparameters defined in **Section 3 (Design)**. Specifically, it sets the threshold $\Delta = 0.1$ used in **Equation (3)** and the n-gram parameters ($n, K$) used in **Section 3.2**.
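The placeholder inventory and resolution step can be sketched recursively. The following are all assumptions for illustration:

-   the `__PLACEHOLDER__` marker,
-   the tuple-path defaults table,
-   raising an error on an unresolved placeholder.

Lists inside the configuration are not traversed in this sketch.

```python
import copy
from typing import Any, Dict, List, Tuple

PLACEHOLDER = "__PLACEHOLDER__"  # hypothetical marker; the real config may differ
Path = Tuple[Any, ...]


def find_placeholders(node: Any, path: Path = ()) -> List[Path]:
    """Recursively collect the key paths whose value is the placeholder marker."""
    if isinstance(node, dict):
        found: List[Path] = []
        for key, value in node.items():
            found.extend(find_placeholders(value, path + (key,)))
        return found
    return [path] if node == PLACEHOLDER else []


def resolve_placeholders(config: Dict[str, Any], defaults: Dict[Path, Any]) -> Dict[str, Any]:
    """Return a copy of config with every placeholder replaced by its default."""
    resolved = copy.deepcopy(config)  # leave the template untouched
    for path in find_placeholders(resolved):
        if path not in defaults:
            raise KeyError(f"no default for {'.'.join(map(str, path))}")
        parent = resolved
        for key in path[:-1]:
            parent = parent[key]
        parent[path[-1]] = defaults[path]
    return resolved


raw = {"scoring": {"delta_relative_difference_threshold": PLACEHOLDER},
       "ngram": {"ngram_length": 2}}
cfg = resolve_placeholders(raw, {("scoring", "delta_relative_difference_threshold"): 0.1})
print(cfg["scoring"])  # {'delta_relative_difference_threshold': 0.1}
```

Failing loudly on a missing default keeps an unresolved hyperparameter from silently reaching a pipeline run.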

### **4. `cleanse_tatqa_dataset`**
*   **Inputs:** The raw TAT-QA DataFrame.
*   **Processes:** Filters rows based on critical missing fields, repairs malformed table structures (truncating/padding rows), and removes empty passages.
*   **Outputs:** A cleansed `pandas.DataFrame` and a `CleansingLog`.
*   **Data Transformation:** Transforms **noisy, potentially malformed data** into **structurally consistent data** suitable for tokenization.
*   **Research Context:** Pre-processing required to ensure the "long-form contexts" described in **Section 5.1** are valid before being fed into the compression pipeline.
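The row-repair and passage-filtering steps can be sketched as two small helpers. Two choices here are assumptions, not documented behaviour of `cleanse_tatqa_dataset`: padding short rows with an empty string, and treating whitespace-only passages as empty.

```python
from typing import Any, Dict, List


def repair_table(table: Dict[str, Any], fill: str = "") -> Dict[str, Any]:
    """Truncate long rows and pad short rows so every row matches the header width."""
    headers = list(table.get("headers", []))
    width = len(headers)
    rows = [list(row[:width]) + [fill] * max(0, width - len(row))
            for row in table.get("rows", [])]
    return {"headers": headers, "rows": rows}


def drop_empty_passages(passages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep only passages whose text is non-blank."""
    return [p for p in passages if str(p.get("text") or "").strip()]


fixed = repair_table({"headers": ["Item", "2019", "2018"],
                      "rows": [["Revenue"], ["Costs", "5", "4", "extra"]]})
print(fixed["rows"])  # [['Revenue', '', ''], ['Costs', '5', '4']]
```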

### **5. `cleanse_finqa_dataset`**
*   **Inputs:** The raw Fin-QA DataFrame.
*   **Processes:** Similar to TAT-QA cleansing but enforces stricter requirements for table existence, as Fin-QA relies heavily on tabular data. Flags invalid numeric answers.
*   **Outputs:** A cleansed `pandas.DataFrame` and a `FinQACleansingLog`.
*   **Data Transformation:** Transforms **raw financial data** into **clean inputs** for the quantization module.
*   **Research Context:** Ensures the integrity of the "structured numerical data" mentioned in **Section 1**, which is a prerequisite for the **Numerical Quantization (Section 3.3)** experiments.

### **6. `normalize_data_task`**
*   **Inputs:** Cleansed TAT-QA and Fin-QA DataFrames.
*   **Processes:** Identifies numeric columns using a 90% parseability threshold and parses string answers into floating-point numbers using regex-based normalization.
*   **Outputs:** DataFrames enriched with numeric metadata and parsed answers.
*   **Data Transformation:** Transforms **string representations** of numbers (e.g., "1.5M") into **floating-point values** (e.g., 1,500,000.0).
*   **Research Context:** This is the identification step for **Section 3.3 (Numerical Quantization)**. It determines which columns $x$ are eligible for the quantization transform $q_i = \text{round}(\dots)$.
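
The two rules in this step can be sketched as shown here. This is a minimal illustration, not the repository's code. The helper names `parse_numeric` and `is_numeric_column` are hypothetical, and so are the accepted formats (thousands separators, a leading `$`, a trailing `%`, and `K`/`M`/`B` suffixes). Only the 90% parseability threshold and the "1.5M" example come from the description above.

```python
import re
from typing import Iterable, Optional

# Hypothetical helpers; the pipeline's actual parsing rules may differ.
_SUFFIXES = {"k": 1e3, "m": 1e6, "b": 1e9}
# A signed decimal number followed by an optional K/M/B suffix.
_NUMBER_RE = re.compile(r"^(-?\d+(?:\.\d+)?)\s*([kmb])?$", re.IGNORECASE)


def parse_numeric(text: str) -> Optional[float]:
    """Parse strings such as '1,800', '$2.75', '10%' or '1.5M' into floats."""
    cleaned = text.strip().replace(",", "").replace("$", "").rstrip("%").strip()
    match = _NUMBER_RE.match(cleaned)
    if match is None:
        return None  # not parseable as a number under these rules
    value = float(match.group(1))
    suffix = match.group(2)
    if suffix:
        value *= _SUFFIXES[suffix.lower()]
    return value


def is_numeric_column(values: Iterable[object], threshold: float = 0.9) -> bool:
    """Flag a column as numeric if at least `threshold` of its non-empty cells parse."""
    cells = [str(v) for v in values if v is not None and str(v).strip()]
    if not cells:
        return False
    parsed = sum(parse_numeric(cell) is not None for cell in cells)
    return parsed / len(cells) >= threshold


print(parse_numeric("1.5M"))                        # 1500000.0
print(is_numeric_column(["2021", "2022", "2023"]))  # True
```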

### **7. `build_offline_corpus`**
*   **Inputs:** Configuration paths for Wikipedia, ShareGPT, and arXiv dumps.
*   **Processes:** Ingests documents from diverse sources, normalizes text (stripping LaTeX/HTML), and consolidates them.
*   **Outputs:** A single JSONL corpus file.
*   **Data Transformation:** Transforms **heterogeneous source files** into a **unified text stream**.
*   **Research Context:** Implements the corpus construction described in **Section 3.1.1 (Static Self-Information)**: "We begin by constructing a large offline corpus... consisting of Wikipedia, ShareGPT conversations, and arXiv articles."
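
Here is a rough sketch of the normalization step, assuming regex-based cleaning. The helpers `normalize_text` and `to_jsonl_line` are hypothetical, and the pipeline's actual rules are not specified here. The `doc_id`/`source`/`title`/`text` record layout follows the corpus format documented in the usage example below. The `$...$` pattern is a crude heuristic: it would also delete the text between two dollar amounts, so a production normalizer would likely be source-specific.

```python
import html
import json
import re

# Hypothetical, deliberately simple cleaning rules.
_HTML_TAG = re.compile(r"<[^>]+>")
_LATEX_MATH = re.compile(r"\$\$.*?\$\$|\$[^$]*\$", re.DOTALL)  # $$...$$ or $...$
_LATEX_CMD = re.compile(r"\\[a-zA-Z]+\*?")                      # e.g. \emph, \section*
_WHITESPACE = re.compile(r"\s+")


def normalize_text(raw: str) -> str:
    """Strip HTML tags, LaTeX math and LaTeX command names, then collapse whitespace."""
    text = _HTML_TAG.sub(" ", raw)
    text = html.unescape(text)                      # e.g. &amp; -> &
    text = _LATEX_MATH.sub(" ", text)
    text = _LATEX_CMD.sub(" ", text)
    text = text.replace("{", "").replace("}", "")  # keep command arguments, drop braces
    return _WHITESPACE.sub(" ", text).strip()


def to_jsonl_line(doc_id: str, source: str, title: str, raw_text: str) -> str:
    """Serialize one normalized document as a single JSONL record."""
    record = {"doc_id": doc_id, "source": source, "title": title, "text": normalize_text(raw_text)}
    return json.dumps(record, ensure_ascii=False)


print(normalize_text(r"The loss $\mathcal{L}$ is \emph{small}."))  # The loss is small.
```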

### **8. `compute_corpus_statistics`**
*   **Inputs:** The offline JSONL corpus.
*   **Processes:** Streams the corpus, tokenizes text using a specific tokenizer (e.g., `cl100k_base`), and aggregates token counts.
*   **Outputs:** A `Counter` of token IDs and the total token count $N$.
*   **Data Transformation:** Transforms **text documents** into **frequency distributions**.
*   **Research Context:** Calculates the denominator $N$ and the raw counts required for **Equation (1)** in **Section 3.1.1**: $f(t) = \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}\{w_i = t\}$.
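
Streaming the corpus and counting tokens can be sketched as follows. The function `count_corpus_tokens` is a hypothetical name, and the tokenizer is passed in as a callable so the example runs without extra dependencies. Here `str.split` stands in for a real tokenizer. With the `tiktoken` package installed, `tiktoken.get_encoding("cl100k_base").encode` could be passed instead to count token IDs.

```python
import json
from collections import Counter
from typing import Callable, Hashable, Iterable, List, Tuple


def count_corpus_tokens(
    jsonl_lines: Iterable[str],
    tokenize: Callable[[str], List[Hashable]],
) -> Tuple[Counter, int]:
    """Hypothetical helper: stream JSONL documents, tokenize 'text', return (counts, N)."""
    counts: Counter = Counter()
    for line in jsonl_lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        document = json.loads(line)
        counts.update(tokenize(document["text"]))
    total = sum(counts.values())  # N: total number of token occurrences
    return counts, total


lines = ['{"text": "revenue rose"}', '{"text": "revenue fell"}']
counts, n_total = count_corpus_tokens(lines, str.split)
print(counts["revenue"], n_total)  # 2 4
```

For a real corpus, an open file handle (for example from `open(path, encoding="utf-8")`) can be passed as `jsonl_lines`, so the file is never loaded into memory at once.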

### **9. `compute_token_probabilities`**
*   **Inputs:** Token counts and total count $N$.
*   **Processes:** Divides counts by $N$ to obtain empirical probabilities.
*   **Outputs:** A dictionary mapping token IDs to probabilities $p(t)$.
*   **Data Transformation:** Transforms **raw counts** into **probability space**.
*   **Research Context:** Implements the calculation of $p(t)$ derived from **Equation (1)** in **Section 3.1.1**, which is the basis for static self-information.

### **10. `compute_static_self_information`**
*   **Inputs:** Token probabilities $p(t)$.
*   **Processes:** Computes the negative binary logarithm of probabilities.
*   **Outputs:** A dictionary mapping token IDs to static self-information scores $I_{\text{stat}}$.
*   **Data Transformation:** Transforms **probabilities** into **information content (bits)**.
*   **Research Context:** Directly implements the static self-information formula from **Section 3.1.1**:
    $$I_{\text{stat}}(t) = -\log_2 p(t)$$
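
Steps 9 and 10 reduce to a division and a logarithm. In this sketch the helper names are hypothetical, and tokens that never occur in the corpus simply get no entry. How the pipeline scores unseen tokens (for example, with smoothing) is not described here.

```python
import math
from typing import Dict, Hashable, Mapping


def token_probabilities(counts: Mapping[Hashable, int], total: int) -> Dict[Hashable, float]:
    """Hypothetical helper: empirical probability p(t) = count(t) / N, as in Equation (1)."""
    if total <= 0:
        raise ValueError("total token count N must be positive")
    return {token: count / total for token, count in counts.items()}


def static_self_information(probs: Mapping[Hashable, float]) -> Dict[Hashable, float]:
    """Hypothetical helper: I_stat(t) = -log2 p(t) in bits; zero probabilities are skipped."""
    return {token: -math.log2(p) for token, p in probs.items() if p > 0}


probs = token_probabilities({"the": 4, "revenue": 2, "EBITDA": 2}, total=8)
print(static_self_information(probs))  # {'the': 1.0, 'revenue': 2.0, 'EBITDA': 2.0}
```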

### **11. `serialize_tables_task`**
*   **Inputs:** DataFrames with structured table dictionaries.
*   **Processes:** Converts table objects into Markdown-formatted strings, handling headers and row alignment.
*   **Outputs:** DataFrames with a new `serialized_tables` column.
*   **Data Transformation:** Transforms **structured objects** into **linear text** suitable for LLM context windows.
*   **Research Context:** Prepares the "structured tabular data" for the prompt, a necessary step before the "Hard Prompt Pruning" described in **Section 3.1** can operate on the table tokens.
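
Here is a minimal sketch of the table serialization, assuming the `caption`/`headers`/`rows` layout used in the synthetic data of the usage example. The `Table:` caption line, the pipe escaping, and the pad-or-truncate policy for ragged rows are illustrative choices. `table_to_markdown` is a hypothetical name.

```python
from typing import Any, Dict, List


def _cell(value: Any) -> str:
    # Escape pipes so a cell cannot break the Markdown column layout.
    return str(value).replace("|", "\\|")


def table_to_markdown(table: Dict[str, Any]) -> str:
    """Hypothetical helper: render a table dict as Markdown, padding or truncating rows."""
    headers = [_cell(h) for h in table.get("headers", [])]
    width = len(headers)
    lines: List[str] = []
    if table.get("caption"):
        lines.append(f"Table: {table['caption']}")
    lines.append("| " + " | ".join(headers) + " |")
    lines.append("| " + " | ".join("---" for _ in headers) + " |")
    for row in table.get("rows", []):
        cells = [_cell(c) for c in row]
        cells = (cells + [""] * width)[:width]  # pad or truncate to the header width
        lines.append("| " + " | ".join(cells) + " |")
    return "\n".join(lines)


eps_table = {
    "caption": "Earnings Per Share Data",
    "headers": ["Year", "Basic EPS", "Diluted EPS"],
    "rows": [["2022", "2.50", "2.48"], ["2023", "2.75", "2.72"]],
}
print(table_to_markdown(eps_table))
```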

### **12. `construct_prompt`**
*   **Inputs:** Question, serialized tables, passages, and optional exemplars.
*   **Processes:** Injects these components into a dataset-specific template.
*   **Outputs:** A single prompt string.
*   **Data Transformation:** Aggregates **disparate context elements** into a **single inference payload**.
*   **Research Context:** Constructs the input $P$ that is subject to compression. This represents the "verbose or unstructured inputs" mentioned in **Section 1** that CompactPrompt aims to optimize.

### **13. `format_exemplars`**
*   **Inputs:** A list of exemplar dictionaries and the target example ID.
*   **Processes:** Formats exemplars into text blocks and filters out the target ID to prevent data leakage.
*   **Outputs:** A formatted string of few-shot examples.
*   **Data Transformation:** Transforms **retrieved examples** into **prompt context**.
*   **Research Context:** Implements the few-shot prompting strategy discussed in **Section 3.4**, ensuring that the "representative examples" are correctly integrated into the prompt structure.
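
Steps 12 and 13 together amount to string templating plus a leakage filter. The template text, the exemplar keys (`example_id`, `question`, `answer`), and the helper names here are all hypothetical. The one behavior taken from the descriptions above is that the target example is excluded from its own few-shot context.

```python
from typing import Dict, List

# Hypothetical template; each dataset in the pipeline may use its own wording.
PROMPT_TEMPLATE = (
    "{exemplars}"
    "Tables:\n{tables}\n\n"
    "Passages:\n{passages}\n\n"
    "Question: {question}\nAnswer:"
)


def render_exemplars(exemplars: List[Dict[str, str]], target_id: str) -> str:
    """Hypothetical helper: format few-shot examples, excluding the target (no leakage)."""
    blocks = [
        f"Question: {ex['question']}\nAnswer: {ex['answer']}"
        for ex in exemplars
        if ex["example_id"] != target_id
    ]
    return "\n\n".join(blocks) + "\n\n" if blocks else ""


def render_prompt(question: str, tables: List[str], passages: List[str], exemplars: str = "") -> str:
    """Hypothetical helper: inject the context pieces into the template to form prompt P."""
    return PROMPT_TEMPLATE.format(
        exemplars=exemplars,
        tables="\n\n".join(tables),
        passages="\n".join(passages) or "(none)",  # e.g. Fin-QA rows with no passages
        question=question,
    )
```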

### **14. `configure_llm_resources`**
*   **Inputs:** The study configuration dictionary.
*   **Processes:** Instantiates concrete `LLMInterface` objects (e.g., `GPT4oInterface`, `LlamaInterface`) based on model names.
*   **Outputs:** A registry of initialized LLM interfaces.
*   **Data Transformation:** Transforms **configuration strings** into **active API clients**.
*   **Research Context:** Sets up the "pretrained LLM or a lightweight scoring agent" mentioned in **Section 3.1.2** for dynamic scoring, and the target models for **Section 5** evaluation.

### **15. `prepare_dynamic_scoring_inputs`**
*   **Inputs:** DataFrames and the scorer LLM interface.
*   **Processes:** Serializes examples to prompts, tokenizes them using the scorer's tokenizer, and computes character offsets.
*   **Outputs:** A dictionary of tokenized prompts and offsets.
*   **Data Transformation:** Transforms **semantic examples** into **token sequences** ready for scoring.
*   **Research Context:** Prepares the input $t$ and context $c$ required to query the model for $P_{\text{model}}(t \mid c)$ as described in **Section 3.1.2**.
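
Character offsets can be recovered by walking the decoded token strings along the prompt, as in the sketch below. This assumes the scorer's tokens decode to strings that concatenate back to the exact prompt text. Byte-level BPE tokens that split a multi-byte character would break that assumption and need byte-level bookkeeping instead. `token_char_offsets` is a hypothetical helper.

```python
from typing import List, Tuple


def token_char_offsets(text: str, token_strings: List[str]) -> List[Tuple[int, int]]:
    """Hypothetical helper: (start, end) spans for tokens whose concatenation is `text`."""
    offsets: List[Tuple[int, int]] = []
    cursor = 0
    for token in token_strings:
        end = cursor + len(token)
        if text[cursor:end] != token:
            raise ValueError(f"token {token!r} does not match text at offset {cursor}")
        offsets.append((cursor, end))
        cursor = end
    if cursor != len(text):
        raise ValueError("tokens do not cover the full text")
    return offsets


print(token_char_offsets("Net income rose", ["Net", " income", " rose"]))
# [(0, 3), (3, 10), (10, 15)]
```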

### **16. `get_prompt_logprobs_task`**
*   **Inputs:** Tokenized prompts and the scorer interface.
*   **Processes:** Queries the LLM API (using `echo=True` or iterative scoring) to retrieve log-probabilities.
*   **Outputs:** A dictionary mapping example IDs to lists of log-probabilities.
*   **Data Transformation:** Transforms **token sequences** into **raw model likelihoods**.
*   **Research Context:** Retrieves the conditional probabilities $P_{\text{model}}(t \mid c)$ required for **Section 3.1.2 (Dynamic Self-Information)**.

### **17. `compute_dynamic_scores_task`**
*   **Inputs:** Raw log-probabilities (natural log).
*   **Processes:** Converts natural logs to bits (base 2) and validates values.
*   **Outputs:** A dictionary of dynamic self-information scores $s_{\text{dyn}}$.
*   **Data Transformation:** Transforms **model log-probs** into **information scores**.
*   **Research Context:** Implements the dynamic self-information formula implied in **Section 3.1.2**:
    $$s_{\text{dyn}}(t \mid c) = -\log_2 P_{\text{model}}(t \mid c)$$
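
The conversion from natural-log probabilities to bits is a division by $\ln 2$, since $-\log_2 P = -\ln P / \ln 2$. The sketch below uses a hypothetical helper, `logprobs_to_bits`, which also validates each value. Passing `None` through is an assumption meant for positions without a log-probability, such as (with some APIs) the first token of an echoed prompt.

```python
import math
from typing import List, Optional


def logprobs_to_bits(logprobs: List[Optional[float]]) -> List[Optional[float]]:
    """Hypothetical helper: convert natural-log probabilities to s_dyn = -log2 P (bits).

    None entries are passed through unchanged so the caller can decide how to score them.
    """
    scores: List[Optional[float]] = []
    for lp in logprobs:
        if lp is None:
            scores.append(None)
            continue
        if math.isnan(lp) or lp > 0:
            raise ValueError(f"invalid log-probability: {lp}")
        scores.append(-lp / math.log(2))  # -ln P / ln 2 = -log2 P
    return scores


print(logprobs_to_bits([None, math.log(0.5), math.log(0.25)]))  # approximately [None, 1.0, 2.0]
```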

### **18. `compute_combined_scores_task`**
*   **Inputs:** Static scores $s_{\text{stat}}$ and dynamic scores $s_{\text{dyn}}$.
*   **Processes:** Computes the relative difference $\Delta$ and applies the fusion rule.
*   **Outputs:** A dictionary of combined importance scores $C(t)$.
*   **Data Transformation:** Fuses **global** and **local** information measures into a **single utility metric**.
*   **Research Context:** Implements **Equations (2) and (3)** from **Section 3.1.3**:
    $$\Delta = \frac{|s_{\text{dyn}} - s_{\text{stat}}|}{s_{\text{stat}}}$$
    $$C(s_{\text{stat}}, s_{\text{dyn}}) = \begin{cases} \frac{s_{\text{stat}} + s_{\text{dyn}}}{2}, & \Delta \leq 0.1 \\ s_{\text{dyn}}, & \Delta > 0.1 \end{cases}$$
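
Equations (2) and (3) translate almost directly into code. This sketch assumes the static and dynamic scores are already aligned position by position. The fallback for $s_{\text{stat}} = 0$ (where $\Delta$ is undefined) is also an assumption, not something stated in the paper. Both helper names are hypothetical.

```python
from typing import List


def fuse_scores(s_stat: float, s_dyn: float, threshold: float = 0.1) -> float:
    """Equations (2)-(3): average if the scores agree within `threshold`, else use s_dyn."""
    if s_stat <= 0:
        # Delta is undefined for s_stat = 0; falling back to s_dyn is an assumption.
        return s_dyn
    delta = abs(s_dyn - s_stat) / s_stat
    return (s_stat + s_dyn) / 2 if delta <= threshold else s_dyn


def fuse_all(static: List[float], dynamic: List[float], threshold: float = 0.1) -> List[float]:
    """Hypothetical helper: apply the fusion rule over aligned token scores."""
    if len(static) != len(dynamic):
        raise ValueError("static and dynamic score lists must be aligned")
    return [fuse_scores(s, d, threshold) for s, d in zip(static, dynamic)]


print(fuse_all([10.0, 10.0], [10.5, 12.0]))  # [10.25, 12.0]
```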

### **19. `group_tokens_into_phrases_task`**
*   **Inputs:** Prompt text and token offsets.
*   **Processes:** Runs a dependency parser to identify NP, VP, and PP spans, then maps tokens to these phrases.
*   **Outputs:** A mapping of example IDs to lists of phrases (token indices).
*   **Data Transformation:** Transforms **flat token sequences** into **hierarchical syntactic units**.
*   **Research Context:** Implements the "dependency-based phrase grouping" described in **Section 3.1.3**, ensuring that pruning operates on grammatical units rather than individual tokens.
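
The token-to-phrase mapping can be sketched without the parser itself. Given the phrases' character spans and the token offsets, each token joins the first phrase whose span contains the token's midpoint. Tokens outside every phrase become single-token units. The midpoint rule and the helper name are assumptions. In practice the spans could come from a parser such as spaCy (for example, `span.start_char` and `span.end_char` on `doc.noun_chunks`).

```python
from typing import Dict, List, Tuple

Span = Tuple[int, int]  # [start, end) character offsets


def group_tokens_by_spans(token_offsets: List[Span], phrase_spans: List[Span]) -> List[List[int]]:
    """Hypothetical helper: assign each token to the first phrase containing its midpoint.

    Tokens outside every phrase become single-token units, so no token is lost.
    """
    groups: Dict[Tuple[str, int], List[int]] = {}
    for i, (start, end) in enumerate(token_offsets):
        midpoint = (start + end) // 2  # tolerates a leading space inside the token
        key = ("token", i)
        for j, (p_start, p_end) in enumerate(phrase_spans):
            if p_start <= midpoint < p_end:
                key = ("phrase", j)
                break
        groups.setdefault(key, []).append(i)
    return sorted(groups.values(), key=lambda group: group[0])


# "Net income rose sharply": NP "Net income" = [0, 10), VP "rose sharply" = [11, 23)
offsets = [(0, 3), (3, 10), (10, 15), (15, 23)]
print(group_tokens_by_spans(offsets, [(0, 10), (11, 23)]))  # [[0, 1], [2, 3]]
```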

### **20. `compute_phrase_scores_task`**
*   **Inputs:** Combined token scores $C(t)$ and phrase groupings.
*   **Processes:** Aggregates token scores within each phrase (e.g., via mean).
*   **Outputs:** Phrase-level importance scores $C(\phi)$.
*   **Data Transformation:** Aggregates **atomic scores** into **structural scores**.
*   **Research Context:** The final step of **Section 3.1.3**, assigning a utility value to each syntactic unit to enable "Phrase-Level Pruning".

### **21. `prune_prompt_task`**
*   **Inputs:** Phrase scores, phrase definitions, and a token budget.
*   **Processes:** Sorts phrases by importance and greedily selects them until the budget is met, then reconstructs the text.
*   **Outputs:** Compressed prompt text and compression metrics.
*   **Data Transformation:** Transforms a **verbose prompt** into a **compressed prompt** $P'$.
*   **Research Context:** Implements the core **Hard Prompt Pruning** logic described in **Section 3.1**, realizing the goal of reducing token usage while preserving high-information content.
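
Steps 20 and 21 can be sketched as a mean over each phrase followed by greedy selection under a token budget. This version skips any phrase that would overflow the budget and keeps scanning for smaller ones. Stopping at the first overflow is an equally plausible reading of "until the budget is met." Phrases are assumed to be non-empty, and the helper names are hypothetical.

```python
from typing import List, Sequence


def phrase_scores(token_scores: Sequence[float], phrases: List[List[int]]) -> List[float]:
    """Hypothetical helper: C(phi) as the mean combined score of the phrase's tokens."""
    return [sum(token_scores[i] for i in p) / len(p) for p in phrases]  # phrases non-empty


def greedy_prune(token_scores: Sequence[float], phrases: List[List[int]], budget: int) -> List[int]:
    """Hypothetical helper: keep top phrases that fit in `budget` tokens; return token indices."""
    scores = phrase_scores(token_scores, phrases)
    ranked = sorted(range(len(phrases)), key=lambda j: scores[j], reverse=True)
    kept: List[int] = []
    used = 0
    for j in ranked:
        size = len(phrases[j])
        if used + size <= budget:  # skip phrases that would overflow the budget
            kept.extend(phrases[j])
            used += size
    return sorted(kept)  # restore the original reading order


tokens = ["Net", " income", " rose", " 9%", " in 2023"]
keep = greedy_prune([1.0, 1.0, 5.0, 5.0, 2.0], [[0, 1], [2, 3], [4]], budget=3)
print("".join(tokens[i] for i in keep))
```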

### **22. `extract_top_ngrams_task`**
*   **Inputs:** Passage texts from the datasets.
*   **Processes:** Tokenizes passages, counts n-grams using a sliding window, and selects the top $K$.
*   **Outputs:** A list of the top-K n-grams and their counts.
*   **Data Transformation:** Transforms **corpus text** into **frequency statistics**.
*   **Research Context:** Implements **Section 3.2.1 (Extraction and Frequency Analysis)**, identifying the "top (K) most frequent patterns" for abbreviation.
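
A sliding-window count can be sketched with `collections.Counter`. Whitespace tokenization and the helper name `top_ngrams` are assumptions. The pipeline may count n-grams over a model tokenizer's tokens instead.

```python
from collections import Counter
from typing import Iterable, List, Tuple


def top_ngrams(texts: Iterable[str], n: int, k: int) -> List[Tuple[Tuple[str, ...], int]]:
    """Hypothetical helper: count contiguous n-grams and return the K most frequent."""
    if n < 1 or k < 1:
        raise ValueError("n and k must be positive")
    counts: Counter = Counter()
    for text in texts:
        words = text.split()  # whitespace tokenization (assumption)
        for i in range(len(words) - n + 1):
            counts[tuple(words[i : i + n])] += 1
    return counts.most_common(k)


passages = ["net income rose in 2023", "net income fell in 2022"]
print(top_ngrams(passages, n=2, k=1))  # [(('net', 'income'), 2)]
```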

### **23. `construct_abbreviation_dict_task`**
*   **Inputs:** Top-K n-grams.
*   **Processes:** Assigns unique placeholders to n-grams and builds bidirectional maps.
*   **Outputs:** An abbreviation dictionary $\mathcal{D}_{abbr}$.
*   **Data Transformation:** Transforms **frequent patterns** into a **substitution logic**.
*   **Research Context:** Implements **Section 3.2.2 (Dictionary Construction)**, creating the mapping required for "lossless round-trip reconstruction."

### **24. `apply_abbreviation_task`**
*   **Inputs:** Passage texts and the abbreviation dictionary.
*   **Processes:** Replaces active n-grams (Top-T) with placeholders in the text.
*   **Outputs:** DataFrames with abbreviated passages.
*   **Data Transformation:** Transforms **redundant text** into **compressed text**.
*   **Research Context:** Implements **Section 3.2.3 (Contextual Replacement)**, applying the "user-configurable n-gram abbreviation pipeline" to reduce the size of attached documents.
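
Steps 23 and 24 can be sketched as a rank-ordered dictionary plus plain string substitution. The `<A0>`, `<A1>`, ... placeholder format and the helper names are hypothetical. The round trip is lossless only if placeholders never appear in the raw text (checked below) and the n-grams themselves contain no placeholder-like text. Replacing longer n-grams first keeps a shorter n-gram from splitting a longer one.

```python
from typing import Dict, List, Tuple


def build_abbreviation_dict(ngrams: List[str]) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Hypothetical helper: map each n-gram (rank order) to a placeholder, plus the inverse."""
    forward = {phrase: f"<A{idx}>" for idx, phrase in enumerate(ngrams)}  # hypothetical format
    backward = {placeholder: phrase for phrase, placeholder in forward.items()}
    return forward, backward


def abbreviate(text: str, forward: Dict[str, str], top_t: int) -> str:
    """Replace only the first `top_t` (highest-ranked) n-grams with their placeholders."""
    if any(placeholder in text for placeholder in forward.values()):
        raise ValueError("text already contains a placeholder; round trip would be lossy")
    active = list(forward.items())[:top_t]
    for phrase, placeholder in sorted(active, key=lambda item: len(item[0]), reverse=True):
        text = text.replace(phrase, placeholder)  # longest phrases first
    return text


def expand(text: str, backward: Dict[str, str]) -> str:
    """Invert `abbreviate` by substituting the placeholders back."""
    for placeholder, phrase in backward.items():
        text = text.replace(placeholder, phrase)
    return text


forward, backward = build_abbreviation_dict(["net income", "operating margin"])
text = "net income and operating margin rose; net income fell"
print(abbreviate(text, forward, top_t=2))  # <A0> and <A1> rose; <A0> fell
```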

### **25. `extract_numeric_values_task`**
*   **Inputs:** DataFrames and numeric column metadata.
*   **Processes:** Extracts and parses floating-point values from identified numeric columns.
*   **Outputs:** A dictionary of numeric arrays $x$ for each column.
*   **Data Transformation:** Transforms **tabular cells** into **numerical vectors**.
*   **Research Context:** Prepares the input data for **Section 3.3 (Numerical Quantization)**, isolating the floating-point columns that require compression.

### **26. `apply_uniform_quantization_task`**
*   **Inputs:** Numeric arrays and bit-width $b$.
*   **Processes:** Computes min/max ranges and maps values to integer codes.
*   **Outputs:** Quantized codes $q_i$ and reconstruction metadata.
*   **Data Transformation:** Transforms **floats** into **integers**.
*   **Research Context:** Implements **Section 3.3.1 (Uniform Integer Quantization)**, specifically **Equation (4)**, where $L = 2^b$ is the number of quantization levels for bit-width $b$:
    $$q_i = \text{round}\left( \frac{x_i - \min_x}{\max_x - \min_x} (L - 1) \right)$$
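
Here is a NumPy sketch of Equation (4), taking the number of levels as $L = 2^b$ for bit-width $b$ (an assumption about how $L$ relates to $b$). The constant-column fallback and the dequantization helper are illustrative additions, and both helper names are hypothetical. Note that `np.round` rounds halves to the nearest even integer, which is why 127.5 maps to 128 in the example.

```python
from typing import Tuple

import numpy as np


def uniform_quantize(x, bit_width: int) -> Tuple[np.ndarray, float, float]:
    """Hypothetical helper: Equation (4) with L = 2**bit_width; returns codes, min, max."""
    values = np.asarray(x, dtype=np.float64)
    levels = 2 ** bit_width
    x_min, x_max = float(values.min()), float(values.max())
    if x_max == x_min:  # constant column: every value maps to code 0
        return np.zeros(values.shape, dtype=np.int64), x_min, x_max
    scaled = (values - x_min) / (x_max - x_min) * (levels - 1)
    return np.round(scaled).astype(np.int64), x_min, x_max


def uniform_dequantize(codes: np.ndarray, x_min: float, x_max: float, bit_width: int) -> np.ndarray:
    """Map codes back to approximate floats (error at most half a quantization step)."""
    levels = 2 ** bit_width
    if x_max == x_min:
        return np.full(codes.shape, x_min, dtype=np.float64)
    return x_min + codes.astype(np.float64) / (levels - 1) * (x_max - x_min)


codes, x_lo, x_hi = uniform_quantize([1500.0, 1650.0, 1800.0], bit_width=8)
print(codes.tolist())  # [0, 128, 255]
```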

### **27. `apply_kmeans_quantization_task`**
*   **Inputs:** Numeric arrays and cluster count $k$.
*   **Processes:** Fits K-Means models and maps values to cluster centroids.
*   **Outputs:** Cluster indices (codes) and centroids.
*   **Data Transformation:** Transforms **floats** into **cluster indices**.
*   **Research Context:** Implements **Section 3.3.2 (K-Means-Based Quantization)**, mapping each $x_i$ to the "nearest centroid index."
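
The sketch below swaps in a plain NumPy implementation of Lloyd's algorithm for one column, so it runs without scikit-learn. The pipeline may well use a library such as `sklearn.cluster.KMeans` instead. The quantile-based initialization and the helper name are assumptions.

```python
from typing import Tuple

import numpy as np


def kmeans_quantize_1d(x, k: int, max_iter: int = 100) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: Lloyd's algorithm in 1-D; x_i is approximated by centroids[codes[i]]."""
    if k < 1:
        raise ValueError("k must be positive")
    values = np.asarray(x, dtype=np.float64)
    distinct = np.unique(values)
    k = min(k, distinct.size)  # cannot have more clusters than distinct values
    centroids = np.quantile(distinct, np.linspace(0.0, 1.0, k))  # deterministic init
    for _ in range(max_iter):
        codes = np.abs(values[:, None] - centroids[None, :]).argmin(axis=1)  # nearest centroid
        updated = np.array([
            values[codes == j].mean() if np.any(codes == j) else centroids[j]  # keep empty clusters
            for j in range(k)
        ])
        if np.allclose(updated, centroids):
            break
        centroids = updated
    codes = np.abs(values[:, None] - centroids[None, :]).argmin(axis=1)
    return codes, centroids


codes, centroids = kmeans_quantize_1d([0.9, 1.0, 1.1, 9.8, 10.0, 10.2], k=2)
print(codes.tolist())  # [0, 0, 0, 1, 1, 1]
```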

### **28. `embed_examples_task`**
*   **Inputs:** DataFrames and the embedding model.
*   **Processes:** Constructs text representations of examples and encodes them into vectors.
*   **Outputs:** An embedding matrix $Z$.
*   **Data Transformation:** Transforms **textual examples** into **vector space**.
*   **Research Context:** Implements **Section 3.4.1 (Embedding and Normalization)**, using `all-mpnet-base-v2` to embed data for clustering.

### **29. `select_optimal_k_task`**
*   **Inputs:** Embedding matrix $Z$.
*   **Processes:** Sweeps $k \in [5, 50]$, computes silhouette scores, and selects the optimal $k^*$.
*   **Outputs:** The optimal cluster count $k^*$.
*   **Data Transformation:** Transforms **geometric structure** into a **hyperparameter decision**.
*   **Research Context:** Implements **Section 3.4.2 (Clustering with Silhouette Optimization)** to "identify the optimal cluster count ($k^*$)."

### **30. `select_representative_exemplars_task`**
*   **Inputs:** Embeddings $Z$ and optimal $k^*$.
*   **Processes:** Clusters data and selects the sample closest to each centroid.
*   **Outputs:** A list of representative example IDs.
*   **Data Transformation:** Transforms **clusters** into **prototypical instances**.
*   **Research Context:** Implements **Section 3.4.3 (Representative Points Selection)**, selecting "prototypes" to serve as few-shot examples.
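
Once a clustering with $k^*$ clusters is available, picking the member closest to each centroid takes a few lines. This sketch assumes the labels and centroids come from some k-means run, and it uses Euclidean distance. If the embeddings are L2-normalized, ranking members by Euclidean distance to a centroid gives the same order as ranking them by cosine similarity to it. `select_representatives` is a hypothetical name.

```python
from typing import List

import numpy as np


def select_representatives(Z, labels, centroids) -> List[int]:
    """Hypothetical helper: for each cluster, the row index of the member nearest its centroid."""
    Z = np.asarray(Z, dtype=np.float64)
    labels = np.asarray(labels)
    representatives: List[int] = []
    for j, centroid in enumerate(np.asarray(centroids, dtype=np.float64)):
        members = np.flatnonzero(labels == j)
        if members.size == 0:
            continue  # empty cluster: nothing to select
        distances = np.linalg.norm(Z[members] - centroid, axis=1)
        representatives.append(int(members[np.argmin(distances)]))
    return representatives


Z = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.2, 5.0]]
print(select_representatives(Z, labels=[0, 0, 1, 1], centroids=[[0.04, 0.0], [5.15, 5.0]]))  # [0, 3]
```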

### **31. `compute_similarity_embeddings_task`**
*   **Inputs:** Original and compressed prompt pairs.
*   **Processes:** Embeds both versions using `all-mpnet-base-v2`.
*   **Outputs:** Embedding matrices $U$ (original) and $V$ (compressed).
*   **Data Transformation:** Transforms **prompt pairs** into **comparable vectors**.
*   **Research Context:** Prepares the data for **Section 3.5.1 (Full-Dimensional Cosine Similarity)** evaluation.

### **32. `evaluate_semantic_similarity_task`**
*   **Inputs:** Embedding matrices $U$ and $V$.
*   **Processes:** Computes cosine similarity and aggregates statistics (mean, 5th percentile).
*   **Outputs:** Similarity statistics and flagged outliers.
*   **Data Transformation:** Transforms **vector relationships** into **fidelity metrics**.
*   **Research Context:** Implements the metric calculation in **Section 3.5.1**, reporting "mean and 5th percentile scores to ensure worst-case fidelity."
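
The row-wise cosine similarity and the two summary statistics can be sketched as follows. The 0.8 cut-off used to flag outliers is a hypothetical placeholder, not a value from the paper, and the helper name is likewise illustrative.

```python
from typing import Dict

import numpy as np


def similarity_report(U, V, flag_below: float = 0.8) -> Dict[str, object]:
    """Hypothetical helper: row-wise cosine similarity of original (U) vs compressed (V)."""
    U = np.asarray(U, dtype=np.float64)
    V = np.asarray(V, dtype=np.float64)
    if U.shape != V.shape:
        raise ValueError("U and V must have the same shape")
    norms = np.linalg.norm(U, axis=1) * np.linalg.norm(V, axis=1)
    if np.any(norms == 0):
        raise ValueError("zero-length embedding; cosine similarity is undefined")
    sims = np.sum(U * V, axis=1) / norms
    return {
        "mean": float(sims.mean()),
        "p5": float(np.percentile(sims, 5)),                   # worst-case fidelity
        "flagged": np.flatnonzero(sims < flag_below).tolist(),  # hypothetical outlier rule
    }


report = similarity_report([[1, 0], [0, 1], [1, 1]], [[1, 0], [1, 0], [2, 2]])
print(report["flagged"])  # [1]
```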

### **33. `run_llm_calibration_task`**
*   **Inputs:** Calibration pairs and LLM agents.
*   **Processes:** Runs the calibration protocol, collecting scores and computing Cohen's Kappa.
*   **Outputs:** Calibration agreement metrics.
*   **Data Transformation:** Transforms **agent outputs** into **reliability metrics**.
*   **Research Context:** Implements the **Human Validation** setup described in **Section 3.5.2**, ensuring evaluators (proxies) are aligned before the main study.
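
For reference, unweighted Cohen's kappa is $\kappa = (p_o - p_e)/(1 - p_e)$. Here $p_o$ is the observed agreement and $p_e$ is the agreement expected from each rater's label frequencies. The sketch below implements that formula under a hypothetical helper name. For 1 to 5 ordinal ratings a weighted variant is also common, and this section does not say which variant the study uses.

```python
from collections import Counter
from typing import Hashable, Sequence


def cohens_kappa(rater_a: Sequence[Hashable], rater_b: Sequence[Hashable]) -> float:
    """Hypothetical helper: unweighted Cohen's kappa, (p_o - p_e) / (1 - p_e)."""
    if len(rater_a) != len(rater_b) or not rater_a:
        raise ValueError("need two equal-length, non-empty rating sequences")
    n = len(rater_a)
    p_observed = sum(a == b for a, b in zip(rater_a, rater_b)) / n
    counts_a, counts_b = Counter(rater_a), Counter(rater_b)
    p_expected = sum(counts_a[c] * counts_b[c] for c in counts_a.keys() | counts_b.keys()) / n**2
    if p_expected == 1.0:
        return 1.0  # both raters used one identical label throughout (kappa undefined; convention)
    return (p_observed - p_expected) / (1.0 - p_expected)


print(round(cohens_kappa([1, 2, 3, 1], [1, 2, 3, 2]), 4))  # 0.6364
```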

### **34. `run_main_evaluation_task`**
*   **Inputs:** Evaluation pairs and calibrated agents.
*   **Processes:** Collects semantic equivalence ratings (1-5) for 90 pairs using a randomized protocol.
*   **Outputs:** Raw rating data.
*   **Data Transformation:** Transforms **comparative judgments** into **structured rating data**.
*   **Research Context:** Executes the main **Human Validation** study from **Section 3.5.2**, rating "90 original & compressed prompt pairs."

### **35. `analyze_human_vs_embedding_task`**
*   **Inputs:** Human ratings and embedding similarities.
*   **Processes:** Aggregates ratings and compares them against cosine similarity thresholds.
*   **Outputs:** A report on mismatches and correlation.
*   **Data Transformation:** Transforms **multi-modal evaluation data** into **conclusions**.
*   **Research Context:** Implements the analysis in **Section 3.5.2**, verifying the correlation between human scores and cosine similarities and identifying "rare nuance shifts missed by embedding alone."
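
Here is a minimal sketch of the comparison, assuming one aggregated human rating per pair. It computes the Pearson correlation via `np.corrcoef` and lists the indices where the embedding similarity looks acceptable but the human rating does not. Both thresholds and the helper name are hypothetical. `np.corrcoef` returns `nan` if either input is constant.

```python
from typing import List, Sequence, Tuple

import numpy as np


def compare_human_and_embedding(
    ratings: Sequence[float],
    similarities: Sequence[float],
    sim_threshold: float = 0.9,     # hypothetical
    rating_threshold: float = 4.0,  # hypothetical
) -> Tuple[float, List[int]]:
    """Hypothetical helper: Pearson r plus pairs where embeddings pass but humans disagree."""
    r = np.asarray(ratings, dtype=np.float64)
    s = np.asarray(similarities, dtype=np.float64)
    if r.shape != s.shape or r.size < 2:
        raise ValueError("need two aligned sequences with at least two pairs")
    correlation = float(np.corrcoef(r, s)[0, 1])
    mismatches = np.flatnonzero((s >= sim_threshold) & (r < rating_threshold)).tolist()
    return correlation, mismatches


corr, missed = compare_human_and_embedding([5, 2, 1, 4], [0.95, 0.93, 0.40, 0.91])
print(missed)  # [1]
```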

### **36. `run_compactprompt_pipeline`**
*   **Inputs:** Raw DataFrames, configuration, and experimental condition flags.
*   **Processes:** Orchestrates the entire 12-phase pipeline, invoking all 35 sub-tasks in the correct dependency order, managing state, logging, and error handling.
*   **Outputs:** An `OrchestratorOutput` object containing all artifacts (compressed data, metrics, logs).
*   **Data Transformation:** Transforms **raw datasets and configuration** into **final experimental results**.
*   **Research Context:** This is the **Unified Pipeline** itself, realizing the end-to-end system described in **Section 1 (Introduction)** and **Section 4 (CompactPrompt Tool)**. It binds the disparate compression techniques into a coherent workflow that "merges hard prompt compression with lightweight file-level data compression."

<br><br>

## **Usage Example**

The Python script below shows how to call the pipeline orchestrator, `run_compactprompt_pipeline`, on synthetic mock data:

```python
# ==============================================================================
# COMPACTPROMPT PIPELINE USAGE EXAMPLE USING MOCK DATA
# ==============================================================================
# This script demonstrates the end-to-end execution of the CompactPrompt pipeline
# using synthetically generated mock data and a configuration loaded from a YAML file.
# It serves as a high-fidelity, implementation-grade reference for deploying
# the compression and selection layer in enterprise LLM workflows.
# ==============================================================================
# Note: This script assumes that `run_compactprompt_pipeline` and its dependencies
# have already been imported or defined in the current session.

import yaml
import pandas as pd
import os
import json
import logging

# Configure logging for the example execution
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("CompactPromptExample")

# ------------------------------------------------------------------------------
# STEP 1: SYNTHETIC DATA GENERATION
# ------------------------------------------------------------------------------
# We generate professional-grade synthetic DataFrames for TAT-QA and Fin-QA
# that strictly adhere to the schema requirements defined in the task list.
# These DataFrames simulate the raw input data ingested by the pipeline.

logger.info("Generating synthetic TAT-QA and Fin-QA datasets...")

# 1.1 Synthetic TAT-QA DataFrame
tatqa_data = [
    {
        "example_id": "tatqa_syn_001",
        "split": "dev",
        "question_text": "What was the total revenue for Alpha Corp in 2023?",
        "tables": [
            {
                "table_id": "tatqa_syn_001_t1",
                "caption": "Consolidated Statements of Operations",
                "headers": ["Year", "Revenue (USD millions)", "Net Income (USD millions)"],
                "rows": [
                    ["2021", "1500", "200"],
                    ["2022", "1650", "220"],
                    ["2023", "1800", "250"]
                ]
            }
        ],
        "passages": [
            {
                "passage_id": "tatqa_syn_001_p1",
                "text": (
                    "Alpha Corp reported strong financial performance in 2023, driven by "
                    "robust demand in its cloud services segment. Total revenue reached "
                    "record highs, reflecting a 9% year-over-year increase."
                )
            }
        ],
        "answer_type": "span_number",
        "answer_value": "1800",
        "answer_unit": "USD millions"
    },
    {
        "example_id": "tatqa_syn_002",
        "split": "dev",
        "question_text": "By how much did net income increase from 2021 to 2022?",
        "tables": [
            {
                "table_id": "tatqa_syn_002_t1",
                "caption": "Financial Highlights",
                "headers": ["Metric", "2021", "2022"],
                "rows": [
                    ["Revenue", "1500", "1650"],
                    ["Net Income", "200", "220"]
                ]
            }
        ],
        "passages": [
            {
                "passage_id": "tatqa_syn_002_p1",
                "text": "Net income growth was supported by operational efficiency improvements."
            }
        ],
        "answer_type": "arithmetic",
        "answer_value": "20",
        "answer_unit": "USD millions"
    }
]
tatqa_raw_df = pd.DataFrame(tatqa_data)

# 1.2 Synthetic Fin-QA DataFrame
finqa_data = [
    {
        "example_id": "finqa_syn_001",
        "split": "dev",
        "question_text": "What was the percentage growth in EPS from 2022 to 2023?",
        "tables": [
            {
                "table_id": "finqa_syn_001_t1",
                "caption": "Earnings Per Share Data",
                "headers": ["Year", "Basic EPS", "Diluted EPS"],
                "rows": [
                    ["2022", "2.50", "2.48"],
                    ["2023", "2.75", "2.72"]
                ]
            }
        ],
        "passages": [
            {
                "passage_id": "finqa_syn_001_p1",
                "text": (
                    "The company's earnings per share (EPS) benefited from share repurchases "
                    "executed throughout the fiscal year 2023."
                )
            }
        ],
        "answer_type": "number",
        "answer_value": "10",
        "answer_unit": "percent"
    },
    {
        "example_id": "finqa_syn_002",
        "split": "dev",
        "question_text": "Calculate the operating margin for 2023.",
        "tables": [
            {
                "table_id": "finqa_syn_002_t1",
                "caption": "Income Statement",
                "headers": ["Item", "2023 Amount (USD millions)"],
                "rows": [
                    ["Revenue", "5000"],
                    ["Operating Income", "1250"]
                ]
            }
        ],
        "passages": [],
        "answer_type": "number",
        "answer_value": "25",
        "answer_unit": "percent"
    }
]
finqa_raw_df = pd.DataFrame(finqa_data)

logger.info(f"Created TAT-QA DataFrame with {len(tatqa_raw_df)} rows.")
logger.info(f"Created Fin-QA DataFrame with {len(finqa_raw_df)} rows.")


# ------------------------------------------------------------------------------
# STEP 2: CONFIGURATION LOADING
# ------------------------------------------------------------------------------
# We load the study configuration from the 'config.yaml' file. This file contains
# all hyperparameters for compression, model selection, and evaluation.
# Ensure 'config.yaml' exists in your working directory

config_path = "config.yaml"

with open(config_path, "r") as f:
    study_config = yaml.safe_load(f)

logger.info("Configuration loaded successfully.")

# ------------------------------------------------------------------------------
# STEP 3: OFFLINE CORPUS PREPARATION (MOCK SETUP)
# ------------------------------------------------------------------------------
# The pipeline requires an offline corpus for static self-information.
# In a real deployment, this would be a large JSONL file containing Wikipedia,
# ShareGPT, and arXiv dumps.
#
# Format of "offline_corpus.jsonl":
# Each line is a JSON object with keys:
#   - "doc_id": Unique string ID (e.g., "wiki_123")
#   - "source": Source identifier ("wikipedia", "sharegpt", "arxiv")
#   - "title": Document title (optional)
#   - "text": The raw text content of the document.
#
# For this example, we will create a dummy corpus file if it doesn't exist
# to allow the pipeline to run.

corpus_dir = "compactprompt_outputs"
os.makedirs(corpus_dir, exist_ok=True)  # no-op if the directory already exists

corpus_path = os.path.join(corpus_dir, "offline_corpus.jsonl")
if not os.path.exists(corpus_path):
    logger.info("Creating dummy offline corpus for demonstration...")
    dummy_corpus = [
        {"doc_id": "wiki_1", "source": "wikipedia", "title": "Finance", "text": "Finance is the study of money and currency."},
        {"doc_id": "wiki_2", "source": "wikipedia", "title": "Revenue", "text": "Revenue is the income that a business has from its normal business activities."},
        {"doc_id": "sharegpt_1", "source": "sharegpt", "title": "Chat 1", "text": "User: What is EPS? Assistant: Earnings Per Share."},
        {"doc_id": "arxiv_1", "source": "arxiv", "title": "Transformers", "text": "Attention is all you need."}
    ]
    with open(corpus_path, "w") as f:
        for doc in dummy_corpus:
            f.write(json.dumps(doc) + "\n")

# ------------------------------------------------------------------------------
# STEP 4: PIPELINE EXECUTION
# ------------------------------------------------------------------------------
# We now invoke the orchestrator function.
#
# CRITICAL NOTE ON MODEL NAMES:
# The user MUST verify that the model names specified in `target_llm` and `scorer_llm`
# match the exact identifiers required by their API providers (OpenAI, Anthropic, Together).
# The config file provides defaults, but API updates may change these strings.
# Ensure your environment variables (OPENAI_API_KEY, etc.) are set before running.

logger.info("Initializing CompactPrompt Pipeline...")

try:
    # We use the 'compressed_plus_data' condition to demonstrate the full suite
    # of compression techniques (hard pruning + n-gram abbreviation + quantization).
    # We skip human evaluation here for automated execution, but in a real study,
    # set skip_human_eval=False to run the LLM-proxy annotation protocol.
    
    orchestrator_output, orchestrator_log = run_compactprompt_pipeline(
        tatqa_raw_df=tatqa_raw_df,
        finqa_raw_df=finqa_raw_df,
        study_config=study_config,
        condition="compressed_plus_data",
        target_llm="gpt-4o",          # Ensure this matches your API access
        scorer_llm="gpt-4.1-mini",    # Efficient scorer
        ngram_params={"top_n_T": 3, "ngram_size_G": 2}, # Best config from paper
        quantization_params={"bit_width": 8},           # Uniform 8-bit quantization
        output_dir="compactprompt_outputs",
        skip_corpus_build=False,      # Build stats from our dummy corpus
        skip_human_eval=True,         # Skip human eval for this demo run
        random_seed=42
    )

    # --------------------------------------------------------------------------
    # STEP 5: RESULTS INSPECTION
    # --------------------------------------------------------------------------
    logger.info("Pipeline execution successful.")
    
    # Accessing Metrics
    metrics = orchestrator_output.metrics
    logger.info("=== Pipeline Metrics ===")
    logger.info(f"Mean Compression Ratio: {metrics['mean_compression_ratio']:.2f}x")
    logger.info(f"Mean Cosine Similarity: {metrics['mean_cosine_similarity']:.4f}")
    logger.info(f"5th Percentile Similarity: {metrics['percentile_5_similarity']:.4f}")
    
    # Accessing Compressed Data
    # Example: Inspecting the first compressed prompt
    first_id = list(orchestrator_output.pruned_prompts.keys())[0]
    compressed_data = orchestrator_output.pruned_prompts[first_id]
    
    logger.info(f"=== Example Compression ({first_id}) ===")
    logger.info(f"Original Tokens: {compressed_data['original_tokens']}")
    logger.info(f"Compressed Tokens: {compressed_data['compressed_tokens']}")
    logger.info(f"Ratio: {compressed_data['compression_ratio']:.2f}x")
    # logger.info(f"Compressed Text Snippet: {compressed_data['compressed_text'][:200]}...")

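except Exception:
    # logger.exception records the message together with the full traceback.
    # orchestrator_log is only bound if run_compactprompt_pipeline returned,
    # so it cannot be inspected here when the call itself raised.
    logger.exception("Pipeline execution failed.")
    raise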
```

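# Task 1 – Validate `tatqa_raw_df` Schema and Content Quality
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set

import pandas as pd

logger = logging.getLogger(__name__)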

@dataclass
class ValidationError:
    """
    Dataclass to represent a single validation error found in the dataset.

    Attributes:
        row_index (int): The index of the row in the DataFrame where the error occurred.
        example_id (Optional[str]): The example_id of the row, if available.
        column (str): The name of the column where the error was found.
        error_type (str): A categorical description of the error (e.g., 'missing_column', 'dtype_mismatch').
        message (str): A detailed description of the error.
    """
    row_index: int
    column: str
    error_type: str
    message: str
    example_id: Optional[str] = None

@dataclass
class ValidationReport:
    """
    Dataclass to aggregate validation results.

    Attributes:
        is_valid (bool): True if no critical errors were found, False otherwise.
        errors (List[ValidationError]): A list of all validation errors encountered.
        missing_columns (List[str]): List of required columns missing from the DataFrame.
        extra_columns (List[str]): List of unexpected columns present in the DataFrame.
    """
    is_valid: bool = True
    errors: List[ValidationError] = field(default_factory=list)
    missing_columns: List[str] = field(default_factory=list)
    extra_columns: List[str] = field(default_factory=list)

    def add_error(self, error: ValidationError):
        """Adds an error to the report and marks the dataset as invalid."""
        self.errors.append(error)
        self.is_valid = False

# ==============================================================================
# Task 1: Validate tatqa_raw_df Schema and Content Quality
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 1, Step 1: Validate column presence and types
# -------------------------------------------------------------------------------------------------------------------------------
def validate_tatqa_columns(df: pd.DataFrame) -> ValidationReport:
    """
    Validates that the TAT-QA DataFrame contains all required columns (extra columns
    are reported as warnings but do not invalidate it) and that their element-level
    data types are correct based on a random sample.

    This function performs two levels of validation:
    1. Schema Existence: Checks if all required columns are present.
    2. Type Integrity: Samples non-null values from each column to ensure they match
       the expected Python native type (e.g., 'example_id' must be a string,
       'tables' must be a list). This detects mixed-type columns that Pandas
       might label generically as 'object'.

    Args:
        df (pd.DataFrame): The raw TAT-QA input DataFrame.

    Returns:
        ValidationReport: A report containing missing/extra columns and type errors.
    """
    # Initialize the validation report object
    report = ValidationReport()

    # Define the mapping of required columns to their expected native Python types
    # Note: 'answer_unit' is optional (can be None), but if present, must be a string.
    # We handle None values by filtering them out before type checking.
    column_type_map: Dict[str, type] = {
        "example_id": str,
        "split": str,
        "question_text": str,
        "tables": list,      # Complex structure (list of dicts)
        "passages": list,    # Complex structure (list of dicts)
        "answer_type": str,
        "answer_value": str, # Raw answers are strings before parsing
        "answer_unit": str
    }

    # Extract the set of required column names
    required_columns: Set[str] = set(column_type_map.keys())

    # Get the actual columns present in the DataFrame
    actual_columns: Set[str] = set(df.columns)


    # Check for Missing Columns
    # Calculate set difference: required - actual
    missing = required_columns - actual_columns

    # If any required columns are missing, log a critical error and return immediately
    # as we cannot validate types for missing columns.
    if missing:
        report.missing_columns = list(missing)
        report.is_valid = False
        logger.error(f"Missing required columns in TAT-QA dataset: {missing}")
        return report


    # Check for Extra Columns
    # Calculate set difference: actual - required
    extra = actual_columns - required_columns

    # Log extra columns as warnings (does not invalidate the dataset strictly,
    # but indicates schema deviation).
    if extra:
        report.extra_columns = list(extra)
        logger.warning(f"Extra columns detected in TAT-QA dataset: {extra}")


    # Element-Level Type Validation
    # Iterate through each required column to verify data types
    for col_name, expected_type in column_type_map.items():
        # Extract the series for the current column
        series = df[col_name]

        # Drop null values (NaN, None) to check only actual data.
        # If a column allows nulls (like answer_unit), we only validate the non-null entries.
        non_null_series = series.dropna()

        # If the column is empty or contains only nulls, we cannot validate types.
        # This is technically valid (no invalid types exist).
        if non_null_series.empty:
            continue

        # Select a random sample of up to 5 elements to verify.
        # Sampling reduces overhead compared to checking every row, while still
        # catching systematic type issues (e.g., integers in a string column).
        sample_size = min(5, len(non_null_series))
        sample = non_null_series.sample(n=sample_size, random_state=42)

        # Iterate through the sample elements
        for idx, value in sample.items():
            # Check if the value matches the expected native type
            if not isinstance(value, expected_type):
                # Construct a detailed error message
                actual_type_name = type(value).__name__
                expected_type_name = expected_type.__name__

                # Create the validation error record
                error = ValidationError(
                    row_index=int(idx), # Cast to int for serializability
                    column=col_name,
                    error_type="dtype_mismatch",
                    message=(
                        f"Column '{col_name}' contains invalid type at index {idx}. "
                        f"Expected '{expected_type_name}', got '{actual_type_name}'."
                    ),
                    example_id=df.at[idx, "example_id"] if "example_id" in df.columns else None
                )

                # Add error to report and log it
                report.add_error(error)
                logger.error(f"Type validation failed for column '{col_name}': {error.message}")

                # Break after finding one type error in a column to avoid flooding logs
                break

    return report
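
# Sketch (hypothetical helper, not part of the original pipeline): the sampled check above
# inspects at most 5 non-null values per column. That keeps it cheap, but it can miss a column
# where only a few rows carry the wrong type. When the DataFrame is small enough, scanning every
# non-null value catches those sparse mismatches too. `full_scan_type_mismatches` is an
# illustrative name; it reuses the same {column: expected_type} mapping idea as column_type_map.
import pandas as pd
from typing import Dict, List


def full_scan_type_mismatches(df: pd.DataFrame, column_type_map: Dict[str, type]) -> Dict[str, List]:
    """Return, per column, the index labels of non-null values that are not of the expected type."""
    mismatches: Dict[str, List] = {}
    for col_name, expected_type in column_type_map.items():
        # Nulls are skipped, matching the dropna() in the sampled check
        non_null = df[col_name].dropna()
        # True where the element is NOT an instance of the expected native type
        bad_mask = ~non_null.map(lambda v: isinstance(v, expected_type)).astype(bool)
        bad_index = non_null[bad_mask].index.tolist()
        if bad_index:
            mismatches[col_name] = bad_index
    return mismatches

# Usage (illustrative): full_scan_type_mismatches(tatqa_raw_df, {"example_id": str, "tables": list})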

# -------------------------------------------------------------------------------------------------------------------------------
# Task 1, Step 2: Validate content integrity for identifying fields
# -------------------------------------------------------------------------------------------------------------------------------
def validate_tatqa_identifying_fields(df: pd.DataFrame, report: ValidationReport) -> ValidationReport:
    """
    Validates the integrity of 'example_id', 'split', and 'question_text'.
    Checks for null IDs, duplicate IDs within a split, disallowed split values,
    and null or empty question strings.

    Args:
        df (pd.DataFrame): The raw TAT-QA DataFrame.
        report (ValidationReport): The existing report to append errors to.

    Returns:
        ValidationReport: The updated validation report.
    """
    # 1. Validate 'example_id'
    # Check for null values
    null_ids = df[df["example_id"].isnull()]
    for idx in null_ids.index:
        report.add_error(ValidationError(
            row_index=idx,
            column="example_id",
            error_type="null_value",
            message="example_id is None or NaN"
        ))

    # Check for uniqueness within each split:
    # the (split, example_id) pair must be unique, so an ID may repeat only across splits
    if "split" in df.columns and "example_id" in df.columns:
        # Filter out nulls before checking duplicates to avoid noise
        valid_df = df.dropna(subset=["example_id", "split"])

        # Find duplicates: boolean mask where True indicates a duplicate
        duplicates = valid_df.duplicated(subset=["split", "example_id"], keep=False)

        if duplicates.any():
            dup_rows = valid_df[duplicates]
            for idx, row in dup_rows.iterrows():
                report.add_error(ValidationError(
                    row_index=idx,
                    example_id=row["example_id"],
                    column="example_id",
                    error_type="duplicate_id",
                    message=f"Duplicate example_id '{row['example_id']}' found in split '{row['split']}'"
                ))

    # 2. Validate 'split'
    # Allowed values set
    allowed_splits = {"train", "dev", "test"}

    if "split" in df.columns:
        # Identify rows with invalid split values (null splits are flagged as well,
        # because isin() returns False for them)
        invalid_split_mask = ~df["split"].isin(allowed_splits)
        invalid_split_rows = df[invalid_split_mask]

        for idx, row in invalid_split_rows.iterrows():
            report.add_error(ValidationError(
                row_index=idx,
                example_id=row.get("example_id"), # Use .get() in case it's null
                column="split",
                error_type="invalid_value",
                message=f"Invalid split value: '{row['split']}'. Expected one of {allowed_splits}"
            ))

    # 3. Validate 'question_text'
    if "question_text" in df.columns:
        # Check for nulls
        null_q = df[df["question_text"].isnull()]
        for idx in null_q.index:
            report.add_error(ValidationError(
                row_index=idx,
                example_id=df.at[idx, "example_id"] if "example_id" in df.columns else None,
                column="question_text",
                error_type="null_value",
                message="question_text is None"
            ))

        # Check for empty or whitespace-only strings
        # Ensure we only check strings (exclude the nulls we just found)
        non_null_q = df.dropna(subset=["question_text"])
        # Strip whitespace and check length
        empty_q_mask = non_null_q["question_text"].astype(str).str.strip().str.len() == 0
        empty_q_rows = non_null_q[empty_q_mask]

        for idx, row in empty_q_rows.iterrows():
            report.add_error(ValidationError(
                row_index=idx,
                example_id=row.get("example_id"),
                column="question_text",
                error_type="empty_string",
                message="question_text is empty or whitespace only"
            ))

    return report
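
# Sketch (hypothetical helper): `duplicated(subset=["split", "example_id"], keep=False)` treats
# the (split, example_id) pair as the key. An ID reused in a different split therefore passes,
# while every copy of an ID repeated inside one split is flagged, because keep=False marks all
# occurrences rather than only the later ones. `duplicate_ids_within_split` is an illustrative
# name that isolates that rule from the reporting code above.
import pandas as pd
from typing import List


def duplicate_ids_within_split(df: pd.DataFrame) -> List:
    """Return the index labels of rows whose example_id repeats inside the same split."""
    # Rows with a null ID or split are excluded, as in validate_tatqa_identifying_fields
    valid = df.dropna(subset=["example_id", "split"])
    mask = valid.duplicated(subset=["split", "example_id"], keep=False)
    return valid[mask].index.tolist()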

# -------------------------------------------------------------------------------------------------------------------------------
# Task 1, Step 3: Validate complex column structures
# -------------------------------------------------------------------------------------------------------------------------------
def validate_tatqa_complex_structures(df: pd.DataFrame, report: ValidationReport) -> ValidationReport:
    """
    Validates the internal structure of 'tables', 'passages', and answer fields.
    Enforces schema constraints on nested lists and dictionaries.

    Args:
        df (pd.DataFrame): The raw TAT-QA DataFrame.
        report (ValidationReport): The existing report to append errors to.

    Returns:
        ValidationReport: The updated validation report.
    """

    # Helper function to validate a single table dictionary
    def validate_single_table(tbl: Any, row_idx: int, ex_id: Optional[str]) -> List[ValidationError]:
        errors = []
        if not isinstance(tbl, dict):
            errors.append(ValidationError(row_idx, "tables", "invalid_type", f"Table entry is not a dict: {type(tbl)}", ex_id))
            return errors

        # Check required keys
        required_keys = {"table_id", "caption", "headers", "rows"}
        if not required_keys.issubset(tbl.keys()):
            missing = required_keys - tbl.keys()
            errors.append(ValidationError(row_idx, "tables", "missing_keys", f"Table missing keys: {missing}", ex_id))
            return errors

        # Validate headers
        headers = tbl.get("headers")
        if not isinstance(headers, list):
            errors.append(ValidationError(row_idx, "tables", "invalid_headers", "Headers is not a list", ex_id))
            return errors

        # Validate rows
        rows = tbl.get("rows")
        if not isinstance(rows, list):
            errors.append(ValidationError(row_idx, "tables", "invalid_rows", "Rows is not a list", ex_id))
            return errors

        # Validate row consistency
        header_len = len(headers)
        for i, r in enumerate(rows):
            if not isinstance(r, list):
                errors.append(ValidationError(row_idx, "tables", "invalid_row_type", f"Row {i} is not a list", ex_id))
                continue
            if len(r) != header_len:
                errors.append(ValidationError(row_idx, "tables", "row_length_mismatch",
                                              f"Row {i} length {len(r)} != header length {header_len}", ex_id))
        return errors

    # Helper function to validate a single passage dictionary
    def validate_single_passage(psg: Any, row_idx: int, ex_id: Optional[str]) -> List[ValidationError]:
        errors = []
        if not isinstance(psg, dict):
            errors.append(ValidationError(row_idx, "passages", "invalid_type", f"Passage entry is not a dict: {type(psg)}", ex_id))
            return errors

        if "passage_id" not in psg or "text" not in psg:
            errors.append(ValidationError(row_idx, "passages", "missing_keys", "Passage missing 'passage_id' or 'text'", ex_id))
            return errors

        text = psg.get("text")
        if not isinstance(text, str) or not text.strip():
            errors.append(ValidationError(row_idx, "passages", "empty_text", "Passage text is empty or not a string", ex_id))

        return errors

    # Iterate over rows to validate complex structures
    # We use iterrows() here because we need to inspect deep structures which is hard to vectorize
    for idx, row in df.iterrows():
        ex_id = row.get("example_id")

        # 1. Validate 'tables'
        tables = row.get("tables")
        if not isinstance(tables, list):
            report.add_error(ValidationError(idx, "tables", "invalid_type", f"Expected list of dicts, got {type(tables)}", ex_id))
        else:
            for tbl in tables:
                tbl_errors = validate_single_table(tbl, idx, ex_id)
                for err in tbl_errors:
                    report.add_error(err)

        # 2. Validate 'passages'
        passages = row.get("passages")
        if not isinstance(passages, list):
            report.add_error(ValidationError(idx, "passages", "invalid_type", f"Expected list of dicts, got {type(passages)}", ex_id))
        else:
            for psg in passages:
                psg_errors = validate_single_passage(psg, idx, ex_id)
                for err in psg_errors:
                    report.add_error(err)

        # 3. Validate Answer Fields
        # answer_type
        ans_type = row.get("answer_type")
        if not isinstance(ans_type, str):
            report.add_error(ValidationError(idx, "answer_type", "invalid_type", f"Expected string, got {type(ans_type)}", ex_id))

        # answer_value
        ans_val = row.get("answer_value")
        if not isinstance(ans_val, str) or not ans_val.strip():
            report.add_error(ValidationError(idx, "answer_value", "empty_or_invalid", "answer_value is empty or not a string", ex_id))

        # answer_unit (can be None or string). Pandas may store a missing unit as a float NaN,
        # which is treated as null here, consistent with the dropna() in the column-level check.
        ans_unit = row.get("answer_unit")
        is_missing_unit = ans_unit is None or (isinstance(ans_unit, float) and pd.isna(ans_unit))
        if not is_missing_unit and not isinstance(ans_unit, str):
            report.add_error(ValidationError(idx, "answer_unit", "invalid_type", f"Expected string or None, got {type(ans_unit)}", ex_id))

    return report
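
# Sketch (hypothetical helper): a condensed, DataFrame-free restatement of the per-table rules in
# validate_single_table, returning only the error codes. It assumes, as those checks do, that a
# table is a dict with "table_id", "caption", "headers" and "rows", where every row is a list
# as long as the header list. Keeping the rules in a pure function like this makes them easy to
# unit-test; `table_error_codes` is an illustrative name, not part of the pipeline.
from typing import Any, List


def table_error_codes(tbl: Any) -> List[str]:
    """Return the validation error codes for one table object (empty list if it is well-formed)."""
    if not isinstance(tbl, dict):
        return ["invalid_type"]
    if not {"table_id", "caption", "headers", "rows"}.issubset(tbl.keys()):
        return ["missing_keys"]
    headers, rows = tbl["headers"], tbl["rows"]
    if not isinstance(headers, list):
        return ["invalid_headers"]
    if not isinstance(rows, list):
        return ["invalid_rows"]
    codes: List[str] = []
    for r in rows:
        # Row-level problems are collected rather than returned early, as in the original
        if not isinstance(r, list):
            codes.append("invalid_row_type")
        elif len(r) != len(headers):
            codes.append("row_length_mismatch")
    return codes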

# -------------------------------------------------------------------------------------------------------------------------------
# Task 1, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def validate_tatqa_dataset(tatqa_raw_df: pd.DataFrame) -> ValidationReport:
    """
    Orchestrates the full validation pipeline for the TAT-QA raw dataset.

    This function executes three granular validation steps:
    1. Column presence and high-level type check.
    2. Content integrity check for identifying fields (IDs, splits, questions).
    3. Structural validation for complex nested fields (tables, passages, answers).

    Args:
        tatqa_raw_df (pd.DataFrame): The raw input DataFrame for TAT-QA.

    Returns:
        ValidationReport: A comprehensive report detailing validity status and any errors found.
    """
    logger.info("Starting TAT-QA dataset validation...")

    # Step 1: Column Schema Validation
    report = validate_tatqa_columns(tatqa_raw_df)

    # If critical columns are missing, we stop early to avoid KeyErrors in subsequent steps
    if report.missing_columns:
        logger.error("Critical columns missing. Aborting further validation.")
        return report

    # Step 2: Identifying Fields Validation
    report = validate_tatqa_identifying_fields(tatqa_raw_df, report)

    # Step 3: Complex Structure Validation
    report = validate_tatqa_complex_structures(tatqa_raw_df, report)

    if report.is_valid:
        logger.info("TAT-QA dataset validation passed successfully.")
    else:
        logger.warning(f"TAT-QA dataset validation failed with {len(report.errors)} errors.")

    return report
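
# Sketch (hypothetical names, not part of the original pipeline): the orchestrator above follows
# a "fail fast on schema, then accumulate" pattern. Missing columns abort early because later
# steps index those columns directly; every later step only appends to the shared report.
# `MiniReport` is an assumed stand-in for ValidationReport whose add_error() marks the report
# invalid; the real class is defined elsewhere and may differ.
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class MiniReport:
    is_valid: bool = True
    missing_columns: List[str] = field(default_factory=list)
    errors: List[str] = field(default_factory=list)

    def add_error(self, message: str) -> None:
        # Any recorded error invalidates the report (assumed ValidationReport behaviour)
        self.errors.append(message)
        self.is_valid = False


def run_validation_steps(
    schema_step: Callable[[], MiniReport],
    later_steps: List[Callable[[MiniReport], MiniReport]],
) -> MiniReport:
    report = schema_step()
    # Fail fast: later steps would raise KeyError on absent columns
    if report.missing_columns:
        return report
    # Accumulate: each step receives and returns the same report object
    for step in later_steps:
        report = step(report)
    return report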



# ==============================================================================
# Task 2: Validate finqa_raw_df Schema and Content Quality
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 2, Step 1: Validate column presence and types
# -------------------------------------------------------------------------------------------------------------------------------
def validate_finqa_columns(df: pd.DataFrame) -> ValidationReport:
    """
    Validates that the Fin-QA DataFrame contains all required columns (extra columns only
    trigger a warning) and that
    their element-level data types are correct based on a random sample.

    This function performs two levels of validation:
    1. Schema Existence: Checks if all required columns are present.
    2. Type Integrity: Samples non-null values from each column to ensure they match
       the expected Python native type. This detects mixed-type columns that Pandas
       might label generically as 'object'.

    Args:
        df (pd.DataFrame): The raw Fin-QA input DataFrame.

    Returns:
        ValidationReport: A report containing missing/extra columns and type errors.
    """
    # Initialize the validation report object
    report = ValidationReport()

    # Define the mapping of required columns to their expected native Python types
    # The normalized Fin-QA schema mirrors the TAT-QA schema column for column.
    column_type_map: Dict[str, type] = {
        "example_id": str,
        "split": str,
        "question_text": str,
        "tables": list,      # Complex structure (list of dicts)
        "passages": list,    # Complex structure (list of dicts)
        "answer_type": str,
        "answer_value": str, # Raw answers are strings before parsing
        "answer_unit": str
    }

    # Extract the set of required column names
    required_columns: Set[str] = set(column_type_map.keys())

    # Get the actual columns present in the DataFrame
    actual_columns: Set[str] = set(df.columns)


    # Step 1: Check for Missing Columns
    # Calculate set difference: required - actual
    missing = required_columns - actual_columns

    # If any required columns are missing, log a critical error and return immediately
    if missing:
        report.missing_columns = list(missing)
        report.is_valid = False
        logger.error(f"Missing required columns in Fin-QA dataset: {missing}")
        return report


    # Step 2: Check for Extra Columns
    # Calculate set difference: actual - required
    extra = actual_columns - required_columns

    # Log extra columns as warnings. Fin-QA often has 'program' or 'table_ori'.
    if extra:
        report.extra_columns = list(extra)
        logger.warning(f"Extra columns detected in Fin-QA dataset: {extra}")


    # Step 3: Element-Level Type Validation
    # Iterate through each required column to verify data types
    for col_name, expected_type in column_type_map.items():
        # Extract the series for the current column
        series = df[col_name]

        # Drop null values (NaN, None) to check only actual data.
        non_null_series = series.dropna()

        # If the column is empty or contains only nulls, skip validation.
        if non_null_series.empty:
            continue

        # Select a random sample of up to 5 elements to verify.
        sample_size = min(5, len(non_null_series))
        sample = non_null_series.sample(n=sample_size, random_state=42)

        # Iterate through the sample elements
        for idx, value in sample.items():
            # Check if the value matches the expected native type
            if not isinstance(value, expected_type):
                # Construct a detailed error message
                actual_type_name = type(value).__name__
                expected_type_name = expected_type.__name__

                # Create the validation error record
                error = ValidationError(
                    row_index=int(idx),
                    column=col_name,
                    error_type="dtype_mismatch",
                    message=(
                        f"Column '{col_name}' contains invalid type at index {idx}. "
                        f"Expected '{expected_type_name}', got '{actual_type_name}'."
                    ),
                    example_id=df.at[idx, "example_id"] if "example_id" in df.columns else None
                )

                # Add error to report and log it
                report.add_error(error)
                logger.error(f"Type validation failed for column '{col_name}': {error.message}")

                # Break after finding one type error in a column
                break

    return report
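
# Sketch (hypothetical helper): validate_tatqa_columns and validate_finqa_columns compute the
# same two set differences and differ only in their log messages, so a shared helper could serve
# both. Sorting the result is a deliberate choice: list(set_difference) has no guaranteed order
# (string hashing is randomized per process), which makes reports harder to diff between runs.
# `missing_and_extra_columns` is an illustrative name, not part of the pipeline.
from typing import Iterable, List, Tuple


def missing_and_extra_columns(actual: Iterable[str], required: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, extra) column names, each sorted for deterministic reporting."""
    actual_set, required_set = set(actual), set(required)
    return sorted(required_set - actual_set), sorted(actual_set - required_set)

# Usage (illustrative): missing, extra = missing_and_extra_columns(df.columns, column_type_map.keys())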

# -------------------------------------------------------------------------------------------------------------------------------
# Task 2, Step 2: Validate content integrity for identifying fields and answer types
# -------------------------------------------------------------------------------------------------------------------------------
def validate_finqa_content_integrity(df: pd.DataFrame, report: ValidationReport) -> ValidationReport:
    """
    Validates 'example_id', 'split', 'answer_type', and 'answer_value' integrity.
    Checks for uniqueness, valid splits, numeric answer types, and parseable answer values.

    Args:
        df (pd.DataFrame): The raw Fin-QA DataFrame.
        report (ValidationReport): The existing report to append errors to.

    Returns:
        ValidationReport: The updated validation report.
    """
    # 1. Validate 'example_id'
    null_ids = df[df["example_id"].isnull()]
    for idx in null_ids.index:
        report.add_error(ValidationError(idx, "example_id", "null_value", "example_id is None or NaN"))

    if "split" in df.columns and "example_id" in df.columns:
        valid_df = df.dropna(subset=["example_id", "split"])
        duplicates = valid_df.duplicated(subset=["split", "example_id"], keep=False)
        if duplicates.any():
            dup_rows = valid_df[duplicates]
            for idx, row in dup_rows.iterrows():
                report.add_error(ValidationError(
                    row_index=idx,
                    example_id=row["example_id"],
                    column="example_id",
                    error_type="duplicate_id",
                    message=f"Duplicate example_id '{row['example_id']}' found in split '{row['split']}'"
                ))

    # 2. Validate 'split'
    allowed_splits = {"train", "dev", "test"}
    if "split" in df.columns:
        invalid_split_mask = ~df["split"].isin(allowed_splits)
        for idx, row in df[invalid_split_mask].iterrows():
            report.add_error(ValidationError(
                row_index=idx,
                example_id=row.get("example_id"),
                column="split",
                error_type="invalid_value",
                message=f"Invalid split value: '{row['split']}'"
            ))

    # 3. Validate 'answer_type'
    # Fin-QA answers are predominantly numeric ('number'), but text answers may occur, so this
    # step only enforces that answer_type is a string; numeric parseability is checked below
    # for rows labelled 'number'.
    if "answer_type" in df.columns:
        invalid_type_mask = df["answer_type"].apply(lambda x: not isinstance(x, str))
        for idx, row in df[invalid_type_mask].iterrows():
            report.add_error(ValidationError(
                row_index=idx,
                example_id=row.get("example_id"),
                column="answer_type",
                error_type="invalid_type",
                message=f"answer_type is not a string: {type(row['answer_type'])}"
            ))

    # 4. Validate 'answer_value' parseability
    # We attempt to parse as float to flag potential data quality issues early.
    if "answer_value" in df.columns:
        def is_parseable(val: Any) -> bool:
            if not isinstance(val, str):
                return False
            # Strip thousands separators and percent signs before parsing
            clean_val = val.replace(',', '').replace('%', '').strip()
            try:
                float(clean_val)
                return True
            except ValueError:
                return False

        # We only check rows where answer_type is 'number' to avoid false positives on text answers
        number_rows = df[df["answer_type"] == "number"]
        unparseable_mask = ~number_rows["answer_value"].apply(is_parseable)

        for idx, row in number_rows[unparseable_mask].iterrows():
            # Recorded via add_error rather than raised, so validation continues; Task 5
            # cleansing is expected to normalize these values downstream.
            report.add_error(ValidationError(
                row_index=idx,
                example_id=row.get("example_id"),
                column="answer_value",
                error_type="unparseable_number",
                message=f"answer_value '{row['answer_value']}' could not be parsed as float"
            ))

    return report
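
# Sketch (hypothetical helper): the same cleaning rule as is_parseable above, but returning the
# parsed value instead of a boolean so the rule can be inspected directly. Two consequences of
# the rule are worth knowing. Because float() also accepts strings such as "nan" and "inf",
# those count as parseable. And because "%" is simply removed, "12%" yields 12.0, not 0.12,
# leaving any scaling decision to later cleansing. `parse_numeric_answer` is an illustrative name.
from typing import Any, Optional


def parse_numeric_answer(val: Any) -> Optional[float]:
    """Return the float value of a numeric answer string, or None if it cannot be parsed."""
    if not isinstance(val, str):
        return None
    # Remove thousands separators and percent signs, then surrounding whitespace
    clean = val.replace(",", "").replace("%", "").strip()
    try:
        return float(clean)
    except ValueError:
        return None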

# -------------------------------------------------------------------------------------------------------------------------------
# Task 2, Step 3: Validate complex column structures
# -------------------------------------------------------------------------------------------------------------------------------
def validate_finqa_complex_structures(df: pd.DataFrame, report: ValidationReport) -> ValidationReport:
    """
    Validates the internal structure of 'tables' and 'passages' columns for the Fin-QA dataset.

    This function iterates through each row of the DataFrame to enforce strict schema constraints
    on nested list-of-dictionary structures. It ensures that 'tables' and 'passages' contain
    the required keys and correct data types, logging any deviations as validation errors.

    Args:
        df (pd.DataFrame): The raw Fin-QA input DataFrame containing 'tables' and 'passages' columns.
        report (ValidationReport): An existing ValidationReport object to which new errors will be appended.

    Returns:
        ValidationReport: The updated validation report containing any structural errors found.
    """

    def validate_single_table(tbl: Any, row_idx: int, ex_id: Optional[str]) -> List[ValidationError]:
        """
        Validates the structure of a single table dictionary.

        Args:
            tbl (Any): The table object to validate (expected to be a dict).
            row_idx (int): The index of the row in the DataFrame.
            ex_id (Optional[str]): The example_id associated with the row.

        Returns:
            List[ValidationError]: A list of validation errors found within the table structure.
        """
        # Initialize an empty list to collect errors for this specific table
        errors: List[ValidationError] = []

        # Check if the table entry is a dictionary
        if not isinstance(tbl, dict):
            # Log error if type is incorrect
            errors.append(ValidationError(
                row_index=row_idx,
                column="tables",
                error_type="invalid_type",
                message=f"Table entry is not a dict: {type(tbl)}",
                example_id=ex_id
            ))
            return errors

        # Define the set of required keys for a table object
        required_keys: Set[str] = {"table_id", "caption", "headers", "rows"}

        # Check if all required keys are present in the table dictionary
        if not required_keys.issubset(tbl.keys()):
            # Identify missing keys
            missing = required_keys - tbl.keys()
            # Log error for missing keys
            errors.append(ValidationError(
                row_index=row_idx,
                column="tables",
                error_type="missing_keys",
                message=f"Table missing keys: {missing}",
                example_id=ex_id
            ))
            return errors

        # Retrieve and validate the 'headers' field
        headers = tbl.get("headers")
        if not isinstance(headers, list):
            # Log error if headers is not a list
            errors.append(ValidationError(
                row_index=row_idx,
                column="tables",
                error_type="invalid_headers",
                message="Headers is not a list",
                example_id=ex_id
            ))
            return errors

        # Retrieve and validate the 'rows' field
        rows = tbl.get("rows")
        if not isinstance(rows, list):
            # Log error if rows is not a list
            errors.append(ValidationError(
                row_index=row_idx,
                column="tables",
                error_type="invalid_rows",
                message="Rows is not a list",
                example_id=ex_id
            ))
            return errors

        # Determine the expected length of each row based on the headers
        header_len = len(headers)

        # Iterate through each row to validate its structure and length
        for i, r in enumerate(rows):
            # Check if the row is a list
            if not isinstance(r, list):
                # Log error if a row is not a list
                errors.append(ValidationError(
                    row_index=row_idx,
                    column="tables",
                    error_type="invalid_row_type",
                    message=f"Row {i} is not a list",
                    example_id=ex_id
                ))
                continue

            # Check if the row length matches the header length
            if len(r) != header_len:
                # Log error for row length mismatch
                errors.append(ValidationError(
                    row_index=row_idx,
                    column="tables",
                    error_type="row_length_mismatch",
                    message=f"Row {i} length {len(r)} != header length {header_len}",
                    example_id=ex_id
                ))

        return errors

    def validate_single_passage(psg: Any, row_idx: int, ex_id: Optional[str]) -> List[ValidationError]:
        """
        Validates the structure of a single passage dictionary.

        Args:
            psg (Any): The passage object to validate (expected to be a dict).
            row_idx (int): The index of the row in the DataFrame.
            ex_id (Optional[str]): The example_id associated with the row.

        Returns:
            List[ValidationError]: A list of validation errors found within the passage structure.
        """
        # Initialize an empty list to collect errors for this specific passage
        errors: List[ValidationError] = []

        # Check if the passage entry is a dictionary
        if not isinstance(psg, dict):
            # Log error if type is incorrect
            errors.append(ValidationError(
                row_index=row_idx,
                column="passages",
                error_type="invalid_type",
                message=f"Passage entry is not a dict: {type(psg)}",
                example_id=ex_id
            ))
            return errors

        # Check for required keys 'passage_id' and 'text'
        if "passage_id" not in psg or "text" not in psg:
            # Log error for missing keys
            errors.append(ValidationError(
                row_index=row_idx,
                column="passages",
                error_type="missing_keys",
                message="Passage missing 'passage_id' or 'text'",
                example_id=ex_id
            ))
            return errors

        # Retrieve and validate the 'text' field
        text = psg.get("text")
        # Ensure text is a string and is not empty or whitespace-only
        if not isinstance(text, str) or not text.strip():
            # Log error for invalid text content
            errors.append(ValidationError(
                row_index=row_idx,
                column="passages",
                error_type="empty_text",
                message="Passage text is empty or not a string",
                example_id=ex_id
            ))

        return errors

    # Iterate over each row in the DataFrame to validate complex structures
    for idx, row in df.iterrows():
        # Retrieve the example_id for logging purposes
        ex_id = row.get("example_id")

        # 1. Validate 'tables' column
        tables = row.get("tables")
        # Check if 'tables' is a list
        if not isinstance(tables, list):
            # Log error if 'tables' is not a list
            report.add_error(ValidationError(
                row_index=idx,
                column="tables",
                error_type="invalid_type",
                message=f"Expected list of dicts, got {type(tables)}",
                example_id=ex_id
            ))
        else:
            # Iterate through each table in the list and validate it
            for tbl in tables:
                tbl_errors = validate_single_table(tbl, idx, ex_id)
                # Add any found errors to the main report
                for err in tbl_errors:
                    report.add_error(err)

        # 2. Validate 'passages' column
        passages = row.get("passages")
        # Check if 'passages' is a list
        if not isinstance(passages, list):
            # Log error if 'passages' is not a list
            report.add_error(ValidationError(
                row_index=idx,
                column="passages",
                error_type="invalid_type",
                message=f"Expected list of dicts, got {type(passages)}",
                example_id=ex_id
            ))
        else:
            # Iterate through each passage in the list and validate it
            for psg in passages:
                psg_errors = validate_single_passage(psg, idx, ex_id)
                # Add any found errors to the main report
                for err in psg_errors:
                    report.add_error(err)

    return report
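
# Sketch (hypothetical helper): a pure-function restatement of validate_single_passage that
# returns error codes only. It assumes, as the checks above do, that a passage is a dict with
# "passage_id" and a non-blank string "text". Note that a None text and a whitespace-only text
# both map to "empty_text", because the original folds the type check and the blank check
# into one condition. `passage_error_codes` is an illustrative name.
from typing import Any, List


def passage_error_codes(psg: Any) -> List[str]:
    """Return the validation error codes for one passage object (empty list if well-formed)."""
    if not isinstance(psg, dict):
        return ["invalid_type"]
    if "passage_id" not in psg or "text" not in psg:
        return ["missing_keys"]
    text = psg["text"]
    # Non-string or blank text is treated as a single "empty_text" failure
    if not isinstance(text, str) or not text.strip():
        return ["empty_text"]
    return []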

# -------------------------------------------------------------------------------------------------------------------------------
# Task 2, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def validate_finqa_dataset(finqa_raw_df: pd.DataFrame) -> ValidationReport:
    """
    Orchestrates the full validation pipeline for the Fin-QA raw dataset.

    Executes granular validation steps:
    1. Column presence and high-level type check.
    2. Content integrity check for identifying fields and numeric answer validity.
    3. Structural validation for complex nested fields (tables, passages).

    Args:
        finqa_raw_df (pd.DataFrame): The raw input DataFrame for Fin-QA.

    Returns:
        ValidationReport: A comprehensive report detailing validity status and any errors found.
    """
    logger.info("Starting Fin-QA dataset validation...")

    # Step 1: Column Schema Validation
    report = validate_finqa_columns(finqa_raw_df)

    if report.missing_columns:
        logger.error("Critical columns missing in Fin-QA. Aborting further validation.")
        return report

    # Step 2: Identifying Fields and Content Validation
    report = validate_finqa_content_integrity(finqa_raw_df, report)

    # Step 3: Complex Structure Validation
    report = validate_finqa_complex_structures(finqa_raw_df, report)

    if report.is_valid:
        logger.info("Fin-QA dataset validation passed successfully.")
    else:
        logger.warning(f"Fin-QA dataset validation failed with {len(report.errors)} errors.")

    return report


# Task 3 – Validate study_config Completeness and Consistency

# ==============================================================================
# Task 3: Validate study_config Completeness and Consistency
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 3, Step 1: Inventory all <REQUIRED_BY_IMPLEMENTER> placeholders
# -------------------------------------------------------------------------------------------------------------------------------
def inventory_placeholders(config: Dict[str, Any], path: str = "") -> List[str]:
    """
    Recursively traverses the configuration dictionary to identify all keys
    that still hold the placeholder value "<REQUIRED_BY_IMPLEMENTER>".

    This function handles nested dictionaries and lists, constructing a
    dot-separated path for each placeholder found. List indices are
    represented as `[index]`.

    Args:
        config (Dict[str, Any]): The configuration dictionary to inspect.
        path (str): The current path to the config node (used for recursion).
                    Defaults to empty string.

    Returns:
        List[str]: A list of paths pointing to unresolved placeholders.
                   Example: ["offline_corpus_config.tokenization_scheme.name"]
    """
    placeholders = []

    for key, value in config.items():
        # Construct the current path
        current_path = f"{path}.{key}" if path else key

        if isinstance(value, dict):
            # Recurse into nested dictionaries
            placeholders.extend(inventory_placeholders(value, current_path))
        elif isinstance(value, list):
            # Iterate through lists (e.g., list of model configs)
            for i, item in enumerate(value):
                list_item_path = f"{current_path}[{i}]"
                if isinstance(item, dict):
                    placeholders.extend(inventory_placeholders(item, list_item_path))
                elif item == "<REQUIRED_BY_IMPLEMENTER>":
                    placeholders.append(list_item_path)
        elif value == "<REQUIRED_BY_IMPLEMENTER>":
            # Found a placeholder
            placeholders.append(current_path)

    return placeholders
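
# Sketch (hypothetical helper, not part of the original pipeline): read the value at a
# path in the format that inventory_placeholders reports, e.g.
# "llm_config.scorer_llm_options[0].model_name". This is useful for printing what each
# reported placeholder currently holds. It assumes dict keys contain no '.', '[' or ']'.
import re
from typing import Any, Dict

_PATH_SEGMENT = re.compile(r"^([^\[\]]+)((?:\[\d+\])*)$")

def get_value_by_path(cfg: Dict[str, Any], path: str) -> Any:
    """Return the value at a dot/bracket path; raises KeyError/IndexError if absent."""
    current: Any = cfg
    for segment in path.split("."):
        match = _PATH_SEGMENT.match(segment)
        if match is None:
            raise ValueError(f"Malformed path segment: {segment!r}")
        key, indices = match.groups()
        current = current[key]
        # Apply each list index in turn, e.g. "key[0][2]" -> indices 0 then 2
        for idx in re.findall(r"\[(\d+)\]", indices):
            current = current[int(idx)]
    return current

# Usage (assuming study_config is loaded):
#   for p in inventory_placeholders(study_config):
#       print(p, "=", get_value_by_path(study_config, p))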

# -------------------------------------------------------------------------------------------------------------------------------
# Task 3, Step 2: Assign concrete values to required parameters
# -------------------------------------------------------------------------------------------------------------------------------
def resolve_placeholders(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Assigns concrete, scientifically justifiable default values to the
    required parameters that the study configuration leaves as placeholders.

    The configuration dictionary is mutated in place and also returned. A
    path-based setter handles nested keys and list indices. Note that every
    path in the defaults table is written unconditionally: an existing value at
    that path is overwritten even if it is not a placeholder, and a missing leaf
    key is created as long as its parent exists.

    Args:
        config (Dict[str, Any]): The configuration dictionary containing placeholders.

    Returns:
        Dict[str, Any]: The fully resolved configuration dictionary.
    """

    # Define the mapping of config paths to concrete values based on the paper's context
    defaults = {
        "offline_corpus_config.tokenization_scheme.name": "cl100k_base",
        "hard_prompt_compression_config.phrase_grouping.parser_library": "spacy_en_core_web_trf",
        "hard_prompt_compression_config.phrase_score_aggregation.chosen_method": "mean",
        "hard_prompt_compression_config.prompt_token_budget": 1500,
        "ngram_abbreviation_config.ngram_length.default": 2,
        "ngram_abbreviation_config.dictionary_size_K.default": 100,
        "numeric_quantization_config.uniform_integer.bit_width_b": 8,
        "numeric_quantization_config.kmeans_based.num_clusters_k": 16,
        "preprocessing_config.table_serialization.method": "markdown",
        "llm_config.decoding_parameters.temperature": 0.0,
        "llm_config.decoding_parameters.top_p": 1.0,
        "llm_config.decoding_parameters.max_tokens": 256
    }

    def set_value_by_path(cfg: Dict[str, Any], path: str, value: Any) -> None:
        """
        Helper to set a value in a nested dict using a dot-separated path.
        Handles list indices like 'key[0]'.

        Args:
            cfg (Dict[str, Any]): The dictionary to modify.
            path (str): The dot-separated path to the key.
            value (Any): The value to set.

        Raises:
            KeyError: If a path segment cannot be resolved.
            IndexError: If a list index is out of bounds.
            TypeError: If an intermediate node cannot be indexed by the given key.
            ValueError: If a path segment is invalid.
        """
        keys = path.split('.')
        current = cfg

        # Traverse to the parent of the target key
        for key in keys[:-1]:
            if '[' in key and ']' in key:
                # Handle list index: "key[index]"
                k, idx_str = key[:-1].split('[')
                idx = int(idx_str)
                current = current[k][idx]
            else:
                current = current[key]

        # Set the value at the target key
        last_key = keys[-1]
        if '[' in last_key and ']' in last_key:
            k, idx_str = last_key[:-1].split('[')
            idx = int(idx_str)
            current[k][idx] = value
        else:
            current[last_key] = value

    # Apply defaults (any existing value at each path is overwritten)
    for path, value in defaults.items():
        try:
            set_value_by_path(config, path, value)
            logger.info(f"Set '{path}' to '{value}'")
        except (KeyError, IndexError, TypeError, ValueError) as e:
            logger.warning(f"Could not resolve path '{path}': {e}")

    return config
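
# Sketch (hypothetical alternative, not the pipeline's behavior): resolve_placeholders
# writes every default path, so user-supplied values at those paths are replaced. If
# user values should win, a variant can fill a path only when it still holds the
# placeholder. This simplified version assumes plain dot-separated dict keys (no list
# indices), which matches every path in the defaults table above.
from typing import Any, Dict, List

REQUIRED_PLACEHOLDER = "<REQUIRED_BY_IMPLEMENTER>"

def apply_defaults_only_to_placeholders(config: Dict[str, Any],
                                        defaults: Dict[str, Any]) -> List[str]:
    """Fill placeholder leaves from `defaults` in place; return the paths that were filled."""
    filled = []
    for path, value in defaults.items():
        *parents, leaf = path.split(".")
        node: Any = config
        # Walk to the parent dict; stop quietly if any segment is missing
        for key in parents:
            if not isinstance(node, dict) or key not in node:
                node = None
                break
            node = node[key]
        # Only replace an existing leaf that still holds the placeholder string
        if isinstance(node, dict) and node.get(leaf) == REQUIRED_PLACEHOLDER:
            node[leaf] = value
            filled.append(path)
    return filled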

# -------------------------------------------------------------------------------------------------------------------------------
# Task 3, Step 3: Validate LLM access and API credentials
# -------------------------------------------------------------------------------------------------------------------------------
def validate_llm_config(config: Dict[str, Any]) -> bool:
    """
    Validates that LLM configurations are structurally sound and that
    API credentials are not left as default placeholders.

    This function checks that:
    1. API credentials are not the default placeholder string (placeholders produce a warning, not a failure).
    2. At least one scorer LLM is defined, and each has a non-empty 'model_name'.
    3. At least one target LLM is defined, and each has a non-empty 'model_name'.

    Args:
        config (Dict[str, Any]): The full study configuration.

    Returns:
        bool: True if LLM config is valid, False otherwise.
    """
    llm_config = config.get("llm_config", {})
    credentials = llm_config.get("api_credentials", {})

    is_valid = True

    # Check credentials
    for key, value in credentials.items():
        if value == "<EXTERNAL_SECRET_NOT_IN_PAPER>":
            logger.warning(f"Credential '{key}' is still set to placeholder. "
                           "Ensure environment variables or secrets are loaded in production.")
            # Placeholder credentials only trigger a warning, not a failure, because
            # secrets are assumed to be injected at runtime (e.g., via environment variables).

    # Check Scorer LLMs
    scorers = llm_config.get("scorer_llm_options", [])
    if not scorers:
        logger.error("No scorer LLMs defined.")
        is_valid = False
    for i, scorer in enumerate(scorers):
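        # Entries must be dicts; a bare string would raise AttributeError on .get()
        if not isinstance(scorer, dict) or not scorer.get("model_name"):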
            logger.error(f"Scorer definition at index {i} missing model_name: {scorer}")
            is_valid = False

    # Check Target LLMs
    targets = llm_config.get("target_llms_for_evaluation", [])
    if not targets:
        logger.error("No target LLMs defined.")
        is_valid = False
    for i, target in enumerate(targets):
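        # Entries must be dicts; a bare string would raise AttributeError on .get()
        if not isinstance(target, dict) or not target.get("model_name"):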
            logger.error(f"Target definition at index {i} missing model_name: {target}")
            is_valid = False

    return is_valid
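
# Sketch (hypothetical helper; the naming scheme is an assumption, not taken from the
# paper): validate_llm_config only warns about placeholder credentials, assuming secrets
# are injected at runtime. One way to inject them is to replace each placeholder with an
# environment variable named after the upper-cased credential key, e.g. a hypothetical
# key "openai_api_key" would be read from OPENAI_API_KEY.
import os
from typing import Any, Dict, List, Mapping, Optional

SECRET_PLACEHOLDER = "<EXTERNAL_SECRET_NOT_IN_PAPER>"

def inject_credentials_from_env(config: Dict[str, Any],
                                env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Replace placeholder credentials in place; return the keys that remain unresolved."""
    env = os.environ if env is None else env
    credentials = config.get("llm_config", {}).get("api_credentials", {})
    unresolved = []
    for key, value in credentials.items():
        if value != SECRET_PLACEHOLDER:
            continue  # already set to a real value
        secret = env.get(key.upper())
        if secret:
            credentials[key] = secret  # updating an existing key is safe during iteration
        else:
            unresolved.append(key)
    return unresolved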

# -------------------------------------------------------------------------------------------------------------------------------
# Task 3, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def validate_and_update_study_config(study_config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Orchestrates the validation and resolution of the study configuration.

    This function executes the following steps:
    1. Inventories all placeholders to identify missing configurations.
    2. Resolves placeholders with concrete default values.
    3. Validates the structural integrity of LLM configurations.

    Args:
        study_config (Dict[str, Any]): The initial configuration dictionary.

    Returns:
        Dict[str, Any]: The validated and fully resolved configuration dictionary.
    """
    logger.info("Starting study_config validation...")

    # Step 1: Inventory
    placeholders = inventory_placeholders(study_config)
    if placeholders:
        logger.info(f"Found {len(placeholders)} placeholders to resolve.")
    else:
        logger.info("No placeholders found.")

    # Step 2: Resolve
    # Defaults are applied whether or not the inventory found placeholders: every
    # default path whose parent exists in the config is written, placeholder or not.
    resolved_config = resolve_placeholders(study_config)

    # Verify no placeholders remain
    remaining_placeholders = inventory_placeholders(resolved_config)
    if remaining_placeholders:
        logger.error(f"Unresolved placeholders remain: {remaining_placeholders}")
        # In a strict pipeline, we might raise an error here.

    # Step 3: LLM Validation
    if not validate_llm_config(resolved_config):
        logger.error("LLM configuration validation failed.")
    else:
        logger.info("LLM configuration validation passed.")

    logger.info("study_config validation complete.")
    return resolved_config


# Task 4 – Cleanse and Handle Missing Entries in tatqa_raw_df

@dataclass
class CleansingLog:
    """
    Dataclass to track the results of the cleansing process.

    Attributes:
        initial_rows (int): Number of rows in the raw DataFrame.
        rows_dropped_missing_id (int): Rows dropped due to missing/null example_id.
        rows_dropped_duplicate_id (int): Rows dropped due to duplicate example_id within a split.
        rows_dropped_empty_question (int): Rows dropped due to empty/whitespace-only question_text.
        rows_dropped_invalid_tables (int): Rows dropped because 'tables' was not a list or was empty.
        rows_dropped_no_valid_content (int): Rows dropped because they had neither valid tables nor valid passages after repair.
        final_rows (int): Number of rows in the cleansed DataFrame.
        dropped_ids (List[str]): List of example_ids that were dropped.
    """
    initial_rows: int = 0
    rows_dropped_missing_id: int = 0
    rows_dropped_duplicate_id: int = 0
    rows_dropped_empty_question: int = 0
    rows_dropped_invalid_tables: int = 0
    rows_dropped_no_valid_content: int = 0
    final_rows: int = 0
    dropped_ids: List[str] = field(default_factory=list)

# ==============================================================================
# Task 4: Cleanse and Handle Missing Entries in tatqa_raw_df
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 4, Step 1: Drop rows with critical missing fields
# -------------------------------------------------------------------------------------------------------------------------------
def drop_critical_missing_rows(df: pd.DataFrame, log: CleansingLog) -> pd.DataFrame:
    """
    Drops rows from the DataFrame that are missing critical identifying information
    or essential content fields required for downstream processing.

    Criteria for dropping:
    1. `example_id` is Null or NaN.
    2. `example_id` is a duplicate within its `split`.
    3. `question_text` is Null, empty, or whitespace-only.
    4. `tables` is Null, not a list, or an empty list.

    Args:
        df (pd.DataFrame): The raw TAT-QA DataFrame.
        log (CleansingLog): The logging object to update with drop statistics.

    Returns:
        pd.DataFrame: A filtered DataFrame with critical issues resolved.
    """
    # 1. Drop missing example_id
    # Each filter uses a boolean mask. Null IDs cannot be recorded in
    # dropped_ids, so only their count is logged.
    missing_id_mask = df["example_id"].isnull()
    log.rows_dropped_missing_id = int(missing_id_mask.sum())
    df = df[~missing_id_mask].copy()

    # 2. Drop duplicate example_id within split
    # We keep the first occurrence and drop subsequent ones.
    # This assumes 'split' exists and was validated in Task 1, so the same
    # example_id may legitimately appear once in each split.
    duplicate_mask = df.duplicated(subset=["split", "example_id"], keep="first")

    # Log dropped IDs for duplicates
    dropped_dups = df[duplicate_mask]["example_id"].tolist()
    log.dropped_ids.extend(dropped_dups)
    log.rows_dropped_duplicate_id = len(dropped_dups)

    df = df[~duplicate_mask].copy()

    # 3. Drop empty/whitespace question_text
    # Ensure string type before stripping
    # We handle NaN questions here as well
    def is_valid_question(q: Any) -> bool:
        if not isinstance(q, str):
            return False
        return len(q.strip()) > 0

    valid_question_mask = df["question_text"].apply(is_valid_question)
    dropped_questions = df[~valid_question_mask]["example_id"].tolist()
    log.dropped_ids.extend(dropped_questions)
    log.rows_dropped_empty_question = len(dropped_questions)

    df = df[valid_question_mask].copy()

    # 4. Drop invalid 'tables' field
    # Must be a non-empty list
    def is_valid_tables_field(t: Any) -> bool:
        if not isinstance(t, list):
            return False
        return len(t) > 0

    valid_tables_mask = df["tables"].apply(is_valid_tables_field)
    dropped_tables = df[~valid_tables_mask]["example_id"].tolist()
    log.dropped_ids.extend(dropped_tables)
    log.rows_dropped_invalid_tables = len(dropped_tables)

    df = df[valid_tables_mask].copy()

    return df
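
# Sketch (illustration only, a pure-Python restatement of step 2 above): keeping the
# first occurrence per (split, example_id) means the same ID may appear once in each
# split, while a repeat inside one split is dropped. The split names used in the test
# ("train", "dev") are hypothetical.
from typing import Any, Dict, Hashable, List, Set, Tuple

def dedupe_within_split(records: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Any]]:
    """Return (kept_records, dropped_ids), keeping the first row per (split, example_id)."""
    seen: Set[Tuple[Hashable, Hashable]] = set()
    kept: List[Dict[str, Any]] = []
    dropped: List[Any] = []
    for rec in records:
        key = (rec.get("split"), rec.get("example_id"))
        if key in seen:
            dropped.append(rec.get("example_id"))  # later duplicate within the same split
        else:
            seen.add(key)
            kept.append(rec)
    return kept, dropped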

# -------------------------------------------------------------------------------------------------------------------------------
# Task 4, Step 2: Repair or drop malformed table entries
# -------------------------------------------------------------------------------------------------------------------------------
def repair_malformed_tables(df: pd.DataFrame) -> pd.DataFrame:
    """
    Iterates through the 'tables' column of each row to validate and repair individual table structures.

    Repair Logic:
    - A table (dict) is discarded if 'headers' or 'rows' is not a list, or if 'headers' is empty.
      Non-string header values are converted to strings (None becomes "").
    - Rows within a table are repaired:
        - If a row is longer than headers, it is truncated.
        - If a row is shorter than headers, it is padded with empty strings (to preserve structure).
        - If a row is not a list, it is discarded.
        - Cell values are converted to strings (None becomes "").
    - A table is kept even if none of its rows survive, as long as its headers are valid.
    - If an example ends up with no valid tables, its 'tables' field becomes an empty list;
      the orchestrator drops the row only if it also has no valid passages.

    Args:
        df (pd.DataFrame): The DataFrame filtered from Step 1.

    Returns:
        pd.DataFrame: The DataFrame with repaired 'tables' structures.
    """

    def repair_single_row_tables(tables_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        valid_tables = []

        for tbl in tables_list:
            # Step 1 only guarantees a non-empty list, so skip entries that are not dicts
            if not isinstance(tbl, dict):
                continue

            headers = tbl.get("headers")
            rows = tbl.get("rows")

            # Critical structural check
            if not isinstance(headers, list) or not isinstance(rows, list):
                continue

            # Coerce headers to strings (None becomes "")
            clean_headers = [str(h) if h is not None else "" for h in headers]
            header_len = len(clean_headers)

            if header_len == 0:
                continue  # A table without headers is unusable

            repaired_rows = []
            for r in rows:
                if not isinstance(r, list):
                    continue

                # Repair length mismatch: pad short rows with "" and truncate long rows.
                # For long rows the repeat count is negative, which yields no padding.
                fitted = (r + [""] * (header_len - len(r)))[:header_len]
                # Coerce cells to strings (None becomes "")
                repaired_rows.append([str(c) if c is not None else "" for c in fitted])

            # Keep the table whenever its headers are valid, even if no rows survived.
            # A new dict is built so the table objects shared with the raw DataFrame
            # are not mutated in place.
            valid_tables.append({**tbl, "headers": clean_headers, "rows": repaired_rows})

        return valid_tables

    # Apply the repair function
    # We use a lambda wrapper to handle potential non-list inputs if any slipped through (defensive)
    df["tables"] = df["tables"].apply(lambda x: repair_single_row_tables(x) if isinstance(x, list) else [])

    return df
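
# Sketch (illustration only): the per-row repair above reduces to "pad with empty
# strings, truncate to the header width, then stringify cells with None mapped to ''".
# Written as a standalone function (hypothetical name), it can be unit-tested in isolation.
from typing import Any, List

def normalize_table_row(row: List[Any], width: int) -> List[str]:
    """Pad or truncate `row` to exactly `width` cells, mapping None to ''."""
    padded = list(row) + [""] * (width - len(row))  # a negative repeat count yields []
    return ["" if cell is None else str(cell) for cell in padded[:width]]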

# -------------------------------------------------------------------------------------------------------------------------------
# Task 4, Step 3: Repair or drop malformed passage entries
# -------------------------------------------------------------------------------------------------------------------------------
def repair_malformed_passages(df: pd.DataFrame) -> pd.DataFrame:
    """
    Iterates through the 'passages' column to validate and repair passage structures.

    Repair Logic:
    - A passage is kept only if it is a dict whose 'text' is a non-empty string
      (after stripping whitespace).
    - A passage without a 'passage_id' key is discarded rather than assigned a generated
      ID, because downstream citation alignment is assumed to rely on it.
    - A non-list 'passages' value is replaced with an empty list.

    Args:
        df (pd.DataFrame): The DataFrame with repaired tables.

    Returns:
        pd.DataFrame: The DataFrame with repaired 'passages' structures.
    """

    def repair_single_row_passages(passages_list: Any) -> List[Dict[str, Any]]:
        if not isinstance(passages_list, list):
            return []

        valid_passages = []
        for psg in passages_list:
            if not isinstance(psg, dict):
                continue

            text = psg.get("text")
            # Check text validity
            if not isinstance(text, str) or not text.strip():
                continue

            # Skip passages without a passage_id rather than generating one,
            # to avoid alignment errors later in the pipeline
            if "passage_id" not in psg:
                continue

            valid_passages.append(psg)

        return valid_passages

    df["passages"] = df["passages"].apply(repair_single_row_passages)
    return df
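
# Sketch (hypothetical alternative to dropping): instead of discarding passages that lack
# a 'passage_id', a pipeline could generate one. A deterministic option derives it from
# the example_id, the passage position, and a hash of the text, so reruns yield the same
# IDs. The "<example_id>-p<index>-<hash8>" format is an assumption, not from the paper.
import hashlib
from typing import Any, Dict, List

def assign_missing_passage_ids(passages: List[Dict[str, Any]], example_id: str) -> List[Dict[str, Any]]:
    """Return copies of `passages`, adding a deterministic 'passage_id' where absent."""
    result = []
    for i, psg in enumerate(passages):
        psg = dict(psg)  # shallow copy so the caller's dicts are not mutated
        if "passage_id" not in psg:
            digest = hashlib.sha1(str(psg.get("text", "")).encode("utf-8")).hexdigest()[:8]
            psg["passage_id"] = f"{example_id}-p{i}-{digest}"
        result.append(psg)
    return result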

# -------------------------------------------------------------------------------------------------------------------------------
# Task 4, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def cleanse_tatqa_dataset(tatqa_raw_df: pd.DataFrame) -> Tuple[pd.DataFrame, CleansingLog]:
    """
    Orchestrates the cleansing pipeline for the TAT-QA dataset.

    Pipeline:
    1. Drop rows with critical missing fields (ID, Question, Tables).
    2. Repair internal table structures (headers/rows alignment).
    3. Repair passage structures (keep passages with non-empty text and a passage_id).
    4. Final check: Drop rows that have NO valid tables AND NO valid passages (empty context).

    Args:
        tatqa_raw_df (pd.DataFrame): The raw input DataFrame.

    Returns:
        Tuple[pd.DataFrame, CleansingLog]: The cleansed DataFrame and a log of operations.
    """
    log = CleansingLog(initial_rows=len(tatqa_raw_df))
    logger.info(f"Starting TAT-QA cleansing. Initial rows: {len(tatqa_raw_df)}")

    # Step 1: Drop critical missing
    df = drop_critical_missing_rows(tatqa_raw_df, log)
    logger.info(f"Rows after dropping critical missing: {len(df)}")

    # Step 2: Repair tables
    df = repair_malformed_tables(df)

    # Step 3: Repair passages
    df = repair_malformed_passages(df)

    # Step 4: Final Content Check
    # We require at least one valid table OR one valid passage to form a prompt context.
    # If both lists are empty, the example is unusable.
    def has_valid_content(row: pd.Series) -> bool:
        has_tables = len(row["tables"]) > 0
        has_passages = len(row["passages"]) > 0
        return has_tables or has_passages

    valid_content_mask = df.apply(has_valid_content, axis=1)
    dropped_no_content = df[~valid_content_mask]["example_id"].tolist()

    log.dropped_ids.extend(dropped_no_content)
    log.rows_dropped_no_valid_content = len(dropped_no_content)

    df = df[valid_content_mask].copy()

    # Finalize log
    log.final_rows = len(df)
    logger.info(f"TAT-QA cleansing complete. Final rows: {len(df)}")

    # Reset index for clean usage, but keep original index for traceability
    df = df.reset_index(names="original_index")

    return df, log
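
# Sketch (hypothetical sanity check): every input row either survives or is counted by
# exactly one drop counter, so the counters should reconcile with the row totals. Rows
# dropped for a missing example_id are counted but cannot appear in dropped_ids. The
# check only reads attributes, so it accepts CleansingLog or FinQACleansingLog
# (rows_flagged_invalid_numeric is a flag, not a drop, and is ignored).
from typing import Any

DROP_COUNTERS = (
    "rows_dropped_missing_id",
    "rows_dropped_duplicate_id",
    "rows_dropped_empty_question",
    "rows_dropped_invalid_tables",
    "rows_dropped_no_valid_content",
)

def cleansing_log_reconciles(log: Any) -> bool:
    """True if initial_rows == final_rows + all drops and dropped_ids has the expected length."""
    total_dropped = sum(int(getattr(log, name, 0)) for name in DROP_COUNTERS)
    ids_ok = len(log.dropped_ids) == total_dropped - int(log.rows_dropped_missing_id)
    return int(log.initial_rows) == int(log.final_rows) + total_dropped and ids_ok

# Usage: df, log = cleanse_tatqa_dataset(tatqa_raw_df); assert cleansing_log_reconciles(log)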


# Task 5 – Cleanse and Handle Missing Entries in finqa_raw_df

@dataclass
class FinQACleansingLog:
    """
    Dataclass to track the results of the Fin-QA cleansing process.

    Attributes:
        initial_rows (int): Number of rows in the raw DataFrame.
        rows_dropped_missing_id (int): Rows dropped due to missing/null example_id.
        rows_dropped_duplicate_id (int): Rows dropped due to duplicate example_id within a split.
        rows_dropped_empty_question (int): Rows dropped due to empty/whitespace-only question_text.
        rows_dropped_invalid_tables (int): Rows dropped because 'tables' was not a list or was empty.
        rows_dropped_no_valid_content (int): Rows dropped because they had no valid tables after repair.
        rows_flagged_invalid_numeric (int): Rows flagged (not dropped) because their numeric answer could not be parsed.
        final_rows (int): Number of rows in the cleansed DataFrame.
        dropped_ids (List[str]): List of example_ids that were dropped.
    """
    initial_rows: int = 0
    rows_dropped_missing_id: int = 0
    rows_dropped_duplicate_id: int = 0
    rows_dropped_empty_question: int = 0
    rows_dropped_invalid_tables: int = 0
    rows_dropped_no_valid_content: int = 0
    rows_flagged_invalid_numeric: int = 0
    final_rows: int = 0
    dropped_ids: List[str] = field(default_factory=list)

# ==============================================================================
# Task 5: Cleanse and Handle Missing Entries in finqa_raw_df
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 5, Step 1: Drop rows with critical missing fields
# -------------------------------------------------------------------------------------------------------------------------------
def drop_critical_missing_rows_finqa(df: pd.DataFrame, log: FinQACleansingLog) -> pd.DataFrame:
    """
    Drops rows from the Fin-QA DataFrame that are missing critical identifying information
    or essential content fields required for downstream processing.

    Criteria for dropping:
    1. `example_id` is Null or NaN.
    2. `example_id` is a duplicate within its `split`.
    3. `question_text` is Null, empty, or whitespace-only.
    4. `tables` is Null, not a list, or an empty list.

    Args:
        df (pd.DataFrame): The raw Fin-QA DataFrame.
        log (FinQACleansingLog): The logging object to update with drop statistics.

    Returns:
        pd.DataFrame: A filtered DataFrame with critical issues resolved.
    """
    # 1. Drop missing example_id
    missing_id_mask = df["example_id"].isnull()
    log.rows_dropped_missing_id = int(missing_id_mask.sum())
    df = df[~missing_id_mask].copy()

    # 2. Drop duplicate example_id within split
    # We assume 'split' exists and is valid (validated in Task 2).
    duplicate_mask = df.duplicated(subset=["split", "example_id"], keep="first")

    dropped_dups = df[duplicate_mask]["example_id"].tolist()
    log.dropped_ids.extend(dropped_dups)
    log.rows_dropped_duplicate_id = len(dropped_dups)

    df = df[~duplicate_mask].copy()

    # 3. Drop empty/whitespace question_text
    def is_valid_question(q: Any) -> bool:
        if not isinstance(q, str):
            return False
        return len(q.strip()) > 0

    valid_question_mask = df["question_text"].apply(is_valid_question)
    dropped_questions = df[~valid_question_mask]["example_id"].tolist()
    log.dropped_ids.extend(dropped_questions)
    log.rows_dropped_empty_question = len(dropped_questions)

    df = df[valid_question_mask].copy()

    # 4. Drop invalid 'tables' field
    # Must be a non-empty list
    def is_valid_tables_field(t: Any) -> bool:
        if not isinstance(t, list):
            return False
        return len(t) > 0

    valid_tables_mask = df["tables"].apply(is_valid_tables_field)
    dropped_tables = df[~valid_tables_mask]["example_id"].tolist()
    log.dropped_ids.extend(dropped_tables)
    log.rows_dropped_invalid_tables = len(dropped_tables)

    df = df[valid_tables_mask].copy()

    return df

# -------------------------------------------------------------------------------------------------------------------------------
# Task 5, Step 2: Validate and repair table structures
# -------------------------------------------------------------------------------------------------------------------------------
def repair_malformed_tables_finqa(df: pd.DataFrame) -> pd.DataFrame:
    """
    Iterates through the 'tables' column of each row to validate and repair individual table structures.

    Repair Logic:
    - A table (dict) is discarded if 'headers' or 'rows' is not a list, or if 'headers' is empty.
      Non-string header values are converted to strings (None becomes "").
    - Rows within a table are repaired:
        - If a row is longer than headers, it is truncated.
        - If a row is shorter than headers, it is padded with empty strings.
        - If a row is not a list, it is discarded.
        - Cell values are converted to strings (None becomes "").
    - A table is kept even if none of its rows survive, as long as its headers are valid.

    Args:
        df (pd.DataFrame): The DataFrame filtered from Step 1.

    Returns:
        pd.DataFrame: The DataFrame with repaired 'tables' structures.
    """

    def repair_single_row_tables(tables_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        valid_tables = []

        for tbl in tables_list:
            # Step 1 only guarantees a non-empty list, so skip entries that are not dicts
            if not isinstance(tbl, dict):
                continue

            headers = tbl.get("headers")
            rows = tbl.get("rows")

            # Critical structural check
            if not isinstance(headers, list) or not isinstance(rows, list):
                continue

            # Coerce headers to strings (None becomes "")
            clean_headers = [str(h) if h is not None else "" for h in headers]
            header_len = len(clean_headers)

            if header_len == 0:
                continue

            repaired_rows = []
            for r in rows:
                if not isinstance(r, list):
                    continue

                # Repair length mismatch: pad short rows with "" and truncate long rows.
                # For long rows the repeat count is negative, which yields no padding.
                fitted = (r + [""] * (header_len - len(r)))[:header_len]
                # Coerce cells to strings (None becomes "")
                repaired_rows.append([str(c) if c is not None else "" for c in fitted])

            # Keep table if headers exist and structure is valid.
            # Fin-QA tables are critical for numeric reasoning.
            # A new dict is built so the table objects shared with the raw DataFrame
            # are not mutated in place.
            valid_tables.append({**tbl, "headers": clean_headers, "rows": repaired_rows})

        return valid_tables

    # Apply the repair function
    df["tables"] = df["tables"].apply(lambda x: repair_single_row_tables(x) if isinstance(x, list) else [])

    return df

# -------------------------------------------------------------------------------------------------------------------------------
# Task 5, Step 3: Validate answer values for numeric parsing
# -------------------------------------------------------------------------------------------------------------------------------
def validate_numeric_answers_finqa(df: pd.DataFrame, log: FinQACleansingLog) -> pd.DataFrame:
    """
    Validates that rows with 'answer_type' == 'number' have 'answer_value' fields
    that can be parsed into floats.

    This step does NOT drop rows. Instead, it adds a flag column `is_valid_numeric_answer`.
    This allows downstream evaluation to exclude invalid rows from numeric accuracy calculations
    while preserving the data for other purposes (e.g., text generation training).

    Parsing Logic:
    - Remove commas (',').
    - Remove trailing percent signs ('%').
    - Attempt `float()` conversion.

    Args:
        df (pd.DataFrame): The DataFrame with repaired tables.
        log (FinQACleansingLog): Logging object to record invalid numeric answers.

    Returns:
        pd.DataFrame: The DataFrame with an additional 'is_valid_numeric_answer' boolean column.
    """

    def check_numeric_validity(row: pd.Series) -> bool:
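        # Non-numeric answer types are out of scope for this check, so they pass.
        # (Fin-QA answers are predominantly numeric.)
        if row.get("answer_type") != "number":
            return True

        val = row.get("answer_value")
        # Values already stored as numbers are valid unless they are NaN or infinite
        if isinstance(val, (int, float)) and not isinstance(val, bool):
            return bool(np.isfinite(val))
        if not isinstance(val, str):
            return False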

        # Clean string
        clean_val = val.replace(',', '').strip()
        if clean_val.endswith('%'):
            clean_val = clean_val[:-1]

        try:
            float(clean_val)
            return True
        except ValueError:
            return False

    # Apply validation
    df["is_valid_numeric_answer"] = df.apply(check_numeric_validity, axis=1)

    # Log statistics
    # We only count rows where answer_type IS 'number' but validity is False
    invalid_numeric_count = len(df[(df["answer_type"] == "number") & (~df["is_valid_numeric_answer"])])
    log.rows_flagged_invalid_numeric = invalid_numeric_count

    if invalid_numeric_count > 0:
        logger.warning(f"Flagged {invalid_numeric_count} rows with invalid numeric answers in Fin-QA.")

    return df
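
# Illustrative sketch (hypothetical helper): the string-handling branch of
# `check_numeric_validity` as a standalone predicate. It strips commas, strips one
# trailing '%', then tries float(). Note that float() also accepts strings such as "nan"
# or "1e3"; the sketch keeps that behaviour to stay faithful to the rule above.
def is_parseable_numeric_answer(value: object) -> bool:
    """True if `value` is a string that float() accepts after comma/percent stripping."""
    if not isinstance(value, str):
        return False
    cleaned = value.replace(",", "").strip()
    if cleaned.endswith("%"):
        cleaned = cleaned[:-1]
    try:
        float(cleaned)
    except ValueError:
        return False
    return True

# Usage: is_parseable_numeric_answer("1,234.5") -> True; is_parseable_numeric_answer("n/a") -> False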

# -------------------------------------------------------------------------------------------------------------------------------
# Task 5, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def cleanse_finqa_dataset(finqa_raw_df: pd.DataFrame) -> Tuple[pd.DataFrame, FinQACleansingLog]:
    """
    Orchestrates the cleansing pipeline for the Fin-QA dataset.

    Pipeline:
    1. Drop rows with critical missing fields (ID, Question, Tables).
    2. Repair internal table structures (headers/rows alignment).
    3. Final check: Drop rows that have NO valid tables (Fin-QA relies heavily on tables).
    4. Validate numeric answer parseability and flag invalid rows.

    Args:
        finqa_raw_df (pd.DataFrame): The raw input DataFrame.

    Returns:
        Tuple[pd.DataFrame, FinQACleansingLog]: The cleansed DataFrame and a log of operations.
    """
    log = FinQACleansingLog(initial_rows=len(finqa_raw_df))
    logger.info(f"Starting Fin-QA cleansing. Initial rows: {len(finqa_raw_df)}")

    # Step 1: Drop critical missing
    df = drop_critical_missing_rows_finqa(finqa_raw_df, log)
    logger.info(f"Rows after dropping critical missing: {len(df)}")

    # Step 2: Repair tables
    df = repair_malformed_tables_finqa(df)

    # Step 3: Final Content Check (Tables are mandatory for Fin-QA)
    def has_valid_tables(row: pd.Series) -> bool:
        return isinstance(row["tables"], list) and len(row["tables"]) > 0

    valid_content_mask = df.apply(has_valid_tables, axis=1)
    dropped_no_content = df[~valid_content_mask]["example_id"].tolist()

    log.dropped_ids.extend(dropped_no_content)
    log.rows_dropped_no_valid_content = len(dropped_no_content)

    df = df[valid_content_mask].copy()

    # Step 4: Validate Numeric Answers (Flagging only)
    df = validate_numeric_answers_finqa(df, log)

    # Finalize log
    log.final_rows = len(df)
    logger.info(f"Fin-QA cleansing complete. Final rows: {len(df)}")

    # Reset to a 0..n-1 index; the pre-reset index is kept in an 'original_index' column for
    # traceability (the `names=` argument of reset_index requires pandas >= 1.5)
    df = df.reset_index(names="original_index")

    return df, log


# Task 6 – Normalize Numeric Columns and Answer Representations

# ==============================================================================
# Task 6: Normalize Numeric Columns and Answer Representations
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 6, Helper Class: FinancialTextNormalizer
# -------------------------------------------------------------------------------------------------------------------------------
class FinancialTextNormalizer:
    """
    Regex-driven text normalization for financial documents.

    This class provides methods for cleaning and normalizing financial text: stripping
    currency symbols, expanding unit multipliers (e.g., 'M', 'B'), removing percent
    signs, and extracting numeric values. A spaCy model is loaded at construction time,
    but the normalization methods below rely solely on regular expressions.

    Attributes:
        nlp (spacy.language.Language): The loaded spaCy language model.
        currency_pattern (re.Pattern): Compiled regex for currency symbols.
        unit_pattern (re.Pattern): Compiled regex for unit multipliers.
        percent_pattern (re.Pattern): Compiled regex for percentage signs.
        parentheses_negative_pattern (re.Pattern): Compiled regex for accounting-style negatives.
    """

    # Comprehensive currency symbol registry covering major global currencies
    CURRENCY_SYMBOLS: Set[str] = {
        '$', '€', '£', '¥', '₹', '₽', '₩', '₪', '₦', '₱', '₡', '₵', '₲', '₴', '₸',
        'USD', 'EUR', 'GBP', 'JPY', 'CNY', 'INR', 'CAD', 'AUD', 'CHF', 'SEK', 'NZD',
        'MXN', 'SGD', 'HKD', 'NOK', 'KRW', 'TRY', 'RUB', 'BRL', 'ZAR', 'PLN', 'THB',
        'IDR', 'HUF', 'CZK', 'ILS', 'CLP', 'PHP', 'AED', 'SAR', 'MYR', 'RON'
    }

    # Unit multipliers mapping textual representations to their numeric scale factors
    UNIT_PATTERNS: Dict[str, float] = {
        'million': 1e6, 'millions': 1e6,
        'billion': 1e9, 'billions': 1e9,
        'trillion': 1e12, 'trillions': 1e12,
        'thousand': 1e3, 'thousands': 1e3,
        'M': 1e6, 'B': 1e9, 'T': 1e12, 'K': 1e3, 'k': 1e3
    }

    def __init__(self, spacy_model: str = "en_core_web_sm"):
        """
        Initialize the normalizer with a spaCy model.

        Args:
            spacy_model (str): The name of the spaCy model to load (default: "en_core_web_sm").
        """
        try:
            self.nlp = spacy.load(spacy_model)
        except OSError:
            logger.warning(f"spaCy model '{spacy_model}' not found. Downloading...")
            from spacy.cli import download
            download(spacy_model)
            self.nlp = spacy.load(spacy_model)

        self._compile_patterns()

    def _compile_patterns(self) -> None:
        """
        Compile regex patterns for currency, units, and numeric values.

        This method pre-compiles regular expressions to optimize performance during
        repeated normalization calls.
        """
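        # Split the registry into symbols ('$', '€', ...) and alphabetic ISO codes ('USD', ...).
        # Symbols are non-word characters, so a leading \b would stop "$100" from matching at
        # the start of a string or after a space; they are therefore matched directly.
        # ISO codes are matched case-sensitively and only when not embedded in a longer word,
        # so that "USDA" or the word "try" are left untouched.
        symbols = sorted(s for s in self.CURRENCY_SYMBOLS if not s.isalpha())
        codes = sorted(s for s in self.CURRENCY_SYMBOLS if s.isalpha())
        symbol_alternation = '|'.join(re.escape(s) for s in symbols)
        code_alternation = '|'.join(re.escape(c) for c in codes)
        self.currency_pattern = re.compile(
            r'(?:' + symbol_alternation + r')\s*'
            + r'|(?<![A-Za-z])(?:' + code_alternation + r')(?![A-Za-z])\s*'
        )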

        # Pattern matches unit names as whole words
        unit_names = '|'.join(re.escape(unit) for unit in self.UNIT_PATTERNS.keys())
        self.unit_pattern = re.compile(
            r'\b(' + unit_names + r')\b',
            re.IGNORECASE
        )

        self.percent_pattern = re.compile(r'%')
        # Pattern matches numbers enclosed in parentheses, e.g., (1,000.00), typical in accounting
        self.parentheses_negative_pattern = re.compile(r'\(([0-9,\.]+)\)')

    def remove_currency_symbols(self, text: str) -> str:
        """
        Remove currency symbols from text using regex.

        Args:
            text (str): The input text containing potential currency symbols.

        Returns:
            str: The text with currency symbols removed and whitespace normalized.
        """
        # Note: We rely on regex here as it's more robust for symbols than NER alone for removal
        cleaned_text = self.currency_pattern.sub('', text)
        return ' '.join(cleaned_text.split())

    def _expand_units(self, text: str) -> str:
        """
        Expand unit multipliers (e.g., "45.2M" -> "45200000").

        This method identifies numeric values followed by unit suffixes and replaces them
        with the full numeric representation.

        Args:
            text (str): The input text containing numbers with units.

        Returns:
            str: The text with units expanded.
        """
        def replace_unit(match: re.Match) -> str:
            numeric_str = match.group(1)
            unit_str = match.group(2)
            try:
                # Remove commas before parsing float
                value = float(numeric_str.replace(',', ''))
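                # Resolve the multiplier case-insensitively: try the exact spelling, then lower
                # case ("MILLION" -> "million"), then upper case ("m" -> "M"). The pattern below
                # only matches case variants of registered units, so one lookup always succeeds
                # (the previous default of 1 silently turned "5b" into "5").
                multiplier = (
                    self.UNIT_PATTERNS.get(unit_str)
                    or self.UNIT_PATTERNS.get(unit_str.lower())
                    or self.UNIT_PATTERNS.get(unit_str.upper())
                )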
                expanded_value = value * multiplier
                # Return integer string if it's a whole number to avoid .0 artifacts
                return str(int(expanded_value) if expanded_value.is_integer() else expanded_value)
            except ValueError:
                # Return original match if parsing fails
                return match.group(0)

        # Regex to capture number (group 1) and unit (group 2)
        pattern = re.compile(
            r'([+-]?(?:\d{1,3}(?:,\d{3})*|\d+)(?:\.\d+)?)\s*(' +
            '|'.join(re.escape(u) for u in self.UNIT_PATTERNS.keys()) + r')\b',
            re.IGNORECASE
        )
        return pattern.sub(replace_unit, text)

    def normalize_numeric_text(
        self,
        text: str,
        remove_currency: bool = True,
        parse_units: bool = True,
        parse_percentages: bool = True
    ) -> str:
        """
        Normalize financial text by removing currency symbols and standardizing units.

        Args:
            text (str): The raw input text.
            remove_currency (bool): If True, strips currency symbols.
            parse_units (bool): If True, expands 'M', 'B', 'million', etc.
            parse_percentages (bool): If True, removes '%' signs.

        Returns:
            str: The normalized text string ready for numeric parsing.
        """
        if remove_currency:
            text = self.remove_currency_symbols(text)

        # Convert accounting negative format (100) to -100
        text = self.parentheses_negative_pattern.sub(r'-\1', text)

        if parse_units:
            text = self._expand_units(text)

        if parse_percentages:
            text = self.percent_pattern.sub('', text)

        # Remove commas from numbers (e.g., 1,000 -> 1000)
        text = re.sub(r'(\d),(\d)', r'\1\2', text)

        return text.strip()

    def parse_to_float(
        self,
        text: str,
        remove_currency: bool = True,
        parse_units: bool = True
    ) -> Optional[float]:
        """
        Parse a text string to a float value.

        This method orchestrates the normalization steps and attempts to convert the
        cleaned string into a float. It handles percentages by stripping the sign
        but preserving the magnitude (e.g., "15%" -> 15.0).

        Args:
            text (str): Input text containing a numeric value.
            remove_currency (bool): Whether to remove currency symbols.
            parse_units (bool): Whether to expand unit multipliers.

        Returns:
            Optional[float]: Parsed numeric value, or None if parsing fails.
        """
        if not isinstance(text, str):
            return None

        # Normalize currency, accounting negatives, units and thousands separators.
        # The '%' sign is left in place here and dropped below together with other symbols.
        normalized = self.normalize_numeric_text(
            text,
            remove_currency=remove_currency,
            parse_units=parse_units,
            parse_percentages=False
        )

        # Drop leftover symbols ('%', stray currency signs, commas, parentheses) but keep
        # letters, so text such as "Q3 2019" is rejected instead of collapsing to 32019.
        candidate = re.sub(r'[^\w\s.+\-]', '', normalized).strip()

        # Accept only a plain, optionally signed decimal number. Percentages keep their
        # magnitude ("15%" -> 15.0), following the TAT-QA/Fin-QA convention.
        match = re.fullmatch(r'([+-]?)\s*(\d+(?:\.\d*)?|\.\d+)', candidate)
        if match is None:
            return None
        return float(match.group(1) + match.group(2))
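
# Illustrative sketches (hypothetical helpers, simplified): two normalization steps of
# `FinancialTextNormalizer` shown standalone so they can be tested without spaCy.
#
# 1) Unit expansion, as in `_expand_units`, with its own small multiplier table (an
#    assumption, not the class's UNIT_PATTERNS). The regex is compiled once at import time,
#    and a suffix is resolved by trying the exact spelling, then lower case, then upper
#    case, so "5m" reads as 5 million; that suits financial text but not, e.g., metres.
# 2) Accounting negatives and thousands separators, as in `normalize_numeric_text`:
#    "(1,234.5)" -> "-1,234.5", then commas between digits are dropped. The lookahead form
#    used here also handles "1,2,3", which the pairwise `(\d),(\d)` substitution turns into "12,3".
import re
from typing import Dict

SKETCH_UNIT_MULTIPLIERS: Dict[str, float] = {
    "thousand": 1e3, "million": 1e6, "billion": 1e9, "trillion": 1e12,
    "K": 1e3, "M": 1e6, "B": 1e9, "T": 1e12,
}
SKETCH_UNIT_RE = re.compile(
    r"([+-]?\d[\d,]*(?:\.\d+)?)\s*("
    + "|".join(re.escape(u) for u in SKETCH_UNIT_MULTIPLIERS)
    + r")s?\b",  # optional plural "s", e.g. "millions"
    re.IGNORECASE,
)


def expand_units_sketch(text: str) -> str:
    """Replace '<number><unit>' with the expanded number, e.g. '1.5B' -> '1500000000'."""
    def _replace(m: re.Match) -> str:
        unit = m.group(2)
        multiplier = (SKETCH_UNIT_MULTIPLIERS.get(unit)
                      or SKETCH_UNIT_MULTIPLIERS.get(unit.lower())
                      or SKETCH_UNIT_MULTIPLIERS.get(unit.upper()))
        value = float(m.group(1).replace(",", "")) * multiplier
        return str(int(value)) if value.is_integer() else str(value)

    return SKETCH_UNIT_RE.sub(_replace, text)


def normalize_accounting_number(text: str) -> str:
    """Convert accounting-style negatives and strip thousands separators."""
    text = re.sub(r"\(([0-9,.]+)\)", r"-\1", text)  # (1,234.5) -> -1,234.5
    text = re.sub(r"(\d),(?=\d)", r"\1", text)      # 1,234,567 -> 1234567
    return text.strip()

# Usage: expand_units_sketch("Revenue was 1.5B") -> "Revenue was 1500000000"
#        normalize_accounting_number("(1,234.5)") -> "-1234.5"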

# -------------------------------------------------------------------------------------------------------------------------------
# Task 6, Step 1: Identify numeric columns in tables
# -------------------------------------------------------------------------------------------------------------------------------
def identify_numeric_columns(
    df: pd.DataFrame,
    dataset_name: str,
    normalizer: FinancialTextNormalizer,
    threshold: float = 0.9
) -> Dict[Tuple[str, str, str, int], bool]:
    """
    Identifies numeric columns in all tables within the DataFrame.

    Iterates through every table in every row of the dataset. For each column, it attempts
    to parse all non-empty cells as floats. If the ratio of successfully parsed cells to
    total non-empty cells meets or exceeds the threshold, the column is marked as numeric.

    Args:
        df (pd.DataFrame): The DataFrame containing tables.
        dataset_name (str): Name of the dataset ("TAT-QA" or "Fin-QA").
        normalizer (FinancialTextNormalizer): Initialized normalizer instance.
        threshold (float): Fraction of parseable cells required to mark a column as numeric (default: 0.9).

    Returns:
        Dict[Tuple[str, str, str, int], bool]: Metadata mapping
        (dataset_name, example_id, table_id, column_index) -> is_numeric.
    """
    numeric_metadata: Dict[Tuple[str, str, str, int], bool] = {}

    for idx, row in df.iterrows():
        example_id = row["example_id"]
        tables = row["tables"]

        if not isinstance(tables, list):
            continue

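        for table_idx, table in enumerate(tables):
            if not isinstance(table, dict):
                continue

            # Fall back to the table's position so that several tables without an explicit
            # 'table_id' in the same example do not overwrite each other's metadata keys
            table_id = table.get("table_id")
            if table_id in (None, ""):
                table_id = f"table_{table_idx}"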
            headers = table.get("headers", [])
            rows = table.get("rows", [])

            if not headers or not rows:
                continue

            num_cols = len(headers)

            for col_idx in range(num_cols):
                # Extract cells for this column
                cells = []
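                for r in rows:
                    if isinstance(r, list) and len(r) > col_idx:
                        # Treat None as an empty cell rather than the literal string "None"
                        cell = r[col_idx]
                        cells.append("" if cell is None else str(cell))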

                # Filter empty cells to compute valid ratio
                non_empty_cells = [c for c in cells if c.strip()]

                if not non_empty_cells:
                    # Empty columns are not numeric
                    numeric_metadata[(dataset_name, example_id, table_id, col_idx)] = False
                    continue

                # Check parseability of each cell
                parseable_count = 0
                for cell in non_empty_cells:
                    if normalizer.parse_to_float(cell) is not None:
                        parseable_count += 1

                # Determine if column is numeric based on threshold
                ratio = parseable_count / len(non_empty_cells)
                is_numeric = ratio >= threshold

                numeric_metadata[(dataset_name, example_id, table_id, col_idx)] = is_numeric

    return numeric_metadata
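
# Illustrative sketch (hypothetical helper): the column-typing rule of
# `identify_numeric_columns` for a single table. The cell parser is passed in (an
# assumption made here so the rule can be tested without spaCy). A column is numeric when
# at least `threshold` of its non-empty cells parse; all-empty columns are non-numeric.
from typing import Callable, Dict, Optional, Sequence


def numeric_columns_for_table(
    headers: Sequence[str],
    rows: Sequence[Sequence[object]],
    parse: Callable[[str], Optional[float]],
    threshold: float = 0.9,
) -> Dict[int, bool]:
    result: Dict[int, bool] = {}
    for col_idx in range(len(headers)):
        # Collect this column's cells, skipping ragged rows that are too short
        cells = ["" if r[col_idx] is None else str(r[col_idx]) for r in rows if len(r) > col_idx]
        non_empty = [c for c in cells if c.strip()]
        if not non_empty:
            result[col_idx] = False
            continue
        parsed = sum(parse(c) is not None for c in non_empty)
        result[col_idx] = parsed / len(non_empty) >= threshold
    return result

# Usage: with headers ["Item", "2019", "2018"], a column holding "1,000", "400", "n/a"
# has a 2/3 parse ratio and is numeric only when threshold <= 0.667.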

# -------------------------------------------------------------------------------------------------------------------------------
# Task 6, Step 2: Parse answer values to numeric types
# -------------------------------------------------------------------------------------------------------------------------------
def parse_answers(
    df: pd.DataFrame,
    normalizer: FinancialTextNormalizer,
    numeric_types: Set[str]
) -> pd.DataFrame:
    """
    Parses 'answer_value' into a new 'answer_numeric' column for rows with numeric answer types.

    This function iterates through the DataFrame and attempts to convert the 'answer_value'
    to a float if the 'answer_type' matches one of the specified numeric types.
    If parsing fails or the type is non-numeric, 'answer_numeric' is set to NaN.

    Args:
        df (pd.DataFrame): The DataFrame containing answer data.
        normalizer (FinancialTextNormalizer): Initialized normalizer instance.
        numeric_types (Set[str]): Set of answer_type strings that indicate a numeric answer
                                  (e.g., {"number", "span_number"}).

    Returns:
        pd.DataFrame: The DataFrame with an added 'answer_numeric' column containing floats or NaN.
    """

    def parse_row_answer(row: pd.Series) -> float:
        """Helper to parse a single row's answer."""
        ans_type = row.get("answer_type")
        ans_val = row.get("answer_value")

        # Skip non-numeric types
        if ans_type not in numeric_types:
            return np.nan

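        # Values already stored as numbers (e.g., integer counts) are used directly
        if isinstance(ans_val, (int, float)) and not isinstance(ans_val, bool):
            return float(ans_val)

        # Otherwise the value must be a string to be parsed
        if not isinstance(ans_val, str):
            return np.nan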

        # Attempt parsing; return NaN on failure
        parsed = normalizer.parse_to_float(ans_val)
        return parsed if parsed is not None else np.nan

    # Apply parsing logic row-wise
    df["answer_numeric"] = df.apply(parse_row_answer, axis=1)
    return df

# -------------------------------------------------------------------------------------------------------------------------------
# Task 6, Step 3: Store cleansed DataFrames and metadata
# -------------------------------------------------------------------------------------------------------------------------------
def save_artifacts(
    tatqa_df: pd.DataFrame,
    finqa_df: pd.DataFrame,
    metadata: Dict[Tuple[str, str, str, int], bool],
    output_dir: str = "processed_data"
) -> None:
    """
    Saves the processed DataFrames (as pickles) and the numeric-column metadata (as JSON) to disk.

    Args:
        tatqa_df (pd.DataFrame): Processed TAT-QA DataFrame.
        finqa_df (pd.DataFrame): Processed Fin-QA DataFrame.
        metadata (Dict[Tuple[str, str, str, int], bool]): Numeric column metadata keyed by
            (dataset_name, example_id, table_id, column_index).
        output_dir (str): Directory to save files; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Save DataFrames
    tatqa_df.to_pickle(os.path.join(output_dir, "tatqa_normalized.pkl"))
    finqa_df.to_pickle(os.path.join(output_dir, "finqa_normalized.pkl"))

    # Save Metadata
    # Convert tuple keys to strings for JSON
    json_safe_metadata = {}
    for k, v in metadata.items():
        # k is (dataset, example_id, table_id, col_idx)
        key_str = f"{k[0]}|{k[1]}|{k[2]}|{k[3]}"
        json_safe_metadata[key_str] = v

    with open(os.path.join(output_dir, "numeric_metadata.json"), "w") as f:
        json.dump(json_safe_metadata, f, indent=2)

    logger.info(f"Artifacts saved to {output_dir}")
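
# Illustrative sketch (hypothetical helpers): `save_artifacts` flattens each tuple key
# (dataset, example_id, table_id, col_idx) into "dataset|example_id|table_id|col_idx" for
# JSON. Reading the file back needs the inverse mapping, sketched here. It assumes no
# string field contains "|" and that example and table IDs are strings (they are decoded
# as str; only the column index is converted back to int).
import json
from typing import Dict, Tuple

MetaKey = Tuple[str, str, str, int]


def encode_numeric_metadata(metadata: Dict[MetaKey, bool]) -> str:
    return json.dumps({"|".join(map(str, k)): v for k, v in metadata.items()})


def decode_numeric_metadata(payload: str) -> Dict[MetaKey, bool]:
    decoded: Dict[MetaKey, bool] = {}
    for key_str, is_numeric in json.loads(payload).items():
        dataset, example_id, table_id, col_idx = key_str.split("|")
        decoded[(dataset, example_id, table_id, int(col_idx))] = is_numeric
    return decoded

# Usage: decode_numeric_metadata(open("processed_data/numeric_metadata.json").read())
# restores the tuple-keyed dict (the payload is the same flat JSON object written above).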

# -------------------------------------------------------------------------------------------------------------------------------
# Task 6, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def normalize_data_task(
    tatqa_df: pd.DataFrame,
    finqa_df: pd.DataFrame,
    output_dir: str = "processed_data"
) -> Tuple[pd.DataFrame, pd.DataFrame, Dict]:
    """
    Orchestrates Task 6: Normalizing numeric columns and answers.

    Args:
        tatqa_df (pd.DataFrame): Cleansed TAT-QA DataFrame.
        finqa_df (pd.DataFrame): Cleansed Fin-QA DataFrame.
        output_dir (str): Output directory.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame, Dict]: Normalized DataFrames and metadata.
    """
    logger.info("Starting Task 6: Normalization...")

    normalizer = FinancialTextNormalizer()

    # Step 1: Identify numeric columns
    logger.info("Identifying numeric columns...")
    tatqa_meta = identify_numeric_columns(tatqa_df, "TAT-QA", normalizer)
    finqa_meta = identify_numeric_columns(finqa_df, "Fin-QA", normalizer)

    # Merge metadata
    full_metadata = {**tatqa_meta, **finqa_meta}

    # Step 2: Parse answers
    logger.info("Parsing answers...")
    # TAT-QA numeric types
    tatqa_numeric_types = {"span_number", "arithmetic", "count"}
    tatqa_df = parse_answers(tatqa_df, normalizer, tatqa_numeric_types)

    # Fin-QA numeric types
    finqa_numeric_types = {"number"}
    finqa_df = parse_answers(finqa_df, normalizer, finqa_numeric_types)

    # Step 3: Save
    save_artifacts(tatqa_df, finqa_df, full_metadata, output_dir)

    logger.info("Task 6 complete.")
    return tatqa_df, finqa_df, full_metadata


# Task 7 – Collect and Normalize Offline Corpus Documents

# ==============================================================================
# Task 7: Collect and Normalize Offline Corpus Documents
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 7, Helper: LaTeX Cleaning Function
# -------------------------------------------------------------------------------------------------------------------------------
def clean_latex(text: str) -> str:
    """
    Removes LaTeX markup from text while preserving the content.

    This function applies a series of regex substitutions to strip common LaTeX
    commands, environments, and delimiters, leaving behind the natural language text.

    Args:
        text (str): Raw text containing LaTeX markup.

    Returns:
        str: Cleaned text.
    """
    if not text:
        return ""

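    # Remove comments: an unescaped '%' starts a comment that runs to the end of the line.
    # The lookbehind keeps escaped percent signs such as "grew 5\%".
    text = re.sub(r'(?<!\\)%.*', '', text)

    # Remove commands but keep content: \textbf{text} -> text
    # This is a simplification; nested braces are hard for regex, but sufficient for corpus stats
    text = re.sub(r'\\[a-zA-Z]+\{([^}]*)\}', r'\1', text)

    # Remove standalone commands: \noindent, \newpage
    text = re.sub(r'\\[a-zA-Z]+', '', text)

    # Remove math spans (delimiters and content): $...$, \[...\], \(...\)
    # Escaped dollar signs (\$) are literal currency, not math delimiters.
    text = re.sub(r'(?<!\\)\$[^$]*(?<!\\)\$', '', text)  # Inline math
    text = re.sub(r'\\\[.*?\\\]', '', text, flags=re.DOTALL)  # Display math
    text = re.sub(r'\\\(.*?\\\)', '', text, flags=re.DOTALL)  # Inline math (alternate delimiters)

    # Drop leftover unescaped braces, then unescape special characters (\% -> %, \& -> &, ...)
    text = re.sub(r'(?<!\\)[{}]', '', text)
    text = re.sub(r'\\([%$&#_{}])', r'\1', text)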

    # Collapse whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    return text
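
# Illustrative sketch (hypothetical helper): the single-pass rule `\\[a-zA-Z]+\{([^}]*)\}`
# leaves debris on nested markup such as \textbf{\emph{key} result}. Re-applying an
# innermost-only pattern (no braces allowed inside the argument) until the text stops
# changing unwraps commands at any nesting depth.
import re

_INNERMOST_CMD = re.compile(r"\\[a-zA-Z]+\{([^{}]*)\}")


def unwrap_latex_commands(text: str) -> str:
    """Replace \\cmd{arg} with arg, innermost first, until a fixed point is reached."""
    while True:
        new_text = _INNERMOST_CMD.sub(r"\1", text)
        if new_text == text:
            return new_text
        text = new_text

# Usage: unwrap_latex_commands(r"\textbf{\emph{key} result}") -> "key result"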

# -------------------------------------------------------------------------------------------------------------------------------
# Task 7, Step 1: Collect Wikipedia documents
# -------------------------------------------------------------------------------------------------------------------------------
def collect_wikipedia_documents(
    file_path: str,
    max_docs: Optional[int] = None
) -> Generator[Dict[str, str], None, None]:
    """
    Ingests Wikipedia documents from a JSONL file (e.g., output of WikiExtractor).

    Each line in the input file is expected to be a JSON object with 'id', 'url', 'title', and 'text'.

    Args:
        file_path (str): Path to the Wikipedia dump file (can be .gz).
        max_docs (Optional[int]): Maximum number of documents to yield.

    Yields:
        Dict[str, str]: A dictionary conforming to the corpus_document_schema.
    """
    logger.info(f"Collecting Wikipedia documents from {file_path}...")
    count = 0

    # Handle gzip or plain text
    open_func = gzip.open if file_path.endswith('.gz') else open

    try:
        with open_func(file_path, 'rt', encoding='utf-8') as f:
            for line in f:
                if max_docs is not None and count >= max_docs:
                    break

                try:
                    data = json.loads(line)
                    # WikiExtractor format usually has 'id', 'url', 'title', 'text'
                    doc_id = f"wiki_en_{data.get('id', count)}"
                    text = data.get('text', '')
                    title = data.get('title', '')

                    # WikiExtractor output is assumed to be clean already; only surrounding
                    # whitespace is trimmed here
                    text = text.strip()

                    if text:
                        yield {
                            "doc_id": doc_id,
                            "source": "wikipedia",
                            "title": title,
                            "text": text
                        }
                        count += 1
                except json.JSONDecodeError:
                    continue
    except FileNotFoundError:
        logger.warning(f"Wikipedia file {file_path} not found. Skipping.")
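
# Illustrative sketch (hypothetical helpers): the read pattern of
# `collect_wikipedia_documents` (transparent gzip, skip blank or malformed JSON lines,
# optional cap) split into a pure line parser and a thin file opener. The cap uses
# itertools.islice, so max_records=None means "no limit" and 0 means "nothing". Unlike
# the collector, the cap counts parsed records, not documents with non-empty text.
import gzip
import itertools
import json
from typing import Any, Dict, Iterable, Iterator, Optional


def parse_jsonl_lines(lines: Iterable[str], max_records: Optional[int] = None) -> Iterator[Dict[str, Any]]:
    """Yield decoded JSON objects from `lines`, skipping blank and malformed lines."""
    def _decoded() -> Iterator[Dict[str, Any]]:
        for line in lines:
            if not line.strip():
                continue
            try:
                yield json.loads(line)
            except json.JSONDecodeError:
                continue  # mirror the collector: silently skip malformed lines
    return itertools.islice(_decoded(), max_records)


def iter_jsonl_records(path: str, max_records: Optional[int] = None) -> Iterator[Dict[str, Any]]:
    """Stream records from a .jsonl or .jsonl.gz file."""
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rt", encoding="utf-8") as f:
        yield from parse_jsonl_lines(f, max_records)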

# -------------------------------------------------------------------------------------------------------------------------------
# Task 7, Step 2: Collect ShareGPT documents
# -------------------------------------------------------------------------------------------------------------------------------
def collect_sharegpt_documents(
    file_path: str,
    max_docs: Optional[int] = None
) -> Generator[Dict[str, str], None, None]:
    """
    Ingests ShareGPT conversation logs from a JSON file.

    The input file is expected to be a list of conversation objects.
    Each conversation has an 'id' and a 'conversations' list of turns.

    Args:
        file_path (str): Path to the ShareGPT JSON file.
        max_docs (Optional[int]): Maximum number of conversations to yield.

    Yields:
        Dict[str, str]: A dictionary conforming to the corpus_document_schema.
    """
    logger.info(f"Collecting ShareGPT documents from {file_path}...")
    count = 0

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            # Load the entire list (ShareGPT dumps are usually a single JSON list)
            # For extremely large files, ijson would be better, but standard json is used here for simplicity
            data = json.load(f)

            for conv in data:
                if max_docs is not None and count >= max_docs:
                    break

                original_id = conv.get('id', str(count))
                doc_id = f"sharegpt_{original_id}"

                # Concatenate turns
                turns = conv.get('conversations', [])
                text_parts = []
                for turn in turns:
                    role = turn.get('from', 'unknown')
                    content = turn.get('value', '')
                    # Format: "human: ... \n gpt: ..."
                    text_parts.append(f"{role}: {content}")

                full_text = "\n".join(text_parts)

                if full_text.strip():
                    yield {
                        "doc_id": doc_id,
                        "source": "sharegpt",
                        "title": f"ShareGPT Conversation {original_id}",
                        "text": full_text
                    }
                    count += 1
    except FileNotFoundError:
        logger.warning(f"ShareGPT file {file_path} not found. Skipping.")
    except json.JSONDecodeError:
        logger.error(f"Failed to decode JSON from {file_path}.")
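
# Illustrative sketch (hypothetical helper): how `collect_sharegpt_documents` flattens one
# conversation into "role: content" lines joined by newlines. Missing roles default to
# "unknown" and missing content to "", as in the collector.
from typing import Any, Dict


def flatten_sharegpt_conversation(conv: Dict[str, Any]) -> str:
    turns = conv.get("conversations", [])
    return "\n".join(f"{t.get('from', 'unknown')}: {t.get('value', '')}" for t in turns)

# Usage: flatten_sharegpt_conversation({"conversations": [{"from": "human", "value": "Hi"}]})
# -> "human: Hi"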

# -------------------------------------------------------------------------------------------------------------------------------
# Task 7, Step 3: Collect arXiv documents
# -------------------------------------------------------------------------------------------------------------------------------
def collect_arxiv_documents(
    file_path: str,
    max_docs: Optional[int] = None
) -> Generator[Dict[str, str], None, None]:
    """
    Ingests arXiv metadata/abstracts from a JSONL file.

    Each line is expected to be a JSON object with 'id', 'title', 'abstract'.

    Args:
        file_path (str): Path to the arXiv JSONL file.
        max_docs (Optional[int]): Maximum number of documents to yield.

    Yields:
        Dict[str, str]: A dictionary conforming to the corpus_document_schema.
    """
    logger.info(f"Collecting arXiv documents from {file_path}...")
    count = 0

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if max_docs is not None and count >= max_docs:
                    break

                try:
                    data = json.loads(line)
                    original_id = data.get('id', str(count))
                    doc_id = f"arxiv_{original_id}"
                    title = data.get('title', '')
                    abstract = data.get('abstract', '')

                    # Strip LaTeX markup from the title and abstract
                    clean_abstract = clean_latex(abstract)
                    clean_title = clean_latex(title)

                    # Combine title and abstract
                    full_text = f"{clean_title}\n{clean_abstract}"

                    if full_text.strip():
                        yield {
                            "doc_id": doc_id,
                            "source": "arxiv",
                            "title": clean_title,
                            "text": full_text
                        }
                        count += 1
                except json.JSONDecodeError:
                    continue
    except FileNotFoundError:
        logger.warning(f"arXiv file {file_path} not found. Skipping.")

# -------------------------------------------------------------------------------------------------------------------------------
# Task 7, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def build_offline_corpus(
    config: Dict[str, Any],
    output_path: str = "offline_corpus.jsonl"
) -> str:
    """
    Orchestrates the collection of documents from multiple sources (Wikipedia, ShareGPT, arXiv)
    and normalizes them into a unified JSONL corpus file.

    Args:
        config (Dict[str, Any]): The study configuration dictionary. Sources are read from
                                 config['offline_corpus_config']['sources'], a list of dicts
                                 with 'name' ("wikipedia", "sharegpt" or "arxiv"), 'path',
                                 and an optional 'max_docs' limit.
        output_path (str): The file path to save the consolidated corpus.

    Returns:
        str: The path to the generated corpus file.
    """
    logger.info("Starting offline corpus construction...")

    # Each source entry should provide a 'name' and a 'path'; entries without a 'path'
    # fall back to a placeholder file name, so real runs should always set it.
    sources_config = config.get("offline_corpus_config", {}).get("sources", [])

    # Map source names to collector functions
    collectors = {
        "wikipedia": collect_wikipedia_documents,
        "sharegpt": collect_sharegpt_documents,
        "arxiv": collect_arxiv_documents
    }

    total_docs = 0

    with open(output_path, 'w', encoding='utf-8') as out_f:
        for source_def in sources_config:
            name = source_def.get("name")
            # Use the configured 'path'; otherwise default to the placeholder "<name>_dump.jsonl"
            path = source_def.get("path", f"{name}_dump.jsonl")
            limit = source_def.get("max_docs", None) # Optional limit for testing

            if name in collectors:
                collector_func = collectors[name]
                logger.info(f"Processing source: {name}")

                for doc in collector_func(path, max_docs=limit):
                    # Write to JSONL
                    out_f.write(json.dumps(doc) + "\n")
                    total_docs += 1
            else:
                logger.warning(f"Unknown source type: {name}")

    logger.info(f"Corpus construction complete. Total documents: {total_docs}")
    logger.info(f"Corpus saved to: {output_path}")

    return output_path
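

# ------------------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): the corpus builder
# above calls each collector as `collector_func(path, max_docs=limit)` and writes
# every yielded dict as one JSONL line. The hypothetical collector below shows
# one way to honour that contract for a source already stored as JSONL with a
# "text" field (the key Task 8 reads). Skipping empty records is an assumption.
# To try it, one could add a hypothetical entry such as
# "jsonl": collect_jsonl_documents_sketch to the `collectors` mapping.
# ------------------------------------------------------------------------------
import itertools
import json
from typing import Any, Dict, Iterator, Optional


def collect_jsonl_documents_sketch(
    path: str,
    max_docs: Optional[int] = None
) -> Iterator[Dict[str, Any]]:
    """Hypothetical collector: lazily yields up to `max_docs` non-empty documents."""
    def _iter_docs() -> Iterator[Dict[str, Any]]:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # ignore blank lines
                record = json.loads(line)
                text = record.get("text", "") if isinstance(record, dict) else ""
                if text:
                    yield {"text": text}

    # islice(..., None) means "no limit"; otherwise it stops after max_docs items
    return itertools.islice(_iter_docs(), max_docs)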


In [None]:
# Task 8 – Tokenize Offline Corpus and Compute Token Counts

# ==============================================================================
# Task 8: Tokenize Offline Corpus and Compute Token Counts
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 8, Step 1 & 2: Tokenize and Count (Combined for Efficiency)
# -------------------------------------------------------------------------------------------------------------------------------
def tokenize_and_count_corpus(
    corpus_path: str,
    encoding_name: str = "cl100k_base"
) -> Tuple[CounterType[int], int]:
    """
    Streams documents from the corpus file, tokenizes them using either tiktoken or
    a HuggingFace tokenizer, and computes global token frequencies.

    This function supports dynamic tokenizer selection:
    1. Checks if 'encoding_name' is a valid tiktoken encoding.
    2. If not, attempts to load it as a HuggingFace tokenizer ID (e.g., "meta-llama/Llama-2-7b-hf").
    3. Raises ValueError if neither method succeeds.

    This flexibility ensures alignment between the offline corpus statistics and the
    scorer LLM's tokenization, which is critical for accurate self-information computation.

    Args:
        corpus_path (str): Path to the JSONL corpus file.
        encoding_name (str): The tokenizer identifier. Can be a tiktoken encoding name
                             (e.g., "cl100k_base") or a HuggingFace model ID.

    Returns:
        Tuple[Counter[int], int]: A tuple containing:
            - A Counter mapping token IDs (int) to their frequency (int).
            - The total number of tokens (N) observed in the corpus.

    Raises:
        ValueError: If the tokenizer cannot be loaded.
        FileNotFoundError: If the corpus file does not exist.
    """
    tokenizer = None
    is_tiktoken = False

    # 1. Attempt to load tiktoken encoding
    if tiktoken is not None:
        try:
            tokenizer = tiktoken.get_encoding(encoding_name)
            is_tiktoken = True
            logger.info(f"Loaded tiktoken encoding: {encoding_name}")
        except Exception as e:
            logger.debug(f"'{encoding_name}' is not a tiktoken encoding: {e}")

    # 2. If not tiktoken, attempt HuggingFace AutoTokenizer
    if tokenizer is None:
        if AutoTokenizer is not None:
            try:
                # use_fast=True is recommended for speed
                tokenizer = AutoTokenizer.from_pretrained(encoding_name, use_fast=True)
                is_tiktoken = False
                logger.info(f"Loaded HuggingFace tokenizer: {encoding_name}")
            except Exception as e:
                logger.debug(f"Failed to load HF tokenizer '{encoding_name}': {e}")
        else:
            logger.warning("transformers library not installed; skipping HF tokenizer check.")

    # 3. If both fail, raise error
    if tokenizer is None:
        raise ValueError(
            f"Could not load tokenizer for '{encoding_name}'. "
            "Ensure it is a valid tiktoken encoding or HuggingFace model ID, "
            "and that necessary libraries (tiktoken, transformers) are installed."
        )

    token_counts: CounterType[int] = Counter()
    total_tokens = 0
    doc_count = 0

    logger.info(f"Processing corpus from {corpus_path}...")

    if not os.path.exists(corpus_path):
        raise FileNotFoundError(f"Corpus file not found: {corpus_path}")

    try:
        with open(corpus_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    doc = json.loads(line)
                    text = doc.get("text", "")

                    if not text:
                        continue

                    # Normalize whitespace (simple collapse) to match preprocessing expectations
                    text = " ".join(text.split())

                    # Tokenize based on the loaded backend
                    if is_tiktoken:
                        # tiktoken encode
                        tokens = tokenizer.encode(text, disallowed_special=())
                    else:
                        # HuggingFace encode
                        # add_special_tokens=False to avoid BOS/EOS bias in frequency counts
                        tokens = tokenizer.encode(text, add_special_tokens=False)

                    # Update counts
                    token_counts.update(tokens)
                    total_tokens += len(tokens)
                    doc_count += 1

                    if doc_count % 10000 == 0:
                        logger.info(f"Processed {doc_count} documents. Total tokens so far: {total_tokens}")

                except json.JSONDecodeError:
                    logger.warning("Skipping malformed JSON line in corpus.")
                    continue

    except Exception as e:
        logger.error(f"Error processing corpus file: {e}")
        raise

    logger.info(f"Tokenization complete. Processed {doc_count} documents.")
    logger.info(f"Total unique tokens: {len(token_counts)}")
    logger.info(f"Total token occurrences (N): {total_tokens}")

    return token_counts, total_tokens
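

# ------------------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): the counting logic of
# `tokenize_and_count_corpus` in isolation. The hypothetical `encode` callable
# stands in for tiktoken's or a HuggingFace tokenizer's `encode`; the toy
# whitespace vocabulary in the usage note exists only so the example runs
# without either library installed.
# ------------------------------------------------------------------------------
from collections import Counter
from typing import Callable, Iterable, List, Tuple


def count_tokens_sketch(
    texts: Iterable[str],
    encode: Callable[[str], List[int]]
) -> Tuple[Counter, int]:
    """Returns (token_id -> count, N), mirroring the streaming loop above."""
    counts: Counter = Counter()
    total = 0
    for text in texts:
        text = " ".join(text.split())  # same whitespace collapse as above
        if not text:
            continue
        ids = encode(text)
        counts.update(ids)  # adds 1 per occurrence of each token ID
        total += len(ids)   # N counts occurrences, not unique tokens
    return counts, total

# Usage with a toy vocabulary (hypothetical token IDs):
#   vocab = {"revenue": 0, "grew": 1, "by": 2, "5%": 3}
#   count_tokens_sketch(["revenue grew", "revenue  grew by 5%"],
#                       lambda s: [vocab[w] for w in s.split()])
#   returns Counter({0: 2, 1: 2, 2: 1, 3: 1}) and N = 6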


# -------------------------------------------------------------------------------------------------------------------------------
# Task 8, Step 3: Persist raw token statistics
# -------------------------------------------------------------------------------------------------------------------------------
def save_token_statistics(
    token_counts: CounterType[int],
    total_tokens: int,
    output_dir: str,
    encoding_name: str
) -> str:
    """
    Saves the token counts and metadata to a JSON file.

    Since JSON keys must be strings, token IDs are converted to strings in the output.

    Args:
        token_counts (Counter[int]): The token frequency mapping.
        total_tokens (int): The total count N.
        output_dir (str): Directory to save the statistics file.
        encoding_name (str): Name of the tokenizer encoding used.

    Returns:
        str: Path to the saved statistics file.
    """
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, "token_stats.json")

    # Prepare data structure
    # Convert integer keys to strings for JSON serialization
    stats_data = {
        "metadata": {
            "encoding": encoding_name,
            "total_tokens_N": total_tokens,
            "unique_tokens": len(token_counts),
            "timestamp": pd.Timestamp.now().isoformat()
        },
        "counts": {str(k): v for k, v in token_counts.items()}
    }

    logger.info(f"Saving token statistics to {output_path}...")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(stats_data, f, indent=2)

    return output_path
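

# ------------------------------------------------------------------------------
# Illustrative sketch (hypothetical helper, not in the original pipeline):
# because JSON object keys are always strings, a consumer of `token_stats.json`
# has to convert the keys back to integer token IDs. The key names ("counts",
# "metadata" -> "total_tokens_N") are those written by save_token_statistics.
# The consistency check relies on N being the sum of all counts, which holds
# by construction in tokenize_and_count_corpus.
# ------------------------------------------------------------------------------
import json
from collections import Counter
from typing import Tuple


def load_token_statistics_sketch(path: str) -> Tuple[Counter, int]:
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    counts = Counter({int(k): int(v) for k, v in data["counts"].items()})
    total = int(data["metadata"]["total_tokens_N"])
    if sum(counts.values()) != total:
        raise ValueError("Token counts do not sum to total_tokens_N")
    return counts, total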

# -------------------------------------------------------------------------------------------------------------------------------
# Task 8, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def compute_corpus_statistics(
    config: Dict[str, Any],
    corpus_path: str = "offline_corpus.jsonl",
    output_dir: str = "corpus_stats"
) -> Tuple[CounterType[int], int]:
    """
    Orchestrates the computation of static corpus statistics.

    1. Loads tokenizer configuration.
    2. Streams and tokenizes the corpus to compute counts.
    3. Persists the results to disk.

    Args:
        config (Dict[str, Any]): Study configuration containing tokenizer settings.
        corpus_path (str): Path to the input corpus JSONL file.
        output_dir (str): Directory to save output stats.

    Returns:
        Tuple[Counter[int], int]: The token counts and total token count N.
    """
    logger.info("Starting corpus statistics computation...")

    # Extract tokenizer name from config (resolved in Task 3)
    encoding_name = config.get("offline_corpus_config", {}).get("tokenization_scheme", {}).get("name", "cl100k_base")

    # Step 1 & 2: Tokenize and Count
    token_counts, total_tokens = tokenize_and_count_corpus(corpus_path, encoding_name)

    # Step 3: Save
    save_token_statistics(token_counts, total_tokens, output_dir, encoding_name)

    logger.info("Corpus statistics computation complete.")
    return token_counts, total_tokens
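

# ------------------------------------------------------------------------------
# Illustrative sketch: the shape of the configuration fragment read by the corpus
# builder and by compute_corpus_statistics. Only the key names come from the
# code above; the concrete paths and the max_docs value are placeholders.
# ------------------------------------------------------------------------------
from typing import Any, Dict

EXAMPLE_OFFLINE_CORPUS_CONFIG: Dict[str, Any] = {
    "offline_corpus_config": {
        "sources": [
            {"name": "wikipedia", "path": "wikipedia_dump.jsonl", "max_docs": 1000},
            {"name": "sharegpt", "path": "sharegpt_dump.jsonl"},
            {"name": "arxiv", "path": "arxiv_dump.jsonl"},
        ],
        # Either a tiktoken encoding name or a HuggingFace model ID
        "tokenization_scheme": {"name": "cl100k_base"},
    }
}


def resolve_encoding_name_sketch(config: Dict[str, Any]) -> str:
    # Mirrors the lookup chain in compute_corpus_statistics, including its default
    return (
        config.get("offline_corpus_config", {})
        .get("tokenization_scheme", {})
        .get("name", "cl100k_base")
    )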


In [None]:
# Task 9 – Compute Token Frequencies and Probabilities

# ==============================================================================
# Task 9: Compute Token Frequencies and Probabilities
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 9, Step 1 & 2: Compute Frequencies and Probabilities
# -------------------------------------------------------------------------------------------------------------------------------
def compute_frequencies_and_probabilities(
    token_counts: CounterType[int],
    total_tokens_N: int
) -> Tuple[Dict[int, float], Dict[int, float]]:
    """
    Computes empirical frequencies f(t) and corpus probabilities p(t) for each token.

    Equations:
        f(t) = Count(t) / N
        p(t) = f(t)  (No smoothing applied)

    Args:
        token_counts (Counter[int]): Mapping of token ID to count.
        total_tokens_N (int): Total number of tokens in the corpus.

    Returns:
        Tuple[Dict[int, float], Dict[int, float]]:
            - freq: Dictionary mapping token ID to frequency f(t).
            - p: Dictionary mapping token ID to probability p(t).
    """
    logger.info(f"Computing frequencies and probabilities for {len(token_counts)} unique tokens...")

    freq: Dict[int, float] = {}
    p: Dict[int, float] = {}

    if total_tokens_N == 0:
        logger.warning("Total token count is 0. Returning empty probabilities.")
        return freq, p

    for token_id, count in token_counts.items():
        # Compute frequency
        f_t = count / total_tokens_N
        freq[token_id] = f_t

        # Maximum-likelihood estimate: p(t) equals f(t), with no smoothing
        p[token_id] = f_t

    return freq, p
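

# ------------------------------------------------------------------------------
# Worked example of f(t) = Count(t) / N with p(t) = f(t). With hypothetical
# counts {7: 6, 13: 3, 99: 1}, N = 10, so p(7) = 0.6, p(13) = 0.3 and
# p(99) = 0.1. This sketch uses fractions.Fraction so the arithmetic is exact
# and the probabilities sum to exactly 1. Tokens never seen in the corpus get
# no entry under this unsmoothed estimate.
# ------------------------------------------------------------------------------
from fractions import Fraction
from typing import Dict, Mapping


def mle_probabilities_exact_sketch(counts: Mapping[int, int]) -> Dict[int, Fraction]:
    n = sum(counts.values())
    if n == 0:
        return {}
    return {tok: Fraction(c, n) for tok, c in counts.items()}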

# -------------------------------------------------------------------------------------------------------------------------------
# Task 9, Step 3: Validate probability distribution
# -------------------------------------------------------------------------------------------------------------------------------
def validate_probability_distribution(p: Dict[int, float]) -> bool:
    """
    Validates that the computed probabilities sum to approximately 1.0.

    Args:
        p (Dict[int, float]): Dictionary of token probabilities.

    Returns:
        bool: True if valid, False otherwise.
    """
    # math.fsum avoids the rounding error plain sum() accumulates over many small terms
    total_prob = math.fsum(p.values())
    is_valid = math.isclose(total_prob, 1.0, rel_tol=1e-6)

    logger.info(f"Total probability mass: {total_prob:.8f}")

    if not is_valid:
        logger.warning("Probability distribution does not sum to 1.0 within tolerance.")
    else:
        logger.info("Probability distribution is valid.")

    # Log the top-k most probable tokens as a quick sanity check
    top_k = 5

    # Sort by probability descending
    sorted_p = sorted(p.items(), key=lambda item: item[1], reverse=True)[:top_k]

    logger.info(f"Top {top_k} tokens by probability (ID: prob): {sorted_p}")

    return is_valid
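

# ------------------------------------------------------------------------------
# Illustrative sketch (an alternative, not the pipeline's code): summing many
# small floats with sum() accumulates rounding error (sum([0.1] * 10) gives
# 0.9999999999999999), whereas math.fsum returns the correctly rounded 1.0.
# heapq.nlargest finds the top-k entries in O(n log k) instead of sorting the
# whole vocabulary. The function name and signature are hypothetical.
# ------------------------------------------------------------------------------
import heapq
import math
from typing import Dict, List, Tuple


def check_distribution_sketch(
    p: Dict[int, float],
    top_k: int = 5,
    rel_tol: float = 1e-6
) -> Tuple[bool, List[Tuple[int, float]]]:
    total = math.fsum(p.values())
    top = heapq.nlargest(top_k, p.items(), key=lambda item: item[1])
    return math.isclose(total, 1.0, rel_tol=rel_tol), top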

# -------------------------------------------------------------------------------------------------------------------------------
# Task 9, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def compute_token_probabilities(
    token_counts: CounterType[int],
    total_tokens_N: int
) -> Dict[int, float]:
    """
    Orchestrates the computation of token probabilities from raw counts.

    1. Computes f(t) and p(t).
    2. Validates the distribution.
    3. Returns the probability map p(t).

    Args:
        token_counts (Counter[int]): Raw token counts.
        total_tokens_N (int): Total tokens.

    Returns:
        Dict[int, float]: The probability map p(t).
    """
    logger.info("Starting probability computation...")

    # Compute frequencies and probabilities
    freq, p = compute_frequencies_and_probabilities(token_counts, total_tokens_N)

    # Validate probabilities
    validate_probability_distribution(p)

    logger.info("Probability computation complete.")
    return p


In [None]:
# Task 10 – Compute Static Self-Information Scores

# ==============================================================================
# Task 10: Compute Static Self-Information Scores
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 10, Step 1 & 2: Compute Static Self-Information
# -------------------------------------------------------------------------------------------------------------------------------
def calculate_self_information(p: Dict[int, float]) -> Dict[int, float]:
    """
    Computes the static self-information for each token based on its corpus probability.

    Equation:
        I(T) = -log2(p(T))

    Args:
        p (Dict[int, float]): Dictionary mapping token ID to probability p(T).

    Returns:
        Dict[int, float]: Dictionary mapping token ID to static self-information s_stat(T).
    """
    logger.info(f"Computing static self-information for {len(p)} tokens...")

    s_stat: Dict[int, float] = {}

    for token_id, prob in p.items():
        if prob <= 0:
            # Cannot occur when p is derived from positive counts. Skip the token
            # (it receives no score) rather than take log2 of a non-positive value.
            logger.warning(f"Token {token_id} has non-positive probability {prob}. Skipping.")
            continue

        # Compute self-information in bits
        info = -math.log2(prob)
        s_stat[token_id] = info

    return s_stat
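

# ------------------------------------------------------------------------------
# Worked example of I(T) = -log2 p(T): a token with p = 1/2 carries 1 bit,
# p = 1/4 carries 2 bits, and a rare token with p = 1/1024 carries 10 bits, so
# rarer tokens receive higher static scores. Since p(T) = Count(T) / N, the same
# score can be written as log2(N) - log2(Count(T)). The helper name is hypothetical.
# ------------------------------------------------------------------------------
import math


def self_information_bits_sketch(prob: float) -> float:
    if not 0.0 < prob <= 1.0:
        raise ValueError(f"probability must be in (0, 1], got {prob}")
    return -math.log2(prob)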

# -------------------------------------------------------------------------------------------------------------------------------
# Task 10, Step 3: Persist static self-information lookup
# -------------------------------------------------------------------------------------------------------------------------------
def save_static_scores(
    s_stat: Dict[int, float],
    output_dir: str,
    metadata: Optional[Dict[str, Any]] = None
) -> str:
    """
    Saves the static self-information scores to a JSON file.

    Args:
        s_stat (Dict[int, float]): The self-information scores.
        output_dir (str): Directory to save the file.
        metadata (Dict[str, Any], optional): Additional metadata to save.

    Returns:
        str: Path to the saved file.
    """
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, "static_self_information.json")

    # Compute summary stats
    scores = list(s_stat.values())
    if scores:
        stats = {
            "min": min(scores),
            "max": max(scores),
            "mean": sum(scores) / len(scores),
            "count": len(scores)
        }
    else:
        stats = {}

    # Prepare data
    data = {
        "metadata": {
            "timestamp": pd.Timestamp.now().isoformat(),
            "stats": stats,
            **(metadata or {})
        },
        # Convert int keys to strings for JSON
        "s_stat": {str(k): v for k, v in s_stat.items()}
    }

    logger.info(f"Saving static self-information to {output_path}...")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)

    return output_path

# -------------------------------------------------------------------------------------------------------------------------------
# Task 10, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def compute_static_self_information(
    p: Dict[int, float],
    output_dir: str = "corpus_stats",
    extra_metadata: Optional[Dict[str, Any]] = None
) -> Dict[int, float]:
    """
    Orchestrates the computation and persistence of static self-information scores.

    1. Computes I(T) = -log2(p(T)).
    2. Persists the lookup table to disk.

    Args:
        p (Dict[int, float]): Token probabilities.
        output_dir (str): Output directory.
        extra_metadata (Dict[str, Any], optional): Metadata to include in the output file.

    Returns:
        Dict[int, float]: The static self-information map.
    """
    logger.info("Starting static self-information computation...")

    # Compute self-information score
    s_stat = calculate_self_information(p)

    # Persist stats
    save_static_scores(s_stat, output_dir, extra_metadata)

    logger.info("Static self-information computation complete.")
    return s_stat
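

# ------------------------------------------------------------------------------
# Illustrative sketch (hypothetical helpers, not in the original pipeline):
# reading `static_self_information.json` back (string keys -> int token IDs)
# and scoring token IDs that never occurred in the corpus. Such tokens have no
# entry in s_stat; falling back to the largest observed score, i.e. treating
# them as at least as surprising as the rarest seen token, is an ASSUMPTION made
# here, not something the pipeline above specifies.
# ------------------------------------------------------------------------------
import json
from typing import Dict, List


def load_static_scores_sketch(path: str) -> Dict[int, float]:
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return {int(k): float(v) for k, v in data["s_stat"].items()}


def score_token_ids_sketch(token_ids: List[int], s_stat: Dict[int, float]) -> List[float]:
    # Assumed fallback for unseen tokens: the maximum observed score (0.0 if empty)
    fallback = max(s_stat.values()) if s_stat else 0.0
    return [s_stat.get(t, fallback) for t in token_ids]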


In [None]:
# Task 11 – Define Table Serialization Format

# ==============================================================================
# Task 11: Define Table Serialization Format
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 11, Step 1: Choose and formalize serialization method
# -------------------------------------------------------------------------------------------------------------------------------
def serialize_table_markdown(table: Dict[str, Any]) -> str:
    """
    Serializes a table dictionary into a Markdown-formatted string.

    This function converts a structured table dictionary into a string representation
    suitable for LLM consumption. It handles captions, headers, and rows, ensuring
    proper Markdown syntax (pipes, separator lines). It also sanitizes cell content
    by escaping pipes and collapsing newlines to maintain table structure.

    Format:
    Caption: {caption} (if present)
    | Header 1 | Header 2 | ... |
    | --- | --- | ... |
    | Cell 1,1 | Cell 1,2 | ... |
    ...

    Args:
        table (Dict[str, Any]): Table dictionary containing 'headers', 'rows', and optional 'caption'.
                                Expected keys: 'headers' (List[str]), 'rows' (List[List[str]]), 'caption' (str).

    Returns:
        str: The serialized Markdown table string. Returns an empty string if the table structure is invalid.
    """
    if not isinstance(table, dict):
        return ""

    headers = table.get("headers", [])
    rows = table.get("rows", [])
    caption = table.get("caption", "")

    # Basic validation: headers are required to form a table structure
    if not headers or not isinstance(headers, list):
        return ""

    # Helper to escape special characters (pipes and newlines) within cell content
    def clean_cell(cell: Any) -> str:
        s = str(cell) if cell is not None else ""
        # Escape pipes to prevent breaking Markdown table structure
        s = s.replace("|", "\\|")
        # Collapse line breaks (\n, \r\n, \r) to spaces to keep each row on one line
        s = " ".join(s.splitlines())
        return s.strip()

    # Build the table string parts
    parts = []

    # 1. Caption
    if caption and isinstance(caption, str) and caption.strip():
        parts.append(f"Caption: {clean_cell(caption)}")

    # 2. Headers
    clean_headers = [clean_cell(h) for h in headers]
    header_row = "| " + " | ".join(clean_headers) + " |"
    parts.append(header_row)

    # 3. Separator Row
    # Create a separator line with '---' for each column
    separator_row = "| " + " | ".join(["---"] * len(headers)) + " |"
    parts.append(separator_row)

    # 4. Data Rows
    if isinstance(rows, list):
        for row in rows:
            if not isinstance(row, list):
                continue

            # Ensure row length matches headers.
            # If row is shorter, pad with empty strings.
            # If row is longer, truncate (though cleansing should have handled this).
            current_row = row[:len(headers)] + [""] * (len(headers) - len(row))

            clean_row_cells = [clean_cell(c) for c in current_row]
            row_str = "| " + " | ".join(clean_row_cells) + " |"
            parts.append(row_str)

    # Join all parts with newlines
    return "\n".join(parts)
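

# ------------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of the pipeline): an inverse of the
# Markdown format produced by serialize_table_markdown, useful for round-trip
# checks. It splits rows on unescaped pipes and turns "\|" back into "|". It is
# lossy for cell text that itself ends in a backslash, because clean_cell does
# not escape backslashes.
# Possible round-trip use: parse_markdown_table_sketch(serialize_table_markdown(t))
# ------------------------------------------------------------------------------
import re
from typing import Any, Dict, List, Optional

# Matches a pipe NOT preceded by a backslash, i.e. a real cell delimiter
_UNESCAPED_PIPE = re.compile(r"(?<!\\)\|")


def parse_markdown_table_sketch(md: str) -> Dict[str, Any]:
    caption: Optional[str] = None
    table_lines: List[str] = []
    for line in md.splitlines():
        if line.startswith("Caption: "):
            caption = line[len("Caption: "):].replace("\\|", "|")
        elif line.startswith("|"):
            table_lines.append(line)

    def split_row(line: str) -> List[str]:
        # Drop the outer pipes, split on unescaped pipes, then unescape "\|"
        inner = line.strip()[1:-1]
        return [cell.strip().replace("\\|", "|") for cell in _UNESCAPED_PIPE.split(inner)]

    if len(table_lines) < 2:
        return {"caption": caption, "headers": [], "rows": []}
    headers = split_row(table_lines[0])
    # table_lines[1] is the "| --- | ... |" separator row
    rows = [split_row(line) for line in table_lines[2:]]
    return {"caption": caption, "headers": headers, "rows": rows}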

# -------------------------------------------------------------------------------------------------------------------------------
# Task 11, Step 2: Apply serialization consistently
# -------------------------------------------------------------------------------------------------------------------------------
def apply_table_serialization(df: pd.DataFrame) -> pd.DataFrame:
    """
    Applies Markdown serialization to all tables in the DataFrame.

    This function iterates over the 'tables' column of the DataFrame. For each row,
    it serializes the list of table dictionaries into a list of Markdown strings.
    The result is stored in a new column 'serialized_tables'.

    Args:
        df (pd.DataFrame): DataFrame containing a 'tables' column (list of dicts).

    Returns:
        pd.DataFrame: The input DataFrame with a new 'serialized_tables' column (list of strings).
    """
    def serialize_row_tables(tables_list: Any) -> List[str]:
        """Helper to serialize a list of tables for a single row."""
        if not isinstance(tables_list, list):
            return []

        serialized_list = []
        for t in tables_list:
            serialized = serialize_table_markdown(t)
            if serialized:
                serialized_list.append(serialized)
        return serialized_list

    # Serialize each row's tables; non-list entries yield an empty list.
    # Note: the column is added to `df` in place, and the same object is returned.
    df["serialized_tables"] = df["tables"].apply(serialize_row_tables)

    return df

# -------------------------------------------------------------------------------------------------------------------------------
# Task 11, Step 3: Validate serialization output
# -------------------------------------------------------------------------------------------------------------------------------
def validate_serialization(df: pd.DataFrame, sample_size: int = 5) -> bool:
    """
    Validates that serialization produces non-empty strings for valid tables.

    This function inspects a sample of rows where tables exist to ensure that
    the 'serialized_tables' column contains valid Markdown strings (non-empty,
    containing pipes). It logs errors if malformed serialization is detected.

    Args:
        df (pd.DataFrame): DataFrame with 'serialized_tables' column.
        sample_size (int): Number of rows to inspect (default: 5).

    Returns:
        bool: True if validation passes, False otherwise.
    """
    # Filter for rows that actually have tables
    rows_with_tables = df[df["serialized_tables"].map(len) > 0]

    if len(rows_with_tables) == 0:
        logger.warning("No rows with tables found to validate serialization.")
        return True  # Technically valid if empty, but worth noting

    # Inspect the first `sample_size` rows that contain tables (deterministic, not random)
    sample = rows_with_tables.head(sample_size)

    all_valid = True
    for idx, row in sample.iterrows():
        serialized_list = row["serialized_tables"]
        for i, s_table in enumerate(serialized_list):
            # Check 1: Non-empty
            if not s_table.strip():
                logger.error(f"Empty serialization found for row {idx}, table {i}")
                all_valid = False
                continue

            # Check 2: Contains pipes (basic Markdown table indicator)
            if "|" not in s_table:
                logger.error(f"Malformed serialization (no pipes) for row {idx}, table {i}: {s_table[:50]}...")
                all_valid = False
                continue

            # Check 3: Contains a header/body separator row such as "| --- | --- |"
            has_separator = any(
                line.startswith("|") and set(line.replace("|", "").split()) == {"---"}
                for line in s_table.splitlines()
            )
            if not has_separator:
                logger.error(f"Malformed serialization (no separator line) for row {idx}, table {i}")
                all_valid = False

    if all_valid:
        logger.info(f"Serialization validation passed on {len(sample)} samples.")
    else:
        logger.error("Serialization validation failed on one or more samples.")

    return all_valid
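

# ------------------------------------------------------------------------------
# Illustrative sketch (a stricter, hypothetical check than validate_serialization):
# every pipe-delimited line should have as many cells as the header once the
# escaped pipes ("\|") produced by clean_cell are ignored. A line such as
# "| a | b |" has 3 unescaped pipes delimiting 2 cells.
# ------------------------------------------------------------------------------
import re
from typing import List

_UNESCAPED_PIPE_RE = re.compile(r"(?<!\\)\|")


def column_count_mismatches_sketch(serialized: str) -> List[int]:
    """Returns 0-based indices (among pipe lines) whose cell count differs from the header's."""
    table_lines = [ln for ln in serialized.splitlines() if ln.startswith("|")]
    if not table_lines:
        return []

    def n_cells(line: str) -> int:
        # n unescaped pipes delimit n - 1 cells
        return len(_UNESCAPED_PIPE_RE.findall(line)) - 1

    expected = n_cells(table_lines[0])
    return [i for i, ln in enumerate(table_lines) if n_cells(ln) != expected]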

# -------------------------------------------------------------------------------------------------------------------------------
# Task 11, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def serialize_tables_task(
    tatqa_df: pd.DataFrame,
    finqa_df: pd.DataFrame
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Orchestrates table serialization for both TAT-QA and Fin-QA datasets.

    This function applies the Markdown serialization logic to both DataFrames
    and performs validation to ensure the output is correct.

    Args:
        tatqa_df (pd.DataFrame): TAT-QA DataFrame.
        finqa_df (pd.DataFrame): Fin-QA DataFrame.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: The input DataFrames with a new 'serialized_tables' column.
    """
    logger.info("Starting table serialization task...")

    # Process TAT-QA
    logger.info("Serializing TAT-QA tables...")
    tatqa_df = apply_table_serialization(tatqa_df)
    if not validate_serialization(tatqa_df):
        logger.warning("TAT-QA serialization validation reported issues.")

    # Process Fin-QA
    logger.info("Serializing Fin-QA tables...")
    finqa_df = apply_table_serialization(finqa_df)
    if not validate_serialization(finqa_df):
        logger.warning("Fin-QA serialization validation reported issues.")

    logger.info("Table serialization task complete.")
    return tatqa_df, finqa_df


In [None]:
# Task 12 – Define QA Prompt Template Structure

# ==============================================================================
# Task 12: Define QA Prompt Template Structure
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 12, Step 1 & 2: Define Templates for TAT-QA and Fin-QA
# -------------------------------------------------------------------------------------------------------------------------------
class PromptTemplateManager:
    """
    Manages prompt templates for different datasets.

    Templates are designed to be consistent across datasets while allowing for
    specific instructions. They include placeholders for:
    - {instructions}: Task-specific instructions.
    - {exemplars}: Few-shot examples (optional).
    - {tables}: Serialized table content.
    - {passages}: Narrative text content.
    - {question}: The user query.
    """

    # TAT-QA Template
    # Hybrid table-plus-passage context; TATQA_INSTRUCTIONS asks for a direct answer
    TATQA_TEMPLATE = (
        "{instructions}\n\n"
        "{exemplars}"
        "Context:\n"
        "{tables}\n\n"
        "{passages}\n\n"
        "Question: {question}\n"
        "Answer:"
    )

    TATQA_INSTRUCTIONS = (
        "You are a financial QA assistant. Answer the question based on the following "
        "tables and passages. Provide the answer directly."
    )

    # Fin-QA Template
    # Same layout as TAT-QA; FINQA_INSTRUCTIONS adds the emphasis on numerical reasoning
    FINQA_TEMPLATE = (
        "{instructions}\n\n"
        "{exemplars}"
        "Context:\n"
        "{tables}\n\n"
        "{passages}\n\n"
        "Question: {question}\n"
        "Answer:"
    )

    FINQA_INSTRUCTIONS = (
        "You are a financial expert. Perform any necessary calculations to answer the "
        "question based on the provided financial data. Provide the numeric answer."
    )

    @classmethod
    def get_template(cls, dataset: str) -> str:
        """
        Retrieves the raw template string for a given dataset.

        Args:
            dataset (str): Name of the dataset ("TAT-QA" or "Fin-QA").

        Returns:
            str: The format string with placeholders.
        """
        if dataset == "TAT-QA":
            return cls.TATQA_TEMPLATE
        elif dataset == "Fin-QA":
            return cls.FINQA_TEMPLATE
        else:
            raise ValueError(f"Unknown dataset: {dataset}")

    @classmethod
    def get_instructions(cls, dataset: str) -> str:
        """
        Retrieves the default instructions for a given dataset.

        Args:
            dataset (str): Name of the dataset.

        Returns:
            str: The instruction text.
        """
        if dataset == "TAT-QA":
            return cls.TATQA_INSTRUCTIONS
        elif dataset == "Fin-QA":
            return cls.FINQA_INSTRUCTIONS
        else:
            raise ValueError(f"Unknown dataset: {dataset}")

# -------------------------------------------------------------------------------------------------------------------------------
# Task 12, Step 3: Document template placeholders (Implemented via Docstrings and Class Structure)
# -------------------------------------------------------------------------------------------------------------------------------
# The PromptTemplateManager class above encapsulates the documentation and structure.
# The placeholders are explicitly named in the template strings.
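
# ------------------------------------------------------------------------------
# Illustrative sketch: the placeholders of a template can be listed
# programmatically with string.Formatter().parse, which yields
# (literal_text, field_name, format_spec, conversion) tuples. _TEMPLATE_COPY is
# a verbatim copy of PromptTemplateManager.TATQA_TEMPLATE so this runs on its
# own. Note that str.format does not re-parse substituted values, so braces in a
# question or table are safe; literal braces in the template itself would have
# to be written as "{{" and "}}".
# ------------------------------------------------------------------------------
import string
from typing import Set

_TEMPLATE_COPY = (
    "{instructions}\n\n"
    "{exemplars}"
    "Context:\n"
    "{tables}\n\n"
    "{passages}\n\n"
    "Question: {question}\n"
    "Answer:"
)


def template_placeholders_sketch(template: str) -> Set[str]:
    # field_name is None for the trailing literal chunk, so filter it out
    return {field for _, field, _, _ in string.Formatter().parse(template) if field}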

# -------------------------------------------------------------------------------------------------------------------------------
# Task 12, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def construct_prompt(
    dataset: str,
    question: str,
    tables_str: str,
    passages_str: str,
    exemplars_str: str = "",
    custom_instructions: Optional[str] = None
) -> str:
    """
    Constructs the final prompt string by populating the dataset-specific template.

    Args:
        dataset (str): "TAT-QA" or "Fin-QA".
        question (str): The user question.
        tables_str (str): Serialized tables.
        passages_str (str): Concatenated passages.
        exemplars_str (str, optional): Formatted few-shot exemplars. Defaults to "".
        custom_instructions (str, optional): Override default instructions.

    Returns:
        str: The fully constructed prompt.
    """
    template = PromptTemplateManager.get_template(dataset)
    instructions = custom_instructions or PromptTemplateManager.get_instructions(dataset)

    # Ensure exemplars have a trailing newline if present to separate from Context
    if exemplars_str and not exemplars_str.endswith("\n"):
        exemplars_str += "\n"

    prompt = template.format(
        instructions=instructions,
        exemplars=exemplars_str,
        tables=tables_str,
        passages=passages_str,
        question=question
    )

    return prompt


In [None]:
# Task 13 – Define Few-Shot Exemplar Embedding in Prompts

# ==============================================================================
# Task 13: Define Few-Shot Exemplar Embedding in Prompts
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 13, Step 1 & 2: Define Exemplar Formatting and Placement
# -------------------------------------------------------------------------------------------------------------------------------
def format_single_exemplar(
    exemplar: Dict[str, Any],
    index: int,
    dataset: str
) -> str:
    """
    Formats a single few-shot exemplar into a string block.

    Structure:
    Example {index}:
    Context:
    {tables}
    {passages}
    Question: {question}
    Answer: {answer}

    Args:
        exemplar (Dict[str, Any]): The exemplar data containing 'question_text',
                                   'serialized_tables', 'passages', 'answer_value', 'answer_unit'.
        index (int): The 1-based index of the exemplar in the sequence.
        dataset (str): "TAT-QA" or "Fin-QA" (used for minor formatting nuances if any).

    Returns:
        str: The formatted exemplar string.
    """
    # Extract components
    question = (exemplar.get("question_text") or "").strip()

    # Tables are already serialized strings in the 'serialized_tables' list;
    # join them with a blank line between tables
    tables_list = exemplar.get("serialized_tables") or []
    tables_str = "\n\n".join(tables_list)

    # Passages are a list of dicts; extract and join their text
    passages_list = exemplar.get("passages") or []
    passages_str = "\n\n".join(p.get("text", "").strip() for p in passages_list)

    # Answer: treat a missing or None value as empty rather than the string "None"
    raw_answer = exemplar.get("answer_value")
    answer_val = "" if raw_answer is None else str(raw_answer).strip()
    answer_unit = exemplar.get("answer_unit")
    answer_full = f"{answer_val} {answer_unit}" if answer_unit else answer_val

    # Include only non-empty context sections so that an exemplar without
    # tables (or without passages) does not render stray blank lines
    context_str = "\n\n".join(s for s in (tables_str, passages_str) if s)

    # Construct the block; separators between exemplars are added by
    # format_exemplars(), not here
    block = (
        f"Example {index}:\n"
        f"Context:\n"
        f"{context_str}\n\n"
        f"Question: {question}\n"
        f"Answer: {answer_full}"
    )

    return block
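

# Sketch (hypothetical helper, not part of the original pipeline):
# format_single_exemplar() reads every field with .get(), so a malformed
# exemplar renders silently with an empty question or answer instead of
# failing, and a passage that is not a dict raises AttributeError on p.get().
# A pre-flight check along these lines could flag such records before prompting.
from typing import Any, Dict, List


def exemplar_problems(exemplar: Dict[str, Any]) -> List[str]:
    """Return human-readable problems that would make a rendered exemplar misleading."""
    problems: List[str] = []
    if not str(exemplar.get("question_text") or "").strip():
        problems.append("empty question_text")
    # 0 is a legitimate answer, so only None and "" count as missing
    if exemplar.get("answer_value") in (None, ""):
        problems.append("missing answer_value")
    if not exemplar.get("serialized_tables") and not exemplar.get("passages"):
        problems.append("no context (tables and passages both empty)")
    for i, p in enumerate(exemplar.get("passages") or []):
        if not isinstance(p, dict) or "text" not in p:
            problems.append(f"passage {i} lacks a 'text' field")
    return problems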

# -------------------------------------------------------------------------------------------------------------------------------
# Task 13, Step 3: Ensure no leakage and Orchestrate
# -------------------------------------------------------------------------------------------------------------------------------
def format_exemplars(
    target_example_id: str,
    exemplars: List[Dict[str, Any]],
    dataset: str
) -> str:
    """
    Formats a list of candidate exemplars into a single string block for the prompt.

    Enforces the constraint that the target example cannot be used as its own exemplar (leakage prevention).

    Args:
        target_example_id (str): The ID of the example being prompted for.
        exemplars (List[Dict[str, Any]]): A list of candidate exemplar dictionaries.
        dataset (str): The dataset name.

    Returns:
        str: A string containing all formatted exemplars, separated by horizontal rules.
    """
    formatted_blocks = []
    valid_count = 0

    for ex in exemplars:
        # Leakage check
        if str(ex.get("example_id")) == str(target_example_id):
            logger.warning(f"Skipping exemplar {ex.get('example_id')} to prevent leakage for target {target_example_id}.")
            continue

        valid_count += 1
        block = format_single_exemplar(ex, valid_count, dataset)
        formatted_blocks.append(block)

    if not formatted_blocks:
        return ""

    # Join blocks with a distinct separator
    # We add a header "Examples:" at the very top
    full_str = "Examples:\n\n" + "\n\n---\n\n".join(formatted_blocks) + "\n\n---\n\n"

    return full_str
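

# Sketch (hypothetical, assumption): this section does not show how the
# candidate exemplar list is chosen. One reproducible option is to drop the
# target first, then sample k exemplars with a seeded RNG.
import random
from typing import Any, Dict, List


def select_exemplars(
    pool: List[Dict[str, Any]],
    target_example_id: Any,
    k: int = 2,
    seed: int = 0,
) -> List[Dict[str, Any]]:
    """Pick up to k exemplars from pool, never including the target example."""
    # Compare str() forms, matching the normalization used in format_exemplars()
    candidates = [ex for ex in pool if str(ex.get("example_id")) != str(target_example_id)]
    if k >= len(candidates):
        return candidates
    # A dedicated Random instance keeps selection reproducible without touching global state
    return random.Random(seed).sample(candidates, k)

# Design note: filtering before sampling keeps the shot count at k even when the
# target is in the pool. If the target is sampled first, format_exemplars() drops
# it afterwards and the prompt ends up with only k - 1 shots.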


# Task 14 – Configure Scorer LLMs and Target LLMs

# ==============================================================================
# Task 14: Configure Scorer LLMs and Target LLMs
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 14, Step 1 & 2: Define LLM Interfaces (Abstract and Concrete)
# -------------------------------------------------------------------------------------------------------------------------------
class LLMInterface(ABC):
    """
    Abstract base class defining the standardized interface for LLM interactions
    within the CompactPrompt compression pipeline.

    This interface provides a unified abstraction layer across heterogeneous LLM
    providers (OpenAI, Anthropic, Together AI), ensuring consistent methods for:

    1. **Token Counting**: Estimating prompt length for context-window management.
    2. **Prompting**: Sending chat-formatted messages to the model endpoint.
    3. **Response Extraction**: Parsing generated text and log-probabilities.
    4. **Dynamic Self-Information Computation**: Converting log-probabilities to
       information-theoretic scores in bits, as required by Equation (2) in the
       CompactPrompt paper for computing s_dyn(t | c).

    The log-probability extraction capability is critical for implementing the
    dynamic self-information scoring mechanism described in Section 3.1.2 of the
    CompactPrompt paper, where:

        s_dyn(t | c) = -log_2(P_model(t | c))

    Concrete implementations must handle provider-specific API structures while
    exposing a uniform interface to the compression orchestration layer.

    Attributes:
        None at the abstract level; concrete subclasses define provider-specific
        attributes such as API clients and tokenizer instances.

    Notes:
        - Not all LLM providers expose token-level log-probabilities. The Claude
          API, for instance, does not currently support public logprobs access.
          Concrete implementations must handle this gracefully by returning None
          for the logprobs component when unavailable.
        - Tokenizer alignment between the offline corpus and the LLM's native
          tokenizer is addressed in Task 14, Step 3.

    See Also:
        - GPT4oInterface: Concrete implementation for OpenAI GPT-4o family.
        - ClaudeInterface: Concrete implementation for Anthropic Claude-3.5.
        - LlamaInterface: Concrete implementation for Llama-3.3 via Together AI.
    """

    @abstractmethod
    def count_tokens(self, text: str) -> int:
        """
        Count the number of tokens in the input text using the model's tokenizer.

        This method is essential for prompt length estimation and ensuring that
        compressed prompts remain within the model's context window constraints.
        Accurate token counting is critical for the compression ratio calculations
        reported in the CompactPrompt evaluation (Section 5).

        Args:
            text (str): The input text string to tokenize and count.
                Must be a valid UTF-8 encoded string. Empty strings are
                permitted and should return 0.

        Returns:
            int: The total number of tokens in the input text according to
                the model's native tokenizer (or a documented approximation).
                Returns 0 if the input is empty.

        Raises:
            TypeError: If text is not a string type.
            ValueError: If text contains characters that cannot be
                tokenized by the model's tokenizer.

        Notes:
            - Token counts may vary significantly between different tokenizers
              (e.g., BPE vs. WordPiece vs. SentencePiece).
            - For models where the native tokenizer is not accessible, an
              approximation using a similar tokenizer (e.g., tiktoken) may
              be employed, with documented variance expectations.
        """
        pass

    @abstractmethod
    def prompt(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 512,
        temperature: float = 0.0,
        logprobs: bool = False
    ) -> Any:
        """
        Send a chat-formatted prompt to the LLM and retrieve the raw response.

        This method implements the core interaction with the LLM API, supporting
        both generation tasks (for downstream QA evaluation) and scoring tasks
        (for obtaining P_model(t | c) via log-probabilities).

        Args:
            messages (List[Dict[str, str]]): A list of message dictionaries,
                each containing:
                - 'role' (str): One of 'system', 'user', or 'assistant'.
                - 'content' (str): The message text content.
                The list must contain at least one message with role 'user'.

            max_tokens (int, optional): Maximum number of tokens to generate
                in the response. Defaults to 512. Must be a positive integer
                not exceeding the model's maximum generation limit.

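            temperature (float, optional): Sampling temperature controlling
                output randomness. Defaults to 0.0, which minimizes sampling
                variance and is recommended for reproducible evaluation
                experiments (providers do not guarantee fully deterministic
                output even at 0.0). The valid range depends on the provider:
                OpenAI accepts [0.0, 2.0] and Anthropic accepts [0.0, 1.0].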

            logprobs (bool, optional): Whether to request token-level log
                probabilities in the response. Defaults to False. When True,
                the response object will include per-token log-probabilities
                required for computing dynamic self-information scores.

        Returns:
            Any: The raw response object from the provider's API. The exact
                type depends on the provider:
                - OpenAI: openai.types.chat.ChatCompletion
                - Anthropic: anthropic.types.Message
                - Together: together.types.ChatCompletion

        Raises:
            TypeError: If messages is not a list or contains non-dict elements.
            ValueError: If messages is empty or lacks a 'user' role message.
            ValueError: If max_tokens is not a positive integer.
            ValueError: If temperature is outside the valid range.
            RuntimeError: If the API call fails due to network or auth errors.

        Notes:
            - The logprobs parameter may be ignored by providers that do not
              support log-probability access (e.g., Anthropic Claude).
            - For scoring operations, temperature should be set to 0.0 to
              ensure deterministic probability estimates.
        """
        pass

    @abstractmethod
    def extract_response(
        self,
        response: Any
    ) -> Tuple[str, Optional[List[Dict[str, Any]]]]:
        """
        Extract generated text and log-probabilities from the raw API response.

        This method parses the provider-specific response structure to extract:
        1. The generated text content for downstream task evaluation.
        2. Token-level log-probabilities for dynamic self-information scoring.

        Args:
            response (Any): The raw response object returned by the prompt()
                method. The exact type is provider-specific.

        Returns:
            Tuple[str, Optional[List[Dict[str, Any]]]]: A tuple containing:

                [0] str: The generated text content. For multi-block responses
                    (e.g., Anthropic), all text blocks are concatenated.

                [1] Optional[List[Dict[str, Any]]]: A list of log-probability
                    dictionaries, one per generated token, or None if log-probs
                    were not requested or are not supported. Each dictionary
                    contains:
                    - 'token' (str): The token string.
                    - 'logprob' (float): Natural log probability ln(P(t|c)).
                    - 'top_logprobs' (List[Dict[str, Any]]): Alternative tokens
                      and their log-probabilities (may be empty).

        Raises:
            TypeError: If response is None or of unexpected type.
            AttributeError: If the response object lacks expected attributes.
            KeyError: If required fields are missing from the response structure.

        Notes:
            - Log-probabilities are returned in natural log (base e) as provided
              by the API. Use compute_dynamic_self_information() to convert to
              bits (base 2) as required by the CompactPrompt scoring equations.
            - For providers without logprob support (Claude), this method
              returns None for the second tuple element.
        """
        pass

    def compute_dynamic_self_information(
        self,
        logprobs_list: List[Dict[str, Any]]
    ) -> List[float]:
        """
        Convert natural log probabilities to self-information in bits.

        This method implements the conversion from API-provided log-probabilities
        (in natural log) to self-information scores in bits, as required by the
        dynamic self-information computation in Section 3.1.2 of the CompactPrompt
        paper.

        The mathematical transformation is:

            s_dyn(t | c) = -log_2(P_model(t | c))
                         = -ln(P_model(t | c)) / ln(2)

        Since API logprobs are provided as ln(P), we compute:

            I(t) = -logprob / ln(2)

        Args:
            logprobs_list (List[Dict[str, Any]]): A list of log-probability
                dictionaries as returned by extract_response(). Each dictionary
                must contain a 'logprob' key with a float value representing
                the natural log probability ln(P(t | c)).

        Returns:
            List[float]: A list of self-information values in bits, one per
                token in the input list. Higher values indicate tokens that
                are more surprising (informative) given their context.

        Raises:
            TypeError: If logprobs_list is not a list, or if any element is
                not a dictionary.
            KeyError: If any dictionary in the list lacks a 'logprob' key.
            ValueError: If any 'logprob' value is not numeric.
            ValueError: If logprobs_list is empty.

        Example:
            >>> logprobs = [{'token': 'the', 'logprob': -0.5},
            ...             {'token': 'cat', 'logprob': -2.3}]
            >>> interface.compute_dynamic_self_information(logprobs)
            [0.7213..., 3.3181...]  # Values in bits

        Notes:
            - The constant ln(2) ≈ 0.693147 is used for the base conversion.
            - Self-information values are always non-negative since logprobs
              are always ≤ 0 (probabilities are ≤ 1).
        """
        # Validate that logprobs_list is a list
        if not isinstance(logprobs_list, list):
            raise TypeError(
                f"logprobs_list must be a list, got {type(logprobs_list).__name__}"
            )

        # Validate that the list is not empty
        if len(logprobs_list) == 0:
            raise ValueError("logprobs_list cannot be empty")

        # Compute the natural log of 2 for base conversion (ln(2) ≈ 0.693147)
        ln_2 = math.log(2)

        # Initialize list to store self-information values in bits
        self_info: List[float] = []

        # Iterate over each token's log-probability dictionary
        for idx, lp in enumerate(logprobs_list):
            # Validate that each element is a dictionary
            if not isinstance(lp, dict):
                raise TypeError(
                    f"Element at index {idx} must be a dict, got {type(lp).__name__}"
                )

            # Validate that 'logprob' key exists
            if "logprob" not in lp:
                raise KeyError(
                    f"Element at index {idx} missing required 'logprob' key"
                )

            # Extract the natural log probability value
            logprob_value = lp["logprob"]

            # Validate that logprob is a numeric type
            if not isinstance(logprob_value, (int, float)):
                raise ValueError(
                    f"logprob at index {idx} must be numeric, got {type(logprob_value).__name__}"
                )

            # Convert ln(P) to -log_2(P) using: -ln(P) / ln(2)
            # Equation from Section 3.1.2: s_dyn(t|c) = -log_2(P_model(t|c))
            self_info_bits = -logprob_value / ln_2

            # Append the computed self-information value
            self_info.append(self_info_bits)

        # Return the list of self-information values in bits
        return self_info
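

# Worked check (standalone sketch, hypothetical helper names) of the conversion
# in compute_dynamic_self_information(): s_dyn(t | c) = -ln P(t | c) / ln 2.
# A token with probability 0.5 carries exactly 1 bit, a uniform pick among 8
# options carries 3 bits, and the docstring's example logprobs (-0.5 and -2.3
# nats) come out to about 0.7213 and 3.3182 bits.
import math
from typing import List


def bits_from_logprobs(logprobs: List[float]) -> List[float]:
    """Convert natural-log probabilities ln P(t|c) to self-information in bits."""
    ln_2 = math.log(2)
    return [-lp / ln_2 for lp in logprobs]


def sequence_bits(logprobs: List[float]) -> float:
    """Total surprisal of a token sequence in bits.

    By the chain rule P(t_1..t_n) = prod_i P(t_i | t_<i), so the per-token
    self-information values sum to -log2 of the joint probability.
    """
    return sum(bits_from_logprobs(logprobs))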


class GPT4oInterface(LLMInterface):
    """
    Concrete LLM interface implementation for OpenAI's GPT-4o model family.

    This class provides access to GPT-4o and GPT-4o-mini models via the OpenAI
    API, supporting full log-probability extraction for dynamic self-information
    scoring as required by the CompactPrompt pipeline.

    The GPT-4o family uses the o200k_base tokenizer (via tiktoken), which must
    be aligned with the offline corpus tokenizer as specified in Task 14, Step 3.

    Attributes:
        client (OpenAI): The initialized OpenAI API client instance.
        model (str): The model identifier string (e.g., "gpt-4o", "gpt-4o-mini").
        encoding (tiktoken.Encoding): The tiktoken encoding instance for
            the o200k_base vocabulary used by GPT-4o models.

    Notes:
        - GPT-4o supports token-level log-probabilities via the logprobs parameter.
        - The top_logprobs parameter is set to 5 when logprobs are requested,
          providing alternative token probabilities for analysis.
        - Requires OPENAI_API_KEY environment variable to be set.

    Example:
        >>> interface = GPT4oInterface(model="gpt-4o")
        >>> token_count = interface.count_tokens("Hello, world!")
        >>> response = interface.prompt(
        ...     messages=[{"role": "user", "content": "What is 2+2?"}],
        ...     logprobs=True
        ... )
        >>> text, logprobs = interface.extract_response(response)
    """

    def __init__(self, model: str = "gpt-4o") -> None:
        """
        Initialize the GPT-4o interface with API client and tokenizer.

        This constructor sets up the OpenAI API client and loads the appropriate
        tiktoken encoding for accurate token counting. The o200k_base encoding
        is used for all GPT-4o family models.

        Args:
            model (str, optional): The OpenAI model identifier. Defaults to
                "gpt-4o". Valid options include:
                - "gpt-4o": Full GPT-4o model
                - "gpt-4o-mini": Smaller, faster variant

        Raises:
            TypeError: If model is not a non-empty string.
            ValueError: If OPENAI_API_KEY environment variable is not set.
            RuntimeError: If tiktoken encoding fails to load.

        Notes:
            - API key is retrieved from the OPENAI_API_KEY environment variable.
            - The tiktoken o200k_base encoding provides exact token counts
              matching the model's internal tokenization.
        """
        # Validate that model parameter is a non-empty string
        if not isinstance(model, str) or len(model.strip()) == 0:
            raise TypeError("model must be a non-empty string")

        # Retrieve API key from environment variables
        api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")

        # Validate that API key is present
        if not api_key:
            raise ValueError(
                "OPENAI_API_KEY environment variable is not set. "
                "Please set it to your OpenAI API key."
            )

        # Initialize the OpenAI client with the retrieved API key
        self.client: OpenAI = OpenAI(api_key=api_key)

        # Store the model identifier for use in API calls
        self.model: str = model

        # Load the tiktoken encoding for GPT-4o (o200k_base vocabulary)
        try:
            # o200k_base is the encoding used by GPT-4o family models
            self.encoding: tiktoken.Encoding = tiktoken.get_encoding("o200k_base")
        except Exception as e:
            # Log and re-raise if tokenizer loading fails
            logger.error(f"Failed to load tiktoken encoding 'o200k_base': {e}")
            raise RuntimeError(
                f"Failed to initialize tiktoken encoding: {e}"
            ) from e

        # Log successful initialization
        logger.info(f"GPT4oInterface initialized with model: {self.model}")

    def count_tokens(self, text: str) -> int:
        """
        Count tokens in the input text using tiktoken's o200k_base encoding.

        This method provides exact token counts matching GPT-4o's internal
        tokenization, which is essential for accurate compression ratio
        calculations and context window management.

        Args:
            text (str): The input text to tokenize. Must be a valid string.
                Empty strings return 0 tokens.

        Returns:
            int: The number of tokens in the input text according to the
                o200k_base encoding.

        Raises:
            TypeError: If text is not a string.

        Example:
            >>> interface = GPT4oInterface()
            >>> interface.count_tokens("Hello, world!")
            4
        """
        # Validate input type
        if not isinstance(text, str):
            raise TypeError(f"text must be a string, got {type(text).__name__}")

        # Handle empty string case
        if len(text) == 0:
            return 0

        # Encode the text to token IDs using tiktoken
        token_ids: List[int] = self.encoding.encode(text)

        # Return the count of token IDs
        return len(token_ids)

    def prompt(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 512,
        temperature: float = 0.0,
        logprobs: bool = False
    ) -> Any:
        """
        Send a chat completion request to the GPT-4o API.

        This method wraps the OpenAI chat completions endpoint, supporting
        both standard generation and log-probability extraction for dynamic
        self-information scoring.

        Args:
            messages (List[Dict[str, str]]): List of chat messages, each with
                'role' and 'content' keys. Must contain at least one message.

            max_tokens (int, optional): Maximum tokens to generate. Defaults
                to 512. Must be positive and ≤ model's maximum.

            temperature (float, optional): Sampling temperature in [0.0, 2.0].
                Defaults to 0.0 for deterministic output.

            logprobs (bool, optional): Whether to return token log-probabilities.
                Defaults to False. When True, top_logprobs is set to 5.

        Returns:
            Any: OpenAI ChatCompletion response object containing the generated
                text and optionally log-probabilities.

        Raises:
            TypeError: If messages is not a list of dictionaries.
            ValueError: If messages is empty or malformed.
            ValueError: If max_tokens is not positive.
            ValueError: If temperature is outside [0.0, 2.0].
            RuntimeError: If the API call fails.
        """
        # Validate messages parameter type
        if not isinstance(messages, list):
            raise TypeError(f"messages must be a list, got {type(messages).__name__}")

        # Validate messages is non-empty
        if len(messages) == 0:
            raise ValueError("messages list cannot be empty")

        # Validate each message has required keys
        for idx, msg in enumerate(messages):
            if not isinstance(msg, dict):
                raise TypeError(f"Message at index {idx} must be a dict")
            if "role" not in msg or "content" not in msg:
                raise ValueError(
                    f"Message at index {idx} must have 'role' and 'content' keys"
                )

        # Validate max_tokens is a positive integer
        if not isinstance(max_tokens, int) or max_tokens <= 0:
            raise ValueError(f"max_tokens must be a positive integer, got {max_tokens}")

        # Validate temperature is within acceptable range
        if not isinstance(temperature, (int, float)) or not (0.0 <= temperature <= 2.0):
            raise ValueError(f"temperature must be in [0.0, 2.0], got {temperature}")

        # Determine top_logprobs value based on logprobs flag
        # When logprobs=True, request top 5 alternative tokens per position
        top_logprobs_value: Optional[int] = 5 if logprobs else None

        try:
            # Call the OpenAI chat completions API
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                logprobs=logprobs,
                top_logprobs=top_logprobs_value
            )

            # Return the raw response object
            return response

        except Exception as e:
            # Log the error and re-raise with context
            logger.error(f"OpenAI API call failed: {e}")
            raise RuntimeError(f"OpenAI API call failed: {e}") from e

    def extract_response(
        self,
        response: Any
    ) -> Tuple[str, Optional[List[Dict[str, Any]]]]:
        """
        Extract generated text and log-probabilities from GPT-4o response.

        This method parses the OpenAI ChatCompletion response structure to
        extract the generated content and, if available, token-level log-
        probabilities for dynamic self-information computation.

        Args:
            response (Any): The ChatCompletion response object returned by
                the prompt() method.

        Returns:
            Tuple[str, Optional[List[Dict[str, Any]]]]: A tuple containing:

                [0] str: The generated text from the first choice.

                [1] Optional[List[Dict[str, Any]]]: List of logprob dicts if
                    logprobs were requested, otherwise None. Each dict contains:
                    - 'token' (str): The generated token.
                    - 'logprob' (float): Natural log probability.
                    - 'top_logprobs' (List[Dict]): Top 5 alternative tokens.

        Raises:
            TypeError: If response is None.
            AttributeError: If response lacks expected structure.
        """
        # Validate response is not None
        if response is None:
            raise TypeError("response cannot be None")

        # Validate response has choices attribute
        if not hasattr(response, "choices") or len(response.choices) == 0:
            raise AttributeError("response must have non-empty 'choices' attribute")

        # Extract the generated text from the first choice's message content
        generated_text: str = response.choices[0].message.content

        # Handle case where content might be None
        if generated_text is None:
            generated_text = ""

        # Initialize logprobs_data as None (will be populated if available)
        logprobs_data: Optional[List[Dict[str, Any]]] = None

        # Check if log-probabilities are present in the response; the SDK types
        # logprobs.content as optional, so guard against None as well
        logprobs_obj = response.choices[0].logprobs
        if logprobs_obj is not None and logprobs_obj.content is not None:
            # Initialize list to store parsed logprob dictionaries
            logprobs_data = []

            # Iterate over each token's logprob data in the response
            for token_data in logprobs_obj.content:
                # Build the top_logprobs list for alternative tokens
                top_logprobs_list: List[Dict[str, Any]] = []

                # Check if top_logprobs exists for this token
                if token_data.top_logprobs:
                    # Extract each alternative token and its log probability
                    for alt in token_data.top_logprobs:
                        top_logprobs_list.append({
                            "token": alt.token,
                            "logprob": alt.logprob
                        })

                # Append the complete logprob dictionary for this token
                logprobs_data.append({
                    "token": token_data.token,
                    "logprob": token_data.logprob,
                    "top_logprobs": top_logprobs_list
                })

        # Return the extracted text and logprobs tuple
        return generated_text, logprobs_data
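

# Sketch (hypothetical helper names, illustration only): extract_response()
# only walks attributes, so it can be exercised offline with a stand-in object.
# make_fake_chat_completion() mirrors the attribute path the method reads
# (choices[0].message.content and choices[0].logprobs.content[i].token /
# .logprob / .top_logprobs). It is not the OpenAI SDK's own response type, and
# demo_extract() is a compact restatement of the same parsing steps.
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple


def make_fake_chat_completion(text: str, token_logprobs: List[Tuple[str, float]]) -> Any:
    """Build a ChatCompletion-shaped object; logprobs is None when no tokens are given."""
    content = [
        SimpleNamespace(token=tok, logprob=lp, top_logprobs=[SimpleNamespace(token=tok, logprob=lp)])
        for tok, lp in token_logprobs
    ]
    choice = SimpleNamespace(
        message=SimpleNamespace(content=text),
        logprobs=SimpleNamespace(content=content) if token_logprobs else None,
    )
    return SimpleNamespace(choices=[choice])


def demo_extract(response: Any) -> Tuple[str, Optional[List[Dict[str, Any]]]]:
    """Parse text and per-token logprobs following the same steps as extract_response()."""
    choice = response.choices[0]
    text = choice.message.content or ""
    lp_obj = choice.logprobs
    if lp_obj is None or lp_obj.content is None:
        return text, None
    return text, [
        {
            "token": t.token,
            "logprob": t.logprob,
            "top_logprobs": [{"token": a.token, "logprob": a.logprob} for a in (t.top_logprobs or [])],
        }
        for t in lp_obj.content
    ]

# Usage note: the same fake object can be passed to GPT4oInterface.extract_response()
# in a unit test. The constructor still requires OPENAI_API_KEY to be set (it only
# checks that the variable is present) and loads the o200k_base encoding.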


class ClaudeInterface(LLMInterface):
    """
    Concrete LLM interface implementation for Anthropic's Claude-3.5-Sonnet model.

    This class provides access to Claude models via the Anthropic API. Note that
    Claude does not currently expose token-level log-probabilities through its
    public API, which limits its use as a scorer LLM for dynamic self-information
    computation. However, Claude remains a valid target LLM for downstream task
    evaluation.

    Attributes:
        client (Anthropic): The initialized Anthropic API client instance.
        model (str): The model identifier string (e.g., "claude-3-5-sonnet-20241022").
        encoding (tiktoken.Encoding): A tiktoken encoding instance (cl100k_base)
            used as an approximation for offline token counting when the API
            count_tokens endpoint is unavailable.

    Notes:
        - Claude does NOT support public log-probability access; extract_response()
          always returns None for the logprobs component.
        - Token counting uses Claude's native count_tokens API when available,
          falling back to tiktoken cl100k_base approximation.
        - Requires ANTHROPIC_API_KEY environment variable to be set.

    Example:
        >>> interface = ClaudeInterface(model="claude-3-5-sonnet-20241022")
        >>> token_count = interface.count_tokens("Hello, world!")
        >>> response = interface.prompt(
        ...     messages=[{"role": "user", "content": "What is 2+2?"}]
        ... )
        >>> text, logprobs = interface.extract_response(response)
        >>> assert logprobs is None  # Claude does not support logprobs
    """

    def __init__(self, model: str = "claude-3-5-sonnet-20241022") -> None:
        """
        Initialize the Claude interface with API client and fallback tokenizer.

        This constructor sets up the Anthropic API client and loads a tiktoken
        encoding for approximate token counting when the API endpoint is
        unavailable.

        Args:
            model (str, optional): The Anthropic model identifier. Defaults to
                "claude-3-5-sonnet-20241022".

        Raises:
            TypeError: If model is not a non-empty string.
            ValueError: If ANTHROPIC_API_KEY environment variable is not set.
            RuntimeError: If tiktoken encoding fails to load.

        Notes:
            - The cl100k_base encoding is used as an approximation for Claude's
              tokenization; exact counts are obtained via the API when possible.
        """
        # Validate that model parameter is a non-empty string
        if not isinstance(model, str) or len(model.strip()) == 0:
            raise TypeError("model must be a non-empty string")

        # Retrieve API key from environment variables
        api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY")

        # Validate that API key is present
        if not api_key:
            raise ValueError(
                "ANTHROPIC_API_KEY environment variable is not set. "
                "Please set it to your Anthropic API key."
            )

        # Initialize the Anthropic client with the retrieved API key
        self.client: Anthropic = Anthropic(api_key=api_key)

        # Store the model identifier for use in API calls
        self.model: str = model

        # Load tiktoken encoding as fallback for token counting
        try:
            # cl100k_base provides reasonable approximation for Claude tokenization
            self.encoding: tiktoken.Encoding = tiktoken.get_encoding("cl100k_base")
        except Exception as e:
            # Log and re-raise if tokenizer loading fails
            logger.error(f"Failed to load tiktoken encoding 'cl100k_base': {e}")
            raise RuntimeError(
                f"Failed to initialize tiktoken encoding: {e}"
            ) from e

        # Log successful initialization with warning about logprobs limitation
        logger.info(
            f"ClaudeInterface initialized with model: {self.model}. "
            "Note: Claude does not support log-probability extraction."
        )

    def count_tokens(self, text: str) -> int:
        """
        Count tokens using Claude's API with tiktoken fallback.

        This method attempts to use Anthropic's native count_tokens API endpoint
        for exact token counts. If the API call fails, it falls back to tiktoken
        cl100k_base approximation.

        Args:
            text (str): The input text to tokenize. Must be a valid string.

        Returns:
            int: The number of tokens in the input text. Exact count from API
                when available, approximate count from tiktoken otherwise.

        Raises:
            TypeError: If text is not a string.

        Notes:
            - API-based counting is preferred for accuracy.
            - Tiktoken cl100k_base may undercount or overcount by ~5-10%.
        """
        # Validate input type
        if not isinstance(text, str):
            raise TypeError(f"text must be a string, got {type(text).__name__}")

        # Handle empty string case
        if len(text) == 0:
            return 0

        try:
            # Attempt to use Anthropic's native token counting API
            response = self.client.messages.count_tokens(
                model=self.model,
                messages=[{"role": "user", "content": text}]
            )

            # Return the exact token count from the API response
            return response.input_tokens

        except Exception as e:
            # Log the fallback to tiktoken
            logger.warning(
                f"Claude count_tokens API failed: {e}. "
                "Falling back to tiktoken approximation."
            )

            # Use tiktoken cl100k_base as approximation
            token_ids: List[int] = self.encoding.encode(text)

            # Return the approximate token count
            return len(token_ids)

    def prompt(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 1024,
        temperature: float = 0.0,
        logprobs: bool = False
    ) -> Any:
        """
        Send a message request to the Claude API.

        This method wraps the Anthropic messages endpoint. Note that the logprobs
        parameter is accepted for interface compatibility but is ignored, as
        Claude does not support log-probability extraction.

        Args:
            messages (List[Dict[str, str]]): List of chat messages with 'role'
                and 'content' keys. Must contain at least one message.

            max_tokens (int, optional): Maximum tokens to generate. Defaults
                to 1024 (higher default for Claude's longer responses).

            temperature (float, optional): Sampling temperature in [0.0, 1.0].
                Defaults to 0.0 for deterministic output.

            logprobs (bool, optional): Ignored parameter (Claude does not
                support logprobs). Included for interface compatibility.

        Returns:
            Any: Anthropic Message response object containing the generated text.

        Raises:
            TypeError: If messages is not a list of dictionaries.
            ValueError: If messages is empty or malformed.
            ValueError: If max_tokens is not positive.
            ValueError: If temperature is outside valid range.
            RuntimeError: If the API call fails.
        """
        # Validate messages parameter type
        if not isinstance(messages, list):
            raise TypeError(f"messages must be a list, got {type(messages).__name__}")

        # Validate messages is non-empty
        if len(messages) == 0:
            raise ValueError("messages list cannot be empty")

        # Validate each message has required keys
        for idx, msg in enumerate(messages):
            if not isinstance(msg, dict):
                raise TypeError(f"Message at index {idx} must be a dict")
            if "role" not in msg or "content" not in msg:
                raise ValueError(
                    f"Message at index {idx} must have 'role' and 'content' keys"
                )

        # Validate max_tokens is a positive integer
        if not isinstance(max_tokens, int) or max_tokens <= 0:
            raise ValueError(f"max_tokens must be a positive integer, got {max_tokens}")

        # Validate temperature is within acceptable range (Claude uses [0, 1])
        if not isinstance(temperature, (int, float)) or not (0.0 <= temperature <= 1.0):
            raise ValueError(f"temperature must be in [0.0, 1.0] for Claude, got {temperature}")

        # Log warning if logprobs was requested (not supported by Claude)
        if logprobs:
            logger.warning(
                "logprobs=True was requested but Claude does not support "
                "log-probability extraction. This parameter will be ignored."
            )

        try:
            # Call the Anthropic messages API (logprobs parameter not passed)
            response = self.client.messages.create(
                model=self.model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature
            )

            # Return the raw response object
            return response

        except Exception as e:
            # Log the error and re-raise with context
            logger.error(f"Anthropic API call failed: {e}")
            raise RuntimeError(f"Anthropic API call failed: {e}") from e

    def extract_response(
        self,
        response: Any
    ) -> Tuple[str, None]:
        """
        Extract generated text from Claude response (logprobs not available).

        This method parses the Anthropic Message response structure to extract
        the generated text content. Log-probabilities are not available from
        Claude's public API.

        Args:
            response (Any): The Message response object returned by prompt().

        Returns:
            Tuple[str, None]: A tuple containing:
                [0] str: The concatenated text from all content blocks.
                [1] None: Always None (Claude does not support logprobs).

        Raises:
            TypeError: If response is None.
            AttributeError: If response lacks expected structure.
        """
        # Validate response is not None
        if response is None:
            raise TypeError("response cannot be None")

        # Validate response has content attribute
        if not hasattr(response, "content"):
            raise AttributeError("response must have 'content' attribute")

        # Initialize generated text accumulator
        generated_text: str = ""

        # Iterate over content blocks in the response
        for block in response.content:
            # Check if this block contains text content
            if hasattr(block, "type") and block.type == "text":
                # Append the text content to the accumulator
                generated_text += block.text

        # Return text and None for logprobs (Claude does not support logprobs)
        return generated_text, None
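

# --- Illustrative sketch (hypothetical helper, not part of the pipeline) ---
# ClaudeInterface.extract_response keeps only content blocks whose `type` is
# "text" and concatenates their `.text`. The sketch below applies that rule to
# stub objects built with types.SimpleNamespace. It assumes a response shaped
# like Anthropic's Message, i.e. a `content` list of typed blocks. The
# "tool_use" stub is illustrative and shows that non-text blocks contribute
# nothing to the result.
from types import SimpleNamespace
from typing import Any


def sketch_join_text_blocks(response: Any) -> str:
    """Concatenate the `.text` of every block whose type is 'text'."""
    return "".join(
        block.text
        for block in response.content
        if getattr(block, "type", None) == "text"
    )


# Stub response: the middle (non-text) block is skipped
_stub_response = SimpleNamespace(
    content=[
        SimpleNamespace(type="text", text="The answer "),
        SimpleNamespace(type="tool_use", name="calculator", input={}),
        SimpleNamespace(type="text", text="is 4."),
    ]
)
# sketch_join_text_blocks(_stub_response) -> "The answer is 4."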


class LlamaInterface(LLMInterface):
    """
    Concrete LLM interface implementation for Llama-3.3-70B-Instruct via Together AI.

    This class provides access to Llama models hosted on Together AI's inference
    platform. It enforces the use of the correct Hugging Face tokenizer to ensure
    perfect alignment between local token counting/indexing and the model's
    internal tokenization. This alignment is critical for the accuracy of
    dynamic self-information scoring.

    Attributes:
        client (Together): The initialized Together AI client instance.
        model (str): The model identifier on Together AI's platform
            (e.g., "meta-llama/Llama-3.3-70B-Instruct-Turbo").
        tokenizer (AutoTokenizer): The Hugging Face tokenizer for exact token counting.

    Notes:
        - Requires TOGETHER_API_KEY environment variable.
        - Requires HF_TOKEN environment variable for accessing the gated Llama tokenizer.
        - Raises RuntimeError if the specific Llama tokenizer cannot be loaded;
          does NOT fall back to tiktoken to prevent data corruption.

    Example:
        >>> interface = LlamaInterface()
        >>> token_count = interface.count_tokens("Hello, world!")
        >>> response = interface.prompt(
        ...     messages=[{"role": "user", "content": "What is 2+2?"}],
        ...     logprobs=True
        ... )
        >>> text, logprobs = interface.extract_response(response)
    """

    def __init__(
        self,
        model: str = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
    ) -> None:
        """
        Initialize the Llama interface with Together AI client and HF tokenizer.

        Args:
            model (str, optional): The model identifier on Together AI.
                Defaults to "meta-llama/Llama-3.3-70B-Instruct-Turbo".

        Raises:
            TypeError: If model is not a non-empty string.
            ValueError: If TOGETHER_API_KEY is missing.
            ImportError: If 'together' or 'transformers' libraries are missing.
            RuntimeError: If the Hugging Face tokenizer cannot be loaded (e.g., invalid HF_TOKEN).
        """
        # Validate dependencies
        if Together is None:
            raise ImportError("The 'together' library is required. Install it via `pip install together`.")
        if AutoTokenizer is None:
            raise ImportError("The 'transformers' library is required. Install it via `pip install transformers`.")

        # Validate model parameter
        if not isinstance(model, str) or not model.strip():
            raise TypeError("model must be a non-empty string")

        # Retrieve API key
        api_key = os.environ.get("TOGETHER_API_KEY")
        if not api_key:
            raise ValueError(
                "TOGETHER_API_KEY environment variable is not set. "
                "Please set it to your Together AI API key."
            )

        # Initialize Client
        self.client = Together(api_key=api_key)
        self.model = model

        # Initialize Tokenizer
        # We strictly require the correct tokenizer. No fallbacks.
        hf_model_id = "meta-llama/Llama-3.3-70B-Instruct"
        hf_token = os.environ.get("HF_TOKEN")

        try:
            logger.info(f"Loading tokenizer for {hf_model_id}...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                hf_model_id,
                token=hf_token,
                use_fast=True
            )
            logger.info("Successfully loaded Llama tokenizer.")
        except Exception as e:
            logger.error(f"Failed to load Llama tokenizer: {e}")
            raise RuntimeError(
                f"Could not load tokenizer for '{hf_model_id}'. "
                "Ensure HF_TOKEN is set and you have access to the gated model. "
                "Strict token alignment is required; cannot proceed without correct tokenizer."
            ) from e

    def count_tokens(self, text: str) -> int:
        """
        Count tokens using the loaded Hugging Face tokenizer.

        Args:
            text (str): The input text.

        Returns:
            int: The exact number of tokens.

        Raises:
            TypeError: If text is not a string.
        """
        if not isinstance(text, str):
            raise TypeError(f"text must be a string, got {type(text).__name__}")

        if not text:
            return 0

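        # add_special_tokens=True includes the tokenizer's special tokens (e.g., BOS),
        # so this count can exceed the length of the sequence returned by
        # tokenize_with_offsets, which passes add_special_tokens=False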
        token_ids = self.tokenizer.encode(text, add_special_tokens=True)
        return len(token_ids)

    def prompt(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 512,
        temperature: float = 0.0,
        logprobs: bool = False
    ) -> Any:
        """
        Send a chat completion request to Together AI.

        Args:
            messages (List[Dict]): Chat messages.
            max_tokens (int): Max tokens to generate.
            temperature (float): Sampling temperature.
            logprobs (bool): Whether to return log-probabilities.

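        Returns:
            Any: The raw API response object.

        Raises:
            ValueError: If messages, max_tokens, or temperature is invalid.
            RuntimeError: If the Together AI API call fails.
        """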
        # Input validation
        if not isinstance(messages, list) or not messages:
            raise ValueError("messages must be a non-empty list of dicts")
        if not isinstance(max_tokens, int) or max_tokens <= 0:
            raise ValueError("max_tokens must be a positive integer")
        if not isinstance(temperature, (int, float)) or temperature < 0:
            raise ValueError("temperature must be non-negative")

        # Together AI takes an integer `logprobs` argument here (1 when requested,
        # 0 otherwise). Log-probability support on the chat endpoint has varied by
        # model; if prompt-token scores are strictly required, the completion
        # endpoint may be needed in 'score_prompt_tokens'. This method implements
        # the standard chat interface.
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                logprobs=1 if logprobs else 0
            )
            return response
        except Exception as e:
            logger.error(f"Together AI API call failed: {e}")
            raise RuntimeError(f"Together AI API call failed: {e}") from e

    def extract_response(
        self,
        response: Any
    ) -> Tuple[str, Optional[List[Dict[str, Any]]]]:
        """
        Extract generated text and log-probabilities from the response.

        Args:
            response (Any): The API response object.

        Returns:
            Tuple[str, Optional[List[Dict]]]: Generated text and logprobs list.
        """
        if response is None:
            raise TypeError("response cannot be None")

        if not hasattr(response, "choices") or not response.choices:
            raise AttributeError("Response missing 'choices'")

        choice = response.choices[0]

        # Extract text
        if hasattr(choice, "message") and choice.message:
            text = choice.message.content or ""
        else:
            text = ""

        # Extract logprobs
        logprobs_data = None

        # Check for logprobs in the response structure
        # Together AI structure for chat logprobs can vary; typically under 'logprobs'
        if hasattr(choice, "logprobs") and choice.logprobs:
            # Assuming OpenAI-like structure if present
            if hasattr(choice.logprobs, "token_logprobs"):
                # Completion-style
                tokens = choice.logprobs.tokens
                values = choice.logprobs.token_logprobs
                logprobs_data = []
                for t, lp in zip(tokens, values):
                    logprobs_data.append({
                        "token": t,
                        "logprob": lp,
                        "top_logprobs": []  # Often not returned in simple mode
                    })
            elif hasattr(choice.logprobs, "content"):
                # Chat-style (like OpenAI)
                logprobs_data = []
                for item in choice.logprobs.content:
                    logprobs_data.append({
                        "token": item.token,
                        "logprob": item.logprob,
                        "top_logprobs": []  # Top alternatives are not extracted here
                    })

        return text, logprobs_data
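

# --- Illustrative sketch (hypothetical helper, not part of the pipeline) ---
# LlamaInterface.extract_response accepts two log-probability layouts:
#   - completion-style parallel lists (`tokens`, `token_logprobs`)
#   - chat-style per-token records (`content` items with `.token`/`.logprob`)
# This sketch applies the same branching to stub objects, so the uniform list
# of {"token", "logprob", "top_logprobs"} dicts can be inspected without an
# API call. Stub field names mirror those the method reads; the numeric
# values are made up.
from types import SimpleNamespace
from typing import Any, Dict, List, Optional


def sketch_normalize_logprobs(logprobs: Any) -> Optional[List[Dict[str, Any]]]:
    """Convert either supported logprobs layout into a uniform list of dicts."""
    if not logprobs:
        return None
    if hasattr(logprobs, "token_logprobs"):
        # Completion-style: two parallel lists
        pairs = zip(logprobs.tokens, logprobs.token_logprobs)
    elif hasattr(logprobs, "content"):
        # Chat-style: one record per generated token
        pairs = ((item.token, item.logprob) for item in logprobs.content)
    else:
        return None
    return [{"token": t, "logprob": lp, "top_logprobs": []} for t, lp in pairs]


_completion_style = SimpleNamespace(tokens=["4", "."], token_logprobs=[-0.01, -0.2])
_chat_style = SimpleNamespace(
    content=[
        SimpleNamespace(token="4", logprob=-0.01),
        SimpleNamespace(token=".", logprob=-0.2),
    ]
)
# Both layouts normalize to the same list of dicts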

# -------------------------------------------------------------------------------------------------------------------------------
# Task 14, Step 3: Factory and Orchestrator
# -------------------------------------------------------------------------------------------------------------------------------
def create_llm_interface(model_name: str) -> LLMInterface:
    """
    Factory function to instantiate the appropriate LLM interface based on the model name.

    This function serves as a centralized factory for creating concrete LLMInterface
    instances. It maps human-readable model names (e.g., "GPT-4o") to their
    corresponding implementation classes (e.g., GPT4oInterface) and specific model
    identifiers required by the provider APIs.

    Args:
        model_name (str): The standardized model name. Must be one of:
            - "GPT-4o"
            - "GPT-4o-mini"
            - "Claude-3.5-Sonnet"
            - "Llama-3.3-70B-Instruct"

    Returns:
        LLMInterface: An initialized concrete implementation of LLMInterface
            configured for the specified model.

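    Raises:
        ValueError: If the model_name is not recognized in the internal mapping.
        TypeError: If model_name is not a string.
        Exceptions raised by the selected interface's constructor (e.g., a
        missing API key) propagate unchanged.

    Notes:
        Claude does not expose log-probabilities; use GPT-4o or Llama as the
        scorer when dynamic self-information is required.
    """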
    # Validate input type
    if not isinstance(model_name, str):
        raise TypeError(f"model_name must be a string, got {type(model_name).__name__}")

    # Map standardized model names to (API model identifier, interface class)
    model_map: Dict[str, Tuple[str, type]] = {
        "GPT-4o": ("gpt-4o", GPT4oInterface),
        "GPT-4o-mini": ("gpt-4o-mini", GPT4oInterface),
        "Claude-3.5-Sonnet": ("claude-3-5-sonnet-20241022", ClaudeInterface),
        "Llama-3.3-70B-Instruct": ("meta-llama/Llama-3.3-70B-Instruct-Turbo", LlamaInterface)
    }

    # Check if the requested model is supported
    if model_name not in model_map:
        raise ValueError(
            f"Unknown model: '{model_name}'. "
            f"Supported models: {list(model_map.keys())}"
        )

    # Retrieve the API identifier and the class constructor
    model_id, interface_class = model_map[model_name]

    # Instantiate and return the interface
    return interface_class(model=model_id)


def configure_llm_resources(study_config: Dict[str, Any]) -> Dict[str, LLMInterface]:
    """
    Orchestrates the configuration of all LLM resources defined in the study configuration.

    This function iterates through the configured scorer and target LLMs specified
    in the study configuration dictionary. It instantiates the appropriate
    LLMInterface for each unique model found, handling normalization of model names
    to match the factory's expected keys.

    The resulting dictionary serves as a registry of ready-to-use LLM interfaces
    for the entire pipeline, ensuring that clients and tokenizers are initialized
    only once.

    Args:
        study_config (Dict[str, Any]): The full study configuration dictionary.
            Expected to contain 'llm_config' with 'scorer_llm_options' and
            'target_llms_for_evaluation' lists.

    Returns:
        Dict[str, LLMInterface]: A dictionary where keys are standardized model names
            (e.g., "GPT-4o") and values are the initialized LLMInterface objects.
    """
    logger.info("Configuring LLM resources...")

    resources: Dict[str, LLMInterface] = {}
    llm_config = study_config.get("llm_config", {})

    # Collect all unique model names from scorers and targets to avoid duplicate initialization
    model_names: Set[str] = set()

    # Scorer and target entries share the same name-normalization heuristic
    for section in ("scorer_llm_options", "target_llms_for_evaluation"):
        for entry in llm_config.get(section, []):
            raw_name = entry.get("model_name", "")
            if not raw_name:
                continue

            # Map raw config names (case variations, partial matches) to factory keys
            name_lower = raw_name.lower()
            if "gpt-4o" in name_lower and "mini" not in name_lower:
                key = "GPT-4o"
            elif "mini" in name_lower:
                key = "GPT-4o-mini"
            elif "claude" in name_lower:
                key = "Claude-3.5-Sonnet"
            elif "llama" in name_lower:
                key = "Llama-3.3-70B-Instruct"
            else:
                # Fallback to the raw name if no heuristic matches (will likely raise ValueError in factory)
                key = raw_name

            model_names.add(key)

    # Instantiate interfaces for all identified unique models
    for name in model_names:
        try:
            interface = create_llm_interface(name)
            resources[name] = interface
            logger.info(f"Successfully configured interface for {name}")
        except (ValueError, ImportError, RuntimeError) as e:
            # Unknown names, missing API keys, missing client libraries, and
            # tokenizer load failures are logged; remaining models are still configured
            logger.error(f"Failed to configure interface for {name}: {e}")

    return resources
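

# --- Illustrative sketch (hypothetical helper, not part of the pipeline) ---
# configure_llm_resources maps raw config names to factory keys with substring
# checks. Pulling that heuristic into a standalone function makes its edge
# cases easy to see:
#   - Any name containing "mini" maps to "GPT-4o-mini". This includes
#     "gemini-...", because "gemini" contains the substring "mini".
#   - Unmatched names pass through unchanged; create_llm_interface then
#     rejects them with a ValueError.


def sketch_normalize_model_name(raw_name: str) -> str:
    """Apply the same substring rules as configure_llm_resources."""
    name_lower = raw_name.lower()
    if "gpt-4o" in name_lower and "mini" not in name_lower:
        return "GPT-4o"
    if "mini" in name_lower:
        return "GPT-4o-mini"
    if "claude" in name_lower:
        return "Claude-3.5-Sonnet"
    if "llama" in name_lower:
        return "Llama-3.3-70B-Instruct"
    return raw_name


# Mapping examples (follow directly from the rules above):
#   "gpt-4o-2024-08-06"          -> "GPT-4o"
#   "GPT-4.1-mini"               -> "GPT-4o-mini"
#   "claude-3-5-sonnet-20241022" -> "Claude-3.5-Sonnet"
#   "gemini-1.5-pro"             -> "GPT-4o-mini"   (substring false positive)
#   "mistral-large"              -> "mistral-large" (passes through unchanged)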


# Task 15 – Tokenize Prompts for Dynamic Scoring

# ==============================================================================
# Task 15: Tokenize Prompts for Dynamic Scoring
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 15, Step 1: Serialize example to prompt string
# -------------------------------------------------------------------------------------------------------------------------------
def serialize_example_to_prompt(
    row: pd.Series,
    dataset_name: str,
    exemplars_str: str = ""
) -> str:
    """
    Constructs the full prompt string for a single example row.

    Args:
        row (pd.Series): A row from the cleansed DataFrame containing 'question_text',
                         'serialized_tables', 'passages'.
        dataset_name (str): "TAT-QA" or "Fin-QA".
        exemplars_str (str): Pre-formatted exemplars string (optional).

    Returns:
        str: The full prompt string ready for tokenization.
    """
    question = row.get("question_text", "")

    # Tables are already serialized in 'serialized_tables' (list of strings)
    tables_list = row.get("serialized_tables", [])
    tables_str = "\n\n".join(tables_list) if isinstance(tables_list, list) else ""

    # Passages are list of dicts, need to extract text
    passages_list = row.get("passages", [])
    passages_str = ""
    if isinstance(passages_list, list):
        passages_str = "\n\n".join(
            p.get("text", "").strip()
            for p in passages_list
            if isinstance(p, dict)
        )

    # Use the construct_prompt function from Task 12
    # We assume construct_prompt is available in the environment
    return construct_prompt(
        dataset=dataset_name,
        question=question,
        tables_str=tables_str,
        passages_str=passages_str,
        exemplars_str=exemplars_str
    )
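

# --- Illustrative sketch (hypothetical helper, not part of the pipeline) ---
# serialize_example_to_prompt joins the pre-serialized tables and the passage
# texts with blank lines before handing them to construct_prompt. That Task 12
# template is not reproduced here. This sketch isolates only the joining step,
# using a plain dict shaped like a cleansed row; the field values are made up.
# As in the function above, non-dict passages are skipped and each passage's
# text is stripped.
from typing import Any, Dict, Tuple


def sketch_join_context(row: Dict[str, Any]) -> Tuple[str, str]:
    """Return (tables_str, passages_str) built with the serializer's rules."""
    tables_list = row.get("serialized_tables", [])
    tables_str = "\n\n".join(tables_list) if isinstance(tables_list, list) else ""
    passages_list = row.get("passages", [])
    passages_str = ""
    if isinstance(passages_list, list):
        passages_str = "\n\n".join(
            p.get("text", "").strip() for p in passages_list if isinstance(p, dict)
        )
    return tables_str, passages_str


_example_row = {
    "question_text": "What was the change in revenue?",
    "serialized_tables": ["| Year | Revenue |\n| 2019 | 100 |"],
    "passages": [{"text": "  Revenue rose in 2019. "}, "not-a-dict", {"text": "Costs fell."}],
}
# sketch_join_context(_example_row)[1] -> "Revenue rose in 2019.\n\nCosts fell."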

# -------------------------------------------------------------------------------------------------------------------------------
# Task 15, Step 2 & 3: Tokenize and Compute Offsets
# -------------------------------------------------------------------------------------------------------------------------------
def tokenize_with_offsets(
    text: str,
    interface: LLMInterface
) -> Tuple[List[int], List[str], List[Tuple[int, int]]]:
    """
    Tokenizes the input text and computes precise character offsets for each token.

    This function handles the complexities of different tokenizer implementations:
    1. **Hugging Face Tokenizers (e.g., Llama)**: Natively support offset mapping via
       `return_offsets_mapping=True`.
    2. **Tiktoken (e.g., GPT-4o, Claude approximation)**: Does not natively support
       offsets. We implement a robust reconstruction algorithm that decodes tokens
       sequentially and maps them back to the original string to determine character spans.

    This alignment is critical for Task 19, where dependency parsing (operating on
    character spans) must be mapped to LLM tokens for phrase-level pruning.

    Args:
        text (str): The full prompt text to tokenize.
        interface (LLMInterface): The configured LLM interface containing the
            appropriate tokenizer (either `tokenizer` for HF or `encoding` for tiktoken).

    Returns:
        Tuple[List[int], List[str], List[Tuple[int, int]]]: A tuple containing:
            - List[int]: The sequence of token IDs.
            - List[str]: The sequence of decoded token strings.
            - List[Tuple[int, int]]: A list of (start, end) character offsets for
              each token, such that `text[start:end]` corresponds to the token.

    Notes:
        No exception is raised when the interface exposes neither `tokenizer`
        nor `encoding`; a warning is logged and three empty lists are returned.
    """
    token_ids: List[int] = []
    token_strings: List[str] = []
    offsets: List[Tuple[int, int]] = []

    # Case 1: Hugging Face Tokenizer (e.g., Llama)
    # These tokenizers provide a built-in method to retrieve offset mappings.
    if hasattr(interface, 'tokenizer') and interface.tokenizer is not None:
        # Tokenize with offset mapping enabled
        # add_special_tokens=False ensures we don't get unexpected BOS/EOS unless intended
        encoding = interface.tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)

        token_ids = encoding.input_ids
        offsets = encoding.offset_mapping

        # Convert IDs back to tokens for interpretability and debugging
        token_strings = interface.tokenizer.convert_ids_to_tokens(token_ids)

    # Case 2: Tiktoken (e.g., GPT-4o, Claude approximation)
    # Tiktoken encodes to IDs but does not provide character offsets directly.
    # We reconstruct them from each token's raw bytes: byte lengths add up
    # exactly, and a byte-to-character table maps them back onto `text`.
    # Summing the lengths of per-token decoded strings would drift whenever a
    # multi-byte UTF-8 character is split across tokens, because the fragments
    # decode to U+FFFD replacement characters.
    elif hasattr(interface, 'encoding') and interface.encoding is not None:
        enc = interface.encoding
        token_ids = enc.encode(text)

        # byte_to_char[b] is the index of the character containing UTF-8 byte b
        byte_to_char: List[int] = []
        for char_idx, ch in enumerate(text):
            byte_to_char.extend([char_idx] * len(ch.encode('utf-8')))

        byte_offset = 0

        # Iterate through each token ID to reconstruct its string and offset
        for tid in token_ids:
            # Raw bytes of this token (possibly a fragment of a UTF-8 character)
            token_bytes = enc.decode_single_token_bytes(tid)

            # Readable form for debugging; partial characters decode to U+FFFD
            token_strings.append(token_bytes.decode('utf-8', errors='replace'))

            # Exclusive end of this token in the UTF-8 byte stream
            byte_end = byte_offset + len(token_bytes)

            # Character span: from the character holding the token's first byte
            # through the character holding its last byte. A token covering only
            # part of a character is widened to that whole character, so adjacent
            # spans can overlap.
            start = byte_to_char[byte_offset]
            end = byte_to_char[byte_end - 1] + 1
            offsets.append((start, end))

            # Advance to the first byte of the next token
            byte_offset = byte_end

    else:
        logger.warning("No valid tokenizer found in interface. Returning empty lists.")

    return token_ids, token_strings, offsets
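

# --- Illustrative sketch (hypothetical helpers, not part of the pipeline) ---
# This shows why character offsets for byte-level BPE tokens are best derived
# from byte lengths. When a multi-byte UTF-8 character is split across tokens,
# each fragment decodes to U+FFFD, so summing decoded string lengths drifts.
# Hand-written byte chunks stand in for tiktoken output. Assumption, for
# illustration only: "café ok" is split as b"caf", b"\xc3", b"\xa9", b" ok".
# The sketch compares the two strategies on that split.
from typing import List, Tuple


def sketch_naive_offsets(chunks: List[bytes]) -> List[Tuple[int, int]]:
    """Offsets from lengths of per-chunk decoded strings (drifts on splits)."""
    offsets, pos = [], 0
    for chunk in chunks:
        n = len(chunk.decode("utf-8", errors="replace"))
        offsets.append((pos, pos + n))
        pos += n
    return offsets


def sketch_byte_mapped_offsets(text: str, chunks: List[bytes]) -> List[Tuple[int, int]]:
    """Offsets from exact byte lengths mapped back to character indices."""
    byte_to_char: List[int] = []
    for char_idx, ch in enumerate(text):
        byte_to_char.extend([char_idx] * len(ch.encode("utf-8")))
    offsets, byte_pos = [], 0
    for chunk in chunks:
        byte_end = byte_pos + len(chunk)
        # First byte's character through last byte's character (inclusive)
        offsets.append((byte_to_char[byte_pos], byte_to_char[byte_end - 1] + 1))
        byte_pos = byte_end
    return offsets


_text = "café ok"
_chunks = [b"caf", b"\xc3", b"\xa9", b" ok"]
# sketch_naive_offsets(_chunks)              -> [(0, 3), (3, 4), (4, 5), (5, 8)]
#   (the last span overruns len(_text) == 7)
# sketch_byte_mapped_offsets(_text, _chunks) -> [(0, 3), (3, 4), (3, 4), (4, 7)]
#   (_text[4:7] == " ok"; both fragments of "é" map to (3, 4))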

# -------------------------------------------------------------------------------------------------------------------------------
# Task 15, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def prepare_dynamic_scoring_inputs(
    df: pd.DataFrame,
    dataset_name: str,
    scorer_interface: LLMInterface,
    exemplars_str: str = ""
) -> Dict[str, Dict[str, Any]]:
    """
    Orchestrates the preparation of prompts for dynamic scoring.

    For each example in the DataFrame:
    1. Constructs the full prompt string.
    2. Tokenizes the prompt.
    3. Computes character offsets.

    Args:
        df (pd.DataFrame): The input DataFrame (TAT-QA or Fin-QA).
        dataset_name (str): Name of the dataset.
        scorer_interface (LLMInterface): The interface for the scoring model.
        exemplars_str (str): Optional few-shot exemplars string.

    Returns:
        Dict[str, Dict[str, Any]]: A mapping from example_id to a dictionary containing:
            - 'prompt_text': The full prompt string.
            - 'token_ids': List of token IDs.
            - 'token_strings': List of token strings.
            - 'offsets': List of (start, end) character offsets.
    """
    logger.info(f"Preparing dynamic scoring inputs for {dataset_name}...")

    results = {}

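    for _, row in df.iterrows():
        example_id = row.get("example_id")
        # Skip rows without a usable ID (missing column, NaN, or empty string).
        # NaN is truthy, so a plain `if not example_id` check would not catch it.
        if pd.isna(example_id) or example_id == "":
            continue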

        # Step 1: Serialize
        prompt_text = serialize_example_to_prompt(row, dataset_name, exemplars_str)

        # Step 2 & 3: Tokenize and Offsets
        token_ids, token_strings, offsets = tokenize_with_offsets(prompt_text, scorer_interface)

        results[example_id] = {
            "prompt_text": prompt_text,
            "token_ids": token_ids,
            "token_strings": token_strings,
            "offsets": offsets
        }

    logger.info(f"Prepared inputs for {len(results)} examples.")
    return results
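

# --- Illustrative sketch (hypothetical helper, not part of the pipeline) ---
# The offsets returned above let character spans be mapped onto LLM token
# indices, for example a phrase found by the dependency parser in Task 19.
# This sketch assumes half-open (start, end) spans on both sides and selects
# every token that overlaps the phrase span at all. Whether Task 19 uses
# overlap or full containment is not stated here, so treat this rule as an
# assumption.
from typing import List, Tuple


def sketch_tokens_for_span(
    offsets: List[Tuple[int, int]], span_start: int, span_end: int
) -> List[int]:
    """Indices of tokens whose [start, end) overlaps [span_start, span_end)."""
    return [
        i
        for i, (start, end) in enumerate(offsets)
        if start < span_end and end > span_start
    ]


# Hand-made offsets for "Net income rose" tokenized as "Net", " income", " rose"
_offsets = [(0, 3), (3, 10), (10, 15)]
# sketch_tokens_for_span(_offsets, 4, 10) -> [1]      ("income")
# sketch_tokens_for_span(_offsets, 4, 15) -> [1, 2]   ("income rose")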


# Task 16 – Query LLM for Conditional Probabilities

# ==============================================================================
# Task 16: Query LLM for Conditional Probabilities
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 16, Helper: Iterative Scoring for Chat Models
# -------------------------------------------------------------------------------------------------------------------------------
def score_sequence_iterative(
    interface: Any,  # Typed as Any to avoid circular import with LLMInterface class definition
    token_ids: List[int],
    token_strings: List[str]
) -> List[float]:
    """
    Computes log-probabilities for a token sequence using iterative "teacher forcing"
    with a chat completion API.

    This function approximates the conditional probability P(t_i | t_0...t_{i-1})
    by iteratively prompting the model with the prefix t_0...t_{i-1} and checking
    the probability assigned to the actual next token t_i in the generation output.

    CRITICAL NOTE:
    Standard chat APIs (e.g., OpenAI GPT-4o) do not score the prompt itself
    (there is no `echo=True` equivalent). This method therefore computes the
    *continuation probability* P(t_i | User: t_0...t_{i-1}), where the prefix is
    wrapped in a user message by the chat template. It is only a proxy for the
    true prompt probability P(t_i | t_0...t_{i-1}), but it is one of the few
    options available for black-box chat models. It also costs one API call per
    token, so a prompt of T tokens requires T - 1 requests.

    Algorithm:
    1. Initialize logprobs list with 0.0 for the first token (unconditional).
    2. Iterate through the sequence from i = 1 to T-1.
    3. Construct the context string from tokens t_0...t_{i-1}.
    4. Request generation of 1 token with logprobs enabled.
    5. Extract the logprob of the target token t_i from the response's top_logprobs.
       - Uses robust matching: compares token IDs (if tokenizer available) or
         stripped strings to handle whitespace artifacts.
    6. If the target token is not in the top-K logprobs, assign a penalty value (-15.0).

    Args:
        interface (LLMInterface): The LLM interface to use. Must support `prompt`
                                  and `extract_response`.
        token_ids (List[int]): The sequence of token IDs to score.
        token_strings (List[str]): The sequence of token strings (decoded).

    Returns:
        List[float]: A list of natural log probabilities, one for each token.
                     The first token's probability is set to 0.0.
    """
    # Initialize logprobs list; the first token has no preceding context in this scope
    logprobs: List[float] = [0.0]

    # Iterate through the sequence starting from the second token
    for i in range(1, len(token_ids)):
        target_token_id = token_ids[i]
        target_token_str = token_strings[i]

        # Construct context from previous tokens
        # We join with empty string as BPE token strings usually include necessary spacing
        context_text = "".join(token_strings[:i])

        # Construct the message payload for the Chat API
        messages = [{"role": "user", "content": context_text}]

        try:
            # Request generation of exactly 1 token with logprobs enabled
            # We set temperature to 0.0 for deterministic behavior
            response = interface.prompt(
                messages=messages,
                max_tokens=1,
                temperature=0.0,
                logprobs=True
            )

            # Extract the generated text and the logprobs data structure
            _, response_logprobs = interface.extract_response(response)

            # Handle cases where logprobs are missing or empty
            if not response_logprobs:
                logger.warning(f"No logprobs returned for token index {i}. Assigning penalty.")
                logprobs.append(-15.0) # Penalty for missing data
                continue

            # The API returns a list of generated tokens. We only asked for 1.
            # Get the data for the first (and only) generated token position.
            gen_token_data = response_logprobs[0]

            # Initialize target logprob with a penalty value (representing very low probability)
            # -15.0 corresponds to exp(-15) approx 3e-7
            target_logprob = -15.0
            match_found = False

            # Helper to check if a candidate matches the target
            def is_match(candidate_str: str) -> bool:
                # Method A: Re-encode and compare IDs (Most Rigorous)
                if hasattr(interface, 'encoding') and interface.encoding:
                    try:
                        # Encode candidate string to ID
                        cand_ids = interface.encoding.encode(candidate_str)
                        # Check if it matches the target ID
                        # Note: candidate might encode to multiple tokens if it's long,
                        # but here we expect single token generation.
                        if len(cand_ids) == 1 and cand_ids[0] == target_token_id:
                            return True
                    except Exception:
                        pass # Fallback to string comparison

                # Method A (variant): HuggingFace tokenizer, if one is attached
                if hasattr(interface, 'tokenizer') and interface.tokenizer:
                    try:
                        cand_ids = interface.tokenizer.encode(candidate_str, add_special_tokens=False)
                        if len(cand_ids) == 1 and cand_ids[0] == target_token_id:
                            return True
                    except Exception:
                        pass  # Fallback to string comparison

                # Method B: String comparison (Robust Fallback)
                # Compare stripped strings to ignore leading/trailing whitespace differences
                # which are common artifacts in tokenization (e.g. " world" vs "world")
                return candidate_str.strip() == target_token_str.strip()

            # 1. Check if the top generated token is our target
            if is_match(gen_token_data['token']):
                target_logprob = gen_token_data['logprob']
                match_found = True

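            # 2. If not, check the top_logprobs list for alternatives.
            #    This list is only populated if the request asks for alternatives
            #    (e.g. OpenAI's `top_logprobs` parameter), depending on how
            #    interface.prompt builds the request; otherwise only the greedy
            #    token can match and every other target receives the penalty.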
            elif 'top_logprobs' in gen_token_data:
                for alt in gen_token_data['top_logprobs']:
                    if is_match(alt['token']):
                        target_logprob = alt['logprob']
                        match_found = True
                        break

            # Append the found logprob (or the penalty if no match found)
            logprobs.append(target_logprob)

        except Exception as e:
            logger.error(f"Error scoring token {i} ('{target_token_str}'): {e}")
            logprobs.append(-15.0) # Fallback penalty

    return logprobs
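
The matching step inside `score_sequence_iterative` can be isolated as a small pure function. This sketch uses only the whitespace-stripped string comparison (Method B), and the record layout (`token`, `logprob`, `top_logprobs`) is assumed to match what `interface.extract_response` returns; `find_target_logprob` is a hypothetical name.

```python
from typing import Any, Dict

PENALTY = -15.0  # same penalty value used in score_sequence_iterative


def find_target_logprob(gen_token: Dict[str, Any], target: str, penalty: float = PENALTY) -> float:
    """Return the logprob of `target` from one generated-token record, or `penalty`
    if it is neither the greedy token nor among the listed alternatives."""
    def is_match(candidate: str) -> bool:
        # Method B only: compare with leading/trailing whitespace removed
        return candidate.strip() == target.strip()

    if is_match(gen_token["token"]):
        return gen_token["logprob"]
    for alt in gen_token.get("top_logprobs", []):
        if is_match(alt["token"]):
            return alt["logprob"]
    return penalty


record = {
    "token": " profit",
    "logprob": -0.4,
    "top_logprobs": [
        {"token": " profit", "logprob": -0.4},
        {"token": " revenue", "logprob": -1.6},
    ],
}
print(find_target_logprob(record, "revenue"))  # -1.6, found among the alternatives
print(find_target_logprob(record, " margin"))  # -15.0, the penalty
```

Stripping whitespace conflates tokens such as `" in"` and `"in"`, which is why the original function tries token-ID comparison first when a tokenizer or encoding is available.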

# -------------------------------------------------------------------------------------------------------------------------------
# Task 16, Helper: Score Prompt Tokens
# -------------------------------------------------------------------------------------------------------------------------------
def score_prompt_tokens(
    interface: LLMInterface,
    text: str,
    token_ids: List[int],
    token_strings: List[str]
) -> List[float]:
    """
    Obtains the log-probability log(P(t_i | t_0...t_{i-1})) for each token in the text.

    This function dispatches the scoring request to the appropriate strategy based on
    the specific LLMInterface implementation. It enforces strict alignment between
    the locally computed token sequence and the API's response.

    Strategies:
    1. **Llama (Together AI)**: Uses the `completions` endpoint with `echo=True` and `logprobs=1`.
       This is efficient as it scores the entire prompt in one pass.
    2. **GPT-4o (OpenAI)**: Uses iterative "teacher forcing" via the chat API (implemented
       in `score_sequence_iterative`). This is slower but necessary as OpenAI does not
       support `echo=True` for chat models.
    3. **Claude (Anthropic)**: Returns a list of 0.0 values, since the API does not expose token log-probabilities.

    Args:
        interface (LLMInterface): The configured LLM interface.
        text (str): The full prompt text.
        token_ids (List[int]): The list of token IDs computed locally.
        token_strings (List[str]): The list of token strings computed locally.

    Returns:
        List[float]: A list of natural log probabilities, one for each token.

    Raises:
        RuntimeError: If the number of log-probabilities returned by the API does not
            match the number of local tokens (for Llama). This indicates a critical
            tokenizer mismatch.
    """
    # Strategy 1: Llama (Together AI) - Efficient `echo=True` scoring
    if isinstance(interface, LlamaInterface):
        # Verify that the interface has a HuggingFace tokenizer loaded.
        # LlamaInterface is expected to load one in __init__; check defensively.
        if not hasattr(interface, 'tokenizer') or interface.tokenizer is None:
            raise RuntimeError(
                "LlamaInterface must have a valid HuggingFace tokenizer loaded for scoring. "
                "Ensure HF_TOKEN is set and transformers is installed."
            )

        try:
            # Access the raw client to use completions endpoint
            # Note: LlamaInterface exposes .client (Together instance) and .model
            response = interface.client.completions.create(
                model=interface.model,
                prompt=text,
                max_tokens=0,       # We only want to score the prompt, not generate
                logprobs=1,         # Request token logprobs
                echo=True,          # Echo the prompt to get its scores
                temperature=0.0     # Deterministic
            )

            # Extract logprobs from response
            if hasattr(response.choices[0], 'logprobs') and response.choices[0].logprobs:
                token_logprobs = response.choices[0].logprobs.token_logprobs

                # Strict Length Validation
                # We must ensure the API's tokenization matches our local token_ids exactly.
                if len(token_logprobs) != len(token_ids):
                    raise RuntimeError(
                        f"Token length mismatch for Llama scoring! "
                        f"Local tokenizer found {len(token_ids)} tokens, "
                        f"API returned {len(token_logprobs)} logprobs. "
                        "This indicates a divergence between the local HuggingFace tokenizer "
                        "and the remote API tokenizer. Scoring cannot proceed safely."
                    )

                # Replace None values (often the first token has None logprob) with 0.0
                return [lp if lp is not None else 0.0 for lp in token_logprobs]
            else:
                logger.warning("Llama response did not contain logprobs. Returning 0.0s.")
                return [0.0] * len(token_ids)

        except Exception as e:
            logger.error(f"Llama completion scoring failed: {e}")
            raise RuntimeError(f"Llama scoring failed: {e}") from e

    # Strategy 2: GPT-4o (Chat API) - Iterative Scoring
    if isinstance(interface, GPT4oInterface):
        logger.info("Using iterative scoring for GPT-4o (this may take time)...")
        # score_sequence_iterative is defined earlier in this cell (one API call per token)
        return score_sequence_iterative(interface, token_ids, token_strings)

    # Strategy 3: Claude - Unsupported
    if isinstance(interface, ClaudeInterface):
        logger.warning("Claude does not support logprobs. Returning 0.0s.")
        return [0.0] * len(token_ids)

    # Default Fallback for unknown interfaces
    logger.warning(f"Unknown interface type {type(interface)}. Returning 0.0s.")
    return [0.0] * len(token_ids)
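
Both strategies produce the same kind of list: one conditional log-probability per token, with position 0 fixed at 0.0. The toy sketch below uses a hypothetical bigram table as a stand-in for the LLM to show what that list represents and why teacher forcing is expensive: it needs one query per position (T - 1 for T tokens), whereas the `echo=True` path scores the whole prompt in one request. By the chain rule, the exponentiated sum is the joint probability of tokens 1..T-1 given token 0.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical toy "model": P(next | prefix) looked up from a bigram table keyed on
# the last prefix token. A real LLM conditions on the whole prefix.
BIGRAM: Dict[Tuple[str, str], float] = {
    ("net", "income"): 0.5,
    ("income", "rose"): 0.25,
}


def toy_conditional_prob(prefix: List[str], nxt: str) -> float:
    """Return P(nxt | prefix) under the toy table, with a tiny floor for unseen pairs."""
    return BIGRAM.get((prefix[-1], nxt), 1e-9)


def score_teacher_forced(tokens: List[str]) -> Tuple[List[float], int]:
    """Per-token natural-log probabilities plus the number of model queries made.
    Position 0 has no context and is fixed at 0.0, as in the functions above."""
    logprobs = [0.0]
    queries = 0
    for i in range(1, len(tokens)):
        queries += 1  # iterative scoring issues one request per position
        logprobs.append(math.log(toy_conditional_prob(tokens[:i], tokens[i])))
    return logprobs, queries


lps, n_queries = score_teacher_forced(["net", "income", "rose"])
print(n_queries)           # 2
print(math.exp(sum(lps)))  # chain rule: 0.5 * 0.25, approximately 0.125
```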

# -------------------------------------------------------------------------------------------------------------------------------
# Task 16, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def get_prompt_logprobs_task(
    dynamic_inputs: Dict[str, Dict[str, Any]],
    scorer_interface: LLMInterface
) -> Dict[str, List[float]]:
    """
    Orchestrates the retrieval of log-probabilities for all prepared prompts.

    Iterates through each example's tokenized prompt and queries the scorer LLM
    to obtain conditional log-probabilities.

    Args:
        dynamic_inputs (Dict): Output from Task 15, mapping example_id to
                               {'prompt_text', 'token_ids', 'token_strings', ...}.
        scorer_interface (LLMInterface): The interface to use for scoring.

    Returns:
        Dict[str, List[float]]: Mapping example_id -> list of logprobs (natural log).
    """
    logger.info("Starting dynamic scoring query...")

    results = {}
    total_examples = len(dynamic_inputs)
    processed = 0

    # Iterate through dynamic inputs and score prompts
    for example_id, data in dynamic_inputs.items():
        prompt_text = data["prompt_text"]
        token_ids = data["token_ids"]
        token_strings = data["token_strings"]

        # Score the prompt
        logprobs = score_prompt_tokens(scorer_interface, prompt_text, token_ids, token_strings)

        results[example_id] = logprobs
        processed += 1

        if processed % 10 == 0:
            logger.info(f"Scored {processed}/{total_examples} examples.")

    logger.info(f"Scoring complete for {len(results)} examples.")
    return results


In [None]:
# Task 17 – Compute Dynamic Self-Information

# ==============================================================================
# Task 17: Compute Dynamic Self-Information
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 17, Step 1: Convert probabilities to self-information
# -------------------------------------------------------------------------------------------------------------------------------
def convert_logprobs_to_self_info(logprobs: List[float]) -> List[float]:
    """
    Converts natural log probabilities to self-information in bits.

    Equation:
        s_dyn(t) = -log2(P(t)) = -ln(P(t)) / ln(2)

    Args:
        logprobs (List[float]): List of natural log probabilities.

    Returns:
        List[float]: List of self-information scores in bits.
    """
    ln_2 = math.log(2)
    # logprobs from Task 16 should be finite floats; None or NaN entries are
    # handled defensively as missing and assigned a fixed high self-information.

    s_dyn = []
    for lp in logprobs:
        if lp is None or math.isnan(lp):
            # Missing value: assign 20 bits (equivalent to P = 2**-20, about 9.5e-7)
            s_dyn.append(20.0)
        else:
            # A valid logprob is <= 0 (probability <= 1); clamp small positive
            # floating-point errors to 0
            clean_lp = min(0.0, lp)
            s_dyn.append(-clean_lp / ln_2)

    return s_dyn
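
A vectorized NumPy sketch of the same conversion, with a few worked values: ln(0.5) maps to 1 bit and ln(0.25) to 2 bits. The hypothetical `logprobs_to_bits` keeps the 20-bit fallback for NaN. Note that the -15.0 nat penalty from Task 16 converts to about 21.64 bits, slightly more than that fallback, so a token missing from the top-k list ends up looking more informative than one whose logprob is NaN.

```python
import math
from typing import Sequence

import numpy as np

LN2 = math.log(2)


def logprobs_to_bits(logprobs: Sequence[float], missing_bits: float = 20.0) -> np.ndarray:
    """Vectorized sketch of convert_logprobs_to_self_info.

    Natural-log probabilities are clamped to <= 0 (probability <= 1) and converted
    to bits via -ln(p) / ln(2); NaN entries receive `missing_bits`."""
    lp = np.asarray(logprobs, dtype=float)
    # np.maximum(-lp, 0.0) clamps positive float noise to 0 and propagates NaN
    bits = np.maximum(-lp, 0.0) / LN2
    return np.where(np.isnan(lp), missing_bits, bits)


lps = [math.log(0.5), math.log(0.25), -15.0, float("nan"), 1e-12]
print(np.round(logprobs_to_bits(lps), 2))
```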

# -------------------------------------------------------------------------------------------------------------------------------
# Task 17, Step 2: Validate dynamic scores
# -------------------------------------------------------------------------------------------------------------------------------
def validate_dynamic_scores(s_dyn: List[float], example_id: str) -> bool:
    """
    Validates the computed self-information scores.

    Checks for:
    - Non-negativity.
    - Finite values.

    Args:
        s_dyn (List[float]): Self-information scores.
        example_id (str): ID for logging.

    Returns:
        bool: True if valid.
    """
    if not s_dyn:
        logger.warning(f"Empty s_dyn for {example_id}")
        return False

    s_dyn_arr = np.array(s_dyn)

    if np.any(s_dyn_arr < 0):
        logger.error(f"Negative self-information detected for {example_id}. Min: {np.min(s_dyn_arr)}")
        return False

    if not np.all(np.isfinite(s_dyn_arr)):
        logger.error(f"Non-finite self-information detected for {example_id}.")
        return False

    return True

# -------------------------------------------------------------------------------------------------------------------------------
# Task 17, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def compute_dynamic_scores_task(
    logprobs_map: Dict[str, List[float]]
) -> Dict[str, List[float]]:
    """
    Orchestrates the computation of dynamic self-information scores.

    Args:
        logprobs_map (Dict[str, List[float]]): Mapping example_id -> logprobs.

    Returns:
        Dict[str, List[float]]: Mapping example_id -> s_dyn scores.
    """
    logger.info("Starting dynamic self-information computation...")

    s_dyn_map = {}

    for example_id, logprobs in logprobs_map.items():
        # Step 1: Convert
        s_dyn = convert_logprobs_to_self_info(logprobs)

        # Step 2: Validate
        if validate_dynamic_scores(s_dyn, example_id):
            s_dyn_map[example_id] = s_dyn
        else:
            # Validation failed (the error is logged above): substitute zeros so the
            # pipeline can continue. Downstream, these tokens will look uninformative.
            s_dyn_map[example_id] = [0.0] * len(logprobs)

    logger.info(f"Computed scores for {len(s_dyn_map)} examples.")
    return s_dyn_map


In [None]:
# Task 18 – Compute Relative Difference and Combined Score

# ==============================================================================
# Task 18: Compute Relative Difference and Combined Score
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 18, Step 1 & 2: Compute Combined Scores
# -------------------------------------------------------------------------------------------------------------------------------
def calculate_combined_scores(
    token_ids: List[int],
    s_dyn: List[float],
    s_stat_lookup: Dict[str, float], # Keys are strings in JSON
    delta_threshold: float = 0.1
) -> List[float]:
    """
    Computes the combined importance score C(t) for each token.

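    Equations:
        Delta = |s_dyn - s_stat| / s_stat
        C(t)  = (s_stat + s_dyn) / 2   if Delta <= delta_threshold (default 0.1)
              = s_dyn                  otherwise

    If s_stat is missing or zero for a token, Delta is treated as infinite and
    C(t) = s_dyn.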

    Args:
        token_ids (List[int]): List of token IDs.
        s_dyn (List[float]): List of dynamic self-information scores.
        s_stat_lookup (Dict[str, float]): Static self-information lookup table.
        delta_threshold (float): Threshold for relative difference (default 0.1).

    Returns:
        List[float]: List of combined scores C(t).
    """
    combined_scores = []

    for i, tid in enumerate(token_ids):
        # Retrieve static score
        # JSON keys are strings, so convert tid to str
        s_stat_val = s_stat_lookup.get(str(tid))
        s_dyn_val = s_dyn[i]

        # Calculate Delta
        if s_stat_val is None or s_stat_val == 0:
            # If static score is missing or zero, we cannot compute relative difference reliably.
            # We default to dynamic score (Delta = infinity).
            delta = float('inf')
        else:
            delta = abs(s_dyn_val - s_stat_val) / s_stat_val

        # Fusion Rule
        if delta <= delta_threshold:
            c_val = (s_stat_val + s_dyn_val) / 2.0
        else:
            c_val = s_dyn_val

        combined_scores.append(c_val)

    return combined_scores
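
The fusion rule can also be written with NumPy arrays. This sketch (hypothetical `fuse_scores`, same rule and default threshold as above) shows three cases: a token whose static and dynamic scores agree within 10% is averaged, one that disagrees keeps its dynamic score, and one without a static entry falls back to the dynamic score.

```python
from typing import Dict, List

import numpy as np


def fuse_scores(
    token_ids: List[int],
    s_dyn: List[float],
    s_stat_lookup: Dict[str, float],
    delta_threshold: float = 0.1,
) -> np.ndarray:
    """Vectorized restatement of the fusion rule in calculate_combined_scores."""
    dyn = np.asarray(s_dyn, dtype=float)
    # Missing static scores become NaN, so they can never pass the threshold test
    stat = np.array([s_stat_lookup.get(str(t), np.nan) for t in token_ids], dtype=float)
    with np.errstate(divide="ignore", invalid="ignore"):
        delta = np.abs(dyn - stat) / stat
        # A zero static score gives an infinite (or NaN) delta, which also fails
        agree = np.isfinite(delta) & (delta <= delta_threshold)
    return np.where(agree, (stat + dyn) / 2.0, dyn)


# Token 101 agrees (Delta = 0.05), token 102 disagrees (Delta = 0.4), 103 has no static score
print(fuse_scores([101, 102, 103], [10.5, 14.0, 7.0], {"101": 10.0, "102": 10.0}).tolist())
# [10.25, 14.0, 7.0]
```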

# -------------------------------------------------------------------------------------------------------------------------------
# Task 18, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def compute_combined_scores_task(
    dynamic_inputs: Dict[str, Dict[str, Any]],
    s_dyn_map: Dict[str, List[float]],
    s_stat_lookup: Dict[str, float],
    config: Dict[str, Any]
) -> Dict[str, List[float]]:
    """
    Orchestrates the computation of combined importance scores.

    Args:
        dynamic_inputs (Dict): Contains token_ids per example.
        s_dyn_map (Dict): Contains s_dyn scores per example.
        s_stat_lookup (Dict): Static scores lookup.
        config (Dict): Configuration containing delta threshold.

    Returns:
        Dict[str, List[float]]: Mapping example_id -> combined scores.
    """
    logger.info("Starting combined score computation...")

    threshold = config.get("hard_prompt_compression_config", {}).get("delta_relative_difference_threshold", 0.1)
    combined_scores_map = {}

    # Iterate through the examples
    for example_id, inputs in dynamic_inputs.items():
        token_ids = inputs["token_ids"]
        s_dyn = s_dyn_map.get(example_id)

        if not s_dyn:
            logger.warning(f"No dynamic scores found for {example_id}. Skipping.")
            continue

        if len(token_ids) != len(s_dyn):
            logger.error(f"Length mismatch for {example_id}: tokens={len(token_ids)}, s_dyn={len(s_dyn)}")

            # Truncate to minimum length to proceed safely
            min_len = min(len(token_ids), len(s_dyn))
            token_ids = token_ids[:min_len]
            s_dyn = s_dyn[:min_len]

        c_scores = calculate_combined_scores(token_ids, s_dyn, s_stat_lookup, threshold)
        combined_scores_map[example_id] = c_scores

    logger.info(f"Computed combined scores for {len(combined_scores_map)} examples.")
    return combined_scores_map


In [None]:
# Task 19 – Group Tokens into Phrases Using Dependency Parsing

# ==============================================================================
# Task 19: Group Tokens into Phrases Using Dependency Parsing
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 19, Step 1: Run dependency parser
# -------------------------------------------------------------------------------------------------------------------------------
def parse_prompt_text(
    text: str,
    spacy_model_name: str = "en_core_web_sm"
) -> List[Tuple[int, int, str]]:
    """
    Parses the prompt text using spaCy to identify syntactic phrase boundaries.

    This function implements the "dependency-based phrase grouping" required for
    hard prompt compression. Unlike simple noun chunking, this implementation
    extracts a broader set of syntactic units to enable more granular and
    comprehensive pruning.

    Extracted Phrase Types:
    1. **NP (Noun Phrases)**: Extracted via `doc.noun_chunks`. Represents entities
       and objects (e.g., "the total revenue").
    2. **VP (Verb Phrases)**: Identified by `VERB` or `AUX` tokens. We group the
       verb with its immediate auxiliaries and negations to form a coherent action
       unit (e.g., "did not increase").
    3. **PP (Prepositional Phrases)**: Identified by `ADP` (adposition) tokens.
       We group the preposition with its subtree (often an NP) to capture
       contextual modifiers (e.g., "in 2019").

    Args:
        text (str): The raw prompt text to parse.
        spacy_model_name (str): The name of the spaCy model to load.
            Defaults to "en_core_web_sm".

    Returns:
        List[Tuple[int, int, str]]: A list of tuples, each representing a phrase.
            Format: (char_start, char_end, label).
            The list is sorted by start position.

    Notes:
        - Overlapping phrases are possible (e.g., a PP containing an NP).
          Downstream logic in `map_tokens_to_phrases` handles assignment/partitioning.
        - Requires the specified spaCy model to be installed.
    """
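    # Load the spaCy model. The pipeline is cached on the function object, because
    # this function runs once per example and spacy.load() is expensive.
    nlp_cache: Dict[str, Any] = parse_prompt_text.__dict__.setdefault("_nlp_cache", {})
    nlp = nlp_cache.get(spacy_model_name)
    if nlp is None:
        try:
            nlp = spacy.load(spacy_model_name)
        except OSError:
            logger.warning(f"spaCy model '{spacy_model_name}' not found. Downloading...")
            from spacy.cli import download
            download(spacy_model_name)
            nlp = spacy.load(spacy_model_name)
        nlp_cache[spacy_model_name] = nlp

    # Process text. The tagger (pos_) and parser (dependencies, noun_chunks) must
    # remain enabled; unused components such as NER could be disabled for speed.
    doc = nlp(text)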

    phrases: List[Tuple[int, int, str]] = []

    # 1. Extract Noun Phrases (NP)
    # spaCy's noun_chunks iterator provides robust NP extraction
    for chunk in doc.noun_chunks:
        phrases.append((chunk.start_char, chunk.end_char, "NP"))

    # 2. Extract Verb Phrases (VP) and Prepositional Phrases (PP)
    # We iterate over tokens to find heads of these phrases
    for token in doc:
        # Verb Phrases: Group verb with auxiliaries/particles
        if token.pos_ in ["VERB", "AUX"]:
            # We want a compact VP, not the whole clause.
            # Strategy: Collect the verb and its immediate children that are aux/neg/prt
            vp_tokens = [token]
            for child in token.children:
                if child.dep_ in ["aux", "auxpass", "neg", "prt"]:
                    vp_tokens.append(child)

            # Determine span of this token group
            # Note: These might be non-contiguous in rare cases, but for pruning
            # we usually want contiguous spans. We take the min/max extent.
            min_i = min(t.i for t in vp_tokens)
            max_i = max(t.i for t in vp_tokens)

            # Create span from doc
            span = doc[min_i : max_i + 1]
            phrases.append((span.start_char, span.end_char, "VP"))

        # Prepositional Phrases: Group preposition with its subtree
        elif token.pos_ == "ADP":
            # A PP is rooted at the ADP. Its subtree usually includes the object.
            # We take the full subtree of the ADP as the PP span.
            span = list(token.subtree)
            if span:
                min_i = min(t.i for t in span)
                max_i = max(t.i for t in span)

                # Create span object to get char offsets
                pp_span = doc[min_i : max_i + 1]
                phrases.append((pp_span.start_char, pp_span.end_char, "PP"))

    # 3. Deduplication and Sorting
    # Identical (start, end) spans can appear with different labels. Keep the first
    # label seen: noun chunks are appended first, so NP wins over VP/PP; between VP
    # and PP, the one found at the earlier token wins. Labels are informational
    # only; map_tokens_to_phrases ignores them.
    unique_phrases: Dict[Tuple[int, int], str] = {}
    for start, end, label in phrases:
        unique_phrases.setdefault((start, end), label)

    # Convert back to list and sort by start position
    sorted_phrases = sorted(
        [(start, end, label) for (start, end), label in unique_phrases.items()],
        key=lambda x: (x[0], x[1])
    )

    return sorted_phrases
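
The deduplication step does not depend on spaCy, so it can be checked in isolation. This sketch feeds hand-written spans for "Revenue rose in 2019" (hypothetical parser output, not real spaCy results) through a `setdefault`-based version of step 3; the first label recorded for a span wins. `dedupe_and_sort` is a hypothetical name.

```python
from typing import Dict, List, Tuple

Span = Tuple[int, int, str]


def dedupe_and_sort(phrases: List[Span]) -> List[Span]:
    """Keep the first label seen for each (start, end) span, then sort by position."""
    unique: Dict[Tuple[int, int], str] = {}
    for start, end, label in phrases:
        unique.setdefault((start, end), label)
    return sorted(((s, e, lab) for (s, e), lab in unique.items()), key=lambda x: (x[0], x[1]))


# Hypothetical spans: NP "2019", VP "rose", PP "in 2019", and a duplicate of (16, 20)
raw = [(16, 20, "NP"), (8, 12, "VP"), (13, 20, "PP"), (16, 20, "PP")]
print(dedupe_and_sort(raw))
# [(8, 12, 'VP'), (13, 20, 'PP'), (16, 20, 'NP')]
```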

# -------------------------------------------------------------------------------------------------------------------------------
# Task 19, Step 2 & 3: Map tokens to phrases and resolve
# -------------------------------------------------------------------------------------------------------------------------------
def map_tokens_to_phrases(
    token_offsets: List[Tuple[int, int]],
    phrase_spans: List[Tuple[int, int, str]]
) -> List[List[int]]:
    """
    Maps LLM tokens to phrases based on character overlap.
    Ensures every token is assigned to exactly one phrase (partition).

    Args:
        token_offsets (List[Tuple[int, int]]): List of (start, end) char offsets for each token.
        phrase_spans (List[Tuple[int, int, str]]): List of (start, end, label) for candidate phrases.

    Returns:
        List[List[int]]: A list of phrases, where each phrase is a list of token indices.
    """
    num_tokens = len(token_offsets)
    token_to_phrase_idx = [-1] * num_tokens

    # 1. Assign tokens to phrases
    # Phrases arrive sorted by (start, end) and may overlap (e.g. a PP containing
    # an NP). Each token keeps the first phrase that overlaps it, so for nested or
    # overlapping spans the earliest-starting phrase (shortest on ties) wins.
    for p_idx, (p_start, p_end, _) in enumerate(phrase_spans):
        for t_idx, (t_start, t_end) in enumerate(token_offsets):
            # Any positive character overlap counts, including partial overlap:
            # max(p_start, t_start) < min(p_end, t_end)
            if max(p_start, t_start) < min(p_end, t_end):
                # Assign only if unclaimed; a later overlapping phrase never
                # overrides an earlier assignment.
                if token_to_phrase_idx[t_idx] == -1:
                    token_to_phrase_idx[t_idx] = p_idx

    # 2. Group by phrase index
    # We need to handle unassigned tokens (-1).
    # We will create new singleton phrases for them.
    # Map: internal_phrase_id -> list of token_indices
    # We use a dictionary to build groups
    groups: Dict[int, List[int]] = {}

    # Existing phrases
    for p_idx in range(len(phrase_spans)):
        groups[p_idx] = []

    # Singleton counter (start after existing phrases)
    next_singleton_id = len(phrase_spans)

    for t_idx, p_idx in enumerate(token_to_phrase_idx):
        if p_idx != -1:
            groups[p_idx].append(t_idx)
        else:
            # Create singleton phrase
            groups[next_singleton_id] = [t_idx]
            next_singleton_id += 1

    # 3. Convert to list of lists
    # Drop empty groups: phrases whose tokens were all claimed by an earlier
    # overlapping phrase, or that matched no token because of alignment issues.
    final_phrases = []

    # We want to preserve the order of phrases as they appear in the text.
    # We can sort groups by the index of their first token.
    sorted_group_keys = sorted(groups.keys(), key=lambda k: groups[k][0] if groups[k] else float('inf'))

    for k in sorted_group_keys:
        tokens = groups[k]
        if tokens:
            final_phrases.append(tokens)

    return final_phrases
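
To see the first-claim rule on concrete data, the sketch below restates `map_tokens_to_phrases` compactly (hypothetical `partition_tokens`) and runs it on four tokens of "Revenue rose in 2019" with hand-written spans. The NP "2019" sits inside the PP "in 2019", so it claims no token and disappears, and "Revenue", covered by no span, becomes a singleton.

```python
from typing import Dict, List, Tuple


def partition_tokens(
    token_offsets: List[Tuple[int, int]],
    phrase_spans: List[Tuple[int, int, str]],
) -> List[List[int]]:
    """Compact restatement of map_tokens_to_phrases: each token joins the first
    phrase (in sorted span order) it overlaps; unclaimed tokens become singletons."""
    owner = [-1] * len(token_offsets)
    for p_idx, (p_start, p_end, _) in enumerate(phrase_spans):
        for t_idx, (t_start, t_end) in enumerate(token_offsets):
            if owner[t_idx] == -1 and max(p_start, t_start) < min(p_end, t_end):
                owner[t_idx] = p_idx
    groups: Dict[int, List[int]] = {}
    next_id = len(phrase_spans)
    for t_idx, p_idx in enumerate(owner):
        if p_idx == -1:
            p_idx, next_id = next_id, next_id + 1  # new singleton phrase
        groups.setdefault(p_idx, []).append(t_idx)
    # Order groups by their first token; phrases that claimed no token never appear
    return sorted(groups.values(), key=lambda g: g[0])


offsets = [(0, 7), (7, 12), (12, 15), (15, 20)]  # "Revenue", " rose", " in", " 2019"
spans = [(8, 12, "VP"), (13, 20, "PP"), (16, 20, "NP")]
print(partition_tokens(offsets, spans))
# [[0], [1], [2, 3]]
```

The result is a partition: every token index appears exactly once, and the earliest-starting overlapping span wins rather than the smallest one.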

# -------------------------------------------------------------------------------------------------------------------------------
# Task 19, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def group_tokens_into_phrases_task(
    dynamic_inputs: Dict[str, Dict[str, Any]],
    config: Dict[str, Any]
) -> Dict[str, List[List[int]]]:
    """
    Orchestrates the grouping of tokens into phrases for all examples.

    Args:
        dynamic_inputs (Dict): Contains 'prompt_text' and 'offsets' per example.
        config (Dict): Configuration containing parser settings.

    Returns:
        Dict[str, List[List[int]]]: Mapping example_id -> List of phrases (each a list of token indices).
    """
    logger.info("Starting phrase grouping...")

    spacy_model = config.get("hard_prompt_compression_config", {}).get("phrase_grouping", {}).get("parser_library", "en_core_web_sm")
    # Strip a leading "spacy_" prefix if the config value carries one
    if spacy_model.startswith("spacy_"):
        spacy_model = spacy_model[len("spacy_"):]

    results = {}

    for example_id, data in dynamic_inputs.items():
        text = data["prompt_text"]
        offsets = data["offsets"]

        # Step 1: Parse
        phrase_spans = parse_prompt_text(text, spacy_model)

        # Step 2 & 3: Map
        phrases = map_tokens_to_phrases(offsets, phrase_spans)

        results[example_id] = phrases

    logger.info(f"Phrase grouping complete for {len(results)} examples.")
    return results


# Task 20 – Compute Phrase-Level Importance Scores

# ==============================================================================
# Task 20: Compute Phrase-Level Importance Scores
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 20, Step 1 & 2: Aggregate scores and count tokens
# -------------------------------------------------------------------------------------------------------------------------------
def aggregate_phrase_scores(
    combined_scores: List[float],
    phrases: List[List[int]],
    aggregation_method: str = "mean"
) -> Tuple[List[float], List[int]]:
    """
    Computes importance scores and token counts for each phrase.

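    Equation (aggregation_method="mean"):
        C(pi_k) = (1 / |pi_k|) * sum_{t in pi_k} C(t)
    For aggregation_method="sum", the 1 / |pi_k| normalization is dropped.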

    Args:
        combined_scores (List[float]): Token-level combined scores.
        phrases (List[List[int]]): List of phrases, where each phrase is a list of token indices.
        aggregation_method (str): "mean" or "sum". Defaults to "mean".

    Returns:
        Tuple[List[float], List[int]]:
            - List of phrase importance scores.
            - List of phrase token counts.
    """
    phrase_scores = []
    phrase_counts = []

    # Iterate through phrases
    for phrase_tokens in phrases:
        count = len(phrase_tokens)
        phrase_counts.append(count)

        if count == 0:
            # Should not happen with valid grouping, but handle safely
            phrase_scores.append(0.0)
            continue

        # Gather scores for tokens in this phrase
        # Ensure indices are within bounds
        scores = []
        for tid in phrase_tokens:
            if 0 <= tid < len(combined_scores):
                scores.append(combined_scores[tid])
            else:
                logger.warning(f"Token index {tid} out of bounds for scores list of length {len(combined_scores)}")
                scores.append(0.0)

        if not scores:
            phrase_scores.append(0.0)
            continue

        if aggregation_method == "sum":
            phrase_scores.append(sum(scores))
        else: # mean
            phrase_scores.append(sum(scores) / len(scores))

    return phrase_scores, phrase_counts
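
# --- Illustrative sketch (hypothetical helper, not part of the pipeline) ---
# The two aggregation modes accepted by aggregate_phrase_scores, side by side.
# "mean" computes C(pi_k) = (1/|pi_k|) * sum_{t in pi_k} C(t) and is
# length-neutral; "sum" grows with phrase length, so it tends to rank longer
# phrases higher when they later compete for the token budget.
# The scores used in the usage note are made-up toy values.
from statistics import fmean
from typing import Dict, List


def sketch_phrase_aggregates(token_scores: List[float], phrase: List[int]) -> Dict[str, float]:
    """Return both the mean and the sum aggregate for one phrase (a list of token indices)."""
    vals = [token_scores[t] for t in phrase]
    return {"mean": fmean(vals), "sum": sum(vals)}


# Usage (toy numbers): a 1-token phrase scoring 2.0 versus a 4-token phrase whose
# tokens each score 1.0. Under "mean" the short phrase ranks higher (2.0 vs 1.0);
# under "sum" the long phrase does (4.0 vs 2.0).
# toy_scores = [2.0, 1.0, 1.0, 1.0, 1.0]
# sketch_phrase_aggregates(toy_scores, [0])           # {'mean': 2.0, 'sum': 2.0}
# sketch_phrase_aggregates(toy_scores, [1, 2, 3, 4])  # {'mean': 1.0, 'sum': 4.0}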

# -------------------------------------------------------------------------------------------------------------------------------
# Task 20, Step 3: Validate phrase scores
# -------------------------------------------------------------------------------------------------------------------------------
def validate_phrase_scores(phrase_scores: List[float], example_id: str) -> bool:
    """
    Validates computed phrase scores.

    Args:
        phrase_scores (List[float]): List of scores.
        example_id (str): ID for logging.

    Returns:
        bool: True if valid.
    """
    for i, score in enumerate(phrase_scores):
        if math.isnan(score) or math.isinf(score):
            logger.error(f"Invalid phrase score at index {i} for {example_id}: {score}")
            return False
        if score < 0:
            # Combined token scores are expected to be non-negative (surprisal from
            # log-probabilities and a non-negative statistical score), so a negative
            # phrase score indicates an upstream error; reject the example.
            logger.error(f"Negative phrase score at index {i} for {example_id}: {score}")
            return False

    return True

# -------------------------------------------------------------------------------------------------------------------------------
# Task 20, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def compute_phrase_scores_task(
    combined_scores_map: Dict[str, List[float]],
    phrases_map: Dict[str, List[List[int]]],
    config: Dict[str, Any]
) -> Dict[str, Dict[str, List[Any]]]:
    """
    Orchestrates the computation of phrase-level scores for all examples.

    Args:
        combined_scores_map (Dict): Mapping example_id -> token scores.
        phrases_map (Dict): Mapping example_id -> phrases.
        config (Dict): Configuration containing aggregation method.

    Returns:
        Dict[str, Dict]: Mapping example_id -> {'scores': List[float], 'counts': List[int]}.
    """
    logger.info("Starting phrase score computation...")

    method = config.get("hard_prompt_compression_config", {}).get("phrase_score_aggregation", {}).get("chosen_method", "mean")
    results = {}

    # Iterate through examples
    for example_id, phrases in phrases_map.items():
        combined_scores = combined_scores_map.get(example_id)

        if not combined_scores:
            logger.warning(f"No combined scores for {example_id}. Skipping.")
            continue

        # Compute aggregate phrase scores
        scores, counts = aggregate_phrase_scores(combined_scores, phrases, method)

        if validate_phrase_scores(scores, example_id):
            results[example_id] = {
                "scores": scores,
                "counts": counts
            }

    logger.info(f"Computed phrase scores for {len(results)} examples.")
    return results


# Task 21 – Prune Phrases to Enforce Token Budget

# ==============================================================================
# Task 21: Prune Phrases to Enforce Token Budget
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 21, Step 1 & 2: Sort and Select Phrases
# -------------------------------------------------------------------------------------------------------------------------------
def select_phrases_knapsack(
    phrase_scores: List[float],
    phrase_counts: List[int],
    budget: int
) -> List[int]:
    """
    Selects phrases to retain based on importance scores and a token budget.

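    Uses a greedy strategy: sort phrases by score (descending) and keep each one
    that still fits in the remaining budget. This is a heuristic for the 0/1
    knapsack problem (value = score, weight = token count) and is not guaranteed
    to find the highest-scoring subset.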

    Args:
        phrase_scores (List[float]): Importance scores for each phrase.
        phrase_counts (List[int]): Token count for each phrase.
        budget (int): Maximum allowed tokens.

    Returns:
        List[int]: Indices of phrases to retain, sorted by original position.
    """
    # Build (index, score, count) triples
    candidates = [
        (i, score, count)
        for i, (score, count) in enumerate(zip(phrase_scores, phrase_counts))
    ]

    # Sort by score descending; break ties by original index ascending so that
    # earlier context is preferred when scores are equal
    candidates.sort(key=lambda x: (-x[1], x[0]))

    selected_indices = []
    current_tokens = 0

    for idx, score, count in candidates:
        if current_tokens + count <= budget:
            selected_indices.append(idx)
            current_tokens += count

    # Sort indices back to original order to preserve narrative flow
    selected_indices.sort()

    return selected_indices
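
# --- Illustrative sketch (hypothetical helper, not part of the pipeline) ---
# select_phrases_knapsack is a greedy heuristic for the 0/1 knapsack problem,
# so it can miss the highest-scoring subset of phrases. For small inputs, the
# exact dynamic program below gives a reference answer to compare against.
# It runs in O(len(phrase_scores) * budget) time and memory, so it is only
# practical for modest budgets.
from typing import List


def sketch_select_phrases_exact(
    phrase_scores: List[float],
    phrase_counts: List[int],
    budget: int
) -> List[int]:
    """Exact 0/1 knapsack: maximise total score subject to total tokens <= budget."""
    if budget < 0:
        return []
    n_items = len(phrase_scores)
    # best[i][b] = best total score using the first i phrases within b tokens
    best = [[0.0] * (budget + 1) for _ in range(n_items + 1)]
    for i in range(1, n_items + 1):
        weight, value = phrase_counts[i - 1], phrase_scores[i - 1]
        for b in range(budget + 1):
            best[i][b] = best[i - 1][b]  # option 1: skip phrase i-1
            if weight <= b and best[i - 1][b - weight] + value > best[i][b]:
                best[i][b] = best[i - 1][b - weight] + value  # option 2: keep it
    # Backtrack: a phrase was kept wherever including it changed the optimum
    chosen, b = [], budget
    for i in range(n_items, 0, -1):
        if best[i][b] != best[i - 1][b]:
            chosen.append(i - 1)
            b -= phrase_counts[i - 1]
    # Return indices in original order, matching select_phrases_knapsack
    return sorted(chosen)


# Usage (toy numbers): scores [3.0, 2.0, 2.0], counts [3, 2, 2], budget 4.
# Greedy keeps only phrase 0 (total score 3.0, nothing else fits), while the
# exact DP keeps phrases 1 and 2 (total score 4.0).
# sketch_select_phrases_exact([3.0, 2.0, 2.0], [3, 2, 2], 4)  # [1, 2]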

# -------------------------------------------------------------------------------------------------------------------------------
# Task 21, Step 3: Reconstruct compressed prompt
# -------------------------------------------------------------------------------------------------------------------------------
def reconstruct_prompt_text(
    selected_phrase_indices: List[int],
    all_phrases: List[List[int]],
    all_token_strings: List[str]
) -> str:
    """
    Reconstructs the prompt text from the selected phrases.

    Args:
        selected_phrase_indices (List[int]): Indices of phrases to keep.
        all_phrases (List[List[int]]): List of token indices for each phrase.
        all_token_strings (List[str]): List of all token strings in the original prompt.

    Returns:
        str: The compressed prompt text.
    """
    compressed_tokens = []

    for p_idx in selected_phrase_indices:
        # Get token indices for this phrase
        token_indices = all_phrases[p_idx]

        # Get strings
        # Note: We assume token_strings are decoded pieces (e.g. " The", " cat").
        # Simply joining them should reconstruct the text segment.
        for t_idx in token_indices:
            if 0 <= t_idx < len(all_token_strings):
                compressed_tokens.append(all_token_strings[t_idx])

    # Join tokens
    # BPE tokens usually reconstruct by simple concatenation
    return "".join(compressed_tokens)

# -------------------------------------------------------------------------------------------------------------------------------
# Task 21, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def prune_prompt_task(
    dynamic_inputs: Dict[str, Dict[str, Any]],
    phrase_scores_data: Dict[str, Dict[str, List[Any]]],
    phrases_map: Dict[str, List[List[int]]],
    config: Dict[str, Any]
) -> Dict[str, Dict[str, Any]]:
    """
    Orchestrates the pruning of prompts for all examples.

    Args:
        dynamic_inputs (Dict): Contains 'token_strings' per example.
        phrase_scores_data (Dict): Contains 'scores', 'counts' per example.
        phrases_map (Dict): Contains list of phrases (token indices) per example.
        config (Dict): Configuration containing token budget.

    Returns:
        Dict[str, Dict]: Mapping example_id -> {
            'compressed_text': str,
            'original_tokens': int,
            'compressed_tokens': int,
            'compression_ratio': float
        }
    """
    logger.info("Starting prompt pruning...")
    budget = config.get("hard_prompt_compression_config", {}).get("prompt_token_budget", 1500)

    results = {}

    for example_id, p_data in phrase_scores_data.items():
        if example_id not in dynamic_inputs or example_id not in phrases_map:
            continue

        scores = p_data["scores"]
        counts = p_data["counts"]
        phrases = phrases_map[example_id]
        token_strings = dynamic_inputs[example_id]["token_strings"]

        # Step 1 & 2: Select
        selected_indices = select_phrases_knapsack(scores, counts, budget)

        # Step 3: Reconstruct
        compressed_text = reconstruct_prompt_text(selected_indices, phrases, token_strings)

        # Metrics
        original_len = sum(counts)
        compressed_len = sum(counts[i] for i in selected_indices)
        ratio = original_len / compressed_len if compressed_len > 0 else 1.0

        results[example_id] = {
            "compressed_text": compressed_text,
            "original_tokens": original_len,
            "compressed_tokens": compressed_len,
            "compression_ratio": ratio
        }

    logger.info(f"Pruning complete for {len(results)} examples.")
    return results


# Task 22 – Extract N-grams from Passages and Compute Frequencies

# ==============================================================================
# Task 22: Extract N-grams from Passages and Compute Frequencies
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 22, Step 1 & 2: Extract and Count N-grams
# -------------------------------------------------------------------------------------------------------------------------------
def count_ngrams_in_passages(
    passages_list: List[str],
    n: int = 2,
    encoding_name: str = "cl100k_base"
) -> CounterType[Tuple[int, ...]]:
    """
    Extracts n-grams from a list of passage texts and computes their frequencies.

    Args:
        passages_list (List[str]): List of passage texts.
        n (int): N-gram length (default 2).
        encoding_name (str): Tokenizer encoding name.

    Returns:
        Counter[Tuple[int, ...]]: Mapping of n-gram (tuple of token IDs) to count.
    """
    logger.info(f"Counting {n}-grams in {len(passages_list)} passages...")

    try:
        encoding = tiktoken.get_encoding(encoding_name)
    except Exception as e:
        logger.error(f"Failed to load encoding '{encoding_name}': {e}")
        return Counter()

    ngram_counts: CounterType[Tuple[int, ...]] = Counter()

    for text in passages_list:
        if not text:
            continue

        # Tokenize
        tokens = encoding.encode(text, disallowed_special=())

        if len(tokens) < n:
            continue

        # Extract n-grams
        # Sliding window
        for i in range(len(tokens) - n + 1):
            ngram = tuple(tokens[i : i + n])
            ngram_counts[ngram] += 1

    return ngram_counts
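
# --- Illustrative sketch (hypothetical helper, not part of the pipeline) ---
# The sliding-window count above, expressed with zip over n shifted views of
# the token list. A sequence of m tokens yields m - n + 1 overlapping n-grams
# (none when m < n). Plain integer lists stand in for tiktoken token IDs here
# so the example runs without loading an encoding.
from collections import Counter
from typing import Counter as CounterType, List, Sequence, Tuple


def sketch_count_ngrams(
    token_lists: Sequence[List[int]],
    n: int = 2
) -> CounterType[Tuple[int, ...]]:
    counts: CounterType[Tuple[int, ...]] = Counter()
    for tokens in token_lists:
        # zip stops at the shortest view, giving exactly m - n + 1 windows
        counts.update(zip(*(tokens[i:] for i in range(n))))
    return counts


# Usage (toy IDs): [1, 2, 1, 2] contains the bigrams (1, 2), (2, 1), (1, 2).
# sketch_count_ngrams([[1, 2, 1, 2]]).most_common(1)  # [((1, 2), 2)]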

# -------------------------------------------------------------------------------------------------------------------------------
# Task 22, Step 3: Select top-K n-grams
# -------------------------------------------------------------------------------------------------------------------------------
def select_top_k_ngrams(
    ngram_counts: CounterType[Tuple[int, ...]],
    k: int = 100
) -> List[Tuple[Tuple[int, ...], int]]:
    """
    Selects the top-K most frequent n-grams.

    Args:
        ngram_counts (Counter): N-gram counts.
        k (int): Number of top n-grams to select.

    Returns:
        List[Tuple[Tuple[int, ...], int]]: List of (ngram, count) sorted by count descending.
    """
    return ngram_counts.most_common(k)
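
# --- Illustrative sketch (hypothetical helper, not part of the pipeline) ---
# Counter.most_common orders n-grams with equal counts by first insertion, so
# when counts tie at the K-th position the selected dictionary depends on the
# order in which passages were scanned. Ranking on (-count, ngram) instead is
# one way to make the selection independent of input order; this is an
# alternative, not what select_top_k_ngrams does.
import heapq
from typing import Counter as CounterType, List, Tuple


def sketch_top_k_deterministic(
    ngram_counts: CounterType[Tuple[int, ...]],
    k: int = 100
) -> List[Tuple[Tuple[int, ...], int]]:
    # nsmallest on (-count, ngram): highest count first, then smallest token IDs
    return heapq.nsmallest(k, ngram_counts.items(), key=lambda kv: (-kv[1], kv[0]))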

# -------------------------------------------------------------------------------------------------------------------------------
# Task 22, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def extract_top_ngrams_task(
    tatqa_df: pd.DataFrame,
    finqa_df: pd.DataFrame,
    config: Dict[str, Any]
) -> List[Tuple[Tuple[int, ...], int]]:
    """
    Orchestrates the extraction of top-K n-grams from all passages in both datasets.

    Args:
        tatqa_df (pd.DataFrame): TAT-QA DataFrame.
        finqa_df (pd.DataFrame): Fin-QA DataFrame.
        config (Dict): Configuration containing n-gram settings.

    Returns:
        List[Tuple[Tuple[int, ...], int]]: Top-K n-grams and their counts.
    """
    logger.info("Starting n-gram extraction...")

    ngram_config = config.get("ngram_abbreviation_config", {})
    n = ngram_config.get("ngram_size_G", {}).get("best_performing_value", 2)
    k = ngram_config.get("dictionary_size_K", {}).get("default", 100)
    # Use the same tokenizer as Task 24 so n-gram token IDs match at abbreviation time
    encoding_name = config.get("offline_corpus_config", {}).get("tokenization_scheme", {}).get("name", "cl100k_base")

    # Collect all passage texts from both datasets (TAT-QA first, then Fin-QA)
    all_passages = []
    for df in (tatqa_df, finqa_df):
        for passages in df["passages"]:
            if isinstance(passages, list):
                for p in passages:
                    if isinstance(p, dict):
                        all_passages.append(p.get("text", ""))

    # Count
    counts = count_ngrams_in_passages(all_passages, n=n, encoding_name=encoding_name)

    # Select Top-K
    top_k = select_top_k_ngrams(counts, k=k)

    logger.info(f"Extracted top {len(top_k)} {n}-grams.")
    return top_k


# Task 23 – Construct Abbreviation Dictionary

# ==============================================================================
# Task 23: Construct Abbreviation Dictionary
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 23, Step 1 & 2: Assign placeholders and build dictionary
# -------------------------------------------------------------------------------------------------------------------------------
def build_abbreviation_maps(
    top_k_ngrams: List[Tuple[Tuple[int, ...], int]]
) -> Tuple[Dict[Tuple[int, ...], str], Dict[str, Tuple[int, ...]]]:
    """
    Constructs bidirectional mappings between n-grams and placeholder tokens.

    Placeholder format: "[PH{i:03d}]" (e.g., [PH001], [PH002], ...)

    Args:
        top_k_ngrams (List[Tuple[Tuple[int, ...], int]]): List of (ngram, count) tuples.

    Returns:
        Tuple[Dict, Dict]:
            - ngram_to_placeholder: Mapping from n-gram tuple to placeholder string.
            - placeholder_to_ngram: Mapping from placeholder string to n-gram tuple.
    """
    ngram_to_ph = {}
    ph_to_ngram = {}

    for i, (ngram, count) in enumerate(top_k_ngrams):
        # 1-based index for readability
        placeholder = f"[PH{i+1:03d}]"

        ngram_to_ph[ngram] = placeholder
        ph_to_ngram[placeholder] = ngram

    return ngram_to_ph, ph_to_ngram

# -------------------------------------------------------------------------------------------------------------------------------
# Task 23, Step 3: Verify uniqueness
# -------------------------------------------------------------------------------------------------------------------------------
def verify_abbreviation_maps(
    ngram_to_ph: Dict[Tuple[int, ...], str],
    ph_to_ngram: Dict[str, Tuple[int, ...]]
) -> bool:
    """
    Verifies that the abbreviation mappings are bijective and unique.

    Args:
        ngram_to_ph (Dict): N-gram to placeholder map.
        ph_to_ngram (Dict): Placeholder to n-gram map.

    Returns:
        bool: True if valid.
    """
    if len(ngram_to_ph) != len(ph_to_ngram):
        logger.error("Mismatch in dictionary lengths.")
        return False

    # Check round-trip
    for ngram, ph in ngram_to_ph.items():
        if ph_to_ngram.get(ph) != ngram:
            logger.error(f"Mapping mismatch for {ph}")
            return False

    return True

# -------------------------------------------------------------------------------------------------------------------------------
# Task 23, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def construct_abbreviation_dict_task(
    top_k_ngrams: List[Tuple[Tuple[int, ...], int]]
) -> Dict[str, Any]:
    """
    Orchestrates the construction of the abbreviation dictionary.

    Args:
        top_k_ngrams (List): Top-K n-grams from Task 22.

    Returns:
        Dict[str, Any]: Dictionary containing:
            - 'ngram_to_ph': Forward map.
            - 'ph_to_ngram': Reverse map.
            - 'metadata': Count of entries.
    """
    logger.info("Constructing abbreviation dictionary...")

    # Build abbreviation maps
    ngram_to_ph, ph_to_ngram = build_abbreviation_maps(top_k_ngrams)

    if verify_abbreviation_maps(ngram_to_ph, ph_to_ngram):
        logger.info(f"Successfully constructed dictionary with {len(ngram_to_ph)} entries.")
    else:
        logger.error("Dictionary construction failed validation.")
        # The maps are still returned so downstream tasks can proceed; the failure is logged above.

    return {
        "ngram_to_ph": ngram_to_ph,
        "ph_to_ngram": ph_to_ngram,
        "metadata": {"count": len(ngram_to_ph)}
    }


# Task 24 – Apply Abbreviation to Passages

# ==============================================================================
# Task 24: Apply Abbreviation to Passages
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 24, Step 1 & 2: Apply Abbreviation
# -------------------------------------------------------------------------------------------------------------------------------
def abbreviate_text(
    text: str,
    active_ngrams: Set[Tuple[int, ...]],
    ngram_to_ph: Dict[Tuple[int, ...], str],
    n: int,
    encoding: Any
) -> str:
    """
    Applies n-gram abbreviation to a text string using a greedy left-to-right strategy.

    This function tokenizes the input text, scans for occurrences of active n-grams,
    and replaces them with their corresponding placeholder tokens. Non-matching tokens
    are decoded back to strings. The result is a single string containing a mix of
    original text and placeholders.

    Args:
        text (str): The original passage text.
        active_ngrams (Set[Tuple[int, ...]]): A set of n-gram tuples (token IDs) to be replaced.
        ngram_to_ph (Dict[Tuple[int, ...], str]): Mapping from n-gram tuple to placeholder string.
        n (int): The length of the n-grams (e.g., 2 for bigrams).
        encoding (tiktoken.Encoding): The tokenizer used for encoding/decoding.

    Returns:
        str: The abbreviated text string.
    """
    if not text:
        return ""

    # Tokenize the text
    # disallowed_special=() encodes special-token strings (e.g. "<|endoftext|>") as plain text instead of raising
    tokens = encoding.encode(text, disallowed_special=())

    output_parts: List[str] = []
    # Consecutive non-matching token IDs are buffered and decoded together, so
    # multi-byte UTF-8 characters that span several tokens are not corrupted
    # (decoding tokens one at a time would turn partial bytes into U+FFFD).
    pending: List[int] = []
    i = 0
    num_tokens = len(tokens)

    while i < num_tokens:
        # Check whether an active n-gram starts at this position
        candidate = tuple(tokens[i : i + n]) if i <= num_tokens - n else None

        if candidate is not None and candidate in active_ngrams:
            # Flush buffered plain tokens before emitting the placeholder
            if pending:
                output_parts.append(encoding.decode(pending))
                pending = []
            output_parts.append(ngram_to_ph[candidate])
            # Skip the n tokens covered by the placeholder
            i += n
        else:
            # No match: buffer the current token and advance by one
            pending.append(tokens[i])
            i += 1

    # Flush any remaining plain tokens
    if pending:
        output_parts.append(encoding.decode(pending))

    # Join all parts to form the abbreviated string
    return "".join(output_parts)
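
# --- Illustrative sketch (hypothetical helpers, not part of the pipeline) ---
# The greedy left-to-right scan of abbreviate_text, applied to token IDs, plus
# its inverse. Because matches are consumed greedily, overlapping occurrences
# are not all replaced: with the active bigram (7, 7), the sequence [7, 7, 7]
# becomes ["[PH001]", 7]. sketch_expand undoes the mapping, which is the
# ID-level counterpart of the text-level check in verify_reversibility.
from typing import Dict, List, Tuple, Union

SketchItem = Union[int, str]


def sketch_abbreviate_ids(
    tokens: List[int],
    ngram_to_ph: Dict[Tuple[int, ...], str],
    n: int
) -> List[SketchItem]:
    out: List[SketchItem] = []
    i = 0
    while i < len(tokens):
        window = tuple(tokens[i : i + n])
        if len(window) == n and window in ngram_to_ph:
            out.append(ngram_to_ph[window])  # replace the whole n-gram
            i += n
        else:
            out.append(tokens[i])  # keep the unmatched token
            i += 1
    return out


def sketch_expand(items: List[SketchItem], ph_to_ngram: Dict[str, Tuple[int, ...]]) -> List[int]:
    expanded: List[int] = []
    for item in items:
        if isinstance(item, str):
            expanded.extend(ph_to_ngram[item])  # placeholder back to its n-gram
        else:
            expanded.append(item)
    return expanded


# Usage (toy IDs):
# sketch_abbreviate_ids([1, 7, 7, 2, 7], {(7, 7): "[PH001]"}, 2)  # [1, '[PH001]', 2, 7]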

# -------------------------------------------------------------------------------------------------------------------------------
# Task 24, Step 3: Verify reversibility
# -------------------------------------------------------------------------------------------------------------------------------
def verify_reversibility(
    original_text: str,
    abbreviated_text: str,
    ph_to_ngram: Dict[str, Tuple[int, ...]],
    encoding: Any
) -> bool:
    """
    Verifies that the abbreviated text can be reconstructed to match the original text.

    This function attempts to reverse the abbreviation process by replacing placeholders
    with their original n-gram text. It compares the reconstructed text with the original
    text, normalizing whitespace to account for minor tokenization artifacts.

    Args:
        original_text (str): The original passage text.
        abbreviated_text (str): The text after abbreviation.
        ph_to_ngram (Dict[str, Tuple[int, ...]]): Mapping from placeholder string to n-gram tuple.
        encoding (tiktoken.Encoding): The tokenizer used for decoding n-grams.

    Returns:
        bool: True if the reconstructed text matches the original (normalized), False otherwise.
    """
    reconstructed_text = abbreviated_text

    # Iterate through placeholders and replace them
    # Note: We assume placeholders are unique strings like "[PH001]"
    for ph, ngram in ph_to_ngram.items():
        if ph in reconstructed_text:
            # Decode the original n-gram tokens to text
            ngram_text = encoding.decode(list(ngram))
            reconstructed_text = reconstructed_text.replace(ph, ngram_text)

    # Normalize whitespace for comparison: collapse runs of spaces, tabs and
    # newlines so that spacing differences alone do not cause a mismatch
    norm_orig = " ".join(original_text.split())
    norm_recon = " ".join(reconstructed_text.split())

    return norm_orig == norm_recon
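
# --- Illustrative sketch (hypothetical helper, not part of the pipeline) ---
# A single-pass alternative to the placeholder loop in verify_reversibility.
# One regex scan replaces each placeholder where it occurs, never re-scans text
# that a previous replacement inserted, and skips dictionary entries that do
# not appear. The pattern assumes the "[PH001]" naming from Task 23, and
# ph_to_text is assumed to hold already-decoded n-gram strings, for example
# built with encoding.decode(list(ngram)) for each entry of ph_to_ngram.
import re
from typing import Dict

SKETCH_PH_PATTERN = re.compile(r"\[PH\d{3,}\]")


def sketch_expand_placeholders(text: str, ph_to_text: Dict[str, str]) -> str:
    # Unknown placeholders are left unchanged
    return SKETCH_PH_PATTERN.sub(lambda m: ph_to_text.get(m.group(0), m.group(0)), text)


# Usage (toy mapping):
# sketch_expand_placeholders("rest[PH001] year", {"[PH001]": " of the"})  # 'rest of the year'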

# -------------------------------------------------------------------------------------------------------------------------------
# Task 24, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def apply_abbreviation_task(
    tatqa_df: pd.DataFrame,
    finqa_df: pd.DataFrame,
    abbrev_dict: Dict[str, Any],
    config: Dict[str, Any]
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Orchestrates the application of n-gram abbreviation to passages in both datasets.

    This function:
    1. Determines the active set of n-grams (Top-T) based on configuration.
    2. Iterates through all passages in TAT-QA and Fin-QA.
    3. Applies abbreviation to each passage.
    4. Stores the abbreviated text in a new column 'abbreviated_passages'.

    Args:
        tatqa_df (pd.DataFrame): TAT-QA DataFrame containing 'passages'.
        finqa_df (pd.DataFrame): Fin-QA DataFrame containing 'passages'.
        abbrev_dict (Dict[str, Any]): Dictionary containing 'ngram_to_ph' and 'ph_to_ngram'.
        config (Dict[str, Any]): Configuration dictionary with 'ngram_abbreviation_config'.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: The input DataFrames with a new 'abbreviated_passages' column.
    """
    logger.info("Starting abbreviation application task...")

    # Extract mappings
    ngram_to_ph = abbrev_dict["ngram_to_ph"]
    ph_to_ngram = abbrev_dict["ph_to_ngram"]

    # Extract configuration parameters
    ngram_config = config.get("ngram_abbreviation_config", {})
    T = ngram_config.get("top_n_T", {}).get("best_performing_value", 3)
    n = ngram_config.get("ngram_size_G", {}).get("best_performing_value", 2)
    encoding_name = config.get("offline_corpus_config", {}).get("tokenization_scheme", {}).get("name", "cl100k_base")

    # Load tokenizer
    try:
        encoding = tiktoken.get_encoding(encoding_name)
    except Exception as e:
        logger.error(f"Failed to load encoding '{encoding_name}': {e}. Falling back to cl100k_base.")
        encoding = tiktoken.get_encoding("cl100k_base")

    # Determine active n-grams (Top-T)
    # We assume placeholders are named sequentially [PH001], [PH002], etc.
    active_ngrams: Set[Tuple[int, ...]] = set()
    for i in range(1, T + 1):
        ph = f"[PH{i:03d}]"
        if ph in ph_to_ngram:
            active_ngrams.add(ph_to_ngram[ph])

    logger.info(f"Active n-grams count (T={T}): {len(active_ngrams)}")

    # Helper to process a list of passages
    def process_passages_list(passages_list: Any) -> List[Dict[str, Any]]:
        if not isinstance(passages_list, list):
            return []

        new_passages = []
        for p in passages_list:
            if not isinstance(p, dict):
                continue

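            original_text = p.get("text", "")
            # Apply abbreviation
            abbr_text = abbreviate_text(original_text, active_ngrams, ngram_to_ph, n, encoding)

            # Step 3: warn if the abbreviation cannot be expanded back to the original text
            if not verify_reversibility(original_text, abbr_text, ph_to_ngram, encoding):
                logger.warning("Abbreviated passage failed the reversibility check.")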

            # Create a copy of the passage dict with updated text
            new_p = p.copy()
            new_p["text"] = abbr_text
            new_passages.append(new_p)

        return new_passages

    # Apply to TAT-QA
    logger.info("Abbreviating TAT-QA passages...")
    tatqa_df["abbreviated_passages"] = tatqa_df["passages"].apply(process_passages_list)

    # Apply to Fin-QA
    logger.info("Abbreviating Fin-QA passages...")
    finqa_df["abbreviated_passages"] = finqa_df["passages"].apply(process_passages_list)

    logger.info("Abbreviation task complete.")
    return tatqa_df, finqa_df


# Task 25 – Identify Numeric Columns and Extract Values for Quantization

# ==============================================================================
# Task 25: Identify Numeric Columns and Extract Values for Quantization
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 25, Step 1 & 2: Extract Values
# -------------------------------------------------------------------------------------------------------------------------------
def extract_column_values(
    df: pd.DataFrame,
    dataset_name: str,
    numeric_metadata: Dict[Tuple[str, str, str, int], bool],
    normalizer: Any # FinancialTextNormalizer
) -> Dict[Tuple[str, str, str, int], List[float]]:
    """
    Extracts parsed float values from columns identified as numeric.

    Args:
        df (pd.DataFrame): The dataset DataFrame.
        dataset_name (str): "TAT-QA" or "Fin-QA".
        numeric_metadata (Dict): Metadata indicating which columns are numeric.
        normalizer (FinancialTextNormalizer): Parser instance.

    Returns:
        Dict: Mapping from column key to list of parsed float values.
    """
    extracted_values = {}

    for idx, row in df.iterrows():
        example_id = row["example_id"]
        tables = row["tables"]

        if not isinstance(tables, list):
            continue

        for table in tables:
            if not isinstance(table, dict):
                continue

            table_id = table.get("table_id", "unknown")
            rows = table.get("rows", [])

            if not rows:
                continue

            # Visit every column position and keep only those Task 6 marked as numeric.
            # Use the longest row so ragged rows do not hide trailing columns;
            # shorter rows are skipped per cell below.
            num_cols = max(len(r) for r in rows)

            for col_idx in range(num_cols):
                key = (dataset_name, example_id, table_id, col_idx)

                # Only process if marked numeric
                if not numeric_metadata.get(key, False):
                    continue

                # Extract values
                col_values = []
                for r in rows:
                    if col_idx < len(r):
                        val_str = str(r[col_idx])
                        parsed = normalizer.parse_to_float(val_str)
                        if parsed is not None:
                            col_values.append(parsed)

                if col_values:
                    extracted_values[key] = col_values

    return extracted_values

# -------------------------------------------------------------------------------------------------------------------------------
# Task 25, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def extract_numeric_values_task(
    tatqa_df: pd.DataFrame,
    finqa_df: pd.DataFrame,
    numeric_metadata: Dict[Tuple[str, str, str, int], bool]
) -> Dict[Tuple[str, str, str, int], Dict[str, Any]]:
    """
    Orchestrates the extraction of numeric values for quantization.

    Args:
        tatqa_df (pd.DataFrame): TAT-QA DataFrame.
        finqa_df (pd.DataFrame): Fin-QA DataFrame.
        numeric_metadata (Dict): Metadata from Task 6.

    Returns:
        Dict: Mapping column key -> {'values': List[float], 'count': int}.
    """
    logger.info("Starting numeric value extraction...")

    # We need the normalizer from Task 6
    # Assuming FinancialTextNormalizer is available in scope
    normalizer = FinancialTextNormalizer()

    # Extract TAT-QA
    tatqa_values = extract_column_values(tatqa_df, "TAT-QA", numeric_metadata, normalizer)

    # Extract Fin-QA
    finqa_values = extract_column_values(finqa_df, "Fin-QA", numeric_metadata, normalizer)

    # Merge and format
    all_values = {**tatqa_values, **finqa_values}

    final_output = {}
    for key, vals in all_values.items():
        final_output[key] = {
            "values": vals,
            "count": len(vals)
        }

    logger.info(f"Extracted values for {len(final_output)} numeric columns.")
    return final_output
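
The extraction above relies on `FinancialTextNormalizer.parse_to_float` from Task 6. The sketch below swaps in a hypothetical `toy_parse_to_float`; its rules (strip `$`, `,` and `%`, read parentheses as negatives) are assumptions, not the Task 6 logic. On a made-up table, it shows which cells survive and the `{'values', 'count'}` payload that `extract_numeric_values_task` builds for each column key.

```python
import re
from typing import Dict, List, Optional, Tuple

def toy_parse_to_float(text: str) -> Optional[float]:
    """Hypothetical stand-in for FinancialTextNormalizer.parse_to_float (assumed rules)."""
    s = text.strip()
    negative = s.startswith("(") and s.endswith(")")  # accounting-style negative, e.g. "(300)"
    s = re.sub(r"[()$,%\s]", "", s)
    try:
        value = float(s)
    except ValueError:
        return None  # non-numeric cells are skipped, as in extract_column_values
    return -value if negative else value

def toy_extract(rows: List[List[str]], key: Tuple[str, str, str, int]) -> Dict:
    """Collect parsed values of column key[3] and wrap them in the Task 25 payload shape."""
    col_idx = key[3]
    values = []
    for r in rows:
        if col_idx < len(r):  # ragged rows may be shorter than the column index
            parsed = toy_parse_to_float(str(r[col_idx]))
            if parsed is not None:
                values.append(parsed)
    return {key: {"values": values, "count": len(values)}} if values else {}

rows = [["Revenue", "$1,200"], ["Costs", "(300)"], ["Note", "n/a"], ["Margin"]]
result = toy_extract(rows, ("TAT-QA", "ex1", "t0", 1))
print(result)  # {('TAT-QA', 'ex1', 't0', 1): {'values': [1200.0, -300.0], 'count': 2}}
```

Unparseable cells are dropped rather than zero-filled, so `count` can be smaller than the number of rows in the table.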


# Task 26 – Apply Uniform Integer Quantization

# ==============================================================================
# Task 26: Apply Uniform Integer Quantization
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 26, Step 1 & 2: Compute Range and Encode
# -------------------------------------------------------------------------------------------------------------------------------
def quantize_column_uniform(
    values: List[float],
    bit_width: int = 8
) -> Tuple[List[int], Dict[str, float]]:
    """
    Applies uniform integer quantization to a list of float values.

    Equations:
        L = 2^b
        q_i = round((x_i - min_x) / (max_x - min_x) * (L - 1))
        epsilon_max = (max_x - min_x) / (L - 1)

    Args:
        values (List[float]): List of numeric values.
        bit_width (int): Bit width b (default 8).

    Returns:
        Tuple[List[int], Dict]:
            - List of quantized integer codes.
            - Metadata dictionary (min_x, max_x, L, epsilon_max).
    """
    if not values:
        return [], {"min_x": 0.0, "max_x": 0.0, "L": 0, "epsilon_max": 0.0}

    min_x = min(values)
    max_x = max(values)
    L = 2 ** bit_width

    # Handle constant column
    if max_x == min_x:
        # All values map to 0
        codes = [0] * len(values)
        epsilon_max = 0.0
        return codes, {
            "min_x": min_x,
            "max_x": max_x,
            "L": L,
            "epsilon_max": epsilon_max
        }

    # Compute codes
    codes = []
    denominator = max_x - min_x
    scale = L - 1

    for x in values:
        normalized = (x - min_x) / denominator
        q = int(round(normalized * scale))
        # Clip to ensure bounds (floating point errors might push slightly outside)
        q = max(0, min(scale, q))
        codes.append(q)

    epsilon_max = denominator / scale

    return codes, {
        "min_x": min_x,
        "max_x": max_x,
        "L": L,
        "epsilon_max": epsilon_max
    }

# -------------------------------------------------------------------------------------------------------------------------------
# Task 26, Step 3: Compute reconstruction (Helper)
# -------------------------------------------------------------------------------------------------------------------------------
def reconstruct_uniform(
    codes: List[int],
    metadata: Dict[str, float]
) -> List[float]:
    """
    Reconstructs approximate values from quantized codes.

    Equation:
        x_hat = min_x + (q_i / (L - 1)) * (max_x - min_x)

    Args:
        codes (List[int]): Quantized codes.
        metadata (Dict): Metadata from quantization.

    Returns:
        List[float]: Reconstructed values.
    """
    # Unpack the quantization parameters
    min_x = metadata["min_x"]
    max_x = metadata["max_x"]
    L = int(metadata["L"])

    if max_x == min_x:
        return [min_x] * len(codes)

    # Compute scale and range
    scale = L - 1
    range_x = max_x - min_x

    reconstructed = []

    # Dequantize each code back onto the [min_x, max_x] grid
    for q in codes:
        val = min_x + (q / scale) * range_x
        reconstructed.append(val)

    return reconstructed
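
Here is a quick round-trip check of the two equations above. With b bits the codes lie in [0, L-1]. Because `round` snaps each value to the nearest level, the per-value reconstruction error is at most half a step, which is epsilon_max / 2 under this docstring's definition of epsilon_max. The sketch re-implements the formulas with the standard library only, so it runs without the rest of the pipeline. The sample values are made up.

```python
from typing import List, Tuple

def uniform_roundtrip(values: List[float], bit_width: int) -> Tuple[List[int], List[float], float]:
    """Quantize then dequantize with the Task 26 formulas; returns codes, reconstructions, step."""
    lo, hi = min(values), max(values)
    scale = 2 ** bit_width - 1            # L - 1
    step = (hi - lo) / scale              # epsilon_max as defined in the docstring
    codes = [round((x - lo) / (hi - lo) * scale) for x in values]
    recon = [lo + q / scale * (hi - lo) for q in codes]
    return codes, recon, step

values = [-300.0, 0.0, 125.5, 1200.0]
codes, recon, step = uniform_roundtrip(values, bit_width=4)   # L = 16 levels
print(codes)                       # [0, 3, 4, 15]
max_err = max(abs(x - r) for x, r in zip(values, recon))
print(step, max_err <= step / 2)   # 100.0 True
```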

# -------------------------------------------------------------------------------------------------------------------------------
# Task 26, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def apply_uniform_quantization_task(
    extracted_values: Dict[Tuple[str, str, str, int], Dict[str, Any]],
    config: Dict[str, Any]
) -> Dict[Tuple[str, str, str, int], Dict[str, Any]]:
    """
    Orchestrates uniform quantization for all extracted numeric columns.

    Args:
        extracted_values (Dict): Output from Task 25.
        config (Dict): Configuration containing bit width.

    Returns:
        Dict: Mapping column key -> {
            'codes': List[int],
            'metadata': Dict
        }
    """
    logger.info("Starting uniform quantization...")

    bit_width = config.get("numeric_quantization_config", {}).get("uniform_integer", {}).get("bit_width_b", 8)
    results = {}

    # Iterate through extracted values
    for key, data in extracted_values.items():
        values = data["values"]

        # Quantize column data
        codes, meta = quantize_column_uniform(values, bit_width)

        results[key] = {
            "codes": codes,
            "metadata": meta
        }

    logger.info(f"Quantized {len(results)} columns.")
    return results
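
The orchestrator reads a fixed `bit_width_b` from the config (default 8). Inverting the docstring's bound shows how a bit width could instead be derived from a target tolerance. Requiring epsilon_max = (max_x - min_x) / (2^b - 1) <= tol gives b >= log2((max_x - min_x) / tol + 1). `min_bit_width` is a hypothetical helper, not part of the pipeline's configuration. It only illustrates the arithmetic.

```python
import math

def min_bit_width(min_x: float, max_x: float, tol: float) -> int:
    """Smallest b with (max_x - min_x) / (2**b - 1) <= tol (hypothetical helper)."""
    if tol <= 0:
        raise ValueError("tol must be positive")
    span = max_x - min_x
    if span == 0:
        return 1  # constant column: codes are all 0, so one bit suffices
    b = math.ceil(math.log2(span / tol + 1))
    # Make sure the bound holds even if log2 rounds slightly low near a power of two
    while span / (2 ** b - 1) > tol:
        b += 1
    return b

print(min_bit_width(-300.0, 1200.0, tol=10.0))   # 8
```

With a span of 1500, 7 bits give a step of about 11.8 (too coarse), while 8 bits give about 5.9.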


# Task 27 – (Optional) Apply K-Means-Based Quantization

# ==============================================================================
# Task 27: (Optional) Apply K-Means-Based Quantization
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 27, Step 1 & 2: Run K-Means and Encode
# -------------------------------------------------------------------------------------------------------------------------------
def quantize_column_kmeans(
    values: List[float],
    k: int = 16
) -> Tuple[List[int], Dict[str, Any]]:
    """
    Applies K-Means quantization to a list of float values.

    Args:
        values (List[float]): List of numeric values.
        k (int): Number of clusters (default 16).

    Returns:
        Tuple[List[int], Dict]:
            - List of quantized integer codes (cluster indices).
            - Metadata dictionary (centroids, k, mse).
    """
    if not values:
        return [], {"centroids": [], "k": 0, "mse": 0.0}

    # Adjust k if we have fewer unique values than k
    unique_vals = len(set(values))
    if unique_vals < k:
        k = unique_vals

    if k == 0:  # Unreachable when values is non-empty; kept as a defensive guard
        return [], {"centroids": [], "k": 0, "mse": 0.0}

    # Reshape for sklearn
    X = np.array(values).reshape(-1, 1)

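    # Fit KMeans
    # Since scikit-learn 1.4 the default is n_init='auto' (a single run with k-means++ init);
    # n_init=10 is set explicitly so results do not depend on the installed version.
    kmeans = KMeans(n_clusters=k, n_init=10, random_state=42)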
    kmeans.fit(X)

    codes = kmeans.labels_.tolist()
    centroids = kmeans.cluster_centers_.flatten().tolist()

    # Compute MSE
    # inertia_ is sum of squared distances to nearest centroid
    mse = kmeans.inertia_ / len(values)

    return codes, {
        "centroids": centroids,
        "k": k,
        "mse": mse
    }

# -------------------------------------------------------------------------------------------------------------------------------
# Task 27, Step 3: Reconstruction (Helper)
# -------------------------------------------------------------------------------------------------------------------------------
def reconstruct_kmeans(
    codes: List[int],
    metadata: Dict[str, Any]
) -> List[float]:
    """
    Reconstructs values from k-means codes.

    Args:
        codes (List[int]): Cluster indices.
        metadata (Dict): Metadata containing centroids.

    Returns:
        List[float]: Reconstructed values (centroids).
    """
    centroids = metadata["centroids"]
    if not centroids:
        return []

    return [centroids[c] for c in codes]
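
For one-dimensional data, the clustering that `quantize_column_kmeans` delegates to scikit-learn reduces to Lloyd's algorithm on a line. The stdlib-only sketch below is not a drop-in replacement. It uses a simple quantile initialisation instead of scikit-learn's k-means++ and a fixed iteration cap. It shows why k-means codes can beat uniform codes on skewed columns: the centroids follow the data density, whereas uniform levels are spread evenly over [min_x, max_x]. The column values are made up.

```python
from typing import List, Tuple

def kmeans_1d(values: List[float], k: int, iters: int = 50) -> Tuple[List[int], List[float]]:
    """Lloyd's algorithm on a line; returns (codes, centroids)."""
    data = sorted(values)
    k = min(k, len(set(values)))  # same adjustment as quantize_column_kmeans
    # Quantile initialisation: spread initial centroids across the sorted data
    centroids = [data[(i * (len(data) - 1)) // max(k - 1, 1)] for i in range(k)]
    for _ in range(iters):
        # Assignment step: nearest centroid for each value
        codes = [min(range(k), key=lambda c: abs(x - centroids[c])) for x in values]
        # Update step: move each centroid to the mean of its members
        new = []
        for c in range(k):
            members = [x for x, q in zip(values, codes) if q == c]
            new.append(sum(members) / len(members) if members else centroids[c])
        if new == centroids:
            break
        centroids = new
    codes = [min(range(k), key=lambda c: abs(x - centroids[c])) for x in values]
    return codes, centroids

def mse(a: List[float], b: List[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

values = [1.0, 1.2, 1.1, 0.9, 1.0, 50.0, 52.0]   # skewed: most mass near 1
codes, cents = kmeans_1d(values, k=2)
recon = [cents[q] for q in codes]                 # same lookup as reconstruct_kmeans

# Uniform quantization with the same number of levels (b = 1, L = 2)
lo, hi = min(values), max(values)
uniform = [lo + round((x - lo) / (hi - lo)) * (hi - lo) for x in values]
print(mse(values, recon) < mse(values, uniform))  # True: k-means levels fit the dense region
```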

# -------------------------------------------------------------------------------------------------------------------------------
# Task 27, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def apply_kmeans_quantization_task(
    extracted_values: Dict[Tuple[str, str, str, int], Dict[str, Any]],
    config: Dict[str, Any]
) -> Dict[Tuple[str, str, str, int], Dict[str, Any]]:
    """
    Orchestrates K-Means quantization for all extracted numeric columns.

    Args:
        extracted_values (Dict): Output from Task 25.
        config (Dict): Configuration containing k.

    Returns:
        Dict: Mapping column key -> {
            'codes': List[int],
            'metadata': Dict
        }
    """
    logger.info("Starting K-Means quantization...")

    k = config.get("numeric_quantization_config", {}).get("kmeans_based", {}).get("num_clusters_k", 16)
    results = {}

    # Iterate through extracted values
    for key, data in extracted_values.items():
        values = data["values"]

        # Quantize the column's values
        codes, meta = quantize_column_kmeans(values, k)

        results[key] = {
            "codes": codes,
            "metadata": meta
        }

    logger.info(f"Quantized {len(results)} columns using K-Means.")
    return results


# Task 28 – Embed Candidate Examples for Few-Shot Selection

# ==============================================================================
# Task 28: Embed Candidate Examples for Few-Shot Selection
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 28, Step 1: Define textual representation
# -------------------------------------------------------------------------------------------------------------------------------
def construct_embedding_text(row: pd.Series) -> str:
    """
    Constructs the text representation of an example for embedding.

    Format:
    Question: {question}
    Context:
    {serialized_tables}
    {passages}

    Args:
        row (pd.Series): DataFrame row.

    Returns:
        str: Text to embed.
    """
    question = row.get("question_text", "").strip()

    tables = row.get("serialized_tables", [])
    tables_str = "\n".join(tables) if isinstance(tables, list) else ""

    passages = row.get("passages", [])
    passages_str = ""
    if isinstance(passages, list):
        passages_str = "\n".join([p.get("text", "").strip() for p in passages if isinstance(p, dict)])

    return f"Question: {question}\nContext:\n{tables_str}\n{passages_str}"

# -------------------------------------------------------------------------------------------------------------------------------
# Task 28, Step 2: Compute embeddings
# -------------------------------------------------------------------------------------------------------------------------------
def compute_embeddings(
    texts: List[str],
    model_name: str = "all-mpnet-base-v2",
    batch_size: int = 32
) -> np.ndarray:
    """
    Computes embeddings for a list of texts using SentenceTransformers.

    Args:
        texts (List[str]): List of texts.
        model_name (str): Model name.
        batch_size (int): Batch size.

    Returns:
        np.ndarray: Embedding matrix of shape (N, D); D = 768 for the default all-mpnet-base-v2.
    """
    logger.info(f"Loading embedding model: {model_name}")
    model = SentenceTransformer(model_name)

    logger.info(f"Encoding {len(texts)} texts...")
    embeddings = model.encode(texts, batch_size=batch_size, show_progress_bar=True)

    return embeddings

# -------------------------------------------------------------------------------------------------------------------------------
# Task 28, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def embed_examples_task(
    df: pd.DataFrame,
    dataset_name: str,
    config: Dict[str, Any],
    output_dir: str = "embeddings"
) -> Tuple[np.ndarray, List[str]]:
    """
    Orchestrates the embedding of examples.

    Args:
        df (pd.DataFrame): Input DataFrame.
        dataset_name (str): Dataset name.
        config (Dict): Configuration.
        output_dir (str): Directory to save embeddings.

    Returns:
        Tuple[np.ndarray, List[str]]: Embeddings matrix and list of example IDs.
    """
    logger.info(f"Starting embedding task for {dataset_name}...")

    # Step 1: Construct texts
    texts = df.apply(construct_embedding_text, axis=1).tolist()
    example_ids = df["example_id"].tolist()

    # Step 2: Compute
    model_name = config.get("fewshot_selection_config", {}).get("embedding_model", {}).get("name", "all-mpnet-base-v2")
    embeddings = compute_embeddings(texts, model_name)

    # Step 3: Save
    os.makedirs(output_dir, exist_ok=True)

    np.save(os.path.join(output_dir, f"{dataset_name}_embeddings.npy"), embeddings)
    with open(os.path.join(output_dir, f"{dataset_name}_ids.json"), "w") as f:
        json.dump(example_ids, f)

    logger.info(f"Embeddings saved to {output_dir}")
    return embeddings, example_ids
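
Tasks 29 and 30 need the embedding matrix and the ID list in the same row order. The sketch below reads back the two files that `embed_examples_task` writes and checks that they align. `load_embeddings` is a hypothetical helper, and the demo writes a dummy matrix instead of real SentenceTransformer output.

```python
import json
import os
import tempfile

import numpy as np

def load_embeddings(output_dir: str, dataset_name: str):
    """Reload {dataset}_embeddings.npy and {dataset}_ids.json and verify row alignment."""
    emb = np.load(os.path.join(output_dir, f"{dataset_name}_embeddings.npy"))
    with open(os.path.join(output_dir, f"{dataset_name}_ids.json")) as f:
        ids = json.load(f)
    if emb.shape[0] != len(ids):
        raise ValueError(f"{emb.shape[0]} embeddings but {len(ids)} ids")
    return emb, ids

# Demo: a dummy 3 x 4 matrix stands in for real model output
with tempfile.TemporaryDirectory() as tmp:
    np.save(os.path.join(tmp, "TAT-QA_embeddings.npy"), np.zeros((3, 4), dtype=np.float32))
    with open(os.path.join(tmp, "TAT-QA_ids.json"), "w") as f:
        json.dump(["a", "b", "c"], f)
    emb, ids = load_embeddings(tmp, "TAT-QA")
    print(emb.shape, ids)   # (3, 4) ['a', 'b', 'c']
```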


# Task 29 – Select Optimal Cluster Count via Silhouette Score

# ==============================================================================
# Task 29: Select Optimal Cluster Count via Silhouette Score
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 29, Step 1 & 2: Run K-Means and Compute Silhouette
# -------------------------------------------------------------------------------------------------------------------------------
def compute_silhouette_scores(
    embeddings: np.ndarray,
    k_range: range
) -> Dict[int, float]:
    """
    Computes average silhouette scores for a range of cluster counts k.

    Args:
        embeddings (np.ndarray): Embedding matrix of shape (N, D).
        k_range (range): Range of k values to test.

    Returns:
        Dict[int, float]: Mapping from k to average silhouette score.
    """
    scores = {}

    # Ensure we have enough samples for clustering
    n_samples = embeddings.shape[0]
    if n_samples < 2:
        logger.warning("Not enough samples for clustering.")
        return {}

    # silhouette_score requires 2 <= n_labels <= n_samples - 1, so drop k values outside that range
    valid_k_range = [k for k in k_range if 2 <= k < n_samples]

    for k in valid_k_range:
        # Fit KMeans
        kmeans = KMeans(n_clusters=k, n_init=10, random_state=42)
        labels = kmeans.fit_predict(embeddings)

        # Compute Silhouette
        # This can be slow for very large N
        score = silhouette_score(embeddings, labels, metric='euclidean')
        scores[k] = score

        logger.info(f"k={k}: Silhouette Score = {score:.4f}")

    return scores

# -------------------------------------------------------------------------------------------------------------------------------
# Task 29, Step 3: Select optimal k
# -------------------------------------------------------------------------------------------------------------------------------
def select_best_k(scores: Dict[int, float]) -> int:
    """
    Selects the k with the maximum silhouette score.

    Args:
        scores (Dict[int, float]): Mapping k -> score.

    Returns:
        int: The optimal k.
    """
    if not scores:
        return 5  # Fallback when no k could be scored (e.g. too few samples)

    best_k = max(scores, key=scores.get)
    logger.info(f"Optimal k selected: {best_k} (Score: {scores[best_k]:.4f})")
    return best_k
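
For reference, `silhouette_score` averages s(i) = (b(i) - a(i)) / max(a(i), b(i)) over all samples. Here a(i) is the mean distance from sample i to the rest of its own cluster, and b(i) is the smallest mean distance from sample i to any other cluster. The stdlib-only sketch below computes this for tiny inputs. Like the library call, it is O(N^2). It sets s(i) = 0 for singleton clusters, following Rousseeuw's original definition. It also shows why `select_best_k` takes the maximum: a labelling that matches the geometry scores close to 1, while a mismatched labelling goes negative.

```python
import math
from typing import Sequence

def silhouette(points: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean silhouette coefficient with Euclidean distance (stdlib-only sketch)."""
    clusters = sorted(set(labels))
    if not 2 <= len(clusters) <= len(points) - 1:
        raise ValueError("silhouette needs 2 <= n_clusters <= n_samples - 1")
    total = 0.0
    for i, p in enumerate(points):
        # Mean distance from p to the members of each cluster, excluding p itself
        mean_dist = {}
        for c in clusters:
            others = [q for j, q in enumerate(points) if labels[j] == c and j != i]
            if others:
                mean_dist[c] = sum(math.dist(p, q) for q in others) / len(others)
        own = labels[i]
        if own not in mean_dist:  # singleton cluster: s(i) = 0 by convention
            continue
        a = mean_dist[own]
        b = min(d for c, d in mean_dist.items() if c != own)
        total += (b - a) / max(a, b) if max(a, b) > 0 else 0.0
    return total / len(points)

pts = [(0.0, 0.0), (0.0, 1.0), (10.0, 0.0), (10.0, 1.0)]
good = silhouette(pts, [0, 0, 1, 1])   # clusters follow the geometry
bad = silhouette(pts, [0, 1, 0, 1])    # clusters cut across it
print(round(good, 4), round(bad, 4))   # 0.9002 -0.4475
```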

# -------------------------------------------------------------------------------------------------------------------------------
# Task 29, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def select_optimal_k_task(
    embeddings: np.ndarray,
    config: Dict[str, Any]
) -> int:
    """
    Orchestrates the selection of the optimal cluster count.

    Args:
        embeddings (np.ndarray): Embedding matrix.
        config (Dict): Configuration containing k range.

    Returns:
        int: Optimal k.
    """
    logger.info("Starting optimal k selection...")

    k_cfg = config.get("fewshot_selection_config", {}).get("kmeans_candidate_k_range", {})
    start = k_cfg.get("start", 5)
    end = k_cfg.get("end", 50)

    # The configured range is inclusive, while Python's range() excludes its end, hence end + 1
    k_range = range(start, end + 1)

    scores = compute_silhouette_scores(embeddings, k_range)
    best_k = select_best_k(scores)

    return best_k


# Task 30 – Select Representative Exemplars from Clusters

# ==============================================================================
# Task 30: Select Representative Exemplars from Clusters
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 30, Step 1 & 2: Run K-Means and Find Prototypes
# -------------------------------------------------------------------------------------------------------------------------------
def find_cluster_prototypes(
    embeddings: np.ndarray,
    k: int
) -> Tuple[List[int], np.ndarray]:
    """
    Runs K-Means and identifies the index of the sample closest to each centroid.

    Args:
        embeddings (np.ndarray): Embedding matrix.
        k (int): Number of clusters.

    Returns:
        Tuple[List[int], np.ndarray]:
            - List of indices of prototype samples (one per cluster).
            - Array of cluster labels for all samples.
    """
    # K-Means cannot form more clusters than there are samples
    k = min(k, embeddings.shape[0])
    if k < 1:
        return [], np.array([])

    # Fit KMeans
    kmeans = KMeans(n_clusters=k, n_init=10, random_state=42)
    labels = kmeans.fit_predict(embeddings)
    centroids = kmeans.cluster_centers_

    # Find closest sample to each centroid
    # pairwise_distances_argmin_min returns (indices, distances)
    closest_indices, _ = pairwise_distances_argmin_min(centroids, embeddings)

    return closest_indices.tolist(), labels

# -------------------------------------------------------------------------------------------------------------------------------
# Task 30, Step 3: Select top prototypes
# -------------------------------------------------------------------------------------------------------------------------------
def select_top_prototypes(
    prototype_indices: List[int],
    labels: np.ndarray,
    num_exemplars: int = 3
) -> List[int]:
    """
    Selects prototypes from the largest clusters.

    Args:
        prototype_indices (List[int]): Indices of prototypes (aligned with cluster ID 0..k-1).
        labels (np.ndarray): Cluster labels for all samples.
        num_exemplars (int): Number of exemplars to select.

    Returns:
        List[int]: Indices of the selected representative exemplars.
    """
    # Count cluster sizes
    # labels are 0..k-1
    unique, counts = np.unique(labels, return_counts=True)
    cluster_sizes = dict(zip(unique, counts))

    # Sort cluster IDs by size descending
    sorted_clusters = sorted(cluster_sizes.keys(), key=lambda c: cluster_sizes[c], reverse=True)

    # Select top clusters
    selected_clusters = sorted_clusters[:num_exemplars]

    # Get corresponding prototypes
    # prototype_indices[j] is the prototype for cluster j
    selected_indices = [prototype_indices[c] for c in selected_clusters]

    return selected_indices
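
Below are Task 30's two steps in miniature. The K-Means fit is replaced by fixed labels so the arithmetic is easy to follow. For each cluster, the sketch takes the sample nearest its centroid, which is the role `pairwise_distances_argmin_min` plays above. It then keeps the prototypes of the largest clusters. The code is pure Python and illustrative only. The toy points and labels are made up, and labels are assumed to be the contiguous ids 0..k-1, as K-Means produces.

```python
import math
from collections import Counter
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]

def cluster_means(points: Sequence[Point], labels: Sequence[int]) -> List[Point]:
    """Centroid of each cluster 0..k-1 (what KMeans.cluster_centers_ holds after fitting)."""
    k = max(labels) + 1
    means = []
    for c in range(k):
        members = [p for p, lab in zip(points, labels) if lab == c]
        means.append(tuple(sum(col) / len(members) for col in zip(*members)))
    return means

def nearest_to_centroids(points: Sequence[Point], centroids: Sequence[Point]) -> List[int]:
    """For each centroid, the index of the closest sample."""
    return [min(range(len(points)), key=lambda i: math.dist(points[i], c)) for c in centroids]

def top_prototypes(proto_idx: List[int], labels: Sequence[int], n: int) -> List[int]:
    """Prototypes of the n largest clusters, largest first (as in select_top_prototypes)."""
    sizes = Counter(labels)
    ranked = sorted(sizes, key=lambda c: sizes[c], reverse=True)
    return [proto_idx[c] for c in ranked[:n]]

points = [(0.0,), (1.0,), (2.0,), (20.0,), (30.0,), (31.0,), (32.0,), (35.0,)]
labels = [0, 0, 0, 1, 2, 2, 2, 2]
protos = nearest_to_centroids(points, cluster_means(points, labels))
print(protos)                                # [1, 3, 6]
print(top_prototypes(protos, labels, n=2))   # [6, 1]
```

Ranking by cluster size favours exemplars that represent the most common question types. A singleton cluster such as label 1 above is never chosen unless `num_exemplars` is large enough to reach it.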

# -------------------------------------------------------------------------------------------------------------------------------
# Task 30, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def select_representative_exemplars_task(
    embeddings: np.ndarray,
    example_ids: List[str],
    k_star: int,
    config: Dict[str, Any]
) -> List[str]:
    """
    Orchestrates the selection of representative few-shot exemplars.

    Args:
        embeddings (np.ndarray): Embedding matrix.
        example_ids (List[str]): List of example IDs corresponding to rows.
        k_star (int): Optimal cluster count.
        config (Dict): Configuration.

    Returns:
        List[str]: List of selected example_ids.
    """
    logger.info(f"Selecting representative exemplars with k={k_star}...")

    num_exemplars = config.get("fewshot_selection_config", {}).get("num_exemplars_per_prompt", 3)

    # Step 1 & 2: Prototypes
    proto_indices, labels = find_cluster_prototypes(embeddings, k_star)

    # Step 3: Selection
    selected_indices = select_top_prototypes(proto_indices, labels, num_exemplars)

    # Map back to IDs
    selected_ids = [example_ids[i] for i in selected_indices]

    logger.info(f"Selected {len(selected_ids)} exemplars: {selected_ids}")
    return selected_ids


# Task 31 – Compute Embeddings for Semantic Similarity Evaluation

# ==============================================================================
# Task 31: Compute Embeddings for Semantic Similarity Evaluation
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 31, Step 1: Define pairs
# -------------------------------------------------------------------------------------------------------------------------------
def align_prompt_pairs(
    dynamic_inputs: Dict[str, Dict[str, Any]],
    pruned_results: Dict[str, Dict[str, Any]]
) -> Tuple[List[str], List[str], List[str]]:
    """
    Aligns original and compressed prompts by example_id.

    Args:
        dynamic_inputs (Dict): Original prompt data.
        pruned_results (Dict): Compressed prompt data.

    Returns:
        Tuple[List[str], List[str], List[str]]:
            - List of original prompt texts.
            - List of compressed prompt texts.
            - List of corresponding example IDs.
    """
    original_texts = []
    compressed_texts = []
    ids = []

    # Intersection of keys
    common_ids = sorted(list(set(dynamic_inputs.keys()) & set(pruned_results.keys())))

    for eid in common_ids:
        orig = dynamic_inputs[eid]["prompt_text"]
        comp = pruned_results[eid]["compressed_text"]

        original_texts.append(orig)
        compressed_texts.append(comp)
        ids.append(eid)

    return original_texts, compressed_texts, ids

# -------------------------------------------------------------------------------------------------------------------------------
# Task 31, Step 2 & 3: Compute and Store (Orchestrator)
# -------------------------------------------------------------------------------------------------------------------------------
def compute_similarity_embeddings_task(
    dynamic_inputs: Dict[str, Dict[str, Any]],
    pruned_results: Dict[str, Dict[str, Any]],
    config: Dict[str, Any],
    output_dir: str = "similarity_embeddings"
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """
    Orchestrates the computation of embeddings for semantic similarity.

    Args:
        dynamic_inputs (Dict): Original prompts.
        pruned_results (Dict): Compressed prompts.
        config (Dict): Configuration.
        output_dir (str): Output directory.

    Returns:
        Tuple[np.ndarray, np.ndarray, List[str]]:
            - Matrix U (original embeddings).
            - Matrix V (compressed embeddings).
            - List of example IDs.
    """
    logger.info("Starting similarity embedding computation...")

    # Step 1: Align
    orig_texts, comp_texts, ids = align_prompt_pairs(dynamic_inputs, pruned_results)

    if not ids:
        logger.warning("No common examples found for similarity computation.")
        return np.array([]), np.array([]), []

    # Step 2: Compute
    # Reuse compute_embeddings from Task 28
    # We assume compute_embeddings is available in scope
    model_name = config.get("semantic_evaluation_config", {}).get("embedding_model", {}).get("name", "all-mpnet-base-v2")

    logger.info("Embedding original prompts...")
    U = compute_embeddings(orig_texts, model_name)

    logger.info("Embedding compressed prompts...")
    V = compute_embeddings(comp_texts, model_name)

    # Step 3: Save
    os.makedirs(output_dir, exist_ok=True)

    np.save(os.path.join(output_dir, "U_original.npy"), U)
    np.save(os.path.join(output_dir, "V_compressed.npy"), V)
    with open(os.path.join(output_dir, "ids.json"), "w") as f:
        json.dump(ids, f)

    logger.info(f"Similarity embeddings saved to {output_dir}")
    return U, V, ids


# Task 32 – Compute Cosine Similarities and Summary Statistics

# ==============================================================================
# Task 32: Compute Cosine Similarities and Summary Statistics
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 32, Step 1: Compute Cosine Similarity
# -------------------------------------------------------------------------------------------------------------------------------
def compute_cosine_similarities(
    U: np.ndarray,
    V: np.ndarray
) -> np.ndarray:
    """
    Computes cosine similarity between corresponding rows of two matrices.

    Equation:
        s_i = (u_i . v_i) / (||u_i|| * ||v_i||)

    Args:
        U (np.ndarray): Matrix of original embeddings (N, D).
        V (np.ndarray): Matrix of compressed embeddings (N, D).

    Returns:
        np.ndarray: Array of cosine similarities (N,).
    """
    if U.shape != V.shape:
        raise ValueError(f"Shape mismatch: U={U.shape}, V={V.shape}")

    # Compute dot products
    # element-wise multiply then sum along axis 1
    dot_products = np.sum(U * V, axis=1)

    # Compute norms
    norm_u = np.linalg.norm(U, axis=1)
    norm_v = np.linalg.norm(V, axis=1)

    # Avoid division by zero: a zero vector has a zero dot product, so replacing
    # its norm with 1 yields a similarity of 0 instead of NaN
    norm_u[norm_u == 0] = 1.0
    norm_v[norm_v == 0] = 1.0

    similarities = dot_products / (norm_u * norm_v)

    # Clip to [-1, 1] to handle floating point errors
    similarities = np.clip(similarities, -1.0, 1.0)

    return similarities
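
Here is a worked instance of s_i = (u_i . v_i) / (||u_i|| * ||v_i||) on small vectors. It includes the zero-norm case, which `compute_cosine_similarities` maps to 0. The stdlib-only `cosine` below is an illustrative re-derivation for a single pair, not a replacement for the vectorised NumPy version.

```python
import math
from typing import Sequence

def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity of one pair, mirroring the zero-norm and clipping rules above."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u)) or 1.0   # zero norm -> divide by 1, result 0
    nv = math.sqrt(sum(b * b for b in v)) or 1.0
    return max(-1.0, min(1.0, dot / (nu * nv)))    # clip like np.clip above

print(cosine([1.0, 2.0, 2.0], [2.0, 4.0, 4.0]))   # 1.0  (same direction, different length)
print(cosine([3.0, 4.0], [4.0, 3.0]))             # 0.96 (24 / (5 * 5))
print(cosine([1.0, 0.0], [0.0, 1.0]))             # 0.0  (orthogonal)
print(cosine([0.0, 0.0], [3.0, 4.0]))             # 0.0  (zero vector)
```

Cosine similarity ignores vector length, so a compressed prompt whose embedding points the same way as the original scores 1.0 even if the norms differ.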

# -------------------------------------------------------------------------------------------------------------------------------
# Task 32, Step 2 & 3: Compute Statistics and Thresholds
# -------------------------------------------------------------------------------------------------------------------------------
def analyze_similarity_stats(
    similarities: np.ndarray,
    threshold: float = 0.92
) -> Dict[str, float]:
    """
    Computes summary statistics and threshold violations.

    Args:
        similarities (np.ndarray): Array of similarity scores.
        threshold (float): Safety threshold (default 0.92).

    Returns:
        Dict[str, float]: Dictionary of statistics.
    """
    if len(similarities) == 0:
        return {}

    stats = {
        "mean_similarity": float(np.mean(similarities)),
        "median_similarity": float(np.median(similarities)),
        "p05_similarity": float(np.percentile(similarities, 5)),
        "min_similarity": float(np.min(similarities)),
        "max_similarity": float(np.max(similarities)),
        "std_dev": float(np.std(similarities)),
        "fraction_below_threshold": float(np.mean(similarities < threshold))
    }

    return stats
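
The same summary can be reproduced with Python's `statistics` module as a sanity check. `statistics.quantiles(..., n=20, method="inclusive")[0]` uses the same linear interpolation as NumPy's default `np.percentile(x, 5)`. `statistics.pstdev` matches the population (ddof=0) estimate that `np.std` uses by default. The similarity scores below are made-up values.

```python
import statistics

sims = [0.99, 0.97, 0.95, 0.93, 0.91, 0.88]   # illustrative scores only
threshold = 0.92

stats = {
    "mean_similarity": statistics.fmean(sims),
    "median_similarity": statistics.median(sims),
    "p05_similarity": statistics.quantiles(sims, n=20, method="inclusive")[0],
    "min_similarity": min(sims),
    "max_similarity": max(sims),
    "std_dev": statistics.pstdev(sims),
    "fraction_below_threshold": sum(s < threshold for s in sims) / len(sims),
}
print(stats["p05_similarity"], stats["fraction_below_threshold"])
```

Here the 5th percentile falls a quarter of the way between 0.88 and 0.91, giving 0.8875. Two of the six scores fall below 0.92.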

# -------------------------------------------------------------------------------------------------------------------------------
# Task 32, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def evaluate_semantic_similarity_task(
    U: np.ndarray,
    V: np.ndarray,
    example_ids: List[str],
    config: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Orchestrates the semantic similarity evaluation.

    Args:
        U (np.ndarray): Original embeddings.
        V (np.ndarray): Compressed embeddings.
        example_ids (List[str]): Example IDs.
        config (Dict): Configuration.

    Returns:
        Dict[str, Any]: Results containing per-example scores and aggregate stats.
    """
    logger.info("Starting semantic similarity evaluation...")

    threshold = config.get("semantic_evaluation_config", {}).get("interpretive_thresholds", {}).get("semantic_safety_similarity_min", 0.92)

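    # Guard against empty inputs (e.g. Task 31 found no common examples)
    if U.size == 0 or V.size == 0:
        logger.warning("No embeddings to evaluate; skipping semantic similarity.")
        return {"stats": {}, "per_example_scores": {}, "flagged_ids": []}

    # Step 1: Compute
    similarities = compute_cosine_similarities(U, V)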

    # Step 2 & 3: Analyze
    stats = analyze_similarity_stats(similarities, threshold)

    # Map scores to IDs
    per_example_scores = {eid: float(score) for eid, score in zip(example_ids, similarities)}

    # Identify flagged examples
    flagged_ids = [eid for eid, score in per_example_scores.items() if score < threshold]

    logger.info(f"Mean Similarity: {stats['mean_similarity']:.4f}")
    logger.info(f"5th Percentile: {stats['p05_similarity']:.4f}")
    logger.info(f"Flagged {len(flagged_ids)} examples below {threshold}")

    return {
        "stats": stats,
        "per_example_scores": per_example_scores,
        "flagged_ids": flagged_ids
    }


# Task 33 – Recruit and Calibrate Human Annotators

# ==============================================================================
# Task 33: Recruit and Calibrate Human Annotators (LLM Proxies)
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 33, Step 1: Create Annotation Prompt Template
# -------------------------------------------------------------------------------------------------------------------------------
def get_annotation_prompt_template() -> str:
    """
    Returns the prompt template for LLM-based semantic equivalence evaluation.

    This template is designed to be neutral and unbiased. It presents two text
    segments labeled "Text A" and "Text B" without identifying which is the
    original and which is the compressed version. This blinding prevents the
    evaluator model from biasing its score based on the label.

    The prompt instructs the model to act as a human annotator and rate the
    semantic equivalence on a 1-5 scale, focusing on financial facts and logic.

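    Returns:
        str: The prompt template string with placeholders `{text_a}` and `{text_b}`.
            The response-format example contains literal braces, so fill the
            placeholders with `str.replace` rather than `str.format`.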
    """
    return """You are an expert financial analyst acting as a data quality evaluator.
            Your task is to rate the semantic equivalence between two text segments, labeled "Text A" and "Text B".

            Rating Scale:
            1: Completely different meaning. The texts are unrelated or contradict each other.
            2: Mostly different. Key financial figures or logic present in one are missing or altered significantly in the other.
            3: Similar but noticeable differences. Some nuance is lost, or minor details are inconsistent between the texts.
            4: Mostly identical. The core meaning and all key figures are preserved; only minor stylistic changes or negligible omissions exist.
            5: Completely identical meaning. Both texts convey exactly the same information with no loss of semantic content.

            Input:
            Text A:
            {text_a}

            Text B:
            {text_b}

            Instructions:
            - Compare Text A and Text B for semantic equivalence.
            - Focus on financial facts (numbers, dates, entities, direction of change).
            - Ignore minor grammatical fluency issues if the meaning remains clear.
            - Do not assume one text is the "ground truth"; evaluate their mutual consistency.
            - Provide your response in strict JSON format with keys "score" (integer 1-5) and "reasoning" (string).

            Response Format:
            {{
              "score": <int>,
              "reasoning": "<string>"
            }}
            """

# -------------------------------------------------------------------------------------------------------------------------------
# Task 33, Step 2: Initialize LLM Agents
# -------------------------------------------------------------------------------------------------------------------------------
def initialize_proxy_annotators() -> Dict[str, LLMInterface]:
    """
    Initializes the LLM agents that will serve as human proxies.

    Models:
    - Agent 1: OpenAI "gpt-5.1-high"
    - Agent 2: Anthropic "claude-opus-4-5-20251101"
    - Agent 3: Anthropic "claude-sonnet-4-5-20250929"

    Returns:
        Dict[str, LLMInterface]: Mapping of agent ID to initialized interface.
    """
    agents = {}

    # Define configurations
    configs = [
        ("Agent_GPT5", "gpt-5.1-high", GPT4oInterface), # Using GPT4oInterface for OpenAI protocol
        ("Agent_ClaudeOpus", "claude-opus-4-5-20251101", ClaudeInterface),
        ("Agent_ClaudeSonnet", "claude-sonnet-4-5-20250929", ClaudeInterface)
    ]

    for agent_id, model_name, interface_cls in configs:
        try:
            # Instantiate interface
            # The interface classes defined in Task 14 take 'model' as an init argument.
            agent = interface_cls(model=model_name)
            agents[agent_id] = agent
            logger.info(f"Initialized {agent_id} with model {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize {agent_id}: {e}")

    return agents

# -------------------------------------------------------------------------------------------------------------------------------
# Task 33, Step 3: Perform Calibration
# -------------------------------------------------------------------------------------------------------------------------------
def calculate_cohen_kappa(rater1_scores: List[int], rater2_scores: List[int]) -> float:
    """
    Calculates Cohen's Kappa for inter-rater agreement between two raters.

    Cohen's Kappa measures the agreement between two raters who each classify N items
    into C mutually exclusive categories. It accounts for the agreement occurring by chance.

    Formula:
        kappa = (p_o - p_e) / (1 - p_e)
    Where:
        p_o = Relative observed agreement among raters.
        p_e = Hypothetical probability of chance agreement.

    Args:
        rater1_scores (List[int]): List of scores from rater 1.
        rater2_scores (List[int]): List of scores from rater 2.

    Returns:
        float: The Cohen's Kappa score. Returns 0.0 if lists are empty or invalid.
    """
    if len(rater1_scores) != len(rater2_scores) or not rater1_scores:
        return 0.0

    # Confusion Matrix
    # Categories are 1, 2, 3, 4, 5
    categories = [1, 2, 3, 4, 5]
    matrix = {c1: {c2: 0 for c2 in categories} for c1 in categories}

    # Count only pairs where both scores are valid categories, so that n
    # matches the number of items actually entered into the matrix
    n = 0
    for s1, s2 in zip(rater1_scores, rater2_scores):
        if s1 in categories and s2 in categories:
            matrix[s1][s2] += 1
            n += 1

    if n == 0:
        return 0.0

    # Observed Agreement (p_o)
    # Sum of diagonal elements / Total valid items
    agreement_count = sum(matrix[c][c] for c in categories)
    p_o = agreement_count / n

    # Expected Agreement (p_e)
    # Sum of (prob r1 chooses c * prob r2 chooses c) for all c
    p_e = 0.0
    for c in categories:
        count_r1 = sum(matrix[c][k] for k in categories) # Row sum
        count_r2 = sum(matrix[k][c] for k in categories) # Col sum
        prob_r1 = count_r1 / n
        prob_r2 = count_r2 / n
        p_e += prob_r1 * prob_r2

    if p_e == 1.0:
        return 1.0 # Perfect agreement by chance (all same values)

    kappa = (p_o - p_e) / (1 - p_e)
    return kappa
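As a cross-check on the formula, the sketch below computes unweighted Cohen's kappa directly from the marginal counts with `collections.Counter`. It assumes integer scores. For inputs restricted to 1-5 it should agree with `calculate_cohen_kappa`. The function name `cohen_kappa_counter` is hypothetical.

```python
from collections import Counter
from typing import List


def cohen_kappa_counter(r1: List[int], r2: List[int]) -> float:
    """Unweighted Cohen's kappa from marginal counts (illustrative cross-check)."""
    if len(r1) != len(r2) or not r1:
        return 0.0
    n = len(r1)
    # Observed agreement: share of items where both raters gave the same score
    p_o = sum(a == b for a, b in zip(r1, r2)) / n
    # Chance agreement: sum over categories of the product of marginal proportions
    c1, c2 = Counter(r1), Counter(r2)
    p_e = sum((c1[c] / n) * (c2[c] / n) for c in set(c1) | set(c2))
    if p_e == 1.0:
        return 1.0  # both raters used a single identical category throughout
    return (p_o - p_e) / (1 - p_e)


# Worked example: 4 of 5 items agree (p_o = 0.8); marginals give p_e = 5/25 = 0.2,
# so kappa = (0.8 - 0.2) / (1 - 0.2) = 0.75
kappa = cohen_kappa_counter([1, 2, 3, 3, 5], [1, 2, 3, 4, 5])
print(round(kappa, 4))  # 0.75
```

A kappa of 0.75 is well below the raw agreement of 0.8 because part of that agreement is expected by chance. Note that this is the unweighted form, so a 4-vs-5 disagreement counts the same as a 1-vs-5 one.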

def run_calibration_scoring(
    agents: Dict[str, LLMInterface],
    calibration_pairs: List[Dict[str, str]]
) -> Dict[str, Any]:
    """
    Runs the calibration phase where all agents score the same set of pairs.

    This function iterates through the calibration dataset, queries each proxy agent,
    and computes the inter-annotator agreement using Cohen's Kappa. High agreement
    indicates that the agents (and the prompt instructions) are aligned on the
    definition of semantic equivalence.

    Args:
        agents (Dict[str, LLMInterface]): Initialized LLM agents.
        calibration_pairs (List[Dict[str, str]]): List of dicts with 'original' and 'compressed' text.

    Returns:
        Dict[str, Any]: Calibration results including raw scores and the average Kappa agreement metric.
    """
    template = get_annotation_prompt_template()
    results = []

    logger.info(f"Starting calibration scoring for {len(calibration_pairs)} pairs...")

    for i, pair in enumerate(calibration_pairs):
        orig = pair["original"]
        comp = pair["compressed"]
        pair_id = pair.get("id", f"calib_{i}")

        # Calibration always presents the original as Text A and the compressed text as Text B, using the template's neutral labels
        prompt_text = template.format(text_a=orig, text_b=comp)
        messages = [{"role": "user", "content": prompt_text}]

        pair_results = {"pair_id": pair_id, "scores": {}}

        for agent_id, agent in agents.items():
            try:
                # Call LLM with temperature 0.0 for consistent evaluation
                response = agent.prompt(messages=messages, temperature=0.0)
                text_response, _ = agent.extract_response(response)

                # Robust JSON Parsing
                try:
                    start = text_response.find("{")
                    end = text_response.rfind("}") + 1
                    if start != -1 and end != -1:
                        json_str = text_response[start:end]
                        data = json.loads(json_str)
                        score = int(data["score"])
                        # Clamp score to valid range 1-5
                        score = max(1, min(5, score))
                        pair_results["scores"][agent_id] = score
                    else:
                        logger.warning(f"No JSON found in response from {agent_id} for {pair_id}")
                        pair_results["scores"][agent_id] = None
                except (json.JSONDecodeError, ValueError, KeyError) as e:
                    logger.warning(f"Failed to parse response from {agent_id} for {pair_id}: {e}")
                    pair_results["scores"][agent_id] = None

            except Exception as e:
                logger.error(f"Agent {agent_id} failed on {pair_id}: {e}")
                pair_results["scores"][agent_id] = None

        results.append(pair_results)

    # Compute Agreement (Average Pairwise Cohen's Kappa)
    agent_ids = list(agents.keys())
    vectors = {aid: [] for aid in agent_ids}

    # Collect valid scores aligned by index
    valid_indices = []
    for idx, res in enumerate(results):
        scores = res["scores"]
        # Only include pairs where ALL agents provided a score
        if all(scores.get(aid) is not None for aid in agent_ids):
            for aid in agent_ids:
                vectors[aid].append(scores[aid])
            valid_indices.append(idx)

    kappa_scores = []
    if len(valid_indices) > 1:
        for i in range(len(agent_ids)):
            for j in range(i + 1, len(agent_ids)):
                a1 = agent_ids[i]
                a2 = agent_ids[j]
                v1 = vectors[a1]
                v2 = vectors[a2]

                kappa = calculate_cohen_kappa(v1, v2)
                kappa_scores.append(kappa)

    avg_kappa = sum(kappa_scores) / len(kappa_scores) if kappa_scores else 0.0

    logger.info(f"Calibration complete. Average Cohen's Kappa: {avg_kappa:.4f}")

    return {
        "raw_results": results,
        "agreement_metric": float(avg_kappa),
        "num_valid_pairs": len(valid_indices)
    }
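Both the calibration and main-evaluation loops extract the first-to-last brace span from the reply and clamp the score. The sketch below factors that logic into one helper and shows the edge cases it has to handle: prose around the JSON, no JSON at all, and out-of-range scores. `extract_score` is a hypothetical helper and is not part of the pipeline above.

```python
import json
from typing import Optional


def extract_score(text: Optional[str], lo: int = 1, hi: int = 5) -> Optional[int]:
    """Hypothetical helper: parse the outermost {...} span and return a clamped score."""
    if not text:
        return None
    start = text.find("{")
    end = text.rfind("}") + 1  # rfind returns -1 when absent, so end == 0 in that case
    if start == -1 or end <= start:
        return None
    try:
        data = json.loads(text[start:end])
        score = int(data["score"])
    except (json.JSONDecodeError, ValueError, KeyError, TypeError):
        return None
    return max(lo, min(hi, score))  # clamp into the 1-5 rating scale


print(extract_score('Sure. {"score": 4, "reasoning": "Figures match."} Hope this helps.'))  # 4
print(extract_score("I would rate this a 4."))  # None
print(extract_score('{"score": 7, "reasoning": "n/a"}'))  # 5 (clamped)
```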

# -------------------------------------------------------------------------------------------------------------------------------
# Task 33, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def run_llm_calibration_task(
    tatqa_pairs: List[Dict[str, str]],
    finqa_pairs: List[Dict[str, str]]
) -> Dict[str, Any]:
    """
    Orchestrates the LLM-based annotator recruitment and calibration.

    1. Initializes proxy agents.
    2. Selects calibration sample (15 from each dataset).
    3. Runs scoring.
    4. Reports agreement.

    Args:
        tatqa_pairs (List[Dict]): Available TAT-QA pairs (original, compressed).
        finqa_pairs (List[Dict]): Available Fin-QA pairs.

    Returns:
        Dict: Calibration report.
    """
    logger.info("Starting LLM proxy calibration...")

    # Step 1: Initialize
    agents = initialize_proxy_annotators()

    # Step 2: Sample
    # Draw up to 15 pairs from each dataset (fewer if a dataset is smaller).
    sample_tatqa = random.sample(tatqa_pairs, min(15, len(tatqa_pairs)))
    sample_finqa = random.sample(finqa_pairs, min(15, len(finqa_pairs)))
    calibration_set = sample_tatqa + sample_finqa
    random.shuffle(calibration_set)

    logger.info(f"Calibration set size: {len(calibration_set)}")

    # Step 3: Run
    calibration_results = run_calibration_scoring(agents, calibration_set)

    logger.info(f"Calibration complete. Agreement: {calibration_results['agreement_metric']:.4f}")

    return {
        "agents": list(agents.keys()),
        "results": calibration_results
    }
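Step 2 calls `random.sample` on the global generator without a seed, so each run can draw a different calibration set. The sketch below shows a seeded variant that uses a private `random.Random` instance. The `seed` and `per_dataset` parameters and the function name `draw_calibration_set` are assumptions, not part of the task above.

```python
import random
from typing import Dict, List


def draw_calibration_set(
    tatqa_pairs: List[Dict[str, str]],
    finqa_pairs: List[Dict[str, str]],
    per_dataset: int = 15,
    seed: int = 0,
) -> List[Dict[str, str]]:
    """Hypothetical seeded variant of Step 2: same inputs and seed give the same set."""
    rng = random.Random(seed)  # private generator; leaves the global random state untouched
    sample = rng.sample(tatqa_pairs, min(per_dataset, len(tatqa_pairs)))
    sample += rng.sample(finqa_pairs, min(per_dataset, len(finqa_pairs)))
    rng.shuffle(sample)
    return sample


tatqa = [{"original": f"t{i}", "compressed": f"t{i}c"} for i in range(40)]
finqa = [{"original": f"f{i}", "compressed": f"f{i}c"} for i in range(8)]
first = draw_calibration_set(tatqa, finqa, seed=42)
second = draw_calibration_set(tatqa, finqa, seed=42)
print(len(first), first == second)  # 23 True (15 TAT-QA + all 8 Fin-QA)
```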


In [None]:
# Task 34 – Conduct Main Human Evaluation

# ==============================================================================
# Task 34: Conduct Main Human Evaluation (LLM Proxies)
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 34, Step 1: Sample evaluation pairs
# -------------------------------------------------------------------------------------------------------------------------------
def sample_evaluation_pairs(
    tatqa_pairs: List[Dict[str, str]],
    finqa_pairs: List[Dict[str, str]],
    total_samples: int = 90
) -> List[Dict[str, Any]]:
    """
    Samples pairs for the main evaluation from TAT-QA and Fin-QA datasets.

    This function attempts to sample an equal number of pairs from each dataset.
    If a dataset has fewer pairs than required, it takes all available pairs
    and fills the remainder from the other dataset if possible.

    Args:
        tatqa_pairs (List[Dict]): List of TAT-QA pairs (dicts with 'original', 'compressed').
        finqa_pairs (List[Dict]): List of Fin-QA pairs (dicts with 'original', 'compressed').
        total_samples (int): Total number of pairs to sample (default 90).

    Returns:
        List[Dict]: List of sampled pairs with metadata (id, dataset).
    """
    n_tatqa = total_samples // 2
    n_finqa = total_samples - n_tatqa

    # Adjust if one dataset is too small
    if len(tatqa_pairs) < n_tatqa:
        n_tatqa = len(tatqa_pairs)
        n_finqa = min(len(finqa_pairs), total_samples - n_tatqa)
    elif len(finqa_pairs) < n_finqa:
        n_finqa = len(finqa_pairs)
        n_tatqa = min(len(tatqa_pairs), total_samples - n_finqa)

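    # Copy each sampled dict so that adding metadata does not mutate the caller's pairs
    sample_t = [dict(p) for p in random.sample(tatqa_pairs, n_tatqa)]
    sample_f = [dict(p) for p in random.sample(finqa_pairs, n_finqa)]

    # Add metadata
    for p in sample_t:
        p["dataset"] = "TAT-QA"
    for p in sample_f:
        p["dataset"] = "Fin-QA"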

    combined = sample_t + sample_f
    random.shuffle(combined)

    # Add unique pair IDs
    for i, pair in enumerate(combined):
        pair["id"] = f"eval_{i:03d}"

    logger.info(f"Sampled {len(combined)} pairs for evaluation ({len(sample_t)} TAT-QA, {len(sample_f)} Fin-QA).")
    return combined
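The split between datasets is the least obvious part of `sample_evaluation_pairs`. The sketch below restates that arithmetic in isolation so the fill-from-the-other-dataset behaviour can be checked on a few sizes. `allocate_samples` is an illustrative mirror of the logic above and is not an existing function.

```python
from typing import Tuple


def allocate_samples(n_tatqa_avail: int, n_finqa_avail: int, total: int = 90) -> Tuple[int, int]:
    """Mirror of the split logic in sample_evaluation_pairs (illustrative only)."""
    n_t = total // 2
    n_f = total - n_t
    if n_tatqa_avail < n_t:
        # TAT-QA is short: take all of it and let Fin-QA cover the rest if it can
        n_t = n_tatqa_avail
        n_f = min(n_finqa_avail, total - n_t)
    elif n_finqa_avail < n_f:
        # Fin-QA is short: take all of it and let TAT-QA cover the rest if it can
        n_f = n_finqa_avail
        n_t = min(n_tatqa_avail, total - n_f)
    return n_t, n_f


print(allocate_samples(500, 500))  # (45, 45): balanced split
print(allocate_samples(10, 500))   # (10, 80): Fin-QA fills the TAT-QA shortfall
print(allocate_samples(500, 3))    # (87, 3): TAT-QA fills the Fin-QA shortfall
print(allocate_samples(10, 20))    # (10, 20): both short, so fewer than 90 pairs
```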

# -------------------------------------------------------------------------------------------------------------------------------
# Task 34, Step 2: Collect ratings
# -------------------------------------------------------------------------------------------------------------------------------
def collect_proxy_ratings(
    agents: Dict[str, LLMInterface],
    pairs: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """
    Collects semantic equivalence ratings from all proxy agents for the given evaluation pairs.

    This function implements a randomized evaluation protocol to mitigate position bias.
    For each pair, the order of presentation ("Text A" vs "Text B") is randomized.
    The function queries each agent, parses the JSON response, and records the score,
    reasoning, and the specific mapping used for that instance.

    Args:
        agents (Dict[str, LLMInterface]): Dictionary of initialized proxy agents.
        pairs (List[Dict[str, Any]]): List of evaluation pairs containing 'id', 'original', 'compressed'.

    Returns:
        List[Dict[str, Any]]: A list of result dictionaries. Each dictionary contains:
            - pair_id, dataset, example_id
            - mapping: Which text was presented as Text A ("original" or "compressed")
            - scores: Dict mapping agent_id to integer score (1-5)
            - reasoning: Dict mapping agent_id to reasoning string
    """
    template = get_annotation_prompt_template()
    results = []

    logger.info(f"Starting data collection for {len(pairs)} pairs...")

    for i, pair in enumerate(pairs):
        pair_id = pair["id"]
        orig_text = pair["original"]
        comp_text = pair["compressed"]

        # Randomize presentation order to prevent position bias
        # swap = True means Text A is Compressed, Text B is Original
        swap = random.choice([True, False])

        if swap:
            text_a = comp_text
            text_b = orig_text
            mapping = "compressed_first"
        else:
            text_a = orig_text
            text_b = comp_text
            mapping = "original_first"

        # Construct prompt with neutral labels
        prompt_text = template.format(text_a=text_a, text_b=text_b)
        messages = [{"role": "user", "content": prompt_text}]

        pair_result = {
            "pair_id": pair_id,
            "dataset": pair.get("dataset", "unknown"),
            "example_id": pair.get("example_id", "unknown"),
            "mapping": mapping,
            "scores": {},
            "reasoning": {}
        }

        for agent_id, agent in agents.items():
            try:
                # Call LLM with temperature 0.0 to keep evaluation as repeatable as possible (not strictly deterministic)
                response = agent.prompt(messages=messages, temperature=0.0)
                text_response, _ = agent.extract_response(response)

                # Robust JSON Parsing
                try:
                    # Find the JSON object within the response
                    start = text_response.find("{")
                    end = text_response.rfind("}") + 1
                    if start != -1 and end != -1:
                        json_str = text_response[start:end]
                        data = json.loads(json_str)

                        score = int(data["score"])
                        # Clamp score to valid range 1-5
                        score = max(1, min(5, score))

                        pair_result["scores"][agent_id] = score
                        pair_result["reasoning"][agent_id] = data.get("reasoning", "")
                    else:
                        logger.warning(f"No JSON found in response from {agent_id} for {pair_id}")
                        pair_result["scores"][agent_id] = None
                        pair_result["reasoning"][agent_id] = "Parse Error: No JSON found"

                except (json.JSONDecodeError, ValueError, KeyError) as e:
                    logger.warning(f"Agent {agent_id} failed JSON parse on {pair_id}: {e}")
                    pair_result["scores"][agent_id] = None
                    pair_result["reasoning"][agent_id] = f"Parse Error: {str(e)}"

            except Exception as e:
                logger.error(f"Agent {agent_id} API error on {pair_id}: {e}")
                pair_result["scores"][agent_id] = None
                pair_result["reasoning"][agent_id] = f"API Error: {str(e)}"

        results.append(pair_result)

        if (i + 1) % 10 == 0:
            logger.info(f"Collected ratings for {i + 1}/{len(pairs)} pairs.")

    logger.info("Data collection complete.")
    return results
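Each record stores its `mapping`, so the randomized order can be audited after collection. If an agent's mean score differs noticeably between `original_first` and `compressed_first`, that suggests residual position bias. The sketch below is a minimal diagnostic of this kind. `mean_score_by_order` and the agent IDs in the demo data are hypothetical.

```python
from collections import defaultdict
from statistics import mean
from typing import Any, Dict, List


def mean_score_by_order(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, float]]:
    """Hypothetical diagnostic: per agent, average score under each presentation order."""
    buckets: Dict[str, Dict[str, List[int]]] = defaultdict(lambda: defaultdict(list))
    for r in results:
        for agent_id, score in r["scores"].items():
            if score is not None:  # skip parse/API failures recorded as None
                buckets[agent_id][r["mapping"]].append(score)
    return {
        agent: {order: float(mean(vals)) for order, vals in orders.items()}
        for agent, orders in buckets.items()
    }


demo = [
    {"mapping": "original_first", "scores": {"Agent_A": 5, "Agent_B": 4}},
    {"mapping": "compressed_first", "scores": {"Agent_A": 3, "Agent_B": 4}},
    {"mapping": "original_first", "scores": {"Agent_A": 4, "Agent_B": None}},
]
print(mean_score_by_order(demo))
```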


# -------------------------------------------------------------------------------------------------------------------------------
# Task 34, Step 3: Store raw ratings (Orchestrator)
# -------------------------------------------------------------------------------------------------------------------------------
def run_main_evaluation_task(
    tatqa_pairs: List[Dict[str, str]],
    finqa_pairs: List[Dict[str, str]],
    agents: Dict[str, LLMInterface],
    output_dir: str = "human_eval_results"
) -> List[Dict[str, Any]]:
    """
    Orchestrates the main human evaluation task using LLM proxies.

    1. Samples evaluation pairs from both datasets.
    2. Collects ratings from all proxy agents.
    3. Persists the raw ratings to disk.

    Args:
        tatqa_pairs (List[Dict]): List of TAT-QA pairs.
        finqa_pairs (List[Dict]): List of Fin-QA pairs.
        agents (Dict): Dictionary of initialized proxy agents.
        output_dir (str): Directory to save results.

    Returns:
        List[Dict]: The list of raw rating results.
    """
    logger.info("Starting main evaluation...")

    # Step 1: Sample
    eval_pairs = sample_evaluation_pairs(tatqa_pairs, finqa_pairs)

    # Step 2: Collect
    ratings = collect_proxy_ratings(agents, eval_pairs)

    # Step 3: Save
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, "raw_ratings.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(ratings, f, indent=2)

    logger.info(f"Evaluation complete. Ratings saved to {output_path}")
    return ratings
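Task 35 consumes these stored ratings, so it is worth confirming that they survive a write/read cycle intact. For example, a `None` score should become JSON `null` and come back as `None`. The sketch below checks this round trip inside a temporary directory. `save_ratings` and `load_ratings` are hypothetical helpers that use `pathlib` rather than the `os` calls above.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, List


def save_ratings(ratings: List[Dict[str, Any]], output_dir: str) -> Path:
    """Write ratings as indented UTF-8 JSON, creating the directory if needed."""
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / "raw_ratings.json"
    path.write_text(json.dumps(ratings, indent=2), encoding="utf-8")
    return path


def load_ratings(path: Path) -> List[Dict[str, Any]]:
    """Read the ratings file back into a list of records."""
    return json.loads(path.read_text(encoding="utf-8"))


with tempfile.TemporaryDirectory() as tmp:
    records = [{"pair_id": "eval_000", "mapping": "original_first",
                "scores": {"Agent_A": 4, "Agent_B": None}}]
    saved = save_ratings(records, str(Path(tmp) / "human_eval_results"))
    loaded = load_ratings(saved)
    print(loaded == records)  # True
```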


In [None]:
# Task 35 – Aggregate Human Ratings and Compare to Embeddings

# ==============================================================================
# Task 35: Aggregate Human Ratings and Compare to Embeddings
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 35, Step 1 & 2: Aggregate Ratings
# -------------------------------------------------------------------------------------------------------------------------------
def aggregate_ratings(
    raw_ratings: List[Dict[str, Any]]
) -> Tuple[Dict[str, float], float]:
    """
    Aggregates raw ratings per pair and computes the global mean over all
    individual (non-missing) ratings, rather than the mean of the per-pair means.

    Args:
        raw_ratings (List[Dict]): List of rating records from Task 34.

    Returns:
        Tuple[Dict[str, float], float]:
            - Mapping pair_id -> mean_rating.
            - Global mean rating.
    """
    pair_means = {}
    all_scores = []

    for record in raw_ratings:
        pair_id = record["pair_id"]
        scores = [s for s in record["scores"].values() if s is not None]

        if scores:
            mean_score = float(np.mean(scores))
            pair_means[pair_id] = mean_score
            all_scores.extend(scores)
        else:
            logger.warning(f"No valid scores for pair {pair_id}")

    global_mean = float(np.mean(all_scores)) if all_scores else 0.0

    logger.info(f"Aggregated ratings for {len(pair_means)} pairs. Global Mean: {global_mean:.4f}")
    return pair_means, global_mean
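`aggregate_ratings` pools every non-missing score, so a pair for which all agents responded carries more weight in the global mean than a pair for which some agents failed. The worked example below, which uses made-up scores, shows how far the pooled mean and the mean of per-pair means can diverge when responses are missing.

```python
from statistics import mean

# Two pairs: all three agents scored pair 1; only one agent returned a score for pair 2.
scores_by_pair = {"eval_000": [5, 5, 5], "eval_001": [1]}

pooled_mean = float(mean(s for scores in scores_by_pair.values() for s in scores))    # 16 / 4
mean_of_pair_means = float(mean(mean(scores) for scores in scores_by_pair.values()))  # (5 + 1) / 2
print(pooled_mean, mean_of_pair_means)  # 4.0 3.0
```

When every agent scores every pair, the two quantities coincide. The difference only matters when some agents fail to return a score.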

# -------------------------------------------------------------------------------------------------------------------------------
# Task 35, Step 3: Compare to Cosine Similarity
# -------------------------------------------------------------------------------------------------------------------------------
def compare_ratings_to_similarity(
    raw_ratings: List[Dict[str, Any]],
    pair_means: Dict[str, float],
    similarity_scores: Dict[str, float]
) -> Dict[str, Any]:
    """
    Compares human ratings to embedding cosine similarities.

    Identifies mismatches:
    - Case A: High Similarity (>= 0.92) but Low Rating (< 3).
    - Case B: Low Similarity (< 0.92) but High Rating (>= 4).

    Args:
        raw_ratings (List[Dict]): To map pair_id to example_id.
        pair_means (Dict): Mean human rating per pair.
        similarity_scores (Dict): Cosine similarity per example_id.

    Returns:
        Dict: Analysis results including mismatch counts and lists.
    """
    mismatches_A = [] # High Sim, Low Rating
    mismatches_B = [] # Low Sim, High Rating

    # Map pair_id to example_id
    pair_to_example = {r["pair_id"]: r.get("example_id") for r in raw_ratings}

    for pair_id, rating in pair_means.items():
        example_id = pair_to_example.get(pair_id)

        if not example_id or example_id not in similarity_scores:
            continue

        sim = similarity_scores[example_id]

        # Case A
        if sim >= 0.92 and rating < 3.0:
            mismatches_A.append({
                "pair_id": pair_id,
                "example_id": example_id,
                "similarity": sim,
                "rating": rating
            })

        # Case B
        if sim < 0.92 and rating >= 4.0:
            mismatches_B.append({
                "pair_id": pair_id,
                "example_id": example_id,
                "similarity": sim,
                "rating": rating
            })

    total_pairs = len(pair_means)
    frac_A = len(mismatches_A) / total_pairs if total_pairs > 0 else 0.0
    frac_B = len(mismatches_B) / total_pairs if total_pairs > 0 else 0.0

    logger.info(f"Mismatch Case A (High Sim, Low Rating): {len(mismatches_A)} ({frac_A:.2%})")
    logger.info(f"Mismatch Case B (Low Sim, High Rating): {len(mismatches_B)} ({frac_B:.2%})")

    return {
        "mismatches_A": mismatches_A,
        "mismatches_B": mismatches_B,
        "fraction_A": frac_A,
        "fraction_B": frac_B
    }
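The two mismatch cases lie on opposite sides of the 0.92 threshold, so a pair can be flagged as A or B but never both. A mean rating in [3, 4) is never flagged, whatever its similarity. The sketch below restates the same inequalities as a single classifier and exercises the boundaries. `mismatch_case` and `SIM_THRESHOLD` are illustrative names.

```python
from typing import Optional

SIM_THRESHOLD = 0.92  # cosine-similarity cut-off used in compare_ratings_to_similarity


def mismatch_case(sim: float, rating: float) -> Optional[str]:
    """Return "A", "B", or None using the same inequalities as compare_ratings_to_similarity."""
    if sim >= SIM_THRESHOLD and rating < 3.0:
        return "A"  # embeddings say "same", annotators say "different"
    if sim < SIM_THRESHOLD and rating >= 4.0:
        return "B"  # embeddings say "different", annotators say "same"
    return None


for sim, rating in [(0.92, 2.67), (0.92, 3.0), (0.91, 4.0), (0.91, 3.67), (0.97, 4.33)]:
    print(sim, rating, mismatch_case(sim, rating))
```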

# -------------------------------------------------------------------------------------------------------------------------------
# Task 35, Orchestrator Function
# -------------------------------------------------------------------------------------------------------------------------------
def analyze_human_vs_embedding_task(
    raw_ratings: List[Dict[str, Any]],
    similarity_results: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Orchestrates the comparison between human ratings and embedding similarities.

    Args:
        raw_ratings (List[Dict]): Output from Task 34.
        similarity_results (Dict): Output from Task 32 (contains 'per_example_scores').

    Returns:
        Dict: Comprehensive analysis report.
    """
    logger.info("Starting human vs embedding analysis...")

    # Step 1 & 2
    pair_means, global_mean = aggregate_ratings(raw_ratings)

    # Step 3
    sim_scores = similarity_results.get("per_example_scores", {})
    comparison = compare_ratings_to_similarity(raw_ratings, pair_means, sim_scores)

    return {
        "global_mean_rating": global_mean,
        "pair_means": pair_means,
        "comparison_analysis": comparison
    }


In [None]:
# Task 36 – Design Orchestrator Function Responsibilities

# ==============================================================================
# Task 36: Design Orchestrator Function Responsibilities
# ==============================================================================

# -------------------------------------------------------------------------------------------------------------------------------
# Task 36, Step 1: Define Pipeline Stages
# -------------------------------------------------------------------------------------------------------------------------------
class PipelineStage(Enum):
    """
    Enumeration of distinct stages in the CompactPrompt pipeline.

    This enum defines the sequential processing steps required to transform raw input
    into a compressed prompt and obtain model predictions. It serves as a control
    structure for the orchestrator to manage execution flow and logging.

    Attributes:
        INPUT_PROCESSING: Retrieval and validation of raw example data.
        SERIALIZATION: Conversion of structured data (tables) to string format.
        DYNAMIC_SCORING: Computation of token-level importance via scorer LLM.
        HARD_PRUNING: Removal of low-importance phrases based on combined scores.
        DATA_COMPRESSION: Application of n-gram abbreviation and numeric quantization.
        EXEMPLAR_SELECTION: Retrieval of few-shot examples (random or representative).
        PROMPT_ASSEMBLY: Construction of the final prompt string.
        INFERENCE: Querying the target LLM for the final answer.
        METRICS_COMPUTATION: Calculation of compression ratios and accuracy.
    """
    INPUT_PROCESSING = "input_processing"
    SERIALIZATION = "serialization"
    DYNAMIC_SCORING = "dynamic_scoring"
    HARD_PRUNING = "hard_pruning"
    DATA_COMPRESSION = "data_compression"
    EXEMPLAR_SELECTION = "exemplar_selection"
    PROMPT_ASSEMBLY = "prompt_assembly"
    INFERENCE = "inference"
    METRICS_COMPUTATION = "metrics_computation"

# -------------------------------------------------------------------------------------------------------------------------------
# Task 36, Step 2 & 3: Define Configuration and Condition Logic
# -------------------------------------------------------------------------------------------------------------------------------
@dataclass
class PipelineConfig:
    """
    Configuration object for a single execution of the CompactPrompt pipeline.

    This class parses the experimental condition string (e.g., "compressed_plus_data")
    and sets boolean flags to control which pipeline stages are executed. It encapsulates
    all hyperparameters required for compression algorithms, ensuring consistent
    configuration across the pipeline.

    Attributes:
        condition (str): The experimental condition identifier, expected to be one of the
            predefined conditions (e.g., "baseline", "compressed_prompt"). It is not
            validated; the flags below are derived from it by substring matching.
        target_llm (str): Identifier for the target LLM used for inference.
        scorer_llm (str): Identifier for the scorer LLM used for dynamic self-information.
        ngram_params (Dict[str, int]): Parameters for n-gram abbreviation.
            Expected keys: 'top_n_T' (int), 'ngram_size_G' (int).
            Defaults to {'top_n_T': 3, 'ngram_size_G': 2}.
        quantization_params (Dict[str, Any]): Parameters for numeric quantization.
            Expected keys: 'bit_width_b' (int), 'num_clusters_k' (int).
            Defaults to empty dict (uses global defaults).

        # Derived Flags (set in __post_init__)
        use_hard_pruning (bool): Whether to apply phrase-level pruning (Task 21).
        use_abbreviation (bool): Whether to apply n-gram abbreviation (Task 24).
        use_quantization (bool): Whether to apply numeric quantization (Task 26/27).
        exemplar_mode (str): Strategy for few-shot examples ("none", "random", "representative").
        add_dict_context (bool): Whether to append the abbreviation dictionary to the prompt.
    """
    condition: str
    target_llm: str
    scorer_llm: str
    ngram_params: Dict[str, int] = field(default_factory=lambda: {"top_n_T": 3, "ngram_size_G": 2})
    quantization_params: Dict[str, Any] = field(default_factory=dict)

    # Derived flags (initialized in post_init based on condition string)
    use_hard_pruning: bool = field(init=False)
    use_abbreviation: bool = field(init=False)
    use_quantization: bool = field(init=False)
    exemplar_mode: str = field(init=False)
    add_dict_context: bool = field(init=False)

    def __post_init__(self) -> None:
        """
        Parses the condition string to set control flags for the pipeline.

        This method interprets the semantic components of the condition string
        to enable or disable specific compression modules.
        """
        c = self.condition

        # Hard Pruning: Enabled if "compressed" is in the condition name
        # (e.g., "compressed_prompt", "compressed_plus_data")
        self.use_hard_pruning = "compressed" in c

        # Data Compression: Enabled if "plus_data" is present
        # This implies both n-gram abbreviation and numeric quantization
        self.use_abbreviation = "plus_data" in c
        self.use_quantization = "plus_data" in c

        # Added Context: Enabled if "added_context" is present
        # Appends the abbreviation dictionary to the prompt context
        self.add_dict_context = "added_context" in c

        # Exemplar Mode: Determine few-shot strategy
        if "plus_3_random" in c:
            self.exemplar_mode = "random"
        elif "plus_3_representative" in c:
            self.exemplar_mode = "representative"
        else:
            self.exemplar_mode = "none"

        logger.debug(f"Pipeline Config Initialized: {self}")


In [None]:
# Task 37 – Define Orchestrator Inputs and Outputs

# -------------------------------------------------------------------------------------------------------------------------------
# Task 37, Step 1: Specify Input Parameters
# -------------------------------------------------------------------------------------------------------------------------------
@dataclass
class OrchestratorInput:
    """
    Encapsulates all input parameters required for a single execution of the
    CompactPrompt orchestrator.

    This dataclass serves as the primary data transfer object (DTO) for passing
    configuration and context into the orchestrator function. It ensures type safety
    and provides default values for optional compression parameters.

    Attributes:
        example_id (str): Unique identifier for the example being processed.
        dataset (str): Name of the dataset ("TAT-QA" or "Fin-QA").
        condition (str): Experimental condition identifier (e.g., "compressed_plus_data").
        target_llm (str): Identifier for the target LLM used for inference.
        scorer_llm (str): Identifier for the scorer LLM used for dynamic self-information.
        ngram_params (Dict[str, int]): Parameters for n-gram abbreviation.
            Expected keys: 'top_n_T' (int), 'ngram_size_G' (int).
            Defaults to {'top_n_T': 3, 'ngram_size_G': 2}.
        quantization_params (Dict[str, Any]): Parameters for numeric quantization.
            Expected keys: 'bit_width_b' (int), 'num_clusters_k' (int).
            Defaults to empty dict (uses global defaults).
    """
    example_id: str
    dataset: str
    condition: str
    target_llm: str
    scorer_llm: str
    ngram_params: Dict[str, int] = field(default_factory=lambda: {"top_n_T": 3, "ngram_size_G": 2})
    quantization_params: Dict[str, Any] = field(default_factory=dict)
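The dict-valued fields use `field(default_factory=...)` because `dataclasses` rejects a plain mutable default such as `{}` with a `ValueError`. The factory also gives each instance its own dict. The sketch below checks that behaviour on a cut-down stand-in class. It also shows `dataclasses.asdict`, which is one plausible way to produce the `input_params` copy recorded in `OrchestratorLog`. `DemoInput` is hypothetical.

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Dict


@dataclass
class DemoInput:
    """Cut-down stand-in for OrchestratorInput, keeping only the fields needed here."""
    example_id: str
    condition: str
    ngram_params: Dict[str, int] = field(default_factory=lambda: {"top_n_T": 3, "ngram_size_G": 2})
    quantization_params: Dict[str, Any] = field(default_factory=dict)


a = DemoInput("ex_1", "baseline")
b = DemoInput("ex_2", "compressed_plus_data")
a.ngram_params["top_n_T"] = 5     # mutating one instance's dict...
print(b.ngram_params["top_n_T"])  # ...leaves the other untouched: 3
print(asdict(b))
```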

# -------------------------------------------------------------------------------------------------------------------------------
# Task 37, Step 2: Specify Output Artifacts
# -------------------------------------------------------------------------------------------------------------------------------
@dataclass
class OrchestratorOutput:
    """
    Encapsulates the results and artifacts produced by the orchestrator.

    This dataclass standardizes the output format of the pipeline, ensuring that
    all necessary data for downstream analysis (accuracy computation, token usage
    tracking, debugging) is captured and easily accessible.

    Attributes:
        final_prompt (str): The exact text string sent to the target LLM for inference.
        prediction (str): The raw text generated by the target LLM.
        compression_metadata (Dict[str, Any]): Detailed metadata about the compression process.
            Includes:
            - 'retained_phrases': List of phrases kept after pruning.
            - 'active_ngrams': Set of n-grams replaced by abbreviations.
            - 'quantization_ranges': Min/max values used for quantization.
        metrics (Dict[str, Any]): Quantitative metrics for performance analysis.
            Includes:
            - 'original_token_count': Token count before compression.
            - 'compressed_token_count': Token count after compression.
            - 'compression_ratio': Ratio of original to compressed tokens.
    """
    final_prompt: str
    prediction: str
    compression_metadata: Dict[str, Any]
    metrics: Dict[str, Any]
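

# Sketch (hypothetical helper): OrchestratorOutput.metrics documents the keys
# 'original_token_count', 'compressed_token_count' and 'compression_ratio'
# (original tokens divided by compressed tokens). This is one way to build that dict,
# assuming both token counts are already available as ints. An empty compressed
# prompt is treated as an error rather than as an infinite ratio.
from typing import Any, Dict


def build_compression_metrics_sketch(original_tokens: int, compressed_tokens: int) -> Dict[str, Any]:
    """Assemble the documented token-usage metrics for one prompt."""
    if original_tokens < 0 or compressed_tokens <= 0:
        raise ValueError("original_tokens must be >= 0 and compressed_tokens must be > 0")
    return {
        "original_token_count": original_tokens,
        "compressed_token_count": compressed_tokens,
        # A ratio of 2.0 means the compressed prompt uses half as many tokens
        "compression_ratio": original_tokens / compressed_tokens,
    }

# Usage: build_compression_metrics_sketch(1200, 600)["compression_ratio"] is 2.0.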

# -------------------------------------------------------------------------------------------------------------------------------
# Task 37, Step 3: Define Logging Requirements
# -------------------------------------------------------------------------------------------------------------------------------
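@dataclass
class OrchestratorLog:
    """
    Structured log record for a single orchestrator execution.

    This dataclass defines the schema for logging pipeline execution details.
    It is designed to be serialized (e.g., to JSON) for auditing, debugging,
    and performance monitoring. Every field except the timestamp has a default,
    so the orchestrator can create the record early and fill it in as phases run.

    Attributes:
        timestamp (str): ISO 8601 timestamp of when the execution occurred.
        input_hash (str): Truncated SHA-256 hash of the run's input specification.
        config_snapshot (Dict[str, Any]): The resolved study configuration used for the run.
        input_params (Dict[str, Any]): A copy of the input parameters provided to the orchestrator.
        metrics (Dict[str, Any]): A copy of the output metrics generated by the pipeline.
        latency_ms (float): Total execution time in milliseconds.
        status (str): Execution status string; defaults to "RUNNING" and is set to
            "FAILED" by the orchestrator when validation fails.
        error_message (Optional[str]): Detailed error message if the run failed, else None.
        errors (List[str]): Error messages accumulated during execution.
        stage_timings (Dict[Any, float]): Elapsed seconds per pipeline stage, keyed by PipelineStage.
        rows_processed (Dict[str, int]): Number of rows processed per dataset.
    """
    timestamp: str
    input_hash: str = ""
    config_snapshot: Dict[str, Any] = field(default_factory=dict)
    input_params: Dict[str, Any] = field(default_factory=dict)
    metrics: Dict[str, Any] = field(default_factory=dict)
    latency_ms: float = 0.0
    status: str = "RUNNING"
    error_message: Optional[str] = None
    errors: List[str] = field(default_factory=list)
    stage_timings: Dict[Any, float] = field(default_factory=dict)
    rows_processed: Dict[str, int] = field(default_factory=dict)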
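

# Sketch (hypothetical helper): OrchestratorLog is meant to be serialized to JSON, but
# json.dumps rejects Enum dictionary keys. An example is a stage_timings dict keyed by
# PipelineStage members, as the orchestrator below builds, assuming PipelineStage is an
# Enum. This sketch converts Enum keys and values to their .value and turns tuples and
# sets into lists. Any other non-JSON type falls back to str(). It assumes the record
# is a dataclass instance.
import dataclasses
import json
from enum import Enum
from typing import Any


def _to_jsonable_sketch(obj: Any) -> Any:
    """Recursively convert Enums, dicts, and sequences into JSON-friendly types."""
    if isinstance(obj, Enum):
        return obj.value
    if isinstance(obj, dict):
        # JSON object keys must be strings, so stringify converted keys
        return {str(_to_jsonable_sketch(key)): _to_jsonable_sketch(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return [_to_jsonable_sketch(item) for item in obj]
    return obj


def log_record_to_json_sketch(record: Any, indent: int = 2) -> str:
    """Serialize a dataclass log record to a JSON string."""
    return json.dumps(_to_jsonable_sketch(dataclasses.asdict(record)), indent=indent, default=str)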

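
# Sketch (hypothetical helpers, not the pipeline's own Task 9/10/17 functions): the
# orchestrator below turns corpus token counts into probabilities p(T) = count(T) / N
# (Task 9). It then computes static self-information I(T) = -log2(p(T)) (Task 10) and
# converts the scorer's log-probabilities into self-information bits (Task 17).
# Assumptions: counts are keyed by token ID (a Counter or plain dict), and the scorer
# returns natural-log probabilities ln P(t_i | c_i), so bits = -ln P / ln 2.
import math
from typing import Dict, List, Mapping


def static_self_information_sketch(token_counts: Mapping[int, int]) -> Dict[int, float]:
    """Map token ID -> I(T) = -log2(p(T)), with p(T) = count(T) / N."""
    total_tokens_n = sum(token_counts.values())
    if total_tokens_n <= 0:
        raise ValueError("token_counts must contain at least one positive count")
    return {
        token_id: -math.log2(count / total_tokens_n)
        for token_id, count in token_counts.items()
        if count > 0  # zero-count tokens have undefined self-information
    }


def dynamic_self_information_sketch(logprobs: List[float]) -> List[float]:
    """Convert natural-log probabilities into bits of self-information."""
    return [-lp / math.log(2) for lp in logprobs]

# Usage: static_self_information_sketch({7: 3, 9: 1})[9] is 2.0 (p = 0.25), and
# dynamic_self_information_sketch([math.log(0.5)]) is approximately [1.0].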

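
# Sketch (hypothetical alternative to the f-string hash in run_compactprompt_pipeline):
# str(dict) depends on key insertion order, so {"top_n_T": 3, "ngram_size_G": 2} and
# {"ngram_size_G": 2, "top_n_T": 3} produce different hashes there. Serializing with
# json.dumps(sort_keys=True) sorts the keys of every dict, including nested ones, so
# the truncated SHA-256 digest no longer depends on key order. Assumes all parameters
# are JSON-serializable.
import hashlib
import json
from typing import Any, Dict, Optional


def stable_input_hash_sketch(
    condition: str,
    target_llm: str,
    scorer_llm: str,
    ngram_params: Dict[str, Any],
    quantization_params: Optional[Dict[str, Any]] = None,
    length: int = 16,
) -> str:
    """Return a key-order-independent hex digest prefix of the run specification."""
    payload = {
        "condition": condition,
        "target_llm": target_llm,
        "scorer_llm": scorer_llm,
        "ngram_params": ngram_params,
        "quantization_params": quantization_params or {},
    }
    # Canonical form: sorted keys and no optional whitespace
    canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:length]
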
# Top-Level Orchestrator

# ==============================================================================
# END-TO-END ORCHESTRATOR FUNCTION
# ==============================================================================

def run_compactprompt_pipeline(
    tatqa_raw_df: pd.DataFrame,
    finqa_raw_df: pd.DataFrame,
    study_config: Dict[str, Any],
    condition: str = "compressed_plus_data",
    target_llm: str = "claude-3.5-sonnet",
    scorer_llm: str = "gpt-4o",
    ngram_params: Optional[Dict[str, int]] = None,
    quantization_params: Optional[Dict[str, Any]] = None,
    output_dir: str = "compactprompt_outputs",
    skip_corpus_build: bool = False,
    skip_human_eval: bool = False,
    random_seed: int = 42
) -> Tuple[OrchestratorOutput, OrchestratorLog]:
    """
    Execute the complete CompactPrompt research pipeline end-to-end.

    This function orchestrates all 35 task-specific functions to implement
    the full compression and evaluation workflow from the CompactPrompt paper:
    "A Unified Pipeline for Prompt and Data Compression in LLM Workflows".

    The pipeline consists of 12 phases:
        Phase 1 (Tasks 1-3): Input validation and configuration resolution
        Phase 2 (Tasks 4-6): Data cleansing and normalization
        Phase 3 (Tasks 7-10): Offline corpus construction and static self-information
        Phase 4 (Tasks 11-13): Prompt serialization and template construction
        Phase 5 (Task 14): LLM resource configuration
        Phase 6 (Tasks 15-18): Dynamic scoring and combined score computation
        Phase 7 (Tasks 19-21): Phrase-level grouping and budget-constrained pruning
        Phase 8 (Tasks 22-24): N-gram abbreviation for textual data compression
        Phase 9 (Tasks 25-27): Numeric quantization for table data compression
        Phase 10 (Tasks 28-30): Embedding and representative exemplar selection
        Phase 11 (Tasks 31-32): Semantic similarity evaluation
        Phase 12 (Tasks 33-35): Human evaluation protocol

    Args:
        tatqa_raw_df: Raw TAT-QA DataFrame with columns: example_id, split,
            question_text, tables, passages, answer_type, answer_value, answer_unit.
        finqa_raw_df: Raw Fin-QA DataFrame with identical column schema.
        study_config: Nested configuration dictionary containing all parameters
            for corpus construction, compression, quantization, and evaluation.
        condition: Experimental condition from experiment_design_config.prompting_conditions.
            One of: "baseline", "baseline_plus_3_random", "baseline_plus_3_representative",
            "compressed_prompt", "compressed_plus_3_random", "compressed_plus_3_representative",
            "compressed_plus_data", "compressed_plus_data_plus_added_context".
        target_llm: Model identifier for downstream QA generation. One of:
            "gpt-4o", "gpt-4.1-mini", "claude-3.5-sonnet", "llama-3.3-70b-instruct".
        scorer_llm: Model identifier for dynamic self-information scoring.
            Must support log-probability access.
        ngram_params: Dictionary with keys "top_n_T" (int) and "ngram_size_G" (int).
            Defaults to {"top_n_T": 3, "ngram_size_G": 2}, the paper's best configuration.
        quantization_params: Dictionary with key "bit_width" (int) for uniform
            quantization or "num_clusters" (int) for k-means. Defaults to {"bit_width": 8}.
        output_dir: Base directory for all output artifacts including corpus,
            embeddings, logs, and evaluation results.
        skip_corpus_build: If True, attempt to load existing corpus statistics
            from output_dir/corpus_stats instead of rebuilding. Useful for iteration.
        skip_human_eval: If True, skip Tasks 33-35 (human evaluation protocol).
            Useful for rapid experimentation without full evaluation.
        random_seed: Seed for reproducibility of random exemplar selection.

    Returns:
        Tuple containing:
            - OrchestratorOutput: All processed DataFrames, compression artifacts,
              evaluation results, and aggregate metrics.
            - OrchestratorLog: Complete audit trail with timing, status, and hashes.

    Raises:
        ValueError: If input DataFrames fail critical validation checks.
        RuntimeError: If LLM API calls fail after retries.
        FileNotFoundError: If skip_corpus_build=True but corpus files don't exist.

    Example:
        >>> output, log = run_compactprompt_pipeline(
        ...     tatqa_raw_df=tatqa_df,
        ...     finqa_raw_df=finqa_df,
        ...     study_config=config,
        ...     condition="compressed_plus_data",
        ...     target_llm="claude-3.5-sonnet",
        ...     scorer_llm="gpt-4o"
        ... )
        >>> print(f"Mean compression ratio: {output.metrics['mean_compression_ratio']:.2f}x")
        >>> print(f"Mean cosine similarity: {output.metrics['mean_cosine_similarity']:.3f}")

    References:
        Choi et al. (2025). "CompactPrompt: A Unified Pipeline for Prompt and
        Data Compression in LLM Workflows." ACM ICAIF Workshop on LLMs and
        Generative AI for Finance.
    """
    # ==========================================================================
    # INITIALIZATION
    # ==========================================================================

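    # Set random seeds for reproducibility of random exemplar selection
    # (NumPy's global RNG is seeded too, in case downstream tasks sample with it)
    random.seed(random_seed)
    np.random.seed(random_seed)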

    # Apply default values for optional parameters if not provided
    if ngram_params is None:
        ngram_params = {"top_n_T": 3, "ngram_size_G": 2}

    # Apply default quantization parameters if not provided
    if quantization_params is None:
        quantization_params = {"bit_width": 8}

    # Configure logging for pipeline execution tracking
    logger = logging.getLogger("CompactPromptPipeline")
    logger.setLevel(logging.INFO)

    # Record pipeline start timestamp in UTC for audit trail
    start_timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()

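    # Snapshot the run-level input parameters for reference. OrchestratorInput
    # describes a single example (example_id, dataset, ...) and has no fields for
    # the raw DataFrames or run options, so a plain dict is used here instead.
    orchestrator_input: Dict[str, Any] = {
        "tatqa_raw_df": tatqa_raw_df,
        "finqa_raw_df": finqa_raw_df,
        "study_config": study_config,
        "condition": condition,
        "target_llm": target_llm,
        "scorer_llm": scorer_llm,
        "ngram_params": ngram_params,
        "quantization_params": quantization_params,
        "output_dir": output_dir,
        "skip_corpus_build": skip_corpus_build,
        "skip_human_eval": skip_human_eval,
        "random_seed": random_seed,
    }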

    # Compute SHA-256 hash of input specification for deduplication
    input_repr = f"{condition}|{target_llm}|{scorer_llm}|{ngram_params}|{quantization_params}"
    input_hash = hashlib.sha256(input_repr.encode()).hexdigest()[:16]

    # Initialize orchestrator log with timestamp and input hash
    orchestrator_log = OrchestratorLog(
        timestamp=start_timestamp,
        input_hash=input_hash,
        config_snapshot={}
    )

    # Parse experimental condition to set pipeline control flags
    pipeline_config = PipelineConfig(
        condition=condition,
        ngram_params=ngram_params,
        quantization_scheme="uniform" if "bit_width" in quantization_params else "kmeans"
    )

    # Log pipeline configuration for debugging
    logger.info(f"Pipeline condition: {condition}")
    logger.info(f"Hard pruning: {pipeline_config.use_hard_pruning}")
    logger.info(f"N-gram abbreviation: {pipeline_config.use_ngram_abbreviation}")
    logger.info(f"Quantization: {pipeline_config.use_quantization}")
    logger.info(f"Few-shot mode: {pipeline_config.exemplar_mode}")

    # ==========================================================================
    # PHASE 1: INPUT VALIDATION AND CONFIGURATION RESOLUTION (TASKS 1-3)
    # ==========================================================================

    # Record phase start time for timing statistics
    phase_1_start = time.perf_counter()

    # -------------------------------------------------------------------------
    # Task 1 Orchestrator: Validate TAT-QA DataFrame schema and content quality
    # -------------------------------------------------------------------------

    # Invoke Task 1 orchestrator to validate TAT-QA schema and content
    tatqa_validation_report: ValidationReport = validate_tatqa_dataset(
        tatqa_raw_df=tatqa_raw_df
    )

    # Check if validation passed; if critical errors exist, pipeline must halt
    if tatqa_validation_report.has_critical_errors:
        # Record error in log before raising
        orchestrator_log.errors.append(
            f"TAT-QA validation failed: {tatqa_validation_report.errors}"
        )
        orchestrator_log.status = "FAILED"
        # Raise exception to halt pipeline execution
        raise ValueError(
            f"TAT-QA validation failed with {len(tatqa_validation_report.errors)} errors"
        )

    # Log successful TAT-QA validation
    logger.info(f"TAT-QA validation passed: {len(tatqa_raw_df)} rows")

    # -------------------------------------------------------------------------
    # Task 2 Orchestrator: Validate Fin-QA DataFrame schema and content quality
    # -------------------------------------------------------------------------

    # Invoke Task 2 orchestrator to validate Fin-QA schema and content
    finqa_validation_report: ValidationReport = validate_finqa_dataset(
        finqa_raw_df=finqa_raw_df
    )

    # Check if validation passed; halt pipeline on critical errors
    if finqa_validation_report.has_critical_errors:
        # Record error in log before raising
        orchestrator_log.errors.append(
            f"Fin-QA validation failed: {finqa_validation_report.errors}"
        )
        orchestrator_log.status = "FAILED"
        # Raise exception to halt pipeline execution
        raise ValueError(
            f"Fin-QA validation failed with {len(finqa_validation_report.errors)} errors"
        )

    # Log successful Fin-QA validation
    logger.info(f"Fin-QA validation passed: {len(finqa_raw_df)} rows")

    # -------------------------------------------------------------------------
    # Task 3 Orchestrator: Resolve and validate study configuration
    # -------------------------------------------------------------------------

    # Invoke Task 3 orchestrator to inventory placeholders and assign defaults
    resolved_config: Dict[str, Any] = validate_and_update_study_config(
        study_config=study_config
    )

    # Store resolved configuration snapshot in log for reproducibility
    orchestrator_log.config_snapshot = resolved_config

    # Log configuration resolution completion
    logger.info("Study configuration validated and placeholders resolved")

    # Record Phase 1 elapsed time
    phase_1_elapsed = time.perf_counter() - phase_1_start
    orchestrator_log.stage_timings[PipelineStage.VALIDATION] = phase_1_elapsed

    # ==========================================================================
    # PHASE 2: DATA CLEANSING AND NORMALIZATION (TASKS 4-6)
    # ==========================================================================

    # Record phase start time
    phase_2_start = time.perf_counter()

    # -------------------------------------------------------------------------
    # Task 4 Orchestrator: Cleanse and handle missing entries in TAT-QA
    # -------------------------------------------------------------------------

    # Invoke Task 4 orchestrator to drop/repair malformed rows in TAT-QA
    tatqa_clean_df, tatqa_cleansing_log = cleanse_tatqa_dataset(
        tatqa_raw_df=tatqa_raw_df
    )

    # Record number of rows dropped for audit trail
    rows_dropped_tatqa: int = tatqa_cleansing_log.total_rows_dropped

    # Log TAT-QA cleansing results
    logger.info(f"TAT-QA cleansed: {rows_dropped_tatqa} rows dropped, {len(tatqa_clean_df)} remaining")

    # -------------------------------------------------------------------------
    # Task 5 Orchestrator: Cleanse and handle missing entries in Fin-QA
    # -------------------------------------------------------------------------

    # Invoke Task 5 orchestrator to cleanse Fin-QA dataset
    finqa_clean_df, finqa_cleansing_log = cleanse_finqa_dataset(
        finqa_raw_df=finqa_raw_df
    )

    # Record number of rows dropped for Fin-QA
    rows_dropped_finqa: int = finqa_cleansing_log.total_rows_dropped

    # Log Fin-QA cleansing results
    logger.info(f"Fin-QA cleansed: {rows_dropped_finqa} rows dropped, {len(finqa_clean_df)} remaining")

    # -------------------------------------------------------------------------
    # Task 6 Orchestrator: Normalize numeric columns and answer representations
    # -------------------------------------------------------------------------

    # Invoke Task 6 orchestrator to identify numeric columns and parse answers
    tatqa_normalized_df, finqa_normalized_df, numeric_metadata = normalize_data_task(
        tatqa_df=tatqa_clean_df,
        finqa_df=finqa_clean_df,
        output_dir=f"{output_dir}/processed_data"
    )

    # Log normalization results
    numeric_col_count = sum(1 for v in numeric_metadata.values() if v)
    logger.info(f"Normalization complete: {numeric_col_count} numeric columns identified")

    # Record Phase 2 elapsed time
    phase_2_elapsed = time.perf_counter() - phase_2_start
    orchestrator_log.stage_timings[PipelineStage.CLEANSING] = phase_2_elapsed

    # Record rows processed for audit trail
    orchestrator_log.rows_processed = {
        "TAT-QA": len(tatqa_normalized_df),
        "Fin-QA": len(finqa_normalized_df)
    }

    # ==========================================================================
    # PHASE 3: OFFLINE CORPUS AND STATIC SELF-INFORMATION (TASKS 7-10)
    # ==========================================================================

    # Record phase start time
    phase_3_start = time.perf_counter()

    # Define paths for corpus artifacts
    corpus_path = f"{output_dir}/offline_corpus.jsonl"
    corpus_stats_dir = f"{output_dir}/corpus_stats"

    # -------------------------------------------------------------------------
    # Task 7 Orchestrator: Build offline corpus (conditional)
    # -------------------------------------------------------------------------

    if not skip_corpus_build:
        # Invoke Task 7 orchestrator to collect and normalize corpus documents
        corpus_path = build_offline_corpus(
            config=resolved_config,
            output_path=corpus_path
        )

        # Log corpus construction completion
        logger.info(f"Offline corpus built: {corpus_path}")

        # ---------------------------------------------------------------------
        # Task 8 Orchestrator: Tokenize corpus and compute token counts
        # ---------------------------------------------------------------------

        # Invoke Task 8 orchestrator to tokenize and count tokens
        token_counts, total_tokens_N = compute_corpus_statistics(
            config=resolved_config,
            corpus_path=corpus_path,
            output_dir=corpus_stats_dir
        )

        # Log tokenization results
        logger.info(f"Corpus tokenized: {total_tokens_N:,} total tokens, {len(token_counts):,} unique")

        # ---------------------------------------------------------------------
        # Task 9 Orchestrator: Compute token frequencies and probabilities
        # ---------------------------------------------------------------------

        # Invoke Task 9 orchestrator to convert counts to probabilities
        token_probabilities: Dict[int, float] = compute_token_probabilities(
            token_counts=token_counts,
            total_tokens_N=total_tokens_N
        )

        # Log probability computation
        logger.info("Token probabilities computed and validated")

        # ---------------------------------------------------------------------
        # Task 10 Orchestrator: Compute static self-information scores
        # ---------------------------------------------------------------------

        # Invoke Task 10 orchestrator to compute I(T) = -log2(p(T))
        s_stat_lookup: Dict[int, float] = compute_static_self_information(
            p=token_probabilities,
            output_dir=corpus_stats_dir,
            extra_metadata={"corpus_path": corpus_path, "total_tokens": total_tokens_N}
        )

        # Log static self-information computation
        logger.info(f"Static self-information computed: {len(s_stat_lookup):,} tokens")

    else:
        # Load existing corpus statistics from disk
        import json

        # Load pre-computed token counts
        with open(f"{corpus_stats_dir}/token_statistics.json", "r") as f:
            stats = json.load(f)
            token_counts = Counter({int(k): v for k, v in stats["counts"].items()})
            total_tokens_N = stats["total_tokens"]

        # Recompute probabilities from loaded counts
        token_probabilities = compute_token_probabilities(
            token_counts=token_counts,
            total_tokens_N=total_tokens_N
        )

        # Load pre-computed static self-information
        with open(f"{corpus_stats_dir}/static_self_information.json", "r") as f:
            s_stat_data = json.load(f)
            s_stat_lookup = {int(k): v for k, v in s_stat_data["scores"].items()}

        # Log that existing stats were loaded
        logger.info(f"Loaded existing corpus statistics from {corpus_stats_dir}")

    # Record Phase 3 elapsed time
    phase_3_elapsed = time.perf_counter() - phase_3_start
    orchestrator_log.stage_timings[PipelineStage.CORPUS_STATS] = phase_3_elapsed

    # ==========================================================================
    # PHASE 4: PROMPT SERIALIZATION AND TEMPLATE CONSTRUCTION (TASKS 11-13)
    # ==========================================================================

    # Record phase start time
    phase_4_start = time.perf_counter()

    # -------------------------------------------------------------------------
    # Task 11 Orchestrator: Serialize tables to text format
    # -------------------------------------------------------------------------

    # Invoke Task 11 orchestrator to convert table structures to Markdown
    tatqa_serialized_df, finqa_serialized_df = serialize_tables_task(
        tatqa_df=tatqa_normalized_df,
        finqa_df=finqa_normalized_df
    )

    # Log serialization completion
    logger.info("Tables serialized to Markdown format")

    # -------------------------------------------------------------------------
    # Tasks 12-13 are invoked per example within the dynamic scoring phase
    # -------------------------------------------------------------------------

    # Record Phase 4 elapsed time
    phase_4_elapsed = time.perf_counter() - phase_4_start
    orchestrator_log.stage_timings[PipelineStage.SERIALIZATION] = phase_4_elapsed

    # ==========================================================================
    # PHASE 5: LLM RESOURCE CONFIGURATION (TASK 14)
    # ==========================================================================

    # Record phase start time
    phase_5_start = time.perf_counter()

    # -------------------------------------------------------------------------
    # Task 14 Orchestrator: Configure scorer and target LLM interfaces
    # -------------------------------------------------------------------------

    # Invoke Task 14 orchestrator to instantiate all LLM interfaces
    llm_interfaces: Dict[str, LLMInterface] = configure_llm_resources(
        study_config=resolved_config
    )

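    # Fail fast with a clear message if a requested model was not configured
    for model_id in (scorer_llm, target_llm):
        if model_id not in llm_interfaces:
            raise ValueError(
                f"LLM '{model_id}' is not among the configured interfaces: {sorted(llm_interfaces)}"
            )

    # Extract scorer interface for dynamic self-information computation
    scorer_interface: LLMInterface = llm_interfaces[scorer_llm]

    # Extract target interface for downstream QA generation
    target_interface: LLMInterface = llm_interfaces[target_llm]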

    # Log LLM configuration
    logger.info(f"LLM interfaces configured: scorer={scorer_llm}, target={target_llm}")

    # Record Phase 5 elapsed time
    phase_5_elapsed = time.perf_counter() - phase_5_start
    orchestrator_log.stage_timings[PipelineStage.LLM_CONFIG] = phase_5_elapsed

    # ==========================================================================
    # PHASE 6: DYNAMIC SCORING AND COMBINED SCORES (TASKS 15-18)
    # ==========================================================================

    # Record phase start time
    phase_6_start = time.perf_counter()

    # -------------------------------------------------------------------------
    # Task 15 Orchestrator: Prepare dynamic scoring inputs
    # -------------------------------------------------------------------------

    # Invoke Task 15 orchestrator to serialize prompts and tokenize with offsets
    tatqa_dynamic_inputs: Dict[str, Dict[str, Any]] = prepare_dynamic_scoring_inputs(
        df=tatqa_serialized_df,
        dataset_name="TAT-QA",
        scorer_interface=scorer_interface,
        exemplars_str=""
    )

    # Prepare Fin-QA dynamic scoring inputs
    finqa_dynamic_inputs: Dict[str, Dict[str, Any]] = prepare_dynamic_scoring_inputs(
        df=finqa_serialized_df,
        dataset_name="Fin-QA",
        scorer_interface=scorer_interface,
        exemplars_str=""
    )

    # Merge inputs from both datasets for unified processing
    all_dynamic_inputs: Dict[str, Dict[str, Any]] = {
        **tatqa_dynamic_inputs,
        **finqa_dynamic_inputs
    }

    # Log dynamic input preparation
    logger.info(f"Dynamic scoring inputs prepared: {len(all_dynamic_inputs)} examples")

    # -------------------------------------------------------------------------
    # Task 16 Orchestrator: Query LLM for conditional log-probabilities
    # -------------------------------------------------------------------------

    # Invoke Task 16 orchestrator to obtain P_model(t_i | c_i) for each token
    logprobs_map: Dict[str, List[float]] = get_prompt_logprobs_task(
        dynamic_inputs=all_dynamic_inputs,
        scorer_interface=scorer_interface
    )

    # Log logprob retrieval completion
    logger.info(f"Log-probabilities retrieved for {len(logprobs_map)} prompts")

    # -------------------------------------------------------------------------
    # Task 17 Orchestrator: Compute dynamic self-information
    # -------------------------------------------------------------------------

    # Invoke Task 17 orchestrator to convert logprobs to self-information bits
    s_dyn_map: Dict[str, List[float]] = compute_dynamic_scores_task(
        logprobs_map=logprobs_map
    )

    # Log dynamic score computation
    logger.info("Dynamic self-information scores computed")

    # -------------------------------------------------------------------------
    # Task 18 Orchestrator: Compute combined importance scores
    # -------------------------------------------------------------------------

    # Invoke Task 18 orchestrator to fuse static and dynamic scores
    combined_scores_map: Dict[str, List[float]] = compute_combined_scores_task(
        dynamic_inputs=all_dynamic_inputs,
        s_dyn_map=s_dyn_map,
        s_stat_lookup=s_stat_lookup,
        config=resolved_config
    )

    # Log combined score computation
    logger.info("Combined importance scores computed using relative-difference rule")

    # Record Phase 6 elapsed time
    phase_6_elapsed = time.perf_counter() - phase_6_start
    orchestrator_log.stage_timings[PipelineStage.DYNAMIC_SCORING] = phase_6_elapsed

    # ==========================================================================
    # PHASE 7: PHRASE-LEVEL PRUNING (TASKS 19-21) - CONDITIONAL
    # ==========================================================================

    # Record phase start time
    phase_7_start = time.perf_counter()

    # Initialize pruned_results; will be populated based on condition
    pruned_results: Dict[str, Dict[str, Any]] = {}

    if pipeline_config.use_hard_pruning:
        # ---------------------------------------------------------------------
        # Task 19 Orchestrator: Group tokens into phrases using dependency parsing
        # ---------------------------------------------------------------------

        # Invoke Task 19 orchestrator to run spaCy and map tokens to phrases
        phrases_map: Dict[str, List[List[int]]] = group_tokens_into_phrases_task(
            dynamic_inputs=all_dynamic_inputs,
            config=resolved_config
        )

        # Log phrase grouping completion
        logger.info(f"Tokens grouped into phrases for {len(phrases_map)} examples")

        # ---------------------------------------------------------------------
        # Task 20 Orchestrator: Compute phrase-level importance scores
        # ---------------------------------------------------------------------

        # Invoke Task 20 orchestrator to aggregate token scores within phrases
        phrase_scores_data: Dict[str, Dict[str, List[Any]]] = compute_phrase_scores_task(
            combined_scores_map=combined_scores_map,
            phrases_map=phrases_map,
            config=resolved_config
        )

        # Log phrase score computation
        logger.info("Phrase-level importance scores computed")

        # ---------------------------------------------------------------------
        # Task 21 Orchestrator: Prune phrases to enforce token budget
        # ---------------------------------------------------------------------

        # Invoke Task 21 orchestrator to select high-importance phrases under budget
        pruned_results = prune_prompt_task(
            dynamic_inputs=all_dynamic_inputs,
            phrase_scores_data=phrase_scores_data,
            phrases_map=phrases_map,
            config=resolved_config
        )

        # Log pruning completion
        mean_cr = np.mean([r["compression_ratio"] for r in pruned_results.values()])
        logger.info(f"Prompts pruned: mean compression ratio = {mean_cr:.2f}x")

    else:
        # No hard pruning; use original prompts as compressed prompts
        for example_id, inputs in all_dynamic_inputs.items():
            # Use the stored token count; fall back to len(token_ids) only if it is absent
            token_count = inputs["token_count"] if "token_count" in inputs else len(inputs["token_ids"])
            pruned_results[example_id] = {
                "compressed_prompt": inputs["prompt_text"],
                "original_tokens": token_count,
                "compressed_tokens": token_count,
                "compression_ratio": 1.0
            }

        # Log that no pruning was applied
        logger.info("Hard pruning skipped (baseline condition)")

    # Record Phase 7 elapsed time
    phase_7_elapsed = time.perf_counter() - phase_7_start
    orchestrator_log.stage_timings[PipelineStage.HARD_PRUNING] = phase_7_elapsed

    # ==========================================================================
    # PHASE 8: N-GRAM ABBREVIATION (TASKS 22-24) - CONDITIONAL
    # ==========================================================================

    # Record phase start time
    phase_8_start = time.perf_counter()

    # Initialize abbreviation artifacts
    abbrev_dict: Optional[Dict[str, Any]] = None
    tatqa_abbreviated_df = tatqa_serialized_df.copy()
    finqa_abbreviated_df = finqa_serialized_df.copy()

    if pipeline_config.use_ngram_abbreviation:
        # ---------------------------------------------------------------------
        # Task 22 Orchestrator: Extract n-grams and compute frequencies
        # ---------------------------------------------------------------------

        # Invoke Task 22 orchestrator to extract top-K n-grams from passages
        top_k_ngrams: List[Tuple[Tuple[int, ...], int]] = extract_top_ngrams_task(
            tatqa_df=tatqa_serialized_df,
            finqa_df=finqa_serialized_df,
            config=resolved_config
        )

        # Log n-gram extraction
        logger.info(f"Extracted {len(top_k_ngrams)} top n-grams for abbreviation")

        # ---------------------------------------------------------------------
        # Task 23 Orchestrator: Construct abbreviation dictionary
        # ---------------------------------------------------------------------

        # Invoke Task 23 orchestrator to build bidirectional placeholder mapping
        abbrev_dict = construct_abbreviation_dict_task(
            top_k_ngrams=top_k_ngrams
        )

        # Log dictionary construction
        logger.info(f"Abbreviation dictionary constructed with {len(abbrev_dict['ngram_to_placeholder'])} entries")

        # ---------------------------------------------------------------------
        # Task 24 Orchestrator: Apply abbreviation to passages
        # ---------------------------------------------------------------------

        # Invoke Task 24 orchestrator to replace top-T n-grams with placeholders
        tatqa_abbreviated_df, finqa_abbreviated_df = apply_abbreviation_task(
            tatqa_df=tatqa_serialized_df,
            finqa_df=finqa_serialized_df,
            abbrev_dict=abbrev_dict,
            config=resolved_config
        )

        # Log abbreviation application
        logger.info("N-gram abbreviation applied to all passages")

    else:
        # Log that abbreviation was skipped
        logger.info("N-gram abbreviation skipped (condition does not require it)")

    # Record Phase 8 elapsed time
    phase_8_elapsed = time.perf_counter() - phase_8_start
    orchestrator_log.stage_timings[PipelineStage.NGRAM_ABBREVIATION] = phase_8_elapsed

    # ==========================================================================
    # PHASE 9: NUMERIC QUANTIZATION (TASKS 25-27) - CONDITIONAL
    # ==========================================================================

    # Record phase start time
    phase_9_start = time.perf_counter()

    # Initialize quantization results
    quantization_results: Optional[Dict[Tuple[str, str, str, int], Dict[str, Any]]] = None

    if pipeline_config.use_quantization:
        # ---------------------------------------------------------------------
        # Task 25 Orchestrator: Extract numeric values for quantization
        # ---------------------------------------------------------------------

        # Invoke Task 25 orchestrator to extract float arrays from numeric columns
        extracted_values: Dict[Tuple[str, str, str, int], Dict[str, Any]] = extract_numeric_values_task(
            tatqa_df=tatqa_abbreviated_df,
            finqa_df=finqa_abbreviated_df,
            numeric_metadata=numeric_metadata
        )

        # Log numeric extraction
        logger.info(f"Extracted numeric values from {len(extracted_values)} columns")

        # ---------------------------------------------------------------------
        # Task 26 or 27 Orchestrator: Apply quantization based on scheme
        # ---------------------------------------------------------------------

        if pipeline_config.quantization_scheme == "uniform":
            # Invoke Task 26 orchestrator for uniform integer quantization
            quantization_results = apply_uniform_quantization_task(
                extracted_values=extracted_values,
                config=resolved_config
            )

            # Log uniform quantization
            logger.info("Uniform integer quantization applied")

        else:
            # Invoke Task 27 orchestrator for k-means quantization (any non-"uniform" scheme)
            quantization_results = apply_kmeans_quantization_task(
                extracted_values=extracted_values,
                config=resolved_config
            )

            # Log k-means quantization
            logger.info("K-means quantization applied")

    else:
        # Log that quantization was skipped
        logger.info("Numeric quantization skipped (use_quantization=False)")

    # Record Phase 9 elapsed time
    phase_9_elapsed = time.perf_counter() - phase_9_start
    orchestrator_log.stage_timings[PipelineStage.QUANTIZATION] = phase_9_elapsed

    # ==========================================================================
    # PHASE 10: EMBEDDING AND REPRESENTATIVE EXEMPLAR SELECTION (TASKS 28-30)
    # ==========================================================================

    # Record phase start time
    phase_10_start = time.perf_counter()

    # -------------------------------------------------------------------------
    # Task 28 Orchestrator: Embed candidate examples for few-shot selection
    # -------------------------------------------------------------------------

    # Invoke Task 28 orchestrator for TAT-QA embeddings
    tatqa_embeddings, tatqa_example_ids = embed_examples_task(
        df=tatqa_abbreviated_df,
        dataset_name="TAT-QA",
        config=resolved_config,
        output_dir=f"{output_dir}/embeddings"
    )

    # Invoke Task 28 orchestrator for Fin-QA embeddings
    finqa_embeddings, finqa_example_ids = embed_examples_task(
        df=finqa_abbreviated_df,
        dataset_name="Fin-QA",
        config=resolved_config,
        output_dir=f"{output_dir}/embeddings"
    )

    # Log embedding computation
    logger.info(f"Embeddings computed: TAT-QA={tatqa_embeddings.shape}, Fin-QA={finqa_embeddings.shape}")

    # -------------------------------------------------------------------------
    # Task 29 Orchestrator: Select optimal cluster count via silhouette score
    # -------------------------------------------------------------------------

    # Invoke Task 29 orchestrator for TAT-QA optimal k selection
    tatqa_k_star: int = select_optimal_k_task(
        embeddings=tatqa_embeddings,
        config=resolved_config
    )

    # Invoke Task 29 orchestrator for Fin-QA optimal k selection
    finqa_k_star: int = select_optimal_k_task(
        embeddings=finqa_embeddings,
        config=resolved_config
    )

    # Log optimal k selection
    logger.info(f"Optimal cluster counts: TAT-QA k*={tatqa_k_star}, Fin-QA k*={finqa_k_star}")

    # -------------------------------------------------------------------------
    # Task 30 Orchestrator: Select representative exemplars from clusters
    # -------------------------------------------------------------------------

    # Invoke Task 30 orchestrator for TAT-QA representative exemplars
    tatqa_representative_ids: List[str] = select_representative_exemplars_task(
        embeddings=tatqa_embeddings,
        example_ids=tatqa_example_ids,
        k_star=tatqa_k_star,
        config=resolved_config
    )

    # Invoke Task 30 orchestrator for Fin-QA representative exemplars
    finqa_representative_ids: List[str] = select_representative_exemplars_task(
        embeddings=finqa_embeddings,
        example_ids=finqa_example_ids,
        k_star=finqa_k_star,
        config=resolved_config
    )

    # Log representative exemplar selection
    logger.info(f"Representative exemplars selected: TAT-QA={tatqa_representative_ids}, Fin-QA={finqa_representative_ids}")

    # Record Phase 10 elapsed time
    phase_10_elapsed = time.perf_counter() - phase_10_start
    orchestrator_log.stage_timings[PipelineStage.EMBEDDING] = phase_10_elapsed

    # ==========================================================================
    # PHASE 11: SEMANTIC SIMILARITY EVALUATION (TASKS 31-32)
    # ==========================================================================

    # Record phase start time
    phase_11_start = time.perf_counter()

    # -------------------------------------------------------------------------
    # Task 31 Orchestrator: Compute embeddings for semantic similarity
    # -------------------------------------------------------------------------

    # Invoke Task 31 orchestrator to embed original and compressed prompt pairs
    U_orig, V_comp, similarity_example_ids = compute_similarity_embeddings_task(
        dynamic_inputs=all_dynamic_inputs,
        pruned_results=pruned_results,
        config=resolved_config,
        output_dir=f"{output_dir}/similarity_embeddings"
    )

    # Log similarity embedding computation
    logger.info(f"Similarity embeddings computed for {len(similarity_example_ids)} pairs")

    # -------------------------------------------------------------------------
    # Task 32 Orchestrator: Evaluate semantic similarity and compute statistics
    # -------------------------------------------------------------------------

    # Invoke Task 32 orchestrator to compute cosine similarities and statistics
    similarity_results: Dict[str, Any] = evaluate_semantic_similarity_task(
        U=U_orig,
        V=V_comp,
        example_ids=similarity_example_ids,
        config=resolved_config
    )

    # Log similarity evaluation results
    logger.info(
        f"Semantic similarity: mean={similarity_results['mean']:.3f}, "
        f"5th percentile={similarity_results['percentile_5']:.3f}, "
        f"below threshold={similarity_results['below_threshold_count']}"
    )

    # Record Phase 11 elapsed time
    phase_11_elapsed = time.perf_counter() - phase_11_start
    orchestrator_log.stage_timings[PipelineStage.SIMILARITY_EVAL] = phase_11_elapsed

    # ==========================================================================
    # PHASE 12: HUMAN EVALUATION PROTOCOL (TASKS 33-35) - CONDITIONAL
    # ==========================================================================

    # Record phase start time
    phase_12_start = time.perf_counter()

    # Initialize human evaluation results
    human_evaluation: Optional[Dict[str, Any]] = None

    if not skip_human_eval:
        # ---------------------------------------------------------------------
        # Construct evaluation pairs from dynamic inputs and pruned results
        # ---------------------------------------------------------------------

        # Build TAT-QA pairs for evaluation
        tatqa_pairs: List[Dict[str, str]] = [
            {
                "example_id": eid,
                "original": all_dynamic_inputs[eid]["prompt_text"],
                "compressed": pruned_results[eid]["compressed_prompt"]
            }
            for eid in tatqa_example_ids if eid in pruned_results
        ]

        # Build Fin-QA pairs for evaluation
        finqa_pairs: List[Dict[str, str]] = [
            {
                "example_id": eid,
                "original": all_dynamic_inputs[eid]["prompt_text"],
                "compressed": pruned_results[eid]["compressed_prompt"]
            }
            for eid in finqa_example_ids if eid in pruned_results
        ]

        # ---------------------------------------------------------------------
        # Task 33 Orchestrator: Run calibration with LLM-based proxy annotators
        # ---------------------------------------------------------------------

        # Invoke Task 33 orchestrator to calibrate proxy annotators on the first 15 pairs per dataset
        calibration_results: Dict[str, Any] = run_llm_calibration_task(
            tatqa_pairs=tatqa_pairs[:15],
            finqa_pairs=finqa_pairs[:15]
        )

        # Log calibration completion
        logger.info("Proxy annotator calibration complete")

        # ---------------------------------------------------------------------
        # Task 34 Orchestrator: Conduct main evaluation on 90 pairs
        # ---------------------------------------------------------------------

        # Extract calibrated agents from calibration results
        calibrated_agents: Dict[str, LLMInterface] = calibration_results["agents"]

        # Invoke Task 34 orchestrator to collect ratings
        raw_ratings: List[Dict[str, Any]] = run_main_evaluation_task(
            tatqa_pairs=tatqa_pairs,
            finqa_pairs=finqa_pairs,
            agents=calibrated_agents,
            output_dir=f"{output_dir}/human_eval_results"
        )

        # Log main evaluation completion
        logger.info(f"Main evaluation complete: {len(raw_ratings)} ratings collected")

        # ---------------------------------------------------------------------
        # Task 35 Orchestrator: Analyze human ratings vs embedding similarity
        # ---------------------------------------------------------------------

        # Invoke Task 35 orchestrator to compare ratings to cosine similarities
        human_embedding_analysis: Dict[str, Any] = analyze_human_vs_embedding_task(
            raw_ratings=raw_ratings,
            similarity_results=similarity_results
        )

        # Log analysis completion
        logger.info(
            f"Human-embedding analysis: mean rating={human_embedding_analysis['mean_human_rating']:.2f}/5, "
            f"mismatch fraction={human_embedding_analysis['mismatch_fraction']:.2%}"
        )

        # Package human evaluation results
        human_evaluation = {
            "calibration": calibration_results,
            "ratings": raw_ratings,
            "analysis": human_embedding_analysis
        }

    else:
        # Log that human evaluation was skipped
        logger.info("Human evaluation skipped (skip_human_eval=True)")

    # Record Phase 12 elapsed time
    phase_12_elapsed = time.perf_counter() - phase_12_start
    orchestrator_log.stage_timings[PipelineStage.HUMAN_EVAL] = phase_12_elapsed

    # ==========================================================================
    # FINAL ASSEMBLY: CONSTRUCT OUTPUT AND FINALIZE LOG
    # ==========================================================================

    # Compute aggregate metrics for output
    mean_compression_ratio = np.mean([
        r["compression_ratio"] for r in pruned_results.values()
    ])

    # Compute mean cosine similarity from similarity results
    mean_cosine_similarity = similarity_results["mean"]

    # Compute mean human rating if available
    mean_human_rating = (
        human_evaluation["analysis"]["mean_human_rating"]
        if human_evaluation is not None
        else None
    )

    # Construct aggregate metrics dictionary
    metrics: Dict[str, float] = {
        "mean_compression_ratio": float(mean_compression_ratio),
        "mean_cosine_similarity": float(mean_cosine_similarity),
        "percentile_5_similarity": float(similarity_results["percentile_5"]),
        "below_threshold_fraction": float(
            # Guard against division by zero when no similarity pairs were produced
            similarity_results["below_threshold_count"] / max(len(similarity_example_ids), 1)
        )
    }

    # Add human evaluation metric if available
    if mean_human_rating is not None:
        metrics["mean_human_rating"] = float(mean_human_rating)

    # Construct final OrchestratorOutput dataclass
    orchestrator_output = OrchestratorOutput(
        tatqa_processed_df=tatqa_abbreviated_df,
        finqa_processed_df=finqa_abbreviated_df,
        pruned_prompts=pruned_results,
        abbreviation_dict=abbrev_dict,
        quantization_results=quantization_results,
        representative_exemplars={
            "TAT-QA": tatqa_representative_ids,
            "Fin-QA": finqa_representative_ids
        },
        similarity_results=similarity_results,
        human_evaluation=human_evaluation,
        metrics=metrics,
        s_stat_lookup=s_stat_lookup,
        validation_reports={
            "TAT-QA": tatqa_validation_report,
            "Fin-QA": finqa_validation_report
        },
        cleansing_logs={
            "TAT-QA": tatqa_cleansing_log,
            "Fin-QA": finqa_cleansing_log
        }
    )

    # Compute output hash for verification
    output_repr = f"{mean_compression_ratio}|{mean_cosine_similarity}|{len(pruned_results)}"
    output_hash = hashlib.sha256(output_repr.encode()).hexdigest()[:16]

    # Finalize orchestrator log
    orchestrator_log.output_hash = output_hash
    orchestrator_log.status = "SUCCESS"
    orchestrator_log.compression_summary = {
        "mean_compression_ratio": float(mean_compression_ratio),
        "total_examples_processed": len(pruned_results),
        "tatqa_examples": len(tatqa_example_ids),
        "finqa_examples": len(finqa_example_ids)
    }

    # Compute total pipeline elapsed time as the sum of recorded stage timings
    total_elapsed = sum(orchestrator_log.stage_timings.values())

    # Log pipeline completion
    logger.info(
        f"Pipeline complete: {orchestrator_log.status}, "
        f"total time={total_elapsed:.1f}s, "
        f"compression ratio={mean_compression_ratio:.2f}x"
    )

    # Return output and log as tuple
    return orchestrator_output, orchestrator_log
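
# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helpers, NOT part of the pipeline above).
#
# The orchestrator delegates two numeric steps to Task 32 (semantic similarity
# statistics) and Task 26 (uniform integer quantization). The NumPy-only
# functions below are a minimal sketch of what those steps plausibly compute,
# under stated assumptions:
#   * `_sketch_similarity_summary` assumes row i of U and V embeds the original
#     and compressed prompt of the same example, and returns the keys the
#     orchestrator reads ('mean', 'percentile_5', 'below_threshold_count').
#     The threshold value is an assumption and must be supplied by the caller.
#   * `_sketch_uniform_quantize` assumes min-max scaling to unsigned b-bit codes,
#     q = round((x - x_min) / (x_max - x_min) * (2**bits - 1)).
# Both names and the `bits` parameter are hypothetical; the real task
# implementations may differ.
# ---------------------------------------------------------------------------
import numpy as np


def _sketch_similarity_summary(U: np.ndarray, V: np.ndarray, threshold: float) -> dict:
    """Hypothetical: row-wise cosine similarity between paired embeddings."""
    # Normalize each row to unit length; clip norms to avoid dividing by zero
    U_unit = U / np.clip(np.linalg.norm(U, axis=1, keepdims=True), 1e-12, None)
    V_unit = V / np.clip(np.linalg.norm(V, axis=1, keepdims=True), 1e-12, None)
    # Cosine similarity of each (original, compressed) pair is a row-wise dot product
    sims = np.sum(U_unit * V_unit, axis=1)
    return {
        "similarities": sims,
        "mean": float(np.mean(sims)),
        "percentile_5": float(np.percentile(sims, 5)),
        "below_threshold_count": int(np.sum(sims < threshold)),
    }


def _sketch_uniform_quantize(values: np.ndarray, bits: int):
    """Hypothetical: min-max uniform quantization to b-bit integer codes.

    Returns (codes, dequantized); reconstruction error is at most half a step.
    """
    x_min, x_max = float(np.min(values)), float(np.max(values))
    levels = 2 ** bits - 1
    # A constant column gets step 1.0, so every value maps to code 0 and
    # reconstructs exactly to x_min
    step = (x_max - x_min) / levels if x_max > x_min else 1.0
    codes = np.round((values - x_min) / step).astype(np.int64)
    dequantized = x_min + codes * step
    return codes, dequantized


# Example usage (arbitrary values):
#   codes, approx = _sketch_uniform_quantize(np.array([0.0, 2.5, 10.0]), bits=2)
#   stats = _sketch_similarity_summary(U_orig, V_comp, threshold=0.8)
# Min-max scaling is the simplest uniform scheme; it is sensitive to outliers,
# which is one motivation for the k-means alternative used in Task 27.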
