In [1]:
import torch_geometric as pyg
import torch
import numpy as np
import scanpy
from collections import defaultdict
import pickle

In [2]:
import random

In [3]:
from torch_geometric.data import HeteroData

In [4]:
from collections import Counter

In [5]:
import torch.nn.functional as F

In [6]:
data_files = !find 'output/datasets/predict_modality/' -type f -name '*train*h5ad'

In [7]:
data = {}
for filename in data_files:
    data[filename.split('/')[-1]] = scanpy.read_h5ad(filename)

In [8]:
def proteins_to_idxs(data):
    indexes = []
    proteins = data.var.index.to_list()
    for protein_name in proteins:
        protein_name = protein_name.upper()
        if protein_name in node_idxs['gene_name']:
            indexes.append(node_idxs['gene_name'][protein_name])
        else:
            indexes.append(None)
    return indexes, data.X

In [9]:
def genes_to_idxs(data):
    indexes = []
    genes = data.var['gene_ids'].to_list()
    for gene_id in genes:
        if gene_id in node_idxs['gene']:
            indexes.append(node_idxs['gene'][gene_id])
        else:
            indexes.append(None)
    return indexes, data.X

In [10]:
def atac_to_idxs(data):
    indexes = {}
    regions = data.var.index.to_list()
    for region in regions:
        if region in node_idxs['atac_region']:
            indexes.append(node_idxs['atac_region'][region])
        else:
            indexes.append(None)
    return indexes, data.X

In [11]:
from torch_geometric.nn import GATConv
from torch import tensor
import torch

In [12]:
device='cuda:0'

## Protein to gene expression

In [13]:
graph = torch.load('input/graph_with_embeddings.torch')
node_idxs = pickle.load(open('input/nodes_by_type.pickle','rb'))

In [14]:
protein_data = data['openproblems_bmmc_cite_phase1_mod2.censor_dataset.output_train_mod1.h5ad']
protein_idxs, protein_expression = proteins_to_idxs(protein_data)

In [15]:
gene_data =    data['openproblems_bmmc_cite_phase1_mod2.censor_dataset.output_train_mod2.h5ad']
gene_idxs, gene_expression = genes_to_idxs(gene_data)

In [16]:
graph = graph.to('cpu')

In [17]:
graph = pyg.transforms.ToUndirected()(graph)

In [18]:
graph = graph.to(device)

In [19]:
protein_mask = torch.zeros((len(node_idxs['gene_name']),1), dtype=bool, device=device)
protein_mask[protein_idxs] = 1
graph['gene_name']['mask'] = protein_mask

In [20]:
gene_mask = torch.zeros((len(node_idxs['gene']),1), dtype=bool, device=device)
gene_mask[[idx for idx in gene_idxs if idx]] = 1
graph['gene']['mask'] = gene_mask

In [21]:
graph['gene'].mask.sum()

tensor(12437, device='cuda:0')

In [22]:
graph['gene_name'].mask.sum()

tensor(134, device='cuda:0')

## Function to create new data object with expression values

In [23]:
def append_expression(graph, cell_idx):
    newgraph = HeteroData()
    
    gene = tensor(gene_expression[cell_idx].todense())
    protein = tensor(protein_expression[cell_idx].todense())
        
    expression = dict()

    for node_type in ['gene_name', 'gene', 'atac_region']:
        expression[node_type] = torch.ones((
            len(node_idxs[node_type]),
            1
        ),device=device)*-1
    
    for i in range(gene.shape[1]):
        if gene_idxs[i]:
            expression['gene'][gene_idxs[i]] = gene[:,i]
    
    for i in range(protein.shape[1]):
        if protein_idxs[i]:
            expression['gene_name'][protein_idxs[i]] = protein[:,i]

    newgraph['gene_name'].y = expression['gene_name']
    newgraph['gene_name'].x = torch.ones((len(node_idxs['gene_name']),1),device=device)*-1

    newgraph['gene'].x = torch.cat([
        graph['gene'].x,
        expression['gene']
    ], dim=1)
        
    newgraph['atac_region'].x = torch.cat([
        graph['atac_region'].x,
        expression['atac_region']
    ], dim=1)
    
    newgraph['tad'].x = graph['tad'].x
    newgraph['protein'].x = graph['protein'].x
    
    for edge_type, store in graph._edge_store_dict.items():
        for k,v in store.items():
            newgraph[edge_type][k]=v
    
    return newgraph


## EARL = Expression and Representation Learner

✓ For each cell, create a data vector.

✓ Data level batching

Graph level batching (is this necessary?)

✓ Metapath or TransE for featureless (all) nodes?

Random masking (self supervision)

✓ Backprop loss of just unknown

✓ Make most graphs undirected. Remove incoming edges to known nodes?

✓ Create a GNN

Train GNN

In [25]:
from torch_geometric.nn import to_hetero, SAGEConv

In [26]:
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear

