In [1]:
import json
import os

import networkx as nx
import numpy as np
import pandas as pd
import torch

# Load data

In [2]:
# Gene-to-GO-term annotation table (columns ENSEMBL and GO); the first CSV
# column is used as the row index.
bp_db = pd.read_csv(filepath_or_buffer="../data/GOannotations_kept.csv", index_col=0)
bp_db.head(n=5)

Unnamed: 0,ENSEMBL,GO
0,ENSG00000000003,GO:0039532
1,ENSG00000000003,GO:0043123
2,ENSG00000000003,GO:1901223
3,ENSG00000000005,GO:0001886
4,ENSG00000000005,GO:0001937


In [26]:
# Count matrix: one row per sample ID (taken from the first CSV column),
# one column per gene identifier.
counts1 = pd.read_csv(filepath_or_buffer="../data/counts1.csv", index_col=0)
counts1.head(n=5)

Unnamed: 0,ENSG00000000003,ENSG00000000005,ENSG00000000419,ENSG00000000457,ENSG00000000460,ENSG00000000938,ENSG00000000971,ENSG00000001036,ENSG00000001084,ENSG00000001167,...,ENSGR0000167393,ENSGR0000169084,ENSGR0000169093,ENSGR0000178605,ENSGR0000182378,ENSGR0000185291,ENSGR0000198223,ENSGR0000214717,ENSGR0000223511,ENSGR0000223773
089357B,14,7,103,241,72,2057,30,60,207,367,...,1,0,0,0,0,0,0,0,0,0
089366A,11,2,194,511,110,3325,36,111,186,530,...,0,0,0,0,0,0,1,0,0,1
089412B,8,0,312,450,106,3751,45,160,325,653,...,0,0,0,0,0,0,1,0,0,0
089425B,9,0,135,496,133,2758,26,93,182,620,...,0,0,0,0,0,0,0,0,0,0
089687A,4,0,89,267,49,2181,24,75,122,263,...,0,0,0,0,0,0,1,0,0,0


In [4]:
# Per-sample phenotype table indexed by sample ID; the "diagnosis" column is
# dropped right after loading.
pheno1 = pd.read_csv(filepath_or_buffer="../data/pheno1.csv", index_col=0).drop(
    columns=["diagnosis"]
)
pheno1.head(n=5)

Unnamed: 0,age,sex,lithium,condition
089357B,18,F,0,Control
089366A,19,F,0,Control
089412B,23,F,0,Control
089425B,47,F,0,Control
089687A,52,F,0,Control


# Process the data

In [32]:
# Add the age, sex and lithium covariates of pheno1 to counts1.
# Sex is encoded explicitly on that one column (M -> 0, F -> 1; any other
# value becomes NaN). Chaining .replace() over every column leaves the result
# dtype to pandas' implicit downcasting of the object column, which newer
# pandas versions deprecate.
tmp_pheno1 = pheno1[["age", "sex", "lithium"]].assign(
    sex=lambda covariates: covariates["sex"].map({"M": 0, "F": 1})
)
# Index-on-index merge (inner join by default): only samples present in both
# tables are kept.
counts1_merge = pd.merge(counts1, tmp_pheno1, left_index=True, right_index=True)

# Z-score every column using the sample standard deviation (pandas default).
# A constant column has a standard deviation of 0, which would turn the whole
# column into NaN (0 / 0); dividing by 1 instead leaves it at 0 after centring.
counts1_merge = (counts1_merge - counts1_merge.mean()) / counts1_merge.std().replace(0, 1)
counts1_merge.head()

Unnamed: 0,ENSG00000000003,ENSG00000000005,ENSG00000000419,ENSG00000000457,ENSG00000000460,ENSG00000000938,ENSG00000000971,ENSG00000001036,ENSG00000001084,ENSG00000001167,...,ENSGR0000178605,ENSGR0000182378,ENSGR0000185291,ENSGR0000198223,ENSGR0000214717,ENSGR0000223511,ENSGR0000223773,age,sex,lithium
089357B,2.110648,4.704691,-0.899272,-1.436269,-1.13286,-1.111674,-0.8044,-1.458287,-0.62945,-1.0402,...,-0.152399,-0.453101,-0.212458,-0.502577,-0.45051,-0.222561,-0.226593,-2.043569,0.879916,-0.720677
089366A,1.348549,1.129061,-0.091576,0.088378,-0.357019,-0.170067,-0.695273,-0.660205,-0.785903,-0.267949,...,-0.152399,-0.453101,-0.212458,-0.146685,-0.45051,-0.222561,2.492525,-1.972198,0.879916,-0.720677
089412B,0.58645,-0.30119,0.955766,-0.256079,-0.438687,0.146277,-0.531583,0.10658,0.249663,0.314793,...,-0.152399,-0.453101,-0.212458,-0.146685,-0.45051,-0.222561,-0.226593,-1.686712,0.879916,-0.720677
089425B,0.840483,-0.30119,-0.615247,0.003676,0.112569,-0.591117,-0.877151,-0.941881,-0.815703,0.158447,...,-0.152399,-0.453101,-0.212458,-0.502577,-0.45051,-0.222561,-0.226593,0.026202,0.879916,-0.720677
089687A,-0.429682,-0.30119,-1.023533,-1.289451,-1.602448,-1.019593,-0.913527,-1.223557,-1.26271,-1.532924,...,-0.152399,-0.453101,-0.212458,-0.146685,-0.45051,-0.222561,-0.226593,0.383059,0.879916,-0.720677


In [33]:
# Boolean array with one entry per counts1 column: True where the column name
# appears in the ENSEMBL column of the annotation table.
mask_features = counts1.columns.isin(bp_db["ENSEMBL"].unique())
mask_features

