In [1]:
import sys
import os
import json
from dotenv import load_dotenv
from pymongo import MongoClient
from pathlib import Path

# Load environment variables through python-dotenv before any of them is read.
load_dotenv(override=True)

# --- Credentials and model names (required: a missing variable raises KeyError) ---
OPENAI_API_KEY = os.environ["OPENAI_API_CHATBOT_TEST_KEY_INTERNAL"]
MONGO_URI = os.environ["MONGO_URI"]
EMBEDDING_MODEL_NAME = os.environ["EMBEDDING_MODEL_NAME"]
EMBEDDING_DIMENSIONS = os.environ["EMBEDDING_DIMENSIONS"]  # kept as the raw string read from the environment
CHAT_MODEL_NAME = os.environ["CHAT_MODEL_NAME"]
# Expose the same key under the OPENAI_API_KEY variable as well.
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

# --- Vector database ---
DB_NAME = "wcc"
COLLECTION_NAME = "documents"
ATLAS_VECTOR_SEARCH_INDEX_NAME = "vector_index"
VECTOR_SIMILARITY_FUNCTION = "dotProduct"  # alternative: "cosine"

# --- Retriever ---
RETRIEVER_SEARCH_TYPE = "mmr"  # alternative: "similarity_score_threshold"
MAX_CHUNKS_TO_RETRIEVE = 5
SST_CHUNK_MIN_RELEVANCE_SCORE = 0.2
MMR_FETCH_K = 500
MMR_LAMBDA_MULT = 0.5

# --- Chat model ---
MAX_TOKENS_FOR_RESPONSE = 2000
CHAT_MODEL_TEMPERATURE = 0
CHAT_MODEL_FREQ_PENALTY = 0.2
CHAT_MODEL_PRES_PENALTY = 0.2
SHOW_VERBOSE = True
MAX_TOKENS_FOR_HISTORY = 300

# --- Prompt-template location, resolved relative to the notebook's parent directory ---
PARENT_PATH = Path.cwd().parent
EVA_SETTINGS_PATH = PARENT_PATH / 'evasettings'
EVA_SETTINGS_ENVIRONMENT_DIRECTORY = 'local'

In [2]:
# Project directories, resolved relative to the notebook's parent directory.
models_path = PARENT_PATH / 'scripts' / 'models'
vectordatabases_path = PARENT_PATH / 'scripts' / 'vectordatabases'
temp_data_path = PARENT_PATH / 'data' / 'temp'

# Make the project packages importable; each directory is appended at most
# once, so re-running the cell does not grow sys.path.
sys.path.extend(
    candidate
    for candidate in map(str, (models_path, vectordatabases_path))
    if candidate not in sys.path
)

from models import model_rag
from vectordatabases import BaseDB

In [18]:
from langchain.vectorstores import MongoDBAtlasVectorSearch
from langchain.chains import RetrievalQAWithSourcesChain, LLMChain, ConversationChain
from langchain_core.runnables import RunnableSequence
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationSummaryBufferMemory, ChatMessageHistory
from langchain.schema import HumanMessage, SystemMessage
import re
import json
import time
import tiktoken
import pandas as pd
from itertools import takewhile
from datetime import datetime
from collections import defaultdict

