In [None]:
pip install -r requirements.txt

In [59]:
import os
import pickle
import random
import time
from collections import defaultdict, Counter
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.stats import halfnorm
from sklearn.manifold import TSNE
from sklearn.metrics import roc_curve, auc, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [61]:
# Read the three raw event tables (conditions, medications, procedures)
# from the data/ directory, in that order.
conditions_df, medications_df, procedures_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        'data/conditions.csv',
        'data/medications.csv',
        'data/procedures.csv',
    )
)


In [62]:
# Lookup table used to translate condition codes (SNOMED) into
# diagnosis codes (ICD-10).
snomed_to_icd = pd.read_csv(filepath_or_buffer='data/snomed_icd_10_map.csv')

In [63]:
# Attach ICD-10 columns to every condition row through the SNOMED -> ICD-10 map.
# The left merge keeps conditions whose SNOMED code has no entry in the map;
# their ICD columns are NaN and pass the filter below unchanged.
diagnoses_df = conditions_df.merge(snomed_to_icd, left_on='CODE', right_on='SNOMED ID', how='left')

#filter out rows where no mapping was found (ICD ID = "\N")
# .copy() makes the filtered frame independent, so the column assignment below
# does not raise SettingWithCopyWarning.
diagnoses_df = diagnoses_df[diagnoses_df['ICD ID'] != '\\N'].copy()

#strip all '?' characters from ICD ID column
# regex=False: '?' is a regex quantifier, so request a literal replacement
# explicitly instead of relying on the pandas-version-dependent default.
diagnoses_df['ICD ID'] = diagnoses_df['ICD ID'].str.replace('?', '', regex=False)

# Keep only the columns used downstream; copy again because later cells
# assign to columns of this frame.
diagnoses_df = diagnoses_df[['START', 'STOP', 'PATIENT', "ENCOUNTER", "CODE", "DESCRIPTION", "SNOMED ID", "ICD ID", "ICD Name"]].copy()

#filter out rows of procedures_df where code = 428191000124101
# .copy() for the same reason: a later cell assigns procedures_df['DATE'].
procedures_df = procedures_df[procedures_df['CODE'] != 428191000124101].copy()

In [64]:
#create patient_visits
# Parse the event-date column of each source table in place.
for frame, date_col in (
    (diagnoses_df, 'START'),
    (medications_df, 'START'),
    (procedures_df, 'DATE'),
):
    frame[date_col] = pd.to_datetime(frame[date_col])

# Every source is reduced to the same (PATIENT, DATE, CODE, CODE_TYPE) layout.

# Diagnosis: Use ICD ID
diag_visits = (
    diagnoses_df[['PATIENT', 'START', 'ICD ID']]
    .rename(columns={'START': 'DATE', 'ICD ID': 'CODE'})
    .assign(CODE_TYPE='diagnosis')
)

# Medication: Use RxNorm CODE
med_visits = (
    medications_df[['PATIENT', 'START', 'CODE']]
    .rename(columns={'START': 'DATE'})
    .assign(CODE_TYPE='medication')
)

# Procedures: Use HCPCS CODE
proc_visits = procedures_df[['PATIENT', 'DATE', 'CODE']].assign(CODE_TYPE='procedure')

In [65]:
# Stack diagnoses, medications and procedures into one long table and order
# each patient's events chronologically.
combined_visits = pd.concat([diag_visits, med_visits, proc_visits], ignore_index=True)
combined_visits = combined_visits.sort_values(by=['PATIENT', 'DATE'])

# patient id -> list of (DATE, CODE, CODE_TYPE) tuples in sorted order.
patients_visits = defaultdict(list)
#filter out rows where CODE is nan
# Drop missing codes once, then walk the columns directly: this yields the same
# tuples as a per-row iterrows() loop without building a Series for every row.
valid_visits = combined_visits.dropna(subset=['CODE'])
for patient, visit_date, visit_code, visit_type in zip(
    valid_visits['PATIENT'],
    valid_visits['DATE'],
    valid_visits['CODE'],
    valid_visits['CODE_TYPE'],
):
    patients_visits[patient].append((visit_date, visit_code, visit_type))

In [None]:
#create (context_window, label_medication) training pairs
from collections import defaultdict, Counter
from datetime import datetime, timedelta
import random
from scipy.stats import halfnorm

# Flatten all codes
all_codes = [code for visits in patients_visits.values() for _, code, _ in visits]
code_freq = Counter(all_codes)
total_codes = sum(code_freq.values())

