In [None]:
# This notebook can be used to generate results for HousingStatutesQA in the project report
# Luka Rozgic

In [None]:
# ---------------------------------------------------------------------------------------------------------------------------------
# Imports
# ---------------------------------------------------------------------------------------------------------------------------------

import gc
import json
import os
import pickle
import random
import re
import time
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed

import boto3
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.nn import DataParallel
from tqdm import tqdm

from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
from vllm import LLM, SamplingParams

In [None]:
# ---------------------------------------------------------------------------------------------------------------------------------
# Set up paths
# ---------------------------------------------------------------------------------------------------------------------------------
# The notebook lives in <project root>/code, so the project root is the parent of the working directory
# (assumes the kernel's working directory is the notebook's own directory, Jupyter's usual default).
parent_path = os.path.dirname(os.getcwd())
code_path, data_path = (os.path.join(parent_path, sub_dir) for sub_dir in ("code", "data"))
# Holds one saved FAISS store per embedding "model id" (see create_vectorstore).
vectorstore_path = os.path.join(data_path, 'statute_vectorstore')

In [None]:
# ---------------------------------------------------------------------------------------------------------------------------------
# Global Variables
# ---------------------------------------------------------------------------------------------------------------------------------

# Embedding models: Hugging Face name, max input sequence length (tokens, assigned to model.max_seq_length),
# the short id used for file/dict keys, and the batch size for batch encoding (tuned for a 24 GB GPU).
embedding_models = [
    {"model name": name, "max sentence length": max_len, "model id": model_id, "batch size": batch}
    for name, max_len, model_id, batch in (
        ("intfloat/e5-base-v2",       512,  "e5_base_v2",  16),
        ("intfloat/e5-large-v2",      512,  "e5_large_v2", 16),
        ("Qwen/Qwen3-Embedding-0.6B", 8192, "qwen3_0p6b",  8),
    )
]

# Retrieval query variants: 'q' = question only; '<llm>' = LLM-generated rule expansion only;
# 'q_<llm>' = question followed by that expansion (Claude 3 Sonnet and Qwen 2.5 7B expansions).
retrieval_query_types = ['q', 'claude3sonnet', 'q_claude3sonnet', 'qwen2p57b', 'q_qwen2p57b']

In [None]:
# ---------------------------------------------------------------------------------------------------------------------------------
# Embedding wrapper for FAISS that L2-normalizes query and document vectors so retrieval ranks by cosine similarity
# ---------------------------------------------------------------------------------------------------------------------------------
class CosineQueryEmbeddings:
    """LangChain-compatible embedding wrapper that returns unit-length (L2-normalized) vectors.

    The statute index is built from embeddings encoded with ``normalize_embeddings=True``
    (see ``create_vectorstore``). Queries must be unit-length too, so that FAISS distances
    order results the same way cosine similarity would.

    Parameters
    ----------
    model : object exposing ``encode(texts, **kwargs)`` (e.g. a ``SentenceTransformer``).
    """

    def __init__(self, model):
        self.model = model

    def embed_query(self, text):
        """Embed one query string and return it as a unit-length list of floats.

        The norm is computed in float32, even for fp16 models, to avoid half-precision
        round-off. A zero vector is returned unchanged instead of being divided by zero,
        which would yield NaNs that silently corrupt every FAISS distance.
        """
        embedding = np.asarray(self.model.encode([text])[0], dtype=np.float32)
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
        return embedding.tolist()

    def embed_documents(self, texts):
        """Embed a list of documents; the model performs the L2 normalization itself."""
        return self.model.encode(texts, normalize_embeddings=True).tolist()

    def __call__(self, text):
        """Allow the wrapper to be used wherever a plain ``text -> vector`` callable is expected."""
        return self.embed_query(text)

In [None]:
# ---------------------------------------------------------------------------------------------------------------------------------
# Functions
# ---------------------------------------------------------------------------------------------------------------------------------

