# 1. Preparation

In [1]:
import setproctitle

setproctitle.setproctitle("python")
import os
import time
import sys
import chromadb
import bm25s
import Stemmer
import torch
import numpy as np

sys.path.append("../..")

from tqdm import tqdm
from collections import defaultdict
from transformers import set_seed
from chromadb.api.models.Collection import Collection
from benchmark_generator.context.utils.jsonl import read_jsonl
from benchmark_generator.context.utils.pipeline_initializer import initialize_pipeline
from benchmark_generator.context.utils.prompting_interface import prompt_pipeline


# NOTE(review): CUBLAS_WORKSPACE_CONFIG is presumably set to support the
# deterministic mode requested by set_seed below — confirm against the
# PyTorch reproducibility notes.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# Hard-coded GPU index; adjust for the host this notebook runs on.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
# Seed via the transformers helper, with deterministic=True.
set_seed(42, deterministic=True)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# NOTE: the LLM pipeline below is commented out, so `pipe` is undefined in a
# fresh kernel. `rerank` reads `pipe`; re-enable these lines before running
# any benchmark with use_reranker=True.
# pipe = initialize_pipeline("../../models/llama", torch.bfloat16)
# Stemmer handed to bm25s.tokenize for both the corpus and the queries.
stemmer = Stemmer.Stemmer("english")

# # Specific setting for Llama-3-8B-Instruct for batching
# pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
# pipe.tokenizer.padding_side = "left"

In [3]:
# Dataset to evaluate; must be one of the keys of TABLE_DIRS.
dataset = "adventure"

# Name of the table directory for each supported dataset. The content
# benchmark file uses the same stem ("<stem>_questions_annotated.jsonl").
TABLE_DIRS = {
    "chicago": "pneuma_chicago_10K",
    "public": "pneuma_public_bi",
    "chembl": "pneuma_chembl_10K",
    "adventure": "pneuma_adventure_works",
    "fetaqa": "pneuma_fetaqa",
}
# Datasets that also load context descriptions and a context benchmark.
DATASETS_WITH_CONTEXTS = {"chembl", "adventure"}

if dataset not in TABLE_DIRS:
    # Fail here rather than with a NameError in a later cell.
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected one of {sorted(TABLE_DIRS)}"
    )

table_dir = TABLE_DIRS[dataset]
narrations = read_jsonl(
    f"../../pneuma_summarizer/summaries/narrations/{dataset}_narrations.jsonl"
)
rows = read_jsonl(f"../../pneuma_summarizer/summaries/rows/{dataset}.jsonl")
content_benchmark = read_jsonl(
    f"../../data_src/benchmarks/content/{table_dir}_questions_annotated.jsonl"
)
path = f"../../data_src/tables/{table_dir}"

# Always bind these names so later cells work for every dataset: the indexing
# cell passes `contexts` unconditionally, and indexing_keyword accepts None.
contexts = None
context_benchmark = None
if dataset in DATASETS_WITH_CONTEXTS:
    contexts = read_jsonl(
        f"../../data_src/benchmarks/context/{dataset}/contexts_{dataset}.jsonl"
    )
    context_benchmark = read_jsonl(
        f"../../data_src/benchmarks/context/{dataset}/bx_{dataset}.jsonl"
    )

# 2. Indexing

