In [18]:
from lincs_gsnn.proc.get_bio_interactions import get_bio_interactions 

import torch 
import numpy as np
import pandas as pd
from lincs_gsnn.proc.subset import filter_func_nodes
import torch_geometric as pyg

%load_ext autoreload
%autoreload 2



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Gene names produced by the model (read from the predict_grid output), as a numpy str array.
gene_names = pd.read_csv(
    '/home/teddy/local/lincs-traj/workflow/runs/exp/default_v02/output/predict_grid/gene_names.csv'
)['gene_names'].values.astype(str)

# Landmark gene symbol -> UniProt table (tab-separated); columns are renamed in the next cell.
landmark_mapping = pd.read_csv('../extdata/landmark_gene2uni.tsv', sep='\t')

In [3]:
# Give the mapping descriptive column names and keep only symbol + accession.
landmark_mapping = landmark_mapping.rename(columns={'From': 'gene_symbol', 'Entry': 'uniprot_id'})
landmark_mapping = landmark_mapping.loc[:, ['gene_symbol', 'uniprot_id']]
landmark_mapping.head()

Unnamed: 0,gene_symbol,uniprot_id
0,GNPDA1,P46926
1,GNPDA1,D6R917
2,GNPDA1,D6R9P4
3,GNPDA1,D6RAY7
4,GNPDA1,D6RB13


In [4]:
# Build the functional interaction network: func_names are node names
# (e.g. 'RNA__<uniprot>' / 'PROTEIN__<uniprot>', see the next cell) and func_df is
# the edge table (later read through its source/target columns).
func_names, func_df = get_bio_interactions(
    include_extra=True,
    include_mirna=True,
    dorothea_levels=['A', 'B', 'C', 'D'],
)

# of translation (RNA->PROTEIN) edges: 10000


In [5]:
# Node names have the form '<TYPE>__<uniprot>'; collect the accession part for
# RNA nodes and for PROTEIN nodes separately.
rna_uniprots = [name.split('__')[1] for name in func_names if name.startswith('RNA__')]
prot_uniprots = [name.split('__')[1] for name in func_names if name.startswith('PROTEIN__')]

print(f'# RNA uniprots: {len(rna_uniprots)}')
print(f'# PROT uniprots: {len(prot_uniprots)}')

# RNA uniprots: 19584
# PROT uniprots: 11272


In [6]:
# Keep only mappings whose UniProt accession exists as an RNA node in the network.
landmark_mapping = landmark_mapping.loc[landmark_mapping['uniprot_id'].isin(rna_uniprots)]
landmark_mapping.shape

(1013, 2)

In [21]:
# A gene symbol can map to several UniProt accessions. Keep the first row per
# symbol (file order of landmark_gene2uni.tsv) so the symbol -> accession dict
# below is well defined.
# NOTE(review): the kept accession depends on the TSV row order; in the
# unfiltered preview the first GNPDA1 entry (P46926) precedes the D6R... entries.
# Confirm that the first row is the preferred accession.
# The count printed here depends on execution order: if this cell is re-run
# after deduplication, it reports 0.
n_dup_symbols = landmark_mapping.duplicated(subset=['gene_symbol']).sum()
print('# duplicate gene_symbol rows:', n_dup_symbols)

landmark_mapping = landmark_mapping.drop_duplicates(subset=['gene_symbol'])
assert landmark_mapping['gene_symbol'].is_unique

# gene symbol -> UniProt accession
landmark_map_dict = landmark_mapping.set_index('gene_symbol')['uniprot_id'].to_dict()

# Last expression so the resulting shape is actually displayed.
landmark_mapping.shape

0


In [None]:
# Drug (pert_id) and cell-line (cell_iname) lists read from the predict_grid output directory.
drugs = pd.read_csv(
    '/home/teddy/local/lincs-traj/workflow/runs/exp/default_v02/output/predict_grid/pert_ids.csv'
)['pert_id'].values.tolist()
lines = pd.read_csv(
    '/home/teddy/local/lincs-traj/workflow/runs/exp/default_v02/output/predict_grid/cell_inames.csv'
)['cell_iname'].values.tolist()
print(len(drugs))

# Unique (inchi_key, pert_id) pairs from the CLUE compound table; used below to
# attach pert_ids to targetome records by InChIKey.
clue_mapping = (
    pd.read_csv('../../data/compoundinfo_beta.txt', sep='\t')
    .loc[:, ['inchi_key', 'pert_id']]
    .drop_duplicates()
)


8


