In [1]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from torch.nn.functional import normalize
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import cdist

In [2]:
# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Width of every disease/drug embedding vector. It must match the size of the
# encoder output fed to the generator and the Generator's input/output layers.
embedding_dim = 768

cuda


In [3]:
# Disease-drug pairs with precomputed embeddings for both the disease and the drug.
df = pd.read_parquet(path="data/final_data_cleaned.parquet", engine="pyarrow")
df

Unnamed: 0,disease,drug,disease_embedding,drug_embedding
0,21-hydroxylase deficiency,17-alpha-hydroxyprogesterone,"[-0.09322444, 0.14657274, 0.2630737, -0.262951...","[-0.19820696, 0.5990576, 0.012646209, 0.000393..."
1,21-hydroxylase deficiency,dexamethasone,"[-0.09322444, 0.14657274, 0.2630737, -0.262951...","[-0.39190397, 0.13832054, -0.02711408, 0.05101..."
2,21-hydroxylase deficiency,hydrocortisone,"[-0.09322444, 0.14657274, 0.2630737, -0.262951...","[-0.31218493, -0.015504928, -0.069101125, 0.09..."
3,5q-syndrome,lenalidomide,"[0.053790092, -0.19379048, 0.11191088, -0.3218...","[-0.14175633, 0.3328832, 0.110642955, -0.09579..."
4,"ACTH Syndrome, Ectopic",corticotropin,"[-0.26318076, -0.27482936, 0.33002627, -0.5014...","[-0.14670405, 0.15948579, -0.048798714, 0.0340..."
...,...,...,...,...
18761,vasomotor symptom,fezolinetant,"[0.022215698, 0.35601804, 0.041473925, -0.2789...","[-0.035705805, 0.26630142, -0.019500958, 0.047..."
18762,vasomotor symptom,gabapentin,"[0.022215698, 0.35601804, 0.041473925, -0.2789...","[-0.08150145, -0.107544206, 0.111592196, -0.01..."
18763,vasomotor symptom,paroxetine,"[0.022215698, 0.35601804, 0.041473925, -0.2789...","[-0.04558692, 0.5095056, 0.086280674, 0.170817..."
18764,vasomotor symptom,progesterone,"[0.022215698, 0.35601804, 0.041473925, -0.2789...","[-0.26757017, 0.12531605, -0.12637241, -0.0235..."


In [4]:
def split_dataframe(df, random_seed=42):
    """Split disease-drug pairs into train/test sets by disease, stratified by row count.

    Diseases are grouped by how many rows (disease-drug pairs) they have. Within
    each group, 10% of the diseases (rounded up, so at least one) are drawn at
    random and *all* of their rows go to the test set. The rows of the remaining
    diseases in the group go to the training set. A disease therefore never
    appears in both splits.

    Args:
        df: DataFrame with a ``disease`` column (one row per disease-drug pair).
        random_seed: Seed for the disease sampling. The same seed and input give
            the same split.

    Returns:
        ``(df_train, df_test)``, each with a fresh ``RangeIndex``. Groups are
        emitted in ascending order of row count, and rows keep their original
        relative order within a group.

    Raises:
        AssertionError: if the split loses or duplicates rows (e.g. rows whose
            ``disease`` is NaN, which ``groupby`` drops).
    """
    # Local generator seeded exactly like the legacy global one (np.random.seed),
    # so the sampled diseases are unchanged while global NumPy state is untouched.
    rng = np.random.RandomState(random_seed)

    disease_counts = df.groupby('disease').size()

    train_parts = []
    test_parts = []
    # Visit every observed count (ascending) instead of a fixed 1..10, so diseases
    # with more than 10 rows are split too instead of being silently skipped.
    for n_prescriptions in sorted(disease_counts.unique()):
        diseases_with_n = disease_counts[disease_counts == n_prescriptions].index.tolist()

        # ceil() of a positive number is already >= 1, so each group contributes
        # at least one test disease.
        n_to_test = int(np.ceil(len(diseases_with_n) * 0.1))
        test_diseases = rng.choice(diseases_with_n, size=n_to_test, replace=False)

        in_group = df['disease'].isin(diseases_with_n)
        in_test = df['disease'].isin(test_diseases)
        test_parts.append(df[in_test])
        train_parts.append(df[in_group & ~in_test])

    # Concatenate once (no quadratic growth, no concat with an empty object-dtype frame).
    empty = df.iloc[0:0]
    df_train = pd.concat(train_parts) if train_parts else empty
    df_test = pd.concat(test_parts) if test_parts else empty

    assert len(df) == len(df_train) + len(df_test), "Row counts don't match"

    return df_train.reset_index(drop=True), df_test.reset_index(drop=True)

