This is an extra notebook to help me understand the construct_heterogeneous_graph_jihwan.py and construct_heterogeneous_graph_PyG_jihwan.py scripts, by visualizing how the graph gets constructed. In the first script the graph is constructed using the ppi (reactome3.txt) and the gene-cell dependency score (CRISPRGeneEffect.csv). The second script builds further on that graph by adding cell and gene features.

In [125]:
import torch

In [126]:
import sys
#sys.path.append('/Users/jovanafilipovic/Downloads/MSc Bioinformatics/Year 2/Thesis/Python_scripts')

from NetworkAnalysis.MultiGraph import MultiGraph
from NetworkAnalysis.UndirectedInteractionNetwork import UndirectedInteractionNetwork
from itertools import combinations

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import pickle
import os

BASE_PATH = "./Data/"
cancer_type = "Neuroblastoma"
train_ratio = 0.8
ppi = "Reactome"
remove_rpl = "_noRPL"
remove_commonE = ""  #this will not be removed
useSTD = "STD"
crispr_threshold_pos = -1.5
ppi_train_ratio = 0.8

# Read in all relevant DepMap data
ccles_ori = pd.read_csv(BASE_PATH+"Depmap/Model.csv", index_col=0)
# ccles = ccles_ori.loc[ccles_ori.PatientID.drop_duplicates().index]
ccles = ccles_ori
ccles['OncotreePrimaryDisease'].value_counts()

path = BASE_PATH+'Depmap/CRISPRGeneEffect.csv'
crispr_effect = pd.read_csv(path, header=0, index_col=0)
crispr_effect.columns = [i.split(' ')[0] for i in crispr_effect.columns]

# Read in PPI and construct prior node embeddings
# Ndex2 pcnet 
if ppi == 'PCNet':
    ppi_obj = UndirectedInteractionNetwork.from_ndex(ndex_id='c3554b4e-8c81-11ed-a157-005056ae23aa', keeplargestcomponent=False,
                                                    attributes_for_names='v', node_type=int)
else:
    ppi_ = pd.read_csv(BASE_PATH+'reactome3.txt', header=0, sep='\t')
    ppi_obj = UndirectedInteractionNetwork(ppi_, keeplargestcomponent=False) #You are keeping all components, not just the largest connected one.
    
ppi_obj.set_node_types(node_types={i: "gene" for i in ppi_obj.node_names}) #set node type to gene for each node

degreedf = ppi_obj.getDegreeDF(set_index=True) #dataframe with degree of each node
degreedf.loc[['BRIP1', 'RRM2','LCE2C']]


Continuing with Gene1 and Gene2 as columns for the nodes
14034 Nodes and 278974 interactions


Unnamed: 0,Gene,Count
BRIP1,BRIP1,84
RRM2,RRM2,42
LCE2C,LCE2C,13


In [127]:
ppi_obj

        Gene_A   Gene_B
0         A1CF  APOBEC1
1         A1CF    CELF2
2         A1CF    EP300
3         A1CF    RBM47
4          A2M    APOA1
...        ...      ...
278969    ZW10   ZWILCH
278970    ZW10    ZWINT
278971  ZWILCH    ZWINT
278972    ZXDA     ZXDB
278973    ZXDA     ZXDC

[278974 rows x 2 columns]

In [128]:
with open(BASE_PATH+f"multigraphs/{cancer_type.replace(' ', '_')}_{ppi}{remove_rpl}_{useSTD}{remove_commonE}_crispr{str(crispr_threshold_pos).replace('.','_')}.pickle", 'rb') as handle:
    mg_obj = pickle.load(handle)

In [129]:
mg_obj

            Gene_A   Gene_B      type
0             A1CF  APOBEC1  scaffold
1             A1CF    CELF2  scaffold
2             A1CF    EP300  scaffold
3             A1CF    RBM47  scaffold
4              A2M    APOA1  scaffold
...            ...      ...       ...
293068  ACH-002922     RPS5    depmap
293069  ACH-002922   TOPBP1    depmap
293070  ACH-002922    RPS20    depmap
293071  ACH-002922     NIFK    depmap
293072  ACH-002922    SF3B3    depmap

[293073 rows x 3 columns]

In [130]:
all_genes_int = mg_obj.type2nodes['gene']
all_genes_name = [mg_obj.int2gene[i] for i in all_genes_int]

print(all_genes_int[-1])
print(all_genes_name)
print(len(all_genes_name))
print(len(all_genes_int))


