# Versa-spine Embeddings
A fully automated pipeline for information extraction from lumbar spine MRI radiology reports using GPT-4

In [None]:
# Import of packages
import openai
import os

import pandas as pd
import numpy as np

import dotenv
from dotenv import load_dotenv, find_dotenv
from openai import AzureOpenAI
# Load the variables defined in the local .env file into the process environment
load_dotenv('.env')

# Connection settings for the embedding requests; each is None when the variable is not set
API_KEY, API_VERSION, RESOURCE_ENDPOINT = (
    os.getenv(name) for name in ('API_KEY', 'API_VERSION', 'RESOURCE_ENDPOINT')
)


## Data selection and preprocessing

In [None]:
# Central lookup table for the pipeline.
#   - "PathologiesLevel" / "PathologiesLevelSide" / "PathologiesPatient": pathologies keyed by a
#     running integer; each entry has a display name ("name"), a short code ("abbr"), the
#     localisation granularity ("loc") and a comma-plus-space separated synonym list ("syn").
#   - "OutputFormats": output-format instruction text per localisation granularity.
#   - "InterpretationGuidance": trigger patterns with the guidance text that belongs to them.
# The "syn" strings are split into individual search terms further down, so they must not
# contain a trailing comma or misspelled terms.
data_store = {
    "PathologiesLevel": {
        1: {"name": "Endplate Changes",
            "abbr": "endplate",
            "loc": "level",
            "syn": "Modic changes, Modic endplate changes, fibrovascular degenerative changes, fibrovascular changes, fibrofatty degenerative changes, fibrofatty changes, endplate sclerosis, endplate degeneration, endplate irregularity, endplate irregularities, endplate defect, endplate defects, Schmorl's node, schmorls node"
            },
        2: {"name": "Disc Pathology",
            "abbr": "disc",
            "loc": "level",
            "syn": "disc bulge, disc bulging, disc protrusion, disc extrusion, annular fissure, disc tear, annular tear, disc herniation"
            },
        3: {"name": "Spinal Canal Stenosis",
            "abbr": "scs",
            "loc": "level",
            "syn": "spinal canal stenosis, spinal canal narrowing, central canal stenosis, central canal narrowing, canal stenosis, canal narrowing"
            },
        4: {"name": "Facet Joint Arthropathy",
            "abbr": "fj",
            "loc": "level",
            # Fixed spelling: was "facet hypertropthy"
            "syn": "facet joint degeneration, facet joint arthropathy, facet degeneration, facet arthropathy, facet hypertrophy"
            },
    },
    "PathologiesLevelSide": {
        1: {"name": "Lateral Recess Stenosis",
            "abbr": "lrs",
            "loc": "level_side",
            "syn": "lateral recess stenosis, subarticular recess stenosis, recess stenosis, lateral recess narrowing, subarticular recess narrowing, recess narrowing, narrowing of lateral recess, narrowing of subarticular recess, stenosis of lateral recess, stenosis of subarticular recess, effacement of lateral recess, effacement of subarticular recess"
            },
        2: {"name": "Foraminal Stenosis",
            "abbr": "fs",
            "loc": "level_side",
            # No trailing comma here: it would end up inside the last search term
            "syn": "neural foraminal stenosis, neural foraminal narrowing, neural foraminal effacement, neural foraminal nerve root affection, foraminal stenosis, foraminal narrowing, foraminal effacement, foraminal nerve root affection, neuroforaminal stenosis, neuroforaminal narrowing, neuroforaminal effacement, neuroforaminal nerve root affection"
            },
    },
    "PathologiesPatient": {
        1: {"name": "Sacroiliac Joint",
            "abbr": "sij",
            "loc": "patient",
            "syn": "sacroiliac joint degeneration, degeneration of sacroiliac joints, sacro-iliac joint degeneration, degeneration of sacro-iliac joints, SIJ degeneration, degeneration of SIJ, degenerative changes of the sacroiliac joints, degenerative changes of the sacro-iliac joints"
            },
        2: {"name": "Olisthesis",
            "abbr": "olisth",
            "loc": "patient",
            "syn": "anterolisthesis, retrolisthesis, spondylolysis, pseudo-anterolisthesis, pseudo-retrolisthesis, vertebral displacement"
            },
        3: {"name": "Curvature",
            "abbr": "curv",
            "loc": "patient",
            "syn": "scoliosis, levoconvex curvature, dextroconvex curvature, leftward convex curvature, rightward convex curvature, levocurvature, dextrocurvature, levoscoliosis, dextroscoliosis, S-shaped curvature"
            },
        4: {"name": "Fracture",
            "abbr": "frac",
            "loc": "patient",
            "syn": "fracture, osteoporotic fracture, osteoporotic deformation, wedge deformity"
            },
    },
    "OutputFormats": {
        "level": {
            "loc": "level",
            # Fixed typos: "enplate" -> "endplate", "arthropthy" -> "arthropathy", "moderate of no" -> "moderate or no"
            "output": "As a result, give me a list with exactly 20 entries, grouped by pathology. It must contain five entries for each pathology, one for each of the five vertebral levels (L1-2 to L5-S1). For endplate changes, disc pathology and facet joint arthropathy, give only entries of 0 (for pathology absent) or 1 (for pathology present). For spinal canal stenosis the entry must be 0 if there is no spinal canal stenosis, 1 if it is described as mild, 2 if it is described as moderate or no further qualification of stenosis extent is given, and 3 if it is described as severe. Entries in the list must always adhere to this format. Here are three example entries: Endplate Changes L1-L2: 0, Disc Pathology L5-S1: 1, Spinal Canal Stenosis: 3. Ignore levels named ALPHANUMERICID. End the list with 'END OF LIST'."
        },
        "level_side": {
            "loc": "level_side",
            "output": "As a result, give me a list with exactly 20 entries, grouped by pathology; each entry must be on a new line, do not use commas to separate entries. It must contain ten entries for each pathology, two for each of the five vertebral levels (L1-2 to L5-S1), one for the right and one for the left side at each level. The entry must be 0 if there is no mention of a pathology at this level, 1 if the pathology is described as mild, 2 if it is described as moderate or there is no further qualification of the extent of the pathology, and 3 if it is described as severe. Entries in the list must always adhere to this format. Here are two example entries: Foraminal Stenosis L1-L2 right: 2, Lateral Recess Stenosis L5-S1 left: 0. Ignore levels named ALPHANUMERICID. End the list with 'END OF LIST'."
        },
        "patient": {
            "loc": "patient",
            "output": "As a result, give me a list with exactly 4 entries. It must contain one entry for each pathology with corresponding entries of either 1 or 0. Entries in the list must always adhere to this format. Here are two example entries: Sacroiliac joint: 0, Fracture: 1. Ignore levels named ALPHANUMERICID. End the list with 'END OF LIST'."
        },
    },
    "InterpretationGuidance": {
        "bilateral_changes": {
            "patterns": ["left greater than right", "right greater than left", "bilateral"],
            "guidance": "Consider phrases like 'bilateral','left greater than right' or 'right greater than left' as presence of changes on both sides. Please apply this rule strictly in your interpretation."
        },
        "segment_localization": {
            "patterns": ["superior endplate", "inferior endplate"],
            "guidance": "If a change is described as localized at the superior endplate, attribute it to the level above this vertebral body (e.g. superior endplate L2 belongs to the level L1-L2); conversely the inferior endplate belongs to the segment of below its vertebral body (e.g. inferior endplate L3 belongs to the Level L3-L4)."
        },
        "multilevel": {
            "patterns": ["multilevel"],
            "guidance": "If a pathology is described as 'multilevel' assume it is present in all vertebral levels."
        },
        "desiccation": {
            "patterns": ["desiccation"],
            "guidance": "Do not consider desiccation or darkening of discs a pathology."
        },
        "heightloss": {
            "patterns": ["height"],
            "guidance": "Do not consider height loss of a disc a pathology."
        },
        "straight": {
            "patterns": ["straightening"],
            "guidance": "Do not consider straightening or loss of lumbar lordosis a pathology."
        },
        "significant": {
            "patterns": ["significant"],
            "guidance": "Consider pathologies described as 'not significant' as not present."
        },
    },
}



