In [1]:
from datasets import load_dataset, Dataset, DatasetDict
from transformers import LlamaTokenizer
from utils import create_llama2_chat_prompt, save_dataset_to_json, count_tokens, create_llama2_instruction_prompt

from pprint import pprint
import os
import pandas as pd
import json
import openai
import re
from langchain.chains import LLMChain
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.callbacks import get_openai_callback
import asyncio

# Project root is the parent of the current working directory;
# all datasets are read from / written to <root>/data.
MAIN_DIR = os.path.split(os.getcwd())[0]
DATA_DIR = os.path.join(MAIN_DIR, "data")

  from .autonotebook import tqdm as notebook_tqdm


In [254]:
# Read API credentials from <MAIN_DIR>/auth/api_keys.json (kept out of the notebook itself).
api_keys_path = os.path.join(MAIN_DIR, "auth", "api_keys.json")
with open(api_keys_path, "r") as f:
    api_keys = json.load(f)

# Expose the key both to the `openai` module and through the environment.
openai_api_key = api_keys["OPENAI_API_KEY"]
openai.api_key = openai_api_key
os.environ["OPENAI_API_KEY"] = openai_api_key

In [469]:
# Llama-2 chat tokenizer; downloaded files are cached under <MAIN_DIR>/model.
tokenizer_cache_dir = os.path.join(MAIN_DIR, "model")
tokenizer = LlamaTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    cache_dir=tokenizer_cache_dir,
)

In [None]:
def replace_string_by_index(string:str, repl: str, start_idx, end_idx):
    """Return ``string`` with the slice ``[start_idx:end_idx]`` swapped for ``repl``.

    Args:
        string: Text to edit.
        repl: Replacement text spliced in where the slice was removed.
        start_idx: Index of the first character to drop (slice semantics).
        end_idx: Index one past the last character to drop (slice semantics).

    Returns:
        A new string; ``string`` itself is left untouched.
    """
    # Everything before the slice, then the replacement, then everything after it.
    return string[:start_idx] + repl + string[end_idx:]

# Prep Guanaco Chat Dataset

In [82]:
# Fetch the OpenAssistant-Guanaco dataset from the Hugging Face Hub.
dataset_name = "timdettmers/openassistant-guanaco"
datasets = load_dataset(path=dataset_name)

Repo card metadata block was not found. Setting CardData to empty.


In [154]:
# Convert every Guanaco conversation into a Llama-2 chat prompt, one Dataset per split.
chat_dataset = DatasetDict()

for split in datasets:
    chat_queries = []

    for query in datasets[split]["text"]:
        # Splitting on the role markers yields the individual turns; the first
        # piece is whatever precedes the first marker and is dropped.
        message_list = re.split("### Human: |### Assistant: ", query)[1:]
        # Use the notebook-level `tokenizer` (the previous name `hf_tokenizer` is
        # never defined) and keyword arguments, matching the other
        # create_llama2_chat_prompt calls in this notebook.
        chat_query = create_llama2_chat_prompt(
            messages=message_list, tokenizer=tokenizer
        )
        chat_queries.append(chat_query)

    chat_dataset[split] = Dataset.from_dict({"text": chat_queries})

In [156]:
# Persist the converted Guanaco chat prompts under <DATA_DIR>/guanaco/chat.
guanaco_chat_dir = os.path.join(DATA_DIR, "guanaco", "chat")
save_dataset_to_json(chat_dataset, guanaco_chat_dir)

Creating json from Arrow format: 100%|██████████| 10/10 [00:00<00:00, 15.10ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 26.30ba/s]


# Generate Dataset

## Replace Tokens and Generate Ground Truths

In [443]:
# Load the masked clinical documents; the `masked_text` column holds the
# summaries whose placeholders will be filled with synthetic entities.
sdr_csv_path = os.path.join(DATA_DIR, "sdr", "Clindoc_masked_with_classes.csv")
sdr_df = pd.read_csv(sdr_csv_path)
masked_texts = sdr_df["masked_text"]