# ---------------------------------------------------------------------------------------------------------------------------------
# Create FAISS vector store from documents
# ---------------------------------------------------------------------------------------------------------------------------------
def create_vectorstore(embedding_models, statutes_list, to_save = True, use_GPUs = 4):
    """Embed every statute with each embedding model and build one FAISS vector store per model.

    Parameters
    ----------
    embedding_models : list of dict
        Model specs with keys "model name", "max sentence length", "model id" and "batch size"
        (see the global ``embedding_models``).
    statutes_list : list of dict
        Statutes with keys 'idx', 'text', 'citation' and 'state'.
    to_save : bool
        If True, each store is written to ``<vectorstore_path>/<model id>``.
    use_GPUs : int
        Number of CUDA devices (cuda:0 ... cuda:use_GPUs-1) passed to ``encode`` for multi-GPU encoding.

    Returns
    -------
    None. Stores are only persisted to disk (when ``to_save``); nothing is kept in memory.
    """
    # Texts to embed and the metadata attached to each FAISS entry. State and citation are lower-cased
    # so the later state filtering can compare them case-insensitively.
    statute_texts = [statute['text'] for statute in statutes_list]
    statute_metadata = [
        {"id": statute['idx'], "citation": statute['citation'].lower(), "state": statute['state'].lower()}
        for statute in statutes_list
    ]
    devices = ['cuda:' + str(cuda_id) for cuda_id in range(use_GPUs)]

    for model_spec_dict in embedding_models:
        embedding_model_name = model_spec_dict["model name"]
        model = SentenceTransformer(embedding_model_name)
        model.max_seq_length = model_spec_dict["max sentence length"]
        # Run Qwen in fp16 to reduce its memory footprint and speed up encoding
        if model_spec_dict['model id'] == "qwen3_0p6b":
            model.half()

        # Encode statutes (multi-GPU)
        print('Embedding statutes using %s ... ' % (embedding_model_name))
        start_time = time.time()
        embeddings = model.encode(
            statute_texts,
            batch_size=model_spec_dict["batch size"],
            device=devices,
            show_progress_bar=True,
            normalize_embeddings=True
        )
        print('Statutes embedded using %s in %.2f seconds' % (embedding_model_name, time.time() - start_time))

        # Create vectorstore with the precomputed embeddings
        cosine_embeddings = CosineQueryEmbeddings(model)
        vectorstore = FAISS.from_embeddings(
            text_embeddings=list(zip(statute_texts, embeddings)),
            embedding=cosine_embeddings,
            metadatas=statute_metadata
        )
        if to_save:
            vectorstore.save_local(os.path.join(vectorstore_path, model_spec_dict["model id"]))

        # Release this model before the next one is loaded. The store holds it through cosine_embeddings,
        # so every reference must go. Collect garbage first so empty_cache() can return the freed CUDA blocks.
        del vectorstore, cosine_embeddings, embeddings, model
        gc.collect()
        torch.cuda.empty_cache()
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Single query expansion using Claude 3 Sonnet
# ---------------------------------------------------------------------------------------------------------------------------------
def call_claude_with_retry(question, state, max_retries=3, bedrock_client=None, region_name='us-east-1'):
    """Ask Claude 3 Sonnet (via AWS Bedrock) for the legal rule that governs a housing question.

    Parameters
    ----------
    question, state : str
        The question text and the state whose law applies; both are inserted into the prompt.
    max_retries : int
        Maximum number of attempts. Only throttling errors are retried, with exponential backoff plus jitter.
    bedrock_client : optional
        A ``bedrock-runtime`` client to reuse. When None, one is created from a fresh ``boto3.Session``.
    region_name : str
        AWS region used when a client has to be created.

    Returns
    -------
    str
        One of:
        - the generated rule text;
        - ``"Error: <message>"`` for a non-retryable failure;
        - ``"Failed after retries"`` if every attempt was throttled.
        Errors are returned as strings, never raised.
    """
    if bedrock_client is None:
        # This function runs inside ThreadPoolExecutor workers. boto3.client() builds clients from the shared
        # default session, which is not thread-safe; a private Session per call is.
        bedrock_client = boto3.Session().client('bedrock-runtime', region_name=region_name)

    prompt = f"""Consider the housing statute for {state} in the year 2021. The question given in "Question:" is a legal question about housing and eviction law in {state}. Provide applicable legal rule in "Rule:". If you do not know the state law, provide governing rules that address the question under typical eviction law.

Question: {question}

Rule:"""

    request_body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 500,
        "messages": [{"role": "user", "content": prompt}]
    })

    for attempt in range(max_retries):
        try:
            response = bedrock_client.invoke_model(
                modelId='anthropic.claude-3-sonnet-20240229-v1:0',
                body=request_body
            )
            result = json.loads(response['body'].read())
            return result['content'][0]['text']
        except Exception as e:
            if not _is_throttling_error(e):
                return f"Error: {str(e)}"
            # Back off before retrying; there is no point sleeping after the final attempt.
            if attempt < max_retries - 1:
                time.sleep((2 ** attempt) + random.uniform(0, 1))
    return "Failed after retries"


