In [None]:
# Generate embeddings from ESM-2 models

import torch
import pandas as pd
from transformers import EsmTokenizer, EsmModel
from tqdm import tqdm
import os


file_name = 'esm-2_embeddings.pickle.zip'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ESM-2 checkpoints whose mean-pooled embeddings are stored as DataFrame columns.
ESM_MODEL_NAMES = ["facebook/esm2_t30_150M_UR50D", "facebook/esm2_t12_35M_UR50D"]


def embedding_column_name(model_name):
    """Return the DataFrame column that holds `model_name`'s embeddings.

    Only the parameter-count token is kept, e.g.
    'facebook/esm2_t30_150M_UR50D' -> 'embedding_150M', which is the name the
    training cells look up ('embedding_35M' / 'embedding_150M').
    """
    return f"embedding_{model_name.split('_')[-2]}"


def embed_sequence(seq, tokenizer, model):
    """Mean-pool the last hidden layer of `model` over one sequence.

    `seq` is a single sequence string, so the batch dimension of size 1 is
    squeezed away. The first and last positions are excluded from the mean;
    presumably these are the special tokens added by `add_special_tokens=True`.
    Sequences longer than 1024 tokens are truncated.

    Returns:
        1-D numpy array (hidden size of `model`).
    """
    inputs = tokenizer(seq, return_tensors="pt", padding=True,
                       truncation=True, add_special_tokens=True, max_length=1024)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    token_embeddings = outputs.last_hidden_state.squeeze(0)
    return token_embeddings[1:-1].mean(dim=0).cpu().numpy()


if not os.path.exists(file_name):
    df = pd.read_pickle('../../data/all_with_candidates.pickle')
    tqdm.pandas()
    for model_name in ESM_MODEL_NAMES:
        tokenizer = EsmTokenizer.from_pretrained(model_name)
        model = EsmModel.from_pretrained(model_name).to(device).eval()
        df[embedding_column_name(model_name)] = df['seq'].progress_apply(
            lambda seq: embed_sequence(seq, tokenizer, model))
        # Drop the backbone before the next one is loaded so both are never resident at once.
        del model, tokenizer

    df.to_pickle(file_name)

else:
    df = pd.read_pickle(file_name)
    # Older caches may store the columns under the longer 'embedding_t30_150M'-style
    # names; rename them so the training cells find 'embedding_150M' / 'embedding_35M'.
    df = df.rename(columns={
        f"embedding_{'_'.join(name.split('_')[-3:-1])}": embedding_column_name(name)
        for name in ESM_MODEL_NAMES
    })

In [None]:
# Hyperparameter optimization of frozen backbone model -> Effectively only training linear classification head on embeddings

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import optuna
from sklearn.metrics import (
    roc_auc_score, precision_recall_curve, f1_score, accuracy_score,
    precision_score, average_precision_score, recall_score
)
from tqdm import tqdm
from pytorch_lightning import seed_everything
import warnings
import os

pd.options.display.max_columns = 200  # show wide result tables without column elision


def get_best_f1_threshold(y_true, probs):
    """Pick the score cutoff with the highest F1 on (y_true, probs).

    Candidate cutoffs are the thresholds returned by `precision_recall_curve`,
    whose precision/recall at cutoff t count `probs >= t` as positive. The 1e-8
    in the denominator keeps F1 finite where precision and recall are both 0.

    Returns:
        The threshold (an element of the curve's threshold array) at the F1 maximum.
    """
    prec, rec, cutoffs = precision_recall_curve(y_true, probs)
    f1_per_point = 2 * (prec * rec) / (prec + rec + 1e-8)
    best_point = np.argmax(f1_per_point)
    return cutoffs[best_point]


class LinearModel(nn.Module):
    """A single affine layer mapping an embedding vector to one logit.

    The layer lives at ``self.linear``, so state-dict keys are ``linear.weight`` and
    ``linear.bias`` — the keys stored in the saved checkpoints.
    """

    def __init__(self, input_dim):
        super(LinearModel, self).__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=1)

    def forward(self, x):
        """Return the logits with the trailing size-1 output dimension removed."""
        logits = self.linear(x)
        return logits.squeeze(dim=-1)


