In [None]:
import numpy as np
import pandas as pd
from time import time

import torch
import torch.nn.functional as F

from madrigal.utils import DATA_DIR, BASE_DIR
from madrigal.evaluate.predict import get_data_for_analysis_all_drugs
from madrigal.models.models import NovelDDIEncoder, NovelDDIMultilabel
from madrigal.utils import to_device
from madrigal.evaluate.eval_utils import get_evaluate_masks

# Path under DATA_DIR; presumably where polypharmacy outputs for DrugBank are written.
# NOTE(review): not read by any cell visible in this notebook -- confirm it is still needed.
OUTPUT_DIR = DATA_DIR + "polypharmacy_new/DrugBank/"

In [None]:
# Load the pickled drug metadata table from BASE_DIR.
# NOTE(review): unpickling can execute arbitrary code -- only load files from a trusted source.
drug_metadata_nash = pd.read_pickle(BASE_DIR + "processed_data/views_features_new/combined_metadata_nash.pkl")
# The row count is used further down as the number of drugs when EXAMPLE is False.
drug_metadata_nash.shape

(11607, 86)

In [3]:
from multiprocessing import Pool

def classwise_normalized_rank_3d_numpy(tensor):
    """Rank every entry of a 3-D array within its class and rescale the ranks.

    For each index along axis 0 (one "class"), the remaining two axes are
    flattened and every entry receives its 1-based ascending rank (the
    smallest value gets rank 1). Each rank is then divided by
    ``tensor.shape[1] * (tensor.shape[2] - 1) / 2``.

    Parameters
    ----------
    tensor : np.ndarray
        Array of shape ``(n_classes, n_rows, n_cols)``.

    Returns
    -------
    np.ndarray
        Floating-point array of scaled ranks with the same shape as ``tensor``.

    Notes
    -----
    For a square ``n x n`` slice the divisor equals ``n * (n - 1) / 2``, the
    number of entries strictly below the diagonal. The result therefore lies
    in ``(0, 1]`` for those entries only if the caller has made every other
    entry sort after them (as the callers in this notebook do by overwriting
    the diagonal and upper triangle with a large sentinel before calling).
    """
    # Flatten each class into one row so ranking is done per class.
    flat_tensor = tensor.reshape(tensor.shape[0], -1)

    # `order[c, r]` is the flat position of the r-th smallest value of class c.
    order = flat_tensor.argsort(axis=1)

    # Invert that permutation with a single scatter: the element sitting at
    # `order[c, r]` gets rank r + 1. This yields exactly what
    # `argsort().argsort() + 1` yields, without paying for a second sort, and
    # needs no special case for a single class.
    flat_rank = np.empty_like(order)
    np.put_along_axis(flat_rank, order, np.arange(1, order.shape[1] + 1), axis=1)
    del order  # release the index buffer before allocating the float result

    # Scale the integer ranks by the number of strictly-lower-triangular entries.
    normalized_rank = flat_rank / (tensor.shape[1] * (tensor.shape[2] - 1) / 2)

    # Restore the original (n_classes, n_rows, n_cols) layout.
    return normalized_rank.reshape(tensor.shape)

In [None]:
# Run-wide configuration shared by the cells below.
EXAMPLE = True  # NOTE: Only generating embeddings and scores for 200 drugs. Normalization is also within pairs of these drugs, which can differ significantly from the full dataset.
# data_source, split_method and repeat are forwarded to get_data_for_analysis_all_drugs
# and are also used to build the model-output paths.
data_source = 'DrugBank'
split_method = 'split_by_pairs'
repeat = None
# eval_type and finetune_mode are forwarded to get_evaluate_masks.
eval_type = 'full_full'
finetune_mode = 'str_str+random_sample'
# Parent directory of all per-checkpoint folders; the cross-run (gmean) outputs are written here.
split_output_dir = BASE_DIR + f"model_output/{data_source}/{split_method}/"

# For each run

In [None]:
# Run to process in the "For each run" section. Alternatives (swap in to process another seed):
#   'drawn-grass-4'      DrugBank (trained on all data, seed=1)
#   'misty-oath-5'       DrugBank (trained on all data, seed=0)
#   'whole-fog-7'        DrugBank (trained on all data, seed=99)
#   'snowy-serenity-8'   DrugBank (trained on all data, seed=42)
checkpoint = 'revived-aardvark-8'  # DrugBank (trained on all data, seed=2)
checkpoint_dir = BASE_DIR + f'model_output/{data_source}/{split_method}/{checkpoint}/'

