Segment MIMIC Notes

In [1]:
# Import Libraries
from collections import defaultdict
from functools import lru_cache
import re

import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import unidecode

import os


  from .autonotebook import tqdm as notebook_tqdm


Load records!

In [2]:
types = {
        'CHARTDATE': pd.StringDtype(),
        'CHARTTIME': pd.StringDtype(),
        'STORETIME': pd.StringDtype(),
        'CATEGORY': pd.StringDtype(),
        'DESCRIPTION': pd.StringDtype(),
        'ISERROR': pd.StringDtype(),
        'TEXT': pd.StringDtype()
    }

In [3]:
good_categories = {
    'Nursing/other': 11, # 10
    'Radiology': 9,
    'Nursing': 6, # mostly Action, response, plan
    'ECG': 0,
    'Physician ': 10, # 10
    'Discharge summary': 10,
    'Echo': 10,
    'Respiratory ': 10,
    'Nutrition': 9,
    'General': 8,
    'Rehab Services': 9,
    'Social Work': 8, # no good titles
    'Case Management ': 5, # Action, response, plan
    'Pharmacy': 4, # assesment, recommanation
    'Consult': 10,
}

In [4]:
# data path
note_path = "/home/h6x/git_projects/cosc-526-data-engineering-project/data/NOTEEVENTS.csv.gz"

In [5]:
# read the data
cutoff = None
notes = pd.read_csv(note_path, dtype=types, nrows=cutoff)

Data summary

In [6]:
stats = pd.DataFrame({
    "count": notes["CATEGORY"].sort_index(inplace=False).value_counts(),
    "goodness": good_categories
}).sort_values(["goodness", "count"], ascending=[False, False])

In [7]:
stats

Unnamed: 0,count,goodness
Nursing/other,822497,11
Physician,141624,10
Discharge summary,59652,10
Echo,45794,10
Respiratory,31739,10
Consult,98,10
Radiology,522279,9
Nutrition,9418,9
Rehab Services,5431,9
General,8301,8


Convetr CHARTDATE into datetime

In [8]:
notes["CHARTDATE"] = pd.to_datetime(notes["CHARTDATE"])
notes = notes.sort_values(["SUBJECT_ID", "CHARTDATE"])

