In [1]:
%reset -f
!pip install --upgrade pip
!pip install nltk
!pip install xgboost
!pip install sentence_transformers
!pip install torch



In [2]:
import pandas as pd

# data_size = '10p'
data_size = 'full'
version = 'v1'

data = pd.read_csv('../data/neiss_10p_sample.csv') if data_size == '10p' else pd.read_csv('../data/consolidated_cleaned_neiss_2014_2023.csv')
data = data[['CPSC_Case_Number', 'Narrative' ]]
data.head(10)

  data = pd.read_csv('../data/neiss_10p_sample.csv') if data_size == '10p' else pd.read_csv('../data/consolidated_cleaned_neiss_2014_2023.csv')


Unnamed: 0,CPSC_Case_Number,Narrative
0,140103999,32 YOM CO PAIN IN THE CHEST AFTER FALLING WHIL...
1,140104001,18 YOM BURNED LT HAND WHILE POURING GASOLINE O...
2,140104003,31 YOF STATES SHE FELL THROUGH THE BATHROOM FL...
3,140104670,53YOF WASHING DISHES AT HOME A GLASS BROKE AND...
4,140104672,23YOM RIDING AN ATV ROLLED OVER HURT SHOULDERD...
5,140104673,14MOM WALKING WITH A PENCIL IN HIS MOUTH FELL ...
6,140104932,15YOM WITH CONTUSION TO HIP FROM ICE SKATING FALL
7,140104935,40YOM WITH CONCUSSION FROM ICE SKATING FALL
8,140104936,4YOM WITH FRACTURED WRIST IN FALL FROM BUNK BED
9,140104937,2YOM WITH A *** IN NOSE


In [3]:
def remove_after_dx(narrative):
  if isinstance(narrative, str):
    parts = narrative.split("DX", 1)
    if len(parts) > 1:
      return parts[0]
    else:
      return narrative  # No "DX:" found, return the original string
  else:
    return narrative  # Not a string, return as is

data.loc[:, 'Narrative'] = data['Narrative'].apply(remove_after_dx)
data.head(10)

Unnamed: 0,CPSC_Case_Number,Narrative
0,140103999,32 YOM CO PAIN IN THE CHEST AFTER FALLING WHIL...
1,140104001,18 YOM BURNED LT HAND WHILE POURING GASOLINE O...
2,140104003,31 YOF STATES SHE FELL THROUGH THE BATHROOM FL...
3,140104670,53YOF WASHING DISHES AT HOME A GLASS BROKE AND...
4,140104672,23YOM RIDING AN ATV ROLLED OVER HURT SHOULDER
5,140104673,14MOM WALKING WITH A PENCIL IN HIS MOUTH FELL ...
6,140104932,15YOM WITH CONTUSION TO HIP FROM ICE SKATING FALL
7,140104935,40YOM WITH CONCUSSION FROM ICE SKATING FALL
8,140104936,4YOM WITH FRACTURED WRIST IN FALL FROM BUNK BED
9,140104937,2YOM WITH A *** IN NOSE


In [4]:
medical_terms = {
    "&": "and",
    "***": "",
    ">>": "clinical diagnosis",
    "@": "at",
    "abd": "abdomen",
    "af": "accidental fall",
    "afib": "atrial fibrillation",
    "aki": "acute kidney injury",
    "am": "morning",
    "ams": "altered mental status",
    "bac": "blood alcohol content",
    "bal": "blood alcohol level,",
    "biba": "brought in by ambulance",
    "c/o": "complains of",
    "chi": "closed-head injury",
    "clsd": "closed",
    "cpk": "creatine phosphokinase",
    "cva": "cerebral vascular accident",
    "dx": "clinical diagnosis",
    "ecf": "extended-care facility",
    "er": "emergency room",
    "etoh": "ethyl alcohol",
    "eval": "evaluation",
    "fd": "fall detected",
    "fx": "fracture",
    "fxs": "fractures",
    "glf": "ground level fall",
    "h/o": "history of",
    "htn": "hypertension",
    "hx": "history of",
    "inj": "injury",
    "inr": "international normalized ratio",
    "intox": "intoxication",
    "l": "left",
    "loc": "loss of consciousness",
    "lt": "left",
    "mech": "mechanical",
    "mult": "multiple",
    "n.h.": "nursing home",
    "nh": "nursing home",
    "p/w": "presents with",
    "pm": "afternoon",
    "pt": "patient",
    "pta": "prior to arrival",
    "pts": "patient's",
    "px": "physical examination", # not "procedure",
    "r": "right",
    "r/o": "rules out",
    "rt": "right",
    "s'd&f": "slipped and fell",
    "s/p": "after",
    "sah": "subarachnoid hemorrhage",
    "sdh": "acute subdural hematoma",
    "sts": "sit-to-stand",
    "t'd&f": "tripped and fell",
    "tr": "trauma",
    "uti": "urinary tract infection",
    "w/": "with",
    "w/o": "without",
    "wks": "weeks"
}

