In [None]:
import os
import openai
from azure.identity import DefaultAzureCredential
from azure.search.documents import SearchClient
from azure.search.documents.models import QueryType
from langchain.agents import Tool, AgentExecutor
from langchain.agents.react.base import ReActDocstoreAgent
from langchain.llms.openai import AzureOpenAI
from langchain.prompts import PromptTemplate, BasePromptTemplate
from typing import List

# Deployment settings. Each value is taken from the environment variable of the same name
# when that variable is set to a non-empty string; otherwise the literal on the right is used.
# Either export the variables or replace the fallback literals directly here.
AZURE_STORAGE_ACCOUNT = os.getenv("AZURE_STORAGE_ACCOUNT") or "mystorageaccount"
AZURE_STORAGE_CONTAINER = os.getenv("AZURE_STORAGE_CONTAINER") or "content"
AZURE_SEARCH_SERVICE = os.getenv("AZURE_SEARCH_SERVICE") or "gptkb"
AZURE_SEARCH_INDEX = os.getenv("AZURE_SEARCH_INDEX") or "gptkbindex"
AZURE_OPENAI_SERVICE = os.getenv("AZURE_OPENAI_SERVICE") or "myopenai"
AZURE_OPENAI_GPT_DEPLOYMENT = os.getenv("AZURE_OPENAI_GPT_DEPLOYMENT") or "davinci"

# Names of the search index fields holding the text, the category and the source page of a document
KB_FIELDS_CONTENT = os.getenv("KB_FIELDS_CONTENT") or "content"
KB_FIELDS_CATEGORY = os.getenv("KB_FIELDS_CATEGORY") or "category"
KB_FIELDS_SOURCEPAGE = os.getenv("KB_FIELDS_SOURCEPAGE") or "sourcepage"

# Authenticate as the current user for Azure OpenAI, Cognitive Search and Blob Storage: no secrets are
# needed ('az login' locally, managed identity when deployed on Azure). If you need to use keys, create
# a separate AzureKeyCredential instance with the key for each service instead.
azure_credential = DefaultAzureCredential()

# Used by the OpenAI SDK
openai.api_version = "2022-12-01"
openai.api_base = "https://{}.openai.azure.com".format(AZURE_OPENAI_SERVICE)
openai.api_type = "azure"

# Comment these two lines out if using keys, set your API key in the OPENAI_API_KEY environment variable instead
# NOTE(review): the token is requested once, when this cell runs; presumably it expires after a while,
# so re-run this cell if later calls start failing with authentication errors - confirm.
openai.api_type = "azure_ad"
openai.api_key = azure_credential.get_token("https://cognitiveservices.azure.com/.default").token

# Client for querying the Cognitive Search index, authenticated with the same credential
search_client = SearchClient(
    credential=azure_credential,
    index_name=AZURE_SEARCH_INDEX,
    endpoint="https://{}.search.windows.net".format(AZURE_SEARCH_SERVICE),
)

In [None]:
# Modified version of langchain's ReAct prompt that includes instructions and examples for how to cite information sources

# Few-shot example(s) showing the Question / Thought / Action / Observation layout, with the
# source name in angled brackets at the start of the observation.
EXAMPLES = [
    """Question: Which type of data should be sent from video cameras in a native binary format?
Thought: I need to search Identify data formats, find the Unstructured data.
Action: Search[Identify data formats]
Observation: <Explore file storage - Training _ Microsoft Learn-2.pdf> According to the source, video data should be sent from video cameras in a native binary format.
""",
]

# Tail of the prompt: the user's question followed by the agent's running scratchpad.
SUFFIX = """\nQuestion: {input}
{agent_scratchpad}"""

# Instructions placed before the examples. The pieces are wrapped in parentheses so that every
# sentence is part of PREFIX: with backslash continuations a single missing backslash silently
# turns the remaining string into a discarded expression statement. Each sentence except the
# last ends with a space so adjacent sentences do not run together.
PREFIX = (
    "Answer questions as shown in the following examples, by splitting the question into individual search or lookup actions to find facts until you can answer the question. "
    "Observations are prefixed by their source name in angled brackets, source names MUST be included with the actions in the answers. "
    "All questions must be answered from the results from search or look up actions, only facts resulting from those can be used in an answer. "
    "Answer questions as truthfully as possible, and ONLY answer the questions using the information from observations, do not speculate or use your own knowledge."
)

# Assemble the few-shot ReAct prompt from PREFIX, the examples and SUFFIX.
# `prefix` is passed by keyword on purpose: the fourth positional parameter of
# PromptTemplate.from_examples is `example_separator`, so a positional PREFIX would be
# used as the separator between the examples and the suffix instead of as the preamble.
prompt = PromptTemplate.from_examples(
    examples=EXAMPLES,
    suffix=SUFFIX,
    input_variables=["input", "agent_scratchpad"],
    prefix=PREFIX,
)