In [9]:
notes.head(3)

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,CHARTDATE,CHARTTIME,STORETIME,CATEGORY,DESCRIPTION,CGID,ISERROR,TEXT
1671019,1678765,2,163353.0,2138-07-17,2138-07-17 23:08:00,2138-07-17 23:18:00,Nursing/other,Report,17774.0,,Nursing Transfer note Pt admitted to NICU fo...
1671574,1678764,2,163353.0,2138-07-17,2138-07-17 22:51:00,2138-07-17 23:12:00,Nursing/other,Report,16929.0,,Neonatology Attending Triage Note Baby [**Nam...
291220,272794,3,,2101-10-06,,,ECG,Report,,,Sinus rhythm Inferior/lateral ST-T changes are...


Filter the notes by goodness : select only Nursing/other notes(11)

In [10]:
note_relevance = notes["CATEGORY"].isin(stats.query("goodness == 11").index)
notes = notes[note_relevance]

In [11]:
notes.head(3)

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,CHARTDATE,CHARTTIME,STORETIME,CATEGORY,DESCRIPTION,CGID,ISERROR,TEXT
1671019,1678765,2,163353.0,2138-07-17,2138-07-17 23:08:00,2138-07-17 23:18:00,Nursing/other,Report,17774.0,,Nursing Transfer note Pt admitted to NICU fo...
1671574,1678764,2,163353.0,2138-07-17,2138-07-17 22:51:00,2138-07-17 23:12:00,Nursing/other,Report,16929.0,,Neonatology Attending Triage Note Baby [**Nam...
1297688,1260684,3,145834.0,2101-10-21,2101-10-21 06:58:00,2101-10-21 07:15:00,Nursing/other,Report,21570.0,,Micu Progress Nursing Note: Patient arrived i...


In [12]:
if len(notes) == 0:
    raise Exception("Filtering removed all notes")

In [13]:
# Add a new column record_number to the notes DataFrame, which gives each note a sequential number per patient (SUBJECT_ID). So for each patient, their first note will be 0, next is 1, and so on.
notes = notes.groupby('SUBJECT_ID', group_keys=False).apply(lambda group: group.assign(record_number=range(len(group))))

  notes = notes.groupby('SUBJECT_ID', group_keys=False).apply(lambda group: group.assign(record_number=range(len(group))))


Adding index to the notes

In [14]:
notes = notes[["ROW_ID", "SUBJECT_ID", "record_number", "TEXT"]]
notes = notes.rename(columns={"ROW_ID": "rid", "SUBJECT_ID": "pid", "record_number": "rord", "TEXT": "text"})
notes = notes.set_index(["rid", "pid", "rord"])

In [15]:
notes.head(3)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,text
rid,pid,rord,Unnamed: 3_level_1
1678765,2,0,Nursing Transfer note Pt admitted to NICU fo...
1678764,2,1,Neonatology Attending Triage Note Baby [**Nam...
1260684,3,0,Micu Progress Nursing Note: Patient arrived i...


Now we are doing some data cleaning and segments

In [16]:
#  Assign text into series
records = notes
it = records.text

Cut records into segments

In [17]:
#subsplit(text) tries to identify subsection headers (like Medications:, Allergies:) and split the text accordingly.
def subsplit(text):
    l = re.split(r"\n(.{1,30}:)(?![0-9])", text) # Split on the first 30 characters followed by a colon(So it's looking for section headers like: Medications:, Allergies:, etc.)
    if len(l) == 1:
        yield text # If no subsection headers are found, return the whole text
        return
    if l[0]:
        yield l[0] # If there's text before the first matched header, yield that separately.
    for i in range(1, len(l), 2):
        yield l[i] + l[i+1] # Yield the matched header and the text that follows it.


# Chief Complaint:
# Chest pain and shortness of breath.
# Medications:
# Aspirin 81mg daily.

In [18]:
def cut_record(text):
    top_split_pattern = r"\n\n|\n ?__+\n" # split by double newlines or a newline followed by one or more underscores
    for part in re.split(top_split_pattern, text):
        part = re.sub(r"^\s*\[\*\*[0-9\-]*\*\*\]\s+([0-9]{4}|[0-9]{1,2}:[0-9]{1,2} (PM|AM))", "", part) #  Remove certain date/time tokens
        part = re.sub(r" +FINAL REPORT\n", "", part) # Remove "FINAL REPORT" header
        part = part.strip() # Remove leading/trailing whitespace
        if not part: # Skips any empty or whitespace-only parts after cleaning.
            continue
        yield from subsplit(part) #Passes each cleaned part to a function called subsplit

In [19]:
tqdm.pandas()
parts = it.progress_apply(lambda x: pd.Series(cut_record(x))).stack()

100%|██████████| 822497/822497 [01:05<00:00, 12530.95it/s]


In [20]:
parts.index.names = ['rid', 'pid', 'rord',  'srord']
parts.name = "text"
parts = parts.reset_index()

In [21]:
parts.head(3)

Unnamed: 0,rid,pid,rord,srord,text
0,1678765,2,0,0,Nursing Transfer note
1,1678765,2,0,1,Pt admitted to NICU for sepsis eval. Please se...
2,1678765,2,0,2,"Infant stable in RA. RR 30-40's, sats 96-100%...."


Extract and normalize the segments

In [22]:
def normalize_title(title):
    if title is None:
        return None
    title = re.sub(r"\s+", " ", title) # Replace multiple spaces with a single space
    title = title.strip() # Remove leading/trailing whitespace
    title = title.lower() # Convert to lowercase
    title = unidecode.unidecode(title) # Remove accents and special characters
    title = re.sub(r"[0-9]", "9", title) # Replace digits with '9' : This trick can help when you're clustering or grouping similar sections while ignoring numeric variations.
    return title # Return the cleaned title

In [23]:
def get_title(text):
    m = re.search(r"^(.*?)(?:\:|\.{3,4})(?![0-9])", text) # This regex looks for a title at the start of the text, followed by either a colon or 3-4 dots, and ensures it doesn't end with a number.
    # Negative lookahead to make sure the separator isn’t followed by a number
    if not m: # If no match is found, return None and the original text as the body.
        return None, text
    l, r = m.span() # Get the start and end indices of the match
    title = m.group(1).strip() # Extract the title from the match and remove leading/trailing whitespace
    body = text[r:].strip() # Extract the body of the text after the title
    # Remove any leading/trailing whitespace from the body
    return title, body

In [24]:
def extract_and_normalize(text):
    title, body = get_title(text) #check if the text has a title and body
    return body, title, normalize_title(title)

In [25]:
tqdm.pandas(desc="> Extract and normalize")
derived_columns = pd.DataFrame.from_records(
    parts["text"].progress_apply(extract_and_normalize), # Applies a function with a progress bar over parts["text"].
    columns=["stext", "title", "stitle"]                 # Gets 3 values per row
)
parts = pd.concat([parts, derived_columns], axis=1) # Concatenates the original parts DataFrame with the new derived columns.

> Extract and normalize: 100%|██████████| 4230382/4230382 [00:10<00:00, 420897.68it/s]


In [26]:
parts.head(3)

Unnamed: 0,rid,pid,rord,srord,text,stext,title,stitle
0,1678765,2,0,0,Nursing Transfer note,Nursing Transfer note,,
1,1678765,2,0,1,Pt admitted to NICU for sepsis eval. Please se...,Pt admitted to NICU for sepsis eval. Please se...,,
2,1678765,2,0,2,"Infant stable in RA. RR 30-40's, sats 96-100%....","Infant stable in RA. RR 30-40's, sats 96-100%....",,


Filtering parts

In [27]:
def filter_parts(parts, min_stext_length=10):
    """Remove segments with little context"""
    mask = parts["stext"].str.len() >= min_stext_length # Check if the length of the text is greater than or equal to min_stext_length # boolian result
    return parts[mask].reset_index(drop=True) # Reset the index of the filtered DataFrame # apply the mask to the DataFrame and reset the index

In [28]:
parts = filter_parts(parts)
parts.reset_index(inplace=True, drop=True)

In [29]:
parts.head(3)

Unnamed: 0,rid,pid,rord,srord,text,stext,title,stitle
0,1678765,2,0,0,Nursing Transfer note,Nursing Transfer note,,
1,1678765,2,0,1,Pt admitted to NICU for sepsis eval. Please se...,Pt admitted to NICU for sepsis eval. Please se...,,
2,1678765,2,0,2,"Infant stable in RA. RR 30-40's, sats 96-100%....","Infant stable in RA. RR 30-40's, sats 96-100%....",,


Filtering Titles

In [30]:
def select_good_titles(titles, repeats=20, words=6):
    mask = titles["count"] >= repeats
    mask &= titles["title"].str.len() > 0
    mask &= titles["title"].str.count(" ") < words
    mask &= ~titles["title"].str.contains(",")
    return mask

In [31]:
def get_good_titles(parts, col="stitle"):
    titles = parts[col].value_counts().reset_index()
    titles.columns = ["title", "count"]
    titles = titles[select_good_titles(titles)].reset_index(drop=True)

    tid2t = titles.title.to_dict()
    tid2t = {k+1: v for k,v in tid2t.items()}
    t2tid = {v: k for k,v in tid2t.items()}
    return tid2t, t2tid

In [32]:
tid2t, t2tid = get_good_titles(parts)

In [33]:
tid2t

{1: 'resp',
 2: 'neuro',
 3: 'cv',
 4: 'gi',
 5: 'plan',
 6: 'gu',
 7: 'p',
 8: 'social',
 9: 'a',
 10: 'id',
 11: 'skin',
 12: 'o',
 13: 'endo',
 14: 'gi/gu',
 15: 'fen',
 16: 'a/p',
 17: '#9 o',
 18: 'cardiac',
 19: 'dev',
 20: 'pe',
 21: 'heme',
 22: '#9',
 23: 'access',
 24: 'start date',
 25: 's',
 26: 'respiratory care',
 27: 'events',
 28: 'abd',
 29: 'pain',
 30: 'pulm',
 31: 'renal',
 32: 'd',
 33: '9. o',
 34: 'nutrition',
 35: 'heent',
 36: 'a/goals',
 37: 's/o',
 38: 'ext',
 39: '9. resp',
 40: 'cvs',
 41: 'ln',
 42: 'parenting',
 43: 'assessment/plan',
 44: 'wt',
 45: '#9 fen',
 46: '#9 resp',
 47: 'r',
 48: 'g&d',
 49: 'gu/gi',
 50: 'hc',
 51: 'resp care',
 52: 'integ',
 53: 'soc',
 54: '#9o',
 55: '9. fen',
 56: '9',
 57: 'g/d',
 58: 'chest',
 59: 'imp',
 60: 'general',
 61: 'labs',
 62: 'addendum',
 63: 'plans',
 64: 'c/v',
 65: 'ms',
 66: '#9. o',
 67: 'bili',
 68: 'cvr',
 69: '[** **]',
 70: 'dispo',
 71: '9. g/d',
 72: 'imp/plan',
 73: 'respiratory',
 74: 'assessment

In [34]:
print(f'Good titles: {len(tid2t)}')
print(f'Good titles: {len(t2tid)}')

Good titles: 2121
Good titles: 2121


Adding Labels

In [35]:
parts["label"] = parts["stitle"].map(defaultdict(int, t2tid))-1

In [36]:
titles = pd.DataFrame({
        "title": tid2t.values(),
        "freq": parts.label.value_counts().iloc[1:].sort_index()
    })

Saving the segments

In [38]:
def create_name(pre, name, post):
    if name:
        return f"{pre}{name}-{post}"
    return f"{pre}{post}"

In [None]:
name = "c_nurse"
parts.reset_index(drop=True).to_feather(create_name("dataset/", name, "parts.feather")) # Save the parts DataFrame to a Feather file
titles.to_feather(create_name("dataset/", name, "titles.feather"))
# Saving both datasets using to_feather (a fast binary format for DataFrames)

In [41]:
parts.shape

(3969550, 9)

In [43]:
parts.head(10)

Unnamed: 0,rid,pid,rord,srord,text,stext,title,stitle,label
0,1678765,2,0,0,Nursing Transfer note,Nursing Transfer note,,,-1
1,1678765,2,0,1,Pt admitted to NICU for sepsis eval. Please se...,Pt admitted to NICU for sepsis eval. Please se...,,,-1
2,1678765,2,0,2,"Infant stable in RA. RR 30-40's, sats 96-100%....","Infant stable in RA. RR 30-40's, sats 96-100%....",,,-1
3,1678764,2,1,0,Neonatology Attending Triage Note,Neonatology Attending Triage Note,,,-1
4,1678764,2,1,1,Baby [**Name (NI) 1**] [**Known lastname 2**] ...,Baby [**Name (NI) 1**] [**Known lastname 2**] ...,,,-1
5,1678764,2,1,2,Mother is 34 years old G1 P0-1.,Mother is 34 years old G1 P0-1.,,,-1
6,1678764,2,1,3,"PNS: A pos, Ab neg, HBSAg neg, RPR NR, RI, GB...","A pos, Ab neg, HBSAg neg, RPR NR, RI, GBS neg....",PNS,pns,183
7,1678764,2,1,4,"PE - Baby is [**Name2 (NI) 5**] and vigorous, ...","PE - Baby is [**Name2 (NI) 5**] and vigorous, ...",,,-1
8,1678764,2,1,6,Assessment/plan:\nTerm male infant with increa...,Term male infant with increased risk of sepsis...,Assessment/plan,assessment/plan,42
9,1260684,3,0,1,Patient arrived in unit at 19:15 from ED. Hx o...,Patient arrived in unit at 19:15 from ED. Hx o...,,,-1


In [44]:
# get the count of rows in the parts DataFrame where the label is not -1
print(f"Number of rows with label not -1: {len(parts[parts.label != -1])}")

Number of rows with label not -1: 2351046
