In [None]:
import os
import random
import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt

from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_curve, auc, precision_recall_curve, average_precision_score
)

plt.style.use('fivethirtyeight')

## Helper functions

In [None]:
def plot_losses(losses, val_losses):
    """Draw training and validation loss curves on a log-scaled y axis.

    Args:
        losses: Sequence of per-epoch training losses.
        val_losses: Sequence of per-epoch validation losses.

    Returns:
        The Matplotlib figure holding the plot.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(losses, label='Training Loss', c='b')
    ax.plot(val_losses, label='Validation Loss', c='r')
    ax.set_yscale('log')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    fig.tight_layout()
    return fig


def calculate_metrics(y_true, y_pred):
    """Compute threshold-based and ranking-based binary classification metrics.

    Args:
        y_true: 1-D array-like of ground-truth binary labels.
        y_pred: 1-D array-like of predicted scores for the positive class.
            Scores ``>= 0.5`` are counted as positive predictions.

    Returns:
        Tuple ``(acc, precision, recall, f1, roc_auc, pr_auc)``.
    """
    y_score = np.asarray(y_pred)
    y_label = (y_score >= 0.5).astype(np.int32)

    # Accuracy / precision / recall / F1 are defined on hard decisions.
    acc = accuracy_score(y_true, y_label)
    precision = precision_score(y_true, y_label, average='binary')
    recall = recall_score(y_true, y_label, average='binary')
    f1 = f1_score(y_true, y_label, average='binary')

    # ROC and PR curves rank samples by score, so they must see the raw
    # scores. Feeding the thresholded labels leaves a single operating point
    # and makes ROC AUC degenerate to balanced accuracy.
    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)
    pr_auc = average_precision_score(y_true, y_score)
    return acc, precision, recall, f1, roc_auc, pr_auc


def pot_metrics(y_true, y_pred):
    """Show ROC and precision-recall curves side by side and print both AUCs.

    Args:
        y_true: Ground-truth binary labels.
        y_pred: Predicted scores passed straight to the scikit-learn curve
            functions.
    """
    fpr, tpr, _ = roc_curve(y_true, y_pred)
    roc_auc = auc(fpr, tpr)

    precision, recall, _ = precision_recall_curve(y_true, y_pred)
    pr_auc = average_precision_score(y_true, y_pred)

    def _decorate(ax, xlabel, ylabel, title, legend_loc):
        # Both panels share the same limits, grid and labelling sequence.
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend(loc=legend_loc)
        ax.grid(True)

    fig, (roc_ax, pr_ax) = plt.subplots(1, 2, figsize=(14, 6))

    # Left panel: ROC curve against the chance diagonal.
    roc_ax.plot(fpr, tpr, color='darkorange', lw=2,
                label=f'ROC curve (AUC = {roc_auc:.2f})')
    roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--',
                label='Random classifier')
    _decorate(roc_ax, 'False Positive Rate', 'True Positive Rate',
              'Receiver Operating Characteristic (ROC) Curve', "lower right")

    # Right panel: PR curve against the positive-class prevalence baseline.
    pr_ax.plot(recall, precision, color='blue', lw=2,
               label=f'PR curve (AUC = {pr_auc:.2f})')
    positive_ratio = np.sum(y_true) / len(y_true)
    pr_ax.axhline(y=positive_ratio, color='r', linestyle='--',
                  label='Random classifier')
    _decorate(pr_ax, 'Recall', 'Precision', 'Precision-Recall Curve',
              "upper right")

    plt.tight_layout()
    plt.show()

    print(f"ROC AUC: {roc_auc:.4f}")
    print(f"PR AUC: {pr_auc:.4f}")


def train_model(model, train_loader, val_loader, optimizer, epochs, device):
    """Train a pairwise model with binary cross-entropy and validate each epoch.

    Args:
        model: Module called as ``model(x1, x2)`` that returns probabilities
            with the same shape as the label batch.
        train_loader: Iterable yielding ``(x1, x2, y)`` training batches.
        val_loader: Iterable yielding ``(x1, x2, y)`` validation batches.
        optimizer: Optimizer bound to ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(losses, val_losses)``: the per-epoch mean of the batch losses
        for training and validation. Both can be passed to ``plot_losses``.
    """
    losses = []
    val_losses = []

    for epoch in range(epochs):
        model.train()
        batch_losses = []
        for x1, x2, y in train_loader:
            x1 = x1.to(device)
            x2 = x2.to(device)
            y = y.to(device)
            # Clear gradients first so nothing left over from before this
            # call (or from a previous step) leaks into the update.
            optimizer.zero_grad()
            yhat = model(x1, x2)
            batch_loss = F.binary_cross_entropy(yhat, y)
            batch_loss.backward()
            optimizer.step()
            batch_losses.append(batch_loss.item())

        loss = np.mean(batch_losses)
        losses.append(loss)

        model.eval()
        with torch.no_grad():
            batch_losses = []
            n_correct = 0
            n_seen = 0
            for x1, x2, y in val_loader:
                x1 = x1.to(device)
                x2 = x2.to(device)
                y = y.to(device)
                yhat = model(x1, x2)
                batch_losses.append(F.binary_cross_entropy(yhat, y).item())
                # Count exact matches between thresholded predictions and
                # labels; summing counts gives the sample-weighted accuracy.
                n_correct += (yhat > 0.5).to(y.dtype).eq(y).sum().item()
                n_seen += y.shape[0]

            val_loss = np.mean(batch_losses)
            val_losses.append(val_loss)
            accuracy_average = n_correct / n_seen if n_seen else 0

        print(f"Epoch: {epoch + 1} -- loss: {loss:.4f}, val_loss: {val_loss:.4f}, accuracy: {accuracy_average:.4f}")

    return losses, val_losses


def eval_model(model, val_loader, device):
    """Score a model on a loader and report classification metrics.

    Args:
        model: Module called as ``model(x1, x2)``.
        val_loader: Iterable yielding ``(x1, x2, y)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(acc, precision, recall, f1, roc_auc, pr_auc)`` as produced by
        ``calculate_metrics`` on the concatenated labels and model outputs.
    """
    model.eval()
    label_chunks, score_chunks = [], []

    with torch.no_grad():
        for x1, x2, y in val_loader:
            x1, x2, y = x1.to(device), x2.to(device), y.to(device)
            scores = model(x1, x2)
            label_chunks.append(y.cpu().numpy())
            score_chunks.append(scores.cpu().numpy())

    y_true = np.concatenate(label_chunks)
    y_pred = np.concatenate(score_chunks)

    metrics = calculate_metrics(y_true, y_pred)
    acc, precision, recall, f1, roc_auc, pr_auc = metrics
    print(f"Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, ROC AUC: {roc_auc:.4f}, PR AUC: {pr_auc:.4f}")
    return acc, precision, recall, f1, roc_auc, pr_auc


def save_model(model, model_path):
    """Write the model's ``state_dict`` to ``model_path``.

    Note that the model itself is moved to the CPU (in place) before its
    weights are serialized, so it stays on the CPU after this call.
    """
    cpu_model = model.to("cpu")
    weights = cpu_model.state_dict()
    torch.save(weights, model_path)


class PairedDataset(Dataset):
    """Dataset of labelled ID pairs backed by per-ID embedding matrices.

    Every embedding referenced by either ID list is read from the HDF5 file
    once at construction time and kept in memory.

    Args:
        ids1: Sequence of IDs for the first element of each pair.
        ids2: Sequence of IDs for the second element of each pair.
        labels: Sequence of labels, one per pair.
        embedding_h5: Path to an HDF5 file with one 2-D dataset per ID.
    """

    def __init__(self, ids1, ids2, labels, embedding_h5):
        self.ids1 = ids1
        self.ids2 = ids2
        self.labels = labels
        self.embed_data = {}

        with h5py.File(embedding_h5, "r") as h5fin:
            # Each distinct ID is loaded once even if it appears in many pairs.
            for protein_id in set(ids1).union(ids2):
                self.embed_data[protein_id] = h5fin[protein_id][:, :]

    def __len__(self):
        """Return the number of pairs."""
        return len(self.ids1)

    def __getitem__(self, i):
        """Return ``(x1, x2, label)`` for pair ``i`` as float tensors."""
        x1 = self.loader(self.ids1[i])
        x2 = self.loader(self.ids2[i])
        return x1, x2, torch.as_tensor(self.labels[i]).float()

    def loader(self, id, max_len=1500):
        """Return the embedding for ``id`` with exactly ``max_len`` rows.

        Longer embeddings are truncated to their first ``max_len`` rows and
        shorter ones are zero-padded at the end.

        Args:
            id: Key into the in-memory embedding table.
            max_len: Number of rows in the returned tensor.

        Returns:
            Float tensor of shape ``(max_len, seq_dim)``.
        """
        embedding = self.embed_data[id]
        seq_len = embedding.shape[0]
        seq_dim = embedding.shape[1]
        if seq_len >= max_len:
            # Also covers seq_len == max_len, which previously matched neither
            # branch and left ``x`` unassigned.
            x = embedding[:max_len]
        else:
            x = np.concatenate(
                (embedding, np.zeros((max_len - seq_len, seq_dim))))

        return torch.from_numpy(x).float()


def load_data(data_file, batch_size, embedding_h5, train=True):
    """Build a ``DataLoader`` over the pairs listed in a headerless TSV file.

    Args:
        data_file: Tab-separated file whose first three columns are the two
            IDs of a pair and its label.
        batch_size: Number of pairs per batch.
        embedding_h5: HDF5 file handed to ``PairedDataset``.
        train: When true the loader shuffles the pairs.

    Returns:
        A ``DataLoader`` wrapping a ``PairedDataset``.
    """
    pairs = pd.read_csv(data_file, sep="\t", header=None)
    first_ids, second_ids, pair_labels = (
        pairs[column].to_list() for column in (0, 1, 2)
    )
    dataset = PairedDataset(first_ids, second_ids, pair_labels, embedding_h5)
    return DataLoader(dataset, batch_size=batch_size, shuffle=train)


def set_seed(seed=1234):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer used for every generator and for ``PYTHONHASHSEED``.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

## Data loader

In [None]:
# Sub-directory of data_dir that holds the embeddings file and the fold files
# for the dataset being evaluated.
spe = "yeast1"

# Hyper-parameters shared by every model section below.
batch_size = 128
lr = 0.001
# Seed the Python, NumPy and PyTorch RNGs before any model is built.
set_seed(1234)

## local test
# Settings for a local run; uncomment these (and comment out the Colab block
# below) when not running on Google Colab.

# data_dir = "ppi-data"
# kfold_dir = "kfold_20"
# epochs = 10

## google colab test
# NOTE: this import and the Drive mount are Colab-specific and are expected to
# fail in other environments.

from google.colab import drive
drive.mount('/content/drive')
data_dir = "drive/MyDrive/ppi-data"
kfold_dir = "kfold"
epochs = 30

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Embeddings file read by PairedDataset for both training and validation pairs.
embedding_h5 = os.path.join(data_dir, spe, "embedding.h5")

## MyPPI

In [None]:
class Featuring(nn.Module):
    """MLP encoder mapping a flattened embedding to a 128-d feature vector.

    Four ``Linear -> ReLU -> BatchNorm1d`` blocks (2048, 1024, 512, 128 units)
    with dropout (p=0.5) after each of the first three.

    Args:
        seq_len: Number of rows of the unflattened input.
        input_dim: Number of columns of the unflattened input.
    """

    def __init__(self, seq_len, input_dim):
        super().__init__()

        widths = (seq_len * input_dim, 2048, 1024, 512, 128)
        # Registers lin1/bn1 ... lin4/bn4 in forward order.
        for idx, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"lin{idx}", nn.Linear(n_in, n_out))
            setattr(self, f"bn{idx}", nn.BatchNorm1d(n_out))

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, seq_len * input_dim)`` to ``(batch, 128)``."""
        hidden_blocks = (
            (self.lin1, self.bn1),
            (self.lin2, self.bn2),
            (self.lin3, self.bn3),
        )
        for linear, norm in hidden_blocks:
            x = self.dropout(norm(self.relu(linear(x))))
        # The final block has no dropout.
        return self.bn4(self.relu(self.lin4(x)))