class RAG:
    def __init__(self, chat_data):
        self.chat_data = chat_data
        self.llm_streaming_callback_handler = None
        self.eva_analytics = []

    ## Retrieval pipeline helpers and the public entry point (get_response)

    def _vector_search(self):
        self._get_vector_store()
        self._paraphrase_query()
        
        unranked_documents = []
        stats_dict = defaultdict(lambda: {
            "total_relevancy": 0,
            "chunk_count": 0,
            "chunks": {}
        })
        max_chunks_per_search = 10
    
        # Define weights for scoring
        weight_chunk_relevancy = 0.6  # Increase impact of chunk relevancy
        weight_frequency = 0.4        # Increase impact of frequency
    
        for phrase in self.paraphrases:
            results = self.vector_store.similarity_search_with_relevance_scores(query=phrase, k=max_chunks_per_search)
            
            for doc, score in results:
                doc_id = doc.metadata.get("_id")
                source = doc.metadata.get("source", "unknown")
    
                # Update source-level stats (without average source relevancy)
                stats_dict[source]["total_relevancy"] += score
                stats_dict[source]["chunk_count"] += 1
    
                # Update individual chunk stats within each source
                if doc_id not in stats_dict[source]["chunks"]:
                    stats_dict[source]["chunks"][doc_id] = {"total_relevancy": 0, "count": 0}
                stats_dict[source]["chunks"][doc_id]["total_relevancy"] += score
                stats_dict[source]["chunks"][doc_id]["count"] += 1
    
                # Add unique documents to unranked_documents list
                if doc_id not in {d['doc_id'] for d in unranked_documents}:
                    clean_metadata = {
                        key: str(value) if not isinstance(value, (str, dict)) else value
                        for key, value in doc.metadata.items() if key not in {"_id", "embedding"}
                    }
                    
                    unranked_documents.append({
                        "content": doc.page_content.replace("\\n", "\n"),
                        "metadata": clean_metadata,
                        "relevancy_score": score,
                        "doc_id": doc_id
                    })
    
        # Calculate average relevancies and prepare stats_dict
        weighted_chunks = []
        for source, source_data in stats_dict.items():
            for doc_id, chunk_data in source_data["chunks"].items():
                # Calculate average chunk relevancy
                avg_chunk_relevancy = chunk_data["total_relevancy"] / chunk_data["count"]
                chunk_frequency = chunk_data["count"]
                
                # Calculate the weighted score for this chunk
                weighted_score = (
                    weight_chunk_relevancy * avg_chunk_relevancy +
                    weight_frequency * (chunk_frequency ** 0.5)  # Apply square root to balance frequency's impact
                )
    
                # Append chunk info to weighted_chunks list
                weighted_chunks.append({
                    "source": source,
                    "doc_id": doc_id,
                    "average_chunk_relevancy": avg_chunk_relevancy,
                    "chunk_frequency": chunk_frequency,
                    "weighted_score": weighted_score
                })
    
        # Filter based on weighted score threshold
        filtered_chunks = sorted(weighted_chunks, key=lambda x: x["weighted_score"], reverse=True)
        print("Total chunks before filter:", len(filtered_chunks))
        filtered_chunks = filtered_chunks[:10]
        print("Total chunks after filter:", len(filtered_chunks))
    
        print("\n---- Filtered Chunks Based on Weighted Score ----\n")
        for chunk in filtered_chunks:
            print(f"Source: {chunk['source']}")
            print(f"  Doc ID: {chunk['doc_id']}")
            print(f"  Average Chunk Relevancy: {chunk['average_chunk_relevancy']:.2f}")
            print(f"  Chunk Frequency: {chunk['chunk_frequency']}")
            print(f"  Weighted Score: {chunk['weighted_score']:.2f}")
            print("-" * 50)  # Separator line for readability
    
        selected_doc_ids = {chunk["doc_id"] for chunk in filtered_chunks}
        ranked_and_filtered_chunks = [doc for doc in unranked_documents if doc["doc_id"] in selected_doc_ids]
    
        merged_ranked_filtered_chunks = self.extract_content_by_source(ranked_and_filtered_chunks)
        
        unique_sources = list({chunk["source"] for chunk in merged_ranked_filtered_chunks})
        chunk_summary = self._chunks_summarizer(merged_ranked_filtered_chunks)

        return unique_sources, chunk_summary

    
    def extract_content_by_source(self, unranked_documents):
        # Group chunks by source
        source_chunks = {}
        for doc in unranked_documents:
            source = doc["metadata"].get("source", "unknown")
            content = doc.get("content", "")
            chunk_sequence = doc["metadata"].get("chunk", 0)  # Use "chunk" as an integer sequence
    
            # Check for "text:" in content and extract content after it if present
            if "text:" in content:
                text_chunk = content.split("text:", 1)[1].strip()
            else:
                text_chunk = content.strip()
    
            if source not in source_chunks:
                source_chunks[source] = []
            source_chunks[source].append({"content": text_chunk, "chunk_sequence": chunk_sequence})
    
        content_by_source = []
        
        for source, chunks in source_chunks.items():
            sorted_chunks = sorted(chunks, key=lambda x: x["chunk_sequence"])
            merged_content = " ".join(chunk["content"] for chunk in sorted_chunks)
            content_by_source.append({
                "source": source,
                "content": merged_content
            })
    
        for source_data in content_by_source:
            print(f"Source: {source_data['source']}")
            print(f"Merged Content: {source_data['content']}")  # Print a preview of the merged content
            print()
    
        return content_by_source


    def _chunks_summarizer(self, merged_chunks_by_source):
        text_chunks_string = "\n".join(str(chunk) for chunk in merged_chunks_by_source)
        chunk_summary = text_chunks_string
        # self.eva_analytics.append({
        #     "event": "_chunks_summarizer",
        #     "response_metadata": chunks_summarizer_result.response_metadata
        # })
        
        try:
            chunks_summarizer_prompt = self._load_template(
                self.chat_data.prompt_template_directory_name, 
                self.chat_data.chunk_summarizer_prompt_template_file_name
            )
    
            prompt_template = PromptTemplate(
                template=chunks_summarizer_prompt,
                input_variables=["user_input", "text_chunks"]
            )
            
            chunks_summarizer_chain = RunnableSequence(prompt_template, self.llm_eva)
            chunks_summarizer_result = chunks_summarizer_chain.invoke({
                "user_input": self.chat_data.user_input,
                "text_chunks": text_chunks_string,
            })
            
            cleaned_content = re.sub(r"```json|```", "", chunks_summarizer_result.content).strip()
            chunks_summarizer_result_json = json.loads(cleaned_content)
            if "summary" in chunks_summarizer_result_json:
                chunk_summary = chunks_summarizer_result_json["summary"]
        except (json.JSONDecodeError, AttributeError, KeyError) as e:        
            chunk_summary = text_chunks_string

        print()                    
        print("Chunk Summary: ", chunk_summary)
        return chunk_summary
        
    

    def _paraphrase_query(self):
        self.paraphrases = [self.chat_data.user_input]
    
        paraphrase_prompt = self._load_template(
            self.chat_data.prompt_template_directory_name, 
            self.chat_data.paraphrasing_prompt_template_file_name
        )

        prompt_template = PromptTemplate(
            template=paraphrase_prompt,
            input_variables=["user_input"]
        )
        
        paraphrase_chain = RunnableSequence(prompt_template, self.llm_eva)
        paraphrase_result = paraphrase_chain.invoke({
            "user_input": self.chat_data.user_input
        })

        # self.eva_analytics.append({
        #     "event": "_paraphrase_query",
        #     "response_metadata": paraphrase_result.response_metadata
        # })
        
        try:
            cleaned_content = re.sub(r"```json|```", "", paraphrase_result.content).strip()
            paraphrase_result_json = json.loads(cleaned_content)
            if "paraphrases" in paraphrase_result_json:
                self.paraphrases.extend(paraphrase_result_json["paraphrases"])

            for phrase in self.paraphrases:
                print(phrase)
        
        except (json.JSONDecodeError, AttributeError, KeyError) as e:        
            pass     


    def _get_vector_store(self):
        self.llm_eva = ChatOpenAI(
            model_name=self.chat_data.rag_settings.chat_model_name,
            temperature=self.chat_data.rag_settings.temperature,
            max_tokens=self.chat_data.rag_settings.max_tokens_for_response,
            openai_api_key=self.chat_data.llm_settings.llm_key
        )

        llm_embeddings = OpenAIEmbeddings(
            model=self.chat_data.llm_settings.embedding_model_name,
            openai_api_key=self.chat_data.llm_settings.llm_key
        )
        
        db_instance = BaseDB().get_vector_db(
            self.chat_data.db_type,
            self.chat_data.db_settings,
            llm_embeddings
        )
        
        self.vector_store = db_instance.vector_index

    
    def get_response(self):
        llm_params = {
            "openai_api_key": self.chat_data.llm_settings.llm_key,
            "max_tokens": self.chat_data.rag_settings.max_tokens_for_response,
            "model_name": self.chat_data.rag_settings.chat_model_name,
            "temperature": self.chat_data.rag_settings.temperature,
            "frequency_penalty": self.chat_data.rag_settings.frequency_penalty,
            "presence_penalty": self.chat_data.rag_settings.presence_penalty
        }

        self.llm_eva = ChatOpenAI(**llm_params)

        self.summarized_history = ""
        self.memory = None
        if self.chat_data.chat_history:
            self.summarized_history, self.memory = self._summarize_history()

        print("Summarized History: ", self.summarized_history)
        print()
        print("User Query: ", self.chat_data.user_input)
        print()

        base_response = self._run_base_prompt()
        if base_response is not None:
            return model_rag.ChatResponse(response=base_response, sources=[])

        self._detect_intent_and_rephrase_query()

        print("Detected Intent:", self.detected_intent)
        print("Rephrased User Query:", self.chat_data.user_input)

        if self.chat_data.strict_follow_up:
            follow_up_question = self._follow_up_questions()
            if follow_up_question:
                return model_rag.ChatResponse(response=follow_up_question, sources=[]), self.eva_analytics

        if self.chat_data.rag_settings.streaming_response and self.llm_streaming_callback_handler:
            llm_params["streaming"] = True
            llm_params["callback_manager"] = [self.llm_streaming_callback_handler]
            self.llm_eva = ChatOpenAI(**llm_params)

        result = self._invoke_qa()
        return result, self.eva_analytics

        
    ## Private Methods

    def _invoke_qa(self):
        # Run vector search to get the chunk summary and unique sources
        unique_sources, chunk_summary = self._vector_search()
        
        qa_chain = self._get_qa_instance()
        qa_input = {
            "summaries": chunk_summary,
            "question": self.chat_data.user_input
        }
        
        start_time = time.time()
        result = qa_chain(qa_input)
        end_time = time.time()
        latency = end_time - start_time
    
        # Extract the answer
        output_text = result.get("text", "")
        
        # Calculate token usage for analytics
        try:
            tokenizer = tiktoken.encoding_for_model(self.chat_data.rag_settings.chat_model_name)
        except:
            tokenizer = tiktoken.encoding_for_model("gpt-4o")
    
        formatted_prompt = self.chat_prompt_content.format(
            history=self.summarized_history,
            summaries=chunk_summary,
            question=self.chat_data.user_input
        )
    
        prompt_tokens = len(tokenizer.encode(formatted_prompt))
        completion_tokens = len(tokenizer.encode(output_text))
        total_tokens = prompt_tokens + completion_tokens
    
        # Append analytics data
        self.eva_analytics.append({
            "event": "invoke_qa",
            "response_time_in_seconds": latency,
            "response_metadata": {
                "model_name": self.chat_data.rag_settings.chat_model_name,
                "token_usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens
                }
            }
        })
    
        # Return the result as needed
        return {
            "answer": output_text,
            "sources": unique_sources  # Include unique sources if needed
        }

    
    def _get_qa_instance(self):
        self.chat_prompt_content = self._build_chat_prompt()
        prompt_template = PromptTemplate(
            template=self.chat_prompt_content,
            input_variables=['summaries', 'question']
        )
    
        if self.memory:
            qa_chain = ConversationChain(
                    llm=self.llm_eva,
                    prompt=prompt_template,
                    memory=self.memory,
                    verbose=SHOW_VERBOSE
                )
        else:
            qa_chain = LLMChain(
                llm=self.llm_eva,
                prompt=prompt_template,
                verbose=SHOW_VERBOSE
            )
        
        return qa_chain
    
    
    def _follow_up_questions(self):
        # Load the follow-up prompt template
        follow_up_prompt = self._load_template(
            self.chat_data.prompt_template_directory_name, 
            self.chat_data.follow_up_prompt_template_file_name
        )
        
        # Define the prompt template
        prompt_template = PromptTemplate(
            template=follow_up_prompt,
            input_variables=["missing_fields", "intent_name"]
        )
    
        # Collect missing fields
        missing_fields = []
        
        if self.detected_intent in self.chat_data.intent_details:
            intent_detail_obj = self.chat_data.intent_details[self.detected_intent]
            
            required_fields = intent_detail_obj.required_fields_prior_responding or []
            
            for field in required_fields:
                if field not in self.extracted_fields:
                    missing_fields.append(field)
    
        # If there are no missing fields, no follow-up questions are required
        if not missing_fields:
            return None 
    
        # Generate the follow-up questions using the detected intent and missing fields
        follow_up_chain = RunnableSequence(prompt_template, self.llm_eva)
        follow_up_result = follow_up_chain.invoke({
            "missing_fields": ", ".join(missing_fields),
            "intent_name": self.detected_intent
        })
    
        # Return the follow-up questions content
        return follow_up_result.content.strip()



    def _format_response(self, response_text):
        extracted_sources = []
        try:
            source_tag_pattern = r"<sources>(.*?)<\/sources>"
            matches = re.findall(source_tag_pattern, response_text, re.DOTALL)
            for match in matches:
                pans = re.split(r"\s*,\s*|\n+", match)
                pans = [pan for pan in pans if pan.isdigit()]  # Only keep numeric values
                extracted_sources.extend(pans)

            response_text = re.sub(source_tag_pattern, '', response_text).strip()
        except Exception as e:
            print(f"Error while extracting sources: {e}")

        return response_text, extracted_sources

    def _extract_sources(self, sources_documents, extracted_source_ids):
        filtered_sources_list = []
        for doc in sources_documents:
            source_id = doc.metadata.get(next((key for key in doc.metadata if key.lower() == "source"), ""), "")
            if source_id in extracted_source_ids:
                language = doc.metadata.get(next((key for key in doc.metadata if key.lower() == "language"), ""), "")
                if language.lower() == "english":
                    filtered_sources_list.append(
                        model_rag.Source(
                            source=source_id,
                            type=doc.metadata.get(next((key for key in doc.metadata if key.lower() == "type"), ""), ""),
                            title=doc.metadata.get(next((key for key in doc.metadata if key.lower() == "title"), ""), ""),
                            country=doc.metadata.get(next((key for key in doc.metadata if key.lower() == "country"), ""), ""),
                            language=language
                        )
                    )
        return filtered_sources_list
        
    
    def _load_template(self, project_template_directory_name, template_file_name):
        project_template_directory_path = os.path.join(EVA_SETTINGS_PATH, project_template_directory_name, EVA_SETTINGS_ENVIRONMENT_DIRECTORY)
        template_file_path = project_template_directory_path+ '/' + self.chat_data.rag_type + '/' + template_file_name
        with open(template_file_path, "r") as file:
            return file.read()

    def _build_chat_prompt(self):
        if self.detected_intent == "other":
            chat_prompt_filename = self.chat_data.free_flowing_prompt_template_file_name
            chat_template = self._load_template(self.chat_data.prompt_template_directory_name, chat_prompt_filename)           
            return chat_template.format(
                history=self.summarized_history,
                summaries="{summaries}",
                question="{question}"
            )
        else:                        
            chat_prompt_filename = self.chat_data.intent_details.get(self.detected_intent).filename
            chat_template = self._load_template(self.chat_data.prompt_template_directory_name, chat_prompt_filename)
            key_fields = self.chat_data.intent_details[self.detected_intent].required_fields_prior_responding or ""
            
            return chat_template.format(
                history=self.summarized_history,
                key_fields=key_fields,
                summaries="{summaries}",
                question="{question}"
            )

    def _get_qa_retriever(self):
        llm_embeddings = OpenAIEmbeddings(
            model=self.chat_data.llm_settings.embedding_model_name,
            openai_api_key=self.chat_data.llm_settings.llm_key
        )
    
        db_instance = BaseDB().get_vector_db(
            self.chat_data.db_type,
            self.chat_data.db_settings,
            llm_embeddings
        )
        vector_store = db_instance.vector_index
        
        search_type = self.chat_data.rag_settings.retriever_search_settings.search_type
        search_kwargs = {
            "k": self.chat_data.rag_settings.retriever_search_settings.max_chunks_to_retrieve
        }
        if search_type == model_rag.RetrieverSearchType.Similarity_Score_Threshold:
            search_kwargs["score_threshold"] = self.chat_data.rag_settings.retriever_search_settings.retrieved_chunks_min_relevance_score
        elif search_type == model_rag.RetrieverSearchType.MMR:
            search_kwargs["fetch_k"] = self.chat_data.rag_settings.retriever_search_settings.fetch_k
            search_kwargs["lambda_mult"] = self.chat_data.rag_settings.retriever_search_settings.lambda_mult

        qa_retriever = vector_store.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        )
        
        return qa_retriever


    def _get_key_fields_for_lookup(self):
        merged_key_fields = set()
        
        for intent, details in self.chat_data.intent_details.items():
            required_fields = details.required_fields_prior_responding or []
            merged_key_fields.update(required_fields)
    
        return list(merged_key_fields)
    
    def _detect_intent_and_rephrase_query(self):
        self.detected_intent = "other"
        self.extracted_fields = {}
        merged_key_fields = self._get_key_fields_for_lookup()
    
        intent_prompt = self._load_template(
            self.chat_data.prompt_template_directory_name, 
            self.chat_data.intent_detection_prompt_template_file_name
        )

        prompt_template = PromptTemplate(
            template=intent_prompt,
            input_variables=["user_input", "history", "intent_list", "merged_key_fields"]
        )
        
        intent_chain = RunnableSequence(prompt_template, self.llm_eva)
        intent_result = intent_chain.invoke({
            "user_input": self.chat_data.user_input,  
            "history": self.summarized_history,  
            "intent_list": "\n".join([f'- "{intent_name}"' for intent_name in self.chat_data.intent_details.keys()]),
            "merged_key_fields": ", ".join(merged_key_fields)
        })
        
        try:
            cleaned_content = re.sub(r"```json|```", "", intent_result.content).strip()
            intent_result_json = json.loads(cleaned_content)
            
            if all(key in intent_result_json for key in ["detected_intent", "rephrased_query", "extracted_fields"]):
                self.detected_intent = intent_result_json.get("detected_intent", self.detected_intent).strip().lower()
                self.chat_data.user_input = intent_result_json.get("rephrased_query", self.chat_data.user_input).strip()
                self.extracted_fields = intent_result_json.get("extracted_fields", self.extracted_fields)

        except (json.JSONDecodeError, AttributeError, KeyError) as e:        
            pass        
                
    
    
    def _run_base_prompt(self):
        base_prompt = self._load_template(
            self.chat_data.prompt_template_directory_name, 
            self.chat_data.base_prompt_template_file_name
        )
        
        prompt_template = PromptTemplate(
            template=base_prompt,
            input_variables=["user_input", 'history']
        )
        
        base_chain = RunnableSequence(prompt_template, self.llm_eva)
        base_result = base_chain.invoke({
            "user_input": self.chat_data.user_input,  
            "history": self.summarized_history
        })

        base_response = base_result.content.strip().strip(' "\'')
        if base_response.lower() == "none":
            return None
        return base_response

    
    def _summarize_history(self):        
        if not self.chat_data.chat_history:
            return "", None

        def trim_message(message, max_lines=2):
            lines = message.splitlines()
            if len(lines) > max_lines:
                return "\n".join(lines[:max_lines]) + "..."
            return message
        
        history = ChatMessageHistory()
        for conv in self.chat_data.chat_history:
            if conv.role.lower() == 'human':
                history.add_message(HumanMessage(content=conv.message))
            elif conv.role.lower() == 'ai':
                trimmed_message = trim_message(conv.message)
                history.add_message(SystemMessage(content=trimmed_message))

        memory_template = self._load_template(
            self.chat_data.prompt_template_directory_name, 
            self.chat_data.memory_prompt_template_file_name
        )
        
        custom_prompt = PromptTemplate(
            input_variables=['new_lines', 'summary'],
            template=memory_template
        )
        
        memory = ConversationSummaryBufferMemory(
            llm=self.llm_eva,
            max_token_limit=self.chat_data.rag_settings.max_tokens_for_history,
            prompt=custom_prompt,
            chat_memory=history,
            return_messages=True,
            memory_key="history",
            input_key="question"
        )

        memory.prune()
        summarized_history = memory.predict_new_summary(memory.chat_memory.messages, "")
       
        return summarized_history, memory



