In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, random_split, DataLoader
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MultilabelF1Score
from sklearn.model_selection import KFold, StratifiedKFold
import warnings
import copy

warnings.filterwarnings("ignore")
plt.style.use('ggplot')

MAIN_DIR = "/workspace/data"

In [2]:
class Config:
    """Configuration class for paths and hyperparameters."""
    train_sequences_path = MAIN_DIR + "/Train/train_sequences.fasta"
    train_labels_path = MAIN_DIR + "/Train/train_terms.tsv"
    test_sequences_path = MAIN_DIR + "/Test/testsuperset.fasta"
    
    num_labels = 500
    n_epochs = 20
    batch_size = 128
    lr = 0.01
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


embeds_map = {
    "T5": "t5embeds",
    "ProtBERT": "protbert-embeddings-for-cafa5",
    "EMS2": "cafa-5-ems-2-embeddings-numpy"
}

embeds_dim = {
    "T5": 1024,
    "ProtBERT": 1024,
    "EMS2": 1280
}

In [3]:
class ProteinSequenceDataset(Dataset):
    """Dataset class for protein sequences with embeddings."""
    
    def __init__(self, datatype, embeddings_source):
        super(ProteinSequenceDataset, self).__init__()
        self.datatype = datatype
        
        if embeddings_source in ["ProtBERT", "EMS2"]:
            embeds = np.load(f"/workspace/{embeds_map[embeddings_source]}/{datatype}_embeddings.npy")
            ids = np.load(f"/workspace/{embeds_map[embeddings_source]}/{datatype}_ids.npy")
        
        if embeddings_source == "T5":
            embeds = np.load(f"/workspace/{embeds_map[embeddings_source]}/{datatype}_embeds.npy")
            ids = np.load(f"/workspace/{embeds_map[embeddings_source]}/{datatype}_ids.npy")
            
        embeds_list = [embeds[l, :] for l in range(embeds.shape[0])]
        self.df = pd.DataFrame({"EntryID": ids, "embed": embeds_list})
        
        if datatype == "train":
            np_labels = np.load(f"/workspace/train_targets_top{Config.num_labels}.npy")
            df_labels = pd.DataFrame(self.df['EntryID'])
            df_labels['labels_vect'] = list(np_labels)
            self.df = self.df.merge(df_labels, on="EntryID")
            
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        embed = torch.tensor(self.df.iloc[index]["embed"], dtype=torch.float32)
        if self.datatype == "train":
            targets = torch.tensor(self.df.iloc[index]["labels_vect"], dtype=torch.float32)
            return embed, targets
        if self.datatype == "test":
            protein_id = self.df.iloc[index]["EntryID"]
            return embed, protein_id

In [4]:
class MultiLayerPerceptron(nn.Module):
    """Multi-layer Perceptron model."""
    
    def __init__(self, input_dim, num_classes):
        super(MultiLayerPerceptron, self).__init__()
        self.linear1 = nn.Linear(input_dim, 864)
        self.activation1 = nn.ReLU()
        self.linear2 = nn.Linear(864, 712)
        self.activation2 = nn.ReLU()
        self.linear3 = nn.Linear(712, num_classes)
    
    def forward(self, x):
        x = self.linear1(x)
        x = self.activation1(x)
        x = self.linear2(x)
        x = self.activation2(x)
        x = self.linear3(x)
        return x


