In [None]:
# Import necessary packages
import pandas as pd
import torch
import warnings

# Suppress every warning for the rest of the kernel session.
# NOTE(review): this is a blanket filter, so warnings raised by pandas or torch
# in later cells are hidden as well — consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

from pipeline.pipeline_initializer import initialize_pipeline
from pipeline.prompting_interface import prompt_pipeline

In [None]:
# Build the pipeline object that every prompt in this notebook is sent through.
# presumably the second argument selects the weight dtype — confirm in
# pipeline.pipeline_initializer.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model_dtype = torch.bfloat16
pipe = initialize_pipeline(model_name, model_dtype)

In [None]:
def get_prompt(metadata: str, question: str):
    """Build the yes/no relevance prompt for one (metadata, question) pair.

    The prompt has three lines: the metadata wrapped in double quotes, the
    question wrapped in double quotes, and a fixed instruction asking whether
    the dataset described by the metadata answers the question, with the reply
    required to begin with yes/no.

    Args:
        metadata: Text describing a dataset.
        question: The question to test against that dataset.

    Returns:
        The assembled prompt string.
    """
    instruction = (
        "Metadata describes a specific dataset that we have access to. "
        "Does this dataset answer the question? "
        "Begin your response with yes/no."
    )
    prompt_lines = [
        f'Metadata:"{metadata}"',
        f'Question:"{question}"',
        instruction,
    ]
    return "\n".join(prompt_lines)

In [None]:
# Input files — adjust both paths for the dataset being processed.
# Later cells read the columns "table", "context" and "question" from bx1 and
# "table", "context" and "context_question" from contexts.
bx1_path = "BX1_chicago.csv"
contexts_path = "contexts_chicago.csv"

bx1 = pd.read_csv(bx1_path)
contexts = pd.read_csv(contexts_path)

In [None]:
# Attach to every BX1 row the context question recorded in `contexts` for the
# same (table, context) pair. When several rows of `contexts` match, the first
# one is used.
questions_to_get_contexts = []
# Iterate over the values directly so the lookup does not depend on bx1 having
# a 0..n-1 index.
for table, context in zip(bx1["table"], bx1["context"]):
    matches = contexts.loc[
        (contexts["table"] == table) & (contexts["context"] == context),
        "context_question",
    ]
    if matches.empty:
        # Fail with the offending pair instead of an opaque "KeyError: 0".
        raise KeyError(
            f"No row in contexts matches table={table!r} and context={context!r}"
        )
    questions_to_get_contexts.append(matches.iloc[0])
bx1["question_to_get_context"] = questions_to_get_contexts

In [None]:
# For every BX1 row, ask the model which tables in the same context-question
# group have metadata that answers the row's question, then replace the row's
# "table" value with the string form of the sorted list of those tables (the
# row's own table is always included).
# NOTE: bx1["table"] is overwritten in place, so re-running this cell without
# reloading bx1 would treat the already-converted strings as table names.
dfd_questions = bx1["question_to_get_context"].unique()
# NOTE(review): only the first three context questions are processed — this
# looks like a debugging limit; confirm before a full run.
for dfd_question in dfd_questions[:3]:
    print(f"Processing question: {dfd_question}")
    # Boolean indexing yields a copy, so the writes to bx1 below never alter the
    # table names read from this group.
    filtered_bx1 = bx1[bx1["question_to_get_context"] == dfd_question]
    for i in filtered_bx1.index:
        print(f"Processing row: {i}")
        question = filtered_bx1.at[i, "question"]
        init_table = filtered_bx1.at[i, "table"]

        relevant_tables = {init_table}
        for j in filtered_bx1.index:
            print(f"Checking metadata {j}")
            metadata = filtered_bx1.at[j, "context"]
            table = filtered_bx1.at[j, "table"]

            # A table that is already selected cannot change the result, so
            # skip the (expensive) model call for it.
            if table in relevant_tables:
                continue

            prompt = get_prompt(metadata, question)
            conversation = [{"role": "user", "content": prompt}]
            model_output = prompt_pipeline(pipe, conversation)[-1]["content"]
            # Strip first so leading whitespace in the reply does not hide a "yes".
            if model_output.strip().lower().startswith("yes"):
                relevant_tables.add(table)
        bx1.loc[i, "table"] = str(sorted(relevant_tables))
        # Written after every row — presumably to checkpoint progress during
        # the slow model calls.
        bx1.to_csv("BX1_chicago_corrected.csv", index=False)  # Adjust name