def _is_throttling_error(exc):
    """Return True when ``exc`` looks like a Bedrock throttling / rate-limit error (i.e. worth retrying)."""
    response = getattr(exc, 'response', None)
    # botocore's ClientError carries the service error code in exc.response['Error']['Code'].
    error_code = response.get('Error', {}).get('Code', '') if isinstance(response, dict) else ''
    if error_code in ('ThrottlingException', 'TooManyRequestsException'):
        return True
    # Fall back to the message, using specific phrases. A bare "rate" test also matches words such as
    # "generate" or "accurate" and would retry errors that can never succeed.
    message = str(exc).lower()
    return any(marker in message for marker in ('throttl', 'too many requests', 'rate exceeded', 'rate limit'))
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Multi-thread solution to make multiple inference calls to Claude 3 Sonnet and generate query expansions, progress tracking across batches
# ---------------------------------------------------------------------------------------------------------------------------------
def query_expansions_claude_with_batch_progress(questions, max_workers=5, to_save=True):
    """Generate a Claude 3 Sonnet rule expansion for every question, using a thread pool.

    Parameters
    ----------
    questions : list of dict
        Items with keys 'idx', 'question' and 'state'.
    max_workers : int
        Maximum number of concurrent Bedrock requests; this is the only throttle.
    to_save : bool
        If True, write the results to ``<data_path>/claude3sonnet_generated_expansions.json``. That is the
        file ``generate_retrieval_query_list`` reads for the 'claude3sonnet' query types.

    Returns
    -------
    list of dict
        One ``{'idx', 'question', 'state', 'rule'}`` record per question, in input order. Failures are
        recorded as ``rule = "Error: <message>"`` instead of being raised.
    """
    results = [None] * len(questions)
    completed_count = 0
    # Progress is reported per batch of 100 completed requests
    total_batches = (len(questions) + 99) // 100

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_position = {
            executor.submit(call_claude_with_retry, q['question'], q['state']): position
            for position, q in enumerate(questions)
        }

        with tqdm(total=total_batches, desc="Processing batches of 100") as pbar:
            # No sleep in this loop: every request is already queued on the executor, so pausing here never
            # throttled Bedrock. It only added 0.5 s of wall time per question.
            for future in as_completed(future_to_position):
                position = future_to_position[future]
                question = questions[position]
                try:
                    rule = future.result()
                except Exception as e:
                    rule = f"Error: {str(e)}"
                # Store by input position so the output order does not depend on completion order.
                results[position] = {
                    'idx': question['idx'],
                    'question': question['question'],
                    'state': question['state'],
                    'rule': rule
                }

                completed_count += 1
                if completed_count % 100 == 0:
                    pbar.update(1)
                    pbar.set_postfix(completed=completed_count)

            # Account for the final partial batch
            if completed_count % 100 != 0:
                pbar.update(1)
                pbar.set_postfix(completed=completed_count)

    if to_save:
        # Saved under data_path, where generate_retrieval_query_list reads it. (The undefined `store_dir` /
        # `structured_reasoning_results` names used before raised NameError.)
        with open(os.path.join(data_path, 'claude3sonnet_generated_expansions.json'), 'w') as f:
            json.dump(results, f, indent=2)

    return results
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Generate all query-expansion prompts up front, because Qwen 2.5 7B Instruct expansions are generated in batches for speed
# ---------------------------------------------------------------------------------------------------------------------------------
def generate_query_expansion_prompts(questions):
    """Build the rule-expansion prompt for every question ahead of batched vLLM inference.

    Parameters
    ----------
    questions : list of dict
        Items with keys 'idx', 'question' and 'state'.

    Returns
    -------
    list of dict
        ``{'idx', 'question', 'state', 'prompt'}`` per question, in input order. The prompt text is the
        same as the one ``call_claude_with_retry`` sends to Claude 3 Sonnet.
    """
    def build_prompt(state, question):
        return f"""Consider the housing statute for {state} in the year 2021. The question given in "Question:" is a legal question about housing and eviction law in {state}. Provide applicable legal rule in "Rule:". If you do not know the state law, provide governing rules that address the question under typical eviction law.

Question: {question}

Rule:"""

    return [
        {
            'idx': item['idx'],
            'question': item['question'],
            'state': item['state'],
            'prompt': build_prompt(item['state'], item['question']),
        }
        for item in questions
    ]
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Generate all query expansions using Qwen 2.5 7B Instruct and VLLM
# ---------------------------------------------------------------------------------------------------------------------------------
def generate_query_expansions(prompts_with_metadata, llm, sampling_params, batch_size=32, to_save=True):
    """Generate a rule expansion for every prompt with a vLLM model (Qwen 2.5 7B Instruct), in batches.

    Parameters
    ----------
    prompts_with_metadata : list of dict
        Output of ``generate_query_expansion_prompts``: items with 'idx', 'question', 'state' and 'prompt'.
    llm : vllm.LLM
        Model used for generation (anything with a compatible ``generate(prompts, sampling_params)``).
    sampling_params : vllm.SamplingParams
    batch_size : int
        Number of prompts passed to one ``llm.generate`` call.
    to_save : bool
        If True, write the results to ``<data_path>/qwen2p57b_generated_expansions.json``. That is the file
        read for the 'qwen2p57b' query types.

    Returns
    -------
    list of dict
        ``{'idx', 'question', 'state', 'rule'}`` per prompt, in input order. 'rule' is the text of the first
        completion.
    """
    results = []

    for start in tqdm(range(0, len(prompts_with_metadata), batch_size), desc="Processing batches"):
        batch = prompts_with_metadata[start:start + batch_size]
        outputs = llm.generate([item['prompt'] for item in batch], sampling_params)
        # vLLM returns outputs in prompt order, so they pair with the batch items positionally.
        for item, output in zip(batch, outputs):
            results.append({
                'idx': item['idx'],
                'question': item['question'],
                'state': item['state'],
                'rule': output.outputs[0].text
            })

    if to_save:
        # Dump the collected results (the undefined `structured_reasoning_results` raised NameError before).
        with open(os.path.join(data_path, 'qwen2p57b_generated_expansions.json'), 'w') as f:
            json.dump(results, f, indent=2)
    return results
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Generate all queries to be used for QA (question only, expansion-only and question+expansion)
# ---------------------------------------------------------------------------------------------------------------------------------
def generate_retrieval_query_list(retrieval_query_type, questions_list):
    """Return a copy of every question with an added 'query' string for the requested query type.

    Supported types:
      * 'q'                               -- the question text alone
      * 'claude3sonnet' / 'qwen2p57b'     -- the stored LLM rule expansion alone
      * 'q_claude3sonnet' / 'q_qwen2p57b' -- question + ' ' + the stored expansion

    Expansions are read from ``<data_path>/<source>_generated_expansions.json``. A question with no stored
    expansion gets '' in its place. The input dicts are not modified (shallow copies are returned).

    Raises
    ------
    ValueError
        For any other ``retrieval_query_type``.
    """
    if retrieval_query_type == 'q':
        return [{**q, 'query': q['question']} for q in questions_list]

    # (query type, expansion source, whether the question text is prepended)
    variants = (
        ('claude3sonnet',   'claude3sonnet', False),
        ('qwen2p57b',       'qwen2p57b',     False),
        ('q_claude3sonnet', 'claude3sonnet', True),
        ('q_qwen2p57b',     'qwen2p57b',     True),
    )
    for query_type, source, prepend_question in variants:
        if retrieval_query_type == query_type:
            break
    else:
        raise ValueError(f"Unknown retrieval_query_type: {retrieval_query_type}")

    with open(os.path.join(data_path, f'{source}_generated_expansions.json'), 'r') as f:
        rule_by_idx = {item['idx']: item['rule'] for item in json.load(f)}

    def build_query(q):
        rule = rule_by_idx.get(q['idx'], '')
        if prepend_question:
            return q['question'] + ' ' + rule
        return rule

    return [{**q, 'query': build_query(q)} for q in questions_list]
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Collect the top retrievals (up to top_to_store per query) for every combination of embedding model and query type; the output also contains the ground-truth statutes
# ---------------------------------------------------------------------------------------------------------------------------------
def generate_recall_results(embedding_models, retrieval_query_types, vectorstore_path, to_save=True, top_to_store=100, top_to_search=10000, questions=None, use_GPUs=4):
    """Retrieve statutes for every (embedding model, query type) pair and keep the in-state top hits.

    For each model:
      1. Load the saved FAISS store ``<vectorstore_path>/<model id>``.
      2. Embed every query of each type in ``retrieval_query_types`` in one batch.
      3. Fetch the ``top_to_search`` nearest statutes across all states.
      4. Keep, in FAISS result order, the first ``top_to_store`` whose state matches the question's state.

    Parameters
    ----------
    embedding_models : list of dict
        Model specs (see the global ``embedding_models``).
    retrieval_query_types : list of str
        Types accepted by ``generate_retrieval_query_list``.
    vectorstore_path : str
        Directory containing one saved store per model id.
    to_save : bool
        If True, pickle the results collected so far to ``<data_path>/recall_results_all.pkl`` after each model.
    top_to_store : int
        Number of in-state hits kept per query.
    top_to_search : int
        Number of nearest neighbours requested from FAISS before state filtering.
    questions : list of dict, optional
        Questions to evaluate; defaults to the notebook global ``questions_list``.
    use_GPUs : int
        Number of CUDA devices (cuda:0 ... cuda:use_GPUs-1) used to encode the queries.

    Returns
    -------
    dict
        ``results[model id][query type]`` is a list with one record per question that has golden statutes:
        ``{'idx', 'state', 'golden statute idx' (set), 'retrieved_idxs', 'retrieved_scores'}``. The scores are
        the raw values returned by ``index.search`` (for an L2 index, smaller means closer).
    """
    if questions is None:
        questions = questions_list  # backward compatibility: previously always read this notebook global
    devices = ['cuda:' + str(cuda_id) for cuda_id in range(use_GPUs)]
    recall_results = dict()

    for embedding_model in embedding_models:
        model_id = embedding_model["model id"]
        recall_results[model_id] = dict()

        # Load the query encoder with the same settings used to build the index. These settings were previously
        # read from an undefined `model_spec_dict`.
        model = SentenceTransformer(embedding_model["model name"])
        model.max_seq_length = embedding_model["max sentence length"]
        if model_id == "qwen3_0p6b":
            model.half()

        cosine_embeddings = CosineQueryEmbeddings(model)
        vectorstore = FAISS.load_local(
            os.path.join(vectorstore_path, model_id),
            cosine_embeddings,
            # Required by LangChain because the docstore is unpickled; only load stores this notebook created.
            allow_dangerous_deserialization=True
        )

        for retrieval_query_type in retrieval_query_types:
            query_list = generate_retrieval_query_list(retrieval_query_type, questions)
            # Only questions with at least one golden statute can be scored for recall
            valid_queries = [q for q in query_list if q['statutes']]
            if not valid_queries:
                recall_results[model_id][retrieval_query_type] = []
                continue

            query_embeddings = model.encode(
                [q['query'] for q in valid_queries],
                batch_size=embedding_model["batch size"],
                device=devices,
                normalize_embeddings=True,
                show_progress_bar=True
            )
            # FAISS indexes work on float32; the fp16 (Qwen) model yields float16 embeddings.
            query_embeddings = np.ascontiguousarray(query_embeddings, dtype=np.float32)

            # Single batch search for all queries
            scores, indices = vectorstore.index.search(query_embeddings, top_to_search)

            records = []
            for query_idx, query in enumerate(tqdm(valid_queries, desc="Processing queries")):
                state = query['state'].lower()
                retrieved_idxs, retrieved_scores = _top_hits_in_state(
                    vectorstore, scores[query_idx], indices[query_idx], state, top_to_store)
                records.append({
                    'idx': query['idx'],
                    'state': state,
                    'golden statute idx': {s['statute_idx'] for s in query['statutes']},
                    'retrieved_idxs': retrieved_idxs,
                    'retrieved_scores': retrieved_scores,
                })
            recall_results[model_id][retrieval_query_type] = records

        # Release the model. The store holds it through cosine_embeddings, so drop all three; collect
        # garbage before emptying the CUDA cache so the freed blocks are actually returned.
        del vectorstore, cosine_embeddings, model
        gc.collect()
        torch.cuda.empty_cache()

        # Checkpoint after every model so a crash does not lose the models already finished
        if to_save:
            with open(os.path.join(data_path, 'recall_results_all.pkl'), 'wb') as f:
                pickle.dump(recall_results, f)

    return recall_results


