In [None]:
import sys
sys.path.append('./src')
from models import EmbeddingModel

from tqdm import tqdm
import numpy as np
import pandas as pd 
import re
import torch

In [None]:
# Evaluation cohort: each row carries a subject_id / hadm_id admission key plus
# the labels `icd_code` and `egfr_category` summarised further below.
#
# icd_code (ICD-10 chronic kidney disease codes):
#   N181  Chronic kidney disease, stage 1
#   N182  Chronic kidney disease, stage 2 (mild)
#   N183  Chronic kidney disease, stage 3 (moderate)
#   N184  Chronic kidney disease, stage 4 (severe)
#   N185  Chronic kidney disease, stage 5
#   N186  End stage renal disease
#   N189  Chronic kidney disease, unspecified
#
# egfr_category:
#   0  Unknown
#   1  Normal or high
#   2  Mildly decreased
#   3  Moderately decreased
#   4  Severely decreased
#   5  Kidney failure
filename = "ckd_cohort10k"
cohort = pd.read_csv(f"../eval_datasets/{filename}.csv")

# Candidate notes: discharge summaries followed by radiology reports, stacked
# row-wise. pd.concat keeps each file's own index, so index labels repeat.
data = pd.concat(
    [pd.read_csv(note_path) for note_path in ("../data/note/discharge.csv", "../data/note/radiology.csv")]
)

print(cohort['icd_code'].value_counts())

In [None]:
# Project embedding wrapper (imported from the `models` module under ./src), loaded
# from a local checkpoint directory with mean pooling.
model = EmbeddingModel(pooling='mean', model_name='../models/med_gte_simcse')
# Short tag used in the saved embedding file name.
modelname = 'simcse'

In [None]:
# additional preprocessing for ckd cohort 
def clean_text(text):
    """Normalise one raw note string before tokenisation.

    Substitutions are applied in a fixed order (the order matters: deleting
    underscores first can create a new '//' that the next step then replaces):
      1. delete every run of underscores (de-identification placeholders);
      2. replace each '//' with a space;
      3. replace each run of newlines with a space;
      4. squeeze any run of two or more whitespace characters into one space;
    and finally leading/trailing whitespace is trimmed.

    Args:
        text: Raw note text.

    Returns:
        The cleaned string (possibly empty).
    """
    substitutions = (
        (r'[_]+', ''),
        ('//', ' '),
        (r'\n+', ' '),
        (r'\s\s+', ' '),
    )
    for pattern, replacement in substitutions:
        # '//' contains no regex metacharacters, so re.sub here matches str.replace.
        text = re.sub(pattern, replacement, text)
    return text.strip()

def build_note_index(data):
    """Group note texts by admission key, keeping their original row order.

    Args:
        data: Notes frame with 'subject_id', 'hadm_id' and 'text' columns.

    Returns:
        dict mapping (subject_id, hadm_id) -> list of raw 'text' values, in the
        order the rows appear in `data`. Rows whose subject_id or hadm_id is NaN
        are skipped: the element-wise `==` filter in get_matching_texts could
        never match them either, because NaN never compares equal.
    """
    keyed = data.dropna(subset=['subject_id', 'hadm_id'])
    index = {}
    for key, text in zip(zip(keyed['subject_id'], keyed['hadm_id']), keyed['text']):
        index.setdefault(key, []).append(text)
    return index


def get_matching_texts(row, data, text_index=None):
    """Return the cleaned, non-blank note texts for one cohort admission.

    Args:
        row: Cohort row providing 'subject_id' and 'hadm_id'.
        data: Notes frame; only scanned when `text_index` is None.
        text_index: Optional result of build_note_index(data). When given, the
            lookup is a dict access instead of a full scan of `data`.

    Returns:
        clean_text(...) of every matching text that is a non-blank string, in
        `data` row order. Empty list if nothing matches.
    """
    key = (row['subject_id'], row['hadm_id'])
    if text_index is None:
        # Fallback: boolean scan over every note, O(len(data)) per call.
        matching_texts = data[(data['subject_id'] == key[0]) &
                              (data['hadm_id'] == key[1])]['text'].tolist()
    else:
        # Keys compare numerically, so int ids in the cohort match float ids in the
        # notes (e.g. a hadm_id column upcast to float by NaNs): hash(5) == hash(5.0).
        matching_texts = text_index.get(key, [])
    return [clean_text(text) for text in matching_texts if isinstance(text, str) and text.strip()]