kg_encoder = 'hgt'

# Epoch whose snapshot is loaded; set to None to load "best_model.pt" instead.
epoch = 700
ckpt_filename = "best_model.pt" if epoch is None else f"checkpoint_{epoch}.pt"
ckpt_path = checkpoint_dir + ckpt_filename

## Generate embeddings and scores

In [5]:
# Build one batch covering the drugs to analyse; the first two return values are not needed here.
# With EXAMPLE set, only the first 200 drugs are requested; otherwise one per metadata row.
_, _, batch, label_map = get_data_for_analysis_all_drugs(
    data_source=data_source, 
    kg_encoder=kg_encoder, 
    split_method=split_method, 
    repeat=repeat, 
    path_base=DATA_DIR, 
    checkpoint=ckpt_path, 
    first_num_drugs=200 if EXAMPLE else drug_metadata_nash.shape[0], 
    add_specific_drugs="nash"
)

# Label ids and positive/negative indicators taken from the batch's edge indices.
ddi_labels = batch['edge_indices']['label']
ddi_pos_neg_samples = batch['edge_indices']['pos_neg'].float()
true_ddis = ddi_pos_neg_samples  # alias of the same tensor, not a copy
# Entries of label_map whose index occurs at least once in ddi_labels.
label_map_valid = np.array(label_map)[np.unique(ddi_labels)]

In [6]:
# Prefer the GPU, but fall back to the CPU so the cell still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Deserialize on the CPU; the model is moved to `device` once its weights are restored.
# NOTE: `checkpoint` is rebound here from the run name (str) to the loaded checkpoint dict.
checkpoint = torch.load(ckpt_path, map_location="cpu")
if epoch is None:
    # "best_model.pt" was loaded: take the epoch stored in the checkpoint so that the
    # output filenames written further down carry an epoch number.
    epoch = checkpoint["epoch"]

# Rebuild the model from the configs stored in the checkpoint, then restore its weights.
encoder = NovelDDIEncoder(**checkpoint['encoder_configs'])
model = NovelDDIMultilabel(encoder, **checkpoint['model_configs'])
model.load_state_dict(checkpoint['state_dict'])
model.eval()  # inference mode
model.to(device)

# Move the batch inputs to the same device as the model.
batch_head = to_device(batch['head'], device)  # dict
batch_tail = to_device(batch['tail'], device)
batch_kg = to_device(batch['kg'], device)
head_masks_base = batch['head']['masks']
tail_masks_base = batch['tail']['masks']

# Derive the evaluation masks from the base masks for the configured eval_type / finetune_mode.
masks_head, masks_tail = get_evaluate_masks(head_masks_base, tail_masks_base, eval_type, finetune_mode, device)

Using pretrained structure encoder
Using pretrained KG encoder
Using pretrained CV encoder
Using pretrained TX encoder
INCOMP_KEYS (make sure these contain what you expected):
%s {   'Missing keys': [   'covariates_embeddings.0.weight'],
    'Unexpected_keys': [   ]}




In [None]:
# Unpack the head-side inputs consumed by the encoder call below.
head_drugs = batch_head['drugs']
head_mol_strs = batch_head['strs']
head_cv = batch_head['cv']
head_tx_all_cell_lines = batch_head['tx']
head_masks = masks_head

# NOTE(review): the tail-side variables are assigned but not passed to the encoder or
# decoder in any cell visible here -- confirm whether they are still needed.
tail_drugs = batch_tail['drugs']
tail_mol_strs = batch_tail['strs']
tail_cv = batch_tail['cv']
tail_tx_all_cell_lines = batch_tail['tx']
tail_masks = masks_tail

# Encode all drugs once without tracking gradients.
with torch.no_grad():
    z_full = model.encoder(head_drugs, head_masks, head_mol_strs, batch_kg, head_cv, head_tx_all_cell_lines)
    if model.normalize:
        # Applied only when the model flags it; uses F.normalize's default arguments.
        z_full = F.normalize(z_full)

# Persist the embeddings on CPU in the checkpoint folder.
# NOTE(review): the filename carries no epoch, so runs at different epochs overwrite each other.
torch.save(z_full.detach().cpu(), f"{checkpoint_dir}/{data_source}_drug_embeddings_full.pt")

