In [1]:
import networkx as nx
import scanpy
import json
import re

In [2]:
import json
from collections import defaultdict

In [3]:
import numpy as np

In [4]:
from pprint import pprint

In [5]:
import pickle
from collections import Counter, defaultdict

In [6]:
chromosomes = {'chr1', 
               'chr2', 
               'chr3', 
               'chr4', 
               'chr5', 
               'chr6', 
               'chr7', 
               'chr8', 
               'chr9', 
               'chr10', 
               'chr11', 
               'chr12', 
               'chr13', 
               'chr14', 
               'chr15', 
               'chr16', 
               'chr17', 
               'chr18', 
               'chr19',
               'chr20',
               'chr21',
               'chr22',
               'chrX', 
              }

## Creating graph from STRING protein-protein interactions

In [7]:
string_filename = 'input/9606.protein.links.detailed.v11.5.txt'

In [8]:
len_string_file= !wc -l {string_filename}

len_string_file = int(len_string_file[0].split(' ')[0])
len_string_file

11938499

In [9]:
protein_coocurrence = nx.DiGraph()
protein_database_association = nx.DiGraph()
protein_text_mining = nx.DiGraph()
protein_experimental_association = nx.DiGraph()

string_file = open(string_filename)

header = next(string_file).strip().split(' ')
# list(enumerate(header))

i,pct = 0,0
stops = 10
step = int(len_string_file/stops)
for line in string_file:
    if i==step:
        i=0
        pct+=int(1/stops*100)
        print(f'{pct}%',flush=True)
    line=line.strip().split(' ')
    # strip off the 9606 id that STRING adds (as an id for homo sapiens)
    # to the beginning of the protein ensembl id
    p1 = ('protein',line[0][5:])
    p2 = ('protein',line[1][5:])
    labeled = zip(header, line)
    interaction_attributes = {x:int(y) for x,y in list(labeled)[2:]}
    if interaction_attributes['coexpression'] > 0:
        protein_coocurrence.add_edge(p1, p2, 
                                      weight=interaction_attributes['coexpression']/1000)
#     if interaction_attributes['database'] > 0:
#         protein_database_association.add_edge(p1, p2, 
#                                       weight=interaction_attributes['database']/1000)
#     if interaction_attributes['textmining'] > 0:
#         protein_text_mining.add_edge(p1, p2, 
#                                       weight=interaction_attributes['textmining']/1000)
    if interaction_attributes['experimental'] > 0:
        protein_experimental_association.add_edge(p1, p2, 
                                      weight=interaction_attributes['experimental']/1000)
    i+=1

10%
20%
30%
40%
50%
60%
70%
80%
90%
100%


In [10]:
print(protein_coocurrence.number_of_nodes())
print(protein_coocurrence.number_of_edges())

19026
6525628


In [11]:
print(protein_experimental_association.number_of_nodes())
print(protein_experimental_association.number_of_edges())

0
0


In [12]:
print(protein_text_mining.number_of_nodes())
print(protein_text_mining.number_of_edges())

0
0


In [13]:
print(protein_database_association.number_of_nodes())
print(protein_database_association.number_of_edges())

0
0


In [14]:
len(set(protein_coocurrence.nodes) | set(protein_experimental_association.nodes) 
    | set(protein_text_mining.nodes) | set(protein_database_association.nodes))

19026

In [15]:
pickle.dump(protein_coocurrence, open('input/protein_interaction_graph.pickle','wb'))

In [16]:
# protein_coocurrence = pickle.load(open('input/protein_interaction_graph.pickle','rb'))

## Ensembl data

In [17]:
ensembl_data = json.load(open('input/homo_sapiens.json'))

In [18]:
len(ensembl_data['genes'])

67990

In [19]:
gene_protein = defaultdict(list)
ensembl_regions = {}
for item in ensembl_data['genes']:
#     pprint(item)
    if 'transcripts' in item:
        for transcript in item['transcripts']:
            if 'translations' in transcript:
                for translation in transcript['translations']:
                    gene = ('gene',item['id'])
                    protein = ('protein', translation['id'])
                    gene_protein[gene].append(protein)
    if 'start' in item and 'end' in item and 'seq_region_name' in item:
        gene = ('gene',item['id'])
        chromosome = 'chr'+item['seq_region_name']
        ensembl_regions[gene] = (chromosome, item['start'], item['end'])
#     print('-'*80)
gene_protein = dict(gene_protein)

In [20]:
len(gene_protein)

24337

In [21]:
len(ensembl_regions)

67990

In [22]:
all_hgnc_proteins = set([item for sublist in gene_protein.values() for item in sublist])

In [23]:
len(all_hgnc_proteins)

117680

In [24]:
len(all_hgnc_proteins - set(protein_coocurrence.nodes()))

99263

In [25]:
len(set(protein_coocurrence.nodes()) - set(all_hgnc_proteins))

609

## Gene names in data

In [26]:
gene_data = scanpy.read_h5ad('output/datasets/predict_modality/openproblems_bmmc_cite_phase1_mod2/openproblems_bmmc_cite_phase1_mod2.censor_dataset.output_train_mod2.h5ad')

In [27]:
genes_in_data = gene_data.var['gene_ids'].to_list()

In [28]:
len(genes_in_data)

13953

In [29]:
not_in_ensembl = []
for gene in genes_in_data:
    gene = ('gene',gene)
    if gene not in ensembl_regions:
        not_in_ensembl.append(gene)
print(len(not_in_ensembl))

23


In [30]:
print(len(not_in_ensembl)/len(genes_in_data))

0.001648391027019279


## HGNC data

In [31]:
hgnc = json.load(open('input/hgnc_complete_set.json'))['response']['docs']

In [32]:
gene_ensembl_id = {}

no_ensembl = list()
for item in hgnc:
    symbols = []
    if 'symbol' in item:
        symbol = item['symbol']
        symbols.append(symbol.upper())
    if 'alias_symbol' in item:
        alias_symbols = item['alias_symbol']
        for symbol in alias_symbols:
            symbols.append(symbol.upper())
    
    if len(symbols)>0 and 'ensembl_gene_id' in item:
        ensembl_id = item['ensembl_gene_id']
        gene = ('gene', ensembl_id)
        for symbol in symbols:
            if symbol in gene_ensembl_id:
                gene_ensembl_id[symbol].append(gene)
            else:
                gene_ensembl_id[symbol] = [gene]
    else:
        no_ensembl.append(item)
print(len(no_ensembl))

no_symbols = []    

2913


In [33]:
ids_per_symbol = Counter()
for symbol, ids in gene_ensembl_id.items():
    ids_per_symbol[len(ids)]+=1
ids_per_symbol.items()

