# Finetune ProteinBERT for ASM classification
### 0. (OPTIONAL) Prepare binary-classification datasets from FASTA files

In [None]:

import random

from Bio import SeqIO

# Cross-validation fold (1-6) whose FASTA splits are converted below; kept as a
# string because it is concatenated directly into the file paths.
FOLD_NO: str = "6"


def read_fasta(file_path, label):
    """Parse every record of a FASTA file and tag it with a class label.

    Args:
        file_path: Path of the FASTA file to read.
        label: Value attached to every sequence read from this file.

    Returns:
        A list of ``(sequence, label)`` tuples in file order, where
        ``sequence`` is the record's sequence converted to a plain string.
    """
    with open(file_path, "r") as handle:
        return [(str(entry.seq), label) for entry in SeqIO.parse(handle, "fasta")]


def combine_and_shuffle(sequences1, sequences2, rng=None):
    """Concatenate two lists of ``(sequence, label)`` tuples, shuffle them, and build a DataFrame.

    Args:
        sequences1: First list of ``(sequence, label)`` tuples.
        sequences2: Second list of ``(sequence, label)`` tuples.
        rng: Optional ``random.Random`` instance used for the shuffle, so the
            output order can be made reproducible. When ``None`` (the default)
            the module-level ``random`` state is used, as before.

    Returns:
        A ``pandas.DataFrame`` with columns ``seq`` and ``label``, one row per
        input tuple, in shuffled order. The input lists are not modified.
    """
    # Local import: this preparation cell can run before the cell that imports
    # pandas at module level, which would otherwise raise NameError on `pd`.
    import pandas as pd

    # `+` builds a new list, so the in-place shuffle leaves the inputs untouched.
    combined = sequences1 + sequences2
    if rng is None:
        random.shuffle(combined)
    else:
        rng.shuffle(combined)
    return pd.DataFrame(combined, columns=['seq', 'label'])


import os

# Paths to the FASTA files for the selected fold
path_negative_train = "data(train+val)/negative/PB40/PB40_1z20_clu50_trn" + FOLD_NO + ".fa"
path_positive_train = "data(train+val)/positive/bass_motif/pad/bass_ctm_motif_trn" + FOLD_NO + ".fa"
path_negative_val = "data(train+val)/negative/PB40/PB40_1z20_clu50_val" + FOLD_NO + ".fa"
path_positive_val = "data(train+val)/positive/bass_motif/pad/bass_ctm_motif_val" + FOLD_NO + ".fa"
path_negative_test = "data(train+val)/negative/PB40/PB40_1z20_clu50_val" + FOLD_NO + ".fa"
path_positive_test = "data(train+val)/positive/bass_motif/bass_ntm_motif_test.fa"
# NOTE(review): path_negative_test is the same file as path_negative_val (this
# fold's validation negatives), and neither *_test path is read below --
# confirm which test files are intended.

# Read sequences and assign labels (0 = negative, 1 = positive)
negative_sequences_train = read_fasta(path_negative_train, 0)
positive_sequences_train = read_fasta(path_positive_train, 1)
negative_sequences_val = read_fasta(path_negative_val, 0)
positive_sequences_val = read_fasta(path_positive_val, 1)

# Combine and shuffle datasets.
# NOTE: the shuffle draws from the global `random` state, so row order is only
# reproducible if that generator is seeded beforehand.
shuffled_data_train = combine_and_shuffle(negative_sequences_train, positive_sequences_train)
shuffled_data_val = combine_and_shuffle(negative_sequences_val, positive_sequences_val)

# Save to CSV. pandas does not create missing parent directories, so make sure
# the per-fold output directory exists first.
output_dir = "data(train+val)/prepared/" + FOLD_NO
os.makedirs(output_dir, exist_ok=True)
shuffled_data_train.to_csv(output_dir + "/bass_pb40.train.csv", index=False)
shuffled_data_val.to_csv(output_dir + "/bass_pb40.val.csv", index=False)