class CNN1D(nn.Module):
    """1D Convolutional Neural Network model."""
    
    def __init__(self, input_dim, num_classes):
        super(CNN1D, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=3, kernel_size=3, dilation=1, padding=1, stride=1)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(in_channels=3, out_channels=8, kernel_size=3, dilation=1, padding=1, stride=1)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(int(8 * input_dim / 4), 864)
        self.fc2 = nn.Linear(864, num_classes)
    
    def forward(self, x):
        x = x.reshape(x.shape[0], 1, x.shape[1])
        x = self.pool1(nn.functional.tanh(self.conv1(x)))
        x = self.pool2(nn.functional.tanh(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = nn.functional.tanh(self.fc1(x))
        x = self.fc2(x)
        return x

In [5]:
def train_model(embeddings_source, model_type="linear", n_folds=5):
    """Train models using Stratified K-Fold cross-validation with label count approximation."""
    train_dataset = ProteinSequenceDataset(datatype="train", embeddings_source=embeddings_source)
    
    # Extract labels for stratification
    labels = np.stack(train_dataset.df['labels_vect'].values)
    stratify_y = np.sum(labels, axis=1)  # Use sum of labels as stratification criterion for multi-label
    
    kf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    
    models = []
    best_scores = []
    
    for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(len(train_dataset)), stratify_y)):
        train_set = torch.utils.data.Subset(train_dataset, train_idx)
        val_set = torch.utils.data.Subset(train_dataset, val_idx)
        
        train_dataloader = DataLoader(train_set, batch_size=Config.batch_size, shuffle=True)
        val_dataloader = DataLoader(val_set, batch_size=Config.batch_size, shuffle=True)
        
        if model_type == "linear":
            model = MultiLayerPerceptron(input_dim=embeds_dim[embeddings_source], num_classes=Config.num_labels).to(Config.device)
        elif model_type == "convolutional":
            model = CNN1D(input_dim=embeds_dim[embeddings_source], num_classes=Config.num_labels).to(Config.device)
        
        optimizer = torch.optim.Adam(model.parameters(), lr=Config.lr)
        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=1, threshold=0.0001)
        criterion = nn.CrossEntropyLoss()
        f1_metric = MultilabelF1Score(num_labels=Config.num_labels).to(Config.device)
        
        best_val_score = -float('inf')
        best_model = None
        best_epoch = 0
                
        for epoch in range(Config.n_epochs):
            # TRAINING
            model.train()
            losses, scores = [], []
            for embeds, targets in tqdm(train_dataloader):
                embeds, targets = embeds.to(Config.device), targets.to(Config.device)

                optimizer.zero_grad()
                preds = model(embeds)
                loss = criterion(preds, targets)
                score = f1_metric(preds, targets)
                losses.append(loss.item())
                scores.append(score.item())
                loss.backward()

                optimizer.step()

            train_loss = np.mean(losses)
            train_score = np.mean(scores)

            # VALIDATION
            model.eval()
            losses, scores = [], []
            with torch.no_grad():
                for embeds, targets in val_dataloader:
                    embeds, targets = embeds.to(Config.device), targets.to(Config.device)
                    preds = model(embeds)
                    loss = criterion(preds, targets)
                    score = f1_metric(preds, targets)
                    losses.append(loss.item())
                    scores.append(score.item())

            val_loss = np.mean(losses)
            val_score = np.mean(scores)
            
            if val_score >= best_val_score:
                best_val_score = val_score
                best_model = copy.deepcopy(model)
                best_epoch = epoch
            scheduler.step(val_loss)

            print(f"FOLD [{fold + 1}/{n_folds}] EPOCH [{epoch + 1:02d}] TRAIN LOSS [{train_loss:.6f}] TRAIN SCORE [{train_score:.6f}] VAL LOSS [{val_loss:.6f}] VAL SCORE [{val_score:.6f}]")
        
        models.append(best_model)
        best_scores.append(best_val_score)
        print(f"FOLD [{fold + 1}] BEST EPOCH [{best_epoch + 1:02d}] BEST VALIDATION SCORE [{best_val_score:.6f}]")

        print("=" * 100)
    
    print(f"BEST SCORES: {best_scores}")
    
    return models

protbert_models = train_model("ProtBERT")

100%|██████████| 890/890 [00:03<00:00, 232.05it/s]


FOLD [1] EPOCH [01] TRAIN LOSS [138.721262] TRAIN SCORE [0.092438] VAL LOSS [138.162761] VAL SCORE [0.107106]


100%|██████████| 890/890 [00:03<00:00, 241.61it/s]


FOLD [1] EPOCH [02] TRAIN LOSS [136.675289] TRAIN SCORE [0.120655] VAL LOSS [136.479465] VAL SCORE [0.127398]


100%|██████████| 890/890 [00:03<00:00, 241.47it/s]


FOLD [1] EPOCH [03] TRAIN LOSS [136.027707] TRAIN SCORE [0.130170] VAL LOSS [136.236096] VAL SCORE [0.131175]


100%|██████████| 890/890 [00:03<00:00, 237.67it/s]


FOLD [1] EPOCH [04] TRAIN LOSS [135.758086] TRAIN SCORE [0.135730] VAL LOSS [136.062424] VAL SCORE [0.132483]


100%|██████████| 890/890 [00:03<00:00, 230.43it/s]


FOLD [1] EPOCH [05] TRAIN LOSS [135.283213] TRAIN SCORE [0.139872] VAL LOSS [135.804636] VAL SCORE [0.141575]


100%|██████████| 890/890 [00:03<00:00, 235.86it/s]


