In [3]:
pip install google-generativeai

Note: you may need to restart the kernel to use updated packages.


In [4]:
pip install pandas

Note: you may need to restart the kernel to use updated packages.


In [5]:
pip install datasets

Note: you may need to restart the kernel to use updated packages.


In [6]:
from datasets import load_dataset

ds = load_dataset("FiscaAI/synth-ehr-icd10cm-prompt")

README.md:   0%|          | 0.00/392 [00:00<?, ?B/s]

train-00000-of-00002.parquet:   0%|          | 0.00/70.5M [00:00<?, ?B/s]

train-00001-of-00002.parquet:   0%|          | 0.00/70.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/366120 [00:00<?, ? examples/s]

In [7]:
# Convert the 'train' split to a pandas DataFrame
df = ds['train'].to_pandas()

# Split the 'assistant' column into two new columns: 'icd10_code' and 'explanation'
df[['icd10_code', 'explanation']] = df['assistant'].str.split('\n', n=1, expand=True)

# Remove 'Explanation:' from the 'explanation' column
df['explanation'] = df['explanation'].str.replace('Explanation: ', '', regex=False)

# Remove the ICD10 code (before the first colon) from the 'explanation' column
df['explanation'] = df['explanation'].str.split(':', n=1).str[1].str.strip()

# Optionally, remove any leading/trailing whitespace from 'icd10_code'
df['icd10_code'] = df['icd10_code'].str.strip()


In [8]:
import numpy as np
df['codes'] = df['codes'].apply(lambda x: str(x[0]) if isinstance(x, np.ndarray) else x)

In [9]:
import google.generativeai as genai
import os
from tqdm import tqdm

In [10]:
# Filter rows where 'explanation' or 'user' columns are empty or NaN
filtered_df = df[(df["explanation"].notna()) & (df["explanation"] != "") &
                 (df["user"].notna()) & (df["user"] != "")]

# Display the filtered DataFrame
filtered_df

Unnamed: 0,system,user,assistant,codes,icd10_code,explanation
0,You are a medical coding assistant. Your task ...,The patient has a history of right hand injury...,List of ICD10-CM codes: M24.541\n\n Explanatio...,M24.541,List of ICD10-CM codes: M24.541,"Contracture, right hand."
1,You are a medical coding assistant. Your task ...,The patient has a history of left hand injury ...,List of ICD10-CM codes: M24.542\n\n Explanatio...,M24.542,List of ICD10-CM codes: M24.542,"Contracture, left hand."
2,You are a medical coding assistant. Your task ...,The patient has a history of osteoarthritis in...,List of ICD10-CM codes: M24.561\n\n Explanatio...,M24.561,List of ICD10-CM codes: M24.561,"Contracture, right knee."
3,You are a medical coding assistant. Your task ...,The patient has a history of ankylosing spondy...,List of ICD10-CM codes: M24.652\n\n Explanatio...,M24.652,List of ICD10-CM codes: M24.652,"Ankylosis, left hip."
4,You are a medical coding assistant. Your task ...,The patient has a history of osteoarthritis in...,List of ICD10-CM codes: M24.661\n\n Explanatio...,M24.661,List of ICD10-CM codes: M24.661,"Ankylosis, right knee."
...,...,...,...,...,...,...
366115,You are a medical coding assistant. Your task ...,The patient has a history of progressive muscl...,List of ICD10-CM codes: M33.20\n\n Explanation...,M33.20,List of ICD10-CM codes: M33.20,"Polymyositis, organ involvement unspecified."
366116,You are a medical coding assistant. Your task ...,The patient has a history of progressive muscl...,List of ICD10-CM codes: M33.90\n\n Explanation...,M33.90,List of ICD10-CM codes: M33.90,"Dermatopolymyositis, unspecified, organ involv..."
366117,You are a medical coding assistant. Your task ...,The patient has a history of limited cutaneous...,List of ICD10-CM codes: M34.1\n\n Explanation:...,M34.1,List of ICD10-CM codes: M34.1,CR(E)ST syndrome.
366118,You are a medical coding assistant. Your task ...,The patient has a history of Raynaud's phenome...,List of ICD10-CM codes: M34.81\n\n Explanation...,M34.81,List of ICD10-CM codes: M34.81,Systemic sclerosis with lung involvement.


