In [None]:
import pandas as pd
import os 
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer
from datasets import Dataset
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification
from evaluate import load
import numpy as np

# Directory that holds the PhageDepo dataset files.
path_work = "/home/conchae/PhageDepo_pdb"

# Tab-separated table; the first line of the file is used as the header row.
df_depo = pd.read_csv(
    f"{path_work}/Phagedepo.Dataset.2007.tsv",
    sep="\t",
    header=0,
)

# One sub-frame per value of the "Fold" column.
df_beta_helix = df_depo.loc[df_depo["Fold"].eq("right-handed beta-helix")]
df_beta_prope = df_depo.loc[df_depo["Fold"].eq("6-bladed beta-propeller")]
df_beta_triple = df_depo.loc[df_depo["Fold"].eq("triple-helix")]
df_negative = df_depo.loc[df_depo["Fold"].eq("Negative")]

In [None]:
def get_labels(df, label=1):
    """Build one per-residue label list for every row of ``df``.

    Each row must provide a ``"Full_seq"`` value (its length sets the length of
    the label list) and a ``"Boundaries"`` string, interpreted as follows:

    * ``"Negative"``, ``"full_protein"`` or ``"full"``: every position gets
      ``label``.
    * a string containing ``":"``: the first two ``":"``-separated fields are
      parsed as ``start`` and ``end``.
    * anything else: the last two ``"_"``-separated fields are parsed as
      ``start`` and ``end``.

    For the two span forms, positions ``i`` with ``start <= i < end`` get
    ``label`` and all other positions get ``0``.
    NOTE(review): this treats the boundaries as 0-based and end-exclusive;
    confirm that matches the convention used in the dataset.

    Args:
        df: DataFrame with ``"Boundaries"`` and ``"Full_seq"`` columns.
        label: Value written at labelled positions (default ``1``).

    Returns:
        A list with one list of labels per row, in row order.
    """

    def _parse_span(boundaries):
        # The "start:end" form is checked before the underscore-suffixed form.
        if boundaries.count(":") > 0:
            fields = boundaries.split(":")
            return int(fields[0]), int(fields[1])
        fields = boundaries.split("_")
        return int(fields[-2]), int(fields[-1])

    whole_sequence_tags = ("Negative", "full_protein", "full")
    per_row_labels = []
    for _, record in df.iterrows():
        boundaries = record["Boundaries"]
        n_residues = len(record["Full_seq"])
        if boundaries in whole_sequence_tags:
            per_row_labels.append([label] * n_residues)
            continue
        start, end = _parse_span(boundaries)
        per_row_labels.append(
            [label if start <= pos < end else 0 for pos in range(n_residues)]
        )
    return per_row_labels



In [None]:
# ********************************************
# Transforming the data. Each fold class gets its own integer label:
#   1 = beta-helix, 2 = beta-propeller, 3 = triple-helix, 0 = negative.

# Beta-helix :
seq_beta_helix = df_beta_helix["Full_seq"].tolist()
labels_beta_helix = get_labels(df_beta_helix, label=1)

# Beta propeller :
seq_beta_propeller = df_beta_prope["Full_seq"].tolist()
labels_beta_propeller = get_labels(df_beta_prope, label=2)

# Triple helix :
seq_triple_helix = df_beta_triple["Full_seq"].tolist()
labels_triple_helix = get_labels(df_beta_triple, label=3)

# Negative :
seq_negative = df_negative["Full_seq"].tolist()
labels_negative = get_labels(df_negative, label=0)

# The input data : sequences and labels are concatenated in the same class order
# so that index i of one list matches index i of the other.
sequences = [*seq_beta_helix, *seq_beta_propeller, *seq_triple_helix, *seq_negative]
labels = [*labels_beta_helix, *labels_beta_propeller, *labels_triple_helix, *labels_negative]

# First split: hold out 20% of the data as the test set.
train_sequences, test_sequences, train_labels, test_labels = train_test_split(
    sequences,
    labels,
    test_size=0.2,
    random_state=243,
)
# Second split: 12.5% of the remaining training data (10% of the total) goes to
# the *_CNV pair; the rest is used for the ESM2 fine-tuning.
train_esm2, train_CNV, esm2_labels, CNV_labels = train_test_split(
    train_sequences,
    train_labels,
    test_size=0.125,
    random_state=243,
)


