# Google Authentication

In [1]:
import sys
if 'google.colab' in sys.modules:
    #!pip install langchain google-cloud-aiplatform rich
    google_auth.authenticate_user()

## compass_prompt.py

In [2]:
prefix = """
You should answer the following questions with up-to-date information. You have access to the following tools:"""

format_instruction = """You must use the following format for all responses or your response will be considered incorrect:

Question: the input question you must answer

Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action in a single sentence.
... (this Thought/Action/Action Input/Observation can repeat N times but Question should only appear once)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
If you do not know the answer, you can say so.
"""

suffix = """Begin!"

{chat_history}
Question: {input}
{agent_scratchpad}

"""
zero_shot_prompt = {
    "prefix": prefix,
    "format_instruction": format_instruction,
    "suffix": suffix,
}


## compass_toolkit.py

In [3]:
from dotenv import load_dotenv
import requests
import os

from typing import TYPE_CHECKING, List, Optional

from pydantic import Field

from langchain.tools.base import BaseTool
from langchain.agents.agent_toolkits.base import BaseToolkit

from langchain.callbacks.manager import CallbackManagerForToolRun

load_dotenv()

tool_microservice_url = str(os.getenv("TOOL_URL"))


class CompassToolkit(BaseToolkit):
    def get_tools(self) -> List[BaseTool]:
        """Get the tools in the toolkit."""
        return [
            ProfessorSearchResults(),
            CoursesSearchResults(),
            DegreesSearchResults(),
            GeneralSearchResults(),
        ]


def request_error_handler(url: str, params={}) -> str:
    try:
        response = requests.get(url, params)
        response.raise_for_status()  # Raise an HTTPError if the status code is not in the 200 range
        data = response.json()
        return data
    except requests.exceptions.RequestException as e:
        print(e)
        return e


