In [255]:
import torch
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.nn.inits import glorot, zeros


# optimized version of the standard GCN model
class MyGCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels, edge_index, num_nodes, improved=False, cached=False,
                 bias=True, normalize=True, **kwargs):
        super(MyGCNConv, self).__init__(aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.normalize = normalize

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

        self.num_nodes = num_nodes
        self.edge_index = edge_index
        self.edge_index, self.norm = self.norm(edge_index, self.num_nodes, improved=True)

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)
        self.cached_result = None
        self.cached_num_edges = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight=None, improved=False,
             dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)

        fill_value = 1 if not improved else 2
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_weight=None):
        """"""
        x = torch.matmul(x, self.weight)
        return self.propagate(self.edge_index, x=x, norm=self.norm)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j if norm is not None else x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
    

In [7]:
# read protein interactions

import networkx as nx
import csv

G=nx.Graph()

node2id = {}
count = 0

file = open("string/9606.protein.actions.v11.0.txt")
reader = csv.reader(file, delimiter='\t')
firstline = True
for column in reader:
    if firstline:
        firstline = False
    else:
        n1 = -1
        n2 = -1
        if column[0] in node2id:
            n1 = node2id[column[0]]
        else:
            node2id[column[0]] = count
            n1 = count
            count += 1

        if column[1] in node2id:
            n2 = node2id[column[1]]
        else:
            node2id[column[1]] = count
            n2 = count
            count += 1

        G.add_node(n1)
        G.add_node(n2)
        score = int(column[6])
        if score>=700:
            G.add_edge(n1, n2)


In [8]:
# read map of gene ids to protein ids, map to nodes
file = open("string/9606.protein.aliases.v11.0.txt")
reader = csv.reader(file, delimiter='\t')
firstline = True
gene2protein = {}
for col in reader:
    if firstline:
        firstline = False
    else:
        prot = col[0]
        gene = col[1]
        if gene.find("ENSG")>-1 and prot in node2id:
            gene2protein[gene] = prot


In [257]:
# generate PyTorch datastructure from networkx

import torch_geometric.utils as util
import gzip
import csv
import os
import torch

network = util.from_networkx(G)



In [302]:
# Read the cancer data -> read in tensor of shape NumSamples x NumNodes x NumFeatures [NumFeatures == 1]

from torch_geometric.data import Data

dataset = []
dataset_plain = []
count = 0
for i in [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser("cancer/")) for f in fn]:
    if str(i).find("FPKM-UQ") > -1 and count < 20:
        count += 1
        file = gzip.open(i, mode="rt")
        csvobj = csv.reader(file, delimiter = '\t')
        localX = torch.zeros(network.num_nodes,1)
        for line in csvobj:
            gene = line[0].split(".")[0]
            if gene in gene2protein:
                protein = gene2protein[gene]
                nodeid = node2id[protein]
                evalue = float(line[1])
                #print(gene + " is not missing: " + protein + " with id " + str(nodeid))
                localX[nodeid] = evalue
            else:
                continue
        data = Data(x = localX, y = localX)
        dataset.append(data)
        dataset_plain.append(localX)

# normalize
maxval = torch.max(torch.stack(dataset_plain))
dataset_plain = [t / maxval for t in dataset_plain]


In [303]:
# define the model layout -> use custom GCN model defined above

import torch
import torch.nn.functional as F

class Net(torch.nn.Module):
    def __init__(self, edge_index, num_nodes):
        super(Net, self).__init__()
        self.conv1 = MyGCNConv(1, 8, edge_index, num_nodes, node_dim = 1)
        self.conv2 = MyGCNConv(8, 1, edge_index, num_nodes, node_dim = 1 )

    def forward(self, data):
        x = data
        x = self.conv1(x)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x)
        out = torch.sigmoid(x)
        return out


In [306]:
from torch.utils.data import DataLoader

BATCHSIZE = 10

trainmask = torch.zeros(BATCHSIZE, data.num_nodes).bool().random_(0, 10)
testmask = torch.ones(BATCHSIZE, data.num_nodes).bool()
validationmask = torch.ones(BATCHSIZE, data.num_nodes).bool()
for i in range(0, len(data) - 1):
    for j in range(0, data.num_nodes):
        if trainmask[i][j]:
            testmask[i][j] = False
        else:
            testmask[i][j] = True

loader = DataLoader(dataset_plain, batch_size = BATCHSIZE, shuffle = True)

model = Net(network.edge_index, network.num_nodes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

for epoch in range(10):
    for batch in loader:
        model.train()
        optimizer.zero_grad()
        out = model(batch)
        loss = F.mse_loss(out[trainmask], batch[trainmask])
        loss.backward()
        optimizer.step()
    if epoch % 2 == 0:
        out = model(batch)
        testloss = F.mse_loss(out[testmask], batch[testmask])
        print("Epoch " + str(epoch) + ": " + str(testloss))


Epoch 0: tensor(0.2438, grad_fn=<MseLossBackward>)


Epoch 2: tensor(0.2310, grad_fn=<MseLossBackward>)


Epoch 4: tensor(0.2145, grad_fn=<MseLossBackward>)


Epoch 6: tensor(0.1968, grad_fn=<MseLossBackward>)


Epoch 8: tensor(0.1785, grad_fn=<MseLossBackward>)