In [5]:
import multiprocessing as mp
import numpy as np
import re
import nltk
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

# Load tokenizer globally
sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

def clean_text(text):
    """
    Clean a single text entry.
    """
    if not isinstance(text, str):
        return ""
        
    # lowercase everything
    text = text.lower()
    
    # unglue DX
    regex_dx = r"([ˆ\W]*(dx)[ˆ\W]*)"
    text = re.sub(regex_dx, r". dx: ", text)
    
    # remove age and sex identifications
    regex_age_sex = r"(\d+)\s*?(yof|yf|yo\s*female|yo\s*f|yom|ym|yr|yo\s*male|yo\s*m)"
    age_sex_match = re.search(regex_age_sex, text)
    
    if age_sex_match:
        text = text.replace(age_sex_match.group(0), "patient")
    
    # translate medical terms
    for term, replacement in medical_terms.items():
        if term in ["@", ">>", "&", "***"]:
            pattern = fr"({re.escape(term)})"
            text = re.sub(pattern, f" {replacement} ", text)
        else:
            pattern = fr"(?<!-)\b({re.escape(term)})\b(?!-)"
            text = re.sub(pattern, replacement, text)
    
    # user-friendly format
    try:
        sentences = sent_tokenizer.tokenize(text)
        sentences = [sent.capitalize() for sent in sentences]
        return " ".join(sentences)
    except Exception as e:
        return text

def clean_text_wrapper(args):
    """
    Wrapper function for multiprocessing that unpacks arguments
    """
    return clean_text(args)

def process_texts(texts, use_parallel=True, n_jobs=None):
    """
    Process texts either in parallel or sequentially.
    
    Args:
        texts: list or pandas Series of texts
        use_parallel: bool, whether to use parallel processing
        n_jobs: int, number of CPU cores to use (None = all cores)
    """
    if not use_parallel:
        # Sequential processing with progress bar
        return [clean_text(text) for text in tqdm(texts, desc="Processing texts")]
    
    # Parallel processing
    if n_jobs is None:
        n_jobs = mp.cpu_count()
        
    print(f"Processing {len(texts)} texts using {n_jobs} cores...")
    
    # Create a pool with the specified number of processes
    with mp.Pool(n_jobs) as pool:
        # Process texts in parallel with progress bar
        results = list(tqdm(
            pool.imap(clean_text_wrapper, texts, chunksize=100),
            total=len(texts),
            desc="Processing texts"
        ))
    
    return results

# Example usage:
# Single process version:
# data['Narrative'] = process_texts(data['Narrative'].tolist(), use_parallel=False)

# Multi-process version:
data['Narrative'] = process_texts(data['Narrative'].tolist(), use_parallel=True)

# Or specify number of cores:
# data['Narrative'] = process_texts(data['Narrative'].tolist(), use_parallel=True, n_jobs=4)

Processing 3520522 texts using 24 cores...


Processing texts:   0%|          | 0/3520522 [00:00<?, ?it/s]

In [6]:
data['Narrative']=data['Narrative'].str.upper()
data[['Narrative']]

Unnamed: 0,Narrative
0,PATIENT CO PAIN IN THE CHEST AFTER FALLING WHI...
1,PATIENT BURNED LEFT HAND WHILE POURING GASOLIN...
2,PATIENT STATES SHE FELL THROUGH THE BATHROOM F...
3,PATIENT WASHING DISHES AT HOME A GLASS BROKE A...
4,PATIENT RIDING AN ATV ROLLED OVER HURT SHOULDER
...,...
3520517,PATIENT HERE FOR A FALL HE GOT UP FROM BED TO ...
3520518,PATIENT HERE FOR RIGHT ANKLE PAIN 2 DAYS PRIOR...
3520519,PATIENT HERE VIA EMS FOR A FALL MONTHS PRIOR T...
3520520,PATIENT HERE VIA EMS FOR SMOKE INHALATION CEIL...


In [7]:
replace_list = [
  'INGESTION','ASPIRATION','BURN','ELECTRICAL','SCALD','CHEMICAL','AMPUTATION','THERMAL','CONCUSSION','CONTUSIONS','CRUSHING',
  'DISLOCATION','FOREIGN','FRACTURE','HEMATOMA','LACERATION','DENTAL','NERVE','DAMAGE','INTERNAL','PUNCTURE','STRAIN','SPRAIN', ' SPR ',
  'HEMORRHAGE','ELECTRIC','POISONING','SUBMERSION','AVULSION','RADIATION','DERMA','CONJUNCT','SWELL','WRIST','ABRASION','ACHE',
  'BREAK','CHIN','CUT','ER','FX','HIT','INJURY','LOSE','LOC','PAIN','TWIST','CONTUSION','LAC','YOM','YOF','YR','OLD','MALE','FEMALE',' YO ',' AFT ',
  'YO ',' F ','YF',' M ','MOF','MM ','MOM',' MO ','MO ','ACCIDENTALLY','PATIENT',' PT ','PT ',' P ',' Y ','INJURY','REPORT',' S ',' FE ','HURT','INJ',
  'FELL','INJURE','JURED','URED',' ED',' RT ',' LT '
]