class EaRL(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers):
        super().__init__()
        self.hidden_channels = hidden_channels

        self.convs = torch.nn.ModuleList()
        self.linear = Linear(hidden_channels,1)
        for _ in range(num_layers):
            conv = HeteroConv({
                ('tad', 'overlaps', 'atac_region'): SAGEConv((-1, -1), hidden_channels),
                ('tad', 'overlaps', 'gene'): SAGEConv((-1, -1), hidden_channels),
                ('atac_region', 'rev_overlaps', 'tad'): SAGEConv((-1, -1), hidden_channels),
                ('atac_region', 'overlaps', 'gene'): SAGEConv((-1, -1), hidden_channels),
                ('protein', 'coexpressed', 'protein'): SAGEConv((-1, -1), hidden_channels),
                ('protein', 'tf_interacts', 'gene'): SAGEConv((-1, -1), hidden_channels),
                ('protein', 'trrust_interacts', 'gene'): SAGEConv((-1, -1), hidden_channels),
                ('gene', 'rev_overlaps', 'tad'): SAGEConv((-1, -1), hidden_channels),
                ('gene', 'rev_overlaps', 'atac_region'): SAGEConv((-1, -1), hidden_channels),
                ('gene', 'rev_trrust_interacts', 'protein'): SAGEConv((-1, -1), hidden_channels),
                ('gene', 'rev_tf_interacts', 'protein'): SAGEConv((-1, -1), hidden_channels),
                ('protein', 'rev_associated', 'gene'): SAGEConv((-1, -1), hidden_channels),
                ('gene', 'associated', 'protein'): SAGEConv((-1, -1), hidden_channels),
                ('protein', 'is_named', 'gene_name'): SAGEConv((-1, -1), hidden_channels),
                ('gene_name', 'rev_is_named', 'protein'): SAGEConv((-1, -1), hidden_channels)
            })

            self.convs.append(conv)
        self.name_conv = HeteroConv({('protein', 'is_named', 'gene_name'): SAGEConv((-1, -1), 1)})

    def forward(self, x_dict, edge_index_dict):
        #         gene_names = x_dict['gene_name']
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
            x_dict['gene_name'] = self.linear(x_dict['gene_name'])
            #         x_dict['gene_name'] = gene_names
            #         names = self.name_conv(x_dict, edge_index_dict)
            #         x_dict['gene_name'] = names['gene_name']

            return x_dict


## Training

## TODO trim the gene names, we have way more gene names than we have in the data

In [27]:
num_cells = gene_expression.shape[0]

In [28]:
train_set_size = int(num_cells*.7)

In [29]:
val_set_size = num_cells - train_set_size

In [30]:
# earl = to_hetero(EaRL(64), graph.metadata(), aggr='sum')
earl = EaRL(hidden_channels=64, num_layers=3)
earl = earl.to('cuda:0')
optimizer = torch.optim.Adam(params=earl.parameters(), lr=.001)
earl.train()

n_epochs = 10

cell_idxs = list(range(gene_expression.shape[0]))
random.shuffle(cell_idxs)
cell_idxs = cell_idxs[:train_set_size]

batch_size = 20
for epoch in range(n_epochs):
    mask = graph['gene_name']['mask']
    total_loss = 0
    batch_start = 0
    batch_end = batch_size
    
    num_predictions = min(batch_size-batch_end+len(cell_idxs), batch_size)
    
    batch_idx = 0
    while batch_end < len(cell_idxs)+batch_size:
        optimizer.zero_grad()
        predictions = torch.zeros((
            num_predictions,
            mask.sum()
        ), device=device)
        for i,idx in enumerate(cell_idxs[batch_start:batch_end]):
            newgraph = append_expression(graph, idx)
            predictions[i] = earl(newgraph.x_dict, newgraph.edge_index_dict)['gene_name'][mask].flatten()

        y = newgraph['gene_name'].y[mask]

        loss = ((predictions - y)**2).sum()
        print(f'Batch: {batch_idx}, Loss: {float(loss)}')
        loss.backward()
        optimizer.step()
        batch_start += batch_size
        batch_end += batch_size
        batch_idx += 1
        
    print({'Epoch: {epoch}, Epoch loss: {float(loss)}'})

Batch: 0, Loss: 2065.065185546875
Batch: 1, Loss: 2044.35107421875
Batch: 2, Loss: 1143.814208984375
Batch: 3, Loss: 1375.7183837890625
Batch: 4, Loss: 1414.009033203125
Batch: 5, Loss: 780.3236694335938
Batch: 6, Loss: 626.9471435546875
Batch: 7, Loss: 681.8627319335938
Batch: 8, Loss: 2575.94775390625
Batch: 9, Loss: 621.9063720703125
Batch: 10, Loss: 1047.905517578125
Batch: 11, Loss: 815.927001953125
Batch: 12, Loss: 1354.922607421875
Batch: 13, Loss: 681.4703369140625
Batch: 14, Loss: 1164.22314453125
Batch: 15, Loss: 879.9378051757812
Batch: 16, Loss: 685.1246337890625
Batch: 17, Loss: 624.9552001953125
Batch: 18, Loss: 1130.24853515625
Batch: 19, Loss: 1068.029296875
Batch: 20, Loss: 1232.4345703125
Batch: 21, Loss: 2133.94970703125
Batch: 22, Loss: 1174.459228515625
Batch: 23, Loss: 1549.676513671875
Batch: 24, Loss: 549.6134033203125
Batch: 25, Loss: 1190.075927734375
Batch: 26, Loss: 914.02001953125
Batch: 27, Loss: 945.0771484375
Batch: 28, Loss: 1129.7239990234375
Batch: 29

KeyboardInterrupt: 

## Maybe try sampling zero/one for gene dropouts

In [None]:
from torch.distributions import Bernoulli

In [None]:
bernoulli = Bernoulli(torch.tensor([.5,.1]))

In [None]:
bernoulli.sample((10,))