In [None]:
# Extended targetome drug-target records, joined to CLUE pert_ids via InChIKey,
# then narrowed by a sequence of row filters (each applied to the previous result).
tge = (
    pd.read_csv('../../data/targetome_extended-01-23-25.csv')
    .merge(clue_mapping, on='inchi_key', how='inner')
    .loc[lambda d: d['assay_type'].isin(['Kd', 'Ki'])]           # direct targets only
    .loc[lambda d: d['assay_relation'].isin(['=', '<', '<='])]   # exclude ">" relations
    .loc[lambda d: d['assay_value'] <= 1000]                     # only targets with affinity <= 1000 nM
    .loc[lambda d: d['pert_id'].isin(drugs)]                     # only drugs in our list
    .loc[lambda d: d['uniprot_id'].isin(prot_uniprots)]          # target must exist as a PROTEIN node
)

# Unique (drug, target) pairs; target_name matches the 'PROTEIN__<uniprot>' node names.
dtis = (
    tge.loc[:, ['pert_id', 'uniprot_id']]
    .drop_duplicates()
    .rename(columns={'uniprot_id': 'target'})
    .assign(target_name=lambda d: ['PROTEIN__' + t for t in d['target']])
)

print(tge.shape)
tge.head()

(1185, 9)


Unnamed: 0,pubchem_cid,inchi_key,uniprot_id,pubmed_id,database,assay_type,assay_relation,assay_value,pert_id
23067,5291,KTUFNOKKBVMGRW-UHFFFAOYSA-N,O00571,29191878.0,pubchem_bioassay,Kd,=,435.0,BRD-K92723993
23086,5291,KTUFNOKKBVMGRW-UHFFFAOYSA-N,O14976,18183025.0,pubchem_bioassay,Kd,=,1000.0,BRD-K92723993
23087,5291,KTUFNOKKBVMGRW-UHFFFAOYSA-N,O14976,22037378.0,pubchem_bioassay,Kd,=,1000.0,BRD-K92723993
23113,5291,KTUFNOKKBVMGRW-UHFFFAOYSA-N,O43570,19527930.0,pubchem_bioassay,Ki,=,980.0,BRD-K92723993
23173,5291,KTUFNOKKBVMGRW-UHFFFAOYSA-N,P00519,15711537.0,pubchem_bioassay,Kd,=,44.0,BRD-K92723993


In [10]:
# Number of distinct (drug, target) pairs per drug, largest first.
cnts = (
    dtis
    .groupby('pert_id')
    .size()
    .sort_values(ascending=False)
)
cnts

pert_id
BRD-K42828737    188
BRD-K99964838    117
BRD-K49328571     89
BRD-K81528515     46
BRD-K44227013     37
BRD-K92723993     22
BRD-K51544265      9
BRD-K33379087      4
dtype: int64

In [11]:
# Drugs from the prediction grid that have no retained targetome interaction.
set(drugs).difference(dtis['pert_id'])

set()

In [12]:
# Preview of the drug-target interaction table.
dtis.head(n=5)

Unnamed: 0,pert_id,target,target_name
23067,BRD-K92723993,O00571,PROTEIN__O00571
23086,BRD-K92723993,O14976,PROTEIN__O14976
23113,BRD-K92723993,O43570,PROTEIN__O43570
23173,BRD-K92723993,P00519,PROTEIN__P00519
23256,BRD-K92723993,P00915,PROTEIN__P00915


In [13]:

# UniProt accessions of the landmark (LINCS) genes.
lincs = landmark_mapping['uniprot_id'].values.tolist()

# Prune function nodes that are not downstream of a drug and do not have
# downstream LINCS genes, and drop targets that are no longer relevant
# (search depth set by filter_depth).
func_names2, func_df2, targets2, drugs2, lincs2 = filter_func_nodes(
    func_names, func_df, dtis, lincs, drugs,
    filter_depth=4,
)


filtering function nodes...
function nodes retained: 1440.0/30856
drug nodes retained: 8.0/8
lincs nodes retained: 972.0/974


In [14]:
# Integer position of each retained function node; edges get src/dst positions.
func2idx = dict(zip(func_names2, range(len(func_names2))))
func_df2 = func_df2.assign(
    src_idx=[func2idx[name] for name in func_df2['source']],
    dst_idx=[func2idx[name] for name in func_df2['target']],
)

In [24]:
# Node-name lists for the three node sets (input, function, output).
output_names = ['GENE__' + g for g in gene_names]
input_names = (
    ['DRUG__' + p for p in drugs2]
    + ['LINE__' + c for c in lines]
    + ['GENE__' + g for g in gene_names]
)
function_names = func_names2

# name -> integer position lookup for each node set
input_name2idx = dict(zip(input_names, range(len(input_names))))
output_name2idx = dict(zip(output_names, range(len(output_names))))
function_name2idx = dict(zip(function_names, range(len(function_names))))

In [31]:
# Edges from input nodes (drugs, cell lines, genes) into function nodes.
# Edges are collected as [input node name, function node name] pairs and
# converted to integer positions at the end.
input_edge_list = []

