In [1]:
import pandas as pd
import requests
from tqdm import tqdm
import random
import matplotlib.pyplot as plt

In [2]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import torch.nn as nn
from torch_geometric.data.batch import Batch
from torch_geometric.data import Data
from torch.utils.data import Dataset

In [3]:
# Processed expression matrix: rows indexed by gene name ('Gene.names'), one column per sample (e.g. V1_T0).
scRNA_data = pd.read_csv('GSE200981_scRNAseq_processed.tsv', sep='\t', index_col='Gene.names')
len(scRNA_data)

26364

In [4]:
# Map gene/protein names to STRING identifiers with the STRING "get_string_ids" API.
string_api_url = "https://string-db.org/api"
output_format = "tsv-no-header"
method = "get_string_ids"
STRING_TIMEOUT_S = 600  # one large POST with every gene name; fail instead of hanging forever

params = {
    "identifiers": "\r".join(list(scRNA_data.index)),  # all gene names, carriage-return separated
    "limit": 1,                                        # keep only the best match per identifier
    "echo_query": 1,                                   # include the query identifier as column 0
    "species": 9606,                                   # NCBI taxonomy identifier
    "caller_identity": "www.awesome_app.org",          # application name reported to STRING
}

request_url = "/".join([string_api_url, output_format, method])

results = requests.post(request_url, data=params, timeout=STRING_TIMEOUT_S)
results.raise_for_status()  # surface HTTP errors instead of parsing an error page as data

protein_2_string = dict()
# If several names map to the same STRING id, only the last one is kept here.
string_2_protein = dict()

for line in results.text.strip().split("\n"):
    fields = line.split("\t")
    if len(fields) < 3:
        continue  # blank or malformed row (e.g. empty response)
    # Column 0 is the echoed query name, column 2 the STRING identifier.
    protein_identifier, string_identifier = fields[0], fields[2]
    protein_2_string[protein_identifier] = string_identifier
    string_2_protein[string_identifier] = protein_identifier

In [5]:
# Restrict the matrix to genes STRING could map, in protein_2_string order.
mapped_genes = [gene for gene in protein_2_string]
scRNA_data = scRNA_data.loc[mapped_genes]
scRNA_data

Unnamed: 0_level_0,V1_T0,V2_T0,V3_T0,V4_T0,V5_T0,V6_T0,V7_T0,V8_T0,V9_T0,V10_T0,...,V247_T7,V248_T7,V249_T7,V250_T7,V251_T7,V252_T7,V253_T7,V254_T7,V255_T7,V256_T7
Gene.names,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
OR4F5,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
OR4F3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
OR4F29,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
OR4F16,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
SAMD11,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
DAZ1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
DAZ3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
DAZ2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
CDY1B,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [6]:
import gc

# Free cached GPU memory and collect unreachable Python objects before building large tensors.
torch.cuda.empty_cache()
gc.collect()
# Use the GPU when PyTorch can see one, otherwise the CPU (to force CPU: torch.device('cpu')).
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [7]:
# Display the torch device chosen in the previous cell.
device

device(type='cuda')

In [18]:
def get_all_protein_pairs(protein_list):
    """Return every unordered pair ``(a, b)`` of ``protein_list`` with ``a`` listed before ``b``.

    Pairs are produced in lexicographic index order: (0, 1), (0, 2), ..., (1, 2), ...
    An empty or single-element list yields an empty list.
    """
    n = len(protein_list)
    return [
        (protein_list[first], protein_list[second])
        for first in range(n)
        for second in range(first + 1, n)
    ]