### 1. Verify configs and imports

In [None]:
# Prefix of the prepared CSV files read during fine-tuning
# (<BENCHMARK_NAME>.train.csv / <BENCHMARK_NAME>.val.csv).
BENCHMARK_NAME = 'bass_pb40'

import os

import pandas as pd
import tensorflow as tf
from tensorflow import keras
from proteinbert import OutputType, OutputSpec, FinetuningModelGenerator, load_pretrained_model, finetune
from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs

# Environment sanity check: TensorFlow version and visible GPUs.
print(tf.__version__)
print(tf.config.list_physical_devices('GPU'))

# Binary output spec for the classifier.
# NOTE(review): OutputType's first argument is presumably `is_seq` (per-residue
# output). False should therefore mean one global prediction per sequence, not a
# "local (non-global)" output as an earlier comment claimed -- confirm against
# proteinbert's OutputType.
OUTPUT_TYPE = OutputType(False, 'binary')
UNIQUE_LABELS = [0, 1]  # labels written by the data-preparation step (0 = negative, 1 = positive)
OUTPUT_SPEC = OutputSpec(OUTPUT_TYPE, UNIQUE_LABELS)

### 2. Finetune

In [None]:
# 6-fold cross-validation: fine-tune ProteinBERT on each prepared fold
# (./data(train+val)/prepared/<fold>/) and save each model to ./models/<fold>.
N_FOLDS = 6
# Input length used both for fine-tuning and for the saved model.
# NOTE(review): presumably the maximum sequence length plus ProteinBERT's
# start/end tokens -- confirm against the prepared sequences.
SEQ_LEN = 42

for model_no in range(1, N_FOLDS + 1):
    BENCHMARKS_DIR = './data(train+val)/prepared/' + str(model_no)

    train_set_file_path = os.path.join(BENCHMARKS_DIR, '%s.train.csv' % BENCHMARK_NAME)
    valid_set_file_path = os.path.join(BENCHMARKS_DIR, '%s.val.csv' % BENCHMARK_NAME)
    # Drop rows with missing values and exact duplicate rows.
    train_set = pd.read_csv(train_set_file_path).dropna().drop_duplicates()
    valid_set = pd.read_csv(valid_set_file_path).dropna().drop_duplicates()

    print(f'{len(train_set)} training set records, {len(valid_set)} validation set records.')

    # The pretrained model is loaded afresh on every iteration, so each fold
    # starts from the pretrained weights rather than the previous fold's model.
    pretrained_model_generator, input_encoder = load_pretrained_model()

    model_generator = FinetuningModelGenerator(pretrained_model_generator, OUTPUT_SPEC,
                                               pretraining_model_manipulation_function=get_model_with_hidden_layers_as_outputs,
                                               dropout_rate=0.5)

    training_callbacks = [
        keras.callbacks.ReduceLROnPlateau(patience=1, factor=0.25, min_lr=1e-05, verbose=1),
        keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True),
        # One log directory per fold; a single shared './logs' would merge all
        # folds' curves into one indistinguishable TensorBoard run.
        keras.callbacks.TensorBoard(log_dir=os.path.join('./logs', str(model_no)), histogram_freq=1,
                                    update_freq=100),
    ]

    finetune(model_generator, input_encoder, OUTPUT_SPEC, train_set['seq'], train_set['label'], valid_set['seq'],
             valid_set['label'],
             seq_len=SEQ_LEN, batch_size=64, max_epochs_per_stage=40, lr=1e-04, begin_with_frozen_pretrained_layers=True,
             lr_with_frozen_pretrained_layers=1e-02, n_final_epochs=0, final_seq_len=1024, final_lr=5e-06,
             callbacks=training_callbacks)

    # Build a model with the fine-tuned weights at the training input length and persist it.
    model = model_generator.create_model(seq_len=SEQ_LEN)

    model.save(os.path.join('./models', str(model_no)))