In [1]:
# Prediction with truncation at LENGTH_THRESHOLD

LENGTH_THRESHOLD = 2048

import os
import argparse

import torch
import pandas as pd
from tqdm import tqdm

import sys
sys.path.append('../../')

from dataloader import ProteinDataset
from transfactor_model import TransFactor
from baseline_model import CNNLSTMPredictor


pd.set_option('display.max_columns', 200)


out_dir = 'results_longer_context'
if not os.path.exists(f'transfactor_len_{LENGTH_THRESHOLD}.pickle.zip'):
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', type=str, default='../saved_models/transfactor/transfactor_0.ckpt')  # for argparse to work in jupyter notebook
    parser.add_argument('--ckpt_path', type=str, default='../saved_models/transfactor/transfactor_0.ckpt')
    parser.add_argument('--model_type', type=str, default='esm')
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--data_path', type=str, default='../../data/all_with_candidates.pickle')
    parser.add_argument('--out_dir', type=str, default='results_longer_context')
    parser.add_argument('--subsample', type=int, default=None)
    args = parser.parse_args()
    
    df = pd.read_pickle(args.data_path)
    # Replace candidate with 0.5 to make column numeric
    df['label'] = df['label'].replace('candidate', 0.5).astype(float)
        
    # Non candidate proteins
    df = df[df['group_split_0'].notna()]
    
    if args.device is None:
    	device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
    	device = args.device
    
    for i in range(5):
        args.ckpt_path = f'../saved_models/transfactor/transfactor_{i}.ckpt'
        model_name = os.path.basename(args.ckpt_path).replace('.ckpt', '')
        if args.model_type == 'esm':
            model = TransFactor.load_from_checkpoint(args.ckpt_path)
        elif args.model_type == 'cnn_lstm':
            model = CNNLSTMPredictor.load_from_checkpoint(args.ckpt_path)
        
        model.eval()
        model = model.to(device)
        
        data_config = model.data_config.copy()
        data_config['max_seq_len'] = LENGTH_THRESHOLD
        dataset = ProteinDataset(df=df, config=data_config, split=None, tokenizer=model.tokenizer)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False)
        
        y_hats = []
        with torch.no_grad():
            for batch in tqdm(dataloader):
                batch = [item.to(device) for item in batch]
                seq, attention_mask, y = batch
                y_hat_logits = model(seq, attention_mask)
                y_hat = model.sigmoid(y_hat_logits)
                y_hats.append(y_hat.detach().cpu())
            y_hat = torch.cat(y_hats)
        
        df[model_name+f"_{LENGTH_THRESHOLD}"] = y_hat
        
        os.makedirs(args.out_dir, exist_ok=True)
    
    # Add previous predictions on 1024 length sequences.
    df_raw = pd.read_csv('../benchmark/prediction_values_all_models.csv.zip')
    df = df.merge(df_raw[['protein_ac', 'transfactor_0', 'transfactor_1', 'transfactor_2', 'transfactor_3', 'transfactor_4']], 'left', on='protein_ac')
    df = df.rename(columns={'transfactor_0': 'transfactor_1024_0', 'transfactor_1': 'transfactor_1024_1', 'transfactor_2': 'transfactor_1024_2', 'transfactor_3': 'transfactor_1024_3', 'transfactor_4': 'transfactor_1024_4'})
    df['len'] = df['seq'].str.len()
    df['transfactor_ensemble_1024'] = df[[f'transfactor_1024_{i}' for i in range(5)]].mean(1)
    df[f'transfactor_ensemble_{LENGTH_THRESHOLD}'] = df[[f'transfactor_{LENGTH_THRESHOLD}_{i}' for i in range(5)]].mean(1)

    df.to_pickle(f'transfactor_len_{LENGTH_THRESHOLD}.pickle.zip')
else:
    df = pd.read_pickle(f'transfactor_len_{LENGTH_THRESHOLD}.pickle.zip')

In [2]:
# Performance metrics