In [250]:
# Prompt used to fill masked placeholders with fictional entities.
# The system message describes the task, the six placeholder categories and
# three worked examples. The expected answer is a list of single-key dicts
# mapping a placeholder to the generated entity.
# Doubled braces ({{ }}) keep the JSON examples literal when the template is
# formatted; {query} in the user message is the only real template variable.
# NOTE(review): `system_prompt` / `user_prompt` are reassigned by later cells,
# so the value of these names depends on which cell ran last.
system_prompt = """You are a dataset generator of a diabetes dataset in Singapore.
You always keep people identifiers confidential and only generate fictional entities.
=====
TASK:
You are given a patient summary text. However, some of the texts have been masked. The masked texts are annotated with inside <category>.
Your task is to replace the <category> with a FICTIONAL entity based on the context of the summary to generate a synthetic dataset.
The descriptions for each type of masked token is as follow:

<Person Name>: Name of a person/human. Can be full name, short form, with or without salutation (Dr, Mr, Ms) 
<NRIC/Passport>: National Registration Identity Card (NRIC), Foreign Identity Numbers (FIN) and Passport Number
<Medical Clinical Records>: Medical record numbers
<Phone Number>: Telephone numbers
<Email Address>: Electronic mail addresses
<Home Address>: Patient home address (unit, level, block, street, full 6-digit postal codes).

NOTE:
- Only generate imaginary/fictional names (NOT ACTUAL doctors, phone numbers, email addresses, etc)
- Filled information (names, IDs, phone, emails) must be in Singapore contexts 
- Try to generate a combination of both full names and short names.
- Maintain the spaces and new lines of the original text.
- If no token is present, return an empty list.
- If there are duplicates, include all of them in the list
=====
Return a list which contains the token and the entity generated.
=====
EXAMPLES:
MASKED PATIENT SUMMARY: <Person Name> <NRIC/Passport> 65/Malay /Male ADL Independent Community-Ambulant.
ANSWER: [{{"<Person Name>": "Jonathan Lee"}}, {{"<NRIC/Passport>": "S1234567D"}}]

MASKED PATIENT SUMMARY: <Person Name> 52 year old Malay Female NKDA ADL-independent, community ambulant with wheelchair on follow-up SGH Renal <Person Name>
ANSWER: [{{"<Person Name>": "Steven Lee"}}, {{"<Person Name>": "Steven Lee"}}]

MASKED PATIENT SUMMARY: 51/Chi/M Premorbidly ADLi, comm amb without aid PAST MEDICAL HISTORY RENAL HISTORY ESRF secondary to chronic GN s/p DDRT
ANSWER: []
"""

# Human message: carries the masked summary to be filled.
user_prompt = """
=====
MASKED PATIENT SUMMARY: {query}
"""

# Two-message chat template (system instructions + masked summary) consumed by
# generate_result_class.
FILL_PROMPT = ChatPromptTemplate.from_messages(
    messages = [
        SystemMessagePromptTemplate.from_template(system_prompt),
        HumanMessagePromptTemplate.from_template(user_prompt)
    ]
)