def _top_hits_in_state(vectorstore, row_scores, row_indices, state, top_to_store):
    """Return (statute ids, scores) of the first ``top_to_store`` FAISS hits whose metadata state equals ``state``."""
    statute_ids, statute_scores = [], []
    for score, faiss_idx in zip(row_scores, row_indices):
        if faiss_idx == -1:  # FAISS pads with -1 when fewer than k neighbours exist
            continue
        doc = vectorstore.docstore.search(vectorstore.index_to_docstore_id[faiss_idx])
        if doc.metadata['state'] == state:
            statute_ids.append(doc.metadata['id'])
            statute_scores.append(score)
            if len(statute_ids) >= top_to_store:
                break
    return statute_ids, statute_scores
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Calculate recalls from collected retrieval results and ground truths
# ---------------------------------------------------------------------------------------------------------------------------------
def calculate_recall_at_k_with_ci(recall_results, selected_query_idxs, k, n_bootstrap=1000, confidence=0.95, to_print_table=True, seed=None):
    """Recall@k (hit rate) with a percentile-bootstrap confidence interval for every model / query type.

    A query counts as a hit when at least one of its golden statutes is among its first ``k`` retrieved
    statutes. Only queries whose 'idx' is in ``selected_query_idxs`` are scored.

    Parameters
    ----------
    recall_results : dict
        Output of ``generate_recall_results``.
    selected_query_idxs : iterable
        Question ids to include.
    k : int
        Retrieval cut-off.
    n_bootstrap : int
        Number of bootstrap resamples.
    confidence : float
        Coverage of the interval (0.95 gives the 2.5th and 97.5th percentiles).
    to_print_table : bool
        If True, print a summary table.
    seed : int or None
        Seed for the bootstrap RNG; pass an int for reproducible intervals.

    Returns
    -------
    dict
        ``scores[model id][query type] = {'recall', 'ci_lower', 'ci_upper'}``. All three are NaN when no
        selected query exists for that pair.
    """
    selected = set(selected_query_idxs)  # O(1) membership instead of scanning a list for every result
    rng = np.random.default_rng(seed)
    alpha = 1 - confidence
    lower_percentile = (alpha / 2) * 100
    upper_percentile = (1 - alpha / 2) * 100

    recall_scores = {}
    for model_id, by_query_type in recall_results.items():
        recall_scores[model_id] = {}
        for query_type, results in by_query_type.items():
            hits = [
                1 if result['golden statute idx'].intersection(result['retrieved_idxs'][:k]) else 0
                for result in results
                if result['idx'] in selected
            ]
            n_queries = len(hits)
            if n_queries == 0:
                nan = float('nan')
                recall_scores[model_id][query_type] = {'recall': nan, 'ci_lower': nan, 'ci_upper': nan}
                continue

            recall = np.mean(hits)
            # Hits are 0/1, so the mean of n draws with replacement is Binomial(n, recall) / n. This samples the
            # exact bootstrap distribution without building an n_bootstrap x n resample matrix.
            bootstrap_scores = rng.binomial(n_queries, recall, size=n_bootstrap) / n_queries

            recall_scores[model_id][query_type] = {
                'recall': recall,
                'ci_lower': np.percentile(bootstrap_scores, lower_percentile),
                'ci_upper': np.percentile(bootstrap_scores, upper_percentile)
            }

    if to_print_table:
        table_data = []
        for model_id, by_query_type in recall_scores.items():
            for query_type, scores in by_query_type.items():
                table_data.append({
                    'Model': model_id,
                    'Query Type': query_type,
                    f'Recall@{k}': f"{scores['recall']:.3f}",
                    'CI Range': f"[{scores['ci_lower']:.3f}, {scores['ci_upper']:.3f}]"
                })
        df = pd.DataFrame(table_data)
        print(df.to_string(index=False))

    return recall_scores
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Calculate recall@k per state for all states corresponding to question-statute pairs
# ---------------------------------------------------------------------------------------------------------------------------------
def calculate_recall_per_state_with_ci(recall_results, selected_query_idxs, k, model_id, query_type, n_bootstrap=1000, confidence=0.95, seed=None):
    """Recall@k with a percentile-bootstrap confidence interval for each state, for one model / query type.

    A query is a hit when any of its golden statutes is among its first ``k`` retrieved statutes. Only
    queries whose 'idx' is in ``selected_query_idxs`` are used, grouped by their 'state' field.

    Parameters
    ----------
    recall_results : dict
        Output of ``generate_recall_results``.
    selected_query_idxs : iterable
        Question ids to include.
    k : int
        Retrieval cut-off.
    model_id, query_type : str
        Which entry of ``recall_results`` to evaluate.
    n_bootstrap : int
        Number of bootstrap resamples.
    confidence : float
        Coverage of the interval.
    seed : int or None
        Seed for the bootstrap RNG; pass an int for reproducible intervals.

    Returns
    -------
    dict
        state -> ``{'recall', 'ci_lower', 'ci_upper', 'n_queries'}``. Only states with at least one
        selected query appear.
    """
    selected = set(selected_query_idxs)  # O(1) membership instead of scanning a list for every result

    # Per-state 0/1 hit lists
    hits_by_state = defaultdict(list)
    for result in recall_results[model_id][query_type]:
        if result['idx'] in selected:
            hit = 1 if result['golden statute idx'].intersection(result['retrieved_idxs'][:k]) else 0
            hits_by_state[result['state']].append(hit)

    rng = np.random.default_rng(seed)
    alpha = 1 - confidence
    lower_percentile = (alpha / 2) * 100
    upper_percentile = (1 - alpha / 2) * 100

    state_recalls = {}
    for state, hits in hits_by_state.items():
        n_queries = len(hits)  # always >= 1: a state only appears once a hit/miss was recorded
        recall = np.mean(hits)
        # Resampling n 0/1 hits with replacement gives a mean distributed as Binomial(n, recall) / n
        bootstrap_scores = rng.binomial(n_queries, recall, size=n_bootstrap) / n_queries
        state_recalls[state] = {
            'recall': recall,
            'ci_lower': np.percentile(bootstrap_scores, lower_percentile),
            'ci_upper': np.percentile(bootstrap_scores, upper_percentile),
            'n_queries': n_queries
        }
    return state_recalls
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Plot recalls per state in increasing order with confidence intervals
# ---------------------------------------------------------------------------------------------------------------------------------
def plot_recall_by_state(state_recalls, k, model_id, query_type, statutes_list):
    """Plot per-state Recall@k, sorted ascending, with asymmetric bootstrap-CI error bars.

    Each point is annotated with ``q`` (questions evaluated for that state, from 'n_queries') and ``s``
    (statutes of that state in ``statutes_list``).

    Parameters
    ----------
    state_recalls : dict
        state -> {'recall', 'ci_lower', 'ci_upper', 'n_queries'}, as returned by
        ``calculate_recall_per_state_with_ci``.
    k : int
        Cut-off shown in the axis label and title.
    model_id, query_type : str
        Shown in the title.
    statutes_list : list of dict
        Statutes with a 'state' key; used only for the per-state statute counts.
    """
    statutes_per_state = Counter(statute['state'].lower() for statute in statutes_list)
    ranked = sorted(state_recalls.items(), key=lambda entry: entry[1]['recall'])

    # Gather everything from the stats dicts before a figure is created
    labels, centers, below, above, annotations = [], [], [], [], []
    for state, stats in ranked:
        labels.append(state)
        centers.append(stats['recall'])
        below.append(stats['recall'] - stats['ci_lower'])
        above.append(stats['ci_upper'] - stats['recall'])
        annotations.append(f"q={stats['n_queries']}\ns={statutes_per_state[state]}")
    positions = range(len(labels))

    fig, ax = plt.subplots(figsize=(16, 8))
    ax.errorbar(positions, centers, yerr=[below, above], fmt='o', capsize=5, capthick=2)

    # State names as x tick labels
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10)

    # Vertical q=/s= annotation slightly right of and above each point
    for position, center, annotation in zip(positions, centers, annotations):
        ax.text(position + 0.25, center + 0.02, annotation, ha='center', va='bottom', fontsize=8, rotation=90)

    ax.set_ylabel(f'Recall@{k}')
    ax.set_xlabel('State')
    ax.set_title(f'Recall@{k} by State - {model_id} ({query_type})')
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Generate all prompts for question answering: query only, query + golden passage (oracle, i.e. upper bound), and query + top-10 retrieved statutes
# ---------------------------------------------------------------------------------------------------------------------------------
def generate_prompts(questions, retrieved_statutes=None, statute_lookup=None):
    """Build the three yes/no QA prompt variants for every question.

    Variants:
      * 'question_only'     -- the question alone (closed-book baseline)
      * 'golden_context'    -- the question preceded by its golden statute excerpts (oracle upper bound)
      * 'retrieved_context' -- the question preceded by the texts of its retrieved statutes. Falls back to
                               the question-only prompt when ``retrieved_statutes`` or ``statute_lookup``
                               is missing or empty.

    Parameters
    ----------
    questions : list of dict
        Items with 'question', 'statutes' (list of dicts with an 'excerpt') and 'answer'.
    retrieved_statutes : list of lists, optional
        ``retrieved_statutes[i]`` holds the statute ids used as context for ``questions[i]``. They are used
        as given, so truncate to the desired top-k before calling. Ids missing from ``statute_lookup``
        are skipped.
    statute_lookup : dict, optional
        Statute id -> statute text.

    Returns
    -------
    (dict, list)
        The prompt lists per variant (aligned with ``questions``) and the ground-truth answers.
    """
    prompts = {'question_only': [], 'golden_context': [], 'retrieved_context': []}
    ground_truths = []

    for position, q in enumerate(questions):
        question_only = f"Question: {q['question']}\n\nAnswer with only 'Yes' or 'No':"
        golden_context = "\n".join(statute['excerpt'] for statute in q['statutes'])

        if not (retrieved_statutes and statute_lookup):
            retrieved_prompt = question_only
        else:
            retrieved_context = "\n".join(
                statute_lookup[statute_id]
                for statute_id in retrieved_statutes[position]
                if statute_id in statute_lookup
            )
            retrieved_prompt = f"Context: {retrieved_context}\n\nQuestion: {q['question']}\n\nAnswer with only 'Yes' or 'No':"

        prompts['question_only'].append(question_only)
        prompts['golden_context'].append(f"Context: {golden_context}\n\nQuestion: {q['question']}\n\nAnswer with only 'Yes' or 'No':")
        prompts['retrieved_context'].append(retrieved_prompt)
        ground_truths.append(q['answer'])

    return prompts, ground_truths
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Generate QA answers with Qwen 2.5 7B Instruct using batch inference with VLLM
# ---------------------------------------------------------------------------------------------------------------------------------
def evaluate_qa_batch(prompts, ground_truths, llm, sampling_params, batch_size=32):
    """Run yes/no QA prompts through a vLLM model in batches and parse each answer.

    Parameters
    ----------
    prompts : list of str
        Prompts to run.
    ground_truths : list
        Not used here (kept for interface compatibility); score predictions with ``calculate_metrics_with_ci``.
    llm : vllm.LLM
        Model used for generation (anything with a compatible ``generate(prompts, sampling_params)``).
    sampling_params : vllm.SamplingParams
    batch_size : int
        Number of prompts per ``llm.generate`` call.

    Returns
    -------
    list of str
        'Yes', 'No' or 'Unknown' per prompt, in prompt order.
    """
    # The answer must be the first word, optionally after leading whitespace/punctuation (e.g. '**', quotes) or
    # an "Answer:" prefix. The previous substring test on the first 10 characters labelled replies such as
    # "Unknown", "Not enough information" or "I cannot say" as 'No', and "Yesterday..." as 'Yes'.
    answer_pattern = re.compile(r'\W*(?:answer\s*:\s*)?(yes|no)\b', re.IGNORECASE)

    def extract_yes_no(text):
        """Map a generation to 'Yes', 'No' or 'Unknown' based on its leading word."""
        match = answer_pattern.match(text)
        if match is None:
            return 'Unknown'
        return 'Yes' if match.group(1).lower() == 'yes' else 'No'

    predictions = []
    for start in tqdm(range(0, len(prompts), batch_size), desc="Processing batches"):
        outputs = llm.generate(prompts[start:start + batch_size], sampling_params)
        predictions.extend(extract_yes_no(output.outputs[0].text) for output in outputs)
    return predictions
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Calculate QA accuracy with confidence intervals
# ---------------------------------------------------------------------------------------------------------------------------------
def calculate_metrics_with_ci(predictions, ground_truths, n_bootstrap=1000, confidence=0.95, seed=None):
    """Accuracy of yes/no predictions with a percentile-bootstrap confidence interval.

    'Unknown' predictions are excluded before scoring, so ``accuracy`` is measured over answered questions
    only; how many were excluded is reported under 'unknown'. Predictions and ground truths are paired
    positionally (``zip``), so any surplus items in the longer list are ignored.

    Parameters
    ----------
    predictions : list of str
        'Yes' / 'No' / 'Unknown', e.g. from ``evaluate_qa_batch``.
    ground_truths : list
        Expected answers, compared with ``==``.
    n_bootstrap : int
        Number of bootstrap resamples.
    confidence : float
        Coverage of the interval (0.95 gives the 2.5th and 97.5th percentiles).
    seed : int or None
        Seed for the bootstrap RNG; pass an int for reproducible intervals.

    Returns
    -------
    dict with keys
        'accuracy' : correct / total (0 when nothing was answered)
        'std'      : standard deviation of the bootstrap accuracies (NaN when nothing was answered)
        'ci_95'    : (lower, upper) bootstrap percentile interval at ``confidence``; key name kept for compatibility
        'correct'  : number of correct answered predictions
        'total'    : number of answered (non-'Unknown') predictions
        'unknown'  : number of predictions not scored
        The same keys are returned even when every prediction is 'Unknown'.
    """
    valid_pairs = [(p, g) for p, g in zip(predictions, ground_truths) if p != 'Unknown']
    unknown = len(predictions) - len(valid_pairs)

    if not valid_pairs:
        nan = float('nan')
        return {'accuracy': 0, 'std': nan, 'ci_95': (nan, nan), 'correct': 0, 'total': 0, 'unknown': unknown}

    total = len(valid_pairs)
    correct = sum(p == g for p, g in valid_pairs)
    accuracy = correct / total

    # Per-item correctness is 0/1, so the accuracy of a resample of size `total` (drawn with replacement) is
    # Binomial(total, accuracy) / total: the exact bootstrap distribution without an explicit resample loop.
    rng = np.random.default_rng(seed)
    bootstrap_accs = rng.binomial(total, accuracy, size=n_bootstrap) / total

    alpha = 1 - confidence
    ci_lower = np.percentile(bootstrap_accs, 100 * alpha / 2)
    ci_upper = np.percentile(bootstrap_accs, 100 * (1 - alpha / 2))

    return {
        'accuracy': accuracy,
        'std': np.std(bootstrap_accs),
        'ci_95': (ci_lower, ci_upper),
        'correct': correct,
        'total': total,
        'unknown': unknown
    }
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Helper that checks whether a golden statute excerpt matches the statute-database text with the same index
# ---------------------------------------------------------------------------------------------------------------------------------
def is_substring(str1, str2, keep_fraction=0.5, min_trim_length=20):
    """Return True if a normalized tail of ``str1`` occurs inside normalized ``str2``.

    Both strings are lowercased and stripped of everything except ASCII letters a-z.
    If the normalized ``str1`` is longer than ``min_trim_length`` characters, only its last
    ``keep_fraction`` of characters (at least one) is kept before the containment check. The leading
    part is dropped, so differences confined to the start of the excerpt do not cause a mismatch.

    Args:
        str1: Golden excerpt text to look for.
        str2: Statute-database text to search in.
        keep_fraction: Fraction of the normalized ``str1`` tail to keep when trimming (default 0.5).
        min_trim_length: Trimming only applies when normalized ``str1`` is longer than this (default 20).

    Returns:
        bool. Note that a ``str1`` with no letters normalizes to '' and therefore matches any ``str2``.
    """
    # Normalize: lowercase, keep only letters
    norm_str1 = re.sub(r'[^a-z]', '', str1.lower())
    norm_str2 = re.sub(r'[^a-z]', '', str2.lower())

    # Keep only the tail of norm_str1 (the previous comment said 70%; the code has always kept 50%).
    if len(norm_str1) > min_trim_length:
        # max(1, ...) avoids norm_str1[-0:], which would silently keep the whole string.
        keep_length = max(1, int(len(norm_str1) * keep_fraction))
        norm_str1 = norm_str1[-keep_length:]

    return norm_str1 in norm_str2
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Identify questions that pass golden passage - statute matching criteria
# ---------------------------------------------------------------------------------------------------------------------------------
def verify_question_statute_match(questions_list, statutes_list, print_per_state_quality=True, to_save=True):
    """Select questions whose golden statute excerpts all match the statute database text.

    For every question, each cited statute (``question['statutes']`` entries with 'statute_idx' and
    'excerpt') is checked with ``is_substring`` against the corpus statute that has the same 'idx'.
    A question is kept only if every check passes. Questions citing no statutes are kept.

    Args:
        questions_list: List of question dicts with 'state' and 'statutes' keys.
        statutes_list: List of statute dicts with 'idx' and 'text' keys.
        print_per_state_quality: If True, print the fraction of matching statute checks per state,
            sorted from lowest to highest.
        to_save: If True, pickle the selected questions to
            ``data_path/selected_questions_with_matching_statutes.pkl``.

    Returns:
        List of the selected question dicts, in input order.
    """
    statute_texts = {statute['idx']: statute['text'] for statute in statutes_list}

    # Pre-seed every state so states whose questions cite no statutes still appear in the report (as 0.000).
    state_match = {q['state'].lower(): [] for q in questions_list}

    selected_questions = []
    for question in questions_list:
        state = question['state'].lower()
        # Keyed by statute index: a repeated statute_idx keeps only its last excerpt.
        question_statute_idxs = {s['statute_idx']: s['excerpt'] for s in question['statutes']}

        matches = []
        for idx, golden_text in question_statute_idxs.items():
            db_text = statute_texts.get(idx)
            # A golden statute absent from the corpus counts as a mismatch instead of aborting with KeyError.
            matches.append(int(db_text is not None and is_substring(golden_text, db_text)))

        if all(matches):
            selected_questions.append(question)
        state_match[state].extend(matches)

    if print_per_state_quality:
        # Average match rate per state, sorted by increasing rate
        state_averages = sorted(
            ((state, sum(m) / len(m) if m else 0) for state, m in state_match.items()),
            key=lambda x: x[1],
        )
        for state, avg_matches in state_averages:
            print(f"{state}: {avg_matches:.3f}")

    if to_save:
        with open(os.path.join(data_path, 'selected_questions_with_matching_statutes.pkl'), 'wb') as f:
            pickle.dump(selected_questions, f)

    return selected_questions