dict_items([(1, 77472), (7, 3), (3, 179), (2, 1370), (4, 41), (6, 10), (5, 15), (8, 1), (9, 1)])

In [34]:
# not_in_hgnc = list()
# all_genes_in_hgnc = set(gene_ensembl_id.values())
# for gene in genes_in_data:
#     if gene not in all_genes_in_hgnc:
#         not_in_hgnc.append(gene)
# print(len(not_in_hgnc), len(genes_in_data))

In [35]:
protein_data = scanpy.read_h5ad('output/datasets/predict_modality/openproblems_bmmc_cite_phase1_mod2/openproblems_bmmc_cite_phase1_mod2.censor_dataset.output_train_mod1.h5ad')

protein_names = protein_data.var.index.to_list()

In [36]:
protein_not_in_hgnc = []
for protein in protein_names:
    if protein.upper() not in gene_ensembl_id:
        protein_not_in_hgnc.append(protein.upper())
print(len(protein_not_in_hgnc))

25


In [37]:
manually_mapped = {
'CD3': ['ENSG00000198851', 'ENSG00000160654', 'ENSG00000167286', 'ENSG00000167286'],
'CD8': ['ENSG00000172116', 'ENSG00000172116', 'ENSG00000172116', 'ENSG00000172116', 'ENSG00000172116', 'ENSG00000153563', 'ENSG00000153563', 'ENSG00000153563', 'ENSG00000254126'],
'HLA-A-B-C': ['ENSG00000204525', 'ENSG00000234745', 'ENSG00000204525', ''],
'CD45RA': ['ENSG00000081237', 'ENSG00000081237', 'ENSG00000262418', 'ENSG00000262418'],
'CD45RO': ['ENSG00000081237', 'ENSG00000081237', 'ENSG00000262418', 'ENSG00000262418'],
'CD20': ['ENSG00000156738', 'ENSG00000156738', 'ENSG00000156738', 'ENSG00000156738', 'ENSG00000156738'],
'PODOPLANIN': ['ENSG00000162493', 'ENSG00000162493', 'ENSG00000162493', 'ENSG00000162493', 'ENSG00000162493', 'ENSG00000162493', 'ENSG00000162493'],
'IGM': ['ENSG00000211899', 'ENSG00000282657', 'ENSG00000211899'],
'HLA-DR': ['ENSG00000227993', 'ENSG00000206308', 'ENSG00000204287', 'ENSG00000234794', 'ENSG00000230726', 'ENSG00000228987', 'ENSG00000226260', 'ENSG00000196126', 'ENSG00000196101', 'ENSG00000231679'],
'CD57': ['ENSG00000109956', 'ENSG00000109956', 'ENSG00000109956'],
'CD39': ['ENSG00000138185', 'ENSG00000138185', 'ENSG00000138185', 'ENSG00000138185'],
'CD11A': ['ENSG00000005844', 'ENSG00000005844'],
'INTEGRINB7': ['ENSG00000139626'],
'TCR': ['ENSG00000146399'],
'FCERIA': ['ENSG00000179639'],
'CD13': ['ENSG00000166825', 'ENSG00000166825', 'ENSG00000166825', 'ENSG00000166825'],
'IGD': ['ENSG00000211898'],
'CD18': ['ENSG00000160255', 'ENSG00000160255', 'ENSG00000160255', 'ENSG00000160255', 'ENSG00000160255'],
'CD45': ['ENSG00000081237', 'ENSG00000081237', 'ENSG00000262418', 'ENSG00000262418'],
'CD26': ['ENSG00000197635'],
'CD158': ['ENSG00000189013', 'ENSG00000189013', 'ENSG00000189013', 'ENSG00000274232', 'ENSG00000275456', 'ENSG00000275317', 'ENSG00000277362', 'ENSG00000277750', 'ENSG00000273575', 'ENSG00000274232', 'ENSG00000274232', 'ENSG00000278271', 'ENSG00000278430', 'ENSG00000276044', 'ENSG00000278074', 'ENSG00000274609', 'ENSG00000276779', 'ENSG00000274955', 'ENSG00000277964', 'ENSG00000275317', 'ENSG00000278430', 'ENSG00000275699', 'ENSG00000284460', 'ENSG00000284365', 'ENSG00000284365', 'ENSG00000284013', 'ENSG00000283961', 'ENSG00000283869', 'ENSG00000283961', 'ENSG00000284013', 'ENSG00000284509', 'ENSG00000283961', 'ENSG00000284013', 'ENSG00000284365'],
'TCRVA7.2': ['ENSG00000256553'],
'TCRVD2': ['ENSG00000211821'],
'CD158B': ['ENSG00000243772', 'ENSG00000274830', 'ENSG00000276218', 'ENSG00000277484', 'ENSG00000276590', 'ENSG00000278327', 'ENSG00000277554', 'ENSG00000275658', 'ENSG00000273887', 'ENSG00000274952', 'ENSG00000274410', 'ENSG00000275008', 'ENSG00000274108', 'ENSG00000275623', 'ENSG00000274402', 'ENSG00000277317', 'ENSG00000277924', 'ENSG00000276459', 'ENSG00000284333', 'ENSG00000283708', 'ENSG00000283702', 'ENSG00000283790', 'ENSG00000284504', 'ENSG00000284510'],
'CD94': ['ENSG00000134539', 'ENSG00000134539', 'ENSG00000134539'],
}
for protein_name in manually_mapped:
    gene_ensembl_id[protein_name] = [('gene',ensembl_id) for ensembl_id in manually_mapped[protein_name]]

In [38]:
protein_not_in_hgnc = []
for protein in protein_names:
    if protein.upper() not in gene_ensembl_id:
        protein_not_in_hgnc.append(protein.upper())
print(len(protein_not_in_hgnc))

0


## Transcription Factor Data prep

In [39]:
tf_file = open('input/tf-target-infomation.txt')
header = next(tf_file)
tf_target = nx.DiGraph()
for line in tf_file:
    tf, target, tissue = line.strip().split('\t')
    
    if tf in gene_ensembl_id and target in gene_ensembl_id:
        tf_genes = gene_ensembl_id[tf]
        target_genes = gene_ensembl_id[target]
        for tf_gene in tf_genes:
            if tf_gene in gene_protein:
                tf_proteins = gene_protein[tf_gene]
                for tf_protein in tf_proteins:
                    for target_gene in target_genes:
                        tf_target.add_edge(tf_protein, target_gene, 
                                           interaction_type='tf_Unknown_protein_gene', weight=1.0)
#                         tf_target.add_edge(target_gene, tf_protein,
#                                            interaction_type='tf_Unknown_gene_protein', weight=1.0)

In [40]:
print(tf_target.number_of_nodes())
print(tf_target.number_of_edges())

