## Hyperparameter tuning using WandB

In [None]:
#Importing Libraries
import wandb
import evaluate
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from evaluate import load
from datasets import Dataset
import numpy as np
import pandas as pd
import os
from scipy.special import softmax
from peft import LoraConfig, get_peft_model
import torch
import torch.nn.functional as F
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score,matthews_corrcoef


In [None]:
def data_load(data_dir='bbbp_dataset'):
    """Read the BBBP training and validation splits from CSV.

    Args:
        data_dir: Directory holding ``train.csv`` and ``valid.csv``.
            Defaults to ``'bbbp_dataset'``, the location that used to be
            hard-coded, so existing ``data_load()`` calls are unaffected.

    Returns:
        A ``(train, validation)`` tuple of pandas DataFrames.

    Raises:
        FileNotFoundError: If either CSV file does not exist.
    """
    train_bbbp = pd.read_csv(os.path.join(data_dir, 'train.csv'))
    val_bbbp = pd.read_csv(os.path.join(data_dir, 'valid.csv'))

    return train_bbbp, val_bbbp

In [None]:
def data_prep(data_process, tokenizer):
    """Turn a DataFrame of molecules into a labelled, tokenised dataset.

    Args:
        data_process: DataFrame with a ``smiles`` column (model input text)
            and a ``p_np`` column (used as the label).
        tokenizer: Callable applied to the list of SMILES strings; its
            output is passed to ``Dataset.from_dict``.

    Returns:
        A ``Dataset`` holding the tokenizer output plus a ``labels`` column.
    """
    # Tokenise every SMILES string in one call and wrap the result.
    features = Dataset.from_dict(tokenizer(data_process['smiles'].tolist()))

    # Attach the targets under the column name "labels".
    return features.add_column("labels", data_process['p_np'].tolist())

In [None]:
from peft import LoraConfig, get_peft_model, PeftModel

def lora_config(r, lora_alpha, dropout):
    """Build the LoRA configuration used for sequence classification.

    Args:
        r: Value forwarded as ``LoraConfig(r=...)``.
        lora_alpha: Value forwarded as ``LoraConfig(lora_alpha=...)``.
        dropout: Value forwarded as ``LoraConfig(lora_dropout=...)``.

    Returns:
        A ``LoraConfig`` with task type ``"SEQ_CLS"`` and
        ``target_modules='all-linear'``.
    """
    settings = {
        "task_type": "SEQ_CLS",  # sequence classification task
        "r": r,
        "lora_alpha": lora_alpha,
        "target_modules": 'all-linear',
        "lora_dropout": dropout,
    }
    return LoraConfig(**settings)

### Weighted Loss Function

In [None]:
def class_weights_calculation(train_dataset):
    """Compute inverse-frequency weights for the two classes (0 and 1).

    Each class gets the weight ``1 - (count of that class / total)``, so the
    rarer class receives the larger weight.

    Args:
        train_dataset: Any mapping-like object whose ``'labels'`` entry is an
            iterable of class ids (0 or 1).

    Returns:
        A float32 tensor ``[weight_for_class_0, weight_for_class_1]``.

    Raises:
        ValueError: If there are no labels (the weights would be undefined).
    """
    # Materialise the column exactly once: indexing a dataset column can be
    # expensive, and not every column type offers a list-style ``count``.
    labels = list(train_dataset['labels'])
    total = len(labels)
    if total == 0:
        raise ValueError("Cannot compute class weights from an empty label list.")

    class_weights = [1 - labels.count(class_id) / total for class_id in (0, 1)]
    return torch.tensor(class_weights, dtype=torch.float32)


