# PEFT with DNA Language Models

This notebook demonstrates how to use parameter-efficient fine-tuning (PEFT) techniques from the PEFT library to fine-tune a DNA Language Model (DNA-LM). The fine-tuned DNA-LM is then applied to a task from the Nucleotide Transformer downstream benchmark. PEFT techniques are crucial for adapting large pre-trained models to specific tasks with limited computational resources.

### 1. Import relevant libraries

We'll start by importing the required libraries, including the PEFT library and other dependencies.

In [53]:
import sklearn
import torch
import transformers 
import peft
import tqdm
import math

### 2. Load models


We'll load a pre-trained DNA Language Model, "SpeciesLM", that serves as the base for fine-tuning. This is done using the transformers library from HuggingFace.

In [54]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

In [55]:
tokenizer = AutoTokenizer.from_pretrained("gagneurlab/SpeciesLM", revision = "downstream_species_lm")
lm = AutoModelForMaskedLM.from_pretrained("gagneurlab/SpeciesLM", revision = "downstream_species_lm")

In [56]:
print(torch.cuda.is_available())
lm.eval()
lm.to("cuda")
print("Done!")

True
Done!


### 3. Prepare datasets

We'll load the `nucleotide_transformer_downstream_tasks` dataset, which contains 18 downstream tasks from the Nucleotide Transformer paper. This dataset provides a consistent genomics benchmark with both binary and multi-class classification tasks.

In [57]:
from datasets import load_dataset

raw_data = load_dataset("InstaDeepAI/nucleotide_transformer_downstream_tasks", "H3")

We'll use the "H3" subset of this dataset, which contains 13,468 rows in the training data and 1,497 rows in the test data.

In [58]:
raw_data

DatasetDict({
    train: Dataset({
        features: ['sequence', 'name', 'label'],
        num_rows: 13468
    })
    test: Dataset({
        features: ['sequence', 'name', 'label'],
        num_rows: 1497
    })
})

The dataset consists of three columns: `sequence`, `name`, and `label`. An example is given below.

In [59]:
raw_data['train'][0]

{'sequence': 'TCACTTCGATTATTGAGGCAGTCTTCATTAAAGTTTATTACAATGGATATGGTATCACCAGTCTTGAACCTACAATCATCTATTTTAGGTGAGCTCGTAGGCATTATTGGAAAAGTGTTCTTTCTCTTAATAGAAGAGATTAAATACCCGATAATCACACCCAAAATTATTGTGGATGCCCAGATATCTTCTTGGTCATTGTTTTTTTTCGCTTCAATCTGTAATCTCTCTGCAAAATTTCGGGAGCCAATAGTGACAACATCGTCAATAATAAGTTTGATGGAATCGGAAAAAGATCTTAAAAATGTAAATGAGTATTTCCAAATAATGGCCAAAATGCTCTTTATATTGGAAAATAAAATAGTTGTTTCGCTCTTCGTAGTATTTAACATTTCCGTTCTTATCATTGTAAAGTCTGAGCCATATTCATATGGAAAAGTGCTTTTTAAACCTAGTTCCTCCATATTTTAGTTTTTTATCGATATTGGAAAAAAAAGAGC',
 'name': 'YBR063C_YBR063C_367930|0',
 'label': 0}

We split our dataset into training, validation, and test sets. The validation set is held out from the original training data.

In [61]:
from datasets import Dataset, DatasetDict

# Hold out 15% of the original training data for validation
train_valid_split = raw_data['train'].train_test_split(test_size=0.15, seed=42)

ds = DatasetDict({
    'train': train_valid_split['train'],
    'validation': train_valid_split['test'],
    'test': raw_data['test']
})

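Then, we generate our data and labels. Each sequence is converted into overlapping 6-mers, prefixed with a species token, and tokenized. To keep the demo fast, only the first 200 examples of each split are used.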

In [62]:
def get_kmers(seq, k=6, stride=1):
    return [seq[i:i + k] for i in range(0, len(seq), stride) if i + k <= len(seq)]
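
To make the tokenizer input concrete, here is a small usage sketch of `get_kmers` on a made-up 10-nucleotide sequence (the sequence is illustrative and not taken from the dataset). At stride 1, a sequence of length L yields L - k + 1 overlapping k-mers.

```python
def get_kmers(seq, k=6, stride=1):
    return [seq[i:i + k] for i in range(0, len(seq), stride) if i + k <= len(seq)]

toy_seq = "ACGTACGTAC"  # hypothetical example sequence
kmers = get_kmers(toy_seq)
print(kmers)
# ['ACGTAC', 'CGTACG', 'GTACGT', 'TACGTA', 'ACGTAC']

# At stride 1, the number of k-mers is len(seq) - k + 1
assert len(kmers) == len(toy_seq) - 6 + 1

# The model input is the species token followed by the space-separated k-mers
model_input = "candida_glabrata " + " ".join(kmers)
print(model_input)
```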

In [63]:
dataset_limit = 200  # cap each split for a quick demo; set to None to use everything

def encode_split(split, limit=dataset_limit):
    """Tokenize the first `limit` sequences of a split and return (input_ids, labels)."""
    n = min(limit, len(split)) if limit else len(split)
    input_ids = []
    for i in range(n):
        # Prepend the species token, then pass space-separated 6-mers to the tokenizer
        text = "candida_glabrata " + " ".join(get_kmers(split[i]['sequence']))
        input_ids.append(tokenizer(text)["input_ids"])
    return input_ids, split['label'][:n]

train_sequences, train_labels = encode_split(ds['train'])
val_sequences, val_labels = encode_split(ds['validation'])
test_sequences, test_labels = encode_split(ds['test'])

Finally, we create a `Dataset` object for each of our sets.

In [64]:
import pandas as pd
from datasets import Dataset

a = {"input_ids": train_sequences, "labels": train_labels}
df = pd.DataFrame.from_dict(a)
train_dataset = Dataset.from_pandas(df)

b = {"input_ids": val_sequences, "labels": val_labels}
df = pd.DataFrame.from_dict(b)
val_dataset = Dataset.from_pandas(df)

c = {"input_ids": test_sequences, "labels": test_labels}
df = pd.DataFrame.from_dict(c)
test_dataset = Dataset.from_pandas(df)

### 4. Train model

Now, we'll train our DNA Language Model on the training dataset. We'll add a linear classification layer on top of the language model's final hidden layer, and then train all the parameters of the model.

In [65]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
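
`DataCollatorWithPadding` pads each batch to the length of its longest example so that sequences of different lengths can be stacked into one tensor. The plain-Python sketch below only illustrates the idea: `pad_batch` is a hypothetical helper, not the library's implementation, and it assumes right-side padding with a pad id of 0 (matching `padding_idx=0` in the model's embedding layer).

```python
def pad_batch(batch, pad_id=0):
    """Pad token-id lists to a common length and build matching attention masks."""
    max_len = max(len(ids) for ids in batch)
    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids in batch]
    # 1 marks a real token, 0 marks padding that attention should ignore
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

toy_batch = [[2, 15, 27, 3], [2, 15, 3]]  # made-up token ids
padded = pad_batch(toy_batch)
print(padded["input_ids"])       # [[2, 15, 27, 3], [2, 15, 3, 0]]
print(padded["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```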

In [74]:
import torch
from torch import nn

class DNA_LM(nn.Module):
    def __init__(self, model, num_labels):
        super().__init__()
        self.model = model
        self.in_features = 768  # hidden size of the base model
        # Size the head from the argument rather than hard-coding two classes
        self.out_features = num_labels
        self.classifier = nn.Linear(self.in_features, self.out_features)

    def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Hidden states of the last encoder layer: (batch, seq_len, hidden)
        sequence_output = outputs.hidden_states[-1]
        # Use the [CLS] token (position 0) for classification
        cls_output = sequence_output[:, 0, :]
        logits = self.classifier(cls_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.out_features), labels.view(-1))

        return (loss, logits) if loss is not None else logits

# Number of classes for your classification task
num_labels = 2
classification_model = DNA_LM(lm, num_labels)
classification_model.to('cuda')

DNA_LM(
  (model): BertForMaskedLM(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(5504, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
       

In [76]:
from transformers import Trainer, TrainingArguments

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
)

# Initialize Trainer
trainer = Trainer(
    model=classification_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Train the model
trainer.train()

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


| Epoch | Training Loss | Validation Loss |
|---|---|---|
| 1 | No log | 0.482426 |
| 2 | No log | 0.413035 |
| 3 | No log | 0.563575 |
| 4 | No log | 0.589309 |
| 5 | No log | 0.616723 |


TrainOutput(global_step=65, training_loss=0.19404044518103966, metrics={'train_runtime': 26.5384, 'train_samples_per_second': 37.681, 'train_steps_per_second': 2.449, 'total_flos': 0.0, 'train_loss': 0.19404044518103966, 'epoch': 5.0})
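
The `global_step=65` in the output can be checked by hand, and the same arithmetic suggests why the Training Loss column reads "No log". This sketch assumes a single device, no gradient accumulation, and the `TrainingArguments` default of `logging_steps=500`; none of these are stated explicitly in the notebook.

```python
import math

num_examples = 200      # dataset_limit used when building the splits
batch_size = 16         # per_device_train_batch_size
num_epochs = 5          # num_train_epochs
logging_steps = 500     # assumed: the TrainingArguments default

steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_steps)  # 13 65

# Training loss is only logged every `logging_steps` optimizer steps,
# so a run shorter than that never reports it during training.
print(total_steps < logging_steps)  # True
```

Under these assumptions, passing a smaller `logging_steps` (for example, the number of steps per epoch) would make the training loss appear in the table.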

### 5. Evaluation

In [77]:
# Generate predictions

predictions = trainer.predict(test_dataset)
logits = predictions.predictions
predicted_labels = logits.argmax(axis=-1)
print(predicted_labels)

[0 1 1 0 0 1 1 0 0 0 0 1 0 0 1 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 1 1 1 0 0 0 0
 1 1 0 0 0 1 1 0 1 0 0 0 0 0 1 0 0 0 0 1 1 1 0 1 1 1 1 0 0 1 0 1 1 0 0 0 1
 0 0 0 0 1 0 0 0 0 1 0 0 1 0 1 0 0 0 0 1 0 1 0 1 0 0 1 1 0 0 0 1 0 0 0 0 0
 0 0 1 1 0 0 1 1 0 0 0 0 1 1 0 1 1 0 1 0 0 1 0 1 0 1 0 1 1 0 1 0 1 1 0 0 1
 0 0 1 1 0 0 1 0 1 0 1 0 1 1 1 0 1 0 0 1 0 1 0 1 0 0 0 1 1 0 1 1 0 0 1 0 0
 1 0 0 1 1 0 1 1 1 1 0 1 1 1 0]


In [79]:
# Evaluate the predictions against the true test labels
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

accuracy = accuracy_score(test_labels, predicted_labels)
print(f"Accuracy: {accuracy}")

precision, recall, f1, _ = precision_recall_fscore_support(test_labels, predicted_labels, average='weighted')
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1-Score: {f1}")

Accuracy: 0.815
Precision: 0.8283914728682171
Recall: 0.815
F1-Score: 0.8142075202284625
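
Note that the weighted recall above equals the accuracy (0.815). This is expected: with `average='weighted'`, each class's recall is weighted by its support, which reduces to the overall fraction of correct predictions. A minimal pure-Python sketch on made-up labels illustrates this (`weighted_scores` is a hypothetical helper, not part of scikit-learn):

```python
from collections import Counter

def weighted_scores(y_true, y_pred):
    """Return accuracy plus support-weighted precision and recall."""
    support = Counter(y_true)
    n = len(y_true)
    precision = recall = 0.0
    for cls, count in support.items():
        tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
        predicted = sum(p == cls for p in y_pred)
        weight = count / n  # fraction of true examples belonging to this class
        precision += weight * (tp / predicted if predicted else 0.0)
        recall += weight * (tp / count)
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / n
    return accuracy, precision, recall

y_true = [0, 0, 0, 0, 0, 1, 1, 1]  # made-up labels
y_pred = [0, 0, 0, 1, 1, 1, 1, 0]
accuracy, precision, recall = weighted_scores(y_true, y_pred)
print(accuracy, precision, recall)
```

Weighted precision, by contrast, generally differs from accuracy, as it does in the results above.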


### 7. Parameter Efficient Fine-Tuning Techniques

In [80]:
classification_model

DNA_LM(
  (model): BertForMaskedLM(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(5504, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
       

In [81]:
from peft import LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=32,
    target_modules=["query", "key", "value"],
    lora_dropout=0.01,
)
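
This configuration injects low-rank adapters into the `query`, `key`, and `value` projections. In LoRA, the frozen weight `W` is augmented with a trainable update `B @ A` of rank `r`, scaled by `lora_alpha / r` (here 32 / 8 = 4). The pure-Python sketch below illustrates that computation on tiny made-up matrices; it omits dropout and is not PEFT's implementation.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]

def lora_forward(W, A, B, x, r, lora_alpha):
    """Compute W x + (lora_alpha / r) * B (A x)."""
    scaling = lora_alpha / r
    base = matvec(W, x)                 # frozen pre-trained projection
    update = matvec(B, matvec(A, x))    # low-rank trainable path
    return [b + scaling * u for b, u in zip(base, update)]

# Toy sizes: 3 input/output features, rank 2 (made-up numbers)
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
A = [[0.5, 0.0, 0.0], [0.0, 0.5, 0.0]]   # r x in_features
B_zero = [[0.0, 0.0]] * 3                # out_features x r, zero at initialization
x = [1.0, 2.0, 3.0]

# With B = 0 the adapted layer reproduces the frozen layer exactly
print(lora_forward(W, A, B_zero, x, r=2, lora_alpha=8))  # [1.0, 2.0, 3.0]

B = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
print(lora_forward(W, A, B, x, r=2, lora_alpha=8))  # [3.0, 6.0, 3.0]
```

Because only `A` and `B` receive gradients, the number of trainable parameters per projection grows with `r` rather than with the full 768 × 768 weight.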

In [82]:
from peft import get_peft_model

peft_model = get_peft_model(classification_model, peft_config)
peft_model.print_trainable_parameters()

trainable params: 443,906 || all params: 90,720,900 || trainable%: 0.4893
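
The reported count can be reproduced by hand. Each adapted 768 × 768 projection gains two matrices, `A` (r × 768) and `B` (768 × r). The sketch below assumes the remaining trainable parameters are the classification head (768 × 2 weights plus 2 biases); this is consistent with the printed total but is an inference, not something the output states.

```python
hidden = 768        # in/out features of query, key, value
r = 8               # LoRA rank from peft_config
num_layers = 12     # BertLayer count in the model printout
num_targets = 3     # query, key, value
num_labels = 2

lora_per_module = r * hidden + hidden * r            # A and B matrices
lora_total = num_layers * num_targets * lora_per_module
head = hidden * num_labels + num_labels              # assumed: classifier weights + biases

trainable = lora_total + head
all_params = 90_720_900                              # from print_trainable_parameters()
print(lora_total, head, trainable)                   # 442368 1538 443906
print(f"{100 * trainable / all_params:.4f}")         # 0.4893
```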


In [83]:
peft_model

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): DNA_LM(
      (model): BertForMaskedLM(
        (bert): BertModel(
          (embeddings): BertEmbeddings(
            (word_embeddings): Embedding(5504, 768, padding_idx=0)
            (position_embeddings): Embedding(512, 768)
            (token_type_embeddings): Embedding(2, 768)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (encoder): BertEncoder(
            (layer): ModuleList(
              (0-11): 12 x BertLayer(
                (attention): BertAttention(
                  (self): BertSdpaSelfAttention(
                    (query): lora.Linear(
                      (base_layer): Linear(in_features=768, out_features=768, bias=True)
                      (lora_dropout): ModuleDict(
                        (default): Dropout(p=0.01, inplace=False)
                      )
                      (lora

In [84]:
from transformers import Trainer, TrainingArguments

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
)

# Initialize Trainer
trainer = Trainer(
    model=peft_model.model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Train the model
trainer.train()

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


| Epoch | Training Loss | Validation Loss |
|---|---|---|
| 1 | No log | 0.643141 |
| 2 | No log | 0.631494 |
| 3 | No log | 0.614654 |
| 4 | No log | 0.609057 |
| 5 | No log | 0.610397 |
| 6 | No log | 0.615955 |
| 7 | No log | 0.621695 |
| 8 | No log | 0.61908 |
| 9 | No log | 0.620627 |
| 10 | No log | 0.620246 |


TrainOutput(global_step=130, training_loss=0.10182150327242338, metrics={'train_runtime': 48.6567, 'train_samples_per_second': 41.104, 'train_steps_per_second': 2.672, 'total_flos': 0.0, 'train_loss': 0.10182150327242338, 'epoch': 10.0})

### 8. Evaluate PEFT Model

In [85]:
# Generate predictions

predictions = trainer.predict(test_dataset)
logits = predictions.predictions
predicted_labels = logits.argmax(axis=-1)
print(predicted_labels)

[0 1 1 0 0 1 1 0 0 0 0 1 0 0 1 0 1 1 0 0 1 1 0 0 0 0 1 0 0 0 1 1 1 0 0 0 0
 1 1 0 0 0 1 1 0 1 0 0 0 0 0 1 0 0 0 0 1 1 1 0 1 1 1 1 0 0 1 0 1 1 0 0 0 1
 0 0 0 0 1 0 0 0 0 1 0 0 1 0 1 0 0 0 0 1 0 1 0 1 0 0 1 1 0 0 0 1 0 0 0 0 0
 0 0 1 1 0 0 1 1 0 0 0 0 1 1 0 1 1 0 1 0 0 1 0 1 0 1 0 1 1 0 1 0 1 1 0 0 1
 0 0 1 1 0 0 1 0 1 0 1 0 1 1 1 0 1 0 0 1 0 1 0 1 0 0 0 1 1 0 1 1 0 0 1 0 0
 1 0 0 1 1 0 1 1 1 1 0 1 1 1 1]


In [86]:
# Evaluate the PEFT model's predictions against the true test labels
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

accuracy = accuracy_score(test_labels, predicted_labels)
print(f"Accuracy: {accuracy}")

precision, recall, f1, _ = precision_recall_fscore_support(test_labels, predicted_labels, average='weighted')
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1-Score: {f1}")

Accuracy: 0.81
Precision: 0.8219560573695452
Recall: 0.81
F1-Score: 0.8093149038461539


### WIP: Embed the sequences

We embed the sequences in the training dataset using our DNA-Language Model. 

We start by creating a function that generates k-mers for any given sequence (`get_kmers` below).

In [None]:
def get_kmers(seq, k=6, stride=1):
    return [seq[i:i + k] for i in range(0, len(seq), stride) if i + k <= len(seq)]

Then, we tokenize our sequences using the tokenizer.

Next, we create a `torch.Tensor` matrix from our sequences.

In [None]:
tokenized_sequences = torch.tensor(sequences)

Checking the shape of our tokenized sequences, we get:

In [None]:
tokenized_sequences.shape

We'll generate the embeddings for our tokenized sequences.

In [None]:
embedding = lm(tokenized_sequences.to('cuda'), output_hidden_states=True)["hidden_states"]

In [None]:
type(embedding[12])

In [None]:
embeddings = []
embeddings.append(embedding[12])

In [None]:
embeddings = []
batch_size = 64
device = "cuda"

for i in tqdm.tqdm(range(math.ceil(tokenized_sequences.shape[0] / batch_size))):
    batch = tokenized_sequences[i * batch_size:(i + 1) * batch_size].to(device)
    with torch.inference_mode(), torch.autocast(device):
        # hidden_states is a tuple of tensors, one per layer (plus the embedding output)
        hidden_states = lm(batch, output_hidden_states=True)["hidden_states"]
    # Keep only the last layer; the tuple itself has no .cpu() method
    embeddings.append(hidden_states[-1].cpu())

    # Alternative (untested): average several layers and pool over tokens
    # embedding = torch.stack(hidden_states[8:], axis=0)
    # embedding = torch.mean(embedding[:,2:-1,:], axis=1)

# Join the per-batch tensors into one (num_sequences, seq_len, hidden_size) tensor
embeddings = torch.cat(embeddings, dim=0)

In [None]:
tokenized

In [None]:
outputs.hidden_states[-1].shape

In [None]:
last_hidden_states = outputs.hidden_states[-1]
last_hidden_states.shape

In [None]:
embeddings.shape