27297
7826011


In [41]:
tf_target_proteins = []
for node_type,node_id in tf_target.nodes:    
    if node_type=='protein':
        tf_target_proteins.append(node_id)
print(len(set(tf_target_proteins)))

3817


In [42]:
tf_file = open('input/trrust_rawdata.human.tsv')

trrust = nx.DiGraph()
for line in tf_file:
    tf, target, interaction_type, _ = line.strip().split('\t')
    if tf in gene_ensembl_id and target in gene_ensembl_id:
        tf_genes = gene_ensembl_id[tf]
        target_genes = gene_ensembl_id[target]
        for tf_gene in tf_genes:
            if tf_gene in gene_protein:
                tf_proteins = gene_protein[tf_gene]
                for tf_protein in tf_proteins:
                    for target_gene in target_genes:
                        trrust.add_edge(tf_protein, target_gene, 
                                           interaction_type=f'trrust_{interaction_type}_protein_gene', weight=1.0)
#                         trrust.add_edge(target_gene, tf_protein,
#                                            interaction_type=f'trrust_{interaction_type}_gene_protein', weight=1.0)


In [43]:
print(trrust.number_of_nodes())
print(trrust.number_of_edges())

8366
84470


In [44]:
len(set(trrust.nodes) - set(tf_target.nodes))

3227

In [45]:
len(set(tf_target.nodes) - set(trrust.nodes))

22158

In [46]:
len(set(tf_target.nodes) & set(trrust.nodes))

5139

In [47]:
tf_graph = nx.compose(tf_target, trrust)

## TAD data preparation

In [48]:
tad_files = !find input/TAD_annotations/TADs/ -type f -name '*10kb*'

In [49]:
tad_regions = {ch:[] for ch in chromosomes}
for filename in tad_files:
    file = open(filename)
    for line in file:
        chromosome, start, end = line.strip().split(' ')
        tad_regions[chromosome].append((int(start),int(end)))

# sort tad starts along with the index to the end
for chromosome in tad_regions:
    tad_regions[chromosome] = sorted(tad_regions[chromosome])

In [50]:
tad_regions['chr1'][:5]

[(50000, 1240000),
 (535000, 1845000),
 (565000, 1605000),
 (565000, 1645000),
 (565000, 1675000)]

In [51]:
chromosome_gene_region = {chrm:[] for chrm in chromosomes}
for gene, region in ensembl_regions.items():
    chromosome, start, end = region
    chromosome = chromosome
    if chromosome in chromosomes:
        chromosome_gene_region[chromosome].append((int(start), int(end), gene))

In [52]:
for chromosome in chromosome_gene_region:
    chromosome_gene_region[chromosome] = sorted(chromosome_gene_region[chromosome])

In [53]:
tad_gene_linkages = defaultdict(list)

for chromosome in chromosomes:
    tad_idx = 0
    gene_idx = 0
    tads = tad_regions[chromosome]
    genes = chromosome_gene_region[chromosome]
        
    while True:
        tad_start, tad_end = tads[tad_idx]
        gene_start, gene_end, gene = genes[gene_idx]
        tad_term = ('tad', f'{chromosome}-{tad_start}-{tad_end}')
        if (gene_start <= tad_end and gene_end >= tad_start):
            tad_gene_linkages[tad_term].append(gene)
            gene_idx = gene_idx+1
            if gene_idx==len(genes):
                break
        elif gene_start <= tad_start:
            gene_idx = gene_idx+1
            if gene_idx==len(genes):
                break
        elif gene_start >= tad_end:
            tad_idx = tad_idx+1
            if tad_idx == len(tads):
                break

In [54]:
tad_gene_counts = Counter()
for tad_id in tad_gene_linkages:
    tad_gene_counts[len(tad_gene_linkages[tad_id])]+=1
    

In [55]:
sum([gene_count*tad_count for gene_count, tad_count in tad_gene_counts.items()])

59732

## Get all ATAC regions

In [56]:
atac_data = scanpy.read_h5ad('output/datasets/predict_modality/openproblems_bmmc_multiome_phase1_mod2/openproblems_bmmc_multiome_phase1_mod2.censor_dataset.output_train_mod1.h5ad')

In [57]:
len(atac_data.var)

116490

In [58]:
atac_parsed = [tuple(x.split('-')) for x in atac_data.var.index.to_list()]
atac_regions = {chrm:[] for chrm in chromosomes}
for chromosome, start, end in atac_parsed:
    if chromosome in atac_regions:
        atac_regions[chromosome].append((int(start),int(end)))

for chromosome in atac_regions:
    atac_regions[chromosome] = sorted(atac_regions[chromosome])

In [59]:
len(atac_parsed)

116490

## ATAC to TAD linkage

In [60]:
tad_atac_linkages = defaultdict(list)

for chromosome in chromosomes:
    atac_idx = 0
    tad_idx = 0
    atacs = atac_regions[chromosome]
    tads = tad_regions[chromosome]
        
    while True:
        atac_start, atac_end = atacs[atac_idx]
        tad_start, tad_end = tads[tad_idx]
        tad_term = ('tad', f'{chromosome}-{tad_start}-{tad_end}')
        atac_term = ('atac_region', f'{chromosome}-{atac_start}-{atac_end}')

        if (atac_start <= tad_end and atac_end >= tad_start):
            tad_atac_linkages[tad_term].append(atac_term)
            atac_idx = atac_idx+1
            if atac_idx==len(atacs):
                break
        elif atac_start <= tad_start:
            atac_idx = atac_idx+1
            if atac_idx==len(atacs):
                break
        elif atac_start >= tad_end:
            tad_idx = tad_idx+1
            if tad_idx == len(tads):
                break


In [61]:
tad_atac_counts = Counter()
for tad_id in tad_atac_linkages:
    tad_atac_counts[len(tad_atac_linkages[tad_id])]+=1
    

In [62]:
sum([number*count for number,count in tad_atac_counts.items()])

116232

## Gene to ATAC linkage

In [63]:
atac_gene_linkages = defaultdict(list)

for chromosome in chromosomes:
    atac_idx = 0
    gene_idx = 0
    atacs = atac_regions[chromosome]
    genes = chromosome_gene_region[chromosome]
        
    while True:
        atac_start, atac_end = atacs[atac_idx]
        gene_start, gene_end, gene = genes[gene_idx]
        atac_term = ('atac_region', f'{chromosome}-{atac_start}-{atac_end}')
        if (gene_start <= atac_end and gene_end >= atac_start):
            atac_gene_linkages[atac_term].append(gene)
            gene_idx = gene_idx+1
            if gene_idx==len(genes):
                break
        elif gene_start <= atac_start:
            gene_idx = gene_idx+1
            if gene_idx==len(genes):
                break
        elif gene_start >= atac_end:
            atac_idx = atac_idx+1
            if atac_idx == len(atacs):
                break