In [15]:
# Number of labels (classes) scored per decoder call, so each chunk of predictions
# moved to host memory stays small.
LABEL_CHUNK_SIZE = 30
# len() works whether label_map is a list or an array (it is wrapped in np.array elsewhere).
num_labels = len(label_map)
num_drugs = z_full.shape[0]
raw_scores_stem = f"{checkpoint_dir}/{data_source}_drugs_raw_scores_{epoch}"

# Disk-backed (labels, drugs, drugs) buffer for the raw scores of every label and drug pair.
fp = np.memmap(
    raw_scores_stem + ".raw",
    dtype=np.float32, mode="w+", shape=(num_labels, num_drugs, num_drugs)
)

for start in range(0, num_labels, LABEL_CHUNK_SIZE):
    end = min(start + LABEL_CHUNK_SIZE, num_labels)  # the last chunk may be shorter
    print(start)
    label_range = (start, end)
    with torch.no_grad():
        # Score every drug against every drug for the labels in [start, end).
        pred_scores = model.decoder(z_full, z_full, label_range).detach().cpu().numpy()
    fp[start:end, :, :] = pred_scores

# Push pending writes to the .raw file, then also store the scores as a
# self-describing .npy file (the format the normalization step loads).
fp.flush()
with open(raw_scores_stem + ".npy", "wb") as f:
    np.save(f, fp)

0
30
60
90
120
150


## Convert scores to normalized ranks

Convert the raw scores of each class into ranks over all drug pairs, then rescale the ranks to [0, 1]. This step can take a few hours to run for a large number of drugs.

Script version can be found at `notebooks/normalize_scores.py`

In [None]:
# Memory-map the saved raw scores read-only, and create a writable float32 memmap of
# the same shape that will receive the normalized ranks.
raw_scores = np.load(f"{checkpoint_dir}/{data_source}_drugs_raw_scores_{epoch}.npy", mmap_mode="r")
raw_scores_norm = np.memmap(f"{checkpoint_dir}/{data_source}_drugs_normalized_ranks_{epoch}.raw", mode="w+", dtype=np.float32, shape=raw_scores.shape)

# 2 x K array of (row, col) indices of the upper triangle including the diagonal (k=0).
mask_indices = np.vstack(np.triu_indices(raw_scores.shape[1], k=0, m=raw_scores.shape[2]))
# NOTE(review): `interval` is not read by the visible cells before it is reassigned further down.
interval = 1
def run_slice(tup):
    """Rank-normalize the classes ``[start, end)`` of ``raw_scores``.

    Reads the module-level ``raw_scores`` and ``mask_indices`` and writes the
    result into the module-level ``raw_scores_norm``; returns nothing.

    Parameters
    ----------
    tup : tuple
        ``(start, end)`` bounds of the class slice to process.
    """
    t_start = time()
    lo, hi = tup
    rows, cols = mask_indices[0], mask_indices[1]

    # Work on a copy of the slice so the source array is left untouched.
    scores = raw_scores[lo:hi, :, :]
    scores = scores.copy()

    # Overwrite the diagonal and upper triangle with a large sentinel so those entries
    # sort after the remaining ones (presumes every real score is below 1e7).
    scores[:, rows, cols] = 1e7
    ranks = classwise_normalized_rank_3d_numpy(scores)

    # Blank the sentinel positions, then mirror the lower triangle onto the upper one.
    ranks[:, rows, cols] = 0
    symmetric_ranks = ranks + ranks.swapaxes(1, 2)
    # assert symmetric_ranks.max() < 1e7
    raw_scores_norm[lo:hi, :, :] = symmetric_ranks
    t_end = time()
    # print(f"Finished normalizing class {lo} in {((t_end - t_start) / 60):.4f} minutes.")

In [28]:
# Sanity check: number of NaNs in the raw scores (the recorded output is 0).
np.isnan(raw_scores).sum()

0

In [None]:
# One (start, end) task per class: each worker call normalizes a single class.
class_slices = [(idx, idx + 1) for idx in range(raw_scores.shape[0])]

st = time()
with Pool() as pool:
    pool.map(run_slice, class_slices)
e = time()
print(f"Takes {(e-st):.4f} seconds to run the score normalization.")

