In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from utils import set_seed, load_data
from utils import train_model, eval_model

In [None]:
class MergedDBN(nn.Module):
    """Two-tower fully-connected network that scores a pair of inputs.

    Each input is flattened from dim 1 and passed through its own branch.
    A branch is a stack of Linear -> ReLU -> BatchNorm1d -> Dropout layers,
    with no Dropout after the last one. The two branch outputs are
    concatenated, and a small classifier head maps them to one probability
    via a sigmoid.

    Args:
        seq_len: Number of features per input *after* flattening. Despite the
            name, callers pass the full flattened size (e.g.
            ``input_dim * seq_len``), not only the sequence length.
        hidden_dims: Hidden-layer widths of each branch. The last entry is the
            per-branch output width.
        head_dim: Width of the hidden layer in the classifier head.
        dropout: Dropout probability used in the branches and the head.

    The defaults reproduce the original architecture exactly, including the
    parameter names in ``state_dict()``, so existing checkpoints still load.
    """

    def __init__(self, seq_len=1500, hidden_dims=(2048, 1024, 512, 128),
                 head_dim=8, dropout=0.5):
        super().__init__()
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer width")

        # Two independent (non-weight-shared) towers with identical layout.
        self.left_branch = self._make_branch(seq_len, hidden_dims, dropout)
        self.right_branch = self._make_branch(seq_len, hidden_dims, dropout)

        # Combined classifier over the concatenated branch outputs.
        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_dims[-1], head_dim),
            nn.ReLU(),
            nn.BatchNorm1d(head_dim),
            nn.Dropout(dropout),
            nn.Linear(head_dim, 1),
        )

        # No L2 penalty is applied inside the model. To regularize, pass
        # weight_decay to the optimizer.
        self._init_weights()

    @staticmethod
    def _make_branch(in_features, hidden_dims, dropout):
        """Build one tower: Linear -> ReLU -> BatchNorm1d (-> Dropout) per layer."""
        layers = []
        last = len(hidden_dims) - 1
        for i, width in enumerate(hidden_dims):
            layers += [nn.Linear(in_features, width), nn.ReLU(), nn.BatchNorm1d(width)]
            # The original architecture has no Dropout after the final layer.
            if i != last:
                layers.append(nn.Dropout(dropout))
            in_features = width
        return nn.Sequential(*layers)

    def _init_weights(self):
        """Xavier-uniform init for every Linear weight; zero the biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x1, x2):
        """Return the pair probability for each sample, shape ``(batch,)``.

        ``x1`` and ``x2`` are flattened from dim 1, so any trailing shape whose
        element count equals ``seq_len`` is accepted.
        """
        left = self.left_branch(torch.flatten(x1, start_dim=1))
        right = self.right_branch(torch.flatten(x2, start_dim=1))
        logits = self.classifier(torch.cat((left, right), dim=1))
        # Squeeze only the trailing singleton. A bare .squeeze() would also
        # drop the batch dim when batch size is 1, returning a 0-d tensor.
        return torch.sigmoid(logits).squeeze(-1)

In [None]:
spe = "yeast"  # species sub-directory inside data_dir

# Choose paths and run length from the runtime instead of hand-toggling a
# commented-out block. On Colab, mount Google Drive and use the full splits;
# otherwise use the smaller local subsets so Restart & Run All still works.
try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    drive_mount_point = "/content/drive"
    drive.mount(drive_mount_point)
    # Derive the path from the mount point rather than a cwd-relative
    # "drive/..." path, which only resolved when the cwd was /content.
    data_dir = os.path.join(drive_mount_point, "MyDrive", "ppi-data")
    train_file = os.path.join(data_dir, spe, "action/train_action.tsv")
    val_file = os.path.join(data_dir, spe, "action/val_action.tsv")
    test_file = os.path.join(data_dir, spe, "action/test_action.tsv")
    epochs = 50
else:
    data_dir = "ppi-data"
    train_file = os.path.join(data_dir, spe, "action/train_action_20.tsv")
    val_file = os.path.join(data_dir, spe, "action/val_action_10.tsv")
    test_file = os.path.join(data_dir, spe, "action/test_action_10.tsv")
    epochs = 10

embedding_h5 = os.path.join(data_dir, spe, "seq/pipr.embedding.h5")

input_dim = 13  # features per position; presumably the width of pipr.embedding.h5 vectors — TODO confirm
seq_len = 1500  # positions per sequence; assumed to be the length load_data pads/truncates to — TODO confirm
batch_size = 32
lr = 0.001

# Seed before building the loaders (presumably so train-loader shuffling is
# reproducible — depends on what utils.set_seed seeds).
set_seed(1234)

device = "cuda" if torch.cuda.is_available() else "cpu"

train_loader = load_data(train_file, batch_size, embedding_h5, train=True)
val_loader = load_data(val_file, batch_size, embedding_h5, train=False)
test_loader = load_data(test_file, batch_size, embedding_h5, train=False)

In [None]:
# The model flattens each input, so its per-branch input size must equal the
# flattened per-sample size: input_dim * seq_len.
model = MergedDBN(input_dim * seq_len).to(device)

# L2-style regularization is configured here, not in the model. 0.0 is Adam's
# default (no penalty) and preserves the previous behavior; set > 0 to enable it.
# Note: torch.optim.Adam adds weight_decay to the gradient (coupled L2),
# unlike AdamW's decoupled decay.
weight_decay = 0.0
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

In [None]:
# Train `model` in place with `optimizer` for `epochs` epochs on train_loader.
# NOTE(review): train_model lives in utils. How it uses val_loader (per-epoch
# validation, best-checkpoint restore, early stopping) is not visible here — confirm.
train_model(model, train_loader, val_loader, optimizer, epochs, device)

In [None]:
# Evaluate on the held-out test split. As the cell's last expression, whatever
# eval_model returns is displayed.
# NOTE(review): `model` holds whatever weights train_model left it with
# (last epoch unless train_model restores a best checkpoint — confirm in utils).
eval_model(model, test_loader, device)