In [1]:
import pandas as pd
import requests
from tqdm import tqdm
import random
import matplotlib.pyplot as plt



In [2]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import torch.nn as nn
from torch_geometric.data.batch import Batch
from torch_geometric.data import Data
from torch.utils.data import Dataset

  import scipy.sparse


In [3]:
scRNA_data = pd.read_csv('GSE200981_scRNAseq_processed.tsv', sep='\t')
scRNA_data.index = scRNA_data['Gene.names']
scRNA_data = scRNA_data.drop('Gene.names', axis=1)
len(scRNA_data)

26364

In [4]:
#Mapping string to protein names
string_api_url = "https://string-db.org/api"
output_format = "tsv-no-header"
method = "get_string_ids"

params = {

    "identifiers" : "\r".join(list(scRNA_data.index)), # your protein list
    "limit": 1,
    "echo_query": 1,
    "species" : 9606, # species NCBI identifier 
    "caller_identity" : "www.awesome_app.org" # your app name

}

request_url = "/".join([string_api_url, output_format, method])

results = requests.post(request_url, data=params)


protein_2_string = dict()
string_2_protein = dict()

for line in results.text.strip().split("\n"):
    l = line.split("\t")
    protein_identifier, string_identifier = l[0], l[2]
    protein_2_string[protein_identifier] = string_identifier
    string_2_protein[string_identifier] = protein_identifier

In [5]:
scRNA_data = scRNA_data.loc[list(protein_2_string.keys())]
scRNA_data

Unnamed: 0_level_0,V1_T0,V2_T0,V3_T0,V4_T0,V5_T0,V6_T0,V7_T0,V8_T0,V9_T0,V10_T0,...,V247_T7,V248_T7,V249_T7,V250_T7,V251_T7,V252_T7,V253_T7,V254_T7,V255_T7,V256_T7
Gene.names,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
OR4F5,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
OR4F3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
OR4F29,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
OR4F16,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
SAMD11,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
DAZ1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
DAZ3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
DAZ2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
CDY1B,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [6]:
import gc
torch.cuda.empty_cache()
gc.collect()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
def get_all_protein_pairs(protein_list):
    protein_pairs = []
    for i in range(len(protein_list)):
        protein1 = protein_list[i]
        for j in range(i+1, len(protein_list)):
            protein2 = protein_list[j]
        
            protein_pairs.append((protein1, protein2))
            
    return protein_pairs

class Contrastive_Dataset(Dataset):
    def __init__(self, scRNA_data, string_2_protein):
        filename = '9606.protein.links.v12.0.txt'

        file = open(filename, 'r')
        lines = file.readlines()
        lines.pop(0)

        string_2_index = dict()
        counter = 0
        for string_id in string_2_protein:
            string_2_index[string_id] = counter
            counter += 1

        list_network = list()
        
        """self.train_node_features = list()
        self.train_list_outputs = list()
        
        self.test_node_features = list()
        self.test_list_outputs = list()
        
        self.val_node_features = list()
        self.val_list_outputs = list()"""
        
        self.e_nodes = list()
        self.m_nodes = list()

        print('Getting network tensor...')
        for line in tqdm(lines):
            line = line.strip().split(' ')

            if int(line[2]) >= 999:

                try:
                    id1 = string_2_index[line[0]]
                    id2 = string_2_index[line[1]]
                    list_network.append([id1, id2])
                    list_network.append([id2, id1])

                except KeyError:
                    continue

        print('Getting node features tensor...')
        T0_column_vals = [column for column in scRNA_data.columns if 'T0' in column]
        T8_column_vals = [column for column in scRNA_data.columns if 'T7' in column]
        
        proteins = set([string_2_protein[string_id] for string_id in string_2_index])
        #train_proteins, test_proteins, validate_proteins = np.split(proteins, [int(len(proteins)*0.7), int(len(proteins*0.9))])
        
        #train_protein_pairs = get_all_protein_pairs(train_proteins)
        #testing_protein_pairs = get_all_protein_pairs(test_proteins)
        #validate_protein_pairs = get_all_protein_pairs(validate_proteins)
        
        #for 
        
        for column in T0_column_vals:   
            self.e_nodes.append(torch.tensor([scRNA_data.loc[protein, column] for protein in proteins], dtype=torch.float32).to(device))
                                        
        for column in T8_column_vals:
            self.m_nodes.append(torch.tensor([scRNA_data.loc[protein, column] for protein in proteins], dtype=torch.float32).to(device))
        
        self.edge_index = torch.tensor(list_network).t().contiguous()
        self.edge_index = self.edge_index.to(device)
        #print(self.e_nodes[0].view(len(self.e_nodes[0]), 1))
        self.e_nodes = [i.view(len(i), 1) for i in self.e_nodes]
        self.m_nodes = [i.view(len(i), 1) for i in self.m_nodes]
        #self.node_features = torch.tensor(self.node_features, dtype=torch.float)
        #self.list_outputs = torch.tensor(self.list_outputs, dtype=torch.float)
    
    def get_e_nodes(self):
        return self.e_nodes[0:5]
    
    def get_m_nodes(self):
        return self.m_nodes[0:5]
    
    def get_edge_index(self):
        return self.edge_index

