# Dolly Server set-up

### Installing required dependencies

In [None]:
%pip install -U chromadb==0.3.22 langchain==0.0.164 transformers==4.29.0 accelerate==0.19.0 bitsandbytes pypdf uvicorn starlette ffmpeg-python ffmpy ffprobe python-magic

In [None]:
!pip install git+https://github.com/openai/whisper.git -q

In [None]:
!apt-get install ffmpeg -y

In [None]:
dbutils.library.restartPython() 

## Writing Dolly Server code

This section creates the `server.py` file, which is then run on a Databricks cluster hosted on AWS. We strongly recommend using an AWS g5.4xlarge instance on Databricks so that the Dolly model has enough GPU memory to run effectively.



In [None]:
%%writefile server.py
import whisper
import re, time
from io import BytesIO
from typing import Any, Dict, List
from pypdf import PdfReader
import magic
import shutil

from langchain.docstore.document import Document
from langchain.document_loaders import PyPDFLoader

from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter


from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch, asyncio, os
from langchain import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain.chains.question_answering import load_qa_chain

from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route

# Directory polled by check_file_changes() for newly uploaded files.
PDF_DIRECTORY = "/dbfs/data"

# Directory passed to Chroma as its persist_directory; its contents are
# cleared by check_file_changes() before new embeddings are created.
vector_db_path = "./vector_db"

# Whisper model used by multimedia_to_text(); loaded once when the module is imported.
model_whisper = whisper.load_model("medium")

# Sentence-transformers embedding model handed to Chroma in create_embed().
hf_embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

def clear_files(directory):
    """
    Clears out all files and subdirectories in the specified directory.

    The directory itself is kept; only its direct children are removed
    (sub-directories are removed recursively). Symbolic links are unlinked
    rather than followed, so the link target is never touched.

    Args:
        directory (str): Path of the directory whose contents are removed.
    Returns:
        None. Progress and errors are reported with print(); OSError is
        caught and reported rather than propagated.
    """
    if not os.path.exists(directory):
        print("Directory does not exist.")
        return

    try:
        # Snapshot the entries first so that nothing is deleted while the
        # directory is still being iterated.
        with os.scandir(directory) as scan:
            entries = list(scan)

        for entry in entries:
            # shutil.rmtree refuses to operate on a symlink, so only real
            # directories go through it; everything else is unlinked.
            if entry.is_dir(follow_symlinks=False):
                shutil.rmtree(entry.path)
            else:
                os.remove(entry.path)
        print("Previous embeddings deleted")

    except OSError as e:
        print(f"Error while clearing files: {e}")


def parse_pdf(file: BytesIO) -> List[str]:
    """Extract and tidy the text of every page in a PDF.

    Args:
        file (BytesIO): The PDF to read. It is handed directly to
            ``PdfReader``; the caller in this module passes a file path.
    Returns:
        List[str]: One cleaned-up text string per page, in page order.
    """

    def _tidy(raw_text):
        # Re-join words that were hyphenated across a line break.
        dehyphenated = re.sub(r"(\w+)-\n(\w+)", r"\1\2", raw_text)
        # Turn line breaks that fall mid-sentence into plain spaces.
        unwrapped = re.sub(r"(?<!\n\s)\n(?!\s\n)", " ", dehyphenated.strip())
        # Collapse runs of blank lines into a single paragraph break.
        return re.sub(r"\n\s*\n", "\n\n", unwrapped)

    reader = PdfReader(file)
    return [_tidy(page.extract_text()) for page in reader.pages]

def multimedia_to_text(path):
    """Transcribe an audio or video file with the module-level Whisper model.

    Args:
        path (str): The path to the file. Can be MP3, WAV, MP4.
    Returns:
        str: The value stored under the ``'text'`` key of the transcription result.
    """
    # Only the flat transcript is needed; the rest of the result is discarded.
    return model_whisper.transcribe(path)['text']

def text_to_docs(text: str) -> List[Document]:
    """Split page text into chunk-sized Documents carrying source metadata.

    Accepts either a single string (treated as one page) or a list of
    strings (one per page). Every returned Document has ``page`` (1-based),
    ``chunk`` (0-based index within its page) and ``source``
    (``"<page>-<chunk>"``) in its metadata.
    """
    # A bare string is treated as a one-page document.
    pages = [text] if isinstance(text, str) else text

    # Wrap every page in a Document numbered from 1.
    numbered_pages = [
        Document(page_content=content, metadata={"page": number})
        for number, content in enumerate(pages, start=1)
    ]

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=2000,
        separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
        chunk_overlap=0,
    )

    # Break each page into chunks and tag each chunk with where it came from.
    chunked_docs = []
    for page_doc in numbered_pages:
        page_no = page_doc.metadata["page"]
        pieces = splitter.split_text(page_doc.page_content)
        for piece_index, piece in enumerate(pieces):
            chunked_docs.append(
                Document(
                    page_content=piece,
                    metadata={
                        "page": page_no,
                        "chunk": piece_index,
                        "source": f"{page_no}-{piece_index}",
                    },
                )
            )
    return chunked_docs