# ---------------------------------------------------------------------------------------------------------------------------------

In [None]:
# ---------------------------------------------------------------------------------------------------------------------------------
# Load Housing Statutes Dataset
# ---------------------------------------------------------------------------------------------------------------------------------
# The dataset is read from a local JSON cache. To (re)create the cache, run the following once in a separate cell.
# Loading the HousingStatute and BarExam datasets requires these specific older library versions:
#
#   %pip install -q datasets==2.14.0 huggingface-hub==0.20.0 fsspec==2023.9.2
#   from datasets import load_dataset
#   questions = load_dataset("reglab/housing_qa", "questions", split="test")
#   statutes = load_dataset("reglab/housing_qa", "statutes", split="corpus")
#   os.makedirs(os.path.join(data_path, "housing_statutes_dataset"), exist_ok=True)
#   with open(os.path.join(data_path, "housing_statutes_dataset", "questions.json"), "w") as f:
#       json.dump(questions.to_list(), f, indent=2)
#   with open(os.path.join(data_path, "housing_statutes_dataset", "statutes.json"), "w") as f:
#       json.dump(statutes.to_list(), f, indent=2)

housing_qa_dir = os.path.join(data_path, "housing_statutes_dataset")
questions_json_path = os.path.join(housing_qa_dir, "questions.json")
statutes_json_path = os.path.join(housing_qa_dir, "statutes.json")

