In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
#from bioblp import train
from pykeen.pipeline import pipeline
from pykeen.models import TransE
import pandas as pd 
from pathlib import Path
import toml

  from .autonotebook import tqdm as notebook_tqdm


In [93]:
from bioblp.data import COL_EDGE, COL_SOURCE, COL_TARGET
DATA_DIR = Path("../data")


In [9]:
from bioblp.data import load_splits

In [96]:
# Pick the dataset whose pre-computed splits should be loaded; the commented
# name is an alternative dataset that can be swapped in.
# dataset_name = 'biokg_mini_random_900505'
dataset_name = 'biokg_random_900505'
splits_dir = DATA_DIR / 'raw/biokg_full_splits'

# load_splits hands back the three splits in (train, valid, test) order.
train, valid, test = load_splits(dataset=dataset_name, data_path=splits_dir)


In [10]:
# Tabulate the training triples using the shared column-name constants.
# NOTE: accessing `train.triples` is presumably what triggers the
# "Reconstructing all label-based triples" warning shown in the cell output.
train_df = pd.DataFrame(train.triples, columns=[COL_SOURCE, COL_EDGE, COL_TARGET])
# Index through COL_EDGE (not a hard-coded `.edg` attribute) so the count
# keeps working if the column-name constant ever changes.
train_df[COL_EDGE].value_counts()

Reconstructing all label-based triples. This is expensive and rarely needed.


DDI                            1194699
PROTEIN_PATHWAY_ASSOCIATION     229631
PPI                             105482
PROTEIN_DISEASE_ASSOCIATION      99608
MEMBER_OF_COMPLEX                79244
DRUG_DISEASE_ASSOCIATION         60319
DPI                              25553
COMPLEX_IN_PATHWAY               20458
COMPLEX_TOP_LEVEL_PATHWAY        14097
DRUG_TARGET                      13670
DRUG_PATHWAY_ASSOCIATION          4646
DISEASE_GENETIC_DISORDER          4594
DRUG_ENZYME                       4508
RELATED_GENETIC_DISORDER          3716
DISEASE_PATHWAY_ASSOCIATION       3258
DRUG_TRANSPORTER                  2730
DRUG_CARRIER                       698
Name: edg, dtype: int64

## Load benchmark

In [14]:
# Location of the DPI benchmark TSV, resolved relative to DATA_DIR.
dpi_benchmark_path = DATA_DIR / 'benchmarks/dpi_fda.tsv'

In [15]:
# The TSV is read without a header row: `names` supplies the column labels,
# so the first line of the file is treated as data (header=None made explicit).
dpi_bm = pd.read_csv(
    dpi_benchmark_path,
    sep='\t',
    header=None,
    names=[COL_SOURCE, COL_EDGE, COL_TARGET],
)

In [16]:
# Count benchmark triples per relation; index through COL_EDGE (the constant
# used to name the column) instead of the hard-coded `.edg` attribute.
dpi_bm[COL_EDGE].value_counts()

DPI    19161
Name: edg, dtype: int64


