# RAG QA with memory using local LLMs

This notebook demonstrates Retrieval-Augmented Generation (RAG) with locally run Large Language Models (LLMs).
The process involves preparing the data, creating the database, preparing the questions, and running the queries.
After all cells have run successfully, open the **public URL** to access the Gradio interface.

***Key Points***
- Multiple file formats accepted in any combination: PDF, CSV, and plain text (`.txt`).
- Embedding database: ChromaDB
- Local LLM
- Gradio interface

For more information, please see the accompanying `README.md`.

## Setting up LangChain


In [None]:
# import os
# os.environ['HUGGINGFACE_HUB_CACHE'] = # path to local models

## Create Database

In [None]:
from create_db import create_db

# Input documents (any mix of PDF, CSV and .txt files) and the directory where
# the vector database is written.
data = "../example/llm_papers/"
db_dir = "../results/llm-db"
# Shared settings file; the following cells load it again via utils.load_config.
config_file = "config.yaml"

# Build the database from the documents in `data` into `db_dir`; the next cell
# reads it back from `db_dir` with `query.read_db`.
# NOTE(review): `db` itself is not used by the later cells shown here.
db = create_db(data, db_dir, config_file)

## Make a retriever

In [None]:
import utils
import query

# Load default parameters from the same config file used to build the database.
config = utils.load_config(config_file)

# Load the database stored in `db_dir`; the result is the retriever that the
# QA chain below searches.
retriever = query.read_db(db_dir, config)

## Load the LLM

In [None]:
# Load the local LLM and its tokenizer according to `config`.
model, tokenizer = query.load_llm(config)

## Question answering chain with memory

In [None]:
# Build the QA chain over `retriever`. Besides the chain, this returns the
# token streamer that `respond` iterates to stream generated text to the UI.
# NOTE(review): the meaning of the trailing positional `True` is defined in
# `query.qa_generator` (presumably enables memory/streaming) — confirm there.
qa_chain_with_memory, streamer = query.qa_generator(
    model,
    tokenizer,
    config,
    retriever,
    True)

# Conversation memory: `respond` passes `memory.buffer` to the chain as the
# chat history and records each question/answer with `memory.save_context`.
memory = query.create_memory(model, tokenizer, config)

In [None]:
def _get_chat_history(chat_history) -> str:
    return chat_history

## Inference

In [None]:
## Cite sources
import textwrap
def wrap_text_preserve_newlines(text, width=110):
    """Wrap every line of *text* to at most *width* columns.

    Calling ``textwrap.fill`` on the whole string would merge paragraphs.
    Instead, the text is split on newline characters, each segment is filled
    on its own, and the segments are joined back with newlines, so existing
    line breaks (including blank lines) survive.
    """
    return '\n'.join(
        textwrap.fill(segment, width=width) for segment in text.split('\n')
    )

def process_llm_response(llm_response):
    """Print a chain response's answer, wrapped, followed by its sources.

    Args:
        llm_response: dict returned by a QA chain. The answer is read from
            ``"result"`` when present; otherwise from ``"answer"``, the key
            used by the conversational chain in this notebook (its output is
            read as ``llm_response['answer']`` when saving memory).
            ``"source_documents"`` must hold objects whose ``metadata`` has a
            ``"source"`` entry.

    Raises:
        KeyError: if the response has neither ``"result"`` nor ``"answer"``,
            or lacks ``"source_documents"``.
    """
    # Support both output-key conventions instead of failing on the
    # conversational chain's output, which has no "result" key.
    if "result" in llm_response:
        answer = llm_response["result"]
    else:
        answer = llm_response["answer"]

    print(wrap_text_preserve_newlines(answer))
    print('\n\nSources:')
    for source in llm_response["source_documents"]:
        print(source.metadata['source'])

In [None]:
from threading import Thread
 
class CustomThread(Thread):
    """A ``Thread`` whose ``join()`` returns the target's return value.

    The notebook uses it to run the QA chain in the background while the
    calling thread consumes the token streamer, then to collect the chain's
    complete response.

    If the target raises, the exception is stored and re-raised by
    ``join()`` in the joining thread. A plain ``Thread`` only reports it
    through ``threading.excepthook`` and would leave the caller with None.
    """

    def __init__(self, group=None, target=None, name=None,
                 args=(), kwargs=None, Verbose=None):
        # kwargs=None avoids a shared mutable default dict; Thread turns None
        # into a fresh {}. `Verbose` is accepted only for backward
        # compatibility and is ignored.
        super().__init__(group, target, name, args, kwargs)
        self._return = None
        self._exception = None

    def run(self):
        """Call the target and remember its return value or exception."""
        try:
            if self._target is not None:
                self._return = self._target(*self._args, **self._kwargs)
        except Exception as exc:
            self._exception = exc
        finally:
            # Mirror Thread.run: drop references to the target and its
            # arguments so they can be garbage-collected once the thread ends.
            del self._target, self._args, self._kwargs

    def join(self, *args):
        """Wait for the thread, then return the target's result.

        Accepts the same optional ``timeout`` as ``Thread.join``. If the
        timeout expires first, the result is still None.

        Raises:
            The exception raised by the target, if any.
        """
        super().join(*args)
        if self._exception is not None:
            raise self._exception
        return self._return