for cache_path in (questions_json_path, statutes_json_path):
    if not os.path.exists(cache_path):
        raise FileNotFoundError(
            f"{cache_path} not found; create the JSON cache with the commented download steps in this cell."
        )

with open(questions_json_path, "r") as f:
    questions_list = json.load(f)

with open(statutes_json_path, "r") as f:
    statutes_list = json.load(f)

assert questions_list and statutes_list, "HousingQA JSON cache is empty"
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Create statute vectorstores with all embedding models
# ---------------------------------------------------------------------------------------------------------------------------------
create_vectorstore(embedding_models, statutes_list, to_save=True, use_GPUs=4)
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Generate Claude 3 Sonnet query expansions
# ---------------------------------------------------------------------------------------------------------------------------------
# `questions` only existed inside the disabled download snippet of the loading cell, so on a fresh kernel the original
# call raised NameError. Pass the loaded list instead, as the Qwen expansion below already does.
# NOTE(review): confirm query_expansions_claude_with_batch_progress works with a list of question dicts rather than a HF Dataset.
claude_query_expansions = query_expansions_claude_with_batch_progress(questions_list, max_workers=5)
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Generate Qwen 2.5 7B Instruct query expansions
# ---------------------------------------------------------------------------------------------------------------------------------
# Release unreferenced memory from earlier steps before vLLM reserves gpu_memory_utilization=0.9 of GPU memory.
gc.collect()
torch.cuda.empty_cache()