class Classifier(nn.Module):
    """Binary head turning a 128-d feature vector into one probability.

    ``Linear(128, 8) -> ReLU -> BatchNorm1d -> Dropout(0.5) -> Linear(8, 1)``
    followed by a sigmoid.
    """

    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(128, 8)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(8)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(8, 1)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, 128)`` to probabilities of shape ``(batch,)``."""
        hidden = self.dropout(self.bn(self.relu(self.fc1(x))))
        # Drop only the trailing unit dimension so a batch of one stays 1-D.
        logits = self.fc2(hidden).squeeze(-1)
        return F.sigmoid(logits)
    

class InteractionModel(nn.Module):
    """Siamese pair model: shared encoder, element-wise product, classifier.

    Args:
        featuring: Module applied to each flattened input; the same instance
            (and therefore the same weights) encodes both inputs.
        classifier: Module applied to the product of the two encodings.
    """

    def __init__(self, featuring, classifier):
        super().__init__()

        self.featuring = featuring
        self.classifier = classifier

    def forward(self, x1, x2):
        """Return the classifier output for the pair ``(x1, x2)``."""
        # Flatten everything but the batch dimension before encoding.
        encoded_first = self.featuring(torch.flatten(x1, start_dim=1))
        encoded_second = self.featuring(torch.flatten(x2, start_dim=1))
        return self.classifier(encoded_first * encoded_second)


# 5-fold cross-validation of the MyPPI model; one list per metric collects the
# validation score of each fold.
acc_list, precision_list, recall_list = [], [], []
f1_list, roc_auc_list, pr_auc_list = [], [], []

# Input geometry expected by Featuring: seq_len rows of input_dim features.
seq_len = 1500
input_dim = 13

fold_dir = os.path.join(data_dir, spe, kfold_dir)

for k in range(5):
    print(f"Kfold: ======================== {k+1} ========================")
    train_file = os.path.join(fold_dir, f"train_fold_{k+1}.tsv")
    val_file = os.path.join(fold_dir, f"val_fold_{k+1}.tsv")
    train_loader = load_data(train_file, batch_size, embedding_h5, train=True)
    val_loader = load_data(val_file, batch_size, embedding_h5, train=False)

    # A fresh model and optimizer per fold so no weights carry over.
    featuring = Featuring(seq_len, input_dim)
    classifier = Classifier()
    model = InteractionModel(featuring, classifier).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_model(model, train_loader, val_loader, optimizer, epochs, device)
    acc, precision, recall, f1, roc_auc, pr_auc = eval_model(model, val_loader, device)

    fold_scores = (acc, precision, recall, f1, roc_auc, pr_auc)
    score_lists = (acc_list, precision_list, recall_list,
                   f1_list, roc_auc_list, pr_auc_list)
    for score_list, score in zip(score_lists, fold_scores):
        score_list.append(score)

# Mean ± standard deviation across folds, in percent.
print(f"Accuracy: {np.mean(acc_list)*100:.2f}±{np.std(acc_list)*100:.2f}")
print(f"Precision: {np.mean(precision_list)*100:.2f}±{np.std(precision_list)*100:.2f}")
print(f"Recall: {np.mean(recall_list)*100:.2f}±{np.std(recall_list)*100:.2f}")
print(f"F1: {np.mean(f1_list)*100:.2f}±{np.std(f1_list)*100:.2f}")
print(f"ROC AUC: {np.mean(roc_auc_list)*100:.2f}±{np.std(roc_auc_list)*100:.2f}")
print(f"PR AUC: {np.mean(pr_auc_list)*100:.2f}±{np.std(pr_auc_list)*100:.2f}")

## PIPR

In [None]:
class ProteinInteractionModel(nn.Module):
    """Siamese Conv1d + bidirectional-GRU encoder with an MLP classifier.

    Each input is encoded by five ``Conv1d -> MaxPool1d(3) -> BiGRU`` stages,
    where the GRU output is concatenated with the GRU input along the channel
    axis, followed by a sixth convolution and global average pooling. The two
    encodings are multiplied element-wise and classified by three linear
    layers.

    Args:
        input_dim: Number of features per sequence position.
        hidden_dim: Channel width of every convolution and GRU.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()

        def conv(in_channels):
            return nn.Conv1d(in_channels, hidden_dim, kernel_size=3, padding=1)

        def bigru():
            return nn.GRU(hidden_dim, hidden_dim, bidirectional=True,
                          batch_first=True)

        # After each stage the channels are [GRU forward | GRU backward | conv],
        # i.e. 2 * hidden_dim + hidden_dim.
        stacked = 3 * hidden_dim

        self.conv1 = conv(input_dim)
        self.gru1 = bigru()
        self.conv2 = conv(stacked)
        self.gru2 = bigru()
        self.conv3 = conv(stacked)
        self.gru3 = bigru()
        self.conv4 = conv(stacked)
        self.gru4 = bigru()
        self.conv5 = conv(stacked)
        self.gru5 = bigru()
        self.conv6 = conv(stacked)

        self.pool = nn.MaxPool1d(kernel_size=3, stride=3)
        self.adaptive_pool = nn.AdaptiveAvgPool1d(1)
        self.leaky_relu = nn.LeakyReLU(0.3)

        head_dim = (hidden_dim + 7) // 2
        self.fc1 = nn.Linear(hidden_dim, 100)
        self.fc2 = nn.Linear(100, head_dim)
        self.fc3 = nn.Linear(head_dim, 1)

        self.dropout = nn.Dropout(0.5)
        # bn1 is registered (and so appears in the state dict) but is not
        # called anywhere in this class.
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(100)
        self.bn3 = nn.BatchNorm1d(head_dim)

    def process_sequence(self, x):
        """Encode ``x`` of shape ``(b, input_dim, L)`` to ``(b, hidden_dim)``.

        Each of the five pooling stages divides the length by three (floor),
        e.g. 1500 -> 500 -> 166 -> 55 -> 18 -> 6.
        """
        stages = (
            (self.conv1, self.gru1),
            (self.conv2, self.gru2),
            (self.conv3, self.gru3),
            (self.conv4, self.gru4),
            (self.conv5, self.gru5),
        )
        for conv, gru in stages:
            pooled = self.pool(conv(x))                   # (b, hidden_dim, L // 3)
            recurrent, _ = gru(pooled.permute(0, 2, 1))   # (b, L // 3, 2 * hidden_dim)
            # Concatenate the GRU output with its own input on the channel axis.
            x = torch.cat([recurrent.permute(0, 2, 1), pooled], dim=1)  # (b, 3 * hidden_dim, L // 3)

        x = self.conv6(x)                                 # (b, hidden_dim, L')
        return self.adaptive_pool(x).squeeze(-1)          # (b, hidden_dim)

    def forward(self, x1, x2):
        """Return interaction probabilities of shape ``(b,)``.

        Both inputs are ``(b, L, input_dim)`` and are transposed to
        channels-first before encoding.
        """
        encoded_first = self.process_sequence(x1.permute(0, 2, 1))
        encoded_second = self.process_sequence(x2.permute(0, 2, 1))

        x = encoded_first * encoded_second

        for linear, norm in ((self.fc1, self.bn2), (self.fc2, self.bn3)):
            x = self.dropout(self.leaky_relu(norm(linear(x))))

        return F.sigmoid(torch.flatten(self.fc3(x)))


