In [1]:
!pip install --upgrade pip
!pip install nltk
!pip install xgboost
!pip install sentence_transformers
!pip install torch



In [2]:
import pandas as pd

# new_columns=pd.read_csv('../data/neiss_10p_sample_new_columns.csv')
data_10P = pd.read_csv('../data/neiss_10p_sample.csv')

In [3]:
# data=data_10P.merge(new_columns,how='inner',on='CPSC_Case_Number').merge(sematic,how='inner',on='CPSC_Case_Number').reset_index(drop=True)
# data.rename(columns={"sematic_distance_bert": "sematic_distance"},inplace=True)
# pick only Narrative and Diagnosis columns
data = data_10P[['CPSC_Case_Number', 'Narrative' ]]
data.head(10)

Unnamed: 0,CPSC_Case_Number,Narrative
0,221032332,14YOM REPORTS HE FELL 1 WEEK AND COMPLAINS OF ...
1,181109464,A 28YOM BENT TO PICK UP CRATE AT HOME TO ED WI...
2,210103105,35YOMRIDING ON MOUNTAIN BIKE PRACTICING FELL D...
3,161157997,14 MONTH OLD FEMALE ABRASION FOR NOSE AND FORE...
4,181107411,4YR M PLAYING WITH TOY KITCHEN APPLIANCE AND ...
5,200134239,7MOM SITTING ON THE COUNTER AND GRABBED A HOT ...
6,140951498,12YOM FELL DURING PE ACTIVITY DX CONTUSED HAND
7,221017396,44YOF CHILD WAS SWINGING A BACK PACK THAT HAD ...
8,200645623,28YOM FELL OFF SKATEBOARD LANDED ON L SIDE DX ...
9,141040420,16YOF ACTIVE PLAYING VOLLEYBALL 7 DAYS A WEEK ...


In [4]:
def remove_after_dx(narrative):
  if isinstance(narrative, str):
    parts = narrative.split("DX", 1)
    if len(parts) > 1:
      return parts[0]
    else:
      return narrative  # No "DX:" found, return the original string
  else:
    return narrative  # Not a string, return as is

data.loc[:, 'Narrative'] = data['Narrative'].apply(remove_after_dx)
data.head(10)

Unnamed: 0,CPSC_Case_Number,Narrative
0,221032332,14YOM REPORTS HE FELL 1 WEEK AND COMPLAINS OF ...
1,181109464,A 28YOM BENT TO PICK UP CRATE AT HOME TO ED WI...
2,210103105,35YOMRIDING ON MOUNTAIN BIKE PRACTICING FELL DOWN
3,161157997,14 MONTH OLD FEMALE ABRASION FOR NOSE AND FORE...
4,181107411,4YR M PLAYING WITH TOY KITCHEN APPLIANCE AND ...
5,200134239,7MOM SITTING ON THE COUNTER AND GRABBED A HOT ...
6,140951498,12YOM FELL DURING PE ACTIVITY
7,221017396,44YOF CHILD WAS SWINGING A BACK PACK THAT HAD ...
8,200645623,28YOM FELL OFF SKATEBOARD LANDED ON L SIDE
9,141040420,16YOF ACTIVE PLAYING VOLLEYBALL 7 DAYS A WEEK ...


In [5]:
medical_terms = {
    "&": "and",
    "***": "",
    ">>": "clinical diagnosis",
    "@": "at",
    "abd": "abdomen",
    "af": "accidental fall",
    "afib": "atrial fibrillation",
    "aki": "acute kidney injury",
    "am": "morning",
    "ams": "altered mental status",
    "bac": "blood alcohol content",
    "bal": "blood alcohol level,",
    "biba": "brought in by ambulance",
    "c/o": "complains of",
    "chi": "closed-head injury",
    "clsd": "closed",
    "cpk": "creatine phosphokinase",
    "cva": "cerebral vascular accident",
    "dx": "clinical diagnosis",
    "ecf": "extended-care facility",
    "er": "emergency room",
    "etoh": "ethyl alcohol",
    "eval": "evaluation",
    "fd": "fall detected",
    "fx": "fracture",
    "fxs": "fractures",
    "glf": "ground level fall",
    "h/o": "history of",
    "htn": "hypertension",
    "hx": "history of",
    "inj": "injury",
    "inr": "international normalized ratio",
    "intox": "intoxication",
    "l": "left",
    "loc": "loss of consciousness",
    "lt": "left",
    "mech": "mechanical",
    "mult": "multiple",
    "n.h.": "nursing home",
    "nh": "nursing home",
    "p/w": "presents with",
    "pm": "afternoon",
    "pt": "patient",
    "pta": "prior to arrival",
    "pts": "patient's",
    "px": "physical examination", # not "procedure",
    "r": "right",
    "r/o": "rules out",
    "rt": "right",
    "s'd&f": "slipped and fell",
    "s/p": "after",
    "sah": "subarachnoid hemorrhage",
    "sdh": "acute subdural hematoma",
    "sts": "sit-to-stand",
    "t'd&f": "tripped and fell",
    "tr": "trauma",
    "uti": "urinary tract infection",
    "w/": "with",
    "w/o": "without",
    "wks": "weeks"
}