class ProfessorSearchResults(BaseTool):
    name = "get_professor_rating_and_classes_taught"
    description = (
        "a search engine on professor of UT Dallas on RateMyProfessor database"
        "useful for when you need to answer questions about professors ratings, difficulty, and class taught."
        "will not return contact information, use the general_search tool for that."
        "Input should be a First, Last or Full name of the professor without greeting prefix"
        "Return will be full name, courses taught, overall rating, and difficulty rating"
    )

    def _run(
        self, name: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        url = f"https://{tool_microservice_url}/get-professor-rmp/{name}"
        return request_error_handler(url)

    async def _arun(
        self, name: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        raise NotImplementedError("does not support async yet")


class CoursesSearchResults(BaseTool):

    name = "course_search"
    description = (
        "a search engine on course database of UT Dallas"
        "useful for when you need to search for answer about courses."
        "Input should be a search query"
        "Return will be multiple results with course title and snippet"
    )

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        url = f"https://{tool_microservice_url}/get-possible-courses/{query}"
        return request_error_handler(url)

    async def _arun(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        raise NotImplementedError("does not support async yet")


class DegreesSearchResults(BaseTool):

    name = "college_degree_search"
    description = (
        "a search engine on college degree database of UT Dallas"
        "useful for when you need to search for answer about college degrees."
        "Input should be a search query"
        "Return will be multiple results with title, and snippet of the degree"
    )

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        url = f"https://{tool_microservice_url}/get-degree-info/{query}"
        return request_error_handler(url)

    async def _arun(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        raise NotImplementedError("does not support async yet")


class GeneralSearchResults(BaseTool):
    name = "general_search"
    description = (
        "a search engine for general information about UT Dallas"
        "useful for when you need to search for answer related to professor(s), staff(s), school(s), department(s), and UT Dallas"
        "Searching for courses or college degrees are discouraged as there are better tools"
        "Input should be a search query"
        "Return will be multiple results with title, link and snippet"
    )

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        url = f"https://{tool_microservice_url}/search/{query}"
        return request_error_handler(url)

    async def _arun(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        raise NotImplementedError("does not support async yet")


class DictionaryRun(BaseTool):
    name = "get_definition_of_word"
    description = "a dictionary for simple word" "Input should be word or phrases"

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        url = f"https://{tool_microservice_url}/dictionary/{query}"
        return request_error_handler(url)

    async def _arun(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        url = f"https://{tool_microservice_url}/dictionary/{query}"
        return await request_error_handler(url)


## compass_agent.py 
### Zero Shot

In [14]:
from langchain.agents import ZeroShotAgent, AgentExecutor
from langchain.experimental.plan_and_execute import PlanAndExecute, load_agent_executor, load_chat_planner

from langchain import LLMChain

#from compass_prompt import zero_shot_prompt #! uncomment this line when using compass_prompt.py


class CompassAgent:
    
    def __init__(self, llm, tools, memory) -> None:
        self.__init_plan_and_execute__(llm, tools, memory)
    
    def __init_zero_shot__(self, llm, tools, memory) -> None:
        self.prompt = ZeroShotAgent.create_prompt(
            tools=tools,
            prefix=zero_shot_prompt["prefix"],
            suffix=zero_shot_prompt["suffix"],
            format_instructions=zero_shot_prompt["format_instruction"],
            input_variables=["input", "chat_history", "agent_scratchpad"],
        )
        llm_chain = LLMChain(llm=llm, prompt=self.prompt)
        agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
        self.agent_chain = AgentExecutor.from_agent_and_tools(
            agent=agent,
            tools=tools,
            verbose=True,
            memory=memory,
            handle_parsing_errors=False,
        )
        
    def __init_plan_and_execute__(self, llm, tools, memory) -> None:
        planner = load_chat_planner(llm=llm)
        executor = load_agent_executor(llm=llm, tools=tools, verbose=False)
        self.agent_chain = PlanAndExecute(planner=planner, executor=executor, memory=memory)
        

    def _run(self, input: str) -> str:
        try:
            return self.agent_chain.run(input=input)
        except ValueError as e:
            return self.error_handler(e)

    async def _arun(self, input: str) -> str:
        try:
            return self.agent_chain.arun(input=input)
        except ValueError as e:
            return self.error_handler(e)

    def error_handler(self, e):
        response = str(e)
        prefix = "Could not parse LLM output: `"
        if not response.startswith(prefix):
            raise e
        response = response.removeprefix(prefix).removesuffix("`")
        return response


## compass_inference.py


In [15]:
import os

#from compass_agent import CompassAgent #! uncomment this line when using compass_agent.py
#from compass_toolkit import CompassToolkit #! uncomment this line when using compass_toolkit.py

from dotenv import load_dotenv

from google.cloud import aiplatform

from langchain.llms import VertexAI
from langchain.memory import ConversationBufferMemory, MongoDBChatMessageHistory
from langchain.schema.messages import AIMessage, HumanMessage

load_dotenv()


class CompassInference:
    def __init__(self) -> None:
        aiplatform.init(project="aerobic-gantry-387923", location="us-central1")
        self.vertex = VertexAI(
            temperature=0, 
            max_tokens=1024, 
            top_p=0.95, 
            top_k=40)
        self.tools = CompassToolkit().get_tools()

    def _run(self, user_message: str, mongodb_past_history) -> str:
        
        clone_memory = self.clone_message_history(mongodb_past_history)
        agent = CompassAgent(llm=self.vertex, tools=self.tools, memory=clone_memory)

        bot_message = agent._run(user_message)

        return bot_message

    def clone_message_history(
        self, message_history: MongoDBChatMessageHistory
    ) -> ConversationBufferMemory:
        memory_clone = ConversationBufferMemory(memory_key="chat_history")
        try:
            for message in message_history.messages:
                if isinstance(message, AIMessage):
                    memory_clone.chat_memory.add_ai_message(message.content)
                elif isinstance(message, HumanMessage):
                    memory_clone.chat_memory.add_user_message(message.content)
        except Exception as e:
            pass

        return memory_clone


### Testing area


In [16]:
connection_string = (
    f"mongodb+srv://{str(os.getenv('MONGODB_LOGIN'))}@compass-utd.gc5s9o8.mongodb.net"
)

In [17]:
#! Rerun this cell to reset memory
import random
import string



message_history = MongoDBChatMessageHistory(
        #make random session id
        connection_string=connection_string, session_id="test_".join(random.choices(string.ascii_uppercase +
                                                                    string.digits, k=5))
    )

agent = CompassInference()

In [18]:

response = agent._run("What is the different between Math 2417 and Math 2413", message_history)


OutputParserException: Could not parse LLM output: Question: What is the course catalog?
Action:
```
{
  "action": general_search,
  "action_input": "course catalog"
}
```