class WeightedLossTrainer(Trainer):
    """Trainer that uses class-weighted cross-entropy as its loss.

    The class weights come from ``class_weights_calculation`` applied to the
    training dataset. They are computed once and cached, because the label
    distribution does not change between steps.
    """

    # Lazily filled cache of the weights derived from ``self.train_dataset``.
    _cached_class_weights = None

    def _get_class_weights(self, train_dataset=None):
        """Return class weights for ``train_dataset`` or, if omitted, the trainer's own."""
        if train_dataset is not None:
            # An explicitly supplied dataset is never cached: it may differ per call.
            return class_weights_calculation(train_dataset)
        if self._cached_class_weights is None:
            if self.train_dataset is None:
                raise ValueError(
                    "WeightedLossTrainer needs a train_dataset to derive class weights."
                )
            self._cached_class_weights = class_weights_calculation(self.train_dataset)
        return self._cached_class_weights

    def compute_loss(self, model, inputs, train_dataset=None, return_outputs=False, num_items_in_batch=None):
        """Compute the class-weighted cross-entropy loss for one batch.

        Args:
            model: Model called as ``model(**inputs)``.
            inputs: Batch mapping; must contain ``"labels"``.
            train_dataset: Optional dataset to derive the weights from. It
                now defaults to ``None`` (meaning ``self.train_dataset``):
                the Trainer calls this method without that argument, so a
                required parameter here made every call fail.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for signature compatibility; unused.

        Returns:
            The loss tensor, or ``(loss, outputs)`` if ``return_outputs``.
        """
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Extract labels
        labels = inputs.get("labels")

        # The weight tensor must sit on the same device (and use the same
        # dtype) as the logits, otherwise the loss call fails off-CPU.
        class_weights = self._get_class_weights(train_dataset).to(
            device=logits.device, dtype=logits.dtype
        )
        loss_func = torch.nn.CrossEntropyLoss(weight=class_weights)

        # compute loss
        loss = loss_func(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

### Focal Loss Function

In [None]:
#focal loss computation
import torch.nn.functional as F
import torch

def focal_loss(inputs, targets, alpha=1, gamma=2):
    """Mean focal loss for multi-class classification.

    Per sample the loss is ``-alpha * (1 - p_t) ** gamma * log(p_t)``, where
    ``p_t`` is the softmax probability assigned to the true class.

    Args:
        inputs: Logits whose last dimension indexes the classes.
        targets: Integer (int64) class ids, shaped like ``inputs`` without
            its last dimension.
        alpha: Constant scaling factor applied to every sample.
        gamma: Focusing exponent; ``gamma=0`` gives ``alpha`` times the
            plain cross-entropy.

    Returns:
        A scalar tensor: the loss averaged over all samples.
    """
    log_prob = F.log_softmax(inputs, dim=-1)

    # Pick the true-class log-probability directly. Multiplying by a one-hot
    # mask instead would evaluate 0 * -inf = NaN whenever a non-target logit
    # is -inf, and it allocates a full (..., num_classes) mask for nothing.
    log_pt = log_prob.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    pt = torch.exp(log_pt)  # probability of the true class

    per_sample_loss = -alpha * (1 - pt) ** gamma * log_pt

    return per_sample_loss.mean()


class FocalLossTrainer(Trainer):
    """Trainer whose loss is ``focal_loss`` applied to the model's logits."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the focal loss for one batch.

        Args:
            model: Model called as ``model(**inputs)`` after the labels
                have been removed from ``inputs``.
            inputs: Batch mapping; its ``"labels"`` entry is popped.
            return_outputs: When true, return ``(loss, outputs)``.
            **kwargs: Extra keyword arguments are accepted and ignored.

        Returns:
            The loss tensor, or ``(loss, outputs)`` if ``return_outputs``.
        """
        # Take the labels out of the batch so the forward pass only
        # receives the remaining inputs.
        targets = inputs.pop("labels")
        model_outputs = model(**inputs)

        batch_loss = focal_loss(model_outputs.logits, targets)

        if return_outputs:
            return batch_loss, model_outputs
        return batch_loss

In [None]:
from evaluate import load
import numpy as np
from scipy.special import softmax
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score,matthews_corrcoef

# Metric objects loaded once at module level.
accuracy_metric = load("accuracy")
mcc_metric = load("matthews_correlation")


def compute_metrics(eval_pred):
    """Compute binary-classification metrics from logits and labels.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` is indexed as a
            2-D array with class 1 in column 1.

    Returns:
        A dict with the keys ``eval_mcc_metric``, ``Accuracy``, ``AUC-ROC``,
        ``Precision``, ``Recall`` and ``F1-score``.
    """
    logits, labels = eval_pred

    # AUC-ROC needs a continuous score, so keep the class-1 probability;
    # every other metric works on the hard (argmax) prediction.
    positive_scores = softmax(logits, axis=1)[:, 1]
    predicted = np.argmax(logits, axis=1)

    results = {"eval_mcc_metric": matthews_corrcoef(labels, predicted)}
    results["Accuracy"] = accuracy_metric.compute(
        predictions=predicted, references=labels
    )["accuracy"]
    results["AUC-ROC"] = roc_auc_score(labels, positive_scores)
    results["Precision"] = precision_score(labels, predicted)
    results["Recall"] = recall_score(labels, predicted)
    results["F1-score"] = f1_score(labels, predicted)
    return results



In [None]:
import re

# Bayesian hyperparameter search over the learning rate and the LoRA
# settings, maximising the validation MCC.
# NOTE(review): the sweep name mentions "Flavor" although this notebook
# tunes BBBP models — confirm whether it should be renamed.
sweep_config = {
    "name": "Flavor Hyperparameter Tuning",
    "method": "bayes",
    "metric": {
        "goal": "maximize",
        # Must match the name under which the runs log the MCC to W&B.
        "name": "eval/mcc_metric",
    },
    "parameters": {
        "lr": {
            "distribution": "uniform",
            "min": 1e-5,
            "max": 2e-3,
        },
        "r": {"values": [4, 8, 16, 32, 64, 128]},
        "lora_alpha": {"values": [4, 8, 16, 32, 64, 128]},
        "dropout": {"values": [0.0, 0.1, 0.2]},
        # A constant parameter takes a scalar under "value" (lists belong
        # under "values"); previously the constant itself was a list.
        "optimizer": {"value": "adamw"},
    },
}

sweep_id = wandb.sweep(sweep_config, project="huggingface")

# Base models to tune; each one gets `count` trials of the shared sweep.
model_list = [
    "DeepChem/ChemBERTa-77M-MLM",
    "DeepChem/ChemBERTa-10M-MLM",
    "DeepChem/ChemBERTa-10M-MTR",
    "DeepChem/ChemBERTa-5M-MTR",
    "DeepChem/ChemBERTa-77M-MTR",
    "ibm/MoLFormer-XL-both-10pct",
]


def safe_model_name(name1):
    """Return ``name1`` with every non-alphanumeric character replaced by ``__``."""
    return re.sub(r"[^a-zA-Z0-9]", "__", name1)


for model_name in model_list:
    print(f"Running sweep for model: {model_name}")

    # Bind the current model as a default argument so the trial function
    # does not depend on the loop variable's value at call time.
    def run_training(model_name=model_name):
        """Run one sweep trial: fine-tune ``model_name`` with LoRA and save it."""
        print(f"Running training for model: {model_name}")
        # Initialize W&B with sweep. Using the run as a context manager
        # finishes it even when training raises.
        # NOTE(review): the project name says "focal loss" but the weighted
        # loss trainer is used below — confirm the intended name.
        with wandb.init(project="BBBP focal loss Hyperparameter Tuning") as run:
            config = run.config
            # All models share one sweep, so record which model this trial
            # used; the best-run search below filters on it.
            config.update({"model_name": model_name})

            model_id_clean = safe_model_name(model_name)
            print(f"Model ID cleaned: {model_id_clean}")
            run_id = run.id

            # Define unique output folders
            save_dir = f".../{model_id_clean}/{run_id}"
            logging_dir = f".../{model_id_clean}/{run_id}"
            os.makedirs(save_dir, exist_ok=True)

            # Load tokenizer and model
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name,
                num_labels=2,
                trust_remote_code=True
            )

            # Load and preprocess data
            train_data, val_data = data_load()
            training_data = data_prep(train_data, tokenizer)
            validation_data = data_prep(val_data, tokenizer)

            # Apply LoRA
            peft_config = lora_config(config.r, config.lora_alpha, config.dropout)
            lora_model = get_peft_model(model, peft_config)
            lora_model.print_trainable_parameters()

            # Define training args
            training_args = TrainingArguments(
                output_dir=save_dir,
                eval_strategy="epoch",
                learning_rate=config.lr,
                per_device_train_batch_size=32,
                per_device_eval_batch_size=32,
                num_train_epochs=10,
                weight_decay=0.01,
                save_strategy="epoch",
                logging_dir=logging_dir,
                logging_strategy="steps",
                logging_steps=500,
                report_to="wandb",
                save_total_limit=3,
                load_best_model_at_end=True,
                metric_for_best_model="eval_mcc_metric"
            )

            # Train with the weighted loss trainer. The module-level
            # `compute_metrics` is reused instead of an identical local copy.
            # To train with focal loss instead, swap in FocalLossTrainer
            # (that variant was written with early_stopping_patience=5).
            trainer = WeightedLossTrainer(
                model=lora_model,
                args=training_args,
                train_dataset=training_data,
                eval_dataset=validation_data,
                tokenizer=tokenizer,
                compute_metrics=compute_metrics,
                callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
            )

            trainer.train()
            trainer.save_model(save_dir)
            print(f"Model saved to {save_dir}")
            print(f"Training completed for model: {model_name}")

    wandb.agent(sweep_id, function=run_training, count=5)

    api = wandb.Api()
    sweep = api.sweep(f"huggingface/{sweep_id}")
    sweep_runs = list(sweep.runs)
    if sweep_runs:
        print(sweep_runs[0].summary_metrics)

    # Only consider trials of the current model that actually reported the
    # metric; the sweep also holds the trials of previously tuned models.
    runs_with_mcc = [
        run for run in sweep_runs
        if run.config.get("model_name") == model_name
        and 'eval/mcc_metric' in run.summary_metrics
    ]
    if not runs_with_mcc:
        raise ValueError("No runs found with 'eval/mcc_metric' metric.")

    # The sweep maximises MCC, so the best run is the one with the highest
    # value (an ascending sort followed by [0] returned the worst run).
    best_run = max(runs_with_mcc, key=lambda run: run.summary_metrics['eval/mcc_metric'])

    best_hyperparameters = best_run.config
    print(f"Best hyperparameters: {best_hyperparameters}")
    print("completed sweep for model: ",model_name)