In [6]:
import multiprocessing as mp
import numpy as np
import re
import nltk
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

# Load tokenizer globally
sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

def clean_text(text):
    """
    Clean a single text entry.
    """
    if not isinstance(text, str):
        return ""
        
    # lowercase everything
    text = text.lower()
    
    # unglue DX
    regex_dx = r"([ˆ\W]*(dx)[ˆ\W]*)"
    text = re.sub(regex_dx, r". dx: ", text)
    
    # remove age and sex identifications
    regex_age_sex = r"(\d+)\s*?(yof|yf|yo\s*female|yo\s*f|yom|ym|yr|yo\s*male|yo\s*m)"
    age_sex_match = re.search(regex_age_sex, text)
    
    if age_sex_match:
        text = text.replace(age_sex_match.group(0), "patient")
    
    # translate medical terms
    for term, replacement in medical_terms.items():
        if term in ["@", ">>", "&", "***"]:
            pattern = fr"({re.escape(term)})"
            text = re.sub(pattern, f" {replacement} ", text)
        else:
            pattern = fr"(?<!-)\b({re.escape(term)})\b(?!-)"
            text = re.sub(pattern, replacement, text)
    
    # user-friendly format
    try:
        sentences = sent_tokenizer.tokenize(text)
        sentences = [sent.capitalize() for sent in sentences]
        return " ".join(sentences)
    except Exception as e:
        return text

def clean_text_wrapper(args):
    """
    Wrapper function for multiprocessing that unpacks arguments
    """
    return clean_text(args)

def process_texts(texts, use_parallel=True, n_jobs=None):
    """
    Process texts either in parallel or sequentially.
    
    Args:
        texts: list or pandas Series of texts
        use_parallel: bool, whether to use parallel processing
        n_jobs: int, number of CPU cores to use (None = all cores)
    """
    if not use_parallel:
        # Sequential processing with progress bar
        return [clean_text(text) for text in tqdm(texts, desc="Processing texts")]
    
    # Parallel processing
    if n_jobs is None:
        n_jobs = mp.cpu_count()
        
    print(f"Processing {len(texts)} texts using {n_jobs} cores...")
    
    # Create a pool with the specified number of processes
    with mp.Pool(n_jobs) as pool:
        # Process texts in parallel with progress bar
        results = list(tqdm(
            pool.imap(clean_text_wrapper, texts, chunksize=100),
            total=len(texts),
            desc="Processing texts"
        ))
    
    return results

# Example usage:
# Single process version:
# data['Narrative'] = process_texts(data['Narrative'].tolist(), use_parallel=False)

# Multi-process version:
data['Narrative'] = process_texts(data['Narrative'].tolist(), use_parallel=True)

# Or specify number of cores:
# data['Narrative'] = process_texts(data['Narrative'].tolist(), use_parallel=True, n_jobs=4)

Processing 352052 texts using 24 cores...


Processing texts:   0%|          | 0/352052 [00:00<?, ?it/s]

In [7]:
data['Narrative']=data['Narrative'].str.upper()
data[['Narrative']]

Unnamed: 0,Narrative
0,PATIENT REPORTS HE FELL 1 WEEK AND COMPLAINS O...
1,A PATIENT BENT TO PICK UP CRATE AT HOME TO ED ...
2,PATIENTRIDING ON MOUNTAIN BIKE PRACTICING FELL...
3,14 MONTH OLD FEMALE ABRASION FOR NOSE AND FORE...
4,PATIENT M PLAYING WITH TOY KITCHEN APPLIANCE A...
...,...
352047,PATIENT HELPING LOAD BICYCLE IN VAN CUT HEAD
352048,PATIENT PUSHING DOWN FOOD IN BLENDER MAKING A ...
352049,PATIENT FELL DOWN STEPS
352050,PATIENT WAS RIDING A BIKE AND FELL WRIST INJURY