# Build the admission -> notes lookup once. Re-scanning every note for every
# cohort row costs O(len(cohort) * len(data)).
note_text_index = build_note_index(data)
cohort['texts'] = cohort.apply(lambda row: get_matching_texts(row, data, note_text_index), axis=1)

def chunk_text(text, max_tokens=512):
    """Split `text` into consecutive pieces sized to fit one model input.

    The text is tokenized without special tokens and cut into windows of
    ``max_tokens`` minus the number of special tokens the tokenizer adds to a
    single sequence (e.g. [CLS]/[SEP] for BERT-style tokenizers). Each window is
    decoded back to a string, because ``model.get_embedding`` takes strings and
    re-tokenizes them. Without the reservation, every full chunk plus its special
    tokens would overshoot a ``max_tokens`` context.

    Args:
        text: Cleaned note text.
        max_tokens: Encoder context length, including special tokens.

    Returns:
        List of decoded chunk strings; empty if `text` tokenizes to nothing.
    """
    tokenizer = model.tokenizer
    # num_special_tokens_to_add is the Hugging Face tokenizer API. If it is absent,
    # reserve nothing, which reproduces the previous behavior.
    count_special = getattr(tokenizer, 'num_special_tokens_to_add', None)
    n_special = count_special(pair=False) if callable(count_special) else 0
    # Keep the window positive so range() below never gets a zero step.
    window = max(1, max_tokens - n_special)
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    # NOTE(review): decode -> re-encode is not guaranteed to round-trip to the same
    # token count; assumes get_embedding truncates any small overflow. TODO confirm.
    return [tokenizer.decode(token_ids[start:start + window])
            for start in range(0, len(token_ids), window)]

def process_row(row):
    """Embed every note chunk of one cohort row and return their mean embedding.

    Each text in ``row['texts']`` is split with ``chunk_text``. Chunks are passed
    to ``model.get_embedding`` one at a time (one call per chunk). The resulting
    embeddings are averaged with equal weight, so longer notes contribute more
    chunks to the mean.

    Args:
        row: Cohort row whose 'texts' entry is a list of cleaned note strings.

    Returns:
        Mean of the chunk embeddings over the chunk axis. If the row yields no
        chunks at all (no matching notes, or notes that are empty after
        cleaning), a NaN-filled vector with the embedding width and dtype is
        returned. Previously ``np.concatenate([])`` raised and aborted the whole
        cohort run. Drop such rows downstream with ``np.isnan(...).any(axis=1)``.
    """
    all_chunks = [chunk for text in row['texts'] for chunk in chunk_text(text)]

    if not all_chunks:
        # Probe the model once to learn the output width/dtype for the placeholder.
        probe = np.asarray(model.get_embedding([""]))
        return np.full(probe.shape[-1], np.nan, dtype=probe.dtype)

    # One chunk per call, as in the original. The batched call was left commented
    # out there, presumably for memory reasons. TODO confirm.
    chunk_embeddings = [model.get_embedding([chunk]) for chunk in all_chunks]
    embeddings = np.concatenate(chunk_embeddings, axis=0)
    return np.mean(embeddings, axis=0)

# Register .progress_apply on pandas objects, then embed every cohort row.
tqdm.pandas()
cohort['embedding'] = cohort.progress_apply(process_row, axis=1)

# Quick sanity summary of the processed cohort and its label distributions.
print(f"\nNumber of rows in processed cohort: {cohort.shape[0]}")
print(f"Number of unique ICD codes: {cohort['icd_code'].nunique()}")
print(f"Average number of texts per row: {cohort['texts'].map(len).mean():.2f}")
print(cohort['icd_code'].value_counts())
print(cohort['egfr_category'].value_counts())

# Stack per-row vectors in cohort order; presumably (n_rows, dim) if each
# embedding is 1-D. TODO confirm against EmbeddingModel output.
embeddings = np.stack(cohort['embedding'].tolist())

In [None]:
# Save embeddings as a plain .npy matrix with no ids: row i corresponds to row i
# of `cohort` (i.e. of ../eval_datasets/{filename}.csv).
from pathlib import Path

fp = f'../data/embeddings/{modelname}_{filename}.npy'
# Create the output directory if it is missing. Otherwise np.save raises and the
# expensive embedding run above has to be repeated.
Path(fp).parent.mkdir(parents=True, exist_ok=True)
np.save(fp, embeddings)