# Shared Code
This file contains shared classes used across the various `retrieval_augmented_chat` notebooks.

In [1]:
import chromadb
from openai import OpenAI
from IPython.display import Markdown, display
from timeit import default_timer as timer

## Interaction with a local model

In [2]:
class ChatModel:
    """Chat with a locally loaded model through its tokenizer's chat template.

    The prompt is encoded with ``tokenizer.apply_chat_template``, completed with
    ``model.generate`` (sampling enabled) and decoded with
    ``tokenizer.batch_decode``. The decoded text contains the prompt followed by
    the reply; ``inst_separator`` marks where the reply begins.

    Args:
        model: Object exposing ``generate(...)`` and, optionally, ``device``.
        tokenizer: Object exposing ``apply_chat_template``, ``batch_decode``
            and ``eos_token_id``.
        inst_separator: Text preceding the model's reply in the decoded output;
            the reply is everything after its last occurrence.
        temperature: Sampling temperature passed to ``generate``.
        device: Where to move the encoded prompt. When ``None``, the model's
            own ``device`` attribute is used if it has one; otherwise the
            prompt is left where the tokenizer produced it.
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Repetition penalty passed to ``generate``.
    """

    def __init__(self, model, tokenizer, inst_separator=" [/INST] ", temperature=0.4,
                 device=None, max_new_tokens=1000, repetition_penalty=1.20):
        self.model = model
        self.tokenizer = tokenizer
        self.inst_separator = inst_separator
        self.temperature = temperature
        self.device = device
        self.max_new_tokens = max_new_tokens
        self.repetition_penalty = repetition_penalty

    def __send_to_model(self, msg):
        """Send ``msg`` as a single user turn; return the decoded batch (list of str)."""
        messages = [
            {"role": "user", "content": msg},
        ]

        encoded = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
        # Use an explicit device, or follow the model, rather than depending on
        # a ``device`` global that must exist in the calling notebook.
        target_device = self.device if self.device is not None else getattr(self.model, "device", None)
        if target_device is not None:
            encoded = encoded.to(target_device)

        generated_ids = self.model.generate(
            encoded,
            max_new_tokens=self.max_new_tokens,
            do_sample=True,
            # EOS doubles as the pad token (presumably the tokenizer defines none).
            pad_token_id=self.tokenizer.eos_token_id,
            # Honour the configured temperature instead of a hard-coded value.
            temperature=self.temperature,
            repetition_penalty=self.repetition_penalty,
        )
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    def send_chat(self, msg):
        """Return only the model's reply to ``msg``.

        The reply is the text after the last ``inst_separator`` in the decoded
        output. If the separator is absent, the whole decoded text is returned
        instead of raising ``IndexError``.
        """
        result = self.__send_to_model(msg)[0]
        return result.rsplit(self.inst_separator, 1)[-1]

    def basic_chat(self, msg):
        """Print the model's reply to ``msg``."""
        print(self.send_chat(msg))
    

## Interaction with OpenAI API
Here we subclass `ChatModel` and override `send_chat` so that requests go to the OpenAI Chat Completions API instead of a local model.

In [3]:
class OpenAiChatModel(ChatModel):
    """``ChatModel`` variant that sends prompts to the OpenAI Chat Completions API.

    Args:
        organization: OpenAI organization ID handed to the ``OpenAI`` client.
        api_key: OpenAI API key handed to the ``OpenAI`` client.
        openai_model: Name of the chat model to request.
        temperature: Sampling temperature sent with every request.
    """

    def __init__(self, organization, api_key, openai_model, temperature = 0.4):
        # There is no local model or tokenizer; the base class only receives
        # placeholders because ``send_chat`` is overridden below.
        super().__init__(None, None, temperature=temperature)
        # Build the client from the constructor arguments (not notebook globals),
        # so the values the caller passes are the ones actually used.
        self.openai_client = OpenAI(
            organization=organization,
            api_key=api_key,
        )
        self.openai_model = openai_model
        self.temperature = temperature

    def __send_to_model(self, msg):
        """Send ``msg`` as a single user message; return the first choice's content."""
        messages = [
            {"role": "user", "content": msg},
        ]

        completion = self.openai_client.chat.completions.create(
            model=self.openai_model,
            messages=messages,
            temperature=self.temperature,
        )

        return completion.choices[0].message.content

    def send_chat(self, msg):
        """Return the API's reply to ``msg``.

        Overridden because the base class's name-mangled ``__send_to_model``
        cannot be overridden directly.
        """
        return self.__send_to_model(msg)
    

## Retrieval convenience methods
These methods query the vector database for the top N documents most relevant to the question, add them to the request context, and then annotate the model's answer with links to the documents used in the context (and their distances).

In [4]:
class RetrievalAugmentedChat:
    """Answer questions with a chat model, using a Chroma collection for context.

    Args:
        path: Directory of the persistent Chroma database.
        collection_name: Name of an existing collection in that database.
        top_n_results: Number of documents to retrieve for each question.
        chat_model: Object exposing ``send_chat(msg) -> str`` (such as a
            ``ChatModel``).
    """

    def __init__(self, path, collection_name, top_n_results, chat_model):
        self.client = chromadb.PersistentClient(path=path)
        self.collection = self.client.get_collection(name = collection_name)
        self.top_n_results = top_n_results
        self.chat_model = chat_model

    def printmd(self, string):
        """Render ``string`` as Markdown in the notebook output."""
        display(Markdown(string))

    @staticmethod
    def _build_question(documents, msg):
        """Prefix ``msg`` with the retrieved ``documents``; return ``msg`` alone if there are none."""
        if not documents:
            return msg
        return (
            "Based on the following documents:\n"
            + "\n\n".join(documents)
            + "\n Answer the following question with lots of details: "
            + msg
        )

    @staticmethod
    def _format_doc_links(metadatas, distances):
        """Build a Markdown list linking each document's ``source`` with its distance."""
        if not metadatas:
            return ""
        links = [
            f"* [{metadata['source']}]({metadata['source']}) distance: {distance:3.2f}\n"
            for metadata, distance in zip(metadatas, distances)
        ]
        return "\n\n **Reference documents:** \n\n" + "".join(links)

    def chat(self, msg):
        """Answer ``msg`` using retrieved context and return a Markdown string.

        The returned string is the model's answer, then a list of reference
        documents (source link and distance), then the inference time in
        seconds.
        """
        query_result = self.collection.query(
            query_texts=[msg],
            # Use the per-instance setting rather than a notebook-level global.
            n_results=self.top_n_results,
        )
        # ``query`` returns one result list per query text; only one was sent.
        question_with_context = self._build_question(query_result['documents'][0], msg)

        start = timer()
        model_response = self.chat_model.send_chat(question_with_context)
        end = timer()
        model_response_time = f"\n**Inference time in seconds {end - start:3.4f}**\n"

        doc_links = self._format_doc_links(
            query_result['metadatas'][0], query_result['distances'][0]
        )
        return model_response + doc_links + model_response_time

    def markdown_chat(self, msg):
        """Answer ``msg`` and render the result as Markdown."""
        self.printmd(self.chat(msg))