In [4]:
def indexing_keyword(
    stemmer,
    narration_contents: list[dict[str, str]],
    contexts: list[dict[str, str]] = None,
):
    """Build a BM25 retriever over table summaries and optional contexts.

    Args:
        stemmer: Stemmer forwarded to ``bm25s.tokenize``.
        narration_contents: Records with a ``"table"`` and a ``"summary"`` key.
        contexts: Optional records with a ``"table"`` and a ``"context"`` key.
            Contexts whose table has no summary are not indexed.

    Returns:
        A ``bm25s.BM25`` retriever indexed over the corpus. Every document is
        ``{"text": ..., "metadata": {"table": id}}`` where ``id`` is
        ``"<table>_SEP_contents_<i>"`` for summaries and ``"<table>_SEP_<i>"``
        for contexts. Tables are visited in sorted order; within a table,
        summaries come first, then contexts, each in input order.
    """
    # Group once per input list instead of rescanning both lists per table.
    summaries_by_table = defaultdict(list)
    for content in narration_contents:
        summaries_by_table[content["table"]].append(content)

    contexts_by_table = defaultdict(list)
    if contexts is not None:
        for context in contexts:
            # Membership test does not insert into the defaultdict.
            if context["table"] in summaries_by_table:
                contexts_by_table[context["table"]].append(context)

    corpus_json = []
    for table in sorted(summaries_by_table):
        for content_idx, content in enumerate(summaries_by_table[table]):
            corpus_json.append(
                {
                    "text": content["summary"],
                    "metadata": {"table": f"{table}_SEP_contents_{content_idx}"},
                }
            )
        # .get avoids creating empty buckets for tables without contexts.
        for context_idx, context in enumerate(contexts_by_table.get(table, ())):
            corpus_json.append(
                {
                    "text": context["context"],
                    "metadata": {"table": f"{table}_SEP_{context_idx}"},
                }
            )

    corpus_text = [doc["text"] for doc in corpus_json]
    corpus_tokens = bm25s.tokenize(
        corpus_text, stopwords="en", stemmer=stemmer, show_progress=False
    )

    retriever = bm25s.BM25(corpus=corpus_json)
    retriever.index(corpus_tokens, show_progress=False)
    return retriever

In [5]:
print(f"Processing {dataset} dataset")
start = time.time()

# Vector side: open the persisted Chroma index for this dataset.
index_dir = f"../indices/index-{dataset}-pneuma-summarizer"
client = chromadb.PersistentClient(index_dir)
collection = client.get_collection("benchmark")

# Keyword side: BM25 over row summaries, narrations and (if any) contexts.
summaries = rows + narrations
retriever = indexing_keyword(stemmer, summaries, contexts)

end = time.time()
print(f"Indexing time: {end-start} seconds")

Processing adventure dataset


DEBUG:bm25s:Building index from IDs objects


Indexing time: 0.563671350479126 seconds


# 3. Benchmarking

In [6]:
def process_nodes_bm25(items):
    """Min-max normalise BM25 scores and key the hits by table id.

    Args:
        items: ``(results, scores)`` pair; only the first row of each is read.
            ``results[0]`` holds hits exposing ``["metadata"]["table"]`` and
            ``["text"]``; ``scores[0]`` holds the matching raw scores.

    Returns:
        Dict mapping each hit's table id to ``(normalised_score, text)``.
        Scores are rescaled to [0, 1]; if all raw scores are equal every hit
        gets 1.0. An empty result row yields an empty dict.
    """
    results, scores = items
    hits = results[0]
    raw_scores = scores[0]

    # len() rather than truthiness: the score row may be a NumPy array, whose
    # truth value is ambiguous. Without this guard max()/min() raise on empty.
    if len(hits) == 0 or len(raw_scores) == 0:
        return {}

    max_score = max(raw_scores)
    min_score = min(raw_scores)
    all_equal = min_score == max_score
    score_range = max_score - min_score

    processed_nodes: dict[str, tuple[float, str]] = {}
    for i, node in enumerate(hits):
        if all_equal:
            score = 1.0
        else:
            score = (raw_scores[i] - min_score) / score_range
        processed_nodes[node["metadata"]["table"]] = (score, node["text"])
    return processed_nodes

In [7]:
def process_nodes_vec(items):
    """Min-max normalise vector-search scores and key the hits by id.

    Args:
        items: Query result mapping; only the first row of ``"ids"``,
            ``"distances"`` and ``"documents"`` is read.

    Returns:
        Dict mapping each id to ``(normalised_score, document)``. The raw score
        is ``1 - distance``, rescaled to [0, 1]; if all raw scores are equal
        every hit gets 1.0. An empty result row yields an empty dict.
    """
    ids = items["ids"][0]
    documents = items["documents"][0]
    # Turn distances into similarities so that larger means more relevant.
    scores: list[float] = [1 - dist for dist in items["distances"][0]]

    # Without this guard max()/min() raise ValueError when nothing is returned.
    if len(ids) == 0 or len(scores) == 0:
        return {}

    max_score = max(scores)
    min_score = min(scores)
    all_equal = min_score == max_score
    score_range = max_score - min_score

    processed_nodes: dict[str, tuple[float, str]] = {}
    for idx, node_id in enumerate(ids):
        if all_equal:
            score = 1.0
        else:
            score = (scores[idx] - min_score) / score_range
        processed_nodes[node_id] = (score, documents[idx])
    return processed_nodes

