In [41]:
import os
import json
import boto3
import random
import uuid
from time import time
from datetime import datetime
from typing import List, Dict, Any
from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.chains import create_retrieval_chain
from langchain_core.prompts import MessagesPlaceholder
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_history_aware_retriever
from langchain_core.messages import HumanMessage, AIMessage
from langchain_aws import ChatBedrock

In [42]:
# Must run before Config is defined: the class attributes below call
# os.getenv() at class-creation time.
load_dotenv()

class Config:
    """Runtime settings read from environment variables when the class is created."""
    # NOTE(review): the two credential values are read here but are not passed
    # to any boto3 call visible in this notebook.
    AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
    AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
    # Knowledge base id and model ARN handed to retrieve_and_generate().
    KB_ID = os.getenv("KB_ID")
    MODEL_ARN = os.getenv("MODEL_ARN")
    # Name of the DynamoDB table that stores chat history.
    DYNAMODB_TABLE_NAME = os.getenv("DYNAMODB_TABLE_NAME")
    # Region used for every AWS client/resource below; defaults to us-east-1.
    AWS_REGION = os.getenv("AWS_REGION", "us-east-1")

# Initialize the Bedrock Agent Runtime client (used for knowledge-base retrieval)
client = boto3.client("bedrock-agent-runtime", region_name=Config.AWS_REGION)

In [43]:
# Shared chat model: used for query expansion, question rewriting,
# history summarization and the final answer.
llm = ChatBedrock(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    region_name=Config.AWS_REGION,
    beta_use_converse_api=True,
    streaming=True
)

In [44]:
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

def generate_modified_queries(question: str):
    """Ask the LLM for alternative phrasings of ``question`` for retrieval.

    Args:
        question: The user question to expand.

    Returns:
        A list of non-empty, whitespace-trimmed query strings in the order the
        model produced them, with exact duplicates removed. If the model
        returns no usable line, the list contains only ``question`` so callers
        always have something to retrieve with.
    """
    QUERY_PROMPT = PromptTemplate(
        input_variables=["question"],
        template="""You are an AI language model assistant. Your task is to generate five 
        different versions of the given user question to retrieve relevant documents from a 
        vector database. Provide these alternative questions separated by newlines. And only respond the questions along with the original question asked,
        do not include any additional text or explanations.
        Original question: {question}"""
    )

    # Same `prompt | llm` pipeline style used elsewhere in this notebook,
    # instead of the deprecated LLMChain(...).run(...).
    chain = QUERY_PROMPT | llm
    result = chain.invoke({"question": question})
    if isinstance(result, AIMessage):
        result = result.content

    # Models often separate alternatives with blank lines or repeat the
    # original question; blank entries and duplicates would each cost a
    # knowledge-base call downstream, so drop them here.
    stripped = (line.strip() for line in result.splitlines())
    queries = list(dict.fromkeys(line for line in stripped if line))
    return queries or [question]

In [45]:
# Demo: expand one sample question and print each generated variation
# on its own line (this calls the LLM).
user_question = "What are the benefits of our company's remote work policy?"
modified_queries = generate_modified_queries(user_question)

for q in modified_queries:
    print(q)

What are the benefits of our company's remote work policy?
What advantages does our organization offer through its work-from-home policy?
How does our company's flexible work arrangement benefit employees?
What are the positive aspects of our firm's telecommuting policy?
In what ways does our remote work program enhance employee experience?
What perks do staff members enjoy from our company's virtual work setup?


In [46]:
from boto3.dynamodb.conditions import Key,Attr
# Module-level DynamoDB handles shared by ChatHistoryManager.
dynamodb = boto3.resource('dynamodb',region_name=Config.AWS_REGION)
TABLE_NAME = Config.DYNAMODB_TABLE_NAME
# Table that chat history items are queried from and written to.
table = dynamodb.Table(TABLE_NAME)

