In [122]:
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch

In [123]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [124]:
drug_gene_df = pd.read_csv('./Data/preprocessed_data/drug_gene.csv')
drugs = list(drug_gene_df.drug_name.unique())
genes = list(drug_gene_df.gene_name.unique())
disease_gene_df = pd.read_csv('./Data/preprocessed_data/disease_gene.csv')
diseases = list(disease_gene_df.disease_name.unique())
genes.extend(list(disease_gene_df.gene_name.unique()))
disease_drug_df = pd.read_csv('./Data/preprocessed_data/disease_drug.csv')
diseases.extend(list(disease_drug_df.disease_name.unique()))
drugs.extend(list(disease_drug_df.drug_name.unique()))
drug_drug_df = pd.read_csv('./Data/preprocessed_data/drug_drug.csv')
drugs.extend(list(drug_drug_df.drug_1_name.unique()))
drugs.extend(list(drug_drug_df.drug_2_name.unique()))

In [125]:
drugs = list(set(drugs))
genes = list(set(genes))
diseases = list(set(diseases))

In [126]:
len(drugs), len(genes), len(diseases)

(5267, 17682, 5484)

In [127]:
def init_model(model_name, checkpoint=None):
    if checkpoint is None:
        checkpoint = model_name
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(checkpoint).to(device)
    return tokenizer, model

In [128]:
tokenizer, model = init_model('bert-base-uncased')
cutomized_tokenizer, customized_model = init_model('bert-base-uncased','checkpoint-44000')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at checkpoint-44000 were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 

In [129]:
from tqdm.notebook import tqdm

def get_embeddings(text_list, tokenizer, model, batch_size=32):
    embeddings = np.array([])
    for i in tqdm(range(0, len(text_list), batch_size)):
        batch = text_list[i:i+batch_size]
        encoded_input = tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device)
        attn_mask = encoded_input['attention_mask'].cpu()
        with torch.no_grad():
            model_output = model(**encoded_input)[0].squeeze(0)
        embeds = model_output.cpu()
        # take mean of embeddings of all tokens in each sentence in the batch use attn_mask to ignore padding
        embeds = (embeds * attn_mask.unsqueeze(2)).sum(1) / attn_mask.sum(1).unsqueeze(1)
        if embeddings.size == 0:
            embeddings = embeds
        else:
            embeddings = np.concatenate((embeddings, embeds), axis=0)
    return embeddings

In [130]:
# drug_embeddings = get_embeddings(drugs, tokenizer, model, batch_size=64)
custom_drug_embeddings = get_embeddings(drugs, cutomized_tokenizer, customized_model, batch_size=64)

  0%|          | 0/83 [00:00<?, ?it/s]

In [131]:
# gene_embeddings = get_embeddings(genes, tokenizer, model, batch_size=64)
custom_gene_embeddings = get_embeddings(genes, cutomized_tokenizer, customized_model, batch_size=64)
# disease_embeddings = get_embeddings(diseases, tokenizer, model, batch_size=64)
custom_disease_embeddings = get_embeddings(diseases, cutomized_tokenizer, customized_model, batch_size=64)

  0%|          | 0/277 [00:00<?, ?it/s]

  0%|          | 0/86 [00:00<?, ?it/s]

In [132]:
# np.save('./Data/embeddings/drug_embeddings.npy', drug_embeddings)
# np.save('./Data/embeddings/gene_embeddings.npy', gene_embeddings)
# np.save('./Data/embeddings/disease_embeddings.npy', disease_embeddings)

np.save('./Data/embeddings/custom_drug_embeddings_2.npy', custom_drug_embeddings)
np.save('./Data/embeddings/custom_gene_embeddings_2.npy', custom_gene_embeddings)
np.save('./Data/embeddings/custom_disease_embeddings_2.npy', custom_disease_embeddings)


In [133]:
gene_embeddings.shape, custom_gene_embeddings.shape

((17682, 768), (17682, 768))

In [99]:
gene_idx_map = {gene: i for i, gene in enumerate(genes)}
drug_idx_map = {drug: i for i, drug in enumerate(drugs)}
disease_idx_map = {disease: i for i, disease in enumerate(diseases)}

In [100]:
import json

with open('./Data/embeddings/gene_idx_map.json', 'w') as f:
    json.dump(gene_idx_map, f)
with open('./Data/embeddings/drug_idx_map.json', 'w') as f:
    json.dump(drug_idx_map, f)
with open('./Data/embeddings/disease_idx_map.json', 'w') as f:
    json.dump(disease_idx_map, f)

In [119]:
np.random.choice(diseases, 5, replace=False)

array(['Horse Diseases', 'Bone Diseases, Endocrine', 'Hemifacial Spasm',
       'Splenomegaly', 'Mental Retardation, X-Linked 72'], dtype='<U115')

In [121]:
drug_drug_df.head(20)

Unnamed: 0,drug_1_name,drug_2_name
0,Vardenafil,Telmisartan
1,Clonidine,Pentoxifylline
2,Clomipramine,Mirabegron
3,Desipramine,Perampanel
4,L-DOPA,Hydralazine
5,Interferon alfa-n3,Methadone
6,Caffeine,Deferasirox
7,Flurbiprofen,Acenocoumarol
8,Dextroamphetamine,Fluspirilene
9,Deferasirox,Tolvaptan