# Replace the full table with its training split; held-out diseases go to df_test_filtered.
df, df_test_filtered = split_dataframe(df=df, random_seed=42)

In [5]:
# Shuffle df_train and df_test
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
df_test_filtered = df_test_filtered.sample(frac=1, random_state=42).reset_index(drop=True)

In [6]:
# Training split after shuffling (excludes every row of the held-out diseases)
df

Unnamed: 0,disease,drug,disease_embedding,drug_embedding
0,Shigellosis,ceftriaxone,"[-0.17925695, 0.093315504, 0.14118157, -0.0423...","[0.18962991, 0.13530359, 0.024884265, 0.036119..."
1,Diabetic Peripheral Neuropathy,gabapentin,"[0.27419174, 0.021408416, 0.15780602, -0.17473...","[-0.08150145, -0.107544206, 0.111592196, -0.01..."
2,Common Migraine,zavegepant,"[-0.2348871, -0.12846316, 0.22506504, -0.36667...","[0.061711952, 0.07011112, -0.10979448, 0.27578..."
3,Systemic mycosis,amphotericin B,"[0.083499566, 0.051283106, 0.1795818, -0.31725...","[0.1292809, 0.24959983, 0.037212547, 0.1430497..."
4,Malignant tumor of parathyroid gland,cinacalcet,"[-0.11646998, -0.16838503, 0.2693627, -0.24781...","[-0.08332437, 0.21648774, -0.036853496, -0.022..."
...,...,...,...,...
16863,Urethritis,erythromycin estolate,"[0.31411487, 0.3275367, 0.13985643, -0.0139704...","[-0.018034114, 0.1684173, -0.14231458, 0.11360..."
16864,Chronic pain,fentanyl,"[0.13405961, 0.14006968, 0.36103272, -0.153927...","[-0.20299998, 0.049612034, -0.20746674, 0.1787..."
16865,Strongyloidiasis,thiabendazole,"[-0.028201543, 0.2958583, 0.19486448, 0.039303...","[0.106388256, 0.24386232, 0.2922835, -0.151420..."
16866,"Mercury Poisoning, Nervous System",dimercaprol,"[0.3841246, -0.1792188, 0.5426179, -0.35204807...","[-0.26254827, 0.077304736, -0.06569935, 0.1721..."


In [7]:
# Held-out split: all rows of the diseases sampled for testing by split_dataframe
df_test_filtered

Unnamed: 0,disease,drug,disease_embedding,drug_embedding
0,OPIOID TOLERANCE,fentanyl,"[-0.12380553, 0.24230061, 0.14066729, -0.01755...","[-0.20299998, 0.049612034, -0.20746674, 0.1787..."
1,Urolithiasis,tamsulosin,"[0.015423108, 0.3084482, 0.2079388, -0.2220752...","[0.17083947, 0.603384, -0.13388242, 0.05305582..."
2,"Lymphoma, Follicular",umbralisib,"[-0.16385096, -0.30995056, 0.07767179, -0.2071...","[-0.044071615, -0.013846768, 0.1444283, -0.094..."
3,Advanced Hepatocellular Carcinoma,lenvatinib,"[0.1980408, 0.09689286, -0.16902849, -0.108710...","[0.00023591526, 0.1208497, 0.20754297, -0.2127..."
4,Pericardial effusion co-occurrent and due to m...,bleomycin,"[-0.025269533, 0.008863041, -0.1388683, -0.019...","[-0.18694587, 0.13150674, 0.008723415, -0.0926..."
...,...,...,...,...
1893,Bulimia,fluoxetine hydrochloride,"[-0.12157645, -0.026235884, -0.017002245, 0.23...","[-0.12374182, 0.5025294, 0.06577464, -0.002481..."
1894,AIDS with Kaposi's sarcoma,paclitaxel,"[-0.08204432, 0.04230188, -0.05801783, -0.4442...","[-0.031103916, 0.4158533, -0.0920613, 0.080813..."
1895,Pituitary Adenoma,Endostatins,"[0.021202672, 0.120926805, -0.0037852854, -0.1...","[-0.1662487, 0.30463144, 0.11857961, 0.0526934..."
1896,"Nausea/Vomiting, Chemotherapy Induced",dexamethasone,"[0.123240806, -0.1336841, -0.008411284, -0.318...","[-0.39190397, 0.13832054, -0.02711408, 0.05101..."


