In [None]:
! pip install langchain_community tiktoken langchainhub chromadb langchain langchain_fireworks google-search-results requests gradio

In [2]:
import os
# Placeholder credentials. ``setdefault`` keeps a real key that is already
# exported in the environment (shell, Colab secrets, ...) instead of
# overwriting it with the placeholder; replace 'API_KEY' or export the
# variables before running.
os.environ.setdefault('FIREWORKS_API_KEY', 'API_KEY')
# NOTE(review): SERPER_API_KEY is not read anywhere in the visible code.
os.environ.setdefault("SERPER_API_KEY", 'API_KEY')

In [None]:
import json
import re

import gradio as gr
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain_fireworks import ChatFireworks

# System prompt for the chat model. It defines the reply protocol the rest of
# this notebook parses: after a normal answer the model may append
# [CONTEXT_UPDATE]{...JSON...}[/CONTEXT_UPDATE], which extract_context_update()
# decodes and chat() removes before showing the reply to the user.
# NOTE(review): the prompt mentions "[NEW CHAT SESSION]" separators, but
# new_chat() does not send that marker; it clears the memory instead, so
# earlier sessions reach the model only through the working context that
# chat() attaches to every user turn.
system_message = """You are a chatbot having a conversation with a human. You might have multiple sessions separated by [NEW CHAT SESSION]. You can use previous chats if necessary.

Your primary task is to engage in a natural conversation with the user while maintaining context:

1. Maintain a working context of important information about the user (e.g., favorite color, birthday, name) across all chat sessions.
2. Update this working context based ONLY on explicit information provided by the user.
3. Always use the working context to inform your responses and maintain consistency, even in new chat sessions.

For EVERY user message, you MUST:
1. Provide a natural, conversational response to the user's input.
2. After your response, IF there is new information to add to the working context, include a context update in this format:
   [CONTEXT_UPDATE]{"key1": "value1", "key2": "value2"}[/CONTEXT_UPDATE]

Important guidelines:
- Always respond to the user's query or comment first.
- Only include a [CONTEXT_UPDATE] section if there is new information to add.
- The [CONTEXT_UPDATE] section should contain the full updated context, not just new information.
- Never mention the context updates or working context explicitly in your responses to the user.
- When a new chat session starts, use the existing working context to inform your responses, but don't repeat all the information unless it's relevant to the current conversation.

Remember, your main goal is to have a natural conversation. The context management is a background task to help you maintain consistency across all chat sessions."""

# Prompt layout: fixed system instructions, then the running conversation
# (filled from memory under the "chat_history" key), then the current turn.
prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessage(content=system_message),
        MessagesPlaceholder(variable_name="chat_history"),
        HumanMessagePromptTemplate.from_template("{input}"),
    ]
)

# Llama 3.1 405B Instruct hosted on Fireworks, at a low sampling temperature.
llm = ChatFireworks(
    model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
    temperature=0.3,
)

# Conversation memory; the key matches the MessagesPlaceholder above and
# return_messages=True keeps turns as message objects rather than one string.
memory = ConversationBufferMemory(return_messages=True, memory_key="chat_history")

# Facts about the user reported by the model in [CONTEXT_UPDATE] blocks.
# chat() merges new facts into it; clear_history() resets it.
working_context = dict()

# Prompt + model + memory; chat() drives it via chat_chain.predict(input=...).
chat_chain = LLMChain(prompt=prompt, llm=llm, memory=memory, verbose=True)