atac_gene_linkages = dict(atac_gene_linkages)

In [64]:
atac_gene_counts = Counter()
for atac_term in atac_gene_linkages:
    atac_gene_counts[len(atac_gene_linkages[atac_term])]+=1
    

In [65]:
print(sum([number*count for number,count in atac_gene_counts.items()]))
print(len(set([item for linked in atac_gene_linkages for item in linked])))

27233
23885


## Enhancer ATAC overlap 

In [66]:
enhancer_file = open('input/hacer_enhancers.txt')

enhancer_regions = {chromosome:[] 
                    for chromosome in chromosomes}
enhancer_genes = defaultdict(list)

for line in enhancer_file:
    fields = line.strip().split('\t')
    enhancer_id = fields[0]
    chromosome = fields[1]
    start = int(fields[2])
    end = int(fields[3])
    fantom5_gene = fields[6]
    fiftykb_gene = fields[7]
    fourd_gene = fields[8]
    closest_gene = fields[12]

    genes = ','.join([fantom5_gene, fiftykb_gene, 
                      fourd_gene, closest_gene])
    
    if chromosome in enhancer_regions:
        enhancer_regions[chromosome].append((start,end))
    
    enhancer_term = ('enhancer', f'{chromosome}-{start}-{end}')

    for gene in re.split('[,;]', genes):
        if gene != 'NA':
            enhancer_genes[enhancer_term].append(gene)

In [67]:
enhancer_regions = {chromosome: sorted(regions) 
                    for chromosome, regions in enhancer_regions.items()}

In [68]:
enhancer_atac_graph = nx.DiGraph()
atac_enhancer_linkages = defaultdict(list)

for chromosome in chromosomes:
    atac_idx = 0
    enhancer_idx = 0
    atacs = atac_regions[chromosome]
    enhancers = enhancer_regions[chromosome]
        
    while True:
        atac_start, atac_end = atacs[atac_idx]
        enhancer_start, enhancer_end = enhancers[enhancer_idx]
        atac_term = ('atac_region', f'{chromosome}-{atac_start}-{atac_end}')
        enhancer_term = ('enhancer', f'{chromosome}-{enhancer_start}-{enhancer_end}')
        if (enhancer_start <= atac_end and enhancer_end >= atac_start):
            atac_enhancer_linkages[atac_term].append(enhancer_term)
            enhancer_idx = enhancer_idx+1
            if enhancer_idx==len(enhancers):
                break
        elif enhancer_start <= atac_start:
            enhancer_idx = enhancer_idx+1
            if enhancer_idx==len(enhancers):
                break
        elif enhancer_start >= atac_end:
            atac_idx = atac_idx+1
            if atac_idx == len(atacs):
                break
                
atac_enhancer_linkages = dict(atac_enhancer_linkages)    

In [69]:
atac_enhancer_counts = Counter()
for atac, enhancers in atac_enhancer_linkages.items():
    atac_enhancer_counts[len(enhancers)]+=1
atac_enhancer_counts.most_common(10)

# Number of unique atac regions
print(len(set([item for linked in atac_enhancer_linkages for item in linked])))

15414


## ATAC - k nearest genes

In [70]:
from sklearn.neighbors import KDTree

In [71]:
k=6

atac_neighbor_genes = nx.DiGraph()

for chromosome, atacs in atac_regions.items():
    starts, ends, genes = zip(*chromosome_gene_region[chromosome])
    start_tree = KDTree(np.array(starts).reshape(-1,1))
    end_tree = KDTree(np.array(ends).reshape(-1,1))
    atac_starts, atac_ends = zip(*atacs)
    atac_starts = np.array(atac_starts).reshape(-1,1)
    atac_ends = np.array(atac_ends).reshape(-1,1)
    end_dists, end_idxs = end_tree.query(atac_starts,k=k)
    start_dists, start_idxs = start_tree.query(atac_ends,k=k)
    dists = np.concatenate((start_dists, end_dists),axis=1)
    idxs = np.concatenate((start_idxs, end_idxs), axis=1)
    sorted_dists = np.argsort(dists, axis=1)
    closest_idxs = np.take_along_axis(idxs, sorted_dists, axis=1)
    closest_dists = np.take_along_axis(dists, sorted_dists, axis=1)
    for i,row in enumerate(closest_idxs):
        row = np.unique(row)[:k]
        for j,idx in enumerate(row):
            atac = atacs[i]
            atac_term = ('atac_region', f'{chromosome}-{atac[0]}-{atac[1]}')
            gene = genes[idx]
            dist = closest_dists[i,j]
            atac_neighbor_genes.add_edge(atac_term, gene, distance=dist)


### Check that each node has the right number (k) neighbors

In [72]:
c=Counter()
for chromosome, atacs in atac_regions.items():
    for atac in atacs:
        atac_term = ('atac_region', f'{chromosome}-{atac[0]}-{atac[1]}')
        c[len(atac_neighbor_genes[atac_term])] +=1
c.most_common()

[(6, 116465)]

## Link gene products with structure

In [73]:
gene_protein_graph = nx.DiGraph()

for gene, proteins in gene_protein.items():
    for protein in proteins:
        gene_protein_graph.add_edge(gene, protein, interaction_type='gene_protein', weight=1.0)
#         gene_protein_graph.add_edge(protein, gene, interaction_type='protein_gene', weight=1.0)

In [74]:
pickle.dump(gene_protein, open('input/gene_protein_association.pickle','wb'))

In [75]:
print(gene_protein_graph.number_of_nodes())
print(gene_protein_graph.number_of_edges())

142017
117680


## Link Protein Names with ID's

In [76]:
protein_name_proteins = dict()
for protein_name, genes in gene_ensembl_id.items():
    protein_name_proteins[protein_name] = []
    for gene in genes:
        if gene in gene_protein:
            proteins = gene_protein[gene]
            protein_name_proteins[protein_name] += proteins

In [77]:
protein_name_protein_graph = nx.DiGraph()
for protein_name, protein_ids in protein_name_proteins.items():
    for protein_id in protein_ids:
        protein_name_protein_graph.add_edge(protein_id, ('protein_name',protein_name))

In [78]:
print(protein_name_protein_graph.number_of_nodes())
print(protein_name_protein_graph.number_of_edges())

158258
317549


## Count Intersections for quality control

Make sure that species names overlap between graphs

In [79]:
print(protein_coocurrence.number_of_nodes())
print(protein_coocurrence.number_of_edges())

19026
6525628