In [8]:
class Generator(nn.Module):
    """Feed-forward generator mapping a disease embedding to a drug embedding.

    Layer stack (the order fixes the ``model.<idx>.*`` state_dict keys that the
    saved checkpoint is loaded into, so it must not change):
        Linear(embedding_dim -> 1024) -> LayerNorm -> LeakyReLU(0.2) -> Dropout(0.3)
        Linear(1024 -> 512)           -> LayerNorm -> LeakyReLU(0.2) -> Dropout(0.3)
        Linear(512 -> embedding_dim)  -> LayerNorm -> Tanh
    The final Tanh bounds every output component to (-1, 1).
    """

    def __init__(self):
        super().__init__()
        layers = []
        in_features = embedding_dim
        for width in (1024, 512):
            layers += [
                nn.Linear(in_features, width),
                nn.LayerNorm(width),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            ]
            in_features = width
        layers += [
            nn.Linear(in_features, embedding_dim),
            nn.LayerNorm(embedding_dim),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, disease_embedding):
        """Apply the layer stack; input and output have last dimension ``embedding_dim``."""
        return self.model(disease_embedding)

In [9]:
generator = Generator().to(device)

# map_location places the stored tensors on the runtime device, so a checkpoint
# saved from a GPU still loads when cell 2 fell back to the CPU.
# weights_only=True restricts unpickling to plain tensors/containers.
generator.load_state_dict(
    torch.load('models/generator_63_17.pth', map_location=device, weights_only=True)
)

# Inference only: switch off dropout so predictions are deterministic.
generator.eval()

Generator(
  (model): Sequential(
    (0): Linear(in_features=768, out_features=1024, bias=True)
    (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (2): LeakyReLU(negative_slope=0.2)
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (6): LeakyReLU(negative_slope=0.2)
    (7): Dropout(p=0.3, inplace=False)
    (8): Linear(in_features=512, out_features=768, bias=True)
    (9): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (10): Tanh()
  )
)

In [10]:
# BiomedBERT tokenizer/encoder used to embed free-text disease names at query time.
model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = AutoModel.from_pretrained(model_name)
bert_model = bert_model.to(device)

def compute_embeddings(text, tokenizer, bert_model, device):
    """Encode text with the BERT model and return its first-token vector(s).

    Args:
        text: A string (or list of strings) to encode. Inputs are truncated to
            64 tokens and padded to a common length.
        tokenizer: Callable tokenizer returning tensors (``return_tensors='pt'``).
        bert_model: Encoder whose output exposes ``last_hidden_state``.
        device: Device the tokenized inputs are moved to before encoding.

    Returns:
        NumPy array of the position-0 hidden state. Size-1 dimensions are
        squeezed, so a single text gives a 1-D vector.
    """
    encoded = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        padding=True,
        max_length=64,
    )
    encoded = encoded.to(device)

    with torch.no_grad():
        hidden_states = bert_model(**encoded).last_hidden_state

    # Position 0 is the [CLS] token; bring it back to host memory as NumPy.
    first_token = hidden_states[:, 0, :]
    return first_token.cpu().squeeze().numpy()