In [372]:
# Fallback prompt: used when the first-pass answer does not contain the same
# number of entities per category as the masked text. Instead of the text, the
# model only receives the required count for each category and must return
# exactly that many fictional entities.
#
# Fixes relative to the previous version:
# - the second example declared 1 name / 1 NRIC but answered with 3 names and
#   no NRIC, contradicting the count-matching behaviour this prompt exists to
#   enforce; the declared counts now match the answer (3 / 0);
# - removed the duplicated "EXAMPLES:" header and the doubled article ("a a list").
#
# Doubled braces ({{ }}) keep the JSON examples literal when the template is
# formatted; the {..._no} fields in the user message are the template variables.
system_prompt = """You are a dataset generator of a diabetes dataset in Singapore.
You always keep people identifiers confidential and only generate fictional entities.
=====
TASK:
You are given a list of token types and the corresponding number of required generated tokens.
Your task is to generate FICTIONAL entities of the corresponding <category> tokens to fill/replace the token inside a diabetic patient summary.
The descriptions for each type of masked token is as follow:

<Person Name>: Name of a person/human. Can be full name, short form, initials, with or without salutation (Dr, Mr, Ms) 
<NRIC/Passport>: National Registration Identity Card (NRIC), Foreign Identity Numbers (FIN) and Passport Number
<Medical Clinical Records>: Medical record numbers
<Phone Number>: Telephone numbers
<Email Address>: Electronic mail addresses
<Home Address>: Patient home address (unit, level, block, street, full 6-digit postal codes).

NOTE:
- Only generate imaginary/fictional names (NOT ACTUAL doctors, phone numbers, email addresses, etc)
- Filled information (names, IDs, phone, emails) must be in Singapore contexts 
- Try to generate a combination of both full names and short names.
- Maintain the spaces and new lines of the original text.
- If no token is present, return an empty list.
- If there are duplicates, include all of them in the list
=====
Return a list which contains the token and the entity generated.
=====
EXAMPLES:
NUMBER OF TOKENS: 
- <Person Name>: 1
- <NRIC/Passport>: 1
- <Medical Clinical Records>: 0
- <Phone Number>: 0
- <Email Address>: 0
- <Home Address>: 0
ANSWER: [{{"<Person Name>": "Jonathan Lee"}}, {{"<NRIC/Passport>": "S1234567D"}}]

NUMBER OF TOKENS: 
- <Person Name>: 3
- <NRIC/Passport>: 0
- <Medical Clinical Records>: 0
- <Phone Number>: 0
- <Email Address>: 0
- <Home Address>: 0
ANSWER: [{{"<Person Name>": "Lee YM"}}, {{"<Person Name>": "Dr Tan MJ"}}, {{"<Person Name>": "Dr Lee Si Kiat"}}]

NUMBER OF TOKENS: 
- <Person Name>: 0
- <NRIC/Passport>: 0
- <Medical Clinical Records>: 0
- <Phone Number>: 0
- <Email Address>: 0
- <Home Address>: 0
ANSWER: []
"""

# Human message: the per-category counts required for one masked summary.
user_prompt = """
=====
NUMBER OF TOKENS: 
- <Person Name>: {name_no}
- <NRIC/Passport>: {nric_no}
- <Medical Clinical Records>: {mcr_no}
- <Phone Number>: {phone_no}
- <Email Address>: {email_no}
- <Home Address>: {address_no}
ANSWER:
"""

FALLBACK_PROMPT = ChatPromptTemplate.from_messages(
    messages = [
        SystemMessagePromptTemplate.from_template(system_prompt),
        HumanMessagePromptTemplate.from_template(user_prompt)
    ]
)

In [255]:
async def generate_result_class(text):
    """Ask GPT-4 to fill the masked placeholders of one patient summary.

    Args:
        text: Masked patient summary passed to the FILL_PROMPT chain.

    Returns:
        Tuple ``(answer_text, cost)``: the raw model answer and the total cost
        reported by the OpenAI callback for this call.
    """
    chat_model = ChatOpenAI(model = "gpt-4-1106-preview", temperature = 1.0, max_tokens = 512)
    fill_chain = LLMChain(llm=chat_model, prompt=FILL_PROMPT)
    # The callback context accumulates the usage of the call made inside it.
    with get_openai_callback() as usage:
        response = await fill_chain.acall(text)
    answer_text = response["text"]
    return answer_text, usage.total_cost

In [256]:
results = []

for text in masked_texts:
    result = generate_result_class(text)
    results.append(result)
    
results = await asyncio.gather(*results)
ground_truths = [result[0] for result in results]
total_cost = sum([result[1] for result in results])

print(total_cost)

3.180749999999999


In [452]:
import ast

import numpy as np

# Fixed position of every placeholder category inside the count vectors.
TOKEN2ID = {
    "<Person Name>": 0,
    "<NRIC/Passport>": 1,
    "<Medical Clinical Records>": 2,
    "<Phone Number>": 3,
    "<Email Address>": 4,
    "<Home Address>": 5
}

# First flat (non-nested) bracketed list inside an LLM answer.
LIST_PATTERN = r"\[[^\[\]]*\]"