In [21]:
def generate_dummy_conversation():
    """Return a canned two-turn chat history (one Human turn, one AI turn).

    Each entry is a dict with "role" and "message". A new list is built on
    every call.
    """
    return [
        {
            "role": "Human",
            "message": "Give me list of pests for rice",
        },
        {
            "role": "AI",
            "message": "To provide a list of pests affecting rice crops, I need to know the specific country or region you are interested in. Please provide this information so I can assist you better.",
        },
    ]


# Chat history passed with every request by call_chatbot_endpoint.
# Start empty, or seed it with generate_dummy_conversation() to test follow-ups.
conversation_history = list()
# conversation_history = generate_dummy_conversation()

# Passed as `verbose` to the LangChain chains built in RAG._get_qa_instance.
SHOW_VERBOSE = True

def get_chatbot_response(payload: model_rag.ChatRequest):
    """Process one chat request and return what ``RAG.get_response`` returns.

    Prints a two-line banner first, to separate runs in the notebook output.
    """
    chat_processor = RAG(payload)
    banner = '============================'

    print()
    for _ in range(2):
        print(banner)
    print()

    return chat_processor.get_response()
    


def call_chatbot_endpoint(user_input_text, MAX_CHUNKS_TO_RETRIEVE=10, MMR_FETCH_K=500, MMR_LAMBDA_MULT=0.5):
    """Build a chat request from the notebook configuration and print the answer.

    Args:
        user_input_text: The user's question.
        MAX_CHUNKS_TO_RETRIEVE: Retriever "max_chunks_to_retrieve" setting.
        MMR_FETCH_K: Retriever "fetch_k" setting.
        MMR_LAMBDA_MULT: Retriever "lambda_mult" setting.

    The three keyword parameters shadow the module-level constants of the
    same name for the duration of the call. Nothing is returned; the answer
    text is printed.
    """
    global conversation_history

    # Directly create an instance of model_rag.ChatRequest with the required values
    chat_request = model_rag.ChatRequest(
        db_type="mongodb",
        db_settings={
            "uri": MONGO_URI,
            "db_name": DB_NAME,
            "collection_name": COLLECTION_NAME,
            "vector_index_name": ATLAS_VECTOR_SEARCH_INDEX_NAME,
            "vector_similarity_function": VECTOR_SIMILARITY_FUNCTION
        },
        llm_settings={
            "llm_key": OPENAI_API_KEY,
            "vector_dimension_size": EMBEDDING_DIMENSIONS,
            "embedding_model_name": EMBEDDING_MODEL_NAME
        },
        rag_settings={
            "chat_model_name": CHAT_MODEL_NAME,
            "max_tokens_for_response": MAX_TOKENS_FOR_RESPONSE,
            "retriever_search_settings": {
              "fetch_k": MMR_FETCH_K,
              "lambda_mult": MMR_LAMBDA_MULT,
              "max_chunks_to_retrieve": MAX_CHUNKS_TO_RETRIEVE,
              "retrieved_chunks_min_relevance_score": SST_CHUNK_MIN_RELEVANCE_SCORE,
              "search_type": RETRIEVER_SEARCH_TYPE
            },
            "temperature": CHAT_MODEL_TEMPERATURE,
            "frequency_penalty": CHAT_MODEL_FREQ_PENALTY,
            "presence_penalty": CHAT_MODEL_PRES_PENALTY,
            "max_tokens_for_history": MAX_TOKENS_FOR_HISTORY
        },
        prompt_template_directory_name="wcc",
        base_prompt_template_file_name="base_template.txt",
        memory_prompt_template_file_name="memory_summarizer.txt",
        intent_detection_prompt_template_file_name="detect_intent.txt",
        follow_up_prompt_template_file_name="follow_up.txt",
        free_flowing_prompt_template_file_name="free_flowing.txt",
        paraphrasing_prompt_template_file_name="paraphrase.txt",
        chunk_summarizer_prompt_template_file_name="chunk_summarizer.txt",
        intent_details = {
            "diagnosis": {
                "filename": "diagnosis.txt",
                "description": "This intent covers queries related to diagnosing pests or problems affecting crops, including identifying potential pests or diseases based on symptoms, crop type, and location.",
                "required_fields_prior_responding": ["crop", "country/region/location", "symptoms"]
            },
            "symptoms identification": {
                "filename": "symptoms_identification.txt",
                "description": "This intent provides detailed information about symptoms caused by a specific pest or problem, including visual indicators and progression of the symptoms.",
                "required_fields_prior_responding": ["pest"]
            },
            "pest list by location": {
                "filename": "pest_list.txt",
                "description": "This intent provides a list of pests that affect a specific crop in a specific country or region.",
                "required_fields_prior_responding": ["crop", "country/region/location"]
            },
            "integrated pest management advice": {
                "filename": "ipm_pest_management.txt",
                "description": "This intent provides integrated pest management (IPM) advice, including prevention strategies, biocontrol recommendations, and chemical pesticide usage for managing pests or diseases on crops.",
                "required_fields_prior_responding": ["crop", "country/region/location", "pest"]
            },
            "chemical handling": {
                "filename": "chemical_handling_safety.txt",
                "description": "This intent provides safety advice for handling and applying specific chemicals, including personal protective equipment (PPE), safe storage, and disposal recommendations.",
                "required_fields_prior_responding": ["chemical name"]
            },
            "invasive pest status": {
                "filename": "invasive_pest_status.txt",
                "description": "This intent provides information on the current status, distribution, and spread of invasive pests in a specific country or region.",
                "required_fields_prior_responding": ["pest", "country/region/location"]
            },
            "dosage recommendations": {
                "filename": "dosage_recommendations.txt",
                "description": "This intent provides dosage recommendations for chemical or biocontrol products, including application rates, frequency, and any location-specific restrictions or precautions.",
                "required_fields_prior_responding": ["chemical name", "crop", "pest", "country/region/location", "size/area of the crop"]
            }
        },
        rag_type= "v4",
        strict_follow_up= 0,
        chat_history= conversation_history,
        user_input= user_input_text
    )

    chatbot_response = get_chatbot_response(payload=chat_request)

    # NOTE: the exchange is not appended to conversation_history, so every
    # call starts without prior turns.

    # RAG.get_response returns (result, analytics); accept a bare result too.
    result = chatbot_response[0] if isinstance(chatbot_response, tuple) else chatbot_response
    if not result:
        return

    # The QA path yields a dict with "answer". The base-prompt and follow-up
    # paths yield an object exposing the text as `.response`.
    if isinstance(result, dict):
        answer_text = result.get("answer", result.get("response"))
    else:
        answer_text = getattr(result, "answer", None)
        if answer_text is None:
            answer_text = getattr(result, "response", None)

    print(answer_text)


    


In [22]:
# Example run: one question through the full pipeline, with the retriever's
# max_chunks_to_retrieve set to 3. The commented line is an alternative question.
# call_chatbot_endpoint("What's the best way to control white mould in my narcissus?", MAX_CHUNKS_TO_RETRIEVE=3)
call_chatbot_endpoint("What are the typical symptoms of Crook root disease in watercress and what are the various control options for managing it?", MAX_CHUNKS_TO_RETRIEVE=3)



Summarized History:  

User Query:  What are the typical symptoms of Crook root disease in watercress and what are the various control options for managing it?

Detected Intent: symptoms identification
Rephrased User Query: What are the typical symptoms of Crook root disease in watercress, and what are the various control options for managing it?
What are the typical symptoms of Crook root disease in watercress, and what are the various control options for managing it?
What symptoms are commonly associated with Crook root disease in watercress, and how can it be controlled?
Can you describe the usual signs of Crook root disease in watercress and the methods available to manage it?
What are the common indicators of Crook root disease in watercress, and what strategies exist for its control?
How does Crook root disease typically manifest in watercress, and what are the options for controlling it?
What are the typical symptoms of Crook root disease in watercress?
What are the various co