In [None]:
import os.path as osp
from math import ceil

import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
import torch_geometric.transforms as T
from torch_geometric.data import DenseDataLoader, Data
from torch_geometric.nn import DenseSAGEConv, dense_diff_pool

from utils.loaders import load_data, get_onehots
from utils.evaluation_metrics import SRR, auprc_auroc_ap

import numpy as np
from sklearn.utils import shuffle

## TODO: build the DiffPool encoder

In [None]:

max_nodes = 150

# Seed for torch's RNG, which dataset.shuffle() draws from; fixing it makes the
# train/val/test split (and later torch randomness) reproducible across runs.
PROTEINS_SPLIT_SEED = 12345


class MyFilter(object):
    """Pre-filter that keeps only graphs with at most ``max_nodes`` nodes."""

    def __call__(self, data):
        return data.num_nodes <= max_nodes


# NOTE: '__file__' is a string literal here, so realpath() resolves it inside the
# current working directory; the data path is therefore relative to the CWD.
path = osp.join(osp.dirname(osp.realpath('__file__')), '..', 'data',
                'PROTEINS_dense')
dataset = TUDataset(path, name='PROTEINS', transform=T.ToDense(max_nodes),
                    pre_filter=MyFilter())
torch.manual_seed(PROTEINS_SPLIT_SEED)
dataset = dataset.shuffle()
# n = ceil(len / 10): first tenth is test, second tenth validation, rest training.
n = (len(dataset) + 9) // 10
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]
test_loader = DenseDataLoader(test_dataset, batch_size=20)
val_loader = DenseDataLoader(val_dataset, batch_size=20)
train_loader = DenseDataLoader(train_dataset, batch_size=20)

In [None]:
from utils.path_manage import get_files

# Load the triples, look-up tables and the per-drug / per-protein graph lists.
(data, lookup, ASD_dictionary, BCE_dictionary, Edge_list, Edge_features,
 Drug_graph_list, Protein_graph_list) = get_files()
# Half the length of `lookup`, rounded down.
entities = len(lookup) // 2

In [None]:
# Triples whose graphs exceed these sizes are filtered out, and the remaining graphs
# are zero-padded to exactly these node counts by get_adj_mask.
max_protein_nodes, max_drug_nodes = 150, 150

In [None]:
# Look-up tables from drug / protein id to its graph object.
# NOTE(review): each id is paired with the graph at the same position in
# Drug_graph_list / Protein_graph_list, but list(set(...)) has no guaranteed order
# (if the ids are str it can even differ between interpreter sessions). This is only
# correct if get_files() built those lists in this same order — TODO confirm.
Drug_list = list(set(data[:, 0]))
Protein_list = list(set(data[:, 2]))

Drug_graph_dict = dict(zip(Drug_list, Drug_graph_list))
Protein_graph_dict = dict(zip(Protein_list, Protein_graph_list))

# Drop triples whose drug (column 0) or protein (column 2) maps to a str instead of a
# graph object, then triples whose graphs are larger than the padding sizes.
filtered_data = [row for row in data
                 if not isinstance(Drug_graph_dict[row[0]], str)]
filtered_data = [row for row in filtered_data
                 if not isinstance(Protein_graph_dict[row[2]], str)]
filtered_data = [row for row in filtered_data
                 if Drug_graph_dict[row[0]].num_nodes <= max_drug_nodes]
filtered_data = [row for row in filtered_data
                 if Protein_graph_dict[row[2]].num_nodes <= max_protein_nodes]


In [None]:
# Turn the list of surviving rows back into one array, one row per triple.
filtered_data = np.stack(filtered_data, axis=0)

In [None]:
# Distinct protein ids (column 2) left after filtering, as an int64 tensor.
# (The misspelt name is kept in case other cells refer to it.)
protien_ids = torch.LongTensor(list(set(filtered_data[:, 2])))

In [None]:
number_of_batches = 5
number_of_epochs = 20
# Seed for the row shuffle so the subset below is identical on every run.
SUBSET_SEED = 0
x = shuffle(filtered_data, random_state=SUBSET_SEED)
# Small subset for quick experiments. This rebinds `dataset`, which previously held the
# PROTEINS TUDataset loaded at the top of the notebook.
dataset = x[:50]

In [None]:
# n = ceil(len / 10): the first tenth is the test split, the second tenth validation,
# and everything after that is training data.
n = ceil(len(dataset) / 10)
test_dataset, val_dataset, train_dataset = dataset[:n], dataset[n:2 * n], dataset[2 * n:]


In [None]:
# Display the held-out test rows (the first slice of the shuffled filtered_data subset).
test_dataset

In [None]:
def get_adj_mask(max_nodes, graph, dtype=torch.float64):
    """Pad one graph into the dense (x, adj, mask) layout consumed by DenseDataLoader.

    Args:
        max_nodes: number of nodes every graph is padded to; must be at least
            ``graph.num_nodes``.
        graph: graph object with ``num_nodes``, a node-feature tensor ``x`` of shape
            (num_nodes, num_features) and an ``edge_index`` of shape (2, num_edges).
        dtype: floating dtype of the returned ``x`` and ``adj`` (float64 by default).

    Returns:
        ``Data(x=..., adj=..., mask=...)`` where
        ``x`` is (max_nodes, num_features) with zero rows for padding,
        ``adj`` is a symmetric (max_nodes, max_nodes) 0/1 matrix, and
        ``mask`` is a (max_nodes,) bool vector that is True for real nodes. This matches
        the layout ``T.ToDense`` produces, so DenseDataLoader stacks it to
        (batch, max_nodes) as DenseSAGEConv / dense_diff_pool expect.

    Raises:
        ValueError: if the graph has more than ``max_nodes`` nodes.
    """
    num_nodes = graph.num_nodes
    if num_nodes > max_nodes:
        raise ValueError('graph has {} nodes, more than max_nodes={}'.format(
            num_nodes, max_nodes))

    features = graph.x
    num_features = features.shape[1]

    # 1-D node mask. A (max_nodes, max_nodes) matrix filled with
    # mask[0:num_nodes][0:num_nodes] = True marks whole rows (chained slicing slices the
    # rows twice) and has one dimension too many for the dense PyG layers.
    mask = torch.zeros(max_nodes, dtype=torch.bool)
    mask[:num_nodes] = True

    padding = features.new_zeros((max_nodes - num_nodes, num_features))
    nodes = torch.cat([features, padding]).to(dtype)

    # Undirected 0/1 adjacency, written for all edges at once.
    # TODO: edge weights are ignored; every edge is stored as 1.
    edge_index = torch.as_tensor(graph.edge_index, dtype=torch.long)
    adjacency = torch.zeros((max_nodes, max_nodes), dtype=dtype)
    adjacency[edge_index[0], edge_index[1]] = 1
    adjacency[edge_index[1], edge_index[0]] = 1

    return Data(x=nodes, adj=adjacency, mask=mask)
    

In [None]:
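# TODO: import the drug / protein graph lists here (Drug_graph_list and Protein_graph_list currently come from get_files() earlier in the notebook).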

In [None]:
# Pad the protein and drug graph of every test triple, then batch them 5 at a time.
# No shuffle argument is passed, so the i-th protein batch lines up with the i-th drug batch.
# NOTE(review): `relations` covers the whole test split rather than one batch of 5, so it
# only lines up with a batch while the test split has at most 5 rows.
padded_protein_graphs = [get_adj_mask(max_protein_nodes, Protein_graph_dict[row[2]])
                         for row in test_dataset]
padded_drug_graphs = [get_adj_mask(max_drug_nodes, Drug_graph_dict[row[0]])
                      for row in test_dataset]
protein_batch = DenseDataLoader(padded_protein_graphs, batch_size=5)
drug_batch = DenseDataLoader(padded_drug_graphs, batch_size=5)
relations = torch.LongTensor(test_dataset[:, 1])

In [None]:
# Pad the protein graph of the first test triple on its own, for inspection.
first_test_protein_graph = Protein_graph_dict[test_dataset[0, 2]]
testo = get_adj_mask(max_protein_nodes, first_test_protein_graph)

In [None]:
# Dtype of the padded node features; get_adj_mask produces float64 here.
testo.x.dtype

In [None]:
from models.DiffPool import *

In [None]:
# Encoder sizes come from the same constants used to filter and pad the graphs, so the
# model and the padded data cannot drift apart. When run top to bottom, Diff_Pool_Encoder
# here is the one brought in by `from models.DiffPool import *` in the cell above.
protein_encoder = Diff_Pool_Encoder(max_nodes=max_protein_nodes)
drug_encoder = Diff_Pool_Encoder(max_nodes=max_drug_nodes)
decoder = DistMult_Decoder()

In [None]:
# Use the GPU when torch can see one, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
model = Encoder_Decoder(
    protein_encoder=protein_encoder,
    drug_encoder=drug_encoder,
    decoder=decoder,
    num_relationships=4,
)
# Move to the chosen device and cast all floating parameters to float32.
model = model.to(device).float()

In [None]:

# Run the model on each (protein batch, drug batch) pair of padded test graphs.
# Inputs are cast to float32 to match the model's weights (cast with .float() when the
# model was built). Feeding .double() tensors raised "Expected object of scalar type Double
# but got scalar type Float". Inputs are also moved to the model's device.
for p, d in zip(protein_batch, drug_batch):
    prediction = model(
        rel=relations.to(device),
        d_graph=d.x.float().to(device),
        d_adj=d.adj.float().to(device),
        d_mask=d.mask.float().to(device),
        p_graph=p.x.float().to(device),
        p_adj=p.adj.float().to(device),
        p_mask=p.mask.float().to(device),
    )

In [None]:
# Inspect the padded node features of each protein / drug batch pair
# (p_adj / d_adj, iterated here before, are not defined at notebook level).
for proteins, drugs in zip(protein_batch, drug_batch):
    print(proteins.x, drugs.x)

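## Error from the forward-pass loop over `protein_batch` / `drug_batch`: `Expected object of scalar type Double but got scalar type Float for argument #3 'mat2' in call to _th_addmm_out` — the inputs were cast to float64 with `.double()` while the model's weights were cast to float32 with `.float()`.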

In [None]:

# Set-up for the training loop below: device, default Encoder_Decoder, Adam optimiser.
# NOTE(review): train()/test() below call model(x, adj, mask), which does not match
# Encoder_Decoder.forward's signature — TODO adapt.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = Encoder_Decoder().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)


def train(epoch):
    """Run one optimisation epoch over ``train_loader`` and return the mean sample loss.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and ``device``.
    ``epoch`` is accepted for the caller's convenience but is not used.

    NOTE(review): this still follows the graph-classification template: it calls
    ``model(x, adj, mask)`` and unpacks three outputs, which does not match
    ``Encoder_Decoder.forward``. TODO adapt to drug/protein batches.
    """
    model.train()
    loss_all = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output, _, _ = model(data.x, data.adj, data.mask)
        loss = F.nll_loss(output, data.y.view(-1))
        loss.backward()
        loss_all += data.y.size(0) * loss.item()
        optimizer.step()
    # Normalise by the samples the loader actually iterated. The global `train_dataset`
    # was reassigned to the drug/protein triple split and no longer matches train_loader.
    return loss_all / len(train_loader.dataset)


@torch.no_grad()
def test(loader):
    """Return the accuracy of the global ``model`` over ``loader``.

    Each batch is moved to the global ``device``. The first element of the model's
    output is treated as per-class scores, and its arg-max is compared with ``y``.
    Runs without autograd and with the model in eval mode.
    """
    model.eval()
    hits = 0
    for batch in loader:
        batch = batch.to(device)
        scores = model(batch.x, batch.adj, batch.mask)[0]
        _, predicted = scores.max(dim=1)
        hits += predicted.eq(batch.y.view(-1)).sum().item()
    return hits / len(loader.dataset)


# Train for a fixed number of epochs. The reported test accuracy is the one measured at
# the epoch with the best validation accuracy so far.
NUM_EPOCHS = 150
EPOCH_LOG = ('Epoch: {:03d}, Train Loss: {:.7f}, '
             'Val Acc: {:.7f}, Test Acc: {:.7f}')

best_val_acc, test_acc = 0, 0
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = train(epoch)
    val_acc = test(val_loader)
    improved = val_acc > best_val_acc
    if improved:
        test_acc = test(test_loader)
        best_val_acc = val_acc
    print(EPOCH_LOG.format(epoch, train_loss, val_acc, test_acc))

## Old DiffPool model code below (graph-classification setup on the PROTEINS dataset)

In [None]:
class GNN(torch.nn.Module):
    """Three DenseSAGEConv layers, each followed by ReLU and BatchNorm.

    The three layer outputs are concatenated on the feature axis, giving
    ``2 * hidden_channels + out_channels`` features per node. When ``lin`` is exactly
    ``True``, a final Linear + ReLU maps that back to ``out_channels``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 normalize=False, add_loop=False, lin=True):
        super(GNN, self).__init__()

        # Passed positionally to every conv call in forward().
        self.add_loop = add_loop

        # (input width, output width) for layers 1..3. They are registered as conv{i} then
        # bn{i}, in layer order, so parameter order and initialisation follow the layers.
        layer_widths = (
            (in_channels, hidden_channels),
            (hidden_channels, hidden_channels),
            (hidden_channels, out_channels),
        )
        for idx, (width_in, width_out) in enumerate(layer_widths, start=1):
            setattr(self, 'conv{}'.format(idx),
                    DenseSAGEConv(width_in, width_out, normalize))
            setattr(self, 'bn{}'.format(idx), torch.nn.BatchNorm1d(width_out))

        concat_width = 2 * hidden_channels + out_channels
        self.lin = torch.nn.Linear(concat_width, out_channels) if lin is True else None

    def bn(self, i, x):
        """Apply BatchNorm layer ``bn{i}`` to a (batch, nodes, channels) tensor."""
        batch_size, num_nodes, num_channels = x.size()
        norm = getattr(self, 'bn{}'.format(i))
        flat = norm(x.view(-1, num_channels))
        return flat.view(batch_size, num_nodes, num_channels)

    def forward(self, x, adj, mask=None):
        """Return concatenated (optionally projected) node embeddings for dense graphs."""
        # Unpacking rejects inputs that are not (batch, nodes, channels).
        _batch, _nodes, _channels = x.size()

        layer_outputs = []
        hidden = x
        for idx in (1, 2, 3):
            conv = getattr(self, 'conv{}'.format(idx))
            hidden = self.bn(idx, F.relu(conv(hidden, adj, mask, self.add_loop)))
            layer_outputs.append(hidden)

        x = torch.cat(layer_outputs, dim=-1)
        if self.lin is not None:
            x = F.relu(self.lin(x))
        return x


In [None]:
class Diff_Pool_Encoder(torch.nn.Module):
    """Two-level DiffPool encoder that returns one embedding per graph.

    Level 1 pools the padded input graph to ``ceil(0.25 * max_nodes)`` clusters.
    Level 2 pools again to a quarter of that, and a final GNN embeds the coarsened
    graph. The node embeddings are then averaged over the node axis.

    Args:
        max_nodes: number of (padded) nodes per input graph. Defaults to 150, the value
            of the notebook's ``max_nodes`` constant.
        in_channels: node-feature width of the input graphs (3 was the hard-coded value).
        hidden_channels: width of every GNN layer. The embedding GNNs use ``lin=False``
            and therefore output ``3 * hidden_channels`` features.

    After each forward pass, ``link_loss`` and ``ent_loss`` hold the summed DiffPool
    link-prediction and entropy regularisers, which callers may add to their loss.
    """

    def __init__(self, max_nodes=150, in_channels=3, hidden_channels=64):
        # super() must name this class; the previous `super(Net, self)` raised NameError.
        super(Diff_Pool_Encoder, self).__init__()

        # GNN(lin=False) concatenates its three layer outputs.
        embed_width = 3 * hidden_channels

        num_nodes = ceil(0.25 * max_nodes)
        self.gnn1_pool = GNN(in_channels, hidden_channels, num_nodes, add_loop=True)
        self.gnn1_embed = GNN(in_channels, hidden_channels, hidden_channels,
                              add_loop=True, lin=False)

        num_nodes = ceil(0.25 * num_nodes)
        self.gnn2_pool = GNN(embed_width, hidden_channels, num_nodes)
        self.gnn2_embed = GNN(embed_width, hidden_channels, hidden_channels, lin=False)
        self.gnn3_embed = GNN(embed_width, hidden_channels, hidden_channels, lin=False)

        # DiffPool auxiliary losses from the most recent forward pass.
        self.link_loss = None
        self.ent_loss = None

    def forward(self, x, adj, mask=None):
        """Encode a batch of dense graphs.

        Args:
            x: node features, presumably (batch, max_nodes, in_channels).
            adj: dense adjacency, presumably (batch, max_nodes, max_nodes).
            mask: optional (batch, max_nodes) node mask for the padded input.

        Returns:
            (batch, 3 * hidden_channels) graph embeddings (mean over pooled nodes).
        """
        s = self.gnn1_pool(x, adj, mask)
        x = self.gnn1_embed(x, adj, mask)
        x, adj, l1, e1 = dense_diff_pool(x, adj, s, mask)

        # The pooled graphs have a fixed number of clusters, so no mask is passed from here.
        s = self.gnn2_pool(x, adj)
        x = self.gnn2_embed(x, adj)
        x, adj, l2, e2 = dense_diff_pool(x, adj, s)

        x = self.gnn3_embed(x, adj)

        self.link_loss = l1 + l2
        self.ent_loss = e1 + e2
        return x.mean(dim=1)



In [None]:
class DistMult_Decoder(torch.nn.Module):
    """DistMult-style scorer between drug, relation and protein embeddings.

    For each row ``i`` it forms ``drug_i * rel_i`` (element-wise) and takes the dot
    product with *every* protein embedding, giving a (num_drugs, num_proteins) score
    matrix. Dropout is applied to all three inputs.

    Args:
        args: unused; kept for signature compatibility.
        dropout: dropout probability applied to each embedding.
    """

    def __init__(
        self, args=None, dropout=0.05,
    ):
        super(DistMult_Decoder, self).__init__()
        self.inp_drop = torch.nn.Dropout(dropout)

    def forward(self, protein_embedded, drug_embedded, rel_embedded):
        """Score every drug/relation row against every protein.

        Args:
            protein_embedded: (num_proteins, dim) protein embeddings.
            drug_embedded: (num_drugs, dim) drug embeddings.
            rel_embedded: relation embeddings broadcastable against ``drug_embedded``,
                typically (num_drugs, dim).

        Returns:
            (num_drugs, num_proteins) tensor of raw scores.
        """
        drug_embedded = self.inp_drop(drug_embedded)
        protein_embedded = self.inp_drop(protein_embedded)
        rel_embedded = self.inp_drop(rel_embedded)

        # No debug shape printing here: it would fire on every forward pass.
        return torch.mm(drug_embedded * rel_embedded, protein_embedded.t())

In [None]:
class Encoder_Decoder(torch.nn.Module):
    """Encode drug and protein graphs, embed relation ids, and score them with a decoder.

    Args:
        protein_encoder: module mapping (x, adj, mask) of protein graphs to embeddings.
            Defaults to a fresh ``Diff_Pool_Encoder()`` per instance.
        drug_encoder: the same for drug graphs; defaults to a fresh ``Diff_Pool_Encoder()``.
        decoder: module called as ``decoder(protein_emb, drug_emb, rel_emb)``.
            Defaults to a fresh ``DistMult_Decoder()``.
        num_relationships: number of relation ids (rows of the relation embedding).
        embedding_dim: width of the relation embedding. It must match the encoders'
            output width (3 * 64 = 192 for the default Diff_Pool_Encoder).
    """

    def __init__(self, protein_encoder=None, drug_encoder=None, decoder=None,
                 num_relationships=2, embedding_dim=64 * 3):
        super(Encoder_Decoder, self).__init__()

        # Defaults are built per instance. Module instances used as default arguments are
        # created once, when the class is defined, and shared by every Encoder_Decoder.
        self.protein_encoder = (protein_encoder if protein_encoder is not None
                                else Diff_Pool_Encoder())
        self.drug_encoder = (drug_encoder if drug_encoder is not None
                             else Diff_Pool_Encoder())
        self.decoder = decoder if decoder is not None else DistMult_Decoder()

        # NOTE(review): padding_idx=0 pins relation id 0 to a zero vector with no gradient.
        # Confirm that 0 is not a real relation id in the data.
        self.emb_rel = torch.nn.Embedding(num_relationships, embedding_dim=embedding_dim,
                                          padding_idx=0)

    def init(self):
        """Re-initialise the relation embedding weights with Xavier-normal values."""
        torch.nn.init.xavier_normal_(self.emb_rel.weight.data)

    def forward(self, rel, d_graph, d_adj, d_mask, p_graph, p_adj, p_mask):
        """Score drug/relation/protein combinations.

        Args:
            rel: relation ids for the batch.
            d_graph, d_adj, d_mask: dense drug-graph features, adjacency and node mask.
            p_graph, p_adj, p_mask: dense protein-graph features, adjacency and node mask.

        Returns:
            Whatever ``self.decoder`` returns for the encoded inputs.
        """
        rel_embedded = self.emb_rel(rel).squeeze()

        drug_embedded = self.drug_encoder(d_graph, d_adj, d_mask)
        # Proteins are encoded from the protein inputs (p_*), not the drug inputs.
        protein_embedded = self.protein_encoder(p_graph, p_adj, p_mask)

        return self.decoder(protein_embedded, drug_embedded, rel_embedded)



In [None]:
# Default Encoder_Decoder on the GPU when available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
model = Encoder_Decoder()
model = model.to(device)

In [None]:
# Smoke test on PROTEINS batches: each graph is fed as both the "drug" and the "protein",
# with the graph label y standing in for the relation id. The loop variable is `batch`
# so the global `data` (the triples from get_files) is not overwritten. Each batch is
# moved to the model's device.
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        model(rel=batch.y, d_graph=batch.x, d_adj=batch.adj, d_mask=batch.mask,
              p_graph=batch.x, p_adj=batch.adj, p_mask=batch.mask)

In [None]:

# Set-up for the old graph-classification training loop below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# FIXME: `Net` is not defined in any visible cell (Diff_Pool_Encoder's __init__ still
# refers to it). Unless `models.DiffPool` exports it, this raises NameError.
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train(epoch):
    """Run one optimisation epoch over ``train_loader`` and return the mean sample loss.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and ``device``, and
    expects ``model(x, adj, mask)`` to return (log-probabilities, extra, extra).
    ``epoch`` is accepted for the caller's convenience but is not used.
    NOTE(review): this redefines the earlier ``train`` with the same body.
    """
    model.train()
    loss_all = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output, _, _ = model(data.x, data.adj, data.mask)
        loss = F.nll_loss(output, data.y.view(-1))
        loss.backward()
        loss_all += data.y.size(0) * loss.item()
        optimizer.step()
    # Normalise by the samples the loader actually iterated. The global `train_dataset`
    # was reassigned to the drug/protein triple split and no longer matches train_loader.
    return loss_all / len(train_loader.dataset)


@torch.no_grad()
def test(loader):
    """Return the fraction of graphs in ``loader`` whose predicted class equals ``y``.

    Uses the global ``model`` (put in eval mode) and ``device``. Prediction is the
    arg-max over the first element of the model's output.
    """
    model.eval()
    correct_total = 0
    for graphs in loader:
        graphs = graphs.to(device)
        class_scores = model(graphs.x, graphs.adj, graphs.mask)[0]
        predicted = class_scores.max(dim=1)[1]
        labels = graphs.y.view(-1)
        correct_total += predicted.eq(labels).sum().item()
    return correct_total / len(loader.dataset)


# Old graph-classification loop: test accuracy is re-measured only when the validation
# accuracy improves on the best seen so far.
OLD_NUM_EPOCHS = 150
OLD_EPOCH_LOG = ('Epoch: {:03d}, Train Loss: {:.7f}, '
                 'Val Acc: {:.7f}, Test Acc: {:.7f}')

best_val_acc = 0
test_acc = 0
for epoch in range(1, OLD_NUM_EPOCHS + 1):
    train_loss = train(epoch)
    val_acc = test(val_loader)
    if val_acc > best_val_acc:
        test_acc = test(test_loader)
        best_val_acc = val_acc
    print(OLD_EPOCH_LOG.format(epoch, train_loss, val_acc, test_acc))