<h1>Differential Diagnosis with Mistral 7B RAG vs. BioMistral 7B by ContactDoctor</h1>

In [1]:
# %pip (not !pip) installs into the environment of the running kernel.
# transformers, accelerate, sentence-transformers and tenacity are imported
# (or required by device_map='auto') in app.py, so they are listed explicitly.
%pip install -q streamlit langchain_community chromadb huggingface-hub bitsandbytes pypdf tiktoken transformers accelerate sentence-transformers tenacity

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m106.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m104.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.9/18.9 MB[0m [31m116.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m94.9/94.9 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m13.3 MB/s[0m eta [36m0:

In [None]:
import os
from getpass import getpass

# Never store a real access token in the notebook: take it from the HF_TOKEN
# environment variable, or ask for it interactively (input is not echoed).
hf_token = (os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")).strip()

# Streamlit reads .streamlit/secrets.toml and exposes it to app.py as st.secrets.
os.makedirs('.streamlit', exist_ok=True)
with open('.streamlit/secrets.toml', 'w') as f:
    f.write(f"""
[huggingface]
token = "{hf_token}"

[models]
rag = "mistralai/Mistral-7B-Instruct-v0.2"
bio = "BioMistral/BioMistral-7B"
""".lstrip())

In [36]:
%%writefile app.py

# IMPORT LIBRARY
import streamlit as st
import pandas as pd
import os
import torch

from concurrent.futures import ThreadPoolExecutor, as_completed # FOR PARALLELIZATION

from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain import PromptTemplate
from sentence_transformers import SentenceTransformer

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import threading
import time
from tenacity import retry, stop_after_attempt, wait_fixed
import gc




# CODE BLOCK

# Prompt shared by both models; {context} is filled with retrieved text for the
# RAG model and with an empty string for the Bio model.
PROMPT = """Answer the question based only on the following context,:{context}
Question:{question}
What are the top 10 most likely diagnoses? Be precise, listing one diagnosis per line, and try to cover many unique possibilities.
Ensure the order starts with the most likely. The top 10 diagnoses are."""
MAX_INPUT_TOKENS = 2048 # The sequence length limit of BioMistral-7B
MAX_CONTEXT_LENGTH = 4096 # Total context length including prompt
DB_DIR = "./db" # Chroma persistence directory (uploaded PDFs are written here too)

# Device selectors: a device string for the embedding model and an index for
# transformers pipelines (0 = first GPU, -1 = CPU).
HF_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PIPELINE_DEVICE = 0 if torch.cuda.is_available() else -1

# Token and model ids come from .streamlit/secrets.toml
HF_TOKEN     = st.secrets["huggingface"]["token"]
model_id     = st.secrets["models"]["rag"]
bio_model_id = st.secrets["models"]["bio"]


# using 4bit to save memory for model loading
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)


# Allocator tuning for the CUDA caching allocator; `os` is already imported at
# the top of the file, so no second import is needed here.
if torch.cuda.is_available():
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True, max_split_size_mb:512, garbage_collection_threshold:0.8"


### HELPERS TO LOAD THE MODEL ###

#lock to serialize any “move‐model‐on/off GPU” calls
# NOTE: threading.Lock is not reentrant, so unload_model_from_gpu() must never
# be called by code that is already holding gpu_lock.
gpu_lock = threading.Lock()

@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def unload_model_from_gpu(model=None):
    """Run garbage collection, then release cached CUDA memory.

    Args:
        model: Accepted for call-site compatibility; it is not used.

    The collector runs first so that tensors kept alive only by unreachable
    reference cycles are freed before the allocator cache is emptied;
    emptying the cache first would leave that memory reserved.
    """
    with gpu_lock:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


### CACHING HEAVY RESOURCES ###

@st.cache_resource(show_spinner=False)
def get_embedding_fn():
  """Return the sentence-embedding wrapper used for indexing and retrieval.

  Wrapped in st.cache_resource, and placed on HF_DEVICE.
  """
  device_options = {"device": HF_DEVICE}
  embedder = HuggingFaceEmbeddings(
      model_name="sentence-transformers/all-mpnet-base-v2",
      model_kwargs=device_options,
  )
  return embedder

#load tokenizer separately for faster token counting
@st.cache_resource
def get_tokenizer():
    """Return the cached tokenizer of the RAG model (`model_id`).

    Used for token counting without loading the model itself.
    `token` replaces the deprecated `use_auth_token` argument.
    """
    return AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)




#load Mistral 7B RAG
@st.cache_resource(show_spinner=False)
def get_rag_components():
  """Build and cache everything the RAG path needs.

  Returns:
      tuple: (text-generation pipeline for `model_id`,
              retriever over the Chroma store persisted in DB_DIR,
              PromptTemplate built from PROMPT).
  """
  on_gpu = HF_DEVICE == "cuda"

  # Chroma's first positional parameter is the collection name, so the
  # embedding function must be passed by keyword.
  vs = Chroma(
      embedding_function=get_embedding_fn(),
      persist_directory=DB_DIR,
  )
  retriever = vs.as_retriever()

  load_kwargs = {
      "token": HF_TOKEN,  # replaces the deprecated use_auth_token
      "device_map": "auto" if on_gpu else "cpu",
      "torch_dtype": torch.bfloat16 if on_gpu else torch.float32,
  }
  if on_gpu:
    # Only quantize on GPU.
    # NOTE(review): 4-bit bitsandbytes loading is expected to need CUDA — confirm
    # against the installed bitsandbytes version.
    load_kwargs["quantization_config"] = bnb_config
  mod = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

  pipe = pipeline(
      "text-generation",
      model=mod,
      tokenizer=AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN),
      use_fast=True,
      max_new_tokens=256,
  )

  prompt = PromptTemplate(template=PROMPT, input_variables=["context", "question"])
  return pipe, retriever, prompt


#load Bio model
@st.cache_resource(show_spinner=False)
def get_bio_pipeline():
  """Load and cache the text-generation pipeline for `bio_model_id`.

  Loading happens while holding gpu_lock. Do not call
  unload_model_from_gpu() from inside this function: it acquires the same
  non-reentrant lock and would deadlock.

  Returns:
      The transformers text-generation pipeline for the Bio model.
  """
  on_gpu = HF_DEVICE == "cuda"

  load_kwargs = {
      "token": HF_TOKEN,  # replaces the deprecated use_auth_token
      "device_map": "auto" if on_gpu else "cpu",
      "torch_dtype": torch.bfloat16 if on_gpu else torch.float32,
  }
  if on_gpu:
    # Only quantize on GPU.
    # NOTE(review): 4-bit bitsandbytes loading is expected to need CUDA — confirm
    # against the installed bitsandbytes version.
    load_kwargs["quantization_config"] = bnb_config

  with gpu_lock:
    bio_mod = AutoModelForCausalLM.from_pretrained(bio_model_id, **load_kwargs)

    bio_pipe = pipeline(
        "text-generation",
        model=bio_mod,
        tokenizer=AutoTokenizer.from_pretrained(bio_model_id, token=HF_TOKEN),
        use_fast=True,
        max_new_tokens=256,
    )

    return bio_pipe


#streamlit caching: tricks for expensive file I/O:
@st.cache_data(show_spinner=False)
def build_vectorstore(uploaded_files):
    """Rebuild the Chroma index in DB_DIR from the uploaded PDF files.

    Args:
        uploaded_files: Iterable of uploaded file objects exposing `.name`
            and `.getbuffer()`.

    Returns:
        True once the index has been written.

    NOTE(review): get_rag_components() caches a Chroma client opened on the
    same directory; a retriever created before a rebuild may not see the new
    index — confirm whether that cache needs clearing after a rebuild.
    """
    import shutil  # local import: only this function needs it

    #start fresh every build index
    # A persisted Chroma directory contains sub-directories, which os.remove()
    # cannot delete, so the whole tree is removed and recreated.
    if os.path.isdir(DB_DIR):
        shutil.rmtree(DB_DIR)
    os.makedirs(DB_DIR, exist_ok=True)

    paths = []
    for f in uploaded_files:
        # basename() keeps a crafted upload name from escaping DB_DIR
        path = os.path.join(DB_DIR, os.path.basename(f.name))
        with open(path, "wb") as fp:
            fp.write(f.getbuffer())
        paths.append(path)

    docs = []
    for pdf in paths:
        docs.extend(PyPDFLoader(pdf).load())

    splitter = CharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=1000, chunk_overlap=150 #define chunking strategy here: smaller for better targeting
    )
    splits = splitter.split_documents(docs)

    # from_documents() takes the embeddings as `embedding`; passing
    # `embedding_function` collides with the keyword Chroma sets internally.
    vs = Chroma.from_documents(
        documents=splits,
        embedding=get_embedding_fn(),
        persist_directory=DB_DIR,
    )
    vs.persist()
    return True

### HELPERS ###

#making sure the token inputs are within the limit
def check_length(text, tokenizer=None):
  """Tell whether `text` fits within MAX_INPUT_TOKENS.

  Args:
      text: The text to measure.
      tokenizer: Object with an `encode(text)` method; when None the cached
          tokenizer from get_tokenizer() is used.

  Returns:
      True when the token count is within the limit; otherwise shows a
      Streamlit warning and returns False.
  """
  active_tokenizer = get_tokenizer() if tokenizer is None else tokenizer
  n_tokens = len(active_tokenizer.encode(text))
  fits = n_tokens <= MAX_INPUT_TOKENS
  if not fits:
      st.warning(f"Your input is {n_tokens} tokens, over the {MAX_INPUT_TOKENS}-token limit. Please shorten it.")
  return fits

#check context length after RAG retrieval
def check_context_length(prompt_text, tokenizer=None):
  """Tell whether the full prompt leaves room for the generated answer.

  The budget is MAX_CONTEXT_LENGTH minus 256 tokens, the same value the
  pipelines use for max_new_tokens.

  Args:
      prompt_text: The fully formatted prompt (context + question).
      tokenizer: Object with an `encode(text)` method; when None the cached
          tokenizer from get_tokenizer() is used.

  Returns:
      True when the prompt fits the budget; otherwise shows a Streamlit
      warning that states the budget actually enforced and returns False.
  """
  # `is None` (as in check_length) so a supplied tokenizer is always honoured,
  # even if the object happens to be falsy.
  if tokenizer is None:
      tokenizer = get_tokenizer()
  budget = MAX_CONTEXT_LENGTH - 256
  token_count = len(tokenizer.encode(prompt_text))
  if token_count > budget:
      st.warning(f"Total context + prompt is {token_count} tokens, over the {budget}-token limit.")
      return False
  return True


def safe_invoke(model_or_chain, *args, **kwargs):
    """Call a model or chain, reporting any failure in the UI instead of raising.

    Objects that have an `invoke` attribute are called through it; anything
    else is called directly. All positional and keyword arguments are
    forwarded unchanged.

    Returns:
        The call's result, or None after showing a Streamlit error when the
        call raised.
    """
    try:
        entry_point = model_or_chain.invoke if hasattr(model_or_chain, "invoke") else model_or_chain
        return entry_point(*args, **kwargs)
    except Exception as exc:
        st.error(f"MODEL ERROR: {exc}")
        return None



#error handling for bio and naive model
def _generate_and_release(pipe, prompt_text):
    """Run `pipe` once on `prompt_text` and return the completion text.

    The model is deliberately left on the device it was loaded on: the
    pipelines are cached with st.cache_resource and reused for every case, and
    nothing in this file moves a model back after a `.to("cpu")`, so moving it
    away would break every later call. Only unused cached memory is released.
    """
    try:
        raw_output = safe_invoke(pipe, prompt_text, max_new_tokens=256)
    finally:
        with gpu_lock:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    if not raw_output:
        return "Error: No output generated"

    raw = raw_output[0]["generated_text"]
    #strip prompt echo to avoid prompt redundancy
    return raw[len(prompt_text):].lstrip() if raw.startswith(prompt_text) else raw


# NOTE: both wrappers turn exceptions into "Error: ..." strings, so the @retry
# decorators never see a failure and do not actually re-run anything.
@retry(stop=stop_after_attempt(2), wait=wait_fixed(2))
def run_rag_pipeline(pipe, prompt_text):
    """Generate with the RAG pipeline.

    Returns:
        The generated text with the echoed prompt stripped, or a string
        starting with "Error:" when generation failed or produced nothing.
    """
    try:
        return _generate_and_release(pipe, prompt_text)
    except Exception as e:
        return f"Error: {str(e)}"


@retry(stop=stop_after_attempt(2), wait=wait_fixed(2))
def run_bio_pipeline(pipe, prompt_text):
    """Generate with the Bio pipeline.

    Returns:
        The generated text with the echoed prompt stripped, or a string
        starting with "Error:" when generation failed or produced nothing.
    """
    try:
        return _generate_and_release(pipe, prompt_text)
    except Exception as e:
        return f"Error: {str(e)}"



#case processing
def process_case(txt, use_rag=True, use_bio=True):
    """Run one case narrative through the RAG model and/or the Bio model.

    Args:
        txt: The case narrative.
        use_rag: When False the RAG model is skipped.
        use_bio: When False the Bio model is skipped.

    Returns:
        dict with keys "Case", "Mistral7B+RAG" and "BioMistral7B". Each model
        entry holds the generated text, a "skipped"/"too long" note, or an
        "ERROR: ..." string when that model's step raised.
    """
    results = {"Case": txt}
    tokenizer = get_tokenizer()

    #validate input length first
    if not check_length(txt, tokenizer):
        results["Mistral7B+RAG"] = "Case too long, exceeds token limit"
        results["BioMistral7B"] = "Case too long, exceeds token limit"
        return results

    # 1) Retrieve context
    if use_rag:
        try:
            rag_pipe, rag_retriever, prompt = get_rag_components()
            docs = rag_retriever.get_relevant_documents(txt)
            context = "\n\n".join(d.page_content for d in docs)

            # 2) Generate with RAG‐LLM
            prompt_text = prompt.format_prompt(context=context, question=txt).to_string()

            # Check combined length
            if not check_context_length(prompt_text, tokenizer):
                # Same budget check_context_length enforces (256 tokens are
                # reserved, matching max_new_tokens of the pipelines).
                budget = MAX_CONTEXT_LENGTH - 256
                # Keep halving the context until the prompt fits; a single
                # halving is not guaranteed to be enough. The loop ends at the
                # latest when the context is empty.
                while context:
                    context = context[:len(context)//2]  #simple truncation
                    prompt_text = prompt.format_prompt(context=context + "...", question=txt).to_string()
                    if len(tokenizer.encode(prompt_text)) <= budget:
                        break

            rag_out = run_rag_pipeline(rag_pipe, prompt_text)
            results["Mistral7B+RAG"] = rag_out

            #free up GPU memory after RAG
            unload_model_from_gpu()

        except Exception as e:
            results["Mistral7B+RAG"] = f"ERROR: {str(e)}"
    else:
        results["Mistral7B+RAG"] = "RAG processing skipped"

    # 3) Generate with Bio‐LLM
    if use_bio:
        try:
            bio_pipe = get_bio_pipeline()
            # the Bio model gets the same prompt with an empty context
            bio_prompt = PROMPT.format(context="", question=txt)
            bio_out = run_bio_pipeline(bio_pipe, bio_prompt)
            results["BioMistral7B"] = bio_out

            #free up GPU memory after bio_model
            unload_model_from_gpu()

        except Exception as e:
            results["BioMistral7B"] = f"ERROR: {str(e)}"
    else:
        results["BioMistral7B"] = "Bio processing skipped"

    return results




### STREAMLIT UI ###
# Page header
st.title("Differential Diagnosis: Mistral 7B RAG vs. Bio Mistral 7B by BioMistral")
st.caption("Helps the doctor/nurse to develop their differential diagnosis using LLM models")

# Sidebar: PDF upload + index build, then the CSV upload used by batch mode
with st.sidebar:
    st.header("Upload additional resources for RAG (type:.pdf)")
    UploadedFiles = st.file_uploader(
        "Upload here and click on 'Upload'",
        type="pdf",
        accept_multiple_files=True,
    )

    # cap on the number of PDFs that get indexed
    MAX_LINES = 3
    if len(UploadedFiles) > MAX_LINES:
        st.warning(f"Maximum number of files reached. Only the first {MAX_LINES} will be processed.")
        UploadedFiles = UploadedFiles[:MAX_LINES]

    if st.button("Build Index"):
        if UploadedFiles:
            with st.spinner("Indexing…"):
                build_vectorstore(UploadedFiles)
            st.success("RAG index is ready!")
        else:
            st.error("Select at least one PDF first.")

    st.markdown("---")
    st.header("Batch processing case upload (type:.csv)")
    csv_file = st.file_uploader("Upload CSV", type="csv", accept_multiple_files=False)

### SINGLE CASE ###
st.subheader("SINGLE CASE")
question = st.text_area("Case Narrative:",
                        height=180,
                        placeholder="For example: 22-year-old patient with TB was admitted to hospital today. The patient has been to a country outside Sweden. The patient came back to Sweden from the other country. The patient has had a fever for two weeks and is admitted. The doctor has prescribed a medicine. ")
st.write(f"The number of characters are {len(question)} characters.")

if st.button('Start Processing'):
  if not question.strip():
    # an empty narrative passes the token-length check, so reject it explicitly
    st.warning("Please enter a case narrative before processing.")
  elif check_length(question):
    # Decide up front whether a RAG index exists, so the RAG model is not
    # loaded and run only to have its answer hidden afterwards.
    index_ready = os.path.isdir(DB_DIR) and bool(os.listdir(DB_DIR))

    with st.spinner("Processing..."):
      result = process_case(question, use_rag=index_ready)
      tabs = st.tabs(["BIOMode", "RAGMode"])

      with tabs[0]:
        #Biomodel execution
        st.markdown("**BioMistral 7B**")
        st.text(result['BioMistral7B'])

      with tabs[1]:
        if index_ready:
          st.markdown("**Mistral 7B + RAG**")
          st.text(result['Mistral7B+RAG'])
        else:
          st.error("Please upload and build your PDF index first!")

      #free up memory from bio
      unload_model_from_gpu()
  # No st.stop() on the failure paths: check_length() has already shown its
  # warning, and stopping the script would keep the batch section from rendering.

### BATCH PROCESSING ###
st.markdown("---")
st.subheader("BATCH MODE")

if csv_file:
    df = pd.read_csv(csv_file)
    if "Case" not in df.columns:
        st.error("The uploaded CSV must contain a 'Case' column.")
    elif st.button("Start Batch Processing"):
        cases = df["Case"].tolist()
        # one slot per input row so the output keeps the CSV's row order
        results = [None] * len(cases)
        prog = st.progress(0)

        #only run one case at a time to avoid GPU memory constraints
        # NOTE(review): process_case calls st.warning/st.error from the worker
        # thread; those calls presumably lack a Streamlit script context there — confirm.
        BATCH_WORKERS = 1 # if HF_DEVICE != "cuda" else 2
        with ThreadPoolExecutor(max_workers=BATCH_WORKERS) as exe:
            # map each future to its row index so a failure is reported against
            # the case that failed, not the last one submitted
            future_to_idx = {exe.submit(process_case, txt): idx for idx, txt in enumerate(cases)}

            # as each case completes, update progress
            for done, fut in enumerate(as_completed(future_to_idx), start=1):
                idx = future_to_idx[fut]
                try:
                    results[idx] = fut.result()
                except Exception as e:
                    results[idx] = {"Case": cases[idx], "Mistral7B+RAG": f"ERROR: {str(e)}", "BioMistral7B": f"ERROR: {str(e)}"}
                prog.progress(done / len(cases))

        out_df = pd.DataFrame(results)
        st.download_button(
            "Download Results as CSV",
            data=out_df.to_csv(index=False),
            file_name="ddx_comparison.csv"
        )

        # Show sample of results
        st.write("Sample of processed results:")
        st.dataframe(out_df.head())

Overwriting app.py


<h2>Install localtunnel</h2>

In [37]:
# Install the localtunnel npm package; the launch cell below runs it through npx.
!npm install localtunnel

[1G[0K⠙[1G[0K⠹[1G[0K⠸[1G[0K⠼[1G[0K⠴[1G[0K⠦[1G[0K
up to date, audited 23 packages in 969ms
[1G[0K⠦[1G[0K
[1G[0K⠦[1G[0K3 packages are looking for funding
[1G[0K⠦[1G[0K  run `npm fund` for details
[1G[0K⠦[1G[0K
2 [31m[1mhigh[22m[39m severity vulnerabilities

To address all issues (including breaking changes), run:
  npm audit fix --force

Run `npm audit` for details.
[1G[0K⠦[1G[0K

<h2> Run Streamlit in background </h2>

In [38]:
# Start Streamlit on /content/app.py in the background (output goes to /content/logs.txt),
# expose local port 8501 through localtunnel, and print this machine's public IPv4 address.
# NOTE(review): the printed IP is presumably what the localtunnel landing page asks for — confirm.
!streamlit run /content/app.py &>/content/logs.txt & npx localtunnel --port 8501 & curl ipv4.icanhazip.com

34.125.180.239
[1G[0K⠙[1G[0Kyour url is: https://quiet-crabs-start.loca.lt
