# Retrieval Benchmark
Evaluate Recall@8 on 200 dev samples.

In [2]:
try:
    import finrag, jsonlines, openai
except ImportError as e:
    raise RuntimeError(
        "Missing dependencies – activate .venv‑notebook and run pip install -r requirements-notebook.txt"
    ) from e

In [3]:
import json
import random
import sys, os
sys.path.insert(0, os.path.abspath('src'))
from finrag.chunk_utils import build_candidate_chunks
from finrag.retriever import retrieve_evidence
from finrag.embeddings import EmbeddingStore
import matplotlib.pyplot as plt
import seaborn as sns

In [8]:
# Load 200 random samples (only those with QA)
with open('data/dev_turn.json') as f:
    raw = json.load(f)
# drop any entry missing the 'qa' key
data = [entry for entry in raw if 'qa' in entry]

random.seed(42)
samples = random.sample(data, 200)

In [9]:
# Run retrieval and compute recall per sample
hits = []
recalls = []
failed_ids = []
for sample in samples:
    # gold chunk_ids
    gold = ['table:row:' + str(i) for i in sample['qa'].get('ann_table_rows', [])] + ['text:pre:' + str(i) for i in sample['qa'].get('ann_text_rows', [])]
    # retrieve top 8
    preds = retrieve_evidence(sample, sample['qa']['question'], top_k=8)
    # recall@8 per sample
    if gold:
        inter = set(gold) & set(preds)
        recall_i = len(inter)/len(gold)
        recalls.append(recall_i)
        hits.append(1 if inter else 0)
        if recall_i == 0:
            failed_ids.append(sample['id'])

TooManyRequestsError: status_code: 429, body: {'message': "You are using a Trial key, which is limited to 10 API calls / minute. You can continue to use the Trial key for free or upgrade to a Production key with higher rate limits at 'https://dashboard.cohere.com/api-keys'. Contact us on 'https://discord.gg/XW44jPfYJu' or email us at support@cohere.com with any questions"}

In [None]:
# Total Recall@8
print(f'Recall@8: {sum(hits)}/{len(hits)} = {sum(hits)/len(hits):.3f}')

In [None]:
# Histogram of per-sample recall
sns.histplot(recalls, bins=10)
plt.xlabel('Recall per sample')
plt.ylabel('Count')
plt.title('Recall@8 Distribution')
plt.show()

In [None]:
# List failed sample IDs
print('Failed samples:', failed_ids)