# add body parts
# replace_list += [
#   'ANKLE', 'ARM', 'BODY_PART', 'CHEST', 'CONTUSION', 'CUT', 'EAR', 'ELBOW', 'EYE', 'FACE', 'FINGER', 'FOOT', 'FOREHEAD', 'FRACTURE', 'FX', 
#   'HAND', 'HEAD', 'HIP', 'KNEE', 'LAC', 'LACERATION', 'LEG', 'LOC', 'LOSE', 'NECK', 'PAIN', 'SHOULDER', 'SPRAIN', 'STRAIN', 'SWELL', 'THUMB', 
#   'TOE', 'WRIST','ABRASION', 'ACHE', 'BREAK', 'BURN', 'CHIN', 'CUT', 'ER', 'FRACTURE', 'FX', 'HIT', 'INJURY', 'LACERATION', 'LIP', 'LOSE', 'LOC',
#   'MOUTH', 'NOSE', 'PAIN', 'RIB', 'SCALP', 'SPRAIN', 'STRAIN', 'SWELL', 'TOE', 'TWIST', 'WRIST'
# ]

for i in replace_list:
  data['Narrative'] = data['Narrative'].str.replace(i, '')

data.head(10)

Unnamed: 0,CPSC_Case_Number,Narrative
0,140103999,CO IN THE CHESTFALLING WHILE PLAYING FOOTBALL
1,140104001,LEFT HAND WHILE POURING GASOLINE ONTO A FIRE
2,140104003,STATES SHE THROUGH THE BATHROOM FLOOR AT HHO...
3,140104670,WASHING DISHES AT HOME A GLASS BROKE AND SHE ...
4,140104672,RIDING AN ATV ROLLED OV SHOULD
5,140104673,14 WALKING WITH A PENCIL IN HIS MOUTH AND MOUTH
6,140104932,WITH TO HIP FROM ICE SKATING FALL
7,140104935,WITH FROM ICE SKATING FALL
8,140104936,WITH D IN FALL FROM BUNK BED
9,140104937,WITH A IN NOSE


In [8]:
import torch
import os
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import gc  # for garbage collection

# Configure for maximum performance
torch.backends.cudnn.benchmark = True
BATCH_SIZE = 512 
NUM_WORKERS = 8  

chunk_folder = f'../data/embedding_chunks_{data_size}_{version}'
output_file = f'../data/gist_embedding_{data_size}_{version}.csv'

class NarrativeDataset(Dataset):
    def __init__(self, narratives):
        self.narratives = narratives
    
    def __len__(self):
        return len(self.narratives)
    
    def __getitem__(self, idx):
        return self.narratives[idx]

def process_in_batches(data, model, batch_size=BATCH_SIZE, save_interval=10000):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()  # Set to evaluation mode
    
    # Create dataset and dataloader
    dataset = NarrativeDataset(data['Narrative'].values)
    dataloader = DataLoader(
        dataset, 
        batch_size=batch_size,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )
    
    embeddings = []
    case_numbers = []
    chunk_number = 0
    total_processed = 0
    
    print(f"Processing {len(data)} narratives on {device}")
    
    with torch.no_grad():  # Disable gradient computation
        for batch in tqdm(dataloader, desc="Processing batches"):
            batch_embeddings = model.encode(
                batch,
                convert_to_tensor=True,
                batch_size=len(batch)
            )
            
            batch_embeddings = batch_embeddings.cpu().numpy()
            embeddings.append(batch_embeddings)
            
            # Collect case numbers based on actual batch length
            batch_size_actual = len(batch)
            case_numbers.extend(data['CPSC_Case_Number'].iloc[total_processed:total_processed + batch_size_actual].values)
            total_processed += batch_size_actual
            
            # Periodic saving
            if total_processed >= save_interval * (chunk_number + 1):
                save_chunk(embeddings, case_numbers, chunk_number)
                embeddings = []
                case_numbers = []
                chunk_number += 1
                
                torch.cuda.empty_cache()
                gc.collect()
    
    # Save any remaining data
    if embeddings:
        save_chunk(embeddings, case_numbers, chunk_number)
    
    return chunk_number + 1