In [8]:
def hybrid_retriever(
    bm25_res,
    vec_res,
    k: int,
    question: str,
    use_reranker=False,
):
    """Fuse BM25 and vector hits with equal weights and keep the top ``k``.

    Each node's fused score is ``0.5 * bm25 + 0.5 * vector`` over the
    normalised scores; a node missing from one retriever contributes 0.0 there.
    The document text comes from the BM25 hit when present, otherwise from the
    vector hit. Nodes are ordered by descending fused score, ties broken by
    node id. With ``use_reranker`` the top-``k`` list is passed through
    ``rerank`` together with ``question``.

    Returns:
        List of ``(node_id, fused_score, document)`` tuples.
    """
    keyword_hits = process_nodes_bm25(bm25_res)
    vector_hits = process_nodes_vec(vec_res)
    absent = (0.0, None)

    fused: list[tuple[str, float, str]] = []
    for node_id in sorted(keyword_hits.keys() | vector_hits.keys()):
        keyword_score, keyword_doc = keyword_hits.get(node_id, absent)
        vector_score, vector_doc = vector_hits.get(node_id, absent)

        doc = vector_doc if keyword_doc is None else keyword_doc
        fused.append((node_id, 0.5 * keyword_score + 0.5 * vector_score, doc))

    fused.sort(key=lambda node: (-node[1], node[0]))
    top_nodes = fused[:k]
    if not use_reranker:
        return top_nodes
    return rerank(top_nodes, question)

In [9]:
def get_relevance_prompt(desc: str, desc_type: str, question: str):
    """Return the yes/no relevance prompt for a table description.

    ``desc_type`` selects the opening sentence: ``"content"`` for a column
    description, ``"context"`` for a context description. Any other value
    returns ``None``. The rest of the prompt is shared by both types.
    """
    openers = {
        "content": "Given a table with the following columns:",
        "context": "Given this context describing a table:",
    }
    opener = openers.get(desc_type)
    if opener is None:
        return None

    prompt_lines = [
        opener,
        "*/",
        f"{desc}",
        "*/",
        "and this question:",
        "/*",
        f"{question}",
        "*/",
        "Is the table relevant to answer the question? Begin your answer with yes/no.",
    ]
    return "\n".join(prompt_lines)

In [10]:
def rerank(nodes: list[tuple[str, float, str]], question: str):
    """Move nodes the LLM judges relevant ahead of the others.

    One prompt is built per node: ids whose part after ``"_SEP_"`` starts with
    ``"contents"`` use the content prompt, all others the context prompt. A
    node counts as relevant when the last message of its conversation starts
    with "yes" (case-insensitive). Relative order inside the relevant and the
    non-relevant group is preserved.

    Relies on the notebook-level ``pipe`` pipeline object.
    """
    node_tables = [node[0] for node in nodes]

    relevance_prompts = []
    for node in nodes:
        is_content = node[0].split("_SEP_")[1].startswith("contents")
        desc_type = "content" if is_content else "context"
        message = {
            "role": "user",
            "content": get_relevance_prompt(node[2], desc_type, question),
        }
        relevance_prompts.append([message])

    arguments = prompt_pipeline(
        pipe,
        relevance_prompts,
        batch_size=1,
        context_length=8192,
        max_new_tokens=2,
        top_p=None,
        temperature=None,
    )

    relevant_tables = {
        node_tables[arg_idx]
        for arg_idx, argument in enumerate(arguments)
        if argument[-1]["content"].lower().startswith("yes")
    }

    # Stable partition: relevant nodes first, everything else after.
    relevant, not_relevant = [], []
    for table_name, score, doc in nodes:
        bucket = relevant if table_name in relevant_tables else not_relevant
        bucket.append((table_name, score, doc))
    return relevant + not_relevant

In [11]:
def get_question_key(benchmark_type: str, use_rephrased_questions: bool):
    """Return the benchmark record key that holds the question text.

    ``"content"`` benchmarks use ``"question_from_sql_1"`` (original) or
    ``"question"`` (rephrased); every other benchmark type uses
    ``"question_bx1"`` (original) or ``"question_bx2"`` (rephrased).
    """
    if benchmark_type == "content":
        original_key, rephrased_key = "question_from_sql_1", "question"
    else:
        original_key, rephrased_key = "question_bx1", "question_bx2"
    return rephrased_key if use_rephrased_questions else original_key