In [None]:
# Extract Pathology Terms for Text Embedding
# Builds one row per synonym: the pathology abbreviation, the term type and the term itself.
search_topic = []
search_terms = []
prompt_type = []

for grouping in ["PathologiesLevel", "PathologiesLevelSide", "PathologiesPatient"]:
    for info in data_store[grouping].values():
        # Split the synonym string once. Stripping each piece and dropping empty ones keeps a
        # stray trailing comma or extra whitespace from leaking into the terms that get embedded.
        terms = [term.strip() for term in info['syn'].split(',') if term.strip()]
        search_topic.extend([info['abbr']] * len(terms))
        prompt_type.extend(["syn_name"] * len(terms))
        search_terms.extend(terms)

df_search_terms = pd.DataFrame({
    'pathology': search_topic,
    'term_type': prompt_type,
    'impressions': search_terms,
})

# Display the resulting table in the notebook
df_search_terms

## Prompting functions

Initialize output columns

In [None]:
def initialize_dataframe_columns(deployment_id, df):
    """Make sure the per-deployment output columns exist on ``df``.

    Adds ``embed_impression_<deployment_id>`` and
    ``embed_total_tokens_<deployment_id>`` filled with ``pd.NA`` when they are
    missing; columns that already exist are left untouched.

    Args:
        deployment_id: Name of the embedding deployment, used as column suffix.
        df: DataFrame to extend; it is modified in place.

    Returns:
        The same DataFrame object that was passed in.
    """
    for prefix in ("embed_impression", "embed_total_tokens"):
        name = f"{prefix}_{deployment_id}"
        if name in df.columns:
            continue
        df[name] = pd.NA
    return df