## Evaluation

In [None]:
from evaluate import load
import numpy as np
from scipy.special import softmax
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score, matthews_corrcoef

# Accuracy metric object used by compute_metrics below.
accuracy_metric = load("accuracy")


def compute_metrics(eval_pred):
    """Score predictions on the evaluation set.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` is indexed as a
            2-D array with class 1 in column 1.

    Returns:
        A dict with the keys ``eval_mcc_metric``, ``Accuracy``, ``AUC-ROC``,
        ``Precision``, ``Recall`` and ``F1-score``.
    """
    logits, labels = eval_pred
    class1_probability = softmax(logits, axis=1)[:, 1]  # continuous score for AUC-ROC
    hard_predictions = np.argmax(logits, axis=1)  # most likely class per sample

    scores = [
        ("eval_mcc_metric", matthews_corrcoef(labels, hard_predictions)),
        (
            "Accuracy",
            accuracy_metric.compute(predictions=hard_predictions, references=labels)["accuracy"],
        ),
        ("AUC-ROC", roc_auc_score(labels, class1_probability)),
        ("Precision", precision_score(labels, hard_predictions)),
        ("Recall", recall_score(labels, hard_predictions)),
        ("F1-score", f1_score(labels, hard_predictions)),
    ]
    return dict(scores)