In [12]:
def evaluate_benchmark(
    benchmark: list[dict[str, str]],
    benchmark_type: str,
    k: int,
    collection: Collection,
    retriever,
    stemmer,
    n=3,
    use_reranker=False,
    use_rephrased_questions=False,
):
    """Run hybrid retrieval over a benchmark and print the hit rate at ``k``.

    Args:
        benchmark: Records holding the question (under the key chosen by
            ``get_question_key``) and an ``"answer_tables"`` entry.
        benchmark_type: ``"content"`` or a context type; selects the question
            key and is part of the embeddings file name.
        k: Number of top fused nodes inspected for a hit.
        collection: Vector collection queried with precomputed embeddings.
        retriever: BM25 retriever exposing ``retrieve``.
        stemmer: Stemmer forwarded to ``bm25s.tokenize`` for the queries.
        n: Over-fetch multiplier; both retrievers fetch ``k * n`` candidates.
        use_reranker: Forwarded to ``hybrid_retriever``.
        use_rephrased_questions: Selects the rephrased question key and is
            part of the embeddings file name.

    Reads the notebook-level ``dataset`` variable to locate the embeddings
    file. Prints the hit rate (percent), the elapsed time and the indices of
    the questions without a hit; returns nothing.
    """
    start = time.time()
    hitrate_sum = 0
    wrong_questions = []

    # The original branched on use_reranker but both arms computed k * n, so
    # the candidate pool is the same with or without the reranker.
    increased_k = k * n

    question_key = get_question_key(benchmark_type, use_rephrased_questions)
    questions = [data[question_key] for data in benchmark]

    # ndmin=2 keeps a one-row file as a (1, dim) matrix; without it loadtxt
    # squeezes to 1-D and every "embedding" below would be a single float.
    embed_questions = np.loadtxt(
        f"../embeddings/embed-{dataset}-questions-{benchmark_type}-{use_rephrased_questions}.txt",
        ndmin=2,
    )
    embed_questions = [embed.tolist() for embed in embed_questions]

    for idx, datum in enumerate(tqdm(benchmark)):
        answer_tables = datum["answer_tables"]
        question_embedding = embed_questions[idx]

        query_tokens = bm25s.tokenize(
            questions[idx], stemmer=stemmer, show_progress=False
        )
        results, scores = retriever.retrieve(
            query_tokens, k=increased_k, show_progress=False
        )
        bm25_res = (results, scores)

        vec_res = collection.query(
            query_embeddings=[question_embedding], n_results=increased_k
        )

        all_nodes = hybrid_retriever(
            bm25_res, vec_res, increased_k, questions[idx], use_reranker
        )

        # A question is a hit when any of the top-k nodes belongs to an
        # answer table; node ids are "<table>_SEP_<suffix>".
        is_hit = any(
            node_id.split("_SEP_")[0] in answer_tables
            for node_id, _, _ in all_nodes[:k]
        )
        if is_hit:
            hitrate_sum += 1
        else:
            wrong_questions.append(idx)

    end = time.time()
    # Guard the division so an empty benchmark reports 0 instead of raising.
    hit_rate = hitrate_sum / len(benchmark) * 100 if len(benchmark) else 0.0
    print(f"Hit Rate: {round(hit_rate, 2)}")
    print(f"Benchmarking Time: {end - start} seconds")
    print(f"Wrongly answered questions: {wrong_questions}")

In [13]:
# Values of k swept below; evaluate_benchmark looks for an answer table among
# the top k fused nodes.
ks = [1]
# Values of n swept below; evaluate_benchmark fetches k * n candidates from
# each retriever.
ns = [5]
# Forwarded to evaluate_benchmark. rerank() reads the `pipe` pipeline object,
# which must be defined before this is set to True.
use_reranker = False

In [14]:
# BC1: content benchmark with the default (non-rephrased) question key.
for k in ks:
    for n in ns:
        print(f"BC1 with k={k} n={n}")
        evaluate_benchmark(
            benchmark=content_benchmark,
            benchmark_type="content",
            k=k,
            collection=collection,
            retriever=retriever,
            stemmer=stemmer,
            n=n,
            use_reranker=use_reranker,
        )
        print("=" * 50)

BC1 with k=1 n=5


100%|██████████| 1000/1000 [00:03<00:00, 309.99it/s]