# 5-fold cross-validation of the PIPR-style model; one list per metric
# collects the validation score of each fold.
acc_list, precision_list, recall_list = [], [], []
f1_list, roc_auc_list, pr_auc_list = [], [], []

fold_dir = os.path.join(data_dir, spe, kfold_dir)

for k in range(5):
    print(f"Kfold: ======================== {k+1} ========================")
    train_file = os.path.join(fold_dir, f"train_fold_{k+1}.tsv")
    val_file = os.path.join(fold_dir, f"val_fold_{k+1}.tsv")
    train_loader = load_data(train_file, batch_size, embedding_h5, train=True)
    val_loader = load_data(val_file, batch_size, embedding_h5, train=False)

    # A fresh model and optimizer per fold so no weights carry over.
    model = ProteinInteractionModel(input_dim=13, hidden_dim=26).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_model(model, train_loader, val_loader, optimizer, epochs, device)
    acc, precision, recall, f1, roc_auc, pr_auc = eval_model(model, val_loader, device)

    fold_scores = (acc, precision, recall, f1, roc_auc, pr_auc)
    score_lists = (acc_list, precision_list, recall_list,
                   f1_list, roc_auc_list, pr_auc_list)
    for score_list, score in zip(score_lists, fold_scores):
        score_list.append(score)