class Contrastive_Dataset(Dataset):
    """Per-sample expression vectors from two timepoints on a shared STRING interaction graph.

    Item ``idx`` is ``(e_node, m_node, edge_index)``:
      * ``e_node`` / ``m_node`` -- (num_proteins, 1) float32 tensors on ``device`` holding one
        ``e_timepoint`` column / one ``m_timepoint`` column of ``scRNA_data`` for every mapped protein;
      * ``edge_index`` -- (2, num_edges) long tensor over the same protein ordering, with both
        directions of every STRING link whose score is >= ``score_threshold``.

    Parameters
    ----------
    scRNA_data : pandas.DataFrame
        Expression matrix indexed by gene/protein name; columns are samples.
    string_2_protein : dict
        STRING id -> protein name; its iteration order defines the node index.
    batch_size : int
        Stored as ``self.batch_size``.
    links_file : str
        STRING links file (one header line, then space-separated ``id1 id2 score`` rows).
    score_threshold : int
        Minimum link score kept.
    e_timepoint, m_timepoint : str
        Substrings selecting the e-state / m-state columns.

    Raises
    ------
    ValueError
        If ``scRNA_data`` has duplicate index entries for the mapped proteins.
    """

    def __init__(self, scRNA_data, string_2_protein, batch_size,
                 links_file='9606.protein.links.v12.0.txt', score_threshold=999,
                 e_timepoint='T0', m_timepoint='T7'):
        # Legacy counters from an earlier alternating __getitem__; no longer read.
        self.counter = 0
        self.e_counter = 0
        self.m_counter = 0
        self.batch_size = batch_size

        # A single fixed ordering of STRING ids is used for BOTH the edge list and the feature
        # vectors, so that feature row k and edge endpoint k refer to the same protein.
        string_ids = list(string_2_protein)
        string_2_index = {string_id: index for index, string_id in enumerate(string_ids)}
        proteins = [string_2_protein[string_id] for string_id in string_ids]

        print('Getting network tensor...')
        self.edge_index = self._load_edge_index(links_file, string_2_index, score_threshold)

        print('Getting node features tensor...')
        e_columns = [column for column in scRNA_data.columns if e_timepoint in column]
        m_columns = [column for column in scRNA_data.columns if m_timepoint in column]
        self.e_nodes = self._column_tensors(scRNA_data, proteins, e_columns)
        self.m_nodes = self._column_tensors(scRNA_data, proteins, m_columns)

    @staticmethod
    def _load_edge_index(links_file, string_2_index, score_threshold):
        """Stream ``links_file`` and return the (2, num_edges) long edge index on ``device``."""
        edges = []
        with open(links_file, 'r') as links:
            next(links, None)  # skip the header line
            for line in tqdm(links):
                fields = line.strip().split(' ')
                if len(fields) < 3 or int(fields[2]) < score_threshold:
                    continue
                id1 = string_2_index.get(fields[0])
                id2 = string_2_index.get(fields[1])
                if id1 is None or id2 is None:
                    continue  # an endpoint is not among the mapped proteins
                # Store both directions so message passing treats the link as undirected.
                edges.append([id1, id2])
                edges.append([id2, id1])
        # reshape keeps a valid (2, 0) shape when no edge survives the filters.
        edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2)
        return edge_index.t().contiguous().to(device)

    @staticmethod
    def _column_tensors(scRNA_data, proteins, columns):
        """Return one (len(proteins), 1) float32 tensor on ``device`` per column, rows in ``proteins`` order."""
        block = scRNA_data.loc[proteins, columns]
        if block.shape[0] != len(proteins):
            raise ValueError('scRNA_data has duplicate index entries for some mapped proteins')
        values = torch.tensor(block.to_numpy(dtype='float32'), dtype=torch.float32)
        # Transpose so each sample column becomes one contiguous row, then view it as (N, 1).
        values = values.t().contiguous().to(device)
        return [row.view(-1, 1) for row in values]

    def __getitem__(self, idx):
        return self.e_nodes[idx], self.m_nodes[idx], self.edge_index

    def get_e_nodes(self):
        return self.e_nodes

    def get_m_nodes(self):
        return self.m_nodes

    def get_edge_index(self):
        return self.edge_index

    def shuffle(self):
        """Shuffle the e-state and m-state sample lists independently, in place."""
        random.shuffle(self.e_nodes)
        random.shuffle(self.m_nodes)

    def __len__(self):
        # Only indices valid for both lists can be served by __getitem__.
        return min(len(self.e_nodes), len(self.m_nodes))