class ChatHistoryManager:
    """Loads, persists and condenses one user's chat history in DynamoDB.

    Items live in the module-level ``table`` and are queried by ``user_id``
    (key equality) and an ISO-8601 ``timestamp`` string (key range); the
    ``session_id`` attribute is used as a post-query filter.
    """

    # Minimum number of unsummarized exchanges before a summary is produced.
    SUMMARY_TRIGGER = 5
    # Maximum number of exchanges folded into one summary.
    SUMMARY_BATCH = 10

    def __init__(self, user_id: str = "default", session_id: str = "default_session"):
        """Bind the manager to one user and one conversation session."""
        self.user_id = user_id
        self.session_id = session_id

    def load_history(self) -> List[Dict[str, Any]]:
        """Return today's items for this user and session.

        "Today" is the local calendar day of ``datetime.now()`` (naive time).

        Returns:
            The stored items across all result pages, or ``[]`` if the query
            fails (the error is printed, not raised).
        """
        today = datetime.now().date()
        start_time = datetime.combine(today, datetime.min.time()).isoformat()
        end_time = datetime.combine(today, datetime.max.time()).isoformat()

        query_kwargs = {
            'KeyConditionExpression': (
                Key('user_id').eq(self.user_id)
                & Key('timestamp').between(start_time, end_time)
            ),
            'FilterExpression': Attr('session_id').eq(self.session_id),
        }

        items: List[Dict[str, Any]] = []
        try:
            # A single query call returns one page only. The filter is applied
            # after the page is read, so a page can even be empty while more
            # matching items follow; keep going until no continuation key.
            while True:
                response = table.query(**query_kwargs)
                items.extend(response.get('Items', []))
                last_key = response.get('LastEvaluatedKey')
                if not last_key:
                    return items
                query_kwargs['ExclusiveStartKey'] = last_key
        except Exception as e:
            print("Error loading from DynamoDB:", e)
            return []

    def save_history(self, history: List[Dict[str, Any]]):
        """Write each entry of ``history`` to DynamoDB, best effort.

        Missing fields get defaults (current time, this session, empty
        strings, ``summarized=False``). A failed write is printed and the
        remaining entries are still attempted.
        """
        for entry in history:
            item = {
                'user_id': self.user_id,
                'timestamp': entry.get("timestamp", datetime.now().isoformat()),  # sort key
                'session_id': entry.get("session_id", self.session_id),
                'user_message': entry.get("user_message", ""),
                'assistant_response': entry.get("assistant_response", ""),
                'summarized': entry.get("summarized", False)
            }

            # Optional fields only present on summary entries.
            if "summary" in entry:
                item["summary"] = entry["summary"]
            if "summary_of" in entry:
                item["summary_of"] = entry["summary_of"]

            try:
                table.put_item(Item=item)
            except Exception as e:
                print("Error writing to DynamoDB:", e)

    def summarize_if_needed(self, history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Replace the oldest unsummarized exchanges with one LLM summary.

        Nothing happens (``history`` itself is returned) while fewer than
        ``SUMMARY_TRIGGER`` complete, unsummarized exchanges exist. Otherwise
        up to ``SUMMARY_BATCH`` of them are summarized and a new list is
        returned with the summary entry first, followed by every entry that
        was not part of the batch. The input list is not modified and nothing
        is written to DynamoDB here.
        """
        pending = [
            entry for entry in history
            if not entry.get("summarized", False)
            and entry.get("user_message") and entry.get("assistant_response")
        ]

        if len(pending) < self.SUMMARY_TRIGGER:
            return history

        batch = pending[:self.SUMMARY_BATCH]
        history_text = "\n".join(
            f"User: {entry['user_message']}\nAssistant: {entry['assistant_response']}"
            for entry in batch
        )

        # State the real batch size: it is anywhere between SUMMARY_TRIGGER
        # and SUMMARY_BATCH, not always 10.
        summary_prompt = (
            f"Summarize the following {len(batch)} interactions into one "
            f"concise but informative summary:\n{history_text}"
        )

        summary = llm.invoke(summary_prompt)
        if isinstance(summary, AIMessage):
            summary = summary.content

        summary_entry = {
            "role": "system",
            "summary": summary,
            "timestamp": datetime.now().isoformat(),
            "session_id": self.session_id,
            "summarized": True,
            "summary_of": [entry.get("timestamp") for entry in batch]
        }

        # Drop exactly the batched objects (by identity) in one pass.
        batched_ids = {id(entry) for entry in batch}
        remaining = [entry for entry in history if id(entry) not in batched_ids]
        return [summary_entry] + remaining

    def append_chat_pair(self, history: List[Dict[str, Any]], user_msg: str, assistant_msg: str) -> List[Dict[str, Any]]:
        """Record one user/assistant exchange and return the condensed history.

        The new entry is appended to ``history`` in place and persisted on its
        own; the return value is ``summarize_if_needed(history)``.
        """
        new_entry = {
            "user_message": user_msg,
            "assistant_response": assistant_msg,
            "timestamp": datetime.now().isoformat(),
            "session_id": self.session_id,
            "summarized": False
        }

        history.append(new_entry)
        self.save_history([new_entry])  # Save only the new one
        return self.summarize_if_needed(history)

In [38]:
def get_dynamic_prompt(user_input: str, history: List) -> PromptTemplate:
    """Build the answer prompt, with guidance chosen from keywords in the question.

    The lower-cased ``user_input`` is checked for keyword substrings in
    priority order (sensitive, then policy, then benefits); the first group
    that matches decides the first "Considerations" line. ``history`` is
    accepted but not used.

    Returns:
        A ``PromptTemplate`` with ``context``, ``chat_history`` and ``input``
        placeholders.
    """
    lowered = user_input.lower()

    # (keywords, guidance) pairs, ordered by priority.
    guidance_rules = (
        (
            ["complaint", "harassment", "grievance", "termination"],
            "This is a sensitive topic. Be professional and direct the user to official HR channels if appropriate.",
        ),
        (
            ["policy", "rule", "guideline"],
            "Provide exact policy details with reference to the policy document when possible.",
        ),
        (
            ["benefit", "pto", "leave", "insurance"],
            "Include eligibility requirements and any limitations for benefits mentioned.",
        ),
    )

    instructions = next(
        (
            guidance
            for keywords, guidance in guidance_rules
            if any(keyword in lowered for keyword in keywords)
        ),
        "Respond helpfully and professionally.",
    )

    # Single braces are template placeholders filled in later by the caller.
    template_lines = [
        "You are an HR assistant for a company. Use the following context to answer the question at the end.",
        "If you don't know the answer, say you don't know. Be concise but helpful.",
        "",
        "Context:",
        "{context}",
        "",
        "Conversation history:",
        "{chat_history}",
        "",
        "Question: {input}",
        "",
        "Considerations:",
        "1. " + instructions,
        "2. Format lists and important details clearly",
        "3. Provide sources when available",
        "",
        "Answer:",
    ]

    return PromptTemplate.from_template("\n".join(template_lines))


In [47]:
def retrieve_from_kb(input):
    """Run a Bedrock knowledge-base retrieve-and-generate call for ``input``.

    Uses the module-level ``client`` with the knowledge base id and model ARN
    from ``Config``.

    Returns:
        The raw service response, or ``None`` if the call raised (the error
        is printed rather than propagated).
    """
    try:
        request = {
            'input': {'text': input},
            'retrieveAndGenerateConfiguration': {
                'type': 'KNOWLEDGE_BASE',
                'knowledgeBaseConfiguration': {
                    'knowledgeBaseId': Config.KB_ID,
                    'modelArn': Config.MODEL_ARN,
                },
            },
        }
        return client.retrieve_and_generate(**request)
    except Exception as err:
        print(f"Error retrieving knowledge base: {err}")
        return None


In [48]:
def rewrite_question_with_history(user_input: str, history: List[Dict[str, Any]]) -> str:
    """Rewrite the user's question so it is self-contained given recent history.

    Args:
        user_input: The question as the user typed it.
        history: Chat history entries; only the last five are considered, and
            only their non-empty ``user_message`` / ``assistant_response``
            values are used.

    Returns:
        The LLM's rewritten question, or ``user_input`` unchanged when there
        is no usable history to draw context from.
    """
    if not history:
        return user_input

    # Convert history to conversational format
    conversation_history = []
    for entry in history[-5:]:  # Only use last 5 exchanges to avoid context overload
        # Skip blank values: stored items default these fields to "" and
        # summary entries do not carry them, and an empty chat message adds
        # no context for the rewrite.
        user_message = entry.get("user_message")
        assistant_response = entry.get("assistant_response")
        if user_message:
            conversation_history.append(HumanMessage(content=user_message))
        if assistant_response:
            conversation_history.append(AIMessage(content=assistant_response))

    # Nothing to ground the rewrite on: same outcome as having no history.
    if not conversation_history:
        return user_input

    # Create a prompt to rewrite the question with context
    rewrite_prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "Given the conversation history, rephrase and expand this question to be more specific and clear: {input}. Only respond with the rewritten question."),
    ])

    rewrite_chain = rewrite_prompt | llm
    rewritten_question = rewrite_chain.invoke({
        "chat_history": conversation_history,
        "input": user_input
    })

    # Chat models return a message object; unwrap it to plain text.
    if isinstance(rewritten_question, AIMessage):
        return rewritten_question.content
    return rewritten_question


def chat(user_input: str, user_id: str = "default", session_id: str = "default_session") -> str:
    """Answer one user message with history-aware, multi-query retrieval.

    Steps: load today's history, rewrite the question with that context,
    expand it into several query variations, call the knowledge base once per
    variation, then ask the LLM for the final answer and store the exchange.

    Args:
        user_input: The user's message.
        user_id: Identifies whose history to load and append to.
        session_id: Conversation session within that user's history.

    Returns:
        The assistant's final response text.
    """
    history_manager = ChatHistoryManager(user_id, session_id)
    chat_history = history_manager.load_history()

    # Step 1: Rewrite the question with history context
    rewritten_question = rewrite_question_with_history(user_input, chat_history)

    # Step 2: Generate multiple query variations
    modified_queries = generate_modified_queries(rewritten_question)

    # Blank lines in the model output are not queries; sending them to the
    # knowledge base only produces failed calls. If nothing usable remains,
    # fall back to the rewritten question itself.
    queries = [q.strip() for q in modified_queries if q and q.strip()] or [rewritten_question]

    # Step 3: Retrieve from knowledge base for each query variation
    all_responses = []
    for query in queries:
        kb_response = retrieve_from_kb(query)
        # retrieve_from_kb returns None on failure, and the response may lack
        # 'output' or 'text'; treat all of these as "no context".
        kb_text = ((kb_response or {}).get('output') or {}).get('text')
        if kb_text:
            all_responses.append(kb_text)

    # Combine all retrieved context
    combined_context = "\n\n".join(all_responses) if all_responses else "No relevant context found."

    # Only show the last 5 exchanges, skipping entries with blank messages
    # (for example summary items) so the prompt has no empty turns.
    recent_exchanges = [
        f"User: {entry['user_message']}\nAssistant: {entry['assistant_response']}"
        for entry in chat_history[-5:]
        if entry.get('user_message') and entry.get('assistant_response')
    ]

    # Step 4: Generate final response with dynamic prompt
    prompt_template = get_dynamic_prompt(user_input, chat_history)
    prompt = prompt_template.format(
        context=combined_context,
        chat_history="\n".join(recent_exchanges),
        input=user_input
    )

    # Get final response from LLM
    final_response = llm.invoke(prompt)
    if isinstance(final_response, AIMessage):
        final_response = final_response.content

    # Step 5: Save the interaction to history
    history_manager.append_chat_pair(chat_history, user_input, final_response)

    return final_response

In [None]:
# First question: runs the full pipeline (history load, rewrite, query
# expansion, knowledge-base retrieval, final answer) for user "user123"
# in session "session456" and prints the answer.
response1 = chat("What is our remote work policy?", "user123", "session456")
print(response1)


In [None]:
# Opening question of a second conversation (user "123", session "1234");
# the exchange is saved so the next cell can refer back to it.
chat("what do you know about the CEO of AyataCommerce?","123","1234")

In [None]:
# Follow-up in the same user/session: "him" is only resolvable through the
# history-aware question rewrite, so the previous cell must have run first.
chat("what do you know about him?","123","1234")