# Initialize vLLM with Qwen 2.5 7B Instruct
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    tensor_parallel_size=4,
    gpu_memory_utilization=0.9
)

# Set Qwen 2.5 7B Instruct generation parameters
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=512
)
prompts_with_metadata = generate_query_expansion_prompts(questions_list)
qwen_query_expansions = generate_query_expansions(prompts_with_metadata, llm, sampling_params, batch_size=32)
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Verify match between question golden passage text and statute database text with matching statute index and select questions
# ---------------------------------------------------------------------------------------------------------------------------------
# Keep only questions whose every golden excerpt is found in the corpus statute with the same index.
selected_questions = verify_question_statute_match(
    questions_list,
    statutes_list,
    print_per_state_quality=True,
    to_save=True,
)
# Lookup table of retained question indices (every value is True).
selected_query_idxs = dict.fromkeys((q['idx'] for q in selected_questions), True)
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Perform retrieval evaluation (overall)
# ---------------------------------------------------------------------------------------------------------------------------------
recall_results = generate_recall_results(
    embedding_models,
    retrieval_query_types,
    vectorstore_path,
    to_save=True,
    top_to_store=100,
    top_to_search=10000,
)
recall_scores = calculate_recall_at_k_with_ci(
    recall_results,
    selected_query_idxs,
    k,
    n_bootstrap=1000,
    confidence=0.95,
    to_print_table=True,
)
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Perform retrieval evaluation per state
# ---------------------------------------------------------------------------------------------------------------------------------
model_id = "e5_large_v2"  # embedding model used for the per-state breakdown

