<h1> Differential Diagnosis with Llama 3-8B RAG vs. BioLlama 3 </h1>

In [1]:
!pip install -q streamlit langchain_community chromadb huggingface-hub bitsandbytes

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m67.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m37.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.3/18.3 MB[0m [31m60.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m62.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m94.9/94.9 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00

In [None]:
import os
from getpass import getpass

# Write the Streamlit secrets file that app.py reads through st.secrets.
# The Hugging Face token is never stored in the notebook itself: it is taken
# from the HF_TOKEN environment variable, or asked for interactively (hidden
# input) when that variable is not set.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")

secrets_toml = f"""[huggingface]
token = "{hf_token}"

[models]
rag = "meta-llama/Meta-Llama-3-8B-Instruct"
bio = "ContactDoctor/Bio-Medical-Llama-3-8B"
"""

os.makedirs('.streamlit', exist_ok=True)
with open('.streamlit/secrets.toml', 'w') as f:
    f.write(secrets_toml)

In [None]:
%%writefile app.py

# IMPORT LIBRARY
import streamlit as st
import pandas as pd
import os
import torch

# FOR PARALLELIZATION
from concurrent.futures import ThreadPoolExecutor, as_completed 

from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
#from langchain_community.embeddings import OllamaEmbeddings
#from langchain_community.llms import Ollama
#from langchain_core.runnables import RunnablePassthrough
#from langchain_core.output_parsers import StrOutputParser
#from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from sentence_transformers import SentenceTransformer
from langchain_community.embeddings import HuggingFaceEmbeddings
#from langchain_core.messages import AIMessage, HumanMessage

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
#from langchain import HuggingFacePipeline, LLMChain
from langchain import PromptTemplate



# CODE BLOCK

# Prompt template shared by both models; {context} and {question} are filled in
# per case (the bio model is called with an empty context).
PROMPT = (
    "Answer the question based only on the following context,:{context}\n"
    "Question:{question}\n"
    "What are the top 10 most likely diagnoses? Be precise, listing one diagnosis per line, and try to cover many unique possibilities.\n"
    "Ensure the order starts with the most likely. The top 10 diagnoses are."
)

# Upper bound on the token count of a case narrative, enforced by check_length().
MAX_INPUT_TOKENS = 2048

# Directory where the Chroma index is persisted.
DB_DIR = "./db"

# Device selection: string form for model loading, integer form for pipelines.
_HAS_CUDA = torch.cuda.is_available()
HF_DEVICE = "cuda" if _HAS_CUDA else "cpu"
PIPELINE_DEVICE = 0 if _HAS_CUDA else -1

# Credentials and model identifiers come from .streamlit/secrets.toml.
HF_TOKEN = st.secrets["huggingface"]["token"]
model_id = st.secrets["models"]["rag"]
bio_model_id = st.secrets["models"]["bio"]


### CACHING HEAVY RESOURCES ###

@st.cache_resource(show_spinner=False)
def get_embedding_fn():
    """Return the sentence-embedding model used for the vector store.

    Wrapped in ``st.cache_resource`` so the model is built once and reused.
    The model is placed on ``HF_DEVICE``.
    """
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={"device": HF_DEVICE},
    )
    return embedding_model




# Load Llama-3-8B RAG

@st.cache_resource(show_spinner=False)
def get_rag_components():
    """Build (once) everything the RAG branch needs.

    Returns:
        tuple: ``(pipe, retriever, prompt)`` where ``pipe`` is a
        text-generation pipeline around the model named by ``model_id``,
        ``retriever`` reads from the Chroma index persisted in ``DB_DIR`` and
        ``prompt`` is a ``PromptTemplate`` built from ``PROMPT``.
    """
    embed_fn = get_embedding_fn()
    vect = Chroma(persist_directory=DB_DIR, embedding_function=embed_fn)
    retriever = vect.as_retriever()

    # 8-bit loading goes through bitsandbytes, so only request it when a CUDA
    # device was actually selected. On the CPU fallback the model is loaded
    # without quantization instead of failing at start-up.
    load_kwargs = {"device_map": {"": HF_DEVICE}}
    if HF_DEVICE == "cuda":
        load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    mod = AutoModelForCausalLM.from_pretrained(
        model_id,
        use_auth_token=HF_TOKEN,
        **load_kwargs,
    )

    pipe = pipeline(
        "text-generation",
        model=mod,
        tokenizer=AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN),
        use_fast=True,
        max_new_tokens=128,
    )

    prompt = PromptTemplate(template=PROMPT, input_variables=["context", "question"])
    return pipe, retriever, prompt


# Load Bio Model


@st.cache_resource(show_spinner=False)
def get_bio_pipeline():
    """Build (once) the text-generation pipeline for the model in ``bio_model_id``.

    The model is loaded on ``HF_DEVICE``, in float16 when that device is
    "cuda" and in float32 otherwise. Generation is capped at 128 new tokens.
    """
    if HF_DEVICE == "cuda":
        weights_dtype = torch.float16
    else:
        weights_dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        bio_model_id,
        use_auth_token=HF_TOKEN,
        device_map={"": HF_DEVICE},
        torch_dtype=weights_dtype,
    )
    tokenizer = AutoTokenizer.from_pretrained(bio_model_id, use_auth_token=HF_TOKEN)

    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        use_fast=True,
        max_new_tokens=128,
    )

