# In [None]:
import optuna
from optuna.trial import TrialState
from transformers import IntervalStrategy, AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import TrainerCallback, TrainerState, TrainerControl
from datasets import Dataset
import logging
import pandas as pd
import numpy as np
import os
from evaluate import load
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
path_work = "/home/conchae/PhageDepo_pdb"

# ******************************************************************
# LOAD THE DATA :
# Tab-separated domain table whose first row holds the column names.
df_depo = pd.read_csv(
    f"{path_work}/Dpo_domains.phagedepo.0805.final.tsv",
    sep="\t",
    header=0,
)

# One sub-table per fold of interest, selected on the "Fold" column.
df_beta_helix = df_depo.loc[df_depo["Fold"].eq("right-handed beta-helix")]
df_beta_prope = df_depo.loc[df_depo["Fold"].eq("6-bladed beta-propeller")]

def get_labels(df, label=1):
    """Build one per-residue label list for every row of *df*.

    Each row must provide a ``Full_seq`` string and a ``Boundaries`` value.
    When ``Boundaries`` equals ``"full_protein"`` every residue gets *label*.
    Otherwise the last two ``_``-separated fields of ``Boundaries`` are parsed
    as integers ``start`` and ``end``. Residue positions in ``[start, end)``
    get *label* and all other positions get 0.

    Args:
        df: DataFrame with ``Boundaries`` and ``Full_seq`` columns.
        label: class id assigned to residues inside the domain.

    Returns:
        A list with one entry per row, in row order. Each entry is a list of
        ints as long as that row's ``Full_seq``.
    """
    all_labels = []
    for _, record in df.iterrows():
        boundaries = record["Boundaries"]
        n_residues = len(record["Full_seq"])

        if boundaries == "full_protein":
            all_labels.append([label] * n_residues)
            continue

        pieces = boundaries.split("_")
        start = int(pieces[-2])
        end = int(pieces[-1])
        all_labels.append(
            [label if start <= pos < end else 0 for pos in range(n_residues)]
        )
    return all_labels

# Per-residue labels: 1 marks the beta-helix domain, 2 the beta-propeller
# domain, 0 every residue outside the annotated domain.
labels_beta_helix = get_labels(df_beta_helix, label=1)
labels_beta_propeller = get_labels(df_beta_prope, label=2)

seq_beta_helix = df_beta_helix["Full_seq"].tolist()
seq_beta_propeller = df_beta_prope["Full_seq"].tolist()

# The input data : beta-helix rows first, then beta-propeller rows, so that
# sequences[i] and labels[i] always describe the same protein.
sequences = [*seq_beta_helix, *seq_beta_propeller]
labels = [*labels_beta_helix, *labels_beta_propeller]


# ******************************************************************
# DEFINING THE MODEL SIZE :
# ESM-2 checkpoint to fine-tune. Other checkpoints kept as alternatives:
#   "facebook/esm2_t6_8M_UR50D", "facebook/esm2_t12_35M_UR50D"
model_checkpoint = "facebook/esm2_t30_150M_UR50D"
# Text after the last "/", used to name the training output directory.
model_name = model_checkpoint.rsplit("/", 1)[-1]


# ******************************************************************
# PREPROCESS THE DATA FOR THE MODEL :
train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Request the special-tokens mask so per-residue labels land on the right
# tokens. The tokenizer inserts special tokens (for ESM a leading <cls> and a
# trailing <eos>) that have no residue label. Without this, every label is
# shifted by one position.
train_tokenized = tokenizer(train_sequences, return_special_tokens_mask=True)
test_tokenized = tokenizer(test_sequences, return_special_tokens_mask=True)


def _align_labels_to_tokens(special_tokens_masks, residue_labels, ignore_index=-100):
    """Spread per-residue labels over the token positions of each sequence.

    Args:
        special_tokens_masks: one list per sequence, with 1 at special-token
            positions and 0 at residue-token positions.
        residue_labels: one label list per sequence, one label per residue.
        ignore_index: label given to special tokens. -100 is the value the
            token-classification cross-entropy loss ignores.

    Returns:
        One label list per sequence, each as long as that sequence's tokens.

    Raises:
        ValueError: if a sequence's number of residue tokens differs from its
            number of residue labels (e.g. a residue mapped to <unk>).
    """
    aligned = []
    for idx, (mask, res_labels) in enumerate(zip(special_tokens_masks, residue_labels)):
        n_residue_tokens = mask.count(0)
        if n_residue_tokens != len(res_labels):
            raise ValueError(
                f"Sequence {idx}: {n_residue_tokens} residue tokens but "
                f"{len(res_labels)} labels."
            )
        remaining = iter(res_labels)
        aligned.append(
            [ignore_index if is_special else next(remaining) for is_special in mask]
        )
    return aligned


# Remove the mask from the encodings (the model's forward pass does not take
# it) and use it to build token-aligned label lists.
train_token_labels = _align_labels_to_tokens(train_tokenized.pop("special_tokens_mask"), train_labels)
test_token_labels = _align_labels_to_tokens(test_tokenized.pop("special_tokens_mask"), test_labels)

