In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch_geometric import data as DATA
from torch_geometric.nn import GCNConv, global_max_pool

In [2]:
# now we have a new dataset, which is a list of GCN_DATA convert into dataloader
dataset = torch.load('preprocessed_dataset.pt')
train_dataset, test_dataset = train_test_split(
    dataset, test_size=0.2, random_state=42)
train_dataset, val_dataset = train_test_split(
    train_dataset, test_size=0.2, random_state=42)
train_loader = DATA.DataLoader(
    train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader = DATA.DataLoader(val_dataset,  shuffle=True, drop_last=True)
test_loader = DATA.DataLoader(test_dataset,  shuffle=True, drop_last=True)




In [3]:
def Model():
    class GCNNet(torch.nn.Module):
        def __init__(self, n_output=2, n_filters=32, embed_dim=128, num_features_xd=78, num_features_xt=25, output_dim=128, dropout=0.2):

            super(GCNNet, self).__init__()

            self.n_output = n_output
            self.conv1 = GCNConv(num_features_xd, num_features_xd)
            self.conv2 = GCNConv(num_features_xd, num_features_xd*2)
            self.conv3 = GCNConv(num_features_xd*2, num_features_xd * 4)
            self.fc_g1 = torch.nn.Linear(num_features_xd*4, 1024)
            self.fc_g2 = torch.nn.Linear(1024, output_dim)
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(dropout)

            self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim)
            self.conv_xt_1 = nn.Conv1d(
                in_channels=1000, out_channels=n_filters, kernel_size=8)
            self.fc1_xt = nn.Linear(32*121, output_dim)

            self.fc1 = nn.Linear(2*output_dim, 1024)
            self.fc2 = nn.Linear(1024, 512)
            self.out = nn.Linear(512, self.n_output)

        def forward(self, data):
            # get graph input
            x, edge_index, batch = data.x, data.edge_index, data.batch
            # get protein input
            target = data.target

            x = self.conv1(x, edge_index)
            x = self.relu(x)

            x = self.conv2(x, edge_index)
            x = self.relu(x)

            x = self.conv3(x, edge_index)
            x = self.relu(x)
            x = global_max_pool(x, batch)       # global max pooling

            # flatten
            x = self.relu(self.fc_g1(x))
            x = self.dropout(x)
            x = self.fc_g2(x)
            x = self.dropout(x)

            # 1d conv layers
            embedded_xt = self.embedding_xt(target)
            conv_xt = self.conv_xt_1(embedded_xt)
            # flatten
            xt = conv_xt.view(-1, 32 * 121)
            xt = self.fc1_xt(xt)

            # concat
            xc = torch.cat((x, xt), 1)
            # add some dense layers
            xc = self.fc1(xc)
            xc = self.relu(xc)
            xc = self.dropout(xc)
            xc = self.fc2(xc)
            xc = self.relu(xc)
            xc = self.dropout(xc)
            out = self.out(xc)
            return out
    model = GCNNet()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_function = nn.CrossEntropyLoss()
    return model, optimizer, loss_function


In [4]:
def TrainTheModel(num_epochs=20):
    model, optimizer, loss_function = Model()
    train_loss = torch.zeros(num_epochs)
    valid_loss = torch.zeros(num_epochs)
    train_acc = torch.zeros(num_epochs)
    valid_acc = torch.zeros(num_epochs)
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, data in enumerate(train_loader):
            optimizer.zero_grad()
            out = model(data)
            loss = loss_function(out, data.y)
            loss.backward()
            optimizer.step()
            train_loss[epoch] += loss.item()
            pred = out.max(1)[1]
            train_acc[epoch] += pred.eq(data.y).sum().item()
        train_loss[epoch] /= len(train_loader.dataset)
        train_acc[epoch] /= len(train_loader.dataset)

        model.eval()
        for data in val_loader:
            out = model(data)
            loss = loss_function(out, data.y)
            valid_loss[epoch] += loss.item()
            pred = out.max(1)[1]
            valid_acc[epoch] += pred.eq(data.y).sum().item()
        valid_loss[epoch] /= len(val_loader.dataset)
        valid_acc[epoch] /= len(val_loader.dataset)
        print('Epoch: {:03d}, Train Loss: {:.5f}, Train Acc: {:.5f}, Val Loss: {:.5f}, Val Acc: {:.5f}'.format(
            epoch, train_loss[epoch], train_acc[epoch], valid_loss[epoch], valid_acc[epoch]))
    return model, train_loss, valid_loss, train_acc, valid_acc


In [5]:
model, train_loss, valid_loss, train_acc, valid_acc = TrainTheModel()

Epoch: 000, Train Loss: 0.01320, Train Acc: 0.82120, Val Loss: 0.39597, Val Acc: 0.82164
Epoch: 001, Train Loss: 0.01107, Train Acc: 0.82425, Val Loss: 0.40523, Val Acc: 0.82164
Epoch: 002, Train Loss: 0.01173, Train Acc: 0.81733, Val Loss: 0.44951, Val Acc: 0.76662
Epoch: 003, Train Loss: 0.01133, Train Acc: 0.82564, Val Loss: 0.49699, Val Acc: 0.77880
Epoch: 004, Train Loss: 0.01063, Train Acc: 0.83386, Val Loss: 0.45005, Val Acc: 0.82312
Epoch: 005, Train Loss: 0.01013, Train Acc: 0.84226, Val Loss: 0.42727, Val Acc: 0.80355
Epoch: 006, Train Loss: 0.00979, Train Acc: 0.85593, Val Loss: 0.50662, Val Acc: 0.84232
Epoch: 007, Train Loss: 0.00928, Train Acc: 0.86950, Val Loss: 0.53000, Val Acc: 0.84564
Epoch: 008, Train Loss: 0.00917, Train Acc: 0.87107, Val Loss: 0.57366, Val Acc: 0.82755
Epoch: 009, Train Loss: 0.00893, Train Acc: 0.87551, Val Loss: 0.46740, Val Acc: 0.83936
Epoch: 010, Train Loss: 0.00859, Train Acc: 0.87726, Val Loss: 0.50132, Val Acc: 0.84343
Epoch: 011, Train Los

In [6]:
from sklearn.metrics import precision_score, recall_score

model.eval()
test_acc = 0
precision = 0
specificty = 0
for data in test_loader:
    out = model(data)
    pred = out.max(1)[1]
    test_acc += pred.eq(data.y).sum().item()
    precision += precision_score(data.y, pred)
    specificty += recall_score(data.y, pred)
precision /= len(test_loader.dataset)
specificty /= len(test_loader.dataset)
test_acc /= len(test_loader.dataset)
print('Precision: {:.5f}'.format(precision))
print('Specificty: {:.5f}'.format(specificty))
print('Test Accuracy: {:.5f}'.format(test_acc))

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Precision: 0.80851
Specificty: 0.80851
Test Accuracy: 0.85875


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