* [DB01079; Tegaserod](https://go.drugbank.com/drugs/DB01079)
Tegaserod is a serotonin-4 (5-HT4) receptor agonist indicated for the treatment of constipation predominant irritable bowel syndrome (IBS-C) specifically in women under the age of 65. There is currently no safety or efficacy data for use of tegaserod in men.

* [Q13639; 5-hydroxytryptamine receptor 4 (5-HT4)](https://www.uniprot.org/uniprotkb/Q13639/entry)

In [17]:
# Preview the first few benchmark rows.
dpi_bm.head()

Unnamed: 0,src,edg,tgt
0,DB01079,DPI,Q13639
1,DB00114,DPI,P20711
2,DB01158,DPI,P13637
3,DB01069,DPI,P18825
4,DB01186,DPI,P08684


In [97]:
from typing import List

# Label -> id lookup tables for entities and relations, taken from the
# training split.
ent2id_map, rel2id_map = train.entity_to_id, train.relation_to_id

def get_ent_ids_for_entity_list(entity_list: List[str], ent2id_map, strict: bool = False):
    """Translate entity labels into their ids.

    Args:
        entity_list: Entity labels to look up, in the order the ids are wanted.
        ent2id_map: Label -> id mapping; only its ``.get`` method is used.
        strict: When False (the default, matching the original behaviour),
            labels absent from ``ent2id_map`` yield ``None`` in the result.
            When True, a ``KeyError`` naming some of the missing labels is
            raised instead, so the problem surfaces here rather than later
            when the ids are turned into a tensor.

    Returns:
        A list with one entry per label in ``entity_list``: the mapped id, or
        ``None`` for an unknown label when ``strict`` is False.

    Raises:
        KeyError: If ``strict`` is True and at least one label is unknown.
    """
    # Materialise first so a one-shot iterable can be walked twice below.
    entity_list = list(entity_list)
    ids = [ent2id_map.get(ent) for ent in entity_list]
    if strict:
        missing = [ent for ent, ent_id in zip(entity_list, ids) if ent_id is None]
        if missing:
            raise KeyError(
                f"{len(missing)} entities are missing from ent2id_map, e.g. {missing[:5]}"
            )
    return ids


#### Load pretrained KGE model

In [102]:
import torch

# NOTE(review): absolute, machine-specific path; consider deriving it from a
# configurable base directory so the notebook runs elsewhere.
model_dir = Path('/home/jovyan/BioBLP/models/1baon0eg')
print(f'Loading trained model from {model_dir}')
# Fall back to CPU when no GPU is visible; map_location moves the stored
# tensors onto that device while loading.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# SECURITY: torch.load unpickles the file, which can execute arbitrary code;
# only load model files from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects fully pickled models; confirm against the installed version.
model = torch.load(model_dir.joinpath("trained_model.pkl"), map_location=device)

Loading trained model from /home/jovyan/BioBLP/models/1baon0eg


#### Retrieve KG embeddings

In [76]:
# Grab the embedding modules behind the first entity and relation
# representations of the loaded model (via the private `_embeddings` attribute).
entity_representation, relation_representation = (
    model.entity_representations[0]._embeddings,
    model.relation_representations[0]._embeddings,
)
entity_representation

Embedding(106339, 512)

In [98]:
# Look up the KG embedding of every benchmark drug (the source column).
entity_emb_layer = model.entity_representations[0]._embeddings
dpi_bm_drugs = dpi_bm[COL_SOURCE].tolist()
drug_ids = get_ent_ids_for_entity_list(dpi_bm_drugs, ent2id_map)
# Build the index tensor on the same device as the embedding weights: the
# model is loaded with map_location=device, so a CPU index tensor would fail
# with a device mismatch whenever the model sits on CUDA.
drug_ids = torch.tensor(drug_ids, dtype=torch.long, device=entity_emb_layer.weight.device)
drug_embs = entity_emb_layer(drug_ids)

In [99]:
# Look up the KG embedding of every benchmark protein (the target column).
entity_emb_layer = model.entity_representations[0]._embeddings
dpi_bm_prots = dpi_bm[COL_TARGET].tolist()
prot_ids = get_ent_ids_for_entity_list(dpi_bm_prots, ent2id_map)
# Build the index tensor on the same device as the embedding weights: the
# model is loaded with map_location=device, so a CPU index tensor would fail
# with a device mismatch whenever the model sits on CUDA.
prot_ids = torch.tensor(prot_ids, dtype=torch.long, device=entity_emb_layer.weight.device)
prot_embs = entity_emb_layer(prot_ids)

In [87]:
# Sanity check: show the shapes of the drug and protein embedding tensors
# (both should have one row per benchmark pair).
drug_embs.shape, prot_embs.shape

(torch.Size([19161, 512]), torch.Size([19161, 512]))

#### Encode pairs of entities

In [130]:
from collections.abc import Callable

def concatenate(emb1, emb2):
    """Join two embeddings end to end and return them as a single row.

    The inputs are concatenated along dim 0 and the result is reshaped to
    (1, N), where N is the total number of elements in both inputs.
    """
    joined = torch.cat([emb1, emb2], 0)
    return joined.view(1, -1)

def average(emb1, emb2):
    """Element-wise mean of two same-shaped embeddings, returned as one row.

    Args:
        emb1: First embedding tensor.
        emb2: Second embedding tensor; must have the same shape as ``emb1``
            because the two are stacked.

    Returns:
        A tensor of shape (1, N) holding the mean of the two inputs, where N
        is the number of elements in one input.
    """
    # Stack along a new leading dim and reduce over it. The original also
    # built an unused concatenation here, which was wasted work and rejected
    # 0-dim inputs that this stack/mean path handles.
    out = torch.stack((emb1, emb2)).mean(dim=0).view(1, -1)
    return out

def encode_entity_pair(emb1, emb2, transform:Callable):
    """Encode a pair of embeddings by delegating to ``transform``.

    ``transform`` is called as ``transform(emb1, emb2)`` and its result is
    returned unchanged.
    """
    encoded = transform(emb1, emb2)
    return encoded
    

In [131]:
# Encode the first benchmark (drug, protein) pair by averaging its embeddings.
out = encode_entity_pair(
    emb1=drug_embs[0],
    emb2=prot_embs[0],
    transform=average,
)

In [132]:
# Shape of the encoded pair.
out.shape

torch.Size([1, 512])