def train_model(lr=1e-3, max_epochs=5000, patience=25, pos_weight=1.0, embedding_col='embedding_35M'):
    """Fit a linear logistic-regression head on precomputed embeddings.

    Full-batch training with Adam and ``BCEWithLogitsLoss``. After every step the
    model is scored by ROC AUC on the validation split. Training stops once that
    AUC has not improved for ``patience`` consecutive epochs.

    Reads the module-level globals ``df_train`` and ``df_val`` (each needs the
    ``embedding_col`` and ``label`` columns) and ``device``.

    Args:
        lr: Adam learning rate.
        max_epochs: upper bound on the number of full-batch steps.
        patience: epochs without a validation-AUC improvement before stopping.
        pos_weight: loss weight applied to positive examples.
        embedding_col: column holding one embedding vector per row.

    Returns:
        ``(model, best_auc)``: the model restored to the weights that reached the
        highest validation AUC, and that AUC.
    """
    X_train = np.stack(df_train[embedding_col].values)
    y_train = df_train['label'].astype(int).values
    X_val = np.stack(df_val[embedding_col].values)
    y_val = df_val['label'].astype(int).values

    model = LinearModel(X_train.shape[1]).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([pos_weight]).to(device))

    X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
    y_train_tensor = torch.tensor(y_train, dtype=torch.float32).to(device)
    X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)

    best_auc = -np.inf
    best_model_state = None
    patience_counter = 0

    for _ in range(max_epochs):
        model.train()
        optimizer.zero_grad()
        logits = model(X_train_tensor)
        loss = loss_fn(logits, y_train_tensor)
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            val_logits = model(X_val_tensor)
            val_probs = torch.sigmoid(val_logits).cpu().numpy()
            val_auc = roc_auc_score(y_val, val_probs)

        if val_auc > best_auc:
            best_auc = val_auc
            # state_dict() hands back tensors sharing storage with the live parameters,
            # which later optimizer steps overwrite in place; clone so the snapshot
            # really holds this epoch's weights.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                break

    # No snapshot exists if no epoch ran (max_epochs <= 0) or no AUC ever improved on -inf.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, best_auc


def objective(trial, study):
    """Optuna objective: train one linear head and return its best validation AUROC.

    Samples the learning rate, the positive-class loss weight and which ESM-2
    embedding column to use, then trains via ``train_model``. When a trial beats the
    best AUROC seen so far for the current split, it saves the weights and
    hyperparameters to ``saved_models_frozen_optuna/best_model_split_{i}.pt``. That
    path is the one the evaluation cells check for and load.

    Relies on module-level globals set by the driver loop:
    ``i`` (current split index) and ``best_val_auc`` (a one-element list holding
    the best AUROC so far for that split). ``study`` is not used.
    """
    lr = trial.suggest_float('lr', 1e-6, 1e-2, log=True)
    # The trial records the exponent; the loss weight is 10**exponent, i.e. 1, ~3.16 or 10.
    pos_weight = 10**(trial.suggest_float('pos_weight', 0.0, 1.0, step=0.5))
    embedding_col = trial.suggest_categorical('embedding_col', ['embedding_35M', 'embedding_150M'])
    model, val_auc = train_model(lr=lr, max_epochs=5000, patience=25, pos_weight=pos_weight, embedding_col=embedding_col)

    # Compare with the best val AUROC of this split so far; checkpoint only improvements.
    if val_auc > best_val_auc[0]:
        best_val_auc[0] = val_auc
        # Must match the directory the evaluation cells read checkpoints from.
        save_dir = 'saved_models_frozen_optuna'
        os.makedirs(save_dir, exist_ok=True)
        save_path = f'{save_dir}/best_model_split_{i}.pt'
        torch.save({
            'model_state_dict': model.state_dict(),
            'lr': lr,
            'val_auc': val_auc,
            'pos_weight': pos_weight,
            'embedding_col': embedding_col,
        }, save_path)

    return val_auc

In [None]:
seed_everything(42)
warnings.filterwarnings("ignore", category=UserWarning)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def predict_probs(model, X):
    """Return sigmoid probabilities of `model` on the array `X` (cast to float32, run on `device`)."""
    with torch.no_grad():
        return torch.sigmoid(model(torch.tensor(X, dtype=torch.float32).to(device))).cpu().numpy()


df = pd.read_pickle('esm-2_embeddings.pickle.zip')
# NOTE(review): every split is evaluated on the rows marked 'test' in split 0; this
# assumes all `group_split_{i}` columns share the same test rows — confirm.
df_test = df[df['group_split_0'] == 'test'].copy()
y_test = df_test['label'].astype(int).values

results = []