In [15]:
contrastive_dataset = Contrastive_Dataset(scRNA_data, string_2_protein)

Getting network tensor...


100%|██████████████████████████| 13715404/13715404 [00:05<00:00, 2397328.15it/s]


Getting node features tensor...


In [16]:
class ContrastiveLoss(nn.Module):
    def __init__(self):
        super(ContrastiveLoss, self).__init__()
        self.temp = 0.05
        self.cos = nn.CosineSimilarity(dim=1)
    
    def forward(self, zi, zj, xneg, edge_index, model):
        graph_list = [edge_index + xneg.shape[1]*i for i in range(xneg.shape[0])]
        graphs = torch.cat(graph_list, 1).to(device)
        num = torch.exp(self.cos(xi, xj)/self.temp)
        
        zneg = model(xneg, graphs)
        denom = torch.exp(torch.divide(self.cos(zi, zneg),self.temp))
        denom = torch.sum(denom)
        
        return torch.multiply(torch.log(torch.divide(num, denom)), -1)

In [17]:
class GAT_Contrast(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads):
        super().__init__()

        self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout = 0.6)
        self.output_layer = nn.Linear(out_channels*heads, 2)
        self.conv2 = GATConv(hidden_channels*heads, out_channels, heads, dropout=0.6)
        self.num_nodes = 18840
        self.columns = out_channels*heads

    def forward(self, x, edge_index):
        batch_size = x.shape[0]//self.num_nodes
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        
        x = torch.reshape(x, (batch_size, self.num_nodes, self.columns))
        x = torch.mean(x, dim=2)
        return x

In [18]:
from torch_geometric.nn import GATConv

def GAT_contrast_train(model, dataset, epochs, num_nodes, lr = 1e-6, weight_decay = 5e-4, temp=0.05):
    torch.cuda.empty_cache()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = ContrastiveLoss()
    model.train()
    optimizer.zero_grad()
    all_losses = []
    
    e_nodes = dataset.get_e_nodes()
    m_nodes = dataset.get_m_nodes()
    edge_index = dataset.get_edge_index()
    
    for _ in tqdm(range(epochs)):
        losses = []
        
        print('Training on E cells...')
        for i in range(len(e_nodes)):
            tensori = e_nodes[i]
            #print(xi.shape)
            
            for j in range(len(e_nodes)):
                print('Hello')
                tensorj = e_nodes[j]
                
                if i == j:
                    continue
            
                xi = model(tensori, edge_index)
                xj = model(tensorj, edge_index)
                
                loss = criterion(xi, xj, m_nodes, edge_index, model)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                #print(loss.item())
                losses.append(loss.item())
                
            all_losses.append(sum(losses)/len(losses))
            print(all_losses)
        
    plt.plot([i for i in range(1, (len(all_losses)+1))], all_losses)
    plt.xlabel('Number of Epochs')
    plt.ylabel('Loss')
    plt.title('GATConv Contrastive Training')

In [19]:
gat_contrast = GAT_Contrast(1, 16, 18840, 4).to(device)

In [20]:
GAT_contrast_train(gat_contrast, contrastive_dataset, 1, scRNA_data.shape[0])

  0%|                                                     | 0/1 [00:00<?, ?it/s]

Training on E cells...
Hello
Hello





OutOfMemoryError: CUDA out of memory. Tried to allocate 5.29 GiB. GPU 0 has a total capacity of 7.78 GiB of which 1.89 GiB is free. Including non-PyTorch memory, this process has 5.70 GiB memory in use. Of the allocated memory 5.49 GiB is allocated by PyTorch, and 14.24 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)