In [80]:
len(set(protein_coocurrence.nodes) & set(gene_protein_graph.nodes))

18417

In [81]:
# len(set(tad_genes_graph.nodes) & set(tad_atac_graph.nodes))

In [82]:
# len(set(atac_gene_graph.nodes) & set(tad_atac_graph.nodes))

In [83]:
len(set(tf_target.nodes) & set(protein_coocurrence.nodes))

487

In [84]:
len(set([node for node_type,node in gene_protein_graph.nodes if node_type=='protein']) & 
    set([node for node_type,node in tf_graph.nodes if node_type=='protein']))

6768

In [85]:
len(set([node for node_type,node in tf_graph.nodes if node_type=='protein']))

6768

## Relabel proteins in graphs to unified synonyms 

proteins associated with genes - from HGNC most complete

proteins in data - gene association

proteins in co-occurrence - some not in data

proteins in transcription factors - some not in data

In [86]:
protein_gene = dict()
for gene,proteins in gene_protein.items():
    for protein in proteins:
        protein_gene[protein] = gene

In [87]:
synonym_mapping = dict()
# not_mapped
for protein in protein_gene:
    gene = protein_gene[protein]
    synonym_mapping[protein] = gene_protein[gene][0]
print(len(synonym_mapping))

117680


In [88]:
protein_coocurrence = nx.relabel.relabel_nodes(protein_coocurrence, synonym_mapping)

In [89]:
tf_graph = nx.relabel.relabel_nodes(tf_graph, synonym_mapping)

In [90]:
protein_name_protein_graph = nx.relabel.relabel_nodes(protein_name_protein_graph, synonym_mapping)

In [91]:
gene_protein_graph = nx.relabel.relabel_nodes(gene_protein_graph, synonym_mapping)

In [92]:
print(len(gene_protein_graph.nodes))
print(len(gene_protein_graph.edges))

48674
24337


In [93]:
print(len(protein_name_protein_graph.nodes))
print(len(protein_name_protein_graph.edges))

72821
54792


In [94]:
print(len(tf_graph.nodes))
print(len(tf_graph.edges))

24729
983464


In [95]:
len(set(tf_graph.nodes) & set(synonym_mapping.keys()))

973

In [96]:
len(set([node for ntype, node in tf_graph.nodes if ntype=='protein']))

973

In [97]:
len(set(tf_graph.nodes) & set(protein_coocurrence.nodes))

946

In [98]:
len(set(gene_protein_graph.nodes) & set(protein_coocurrence.nodes))

18380

In [99]:
len(set([node for node_type,node in gene_protein_graph.nodes if node_type=='protein']) & 
    set([node for node_type,node in tf_graph.nodes if node_type=='protein']))

973

In [100]:
len(set([node for node_type,node in tf_graph.nodes if node_type=='protein']))

973

## Match genes to proteins

Genes in data

Genes corresponding to proteins in graph

In [101]:
gene_data1 = scanpy.read_h5ad('output/datasets/predict_modality/' + 
                              'openproblems_bmmc_cite_phase1_mod2/'+
                              'openproblems_bmmc_cite_phase1_mod2.censor_dataset.output_train_mod2.h5ad')

In [102]:
gene_data2 = scanpy.read_h5ad('output/datasets/predict_modality/' + 
                              'openproblems_bmmc_multiome_phase1_rna' + 
                              '/openproblems_bmmc_multiome_phase1_rna.censor_dataset.output_train_mod1.h5ad')

In [103]:
genes1 = set(gene_data1.var['gene_ids'])

In [104]:
genes2 = set(gene_data2.var['gene_ids'])

In [105]:
len(genes1)

13953

In [106]:
len(genes2)

13431

In [107]:
all_genes_in_data = set([('gene', gene) for gene in (genes1 | genes2)])

In [108]:
len(all_genes_in_data)

15325

In [109]:
proteins_in_data = protein_data.var.index.to_list()

In [110]:
proteins_in_data_to_genes = set()
for protein_name in proteins_in_data:
    proteins = protein_name_proteins[protein_name.upper()]
    for protein in proteins:
        proteins_in_data_to_genes.add(protein_gene[protein])
#     synonym_mapping[] 

In [111]:
len(proteins_in_data_to_genes)

195

In [112]:
all_proteins_from_external = set(tf_graph.nodes) |  set(protein_coocurrence.nodes) | set(proteins_in_data_to_genes)

In [113]:
all_proteins_from_external = [(ntype, node) for ntype,node in all_proteins_from_external if ntype=='protein']

In [114]:
genes_from_external_proteins = set()
unmapped_external_proteins = list()
for protein in all_proteins_from_external:
    if protein in synonym_mapping:
        protein = synonym_mapping[protein]
        genes_from_external_proteins.add(protein_gene[protein])
    else:
        unmapped_external_proteins.append(protein)

In [115]:
len(genes_from_external_proteins)

18407

In [116]:
len(set(unmapped_external_proteins) & set(protein_coocurrence.nodes))

609

In [117]:
all_genes = genes_from_external_proteins | proteins_in_data_to_genes 

In [118]:
len(all_genes)

18470

## Enhancer - genes Graph

In [119]:
enhancer_gene_graph = nx.DiGraph()
for enhancer, genes in enhancer_genes.items():
    for gene in genes:
        if gene in gene_ensembl_id:
            ensembl_ids = gene_ensembl_id[gene]
            for ensembl_id in ensembl_ids:
                enhancer_gene_graph.add_edge(enhancer, ensembl_id)


In [120]:
print(len(enhancer_gene_graph.nodes))
print(len(enhancer_gene_graph.edges))

209170
494411


## Enhancer - ATAC graph

In [121]:
enhancer_atac_graph = nx.DiGraph()
for atac,enhancers in atac_enhancer_linkages.items():
    for enhancer in enhancers:
        enhancer_atac_graph.add_edge(enhancer, atac)

In [122]:
print(len(enhancer_atac_graph.nodes))
print(len(enhancer_atac_graph.edges))

47363
31950


## Graph with Transcripts/TADs/ATAC

In [123]:
from itertools import combinations

In [124]:
atac_gene_graph = nx.DiGraph()
# TODO should we add links from all the chromosomes
# to the genes? 
for atac, genes in atac_gene_linkages.items():
    for gene in genes:
        atac_gene_graph.add_edge(atac, gene, interaction_type='gene_atac', weight=1.0)
#             gene_structure_graph.add_edge(chromosome, gene)


In [125]:
tad_genes_graph = nx.DiGraph()
for tad, genes in tad_gene_linkages.items():
    for gene in genes:
        tad_genes_graph.add_edge(tad, gene, interaction_type='tad_gene', weight=1.0)
#             tad_genes_graph.add_edge(gene, tad, interaction_type='gene_tad', weight=1.0)