def _parse_token_list(response_text):
    """Extract and parse the first bracketed list contained in an LLM answer.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        Tuple ``(list_literal, token_list)``: the matched substring and the
        Python list it evaluates to.

    Raises:
        ValueError: If no bracketed list is present or it is not a list literal.
    """
    matches = re.findall(LIST_PATTERN, response_text)
    if not matches:
        raise ValueError(f"No list found in model response: {response_text!r}")
    list_literal = matches[0]
    # literal_eval only accepts Python literals, unlike eval(), which would
    # execute arbitrary expressions coming back from the model.
    token_list = ast.literal_eval(list_literal)
    if not isinstance(token_list, list):
        raise ValueError(f"Model response is not a list: {list_literal!r}")
    return list_literal, token_list


def _count_by_type(token_list):
    """Count generated entities per category, indexed as in TOKEN2ID."""
    counter = np.zeros((len(TOKEN2ID)), dtype=int)
    for token_dict in token_list:
        for token_type in token_dict:
            counter[TOKEN2ID[token_type]] += 1
    return counter


# The fallback chain is only invoked when the first answer has wrong counts.
llm = ChatOpenAI(model = "gpt-4-1106-preview", temperature = 1.0, max_tokens = 512)
class_fill_chain = LLMChain(llm=llm, prompt=FALLBACK_PROMPT)

replacement_tokens = []

for text_idx, text in enumerate(ground_truths):
    # Number of placeholders of each category present in the masked summary.
    masked_text = masked_texts[text_idx]
    token_counter = np.array(
        [masked_text.count(token) for token in TOKEN2ID], dtype=int
    )

    # Parse the extracted list, not the whole answer: the model may wrap the
    # list in extra text (e.g. an "ANSWER:" prefix).
    ground_truth, token_list = _parse_token_list(text)
    gpt_counter = _count_by_type(token_list)

    if not np.array_equal(gpt_counter, token_counter):
        print(text_idx)
        refined_response = class_fill_chain(
            {
                "name_no": token_counter[0],
                "nric_no": token_counter[1],
                "mcr_no": token_counter[2],
                "phone_no": token_counter[3],
                "email_no": token_counter[4],
                "address_no": token_counter[5]
                }
            )
        # Re-parse the REFINED answer; previously the original answer was
        # evaluated again, so the fallback result was never used and the
        # assertion below could not pass.
        ground_truth, token_list = _parse_token_list(refined_response["text"])
        gpt_counter = _count_by_type(token_list)

        assert np.array_equal(gpt_counter, token_counter), "Not equal"

    replacement_tokens.append(token_list)
    ground_truths[text_idx] = ground_truth

In [None]:
# Fill every placeholder of each masked summary with its generated entity and
# record the character span of each inserted entity in the final text.
gen_texts = []
span_lists = []
# `replacement_tokens` is the per-text entity list built by the validation
# cell (the previous name `entity_list` is never defined in this notebook).
for text_idx, (entities, masked_text) in enumerate(zip(replacement_tokens, masked_texts)):
    gen_text = masked_text
    span_list = []
    for entity_dict in entities:
        for entity_type, entity in entity_dict.items():
            # Plain substring search: placeholders are literal text, not patterns.
            token_start = gen_text.find(entity_type)
            if token_start < 0:
                raise ValueError(
                    f"Placeholder {entity_type!r} not found in text {text_idx}"
                )
            token_end = token_start + len(entity_type)

            # Show the placeholder with 15 characters of context on each side.
            preview_start = max(token_start - 15, 0)
            print(text_idx, gen_text[preview_start:token_end + 15])

            gen_text = replace_string_by_index(gen_text, entity, token_start, token_end)

            # Entities are applied in the model's order, which need not be text
            # order. Spans recorded earlier that sit to the right of this
            # replacement move by the length difference, so shift them to keep
            # them valid for the final text.
            shift = len(entity) - len(entity_type)
            span_list = [
                (start + shift, end + shift) if start >= token_end else (start, end)
                for start, end in span_list
            ]
            span_list.append((token_start, token_start + len(entity)))
    span_lists.append(span_list)
    gen_texts.append(gen_text)

In [455]:
# Persist the generation artefacts under <DATA_DIR>/sdr:
#   gt.json              - ground-truth entity lists
#   gen_annotations.json - character spans of the inserted entities
#   Clindoc_gen.csv      - the filled-in texts
sdr_dir = os.path.join(DATA_DIR, "sdr")