In [11]:
def generate_drug(disease_embedding):
    """Predict L2-normalized drug embedding(s) from disease embedding(s).

    Args:
        disease_embedding: One embedding (1-D) or a batch (2-D, one row per
            disease), given as a tensor, NumPy array or nested list. It is
            converted to a float32 tensor on the global ``device`` (the device
            ``generator`` was moved to), whatever its original dtype/device.

    Returns:
        Tensor of shape ``(batch, dim)`` on ``device`` with unit-L2-norm rows.
        A 1-D input yields a batch of one.
    """
    # as_tensor converts arrays/lists AND casts/moves existing tensors. Before,
    # only non-tensors were converted, so a CPU or float64 tensor reached the
    # generator with the wrong device/dtype.
    disease_embedding = torch.as_tensor(disease_embedding, dtype=torch.float32, device=device)

    if disease_embedding.dim() == 1:
        disease_embedding = disease_embedding.unsqueeze(0)

    with torch.no_grad():
        # Unit-normalize the input (presumably mirrors training preprocessing -- confirm).
        disease_embedding = nn.functional.normalize(disease_embedding, dim=1)
        generated_drug = generator(disease_embedding)
        return nn.functional.normalize(generated_drug, dim=1)

In [12]:
def _rank_positions(order):
    """Invert a permutation: ``result[i]`` is the position of item ``i`` within ``order``."""
    positions = np.empty_like(order)
    positions[order] = np.arange(len(order))
    return positions


def find_most_similar_drugs(disease_name, combined_df, tokenizer, bert_model, device, top_n=5):
    """Rank candidate drugs for a free-text disease name and return the top-N names.

    The disease name is embedded with ``compute_embeddings`` and mapped to a
    predicted drug embedding by ``generate_drug``. That prediction is compared
    with every unique drug in ``combined_df`` under three metrics:
    cosine similarity (higher is better), Euclidean distance and Manhattan
    distance (both lower is better). Drugs are ordered by the sum of their three
    rank positions. Ties keep the order in which drugs first appear in
    ``combined_df``.

    NOTE: ``generate_drug`` returns L2-normalized vectors and the candidates are
    normalized here, so Euclidean order matches cosine order
    (||a-b||^2 = 2 - 2*cos) up to floating-point ties. In effect, cosine is
    counted twice.

    Args:
        disease_name: Free-text disease name to query.
        combined_df: DataFrame with ``drug`` and ``drug_embedding`` columns. Only
            the first row per drug name is used.
        tokenizer, bert_model, device: Forwarded to ``compute_embeddings``.
        top_n: Number of drug names to return.

    Returns:
        List of up to ``top_n`` drug names, best first.
    """
    # One candidate per drug name; the first occurrence wins.
    candidates = combined_df.drop_duplicates(subset=["drug"])

    disease_embedding = compute_embeddings(disease_name, tokenizer, bert_model, device)
    disease_embedding = torch.tensor(disease_embedding, dtype=torch.float32).to(device)
    disease_embedding = disease_embedding.unsqueeze(0)
    generated_drug_embedding = generate_drug(disease_embedding).cpu().numpy()

    drug_embeddings = np.vstack(candidates["drug_embedding"].values)
    drug_embeddings = normalize(torch.tensor(drug_embeddings), dim=1).numpy()

    similarities = cosine_similarity(generated_drug_embedding, drug_embeddings)[0]
    cosine_order = np.argsort(similarities)[::-1]  # descending similarity
    euclidean_order = np.argsort(
        cdist(generated_drug_embedding, drug_embeddings, metric="euclidean")[0]
    )
    manhattan_order = np.argsort(
        cdist(generated_drug_embedding, drug_embeddings, metric="cityblock")[0]
    )

    # Position of each drug in each ordering, via one O(n) inverse permutation
    # per metric instead of an O(n) search per drug (previously O(n^2) per query).
    total_ranks = (
        _rank_positions(cosine_order)
        + _rank_positions(euclidean_order)
        + _rank_positions(manhattan_order)
    )

    # A stable sort keeps lower candidate index first on ties (same as sorted() on 0..n-1).
    best = np.argsort(total_ranks, kind="stable")[:top_n]
    return candidates["drug"].iloc[best].tolist()