@st.cache_resource
def get_tokenizer():
    """Return the cached tokenizer of the RAG model (``model_id``)."""
    rag_tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN)
    return rag_tokenizer


# Tricks for Expensive File I/O: streamlit caching
@st.cache_data(show_spinner=False)
def build_vectorstore(uploaded_files):
    """Index the uploaded PDFs into the Chroma store persisted at ``DB_DIR``.

    Each upload is saved under ``UploadedTextbook/``, loaded with
    ``PyPDFLoader``, split into chunks (1500 tokens, 200 overlap) and embedded
    with ``get_embedding_fn()``.

    Args:
        uploaded_files: iterable of uploaded file objects exposing ``name``
            and ``getbuffer()``.

    Returns:
        bool: ``True`` once the index has been written.

    Raises:
        ValueError: if no text could be extracted from the uploaded files.
    """
    os.makedirs("UploadedTextbook", exist_ok=True)
    paths = []
    for f in uploaded_files:
        # Keep only the final path component so a crafted file name cannot
        # escape the upload directory.
        path = os.path.join("UploadedTextbook", os.path.basename(f.name))
        with open(path, "wb") as fp:
            fp.write(f.getbuffer())
        paths.append(path)

    docs = []
    for pdf in paths:
        docs.extend(PyPDFLoader(pdf).load())
    splitter = CharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=1500, chunk_overlap=200
    )
    splits = splitter.split_documents(docs)
    if not splits:
        raise ValueError("No text could be extracted from the uploaded PDF files.")

    # The keyword is `embedding`: from_documents forwards it internally as
    # `embedding_function`, so passing `embedding_function=` here collides
    # with that argument and raises a TypeError.
    vs = Chroma.from_documents(
        documents=splits,
        embedding=get_embedding_fn(),
        persist_directory=DB_DIR,
    )
    vs.persist()
    return True


### HELPERS ###
# Make sure the token inputs are within the limit
def check_length(text):
    """Tell whether ``text`` fits within ``MAX_INPUT_TOKENS``.

    The text is tokenized with the RAG model's tokenizer. When it is too long
    a Streamlit warning is shown and ``False`` is returned; otherwise ``True``.
    """
    token_count = len(get_tokenizer().encode(text))
    if token_count <= MAX_INPUT_TOKENS:
        return True

    st.warning(f"Your input is {token_count} tokens, over the {MAX_INPUT_TOKENS}-token limit. Please shorten it.")
    return False

def safe_invoke(model_or_chain, *args, **kwargs):
    """Run a model or chain, reporting failures instead of raising.

    Objects exposing an ``invoke`` method are run through it; anything else is
    called directly. All positional and keyword arguments are forwarded.

    Returns:
        The result of the call, or ``None`` if it raised (in which case the
        error is shown with ``st.error``).
    """
    try:
        # Fall back to the object itself when it has no `invoke` attribute.
        runner = getattr(model_or_chain, "invoke", model_or_chain)
        return runner(*args, **kwargs)
    except Exception as err:
        st.error(f"MODEL ERROR: {err}")
        return None


def process_case(txt):
    """Run one case narrative through both models.

    Args:
        txt: the case narrative.

    Returns:
        dict: keys ``"Case"``, ``"Llama3+RAG"`` and ``"BioLlama3"``. A model
        column holds the generated text, or ``"Error"`` when ``safe_invoke``
        returned nothing for that model.
    """
    # RAG branch: retrieve passages, fill the prompt template, generate.
    generator, retriever, template = get_rag_components()
    passages = [doc.page_content for doc in retriever.get_relevant_documents(txt)]
    rag_prompt = template.format_prompt(
        context="\n\n".join(passages), question=txt
    ).to_string()
    rag_result = safe_invoke(generator, rag_prompt, max_new_tokens=128)
    if rag_result:
        rag_answer = rag_result[0]["generated_text"]
    else:
        rag_answer = "Error"

    # Bio branch: same prompt, but with an empty context.
    bio_generator = get_bio_pipeline()
    bio_result = safe_invoke(
        bio_generator,
        PROMPT.format(context="", question=txt),
        max_new_tokens=128,
    )
    if bio_result:
        bio_answer = bio_result[0]["generated_text"]
    else:
        bio_answer = "Error"

    return {"Case": txt, "Llama3+RAG": rag_answer, "BioLlama3": bio_answer}






### STREAMLIT UI ###
st.title("Differential Diagnosis: Llama 3-8B RAG vs. BioLlama 3 by ContactDoctor")
st.caption("Predefined Medical Guidelines for RAG: Kumar & Clark's Clinical Medicine 10th Ed. 2020")

