In [1]:
import os
import pickle

import numpy as np
import pandas as pd

from uuid import uuid4
from tqdm.notebook import tqdm
from dotenv import load_dotenv

from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain.prompts import ChatPromptTemplate
from whatsappchattodf import WhatsappChatToDF

In [2]:
# Populate os.environ from a .env file (python-dotenv), then read settings.
# Every lookup uses os.environ.get, so a missing variable yields None.
load_dotenv()
DATA_PATH = os.environ.get("project_path")

# Azure OpenAI chat-model settings; env vars are named "<LLM_TO_USE>_<field>".
LLM_TO_USE = "gpt4o"
llm_api_key, llm_api_version, llm_azure_endpoint, llm_deployment_name = (
    os.environ.get(f"{LLM_TO_USE}_{field}")
    for field in ("api_key", "api_version", "api_endpoint", "dep_name")
)

# Azure OpenAI embedding-model settings, same "<EMBED_TO_USE>_<field>" convention.
EMBED_TO_USE = "midasembed"
(
    embed_api_key,
    embed_api_version,
    embed_azure_endpoint,
    embed_deployment_name,
    embed_model,
) = (
    os.environ.get(f"{EMBED_TO_USE}_{field}")
    for field in ("api_key", "api_version", "api_endpoint", "dep_name", "model")
)

In [3]:
# Embedding client: embeds the chat messages below and is also handed to FAISS,
# which uses it to embed retrieval queries.
embeddings = AzureOpenAIEmbeddings(
    azure_endpoint=embed_azure_endpoint,
    deployment=embed_deployment_name,
    model=embed_model,
    api_key=embed_api_key,
    openai_api_version=embed_api_version,
    # Empty tuple: do not reject input text that happens to contain special-token
    # strings (tokenizer-level check) -- chat messages are arbitrary user text.
    disallowed_special=(),
)

# Chat model that answers in the persona of the other chat participant.
llm = AzureChatOpenAI(
    azure_endpoint=llm_azure_endpoint,
    azure_deployment=llm_deployment_name,
    api_key=llm_api_key,
    openai_api_version=llm_api_version,
)

### Load data

In [4]:
# Chat export to use. The transcript is read from ../data/<user_file>/<user_file>.txt,
# the same stem names the pickles written next to it, and the FAISS index lives
# under ../indexes/<user_file>.
user_file = "wc_user"

In [None]:
# Parse the WhatsApp export into a DataFrame (columns used later: Date, Timestamp,
# User, Message) and keep only rows that carry real message text.
chat_to_df = WhatsappChatToDF(f"../data/{user_file}/{user_file}.txt")
data = chat_to_df.run().fillna("")

# Drop empty rows, media placeholders and "deleted message" placeholders in one pass.
# Only the placeholder phrases ("This message was deleted" / "You deleted this
# message") are matched; a bare substring test for "deleted" would also throw away
# genuine messages such as "I deleted the photo".
is_empty = data["Message"] == ""
is_media = data["Message"] == "<Media omitted>"
is_deleted = data["Message"].str.contains(
    "message was deleted|deleted this message", regex=True, na=False
)
data = data[~(is_empty | is_media | is_deleted)].reset_index(drop=True)

# Persona for the LLM: the first participant (by order of appearance) who is not
# Kartheek Palepu.
participants = data["User"].unique()
other_participants = participants[participants != "Kartheek Palepu"]
if len(other_participants) == 0:
    raise ValueError("No participant other than 'Kartheek Palepu' found in the chat.")
bot_name = other_participants[0]
data.shape

In [None]:
# One Document per chat message. "idx" is the row's index label + 1 (the frame was
# reset to a 0-based RangeIndex above); it is used later to fetch neighbouring
# messages around a retrieval hit.
docs = [
    Document(
        id=str(uuid4()),
        metadata={"idx": row.Index + 1, "date": row.Date, "timestamp": row.Timestamp},
        page_content=f"User: {row.User}\n{row.Message}",
    )
    for row in data.itertuples()
]

docs_pickle_path = f"../data/{user_file}/{user_file}_docs.pkl"
with open(docs_pickle_path, "wb") as f:
    pickle.dump(docs, f)

len(docs)

### Create Embeddings

In [None]:
# Embed every message text in batches of `tokens_interval` documents per API call.
tokens_interval = 500
texts = [doc.page_content for doc in docs]
vectors = []
for start in tqdm(range(0, len(texts), tokens_interval)):
    print(f"Iteration {start} to {start + tokens_interval}")
    vectors.extend(embeddings.embed_documents(texts[start : start + tokens_interval]))

# Row i of embeddings_list is the vector for docs[i].
embeddings_list = np.array(vectors)
print(embeddings_list.shape)

with open(f"../data/{user_file}/{user_file}_embeddings.pkl", "wb") as f:
    pickle.dump(embeddings_list, f)

In [9]:
# Pair each message text with its vector, as a concrete list. A bare `zip` is a
# one-shot iterator: FAISS.from_embeddings consumes it, so re-running the FAISS cell
# would receive an exhausted iterator, and the pickle would hold a `zip` object
# instead of plain (text, vector) tuples.
text_embedding_pairs = list(zip([d.page_content for d in docs], embeddings_list))
# zip silently truncates; a stale embeddings_list would misalign texts and metadata.
assert len(text_embedding_pairs) == len(docs), "docs and embeddings_list differ in length"