# 'q'               -> original queries only
# 'q_claude3sonnet' -> original query + Claude 3 Sonnet expansion
for query_type in ('q', 'q_claude3sonnet'):
    state_recalls = calculate_recall_per_state_with_ci(
        recall_results, selected_query_idxs, k, model_id, query_type, n_bootstrap=1000, confidence=0.95
    )
    plot_recall_by_state(state_recalls, k, model_id, query_type, statutes_list)
# ---------------------------------------------------------------------------------------------------------------------------------


# ---------------------------------------------------------------------------------------------------------------------------------
# Perform QA inference with different contexts (question, golden passage + question, retrieved statutes + question) and calculate performance
# ---------------------------------------------------------------------------------------------------------------------------------
# `questions` only existed inside the disabled download snippet of the loading cell (NameError on a fresh kernel); use the loaded list.
# NOTE(review): confirm generate_prompts accepts a list of question dicts, and that it still builds a meaningful
# 'retrieved_context' prompt set when retrieved_statutes/statute_lookup are None.
# NOTE(review): sampling_params is the query-expansion config (temperature=0.7); if evaluate_qa_batch uses it as-is,
# QA accuracies may vary between runs.
prompts, ground_truths = generate_prompts(questions_list, retrieved_statutes=None, statute_lookup=None)

# (prompt set, vLLM batch size); retrieved contexts are longer, hence the smaller batch
qa_settings = [
    ('question_only', 32),
    ('golden_context', 32),
    ('retrieved_context', 4),
]

qa_results = {}
for context_type, qa_batch_size in qa_settings:
    # Collect unreachable Python objects first so empty_cache can actually release their cached CUDA blocks.
    gc.collect()
    torch.cuda.empty_cache()

    predictions = evaluate_qa_batch(prompts[context_type], ground_truths, llm, sampling_params, batch_size=qa_batch_size)
    results = calculate_metrics_with_ci(predictions, ground_truths)
    qa_results[context_type] = results

    print(f"=== {context_type} ===")
    print(f"Accuracy: {results['accuracy']:.3f}")
    print(f"Standard Deviation: {results['std']:.3f}")
    print(f"95% CI: ({results['ci_95'][0]:.3f}, {results['ci_95'][1]:.3f})")
    print(f"Correct: {results['correct']}/{results['total']}")
    print(f"Unknown predictions: {results['unknown']}")
# ---------------------------------------------------------------------------------------------------------------------------------