FOLD [1] EPOCH [06] TRAIN LOSS [135.169333] TRAIN SCORE [0.141201] VAL LOSS [135.111769] VAL SCORE [0.138377]


100%|██████████| 890/890 [00:03<00:00, 225.09it/s]


FOLD [1] EPOCH [07] TRAIN LOSS [135.107695] TRAIN SCORE [0.143423] VAL LOSS [135.586807] VAL SCORE [0.137962]


100%|██████████| 890/890 [00:03<00:00, 230.30it/s]


FOLD [1] EPOCH [08] TRAIN LOSS [135.092189] TRAIN SCORE [0.144372] VAL LOSS [135.007604] VAL SCORE [0.145433]


100%|██████████| 890/890 [00:03<00:00, 226.89it/s]


FOLD [1] EPOCH [09] TRAIN LOSS [135.039800] TRAIN SCORE [0.145519] VAL LOSS [135.700342] VAL SCORE [0.137570]


100%|██████████| 890/890 [00:03<00:00, 232.27it/s]


FOLD [1] EPOCH [10] TRAIN LOSS [134.909015] TRAIN SCORE [0.145939] VAL LOSS [134.585486] VAL SCORE [0.148482]


100%|██████████| 890/890 [00:03<00:00, 233.92it/s]


FOLD [1] EPOCH [11] TRAIN LOSS [135.004413] TRAIN SCORE [0.147352] VAL LOSS [135.236526] VAL SCORE [0.140616]


100%|██████████| 890/890 [00:03<00:00, 228.26it/s]


FOLD [1] EPOCH [12] TRAIN LOSS [134.720687] TRAIN SCORE [0.147097] VAL LOSS [134.770017] VAL SCORE [0.148068]


100%|██████████| 890/890 [00:03<00:00, 239.01it/s]


FOLD [1] EPOCH [13] TRAIN LOSS [133.780786] TRAIN SCORE [0.155346] VAL LOSS [134.393165] VAL SCORE [0.153335]


100%|██████████| 890/890 [00:03<00:00, 229.32it/s]


FOLD [1] EPOCH [14] TRAIN LOSS [133.836869] TRAIN SCORE [0.156728] VAL LOSS [134.144459] VAL SCORE [0.154996]


100%|██████████| 890/890 [00:03<00:00, 232.15it/s]


FOLD [1] EPOCH [15] TRAIN LOSS [133.673773] TRAIN SCORE [0.157393] VAL LOSS [134.074704] VAL SCORE [0.155196]


100%|██████████| 890/890 [00:03<00:00, 229.18it/s]


FOLD [1] EPOCH [16] TRAIN LOSS [133.605824] TRAIN SCORE [0.158001] VAL LOSS [133.944339] VAL SCORE [0.155512]


100%|██████████| 890/890 [00:03<00:00, 230.98it/s]


FOLD [1] EPOCH [17] TRAIN LOSS [133.738787] TRAIN SCORE [0.158423] VAL LOSS [133.986152] VAL SCORE [0.156422]


100%|██████████| 890/890 [00:03<00:00, 240.67it/s]


FOLD [1] EPOCH [18] TRAIN LOSS [133.553370] TRAIN SCORE [0.158521] VAL LOSS [134.109763] VAL SCORE [0.156443]


100%|██████████| 890/890 [00:03<00:00, 235.26it/s]


FOLD [1] EPOCH [19] TRAIN LOSS [133.505510] TRAIN SCORE [0.159004] VAL LOSS [133.930391] VAL SCORE [0.156558]


100%|██████████| 890/890 [00:03<00:00, 232.55it/s]


FOLD [1] EPOCH [20] TRAIN LOSS [133.503844] TRAIN SCORE [0.159492] VAL LOSS [133.845264] VAL SCORE [0.156555]
FOLD [1] BEST VALIDATION SCORE [0.156558]


100%|██████████| 890/890 [00:03<00:00, 231.38it/s]


FOLD [2] EPOCH [01] TRAIN LOSS [138.506616] TRAIN SCORE [0.100064] VAL LOSS [137.927351] VAL SCORE [0.113112]


100%|██████████| 890/890 [00:03<00:00, 241.89it/s]


FOLD [2] EPOCH [02] TRAIN LOSS [136.409456] TRAIN SCORE [0.126316] VAL LOSS [136.216753] VAL SCORE [0.124445]


100%|██████████| 890/890 [00:03<00:00, 240.10it/s]