from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score, accuracy_score, precision_score, average_precision_score, recall_score, precision_recall_curve
import numpy as np

df['label'] = df['label'].astype(str)
predictions = df
Ks = [50, 100, 200]
results = []

for modality in ['transfactor_1024', 'transfactor_2048']:
    for i in list(range(5)) + ['ensemble']:
        config = f'{modality}_{i}'
        result = {'config': config,
              'split': i,
              'ensemble': i == 'ensemble',
              'modality': modality,
             }
        # Without candidates
        prediction_subset = predictions[predictions['label'].isin(['1.0', '0.0'])]
        test_subset = prediction_subset[prediction_subset['group_split_0'] == 'test'].copy()
        result.update({
                  f'auc': roc_auc_score(test_subset['label'] == '1.0', test_subset[config]),
                  f'aps': average_precision_score(test_subset['label'] == '1.0', test_subset[config]),
                 })
        
        for K in Ks:
            test_subset_top_k = test_subset.loc[test_subset[config].nlargest(K).index]
            result.update({f'P@{K}': precision_score(test_subset_top_k['label'] == '1.0', np.ones_like(test_subset_top_k[config]))})

        # ideal threshold
        if i == 'ensemble':
            val_subset = prediction_subset[prediction_subset[f'group_split_0'].isin(['train', 'val'])].copy()
        else:
            val_subset = prediction_subset[prediction_subset[f'group_split_{i}'] == 'val'].copy()
        precision, recall, thresholds = precision_recall_curve(val_subset['label'] == '1.0', val_subset[f'{modality}_ensemble'])
        f1_scores = 2*precision*recall / (precision+recall)
        best_threshold = thresholds[f1_scores.argmax()]

        result.update({
                  f'precision': precision_score(test_subset['label'] == '1.0', test_subset[config] > best_threshold, ),
                  f'recall': recall_score(test_subset['label'] == '1.0', test_subset[config] > best_threshold, ),
                  f'f1': f1_score(test_subset['label'] == '1.0', test_subset[config] > best_threshold, ),
                 })
        results.append(result)

results = pd.DataFrame(results)

results_ensemble = results[results['ensemble']].copy()
results_ensemble = results_ensemble[['config'] + [f'{metric}' for metric in ['auc', 'aps', 'f1', 'precision', 'recall']] + [f'P@{K}' for K in Ks]]
display(results_ensemble.round(2))

cols = [f'{metric}' for metric in ['auc', 'aps', 'precision', 'recall', 'f1']]
cols += [f'P@{K}' for K in [50, 100, 200]]
results_without_ensemble = results[~results['ensemble']].copy()
results_without_ensemble = results_without_ensemble.melt(id_vars=['modality'], value_vars=cols)

display(results_without_ensemble.groupby(['modality', 'variable']).describe()['value'][['mean', 'std']].apply(lambda x: f"{x['mean']:.2f}±{x['std']:.2f}", axis=1).to_frame().reset_index().rename(columns={0: 'value'})\
.pivot(index='modality', columns='variable', values='value')[['auc', 'aps', 'f1', 'precision', 'recall'] + [f'P@{K}' for K in Ks]])

Unnamed: 0,config,auc,aps,f1,precision,recall,P@50,P@100,P@200
5,transfactor_1024_ensemble,0.89,0.3,0.38,0.34,0.44,0.44,0.37,0.28
11,transfactor_2048_ensemble,0.89,0.31,0.39,0.34,0.47,0.4,0.41,0.3


variable,auc,aps,f1,precision,recall,P@50,P@100,P@200
modality,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
transfactor_1024,0.87±0.01,0.27±0.01,0.25±0.12,0.40±0.34,0.39±0.27,0.37±0.03,0.32±0.02,0.26±0.01
transfactor_2048,0.88±0.01,0.28±0.03,0.28±0.14,0.31±0.12,0.44±0.27,0.38±0.03,0.32±0.02,0.27±0.03