In [13]:
def evaluate_multiple_accuracy(df_eval, df_combined, disease_drug_map, tokenizer, bert_model, device, top_n=5):
    """Top-N hit rate over the evaluation rows.

    Each row's ``disease`` is ranked with ``find_most_similar_drugs``. The row
    counts as a hit when at least one of the ``top_n`` returned drugs appears
    in ``disease_drug_map[disease]``.

    Args:
        df_eval: DataFrame with a ``disease`` column; every row is scored.
        df_combined: Candidate drug pool passed to ``find_most_similar_drugs``.
        disease_drug_map: Mapping of disease name to the set of known drugs.
        tokenizer, bert_model, device: Forwarded to ``find_most_similar_drugs``.
        top_n: Number of predictions checked per disease.

    Returns:
        Fraction of rows with at least one correct drug in the top N.

    Raises:
        KeyError: if a disease with a non-empty prediction list is missing
            from ``disease_drug_map``.
        ZeroDivisionError: if ``df_eval`` is empty.
    """
    def _is_hit(disease_name):
        ranked = find_most_similar_drugs(
            disease_name,
            df_combined,
            tokenizer,
            bert_model,
            device,
            top_n=top_n,
        )
        return any(drug in disease_drug_map[disease_name] for drug in ranked)

    n_hits = sum(1 for disease_name in df_eval['disease'] if _is_hit(disease_name))
    return n_hits / len(df_eval)

In [14]:
# Candidate pool for ranking: every row (hence every drug) from both splits.
df_combined = pd.concat([df, df_test_filtered]).reset_index(drop=True)
print(df_combined.shape)

(18766, 4)


In [15]:
# Full disease-drug table (not the embedded subset); used below as ground truth.
drug_dataset = pd.read_csv(filepath_or_buffer="data/final_data.csv")
print(drug_dataset.shape)

(56214, 2)


In [16]:
# Ground-truth lookup: disease name -> set of every drug listed for it in the
# full dataset. This is a single groupby pass; the previous version re-filtered
# all rows once per unique disease (O(n_diseases * n_rows)).
# sort=False keeps first-appearance order; dropna=False keeps a NaN disease key.
disease_drug_map = {
    disease: set(drugs.unique())
    for disease, drugs in drug_dataset.groupby('disease', sort=False, dropna=False)['drug']
}

In [17]:
# Scoring only looks at the disease name, so keep one row per held-out disease.
# Otherwise each disease would be weighted by how many drugs it lists.
df_test_filtered = df_test_filtered.drop_duplicates(subset=["disease"]).reset_index(drop=True)

test_accuracy = evaluate_multiple_accuracy(
    df_eval=df_test_filtered,
    df_combined=df_combined,
    disease_drug_map=disease_drug_map,
    tokenizer=tokenizer,
    bert_model=bert_model,
    device=device,
    top_n=5,
)
print(f"Test accuracy: {test_accuracy:.2%}")

Test accuracy: 49.11%


In [18]:
def calculate_mrr(df_eval, df_combined, disease_drug_map, tokenizer, bert_model, device, top_n=5):
    """Mean Reciprocal Rank of the first correct drug within the top-N predictions.

    For each row, the reciprocal rank is ``1 / k``, where ``k`` is the 1-based
    position of the first predicted drug found in
    ``disease_drug_map[disease]``. It is 0 when none of the ``top_n``
    predictions is correct.

    Args:
        df_eval: DataFrame with a ``disease`` column; every row is scored.
        df_combined: Candidate drug pool passed to ``find_most_similar_drugs``.
        disease_drug_map: Mapping of disease name to the set of known drugs.
        tokenizer, bert_model, device: Forwarded to ``find_most_similar_drugs``.
        top_n: Prediction-list length. Correct drugs ranked lower contribute 0.

    Returns:
        Average reciprocal rank over the rows of ``df_eval``.

    Raises:
        KeyError: if a disease is missing from ``disease_drug_map``.
        ZeroDivisionError: if ``df_eval`` is empty.
    """
    def _reciprocal_rank(disease_name):
        ranked = find_most_similar_drugs(
            disease_name,
            df_combined,
            tokenizer,
            bert_model,
            device,
            top_n=top_n,
        )
        correct_drugs = disease_drug_map[disease_name]
        first_hit = next(
            (position for position, drug in enumerate(ranked, start=1) if drug in correct_drugs),
            None,
        )
        return 0 if first_hit is None else 1 / first_hit

    total = sum(_reciprocal_rank(disease_name) for disease_name in df_eval['disease'])
    return total / len(df_eval)