FOLD [2] EPOCH [03] TRAIN LOSS [135.871766] TRAIN SCORE [0.135369] VAL LOSS [135.734998] VAL SCORE [0.139365]


100%|██████████| 890/890 [00:03<00:00, 241.33it/s]


FOLD [2] EPOCH [04] TRAIN LOSS [135.364095] TRAIN SCORE [0.140977] VAL LOSS [135.779519] VAL SCORE [0.138728]


100%|██████████| 890/890 [00:03<00:00, 240.73it/s]


FOLD [2] EPOCH [05] TRAIN LOSS [135.253530] TRAIN SCORE [0.143684] VAL LOSS [135.526960] VAL SCORE [0.141046]


100%|██████████| 890/890 [00:03<00:00, 241.72it/s]


FOLD [2] EPOCH [06] TRAIN LOSS [135.337324] TRAIN SCORE [0.146232] VAL LOSS [136.590100] VAL SCORE [0.135097]


100%|██████████| 890/890 [00:03<00:00, 236.99it/s]


FOLD [2] EPOCH [07] TRAIN LOSS [135.113697] TRAIN SCORE [0.144739] VAL LOSS [135.714317] VAL SCORE [0.142346]


100%|██████████| 890/890 [00:03<00:00, 240.16it/s]


FOLD [2] EPOCH [08] TRAIN LOSS [133.808050] TRAIN SCORE [0.155989] VAL LOSS [134.142306] VAL SCORE [0.154719]


100%|██████████| 890/890 [00:03<00:00, 235.04it/s]


FOLD [2] EPOCH [09] TRAIN LOSS [133.662140] TRAIN SCORE [0.158609] VAL LOSS [134.125589] VAL SCORE [0.155342]


100%|██████████| 890/890 [00:03<00:00, 238.22it/s]


FOLD [2] EPOCH [10] TRAIN LOSS [133.672621] TRAIN SCORE [0.159609] VAL LOSS [134.356250] VAL SCORE [0.155062]


100%|██████████| 890/890 [00:03<00:00, 225.16it/s]


FOLD [2] EPOCH [11] TRAIN LOSS [133.451471] TRAIN SCORE [0.160144] VAL LOSS [134.291528] VAL SCORE [0.155904]


100%|██████████| 890/890 [00:03<00:00, 235.22it/s]


FOLD [2] EPOCH [12] TRAIN LOSS [133.338269] TRAIN SCORE [0.161635] VAL LOSS [133.953687] VAL SCORE [0.157203]


100%|██████████| 890/890 [00:03<00:00, 228.58it/s]


FOLD [2] EPOCH [13] TRAIN LOSS [133.306518] TRAIN SCORE [0.161709] VAL LOSS [133.986927] VAL SCORE [0.158031]


100%|██████████| 890/890 [00:03<00:00, 226.66it/s]


FOLD [2] EPOCH [14] TRAIN LOSS [133.469719] TRAIN SCORE [0.162374] VAL LOSS [133.919438] VAL SCORE [0.157174]


100%|██████████| 890/890 [00:04<00:00, 221.61it/s]


FOLD [2] EPOCH [15] TRAIN LOSS [133.243797] TRAIN SCORE [0.162260] VAL LOSS [133.891120] VAL SCORE [0.157287]


100%|██████████| 890/890 [00:03<00:00, 229.10it/s]


FOLD [2] EPOCH [16] TRAIN LOSS [133.371352] TRAIN SCORE [0.162385] VAL LOSS [133.803424] VAL SCORE [0.158298]


100%|██████████| 890/890 [00:03<00:00, 226.88it/s]


FOLD [2] EPOCH [17] TRAIN LOSS [133.266993] TRAIN SCORE [0.162390] VAL LOSS [133.951287] VAL SCORE [0.158170]


100%|██████████| 890/890 [00:03<00:00, 238.37it/s]


FOLD [2] EPOCH [18] TRAIN LOSS [133.261395] TRAIN SCORE [0.162524] VAL LOSS [133.888929] VAL SCORE [0.158056]


100%|██████████| 890/890 [00:03<00:00, 226.56it/s]


FOLD [2] EPOCH [19] TRAIN LOSS [133.312781] TRAIN SCORE [0.162811] VAL LOSS [133.968836] VAL SCORE [0.158754]


100%|██████████| 890/890 [00:03<00:00, 228.75it/s]


FOLD [2] EPOCH [20] TRAIN LOSS [133.312800] TRAIN SCORE [0.162721] VAL LOSS [133.919933] VAL SCORE [0.158286]
FOLD [2] BEST VALIDATION SCORE [0.158754]


