# Demo of a Tesla customer support assistant chatbot

In [1]:
import os
from typing import Tuple, List, Mapping, Text, Any
import torch

# simple hack to support import module from parent directory
import sys
sys.path.append('../')

from rag_llama.core.retrievers import StandardRetriever, HybridRetriever, RerankRetriever
from rag_llama.core.generation import Llama, Dialog


## Define system message and input query format templates. 
Note that these are for single-turn chat, similar to how search works.

In [2]:
# Base system prompt. get_formatted_input_dialog uses it on its own when no usable
# reference documents are supplied.
SYSTEM_MESSAGE = """
You are an assistant to a Tesla customer support team. Your job is to answer customer's questions to the best of your ability.
"""

# RAG system prompt: the base prompt followed by instructions on how to use the
# reference documents. get_formatted_input_dialog uses it when documents are supplied.
SYSTEM_MESSAGE_RAG = SYSTEM_MESSAGE + """
 Your role involves leveraging a set of reference documents to ensure accurate responses. 
 While some documents may not directly apply to every question, focus solely on those that seem pertinent. 
 Avoid referencing or citing documents not provided. 
 Craft concise answers, incorporating relevant sections from the provided documents to assist customers effectively.
"""


def get_formatted_input_dialog_for_hyde(query: str) -> Dialog:
    """Build the single-turn dialog that asks the model to write a hypothetical passage (HyDE).

    Args:
        query: the user's question, embedded verbatim into the instruction.

    Returns:
        A two-message dialog: an empty system message followed by a user message
        holding the instruction, the question, and a trailing ``Passage: `` cue.
    """
    hyde_prompt = f'Please write a short passage to answer the question to the best of your ability. The passage should not exceed 200 words.\nQuestion: {query}\n\nPassage: '
    return [
        {'role': 'system', 'content': ""},
        {'role': 'user', 'content': hyde_prompt},
    ]


def get_formatted_input_dialog(query: str, doc_strs: str=None) -> Dialog:
    """Build the single-turn dialog sent to the chat model, with or without reference documents.

    Args:
        query: the user's question.
        doc_strs: reference documents already joined into one string, or None.

    Returns:
        A two-message dialog (system, user). When ``doc_strs`` is None or has 10 or
        fewer characters, the plain system prompt is used and the user message is just
        the query; otherwise the RAG system prompt is used and the user message
        contains both the question and the documents.
    """
    has_documents = doc_strs is not None and len(doc_strs) > 10

    # No usable documents: fall back to a plain question/answer prompt.
    if not has_documents:
        return [
            {'role': 'system', 'content': SYSTEM_MESSAGE},
            {'role': 'user', 'content': query},
        ]

    user_content = f"Question:\n{query}\n\n----\n\nDocuments:\n\n{doc_strs}"
    return [
        {'role': 'system', 'content': SYSTEM_MESSAGE_RAG},
        {'role': 'user', 'content': user_content},
    ]



Create the standard, hybrid, and reranking retriever instances and the LLaMA 2 chat generator instance.

In [3]:
# Paths: the embedding file is relative to the notebook directory; the model and
# tokenizer checkpoints live under the user's home directory ("~" is expanded here).
doc_embed_file = "../data/Tesla_manual_embeddings.pk"
llama_model_ckpt = os.path.expanduser("~/models/meta_llama2/llama-2-7b-chat/consolidated.pth")
llama_tokenizer_ckpt = os.path.expanduser("~/models/meta_llama2/tokenizer.model")
# Use the GPU when one is available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# LLM parameters
max_seq_len = 4096  # passed to Llama.build as max_seq_len
max_gen_len = 512  # passed to chat_completion as max_gen_len
temperature = 0.6  # model sampling temperature
top_p = 0.7  # model sampling top P

# RAG specific parameters
top_k = 4  # number of items to retrieve per query

In [4]:
# Build the three retrievers that are compared below. All of them read the same
# embedding file and run on the same device.
standard_retriever = StandardRetriever(embed_file=doc_embed_file, device=device) 
hybrid_retriever = HybridRetriever(embed_file=doc_embed_file, device=device) 
rerank_retriever = RerankRetriever(embed_file=doc_embed_file, device=device)

