# EN RAG, now also used for the initial run
This version is used for both the initial and the final_to_english runs.

All dialogues are stored in the MongoDB collection `db['NLP-EXPANDED-prammed-postprocessed_translation']`.

RAG results are saved as one JSON file per dialogue (`r_<index>.json`) in the notebook's working directory.



In [None]:
from pymongo import MongoClient
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from accelerate import Accelerator
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain.text_splitter import CharacterTextSplitter
import json
import re
import os

# Accelerate state object.
# NOTE(review): `accelerator` is never referenced below (the model is placed via
# device_map="auto"), so this may be removable — confirm before deleting.
accelerator = Accelerator()

# --- Data source -------------------------------------------------------------
# All dialogues live in this MongoDB collection.
MONGO_URI = 'mongodb://localhost:27017/'
client = MongoClient(MONGO_URI)
db = client['MIMIC-IV']
# Previously used source collection: db['NLP-EXPANDED-prammed']
collection = db['NLP-EXPANDED-prammed-postprocessed_translation']

# --- Generator LLM -----------------------------------------------------------
# "VAGOsolutions/Llama-3-SauerkrautLM-8b-Instruct" was tested for German
# dialogues but did not perform well for diagnoses (ca. 0.33), so this model
# is used instead.
model_name = "MaziyarPanahi/Llama-3-8B-Instruct-v0.8"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # automatic device placement of the weights
    torch_dtype=torch.bfloat16,
)

# Low-temperature sampling; return only the newly generated text, not the prompt.
_generation_settings = dict(
    max_new_tokens=200,
    temperature=0.2,
    do_sample=True,
    return_full_text=False,
)
text_generation_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    **_generation_settings,
)

In [None]:
# Prompt fragments, one entry per extraction target. The four dicts below must
# share exactly the same keys: process_document() iterates q_data and looks the
# same key up in q_prefix, q_postfix and q_format.

# Quantifier that opens the extraction request.
q_prefix = {
    "chiefcomplaint": "all previous or persisting",
    "medication_reported": "every single",
    "medication_pyxis": "every single",
    "diagnosis": "every single assessed, suspected or considered",
    "bp": "every single",
    "heartrate": "every single",
    "o2sat_triage": "every single",
    "resprate": "every single",
    "temperature": "every single",
    "pain_triage": "every single",
}

# What kind of information to extract.
q_data = {
    "chiefcomplaint": "health problems and symptoms",
    "medication_reported": "pharmaceuticals, drugs and medications",
    "medication_pyxis": "pharmaceuticals, drugs, medications",
    "diagnosis": "medical conditions or diagnoses",
    "bp": "'blood pressure' readings values",
    "heartrate": "'heart rate' readings, pulse values",
    "o2sat_triage": "oxygen saturation readings",
    "resprate": "breathing rates values, like x per minute",
    "temperature": "temperature readings",
    "pain_triage": "pain scale",
}

# Qualifier narrowing which mentions count.
# NOTE(review): the "pain_triage" entry ends with a trailing space, which yields
# a double space in the rendered prompt; kept as-is to keep prompts unchanged.
q_postfix = {
    "chiefcomplaint": "that the patient reports having or having had",
    "medication_reported": "that the patient says he has taken or is taking",
    "medication_pyxis": "that the patient is being given to or administered",
    "diagnosis": "or injuries mentioned",
    "bp": "that are in the text snippets",
    "heartrate": "that are in the text snippets",
    "o2sat_triage": "that are in the text snippets",
    "resprate": "that are in the text snippets",
    "temperature": "that are in the text snippets",
    "pain_triage": "levels, values or intensity that someone has ",
}

# Requested answer format with a short example. (An earlier, immediately
# overwritten duplicate of this dict was removed; this is the version that was
# actually in effect.)
q_format = {
    "chiefcomplaint": "as a list, like: nausea; ear pain; headache, alcohol intoxication",
    "medication_reported": "as a list, like: Aspirin 1 tablet; Paracetamol; Antihistamine",
    "medication_pyxis": "as a csv-list of only names/labels, like: Aspirin x mg; Paracetamol; antihistamine",
    "diagnosis": "as a list of diagnoses or symptoms, like: abd pain; hematoma; cardiac arrest",
    "bp": "as a list of numerical values, like: 169 over 90; 169/98",
    "heartrate": "as a csv-list of numerical values, like: 90; 100; 100",
    "o2sat_triage": "as a list of numerical values, like: 90%; 100; 100%",
    "resprate": "as a list of numerical values, like: 10; 20; 20",
    "temperature": "as a list of numerical values, like: 98; 100; 37",
    "pain_triage": "as a list of numerical values, like: 11; 22; 0"
}

# Fail fast on a missing/extra key instead of hitting a KeyError mid-run,
# after the models have been loaded and some documents already processed.
if not (q_prefix.keys() == q_data.keys() == q_postfix.keys() == q_format.keys()):
    raise ValueError(
        "q_prefix, q_data, q_postfix and q_format must define the same keys"
    )

# Prompt sent to the LLM; {context} receives the retrieved dialogue chunks and
# the q_* placeholders receive the fragments above for one target key.
prompt_template = """
Context information is below.
---------------------
{context}
---------------------
Given the context information and not prior knowledge answer always concisely.
Please extract {q_prefix} {q_data} {q_postfix} {q_format}!
ANSWER:
"""

In [None]:
# Embedding model used to index each dialogue's chunks.
# NOTE(review): HuggingFaceBgeEmbeddings is presumably aimed at BGE models (it may
# add a BGE-style query instruction); verify it is the intended wrapper for this
# e5-instruct model.
EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-large-instruct"