100%|██████████| 890/890 [00:03<00:00, 236.07it/s]


FOLD [3] EPOCH [01] TRAIN LOSS [138.871625] TRAIN SCORE [0.093837] VAL LOSS [137.354096] VAL SCORE [0.112008]


100%|██████████| 890/890 [00:03<00:00, 229.28it/s]


FOLD [3] EPOCH [02] TRAIN LOSS [136.817691] TRAIN SCORE [0.122606] VAL LOSS [136.772512] VAL SCORE [0.123262]


100%|██████████| 890/890 [00:03<00:00, 237.17it/s]


FOLD [3] EPOCH [03] TRAIN LOSS [136.115410] TRAIN SCORE [0.129262] VAL LOSS [136.804587] VAL SCORE [0.118733]


100%|██████████| 890/890 [00:03<00:00, 231.36it/s]


FOLD [3] EPOCH [04] TRAIN LOSS [135.745803] TRAIN SCORE [0.134720] VAL LOSS [137.131027] VAL SCORE [0.127815]


100%|██████████| 890/890 [00:03<00:00, 232.73it/s]


FOLD [3] EPOCH [05] TRAIN LOSS [134.552214] TRAIN SCORE [0.145770] VAL LOSS [134.800392] VAL SCORE [0.144592]


100%|██████████| 890/890 [00:03<00:00, 235.90it/s]


FOLD [3] EPOCH [06] TRAIN LOSS [134.544118] TRAIN SCORE [0.148916] VAL LOSS [134.758030] VAL SCORE [0.145852]


100%|██████████| 890/890 [00:03<00:00, 226.32it/s]


FOLD [3] EPOCH [07] TRAIN LOSS [134.200495] TRAIN SCORE [0.150691] VAL LOSS [134.619622] VAL SCORE [0.144673]


100%|██████████| 890/890 [00:03<00:00, 232.28it/s]


FOLD [3] EPOCH [08] TRAIN LOSS [134.142843] TRAIN SCORE [0.151515] VAL LOSS [134.379384] VAL SCORE [0.148031]


100%|██████████| 890/890 [00:03<00:00, 228.16it/s]


FOLD [3] EPOCH [09] TRAIN LOSS [133.997741] TRAIN SCORE [0.152957] VAL LOSS [134.400100] VAL SCORE [0.148525]


100%|██████████| 890/890 [00:03<00:00, 228.90it/s]


FOLD [3] EPOCH [10] TRAIN LOSS [133.916702] TRAIN SCORE [0.153921] VAL LOSS [134.453349] VAL SCORE [0.149302]


100%|██████████| 890/890 [00:03<00:00, 226.49it/s]


FOLD [3] EPOCH [11] TRAIN LOSS [133.764775] TRAIN SCORE [0.155416] VAL LOSS [134.307338] VAL SCORE [0.150163]


100%|██████████| 890/890 [00:03<00:00, 226.11it/s]


FOLD [3] EPOCH [12] TRAIN LOSS [133.745181] TRAIN SCORE [0.155615] VAL LOSS [134.280209] VAL SCORE [0.150628]


100%|██████████| 890/890 [00:03<00:00, 226.84it/s]


FOLD [3] EPOCH [13] TRAIN LOSS [133.645165] TRAIN SCORE [0.156142] VAL LOSS [134.327252] VAL SCORE [0.150572]


100%|██████████| 890/890 [00:03<00:00, 228.68it/s]


FOLD [3] EPOCH [14] TRAIN LOSS [133.758355] TRAIN SCORE [0.156585] VAL LOSS [134.334596] VAL SCORE [0.150213]


100%|██████████| 890/890 [00:03<00:00, 228.17it/s]


FOLD [3] EPOCH [15] TRAIN LOSS [133.687885] TRAIN SCORE [0.156387] VAL LOSS [134.207352] VAL SCORE [0.151254]


100%|██████████| 890/890 [00:03<00:00, 235.93it/s]


FOLD [3] EPOCH [16] TRAIN LOSS [133.763168] TRAIN SCORE [0.156522] VAL LOSS [134.065083] VAL SCORE [0.150592]


100%|██████████| 890/890 [00:03<00:00, 231.63it/s]


FOLD [3] EPOCH [17] TRAIN LOSS [133.683580] TRAIN SCORE [0.156463] VAL LOSS [134.379958] VAL SCORE [0.151144]