In [11]:
# Create a unique key by combining the ICD-10 code with the row index
filtered_df['unique_key'] = filtered_df['codes'] + '_' + filtered_df.index.astype(str)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_df['unique_key'] = filtered_df['codes'] + '_' + filtered_df.index.astype(str)


In [12]:
raw_df=filtered_df.copy()

In [13]:
# Filter out rows with empty explanations
filtered_df = raw_df[raw_df['explanation'].notna() & raw_df['explanation'].str.strip() != '']

In [14]:
import google.generativeai as genai
import os
from tqdm import tqdm  

# Configure Gemini API
genai.configure(api_key=api_key)  # Use environment variables instead

In [15]:
import time
from tqdm import tqdm

# Function to get embeddings in batches
def get_embeddings_in_batch(text_list):
    try:
        result = genai.embed_content(
            model="models/text-embedding-004",
            content=text_list  # Pass a list of texts for batch processing
        )
        if 'embedding' in result and result['embedding']:
            return result['embedding']
        else:
            return [None] * len(text_list)  # Return a list of None if embeddings are missing
    except Exception as e:
        return [None] * len(text_list)  # Handle errors gracefully by returning None for the batch

In [16]:
import pandas as pd
from tqdm import tqdm
import time
import gc  # For garbage collection
import logging

batch_size = 50  # Define batch size for processing
combined_embeddings_dict = {}

# Set up logging to capture potential issues
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

print("Combining explanations and symptoms into embeddings with batch processing...")
with tqdm(total=len(filtered_df), desc="Processing batches") as pbar:
    for batch_start in range(0, len(filtered_df), batch_size):
        batch_end = min(batch_start + batch_size, len(filtered_df))
        batch_df = filtered_df.iloc[batch_start:batch_end]  # Extract batch
        
        explanations = batch_df['explanation'].tolist()  # Extract explanation texts
        symptoms = batch_df['user'].tolist()  # Extract symptoms texts
        
        try:
            explanation_embeddings = get_embeddings_in_batch(explanations)
            symptom_embeddings = get_embeddings_in_batch(symptoms)
        except Exception as e:
            logging.error(f"Error in processing batch starting at row {batch_start}: {e}")
            continue
        
        for i, (explanation_emb, symptom_emb) in enumerate(zip(explanation_embeddings, symptom_embeddings)):
            combined_embeddings_dict[batch_start + i] = (explanation_emb, symptom_emb)
        
        pbar.update(batch_size)  # Update progress for the processed batch
        
        # Optional: Pause after processing a large number of batches
        if batch_start > 0 and batch_start % (10 * batch_size) == 0:
            logging.info(f"Processed {batch_start} rows. Taking a short break...")
            time.sleep(10)
            
        # Clear variables to free memory
        del explanations, symptoms, explanation_embeddings, symptom_embeddings
        gc.collect()

print("Combined embeddings created successfully.")

Combining explanations and symptoms into embeddings with batch processing...


Processing batches: 366150it [9:41:11, 10.50it/s]                              

Combined embeddings created successfully.





In [21]:
import pickle

# Define the output file name
pickle_file_path = "combined_embeddings.pkl"

# Save the dictionary to a pickle file
with open(pickle_file_path, "wb") as pickle_file:
    pickle.dump(combined_embeddings_dict, pickle_file)

print(f"Combined embeddings saved successfully to {pickle_file_path}.")

Combined embeddings saved successfully to combined_embeddings.pkl.


In [22]:
import os
from IPython.display import FileLink

# Set the working directory to /kaggle/working
os.chdir(r'/kaggle/working')

# Display a download link for the combined_embeddings.pkl file
display(FileLink(r'combined_embeddings.pkl'))

In [24]:
# Extract the embeddings directly from the dictionary
filtered_df['explanation_embedding'] = filtered_df.index.map(
    lambda idx: combined_embeddings_dict[idx][0] if idx in combined_embeddings_dict else None
)
filtered_df['symptom_embedding'] = filtered_df.index.map(
    lambda idx: combined_embeddings_dict[idx][1] if idx in combined_embeddings_dict else None
)