# Build the chat model. max_batch_size=1 is sufficient because every
# chat_completion call in this notebook submits a single dialog.
generator = Llama.build(
    ckpt_path=llama_model_ckpt,
    tokenizer_path=llama_tokenizer_ckpt,
    max_seq_len=max_seq_len,
    max_batch_size=1,
    device=device,
)

Loading sentence-transformers/all-MiniLM-L6-v2 model and tokenizer from HuggingFace...
Loading sentence-transformers/all-MiniLM-L6-v2 model and tokenizer from HuggingFace...
Loading sentence-transformers/all-MiniLM-L6-v2 model and tokenizer from HuggingFace...
Loading cross-encoder/ms-marco-MiniLM-L-6-v2 model and tokenizer from HuggingFace...
Starting to load tokenizer checkpoint '/home/michael/models/meta_llama2/tokenizer.model' ...
Starting to load model checkpoints '/home/michael/models/meta_llama2/llama-2-7b-chat/consolidated.pth' ...
Model checkpoint loaded in 8.26 seconds


## Main logic for retrieval and LLM generation

In [5]:
def get_formatted_text_from_documents(documents: List[Mapping[Text, Any]]) -> str:
    """Concatenate the ``formatted_text`` field of each document into one string.

    Args:
        documents: retrieved items; each must provide a ``'formatted_text'`` key.

    Returns:
        The texts in input order, separated by a blank line. An empty input
        yields an empty string.
    """
    texts = []
    for doc in documents:
        texts.append(doc['formatted_text'])
    return "\n\n".join(texts)

def generate_hyde_response(query: str) -> str:
    """Ask the chat model to write a short hypothetical passage answering ``query`` (HyDE).

    Args:
        query: the user's question.

    Returns:
        The text content of the model's generated passage.
    """
    hyde_dialog = get_formatted_input_dialog_for_hyde(query)

    # chat_completion expects a batch of dialogs, so the single dialog is wrapped in a list.
    completions = generator.chat_completion(
        [hyde_dialog],
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )

    return completions[0]['generation']['content']


def run_chat_completions(query: str, use_hyde: bool=False) -> Tuple[str, str, str]:
    """Answer the same query three times, once per retrieval component.

    Args:
        query: the user's question.
        use_hyde: when True, a model-written hypothetical passage (HyDE) is used
            as the retrieval input instead of the raw query. The final prompt
            still asks the user's original question.

    Returns:
        A 3-tuple ``(standard, hybrid, rerank)`` of response strings. Each is the
        model's answer followed by the reference documents that were retrieved for it.
    """
    # HyDE only changes what we search with; the model must still answer the
    # user's actual question, not the hypothetical passage.
    retrieval_input = generate_hyde_response(query) if use_hyde else query

    # One callable per retrieval component, in output order. The rerank retriever
    # takes an extra argument (50) ahead of top_k -- presumably the size of the
    # candidate pool that gets reranked; confirm against RerankRetriever.retrieve.
    retrieval_calls = (
        lambda text: standard_retriever.retrieve(text, top_k),
        lambda text: hybrid_retriever.retrieve(text, top_k),
        lambda text: rerank_retriever.retrieve(text, 50, top_k),
    )

    # Run all retrievals first, then all generations.
    doc_strs_per_retriever = [
        get_formatted_text_from_documents(retrieve(retrieval_input))
        for retrieve in retrieval_calls
    ]

    responses = []
    for doc_strs in doc_strs_per_retriever:
        dialog = get_formatted_input_dialog(query, doc_strs)
        # Each dialog is submitted as its own batch of one.
        result = generator.chat_completion(
            [dialog],  # input needs to be a batch of dialogs
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
        )
        answer = result[0]['generation']['content']
        responses.append(f"{answer}\n\nReference documents:\n{doc_strs}")

    response_standard_rag, response_hybrid_rag, response_rerank_rag = responses
    return response_standard_rag, response_hybrid_rag, response_rerank_rag


A simple hack to display the chat completions from the three retrieval strategies (standard, hybrid, and reranking) side-by-side for easier comparison

In [6]:
from IPython.display import HTML, display

