# Set up

In [1]:
# 1. Imports
import os
import numpy as np
import xml.etree.ElementTree as ET
import torch
import torch.nn as nn
import requests
import pickle
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, matthews_corrcoef, accuracy_score
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)
from datasets import Dataset
from accelerate import Accelerator
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType


# Setup Environment Variables and Accelerator
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
accelerator = Accelerator()

# Setup Wandb
os.environ["WANDB_ENTITY"] = "hc-ai-handson"
os.environ["WANDB_LOG_MODEL"] = "end"
project="esm2-binding-sites"
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mkeisuke-kamata[0m ([33mhc-ai-handson[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

# Data prepartion

In [None]:
def get_data_from_url(url,name):
    response = requests.get(url)
    with open(name, 'wb') as file:
        file.write(response.content)

with wandb.init(project=project,job_type="upload_data") as run:
    artifact = wandb.Artifact(name="binding_sites_random_split_by_family_train",
                              metadata={
                                        "url": "https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family_550K",
                                        },
                              type="dataset")
    url="https://cdn-lfs.huggingface.co/repos/f1/20/f1203a07ea684a9586a90e512fe0ab40290bbaa0f57833aed3e29decf5520637/f17f5ca4beb72ba0a867c94da2f145a1cb7924f8013461dcf39384555ccc3d79?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27train_labels_chunked_by_family.pkl%3B+filename%3D%22train_labels_chunked_by_family.pkl%22%3B&Expires=1695257172&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTY5NTI1NzE3Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy9mMS8yMC9mMTIwM2EwN2VhNjg0YTk1ODZhOTBlNTEyZmUwYWI0MDI5MGJiYWEwZjU3ODMzYWVkM2UyOWRlY2Y1NTIwNjM3L2YxN2Y1Y2E0YmViNzJiYTBhODY3Yzk0ZGEyZjE0NWExY2I3OTI0ZjgwMTM0NjFkY2YzOTM4NDU1NWNjYzNkNzk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=jtL4YhzDy5LYzD3IvmWmPG0u4XlGQB6a48KPkKGFtyirCUvCeywmlffV6RRqwwX1ROChlcDM-ytSPUJ5%7Ef3EkhRPogFCyTnKGJJlZ9DNYyF8qCb-s4FdjPZeh2SPXzGDvoDgtBOfgOv-yx%7Evyk8GILBYfgjZqCKj1I7LByOUmslSYIYNsZOp90bLtVlbQOO3V7EpGnc4iRawkcA90D1I0zwUFfvOxZ6K7mZW4en3nXlNOTZ5e71uI5Kmar%7EY4f6wMYn5B5HOiSL%7EwEww-AUlicMcvkawmVBxTxKoNjZ6M%7EQJ0fnyyzKyXW6rrpjojfEprUs1D9JtCaZlKxrWTnIeSQ__&Key-Pair-Id=KVTP0A1DKRTAX"
    get_data_from_url(url,"train_labels_chunked_by_family.pkl")
    url="https://cdn-lfs.huggingface.co/repos/f1/20/f1203a07ea684a9586a90e512fe0ab40290bbaa0f57833aed3e29decf5520637/c5fb314b71338ce943c62301e00f1fd865afa50bf2d7a6440a2935db30989e45?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27train_sequences_chunked_by_family.pkl%3B+filename%3D%22train_sequences_chunked_by_family.pkl%22%3B&Expires=1695257198&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTY5NTI1NzE5OH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy9mMS8yMC9mMTIwM2EwN2VhNjg0YTk1ODZhOTBlNTEyZmUwYWI0MDI5MGJiYWEwZjU3ODMzYWVkM2UyOWRlY2Y1NTIwNjM3L2M1ZmIzMTRiNzEzMzhjZTk0M2M2MjMwMWUwMGYxZmQ4NjVhZmE1MGJmMmQ3YTY0NDBhMjkzNWRiMzA5ODllNDU%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=Dh2R5oGLfwPWWce97ZJsihX-mSTFfwbtEfq1SI00tVyacaT4dlxnCddshe%7E15tSamiMukONmeTP38hQUmRitP1-v3wjZ3%7E5G7OeBmsWeK5%7Eivva7xcJOOe45KUXgy59fIkBneULZOLw8ysrkD83IGkybkRB3MlNWLeiwHIai9cPgqdF-yCk1nxU2VHL-A2GUh2W3sTo7Q1edE3Qzy2NV2QSbfmadsz1xIYBh135d%7EzJp0Y7MnNGZV-UyGBJJXjgOjwU7B05esYekTjuQTaX9ujh3iEF883SBepgtWypJ81%7EZ7LYYpuvGwCv5eO2-6qf9kMtLZQLww87US4C1%7EWOxMQ__&Key-Pair-Id=KVTP0A1DKRTAX"
    get_data_from_url(url,"train_sequences_chunked_by_family.pkl")
    artifact.add_file(local_path="train_sequences_chunked_by_family.pkl")
    artifact.add_file(local_path="train_labels_chunked_by_family.pkl")
    run.log_artifact(artifact)
    
with wandb.init(project=project,job_type="upload_data") as run:
    artifact = wandb.Artifact(name="binding_sites_random_split_by_family_test",
                              metadata={
                                        "url": "https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family_550K",
                                        },
                              type="dataset")
    url="https://cdn-lfs.huggingface.co/repos/f1/20/f1203a07ea684a9586a90e512fe0ab40290bbaa0f57833aed3e29decf5520637/af3ddcda3b27eeac15739f06046b122ae5bfac944929434c53fbd2a91e44684c?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27test_labels_chunked_by_family.pkl%3B+filename%3D%22test_labels_chunked_by_family.pkl%22%3B&Expires=1695256911&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTY5NTI1NjkxMX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy9mMS8yMC9mMTIwM2EwN2VhNjg0YTk1ODZhOTBlNTEyZmUwYWI0MDI5MGJiYWEwZjU3ODMzYWVkM2UyOWRlY2Y1NTIwNjM3L2FmM2RkY2RhM2IyN2VlYWMxNTczOWYwNjA0NmIxMjJhZTViZmFjOTQ0OTI5NDM0YzUzZmJkMmE5MWU0NDY4NGM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=0Z%7ENnciKC5NA3W47qQxn6t4KGdDY65%7Efma0Kq7Oyfd79xnxZVerXHSD8jeWK52iPP-My7DbqpzWdvQNtLupHzzPhzqqQ%7EUwZfr%7EwbnHOvUCFLGQJczkCWXyQO90cDyzXm5vNR1lG%7EcM5V1COhWeOKRkToeqzeTvaL4e1bTN71ZLCepfYRzGr8le5xHNPTNmhnnpIldkuB4guUGYsJFirdJ5bJT3ZwLRq0hN3EMvPimoORfyVcO%7E-pYN8%7ENqnaQroPjrc8SnVA2BXmZWyyozYZ2vfWw83rnVB%7EYWJcRO46uBqK0tW5bgeBWcI98qyBO458vjyzpdG5skoz7Cq3OwiZA__&Key-Pair-Id=KVTP0A1DKRTAX"
    get_data_from_url(url,"test_labels_chunked_by_family.pkl")
    url="https://cdn-lfs.huggingface.co/repos/f1/20/f1203a07ea684a9586a90e512fe0ab40290bbaa0f57833aed3e29decf5520637/e304531c74a93d537f17b00585424c87afd6b130915dde22f356892b4cb3b240?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27test_sequences_chunked_by_family.pkl%3B+filename%3D%22test_sequences_chunked_by_family.pkl%22%3B&Expires=1695257137&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTY5NTI1NzEzN319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy9mMS8yMC9mMTIwM2EwN2VhNjg0YTk1ODZhOTBlNTEyZmUwYWI0MDI5MGJiYWEwZjU3ODMzYWVkM2UyOWRlY2Y1NTIwNjM3L2UzMDQ1MzFjNzRhOTNkNTM3ZjE3YjAwNTg1NDI0Yzg3YWZkNmIxMzA5MTVkZGUyMmYzNTY4OTJiNGNiM2IyNDA%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=ATB3zaFSwDQ9TClQPD0ERVzTjjeEJZamzZjjnXeda3ecjVCJoKCycIQ-o0zoJL1b0F9NShZyUVmCPQqbMA4fn7RhA63lbOv8%7EBrjD3FnqUVnXaOADvmGhCS0Uq2aFpEM9Z0aLn8rpPt04gIU3MpDGjqRVUX8TlANZcp8y4fXPj9IcEXKT7xIlkhnDoG5Pj4gJzNMziHYRfW4q792pKfjl0XEBeBHEuHxXSBG5CsalXGN1OAb3w2JOruLI0v4L461TPHSPAAuZTYI-KeMhZV1zAMEWQ4ODEKWA%7EcPa94HlKnsC7DEVY-kcPUET98kRbtHpTDH7s9OA51kPJoYk-sZVg__&Key-Pair-Id=KVTP0A1DKRTAX"
    get_data_from_url(url,"test_sequences_chunked_by_family.pkl")
    artifact.add_file(local_path="test_sequences_chunked_by_family.pkl")
    artifact.add_file(local_path="test_labels_chunked_by_family.pkl")
    run.log_artifact(artifact)

# Define functions

In [2]:
def convert_binding_string_to_labels(binding_string):
    """Convert 'proBnd' strings into label arrays."""
    return [1 if char == '+' else 0 for char in binding_string]

def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]

def compute_metrics(p):
    """Compute metrics for evaluation."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)
    return {'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc}

def compute_loss(model, inputs):
    """Custom compute_loss function."""
    logits = model(**inputs).logits
    labels = inputs["labels"]
    loss_fct = nn.CrossEntropyLoss(weight=class_weights)
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    
    # The torch.where function is used to obtain the labels at the positions of the active tokens,
    # and set the ignore index at the positions of inactive tokens (padding tokens).
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss

In [3]:
def data_preparation(config,download=False):
    # Tokenization
    tokenizer = AutoTokenizer.from_pretrained(config["base_model_path"])
    max_sequence_length = 1000
    
    # Use artifacts on wandb
    if download:
        with wandb.init(project=config["wandb_project"], job_type="data_prep") as run:
            artifact_train = run.use_artifact('hc-ai-handson/esm2-binding-sites/binding_sites_random_split_by_family_train:v0', type='dataset')
            artifact_test = run.use_artifact('hc-ai-handson/esm2-binding-sites/binding_sites_random_split_by_family_test:v0', type='dataset')
            artifact_dir_train = artifact_train.download()
            artifact_dir_test = artifact_test.download()
            # Load the data from pickle files
            with open(artifact_dir_train+"/train_sequences_chunked_by_family.pkl", "rb") as f:
                train_sequences = pickle.load(f)
            with open(artifact_dir_train+"/train_labels_chunked_by_family.pkl", "rb") as f:
                train_labels = pickle.load(f)
            with open(artifact_dir_test+"/test_sequences_chunked_by_family.pkl", "rb") as f:
                test_sequences = pickle.load(f)
            with open(artifact_dir_test+"/test_labels_chunked_by_family.pkl", "rb") as f:
                test_labels = pickle.load(f)
    else:
        with open("train_sequences_chunked_by_family.pkl", "rb") as f:
            train_sequences = pickle.load(f)
        with open("train_labels_chunked_by_family.pkl", "rb") as f:
            train_labels = pickle.load(f)
        with open("test_sequences_chunked_by_family.pkl", "rb") as f:
            test_sequences = pickle.load(f)
        with open("test_labels_chunked_by_family.pkl", "rb") as f:
            test_labels = pickle.load(f)
        
    train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
    test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
    
    # Directly truncate the entire list of labels
    train_labels = truncate_labels(train_labels, max_sequence_length)
    test_labels = truncate_labels(test_labels, max_sequence_length)
    
    train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
    test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
    
    # Compute Class Weights
    classes = [0, 1]  
    flat_train_labels = [label for sublist in train_labels for label in sublist]
    class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
    accelerator = Accelerator()
    class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
            
    return train_dataset, test_dataset, class_weights, tokenizer
    
# Define Custom Trainer Class
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        loss = compute_loss(model, inputs)
        return (loss, outputs) if return_outputs else loss

In [4]:
def finetuning(train_dataset,test_dataset,config):
    # Define labels and model
    id2label = {0: "No binding site", 1: "Binding site"}
    label2id = {v: k for k, v in id2label.items()}

    # Train and Save Model
    with wandb.init(project=config["wandb_project"],config=config,tags=["finetuning"]) as run:
        config=wandb.config
        base_model = AutoModelForTokenClassification.from_pretrained(config.base_model_path, 
                                                                     num_labels=len(id2label),
                                                                     id2label=id2label,
                                                                     label2id=label2id)
        # Convert the model into a PeftModel
        peft_config = LoraConfig(
            task_type=TaskType.TOKEN_CLS, 
            inference_mode=False, 
            r=config["r"], 
            lora_alpha=config.lora_alpha, 
            target_modules=["query", "key", "value"], # also try "dense_h_to_4h" and "dense_4h_to_h"
            lora_dropout=config.lora_dropout, 
            bias="none" # or "all" or "lora_only" 
        )
        model = get_peft_model(base_model, peft_config)
      
        # Use the accelerator
        accelerator = Accelerator()
        model = accelerator.prepare(model)
        train_dataset = accelerator.prepare(train_dataset)
        test_dataset = accelerator.prepare(test_dataset)
        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    
        # Training setup
        training_args = TrainingArguments(
            output_dir=f"checkpoint/finetuned_model_{timestamp}",
            learning_rate=config.lr,
            lr_scheduler_type=config.lr_scheduler_type,
            gradient_accumulation_steps=1,
            max_grad_norm=config.max_grad_norm,
            per_device_train_batch_size=config.per_device_train_batch_size,
            per_device_eval_batch_size=config.per_device_train_batch_size,
            num_train_epochs=config.num_train_epochs,
            weight_decay=config.weight_decay,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="f1",
            greater_is_better=True,
            push_to_hub=False,
            logging_dir=None,
            logging_first_step=False,
            logging_steps=200,
            save_total_limit=7,
            no_cuda=False,
            seed=8893,
            fp16=True,
            report_to='wandb',
        )
    
        # Initialize Trainer
        trainer = WeightedTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            tokenizer=tokenizer,
            data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
            compute_metrics=compute_metrics
        )
        
        trainer.train()
        lora_model_path = os.path.join("finetuned_model", f"best_model_esm2_lora_{timestamp}")
        trainer.save_model(lora_model_path)
        tokenizer.save_pretrained(lora_model_path)
        
        # Load the LoRA model
        #finetuned_model = PeftModel.from_pretrained(base_model, lora_model_path, torch_dtype=torch.float16)
        #finetuned_model = accelerator.prepare(finetuned_model)

        # Create a data collator
        #data_collator = DataCollatorForTokenClassification(tokenizer)
        
        # Get the metrics for the training and test datasets
        #train_metrics = compute_metrics_evalaution(train_dataset, finetuned_model, data_collator,"train_")
        #test_metrics = compute_metrics_evalaution(test_dataset, finetuned_model, data_collator,"test_")
        #run.log(train_metrics)
        #run.log(test_metrics)

In [9]:
config = {
    "base_model_path": "facebook/esm2_t6_8M_UR50D",
    "lora_alpha": 1,
    "lora_dropout": 0.2,
    "lr": 5e-03,
    "lr_scheduler_type": "cosine",
    "max_grad_norm": 0.5,
    "num_train_epochs": 4,
    "per_device_train_batch_size": 12,
    "r": 2,
    "weight_decay": 0.2,
    "wandb_project": project,
}

In [6]:
train_dataset, test_dataset, class_weights, tokenizer = data_preparation(config,download=False)

In [7]:
#train_dataset = train_dataset.select(range(100))
#test_dataset = test_dataset.select(range(100))

In [None]:
finetuning(train_dataset,test_dataset,config)

Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Auc
1,0.0729,0.408267,0.246493,0.76477,0.372822,0.836638
2,0.0577,0.5363,0.267435,0.670167,0.382308,0.799162
3,0.055,0.503673,0.268758,0.714185,0.390547,0.819068


In [None]:
"""
### not in use
# Define a function to compute the metrics
def compute_metrics_evalaution(dataset, model, data_collator,prefix):
    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)
    
    # Remove padding and special tokens
    mask = labels != -100
    true_labels = labels[mask].flatten()
    flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)  # Compute the MCC
    
    return {prefix+"accuracy": accuracy, prefix+"precision": precision, prefix+"recall": recall, prefix+"f1": f1, prefix+"auc": auc, prefix+"mcc": mcc}  
"""