# Compute sampling probability
# word2vec-style subsampling: a code with relative frequency f is kept with
# probability sqrt(t / f), capped at 1. The previous expression,
# min(1.0, 1 / sqrt(f)), was always 1.0 because f <= 1, so no code was ever
# downsampled.
# NOTE(review): t = 1e-3 is a commonly used word2vec default; tune it for this
# vocabulary if frequent codes are dropped too aggressively or too rarely.
DOWNSAMPLE_THRESHOLD = 1e-3
downsample_probs = {
    code: min(1.0, float(np.sqrt(DOWNSAMPLE_THRESHOLD / (freq / total_codes))))
    for code, freq in code_freq.items()
}

def sample_context_window(std_weeks=40):
    """Return a random look-back window as a timedelta of whole weeks.

    The length in weeks is drawn from a half-normal distribution with scale
    ``std_weeks`` and truncated to an integer, so a zero-length window is
    possible.
    """
    drawn_weeks = halfnorm.rvs(scale=std_weeks)
    whole_weeks = int(drawn_weeks)
    return timedelta(days=7 * whole_weeks)

def bin_by_2_months(events):
    """Group (date, code) events into calendar 2-month bins.

    Args:
        events: iterable of (date, code) pairs; ``date`` must expose ``.year``
            and ``.month``.

    Returns:
        defaultdict mapping (year, bin_index) -> list of (date, code), where
        bin_index runs 0..5 (Jan-Feb, Mar-Apr, ..., Nov-Dec).
    """
    binned = defaultdict(list)
    for date, code in events:
        # Months are 1-based, so shift before dividing; ``month // 2`` would
        # leave January and December alone and produce 7 uneven bins per year.
        bin_key = (date.year, (date.month - 1) // 2)
        binned[bin_key].append((date, code))
    return binned

# Collected (context_code, medication_code) pairs, both stored as strings.
skip_gram_pairs = []

for patient_id, events in patients_visits.items():
    # Split the patient's timeline: medications are the targets, every other
    # event (diagnoses + procedures) is candidate context.
    med_events = []
    context_events = []
    for event_date, event_code, event_type in events:
        bucket = med_events if event_type == 'medication' else context_events
        bucket.append((event_date, event_code))

    for med_date, med_code in med_events:
        # Randomly sized look-back window that ends right before the medication.
        window_start = med_date - sample_context_window()

        context_in_window = [
            (ctx_date, ctx_code)
            for ctx_date, ctx_code in context_events
            if window_start <= ctx_date < med_date
        ]
        if len(context_in_window) == 0:
            continue

        # One randomly chosen event per 2-month bin; it is kept with the
        # frequency-based downsampling probability of its code.
        for bin_events in bin_by_2_months(context_in_window).values():
            _chosen_date, chosen_code = random.choice(bin_events)
            keep_probability = downsample_probs.get(chosen_code, 1.0)
            if random.random() < keep_probability:
                skip_gram_pairs.append((str(chosen_code), str(med_code)))

print(skip_gram_pairs[:5])

In [None]:
# Vocabulary = every code that occurs on either side of a skip-gram pair
unique_codes = {code for pair in skip_gram_pairs for code in pair}

# Number the codes in sorted order and keep both lookup directions
index_to_code = dict(enumerate(sorted(unique_codes)))
code_to_index = {code: idx for idx, code in index_to_code.items()}

vocab_size = len(code_to_index)
print(f"Vocabulary size: {vocab_size}")

In [69]:
#load skip-grams using dataloader
# Translate each (context, target) code pair into vocabulary indices,
# skipping any pair with a code missing from the vocabulary.
encoded_pairs = []
for h, t in skip_gram_pairs:
    if h in code_to_index and t in code_to_index:
        encoded_pairs.append((code_to_index[h], code_to_index[t]))

class SkipGramDataset(Dataset):
    """Dataset over pre-encoded (context_index, target_index) pairs.

    Indexing returns the two indices of a pair as ``torch.long`` tensors.
    """

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        context_idx, target_idx = self.pairs[idx]
        context_tensor = torch.tensor(context_idx, dtype=torch.long)
        target_tensor = torch.tensor(target_idx, dtype=torch.long)
        return context_tensor, target_tensor

# Wrap the encoded pairs and serve them in shuffled mini-batches of 1024.
dataset = SkipGramDataset(encoded_pairs)
dataloader = DataLoader(dataset=dataset, batch_size=1024, shuffle=True)

#use skip_gram_pairs to create embeddings
class SkipGramDataset(Dataset):
    """Dataset over skip-gram pairs, given either as codes or as indices.

    NOTE: this definition replaces the SkipGramDataset declared earlier in the
    same cell. ``code_to_index`` is therefore optional, so that the earlier
    call style ``SkipGramDataset(encoded_pairs)`` keeps working once this
    class is in scope.

    Args:
        skip_gram_pairs: sequence of (context, target) pairs. With
            ``code_to_index`` the entries are codes looked up in that mapping;
            without it they must already be integer indices.
        code_to_index: optional mapping from code to vocabulary index.
    """

    def __init__(self, skip_gram_pairs, code_to_index=None):
        self.pairs = skip_gram_pairs
        self.code_to_index = code_to_index

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        context_code, target_code = self.pairs[idx]
        if self.code_to_index is not None:
            # Pairs hold raw codes: translate them to indices on access.
            context_code = self.code_to_index[context_code]
            target_code = self.code_to_index[target_code]
        return torch.tensor(context_code, dtype=torch.long), \
               torch.tensor(target_code, dtype=torch.long)

In [70]:
#setup word2vec model
class IndicationEmbeddingModel(nn.Module):
    """Embedding lookup followed by a linear layer that scores every code.

    Args:
        vocab_size: number of codes. When omitted, ``len(code_to_index)`` is
            read at construction time, so the model follows the current
            vocabulary rather than the one that existed when the class was
            defined.
        embed_dim: size of each embedding vector.
        dropout: dropout probability applied to the looked-up embeddings.
    """

    def __init__(self, vocab_size=None, embed_dim=50, dropout=0.2):
        super().__init__()
        if vocab_size is None:
            # Resolved lazily: a default of len(code_to_index) in the signature
            # would be frozen at class-definition time and go stale if the
            # vocabulary is rebuilt.
            vocab_size = len(code_to_index)
        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.output_weights = nn.Linear(embed_dim, vocab_size)

    def forward(self, context_idxs):
        """Return unnormalized scores over the vocabulary for each context index."""
        x = self.embeddings(context_idxs)
        x = self.dropout(x)
        logits = self.output_weights(x)
        return logits

In [None]:
#train model
epochs = 50
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = IndicationEmbeddingModel(vocab_size).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
criterion = nn.CrossEntropyLoss()

# Per-epoch bookkeeping: summed batch loss and wall-clock duration.
losses, epoch_times = [], []

for epoch in range(epochs):
    epoch_start_time = time.time()
    total_loss = 0
    for context_batch, target_batch in dataloader:
        context_batch = context_batch.to(device)
        target_batch = target_batch.to(device)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        loss = criterion(model(context_batch), target_batch)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    losses.append(total_loss)
    epoch_times.append(time.time() - epoch_start_time)

    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

print(f"Average epoch run time: {np.mean(epoch_times):.4f} seconds")


In [None]:
#plot training loss
# One point per epoch; the x axis is the 1-based epoch number.
fig, ax = plt.subplots(figsize=(10, 6))
epoch_numbers = range(1, len(losses) + 1)
ax.plot(epoch_numbers, losses)
ax.set_title('Training Loss per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.grid(True)
fig.tight_layout()
plt.show()

In [73]:
# Normalize embeddings to unit vectors
# Detach from autograd and move to the CPU first: the model may have been
# trained on a GPU, and the later cells convert these embeddings to NumPy,
# which is only possible for CPU tensors.
raw_embeddings = model.embeddings.weight.detach().cpu()
normalized_embeddings = F.normalize(raw_embeddings, p=2, dim=1)

In [None]:
#plot embeddings
index_to_code = {idx: code for code, idx in code_to_index.items()}

# Remember the first code type seen for every code. Keys are strings so they
# match the string codes used in the vocabulary.
code_type_dict = {}
for visits in patients_visits.values():
    for _, code, code_type in visits:
        code_type_dict.setdefault(str(code), code_type)

# Hand scikit-learn an explicit CPU NumPy array instead of a torch tensor
# (a tensor living on a GPU cannot be converted implicitly).
embedding_array = normalized_embeddings.detach().cpu().numpy()
n_codes = len(embedding_array)

# t-SNE requires perplexity < number of samples; cap it for small vocabularies.
tsne = TSNE(n_components=2, perplexity=min(30, max(1, n_codes - 1)), random_state=42)
embeddings_2d = tsne.fit_transform(embedding_array)

type_colors = {
    'diagnosis': 'tab:blue',
    'medication': 'tab:green',
    'procedure': 'tab:red'
}

plt.figure(figsize=(16, 12))
# Label at most 100 random codes; sampling without replacement raises if more
# items are requested than exist, so cap the sample at the vocabulary size.
sample_indices = np.random.choice(n_codes, size=min(100, n_codes), replace=False)

for idx in sample_indices:
    x, y = embeddings_2d[idx]
    code = index_to_code[idx]
    code_type = code_type_dict.get(code, 'unknown')
    # Codes with an unknown type fall back to gray.
    color = type_colors.get(code_type, 'gray')
    plt.scatter(x, y, color=color, s=40, alpha=0.6)
    plt.text(x + 0.2, y, code, fontsize=8)

# Empty scatter calls only create one legend entry per code type.
for label, color in type_colors.items():
    plt.scatter([], [], c=color, label=label)
plt.legend()
plt.title("Visualization of Indication Embeddings")
plt.xlabel("t-SNE Dim 1")
plt.ylabel("t-SNE Dim 2")
plt.grid(True)
plt.tight_layout()
plt.show()

In [76]:
#compare indication embeddings to MEDI dataset
def _load_medi_table(csv_path):
    """Read one MEDI CSV, rename/stringify its code columns, keep ICD10CM rows."""
    table = pd.read_csv(csv_path).rename(columns={'RXCUI': 'rxnorm_code', 'CODE' : 'ICD10'})
    table = table.astype({'rxnorm_code': str, 'ICD10': str})
    #filter out rows where VOCABULARY != ICD10CM
    return table[table['VOCABULARY'] == 'ICD10CM']


medi_df = _load_medi_table('data/MEDI-2.csv')
medi_hps = _load_medi_table('data/MEDI-2_HPS.csv')

In [None]:
# Positives: every (drug, diagnosis) pair listed in the MEDI high-precision set.
positive_pairs = set(zip(medi_hps['rxnorm_code'], medi_hps['ICD10']))
drug_codes = set(medi_df['rxnorm_code'])
diagnosis_codes = set(medi_df['ICD10'])

# Candidate negatives: all drug x diagnosis combinations that are not positives.
# The codes are iterated in sorted order because the iteration order of a set of
# strings differs between Python sessions (hash randomization); without sorting,
# the seeded sample below would not be reproducible.
all_possible_pairs = [(d, icd) for d in sorted(drug_codes) for icd in sorted(diagnosis_codes)]
negative_pairs = [pair for pair in all_possible_pairs if pair not in positive_pairs]

#sample negatives for balance
random.seed(42)
# random.sample raises ValueError when k exceeds the population size, so cap k.
n_negatives = min(750000, len(negative_pairs))
negative_pairs = random.sample(negative_pairs, k=n_negatives)
print(f"Positives: {len(positive_pairs)} | Negatives: {len(negative_pairs)}")

In [None]:
#compute cosine similarity
# .cpu() is required when the embeddings still live on a GPU: .numpy() only
# works on CPU tensors. The rows are unit-normalized, so a dot product is the
# cosine similarity.
emb_matrix = normalized_embeddings.detach().cpu().numpy()
def get_similarity(code1, code2):
    """Return the dot product of the two codes' embeddings.

    Returns None when either code is missing from ``code_to_index``.
    """
    idx1 = code_to_index.get(code1)
    idx2 = code_to_index.get(code2)
    if idx1 is None or idx2 is None:
        return None
    return np.dot(emb_matrix[idx1], emb_matrix[idx2])

# Score pairs
pairs = list(positive_pairs) + list(negative_pairs)
labels = [1]*len(positive_pairs) + [0]*len(negative_pairs)

# NOTE(review): pairs with a code outside the vocabulary are scored 0.0 rather
# than dropped, so they appear as tied scores in the ROC analysis — confirm
# this is the intended evaluation protocol.
scores = []
for drug, diagnosis in tqdm(pairs):
    sim = get_similarity(drug, diagnosis)
    scores.append(sim if sim is not None else 0.0)

In [None]:
#plot roc curve
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

# ROC of the embedding-similarity scores against the MEDI labels.
fpr, tpr, thresholds = roc_curve(labels, scores)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots(figsize=(8,6))
ax.plot(fpr, tpr, label=f'Embedding Similarity (AUC = {roc_auc:.3f})')
# Diagonal = performance of a random classifier.
ax.plot([0, 1], [0, 1], linestyle='--', color='gray')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('ROC Curve: Reproduced Indication Embedding Similarity')
ax.legend(loc='lower right')
ax.grid(True)
plt.show()

The code in the following section borrows heavily from the script that the author of the original paper made available on GitHub.

In [85]:
def _load_pickle(path):
    """Unpickle ``path`` and close the file handle afterwards."""
    # NOTE(security): pickle.load can execute arbitrary code; only load files
    # from a trusted source.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Published embedding matrix and its vocabulary table (first column = index).
evemb = pd.read_table('data/indication_embedding.csv', index_col=0)
voc = pd.read_table('data/indications_vocab.tsv', sep="\t",index_col=0)

#drug to therapeutic group mapping
(g2thrgrds, thrgrds) = _load_pickle("data/g2thrgrds.pkl")
(icd2ccs, ccsdo) = _load_pickle("data/icd2ccs.pkl")
icd2ccs_dict = dict(zip(icd2ccs.index, icd2ccs['ccs']))
# NOTE(review): `p in pc` is a containment test; if the values are strings this
# is a substring match rather than equality — confirm that is intended.
ccs2icd = {p:[i for i,pc in icd2ccs_dict.items() if p in pc] for p in set(icd2ccs_dict.values())}

icd2phe = _load_pickle("data/icd2phe.03.18.pkl")
phe2icd = {p:[i for i,pc in icd2phe.items() if p in pc] for p in set(icd2phe.values())}

cut = 20000
## filter to keep Dx and Rx that are in more than 20000 patients
dxvoc = voc.loc[(voc['type']=='dx') & (voc['id']>0) & (voc['ct'] > cut),:]
rxvoc = voc.loc[(voc['type']=='rx') & (voc['id']>0) & (voc['ct'] > cut),:]

# Embedding rows of the selected drug and diagnosis codes.
rxemb = evemb.loc[rxvoc['code'],:]
dxemb = evemb.loc[dxvoc['code'],:]

In [None]:
# Dot product of every drug embedding with every diagnosis embedding; after the
# final transpose the rows are diagnosis codes and the columns are drug codes.
rxdotdx_sel = rxemb.dot(dxemb.transpose()).transpose()

# Load the MEDI lookups with context managers so the file handles are closed.
# NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
with open("data/MEDI_01212013_HPS.pkl",'rb') as hps_file:
    hps = pickle.load(hps_file)
with open("data/MEDI_01212013.pkl",'rb') as medi_file:
    medi = pickle.load(medi_file)

# isind[dx, rx] is True when the drug is listed under hps[dx].
# presumably hps/medi map a diagnosis code to its indicated drug codes — verify.
isind = []
for row in rxdotdx_sel.index:
    isind.append(rxdotdx_sel.columns.isin(hps[row]))
isind = pd.DataFrame(isind,index=rxdotdx_sel.index,columns = rxdotdx_sel.columns)

## filter to ICD and Drug that have at least one indication relationship in the High Precision Set
## NOTE(review): the removal of nonspecific drugs (prescribed for more than 2% of
## diseases) mentioned by the original script is not performed in this cell.
selcol =(isind.sum(axis=0) >= 1)
selrow =(isind.sum(axis=1) >= 1)

isind = isind.loc[selrow,selcol]
rxdotdx_sel = rxdotdx_sel.loc[selrow,selcol]

### full set -- "low precision"
mediisind = []
for row in rxdotdx_sel.index:
    mediisind.append(rxdotdx_sel.columns.isin(medi[row]))
mediisind = pd.DataFrame(mediisind,index=rxdotdx_sel.index,columns = rxdotdx_sel.columns)

### negative = in neither the High precision set nor the full set of indications
neg = rxdotdx_sel.values.reshape(-1,1)[(~isind & ~mediisind).values.reshape(-1,1)==True]

### positive = high precision set
pos = rxdotdx_sel.values.reshape(-1,1)[isind.values.reshape(-1,1)==True]
# Compute the AUC once and reuse it for the printout.
original_roc = roc_auc_score(np.append(np.ones(pos.shape),np.zeros(neg.shape)), np.append(pos,neg))
print("ROC AUC: {:1.2f}".format(original_roc))

# Scores and labels in matching order: positives first, then negatives.
y_scores = np.append(pos, neg)
y_true = np.append(np.ones(pos.shape), np.zeros(neg.shape))

# Compute ROC
fpr, tpr, thresholds = roc_curve(y_true, y_scores)

# Plot ROC Curve
plt.figure(figsize=(8,6))
plt.plot(fpr, tpr, label=f'ROC AUC = {original_roc:.2f}', color='darkblue', linewidth=2)
# Diagonal = performance of a random classifier.
plt.plot([0, 1], [0, 1], 'k--', alpha=0.5)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
# Title spelling corrected ("Similaririty" -> "Similarity").
plt.title('ROC Curve: Original Indication Embedding Similarity')
plt.legend(loc='lower right')
plt.grid(True)
plt.tight_layout()
plt.show()