In [126]:
tad_atac_graph = nx.DiGraph()
for tad, atacs in tad_atac_linkages.items():
    for atac in atacs:
        tad_atac_graph.add_edge(tad, atac, interaction_type='tad_atac', weight=1.0)
#             tad_atac_graph.add_edge(atac_term, tad, interaction_type='atac_tad', weight=1.0)


In [127]:
print(Counter([ntype for ntype,node in tad_genes_graph.nodes]).most_common())
print(tad_genes_graph.number_of_nodes())
print(tad_genes_graph.number_of_edges())

[('gene', 59732), ('tad', 2611)]
62343
59732


In [128]:
print(Counter([ntype for ntype,node in tad_atac_graph.nodes]).most_common())
print(tad_atac_graph.number_of_nodes())
print(tad_atac_graph.number_of_edges())

[('atac_region', 116232), ('tad', 2497)]
118729
116232


In [129]:
print(Counter([ntype for ntype,node in atac_gene_graph.nodes]).most_common())
print(atac_gene_graph.number_of_nodes())
print(atac_gene_graph.number_of_edges())

[('gene', 27233), ('atac_region', 23884)]
51117
27233


## Convert to PyTorch Geometric (pyg) graph representation

In [130]:
# import scipy
import torch_geometric as pyg
import torch

In [131]:
all_graphs = [tad_atac_graph,
              gene_protein_graph, 
              atac_gene_graph, 
              protein_coocurrence, 
              tf_graph,
              tad_genes_graph,
              protein_name_protein_graph,
              enhancer_atac_graph,
              enhancer_gene_graph,
              atac_neighbor_genes
             ]

In [132]:
node_lens = Counter()
g = set()
for graph in [tad_atac_graph, gene_protein_graph, atac_gene_graph, tf_graph,tad_genes_graph,enhancer_gene_graph,atac_neighbor_genes]:
    for node in graph.nodes:
        if node[0]=='gene':
            g.add(node)
print(len(g))

63995


In [133]:
len(all_genes)

18470

In [134]:
for graph in all_graphs:
    for node in list(graph.nodes):
        node_type, name = node 
        if node not in all_genes and node_type == 'gene':
            graph.remove_node(node)

In [135]:
combined_graph = nx.Graph()
for graph in all_graphs:
    combined_graph = nx.compose(combined_graph, graph.to_undirected())

In [136]:
ccs = list(nx.components.connected_components(combined_graph))

In [137]:
len_ccs = Counter()
for cc in ccs:
    len_ccs[len(cc)]+=1
print('cc size, count')
for sz, cnt in len_ccs.most_common():
    print(f'{sz:8d} {cnt:<7d} {sz*cnt}')

cc size, count
       1 10691   10691
       2 840     1680
       3 367     1101
       4 183     732
       5 91      455
       6 33      198
       7 13      91
       8 11      88
       9 4       36
  392443 1       392443
      13 1       13


In [138]:
gene_degree = Counter()
nodes=set()
for node in list(combined_graph.nodes):
    node_type, name = node
    nodes.add(node)
    if node_type == 'gene':
        gene_degree[combined_graph.degree[node]]+=1

In [139]:
sorted(gene_degree.most_common())[:10]

[(1, 73),
 (2, 295),
 (3, 177),
 (4, 167),
 (5, 157),
 (6, 151),
 (7, 176),
 (8, 147),
 (9, 166),
 (10, 177)]

In [140]:
nodes = list(nodes)

In [141]:
nodes_by_type = defaultdict(set)
for node_type, node_id in nodes:
    nodes_by_type[node_type].add(node_id)
nodes_by_type = {node_type: {node_id:i for i, node_id in enumerate(node_ids)} for node_type, node_ids in nodes_by_type.items()}

In [142]:
def to_pyg_from_networkx(graph, use_weight=False, use_interaction_type=False):
    N = graph.number_of_nodes()
    
    edge_attr = None
    if use_weight:
        weight_list = []
        for src, dst, edge_data in iter(graph.edges(data=True)):
            weight_list.append(edge_data['weight'])
    if use_interaction_type:
        interaction_list = []
        for src, dst, edge_data in iter(graph.edges(data=True)):
            interaction_list.append(edge_data['interaction_type'])
        
        interaction_idx = {x:i for i,x in enumerate(set(interaction_list))}
        interaction_idxs = [interaction_idx[feature] for feature in interaction_list]
        # TODO make the edge feature matrix sparse
        rows = np.arange(len(interaction_idxs))
        interaction_vector = np.zeros((len(rows),len(interaction_idx)), dtype=float)
        interaction_vector[rows, interaction_idxs] = 1
    
    if use_weight and use_interaction_type:
        edge_attr = torch.tensor(np.hstack((interaction_vector, np.array(weight_list, ndmin=2).T)))
    elif use_weight:
        edge_attr = torch.tensor(np.array(weight_list, ndmin=2).T)
    elif use_interaction_type:
        edge_attr = torch.tensor(interaction_vector)
    
    node_idx_map = dict()
    for node in graph.nodes:
        node_type, node_id = node
        node_idx_map[node] = nodes_by_type[node_type][node_id]
    graph = nx.relabel.relabel_nodes(graph, node_idx_map)

    edge_index = torch.LongTensor(list(graph.edges)).t().contiguous().view(2,-1)
    
    return edge_index, edge_attr

In [143]:
from torch_geometric.data import HeteroData

In [144]:
# Graphs that don't have edge or 
bare_graphs = [(tad_atac_graph, ('tad','overlaps','atac_region')),
               (tad_genes_graph, ('tad','overlaps','gene')),
               (gene_protein_graph, ('gene', 'associated', 'protein')), 
               (atac_gene_graph, ('atac_region', 'overlaps', 'gene')), 
               (protein_coocurrence, ('protein', 'coexpressed', 'protein')), 
               (tf_graph, ('protein','tf_interacts','gene')),
               (protein_name_protein_graph, ('protein','is_named','protein_name')),
               (enhancer_atac_graph, ('enhancer','overlaps','atac_region')),
               (enhancer_gene_graph, ('enhancer','associated','gene')),
               (atac_neighbor_genes, ('atac_region','neighbors','gene')),
              ]

data = HeteroData()
for graph, relation in bare_graphs:
    print(relation)
    edge_index, edge_attr = to_pyg_from_networkx(graph)

    data[relation]['edge_index'] = edge_index
    data[relation].num_nodes = graph.number_of_nodes()
    