In [19]:
def graph_collate_fn(batch):
    """Collate ``(e_node, m_node, edge_index)`` samples into disjoint-union graph batches.

    All samples are assumed to share one graph (as Contrastive_Dataset returns). The first
    sample's ``edge_index`` is replicated, and the k-th copy has node ids offset by
    ``num_nodes * k``.

    Returns
    -------
    e_node_features : tensor (B * num_nodes, F)
        e-state features of all samples, sample-major.
    m_node_features : tensor (B * num_nodes, F)
        m-state features, same layout.
    graphs_2 : tensor
        Edge index covering exactly two stacked copies of the graph (used for sample pairs).
    graphs : tensor
        Edge index covering all B copies.

    Raises
    ------
    ValueError
        If ``batch`` is empty.
    """
    if len(batch) == 0:
        raise ValueError("graph_collate_fn received an empty batch")

    e_node_list = [sample[0] for sample in batch]
    m_node_list = [sample[1] for sample in batch]
    edge_index = batch[0][2]
    num_nodes = e_node_list[0].shape[0]

    def offset_copies(count):
        # Concatenate `count` copies of the graph, shifting node ids so the copies are disjoint.
        return torch.cat([edge_index + num_nodes * k for k in range(count)], dim=1)

    # Concatenating (N, F) tensors along dim 0 gives the same (B*N, F) layout as stack + reshape.
    e_node_features = torch.cat(e_node_list, dim=0)
    m_node_features = torch.cat(m_node_list, dim=0)
    return e_node_features, m_node_features, offset_copies(2), offset_copies(len(batch))

In [20]:
# Stored on the dataset as `batch_size`; train_model chooses its own DataLoader batch size.
batch_size = 20
contrastive_dataset = Contrastive_Dataset(
    scRNA_data,
    string_2_protein,
    batch_size=batch_size,
)

Getting network tensor...


100%|██████████| 13715404/13715404 [00:06<00:00, 1971411.62it/s]


Getting node features tensor...
383 256


In [11]:
# Not used: train_model builds a fresh DataLoader every epoch (right after dataset.shuffle()).
#contrastive_dataloader = DataLoader(contrastive_dataset, batch_size=batch_size, shuffle=True, collate_fn=graph_collate_fn)

In [12]:
def cosine_similarity(x1, x2, temp, eps=1e-8):
    """Temperature-scaled cosine similarity of two embedding vectors.

    Parameters
    ----------
    x1, x2 : torch.Tensor
        Embeddings of equal size, e.g. shape (1, D) or (D,); both are flattened.
    temp : float
        Temperature; the cosine similarity is divided by it.
    eps : float
        Lower bound on the product of norms, so an all-zero embedding yields 0 instead of NaN.

    Returns
    -------
    torch.Tensor
        0-d tensor ``cos(x1, x2) / temp``.
    """
    v1 = x1.reshape(-1)
    v2 = x2.reshape(-1)
    norm_product = torch.clamp(torch.norm(v1) * torch.norm(v2), min=eps)
    return torch.dot(v1, v2) / norm_product / temp
    
class ContrastiveLoss(nn.Module):
    """InfoNCE-style loss for one positive pair against a set of negatives.

    ``forward(zn, zd)`` takes two tensors:

    * ``zn`` of shape (2, 1, D): the two embeddings of a positive pair.
    * ``zd`` of shape (K, 1, D): K negative embeddings.

    It returns the 0-d tensor

        -log( exp(s(zn[0], zn[1])) / sum_k exp(s(zn[0], zd[k])) ),

    where ``s`` is cosine similarity divided by ``temp``. The denominator sums over the
    negatives only.

    Parameters
    ----------
    temp : float
        Temperature (default 0.05).
    eps : float
        Norm floor passed to the cosine similarity, so zero embeddings do not give NaN.
    """

    def __init__(self, temp=0.05, eps=1e-8):
        super(ContrastiveLoss, self).__init__()
        self.temp = temp
        self.eps = eps

    def forward(self, zn, zd):
        if zd.shape[0] == 0:
            raise ValueError("ContrastiveLoss needs at least one negative embedding")
        anchor = zn[0].reshape(-1)
        positive = zn[1].reshape(-1)
        negatives = zd.reshape(zd.shape[0], -1)
        pos_sim = F.cosine_similarity(anchor, positive, dim=0, eps=self.eps) / self.temp
        neg_sims = F.cosine_similarity(anchor.unsqueeze(0), negatives, dim=1, eps=self.eps) / self.temp
        # -log(exp(pos) / sum(exp(neg))) == logsumexp(neg) - pos, computed without overflow.
        return torch.logsumexp(neg_sims, dim=0) - pos_sim

