In [1]:
import os
import torch
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))

2
NVIDIA RTX A5000


In [2]:
from transformers import AutoModelForTokenClassification, AutoTokenizer, PreTrainedModel, EsmConfig, EsmPreTrainedModel
from peft import PeftModel, LoraConfig, get_peft_model, TaskType
# from tomlkit import value
import torch
import esm
import pandas as pd
import numpy as np
import random
import os
import wandb
import pickle as pkl
from datetime import datetime
import accelerate
from accelerate import Accelerator
from huggingface_hub import notebook_login
from torch.utils.data import Dataset, random_split
from transformers import (
    EsmForTokenClassification,
    EsmForMaskedLM,
    EsmModel,
    EsmTokenizer,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

In [3]:
concat_all_exp_data = pd.read_pickle('/home/kaustubh/RuBisCO_ML/ESM_LoRA/data/processed_combined_all_exp_assays_data.pkl')

In [4]:
formIII_lsu_variant_data_df = concat_all_exp_data.query('LSU_id.str.startswith("Anc393") or LSU_id.str.startswith("Anc367") or LSU_id == "Anc365" or LSU_id == "Anc366"')
formIII_lsu_variant_data_df['fixed_threshold_activity'] = formIII_lsu_variant_data_df['mean_reading'].apply(lambda x: 1 if x >= 50 else 0)

formIII_lsu_variant_data_df[formIII_lsu_variant_data_df['activity_binary'] == 1].shape[0], formIII_lsu_variant_data_df[formIII_lsu_variant_data_df['activity_binary'] == 0].shape[0]
# formIII_lsu_variant_data_df[formIII_lsu_variant_data_df['fixed_threshold_activity'] == 1].shape[0], formIII_lsu_variant_data_df[formIII_lsu_variant_data_df['fixed_threshold_activity'] == 0].shape[0]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  formIII_lsu_variant_data_df['fixed_threshold_activity'] = formIII_lsu_variant_data_df['mean_reading'].apply(lambda x: 1 if x >= 50 else 0)


(44, 60)

In [5]:
sequences = formIII_lsu_variant_data_df['lsussu_seq'].to_list()
binary_activity = formIII_lsu_variant_data_df['fixed_threshold_activity'].to_list()

In [6]:
accelerator = Accelerator()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [86]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

In [87]:
class ProteinDataset(Dataset):
    def __init__(self, sequences, binary_activity, tokenizer, max_length=512):
        self.sequences = sequences
        self.binary_activity = binary_activity
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx][:self.max_length]
        binding_site = self.binary_activity[idx]
        encoding = self.tokenizer(sequence, truncation=True, padding='max_length', max_length=self.max_length)
        encoding['labels'] = binding_site # + [-100] * (self.max_length - len(binding_site))  # Ignore extra padding tokens
        return encoding

The `ProteinDataset` class is a custom dataset class that inherits from `torch.utils.data.Dataset`. It is designed to handle protein sequences and their corresponding binding sites for use in a machine learning model. Here's a breakdown of what each part of the class does:

- **`__init__` method**: This initializes the dataset with sequences, binding sites, a tokenizer, and an optional maximum length for the sequences.
    - `sequences`: A list of protein sequences.
    - `binary_activity`: A list of active-inactive labels corresponding to each sequence.
    - `tokenizer`: A tokenizer to convert sequences into token IDs.
    - `max_length`: The maximum length of the sequences to be tokenized (default is 512).

- **`__len__` method**: This returns the number of sequences in the dataset.

- **`__getitem__` method**: This retrieves a single item (sequence and its binding site) from the dataset at the specified index (`idx`).
    - It truncates the sequence to the maximum length.
    - It tokenizes the sequence using the provided tokenizer.
    - It adds the binding site label to the tokenized output.

The class is used to prepare the data for training and evaluation in a format that can be easily fed into a machine learning model.

In [88]:
dataset = ProteinDataset(sequences, binary_activity, tokenizer)
train_size = int(0.85 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))
train_dataset, val_dataset = accelerator.prepare(train_dataset, val_dataset)

