## 2024.11.01 - Generative AI | Homework Assignment 2

In this exercise, you will implement the food-for-thought-generating prompts strategy presented by Prof. Gottlob in his lecture. In the food-for-thought-generating prompt strategy we want, we first want to ask an LLM to provide questions that help solve the original query.


*Relevant-Slide:*
<img src="./food-for-thought.png"/>

Passages where you should add your implementation are marked with:

\# YOUR CODE HERE</br>
raise NotImplementedError()

We have provided a simple Hugging Face wrapper so you can test your implementation against an actual llm. To do so, you will need to provide your [personal access token](https://huggingface.co/docs/hub/security-tokens).
The grading is based on a mocked client, similar to visible test cases. Therefore, when you are confident with your implementation, delete your personal access token before submitting the assignment.

In [1]:
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
import requests
import re
from abc import ABC, abstractmethod

In [2]:
@dataclass
class Message:
    role: str
    content: str

class LLMClientInterface(ABC):
    @abstractmethod
    def get_completion(self, messages: List[Message]) -> str:
        pass

class HuggingFaceClient(LLMClientInterface):
    def __init__(self, api_token: str, model_name: str = "HuggingFaceH4/zephyr-7b-alpha"):
        self.api_token = api_token
        self.model_name = model_name
        self.api_url = f"https://api-inference.huggingface.co/models/{model_name}/v1/chat/completions"
        self.headers = { "Authorization": f"Bearer {api_token}", "Content-Type": "application/json" }

    def get_completion(self, messages: List[Message]) -> str:
        messages_dict = [
            {"role": message.role, "content": message.content} for message in messages
        ]
        payload = { "model": self.model_name, "messages": messages_dict, "max_tokens": 500 }
        response = requests.post(self.api_url, headers=self.headers, json=payload)
        if response.status_code != 200:
            raise Exception(f"API request failed with status {response.status_code}: {response.text}")
        return response.json()["choices"][0]["message"]["content"]            

In [8]:
class FoodForThoughtPrompting:
    """
    Implements the food-for-thought prompting strategy for enhanced LLM interactions.
    This strategy breaks down a complex query into sub-questions, gets their answers,
    and synthesizes a final response using this additional context.
    """
    def __init__(self, llm_client: LLMClientInterface):
        self.llm_client = llm_client
        self.system_prompt = "You strive to answer the question as truthfully, precisely and concisely as possible."
    
    def generate_questions(self, query: str) -> List[str]:
        """
        Generates three relevant sub-questions that help break down the main query.
        
        Implementation Guidelines:
        1. Create a prompt that asks the LLM to generate 3 questions whose answers would help solve 
           the original query.

        """

        init_query = f"Which are the three questions Q1,Q2 and Q3 whose answers would be most helpful to solve the following problem: {query}"

        """
        2. Get the LLM's response using self.llm_client.get_completion(messages), where messages is a list of a user prompt and maybe a system prompt (optional).
            e.g. messages = [
                Message(role="system", content="Answer truthfully and concisely [...]"),
                Message(role="user", content="Which are the three questions Q1, [...]")
            ]

        """
        messages = [
        Message(role="system", content="You are a helpful assistant. Please provide three clear and concise questions labeled as Q1, Q2, and Q3."),
        Message(role="user", content=init_query)
        ]

        response = self.llm_client.get_completion(messages)
        """

        4. Extract the three questions from the response (e.g. using regex or simple string splitting).
           You might need to modify your user prompt if you don't get consistently formatted completions.
        
        Args:
            query (str): The original query to be broken down
            
        Returns:
            List[str]: List of exactly three questions
        """

        def _extract_questions(repsponse): 
            pattern = r'Q(\d+):\s*(.*?(?=Q\d+:|$))'
            matches = re.findall(pattern, response, re.DOTALL)
            questions = [match[1].strip() for match in matches]
            return questions[:3]  # Ensure we return exactly 3 questions

        
        questions = _extract_questions(response)
    
        # If we didn't get exactly 3 questions, we might want to try again or handle the error
        if len(questions) != 3:
            raise ValueError(f"Expected 3 questions, but got {len(questions)}")
        
        return questions

    def get_answers(self, questions: List[str]) -> List[str]:
        """
        Generate answers for each of the generated sub-questions.

        Args:
            questions (List[str]): List of questions to be answered
            
        Returns:
            List[str]: List of answers corresponding to each question generated by the LLM.
        """
        answers = []

        for question in questions:
            # Create a message for each question
            messages = [
                Message(role="system", content="You are a helpful assistant. Please provide a clear and concise answer to the following question."),
                Message(role="user", content=question)
            ]

        # Get the LLM's response
            response = self.llm_client.get_completion(messages)

        # Add the response to our list of answers
        answers.append(response.strip())

        return answers
    
    def get_final_answer(self, query: str, qa_pairs: List[Tuple[str, str]]) -> str:
        """
        Synthesizes a final answer using the original query and Q&A context.
        
        Implementation Guidelines:
        1. Create a messages list optionally starting with the system prompt Message.
           It might help to tell the LLM that the whole conversation should be incorporated into its final response
           in the system prompt.
        
        2. For each (question, answer) pair in qa_pairs:
           - Add two Message objects:
             * First with role="user" containing the question
             * Second with role="assistant" containing the answer
        
        3. Add a final Message with role="user" containing the original query
        
        4. Get and return the final response using self.llm_client.get_completion()
        
        Args:
            query (str): The original query
            qa_pairs (List[Tuple[str, str]]): List of (question, answer) tuples providing context
            
        Returns:
            str: Synthesized final answer incorporating the context
        """
        messages = [
            Message(role="system", content="You are a helpful assistant. Your task is to provide a comprehensive answer to the original query based on the context provided in the following Q&A pairs. Incorporate all relevant information from the Q&A pairs into your final response.")
        ]    

        for question, answer in qa_pairs:
            messages.extend([
            Message(role="user", content=question),
            Message(role="assistant", content=answer)
        ])

        messages.append(Message(role="user", content=f"Now, based on all the information provided above, please answer this original query: {query}"))

        # Get the final response from the LLM
        final_answer = self.llm_client.get_completion(messages)

        return final_answer.strip()
        
    def __call__(self, query: str) -> str:
        try:
            questions = self.generate_questions(query)            
            answers = self.get_answers(questions)            
            qa_pairs = list(zip(questions, answers))            
            final_answer = self.get_final_answer(query, qa_pairs)
            return final_answer
        except Exception as e:
            return f"Error processing query: {str(e)}"

In [9]:
# To test your implementation against a real LLM, provide your hugging face access token below.
# Before submitting it, please remove your token again. Grading will be based on a mocked implementation
# If the response seems cut off - no worries; its a known issue
# https://huggingface.co/HuggingFaceH4/zephyr-7b-beta/discussions/52
huggingface_client = HuggingFaceClient("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")
fft_prompting = FoodForThoughtPrompting(huggingface_client)
fft_prompting("How similar are a pen and a marker, on a score ranging from 1 (lowest) to 10 (highest)?")

"Based on their physical design and structure, pens and markers have many similarities but there are also notable differences. On a scale of 1 to 10, with 1 being the least similar and 10 being the most similar, pens and markers could be rated as a 7 in terms of similarity.\n\nBoth pens and markers are cylindrical or hexagonal in shape, have a dispensing mechanism, and come in a variety of barrel colors. However, markers typically have a wider barrel, a larger grip, and often feature a removable cap or lid, which isn't necessarily the case for pens.\n\nAdditionally, markers typically have chisel or bullet-shaped tips, while pens can have fine-point or broad tips. Pens do not have ink that can be refilled, whereas markers often do.\n\nSo while pens and markers share some similarities, they differ in terms of their intended use, design elements, and functional components, bringing their similarity score to around 7."

In [13]:
mock_responses = [
        "You should consider a balanced meal with protein and vegetables.",
        "Lunch is typically eaten between 12:00 and 2:00 PM.",
        "A reasonable budget is $10-15 for a healthy lunch."
    ]
    
client = MockLLMClient(mock_responses)
fft = FoodForThoughtPrompting(client)
    
questions = [
        "What makes a healthy lunch?",
        "When should I eat lunch?",
        "How much should I spend on lunch?"
]
    
answers = fft.get_answers(questions)
print(answers)

['You should consider a balanced meal with protein and vegetables.']


In [10]:
class MockLLMClient(LLMClientInterface):
    def __init__(self, responses):
        self.responses = responses
        self.call_count = 0
        self.messages_history = []
    
    def get_completion(self, messages: List[Message]) -> str:
        self.messages_history.append(messages)
        response = self.responses[self.call_count]
        self.call_count += 1
        return response

def test_generate_questions_success():
    mock_response = """Q1: What are your dietary restrictions or preferences?
Q2: What time of day will you be eating lunch?
Q3: What is your budget for lunch?"""
    
    client = MockLLMClient([mock_response])
    fft = FoodForThoughtPrompting(client)
    questions = fft.generate_questions("What is the best lunch to eat?")
    
    assert len(questions) == 3
    assert all(isinstance(q, str) for q in questions)
    assert "dietary restrictions" in questions[0]
    assert "time of day" in questions[1]
    assert "budget" in questions[2]

def test_get_answers():
    mock_responses = [
        "You should consider a balanced meal with protein and vegetables.",
        "Lunch is typically eaten between 12:00 and 2:00 PM.",
        "A reasonable budget is $10-15 for a healthy lunch."
    ]
    
    client = MockLLMClient(mock_responses)
    fft = FoodForThoughtPrompting(client)
    
    questions = [
        "What makes a healthy lunch?",
        "When should I eat lunch?",
        "How much should I spend on lunch?"
    ]
    
    answers = fft.get_answers(questions)
    assert len(answers) == 3
    assert all(isinstance(a, str) for a in answers)
    assert "balanced meal" in answers[0]
    assert "12:00" in answers[1]
    assert "$10-15" in answers[2]

def test_process_query_end_to_end():
    mock_responses = [
        # Questions generation response
        """Q1: What are your dietary restrictions?
Q2: What time of day will you eat?
Q3: What is your budget?""",
        # Three answers
        "No dietary restrictions.",
        "Lunchtime at 12:30 PM.",
        "Budget is $15.",
        # Final answer
        "Based on the information provided, I recommend a balanced meal..."
    ]
    
    client = MockLLMClient(mock_responses)
    fft = FoodForThoughtPrompting(client)
    
    result = fft("What should I eat for lunch?")
    assert isinstance(result, str)
    assert "balanced meal" in result
    assert client.call_count == 5  # 1 for questions + 3 for answers + 1 for final

def test_error_handling():
    client = MockLLMClient([])  # Empty responses will cause index error
    fft = FoodForThoughtPrompting(client)
    
    result = fft("What should I eat?")
    assert "Error processing query" in result


test_generate_questions_success()
test_get_answers()
test_process_query_end_to_end()
test_error_handling()

print("If you see this message, you are good to go ✅")

AssertionError: 