In [8]:
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
import torch
import json
from pathlib import Path
from typing import List, Dict, Tuple
import pickle
from tqdm import tqdm
from groq import Groq
import os
from dotenv import load_dotenv
from rouge_score import rouge_scorer
load_dotenv()
os.environ['GROQ_API_KEY'] = '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'

In [2]:
test_df_2hop = pd.read_csv('../dataset/MetaQA/2-hop/qa_test.txt', sep='\t', header=None, names=['question', 'answer'])
test_df_2hop

Unnamed: 0,question,answer
0,which person directed the movies starred by [J...,Nancy Meyers|Sam Mendes|George Clooney|Ken Kwa...
1,who are movie co-directors of [Delbert Mann],Franco Zeffirelli|Cary Fukunaga|Lewis Mileston...
2,what are the primary languages in the movies d...,German
3,the screenwriter [Mimsy Farmer] co-wrote movie...,Barbet Schroeder
4,the films acted by [Shaun White] were in which...,Sport|Documentary
...,...,...
14867,who are the actors in the movies directed by [...,Gary Sinise|Debra Messing|Ashton Kutcher|Marti...
14868,the films directed by [Larry Charles] were rel...,2003|2009|2012
14869,the director [Frank Oz] co-directed films with...,Neil LaBute|Jim Henson|Bryan Forbes
14870,what were the release dates of [Ritwik Ghatak]...,1958


In [3]:
class MovieKnowledgeBase:
    def __init__(
        self,
        model_name: str = 'sentence-transformers/all-MiniLM-L6-v2',
        gpu_id: int = 0,
        save_dir: str = './knowledge_base'
    ):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(exist_ok=True)
        
        self.device = f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu'
        self.model = SentenceTransformer(model_name, device=self.device)
        self.embedding_size = self.model.get_sentence_embedding_dimension()
        
        if torch.cuda.is_available():
            self.gpu_resources = faiss.StandardGpuResources()
            self.cpu_index = faiss.IndexFlatIP(self.embedding_size)
            self.index = faiss.index_cpu_to_gpu(
                self.gpu_resources, gpu_id, self.cpu_index
            )
        else:
            self.index = faiss.IndexFlatIP(self.embedding_size)
        
        self.movie_facts = []
        self.fact_texts = []
        
    def _process_kb_line(self, line: str) -> Dict:
        """Process a single line from the knowledge base file."""
        movie, relation, value = line.strip().split('|')
        return {
            'movie': movie,
            'relation': relation,
            'value': value
        }
    
    def _create_fact_text(self, fact: Dict) -> str:
        """Create a natural language representation of a fact."""

        relation_templates = {
            'directed_by': '{movie} was directed by {value}',
            'written_by': '{movie} was written by {value}',
            'starred_actors': '{value} starred in {movie}',
            'release_year': '{movie} was released in {value}',
            'in_language': '{movie} is in {value}',
            'has_genre': '{movie} is a {value} movie',
            'has_tags': '{movie} has tag: {value}',
            'has_imdb_votes': '{movie} has {value} IMDb votes'
        }
        
        template = relation_templates.get(
            fact['relation'],
            '{movie} {relation} {value}'
        )
        
        return template.format(**fact)
    
    def load_knowledge_base(self, kb_file: str):
        """Load and process the knowledge base file."""
        print("Loading knowledge base...")
        with open(kb_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        
        for line in lines:
            fact = self._process_kb_line(line)
            self.movie_facts.append(fact)
            fact_text = self._create_fact_text(fact)
            self.fact_texts.append(fact_text)
        
        print("Creating embeddings...")
        embeddings = self.model.encode(
            self.fact_texts,
            convert_to_numpy=True,
            normalize_embeddings=True,
            batch_size=32
        )
        
        self.index.add(embeddings)
        print(f"Added {len(self.fact_texts)} facts to the index")
        
        self.save_knowledge_base()
    
    def save_knowledge_base(self):
        """Save the processed knowledge base and index."""
        print("Saving knowledge base...")
        
        with open(self.save_dir / 'fact_texts.pkl', 'wb') as f:
            pickle.dump(self.fact_texts, f)
        
        with open(self.save_dir / 'movie_facts.pkl', 'wb') as f:
            pickle.dump(self.movie_facts, f)
        
        if torch.cuda.is_available():
            cpu_index = faiss.index_gpu_to_cpu(self.index)
        else:
            cpu_index = self.index
            
        faiss.write_index(cpu_index, str(self.save_dir / 'faiss_index.bin'))
        print("Knowledge base saved successfully")
    
    def load_saved_knowledge_base(self):
        """Load a previously saved knowledge base."""
        print("Loading saved knowledge base...")
        
        with open(self.save_dir / 'fact_texts.pkl', 'rb') as f:
            self.fact_texts = pickle.load(f)
        
        with open(self.save_dir / 'movie_facts.pkl', 'rb') as f:
            self.movie_facts = pickle.load(f)
        
        cpu_index = faiss.read_index(str(self.save_dir / 'faiss_index.bin'))
        
        if torch.cuda.is_available():
            self.index = faiss.index_cpu_to_gpu(
                self.gpu_resources, 0, cpu_index
            )
        else:
            self.index = cpu_index
            
        print("Knowledge base loaded successfully")
    
    def retrieve_context(self, question: str, k: int = 5) -> List[str]:
        """Retrieve relevant context for a question."""

        question_embedding = self.model.encode(
            question,
            convert_to_numpy=True,
            normalize_embeddings=True
        )
        
        scores, indices = self.index.search(
            question_embedding.reshape(1, -1),
            k
        )
        
        relevant_facts = [self.fact_texts[idx] for idx in indices[0]]
        return relevant_facts

In [36]:
def multi_step_retrieve(kb: MovieKnowledgeBase, question: str, k_array: List[int]) -> List[str]:
    """
    Perform multi-step retrieval, accounting for self-retrieval in subsequent steps.
    """
    all_retrieved = set()  # Use set to avoid duplicates
    current_level_queries = [question]  # Start with original question
    
    for step, k in enumerate(k_array):
        next_level_queries = []
        
        # For each query at this step
        for query in current_level_queries:
            # Get embeddings and search
            query_embedding = kb.model.encode(
                query,
                convert_to_numpy=True,
                normalize_embeddings=True
            )
            
            # For subsequent steps (after first), retrieve k+1 and skip first result
            if step == 0:
                num_to_retrieve = k
                start_idx = 0
            else:
                num_to_retrieve = k + 1
                start_idx = 1  # Skip the first result (self-retrieval)
            
            scores, indices = kb.index.search(
                query_embedding.reshape(1, -1),
                num_to_retrieve
            )
            
            # Get retrieved facts for this query, skipping self-retrieval if necessary
            retrieved_facts = [kb.fact_texts[idx] for idx in indices[0][start_idx:]]
            
            # Add to overall set
            all_retrieved.update(retrieved_facts)
            
            # Add to next level queries
            next_level_queries.extend(retrieved_facts)
        
        # Update current level queries for next iteration
        current_level_queries = next_level_queries
        
        if not current_level_queries or step == len(k_array) - 1:
            break
    
    retrieved_list = list(all_retrieved)
    print(f"\nTotal unique contexts retrieved: {len(retrieved_list)}")
    print(f"Expected total (max possible unique): {sum([k_array[0] * (k_array[1] ** i) for i in range(len(k_array))])}")
    return retrieved_list

In [37]:
def process_qa_dataset(
    kb: MovieKnowledgeBase,
    qa_df: pd.DataFrame,
    llm_call_function,
    k_array: List[int] = [3, 2]  # Example: First get top 3, then top 2 for each
) -> List[str]:
    """
    Process QA dataset with multi-step retrieval knowledge base context.
    
    Args:
        kb: MovieKnowledgeBase instance
        qa_df: DataFrame with 'question' and 'answer' columns
        llm_call_function: Function that makes API call to LLM
        k_array: Array specifying number of retrievals at each step
    
    Returns:
        List of LLM responses
    """
    responses = []
    
    for idx, row in qa_df.iterrows():
        question = row['question']
        ground_truth = row['answer']
        
        # Debug print for this question
        print(f"\nProcessing Question: {question}")
        
        # Perform multi-step retrieval
        retrieved_contexts = multi_step_retrieve(kb, question, k_array)
        
        # Debug prints for retrieval results
        print("\nRetrieval breakdown:")
        print(f"Expected total contexts: {sum([k_array[0] * (k_array[1] ** i) for i in range(len(k_array))])}")
        print(f"Actual unique contexts retrieved: {len(retrieved_contexts)}")
        print("\nRetrieved contexts:")
        for i, context in enumerate(retrieved_contexts, 1):
            print(f"{i}. {context}")
        
        context_text = "\n".join(retrieved_contexts)
        
        prompt = f"""Based on the following context, please answer the question.
        
Context:
{context_text}

Question: {question}

Answer: """
        
        response = llm_call_function(prompt)
        print('\nResults:')
        print('Question: ', question)
        print('Ground Truth: ', ground_truth)
        print('Response: ', response)
        print(f'Total unique contexts used: {len(retrieved_contexts)}')
        print('-' * 50)
        
        # Write to a txt file
        with open('qa_results.txt', 'a') as f:
            f.write(f'Question: {question}\n')
            f.write(f'Ground Truth: {ground_truth}\n')
            f.write('Retrieved Contexts:\n')
            for i, context in enumerate(retrieved_contexts, 1):
                f.write(f"{i}. {context}\n")
            f.write(f'Response: {response}\n')
            f.write(f'Total unique contexts: {len(retrieved_contexts)}\n')
            f.write('-' * 50 + '\n\n')
        
        responses.append(response)
    
    return responses

In [None]:
kb = MovieKnowledgeBase(save_dir='./movie_kb')

In [30]:
kb.load_knowledge_base('../dataset/MetaQA/kb.txt')

Loading knowledge base...
Creating embeddings...
Added 134741 facts to the index
Saving knowledge base...
Knowledge base saved successfully


In [31]:
def get_rouge_score(hypothesis: str, reference: str) -> Dict:
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    scores = scorer.score(reference, hypothesis)
    return {
        'rouge1_f1': scores['rouge1'].fmeasure,
        'rouge1_precision': scores['rouge1'].precision,
        'rouge1_recall': scores['rouge1'].recall,
        'rougeL_f1': scores['rougeL'].fmeasure,
        'rougeL_precision': scores['rougeL'].precision,
        'rougeL_recall': scores['rougeL'].recall
    }

In [None]:
client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)

def lm_call(prompt: str) -> str:
    chat_completion = client.chat.completions.create(
        messages=[
        {
            "role": "system",
            "content": """
Rules:
- Answer ONLY using the information provided in the context
- Provide ONLY the answer, with no explanations or additional text
- Keep answers concise and to the point
- If there are multiple answers, output them in a pipe-separated list (e.g. "answer1|answer2")
"""
        },
        {
            "role": "user",
            "content": prompt,
        }
    ],
        temperature=0.7,
        model="llama-3.2-3b-preview",
    )

    return chat_completion.choices[0].message.content

responses = process_qa_dataset(kb, test_df_2hop, lm_call)

In [None]:
for q, a, r in zip(test_df_2hop['question'], test_df_2hop['answer'], responses):
    print(f"\nQuestion: {q}")
    print(f"Ground Truth: {a}")
    print(f"LLM Response: {r}")