In [19]:
# NOTE(review): 3101 presumably equals the number of unique drugs in df_combined,
# i.e. MRR over the full ranking with no top-N cut-off. Confirm against
# df_combined["drug"].nunique().
mrr = calculate_mrr(
    df_eval=df_test_filtered,
    df_combined=df_combined,
    disease_drug_map=disease_drug_map,
    tokenizer=tokenizer,
    bert_model=bert_model,
    device=device,
    top_n=3101,
)
print(f"Mean Reciprocal Rank (MRR): {mrr:.3f}")

Mean Reciprocal Rank (MRR): 0.367


In [20]:
# Top-25 candidate drugs for a free-text disease query
disease_name = "Leukemia"
top_drugs = find_most_similar_drugs(
    disease_name,
    combined_df=df_combined,
    tokenizer=tokenizer,
    bert_model=bert_model,
    device=device,
    top_n=25,
)
print(top_drugs)

['methotrexate', 'fludarabine', 'busulfan', 'cladribine', 'cytarabine', 'flutamide', 'topotecan', 'clofarabine', 'dactinomycin', 'idarubicin', 'sirolimus', 'epirubicin', 'azathioprine', 'vincristine sulfate', 'mizoribine', 'ifosfamide', 'mitoxantrone', 'dacarbazine', 'vincristine', 'vinorelbine', 'methylprednisolone', 'prednisone', 'tiagabine', 'melphalan', 'procarbazine']


In [21]:
# Top-25 candidate drugs for a free-text disease query
disease_name = "Kidney Cancer"
top_drugs = find_most_similar_drugs(
    disease_name,
    combined_df=df_combined,
    tokenizer=tokenizer,
    bert_model=bert_model,
    device=device,
    top_n=25,
)
print(top_drugs)

['oxaliplatin', 'vincristine', 'capecitabine', 'sorafenib', 'sunitinib', 'dasatinib', 'lenalidomide', 'gemcitabine', 'irinotecan', 'clofarabine', 'temozolomide', 'bortezomib', 'methotrexate', 'etoposide', 'doxorubicin', 'docetaxel', 'everolimus', 'letrozole', 'ibrutinib', 'thalidomide', 'tamoxifen', 'sirolimus', 'sunitinib malate', 'belinostat', 'cisplatin']


In [22]:
# Top-25 candidate drugs for a free-text disease query
disease_name = "Renal Cell Carcinoma"
top_drugs = find_most_similar_drugs(
    disease_name,
    combined_df=df_combined,
    tokenizer=tokenizer,
    bert_model=bert_model,
    device=device,
    top_n=25,
)
print(top_drugs)

['brigatinib', 'selumetinib', 'pemigatinib', 'pazopanib', 'lenvatinib', 'sorafenib', 'encorafenib', 'axitinib', 'regorafenib', 'sunitinib malate', 'tofacitinib', 'ponatinib', 'dasatinib', 'vandetanib', 'infigratinib', 'ruxolitinib', 'capecitabine', 'cobimetinib', 'afatinib', 'osimertinib', 'gemcitabine', 'ibrutinib', 'oxaliplatin', 'dabrafenib', 'clofarabine']


In [23]:
# Top-25 candidate drugs for a free-text disease query
disease_name = "Melanoma"
top_drugs = find_most_similar_drugs(
    disease_name,
    combined_df=df_combined,
    tokenizer=tokenizer,
    bert_model=bert_model,
    device=device,
    top_n=25,
)
print(top_drugs)

['dacarbazine', 'procarbazine', 'pazopanib', 'clofarabine', 'topotecan', 'ibrutinib', 'melphalan', 'tegafur-uracil', 'busulfan', 'abiraterone acetate', 'idarubicin', 'cabazitaxel', 'vemurafenib', 'cytarabine', 'ruxolitinib', 'interferon alfa-2b', 'gemcitabine', 'selumetinib', 'encorafenib', 'vandetanib', 'eribulin', 'Adriamycin-Bleomycin-Vinblastine-Dacarbazine Regimen', 'carmustine', 'brigatinib', 'cladribine']


