In [61]:
import os

import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from Bio.SeqIO.FastaIO import SimpleFastaParser
from torch import nn
from torch.utils.data import Dataset

# Directory holding the FASTA, HDF5 and annotation files. It defaults to the
# original local path but can be redirected with the PP2CS_ROOT environment
# variable, so the notebook runs on other machines without editing code.
# os.path.join(..., '') guarantees a trailing separator, because the rest of
# the notebook builds file paths with plain string concatenation (root + name).
root = os.path.join(
    os.environ.get('PP2CS_ROOT', '/media/johannes/Crucial SSD/PP2CS/'), ''
)

# FASTA file name for each data split.
seq_files = {
    'test': 'ec_vs_NOec_pide20_c50_test.fasta',
    'train': 'ec_vs_NOec_pide20_c50_train.fasta',
    'val': 'ec_vs_NOec_pide20_c50_val.fasta'
}
data_sets = ['train', 'test', 'val']

# Annotation table, read relative to `root`.
ec_annotation_file = 'merged_anno.txt'

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [78]:
class ProteinDataset(Dataset):
    """Embeddings from an HDF5 file paired with a binary 'EC' / 'NC' label.

    For the split ``ds`` the constructor reads three files below ``root``:

    * ``ec_vs_NOec_pide20_c50_<ds>.h5`` -- one HDF5 entry per key; the key
      order defines the integer indexing of the dataset.
    * the FASTA file ``seq_files[ds]`` -- sequence and length per record title.
    * ``ec_annotation_file`` -- annotation table left-joined on the record
      title; rows without an annotation are flagged ``'NC'``, the rest ``'EC'``.

    Attributes:
        path: Path of the HDF5 file for this split.
        keys: List of HDF5 keys; ``keys[i]`` is the entry returned for index i.
        key_map: ``{i: keys[i]}``; kept for backward compatibility.
        data_frame: Frame indexed by FASTA title with the columns
            ``sequence``, ``length``, ``EC`` and ``ec_or_nc``.
    """

    def __init__(self, ds="train"):
        """Load keys, sequences and annotations for the split ``ds``.

        Args:
            ds: Split name; must be a key of ``seq_files``.

        Raises:
            KeyError: If ``ds`` is not a key of ``seq_files``.
            OSError: If one of the input files cannot be opened.
        """
        self.path = root + 'ec_vs_NOec_pide20_c50_' + ds + '.h5'
        # Only the key names are read here; the file is closed again right away
        # and re-opened per item in __getitem__.
        with h5py.File(self.path, 'r') as f:
            self.keys = list(f.keys())
        self.key_map = dict(enumerate(self.keys))

        ids = []
        seqs = []
        with open(root + seq_files[ds]) as fasta:
            for title, seq in SimpleFastaParser(fasta):
                ids.append(title)
                seqs.append(seq)
        self.data_frame = pd.DataFrame(
            data={"sequence": seqs, "length": [len(seq) for seq in seqs]},
            index=ids,
        )

        # Two columns: the identifier (used as index) and the EC annotation.
        annotations = pd.read_csv(root + ec_annotation_file, sep="\\t", names=['index', 'EC'], index_col=0, engine="python")

        # Left join keeps every FASTA record; records without an annotation
        # end up with NaN in 'EC' and are flagged 'NC' below.
        self.data_frame = self.data_frame.merge(annotations, how="left", left_index=True, right_index=True)
        self.data_frame['ec_or_nc'] = np.where(self.data_frame['EC'].isna(), 'NC', 'EC')

    def __len__(self):
        """Return the number of entries in the HDF5 file."""
        return len(self.keys)

    def _binary_label(self, key):
        """Return ``tensor([1.])`` if ``key`` is flagged 'EC', else ``tensor([0.])``.

        The left merge in ``__init__`` repeats a record once per matching
        annotation row, so ``key`` may occur several times in the index.
        Selecting with a list always yields a Series, which handles the unique
        and the repeated case alike.

        Raises:
            KeyError: If ``key`` is not in ``data_frame``.
        """
        flags = self.data_frame.loc[[key], 'ec_or_nc']
        is_ec = bool((flags == 'EC').any())
        return torch.tensor([1.0 if is_ec else 0.0], dtype=torch.float32)

    def __getitem__(self, index):
        """Return ``[embeddings, bin_label]`` for the entry at ``index``.

        ``embeddings`` is the stored array as a transposed float32 tensor
        (presumably (length, 1024) on disk -> (1024, length), to match the
        channel layout of ``nn.Conv1d(1024, ...)`` -- TODO confirm).
        ``bin_label`` is a float32 tensor of shape (1,).

        Raises:
            IndexError: If ``index`` is out of range.
        """
        key = self.keys[index]
        # The file is opened per item rather than kept on the instance, so no
        # open handle is shared between DataLoader worker processes.
        with h5py.File(self.path, 'r') as f:
            embeddings = torch.tensor(f[key][:], dtype=torch.float32).T
        return [embeddings, self._binary_label(key)]

# Never ask for more loader processes than the machine has CPUs.
NUM_WORKERS = min(12, os.cpu_count() or 1)

trainset = ProteinDataset("train")
valset = ProteinDataset("val")

# batch_size=1: each item is fed on its own (presumably because embeddings
# differ in length between proteins and cannot be stacked -- TODO confirm).
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1,
                                          shuffle=True, num_workers=NUM_WORKERS)
# Validation is not shuffled: order has no influence on the metrics, and a
# fixed order keeps per-sample predictions comparable between runs.
valloader = torch.utils.data.DataLoader(valset, batch_size=1,
                                        shuffle=False, num_workers=NUM_WORKERS)