json_outputs = (
    ("gt.json", ground_truths),
    ("gen_annotations.json", span_lists),
)
for file_name, payload in json_outputs:
    with open(os.path.join(sdr_dir, file_name), "w") as f:
        json.dump(payload, f)

gen_df = pd.DataFrame({"text": gen_texts})
gen_df.to_csv(os.path.join(sdr_dir, "Clindoc_gen.csv"), index=False)

## Generate Fine-tuning and Inference Datasets

In [495]:
# Prompts used to build the fine-tuning and evaluation datasets in the cells
# below. The system message lists the six PHI categories the model must
# identify; the user message carries the generated patient summary via {query}.
# NOTE(review): this reassigns `system_prompt` / `user_prompt`, which other
# cells also define with different contents, so run order matters.
system_prompt = """You are a administrator working with diabete patients data.
===
TASK:
Given a patient summary text, identify relevant Protected Health Information (PHI). The PHI can belong to the following categories:

<Person Name>: Name of a person/human. Can be full name, short form, initials, with or without salutation (Dr, Mr, Ms) 
<NRIC/Passport>: National Registration Identity Card (NRIC), Foreign Identity Numbers (FIN) and Passport Number
<Medical Clinical Records>: Medical record numbers
<Phone Number>: Telephone numbers
<Email Address>: Electronic mail addresses
<Home Address>: Patient home address (unit, level, block, street, full 6-digit postal codes).

If no entities is present, return an empty list.
If there are duplicates, include all of them in the list
===
"""

# Human message template; {query} is replaced with one generated summary.
user_prompt = """
PATIENT SUMMARY: {query}
"""

In [496]:
# Build the fine-tuning prompts: every generated text is paired with its
# ground-truth entity list, once in instruction format and once in chat format.
system_text = system_prompt.format()
inst_template = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"

chat_prompts = []
inst_prompts = []

for ground_truth, query in zip(ground_truths, gen_texts):
    user_text = user_prompt.format(query=query)

    inst_prompt = create_llama2_instruction_prompt(
        system_text,
        user_text,
        ground_truth,
        prompt_template=inst_template,
    )
    chat_prompt = create_llama2_chat_prompt(
        system_prompt=system_text,
        messages=[user_text, ground_truth],
        tokenizer=tokenizer,
    )

    chat_prompts.append(chat_prompt)
    inst_prompts.append(inst_prompt)

# One Dataset per prompt format, bundled and written to <DATA_DIR>/sdr/finetune.
inst_dataset = Dataset.from_dict({"text": inst_prompts})
chat_dataset = Dataset.from_dict({"text": chat_prompts})
finetune_dataset = DatasetDict({"chat": chat_dataset, "inst": inst_dataset})

finetune_dir = os.path.join(DATA_DIR, "sdr", "finetune")
save_dataset_to_json(finetune_dataset, finetune_dir)

Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00,  7.39ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 165.86ba/s]


In [497]:
# Longest chat prompt in tokens (0 when there are no prompts).
max_tokens = max(
    [0] + [count_tokens(chat_prompt, tokenizer) for chat_prompt in chat_prompts]
)

print(max_tokens)

2040


In [498]:
# Build the inference prompts: same formats as the fine-tuning set, but without
# the ground-truth answer appended.
system_text = system_prompt.format()
inst_template = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"

eval_chat_prompts = []
eval_inst_prompts = []

for query in gen_texts:
    user_text = user_prompt.format(query=query)

    inst_prompt = create_llama2_instruction_prompt(
        system_text,
        user_text,
        prompt_template=inst_template,
    )
    chat_prompt = create_llama2_chat_prompt(
        system_prompt=system_text,
        messages=[user_text],
        tokenizer=tokenizer,
    )

    eval_chat_prompts.append(chat_prompt)
    eval_inst_prompts.append(inst_prompt)

# One Dataset per prompt format, bundled and written to <DATA_DIR>/sdr/eval.
inst_dataset = Dataset.from_dict({"text": eval_inst_prompts})
chat_dataset = Dataset.from_dict({"text": eval_chat_prompts})
eval_dataset = DatasetDict({"chat": chat_dataset, "inst": inst_dataset})

eval_dir = os.path.join(DATA_DIR, "sdr", "eval")
save_dataset_to_json(eval_dataset, eval_dir)

Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 137.12ba/s]
Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 138.65ba/s]