Hit Rate: 56.7
Benchmarking Time: 3.613685131072998 seconds
Wrongly answered questions: [1, 2, 5, 6, 8, 9, 11, 14, 16, 17, 26, 27, 28, 29, 30, 31, 34, 35, 37, 38, 39, 41, 42, 43, 44, 48, 54, 58, 64, 65, 68, 70, 71, 77, 92, 96, 97, 98, 99, 100, 101, 103, 105, 106, 110, 111, 112, 114, 117, 121, 122, 125, 136, 137, 138, 139, 143, 144, 146, 147, 151, 152, 153, 156, 159, 162, 163, 165, 166, 171, 172, 177, 179, 183, 184, 186, 187, 193, 196, 197, 198, 203, 208, 210, 212, 213, 214, 220, 228, 229, 231, 234, 237, 238, 239, 243, 244, 245, 246, 251, 252, 255, 256, 257, 258, 263, 266, 267, 268, 269, 272, 273, 274, 275, 276, 277, 280, 281, 282, 283, 284, 289, 292, 305, 306, 307, 309, 311, 312, 313, 319, 322, 323, 325, 326, 327, 328, 334, 336, 337, 339, 342, 344, 347, 352, 354, 356, 363, 365, 370, 371, 373, 374, 375, 376, 377, 382, 384, 387, 388, 390, 391, 394, 395, 397, 398, 399, 401, 407, 409, 411, 412, 413, 419, 420, 421, 424, 426, 428, 441, 446, 449, 456, 458, 460, 462, 463, 465, 469, 477, 482, 4




In [18]:
# BC2: content benchmark with the rephrased questions.
for k in ks:
    for n in ns:
        print(f"BC2 with k={k} n={n}")
        evaluate_benchmark(
            content_benchmark,
            "content",
            k,
            collection,
            retriever,
            stemmer,
            n=n,
            use_rephrased_questions=True,
            use_reranker=use_reranker,
        )
        # Separator printed after every run, matching the BC1 cell; it was
        # previously only printed once per k.
        print("=" * 50)

BC2 with k=1 n=5


100%|██████████| 1000/1000 [00:02<00:00, 340.89it/s]

Hit Rate: 54.2
Benchmarking Time: 3.3353850841522217 seconds
Wrongly answered questions: [0, 1, 2, 5, 6, 7, 8, 9, 11, 13, 14, 16, 17, 19, 21, 24, 26, 27, 28, 29, 32, 33, 36, 37, 39, 40, 41, 42, 43, 44, 48, 51, 52, 54, 57, 58, 65, 66, 68, 70, 71, 77, 78, 96, 97, 98, 99, 100, 101, 104, 105, 106, 107, 108, 110, 111, 113, 114, 118, 122, 125, 126, 128, 130, 131, 137, 138, 139, 144, 145, 146, 147, 148, 150, 151, 152, 153, 154, 156, 157, 158, 159, 160, 161, 163, 164, 165, 166, 175, 176, 177, 178, 179, 180, 182, 184, 186, 193, 197, 198, 208, 215, 220, 224, 225, 226, 229, 234, 235, 237, 238, 243, 244, 245, 246, 251, 252, 254, 255, 256, 257, 258, 260, 263, 264, 265, 267, 268, 269, 271, 272, 275, 276, 277, 280, 281, 282, 283, 292, 305, 306, 307, 309, 311, 314, 316, 319, 322, 323, 326, 330, 333, 334, 335, 337, 341, 343, 344, 349, 352, 354, 355, 356, 361, 362, 363, 370, 373, 375, 376, 377, 380, 384, 386, 387, 392, 395, 397, 399, 400, 401, 407, 409, 410, 411, 418, 419, 420, 421, 424, 425, 427, 428, 




In [16]:
# for k in ks:
#     for n in ns:
#         print(f"BX1 with k={k} n={n}")
#         evaluate_benchmark(
#             context_benchmark,
#             "context",
#             k,
#             collection,
#             retriever,
#             stemmer,
#             n=n,
#             use_reranker=use_reranker,
#         )
#         print("=" * 50)

In [17]:
# for k in ks:
#     for n in ns:
#         print(f"BX2 with k={k} n={n}")
#         evaluate_benchmark(
#             context_benchmark,
#             "context",
#             k,
#             collection,
#             retriever,
#             stemmer,
#             n=n,
#             use_rephrased_questions=True,
#             use_reranker=use_reranker,
#         )
#         print("=" * 50)