def extract_context_update(response):
    """Decode the JSON payload of the first [CONTEXT_UPDATE] block in a reply.

    The system prompt asks the model to append
    ``[CONTEXT_UPDATE]{...}[/CONTEXT_UPDATE]`` when it learns something new
    about the user.

    Args:
        response: Raw text returned by the model.

    Returns:
        The decoded ``dict``, or ``None`` in any of these cases:
        no complete opening/closing tag pair, a payload that is not valid
        JSON, or valid JSON that is not an object. The caller merges the
        result with ``dict.update``, so it must be a mapping.
    """
    start_tag = "[CONTEXT_UPDATE]"
    end_tag = "[/CONTEXT_UPDATE]"

    start = response.find(start_tag)
    if start == -1:
        return None
    payload_start = start + len(start_tag)

    # Look for the closing tag only after the opening one; a stray closing
    # tag earlier in the text must not produce an inverted (empty) slice.
    end = response.find(end_tag, payload_start)
    if end == -1:
        return None

    context_json = response[payload_start:end].strip()
    try:
        parsed = json.loads(context_json)
    except json.JSONDecodeError:
        print(f"Error parsing context JSON: {context_json}")
        return None

    if not isinstance(parsed, dict):
        print(f"Ignoring non-object context update: {context_json}")
        return None
    return parsed

def format_working_context():
    """Render the module-level ``working_context`` for the UI side panel.

    Returns:
        A placeholder sentence when the context is empty; otherwise a
        "Working Context:" header followed by one ``- key: value`` line per
        entry (every line, including the last, ends with a newline).
    """
    if not working_context:
        return "No working context available yet."
    entry_lines = [f"- {key}: {value}\n" for key, value in working_context.items()]
    return "Working Context:\n" + "".join(entry_lines)

def estimate_token_count(text):
    """Approximate the token count of ``text`` by counting whitespace-separated words.

    This is a cheap proxy, not a real tokenizer.
    NOTE(review): subword tokenizers typically emit more tokens than words,
    so treat the result as a likely underestimate.
    """
    return sum(1 for _word in text.split())

def _format_history_for_summary(history):
    """Render chat messages as one ``"<role>: <content>"`` line per message.

    Items without ``type``/``content`` attributes (e.g. plain strings) are
    rendered as ``"message: <item>"``.
    """
    lines = []
    for message in history:
        role = getattr(message, "type", "message")
        content = getattr(message, "content", message)
        lines.append(f"{role}: {content}")
    return "\n".join(lines)


def summarize_chat_history(history):
    """Ask the LLM for a condensed summary of a conversation.

    Args:
        history: Either a preformatted transcript string, which is used
            verbatim, or a sequence of chat messages. A sequence is rendered
            as readable ``role: content`` lines rather than the verbose
            ``repr`` of the message objects, which would waste prompt tokens.

    Returns:
        The model's summary text.
    """
    if not isinstance(history, str):
        history = _format_history_for_summary(history)
    summary_prompt = f"""Summarize the following conversation, keeping all important details and context:

{history}

Summary:"""
    # ``predict`` is deprecated on LangChain chat models. ``invoke`` accepts
    # the plain string as a single human message, and its result's
    # ``content`` holds the generated text.
    return llm.invoke(summary_prompt).content

def calculate_total_tokens():
    """Estimate how large the next prompt will be, in ``estimate_token_count`` units.

    The estimate adds up three parts:
    - the system prompt;
    - the JSON-serialized working context;
    - the content of every message currently held in conversation memory.
    """
    pieces = [system_message, json.dumps(working_context)]
    pieces.extend(stored.content for stored in memory.chat_memory.messages)
    return sum(estimate_token_count(piece) for piece in pieces)

def _strip_context_updates(full_response):
    """Remove every [CONTEXT_UPDATE]...[/CONTEXT_UPDATE] section from a model reply.

    An opening tag with no matching closing tag hides everything after it,
    so a truncated or malformed update never leaks raw JSON to the user.
    The remaining visible fragments are stripped and joined with one space.
    """
    visible_parts = re.split(
        r"\[CONTEXT_UPDATE\].*?(?:\[/CONTEXT_UPDATE\]|\Z)",
        full_response,
        flags=re.DOTALL,
    )
    return " ".join(part.strip() for part in visible_parts if part.strip())