100%|██████████| 890/890 [00:03<00:00, 226.80it/s]


FOLD [3] EPOCH [18] TRAIN LOSS [133.658650] TRAIN SCORE [0.156583] VAL LOSS [134.248132] VAL SCORE [0.150752]


100%|██████████| 890/890 [00:03<00:00, 229.64it/s]


FOLD [3] EPOCH [19] TRAIN LOSS [133.656567] TRAIN SCORE [0.156548] VAL LOSS [134.126497] VAL SCORE [0.150673]


100%|██████████| 890/890 [00:03<00:00, 228.25it/s]


FOLD [3] EPOCH [20] TRAIN LOSS [133.750725] TRAIN SCORE [0.156608] VAL LOSS [134.290027] VAL SCORE [0.150788]
FOLD [3] BEST VALIDATION SCORE [0.151254]


100%|██████████| 890/890 [00:03<00:00, 228.51it/s]


FOLD [4] EPOCH [01] TRAIN LOSS [138.476877] TRAIN SCORE [0.097698] VAL LOSS [136.909803] VAL SCORE [0.121729]


100%|██████████| 890/890 [00:03<00:00, 230.24it/s]


FOLD [4] EPOCH [02] TRAIN LOSS [136.346714] TRAIN SCORE [0.127615] VAL LOSS [138.011906] VAL SCORE [0.122532]


100%|██████████| 890/890 [00:03<00:00, 227.73it/s]


FOLD [4] EPOCH [03] TRAIN LOSS [135.857930] TRAIN SCORE [0.132559] VAL LOSS [135.297895] VAL SCORE [0.139898]


100%|██████████| 890/890 [00:03<00:00, 234.48it/s]


FOLD [4] EPOCH [04] TRAIN LOSS [135.230698] TRAIN SCORE [0.141607] VAL LOSS [136.772993] VAL SCORE [0.138586]


100%|██████████| 890/890 [00:03<00:00, 238.12it/s]


FOLD [4] EPOCH [05] TRAIN LOSS [135.069582] TRAIN SCORE [0.143752] VAL LOSS [135.382779] VAL SCORE [0.138560]


100%|██████████| 890/890 [00:03<00:00, 236.45it/s]


FOLD [4] EPOCH [06] TRAIN LOSS [133.907460] TRAIN SCORE [0.154191] VAL LOSS [133.887126] VAL SCORE [0.155304]


100%|██████████| 890/890 [00:03<00:00, 228.59it/s]


FOLD [4] EPOCH [07] TRAIN LOSS [133.731089] TRAIN SCORE [0.157014] VAL LOSS [133.856878] VAL SCORE [0.157118]


100%|██████████| 890/890 [00:03<00:00, 226.78it/s]


FOLD [4] EPOCH [08] TRAIN LOSS [133.570972] TRAIN SCORE [0.157959] VAL LOSS [133.641207] VAL SCORE [0.157565]


100%|██████████| 890/890 [00:03<00:00, 230.05it/s]


FOLD [4] EPOCH [09] TRAIN LOSS [133.569094] TRAIN SCORE [0.158796] VAL LOSS [133.790906] VAL SCORE [0.157563]


100%|██████████| 890/890 [00:03<00:00, 230.73it/s]


FOLD [4] EPOCH [10] TRAIN LOSS [133.420196] TRAIN SCORE [0.159362] VAL LOSS [133.620010] VAL SCORE [0.158831]


100%|██████████| 890/890 [00:03<00:00, 228.02it/s]


FOLD [4] EPOCH [11] TRAIN LOSS [133.412912] TRAIN SCORE [0.160445] VAL LOSS [133.643299] VAL SCORE [0.159276]


100%|██████████| 890/890 [00:04<00:00, 222.07it/s]


FOLD [4] EPOCH [12] TRAIN LOSS [133.497724] TRAIN SCORE [0.160983] VAL LOSS [133.557119] VAL SCORE [0.160373]


100%|██████████| 890/890 [00:03<00:00, 226.30it/s]


FOLD [4] EPOCH [13] TRAIN LOSS [133.252461] TRAIN SCORE [0.161377] VAL LOSS [133.456359] VAL SCORE [0.160637]


100%|██████████| 890/890 [00:03<00:00, 228.14it/s]


FOLD [4] EPOCH [14] TRAIN LOSS [133.321391] TRAIN SCORE [0.162034] VAL LOSS [133.674194] VAL SCORE [0.160965]


100%|██████████| 890/890 [00:03<00:00, 224.63it/s]