def create_embed(pages) :
    """Embed the given documents into a new Chroma store and persist it.

    The resulting store is published through the module-level ``db`` global,
    which get_similar_docs() reads.

    Args:
        pages (List[Document]): Chunked documents, as produced by text_to_docs().
    Returns:
        None
    """
    global db
    # Chroma.from_documents takes the embedder as ``embedding``. The original
    # ``embedding_function=`` keyword is not a named parameter there and is
    # believed to be silently dropped (verify against the LangChain source),
    # which would leave hf_embed unused and fall back to Chroma's default.
    db = Chroma.from_documents(documents=pages, embedding=hf_embed, persist_directory=vector_db_path)
    db.similarity_search("dummy") # tickle it to persist metadata (?)
    db.persist()

def get_similar_docs(question, similar_doc_count):
    """Return the stored documents most similar to ``question``.

    Args:
        question (str): The query text.
        similar_doc_count (int): Maximum number of documents to return.
    Returns:
        list: The documents returned by the vector store's similarity search,
        or an empty list when no vector store has been created yet.
    """
    # ``db`` only exists once create_embed() has run; a query that arrives
    # earlier must not blow up with a NameError.
    store = globals().get("db")
    if store is None:
        return []
    return store.similarity_search(question, k=similar_doc_count)


def build_qa_chain():
    """Build the question-answering chain backed by a Dolly text-generation pipeline.

    Returns:
        The chain produced by ``load_qa_chain`` with ``chain_type="stuff"``,
        using the prompt template defined below.
    """
    # Other Dolly v2 checkpoints (e.g. dolly-v2-3b) can be substituted here for faster inference.
    dolly_checkpoint = "databricks/dolly-v2-7b"

    # Notes carried over from the original implementation:
    #  - On a GPU with less than 24GB RAM, load in 8 bit instead by passing
    #    model_kwargs={'load_in_8bit': True} (requires %pip install bitsandbytes).
    #  - On GPUs without bfloat16 support, like the T4 or V100, use torch.float16.
    generation_settings = {
        "model": dolly_checkpoint,
        "torch_dtype": torch.bfloat16,
        "trust_remote_code": True,
        "device_map": "auto",
        "return_full_text": True,
        "max_new_tokens": 256,
        "top_p": 0.95,
        "top_k": 50,
    }
    text_generator = pipeline(**generation_settings)

    qa_template = """
            I will ask you questions based on the following context:
            — Start of Context —
            {context}
            — End of Context—
            Use the information in the above paragraphs only to answer the question at the end. If the answer is not given in the context, say that "I do not know".

            Question: {question}

            Response:
            """

    qa_prompt = PromptTemplate(
        input_variables=['context', 'question'],
        template=qa_template,
    )
    llm = HuggingFacePipeline(pipeline=text_generator)

    # verbose=True makes the chain print the fully rendered prompt.
    return load_qa_chain(llm=llm, chain_type="stuff", prompt=qa_prompt, verbose=True)



async def homepage(request):
    """Answer a question posted in the request body.

    The UTF-8 decoded body is placed on the application's model queue together
    with a fresh per-request reply queue; the handler then waits for whatever
    is put on that reply queue and returns it as JSON.

    Args:
        request: The incoming request object.
    Returns:
        JSONResponse: The JSON response containing the output.
    """
    raw_body = await request.body()
    question = raw_body.decode("utf-8")

    # Each request gets its own queue so the reply is delivered to this caller only.
    reply_queue = asyncio.Queue()
    await request.app.model_queue.put((question, reply_queue))

    answer = await reply_queue.get()
    return JSONResponse(answer)


def _load_document(file_path):
    """Extract text from ``file_path`` according to its detected MIME type.

    Returns the parsed PDF pages (list of str), the transcript of an
    audio/video file (str), or None when the type is unsupported.
    """
    file_type = magic.from_file(file_path, mime=True)

    if file_type == 'application/pdf':
        return parse_pdf(file_path)
    if file_type in ('audio/mp3', 'audio/wav', 'audio/mpeg', 'video/mp4'):
        return multimedia_to_text(file_path)

    print("Unsupported file type:", file_type)
    return None


