# Create a Custom Controller

The purpose of this notebook is to build the custom controller and filter used by the Financial Analyst Agent in notebook `4_financial_analyst_agent`.

In council's [LLMController](https://github.com/chain-ml/council/blob/main/council/controllers/llm_controller.py), the `_execute` method uses an LLM to determine which chains to execute. In the custom controller, the `_execute` method is responsible both for selecting which chains to execute and for reformulating the user query for each chain in a way that improves the chains' results. This is accomplished by using an LLM to reformulate the user query based on the conversational history between the user and the agent. For example, imagine asking the agent for the financial performance of Microsoft in one query, and then following up with "what is the stock price?". Routing the follow-up unchanged to a chain such as Google search would not capture the user's intent, which is to find the stock price of Microsoft.

In the custom filter, we use an LLM to write a comprehensive response to the user query by aggregating the results from all the successfully executed chains.

## Import the required modules

In [None]:
import logging
from typing import List, Tuple
from string import Template

from council.chains import Chain
from council.contexts import AgentContext, ScoredChatMessage, ChatMessage, ChatMessageKind, LLMContext
from council.controllers import LLMController, ExecutionUnit
from council.filters import FilterBase
from council.llm import LLMMessage, LLMBase

import constants

logger = logging.getLogger(__name__)

## Specifying constants used in the notebook

In [None]:
COMPANY_NAME = "Microsoft"

## Prompts for Controller

The `controller_system_prompt` and `controller_prompt` below are the prompts the controller sends to the LLM. The `controller_prompt` describes a two-step process for understanding and classifying a user's query in a conversational context.  
The query reformulation and chain selection tasks are performed in a single LLM call that asks the model to perform them sequentially; they are referred to as Subtask 1 and Subtask 2, respectively.

Subtask 1 focuses on understanding and improving the current user query by taking into consideration the context of the conversation, which includes the conversational history. The task is to refine the user query if needed by using context from the conversation history. For instance, if a user asked "Who is the CEO of OpenAI?" and then followed up with "How old is he?", the task would be to update the second query to "How old is Sam Altman?". However, the instructions also specify that if the query does not need updating, or if there's no relevant conversational history, the query should be left unchanged.

Subtask 2 uses the updated query from Subtask 1 to identify the intent of the user, similar to council's `LLMController`. The LLM is presented with a list of categories (each with a name and a description), and is instructed to score each category out of 10 based on how well its description aligns with the intent of the updated query. If no category is relevant, the LLM should respond with 'unknown'.

The prompt finishes by describing the exact format in which to return the results of Subtask 1 and Subtask 2, so that the controller can parse them later.

Note the occurrence of words with a `$` prefix, such as `$user_query`. When building the controller's `_execute` method, we will load this prompt as a `Template` from Python's `string` module; words with this prefix denote variables that will be substituted.
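
As a minimal illustration of this substitution (the shortened template and the query value are made up for the example and are not the notebook's actual prompt), `Template.substitute` replaces `$`-prefixed names and leaves brace placeholders such as `{name}` untouched:

```python
from string import Template

# Shortened, illustrative template; not the full controller prompt
example_prompt = Template("User query: $user_query\nAnswer with {name};{score}")

# Only $user_query is a template variable; {name} and {score} are plain text
filled = example_prompt.substitute(user_query="What is the stock price?")
print(filled)
```

Note that `substitute` raises a `KeyError` when a `$` variable in the template is not supplied, so every variable in the prompt needs a value.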

In [None]:
controller_system_prompt = "You are an assistant responsible to identify the intent of the user."

controller_prompt = """
Use the latest user query and the conversational history to identify the intent of the user. 
Break this task down into 2 subtasks. First perform subtask 1 and then subtask 2.

Context for subtask 1:
Conversational history:
$conversational_history

User query: $user_query

Instructions for subtask 1:
# Use the historical conversation to update the user query to better answer the user question
# If the query does not need to be updated, do not update the query
# If there is no conversational history, do not update the query
# If the conversational history is not relevant to the query, do not update the query
# See the below examples for how to update the user query
************
Example 1:
Conversational History:
User: Who is the CEO of OpenAI?
Assistant: Sam Altman

User Query: How old is he?

Updated Query: How old is Sam Altman?
************
Example 2:
Conversational History:
User: Who is the CEO of OpenAI?
Assistant: Sam Altman

User Query: What is the price of Bitcoin?

Updated Query: What is the price of Bitcoin?
************

Context for subtask 2: 
Categories are given as a name and a description (name: {name}, description: {description}):
$answer_choices

Instructions for subtask 2:
# Use the updated query to identify the intent of the user
# score categories out of 10 using their description
# For each category, you will answer with {name};{score};{short justification}
# The updated query should be identical for each category
# Each response is provided on a new line
# When no category is relevant, you will answer exactly with 'unknown'
                                
Your response should always be formatted like this:
Subtask 1: {updated_query}
---
Subtask 2:
{subtask2_results}
"""

## Controller _execute method

One of the variables that will be substituted into the above prompt is the `conversational_history`. To do so, we first need to create a string representation of the conversational history from the context. 

The `build_chat_history` method below iterates through the messages in the context's chat history and builds a string from the messages that come from either the user or the agent. The `max_history_len` argument, set to 4 by default, is the maximum number of messages to include in the conversational history. This default was chosen on the assumption that more recent conversational history is more relevant to determining the intent of the user's last message.

In [None]:
def build_chat_history(context: AgentContext, max_history_len: int = 4) -> str:
    """Format the chat history into a string that can be added to the prompt for the query reformulation model."""
    chat_history = ""
    # Remove the user's most recent message from the chat history
    message_history = list(context.chat_history.messages[:-1])
    # Return a placeholder if no messages precede the user's latest one
    if len(message_history) < 1:
        return "No conversational history"

    for msg in message_history[-max_history_len:]:
        if msg.is_of_kind(ChatMessageKind.User):
            chat_history += f"User: {msg.message}\n"
        if msg.is_of_kind(ChatMessageKind.Agent):
            chat_history += f"Assistant: {msg.message}\n"

    return chat_history
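
To see the truncation and formatting in isolation, here is a small sketch that mirrors the same logic on plain stand-in objects. `StubMessage` and `format_history` are hypothetical names introduced only for this illustration; they are not part of council.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class StubMessage:
    """Hypothetical stand-in for a council chat message (role and text only)."""
    role: str  # "user" or "agent"
    message: str


def format_history(messages: List[StubMessage], max_history_len: int = 4) -> str:
    """Mirror the logic of build_chat_history on the stand-in objects."""
    history = messages[:-1]  # drop the user's most recent message
    if not history:
        return "No conversational history"
    lines = []
    # Keep only the most recent max_history_len messages
    for msg in history[-max_history_len:]:
        if msg.role == "user":
            lines.append(f"User: {msg.message}\n")
        elif msg.role == "agent":
            lines.append(f"Assistant: {msg.message}\n")
    return "".join(lines)


conversation = [
    StubMessage("user", "Who is the CEO of OpenAI?"),
    StubMessage("agent", "Sam Altman"),
    StubMessage("user", "How old is he?"),
]
print(format_history(conversation))
```

The last message ("How old is he?") is excluded because it is passed to the prompt separately as the user query.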

After the prompt is created and sent to the LLM and a response is received, we use the `parse_response` method below to extract the results of the subtasks. It returns a tuple of two strings: the first is the reformulated query and the second holds the chain selection lines.

In [None]:
def parse_response(response: str) -> Tuple[str, str]:
    """Separate the reformulated query and the chain selection in the LLM response."""
    # The prompt asks for "Subtask 1: ..." and "Subtask 2: ..." separated by a "---" line
    parts = response.split("---")
    query_reformulation_response = parts[0].replace("Subtask 1:", "").strip()
    chain_selection_response = parts[1].replace("Subtask 2:", "").strip()

    return query_reformulation_response, chain_selection_response
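
The sketch below applies the same splitting idea to a sample response in the format the prompt requests. `split_subtasks`, the chain names `search` and `docs`, and the sample text are all hypothetical. It also handles a response without the `---` separator by returning an empty chain selection; that fallback is an assumption about the desired behavior, not something the notebook's method does.

```python
from typing import Tuple


def split_subtasks(response: str) -> Tuple[str, str]:
    """Hypothetical helper: split an LLM response into (updated query, chain selection)."""
    first, separator, second = response.partition("---")
    query = first.replace("Subtask 1:", "").strip()
    # Assumed fallback: with no '---' separator, report no chain selection
    selection = second.replace("Subtask 2:", "").strip() if separator else ""
    return query, selection


sample = """Subtask 1: What is the stock price of Microsoft?
---
Subtask 2:
search;9;the query asks for current market data
docs;2;internal documents are unlikely to contain the price"""

query, selection = split_subtasks(sample)
print(query)
print(selection.splitlines())
```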

The code below processes the chain selection from the LLM response by extracting each chain's name and score, keeping only chains whose score exceeds the `response_threshold` (the minimum score required for the agent to execute the chain), and ordering them by descending score. The `_parse_line` function used to parse each line of the chain selection (Subtask 2 for the custom controller) is a method of council's `LLMController` class. Since our custom controller inherits from `LLMController`, it does not need to be redefined.

An `ExecutionUnit` specifies a chain to be executed, and an initial state for the execution. We store the reformulated query in the `initial_state` variable so that each chain will have access to it during execution. 

We are also keeping a list of the chains that will be executed in the current iteration in the `chains_current_iteration` variable. This will be used later in the `select_responses` method.

In [None]:
# Separate reformulated query and chain selection from response
query_reformulation_result, chain_selection_result = self.parse_response(response)
# Create execution plan and provide reformulated query to each execution unit as its initial state
parsed = [self._parse_line(line, self._chains) for line in chain_selection_result.splitlines()]
filtered = [r.unwrap() for r in parsed if r.is_some() and r.unwrap()[1] > self._response_threshold]
if (filtered is None) or (len(filtered) == 0):
    return []

filtered.sort(key=lambda item: item[1], reverse=True)
# List of chain names to be executed by agent in current iteration
self.chains_current_iteration = [chain.name for chain, _ in filtered]
result = [
    ExecutionUnit(r[0], context.budget, initial_state=ChatMessage.chain(query_reformulation_result))
    for r in filtered
    if r is not None
]
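
The parse, filter, sort, and truncate steps can be tried without council using the rough sketch below. The `parse_line` and `plan` functions here are hypothetical stand-ins written for this example (they work on chain names rather than chain objects) and are not council's implementation.

```python
from typing import List, Optional, Tuple


def parse_line(line: str, chain_names: List[str]) -> Optional[Tuple[str, int]]:
    """Hypothetical parser for a '{name};{score};{justification}' line.

    Returns None for malformed lines or lines naming an unknown chain.
    """
    parts = line.split(";", 2)
    if len(parts) < 2:
        return None
    name, score = parts[0].strip(), parts[1].strip()
    if name not in chain_names or not score.isdigit():
        return None
    return name, int(score)


def plan(selection: str, chain_names: List[str], threshold: float, top_k: int) -> List[Tuple[str, int]]:
    """Keep chains scoring above the threshold, highest score first, at most top_k."""
    parsed = [parse_line(line, chain_names) for line in selection.splitlines()]
    kept = [r for r in parsed if r is not None and r[1] > threshold]
    kept.sort(key=lambda item: item[1], reverse=True)
    return kept[:top_k]


chains = ["search", "docs", "filings"]
selection = "docs;2;unlikely to help\nsearch;9;needs current data\nunknown"
print(plan(selection, chains, threshold=5, top_k=2))
```

With a threshold of 5 only `search` survives; lowering the threshold to 1 would also keep `docs`, ordered after `search`.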

We can now bring everything together for the controller's `_execute` method below.

In [None]:
def _execute(self, context: AgentContext) -> List[ExecutionUnit]:
    """Generates an execution plan for the agent based on the provided context, chains, and budget."""
    answer_choices = "\n ".join([f"name: {c.name}, description: {c.description}" for c in self._chains])
    # Load prompts for LLM and substitute parameters
    system_prompt = "You are an assistant responsible to identify the intent of the user."
    
    controller_prompt = Template("""
    Use the latest user query and the conversational history to identify the intent of the user. 
    Break this task down into 2 subtasks. First perform subtask 1 and then subtask 2.
    
    Context for subtask 1:
    Conversational history:
    $conversational_history
    
    User query: $user_query
    
    Instructions for subtask 1:
    # Use the historical conversation to update the user query to better answer the user question
    # If the query does not need to be updated, do not update the query
    # If there is no conversational history, do not update the query
    # If the conversational history is not relevant to the query, do not update the query
    # See the below examples for how to update the user query
    ************
    Example 1:
    Conversational History:
    User: Who is the CEO of OpenAI?
    Assistant: Sam Altman
    
    User Query: How old is he?
    
    Updated Query: How old is Sam Altman?
    ************
    Example 2:
    Conversational History:
    User: Who is the CEO of OpenAI?
    Assistant: Sam Altman
    
    User Query: What is the price of Bitcoin?
    
    Updated Query: What is the price of Bitcoin?
    ************
    
    Context for subtask 2: 
    Categories are given as a name and a description (name: {name}, description: {description}):
    $answer_choices
    
    Instructions for subtask 2:
    # Use the updated query to identify the intent of the user
    # score categories out of 10 using their description
    # For each category, you will answer with {name};{score};{short justification}
    # The updated query should be identical for each category
    # Each response is provided on a new line
    # When no category is relevant, you will answer exactly with 'unknown'
                                    
    Your response should always be formatted like this:
    Subtask 1: {updated_query}
    ---
    Subtask 2:
    {subtask2_results}
    """)
    
    user_prompt = controller_prompt.substitute(
        conversational_history=self.build_chat_history(context),
        user_query=context.chat_history.last_user_message.message,
        answer_choices=answer_choices,
    )
    # Send messages and receive response from model
    messages = [
        LLMMessage.system_message(system_prompt),
        LLMMessage.user_message(user_prompt),
    ]
    response = self._llm.post_chat_request(messages)[0]
    logger.debug(f"llm response: {response}")
    # Separate reformulated query and chain selection from response
    query_reformulation_result, chain_selection_result = self.parse_response(response)
    # Create execution plan and provide reformulated query to each execution unit as its initial state
    parsed = [self._parse_line(line, self._chains) for line in chain_selection_result.splitlines()]
    filtered = [r.unwrap() for r in parsed if r.is_some() and r.unwrap()[1] > self._response_threshold]
    if (filtered is None) or (len(filtered) == 0):
        return []

    filtered.sort(key=lambda item: item[1], reverse=True)
    result = [
        ExecutionUnit(r[0], context.budget, initial_state=ChatMessage.chain(query_reformulation_result))
        for r in filtered
        if r is not None
    ]

    return result[: self._top_k]

## Prompts for Filter class

The prompts below are used for the LLM call in the custom filter.

The `filter_prompt` instructs the LLM to write a research report using the information provided by the successfully completed chains.

In [None]:
filter_system_prompt = "You are a financial analyst whose job is to write a research report answering the user query based on data about $company from different sources."

filter_prompt = """
# Instructions
- The provided context is a list of research data answering the user query from different sources.
- Combine the following data from multiple sources into a single research report to answer the query.
- Make sure to highlight any agreements or disagreements between different responses in the final response.
- Explicitly state from which source different parts of the final response are from.

# Context:
$context

# Query:
$query

Answer:
"""

## Selected responses

The code below creates the `context` for the LLM prompt. It iterates through the chain results in `context.evaluation` (which will be provided by the Evaluator) and appends each result's message to a single string as a `Response:` entry. That string and the user's last query are then substituted into the filter prompt.

In [None]:
def _build_llm_messages(self, context: AgentContext) -> List[LLMMessage]:
    agent_messages = list(context.evaluation)
    query = context.chat_history.last_user_message.message
    context = ""
    for message in agent_messages:
        context += f"Response: {message.message.message}\n\n"

    filter_prompt = Template("""
    # Instructions
    - The provided context is a list of research data answering the user query from different sources.
    - Combine the following data from multiple sources into a single research report to answer the query.
    - Make sure to highlight any agreements or disagreements between different responses in the final response.
    - Explicitly state from which source different parts of the final response are from.
    
    # Context:
    $context
    
    # Query:
    $query
    
    Answer:
    """)
    prompt = filter_prompt.substitute(
        context=context, query=query
    )
    return [
        self._build_system_prompt(company=constants.COMPANY_NAME),
        LLMMessage.user_message(prompt),
    ]

def _build_system_prompt(self, company: str) -> LLMMessage:
    system_prompt = Template(
        "You are a financial analyst whose job is to write a research report answering the user query based on data about $company from different sources."
    ).substitute(company=company)
    return LLMMessage.system_message(system_prompt)
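
As a rough, council-free sketch of the aggregation step, the snippet below builds the context string from plain strings and fills a shortened version of the filter prompt. `build_filter_prompt` and the placeholder response texts are hypothetical and only illustrate how `$context` and `$query` are filled.

```python
from string import Template
from typing import List


def build_filter_prompt(responses: List[str], query: str) -> str:
    """Hypothetical helper mirroring how $context and $query are substituted."""
    # One "Response:" entry per chain result, separated by blank lines
    context = "".join(f"Response: {r}\n\n" for r in responses)
    template = Template("# Context:\n$context\n# Query:\n$query\n\nAnswer:\n")
    return template.substitute(context=context, query=query)


prompt = build_filter_prompt(
    ["Placeholder text from the first chain.", "Placeholder text from the second chain."],
    "How is the company performing?",
)
print(prompt)
```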


We can now bring everything together for the filter's `_execute` method below. It returns the LLM responses that will be sent back to the user.

In [None]:
def _execute(self, context: AgentContext) -> List[ScoredChatMessage]:
    """Selects responses from the agent's context."""
    messages = self._build_llm_messages(context)
    llm_response = self._llm.inner.post_chat_request(LLMContext.from_context(context, self._llm), messages=messages)

    return [ScoredChatMessage(ChatMessage.agent(llm_response.first_choice), 1.0)]

## Complete Controller Implementation

In [None]:
class Controller(LLMController):
    """
    A controller that uses an LLM to decide the execution plan and
    reformulates the user query based on the conversational history.

    Based on LLMController: https://github.com/chain-ml/council/blob/main/council/controllers/llm_controller.py
    """

    def __init__(self, chains: List[Chain], llm: LLMBase, response_threshold: float):
        """
        Initialize a new instance

        Parameters:
            chains (List[Chain]): the chains the controller can select from
            llm (LLMBase): the instance of LLM to use
            response_threshold (float): a minimum threshold to select a response from its score
        """
        super().__init__(chains, llm, response_threshold)

    def _execute(self, context: AgentContext) -> List[ExecutionUnit]:
        """Generates an execution plan for the agent based on the provided context, chains, and budget."""
        response = self._call_llm(context)
        # Separate reformulated query and chain selection from response
        query_reformulation_result, chain_selection_result = self.parse_response(response)
        # Create execution plan and provide reformulated query to each execution unit as its initial state
        parsed = [self._parse_line(line, self._chains) for line in chain_selection_result.splitlines()]
        filtered = [r.unwrap() for r in parsed if r.is_some() and r.unwrap()[1] > self._response_threshold]
        if (filtered is None) or (len(filtered) == 0):
            return []

        filtered.sort(key=lambda item: item[1], reverse=True)
        result = [
            ExecutionUnit(r[0], context.budget, initial_state=ChatMessage.chain(query_reformulation_result))
            for r in filtered
            if r is not None
        ]

        return result[: self._top_k]

    def _build_llm_messages(self, context: AgentContext) -> List[LLMMessage]:
        answer_choices = "\n ".join([f"name: {c.name}, description: {c.description}" for c in self._chains])
        # Load prompts for LLM and substitute parameters
        system_prompt = "You are an assistant responsible to identify the intent of the user."
        controller_prompt = Template("""
        Use the latest user query and the conversational history to identify the intent of the user. 
        Break this task down into 2 subtasks. First perform subtask 1 and then subtask 2.
        
        Context for subtask 1:
        Conversational history:
        $conversational_history
        
        User query: $user_query
        
        Instructions for subtask 1:
        # Use the historical conversation to update the user query to better answer the user question
        # If the query does not need to be updated, do not update the query
        # If there is no conversational history, do not update the query
        # If the conversational history is not relevant to the query, do not update the query
        # See the below examples for how to update the user query
        ************
        Example 1:
        Conversational History:
        User: Who is the CEO of OpenAI?
        Assistant: Sam Altman
        
        User Query: How old is he?
        
        Updated Query: How old is Sam Altman?
        ************
        Example 2:
        Conversational History:
        User: Who is the CEO of OpenAI?
        Assistant: Sam Altman
        
        User Query: What is the price of Bitcoin?
        
        Updated Query: What is the price of Bitcoin?
        ************
        
        Context for subtask 2: 
        Categories are given as a name and a description (name: {name}, description: {description}):
        $answer_choices
        
        Instructions for subtask 2:
        # Use the updated query to identify the intent of the user
        # score categories out of 10 using their description
        # For each category, you will answer with {name};{score};{short justification}
        # The updated query should be identical for each category
        # Each response is provided on a new line
        # When no category is relevant, you will answer exactly with 'unknown'
                                        
        Your response should always be formatted like this:
        Subtask 1: {updated_query}
        ---
        Subtask 2:
        {subtask2_results}
        """)
        user_prompt = controller_prompt.substitute(
            conversational_history=self.build_chat_history(context),
            user_query=context.chat_history.last_user_message.message,
            answer_choices=answer_choices,
        )
        # Assemble the system and user messages for the model
        messages = [
            LLMMessage.system_message(system_prompt),
            LLMMessage.user_message(user_prompt),
        ]
        return messages

    @staticmethod
    def build_chat_history(context: AgentContext, max_history_len: int = 4) -> str:
        """Format the chat history into a string that can be added to the prompt for the query reformulation model."""
        chat_history = ""
        # Remove the user's most recent message from the chat history
        message_history = list(context.chat_history.messages[:-1])

        # Return a placeholder if no messages precede the user's latest one
        if len(message_history) < 1:
            return "No conversational history"

        for msg in message_history[-max_history_len:]:
            if msg.is_of_kind(ChatMessageKind.User):
                chat_history += f"User: {msg.message}\n"
            if msg.is_of_kind(ChatMessageKind.Agent):
                chat_history += f"Assistant: {msg.message}\n"

        return chat_history

    @staticmethod
    def parse_response(response: str) -> Tuple[str, str]:
        """Separate the reformulated query and the chain selection in the LLM response."""
        # The prompt asks for "Subtask 1: ..." and "Subtask 2: ..." separated by a "---" line
        parts = response.split("---")
        query_reformulation_response = parts[0].replace("Subtask 1:", "").strip()
        chain_selection_response = parts[1].replace("Subtask 2:", "").strip()

        return query_reformulation_response, chain_selection_response


## Complete Filter Implementation

In [None]:
class LLMFilter(FilterBase):

    def __init__(self, llm: LLMBase):
        super().__init__()
        self._llm = self.new_monitor("llm", llm)

    def _execute(self, context: AgentContext) -> List[ScoredChatMessage]:
        """Selects responses from the agent's context."""
        messages = self._build_llm_messages(context)
        llm_response = self._llm.inner.post_chat_request(LLMContext.from_context(context, self._llm), messages=messages)

        return [ScoredChatMessage(ChatMessage.agent(llm_response.first_choice), 1.0)]

    def _build_llm_messages(self, context) -> List[LLMMessage]:
        agent_messages = context.evaluation
        query = context.chat_history.last_user_message.message
        context = ""
        for message in agent_messages:
            context += f"Response: {message.message.message}\n\n"

        filter_prompt = Template("""
        # Instructions
        - The provided context is a list of research data answering the user query from different sources.
        - Combine the following data from multiple sources into a single research report to answer the query.
        - Make sure to highlight any agreements or disagreements between different responses in the final response.
        - Explicitly state from which source different parts of the final response are from.
        
        # Context:
        $context
        
        # Query:
        $query
        
        Answer:
        """)
        prompt = filter_prompt.substitute(
            context=context, query=query
        )
        return [
            self._build_system_prompt(company=constants.COMPANY_NAME),
            LLMMessage.user_message(prompt),
        ]

    def _build_system_prompt(self, company: str) -> LLMMessage:
        system_prompt = Template(
            "You are a financial analyst whose job is to write a research report answering the user query based on data about $company from different sources."
        ).substitute(company=company)
        return LLMMessage.system_message(system_prompt)


In the [next part](./4_financial_analyst_agent.ipynb), we will put everything together, add a few new components, and complete our financial analyst.  