# Mean ± standard deviation across folds, in percent.
print(f"Accuracy: {np.mean(acc_list)*100:.2f}±{np.std(acc_list)*100:.2f}")
print(f"Precision: {np.mean(precision_list)*100:.2f}±{np.std(precision_list)*100:.2f}")
print(f"Recall: {np.mean(recall_list)*100:.2f}±{np.std(recall_list)*100:.2f}")
print(f"F1: {np.mean(f1_list)*100:.2f}±{np.std(f1_list)*100:.2f}")
print(f"ROC AUC: {np.mean(roc_auc_list)*100:.2f}±{np.std(roc_auc_list)*100:.2f}")
print(f"PR AUC: {np.mean(pr_auc_list)*100:.2f}±{np.std(pr_auc_list)*100:.2f}")

## DeepFE-PPI

In [None]:
class MergedDBN(nn.Module):
    """Two-branch MLP whose branch outputs are concatenated and classified.

    Each input is flattened and encoded by its own (non-shared) stack of
    ``Linear -> ReLU -> BatchNorm1d`` blocks down to 128 features; the two
    128-d vectors are concatenated and mapped to one probability.

    Args:
        seq_len: Number of rows of each unflattened input.
        input_dim: Number of columns of each unflattened input.
    """

    def __init__(self, seq_len=1500, input_dim=13):
        super().__init__()

        flat_dim = seq_len * input_dim
        self.left_branch = self._make_branch(flat_dim)
        self.right_branch = self._make_branch(flat_dim)

        # Combined classifier over the concatenated 2 x 128 branch features.
        self.classifier = nn.Sequential(
            nn.Linear(256, 8),
            nn.ReLU(),
            nn.BatchNorm1d(8),
            nn.Dropout(0.5),
            nn.Linear(8, 1)
        )

        # No L2 penalty is applied inside the module; if one is wanted it has
        # to be configured on the optimizer (e.g. via weight decay).
        self._init_weights()

    @staticmethod
    def _make_branch(in_features):
        """Build one encoder branch: in_features -> 2048 -> 1024 -> 512 -> 128."""
        return nn.Sequential(
            nn.Linear(in_features, 2048),
            nn.ReLU(),
            nn.BatchNorm1d(2048),
            nn.Dropout(0.5),

            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.5),

            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),

            nn.Linear(512, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128)
        )

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x1, x2):
        """Return probabilities of shape ``(batch,)`` for the pairs ``(x1, x2)``."""
        x1 = torch.flatten(x1, start_dim=1)
        x2 = torch.flatten(x2, start_dim=1)
        x1 = self.left_branch(x1)
        x2 = self.right_branch(x2)
        x = torch.cat((x1, x2), dim=1)
        x = self.classifier(x)
        x = F.sigmoid(x)
        # Squeeze only the trailing unit dimension: a bare squeeze() would also
        # drop the batch dimension for a batch of one and yield a 0-d tensor
        # that no longer matches the label shape.
        return x.squeeze(-1)