In [8]:
replace_list=['ANKLE', 'ARM', 'BODY_PART', 'CHEST', 'CONTUSION', 'CUT', 'EAR', 'ELBOW', 'EYE', 'FACE', 'FINGER', 'FOOT', 'FOREHEAD', 'FRACTURE', 'FX', 'HAND', 'HEAD', 'HIP', 'KNEE', 'LAC', 'LACERATION', 'LEG', 'LOC', 'LOSE', 'NECK', 'PAIN', 'SHOULDER', 
'SPRAIN', 'STRAIN', 'SWELL', 'THUMB', 'TOE', 'WRIST','ABRASION', 'ACHE', 'BREAK', 'BURN', 'CHIN', 'CUT', 'ER', 'FRACTURE', 'FX', 'HIT', 'INJURY', 'LACERATION', 'LIP', 'LOSE', 'LOC', 'MOUTH', 'NOSE', 'PAIN', 'RIB', 'SCALP', 'SPRAIN', 'STRAIN', 'SWELL', 'TOE', 'TWIST', 'WRIST']

for i in replace_list:
  data['Narrative'] = data['Narrative'].str.replace(i, '')


data['Narrative'] = data['Narrative'].str.replace('YOM', '')
data['Narrative'] = data['Narrative'].str.replace('YOF', '')
data['Narrative'] = data['Narrative'].str.replace('YR', '')
data['Narrative'] = data['Narrative'].str.replace('OLD', '')
data['Narrative'] = data['Narrative'].str.replace('MALE', '')
data['Narrative'] = data['Narrative'].str.replace('FEMALE', '')
data['Narrative'] = data['Narrative'].str.replace(' YO ', '')
data['Narrative'] = data['Narrative'].str.replace('YO ', '')
data['Narrative'] = data['Narrative'].str.replace(' F ', '')
data['Narrative'] = data['Narrative'].str.replace('YF', '')
data['Narrative'] = data['Narrative'].str.replace(' M ', '')
data['Narrative'] = data['Narrative'].str.replace('ACCIDENTALLY','')
data['Narrative'] = data['Narrative'].str.replace('AGO', '')
data['Narrative'] = data['Narrative'].str.replace('TODAY', '')
data['Narrative'] = data['Narrative'].str.replace('YESTERDAY', '')
data['Narrative'] = data['Narrative'].str.replace('PATIENT', '')
data['Narrative'] = data['Narrative'].str.replace(' PT ', '')
data['Narrative'] = data['Narrative'].str.replace('INJURY', '')
data['Narrative'] = data['Narrative'].str.replace('REPORT', '')
data['Narrative'] = data['Narrative'].str.replace('HURT', '')
data['Narrative'] = data['Narrative'].str.replace('INJ', '')
data['Narrative'] = data['Narrative'].str.replace('FELL', 'FALL')
data['Narrative'] = data['Narrative'].str.replace('INJURE', '')
data['Narrative'] = data['Narrative'].str.replace('JURED', '')
data['Narrative'] = data['Narrative'].str.replace('URED', '')
data['Narrative'] = data['Narrative'].str.replace(' ED', '')
data['Narrative'] = data['Narrative'].str.replace(' RT ', '')
data['Narrative'] = data['Narrative'].str.replace(' LT ', '')

data.head(10)

Unnamed: 0,CPSC_Case_Number,Narrative
0,221032332,S HE FALL 1 WEEK AND COMPLAINS OF HE HAS BE...
1,181109464,A BENT TO PICK UP CRATE AT HOME TO WITH LOW B...
2,210103105,RIDING ON MOUNTAIN BIKE PRACTICING FALL DOWN
3,161157997,14 MONTH FE FOR AND WAS COMINGDOWN STAIRS...
4,181107411,PLAYING WITH TOY KITCHEN APPLIANCE AND GOT C...
5,200134239,7MOM SITTING ON THE COUNT AND GRABBED A HOT CU...
6,140951498,FALL DURING PE ACTIVITY
7,221017396,CHILD WAS SWINGING A BACK PACK THAT HAD A LAP...
8,200645623,FALL OFF SKATEBOARD LANDED ON LEFT SIDE
9,141040420,ACTIVE PLAYING VOLLEYBALL 7 DAYS A WEEK RUNNI...


In [9]:
import torch
import os
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import gc  # for garbage collection

# Configure for maximum performance
torch.backends.cudnn.benchmark = True
BATCH_SIZE = 512 
NUM_WORKERS = 8  

class NarrativeDataset(Dataset):
    def __init__(self, narratives):
        self.narratives = narratives
    
    def __len__(self):
        return len(self.narratives)
    
    def __getitem__(self, idx):
        return self.narratives[idx]