train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset = train_dataset.add_column("labels", train_token_labels)
test_dataset = test_dataset.add_column("labels", test_token_labels)

# 0 = outside the annotated domain, 1 = beta-helix, 2 = beta-propeller.
num_labels = 3
from transformers import DataCollatorForTokenClassification
data_collator = DataCollatorForTokenClassification(tokenizer)

# ******************************************************************
# THE SEARCH :

# Trial results are written (appended) to this file at INFO level.
logging.basicConfig(
    level=logging.INFO,
    filename=f"{path_work}/hyperparameters_tuning.log",
)

class ObjectiveWrapper:
    """Callable that runs an Optuna objective and logs every finished trial.

    Args:
        objective: callable taking a trial and returning the trial's value.
    """

    def __init__(self, objective):
        self.objective = objective

    def __call__(self, trial):
        value = self.objective(trial)
        # One INFO record per trial that returns normally. Exceptions raised
        # by the objective (pruning included) propagate without a log record.
        message = (
            f"Trial {trial.number} finished with value: {value} "
            f"and parameters: {trial.params}."
        )
        logging.info(message)
        return value


class MyCallback(TrainerCallback):
    """Report each evaluation to an Optuna trial and prune when asked to.

    The reported value must move in the same direction as the study objective.
    The study maximizes accuracy, and MedianPruner judges intermediate values
    using the study direction. Reporting ``eval_loss`` (lower is better) would
    make the pruner stop the *best* trials, so ``eval_accuracy`` is reported
    by default.

    Args:
        trial: the Optuna trial being trained.
        metric_name: key of the evaluation metric to report.
    """

    def __init__(self, trial: optuna.trial.Trial, metric_name: str = "eval_accuracy"):
        self._trial = trial
        self._metric_name = metric_name

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics=None, **kwargs):
        # Only the main process talks to Optuna; with no metrics there is
        # nothing to report.
        if not state.is_local_process_zero or metrics is None:
            return
        self._trial.report(metrics[self._metric_name], step=state.global_step)
        if self._trial.should_prune():
            # Propagates out of trainer.train() and marks the trial as pruned.
            raise optuna.exceptions.TrialPruned()


def objective(trial):
    """Fine-tune one model with trial-sampled hyperparameters.

    Samples batch size, learning rate and weight decay from *trial*. Trains
    for 3 epochs, evaluating and checkpointing every epoch, and reloads the
    best checkpoint (by accuracy) at the end.

    Returns:
        The ``eval_accuracy`` of the final evaluation. ``Trainer.evaluate``
        prefixes every metric key with ``eval_``.
    """
    model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
    batch_size = trial.suggest_int("batch_size", 8, 16)
    learning_rate = trial.suggest_float("lr", 1e-6, 1e-4, log=True)
    weight_decay = trial.suggest_float("weight_decay", 0.0, 0.1)
    # One directory per trial: trials run concurrently (n_jobs > 1) and would
    # otherwise overwrite each other's checkpoints and logs.
    run_tag = f"trial_{trial.number}"
    args = TrainingArguments(
        f"{model_name}-finetuned-depolymerase/{run_tag}",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=3,
        weight_decay=weight_decay,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        logging_dir=f"./logs/{run_tag}",
        push_to_hub=False,
    )
    trainer = Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        data_collator=data_collator,
        callbacks=[MyCallback(trial)],
    )
    trainer.train()
    eval_metrics = trainer.evaluate()
    return eval_metrics["eval_accuracy"]

def compute_metrics(eval_pred):
    """Per-residue accuracy for token classification.

    Args:
        eval_pred: ``(logits, labels)`` pair passed by ``Trainer``. The class
            scores are on the last axis of ``logits``. ``labels`` holds one
            class id per position, with -100 marking positions to ignore
            (special tokens, and padding added by
            DataCollatorForTokenClassification).

    Returns:
        ``{"accuracy": float}``: the fraction of non-ignored positions whose
        argmax prediction equals the label, or 0.0 if no position is kept.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)
    # Flatten implicitly via boolean masking; ignored positions never count.
    keep = labels != -100
    n_kept = int(keep.sum())
    if n_kept == 0:
        return {"accuracy": 0.0}
    n_correct = int((predictions[keep] == labels[keep]).sum())
    return {"accuracy": n_correct / n_kept}


wrapper = ObjectiveWrapper(objective)

# No pruning until 5 trials have finished, and within a trial none until the
# reported step (the Trainer's global step) exceeds 30.
pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=30)

study = optuna.create_study(pruner=pruner, direction="maximize")
# 20 trials in total, with up to 5 running at the same time.
study.optimize(wrapper, n_trials=20, n_jobs=5)

# Log the result of the optimization
best_trial = study.best_trial
logging.info(
    f"Best trial finished with value: {best_trial.value} "
    f"and parameters: {best_trial.params}."
)