In [89]:
# Extract back the sequences and their binding sites info for the entries in the train_dataset & val_dataset
train_sequences = [sequences[i] for i in train_dataset.indices]
train_binary_activity = [binary_activity[i] for i in train_dataset.indices]

val_sequences = [sequences[i] for i in val_dataset.indices]
val_binary_activity = [binary_activity[i] for i in val_dataset.indices]

In [90]:
train_sequences, train_binary_activity

(['MSIRYTDFVDLNYTPGDNDLVCTFRIEPADGMSLEAAASRVASESSNGTWATLSVDEDIKKLKATVFEIDGNIIKIAYPLGLFEPGNIPQILSSIAGNIFGMKAVKNIRLLDCEWPEELLSSFKGPQFGSEGVQEILGVDDRPLTATVPKPKVGLSTEQHAEVGYEAWVGGLDLLKDDENLTDQPFNPFEERVKESLEARDKAEDETGEKKAYLVNITAETNEMLERAELVAEYGGEYVMVDVITAGWAAVQTLRERCEDLGLAIHAHRAMHAAFDRLPSHGVSMRVLAQLARLVGVDQLHTGTAGLGKLANDKETLAINDWLRSDWYGIKDVLPVASGGLHPGLVPELLDAFGTNIIIQAGGGVHGHPDGTRAGAKALRQAVEAVVEGVSLEEYAKDHPELAKALEKWGHVRPR',
  'MSRYTDYVDLNYTPKENDLICTFHIEPADGVDLEEAAGRVAAESSIGTWTDVSTMPEIWEKLKARVYEIDETGNIVKIAYPLDLFEPGNIPQILSSIAGNIFGMKAVKNLRLLDIRFPEELVKSFKGPKFGIEGVRELLGVYDRPLVGTIVKPKVGLSAEEHAEVAYEAWVGGLDLVKDDENLTSQPFNPFEERVKKVLEARDKAEEETGEKKVYLVNITAPTEEMIRRAELVADLGGKYVMIDIITAGFSAVQSLREEDLGLAIHAHRAMHAAFTRNPKHGISMLVLAKLARLVGVDQLHIGTGVGKMEGDKEEVLAIRDALRLDRVPADEANHFLEQDWYNIKPVFPVASGGLHPGLVPDLIDIFGKDIIIQAGGGVHGHPDGTRAGAKALRQAIEAVMEGISLEEYAKEHKELKKALEKWGHVR',
  'MSKRYTDYVDLNYTPKENDLICTFHIEPADGVDLEEAAGRVAAESSIGTWTDVSTMPEIWEKLKARVYEIDESGNIVKIAYPLDLFEPGNIPQILSSIAGNIFGMKAVKNLRLLDIRFPKELVKSFKGPKFGIEGVRELLGVY

In [91]:
val_sequences, val_binary_activity

(['MSKRYTDYVDLNYTPKENDLICTFHIEPADGVDLEEAAGRVAAESSIGTWTDVSTMPEIVEKLKARVYEIDESGNIVKIAYPLDLFEPGNIPQILSSIAGNIFGMKAVKNLRLLDIRFPKELVKSFKGPKFGIEGVRELLGVYDRPLVGTIVKPKVGLSAEEHAEVAYEAWVGGLDLVKDDENLTSQPFNPFEERVKKVLEARDKAEEETGEKKVYLVNITAPTEEMIRRAELVADLGGKYVMIDIITAGFSAVQSLREEDLGLAIHAHRAMHAAFTRNPKHGISMLVLAKLARLVGVDQLHIGTVVGKMEGDKEEVLAIRDALRLDRVPADEANHFLEQDWYNIKPVFPVASGGLHPGLVPDLIDIFGKDIIIQAGGGVHGHPDGTRAGAKALRQAIEAAMEGISLEEYAKEHKELKKALEKWGHVR',
  'MSIRYEDFLDLNYEPGDNDLICTFRIEPADGISMEAAASRVASESSNGTWTTLQVMPDRIKKLSATVFEIDGNIVKIAYPADLFEPGNMPQILSSIAGNIMGMKAVDTIRLLDCHWPESLVSSFPGPQFGSSVRRELFGVHDRPLTATVPKPKVGLSAEQHAQIAYEAWVGGLDLIKDDENLTDQPFNPFEERVKKVLAARDKAEEETGEKKAYLVNITAETNEMLERADLVADYGGEYVMIDVITAGWSAVQTLRERCEDLGLAIHAHRAMHAAFTRLPSHGVSMRVLAQIARLVGVDQLHTGTGLGKMEGDEDVLGIADWLRQDLYNINDVFPVASGGLHPGLVPELIEAFGTDICIQAGGGVHGHPDGTRAGAKALRQAVEAAMEGVSLEEYADDHPELATALDKWGTERPR',
  'MSERYEDFVDLSYTPGENDLVCTFRIEPAEGMSMEEAASRVASESSNGTWATLSTMEDIKDLKAKTFSIDGNIIKIAYPLGLFEAGNMPQILSCIAGNIMGMKAVDTLRLLDIHWPEELLSSFKGPQFGSEGRQEIFGVHDR

In [10]:
def model_init(trial):
    base_model = EsmForTokenClassification.from_pretrained(
        "facebook/esm2_t6_8M_UR50D",
        num_labels=2 # For binary classification
    )

    config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        r=16,
        lora_alpha=16,
        target_modules=["query", "key", "value"],
        lora_dropout=0.1,
        bias="all",
    )
    
    lora_model = get_peft_model(base_model, config)
    return accelerator.prepare(lora_model)

# config = LoraConfig(
#     r=8,  # Rank of low-rank matrix (controls adaptation strength)
#     lora_alpha=32,  # Scaling factor
#     lora_dropout=0.1,  # Dropout rate for LoRa layers
#     bias="none",
#     target_modules=["self_attn.k_proj", "self_attn.v_proj"],
#     task_type="FEATURE_EXTRACTION"  # Since ESM-2 is a feature extractor
# )

In [11]:
def wandb_hp_space(trial):
    return {
        "method": "random",
        "metric": {"name": "accuracy", "goal": "maximize"},
        "parameters": {
            "learning_rate": {"distribution": "uniform", "min": 1e-5, "max": 1e-3},
            "per_device_train_batch_size": {"values": [2, 4, 8, 16]},
        },
    }

In [12]:
class CustomModel(EsmPreTrainedModel):
    def __init__(self):
        super().__init__(EsmConfig.from_pretrained("facebook/esm2_t6_8M_UR50D"))
        self.backbone = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")

        self.outputs = torch.nn.Linear(320, 1)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        labels=None,
        inputs_embeds=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True
    ):
        outputs = self.backbone(
            input_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        sequence_output = outputs.last_hidden_state # (B, L, 1280)
        bos_emb = sequence_output[:,0] # (B, L, 1280) -> (B, 1280)
        # outputs = [self.outputs[i](sequence_output) for i in range(5)]
        outputs = self.outputs(bos_emb).squeeze(1) # (B,)

        # if labels, then we are training
        loss = None
        if labels is not None:
            assert outputs.shape == labels.shape, f"{outputs}, {labels}"
            loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")
            loss = loss_fn(outputs, labels.float())
            # loss = sum(losses)/len(losses)

        return {
            "loss": loss,
            "logits": outputs
        }
    
    def get_embedding(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        labels=None,
        inputs_embeds=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True
    ):
        outputs = self.backbone(
            input_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        return outputs

In [13]:
def model_init():
    base_model = CustomModel()
    config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        r=16,
        lora_alpha=16,
        target_modules=["query", "key", "value"],  # Apply LoRa to self-attention layers
        lora_dropout=0.1,
        bias="all",
    )
    lora_model = get_peft_model(base_model, config)
    return accelerator.prepare(lora_model)

In [14]:
from sklearn.metrics import roc_auc_score
def compute_metrics(eval_pred):
    predictions, labels = eval_pred # (B,)
    # print(predictions.shape)
    # predictions = np.argmax(predictions, axis=2)  # Convert logits to class labels
    # labels = accelerator.gather(labels)
    # mask = labels != -100
    # accuracy = (predictions[mask] == labels[mask]).mean()
    accuracy = roc_auc_score(labels, predictions)
    return {'accuracy': accuracy}

In [None]:
output_dir = f"./training_runs/esm2_t6_8M-finetuned-lora"

args = TrainingArguments(
        output_dir,
        evaluation_strategy="epoch",
        learning_rate=5e-3,
        per_device_train_batch_size=2,
        num_train_epochs=5,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        save_strategy="epoch",
        label_names=["labels"],
    )



In [16]:
trainer = Trainer(
    model=None,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    model_init=model_init,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
best_trial = trainer.hyperparameter_search(
    direction="maximize",
    backend="wandb",
    hp_space=wandb_hp_space,
    n_trials=10,
)

print("Best Trial:", best_trial)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Create sweep with ID: pvuaf0c1
Sweep URL: https://wandb.ai/kauamritkar-university-of-wisconsin-madison/uncategorized/sweeps/pvuaf0c1


[34m[1mwandb[0m: Agent Starting Run: uwkevz5s with config:
[34m[1mwandb[0m: 	learning_rate: 0.0008249543621818322
[34m[1mwandb[0m: 	per_device_train_batch_size: 2
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


Trying to set _wandb in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set assignments in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set metric in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.6944,0.801625,0.916667
2,0.6668,0.790349,1.0
3,0.6919,0.801848,1.0
4,0.6723,0.79895,1.0
5,0.6818,0.796464,1.0




0,1
eval/accuracy,▁████
eval/loss,█▁█▆▅
eval/runtime,▃▂▁▂█
eval/samples_per_second,▆▇█▇▁
eval/steps_per_second,▆▇█▇▁
train/epoch,▁▁▃▃▄▅▆▆▇██
train/global_step,▁▁▃▃▄▅▆▆▇██
train/grad_norm,▇▃▁▁█
train/learning_rate,█▆▅▃▁
train/loss,█▁▇▂▅

0,1
eval/accuracy,1.0
eval/loss,0.79646
eval/runtime,0.0497
eval/samples_per_second,160.835
eval/steps_per_second,20.104
total_flos,5194426490880.0
train/epoch,5.0
train/global_step,55.0
train/grad_norm,0.5937
train/learning_rate,7e-05


[34m[1mwandb[0m: Agent Starting Run: 65ih2sg5 with config:
[34m[1mwandb[0m: 	learning_rate: 0.0008098346289701476
[34m[1mwandb[0m: 	per_device_train_batch_size: 2


Trying to set _wandb in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set assignments in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set metric in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.




Epoch,Training Loss,Validation Loss,Accuracy
1,0.6941,0.801204,0.916667
2,0.6667,0.7903,1.0
3,0.6919,0.801688,1.0
4,0.6723,0.798859,1.0
5,0.6818,0.796411,1.0




0,1
eval/accuracy,▁████
eval/loss,█▁█▆▅
eval/runtime,▂▁▁▂█
eval/samples_per_second,▇██▇▁
eval/steps_per_second,▇██▇▁
train/epoch,▁▁▃▃▄▅▆▆▇██
train/global_step,▁▁▃▃▄▅▆▆▇██
train/grad_norm,▇▃▁▁█
train/learning_rate,█▆▅▃▁
train/loss,█▁▇▂▅

0,1
eval/accuracy,1.0
eval/loss,0.79641
eval/runtime,0.0488
eval/samples_per_second,163.768
eval/steps_per_second,20.471
total_flos,5194426490880.0
train/epoch,5.0
train/global_step,55.0
train/grad_norm,0.59499
train/learning_rate,7e-05


[34m[1mwandb[0m: Agent Starting Run: yykbf3d2 with config:
[34m[1mwandb[0m: 	learning_rate: 0.0009803028794160072
[34m[1mwandb[0m: 	per_device_train_batch_size: 2


Trying to set _wandb in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set assignments in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set metric in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.




Epoch,Training Loss,Validation Loss,Accuracy
1,0.6966,0.806087,1.0
2,0.6674,0.790396,1.0
3,0.6922,0.80315,1.0
4,0.6724,0.799634,1.0
5,0.6819,0.79679,1.0




0,1
eval/accuracy,▁▁▁▁▁
eval/loss,█▁▇▅▄
eval/runtime,▃▁▁▁█
eval/samples_per_second,▆███▁
eval/steps_per_second,▆███▁
train/epoch,▁▁▃▃▄▅▆▆▇██
train/global_step,▁▁▃▃▄▅▆▆▇██
train/grad_norm,▇▃▁▁█
train/learning_rate,█▆▄▃▁
train/loss,█▁▇▂▄

0,1
eval/accuracy,1.0
eval/loss,0.79679
eval/runtime,0.05
eval/samples_per_second,160.129
eval/steps_per_second,20.016
total_flos,5194426490880.0
train/epoch,5.0
train/global_step,55.0
train/grad_norm,0.58195
train/learning_rate,9e-05


[34m[1mwandb[0m: Agent Starting Run: eh3xnoo6 with config:
[34m[1mwandb[0m: 	learning_rate: 0.0006965141764047415
[34m[1mwandb[0m: 	per_device_train_batch_size: 4


Trying to set _wandb in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set assignments in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set metric in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.790749,0.833333
2,0.676700,0.79354,1.0
3,0.676700,0.801404,1.0
4,0.682200,0.798745,1.0
5,0.676700,0.796916,0.916667




0,1
eval/accuracy,▁███▅
eval/loss,▁▃█▆▅
eval/runtime,▂▁▁▁█
eval/samples_per_second,▇███▁
eval/steps_per_second,▇███▁
train/epoch,▁▂▃▅▅▆███
train/global_step,▁▂▃▅▅▆███
train/grad_norm,█▃▁
train/learning_rate,█▅▁
train/loss,▁█▁

0,1
eval/accuracy,0.91667
eval/loss,0.79692
eval/runtime,0.0513
eval/samples_per_second,155.952
eval/steps_per_second,19.494
total_flos,5194426490880.0
train/epoch,5.0
train/global_step,30.0
train/grad_norm,0.14383
train/learning_rate,0.0


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: uz0euv5g with config:
[34m[1mwandb[0m: 	learning_rate: 3.141012562031269e-05
[34m[1mwandb[0m: 	per_device_train_batch_size: 16


Trying to set _wandb in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set assignments in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set metric in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.79431,0.916667
2,No log,0.793584,0.916667
3,No log,0.793274,0.916667
4,No log,0.793338,0.916667
5,0.677000,0.793359,0.916667




0,1
eval/accuracy,▁▁▁▁▁
eval/loss,█▃▁▁▂
eval/runtime,▃▁▁▁█
eval/samples_per_second,▅███▁
eval/steps_per_second,▅███▁
train/epoch,▁▃▅▆███
train/global_step,▁▃▅▆███
train/grad_norm,▁
train/learning_rate,▁
train/loss,▁

0,1
eval/accuracy,0.91667
eval/loss,0.79336
eval/runtime,0.0487
eval/samples_per_second,164.278
eval/steps_per_second,20.535
total_flos,5194426490880.0
train/epoch,5.0
train/global_step,10.0
train/grad_norm,0.02324
train/learning_rate,0.0


[34m[1mwandb[0m: Agent Starting Run: p69ffmvf with config:
[34m[1mwandb[0m: 	learning_rate: 0.00032194618130384085
[34m[1mwandb[0m: 	per_device_train_batch_size: 4


Trying to set _wandb in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set assignments in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set metric in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.788126,0.916667
2,0.674700,0.788751,0.916667
3,0.674700,0.793433,0.916667
4,0.681900,0.792592,0.916667
5,0.676800,0.791852,0.916667




0,1
eval/accuracy,▁▁▁▁▁
eval/loss,▁▂█▇▆
eval/runtime,▄▂▂▁█
eval/samples_per_second,▅▇▇█▁
eval/steps_per_second,▅▇▇█▁
train/epoch,▁▂▃▅▅▆███
train/global_step,▁▂▃▅▅▆███
train/grad_norm,█▃▁
train/learning_rate,█▅▁
train/loss,▁█▃

0,1
eval/accuracy,0.91667
eval/loss,0.79185
eval/runtime,0.0496
eval/samples_per_second,161.295
eval/steps_per_second,20.162
total_flos,5194426490880.0
train/epoch,5.0
train/global_step,30.0
train/grad_norm,0.14665
train/learning_rate,0.0


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: ovmph146 with config:
[34m[1mwandb[0m: 	learning_rate: 0.0006223177481951578
[34m[1mwandb[0m: 	per_device_train_batch_size: 2


Trying to set _wandb in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set assignments in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set metric in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.




Epoch,Training Loss,Validation Loss,Accuracy
1,0.6913,0.796265,0.916667
2,0.6662,0.789086,1.0
3,0.6915,0.79916,1.0
4,0.6723,0.797247,1.0
5,0.6818,0.795314,1.0




0,1
eval/accuracy,▁████
eval/loss,▆▁█▇▅
eval/runtime,▃▂▁▁█
eval/samples_per_second,▆▇██▁
eval/steps_per_second,▆▇██▁
train/epoch,▁▁▃▃▄▅▆▆▇██
train/global_step,▁▁▃▃▄▅▆▆▇██
train/grad_norm,▇▃▁▁█
train/learning_rate,█▆▅▃▁
train/loss,█▁█▃▅

0,1
eval/accuracy,1.0
eval/loss,0.79531
eval/runtime,0.0504
eval/samples_per_second,158.671
eval/steps_per_second,19.834
total_flos,5194426490880.0
train/epoch,5.0
train/global_step,55.0
train/grad_norm,0.61435
train/learning_rate,6e-05


[34m[1mwandb[0m: Agent Starting Run: nq717uq7 with config:
[34m[1mwandb[0m: 	learning_rate: 6.771276697456436e-05
[34m[1mwandb[0m: 	per_device_train_batch_size: 2


Trying to set _wandb in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set assignments in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set metric in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.




Epoch,Training Loss,Validation Loss,Accuracy
1,0.6826,0.791297,0.833333
2,0.6641,0.789486,0.916667
3,0.6909,0.791014,0.916667
4,0.6727,0.790872,0.916667
5,0.6819,0.790635,0.916667




0,1
eval/accuracy,▁████
eval/loss,█▁▇▆▅
eval/runtime,▃▁▁▁█
eval/samples_per_second,▆███▁
eval/steps_per_second,▆███▁
train/epoch,▁▁▃▃▄▅▆▆▇██
train/global_step,▁▁▃▃▄▅▆▆▇██
train/grad_norm,▆▃▁▁█
train/learning_rate,█▆▅▃▁
train/loss,▆▁█▃▆

0,1
eval/accuracy,0.91667
eval/loss,0.79064
eval/runtime,0.0518
eval/samples_per_second,154.411
eval/steps_per_second,19.301
total_flos,5194426490880.0
train/epoch,5.0
train/global_step,55.0
train/grad_norm,0.69199
train/learning_rate,1e-05


[34m[1mwandb[0m: Agent Starting Run: nhvc5aii with config:
[34m[1mwandb[0m: 	learning_rate: 0.00022075371526919648
[34m[1mwandb[0m: 	per_device_train_batch_size: 4


Trying to set _wandb in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set assignments in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set metric in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.788734,0.833333
2,0.674100,0.7887,0.833333
3,0.674100,0.791916,0.833333
4,0.681900,0.791336,0.833333
5,0.676900,0.79083,0.833333




0,1
eval/accuracy,▁▁▁▁▁
eval/loss,▁▁█▇▆
eval/runtime,▃▁▁▁█
eval/samples_per_second,▆███▁
eval/steps_per_second,▆███▁
train/epoch,▁▂▃▅▅▆███
train/global_step,▁▂▃▅▅▆███
train/grad_norm,█▄▁
train/learning_rate,█▅▁
train/loss,▁█▄

0,1
eval/accuracy,0.83333
eval/loss,0.79083
eval/runtime,0.049
eval/samples_per_second,163.166
eval/steps_per_second,20.396
total_flos,5194426490880.0
train/epoch,5.0
train/global_step,30.0
train/grad_norm,0.14711
train/learning_rate,0.0


[34m[1mwandb[0m: Agent Starting Run: w2y7s38g with config:
[34m[1mwandb[0m: 	learning_rate: 0.0005422392414127516
[34m[1mwandb[0m: 	per_device_train_batch_size: 4


Trying to set _wandb in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set assignments in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Trying to set metric in the hyperparameter search but there is no corresponding field in `TrainingArguments`.
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.788992,0.833333
2,0.675900,0.791145,0.916667
3,0.675900,0.798135,1.0
4,0.682000,0.796367,0.916667
5,0.676700,0.795011,0.916667




0,1
eval/accuracy,▁▅█▅▅
eval/loss,▁▃█▇▆
eval/runtime,▂▁▃▂█
eval/samples_per_second,▇█▆▇▁
eval/steps_per_second,▇█▆▇▁
train/epoch,▁▂▃▅▅▆███
train/global_step,▁▂▃▅▅▆███
train/grad_norm,█▃▁
train/learning_rate,█▅▁
train/loss,▁█▂

0,1
eval/accuracy,0.91667
eval/loss,0.79501
eval/runtime,0.0493
eval/samples_per_second,162.289
eval/steps_per_second,20.286
total_flos,5194426490880.0
train/epoch,5.0
train/global_step,30.0
train/grad_norm,0.14553
train/learning_rate,0.0


Best Trial: BestRun(run_id='uwkevz5s', objective=1.0, hyperparameters={'learning_rate': 0.0008249543621818322, 'per_device_train_batch_size': 2, 'assignments': {}, 'metric': 'eval/loss'}, run_summary=None)


In [18]:
def train_final_model(best_trial):
    best_hyperparameters = best_trial.hyperparameters
    model = model_init()
    args.learning_rate = best_hyperparameters["learning_rate"]
    args.per_device_train_batch_size = best_hyperparameters["per_device_train_batch_size"]
    final_trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )
    final_trainer.train()
    # Explicitly save the model's configuration
    model.config.save_pretrained(output_dir)
    # Save the model
    final_trainer.save_model(output_dir)
    
    # Log in to Hugging Face account
    # notebook_login()
    # Push the model to the Hub
    # repo_name = "AmelieSchreiber/esm2_t6_8M_UR50D_lora_rna_binding_sites" # Change this to your desired repository name
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    # model.push_to_hub(repo_name)
train_final_model(best_trial)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.




Epoch,Training Loss,Validation Loss,Accuracy
1,0.6958,0.783984,0.916667
2,0.6688,0.781743,0.916667
3,0.6912,0.800304,0.916667
4,0.6727,0.799796,0.75
5,0.6813,0.797781,0.75




In [None]:
import itertools

def train_with_hyperparams(learning_rate, batch_size):
    args = TrainingArguments(
        output_dir="./esm2_t6_8M-finetuned-lora",
        evaluation_strategy="epoch",
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        num_train_epochs=5,
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
    )

    trainer = Trainer(
        model=model_init(None),
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )
    trainer.train()

    # Evaluate and return accuracy
    results = trainer.evaluate()
    return results["eval_accuracy"]

# Define hyperparameter grid
learning_rates = [1e-5, 5e-5, 1e-4, 5e-4]
batch_sizes = [2, 4, 8, 16]

best_accuracy = 0
best_hyperparams = {}

# Perform grid search
for lr, bs in itertools.product(learning_rates, batch_sizes):
    print(f"Training with learning_rate={lr}, batch_size={bs}")
    accuracy = train_with_hyperparams(lr, bs)

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_hyperparams = {"learning_rate": lr, "batch_size": bs}

print("Best hyperparameters:", best_hyperparams)