array([ True,  True,  True, ..., False, False, False])

In [29]:
# Assemble the arrays for the GNN dataset: 'x' is the count matrix (one row
# per sample), 'y' is the condition label of each sample.
gnn_dataset = {}
gnn_dataset['x'] = counts1.to_numpy()
# Look the labels up by sample ID so that row i of 'y' belongs to row i of 'x'.
# counts1 and pheno1 are read from separate files, so identical row order is
# not guaranteed. Raises KeyError if a sample in counts1 has no row in pheno1.
gnn_dataset['y'] = pheno1.loc[counts1.index, 'condition'].to_numpy()

Unnamed: 0,ENSG00000000003,ENSG00000000005,ENSG00000000419,ENSG00000000457,ENSG00000000460,ENSG00000000938,ENSG00000000971,ENSG00000001036,ENSG00000001084,ENSG00000001167,...,ENSGR0000178605,ENSGR0000182378,ENSGR0000185291,ENSGR0000198223,ENSGR0000214717,ENSGR0000223511,ENSGR0000223773,age,sex,lithium
089357B,2.110648,4.704691,-0.899272,-1.436269,-1.13286,-1.111674,-0.8044,-1.458287,-0.62945,-1.0402,...,-0.152399,-0.453101,-0.212458,-0.502577,-0.45051,-0.222561,-0.226593,18,F,0
089366A,1.348549,1.129061,-0.091576,0.088378,-0.357019,-0.170067,-0.695273,-0.660205,-0.785903,-0.267949,...,-0.152399,-0.453101,-0.212458,-0.146685,-0.45051,-0.222561,2.492525,19,F,0
089412B,0.58645,-0.30119,0.955766,-0.256079,-0.438687,0.146277,-0.531583,0.10658,0.249663,0.314793,...,-0.152399,-0.453101,-0.212458,-0.146685,-0.45051,-0.222561,-0.226593,23,F,0
089425B,0.840483,-0.30119,-0.615247,0.003676,0.112569,-0.591117,-0.877151,-0.941881,-0.815703,0.158447,...,-0.152399,-0.453101,-0.212458,-0.502577,-0.45051,-0.222561,-0.226593,47,F,0
089687A,-0.429682,-0.30119,-1.023533,-1.289451,-1.602448,-1.019593,-0.913527,-1.223557,-1.26271,-1.532924,...,-0.152399,-0.453101,-0.212458,-0.146685,-0.45051,-0.222561,-0.226593,52,F,0


In [8]:
# Build the binary gene x GO-term connection matrix:
# matrix_connection[i, j] == 1 when gene list_genes[i] is annotated with GO
# term list_go[j] in bp_db, and 0 otherwise.
bp_db_genes = set(bp_db["ENSEMBL"])
# Keep the counts1 genes that DO have an annotation, in counts1 column order.
# (Filtering with "not in" left only unannotated genes, so no bp_db row could
# ever match and the matrix stayed all zeros.)
list_genes = [x for x in counts1.columns if x in bp_db_genes]
gp_go = bp_db.groupby("GO")
list_go = list(gp_go.groups.keys())

# Position of every gene / GO term along the matrix axes.
index_genes = {gene: idx for idx, gene in enumerate(list_genes)}
index_go = {go: idx for idx, go in enumerate(list_go)}

matrix_connection = torch.zeros((len(list_genes), len(list_go)), dtype=torch.float32)

# Only annotation rows whose gene and GO term both have a matrix position can
# be written; this drops genes that are absent from counts1.
annotations_kept = bp_db[bp_db["ENSEMBL"].isin(list_genes) & bp_db["GO"].isin(list_go)]
gene_positions = annotations_kept["ENSEMBL"].map(index_genes).to_numpy(dtype=np.int64)
go_positions = annotations_kept["GO"].map(index_go).to_numpy(dtype=np.int64)

# Set all (gene, GO) cells in one indexed assignment instead of a Python loop
# over rows; duplicate pairs are harmless because the same value is written.
matrix_connection[torch.as_tensor(gene_positions), torch.as_tensor(go_positions)] = 1

# Load intermediate data from pre-processing

In [18]:
# Read the graph stored in GML format and print its one-line summary.
graph = nx.read_gml(path="../data/bp_graph.gml")
print(str(graph))

Graph with 19790 nodes and 44166 edges


In [20]:
# Per-GO-term table with columns root, d+ and d-; the GO identifier (first CSV
# column) is used as the row index.
df_go_level = pd.read_csv(filepath_or_buffer="../data/go_to_level.csv", index_col=0)
df_go_level.head(n=5)

Unnamed: 0,root,d+,d-
GO:1904355,6,3,0
GO:0071243,4,2,3
GO:0075201,4,1,1
GO:0007026,5,5,0
GO:0150070,5,1,1


In [30]:
# Load the integer -> GO-term mapping. JSON object keys are always strings, so
# they are converted back to int while the dictionary is built.
with open("../data/map_int_go.txt", 'r') as fp:
    map_int_go = {int(idx): go for idx, go in json.load(fp).items()}
# Preview the first five (index, GO term) pairs.
list(map_int_go.items())[:5]

[(0, 'GO:0000002'),
 (1, 'GO:0000012'),
 (2, 'GO:0000017'),
 (3, 'GO:0000018'),
 (4, 'GO:0000019')]

In [27]:
# Reverse lookup: GO term -> integer index (inverse of map_int_go).
map_go_int = dict(zip(map_int_go.values(), map_int_go.keys()))
# Preview the first five (GO term, index) pairs.
list(map_go_int.items())[:5]

[('GO:0000002', 0),
 ('GO:0000012', 1),
 ('GO:0000017', 2),
 ('GO:0000018', 3),
 ('GO:0000019', 4)]