In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from utils import set_seed, load_data
from utils import train_model, eval_model

In [None]:
class ProteinInteractionModel(nn.Module):
    """Siamese conv/GRU encoder plus MLP head that scores a pair of sequences.

    Both inputs go through the same encoder (``process_sequence``). The two
    fixed-size encodings are multiplied element-wise and fed to a three-layer
    classifier whose output is squashed by a sigmoid.

    Args:
        input_dim: Features per sequence position (input channels of ``conv1``).
        hidden_dim: Channel width produced by every convolution and used as the
            hidden size of every GRU.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()

        # Every GRU is bidirectional, so it emits 2 * hidden_dim features. Those
        # are concatenated with the hidden_dim-wide GRU input, which is why
        # conv2..conv6 take 3 * hidden_dim input channels.
        # Attribute names and creation order are kept as-is so parameter order
        # and state_dict keys stay stable.
        self.conv1 = nn.Conv1d(input_dim, hidden_dim, kernel_size=3, padding=1)
        self.gru1 = nn.GRU(hidden_dim, hidden_dim, bidirectional=True,
                           batch_first=True)
        self.conv2 = nn.Conv1d(3 * hidden_dim, hidden_dim, kernel_size=3,
                               padding=1)
        self.gru2 = nn.GRU(hidden_dim, hidden_dim, bidirectional=True,
                           batch_first=True)
        self.conv3 = nn.Conv1d(3 * hidden_dim, hidden_dim, kernel_size=3,
                               padding=1)
        self.gru3 = nn.GRU(hidden_dim, hidden_dim, bidirectional=True,
                           batch_first=True)
        self.conv4 = nn.Conv1d(3 * hidden_dim, hidden_dim, kernel_size=3,
                               padding=1)
        self.gru4 = nn.GRU(hidden_dim, hidden_dim, bidirectional=True,
                           batch_first=True)
        self.conv5 = nn.Conv1d(3 * hidden_dim, hidden_dim, kernel_size=3,
                               padding=1)
        self.gru5 = nn.GRU(hidden_dim, hidden_dim, bidirectional=True,
                           batch_first=True)
        self.conv6 = nn.Conv1d(3 * hidden_dim, hidden_dim, kernel_size=3,
                               padding=1)

        # Stride-3 max pooling shrinks the sequence by a factor of 3 per stage;
        # the adaptive average pool collapses whatever length is left to 1.
        self.pool = nn.MaxPool1d(kernel_size=3, stride=3)
        self.adaptive_pool = nn.AdaptiveAvgPool1d(1)
        self.leaky_relu = nn.LeakyReLU(0.3)

        # Classifier head: hidden_dim -> 100 -> (hidden_dim + 7) // 2 -> 1.
        self.fc1 = nn.Linear(hidden_dim, 100)
        self.fc2 = nn.Linear(100, (hidden_dim + 7) // 2)
        self.fc3 = nn.Linear((hidden_dim + 7) // 2, 1)

        self.dropout = nn.Dropout(0.5)
        # NOTE: bn1 is not used anywhere in forward(); it is kept only so that
        # the module's state_dict keys do not change.
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(100)
        self.bn3 = nn.BatchNorm1d((hidden_dim + 7) // 2)

    def _conv_gru_stage(self, conv, gru, x):
        """Run one conv -> max-pool -> bidirectional GRU stage with a skip concat.

        Args:
            conv: Convolution mapping the incoming channels to ``hidden_dim``.
            gru: Batch-first bidirectional GRU applied to the pooled sequence.
            x: Tensor of shape (batch, channels, length).

        Returns:
            Tensor of shape (batch, 3 * hidden_dim, length // 3): the GRU output
            stacked on top of the pooled convolution output along the channel
            axis.
        """
        x = self.pool(conv(x))  # (b, hidden_dim, length // 3)
        # The GRU is batch-first, so it wants (b, length, features).
        gru_out, _ = gru(x.permute(0, 2, 1))  # (b, length // 3, 2 * hidden_dim)
        gru_out = gru_out.permute(0, 2, 1)  # (b, 2 * hidden_dim, length // 3)
        return torch.cat([gru_out, x], dim=1)

    def process_sequence(self, x):
        """Encode one sequence into a fixed-size vector.

        Args:
            x: Tensor of shape (batch, input_dim, length). The length must be at
                least 3 ** 5 = 243 so that five stride-3 poolings leave a
                non-empty sequence (a length of 1500 shrinks as
                1500 -> 500 -> 166 -> 55 -> 18 -> 6).

        Returns:
            Tensor of shape (batch, hidden_dim).
        """
        stages = (
            (self.conv1, self.gru1),
            (self.conv2, self.gru2),
            (self.conv3, self.gru3),
            (self.conv4, self.gru4),
            (self.conv5, self.gru5),
        )
        for conv, gru in stages:
            x = self._conv_gru_stage(conv, gru, x)  # (b, 3 * hidden_dim, len // 3)

        x = self.conv6(x)  # (b, hidden_dim, remaining length)
        return self.adaptive_pool(x).squeeze(-1)  # (b, hidden_dim)

    def forward(self, x1, x2):
        """Score a batch of sequence pairs.

        Args:
            x1: Tensor of shape (batch, length, input_dim) for the first sequence.
            x2: Tensor of shape (batch, length, input_dim) for the second one.

        Returns:
            1-D tensor of length ``batch`` with values in [0, 1].
        """
        # Conv1d expects channels first, so move the feature axis to dim 1.
        x1 = self.process_sequence(x1.permute(0, 2, 1))
        x2 = self.process_sequence(x2.permute(0, 2, 1))

        # The element-wise product makes the pair representation symmetric in
        # (x1, x2).
        merged = x1 * x2

        x = self.fc1(merged)
        x = self.bn2(x)
        x = self.leaky_relu(x)
        x = self.dropout(x)

        x = self.fc2(x)
        x = self.bn3(x)
        x = self.leaky_relu(x)
        x = self.dropout(x)

        x = self.fc3(x)  # (b, 1)
        x = torch.flatten(x)  # (b,)
        # torch.sigmoid computes the same values as F.sigmoid.
        return torch.sigmoid(x)


In [None]:

# Sub-directory name that selects the dataset in every path built below.
spe = "yeast"

# Alternative local (non-Colab) setup, kept for reference:
# data_dir = "ppi-data"
# train_file = os.path.join(data_dir, spe, "action/train_action_20.tsv")
# val_file = os.path.join(data_dir, spe, "action/val_action_10.tsv")
# test_file = os.path.join(data_dir, spe, "action/test_action_10.tsv")
#
# epochs = 10

# Imported here rather than in the imports cell because it only exists on Colab.
from google.colab import drive

# Build data_dir from the mount point itself. The previous relative path
# ("drive/MyDrive/ppi-data") only matched the mount while the working directory
# was /content; the absolute form resolves to the same place from any directory.
mount_point = "/content/drive"
drive.mount(mount_point)
data_dir = os.path.join(mount_point, "MyDrive/ppi-data")
train_file = os.path.join(data_dir, spe, "action/train_action.tsv")
val_file = os.path.join(data_dir, spe, "action/val_action.tsv")
test_file = os.path.join(data_dir, spe, "action/test_action.tsv")

epochs = 50

embedding_h5 = os.path.join(data_dir, spe, "seq/pipr.embedding.h5")

# Model and optimisation hyper-parameters.
input_dim = 13  # NOTE(review): presumably the per-position width of the embeddings in embedding_h5 - confirm
hidden_dim = 26
batch_size = 32
lr = 0.001

# Seed before any loader or model is created (set_seed comes from utils).
set_seed(1234)

device = "cuda" if torch.cuda.is_available() else "cpu"

# load_data comes from utils; the `train` flag presumably toggles
# training-only behaviour such as shuffling - confirm in utils.
train_loader = load_data(train_file, batch_size, embedding_h5, train=True)
val_loader = load_data(val_file, batch_size, embedding_h5, train=False)
test_loader = load_data(test_file, batch_size, embedding_h5, train=False)


In [None]:
# Build the network and place it on the selected device in one expression;
# nn.Module.to moves the parameters in place and returns the module itself.
model = ProteinInteractionModel(input_dim, hidden_dim).to(device)

# Adam over every model parameter, using the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# train_model and eval_model are imported from the local utils module; their
# internals are not visible in this notebook.
# Presumably fits `model` on train_loader for `epochs` epochs while monitoring
# val_loader - confirm against utils.train_model.
train_model(model, train_loader, val_loader, optimizer, epochs, device)

# Evaluate on the held-out test loader. As the last expression of the cell,
# whatever eval_model returns (if anything) is shown as the cell output.
eval_model(model, test_loader, device)