def process_in_batches(data, model, batch_size=BATCH_SIZE, save_interval=10000):
    """
    Process narratives in batches with periodic saving
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()  # Set to evaluation mode
    
    # Create dataset and dataloader
    dataset = NarrativeDataset(data['Narrative'].values)
    dataloader = DataLoader(
        dataset, 
        batch_size=batch_size,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )
    
    # Initialize storage
    embeddings = []
    case_numbers = []
    current_count = 0
    chunk_number = 0
    
    print(f"Processing {len(data)} narratives on {device}")
    
    with torch.no_grad():  # Disable gradient computation
        for batch in tqdm(dataloader, desc="Processing batches"):
            # Generate embeddings
            batch_embeddings = model.encode(
                batch,
                convert_to_tensor=True,
                batch_size=len(batch)
            )
            
            # Move to CPU and convert to numpy
            batch_embeddings = batch_embeddings.cpu().numpy()
            embeddings.append(batch_embeddings)
            
            # Store case numbers
            start_idx = current_count
            end_idx = current_count + len(batch)
            case_numbers.extend(data['CPSC_Case_Number'].iloc[start_idx:end_idx].values)
            
            current_count += len(batch)
            
            # Periodic saving
            if current_count >= save_interval:
                save_chunk(embeddings, case_numbers, chunk_number)
                embeddings = []
                case_numbers = []
                current_count = 0
                chunk_number += 1
                
                # Clear GPU memory
                torch.cuda.empty_cache()
                gc.collect()
    
    # Save any remaining data
    if embeddings:
        save_chunk(embeddings, case_numbers, chunk_number)
    
    return chunk_number + 1

def save_chunk(embeddings, case_numbers, chunk_number):
    """
    Save a chunk of processed data
    """
    # Concatenate all embeddings in this chunk
    embeddings_array = np.vstack(embeddings)
    
    # Create DataFrame
    embedding_df = pd.DataFrame(
        embeddings_array,
        index=case_numbers
    )
    
    # Save to file
    os.makedirs('../data/embedding_chunks', exist_ok=True)
    filename = f'../data/embedding_chunks/gist_embedding_chunk_{chunk_number}.csv'
    embedding_df.to_csv(filename)
    print(f"Saved chunk {chunk_number} with shape {embedding_df.shape}")

def combine_chunks(num_chunks, output_file='../data/gist_embedding_final.csv'):
    """
    Combine all chunks into a single file
    """
    dfs = []
    for i in range(num_chunks):
        chunk_file = f'../data/embedding_chunks/gist_embedding_chunk_{i}.csv'
        df = pd.read_csv(chunk_file)
        dfs.append(df)
        
    final_df = pd.concat(dfs, axis=0)
    final_df.to_csv(output_file, index=False)
    print(f"Final file saved with shape {final_df.shape}")

# Main execution
try:
    # Initialize model
    revision = None
    model = SentenceTransformer("avsolatorio/GIST-small-Embedding-v0", revision=revision)
    
    # Reset index for consistent processing
    data = data.reset_index(drop=True)
    
    # Process data in chunks
    num_chunks = process_in_batches(
        data,
        model,
        batch_size=BATCH_SIZE,
        save_interval=50000  # Save every 50K records
    )
    
    # Combine all chunks
    combine_chunks(num_chunks)
    
    print("Processing completed successfully!")

except Exception as e:
    print(f"An error occurred: {str(e)}")

finally:
    # Clean up GPU memory
    torch.cuda.empty_cache()
    gc.collect()

Processing 352052 narratives on cuda


Processing batches:   0%|          | 0/688 [00:00<?, ?it/s]

Saved chunk 0 with shape (10240, 384)
Saved chunk 1 with shape (10240, 384)
Saved chunk 2 with shape (10240, 384)
Saved chunk 3 with shape (10240, 384)
Saved chunk 4 with shape (10240, 384)
Saved chunk 5 with shape (10240, 384)
Saved chunk 6 with shape (10240, 384)
Saved chunk 7 with shape (10240, 384)
Saved chunk 8 with shape (10240, 384)
Saved chunk 9 with shape (10240, 384)
Saved chunk 10 with shape (10240, 384)
Saved chunk 11 with shape (10240, 384)
Saved chunk 12 with shape (10240, 384)
Saved chunk 13 with shape (10240, 384)
Saved chunk 14 with shape (10240, 384)
Saved chunk 15 with shape (10240, 384)
Saved chunk 16 with shape (10240, 384)
Saved chunk 17 with shape (10240, 384)
Saved chunk 18 with shape (10240, 384)
Saved chunk 19 with shape (10240, 384)
Saved chunk 20 with shape (10240, 384)
Saved chunk 21 with shape (10240, 384)
Saved chunk 22 with shape (10240, 384)
Saved chunk 23 with shape (10240, 384)
Saved chunk 24 with shape (10240, 384)
Saved chunk 25 with shape (10240, 3