In [None]:

# Hugging Face hub ids of the base models that were fine-tuned.
_HUB_MODEL_IDS = (
    "DeepChem/ChemBERTa-5M-MTR",
    "DeepChem/ChemBERTa-10M-MTR",
    "DeepChem/ChemBERTa-77M-MLM",
    "DeepChem/ChemBERTa-10M-MLM",
    "DeepChem/ChemBERTa-77M-MTR",
    "ibm/MoLFormer-XL-both-10pct",
)

# Map the on-disk folder names to the base HuggingFace model names. A folder
# name is the hub id with "/" and "-" each replaced by "__".
MODEL_NAME_MAP = {
    hub_id.replace("/", "__").replace("-", "__"): hub_id
    for hub_id in _HUB_MODEL_IDS
}

# Held-out test split used to score every saved checkpoint.
test_data = pd.read_csv('bbbp_dataset/test.csv')

# Root folder containing the saved models (add the model saved path).
models_root_dir = ".../models_bbbp_chem_WL"

# Evaluation-only trainer settings: no experiment tracking, no progress bars.
eval_args = TrainingArguments(
    output_dir="./test_results_bbbp",
    per_device_eval_batch_size=32,
    report_to="none",
    disable_tqdm=True,
)

def find_all_peft_checkpoints(root_dir):
    """Collect every PEFT adapter checkpoint below ``root_dir``.

    The expected layout is ``root_dir/<model_folder>/<run_id>/checkpoint-*``.
    A checkpoint directory counts only if it contains ``adapter_config.json``.

    Args:
        root_dir: Directory holding one sub-directory per model.

    Returns:
        A list of ``(model_folder, run_id, checkpoint_path)`` tuples, ordered
        by model folder, then run id, then checkpoint step. The order is
        deterministic, whereas raw ``os.listdir`` order is arbitrary.

    Raises:
        FileNotFoundError: If ``root_dir`` does not exist.
    """
    prefix = "checkpoint-"

    def _step_order(name):
        # Order numerically by step so "checkpoint-52" precedes
        # "checkpoint-104"; names without a numeric suffix go last.
        suffix = name[len(prefix):]
        if suffix.isdecimal():
            return (0, int(suffix), name)
        return (1, 0, name)

    checkpoints = []
    for model_folder in sorted(os.listdir(root_dir)):
        model_folder_path = os.path.join(root_dir, model_folder)
        if not os.path.isdir(model_folder_path):
            continue
        for run_id in sorted(os.listdir(model_folder_path)):
            run_path = os.path.join(model_folder_path, run_id)
            if not os.path.isdir(run_path):
                continue
            candidates = [
                name for name in os.listdir(run_path) if name.startswith(prefix)
            ]
            for subdir in sorted(candidates, key=_step_order):
                checkpoint_path = os.path.join(run_path, subdir)
                if os.path.exists(os.path.join(checkpoint_path, "adapter_config.json")):
                    checkpoints.append((model_folder, run_id, checkpoint_path))
    return checkpoints