# Sidebar: optional extra PDFs for the RAG index, and the CSV for batch mode.
with st.sidebar:
    st.header("Upload additional resources for RAG (type:.pdf)")
    UploadedFiles = st.file_uploader("Upload here and click on 'Upload'", type="pdf", accept_multiple_files=True)

    if st.button("Build Index"):
        if not UploadedFiles:
            st.error('Select at least one PDF file.')
        else:
            # build_vectorstore creates the upload directory itself, so no
            # mkdir is done here: it used to raise on the second click and a
            # bare `except` then skipped indexing silently.
            try:
                with st.spinner("Indexing…"):
                    build_vectorstore(UploadedFiles)
            except Exception as exc:
                # Show the real failure instead of swallowing it.
                st.error(f"Indexing failed: {exc}")
            else:
                st.success("RAG index is ready!")

    st.markdown("---")
    st.header("Batch processing case upload (type:.csv)")
    csv_file = st.file_uploader(
        "Upload CSV",
        type="csv",
        accept_multiple_files=False)

### SINGLE CASE ###
st.subheader("SINGLE CASE")
question = st.text_area("Case Narrative:",
                        height=180,
                        placeholder="For example: 22-year-old patient with TB was admitted to hospital today. The patient has been to a country outside Sweden. The patient came back to Sweden from the other country. The patient has had a fever for two weeks and is admitted. The doctor has prescribed a medicine. ")
st.write(f"The number of characters are {len(question)} characters.")

if st.button('Start Processing'):
    if not question.strip():
        # Nothing to diagnose: do not send an empty prompt to the models.
        st.warning("Please enter a case narrative first.")
    elif check_length(question):
        with st.spinner("Processing..."):
            rag_pipe, rag_retriever, prompt = get_rag_components()
            docs = rag_retriever.get_relevant_documents(question)
            context = "\n\n".join([d.page_content for d in docs])
            prompt_text = prompt.format_prompt(context=context, question=question).to_string()

            # Show the RAG answer as soon as it is available, before the
            # second model runs.
            st.markdown("**Llama 3-8B + RAG**")
            output = rag_pipe(prompt_text)[0]["generated_text"]
            st.text(output)

            bio_pipe = get_bio_pipeline()
            bio_prompt = PROMPT.format(context="", question=question)
            bio_output = bio_pipe(bio_prompt)[0]["generated_text"]

            st.markdown("**BioLlama 3**")
            st.text(bio_output)
    # When the input is too long check_length has already shown a warning.
    # The script is deliberately not stopped here, so the batch section below
    # is still rendered.

### BATCH PROCESSING ###
st.markdown("---")
st.subheader("BATCH MODE")

if csv_file:
    df = pd.read_csv(csv_file)
    if "Case" not in df.columns:
        st.error("The uploaded CSV must contain a 'Case' column.")
    elif st.button("Start Batch Processing"):
        # Empty cells are skipped; every remaining case is passed on as text.
        cases = df["Case"].dropna().astype(str).tolist()
        if not cases:
            st.warning("The 'Case' column contains no cases to process.")
        else:
            # One slot per case so results keep the CSV row order even though
            # the futures finish in arbitrary order.
            results = [None] * len(cases)
            prog = st.progress(0)

            with ThreadPoolExecutor(max_workers=4) as exe:
                future_to_row = {
                    exe.submit(process_case, txt): row
                    for row, txt in enumerate(cases)
                }

                # as each case completes, update progress
                for done, fut in enumerate(as_completed(future_to_row), start=1):
                    results[future_to_row[fut]] = fut.result()
                    prog.progress(done / len(cases))

            out_df = pd.DataFrame(results)
            st.download_button(
                "Download Results as CSV",
                data=out_df.to_csv(index=False),
                file_name="ddx_comparison.csv"
            )

Overwriting app.py


<h2>Install local-tunnel </h2>

In [33]:
# Install the localtunnel npm package; the run cell starts it with `npx localtunnel`.
!npm install localtunnel

[1G[0K⠙[1G[0K⠹[1G[0K⠸[1G[0K⠼[1G[0K
up to date, audited 23 packages in 723ms
[1G[0K⠼[1G[0K
[1G[0K⠼[1G[0K3 packages are looking for funding
[1G[0K⠼[1G[0K  run `npm fund` for details
[1G[0K⠼[1G[0K
2 [31m[1mhigh[22m[39m severity vulnerabilities

To address all issues (including breaking changes), run:
  npm audit fix --force

Run `npm audit` for details.
[1G[0K⠼[1G[0K

<h2> Run Streamlit in background </h2>

In [None]:
# Three commands in one line:
#   1. start the Streamlit app in the background, with all of its output
#      redirected to /content/logs.txt;
#   2. start localtunnel in the background, exposing port 8501;
#   3. print this machine's public IPv4 address.
# NOTE(review): the printed IP is presumably what the localtunnel landing page
# asks for as the tunnel password — confirm.
# NOTE(review): the /content/ paths assume a Colab-style working directory — confirm.
!streamlit run /content/app.py &>/content/logs.txt & npx localtunnel --port 8501 & curl ipv4.icanhazip.com

34.86.97.249
[1G[0K⠙[1G[0Kyour url is: https://gentle-lemons-draw.loca.lt