for i in tqdm(range(5)):
    # `df_train`, `df_val`, `i` and `best_val_auc` are read as globals by
    # `train_model` / `objective`, so these names must not change.
    df_train = df[df[f'group_split_{i}'] == 'train']
    df_val = df[df[f'group_split_{i}'] == 'val']
    ckpt_path = f'saved_models_frozen_optuna/best_model_split_{i}.pt'

    if not os.path.exists(ckpt_path):
        # Perform HPO if not ran before
        study = optuna.create_study(storage=f'sqlite:///optuna_study_split_{i}.db', load_if_exists=False,
                                    study_name=f'linear_split_{i}',
                                    sampler=optuna.samplers.TPESampler(seed=42), direction='maximize')

        best_val_auc = [-np.inf]  # mutable container so `objective` can update it in place
        study.optimize(lambda trial: objective(trial, study), n_trials=50, n_jobs=1)

    ####################### Best model after HPO ########################
    # weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    embedding_col = ckpt['embedding_col']
    X_val = np.stack(df_val[embedding_col].values)
    y_val = df_val['label'].astype(int).values
    X_test = np.stack(df_test[embedding_col].values)
    X_all = np.stack(df[embedding_col].values)

    model = LinearModel(X_val.shape[1]).to(device)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    val_probs = predict_probs(model, X_val)
    test_probs = predict_probs(model, X_test)
    all_probs = predict_probs(model, X_all)
    best_thresh = get_best_f1_threshold(y_val, val_probs)

    # precision_recall_curve scores each threshold t with `probs >= t`; use the same
    # comparison so the predictions sit at the F1-optimal operating point.
    test_pred = test_probs >= best_thresh

    df[f'linreg_{i}'] = all_probs
    df[f'linreg_{i}_pred_label'] = all_probs >= best_thresh

    result = {
        'split': i,
        'best_thresh': best_thresh,
        'AUROC': roc_auc_score(y_test, test_probs),
        'APS': average_precision_score(y_test, test_probs),
        'F1-score': f1_score(y_test, test_pred),
        'Precision': precision_score(y_test, test_pred),
        'Recall': recall_score(y_test, test_pred),
    }

    # Precision@K: fraction of positives among the K highest-scoring test rows.
    df_test['tmp'] = test_probs
    for K in [50, 100, 200]:
        top_k = df_test.nlargest(K, 'tmp')
        result[f'Precision@{K}'] = precision_score(top_k['label'] == 1.0, np.ones_like(top_k['tmp']))

    results.append(result)

results = pd.DataFrame(results)

In [None]:
# Aggregated individual models performance: mean ± std of each metric over the splits.
# 'split' is only the row identifier, so it is left out of the aggregate.
metrics_per_split = results.drop(columns='split')
mean_std = metrics_per_split.agg(['mean', 'std']).T
mean_std_formatted = mean_std.apply(lambda row: f"{row['mean']:.2f}±{row['std']:.2f}", axis=1)
result_df = pd.DataFrame(mean_std_formatted, columns=['Mean ± Std'])

result_df.T

In [None]:
# Hyperparameters chosen by the HPO for each split, read back from the saved checkpoints.
_hp_fields = {'lr': 'lr', 'pos_weight': 'pos_weight', 'emb': 'embedding_col'}
hp = []
for i in range(5):
    ckpt = torch.load(f'saved_models_frozen_optuna/best_model_split_{i}.pt', map_location=device, weights_only=False)
    hp.append({'split': i, **{column: ckpt[key] for column, key in _hp_fields.items()}})
pd.DataFrame(hp)

In [None]:
# Ensemble performance: average the five per-split models' probabilities.
df['linreg_ensemble'] = df[[f'linreg_{split}' for split in range(5)]].mean(axis=1)

df_test = df[df['group_split_0'] == 'test'].copy()
# The threshold is tuned on the non-test (train + val) rows of split 0, the same split
# `df_test` is taken from, so the cell does not depend on a loop index left by another cell.
# NOTE(review): each member model was trained on part of these rows, so the threshold is
# partly fit to in-sample predictions — consider out-of-fold predictions if that matters.
df_trainval = df[df['group_split_0'].isin(['train', 'val'])]
y_trainval = df_trainval['label'].astype(int).values
best_thresh = get_best_f1_threshold(y_trainval, df_trainval['linreg_ensemble'])
test_probs = df_test['linreg_ensemble']
y_test = df_test['label'].astype(int).values
# precision_recall_curve counts `score >= threshold` as positive; compare the same way.
test_pred = test_probs >= best_thresh

result = {
    'split': 'ensemble',
    'AUROC': roc_auc_score(y_test, test_probs),
    'APS': average_precision_score(y_test, test_probs),
    'F1-score': f1_score(y_test, test_pred),
    'Precision': precision_score(y_test, test_pred),
    'Recall': recall_score(y_test, test_pred),
}

# Precision@K: fraction of positives among the K highest-scoring test rows.
for K in [50, 100, 200]:
    top_k = df_test.nlargest(K, 'linreg_ensemble')
    result[f'Precision@{K}'] = precision_score(top_k['label'] == 1.0, np.ones_like(top_k['linreg_ensemble']))

result = pd.DataFrame([result])
result.round(2)