In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftConfig, PeftModel
from dotenv import load_dotenv
import torch
import os

# Load environment variables (e.g. HF_TOKEN, read in get_model) from a local .env file
load_dotenv()

# DEFINE CONSTANTS
LLAMA3 = "meta-llama/Meta-Llama-3-8B-Instruct"
MISTRAL = "mistralai/Mistral-7B-Instruct-v0.2"
RAG1 = "rag-sources/correct-test-predictions.pdf"      # File containing all correct predictions
RAG2 = "rag-sources/icd10-tabular-list.pdf"            # ICD-10 manual

# Define which setup to run here
model_name = MISTRAL     # LLAMA3, MISTRAL
is_fine_tuned = False    # True, False
rag_source = RAG2       # None, RAG1, RAG2
input_file = "doc/translations.txt" # File to read dataframe from
output_file = "doc/test_results.txt" # File to save updated dataframe to

# Fail fast on a typo in the setup above, before any model weights are loaded
if model_name not in (LLAMA3, MISTRAL):
    raise ValueError(f"Unsupported model_name: {model_name!r}")
if rag_source not in (None, RAG1, RAG2):
    raise ValueError(f"Unsupported rag_source: {rag_source!r}")

def get_model(name: str):
    """Load a causal LM with 4-bit NF4 quantization, together with its tokenizer.

    Side effect: the tokenizer is stored in the module-level ``tokenizer``
    global, which the prediction helpers in this notebook read.

    Args:
        name: Hugging Face Hub model id (or local path) to load.

    Returns:
        The quantized ``AutoModelForCausalLM`` instance, placed with
        ``device_map="auto"``.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Gated repos (e.g. Llama 3) need the token for the tokenizer files as well
    # as for the weights, so pass the same credential to both loads.
    hf_token = os.getenv("HF_TOKEN")

    global tokenizer
    tokenizer = AutoTokenizer.from_pretrained(name, token=hf_token)

    return AutoModelForCausalLM.from_pretrained(
        name,
        quantization_config=bnb_config,
        device_map="auto",
        token=hf_token,
        # Extra kwarg: from_pretrained applies it to the model config
        pad_token_id=tokenizer.eos_token_id,
    )

# Load either the plain base checkpoint, or the base checkpoint recorded in the
# PEFT adapter's config with that adapter attached on top.
if not is_fine_tuned:
    model = get_model(model_name)
else:
    adapter_path = "mistral-ft" if model_name != LLAMA3 else "llama-ft"
    config = PeftConfig.from_pretrained(adapter_path)
    model = PeftModel.from_pretrained(
        get_model(config.base_model_name_or_path),
        adapter_path,
    )

# Reuse EOS as the padding token (suppresses missing-pad-token warnings)
tokenizer.pad_token = tokenizer.eos_token

# Inference only: switch the model out of training mode
model.eval()

In [None]:
if rag_source is not None:
    # Import from langchain_community (the `langchain.document_loaders` path is
    # a deprecated re-export); the embeddings and vector store cell already
    # import from this package.
    from langchain_community.document_loaders import PyPDFLoader

    loader = PyPDFLoader(rag_source)
    pages = loader.load()

In [None]:
if rag_source is not None:
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    # Split the loaded pages into chunks of up to ~400 characters, with no overlap
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=0)
    split_text = text_splitter.split_documents(pages)
    print(f"Split text into {len(split_text)} chunks.")

    # Print an example chunk. A bare expression nested inside an `if` is not
    # rendered as cell output, so print explicitly. Fall back to the last chunk
    # when there are fewer than four, and skip when there are none.
    if split_text:
        print(split_text[min(3, len(split_text) - 1)])

In [None]:
if rag_source is not None:
    from langchain_community.vectorstores import Chroma
    from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings

    # PubMedBERT sentence embeddings (trained on medical text) used to index the chunks
    embedding_function = SentenceTransformerEmbeddings(
        model_name="NeuML/pubmedbert-base-embeddings",
    )

    # In-memory index; add persist_directory="chroma-db" to keep it on disk.
    # NOTE(review): re-running this cell in the same kernel may append duplicate
    # chunks to the default collection - confirm with the installed chromadb version.
    db = Chroma.from_documents(
        documents=split_text,
        embedding=embedding_function,
    )

    # Smoke-test retrieval: show the top 5 hits for a sample query
    query = "Bronchitis"
    search_res = db.similarity_search(query, k=5)
    print(search_res)

In [None]:
import pandas as pd
import time

# Retrieval helper for RAG: fetch matching chunks from the Chroma store
def get_context(cause_of_death: str) -> str:
    """Return the retrieved context text for a cause of death.

    Runs ``similarity_search`` on the global vector store ``db`` with its
    default ``k`` and concatenates the ``page_content`` of every hit, each
    followed by a newline. Returns an empty string when there are no hits.
    """
    hits = db.similarity_search(cause_of_death)
    return "".join(f"{hit.page_content}\n" for hit in hits)

def predict_mistral(prompt: str, cause_of_death: str) -> str:
    """Ask Mistral Instruct for an ICD-10 code using the ``[INST]`` prompt format.

    Args:
        prompt: Instructions. The cause of death is appended to them and, when
            ``rag_source`` is set, so is the retrieved context from ``get_context``.
        cause_of_death: The cause-of-death text to classify.

    Returns:
        The decoded completion (newly generated tokens only) on a single line,
        with surrounding whitespace removed.
    """
    prompt += f"\n\nCause of death: {cause_of_death}"
    if rag_source is not None:
        prompt += f"\n\nContext: {get_context(cause_of_death)}"

    # The template already spells out "<s>", so stop the tokenizer from adding
    # its own BOS as well (the Mistral tokenizer prepends one by default, which
    # would start the sequence with two BOS tokens).
    inputs = tokenizer(
        f"<s>[INST]{prompt}[/INST]",
        return_tensors="pt",
        add_special_tokens=False,
    )
    input_ids = inputs["input_ids"].to(model.device)

    outputs = model.generate(
        input_ids=input_ids,
        # Pass the mask explicitly: pad == eos here, so generate() cannot infer it
        attention_mask=inputs["attention_mask"].to(model.device),
        max_new_tokens=256,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

    # Decode only the newly generated tokens. Splitting the full decoded text on
    # "[/INST]" would break if the prompt or retrieved context contained that marker.
    res = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return res.replace("\n", " ").strip()

def predict_llama(prompt: str, cause_of_death: str) -> str:
    """Ask Llama 3 Instruct for an ICD-10 code using the tokenizer's chat template.

    Args:
        prompt: Instructions, sent as the system message.
        cause_of_death: The cause-of-death text, sent in the user message. When
            ``rag_source`` is set, retrieved context from ``get_context`` is appended.

    Returns:
        The decoded completion (newly generated tokens only) on a single line,
        with surrounding whitespace and any leading ``assistant`` tag removed.
    """
    user_input = f"Cause of death: {cause_of_death}"
    if rag_source is not None:
        user_input += f"\n\nContext: {get_context(cause_of_death)}"

    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": user_input},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Llama 3 ends an assistant turn with <|eot_id|>, so stop on it as well as EOS
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    outputs = model.generate(
        input_ids,
        # Single unpadded sequence, so every position is real. Pass the mask
        # explicitly because pad == eos here and generate() cannot infer it.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens, not the echoed prompt
    res = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()

    # Remove a leading "assistant" role tag. Note that str.strip("assistant")
    # would instead remove any of the characters a/s/i/t/n from BOTH ends,
    # corrupting answers such as "Bronchitis".
    tag = "assistant"
    if res.startswith(tag):
        res = res[len(tag):]
    return res.replace("\n", " ").strip()

def get_df_column() -> str:
    """Build the result-column name for the active setup.

    The name is ``<family>_<variant>`` with an optional ``_rag1``/``_rag2``
    suffix, e.g. ``mistral_base_rag2`` or ``llama3_ft``.
    """
    family = "llama3" if model_name == LLAMA3 else "mistral"
    variant = "ft" if is_fine_tuned == True else "base"
    parts = [family, variant]

    if rag_source == RAG1:
        parts.append("rag1")
    elif rag_source == RAG2:
        parts.append("rag2")

    return "_".join(parts)

# Resolve the run configuration once. An unsupported model_name is a setup
# error, so fail before the loop instead of raising (and swallowing) the same
# exception once per row and then writing an unchanged output file.
result_column = get_df_column()
if model_name == LLAMA3:
    predict = predict_llama
elif model_name == MISTRAL:
    predict = predict_mistral
else:
    raise ValueError("Invalid setup at top of file.")

# Start timer to record runtime
start = time.perf_counter()

print(f"RUNNING CONFIG: {result_column}")
print("--------------------------------")

# Read current data (for a quick smoke test, slice it here, e.g. df = df[:5])
df = pd.read_csv(input_file, sep="\t")

prompt = """You are to assign an ICD-10 code to a cause of death using the following instructions:
- Use standard ICD-10 codes, not ICD-10-CM billing codes.
- Each ICD-10 code should be 3 or 5 characters long, for example: 'X01.0' or 'C15'.
- If the cause of death is 'unknown' or 'blank', use code 'R99'.
- If you lack sufficient information to assign a code, do not try to guess. Instead, use code 'Æ99.9'.
- Your response should only contain a single ICD-10 code using this format: '<ICD-10 CODE>'.
- Do not explain your answer."""

# Make sure the result column exists (object dtype) even if every row fails
if result_column not in df.columns:
    df[result_column] = None

# Iterate through dataset and set predictions
n_rows = len(df)
for index, row in df.iterrows():
    try:
        print(f"Processing index {index} of {n_rows}.....")

        res = predict(prompt, row["eng_translation"])

        # Save prediction for the given model in dataframe
        df.at[index, result_column] = res

        print(f"Prediction: {res}, Gold: {row['icd10']}")

    except Exception as err:
        # Best-effort: report which row failed and continue, so one bad row
        # does not abort the whole run.
        print(f"Index {index} failed: {type(err).__name__}: {err}")

# Save updated data as new file
df.to_csv(output_file, index=False, sep="\t")

end = time.perf_counter()

time_elapsed = f"Time elapsed: {end-start:.2f}"
print(time_elapsed)

# Append compute time to file
with open("timer.txt", "a", encoding="utf-8") as f:
    f.write(f"{result_column} {time_elapsed}\n")

In [None]:
# Display the results table with every column visible (option persists for the session)
pd.options.display.max_columns = None
df