# NER with GPT-4

In [172]:
from langchain.chains import create_extraction_chain
from langchain.chains.openai_functions.extraction import _get_extraction_function

# Toy extraction schema, used only to inspect the function definition that
# LangChain's extraction helper builds from a schema.
schema = {
    "properties": {
        "person_name": {"type": "string"},
        "person_height": {"type": "integer"},
        "person_hair_color": {"type": "string"},
        "dog_name": {"type": "string"},
        "dog_breed": {"type": "string"},
    },
    "required": [],
}

# Disabled experiment: an extraction chain was never wired up here.
# ner_chain = create_extraction_chain(
#     llm=ChatOpenAI(model="gpt-4-1106-preview", temperature = 0, max_tokens = 512),
#     prompt=None,
#     verbose=True
# )

# NOTE(review): `_get_extraction_function` is a private (underscore-prefixed)
# helper and may change between LangChain versions.
print(_get_extraction_function(schema))

{'name': 'information_extraction', 'description': 'Extracts the relevant information from the passage.', 'parameters': {'type': 'object', 'properties': {'info': {'type': 'array', 'items': {'type': 'object', 'properties': {'person_name': {'title': 'person_name', 'type': 'string'}, 'person_height': {'title': 'person_height', 'type': 'integer'}, 'person_hair_color': {'title': 'person_hair_color', 'type': 'string'}, 'dog_name': {'title': 'dog_name', 'type': 'string'}, 'dog_breed': {'title': 'dog_breed', 'type': 'string'}}, 'required': []}}}, 'required': ['info']}}

In [361]:
# Prompt for GPT-4 based PHI extraction (NER). The system message lists twelve
# numbered entity categories and asks for a list of JSON objects with "entity"
# and "entity_type" (the category number), followed by two worked examples.
# Doubled braces ({{ }}) keep the JSON examples literal when the template is
# formatted; {query} in the user message is the only template variable.
# NOTE(review): this reassigns `system_prompt` / `user_prompt`, which other
# cells also define with different contents, so run order matters.
system_prompt = """You are a doctor who treats diabetes patients in Singapore.
One of your job scope is to write patient summary which contains information related to patient treatments.
=====
TASK:
You are given a patient summary text. You need to identify relevant Protected Health Information (PHI) for anonymisation pipeline.
Your task is to perform Name Entity Recognition on the given text and extract from the given text ALL relevant entities which belong to the following categories:

1. People name (Full name and Short Form - e.g. Mr Kim, Dr Lee)
2. Telephone and Fax numbers
3. Electronic mail addresses
4. National Registration Identity Card (NRIC), Foreign Identity Numbers (FIN) and Passport Number
5. Medical record numbers
6. Account numbers
7. Certificate/license numbers
8. Vehicle identifiers & license plate numbers
9. Device identifiers
10. Web Universal Resource Locators (URLs) and Internet Protocol (IP) address numbers
11. Patient home address (unit, level, block, street, full postal codes). Exclude if the initial four digits of a postal code or hospital addresses.
12. Date of birth (DOB) ONLY. Do not extract dates of clinical meetings, operations, etc.

=====
OUTPUT INSTRUCTIONS:
Your output should be a list of JSON objects which contains "entity" and "entity_type". Entity type should be an integer corresponding the class number listed above.
=====
EXAMPLES:
PATIENT SUMMARY: Madam Tan 73 year old Chinese female DA: Erythromycin Stays with husband, son and helper Baseline occasionally able to shake and nod in response to questions according to son Past Medical History 1. End stage renal failure secondary to diabetic nephropathy - on HD 1,3,5 via left BC AVF 2. Diabetes mellitus - last HbA1c 6.2% (December 2018) 3. Hypertension 4. Hyperlipidaemia 5. Steal syndrome - status post L BC AVF creation (20/2/12) by Prof Lim Boon Leng and team
OUTPUT JSON:
[{{"entity": "Madam Tan", "entity_type": 1}},{{"entity": "Prof Lim Boon Leng", "entity_type": 1}}]

PATIENT SUMMARY: 75 year old Chinese female ADLs assisted, WC bound lives with domestic helper and family NKDA === PMHX === 1. DM 2. HTN 3. Hyperlipidemia 4. IHD/TVD - s/p PCI (Mar 2014) DES prox-mid LAD - Last 2DE (Mar 2014): EF 57% mild-mod MR - MIBI (May 2014): mild ischaemia in basal infero-lateral wall of LV 5. ESRF secondary to DM on HD 2/4/6 via R BC AVF 6. R eye blindness 7. Previous ICH 8. Previous hx of biliary colic
OUTPUT JSON:
[] 
"""