from transformers import Trainer

valid_checkpoints = find_all_peft_checkpoints(models_root_dir)
print(f"Found {len(valid_checkpoints)} valid checkpoints.")

# The tokenizer and the tokenised test set depend only on the base model,
# so build them once per base model rather than once per checkpoint.
prepared_inputs = {}
# One row per evaluated checkpoint, so the metrics are kept after the loop.
all_test_results = []

for model_folder, run_id, checkpoint_path in valid_checkpoints:
    print("Model folder: ",model_folder)

    hf_model_name = MODEL_NAME_MAP.get(model_folder)
    if hf_model_name is None:
        # An unexpected folder under the models root should not abort the
        # evaluation of every remaining checkpoint.
        print(f"Skipping {checkpoint_path}: no base model mapped for folder '{model_folder}'")
        continue
    print(f"Using base model: {hf_model_name}")

    if hf_model_name not in prepared_inputs:
        tokenizer = AutoTokenizer.from_pretrained(hf_model_name, trust_remote_code=True)
        # Same preprocessing as for training/validation (smiles -> tokens,
        # p_np -> labels).
        prepared_inputs[hf_model_name] = (tokenizer, data_prep(test_data, tokenizer))
    tokenizer, test_dataset = prepared_inputs[hf_model_name]

    # A fresh base model is loaded for every checkpoint before its adapter
    # is attached, so adapters from different checkpoints never mix.
    base_model = AutoModelForSequenceClassification.from_pretrained(
        hf_model_name,
        num_labels=2,
        problem_type="single_label_classification",
        trust_remote_code=True
    )

    # Load the adapter checkpoint
    adapter_model = PeftModel.from_pretrained(base_model, checkpoint_path)
    adapter_model.eval()

    # Eval
    trainer = Trainer(
        model=adapter_model,
        args=eval_args,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    print(f"\n Evaluating {model_folder}/{run_id}/{os.path.basename(checkpoint_path)}")

    test_results = trainer.evaluate()
    # Report and keep the metrics; previously they were overwritten on the
    # next iteration without ever being shown.
    print(test_results)
    all_test_results.append({
        "model_folder": model_folder,
        "run_id": run_id,
        "checkpoint": os.path.basename(checkpoint_path),
        **test_results,
    })

# Side-by-side summary of every evaluated checkpoint.
test_results_df = pd.DataFrame(all_test_results)
if not test_results_df.empty:
    print(test_results_df)

## Load and Merge Base Model with LoRA weights

In [None]:
from peft import PeftModel

# Load the base classifier the adapter was trained on.
base_model = AutoModelForSequenceClassification.from_pretrained(
    "DeepChem/ChemBERTa-77M-MLM",  # add the base model name
    trust_remote_code=True,
    problem_type="single_label_classification",
    num_labels=2,
)

# Attach the LoRA adapter stored in the chosen checkpoint.
adapter_model = PeftModel.from_pretrained(
    base_model,
    ".../models_bbbp_chem_WL/DeepChem__ChemBERTa__77M__MLM/x38bwbvz/checkpoint-416",
)

# Merge the adapter into the base model.
# NOTE(review): the variable name mentions clintox/molformer, but the paths
# above point at a BBBP ChemBERTa checkpoint; the name is kept because the
# save cell refers to it.
final_model_clintox_molformer = adapter_model.merge_and_unload()

In [None]:
# Destination directory for the merged model.
save_path = ".../Final_merged_model/weighted_loss_bbbp/final_model_chem_77M-MLM-WL"

# Write the object returned by merge_and_unload() to save_path.
final_model_clintox_molformer.save_pretrained(save_path)