async def check_file_changes():
    """Poll PDF_DIRECTORY once a second and re-embed when a new file appears.

    create_embed() replaces the single module-level vector store, so only one
    document is active at a time. When several new files show up in the same
    poll, the most recently modified one is used (previously the choice was
    arbitrary). Errors while processing a file are reported and the watcher
    keeps running instead of dying silently.

    Returns:
        None. Runs until the task is cancelled.
    """
    directory = PDF_DIRECTORY  # Specify the directory to monitor for file changes
    seen_files = set()

    while True:
        # List the directory exactly once per iteration: a second listing
        # could mark a file as seen that was never compared as new.
        try:
            current_files = set(os.listdir(directory))
        except OSError as e:
            print(f"Error while listing {directory}: {e}")
            current_files = seen_files

        new_files = current_files - seen_files
        if new_files:
            print("New file(s) uploaded:", new_files)
            try:
                file_path = max(
                    (os.path.join(directory, name) for name in new_files),
                    key=os.path.getmtime,
                )
                print(file_path)
                doc = _load_document(file_path)

                if doc is not None:
                    # Only wipe the persisted store once there is something
                    # to replace it with.
                    clear_files(vector_db_path)
                    pages = text_to_docs(doc)
                    create_embed(pages)
                    print("************************----------Created Embeddings----------*************************")
                else:
                    print(" ^^^^^^^^^^^^^^^^^^^^^^^^^^ No doc created ^^^^^^^^^^^^^^^^^^^^^^^^^^")
            except Exception as e:
                # Top-level boundary of a long-running task: report and keep
                # watching rather than letting one bad file stop the watcher.
                print(f"Error while processing new file(s) {new_files}: {e}")

        seen_files = current_files
        await asyncio.sleep(1)  # Adjust the time interval between checks as needed

async def server_loop(q):
    """Process incoming requests from a queue in a loop and generate responses.

    Each queue item is a ``(question, response_queue)`` pair. The answer put on
    ``response_queue`` is a dict with ``similar_docs`` (list of
    ``{"page_content", "metadata"}`` dicts) and ``output_text``. If answering
    fails, a dict with empty ``similar_docs``/``output_text`` and an ``error``
    message is sent instead, so the caller is never left waiting and the loop
    keeps serving later requests.

    Args:
        q: The input queue containing requests.
    Returns:
        None
    """
    qa_chain = build_qa_chain()
    while True:
        (string, response_q) = await q.get()
        try:
            similar_docs = get_similar_docs(string, similar_doc_count=2)
            # Convert Document objects to plain dictionaries for the JSON response.
            similar_docs_serializable = [
                {"page_content": doc.page_content, "metadata": doc.metadata}
                for doc in similar_docs
            ]
            out = qa_chain({"input_documents": similar_docs, "question": string})

            res = {
                "similar_docs": similar_docs_serializable,
                "output_text": out['output_text']
            }
        except Exception as e:
            # Top-level boundary of the worker: report the failure to the
            # caller instead of letting the exception end the loop.
            print(f"Error while answering request: {e}")
            res = {
                "similar_docs": [],
                "output_text": "",
                "error": str(e)
            }

        await response_q.put(res)

def startup():
    """Create the request queue and launch the background tasks.

    Registered as an ``on_startup`` handler of the Starlette app. Stores the
    queue on ``app.model_queue`` (read by the request handler) and starts the
    model worker and the upload-directory watcher.
    """
    q = asyncio.Queue()
    app.model_queue = q
    # Keep strong references to the tasks: the event loop itself only holds
    # weak references, so an unreferenced task may be garbage-collected
    # before it finishes.
    app.background_tasks = [
        asyncio.create_task(server_loop(q)),
        asyncio.create_task(check_file_changes()),
    ]


# ASGI application: a single POST route at "/" handled by homepage();
# startup() runs when the application starts.
app = Starlette(
    routes=[
        Route("/", homepage, methods=["POST"]),
    ],
    on_startup=[startup]
)

# Run the Server

In [None]:
# COMMAND ----------
# Run the server
# Read the cluster, workspace and organisation identifiers from the Spark
# configuration and assemble the driver-proxy URL for port 7777, the port
# the uvicorn command below listens on.
cluster_id = spark.conf.get("spark.databricks.clusterUsageTags.clusterId")
workspace_url = spark.conf.get("spark.databricks.workspaceUrl")
org_id = spark.conf.get("spark.databricks.clusterUsageTags.clusterOwnerOrgId")
endpoint_url = f"https://{workspace_url}/driver-proxy-api/o/{org_id}/{cluster_id}/7777/"
print(f"Access this API at {endpoint_url}")

# COMMAND ----------

!uvicorn --host 0.0.0.0 --port 7777 server:app