In [None]:
import json
import requests

# Helper function to send POST requests
def post_request(url, headers, body, timeout=60):
    """Send a POST request and return the response.

    Args:
        url: Target URL.
        headers: Dictionary of HTTP headers to send.
        body: Request payload, passed to ``requests.post`` as ``data``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely on an unresponsive
            endpoint.

    Returns:
        The ``requests.Response`` object of the successful request.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.post(url, headers=headers, data=body, timeout=timeout)
    response.raise_for_status()
    return response

## Embedding pipeline - main instruction

In [None]:
# Extract embeddings
# Requests one embedding per search term and stores the vector and the token usage
# in the per-deployment columns of df_search_terms.
deployment_id = "text-embedding-ada-002" #Options: text-embedding-ada-002, text-embedding-3-large, text-embedding-3-small
df_search_terms = initialize_dataframe_columns(deployment_id, df_search_terms)

# Everything that does not depend on the individual term is built once, outside the loop
embeddings_url = f"{RESOURCE_ENDPOINT}/openai/deployments/{deployment_id}/embeddings?api-version={API_VERSION}"
headers = {
    'Content-Type': 'application/json',
    'api-key': API_KEY
}
embedding_column = f"embed_impression_{deployment_id}"
tokens_column = f"embed_total_tokens_{deployment_id}"

for index, row in df_search_terms.iterrows():
    report_text = row['impressions']

    body = json.dumps({
        "input": report_text,
    })
    response = post_request(embeddings_url, headers, body)

    # Parse the response body once and read both values from it
    payload = json.loads(response.text)
    embedding = payload['data'][0]['embedding']
    total_tokens = payload['usage']['total_tokens']

    # Update the DataFrame
    df_search_terms.at[index, embedding_column] = embedding
    df_search_terms.at[index, tokens_column] = total_tokens
    

In [None]:
# Destination workbook for the search terms and their embeddings.
# TODO: adjust this .xlsx path to where the data should be saved.
save_path = "search_term_embeddings.xlsx"
# NOTE(review): the embedding lists are presumably written as text cells; Excel limits a cell
# to 32,767 characters, so very long embeddings may get truncated - verify after saving.
df_search_terms.to_excel(save_path, index=False)

## Embedding Analysis

In [None]:
# --- Embedding functions ---
def string_to_array(s):
    """Parse the text form of a number list (e.g. ``"[0.1, -2.0]"``) into floats.

    Square brackets are removed, the text is split on ``", "`` and any comma
    left inside a piece is dropped. Pieces that end up empty, a lone ``"-"``
    or a single space are skipped; every other piece is converted with
    ``float``.

    Args:
        s: String representation of the list.

    Returns:
        A list of floats in the order they appear in ``s``.
    """
    skipped = {'', '-', ' '}
    unbracketed = s.translate(str.maketrans('', '', '[]'))
    pieces = (piece.replace(',', '') for piece in unbracketed.split(', '))
    return [float(piece) for piece in pieces if piece not in skipped]
    
def get_embeddings(df, col):
    """Collect the embeddings stored as text in one column into a NumPy array.

    Args:
        df: DataFrame holding the embeddings.
        col: Name of the column whose cells contain the embedding strings.

    Returns:
        ``np.ndarray`` built from one parsed embedding per row, in row order.
    """
    # Each cell is converted from its string form by string_to_array
    return np.array([string_to_array(cell) for cell in df[col]])

# --- Similarity functions --- 
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import silhouette_score

# --- Clustering functions --- 
import umap.umap_ as umap

# --- Plotting functions --- 
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colorbar import Colorbar

### Cosine similarity of each pathology term embedding with agglomerative clustering

In [None]:
# Load embeddings from .xlsx
# TODO: point this at the .xlsx file that holds the saved term embeddings.
save_path = "search_term_embeddings.xlsx"
df_syn = pd.read_excel(save_path)

# Extract embeddings from a dataframe to a matrix
embed_syns = get_embeddings(df_syn, 'embed_impression_text-embedding-ada-002')

# Calculate cosine similarity
similarity_matrix = cosine_similarity(embed_syns)

# One colour scale shared by the heatmap and the manually added colorbar; without passing it
# to clustermap the heatmap would scale to the data range and could disagree with the colorbar.
sim_vmin = min(0.7, similarity_matrix.min())
sim_vmax = similarity_matrix.max()

# Create a unique color map for the categories
categories = ['scs', 'lrs', 'fs',  'fj', 'sij', 'endplate', 'disc', 'curv', 'olisth', 'frac']
colors = sns.color_palette('CMRmap', len(categories))  # Choose a color palette
category_colors = {category: color for category, color in zip(categories, colors)}

# Map categories to colors for tick labels
row_colors = df_syn['pathology'].map(category_colors).to_numpy()
col_colors = df_syn['pathology'].map(category_colors).to_numpy()

# Create the cluster map (clustermap opens its own figure, so no separate plt.figure() is needed)
cluster_map = sns.clustermap(
    similarity_matrix,
    annot=False,
    method='average', # 'single','complete','average','weighted','centroid','median','ward'
    cmap='coolwarm',
    vmin=sim_vmin,
    vmax=sim_vmax,
    xticklabels=df_syn['impressions'],
    yticklabels=df_syn['impressions'],
    row_colors=row_colors,
    col_colors=col_colors, # option 1
    #col_cluster=False, # option 2
    cbar_pos=None,
    #orientation='horizontal',
    figsize=(15, 15)
)

# Align and size the axis labels for better readability
plt.setp(cluster_map.ax_heatmap.xaxis.get_majorticklabels(),  ha='right', fontsize=10) #rotation=45,
plt.setp(cluster_map.ax_heatmap.yaxis.get_majorticklabels(), fontsize=10)

# Add a horizontal colorbar at the top
cbar_ax = cluster_map.fig.add_axes([0.19, 0.86, 0.58, 0.018])  # [left, bottom, width, height]
norm = plt.Normalize(vmin=sim_vmin, vmax=sim_vmax)
sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=norm)
sm.set_array([])
cbar = Colorbar(ax=cbar_ax, mappable=sm, orientation='horizontal')
#cbar.set_label('Cosine Similarity', fontsize=12)
cbar.ax.tick_params(length=5, labelsize=10)
cbar.outline.set_visible(False) # Remove the border around the colorbar

# Add a legend for the category colors (zero-size bars only serve as legend handles)
for label, color in category_colors.items():
    cluster_map.ax_col_dendrogram.bar(0, 0, color=color, label=label, linewidth=0)
cluster_map.ax_col_dendrogram.legend(loc="upper left", ncol=10, bbox_to_anchor=(0, .6))

# Save the figure
save_path = "cos_similarity.png"
plt.savefig(save_path, dpi=300)

plt.show()

### Unsupervised UMAP clustering of pathology term embeddings

In [None]:
#Unsupervised UMAP hyperparameter sweep
# Fits UMAP for every combination in the grid and keeps the combination whose projection
# gives the highest silhouette score for the pathology labels.
from itertools import product

param_grid = {
    'n_neighbors': [5, 10, 15, 20, 25],
    'min_dist': [0.01, 0.05, 0.1, 0.15],
    'negative_sample_rate': [2, 4, 6],
    'local_connectivity': [1, 2, 3],
}

# Silhouette scores can be negative, so start below every possible score; this way
# best_params is always defined after the sweep, even if no combination scores above 0.
best_score = -np.inf
best_params = None

# product() walks the grid in the same order as nested loops over the keys above
for values in product(*param_grid.values()):
    params = dict(zip(param_grid, values))
    umap_result = umap.UMAP(random_state=11, **params).fit_transform(embed_syns)

    silhouette_avg = silhouette_score(umap_result, df_syn['pathology'])
    print(silhouette_avg)

    if silhouette_avg > best_score:
        best_params = params
        print(best_params)
        best_score = float(silhouette_avg)
                    


In [None]:
# Show the hyperparameter combination selected by the sweep
print(best_params)
# Recorded output, presumably from an earlier run:
#{'n_neighbors': 10, 'min_dist': 0.01, 'negative_sample_rate': 6, 'local_connectivity': 3}

In [None]:
# Unsupervised UMAP re-fitted with the hyperparameters selected by the sweep
# (the sweep keeps the combination with the highest silhouette score)
reducer = umap.UMAP(
    n_neighbors=best_params['n_neighbors'],
    min_dist=best_params['min_dist'],
    negative_sample_rate=best_params['negative_sample_rate'],
    local_connectivity=best_params['local_connectivity'],
    random_state=11,
)
umap_result = reducer.fit_transform(embed_syns)

# Silhouette score of the pathology labels in the projection
silhouette_avg = silhouette_score(umap_result, df_syn['pathology'])
print(silhouette_avg)

# One figure, two panels next to each other
fig, axes = plt.subplots(1, 2, figsize=(25,6))
axes = axes.flatten()

# Left panel: coloured by pathology; right panel: coloured by individual term
term_labels = list(df_syn['impressions'])
unique_labels = sorted(set(term_labels))
panels = [
    (list(df_syn['pathology']), 'tab10'),
    (term_labels, sns.color_palette("nipy_spectral", len(unique_labels)+2)),
]

for ii, (labels, palette) in enumerate(panels):
    ax = axes[ii]

    # Scatter plot of the UMAP projection
    scatter = sns.scatterplot(ax=ax, x=umap_result[:, 0], y=umap_result[:, 1],
                              hue=labels, palette=palette, s=30, alpha=0.7, edgecolor='k')

    if ii == 1:
        # The term legend is large: place it outside the axes, in three columns
        handles, legend_labels = scatter.get_legend_handles_labels()
        ax.legend(
            handles=handles, labels=legend_labels,
            loc='upper left', bbox_to_anchor=(1, 1),
            ncol=3, fontsize='small', frameon=False
        )
        # Shrink the plotting area so the legend fits into the saved image
        fig.subplots_adjust(left=0.2, right=0.5, top=0.9, bottom=0.1)

    # Remove the axis ticks
    ax.set(xticks=[], yticks=[])

fig.suptitle('Unsupervised UMAP of Embeddings')

# Save the figure
save_path = "unsupervised_UMAP_syn_embeddings.png"
plt.savefig(save_path, dpi=300)

plt.show()