with open(f"../data/{user_file}/{user_file}_text_embed_pair.pkl", "wb") as f:
    pickle.dump(text_embedding_pairs, f)

### Setup FAISS

In [10]:
# Per-document metadata and ids, aligned index-for-index with `docs`.
metadatas, uuids = [d.metadata for d in docs], [d.id for d in docs]

In [11]:
# Reuse the on-disk FAISS index when present; otherwise build it from the
# precomputed vectors and save it.
VS_PATH = f"../indexes/{user_file}"
if os.path.exists(VS_PATH):
    # allow_dangerous_deserialization: load_local unpickles the saved docstore.
    # Acceptable only because this notebook wrote the index itself.
    vectorstore = FAISS.load_local(
        VS_PATH, embeddings, allow_dangerous_deserialization=True
    )
else:
    vectorstore = FAISS.from_embeddings(
        embedding=embeddings,
        text_embeddings=text_embedding_pairs,
        metadatas=metadatas,
        ids=uuids,
    )
    vectorstore.save_local(VS_PATH)

In [12]:
# Maximal-marginal-relevance retrieval returning 5 messages per query.
retriever = vectorstore.as_retriever(search_kwargs={"k": 5}, search_type="mmr")

In [13]:
# Persona prompt: the model speaks as `bot_name`. Only the first and last pieces are
# f-strings; the plain literals keep {question} and {context} as template variables
# for ChatPromptTemplate to fill at invoke time.
template = (
    f"You are {bot_name}"
    " in the below conversation and I am Kartheek Palepu.\n"
    "You have to reply to my messages.\n"
    "\n"
    "Below are the steps:\n"
    "- Understand the question/message from Kartheek Palepu.\n"
    "- Use the Chat history to answer the question/message.\n"
    "- Think step by step and understand how you would answer the question/message.\n"
    "- If you find an appropriate answer in the chat history, use the same.\n"
    "- If not, use the Chat History and come up with an identical response\n"
    "Note: The chat language can be telugu typed in english. Hence follow the same wherever it is needed.\n"
    "\n"
    "Kartheek Palepu: {question} \n"
    "Chat History: {context}"
    f"\n{bot_name}"
)
prompt = ChatPromptTemplate.from_template(template=template)

In [None]:
# Sanity check: recover the non-Kartheek participant from the Documents themselves
# (the first line of every page_content is "User: <name>"). Every document is
# scanned, not just the first 10, and first-appearance order is kept so the result
# is deterministic and comparable to `bot_name`. The `None` default avoids a
# StopIteration when no other participant is present.
all_users = list(
    dict.fromkeys(d.page_content.splitlines()[0].replace("User: ", "", 1) for d in docs)
)
next((user for user in all_users if user != "Kartheek Palepu"), None)

In [None]:
def get_additional_msgs(retrieved_docs, docs, k=1):
    """Expand each retrieved message into a window of its neighbouring messages.

    Args:
        retrieved_docs: Documents returned by the retriever; each must carry an
            integer ``metadata["idx"]``.
        docs: The full list of chat Documents to draw neighbours from.
        k: Number of messages to include on each side of every hit.

    Returns:
        One string per retrieved document: the ``page_content`` of every doc in
        ``docs`` whose ``idx`` lies in ``[idx - k, idx + k]`` (kept in ``docs``
        order), joined with single newlines.
    """
    windows = []
    for hit in retrieved_docs:
        centre = hit.metadata["idx"]
        window_ids = range(centre - k, centre + k + 1)  # inclusive +/- k window
        neighbours = [doc for doc in docs if doc.metadata["idx"] in window_ids]
        windows.append(format_docs(neighbours, sep="\n"))
    return windows


def format_docs(docs, sep="\n\n"):
    """Join strings, or the ``page_content`` of Document-like objects, with ``sep``.

    Args:
        docs: Either all strings, or all objects exposing ``page_content``.
        sep: Separator placed between consecutive items.

    Returns:
        The joined string (empty for an empty input).

    Raises:
        ValueError: If ``docs`` mixes types or contains items that are neither
            strings nor objects with a ``page_content`` attribute.
    """
    if all(isinstance(item, str) for item in docs):
        pieces = docs
    elif all(hasattr(item, "page_content") for item in docs):
        pieces = [item.page_content for item in docs]
    else:
        raise ValueError(
            "The input must be a list of strings or Document objects with a 'page_content' attribute."
        )
    return sep.join(pieces)


QUESTION = "Whatcha doing"

# `invoke` is the Runnable entry point; `get_relevant_documents` is deprecated.
retrieved_docs = retriever.invoke(QUESTION)
retrieved_docs_full = get_additional_msgs(retrieved_docs, docs, k=1)

# Render the context windows once and send exactly that text to the model. Passing
# the raw list would make the prompt contain its Python repr (brackets, quotes and
# literal "\n" escapes) rather than the conversation printed here.
context = format_docs(retrieved_docs_full, "\n\n")
print(context)
print("===" * 20)

chain = prompt | llm
print(chain.invoke({"context": context, "question": QUESTION}).content)