In [None]:
from transformers import TextIteratorStreamer
from threading import Thread
from transformers import StoppingCriteria, StoppingCriteriaList
import gradio as gr 
import torch

# Use the GPU when CUDA is available, otherwise the CPU.
# NOTE(review): `torch_device` is not referenced in the visible cells.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

def format_chat_prompt(message, chat_history, instruction):
    """Build a plain-text prompt from a system instruction and chat turns.

    Layout: ``System:<instruction>``, then one ``User:``/``Assistant:`` pair
    per ``(user, bot)`` turn in *chat_history*, and finally the new
    *message* followed by an open ``Assistant:`` line for the model to
    complete.

    NOTE(review): not called from the visible cells.
    """
    pieces = [f"System:{instruction}"]
    for user_message, bot_message in chat_history:
        pieces.append(f"\nUser: {user_message}\nAssistant: {bot_message}")
    pieces.append(f"\nUser: {message}\nAssistant:")
    return "".join(pieces)

# Reuse the EOS token for padding (presumably because the tokenizer defines no
# pad token). This mutates the same tokenizer object that was passed to
# `query.qa_generator` and `query.create_memory` above.
tokenizer.pad_token = tokenizer.eos_token

def respond(message, chat_history, instruction, temperature=0.7):
    """Answer *message* with the RAG chain, streaming the reply into the chat.

    Gradio event handler (generator). Each ``yield`` emits
    ``("", chat_history)``: the empty string clears the prompt textbox and
    the history re-renders the chatbot with the partially built reply.

    The chain runs on a worker thread while this thread reads the generated
    text from the module-level ``streamer``. After generation, the retrieved
    source documents are appended to the reply and the exchange is saved to
    ``memory``.

    Args:
        message: the user's question.
        chat_history: current chatbot history, a list of ``[user, bot]``
            turns. The caller's list is not modified.
        instruction: value of the "System message" box. NOTE(review): not
            used here; the chain's prompt is configured in
            ``query.qa_generator``.
        temperature: NOTE(review): not used, and the UI slider is not wired
            to it.

    Yields:
        ``("", updated_chat_history)`` tuples.
    """

    def _append_to_reply(history, text):
        # Return a copy of *history* with *text* appended to the last reply.
        last_turn = list(history[-1])
        last_turn[-1] += text
        return history[:-1] + [last_turn]

    chat_history = chat_history + [[message, ""]]

    # Run the chain on a separate thread so the UI is not blocked; generated
    # text is pulled from the streamer here in the calling thread.
    worker = CustomThread(
        target=qa_chain_with_memory,
        kwargs={"inputs": {"question": message, "chat_history": memory.buffer}},
    )
    worker.start()

    for idx, token in enumerate(streamer):
        # Drop the leading space the first streamed token may carry.
        if idx == 0 and token.startswith(" "):
            token = token[1:]
        chat_history = _append_to_reply(chat_history, token)
        yield "", chat_history

    llm_response = worker.join()
    if llm_response is None:
        # No response came back (e.g. the chain raised on the worker thread).
        # Report it instead of crashing on None below, and do not record the
        # failed exchange in memory.
        chat_history = _append_to_reply(
            chat_history,
            "\n\n[Error: no response was produced by the QA chain.]",
        )
        yield "", chat_history
        return

    chat_history = _append_to_reply(chat_history, " \n\n Source: \n\n ")
    yield "", chat_history
    for source in llm_response["source_documents"]:
        metadata = source.metadata
        # Not every document has a "page" entry (presumably only PDF pages do;
        # CSV and .txt inputs may not), so don't assume the key exists.
        page = metadata.get("page")
        page_part = f"Page number:{page} " if page is not None else ""
        chat_history = _append_to_reply(
            chat_history, f"{page_part}Document:{metadata['source']} \n\n"
        )
        yield "", chat_history

    memory.save_context({"input": message}, {"output": llm_response['answer']})
    
# Close any Gradio apps still running from earlier executions of this cell.
gr.close_all()

# Chat UI. Components are laid out in the order they are created below.
with gr.Blocks(theme=gr.themes.Soft(), ) as demo:
    chatbot = gr.Chatbot(height=540)
    msg = gr.Textbox(label="Prompt")
    with gr.Accordion(label="Advanced options",open=False):
        system = gr.Textbox(label="System message", lines=2, value="A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.")
        # NOTE(review): the two sliders below are not listed in the `inputs`
        # of the event handlers, so changing them has no effect on `respond`.
        temperature = gr.Slider(label="temperature", minimum=0.1, maximum=1, value=0.7, step=0.1)
        max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=1000, value=250, step=1, interactive=True)
    btn = gr.Button("Submit")
    clear = gr.ClearButton(components=[msg, chatbot], value="Clear console")

    # Clicking Submit or pressing Enter in the prompt box runs `respond`, whose
    # yielded (prompt text, history) pairs are streamed into (msg, chatbot).
    btn.click(respond, inputs=[msg, chatbot, system], outputs=[msg, chatbot])
    msg.submit(respond, inputs=[msg, chatbot, system], outputs=[msg, chatbot]) #Press enter to submit
# queue() enables streaming from the generator handler (required in older
# Gradio versions); share=True also publishes the public share URL.
demo.queue().launch(share=True)