# drug -> protein target edges
for t in targets2.itertuples(index=False):
    # dict membership is O(1); scanning the function_names list is O(n)
    if t.target_name in function_name2idx:
        input_edge_list.append(['DRUG__' + t.pert_id, t.target_name])
    else:
        print(f'No function node for target {t.target} in drug {t.pert_id}')

# cell line -> every function node
for line in lines:
    input_edge_list.extend(['LINE__' + line, f] for f in function_names)

# gene -> its RNA function node
for g in gene_names:
    targ = landmark_map_dict.get(g)
    if targ is None:
        print(f'No uniprot mapping for gene {g}')
    elif 'RNA__' + targ not in function_name2idx:
        # a mapping exists, but its RNA node is not among the retained function nodes
        print(f'No RNA function node for gene {g} (uniprot {targ})')
    else:
        input_edge_list.append(['GENE__' + g, 'RNA__' + targ])
    # TODO (original note): self loop as placeholder for genes without an edge?

# Convert names to positions. Building the two lists directly (rather than
# unpacking np.array(...).T) also works when the edge list is empty -> shape (2, 0).
src = [input_name2idx[s] for s, _ in input_edge_list]
dst = [function_name2idx[d] for _, d in input_edge_list]
input_edge_index = torch.tensor([src, dst], dtype=torch.long)
input_edge_index.shape
    

No uniprot mapping for gene B4GAT1
No uniprot mapping for gene CHP1
No uniprot mapping for gene HDGFL3
No uniprot mapping for gene ABCF3
No uniprot mapping for gene PLSCR3
No uniprot mapping for gene JPT2


torch.Size([2, 17170])

In [33]:
# Function -> function edges, built from the integer positions stored in
# func_df2 (src_idx / dst_idx are positions into func_names2).
func_edge_index = torch.stack(
    [
        torch.tensor(func_df2['src_idx'].values, dtype=torch.long),
        torch.tensor(func_df2['dst_idx'].values, dtype=torch.long),
    ],
    dim=0,
)
func_edge_index.shape

torch.Size([2, 22982])

In [34]:
# Edges from each gene's RNA function node to its GENE__ output node.
# Edges are collected as [function node name, output node name] pairs and
# converted to integer positions at the end.
output_edge_list = []

for g in gene_names:
    targ = landmark_map_dict.get(g)
    if targ is None:
        print(f'No uniprot mapping for gene {g}')
    elif 'RNA__' + targ not in function_name2idx:
        # a mapping exists, but its RNA node is not among the retained function nodes
        print(f'No RNA function node for gene {g} (uniprot {targ})')
    else:
        output_edge_list.append(['RNA__' + targ, 'GENE__' + g])  # RNA node -> output gene edge
    # TODO (original note): self loop as placeholder for genes without an edge?

# Convert names to positions; an empty edge list yields shape (2, 0).
src = [function_name2idx[s] for s, _ in output_edge_list]
dst = [output_name2idx[d] for _, d in output_edge_list]
output_edge_index = torch.tensor([src, dst], dtype=torch.long)
output_edge_index.shape

No uniprot mapping for gene B4GAT1
No uniprot mapping for gene CHP1
No uniprot mapping for gene HDGFL3
No uniprot mapping for gene ABCF3
No uniprot mapping for gene PLSCR3
No uniprot mapping for gene JPT2


torch.Size([2, 972])

In [35]:
# Bundle the three edge sets and the node-name lists into one HeteroData object.
data = pyg.data.HeteroData()

data['edge_index_dict'] = {
    ('input',    'to', 'function'): input_edge_index,
    ('function', 'to', 'function'): func_edge_index,
    ('function', 'to', 'output'):   output_edge_index,
}

# The 'function' names must be the *filtered* list (function_names == func_names2).
# func_edge_index and the input/output edge indices hold positions into that list,
# not into the unfiltered func_names.
data['node_names_dict'] = {
    'input': input_names,
    'function': function_names,
    'output': output_names,
}

# Sanity check: every edge endpoint must refer to an existing node name.
n_nodes = {k: len(v) for k, v in data['node_names_dict'].items()}
for (src_type, _, dst_type), ei in data['edge_index_dict'].items():
    if ei.numel() > 0:
        assert int(ei[0].max()) < n_nodes[src_type], (src_type, 'to', dst_type)
        assert int(ei[1].max()) < n_nodes[dst_type], (src_type, 'to', dst_type)

In [36]:
# Serialize the assembled network object (pickled by torch.save).
# NOTE(review): newer torch versions default torch.load to weights_only=True,
# which may refuse arbitrary objects like this one — confirm how it is loaded.
torch.save(obj=data, f='../extdata/bionetwork.pt')