In [13]:
def _pairwise_contrastive_loss(anchor_nodes, negative_nodes, graphs_2, graphs, gat, mlp, criterion, batch_size):
    """Mean contrastive loss over all ordered pairs (i, j), i != j, of samples in ``anchor_nodes``.

    ``anchor_nodes`` / ``negative_nodes`` are flattened (batch_size * num_nodes, 1) feature tensors
    as produced by graph_collate_fn. Each anchor pair is a positive, and every sample of
    ``negative_nodes`` is a negative. Returns None when fewer than two samples are present
    (no positive pair exists).
    """
    zd = mlp(gat(negative_nodes, graphs))
    anchors = torch.reshape(anchor_nodes, (batch_size, 1, -1))
    pair_losses = []
    for i in range(batch_size):
        for j in range(batch_size):
            if i == j:
                continue
            xn = torch.stack([anchors[i], anchors[j]], dim=0)                   # (2, 1, num_nodes)
            xn = torch.reshape(xn, (xn.shape[0] * xn.shape[2], xn.shape[1]))  # (2 * num_nodes, 1)
            zn = mlp(gat(xn, graphs_2))
            pair_losses.append(criterion(zn, zd))
    if not pair_losses:
        return None
    return sum(pair_losses) / len(pair_losses)


def train_model(epochs, dataset, gat, mlp, optimizer, batch_size=10):
    """Train ``gat`` + ``mlp`` contrastively on ``dataset`` and return the per-step losses.

    Each batch takes two optimizer steps.

    1. e-state sample pairs are the positives and the m-state samples the negatives.
    2. The reverse: m-state pairs are the positives and the e-state samples the negatives.

    Gradients are zeroed before every step.

    Parameters
    ----------
    epochs : int
        Number of passes over ``dataset`` (re-shuffled with ``dataset.shuffle()`` each epoch).
    dataset : Contrastive_Dataset
        Provides samples, ``shuffle()`` and ``get_e_nodes()``.
    gat, mlp : torch.nn.Module
        Encoder and projection head.
    optimizer : torch.optim.Optimizer
        Optimizer over both models' parameters.
    batch_size : int
        DataLoader batch size (default 10).

    Returns
    -------
    list[float]
        Loss value of every optimizer step, in order.
    """
    torch.cuda.empty_cache()
    criterion = ContrastiveLoss()
    plot_losses = []
    if len(dataset) == 0:
        return plot_losses
    # Nodes per graph, taken from the data instead of a hard-coded constant.
    num_nodes = dataset.get_e_nodes()[0].shape[0]

    for _ in tqdm(range(epochs)):
        dataset.shuffle()
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=graph_collate_fn)
        torch.cuda.empty_cache()

        for e_nodes, m_nodes, graphs_2, graphs in tqdm(dataloader):
            samples_in_batch = e_nodes.shape[0] // num_nodes
            for anchors, negatives in ((e_nodes, m_nodes), (m_nodes, e_nodes)):
                optimizer.zero_grad()  # without this, gradients accumulate across every step
                loss = _pairwise_contrastive_loss(
                    anchors, negatives, graphs_2, graphs, gat, mlp, criterion, samples_in_batch
                )
                if loss is None:
                    continue  # single-sample batch: nothing to contrast
                plot_losses.append(loss.item())
                loss.backward()
                optimizer.step()
    return plot_losses

In [14]:
class GAT_Contrast(torch.nn.Module):
    """Single GATConv layer followed by ReLU, producing per-node embeddings.

    Parameters
    ----------
    in_channels : int
        Per-node input features.
    hidden_channels : int
        Per-head output features of the GAT layer. With GATConv's default head
        concatenation, the output width is ``hidden_channels * heads``.
    out_channels : int
        Currently unused (reserved for a second GAT layer).
    heads : int
        Number of attention heads.
    dropout : float
        Attention dropout passed to GATConv (default 0.6).
    num_nodes : int
        Nodes per graph; stored as ``self.num_nodes`` for callers, not used in forward.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads, dropout=0.6, num_nodes=18840):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=dropout)
        self.num_nodes = num_nodes

    def forward(self, x, edge_index):
        """Return ``relu(conv1(x, edge_index))`` for node features ``x`` of shape (total_nodes, in_channels)."""
        return F.relu(self.conv1(x, edge_index))

class MLP(torch.nn.Module):
    """Two-layer projection head mapping each graph's flattened node embeddings to one vector.

    Parameters
    ----------
    in_channels : int
        Length of one graph's flattened node embeddings (nodes per graph * features per node).
    hidden_channels, out_channels : int
        Widths of the hidden and output layers.
    num_nodes : int
        Nodes per graph; stored as ``self.num_nodes`` for callers. ``forward`` derives the
        per-graph split from ``in_channels`` instead.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_nodes=18840):
        super().__init__()
        self.linear1 = nn.Linear(in_channels, hidden_channels)
        self.linear2 = nn.Linear(hidden_channels, out_channels)
        self.num_nodes = num_nodes

    def forward(self, x):
        """Map (batch * nodes, features) node embeddings to (batch, 1, out_channels).

        Rows belonging to one graph must be contiguous, as laid out by graph_collate_fn.
        """
        # One row per graph: its node embeddings flattened into in_channels values.
        x = torch.reshape(x, (-1, 1, self.linear1.in_features))
        x = F.relu(self.linear1(x))
        # The final ReLU makes every embedding non-negative (an all-zero output is possible).
        return F.relu(self.linear2(x))