FOLD [4] EPOCH [15] TRAIN LOSS [133.228458] TRAIN SCORE [0.163198] VAL LOSS [133.426116] VAL SCORE [0.161275]


100%|██████████| 890/890 [00:04<00:00, 220.90it/s]


FOLD [4] EPOCH [16] TRAIN LOSS [133.128306] TRAIN SCORE [0.163282] VAL LOSS [133.522641] VAL SCORE [0.161724]


100%|██████████| 890/890 [00:03<00:00, 227.89it/s]


FOLD [4] EPOCH [17] TRAIN LOSS [133.169862] TRAIN SCORE [0.164241] VAL LOSS [133.620050] VAL SCORE [0.163005]


100%|██████████| 890/890 [00:03<00:00, 231.62it/s]


FOLD [4] EPOCH [18] TRAIN LOSS [132.828364] TRAIN SCORE [0.166191] VAL LOSS [133.441894] VAL SCORE [0.163486]


100%|██████████| 890/890 [00:03<00:00, 231.94it/s]


FOLD [4] EPOCH [19] TRAIN LOSS [132.819671] TRAIN SCORE [0.165757] VAL LOSS [133.307160] VAL SCORE [0.163795]


100%|██████████| 890/890 [00:03<00:00, 228.92it/s]


FOLD [4] EPOCH [20] TRAIN LOSS [133.082047] TRAIN SCORE [0.166395] VAL LOSS [133.250644] VAL SCORE [0.163600]
FOLD [4] BEST VALIDATION SCORE [0.163795]


100%|██████████| 890/890 [00:03<00:00, 231.16it/s]


FOLD [5] EPOCH [01] TRAIN LOSS [138.545555] TRAIN SCORE [0.096711] VAL LOSS [138.285748] VAL SCORE [0.117896]


100%|██████████| 890/890 [00:03<00:00, 233.67it/s]


FOLD [5] EPOCH [02] TRAIN LOSS [136.508250] TRAIN SCORE [0.125468] VAL LOSS [136.351164] VAL SCORE [0.129699]


100%|██████████| 890/890 [00:03<00:00, 235.86it/s]


FOLD [5] EPOCH [03] TRAIN LOSS [135.846989] TRAIN SCORE [0.135129] VAL LOSS [136.232476] VAL SCORE [0.136391]


100%|██████████| 890/890 [00:03<00:00, 227.06it/s]


FOLD [5] EPOCH [04] TRAIN LOSS [135.381464] TRAIN SCORE [0.138900] VAL LOSS [136.469865] VAL SCORE [0.132683]


100%|██████████| 890/890 [00:03<00:00, 226.93it/s]


FOLD [5] EPOCH [05] TRAIN LOSS [135.261848] TRAIN SCORE [0.142321] VAL LOSS [135.694708] VAL SCORE [0.142342]


100%|██████████| 890/890 [00:03<00:00, 224.50it/s]


FOLD [5] EPOCH [06] TRAIN LOSS [134.954262] TRAIN SCORE [0.144464] VAL LOSS [135.094635] VAL SCORE [0.147164]


100%|██████████| 890/890 [00:03<00:00, 228.12it/s]


FOLD [5] EPOCH [07] TRAIN LOSS [134.874433] TRAIN SCORE [0.145740] VAL LOSS [136.022490] VAL SCORE [0.135021]


100%|██████████| 890/890 [00:03<00:00, 240.08it/s]


FOLD [5] EPOCH [08] TRAIN LOSS [134.745013] TRAIN SCORE [0.146397] VAL LOSS [134.996411] VAL SCORE [0.147466]


100%|██████████| 890/890 [00:03<00:00, 229.72it/s]


FOLD [5] EPOCH [09] TRAIN LOSS [134.621923] TRAIN SCORE [0.147733] VAL LOSS [134.986535] VAL SCORE [0.145726]


100%|██████████| 890/890 [00:03<00:00, 237.14it/s]


FOLD [5] EPOCH [10] TRAIN LOSS [134.597000] TRAIN SCORE [0.148663] VAL LOSS [135.184306] VAL SCORE [0.149734]


100%|██████████| 890/890 [00:03<00:00, 226.85it/s]


FOLD [5] EPOCH [11] TRAIN LOSS [133.605186] TRAIN SCORE [0.156260] VAL LOSS [134.000122] VAL SCORE [0.155615]


100%|██████████| 890/890 [00:03<00:00, 227.44it/s]