# 5-fold cross-validation of the MergedDBN model; one list per metric collects
# the validation score of each fold.
acc_list, precision_list, recall_list = [], [], []
f1_list, roc_auc_list, pr_auc_list = [], [], []

# Input geometry expected by MergedDBN: seq_len rows of input_dim features.
seq_len = 1500
input_dim = 13

fold_dir = os.path.join(data_dir, spe, kfold_dir)

for k in range(5):
    print(f"Kfold: ======================== {k+1} ========================")
    train_file = os.path.join(fold_dir, f"train_fold_{k+1}.tsv")
    val_file = os.path.join(fold_dir, f"val_fold_{k+1}.tsv")
    train_loader = load_data(train_file, batch_size, embedding_h5, train=True)
    val_loader = load_data(val_file, batch_size, embedding_h5, train=False)

    # A fresh model and optimizer per fold so no weights carry over.
    model = MergedDBN(seq_len, input_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_model(model, train_loader, val_loader, optimizer, epochs, device)
    acc, precision, recall, f1, roc_auc, pr_auc = eval_model(model, val_loader, device)

    fold_scores = (acc, precision, recall, f1, roc_auc, pr_auc)
    score_lists = (acc_list, precision_list, recall_list,
                   f1_list, roc_auc_list, pr_auc_list)
    for score_list, score in zip(score_lists, fold_scores):
        score_list.append(score)

# Mean ± standard deviation across folds, in percent.
print(f"Accuracy: {np.mean(acc_list)*100:.2f}±{np.std(acc_list)*100:.2f}")
print(f"Precision: {np.mean(precision_list)*100:.2f}±{np.std(precision_list)*100:.2f}")
print(f"Recall: {np.mean(recall_list)*100:.2f}±{np.std(recall_list)*100:.2f}")
print(f"F1: {np.mean(f1_list)*100:.2f}±{np.std(f1_list)*100:.2f}")
print(f"ROC AUC: {np.mean(roc_auc_list)*100:.2f}±{np.std(roc_auc_list)*100:.2f}")
print(f"PR AUC: {np.mean(pr_auc_list)*100:.2f}±{np.std(pr_auc_list)*100:.2f}")

## DeepTrio

In [None]:
class DeepTrio(nn.Module):
    """Multi-scale Conv1d pair model with shared filters and a two-layer head.

    A bank of 33 bias-free convolutions (``n`` = 2..34) with growing kernel
    sizes and strides is applied to both inputs. Each convolution output is
    globally max-pooled, the pooled features of the two inputs are summed, and
    all scales are concatenated before the fully connected head.

    Args:
        em_dim: Number of features per sequence position (input channels).
        kernel_rate_1, strides_rate_1, filter_num_1: Kernel-size rate, stride
            rate and filter count used for ``n <= 15``.
        kernel_rate_2, strides_rate_2, filter_num_2: Same, for ``n > 15``.
        con_drop: Channel dropout probability after each convolution.
        fn_drop_1: Dropout probability before the first linear layer.
        fn_drop_2: Dropout probability after the first linear layer.
        node_num: Width of the hidden linear layer.

    Note:
        No padding is used, so inputs must be at least as long as the largest
        kernel (162 positions with the default rates).
    """

    def __init__(self, em_dim=13, kernel_rate_1=0.16, strides_rate_1=0.15,
                 filter_num_1=150, kernel_rate_2=0.14, strides_rate_2=0.25,
                 filter_num_2=175, con_drop=0.05, fn_drop_1=0.2, fn_drop_2=0.1,
                 node_num=256):
        super(DeepTrio, self).__init__()

        # One convolution per scale n; kernel grows with n**2, stride with n-1.
        self.conv_layers = nn.ModuleList()
        for n in range(2, 35):
            if n <= 15:
                kernel_rate, strides_rate, filter_num = (
                    kernel_rate_1, strides_rate_1, filter_num_1)
            else:
                kernel_rate, strides_rate, filter_num = (
                    kernel_rate_2, strides_rate_2, filter_num_2)
            self.conv_layers.append(nn.Conv1d(
                in_channels=em_dim,
                out_channels=filter_num,
                kernel_size=int(np.ceil(kernel_rate * n ** 2)),
                stride=int(np.ceil(strides_rate * (n - 1))),
                padding=0,
                bias=False
            ))

        self.conv_dropout = nn.Dropout2d(con_drop)
        self.fc_dropout1 = nn.Dropout(fn_drop_1)
        self.fc_dropout2 = nn.Dropout(fn_drop_2)

        # Global max pooling leaves one value per filter, so the concatenated
        # feature size is the total number of filters.
        total_features = sum(conv.out_channels for conv in self.conv_layers)

        self.fc1 = nn.Linear(total_features, node_num)
        self.fc2 = nn.Linear(node_num, 1)

    def _pooled_features(self, conv_layer, x):
        """Apply one convolution, channel dropout and global max pooling."""
        conv_out = F.relu(conv_layer(x))
        # Dropout2d drops whole channels; add a dummy trailing axis for it.
        conv_out = self.conv_dropout(conv_out.unsqueeze(-1)).squeeze(-1)
        pooled = F.max_pool1d(conv_out, conv_out.size(-1))
        return pooled.view(pooled.size(0), -1)

    def forward(self, x1, x2):
        """Return probabilities of shape ``(batch,)``.

        Both inputs are ``(batch, seq_len, em_dim)``.
        """
        # Permute for Conv1d: (batch, channels, seq_len)
        x1 = x1.permute(0, 2, 1)
        x2 = x2.permute(0, 2, 1)

        # The same filters score both inputs; summing makes the pair
        # representation independent of the input order.
        features = [
            self._pooled_features(conv_layer, x1)
            + self._pooled_features(conv_layer, x2)
            for conv_layer in self.conv_layers
        ]
        concatenated = torch.cat(features, dim=1)

        # Fully connected layers
        x = self.fc_dropout1(concatenated)
        x = self.fc1(x)
        x = self.fc_dropout2(x)
        x = F.relu(x)
        x = self.fc2(x)

        # Squeeze only the trailing unit dimension: a bare squeeze() would also
        # drop the batch dimension for a batch of one and yield a 0-d tensor
        # that no longer matches the label shape.
        return F.sigmoid(x).squeeze(-1)


# 5-fold cross-validation of the DeepTrio model; one list per metric collects
# the validation score of each fold.
acc_list = []
precision_list = []
recall_list = []
f1_list = []
roc_auc_list = []
pr_auc_list = []

# DeepTrio is trained for 15 epochs. The value lives in its own name so the
# shared `epochs` setting from the configuration cell is not overwritten,
# which would silently change the other model sections if their cells are
# re-run after this one.
deeptrio_epochs = 15

for k in range(5):
    print(f"Kfold: ======================== {k+1} ========================")
    train_file = os.path.join(data_dir, spe, kfold_dir, f"train_fold_{k+1}.tsv")
    val_file = os.path.join(data_dir, spe, kfold_dir, f"val_fold_{k+1}.tsv")
    train_loader = load_data(train_file, batch_size, embedding_h5, train=True)
    val_loader = load_data(val_file, batch_size, embedding_h5, train=False)

    # A fresh model and optimizer per fold so no weights carry over.
    model = DeepTrio().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_model(model, train_loader, val_loader, optimizer, deeptrio_epochs, device)
    acc, precision, recall, f1, roc_auc, pr_auc = eval_model(model, val_loader, device)
    acc_list.append(acc)
    precision_list.append(precision)
    recall_list.append(recall)
    f1_list.append(f1)
    roc_auc_list.append(roc_auc)
    pr_auc_list.append(pr_auc)

# Mean ± standard deviation across folds, in percent.
print(f"Accuracy: {np.mean(acc_list)*100:.2f}±{np.std(acc_list)*100:.2f}")
print(f"Precision: {np.mean(precision_list)*100:.2f}±{np.std(precision_list)*100:.2f}")
print(f"Recall: {np.mean(recall_list)*100:.2f}±{np.std(recall_list)*100:.2f}")
print(f"F1: {np.mean(f1_list)*100:.2f}±{np.std(f1_list)*100:.2f}")
print(f"ROC AUC: {np.mean(roc_auc_list)*100:.2f}±{np.std(roc_auc_list)*100:.2f}")
print(f"PR AUC: {np.mean(pr_auc_list)*100:.2f}±{np.std(pr_auc_list)*100:.2f}")

In [None]:
# acc_list = [0.9334, 0.9097, 0.9321, 0.9437, 0.7868]
# precision_list = [0.9291, 0.9479, 0.9500,0.9482, 0.7244]
# recall_list = [0.9367, 0.8662, 0.9127, 0.9398, 0.9278]
# f1_list = [0.9329, 0.9053, 0.9310, 0.9440, 0.8136]
# roc_auc_list = [0.9335, 0.9095, 0.9322, 0.9437, 0.7863]
# pr_auc_list = [0.9016, 0.8877, 0.9108, 0.9215, 0.7083]

# print(f"Accuracy: {np.mean(acc_list)*100:.2f}±{np.std(acc_list)*100:.2f}")
# print(f"Precision: {np.mean(precision_list)*100:.2f}±{np.std(precision_list)*100:.2f}")
# print(f"Recall: {np.mean(recall_list)*100:.2f}±{np.std(recall_list)*100:.2f}")
# print(f"F1: {np.mean(f1_list)*100:.2f}±{np.std(f1_list)*100:.2f}")
# print(f"ROC AUC: {np.mean(roc_auc_list)*100:.2f}±{np.std(roc_auc_list)*100:.2f}")
# print(f"PR AUC: {np.mean(pr_auc_list)*100:.2f}±{np.std(pr_auc_list)*100:.2f}")