def chat(message, history, max_tokens):
    """Handle one user turn submitted from the Gradio textbox.

    Args:
        message: Text the user just submitted.
        history: Gradio chatbot history, a list of ``(user, bot)`` tuples;
            ``None`` is treated as an empty history.
        max_tokens: Budget from the slider. When the estimated prompt size
            exceeds it, the stored conversation is replaced by an LLM
            summary before this turn is sent.

    Returns:
        ``("", history, token_count, context_text)``. This clears the
        textbox and refreshes the chatbot, the token counter and the
        working-context view.
    """
    history = history or []
    if not message or not message.strip():
        # Nothing to send: skip the LLM call and just refresh the side panels.
        return "", history, calculate_total_tokens(), format_working_context()

    # The working context travels with every turn, so the model still sees it
    # after the memory has been cleared or summarized.
    full_input = f"Human: {message}\n\nCurrent working context: {json.dumps(working_context)}"

    past_messages = memory.chat_memory.messages
    # Only summarize when there is history to compress; with an empty memory
    # the call would spend an LLM request summarizing nothing.
    if past_messages and calculate_total_tokens() > max_tokens:
        summary = summarize_chat_history(past_messages)
        memory.clear()
        memory.chat_memory.add_message(AIMessage(content=f"Previous conversation summary: {summary}"))

    full_response = chat_chain.predict(input=full_input)

    # Merge any facts the model reported into the shared working context.
    context_update = extract_context_update(full_response)
    if context_update:
        working_context.update(context_update)

    history.append((message, _strip_context_updates(full_response)))

    return "", history, calculate_total_tokens(), format_working_context()

def new_chat():
    """Start a fresh chat session while keeping the working context.

    Only the conversation memory is cleared. ``working_context`` is left
    intact, so chat() keeps attaching it to every subsequent user turn.

    Returns:
        Values for (message box, chatbot, max-tokens slider, token count,
        context view). The slider is reset to 1000.
    """
    # NOTE(review): the system prompt refers to "[NEW CHAT SESSION]"
    # separators, but no such marker is sent to the model here.
    memory.clear()
    return None, [], 1000, calculate_total_tokens(), format_working_context()

def clear_history():
    """Forget everything: reset the working context and clear conversation memory.

    Returns:
        Values for (message box, chatbot, max-tokens slider, token count,
        context view). The slider goes back to 1000, and the context view
        shows a confirmation message.
    """
    global working_context
    working_context = {}
    memory.clear()
    token_total = calculate_total_tokens()
    return (None, [], 1000, token_total, "Working context and chat history cleared.")

# Gradio UI. Components are created inside the Blocks context in the order
# they should appear, so statement order here defines the page layout.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    # Summarization threshold passed to chat(), which compares it against
    # calculate_total_tokens() before each turn.
    max_tokens = gr.Slider(minimum=500, maximum=100000, value=1000, step=100, label="Max Tokens Before Summarization in chat history")
    # Computed once when the UI is built; every handler below returns a fresh count.
    initial_count = calculate_total_tokens()
    token_count = gr.Number(label="Current Token Count", value=initial_count, interactive=False)

    with gr.Row():  # Lay the two buttons out side by side
        clear = gr.Button("New Chat")
        clear_history_btn = gr.Button("Clear History")

    context_view = gr.Textbox(label="Current Working Context", interactive=False, value=format_working_context())

    # Each handler returns one value per output component, in the order listed.
    msg.submit(chat, [msg, chatbot, max_tokens], [msg, chatbot, token_count, context_view])
    clear.click(new_chat, outputs=[msg, chatbot, max_tokens, token_count, context_view])
    clear_history_btn.click(clear_history, outputs=[msg, chatbot, max_tokens, token_count, context_view])

if __name__ == "__main__":
    # share=True publishes a temporary public *.gradio.live URL (visible in the
    # captured cell output). Anyone with that link can chat with the model
    # billed to the configured Fireworks API key.
    demo.launch(share=True, debug=True)

  warn_deprecated(


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://c852b9f7e369881ca9.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