def display_completions_in_columns(query: str, response_standard_rag: str, response_hybrid_rag: str, response_rerank_rag: str) -> None:
    """
    Display the query and the three responses in side-by-side columns.

    All text is HTML-escaped before being embedded, so characters such as
    ``<``, ``>`` and ``&`` in the query, the model output, or the reference
    documents are shown literally instead of being interpreted as markup.

    Parameters:
    query (str): user query
    response_standard_rag (str): response for left column.
    response_hybrid_rag (str): response for middle column.
    response_rerank_rag (str): response for right column.
    """
    # Local import so this cell does not depend on an edit to the notebook's import cell.
    from html import escape

    def escape_text(text) -> str:
        # A missing value renders as an empty cell rather than the literal "None".
        if text is None:
            return ''
        return escape(text)

    def to_html(text) -> str:
        # Escape first, then turn newlines into <br> so the line breaks survive escaping.
        return escape_text(text).replace('\n', '<br>')

    safe_query = escape_text(query)
    response_standard_rag = to_html(response_standard_rag)
    response_hybrid_rag = to_html(response_hybrid_rag)
    response_rerank_rag = to_html(response_rerank_rag)

    html_content = f'''
    <div>
        <div style="padding: 20px 0; font-weight: bold;">User: {safe_query}</div>
        <div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 20px;">
            <div>
                <div style="font-weight: bold;">Assistant standard RAG:</div>
                {response_standard_rag}
            </div>
            <div>
                <div style="font-weight: bold;">Assistant hybrid RAG:</div>
                {response_hybrid_rag}
            </div>
            <div>
                <div style="font-weight: bold;">Assistant reranking RAG:</div>
                {response_rerank_rag}
            </div>
        </div>
    </div>
    '''
    display(HTML(html_content))

In [7]:
def ask_question(query: str, use_hyde: bool=False):
    """Run ``query`` through all three retrieval pipelines and render the answers side by side.

    Args:
        query: the user's question.
        use_hyde: forwarded to ``run_chat_completions``.
    """
    standard_answer, hybrid_answer, rerank_answer = run_chat_completions(query, use_hyde)
    display_completions_in_columns(
        query,
        standard_answer,
        hybrid_answer,
        rerank_answer,
    )

# Now we can start asking questions about Tesla cars

We will go over the same query multiple times, with different retrieval components. The responses will be displayed in different columns for better comparisons.

In [8]:
# Model- and year-specific how-to question; use_hyde is left at its default (False).
question = "How to enable Autopilot on Tesla Model S 2018 model?"
ask_question(question)

In [9]:
# Question about when full self-driving should not be used; use_hyde defaults to False.
question = "Under what circumstances that I should not use full self-driving on my Tesla car?"
ask_question(question)

In [10]:
# Procedure question (opening the door when power is very low); use_hyde defaults to False.
question = "How to open the door of a Tesla Model S car when the power is very low?"
ask_question(question)

In [11]:
# Question about using Autopilot in rain or snow; use_hyde defaults to False.
question = "Can I use autopilot in raining or snowing conditions?"
ask_question(question)

In [12]:
# Troubleshooting question (unresponsive touchscreen); use_hyde defaults to False.
question = "What should I do if the touchscreen of my Tesla car is not responding?"
ask_question(question)

## Questions with custom vocabularies

We then ask some hard questions about alert codes. This example demonstrates the importance of hybrid retrieval: the alert code is not a standard English word, so the embedding model does not treat it as a single entity. In this case, both standard retrieval and retrieval with reranking select incorrect documents, with reranking performing worse. Only hybrid retrieval answers the question correctly.

**APP_w222**
Cruise control unavailable
Reduced front camera visibility

In [13]:
# The alert code is typed in lowercase here ("app_w222"), whereas the heading
# quoted in the markdown above writes it as "APP_w222".
question = "I see code app_w222 on my dashboard what should I do?"
ask_question(question)

**BMS_a069**
Battery charge level low
Charge now

In [17]:
# The alert code is typed in lowercase here ("bms_a069"), whereas the heading
# quoted in the markdown above writes it as "BMS_a069".
question = "what does bms_a069 mean?"
ask_question(question)

Let's try retrieval with HyDE, which does not help in this case.

In [15]:
# The second positional argument is use_hyde=True: run_chat_completions first
# generates a hypothetical passage with generate_hyde_response and retrieves with it.
question = "I see code app_w222 on my dashboard what should I do?"
ask_question(question, True)

In [16]:
# The second positional argument is use_hyde=True: run_chat_completions first
# generates a hypothetical passage with generate_hyde_response and retrieves with it.
question = "what does bms_a069 mean?"
ask_question(question, True)