In [10]:
import warnings
warnings.filterwarnings("ignore")

import os
import re
import glob
import json
import torch
import numpy as np
from tqdm import tqdm
import pandas as pd
from datasets import Dataset

from PIL import Image
from transformers import AutoModel, AutoTokenizer,BitsAndBytesConfig

# Dutch medical records

In [None]:
# Directory that is searched (non-recursively, via a '*.txt' glob further down)
# for the input medical-record text files.
directory_path = './data'  # Update this path to where the .txt files live

In [11]:
# Load the prompt definitions. The encoding is given explicitly because the
# file holds Dutch text with non-ASCII characters (e.g. "patiënt"); without it
# open() falls back to the platform's locale encoding, which is not UTF-8 on
# every system.
with open('prompts_version1.json', 'r', encoding='utf-8') as f:
    json_prompts = json.load(f)

# Element 0: dict mapping each question (later used as an output column name)
# to the extraction instruction that is sent to the model.
prompts = json_prompts[0]
print("Prompts:", prompts)
# Element 1: the system prompt that is placed in front of every instruction.
system = json_prompts[1]
print("System prompt:", system)

Prompts: {'Wordt er een diagnose of worden er meerdere diagnoses vastgesteld en zo ja, welke waren dit?': "Beschrijf alle diagnoses die worden vastgesteld en alles diagnoses waar een verdenking van is. Wanneer dit genoemd wordt, geef aan of het om de linker of de rechterhand gaat. Indien er geen informatie beschikbaar is om deze vraag te beantwoorden, schrijf 'None'. ", 'Had deze patiënt complicaties van behandelingen en zo ja, welke waren dit?': "Beschrijf alle klachten of complicaties die deze patiënt heeft ervaren tijdens of na de behandelingen. Geef aan welke complicaties, klachten of symptomen in het dossier worden genoemd, zoals pijn, ongemak of andere bijwerkingen. Indien er geen informatie beschikbaar is om deze vraag te beantwoorden, schrijf 'None'. Indien aanvullende informatie nodig is, mag deze kort worden gespecificeerd.", 'Is er pijnmedicatie voorgeschreven? Zo ja, welke?': "Geef alle voorgeschreven pijnmedicatie voor deze patiënt. Voeg geen extra informatie toe. Als er g

In [12]:
# Template for one extraction request. It is filled with str.format() using the
# named fields `system`, `prompt` and `medical_record`; the literal text
# "Medisch dossier:" (Dutch for "Medical record:") introduces the record.
llama_prompt = """{system}

{prompt}

Medisch dossier:

{medical_record}

"""

In [13]:
def generate(model, tokenizer, llama_prompt, system, prompt, medical_record):
    """Send one extraction request for one medical record and return the reply.

    Args:
        model: Object exposing a ``chat(image=, msgs=, tokenizer=, ...)`` method.
        tokenizer: Tokenizer handed through to ``model.chat``.
        llama_prompt: Template string with ``{system}``, ``{prompt}`` and
            ``{medical_record}`` fields.
        system: Text substituted for ``{system}``.
        prompt: Text substituted for ``{prompt}``.
        medical_record: Text substituted for ``{medical_record}``.

    Returns:
        Whatever ``model.chat`` returns for the request.
    """
    request_text = llama_prompt.format(
        system=system, prompt=prompt, medical_record=medical_record
    )

    # All-zero (black) 224x224 RGB placeholder. NOTE(review): presumably needed
    # because the multimodal chat interface expects an image even for
    # text-only requests -- confirm against the model documentation.
    placeholder_image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

    return model.chat(
        image=placeholder_image,
        msgs=[{'role': 'user', 'content': [placeholder_image, request_text]}],
        tokenizer=tokenizer,
        sampling=True,
        temperature=0.95,
        stream=False,
    )

In [14]:
# Fixed prompt parts for the Dutch-to-English translation step.
# `system_translate` fills the {system_translate} field of the translation
# template and `prompt_translate` fills its {prompt_translate} field.
system_translate = "You are a skilled language translator specializing in Dutch-to-English translation. Your translations should be accurate, preserving the original tone, context, and subtle nuances of the Dutch text. Provide clear and fluent English translations that convey the intended meaning effectively."
prompt_translate = "Translate the following Dutch text to English and do not add any additional information:"
# Empty placeholder; the text actually translated is passed per call to
# translate(). NOTE(review): not read by any code visible in this notebook.
text_to_translate = ""

In [15]:
# Template for one translation request. It is filled with str.format() using
# the named fields `system_translate`, `prompt_translate` and
# `text_to_translate`.
llama_prompt_translate = """{system_translate}

{prompt_translate}

{text_to_translate}

"""

In [16]:
def translate(model, tokenizer, system_translate, prompt_translate, text_to_translate, llama_prompt_translate):
    """Send one translation request and return the model's reply.

    Args:
        model: Object exposing a ``chat(image=, msgs=, tokenizer=, ...)`` method.
        tokenizer: Tokenizer handed through to ``model.chat``.
        system_translate: Text substituted for ``{system_translate}``.
        prompt_translate: Text substituted for ``{prompt_translate}``.
        text_to_translate: Text substituted for ``{text_to_translate}``.
        llama_prompt_translate: Template string containing the three fields
            named above.

    Returns:
        Whatever ``model.chat`` returns for the request.
    """
    request_text = llama_prompt_translate.format(
        system_translate=system_translate,
        prompt_translate=prompt_translate,
        text_to_translate=text_to_translate,
    )

    # All-zero (black) 224x224 RGB placeholder. NOTE(review): presumably needed
    # because the multimodal chat interface expects an image even for
    # text-only requests -- confirm against the model documentation.
    placeholder_image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

    conversation = [{'role': 'user', 'content': [placeholder_image, request_text]}]

    return model.chat(
        image=placeholder_image,
        msgs=conversation,
        tokenizer=tokenizer,
        sampling=True,
        temperature=0.95,
        stream=False,
    )


In [None]:
# Collect all .txt file paths in the specified directory. glob returns paths in
# arbitrary, OS-dependent order, so the list is sorted to make the row order of
# the resulting DataFrame reproducible between runs.
file_list = sorted(glob.glob(os.path.join(directory_path, '*.txt')))

# Dictionary to hold groups of files by their main identifier
file_groups = {}

# Regex pattern to extract the main identifier and the optional " (n)" suffix
pattern = re.compile(r"(\d+)_dossieruitdraai_anonymized(?: \((\d+)\))?")

# Group files by main identifier, remembering each file's sequence number
for file in file_list:
    # Match on the file name only, so directory names can never be mistaken
    # for the identifier.
    match = pattern.search(os.path.basename(file))
    if match:
        main_id = match.group(1)  # Digits before "_dossieruitdraai_anonymized"
        # A file without a " (n)" suffix gets 0 so that it sorts before
        # "(1)", "(2)", ... instead of tying with "(1)".
        seq_num = int(match.group(2)) if match.group(2) else 0
        file_groups.setdefault(main_id, []).append((seq_num, file))

# Lists that become the DataFrame columns
main_ids = []
concatenated_texts = []
file_counts = []

# Process each group
for main_id, files in file_groups.items():
    # Order by sequence number; the path breaks any remaining tie deterministically
    sorted_files = sorted(files)

    # Collect the individual records of all files of this identifier, in order
    text_list = []
    for _, file in sorted_files:
        # utf-8-sig drops a leading byte-order mark if the file has one
        with open(file, 'r', encoding='utf-8-sig') as f:
            text = f.read()

        # Split before every "Naam:"; the zero-width lookahead keeps "Naam:"
        # at the start of the record it introduces.
        for record in re.split(r'(?=Naam:)', text):
            record = record.strip()
            if record:  # Skip empty pieces (e.g. text before the first "Naam:")
                text_list.append(record)

    main_ids.append(main_id)
    concatenated_texts.append(text_list)
    # This is the number of records (not files) found for the identifier.
    # The column keeps the name 'File_Count' because later cells drop it by name.
    file_counts.append(len(text_list))

# One row per document ID; 'Medical_record' holds the list of its records
df = pd.DataFrame({'Document_ID': main_ids, 'Medical_record': concatenated_texts, 'File_Count': file_counts})
# Preview only: avoid rendering every full medical record into the notebook
df.head()


# Information extraction

In [18]:
# Hugging Face repository used for both the model weights and the tokenizer
model_id = "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1"

# 4-bit NF4 quantization with double quantization; computations run in float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# trust_remote_code=True executes Python code shipped with the model
# repository, so only use it with a repository you trust.
model = AutoModel.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

Loading checkpoint shards: 100%|██████████| 4/4 [03:10<00:00, 47.73s/it]


In [19]:
# Extract information from each record and translate every extracted answer
def information_extraction_translation(list_of_med_records, model, tokenizer, system, prompt, llama_prompt, system_translate, prompt_translate, llama_prompt_translate):
    """Run one extraction prompt over a list of records and translate the answers.

    Args:
        list_of_med_records: Iterable of medical-record texts.
        model: Model object passed through to ``generate`` and ``translate``.
        tokenizer: Tokenizer passed through to ``generate`` and ``translate``.
        system: System prompt for the extraction request.
        prompt: Extraction instruction for the extraction request.
        llama_prompt: Template for the extraction request.
        system_translate: System prompt for the translation request.
        prompt_translate: Instruction for the translation request.
        llama_prompt_translate: Template for the translation request.

    Returns:
        dict with two lists of equal length, in input order:
        ``'retrieved'`` holds the ``generate`` result per record and
        ``'translated'`` holds the ``translate`` result for each of those.
        Both lists are empty when ``list_of_med_records`` is empty.
    """
    list_of_retrieved_information = []
    list_of_translated_information = []
    for medical_record in list_of_med_records:
        # Keyword arguments on purpose: generate() expects the template
        # *before* system/prompt/record, so passing the values positionally in
        # this function's own parameter order hands it the system prompt as
        # the template and never inserts the record.
        pred = generate(
            model,
            tokenizer,
            llama_prompt=llama_prompt,
            system=system,
            prompt=prompt,
            medical_record=medical_record,
        )
        list_of_retrieved_information.append(pred)

        translated = translate(
            model,
            tokenizer,
            system_translate=system_translate,
            prompt_translate=prompt_translate,
            text_to_translate=pred,
            llama_prompt_translate=llama_prompt_translate,
        )
        list_of_translated_information.append(translated)

    return {
        'retrieved': list_of_retrieved_information,
        'translated': list_of_translated_information,
    }

In [20]:
# For every question, add a column named after the question. Each cell holds
# the dict returned by information_extraction_translation for that document's
# list of records.
for column_name, prompt in tqdm(prompts.items()):
    df[column_name] = df['Medical_record'].apply(
        lambda records: information_extraction_translation(
            list_of_med_records=records,
            model=model,
            tokenizer=tokenizer,
            system=system,
            prompt=prompt,
            llama_prompt=llama_prompt,
            system_translate=system_translate,
            prompt_translate=prompt_translate,
            llama_prompt_translate=llama_prompt_translate,
        )
    )

100%|██████████| 15/15 [32:05<00:00, 128.34s/it]


In [21]:
# Convert the result table (records plus one dict column per question) into a
# Hugging Face Dataset.
hf_dataset = Dataset.from_pandas(df)

# Write the dataset to the directory "Dutch_extracted_translated_information"
hf_dataset.save_to_disk("Dutch_extracted_translated_information")

Saving the dataset (1/1 shards): 100%|██████████| 2/2 [00:00<00:00, 15.98 examples/s]


In [22]:
# Build the Dutch table: one row per individual record, holding the model's
# untranslated answer for every question.
df_dutch = df.copy().drop(columns=['File_Count'])

# Every column other than the two identifying ones is a question column.
# Selecting by name keeps this correct even if the column order changes.
prompt_columns = [col for col in df_dutch.columns if col not in ('Document_ID', 'Medical_record')]

# Keep only the 'retrieved' list from each question column's result dict
for col in prompt_columns:
    df_dutch[col] = df_dutch[col].apply(lambda x: x['retrieved'] if isinstance(x, dict) else [])

# Expand `Medical_record` and every 'retrieved' list together so that each
# record, with its answers, gets its own row
df_dutch = df_dutch.explode(['Medical_record'] + prompt_columns).reset_index(drop=True)

df_dutch.to_excel('Dutch_extracted_information.xlsx', index=False)

# Preview to verify the structure; kept as the cell's last expression because
# the notebook only renders the final expression of a cell
df_dutch.head()

In [23]:
# Build the English table: one row per individual record, holding the
# translated answer for every question.
df_translated = df.copy().drop(columns=['File_Count'])

# Every column other than the two identifying ones is a question column.
# Selecting by name keeps this correct even if the column order changes.
prompt_columns = [col for col in df_translated.columns if col not in ('Document_ID', 'Medical_record')]

# Keep only the 'translated' list from each question column's result dict
for col in prompt_columns:
    df_translated[col] = df_translated[col].apply(lambda x: x['translated'] if isinstance(x, dict) else [])

# Expand `Medical_record` and every 'translated' list together so that each
# record, with its translated answers, gets its own row
df_translated = df_translated.explode(['Medical_record'] + prompt_columns).reset_index(drop=True)

df_translated.to_excel('Dutch_extracted_translated_information.xlsx', index=False)

# Preview to verify the structure; kept as the cell's last expression because
# the notebook only renders the final expression of a cell
df_translated.head()