14070
['A1CF', 'A2M', 'A4GNT', 'AAAS', 'AADAT', 'AAG1', 'AAK1', 'AAMP', 'AAR2', 'AARS1', 'AARS2', 'AARSD1', 'AASDHPPT', 'AATF', 'ABAT', 'ABCA1', 'ABCA10', 'ABCA12', 'ABCA13', 'ABCA2', 'ABCA3', 'ABCA4', 'ABCA5', 'ABCA6', 'ABCA7', 'ABCA8', 'ABCA9', 'ABCB1', 'ABCB10', 'ABCB11', 'ABCB4', 'ABCB7', 'ABCB8', 'ABCB9', 'ABCC1', 'ABCC10', 'ABCC11', 'ABCC12', 'ABCC2', 'ABCC3', 'ABCC4', 'ABCC6', 'ABCC8', 'ABCC9', 'ABCD1', 'ABCD2', 'ABCD3', 'ABCD4', 'ABCE1', 'ABCF1', 'ABCF2', 'ABCG1', 'ABCG4', 'ABCG5', 'ABCG8', 'ABHD17A', 'ABHD17B', 'ABHD17C', 'ABHD5', 'ABI1', 'ABI2', 'ABL1', 'ABL2', 'ABLIM1', 'ABLIM2', 'ABLIM3', 'ABR', 'ABRAXAS1', 'ABRAXAS2', 'ABT1', 'ACAA1', 'ACAA2', 'ACACA', 'ACACB', 'ACAD11', 'ACAD9', 'ACADL', 'ACADM', 'ACADS', 'ACADVL', 'ACAN', 'ACAP1', 'ACAP2', 'ACAT1', 'ACAT2', 'ACBD3', 'ACBD4', 'ACBD5', 'ACD', 'ACE', 'ACE2', 'ACHE', 'ACIN1', 'ACKR1', 'ACKR2', 'ACKR3', 'ACKR4', 'ACLY', 'ACO1', 'ACO2', 'ACOT1', 'ACOT11', 'ACOT12', 'ACOT13', 'ACOT2', 'ACOT4', 'ACOT7', 'ACOT7L', 'ACOT8', 'ACOX1

In [131]:
#PPI
ppi_obj = mg_obj.getEdgeType_subset(edge_type='scaffold') #gene-gene interactions
ppi_obj_new_gene2int = {n:i for i, n in enumerate(all_genes_name)}
ppi_obj_new_int2gene = {v:k for k, v in ppi_obj_new_gene2int.items()}
ppi_interactions = ppi_obj.getInteractionNamed() #function defined in multigraph.py, it returns a dataframe with 2 columns "Gene_A", "Gene_B"
print(ppi_interactions)
ppi_interactions = ppi_interactions.map(lambda x: ppi_obj_new_gene2int[x]) #replace gene names with their corresponding integers
print(ppi_interactions)

print(len(ppi_obj_new_gene2int), len(ppi_obj_new_int2gene))

Returning UndirectedInteractionNetwork object.
Continuing with Gene_A and Gene_B as columns for the nodes
14034 Nodes and 278974 interactions
        Gene_A   Gene_B
0         A1CF  APOBEC1
1         A1CF    CELF2
2         A1CF    EP300
3         A1CF    RBM47
4          A2M    APOA1
...        ...      ...
278969    ZW10   ZWILCH
278970    ZW10    ZWINT
278971  ZWILCH    ZWINT
278972    ZXDA     ZXDB
278973    ZXDA     ZXDC

[278974 rows x 2 columns]
        Gene_A  Gene_B
0            0     579
1            0    1828
2            0    3466
3            0    9946
4            1     574
...        ...     ...
278969   14026   14027
278970   14026   14028
278971   14027   14028
278972   14029   14030
278973   14029   14031

[278974 rows x 2 columns]
14034 14034


In [132]:

# DEP obj
dep_obj = mg_obj.getEdgeType_subset(edge_type='depmap')
cells = [k for k, v in mg_obj.node_type_names.items() if v == 'cell']
cell2int = {c:i for i, c in enumerate(cells)}
int2cell = {v:k for k, v in cell2int.items()}
dep_interactions = dep_obj.getInteractionNamed() #2 columns
print(dep_interactions)
dep_genes = [dep_obj.int2gene[i] for i in dep_obj.type2nodes['gene']]
print(dep_genes)

dep_interactions.loc[~dep_interactions.Gene_A.isin(cells), ['Gene_A', 'Gene_B']] = \
    dep_interactions.loc[~dep_interactions.Gene_A.isin(cells), ['Gene_B', 'Gene_A']].values # assure that all values in Gene_A are cells, otherwise switch with the gene column

assert dep_interactions.Gene_A.isin(cells).sum() == dep_interactions.shape[0] #all Gene_A should be cell lines
dep_interactions = dep_interactions.map(lambda x: cell2int[x] if x in cell2int else ppi_obj_new_gene2int[x]) #map cell lines and genes to their resp. integers
print(dep_interactions)
dep_interactions = dep_interactions[['Gene_B', 'Gene_A']]
print(dep_interactions)
print(dep_interactions.shape)

Returning UndirectedInteractionNetwork object.
Continuing with Gene_A and Gene_B as columns for the nodes
1037 Nodes and 14099 interactions
            Gene_A  Gene_B
278974  ACH-000078   HSPD1
278975  ACH-000078   PRPF6
278976  ACH-000078   RPS4X
278977  ACH-000078   CTDP1
278978  ACH-000078    BUB3
...            ...     ...
293068  ACH-002922    RPS5
293069  ACH-002922  TOPBP1
293070  ACH-002922   RPS20
293071  ACH-002922    NIFK
293072  ACH-002922   SF3B3

[14099 rows x 2 columns]
['AAMP', 'AARS1', 'ABCB7', 'ABCE1', 'ABT1', 'ACO2', 'ACTL6A', 'ACTR10', 'ACTR1A', 'ACTR2', 'ADSL', 'AFG3L2', 'AHCTF1', 'AK6', 'AKIRIN2', 'ALDOA', 'ALG1', 'ALG11', 'ALG13', 'ALG2', 'ALK', 'ALYREF', 'ANAPC1', 'ANAPC10', 'ANAPC11', 'ANAPC2', 'ANAPC4', 'ANAPC5', 'ANKLE2', 'AQR', 'ARCN1', 'ARF4', 'ARFRP1', 'ARIH1', 'ARL2', 'ARPC4', 'ASCL1', 'ASPM', 'ATL2', 'ATP2A2', 'ATP5F1A', 'ATP5F1B', 'ATP5F1D', 'ATP5F1E', 'ATP5ME', 'ATP5PB', 'ATP6AP1', 'ATP6V0C', 'ATP6V0D1', 'ATP6V1A', 'ATP6V1B2', 'ATP6V1C1', 'ATP6V1D', 'A

In [153]:
def read_gmt_file(fp, nw_obj): #read gmt files: generates a dictionary with gene set names and corresponding genes
    genes_per_DB = {}
    if isinstance(nw_obj, list):
        focus_genes = set(nw_obj)
    else:
        focus_genes = set(nw_obj.node_names)
    with open(fp) as f:
        lines = f.readlines()
        for line in lines:
            temp = line.strip('\n').split('\t')
            genes_per_DB[temp[0]] = set(gene for gene in temp[2:]) & focus_genes
    return genes_per_DB

gene_feat_name = 'cgp'
# Gene features
if gene_feat_name == 'cgp':
    cgn = read_gmt_file(BASE_PATH+"MsigDB/c2.cgp.v2023.2.Hs.symbols.gmt", ppi_obj) #ppi_obj is used to filter out genes that are not in the network
elif gene_feat_name == 'bp':
    cgn = read_gmt_file(BASE_PATH+"MsigDB/c5.go.bp.v2023.2.Hs.symbols.gmt", ppi_obj)
elif gene_feat_name == 'go':    
    cgn = read_gmt_file(BASE_PATH+"MsigDB/c5.go.v2023.2.Hs.symbols.gmt", ppi_obj)
elif gene_feat_name == 'cp':  
    cgn = read_gmt_file(BASE_PATH+"MsigDB/c2.cp.v2023.2.Hs.symbols.gmt", ppi_obj)

#print(cgn)

# Create a dataframe with all genes (rows) and their corresponding gene sets (columns)
cgn_df = pd.DataFrame(np.zeros((len(all_genes_name), len(cgn))), index=all_genes_name, columns=list(cgn.keys()))
for k, v in cgn.items():
    cgn_df.loc[list(v), k] = 1 #set 1 if gene is in the gene set, 0 otherwise

print(cgn_df)




       ABBUD_LIF_SIGNALING_1_DN  ABBUD_LIF_SIGNALING_1_UP  \
A1CF                        0.0                       0.0   
A2M                         0.0                       0.0   
A4GNT                       0.0                       0.0   
AAAS                        0.0                       0.0   
AADAT                       0.0                       0.0   
...                         ...                       ...   
ZXDA                        0.0                       0.0   
ZXDB                        0.0                       0.0   
ZXDC                        0.0                       0.0   
ZYX                         0.0                       0.0   
ZZZ3                        0.0                       0.0   

       ABBUD_LIF_SIGNALING_2_DN  ABBUD_LIF_SIGNALING_2_UP  \
A1CF                        0.0                       0.0   
A2M                         0.0                       0.0   
A4GNT                       0.0                       0.0   
AAAS                   

In [134]:
# Create a dataframe with all genes (rows) and their corresponding gene sets (columns)
cgn_df = pd.DataFrame(np.zeros((len(all_genes_name), len(cgn))), index=all_genes_name, columns=list(cgn.keys()))
for k, v in cgn.items():
    cgn_df.loc[list(v), k] = 1  #set 1 if gene is in the gene set, 0 otherwise
zero_gene_feat = cgn_df.index[cgn_df.sum(axis=1) == 0] # This is not allowed because all genes must have features
# Check how many of the dep genes are in that all 0, otherwise this is basically of no use
zero_depgenes = set(zero_gene_feat) & set(dep_genes) #genes that are in the dep_genes and have no features
print(len(zero_depgenes))
cgn_df = cgn_df.drop(zero_depgenes) #remove genes that have no features
zero_depgenes_ids = [ppi_obj_new_gene2int[gene] for gene in zero_depgenes]

# Filter ppi_obj_new_gene2int to only include genes present in the updated cgn_df
print(len(ppi_obj_new_gene2int))
ppi_obj_new_gene2int = {k:v for k, v in ppi_obj_new_gene2int.items() if k in cgn_df.index}
print(len(ppi_obj_new_gene2int))

# Delete zero_depgenes from ppi_interactions and dep_interactions
ppi_interactions = ppi_interactions[~ppi_interactions.isin(zero_depgenes_ids).any(axis=1)]
dep_interactions = dep_interactions[~dep_interactions.isin(zero_depgenes_ids).any(axis=1)]

#gene featur matrix (rows=genes) 
gene_feat = torch.from_numpy(cgn_df.values).to(torch.float) ##why not filtering???? 크키맞출라고


1
14034
14033


In [135]:
print(ppi_interactions)

        Gene_A  Gene_B
0            0     579
1            0    1828
2            0    3466
3            0    9946
4            1     574
...        ...     ...
278969   14026   14027
278970   14026   14028
278971   14027   14028
278972   14029   14030
278973   14029   14031

[278961 rows x 2 columns]


In [136]:
print(len(ppi_obj_new_gene2int))

ppi_obj_new_gene2int = {gene: idx for idx, gene in enumerate(cgn_df.index)}

print(len(ppi_obj_new_gene2int))

14033
14033


In [137]:
gene_feat = torch.from_numpy(cgn_df.values).to(torch.float) ##why not filtering???? 크키맞출라고
gene_feat

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [138]:
path = BASE_PATH+'Depmap/OmicsCNGene.csv'
ccle_cnv = pd.read_csv(path, header=0, index_col=0)
ccle_cnv.columns = [i.split(' ')[0] for i in ccle_cnv.columns] 
print(ccle_cnv.shape)
ccle_cnv = ccle_cnv[ccle_cnv.columns[ccle_cnv.isna().sum() == 0]] #remove columns with missing values
print(ccle_cnv.shape)

ccle_cnv = ccle_cnv.loc[list(set(cells) & set(ccle_cnv.index))] #filter only the cells that are in the cell lines
print(ccle_cnv)

hvg_q = ccle_cnv.std().quantile(q=0.95)  #compute the 95th percentile of the standard deviation per gene accross all cells
hvg_final = ccle_cnv.std()[ccle_cnv.std() >= hvg_q].index #select genes with high variance (top 5%)

ccle_cnv_hvg = ccle_cnv[hvg_final]
cell_feat = torch.from_numpy(ccle_cnv_hvg.loc[cell2int.keys()].values).to(torch.float) #cell feature matrix
cell_feat

(1788, 24383)
(1788, 24352)
              FAM87B  LINC01128  AL669831.7    FAM41C  LINC02593    SAMD11  \
ACH-001366  1.051869   1.051869    1.051869  1.051869   1.051869  1.051869   
ACH-002278  0.905035   0.905035    0.905035  0.905035   0.905035  0.905035   
ACH-000312  0.617370   0.617370    0.617370  0.617370   0.617370  0.617370   
ACH-000259  1.418187   1.418187    1.418187  1.418187   1.418187  1.418187   
ACH-001603  0.543102   0.543102    0.543102  0.543102   0.543102  0.543102   
ACH-000227  0.548441   0.548441    0.548441  0.548441   0.548441  0.548441   
ACH-002280  1.006563   1.006563    1.006563  1.006563   1.006563  1.006563   
ACH-002261  0.962287   0.962287    0.962287  0.962287   0.962287  0.962287   
ACH-001302  0.539973   0.539973    0.539973  0.539973   0.539973  0.539973   
ACH-000366  0.791845   0.791845    0.791845  0.791845   0.791845  0.791845   
ACH-002083  0.722578   0.722578    0.722578  0.722578   0.722578  0.722578   
ACH-000078  0.820274   0.820274    0

tensor([[0.8203, 0.8203, 0.8203,  ..., 1.2385, 1.2385, 1.2385],
        [0.5278, 0.5278, 0.5278,  ..., 1.0132, 1.0132, 1.0132],
        [0.4920, 0.4920, 0.4920,  ..., 0.9906, 0.9906, 0.9906],
        ...,
        [0.9981, 0.9981, 0.9981,  ..., 1.0062, 1.0062, 1.0062],
        [0.6785, 0.6785, 0.6785,  ..., 1.0461, 1.0461, 1.0461],
        [1.8529, 1.3213, 2.2529,  ..., 1.2966, 1.7596, 1.7596]])

In [141]:
#  heterodata_obj = torch.load(BASE_PATH+f"multigraphs/"\
#                                 f"heteroData_gene_cell_{cancer_type.replace(' ', '_')}_{ppi}"\
#                                     f"_crispr{str(crp_pos).replace('.','_')}{drugs}_{gene_feat}_{args.cell_feat}.pt")
    
heterodata_obj = torch.load('Data/multigraphs/heteroData_gene_cell_Neuroblastoma_Reactome_crispr-1_5_cgp_cnv.pt')
heterodata_obj

HeteroData(
  gene={
    node_id=[14034],
    names=[14034],
    x=[14034, 3438],
  },
  cell={
    node_id=[37],
    names=[37],
    x=[37, 1218],
  },
  (gene, interacts_with, gene)={ edge_index=[2, 278974] },
  (gene, dependency_of, cell)={ edge_index=[2, 14099] },
  (gene, rev_interacts_with, gene)={ edge_index=[2, 278974] },
  (cell, rev_dependency_of, gene)={ edge_index=[2, 14099] }
)

In [145]:
    cell2int = dict(zip(heterodata_obj['cell'].names, heterodata_obj['cell'].node_id.numpy()))
    gene2int = dict(zip(heterodata_obj['gene'].names, heterodata_obj['gene'].node_id.numpy()))
    dep_genes = list(set(heterodata_obj['gene', 'dependency_of', 'cell'].edge_index[0].numpy())) # all genes that have a dependency edge


In [151]:
# Define the full probability matrix for validation
cls_int = heterodata_obj['cell'].node_id
cl_probs = torch.zeros((2, len(cls_int)*len(dep_genes)), dtype=torch.long)

for i, cl in enumerate(cls_int):
        # cl = 20
    x_ = torch.stack((torch.tensor(dep_genes), 
                    torch.tensor([cl]*len(dep_genes))), dim=0)
                        
    cl_probs[:, i*len(dep_genes):(i+1)*len(dep_genes)] = x_
full_pred_data = heterodata_obj.clone()
full_pred_data['gene', 'dependency_of', 'cell'].edge_label_index = cl_probs

print(full_pred_data)
print(cl_probs)
#print(full_pred_data['gene', 'dependency_of', 'cell'].edge_label_index)

HeteroData(
  gene={
    node_id=[14034],
    names=[14034],
    x=[14034, 3438],
  },
  cell={
    node_id=[37],
    names=[37],
    x=[37, 1218],
  },
  (gene, interacts_with, gene)={ edge_index=[2, 278974] },
  (gene, dependency_of, cell)={
    edge_index=[2, 14099],
    edge_label_index=[2, 37000],
  },
  (gene, rev_interacts_with, gene)={ edge_index=[2, 278974] },
  (cell, rev_dependency_of, gene)={ edge_index=[2, 14099] }
)
tensor([[10241,     7,     9,  ..., 10237,  2046, 10239],
        [    0,     0,     0,  ...,    36,    36,    36]])