In [None]:
# Category to exclude from results, to simulate scenarios where there's a set of docs you can't see.
# search() and lookup() read this global each time they are called; None (or any falsy value)
# means no category filter is applied.
exclude_category = None

def search(terms: str) -> str:
    """Run a semantic search for `terms` and format the top hits as one observation string.

    Up to three documents are requested. Each hit becomes one line consisting of the
    source page name in angled brackets followed by the first 500 characters of the
    document content with line breaks removed. When the module-level `exclude_category`
    is set, documents whose category field equals it are filtered out.

    Args:
        terms: The search text.

    Returns:
        The formatted hits joined by newlines; an empty string when there are no hits.
    """
    print ("\nsearching: " + terms)
    # Optionally enable captions for summaries by adding optional argument query_caption="extractive|highlight-false"
    # and adjust the string formatting below to include the captions from the @search.captions field

    # Build the filter on the configured category field (KB_FIELDS_CATEGORY) rather than a
    # hard-coded name. Single quotes are doubled to escape them inside the filter's string literal.
    category_filter = (
        "{} ne '{}'".format(KB_FIELDS_CATEGORY, exclude_category.replace("'", "''"))
        if exclude_category
        else None
    )
    results = search_client.search(terms,
                                   filter=category_filter,
                                   top=3,
                                   query_type=QueryType.SEMANTIC,
                                   query_language="en-us",
                                   query_speller="lexicon",
                                   semantic_configuration_name="default")

    lines = []
    for doc in results:
        # A document may carry no content; treat that as empty text instead of failing on None.
        snippet = (doc[KB_FIELDS_CONTENT] or "")[:500]
        lines.append(f"<{doc[KB_FIELDS_SOURCEPAGE]}> " + snippet.replace("\n", "").replace("\r", ""))
    return "\n".join(lines)

def lookup(terms: str):
    """Look up `terms` and return the best single answer or caption text.

    Requests one document with one extractive answer and extractive captions. When the
    module-level `exclude_category` is set, documents whose category field equals it are
    filtered out.

    Args:
        terms: The lookup text.

    Returns:
        The text of the first extractive answer when there is one; otherwise the caption
        texts of the first matching document joined by newlines; otherwise None.
    """
    print ("\nlooking up: " + terms)
    # Build the filter on the configured category field (KB_FIELDS_CATEGORY) rather than a
    # hard-coded name. Single quotes are doubled to escape them inside the filter's string literal.
    category_filter = (
        "{} ne '{}'".format(KB_FIELDS_CATEGORY, exclude_category.replace("'", "''"))
        if exclude_category
        else None
    )
    results = search_client.search(terms,
                                   filter=category_filter,
                                   top=1,
                                   include_total_count=True,
                                   query_type=QueryType.SEMANTIC,
                                   query_language="en-us",
                                   query_speller="lexicon",
                                   semantic_configuration_name="default",
                                   query_answer="extractive|count-1",
                                   query_caption="extractive|highlight-false")

    # The answers may come back as None rather than an empty list, so test truthiness
    # instead of calling len() on the result.
    answers = results.get_answers()
    if answers:
        return answers[0].text

    # Likewise tolerate a missing count, an exhausted result set and a document without captions.
    if (results.get_count() or 0) > 0:
        first_doc = next(results, None)
        captions = first_doc.get("@search.captions") if first_doc is not None else None
        if captions is not None:
            return "\n".join(c.text for c in captions)
    return None

# Completion model for the agent, using the key/token currently configured on the openai module
llm = AzureOpenAI(
    openai_api_key=openai.api_key,
    deployment_name=AZURE_OPENAI_GPT_DEPLOYMENT,
    temperature=0.3,
)

# The two actions the agent can take, backed by the search() and lookup() functions
tools = [
    Tool(
        name="Search",
        description="useful for when you need to ask with search",
        func=search,
    ),
    Tool(
        name="Lookup",
        description="useful for when you need to ask with lookup",
        func=lookup,
    ),
]

class ReAct(ReActDocstoreAgent):
    """ReActDocstoreAgent subclass whose prompt is the module-level `prompt`."""

    @classmethod
    def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
        """Return the module-level `prompt`; the `tools` argument is not used."""
        return prompt

# Build the agent from the model and tools, then wrap it in an executor that logs each step
agent = ReAct.from_llm_and_tools(llm=llm, tools=tools)
chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)

In [None]:
# Exclude category, to simulate scenarios where there's a set of docs you can't see.
# It is assigned again in this cell so it can be changed and the question re-run without
# re-running the cell that defines search() and lookup(), which read this global when called.
exclude_category = None

# Run the agent on a sample question (verbose=True on the executor prints the intermediate steps)
chain.run("Ask me a question on DP-900 Microsoft Azure Data Fundamentals")