# Setup

In [None]:
# Base directory against which the benchmark CSV paths and the 'Results/'
# output folder are resolved (relative to the current working directory).
BENCHMARKS_DIR = '../'

In [None]:
def AUC_PRC(y_true,y_pred,savepath):
    """Score binary predictions, plot ROC / PR curves and persist the results.

    The following files are written, each name being appended directly to
    ``savepath`` (plain string concatenation, so ``savepath`` must already end
    with a path separator when it denotes a directory):

    * ``y_true.txt``  - the labels as a Python list literal
    * ``y_pred.txt``  - the raw scores as a Python list literal
    * ``AUC_PRC.svg`` - side-by-side ROC and precision-recall plots
    * ``results.txt`` - AUROC, AUPRC, accuracy, sensitivity, specificity, MCC

    Args:
        y_true: Ground-truth labels; anything ``np.asarray`` accepts.
            Expected to contain only 0 and 1.
        y_pred: Predicted scores, same length as ``y_true``. They are rounded
            to the nearest integer to obtain hard labels, so they are expected
            to lie in [0, 1] - TODO confirm against the caller.
        savepath: String prefix for every output file.

    Returns:
        The area under the ROC curve.
    """
    # Accept lists / Series as well as arrays; `.tolist()` below needs an array.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Dump the raw inputs first so they survive even if a metric below fails.
    with open(savepath+"y_true.txt","w",encoding="utf-8") as f:
        f.write(str(y_true.tolist()))
    with open(savepath+"y_pred.txt","w",encoding="utf-8") as f:
        f.write(str(y_pred.tolist()))

    # Threshold-free metrics are computed on the raw scores.
    fpr, tpr, _ = roc_curve(y_true, y_pred)
    roc_auc = auc(fpr, tpr)

    precision, recall, _ = precision_recall_curve(y_true, y_pred)
    prc_auc = auc(recall, precision)

    # Hard labels: np.around rounds half to even, so a score of exactly 0.5
    # becomes class 0.
    y_label = np.around(y_pred,0).astype(int)

    # Pin the label set: without it the matrix collapses to 1x1 whenever only
    # one class occurs in both y_true and y_label, and the unpack below raises.
    tn, fp, fn, tp = confusion_matrix(y_true, y_label, labels=[0, 1]).ravel()

    def _ratio(numerator, denominator):
        # An undefined ratio (empty denominator) is reported as nan instead of
        # relying on NumPy's divide-by-zero warning path.
        return numerator / denominator if denominator else float('nan')

    acc = _ratio(tp + tn, tn + fp + fn + tp)
    sn = _ratio(tp, tp + fn)
    sp = _ratio(tn, tn + fp)
    mcc = matthews_corrcoef(y_true, y_label)

    # NOTE(review): the figure is left open after saving, as before; callers
    # that evaluate many models in one process may want to plt.close() it.
    plt.figure(figsize=(12, 6))
    # ROC
    plt.subplot(1, 2, 1)
    plt.plot(fpr, tpr, color='orange', lw=2, label=f'ROC curve (AUROC = {roc_auc:.2f})')
    # Diagonal reference line (TPR == FPR).
    plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc='lower right')
    # PRC
    plt.subplot(1, 2, 2)
    plt.plot(recall, precision, color='orange', lw=2, label=f'PRC curve (AUPRC = {prc_auc:.2f})')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='lower left')

    plt.tight_layout()
    plt.savefig(savepath+"AUC_PRC.svg",dpi=600)

    with open(savepath+"results.txt","w",encoding="utf-8") as f:
        f.write(f'AUROC: {roc_auc:.4f}\n')
        f.write(f'AUPRC: {prc_auc:.4f}\n')
        f.write(f'Accuracy: {acc:.4f}\n')
        f.write(f'Sensitivity (Recall): {sn:.4f}\n')
        f.write(f'Specificity: {sp:.4f}\n')
        f.write(f'Matthews Correlation Coefficient (MCC): {mcc:.4f}\n')
    return roc_auc