FOLD [5] EPOCH [12] TRAIN LOSS [133.425593] TRAIN SCORE [0.157971] VAL LOSS [134.011844] VAL SCORE [0.156089]


100%|██████████| 890/890 [00:03<00:00, 225.04it/s]


FOLD [5] EPOCH [13] TRAIN LOSS [133.484816] TRAIN SCORE [0.158342] VAL LOSS [133.992345] VAL SCORE [0.155023]


100%|██████████| 890/890 [00:03<00:00, 227.57it/s]


FOLD [5] EPOCH [14] TRAIN LOSS [133.272434] TRAIN SCORE [0.159489] VAL LOSS [133.934649] VAL SCORE [0.157044]


100%|██████████| 890/890 [00:03<00:00, 224.94it/s]


FOLD [5] EPOCH [15] TRAIN LOSS [133.208258] TRAIN SCORE [0.159466] VAL LOSS [133.741711] VAL SCORE [0.156797]


100%|██████████| 890/890 [00:03<00:00, 226.36it/s]


FOLD [5] EPOCH [16] TRAIN LOSS [133.278642] TRAIN SCORE [0.159661] VAL LOSS [133.822960] VAL SCORE [0.156785]


100%|██████████| 890/890 [00:03<00:00, 236.16it/s]


FOLD [5] EPOCH [17] TRAIN LOSS [133.256029] TRAIN SCORE [0.159657] VAL LOSS [133.679817] VAL SCORE [0.157046]


100%|██████████| 890/890 [00:03<00:00, 238.36it/s]


FOLD [5] EPOCH [18] TRAIN LOSS [133.301143] TRAIN SCORE [0.159776] VAL LOSS [133.740755] VAL SCORE [0.156871]


100%|██████████| 890/890 [00:03<00:00, 235.39it/s]


FOLD [5] EPOCH [19] TRAIN LOSS [133.307326] TRAIN SCORE [0.159978] VAL LOSS [133.745292] VAL SCORE [0.157402]


100%|██████████| 890/890 [00:03<00:00, 232.33it/s]


FOLD [5] EPOCH [20] TRAIN LOSS [133.258022] TRAIN SCORE [0.159848] VAL LOSS [133.780795] VAL SCORE [0.157135]
FOLD [5] BEST VALIDATION SCORE [0.157402]


In [6]:
def predict(embeddings_source, models):
    """Generate predictions using ensemble of models."""
    test_dataset = ProteinSequenceDataset(datatype="test", embeddings_source=embeddings_source)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    
    for model in models:
        model.eval()
    
    labels = pd.read_csv(Config.train_labels_path, sep="\t")
    top_terms = labels.groupby("term")["EntryID"].count().sort_values(ascending=False)
    labels_names = top_terms[:Config.num_labels].index.values
    print("GENERATE PREDICTION FOR TEST SET...")
    
    protein_ids = np.empty(shape=(len(test_dataloader) * Config.num_labels,), dtype=object)
    go_terms = np.empty(shape=(len(test_dataloader) * Config.num_labels,), dtype=object)
    confidences = np.empty(shape=(len(test_dataloader) * Config.num_labels,), dtype=np.float32)
    
    for i, (embed, protein_id) in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
        embed = embed.to(Config.device)
        preds_list = []
        for model in models:
            with torch.no_grad():
                preds = nn.functional.sigmoid(model(embed)).squeeze().detach().cpu().numpy()
            preds_list.append(preds)
        avg_preds = np.mean(preds_list, axis=0)
        confidences[i * Config.num_labels:(i + 1) * Config.num_labels] = avg_preds
        protein_ids[i * Config.num_labels:(i + 1) * Config.num_labels] = protein_id[0]
        go_terms[i * Config.num_labels:(i + 1) * Config.num_labels] = labels_names
    
    submission_df = pd.DataFrame({"Id": protein_ids, "GO term": go_terms, "Confidence": confidences})
    print("PREDICTIONS DONE")
    return submission_df

submission_df = predict("ProtBERT", protbert_models)
submission_df.to_csv('submission.tsv', sep='\t', header=False, index=False)

GENERATE PREDICTION FOR TEST SET...


141865it [01:06, 2124.00it/s]


PREDICTIONS DONE


In [None]:
!kaggle competitions submit -c cafa-6-protein-function-prediction -f submission.tsv -m "CAFA6 ProtBERT Embeddings Stratified 5-Fold"

 20%|███████▉                               | 397M/1.91G [00:11<00:46, 35.5MB/s]