In [24]:
# Top-25 candidate drugs for a free-text disease query
disease_name = "Non-Hodgkin's Lymphoma"
top_drugs = find_most_similar_drugs(
    disease_name,
    combined_df=df_combined,
    tokenizer=tokenizer,
    bert_model=bert_model,
    device=device,
    top_n=25,
)
print(top_drugs)

['vincristine', 'oxaliplatin', 'dacarbazine', 'clofarabine', 'methotrexate', 'procarbazine', 'gemcitabine', 'etoposide', 'temozolomide', 'topotecan', 'cisplatin', 'tiagabine', 'busulfan', 'belinostat', 'idarubicin', 'bortezomib', 'cytarabine', 'cabazitaxel', 'carmustine', 'cladribine', 'ibrutinib', 'fludarabine', 'nelarabine', 'doxorubicin', 'epirubicin']


In [25]:
# Top-25 candidate drugs for a free-text disease query
disease_name = "Invasive carcinoma of breast"
top_drugs = find_most_similar_drugs(
    disease_name,
    combined_df=df_combined,
    tokenizer=tokenizer,
    bert_model=bert_model,
    device=device,
    top_n=25,
)
print(top_drugs)

['tamoxifen', 'docetaxel', 'BEP regimen', 'oxaliplatin', 'capecitabine', 'vincristine', 'carboplatin', 'etoposide', 'sunitinib', 'TAC Regimen', 'letrozole', 'Folfox protocol', 'abiraterone', 'anastrozole', 'irinotecan', 'cisplatin', 'FOLFOX Regimen', 'thalidomide', 'cabazitaxel', 'lenalidomide', 'folfirinox', 'cyclophosphamide', 'lapatinib', 'etoposide phosphate', 'docetaxel anhydrous']


In [26]:
# Top-25 candidate drugs for a free-text disease query
disease_name = "Lung Cancer"
top_drugs = find_most_similar_drugs(
    disease_name,
    combined_df=df_combined,
    tokenizer=tokenizer,
    bert_model=bert_model,
    device=device,
    top_n=25,
)
print(top_drugs)

['oxaliplatin', 'cisplatin', 'vincristine', 'capecitabine', 'temozolomide', 'etoposide', 'docetaxel', 'gemcitabine', 'methotrexate', 'doxorubicin', 'sorafenib', 'dasatinib', 'irinotecan', 'dacarbazine', 'clofarabine', 'lenalidomide', 'bortezomib', 'paclitaxel', 'procarbazine', 'cabazitaxel', 'cyclophosphamide', 'belinostat', 'cytarabine', 'gefitinib', 'irinotecan hydrochloride']


In [27]:
# Top-25 candidate drugs for a free-text disease query
disease_name = "Adenocarcinoma"
top_drugs = find_most_similar_drugs(
    disease_name,
    combined_df=df_combined,
    tokenizer=tokenizer,
    bert_model=bert_model,
    device=device,
    top_n=25,
)
print(top_drugs)

['sunitinib', 'oxaliplatin', 'vincristine', 'carboplatin', 'capecitabine', 'etoposide', 'irinotecan', 'temozolomide', 'cisplatin', 'gemcitabine', 'docetaxel', 'doxorubicin', 'tamoxifen', 'irinotecan hydrochloride', 'bendamustine', 'cabazitaxel', 'gefitinib', 'BEP regimen', 'epirubicin', 'thalidomide', 'sorafenib', 'lapatinib', 'capmatinib', 'macimorelin', 'cyclophosphamide']


In [28]:
# Top-25 candidate drugs for a free-text disease query
disease_name = "Colorectal Cancer"
top_drugs = find_most_similar_drugs(
    disease_name,
    combined_df=df_combined,
    tokenizer=tokenizer,
    bert_model=bert_model,
    device=device,
    top_n=25,
)
print(top_drugs)

['capecitabine', 'oxaliplatin', 'gemcitabine', 'irinotecan', 'temozolomide', 'sorafenib', 'docetaxel', 'cisplatin', 'irinotecan hydrochloride', 'clofarabine', 'dasatinib', 'vincristine', 'procarbazine', 'doxorubicin', 'dacarbazine', 'cabazitaxel', 'eribulin', 'topotecan', 'etoposide', 'ibrutinib', 'bortezomib', 'tegafur-uracil', 'lenalidomide', 'idarubicin', 'belinostat']