In [None]:
def PBertKla(item):
    """Fine-tune the pre-trained model on the PBertKla benchmark and evaluate it.

    Reads ``Data/PBertKla_train.csv`` and ``Data/PBertKla_test.csv`` under
    ``BENCHMARKS_DIR`` (columns ``seq`` and ``label`` are used), holds out 10%
    of the training data as a stratified validation set, fine-tunes, evaluates
    on the test set and writes the evaluation artefacts produced by
    ``AUC_PRC`` into ``<BENCHMARKS_DIR>/Results/``.

    Args:
        item: A ``(batch_size, lr, seqlen)`` tuple: training batch size,
            learning rate passed to ``finetune`` as ``lr``, and the sequence
            length used for both fine-tuning and the start of evaluation.

    Returns:
        The test-set AUROC as returned by ``AUC_PRC`` (previously the value
        was computed and discarded).
    """
    (batch_size,lr,seqlen) = item
    BENCHMARK_NAME = 'Data/PBertKla'
    # A local (non-global) binary output
    OUTPUT_TYPE = OutputType(False, 'binary')
    UNIQUE_LABELS = [0, 1]
    OUTPUT_SPEC = OutputSpec(OUTPUT_TYPE, UNIQUE_LABELS)

    # Loading the dataset

    train_set_file_path = os.path.join(BENCHMARKS_DIR, '%s_train.csv' % BENCHMARK_NAME)
    train_set = pd.read_csv(train_set_file_path).dropna().drop_duplicates()
    # Fixed random_state keeps the train/validation split reproducible across runs.
    train_set, valid_set = train_test_split(train_set, stratify = train_set['label'], test_size = 0.1, random_state = 0)

    test_set_file_path = os.path.join(BENCHMARKS_DIR, '%s_test.csv' % BENCHMARK_NAME)
    test_set = pd.read_csv(test_set_file_path).dropna().drop_duplicates()

    print(f'{len(train_set)} training set records, {len(valid_set)} validation set records, {len(test_set)} test set records.')


    # Loading the pre-trained model and fine-tuning it on the loaded dataset

    pretrained_model_generator, input_encoder = load_pretrained_model()

    # get_model_with_hidden_layers_as_outputs gives the model output access to the hidden layers (on top of the output)
    model_generator = FinetuningModelGenerator(pretrained_model_generator, OUTPUT_SPEC, pretraining_model_manipulation_function = \
            get_model_with_hidden_layers_as_outputs, dropout_rate = 0.5)

    training_callbacks = [
        keras.callbacks.ReduceLROnPlateau(patience = 1, factor = 0.25, min_lr = 1e-05, verbose = 1),
        keras.callbacks.EarlyStopping(patience = 2, restore_best_weights = True),
    ]

    finetune(model_generator, input_encoder, OUTPUT_SPEC, train_set['seq'], train_set['label'], valid_set['seq'], valid_set['label'], \
            seq_len = seqlen, batch_size = batch_size, max_epochs_per_stage = 40, lr = lr, begin_with_frozen_pretrained_layers = True, \
            lr_with_frozen_pretrained_layers = 1e-02, n_final_epochs = 1, final_seq_len = 1024, final_lr = 1e-05, callbacks = training_callbacks)


    # Evaluating the performance on the test-set

    # Only the labels and scores are needed here. The first two return values
    # are discarded; the second was previously bound to a local named
    # `confusion_matrix`, shadowing the imported sklearn function.
    _, _, y_true, y_pred = evaluate_by_len(model_generator, input_encoder, OUTPUT_SPEC, test_set['seq'], test_set['label'], \
            start_seq_len = seqlen, start_batch_size = batch_size)

    # AUC_PRC appends file names directly to this string, so the trailing
    # separator (produced by the empty last component) is required.
    savepath = os.path.join(BENCHMARKS_DIR, 'Results', '')
    os.makedirs(savepath, exist_ok=True)
    return AUC_PRC(y_true,y_pred,savepath)

In [None]:
import os
import pandas as pd
import itertools
import multiprocessing as mp
from IPython.display import display
from tensorflow import keras
from sklearn.model_selection import train_test_split
from proteinbert import OutputType, OutputSpec, FinetuningModelGenerator, load_pretrained_model, finetune, evaluate_by_len
from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
from sklearn.metrics import accuracy_score, confusion_matrix, matthews_corrcoef


# Script entry point: run a single fine-tuning / evaluation pass.
if __name__ == '__main__':
    # Hyper-parameters in the order PBertKla unpacks them: (batch_size, lr, seqlen).
    item=(4,5e-4,256)
    PBertKla(item)