# 1. Preparation

In [1]:
import os
import setproctitle
import json
import chromadb
import pandas as pd
import duckdb
import bm25s
import Stemmer

from transformers import set_seed
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# Reproducibility / device setup for the whole notebook.
# cuBLAS workspace setting, set together with deterministic=True below
# (PyTorch's reproducibility notes require it for deterministic cuBLAS ops).
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
# Expose only GPU 0 to this process. NOTE(review): only effective if CUDA has not
# already been initialised by the imports above — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Rename the OS-level process title.
setproctitle.setproctitle("python")
# Seed the RNGs with 42 and request deterministic algorithms.
set_seed(42, deterministic=True)

In [2]:
def read_jsonl(file_path):
    """Read a JSON Lines file into a list of parsed objects.

    Each non-blank line is parsed with ``json.loads``. Whitespace-only lines (for
    example an extra trailing newline) are skipped instead of raising
    ``json.JSONDecodeError``.

    Args:
        file_path: path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        List of the decoded objects, in file order.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            stripped = line.strip()
            if not stripped:
                continue
            data.append(json.loads(stripped))
    return data
# Dense encoder used both to embed the indexed documents and to embed queries.
# trust_remote_code=True allows code shipped with the model repository to run on load.
model = SentenceTransformer('dunzhang/stella_en_1.5B_v5', trust_remote_code=True)
# English stemmer (PyStemmer) passed to bm25s.tokenize for the keyword index and queries.
stemmer = Stemmer.Stemmer("english")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
dataset = "public"  # one of "chicago", "public", "chembl"
if dataset == "chicago":
    col_descriptions = pd.read_csv("../pneuma_summarizer_narrations/chicago_cols.csv")
    jsonl_data = read_jsonl("benchmarks/content/pneuma_chicago_10K_questions_annotated.jsonl")
    # context_benchmark = pd.read_csv("benchmarks/BX1_chicago_corrected.csv")
    contexts = pd.read_csv("contexts/contexts_chicago.csv")
    path = "datasets/pneuma_chicago_10K"
elif dataset == "public":
    col_descriptions = pd.read_csv("narrations/public_cols.csv")
    jsonl_data = read_jsonl("benchmarks/content/pneuma_public_bi_questions_annotated.jsonl")
    # context_benchmark = pd.read_csv("benchmarks/BX1_public_corrected.csv")
    contexts = pd.read_csv("contexts/contexts_public.csv")
    path = "datasets/public_bi_benchmark"
elif dataset == "chembl":
    col_descriptions = pd.read_csv("narrations/chembl_cols.csv")
    jsonl_data = read_jsonl("benchmarks/content/pneuma_chembl_10K_questions_annotated.jsonl")
    # context_benchmark = pd.read_csv("")
    contexts = pd.read_csv("contexts/contexts_chembl.csv")
    path = "datasets/pneuma_chembl_10K"
else:
    # Fail here rather than with a NameError on `path` further down.
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'chicago', 'public' or 'chembl'")

# In-memory DuckDB connection used to read the table CSVs.
con = duckdb.connect()
# Table names = CSV file names without the ".csv" suffix, sorted by that stripped name.
tables = sorted(file[:-4] for file in os.listdir(path) if file.endswith(".csv"))

import shutil
# Start from an empty Chroma store for this dataset; ignore_errors covers the first
# run, when the directory does not exist yet.
# NOTE(review): re-running this cell in the same kernel deletes the directory under a
# live client — if chromadb complains, delete the collection via the client instead.
shutil.rmtree(f"experiment-{dataset}", ignore_errors=True)
client = chromadb.PersistentClient(f"experiment-{dataset}")
collection = client.create_collection(name="benchmark", metadata={"hnsw:space": "cosine"})

In [None]:
def get_processed_df(path: str, table: str) -> pd.DataFrame:
    """Load ``{path}/{table}.csv`` via the module-level DuckDB connection ``con``.

    Duplicate rows are dropped and the index is reset to 0..n-1. Every datetime
    column is rendered as text such as ``"January 5, 2020 13:04:05.123"``
    (millisecond precision); missing timestamps become ``"NaT"``.
    """
    def _format_timestamp(ts) -> str:
        if not pd.notnull(ts):
            return 'NaT'
        month = ts.strftime('%B ')
        day = str(ts.day).lstrip('0')
        # '%f' gives microseconds; dropping the last 3 characters keeps milliseconds.
        tail = ts.strftime(', %Y %H:%M:%S.%f')[:-3]
        return month + day + tail

    frame = (
        con.sql(f"from '{path}/{table}.csv'")
        .to_df()
        .drop_duplicates()
        .reset_index(drop=True)
    )
    datetime_cols = [name for name in frame.columns if pd.api.types.is_datetime64_any_dtype(frame[name])]
    for name in datetime_cols:
        frame[name] = pd.to_datetime(frame[name], errors='coerce').apply(_format_timestamp)
    return frame

In [None]:
def find_largest_smaller_or_equal(tokens_list: list[int], max_tokens: int):
    """Return the highest index ``i`` with ``tokens_list[i] <= max_tokens``, or -1 if none.

    The list is scanned from the end, so this works for unsorted input as well.
    """
    return next(
        (position for position in reversed(range(len(tokens_list))) if tokens_list[position] <= max_tokens),
        -1,
    )

In [None]:
def get_relevancy_prompt(dataset: str, question: str):
    """Build the yes/no relevancy prompt for one serialised table chunk.

    Both the table text and the question are wrapped in balanced ``/* ... */``
    delimiters. The table block used to open with ``*/``, which left it unbalanced
    and inconsistent with the question block.

    Args:
        dataset: serialised table text (header line plus rows).
        question: the natural-language question.

    Returns:
        The prompt, which asks the model to begin its answer with yes/no.
    """
    return f"""Given this dataset:
/*
{dataset}
*/
and this question:
/*
{question}
*/
Is the dataset relevant to answer the question? Begin your answer with yes/no."""

In [None]:
from utils.pipeline_initializer import initialize_pipeline
from utils.prompting_interface import prompt_pipeline
import torch

# Llama-3-8B-Instruct pipeline in bfloat16. It is only used by is_table_relevant
# (reached through rerank when ranking=True), but it is loaded unconditionally here.
# initialize_pipeline is project-local — NOTE(review): device placement not visible here.
pipe = initialize_pipeline("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16)

# 2. Indexing

In [None]:
import math
def indexing_vector(model, collection, tables: list[str], contexts: pd.DataFrame, max_batch_size: int = 5000):
    """Embed column headers and table contexts and add them to a Chroma collection.

    For each table, two kinds of documents are embedded with ``model.encode``:

    * the table's column names joined with " | " — id ``{table}_columns``,
      metadata ``{"table": f"{table}_SEP_columns"}``;
    * every ``contexts["context"]`` whose ``contexts["table"]`` ends with the table
      name — id ``{table}_context_{i}``, metadata ``{"table": f"{table}_SEP_{i}"}``.

    Column names are read from ``{path}/{table}.csv`` through the module-level DuckDB
    connection ``con``; ``path`` is also a module-level setting.

    Args:
        model: encoder exposing ``encode(text)`` whose result has ``tolist()``.
        collection: Chroma collection to insert into.
        tables: table names (CSV file names without extension).
        contexts: DataFrame with ``table`` and ``context`` columns.
        max_batch_size: maximum number of documents per ``collection.add`` call.
            It replaces the global ``collection_max_batch_size``, which the notebook
            never defined. NOTE(review): confirm 5000 is within the client's limit;
            chromadb exposes it on the client as ``max_batch_size`` or
            ``get_max_batch_size()``, depending on version.

    Raises:
        ValueError: if ``max_batch_size`` < 1.
    """
    if max_batch_size < 1:
        raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")

    embeddings = []
    metadatas = []
    ids = []

    for table in tqdm(tables):
        # Only the header is needed; the relation's `.columns` avoids loading the CSV into pandas.
        columns = " | ".join(con.sql(f"select * from '{path}/{table}.csv'").columns)
        embeddings.append(model.encode(columns).tolist())
        metadatas.append({"table": f"{table}_SEP_columns"})
        ids.append(f"{table}_columns")

        # NOTE(review): endswith() also matches other tables whose name ends with `table`
        # (e.g. "xfoo" for "foo") — confirm the contexts file's naming rules this out.
        table_contexts = contexts.loc[contexts["table"].str.endswith(table), "context"]
        for ctx_idx, context in enumerate(table_contexts):
            embeddings.append(model.encode(context).tolist())
            metadatas.append({"table": f"{table}_SEP_{ctx_idx}"})
            ids.append(f"{table}_context_{ctx_idx}")

    # Insert in contiguous, non-overlapping slices. The modulo-based wrap-around bounds
    # used before dropped documents when there were fewer than one batch, and re-sent
    # the first ids when the last batch was partial.
    for start in tqdm(range(0, len(embeddings), max_batch_size), "inserting into collections"):
        stop = start + max_batch_size
        collection.add(
            embeddings=embeddings[start:stop],
            metadatas=metadatas[start:stop],
            ids=ids[start:stop],
        )

def indexing_keyword(stemmer, col_descriptions: pd.DataFrame, tables: list[str], contexts: pd.DataFrame):
    """Build a BM25 index over column descriptions and table contexts.

    For each table, the following documents are added:

    * the " | "-joined ``description`` values of the ``col_descriptions`` rows
      whose ``table`` equals the table name — metadata ``{"table": f"{table}_SEP_columns"}``;
    * one document per ``contexts`` row whose ``table`` ends with the table name —
      metadata ``{"table": f"{table}_SEP_{i}"}``.

    The corpus is tokenized with English stopwords removed and ``stemmer`` applied.

    Args:
        stemmer: stemmer passed to ``bm25s.tokenize``.
        col_descriptions: DataFrame with ``table`` and ``description`` columns.
            The previous ``str`` annotation was wrong; the function indexes it as a DataFrame.
        tables: table names to index.
        contexts: DataFrame with ``table`` and ``context`` columns.

    Returns:
        A ``bm25s.BM25`` retriever built with ``corpus=`` the list of
        ``{"text", "metadata"}`` dicts, so retrieval results are those dicts.
    """
    corpus_json = []

    for table in tqdm(tables):
        table_descriptions = col_descriptions.loc[col_descriptions["table"] == table, "description"].to_list()
        # NOTE(review): a NaN description would make this join raise TypeError — confirm none exist.
        col_summary = " | ".join(table_descriptions)
        corpus_json.append({"text": col_summary, "metadata": {"table": f"{table}_SEP_columns"}})

        # NOTE(review): endswith() also matches tables whose name merely ends with `table`;
        # indexing_vector selects contexts the same way, so the two indexes stay aligned.
        table_contexts = contexts.loc[contexts["table"].str.endswith(table), "context"]
        for context_idx, context in enumerate(table_contexts):
            corpus_json.append({"text": context, "metadata": {"table": f"{table}_SEP_{context_idx}"}})

    corpus_text = [doc["text"] for doc in corpus_json]
    corpus_tokens = bm25s.tokenize(corpus_text, stopwords="en", stemmer=stemmer, show_progress=False)

    retriever = bm25s.BM25(corpus=corpus_json)
    retriever.index(corpus_tokens, show_progress=False)

    return retriever

In [None]:
# Dense index (Chroma `collection`): column names + per-table contexts.
# Keyword index (BM25 `retriever`): column descriptions + per-table contexts.
# Both use the same "{table}_SEP_{suffix}" metadata so their scores can be fused in `process`.
indexing_vector(model, collection, tables, contexts)
retriever = indexing_keyword(stemmer, col_descriptions, tables, contexts)

# 3. Benchmarking

In [None]:
def is_table_relevant(
    path: str,
    table: str,
    max_tokens: int,
    question: str
) -> bool:
    """Ask the LLM whether ``table`` can help answer ``question``.

    The table (loaded with ``get_processed_df``) is serialised as one ``col:`` header
    line plus one ``row i:`` line per row. Rows are packed greedily into chunks so that
    header + chunk rows stay within ``max_tokens`` tokens (counted with
    ``pipe.tokenizer``). Chunks are sent to the LLM one at a time, stopping at the
    first answer that starts with "yes".

    Args:
        path: directory holding ``{table}.csv``.
        table: table name without the ``.csv`` extension.
        max_tokens: token budget for the header plus one chunk's rows. The prompt
            template and the question are not counted.
        question: natural-language question.

    Returns:
        True as soon as a chunk is judged relevant. False if no chunk is, if the table
        has no rows, or if the next row alone (plus header) exceeds ``max_tokens``;
        in that last case the remaining rows are not examined.
    """
    df = get_processed_df(path, table)
    columns = "col: " + " | ".join(df.columns)
    # df has a 0..n-1 index (reset in get_processed_df), so row numbers are 1-based positions.
    rows = [f"row {row_idx+1}: " + " | ".join(row.astype(str)) for row_idx, row in df.iterrows()]

    header_tokens = len(pipe.tokenizer.tokenize(columns))
    row_tokens = [len(pipe.tokenizer.tokenize(row_text)) for row_text in rows]
    # The header is sent once per chunk, so count it once per chunk — not once per row,
    # as the former running total did, which made chunks far smaller than the budget.
    row_budget = max_tokens - header_tokens

    start = 0
    while start < len(rows):
        end, used = start, 0
        while end < len(rows) and used + row_tokens[end] <= row_budget:
            used += row_tokens[end]
            end += 1
        if end == start:
            # Even a single row does not fit alongside the header.
            return False

        prompt = get_relevancy_prompt(
            columns + "\n" + "\n".join(rows[start:end]),
            question
        )

        answer: str = prompt_pipeline(
            pipe, [{"role": "user", "content": prompt}], context_length=8192, max_new_tokens=3, top_p=None, temperature=None
        )[-1]["content"]

        # Tolerate leading whitespace/newlines before the yes/no token.
        if answer.strip().lower().startswith("yes"):
            return True

        start = end
    return False

In [None]:
from collections import defaultdict
def rerank(nodes: list[tuple], question: str, max_tokens: int = 7000):
    """Move nodes whose source table the LLM judges relevant to the front.

    ``nodes`` are ``(node_id, score)`` pairs as returned by ``process``. Node ids have
    the form ``"{table}_SEP_{suffix}"``, so the table name is the part before
    ``_SEP_``. Passing the raw id would make ``is_table_relevant`` look for a
    nonexistent ``{table}_SEP_{suffix}.csv``.

    Relevance is decided once per table, because several nodes can come from the same
    table and each check costs LLM calls. The original order is preserved within the
    relevant group and within the non-relevant group.

    Args:
        nodes: ``(node_id, score)`` pairs, already sorted by score.
        question: question forwarded to ``is_table_relevant``.
        max_tokens: per-chunk token budget forwarded to ``is_table_relevant``.

    Returns:
        The same ``(node_id, score)`` pairs: relevant-table nodes first, then the rest.
    """
    def _table_of(node_id: str) -> str:
        return node_id.split("_SEP_")[0]

    relevance_by_table: dict[str, bool] = {}
    for node_id, _ in nodes:
        table = _table_of(node_id)
        if table not in relevance_by_table:
            relevance_by_table[table] = is_table_relevant(path, table, max_tokens, question)

    relevant = [(node_id, score) for node_id, score in nodes if relevance_by_table[_table_of(node_id)]]
    others = [(node_id, score) for node_id, score in nodes if not relevance_by_table[_table_of(node_id)]]
    return relevant + others

In [None]:
def process_nodes_bm25(results, scores):
    """Min-max normalise BM25 scores into a ``{metadata["table"]: score}`` dict.

    ``results[0]`` holds the retrieved corpus dicts and ``scores[0]`` their raw scores,
    position by position. When all scores are equal, every node gets score 1.
    """
    raw_scores: list[float] = scores[0]
    lo, hi = min(raw_scores), max(raw_scores)
    span = hi - lo

    def _normalise(value):
        return 1 if lo == hi else (value - lo) / span

    return {
        doc["metadata"]["table"]: _normalise(raw_scores[pos])
        for pos, doc in enumerate(results[0])
    }

In [None]:
def process_nodes_vec(items):
    """Min-max normalise Chroma query results into a ``{metadata["table"]: score}`` dict.

    Each distance in ``items["distances"][0]`` is first turned into a similarity
    (``1 - distance``). When all similarities are equal, every node gets score 1.
    """
    similarities: list[float] = [1 - distance for distance in items["distances"][0]]
    lo, hi = min(similarities), max(similarities)
    span = hi - lo

    def _normalise(value):
        return 1 if lo == hi else (value - lo) / span

    return {
        meta["table"]: _normalise(similarities[pos])
        for pos, meta in enumerate(items['metadatas'][0])
    }

In [None]:
def process(bm25_res, bm25_sc, vec_res, k: int, question: str, ranking=False, bm25_weight: float = 0.5):
    """Fuse BM25 and vector results into one ranked list of ``(node_id, score)``.

    Each source is min-max normalised separately. A node missing from one source gets
    0 for that source. The fused score is
    ``bm25_weight * bm25 + (1 - bm25_weight) * vector``; the default 0.5 weighs both equally.
    Ties are broken by node id, so the ordering is deterministic.

    Args:
        bm25_res, bm25_sc: results and scores from ``retriever.retrieve``.
        vec_res: result of ``collection.query``.
        k: number of fused nodes to keep.
        question: forwarded to ``rerank`` when ``ranking`` is set.
        ranking: if True, reorder the top-``k`` nodes with the LLM ``rerank``.
        bm25_weight: weight of the BM25 score in the fused score.

    Returns:
        Up to ``k`` ``(node_id, score)`` pairs, best first (or reranked).
    """
    processed_nodes_bm25 = process_nodes_bm25(bm25_res, bm25_sc)
    processed_nodes_vec = process_nodes_vec(vec_res)

    vec_weight = 1 - bm25_weight
    all_nodes = [
        (
            node_id,
            bm25_weight * processed_nodes_bm25.get(node_id, 0.0)
            + vec_weight * processed_nodes_vec.get(node_id, 0.0),
        )
        for node_id in processed_nodes_bm25.keys() | processed_nodes_vec.keys()
    ]

    sorted_nodes = sorted(all_nodes, key=lambda node: (-node[1], node[0]))[:k]
    if ranking:
        return rerank(sorted_nodes, question)
    return sorted_nodes

In [None]:
def benchmark(jsonl_data, k, model, collection, retriever, stemmer, ranking=False, rephrase=False):
    """Report and return the number of hits@k of the hybrid retriever.

    For each datum, the question is ``datum["question"]`` if ``rephrase`` is set and
    ``datum["question_from_sql_1"]`` otherwise. It is run against the BM25
    ``retriever`` and the Chroma ``collection``, and the results are fused by
    ``process``. A question counts as a hit when any of the top-``k`` node ids,
    stripped of its ``_SEP_`` suffix, is in ``datum["answer_tables"]``. With
    ``ranking``, the candidate pool is ``3 * k`` and is reranked by the LLM before the cut.

    Prints the hit count, the hit rate (0.0 for empty input), and the indices of
    missed questions.

    Returns:
        Number of hits.
    """
    increased_k = k * 3 if ranking else k
    question_key = "question" if rephrase else "question_from_sql_1"

    hitrate_sum = 0
    wrong_list = []
    for idx, datum in enumerate(tqdm(jsonl_data)):
        question = datum[question_key]
        answer_tables = datum["answer_tables"]
        question_embedding = model.encode(question).tolist()

        # Tokenize the query the same way indexing_keyword tokenized the corpus (stopwords="en").
        query_tokens = bm25s.tokenize(question, stopwords="en", stemmer=stemmer, show_progress=False)
        results, scores = retriever.retrieve(query_tokens, k=increased_k, show_progress=False)

        vec_res = collection.query(
            query_embeddings=[question_embedding],
            n_results=increased_k
        )

        all_nodes = process(results, scores, vec_res, increased_k, question, ranking)
        if any(node_id.split("_SEP_")[0] in answer_tables for node_id, _ in all_nodes[:k]):
            hitrate_sum += 1
        else:
            wrong_list.append(idx)

    hit_rate = hitrate_sum / len(jsonl_data) if len(jsonl_data) else 0.0
    print(f"Final Hit Rate Sum: {hitrate_sum}")
    print(f"Hit Rate: {hit_rate}")
    print(f"Wrong List: {wrong_list}")
    return hitrate_sum

In [None]:
# Cut-offs k at which the hit rate is reported.
ks = [1, 10, 30, 50]

In [None]:
# Hit rate on the rephrased questions (datum["question"]), without LLM reranking.
# NOTE: `result` only keeps the hit count of the last k; the per-k numbers are printed.
for k in ks:
    print(f"K (no rerank): {k}")
    result = benchmark(jsonl_data, k, model, collection, retriever, stemmer, rephrase=True)

In [None]:
# Hit rate on the SQL-derived questions (datum["question_from_sql_1"]), without LLM reranking.
# NOTE: `result` only keeps the hit count of the last k; the per-k numbers are printed.
for k in ks:
    print(f"K (no rerank): {k}")
    result = benchmark(jsonl_data, k, model, collection, retriever, stemmer)

In [None]:
def benchmark_context(benchmark: pd.DataFrame, k, model, collection, retriever, stemmer, ranking=False):
    """Report and return hits@k over a DataFrame with ``question`` / ``answer_tables`` columns.

    Uses the same retrieval and fusion as ``benchmark``. Node ids are compared on
    their table part (before ``_SEP_``), also as ``benchmark`` does; comparing the raw
    ``{table}_SEP_{suffix}`` ids against table names could never match.

    Prints the hit count, the hit rate (0.0 for an empty frame), and the index labels
    of missed rows.

    Returns:
        Number of hits.
    """
    increased_k = k * 3 if ranking else k

    hitrate_sum = 0
    wrong_list = []
    for row_idx, row in tqdm(benchmark.iterrows(), total=len(benchmark)):
        question = row["question"]
        # NOTE(review): if this column was read from CSV it may be a stringified list,
        # in which case `in` below is a substring test — confirm the benchmark format.
        answer_tables = row["answer_tables"]
        question_embedding = model.encode(question).tolist()

        # Tokenize the query the same way indexing_keyword tokenized the corpus (stopwords="en").
        query_tokens = bm25s.tokenize(question, stopwords="en", stemmer=stemmer, show_progress=False)
        results, scores = retriever.retrieve(query_tokens, k=increased_k, show_progress=False)

        vec_res = collection.query(
            query_embeddings=[question_embedding],
            n_results=increased_k
        )

        all_nodes = process(results, scores, vec_res, increased_k, question, ranking)
        if any(node_id.split("_SEP_")[0] in answer_tables for node_id, _ in all_nodes[:k]):
            hitrate_sum += 1
        else:
            wrong_list.append(row_idx)

    hit_rate = hitrate_sum / len(benchmark) if len(benchmark) else 0.0
    print(f"Final Hit Rate Sum: {hitrate_sum}")
    print(f"Hit Rate: {hit_rate}")
    print(f"Wrong List: {wrong_list}")
    return hitrate_sum

In [None]:
# for k in ks:
#     print(f"K (no rerank): {k}")
#     result = benchmark_context(context_benchmark, k, model, collection, retriever, stemmer)