# Store the normalized ranks as a .npy file next to the .raw memmap.
with open(f"{checkpoint_dir}/{data_source}_drugs_normalized_ranks_{epoch}.npy", "wb") as f:
    np.save(f, raw_scores_norm)

In [25]:
# Write any pending changes of the normalized-rank memmap to its backing .raw file.
raw_scores_norm.flush()

# Across runs

## Geometric mean-aggregate normalized ranks

In [None]:
from scipy.stats.mstats import gmean

checkpoints = [
    "drawn-grass-4",  # DrugBank (trained on all data, seed=1)
    "misty-oath-5",  # DrugBank (trained on all data, seed=0)
    "whole-fog-7",  # DrugBank (trained on all data, seed=99)
    "snowy-serenity-8",  # DrugBank (trained on all data, seed=42)
    "revived-aardvark-8",  # DrugBank (trained on all data, seed=2)
]

# Memory-map every run's normalized ranks read-only; nothing is loaded into RAM yet.
# NOTE: this loop rebinds `checkpoint` and `checkpoint_dir` to the last run in the list.
normalized_ranks_list = []
for checkpoint in checkpoints:
    checkpoint_dir = BASE_DIR + f'model_output/{data_source}/{split_method}/{checkpoint}/'
    normalized_ranks = np.load(f"{checkpoint_dir}/{data_source}_drugs_normalized_ranks_{epoch}.npy", mmap_mode="r")
    normalized_ranks_list.append(normalized_ranks)

# Fail early, naming the offending runs, instead of part-way through the aggregation.
ranks_shape = normalized_ranks_list[0].shape
mismatched_runs = [
    name for name, ranks in zip(checkpoints, normalized_ranks_list) if ranks.shape != ranks_shape
]
if mismatched_runs:
    raise ValueError(
        f"Normalized ranks of {mismatched_runs} do not have the expected shape {ranks_shape}."
    )

# Disk-backed buffer for the aggregated ranks, same shape as each run's array.
gmean_fp = np.memmap(
    split_output_dir + f"{data_source}_drugs_normalized_ranks_{epoch}_gmean.raw",
    dtype=np.float32, mode="w+", shape=ranks_shape
)

# Number of classes aggregated per step; only this many classes of every run
# are stacked in memory at once.
interval = 10
num_classes = ranks_shape[0]

st = time()
for start in range(0, num_classes, interval):
    end = min(start + interval, num_classes)  # the last chunk may be shorter
    print(start)
    # Stack the runs on a new trailing axis and take the geometric mean across it.
    # NOTE(review): the per-run normalization writes 0 on the diagonal, so gmean sees
    # zeros there -- confirm the aggregated diagonal comes out as intended.
    gmean_fp[start:end, :, :] = gmean(np.stack([ranks[start:end, :, :] for ranks in normalized_ranks_list], axis=-1), axis=-1)
e = time()
print(f"Takes {(e-st):.4f} seconds to run gmean.")

# Push pending writes to the .raw file, then also store the result as a .npy file.
gmean_fp.flush()
with open(split_output_dir + f"{data_source}_drugs_normalized_ranks_{epoch}_gmean.npy", "wb") as f:
    np.save(f, gmean_fp)

0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
Takes 3015.0882 seconds to run gmean for 5 normalized ranks.


## Re-normalize (normalized) ranks

In [None]:
# Memory-map the aggregated (geometric-mean) ranks read-only.
gmean_ranks = np.load(split_output_dir + f"{data_source}_drugs_normalized_ranks_{epoch}_gmean.npy", mmap_mode="r")

# Writable float32 memmap of the same shape that receives the re-normalized ranks.
gmean_ranks_norm = np.memmap(
    split_output_dir + f"{data_source}_drugs_normalized_ranks.raw", 
    mode="w+", dtype=np.float32, shape=gmean_ranks.shape
)
# 2 x K array of (row, col) indices of the upper triangle including the diagonal (k=0).
mask_indices = np.vstack(np.triu_indices(gmean_ranks.shape[1], k=0, m=gmean_ranks.shape[2]))