# Human message: the patient summary to analyse.
user_prompt = """=====
PATIENT SUMMARY: {query}
"""

# Candidate instruction that is currently NOT part of the prompt:
# Note: Keep the entities in the exact form as the original texts (no rephrasing, correction, or change in capital letters, etc)


# Two-message chat template (system instructions + patient summary).
EXTRACT_PROMPT = ChatPromptTemplate.from_messages(
    messages = [
        SystemMessagePromptTemplate.from_template(system_prompt),
        HumanMessagePromptTemplate.from_template(user_prompt)
    ]
)

In [341]:
# Chain used for entity extraction: GPT-4 at temperature 0 with the NER prompt.
ner_llm = ChatOpenAI(model = "gpt-4-1106-preview", temperature = 0, max_tokens = 512)
extract_chain = LLMChain(llm=ner_llm, prompt=EXTRACT_PROMPT)

In [349]:
ner_results = []

async def generate_ner_result(text):
    with get_openai_callback() as cb:
        resp = await extract_chain.acall(text)
    return resp["text"], cb.total_cost

for text in unmasked_text:
    result = generate_ner_result(text)
    ner_results.append(result)
    
ner_results = await asyncio.gather(*ner_results)
gpt_preds = [result[0] for result in ner_results]
total_cost = sum([result[1] for result in ner_results])
print("Total Cost: {:.3f}".format(total_cost))


2.966900000000006


In [351]:
# Attach the GPT-4 predictions as column `preds_3` and write the frame to
# <DATA_DIR>/sdr/Clindoc_recovered.csv.
recovered_csv_path = os.path.join(DATA_DIR, "sdr", "Clindoc_recovered.csv")

sdr_df["preds_3"] = gpt_preds
sdr_df.to_csv(recovered_csv_path)

In [366]:
# Render one complete chat prompt (system + user + answer) for manual inspection.
# NOTE(review): `refined_texts` is not defined in the visible part of this
# notebook; confirm which collection of texts it is meant to be.
sample_query = refined_texts[0]
sample_gt_response = gpt_preds[0]

is_chat = True

tokenizer = LlamaTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", cache_dir=os.path.join(MAIN_DIR, "model"),
    )

# NOTE(review): `prompt_str` is only assigned when `is_chat` is True.
if is_chat:
    prompt_str = create_llama2_chat_prompt(
        messages=[
            user_prompt.format(query = sample_query),
            sample_gt_response
            ],
        # The system prompt escapes its JSON examples with doubled braces for the
        # template engine; .format() collapses them to single braces, matching
        # how the dataset-building cells pass `system_prompt.format()`.
        tokenizer=tokenizer, system_prompt=system_prompt.format()
    )

In [367]:
# Show the rendered sample prompt built in the previous cell.
print(prompt_str)

[INST] <<SYS>>
You are a doctor who treats diabetes patients in Singapore.
One of your job scope is to write patient summary which contains information related to patient treatments.
=====
TASK:
You are given a patient summary text. You need to identify relevant Protected Health Information (PHI) for anonymisation pipeline.
Your task is to perform Name Entity Recognition on the given text and extract from the given text ALL relevant entities which belong to the following categories:

1. People name (Full name and Short Form - e.g. Mr Kim, Dr Lee)
2. Telephone and Fax numbers
3. Electronic mail addresses
4. National Registration Identity Card (NRIC), Foreign Identity Numbers (FIN) and Passport Number
5. Medical record numbers
6. Account numbers
7. Certificate/license numbers
8. Vehicle identifiers & license plate numbers
9. Device identifiers
10. Web Universal Resource Locators (URLs) and Internet Protocol (IP) address numbers
11. Patient home address (unit, level, block, street, full p