def save_chunk(embeddings, case_numbers, chunk_number):
    """
    Save a chunk of processed data
    """
    # Concatenate all embeddings in this chunk
    embeddings_array = np.vstack(embeddings)
    
    # Create DataFrame
    embedding_df = pd.DataFrame(
        embeddings_array,
        index=case_numbers,
    )
    
    # Save to file
    os.makedirs(chunk_folder, exist_ok=True)
    filename = f'{chunk_folder}/gist_embedding_chunk_{chunk_number}.csv'
    embedding_df.to_csv(filename)
    print(f"Saved chunk {chunk_number} with shape {embedding_df.shape}")

def combine_chunks(num_chunks, output_file):
    """
    Combine all chunks into a single file
    """
    dfs = []
    for i in range(num_chunks):
        chunk_file = f'{chunk_folder}/gist_embedding_chunk_{i}.csv'
        df = pd.read_csv(chunk_file)
        dfs.append(df)
        
    final_df = pd.concat(dfs, axis=0)
    final_df.rename(columns={'Unnamed: 0': 'CPSC_Case_Number'}, inplace=True)
    final_df.to_csv(output_file, index=False)
    print(f"Final file saved with shape {final_df.shape}")
    return final_df

# Main execution
try:
    # Initialize model
    revision = None
    model = SentenceTransformer("avsolatorio/GIST-small-Embedding-v0", revision=revision)
    
    # Reset index for consistent processing
    data = data.reset_index(drop=True)
    
    # Process data in chunks
    num_chunks = process_in_batches(
        data,
        model,
        batch_size=BATCH_SIZE,
        save_interval=50000  # Save every 50K records
    )
    
    # Combine all chunks
    final_df = combine_chunks(num_chunks, output_file=output_file)
    print("Processing completed successfully!")

except Exception as e:
    print(f"An error occurred: {str(e)}")

finally:
    # Clean up GPU memory
    torch.cuda.empty_cache()
    gc.collect()

Processing 3520522 narratives on cuda


Processing batches:   0%|          | 0/6877 [00:00<?, ?it/s]

Saved chunk 0 with shape (50176, 384)
Saved chunk 1 with shape (50176, 384)
Saved chunk 2 with shape (49664, 384)
Saved chunk 3 with shape (50176, 384)
Saved chunk 4 with shape (50176, 384)
Saved chunk 5 with shape (49664, 384)
Saved chunk 6 with shape (50176, 384)
Saved chunk 7 with shape (50176, 384)
Saved chunk 8 with shape (49664, 384)
Saved chunk 9 with shape (50176, 384)
Saved chunk 10 with shape (50176, 384)
Saved chunk 11 with shape (49664, 384)
Saved chunk 12 with shape (50176, 384)
Saved chunk 13 with shape (50176, 384)
Saved chunk 14 with shape (49664, 384)
Saved chunk 15 with shape (50176, 384)
Saved chunk 16 with shape (50176, 384)
Saved chunk 17 with shape (49664, 384)
Saved chunk 18 with shape (50176, 384)
Saved chunk 19 with shape (50176, 384)
Saved chunk 20 with shape (49664, 384)
Saved chunk 21 with shape (50176, 384)
Saved chunk 22 with shape (50176, 384)
Saved chunk 23 with shape (49664, 384)
Saved chunk 24 with shape (50176, 384)
Saved chunk 25 with shape (50176, 3

In [9]:
final_df.head(5)

Unnamed: 0,CPSC_Case_Number,0,1,2,3,4,5,6,7,8,...,374,375,376,377,378,379,380,381,382,383
0,140103999,-0.022808,0.007418,0.027467,-0.012784,0.065578,2.1e-05,0.085643,0.013887,-0.004115,...,0.036178,0.025958,-0.039153,0.030998,-0.036593,-0.015551,-0.02823,-0.04937,0.037316,0.069519
1,140104001,-0.026905,0.023876,0.032426,-0.036448,0.052163,0.016243,0.069046,0.048201,-0.042924,...,0.001678,-0.015718,-0.034743,0.029062,0.047233,-0.027738,-0.04926,0.009754,-0.000638,0.019649
2,140104003,-0.022458,-0.034115,0.086562,0.027872,-0.017475,0.035831,0.075362,0.066578,0.02902,...,0.033147,-0.031686,-0.025193,-0.048929,-0.055719,-0.011231,-0.062054,-0.023576,-0.011888,0.03031
3,140104670,-0.021107,-0.035272,0.060965,-0.047719,0.004092,-0.030488,0.082513,0.047852,0.004024,...,-0.051164,-0.067823,0.006062,-0.00843,-0.001242,0.018775,0.020745,-0.003641,-0.034434,0.058451
4,140104672,-0.021438,-0.023495,0.031477,-0.029871,-0.016338,0.063439,0.031051,0.068447,-0.001414,...,0.021603,0.010978,0.012187,0.006838,-0.048688,-0.024933,-0.088906,-0.003716,0.028649,-0.001607


In [11]:
%reset -f