In [None]:
# ********************************************
# Define the model :
#model_checkpoint = "facebook/esm2_t6_8M_UR50D"
model_checkpoint = "facebook/esm2_t12_35M_UR50D"
#model_checkpoint = "facebook/esm2_t30_150M_UR50D"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# No padding/truncation is requested here; sequences keep their own lengths.
train_tokenized = tokenizer(train_esm2)
test_tokenized = tokenizer(test_sequences)


def _align_labels_to_tokens(tokenized, residue_labels, ignore_index=-100):
    """Make each per-residue label list line up with its token ids.

    The ESM tokenizer wraps every sequence as ``<cls> residues <eos>``, so
    ``input_ids`` is two entries longer than the per-residue labels. If the
    labels are passed through unchanged, the token-classification collator
    right-pads them, which puts each residue's label on the *preceding* token
    (the first label lands on ``<cls>``) and leaves the last residue without
    supervision. Adding ``ignore_index`` at both ends restores the alignment
    and keeps the special tokens out of the loss and the metric.

    Args:
        tokenized: Tokenizer output exposing ``"input_ids"`` (one list per sequence).
        residue_labels: One list of per-residue labels per sequence, same order.
        ignore_index: Label assigned to the two special-token positions.

    Returns:
        A new list of label lists, each the same length as its ``input_ids``.

    Raises:
        ValueError: If a sequence did not tokenize to exactly one token per
            residue plus the two special tokens, since the labels could not
            be aligned reliably in that case.
    """
    aligned = []
    for row, (input_ids, row_labels) in enumerate(zip(tokenized["input_ids"], residue_labels)):
        padded = [ignore_index] + list(row_labels) + [ignore_index]
        if len(padded) != len(input_ids):
            raise ValueError(
                f"Sequence {row}: {len(input_ids)} tokens but {len(row_labels)} residue labels; "
                "expected one token per residue plus <cls> and <eos>."
            )
        aligned.append(padded)
    return aligned


train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

# New lists are built instead of mutating esm2_labels/test_labels, so this cell
# can be re-run without stacking extra -100 entries.
train_dataset = train_dataset.add_column("labels", _align_labels_to_tokens(train_tokenized, esm2_labels))
test_dataset = test_dataset.add_column("labels", _align_labels_to_tokens(test_tokenized, test_labels))

# Four classes: 0 = background/negative, 1-3 = the three fold labels.
num_labels = 4
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=num_labels)


In [None]:
# ********************************************
# Training :
# Batch collator for token classification, built from the tokenizer.
data_collator = DataCollatorForTokenClassification(tokenizer)

batch_size = 4
# Keep only the part of the checkpoint id after the last "/".
model_name = model_checkpoint.rsplit("/", 1)[-1]

# Evaluation and checkpointing both run once per epoch; the checkpoint with the
# best "accuracy" is reloaded when training finishes.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm against the installed version.
args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-depolymerase.2007.4_labels",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-5,  # original note here was "#4" (meaning not evident from this cell)
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=15,
    weight_decay=0.001,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_dir='./logs',
    push_to_hub=False,
)

# Accuracy metric object used by compute_metrics.
metric = load("accuracy")

def compute_metrics(eval_pred):
    """Compute accuracy over all positions whose label is not -100.

    Args:
        eval_pred: Pair ``(predictions, labels)``. ``predictions`` is reduced
            with ``argmax`` over axis 2; both arrays are then flattened.

    Returns:
        The result of ``metric.compute`` on the kept predictions and labels.
    """
    scores, references = eval_pred
    flat_references = references.reshape((-1,))
    flat_predictions = np.argmax(scores, axis=2).reshape((-1,))
    # Positions labelled -100 are excluded from the metric.
    keep = flat_references != -100
    return metric.compute(
        predictions=flat_predictions[keep],
        references=flat_references[keep],
    )

# Wire the model, training arguments, datasets, collator and metric together.
# NOTE(review): eval_dataset is the test split, and the training arguments
# reload the best checkpoint by accuracy, so checkpoint selection sees the test
# data. Confirm this is intended before reporting test accuracy.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()
