# Molformer Full Finetuning

### Installing deepseed for FusedLamb optimizer

### Load Molformer model and tokeniser

In [1]:
import os
import torch
  
device = torch.device("cuda:0")

In [2]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer


model_name = "ibm/MoLFormer-XL-both-10pct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2,problem_type="single_label_classification",trust_remote_code=True)



  from .autonotebook import tqdm as notebook_tqdm
Some weights of MolformerForSequenceClassification were not initialized from the model checkpoint at ibm/MoLFormer-XL-both-10pct and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.dense2.bias', 'classifier.dense2.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
model.to(device)

MolformerForSequenceClassification(
  (molformer): MolformerModel(
    (embeddings): MolformerEmbeddings(
      (word_embeddings): Embedding(2362, 768, padding_idx=2)
      (dropout): Dropout(p=0.2, inplace=False)
    )
    (encoder): MolformerEncoder(
      (layer): ModuleList(
        (0-11): 12 x MolformerLayer(
          (attention): MolformerAttention(
            (self): MolformerSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (rotary_embeddings): MolformerRotaryEmbedding()
              (feature_map): MolformerFeatureMap(
                (kernel): ReLU()
              )
            )
            (output): MolformerSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=T

### Load Dataset

In [4]:
import pandas as pd

train_clin=pd.read_csv('/home/raghvendra2/Molformer_Finetuning/clintox_train.csv')
val_clin=pd.read_csv('/home/raghvendra2/Molformer_Finetuning/clintox_valid.csv')

In [None]:
print(train_clin.isnull().sum()) 

### Preparing Training and Validation Dataset for Training

In [5]:
from datasets import Dataset

smiles_list_clin = train_clin['smiles'].tolist()
smiles_val_clin=val_clin['smiles'].tolist()

  # Adjust based on your model's limits

train_tokenized_clin = tokenizer(
    smiles_list_clin, padding=True
)
val_tokenized_clin = tokenizer(
    smiles_val_clin, padding=True
)




train_dataset_clin = Dataset.from_dict(train_tokenized_clin)
val_dataset_clin = Dataset.from_dict(val_tokenized_clin)

train_labels_clin = train_clin['CT_TOX'].tolist() # Assuming tasks start from column 1
val_labels_clin = val_clin['CT_TOX'].tolist()



train_dataset_clin = train_dataset_clin.add_column("labels", train_labels_clin)
val_dataset_clin = val_dataset_clin.add_column("labels", val_labels_clin)



### Finetune with Loss Trainer

In [6]:
from evaluate import load
import numpy as np
from scipy.special import softmax
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score,matthews_corrcoef

accuracy_metric = load("accuracy")

def compute_metrics(eval_pred):

    
    logits, labels = eval_pred
    #logits = torch.tensor(logits)
    probabilities = softmax(logits, axis=1)[:, 1]  # Get probabilities for class 1
    predictions = np.argmax(logits, axis=1)  
   
    #probabilities = expit(logits)
    print(logits)
    

    predictions = np.argmax(logits, axis=1)  
    print(predictions)
    mcc = matthews_corrcoef(labels, predictions)
    
    

    return {
        "eval_mcc_metric": mcc,
        "Accuracy": accuracy_metric.compute(predictions=predictions, references=labels)["accuracy"],
        "AUC-ROC": roc_auc_score(labels, probabilities),
        "Precision": precision_score(labels, predictions),
        "Recall": recall_score(labels, predictions),
        "F1-score": f1_score(labels, predictions)
    }

In [7]:
from transformers import Trainer
import torch

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False,num_items_in_batch=None):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.logits


        # Apply loss function        

        loss_fn = torch.nn.CrossEntropyLoss()
        loss = loss_fn(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        # Debugging NaN loss
        if torch.isnan(loss):
            print("⚠️ Warning: Loss is NaN!")
        return (loss, outputs) if return_outputs else loss

    


In [9]:
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

training_args = TrainingArguments(
    output_dir="./molformer_clintox_fullfinetune",
    learning_rate=3e-5,  # Match the shell script
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    save_strategy="steps",         # Save model every N steps
    save_steps=1000,
    num_train_epochs=500,  # Match the shell script
    logging_dir="./logs",
    logging_steps=200,
    eval_steps=1000,
    evaluation_strategy="steps",
    save_total_limit=None,  # Save all checkpoints
    load_best_model_at_end=True,   
    metric_for_best_model="AUC-ROC"
    
)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset_clin,
    eval_dataset=val_dataset_clin,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
)
trainer.train()








Step,Training Loss,Validation Loss,Mcc Metric,Accuracy,Auc-roc,Precision,Recall,F1-score
1000,0.0,0.398748,0.697217,0.972973,0.904464,1.0,0.5,0.666667
2000,0.0,0.425769,0.697217,0.972973,0.874107,1.0,0.5,0.666667
3000,0.0,0.451263,0.697217,0.972973,0.884821,1.0,0.5,0.666667
4000,0.0,0.448866,0.697217,0.972973,0.9125,1.0,0.5,0.666667
5000,0.0,0.45913,0.697217,0.972973,0.91875,1.0,0.5,0.666667


[[ 7.976142  -8.2034645]
 [ 7.44862   -7.666009 ]
 [ 7.455622  -7.333296 ]
 [ 8.047769  -7.8109713]
 [ 8.233107  -8.240908 ]
 [ 7.915318  -7.825328 ]
 [ 8.315034  -8.498877 ]
 [ 7.9342985 -7.809506 ]
 [ 8.437235  -8.633329 ]
 [ 8.59896   -8.195477 ]
 [ 7.5281596 -7.861973 ]
 [ 7.6701136 -7.3097057]
 [ 7.7835135 -7.5402365]
 [ 8.093122  -8.210205 ]
 [ 8.39883   -8.630177 ]
 [ 7.9815817 -7.9596405]
 [ 8.118087  -7.8033733]
 [ 8.36378   -8.136759 ]
 [ 8.202963  -8.109262 ]
 [ 8.047143  -8.073186 ]
 [ 7.932391  -8.257773 ]
 [ 8.269874  -8.21621  ]
 [ 8.427959  -8.458924 ]
 [ 7.9172153 -7.789404 ]
 [ 8.179612  -8.106303 ]
 [ 8.134883  -7.8583164]
 [ 7.967719  -8.095737 ]
 [ 8.129452  -8.37828  ]
 [ 8.258147  -8.158188 ]
 [ 7.172181  -6.802682 ]
 [ 8.232873  -8.341744 ]
 [ 7.993139  -7.6231103]
 [ 7.906086  -7.844763 ]
 [ 7.9474926 -8.108671 ]
 [ 8.155557  -8.156211 ]
 [ 8.273928  -8.530152 ]
 [ 7.318047  -7.007184 ]
 [ 6.6843014 -6.4497404]
 [ 7.8696604 -7.595783 ]
 [-7.420062   7.5792503]




[[ 8.4600315 -8.822439 ]
 [ 7.7215567 -8.20003  ]
 [ 7.6789007 -7.473418 ]
 [ 8.527677  -8.432951 ]
 [ 8.468397  -8.43548  ]
 [ 8.073323  -7.9199996]
 [ 8.5418415 -8.756078 ]
 [ 8.417519  -8.371271 ]
 [ 8.692868  -8.8452015]
 [ 8.866499  -8.63217  ]
 [ 8.20931   -8.515105 ]
 [ 8.214742  -7.8535013]
 [ 8.417879  -8.726584 ]
 [ 8.55682   -8.661205 ]
 [ 8.541762  -8.927278 ]
 [ 8.480862  -8.50317  ]
 [ 8.424135  -8.828307 ]
 [ 8.552292  -8.374814 ]
 [ 8.474842  -8.346403 ]
 [ 8.580782  -8.699047 ]
 [ 8.26714   -8.739472 ]
 [ 8.277401  -8.289169 ]
 [ 8.636427  -8.725577 ]
 [ 7.95376   -7.5760007]
 [ 8.487286  -8.504638 ]
 [ 8.365292  -7.929423 ]
 [ 8.402824  -8.409408 ]
 [ 8.602202  -8.844472 ]
 [ 8.547641  -8.38483  ]
 [ 8.026261  -7.6109295]
 [ 8.338638  -8.462088 ]
 [ 8.48992   -8.047906 ]
 [ 8.28      -8.3168545]
 [ 8.514458  -8.6716   ]
 [ 8.567383  -8.533015 ]
 [ 8.667568  -8.821632 ]
 [ 7.9340057 -7.908284 ]
 [ 7.070325  -6.7257953]
 [ 8.168016  -7.818686 ]
 [-7.8820143  7.975925 ]




[[ 8.816636  -9.218263 ]
 [ 8.206443  -8.572409 ]
 [ 8.035468  -7.8045273]
 [ 8.590119  -8.328663 ]
 [ 8.875689  -8.670244 ]
 [ 8.582598  -8.455249 ]
 [ 9.030329  -9.0968275]
 [ 8.8937645 -8.592189 ]
 [ 9.140474  -9.279081 ]
 [ 8.985215  -8.588616 ]
 [ 8.61796   -9.015681 ]
 [ 8.740722  -8.606803 ]
 [ 8.721416  -9.148652 ]
 [ 8.897603  -9.028558 ]
 [ 8.936075  -9.149446 ]
 [ 8.8066845 -8.657784 ]
 [ 8.852861  -9.253287 ]
 [ 8.965634  -8.795188 ]
 [ 8.944965  -8.754214 ]
 [ 8.966086  -9.114037 ]
 [ 8.81981   -9.000112 ]
 [ 8.735507  -8.687104 ]
 [ 9.08207   -9.078596 ]
 [ 8.571126  -8.376744 ]
 [ 8.867149  -8.830878 ]
 [ 8.759537  -8.306289 ]
 [ 8.729133  -8.804563 ]
 [ 8.879483  -9.121028 ]
 [ 8.936614  -8.75976  ]
 [ 8.278281  -7.787169 ]
 [ 8.735613  -8.774414 ]
 [ 8.80815   -8.30507  ]
 [ 8.427335  -8.38117  ]
 [ 8.852214  -8.91591  ]
 [ 8.989321  -9.017148 ]
 [ 9.008277  -9.141353 ]
 [ 8.540649  -8.666766 ]
 [ 7.488184  -6.9350386]
 [ 8.605814  -8.226139 ]
 [-8.133172   8.308456 ]




[[ 8.932809  -9.438262 ]
 [ 8.319634  -8.7654   ]
 [ 7.993917  -7.684452 ]
 [ 9.199221  -9.195035 ]
 [ 9.0554285 -8.9361   ]
 [ 8.622343  -8.363922 ]
 [ 9.155583  -9.152173 ]
 [ 8.840806  -9.180493 ]
 [ 9.357496  -9.470092 ]
 [ 9.401744  -9.032939 ]
 [ 8.96011   -8.891446 ]
 [ 9.047278  -8.721943 ]
 [ 8.951804  -9.079738 ]
 [ 9.184522  -9.219443 ]
 [ 9.142862  -9.299851 ]
 [ 9.1579485 -8.919745 ]
 [ 9.104917  -9.148069 ]
 [ 9.157532  -9.0305195]
 [ 9.007813  -8.861681 ]
 [ 9.142425  -9.180175 ]
 [ 9.030221  -9.239734 ]
 [ 9.045479  -8.962274 ]
 [ 9.259438  -9.2431345]
 [ 8.576483  -8.340257 ]
 [ 9.161357  -9.093768 ]
 [ 8.936835  -8.613013 ]
 [ 9.10777   -9.1093645]
 [ 9.053472  -9.264688 ]
 [ 9.061289  -8.845121 ]
 [ 8.542988  -8.076984 ]
 [ 8.942612  -9.085942 ]
 [ 8.879145  -8.339897 ]
 [ 8.851014  -8.876269 ]
 [ 9.044156  -9.107879 ]
 [ 9.159938  -9.122339 ]
 [ 9.172993  -9.331667 ]
 [ 8.775998  -8.537176 ]
 [ 8.158906  -7.972153 ]
 [ 8.543487  -8.4133005]
 [-8.445318   8.56668  ]




[[ 9.060866  -9.441971 ]
 [ 8.435446  -8.894133 ]
 [ 8.443639  -8.263637 ]
 [ 9.161125  -8.902023 ]
 [ 9.096474  -9.0211315]
 [ 8.239926  -7.8163095]
 [ 9.11348   -9.176429 ]
 [ 9.066248  -9.142174 ]
 [ 9.327877  -9.513355 ]
 [ 9.466238  -9.095036 ]
 [ 9.010973  -8.783766 ]
 [ 9.316803  -9.164569 ]
 [ 9.105522  -8.893385 ]
 [ 9.15725   -9.107633 ]
 [ 9.180476  -9.384787 ]
 [ 9.103544  -8.985961 ]
 [ 9.20333   -9.285212 ]
 [ 9.165677  -8.974496 ]
 [ 9.15814   -8.991894 ]
 [ 9.14394   -9.330064 ]
 [ 9.028157  -9.289801 ]
 [ 9.090616  -8.943526 ]
 [ 9.317346  -9.148583 ]
 [ 8.759788  -8.502395 ]
 [ 9.125277  -9.061471 ]
 [ 9.025848  -8.546722 ]
 [ 9.038679  -9.104787 ]
 [ 9.144011  -9.236082 ]
 [ 9.176923  -8.970052 ]
 [ 8.5268    -8.244446 ]
 [ 9.001042  -9.121523 ]
 [ 9.106316  -8.638199 ]
 [ 8.876079  -8.807353 ]
 [ 9.148453  -9.187101 ]
 [ 9.248014  -9.184084 ]
 [ 9.261742  -9.351289 ]
 [ 8.496208  -7.808834 ]
 [ 7.3163366 -6.3999367]
 [ 8.934605  -8.439486 ]
 [-8.526587   8.742594 ]


TrainOutput(global_step=5000, training_loss=0.001226675476988021, metrics={'train_runtime': 1792.6748, 'train_samples_per_second': 330.512, 'train_steps_per_second': 2.789, 'total_flos': 3.794420015532e+16, 'train_loss': 0.001226675476988021, 'epoch': 500.0})