def run_slice(tup):
    """Rank-normalize the classes ``[start, end)`` of ``gmean_ranks``.

    Reads the module-level ``gmean_ranks`` and ``mask_indices`` and writes the
    result into the module-level ``gmean_ranks_norm``; returns nothing.
    This definition replaces the ``run_slice`` defined earlier in the notebook.

    Parameters
    ----------
    tup : tuple
        ``(start, end)`` bounds of the class slice to process.
    """
    t_start = time()
    lo, hi = tup
    rows, cols = mask_indices[0], mask_indices[1]

    # Work on a copy of the slice so the source array is left untouched.
    aggregated = gmean_ranks[lo:hi, :, :]
    aggregated = aggregated.copy()

    # Overwrite the diagonal and upper triangle with a large sentinel so those entries
    # sort after the remaining ones (presumes every real value is below 1e7).
    aggregated[:, rows, cols] = 1e7
    ranks = classwise_normalized_rank_3d_numpy(aggregated)

    # Blank the sentinel positions, then mirror the lower triangle onto the upper one.
    ranks[:, rows, cols] = 0
    symmetric_ranks = ranks + ranks.swapaxes(1, 2)
    # assert symmetric_ranks.max() < 1e7
    gmean_ranks_norm[lo:hi, :, :] = symmetric_ranks
    t_end = time()
    # print(f"Finished normalizing class {lo} in {((t_end - t_start) / 60):.4f} minutes.")

# One (start, end) task per class: each worker call re-normalizes a single class.
class_slices = [(idx, idx + 1) for idx in range(gmean_ranks.shape[0])]

st = time()
with Pool() as pool:
    pool.map(run_slice, class_slices)
e = time()
print(f"Takes {(e-st):.4f} seconds to run score normalization.")

# Flush the memmap to its .raw file (as is done for every other memmap in this
# notebook), then also store the final ranks as a .npy file.
gmean_ranks_norm.flush()
with open(split_output_dir + f"{data_source}_drugs_normalized_ranks.npy", "wb") as f:
    np.save(f, gmean_ranks_norm)

Finished normalizing class 28 in 1.7004 minutes.
Finished normalizing class 60 in 1.7130 minutes.
Finished normalizing class 58 in 1.7616 minutes.
Finished normalizing class 62 in 1.7917 minutes.
Finished normalizing class 56 in 2.2525 minutes.
Finished normalizing class 38 in 2.7902 minutes.
Finished normalizing class 14 in 2.7951 minutes.
Finished normalizing class 32 in 2.8403 minutes.
Finished normalizing class 44 in 2.8908 minutes.
Finished normalizing class 18 in 2.9051 minutes.
Finished normalizing class 0 in 2.9185 minutes.
Finished normalizing class 2 in 2.9214 minutes.
Finished normalizing class 20 in 2.9364 minutes.
Finished normalizing class 16 in 3.0938 minutes.
Finished normalizing class 36 in 3.2349 minutes.
Finished normalizing class 22 in 3.3714 minutes.
Finished normalizing class 40 in 3.3913 minutes.
Finished normalizing class 24 in 3.4244 minutes.
Finished normalizing class 4 in 3.4342 minutes.
Finished normalizing class 8 in 3.4389 minutes.
Finished normalizing cla

In [None]:
# Display the final re-normalized ranks: per class, a symmetric matrix with zeros on the diagonal.
gmean_ranks_norm

memmap([[[0.        , 0.6872362 , 0.12949748, ..., 0.30432162,
          0.37869346, 0.45221105],
         [0.6872362 , 0.        , 0.27125627, ..., 0.37366834,
          0.46492463, 0.49110553],
         [0.12949748, 0.27125627, 0.        , ..., 0.30708542,
          0.598995  , 0.330201  ],
         ...,
         [0.30432162, 0.37366834, 0.30708542, ..., 0.        ,
          0.44246233, 0.41954774],
         [0.37869346, 0.46492463, 0.598995  , ..., 0.44246233,
          0.        , 0.680603  ],
         [0.45221105, 0.49110553, 0.330201  , ..., 0.41954774,
          0.680603  , 0.        ]],

        [[0.        , 0.33954775, 0.68708545, ..., 0.67577887,
          0.7384925 , 0.8203518 ],
         [0.33954775, 0.        , 0.528995  , ..., 0.2919598 ,
          0.41592965, 0.49537688],
         [0.68708545, 0.528995  , 0.        , ..., 0.8427136 ,
          0.8675879 , 0.8438191 ],
         ...,
         [0.67577887, 0.2919598 , 0.8427136 , ..., 0.        ,
          0.82798994, 0.9