# Literal markers after which LLM output is discarded; each is applied in order
# and the text is cut at its first occurrence. They are matched as plain text,
# not as regexes, so "(no value mentioned)" really matches its parentheses.
# NOTE(review): "nprint" may have been intended as "\nprint"; kept as-is.
_ANSWER_STOP_MARKERS = (
    "ANSWER", "\n\n\n", "! ! ", "Final Answer:", "(Note:",
    "nprint", "python", "''''", "Answer:", "(no value mentioned)",
    "END OF ANSWER", "The solution is:",
)

# Three repeats of the same symbol (optionally space-separated), e.g. "!!!" or
# "; ; ;". Word characters are excluded so numeric readings such as "111/70" or
# "1000" are not collapsed into "1/70" / "10".
_REPEATED_SYMBOL_RE = re.compile(r"([^\w\s])\s*\1\s*\1")

# model name -> embeddings object, so the model is loaded once per session
# instead of once per document.
_EMBEDDER_CACHE = {}


def _get_embedder(model_name):
    """Return the embeddings object for *model_name*, creating it on first use."""
    embedder = _EMBEDDER_CACHE.get(model_name)
    if embedder is None:
        embedder = HuggingFaceBgeEmbeddings(model_name=model_name)
        _EMBEDDER_CACHE[model_name] = embedder
    return embedder


def _clean_answer(answer):
    """Cut LLM output at the first stop marker and collapse tripled symbols."""
    for marker in _ANSWER_STOP_MARKERS:
        answer = answer.split(marker, 1)[0].strip()
    return _REPEATED_SYMBOL_RE.sub(r"\1", answer)


def _stringify_ids(json_data_used):
    """Return a copy of *json_data_used* with each sub-dict's "_id" made a str.

    Non-dict input is returned unchanged; non-dict values are left as they are
    (previously a string value containing "_id" raised TypeError). The caller's
    document is not mutated.
    """
    if not isinstance(json_data_used, dict):
        return json_data_used
    return {
        key: {**value, "_id": str(value["_id"])}
        if isinstance(value, dict) and "_id" in value
        else value
        for key, value in json_data_used.items()
    }


def process_document(document, index):
    """Run RAG extraction for every target field of one dialogue document.

    The dialogue is split into overlapping chunks and indexed in a FAISS vector
    store. For each key of ``q_data``, the 12 most similar chunks are retrieved
    and sent to the LLM with that key's prompt. The cleaned answers are written
    to ``r_{index}.json`` in the current working directory.

    Args:
        document: MongoDB document; reads "dialogue", "_id" and, optionally,
            "json_data_used".
        index: Number used in the output file name.

    Returns:
        The dict that was written: "_id" (as str), "answers" (field key ->
        cleaned answer) and "json_data_used" (with nested "_id"s as str).
    """
    dialogue = document.get('dialogue', "")
    # Split on single spaces so chunk boundaries fall between words.
    text_splitter = CharacterTextSplitter(chunk_size=200, chunk_overlap=40, separator=" ")
    chunks = text_splitter.split_text(dialogue)

    vector_store = FAISS.from_texts(chunks, _get_embedder(EMBEDDING_MODEL_NAME))
    # The retriever does not depend on the question, so build it once.
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 12})

    answers = {}
    for key in q_data:
        question = f"{q_prefix[key]} {q_data[key]} {q_postfix[key]} {q_format[key]}"
        docs = retriever.get_relevant_documents(question)
        context_text = " ".join(doc.page_content for doc in docs)
        print('Context-Text from RETRIEVER--------------------')
        print(q_data[key])
        print()
        print(context_text)
        print()
        print('--------------------Context-Text from RETRIEVER')

        prompt = prompt_template.format(
            context=context_text,
            q_prefix=q_prefix[key],
            q_data=q_data[key],
            q_postfix=q_postfix[key],
            q_format=q_format[key]
        )

        result = text_generation_pipeline(prompt)
        answer = result[0]['generated_text'] if isinstance(result, list) else result
        answer = _clean_answer(answer)

        print('Answer******************************************')
        print(answer)
        print(':******************************************Answer')
        print()
        answers[key] = answer

    output = {
        "_id": str(document["_id"]),
        "answers": answers,
        "json_data_used": _stringify_ids(document.get("json_data_used", {})),
    }
    with open(f"r_{index}.json", 'w', encoding="utf-8") as f:
        json.dump(output, f, indent=4)
    return output

# Set the starting index for document processing. Indices are 1-based:
# start_index=1 is the first document returned by collection.find(), and the
# index is also used in the output file name r_{index}.json.
start_index = 102  # Change this to your specific starting index

if start_index < 1:
    raise ValueError(f"start_index must be >= 1, got {start_index}")

# Loop through each document starting from the start_index
try:
    # Every document needs several LLM generations. With the default batch size
    # a whole batch is processed between server round-trips, long enough for the
    # server to reap the idle cursor; the resulting CursorNotFound is raised by
    # the for-loop itself (outside the inner try) and would end the run.
    # batch_size(1) fetches per document, keeping the cursor active.
    # NOTE(review): find() has no sort, so resuming via skip relies on the
    # collection's natural order staying stable between runs.
    cursor = collection.find().skip(start_index - 1).batch_size(1)
    for i, document in enumerate(cursor, start_index):
        try:
            process_document(document, i)
        except Exception as e:
            # Best effort: report and continue with the next document.
            print(f"Error processing document {i}: {e}")
finally:
    client.close()

print(f"Processed documents starting from index {start_index}")