In [79]:
class Net(nn.Module):
    """Two 1-D convolutions, adaptive max pooling and a three-layer MLP head.

    Input is ``(batch, embedding_dim, length)``; output is one raw logit per
    sample with shape ``(batch, 1, 1)`` (no sigmoid is applied).

    Args:
        embedding_dim: Number of input channels (default 1024).
        kernel_size: Kernel width of both convolutions (default 7).
        pooled_length: Output length of the adaptive max pool, which is also
            the input width of the first linear layer (default 256).

    The defaults reproduce the original fixed architecture, and parameter names
    and creation order are unchanged, so existing state dicts still load.
    """

    def __init__(self, embedding_dim=1024, kernel_size=7, pooled_length=256):
        super().__init__()
        self.conv1 = nn.Conv1d(embedding_dim, 32, kernel_size)
        self.conv2 = nn.Conv1d(32, 1, kernel_size)
        self.pool = nn.AdaptiveMaxPool1d(pooled_length)
        self.fc1 = nn.Linear(pooled_length, 128)
        self.fc2 = nn.Linear(128, 32)
        self.fc3 = nn.Linear(32, 1)
        # Each unpadded convolution shortens the sequence by kernel_size - 1,
        # so at least this many positions are needed for conv2 to produce one
        # output position (13 with the default kernel size).
        self.min_length = 2 * (kernel_size - 1) + 1

    def forward(self, x):
        """Return the logit tensor of shape ``(batch, 1, 1)`` for ``x``.

        Inputs shorter than ``min_length`` are zero-padded on the right instead
        of raising inside the convolutions; longer inputs pass through
        unchanged.
        """
        shortfall = self.min_length - x.size(-1)
        if shortfall > 0:
            x = F.pad(x, (0, shortfall))
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # Pool the variable-length sequence to a fixed width for the MLP.
        x = self.pool(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Build the classifier, then move its parameters to the selected device.
net = Net()
net = net.to(device)

In [90]:
import torch.optim as optim

# The loader yields one sample at a time, so gradients are accumulated over
# this many samples before each optimizer update.
ACCUMULATION_STEPS = 64
# Report the mean loss every LOG_EVERY samples.
LOG_EVERY = 2000

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

net.train()
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    pending = 0  # samples whose gradients have not been applied yet
    # Start every epoch from clean gradients (also matters when this cell is
    # re-run on a net that still holds gradients from an earlier run).
    optimizer.zero_grad()
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # forward + backward
        # net returns (batch, 1, 1); drop dim 1 to match labels (batch, 1).
        outputs = net(inputs).squeeze(1)
        loss = criterion(outputs, labels)
        # Scale so the accumulated gradient is the mean over the window
        # rather than the sum.
        (loss / ACCUMULATION_STEPS).backward()
        pending += 1

        # optimize once a full window has been accumulated
        if pending == ACCUMULATION_STEPS:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

        # print statistics (unscaled per-sample loss)
        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

    # Apply the gradients of the last, incomplete window so the tail of the
    # epoch is not silently dropped.
    if pending:
        optimizer.step()
        optimizer.zero_grad()

print('Finished Training')

[1,  2000] loss: 0.271
[1,  4000] loss: 0.258
[1,  6000] loss: 0.246
[1,  8000] loss: 0.231
[1, 10000] loss: 0.243
[1, 12000] loss: 0.243
[1, 14000] loss: 0.244
[1, 16000] loss: 0.227
[1, 18000] loss: 0.233
[1, 20000] loss: 0.235
[1, 22000] loss: 0.224
[1, 24000] loss: 0.238
[1, 26000] loss: 0.208
[1, 28000] loss: 0.218
[1, 30000] loss: 0.229
[1, 32000] loss: 0.237
[2,  2000] loss: 0.201
[2,  4000] loss: 0.195
[2,  6000] loss: 0.233
[2,  8000] loss: 0.207
[2, 10000] loss: 0.217
[2, 12000] loss: 0.182
[2, 14000] loss: 0.214
[2, 16000] loss: 0.200
[2, 18000] loss: 0.214
[2, 20000] loss: 0.202
[2, 22000] loss: 0.227
[2, 24000] loss: 0.230
[2, 26000] loss: 0.219
[2, 28000] loss: 0.213
[2, 30000] loss: 0.219
[2, 32000] loss: 0.194
Finished Training


In [91]:
import sklearn.metrics

correct = 0
total = 0
y_pred = []
y_true = []

# Switch to inference behaviour. The current layers behave the same in both
# modes, but this stays correct if dropout or batch norm is added later.
net.eval()
with torch.no_grad():
    for data in valloader:
        embeddings, labels = data
        embeddings, labels = embeddings.to(device), labels.to(device)
        # net returns (batch, 1, 1); drop dim 1 so predictions line up
        # element-wise with labels (batch, 1) for any batch size instead of
        # relying on broadcasting.
        logits = net(embeddings).squeeze(1)
        # Probability >= 0.5 rounds to class 1 (torch.round rounds exactly
        # 0.5 to 0).
        predicted = torch.round(torch.sigmoid(logits))
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        y_pred.extend(predicted.flatten().tolist())
        y_true.extend(labels.flatten().tolist())

print('Accuracy of the network on the dev set: %d %%' % (
    100 * correct / total))

print(sklearn.metrics.classification_report(y_true, y_pred))
print(sklearn.metrics.confusion_matrix(y_true, y_pred))

Accuracy of the network on the dev set: 92 %
              precision    recall  f1-score   support

         0.0       0.95      0.96      0.95       419
         1.0       0.77      0.74      0.75        81

    accuracy                           0.92       500
   macro avg       0.86      0.85      0.85       500
weighted avg       0.92      0.92      0.92       500

[[401  18]
 [ 21  60]]