In [15]:
import torch_geometric.utils
from torch_geometric.nn import GATConv

def GAT_contrast_train(gat, mlp, dataset, epochs, num_nodes, lr = 1e-4, weight_decay = 5e-4, temp=0.05):
    """Jointly train ``gat`` and ``mlp`` with Adam through train_model and return its loss history.

    Parameters
    ----------
    gat, mlp : torch.nn.Module
        Graph encoder and projection head; both are put in training mode.
    dataset : Contrastive_Dataset
        Passed through to train_model.
    epochs : int
        Number of passes over ``dataset``.
    num_nodes : int
        Not used by this function (kept for call compatibility).
    lr, weight_decay : float
        Adam learning rate and weight decay.
    temp : float
        Not used by this function; the ContrastiveLoss is constructed inside train_model.

    Returns
    -------
    list[float]
        The losses returned by train_model.
    """
    torch.cuda.empty_cache()
    params = list(gat.parameters()) + list(mlp.parameters())
    optimizer = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    gat.train()
    mlp.train()
    optimizer.zero_grad()
    return train_model(epochs, dataset, gat, mlp, optimizer)

In [16]:
#Need to put in tensor batches again
gat_contrast = GAT_Contrast(1, 16, 2, 1).to(device)
mlp_contrast = MLP(301440, 64, 32).to(device)

In [17]:
# 500 epochs of contrastive training; returns one loss value per optimizer step.
losses = GAT_contrast_train(
    gat_contrast,
    mlp_contrast,
    contrastive_dataset,
    epochs=500,
    num_nodes=scRNA_data.shape[0],
)

  0%|          | 0/500 [00:00<?, ?it/s]
  0%|          | 0/39 [00:00<?, ?it/s][A
  3%|▎         | 1/39 [00:01<00:49,  1.29s/it][A
  5%|▌         | 2/39 [00:02<00:46,  1.24s/it][A
  8%|▊         | 3/39 [00:03<00:42,  1.17s/it][A
 10%|█         | 4/39 [00:04<00:39,  1.14s/it][A
 13%|█▎        | 5/39 [00:05<00:38,  1.12s/it][A
 15%|█▌        | 6/39 [00:06<00:36,  1.11s/it][A
 18%|█▊        | 7/39 [00:07<00:35,  1.11s/it][A
 21%|██        | 8/39 [00:09<00:34,  1.10s/it][A
 23%|██▎       | 9/39 [00:10<00:32,  1.10s/it][A
 26%|██▌       | 10/39 [00:11<00:32,  1.14s/it][A
 28%|██▊       | 11/39 [00:12<00:31,  1.12s/it][A
 31%|███       | 12/39 [00:13<00:30,  1.11s/it][A
 33%|███▎      | 13/39 [00:14<00:28,  1.11s/it][A
 36%|███▌      | 14/39 [00:15<00:27,  1.10s/it][A
 38%|███▊      | 15/39 [00:16<00:26,  1.10s/it][A
 41%|████      | 16/39 [00:17<00:25,  1.10s/it][A
 44%|████▎     | 17/39 [00:18<00:24,  1.10s/it][A
 46%|████▌     | 18/39 [00:20<00:22,  1.09s/it][A
 49%|████

IndexError: list index out of range