('tad', 'overlaps', 'atac_region')
('tad', 'overlaps', 'gene')
('gene', 'associated', 'protein')
('atac_region', 'overlaps', 'gene')
('protein', 'coexpressed', 'protein')
('protein', 'tf_interacts', 'gene')
('protein', 'is_named', 'protein_name')
('enhancer', 'overlaps', 'atac_region')
('enhancer', 'associated', 'gene')
('atac_region', 'neighbors', 'gene')


In [145]:
data

HeteroData(
  [1m(tad, overlaps, atac_region)[0m={
    edge_index=[2, 116232],
    num_nodes=118729
  },
  [1m(tad, overlaps, gene)[0m={
    edge_index=[2, 18328],
    num_nodes=20939
  },
  [1m(gene, associated, protein)[0m={
    edge_index=[2, 18470],
    num_nodes=42807
  },
  [1m(atac_region, overlaps, gene)[0m={
    edge_index=[2, 15951],
    num_nodes=39835
  },
  [1m(protein, coexpressed, protein)[0m={
    edge_index=[2, 6509230],
    num_nodes=18989
  },
  [1m(protein, tf_interacts, gene)[0m={
    edge_index=[2, 777242],
    num_nodes=16266
  },
  [1m(protein, is_named, protein_name)[0m={
    edge_index=[2, 54792],
    num_nodes=72821
  },
  [1m(enhancer, overlaps, atac_region)[0m={
    edge_index=[2, 31950],
    num_nodes=47363
  },
  [1m(enhancer, associated, gene)[0m={
    edge_index=[2, 448589],
    num_nodes=206282
  },
  [1m(atac_region, neighbors, gene)[0m={
    edge_index=[2, 276738],
    num_nodes=134164
  }
)

#### Attributed graphs, not using edge attributes at the moment

In [146]:
# for graph, relation in attributed_graphs:
#     edge_index, edge_attr = to_pyg_from_networkx(graph, use_weight=True, use_interaction_type=True)
    
#     data[relation]['edge_index'] = edge_index
#     data[relation]['edge_attr'] = edge_attr
#     data[relation].num_nodes = graph.number_of_nodes()

# data['gene_ensembl_id'] = gene_ensembl_id

In [147]:
data

HeteroData(
  [1m(tad, overlaps, atac_region)[0m={
    edge_index=[2, 116232],
    num_nodes=118729
  },
  [1m(tad, overlaps, gene)[0m={
    edge_index=[2, 18328],
    num_nodes=20939
  },
  [1m(gene, associated, protein)[0m={
    edge_index=[2, 18470],
    num_nodes=42807
  },
  [1m(atac_region, overlaps, gene)[0m={
    edge_index=[2, 15951],
    num_nodes=39835
  },
  [1m(protein, coexpressed, protein)[0m={
    edge_index=[2, 6509230],
    num_nodes=18989
  },
  [1m(protein, tf_interacts, gene)[0m={
    edge_index=[2, 777242],
    num_nodes=16266
  },
  [1m(protein, is_named, protein_name)[0m={
    edge_index=[2, 54792],
    num_nodes=72821
  },
  [1m(enhancer, overlaps, atac_region)[0m={
    edge_index=[2, 31950],
    num_nodes=47363
  },
  [1m(enhancer, associated, gene)[0m={
    edge_index=[2, 448589],
    num_nodes=206282
  },
  [1m(atac_region, neighbors, gene)[0m={
    edge_index=[2, 276738],
    num_nodes=134164
  }
)

## Save the graphs and metadata

In [148]:
torch.save(data, 'input/pyg_graph.torch')

In [149]:
import pickle

In [150]:
pickle.dump(nodes_by_type, open('input/nodes_by_type.pickle','wb'))

In [151]:
pickle.dump(gene_ensembl_id, open('input/gene_ensembl_id.pickle','wb'))

In [152]:
pickle.dump(protein_name_proteins, open('input/protein_name_proteins.pickle','wb'))

In [153]:
pickle.dump(bare_graphs, open('input/graphs.pickle','wb'))

## Count node types

In [154]:
number_of_nodes_by_type = {node_type:len(node_ids) for node_type,node_ids in nodes_by_type.items()}

In [155]:
number_of_nodes_by_type

{'atac_region': 116465,
 'enhancer': 191827,
 'protein': 24946,
 'protein_name': 52958,
 'tad': 2862,
 'gene': 18470}

In [156]:
sum(number_of_nodes_by_type.values())-number_of_nodes_by_type['tad']-number_of_nodes_by_type['protein_name']

351708

## Add the node2vec embeddings

In [157]:
device='cpu'

In [158]:
data = torch.load('input/pyg_graph.torch')

In [159]:
data = data.to(device)

In [160]:
embedding = torch.load('input/node2vec_embeddings.torch').to(device)

In [161]:
subgraph = torch.load('input/node2vec_subgraph.torch').to(device)

In [162]:
subgraph._node_type_names

['protein', 'gene', 'enhancer', 'atac_region']

In [163]:
node_types = subgraph.node_type.unique()

In [164]:
edge_types = subgraph.edge_type.unique()

In [165]:
subgraph.node_type.shape

torch.Size([351708])

In [166]:
embedding.shape

torch.Size([351708, 128])

In [167]:
for node_type in node_types:
    idxs = (subgraph.node_type==node_type).nonzero().view((-1,))
    node_name = subgraph._node_type_names[node_type]
    data[node_name].x = embedding[idxs]

In [168]:
import torch_scatter

In [169]:
tad_gene_idxs = data[('tad','overlaps','gene')].edge_index[1]
tad_gene_embeddings = data['gene'].x[tad_gene_idxs]

tad_summed_gene_embeddings = torch.zeros((len(nodes_by_type['tad']),
                                          tad_gene_embeddings.shape[1]),
                                         device=device)

tad_summed_gene_embeddings = torch_scatter.scatter(src=tad_gene_embeddings,
                                                   index=data[('tad','overlaps','gene')].edge_index[0],
                                                   dim=0
                                                  )

In [170]:
tad_summed_gene_embeddings.shape

torch.Size([2860, 128])

In [171]:
(torch.abs(tad_summed_gene_embeddings.sum(axis=1)) > 0).sum()

tensor(2019)

In [172]:
# atac_gene_idxs = data[('atac_region','overlaps','gene')].edge_index[1]
# atac_gene_embeddings = data['gene'].x[atac_gene_idxs]

# atac_summed_gene_embeddings = torch.zeros((len(nodes_by_type['atac_region']),
#                                            atac_gene_embeddings.shape[1]),
#                                           device=device)

# atac_summed_gene_embeddings = torch_scatter.scatter(src=atac_gene_embeddings,
#                                                     index=data[('atac_region','overlaps','gene')].edge_index[0],
#                                                     dim=0,
#                                                    )

# atac_summed_gene_embeddings.shape

# (torch.abs(atac_summed_gene_embeddings.sum(axis=1)) > 0).sum()

# data['atac_region'].x = atac_summed_gene_embeddings

# len(nodes_by_type['atac_region'])

In [173]:
data['tad'].x = tad_summed_gene_embeddings

In [174]:
data

HeteroData(
  [1mprotein[0m={ x=[24946, 128] },
  [1mgene[0m={ x=[18470, 128] },
  [1menhancer[0m={ x=[191827, 128] },
  [1matac_region[0m={ x=[116465, 128] },
  [1mtad[0m={ x=[2860, 128] },
  [1m(tad, overlaps, atac_region)[0m={
    edge_index=[2, 116232],
    num_nodes=118729
  },
  [1m(tad, overlaps, gene)[0m={
    edge_index=[2, 18328],
    num_nodes=20939
  },
  [1m(gene, associated, protein)[0m={
    edge_index=[2, 18470],
    num_nodes=42807
  },
  [1m(atac_region, overlaps, gene)[0m={
    edge_index=[2, 15951],
    num_nodes=39835
  },
  [1m(protein, coexpressed, protein)[0m={
    edge_index=[2, 6509230],
    num_nodes=18989
  },
  [1m(protein, tf_interacts, gene)[0m={
    edge_index=[2, 777242],
    num_nodes=16266
  },
  [1m(protein, is_named, protein_name)[0m={
    edge_index=[2, 54792],
    num_nodes=72821
  },
  [1m(enhancer, overlaps, atac_region)[0m={
    edge_index=[2, 31950],
    num_nodes=47363
  },
  [1m(enhancer, associated, gene)[0m={
  

## Add ATAC DNABERT embeddings

In [175]:
atac_embedding = torch.load('input/projected_atac_sequences.torch')

In [176]:
atac_embedding.shape

torch.Size([116468, 128])

In [177]:
data['atac_region'].x.shape

torch.Size([116465, 128])

In [178]:
embedded_atac_ids = [x.strip() for x in open('input/atac_list.txt')]

In [179]:
len(nodes_by_type['atac_region'])

116465

In [180]:
matched_atac_idxs = []
for i in range(len(embedded_atac_ids)):
    if embedded_atac_ids[i] in nodes_by_type['atac_region']:
        matched_atac_idxs.append(i)
    else:
        print(embedded_atac_ids[i])

chrY-11295162-11295942
chrY-11333659-11334343
chrY-56836454-56837350


In [181]:
len(matched_atac_idxs)

116465

In [182]:
matched_atac_embeddings = atac_embedding[matched_atac_idxs]

In [183]:
matched_atac_embeddings.shape

torch.Size([116465, 128])

In [184]:
data['atac_region'].x = torch.cat([data['atac_region'].x, matched_atac_embeddings], axis=1)

In [185]:
data

HeteroData(
  [1mprotein[0m={ x=[24946, 128] },
  [1mgene[0m={ x=[18470, 128] },
  [1menhancer[0m={ x=[191827, 128] },
  [1matac_region[0m={ x=[116465, 256] },
  [1mtad[0m={ x=[2860, 128] },
  [1m(tad, overlaps, atac_region)[0m={
    edge_index=[2, 116232],
    num_nodes=118729
  },
  [1m(tad, overlaps, gene)[0m={
    edge_index=[2, 18328],
    num_nodes=20939
  },
  [1m(gene, associated, protein)[0m={
    edge_index=[2, 18470],
    num_nodes=42807
  },
  [1m(atac_region, overlaps, gene)[0m={
    edge_index=[2, 15951],
    num_nodes=39835
  },
  [1m(protein, coexpressed, protein)[0m={
    edge_index=[2, 6509230],
    num_nodes=18989
  },
  [1m(protein, tf_interacts, gene)[0m={
    edge_index=[2, 777242],
    num_nodes=16266
  },
  [1m(protein, is_named, protein_name)[0m={
    edge_index=[2, 54792],
    num_nodes=72821
  },
  [1m(enhancer, overlaps, atac_region)[0m={
    edge_index=[2, 31950],
    num_nodes=47363
  },
  [1m(enhancer, associated, gene)[0m={
  

## Initialize protein_names

In [186]:
protein_idxs = data[('protein','is_named','protein_name')].edge_index
protein_embeddings = data['protein'].x[protein_idxs[0]]

name_summed_protein_embeddings = torch.zeros((len(nodes_by_type['protein_name']),
                                              protein_embeddings.shape[1]),
                                          device=device)

name_summed_protein_embeddings = torch_scatter.scatter(src=protein_embeddings,
                                                       index=protein_idxs[1],
                                                       dim=0)

data['protein_name'].x = name_summed_protein_embeddings

len(nodes_by_type['protein_name'])

52958

In [187]:
(torch.abs(name_summed_protein_embeddings.sum(axis=1)) > 0).sum()

tensor(52958)

In [188]:
data['protein_name'].x.shape

torch.Size([52958, 128])

In [189]:
len(nodes_by_type['protein_name'])

52958

In [190]:
len(nodes_by_type['protein'])

24946

In [191]:
data[('protein','is_named','protein_name')].edge_index[0].unique().shape

torch.Size([19863])

In [192]:
data[('protein','is_named','protein_name')].edge_index[1].unique().shape

torch.Size([52958])

In [193]:
data

HeteroData(
  [1mprotein[0m={ x=[24946, 128] },
  [1mgene[0m={ x=[18470, 128] },
  [1menhancer[0m={ x=[191827, 128] },
  [1matac_region[0m={ x=[116465, 256] },
  [1mtad[0m={ x=[2860, 128] },
  [1mprotein_name[0m={ x=[52958, 128] },
  [1m(tad, overlaps, atac_region)[0m={
    edge_index=[2, 116232],
    num_nodes=118729
  },
  [1m(tad, overlaps, gene)[0m={
    edge_index=[2, 18328],
    num_nodes=20939
  },
  [1m(gene, associated, protein)[0m={
    edge_index=[2, 18470],
    num_nodes=42807
  },
  [1m(atac_region, overlaps, gene)[0m={
    edge_index=[2, 15951],
    num_nodes=39835
  },
  [1m(protein, coexpressed, protein)[0m={
    edge_index=[2, 6509230],
    num_nodes=18989
  },
  [1m(protein, tf_interacts, gene)[0m={
    edge_index=[2, 777242],
    num_nodes=16266
  },
  [1m(protein, is_named, protein_name)[0m={
    edge_index=[2, 54792],
    num_nodes=72821
  },
  [1m(enhancer, overlaps, atac_region)[0m={
    edge_index=[2, 31950],
    num_nodes=47363
  },


## Save the data

In [194]:
torch.save(data, 'input/graph_with_embeddings.torch')