# Tutorial on PEFT with DNA Language Models

This notebook is a tutorial on how to use parameter-efficient fine-tuning (PEFT) techniques from the PEFT library to fine-tune a DNA Language Model (DNA-LM). The fine-tuned DNA-LM is then used to solve a task from the Nucleotide Transformer downstream-tasks benchmark.

### 1. Import relevant libraries

In [56]:
import sklearn
import torch
import transformers 
import peft
import tqdm
import math

### 2. Load datasets

We load the ```nucleotide_transformer_downstream_tasks``` dataset, which contains 18 downstream tasks from the Nucleotide Transformer paper. It provides a consistent genomics benchmark with both binary and multi-class classification tasks.

In [57]:
from datasets import load_dataset

ds = load_dataset("InstaDeepAI/nucleotide_transformer_downstream_tasks", "H3")

We'll use the "H3" subset of this dataset, which contains 13,468 rows of training data and 1,497 rows of test data.

In [58]:
ds

DatasetDict({
    train: Dataset({
        features: ['sequence', 'name', 'label'],
        num_rows: 13468
    })
    test: Dataset({
        features: ['sequence', 'name', 'label'],
        num_rows: 1497
    })
})

The dataset consists of three columns, ```sequence```, ```name``` and ```label```. An example is given below.

In [59]:
ds['train'][0]

{'sequence': 'TCACTTCGATTATTGAGGCAGTCTTCATTAAAGTTTATTACAATGGATATGGTATCACCAGTCTTGAACCTACAATCATCTATTTTAGGTGAGCTCGTAGGCATTATTGGAAAAGTGTTCTTTCTCTTAATAGAAGAGATTAAATACCCGATAATCACACCCAAAATTATTGTGGATGCCCAGATATCTTCTTGGTCATTGTTTTTTTTCGCTTCAATCTGTAATCTCTCTGCAAAATTTCGGGAGCCAATAGTGACAACATCGTCAATAATAAGTTTGATGGAATCGGAAAAAGATCTTAAAAATGTAAATGAGTATTTCCAAATAATGGCCAAAATGCTCTTTATATTGGAAAATAAAATAGTTGTTTCGCTCTTCGTAGTATTTAACATTTCCGTTCTTATCATTGTAAAGTCTGAGCCATATTCATATGGAAAAGTGCTTTTTAAACCTAGTTCCTCCATATTTTAGTTTTTTATCGATATTGGAAAAAAAAGAGC',
 'name': 'YBR063C_YBR063C_367930|0',
 'label': 0}

### 3. Load models


We'll use Species-LM, a "species-aware" DNA Language Model, for our task. It can be loaded from the Hugging Face Hub.

In [60]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

In [61]:
tokenizer = AutoTokenizer.from_pretrained("gagneurlab/SpeciesLM", revision = "downstream_species_lm")
lm = AutoModelForMaskedLM.from_pretrained("gagneurlab/SpeciesLM", revision = "downstream_species_lm")

In [62]:
print(torch.cuda.is_available())
lm.eval()
lm.to("cuda")
print("Done!")

True
Done!


### 4. Embed the sequences

We embed the sequences in the training dataset using our DNA-Language Model. 

We start by creating a function that generates kmers for any given sequence (```get_kmers``` below).

In [63]:
def get_kmers(seq, k=6, stride=1):
    return [seq[i:i + k] for i in range(0, len(seq), stride) if i + k <= len(seq)]
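
As a quick check of what ```get_kmers``` returns, the sketch below runs it on a short made-up sequence (the 8-base string is purely illustrative and is not taken from the dataset). With the default stride of 1, a sequence of length n yields n - k + 1 overlapping k-mers.

```python
def get_kmers(seq, k=6, stride=1):
    # Same function as above: slide a window of width k over the sequence.
    return [seq[i:i + k] for i in range(0, len(seq), stride) if i + k <= len(seq)]

toy = "ACGTACGT"  # illustrative sequence, not taken from the dataset
kmers = get_kmers(toy)
print(kmers)            # ['ACGTAC', 'CGTACG', 'GTACGT']
print(" ".join(kmers))  # space-separated string, the form that is passed to the tokenizer
```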

Then, we tokenize our sequences using the tokenizer.

In [64]:
sequences = []

for i in range(0, len(ds['train'])):
    sequence = ds['train'][i]['sequence']
    sequence = "candida_glabrata " + " ".join(get_kmers(sequence))
    sequence = tokenizer(sequence)["input_ids"]

    sequences.append(sequence)

Next, we create a ```torch.Tensor``` matrix from our sequences.

In [67]:
tokenized_sequences = torch.tensor(sequences)

Checking the shape of our tokenized sequences, we get:

In [68]:
tokenized_sequences.shape

torch.Size([13468, 498])
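
The sequence length of 498 can be reconstructed by hand. The sketch below assumes that every input sequence is 500 bases long and that the tokenizer wraps the species token and the k-mers in one ```[CLS]``` and one ```[SEP]``` token. Both are assumptions inferred from the shape above, so treat this as a plausibility check rather than a description of the tokenizer.

```python
seq_len = 500      # assumed length of each H3 sequence in bases
k, stride = 6, 1   # defaults of get_kmers

n_kmers = (seq_len - k) // stride + 1   # overlapping 6-mers per sequence
n_species = 1                           # the "candida_glabrata" prefix token
n_special = 2                           # assumed [CLS] and [SEP] added by the tokenizer

total_tokens = n_kmers + n_species + n_special
print(n_kmers, total_tokens)  # 495 498
```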

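Next, we generate one embedding per sequence. For each batch, we average the hidden states from index 8 onwards and then mean-pool over the k-mer tokens.

In [69]:
embeddings = []
batch_size = 64
device = "cuda"

for i in tqdm.tqdm(range(math.ceil(tokenized_sequences.shape[0] / batch_size))):
    batch = tokenized_sequences[i * batch_size:(i + 1) * batch_size].to(device)
    with torch.inference_mode(), torch.autocast(device):
        hidden_states = lm(batch, output_hidden_states=True)["hidden_states"]
        # Average hidden states 8 onwards: (layers, batch, tokens, hidden) -> (batch, tokens, hidden)
        embedding = torch.stack(hidden_states[8:], dim=0).mean(dim=0)
        # Drop the first two tokens and the last one (special and species tokens),
        # then average over the remaining k-mer tokens -> (batch, hidden)
        embedding = embedding[:, 2:-1, :].mean(dim=1)
        embeddings.append(embedding.cpu())

embeddings = torch.concat(embeddings, dim=0)

100%|██████████| 211/211 [00:21<00:00,  9.63it/s]


In [70]:
embeddings.shape

torch.Size([13468, 768])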
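
One way to reduce the hidden states to a single vector per sequence is to average over the selected layers first and over the k-mer tokens second. The sketch below checks the resulting shapes on random tensors. The batch size of 2 and the random values are stand-ins, and the count of 13 hidden states assumes the usual Hugging Face BERT convention of returning the embedding output plus one state per layer.

```python
import torch

torch.manual_seed(0)
n_states, batch, tokens, hidden = 13, 2, 498, 768  # 13 = embedding output + 12 layers

# Stand-in for the tuple returned as "hidden_states" (random values, not real model output).
hidden_states = tuple(torch.randn(batch, tokens, hidden) for _ in range(n_states))

stacked = torch.stack(hidden_states[8:], dim=0)  # (5, batch, tokens, hidden)
layer_mean = stacked.mean(dim=0)                 # (batch, tokens, hidden)
pooled = layer_mean[:, 2:-1, :].mean(dim=1)      # (batch, hidden)

print(stacked.shape, layer_mean.shape, pooled.shape)
```

Averaging over `dim=0` of the stacked tensor removes the layer axis. Skipping that step and averaging over `dim=1` instead would average across the sequences in the batch rather than across tokens.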


### 5. Train the model

Now, we'll train our DNA Language Model on the training dataset. We'll add a linear layer on top of the language model and then train all of the model's parameters.

In [71]:
print(lm)

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(5504, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

In [72]:
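import torch.nn as nn

class DNA_LM(nn.Module):
    def __init__(self, model, num_classes, hidden_size=768):
        super().__init__()
        # Reuse the pretrained BERT encoder (without the masked-LM head).
        self.model = model.bert
        # Linear classification head on top of the 768-dimensional hidden states.
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask=None, labels=None):
        hidden = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # Mean-pool over the k-mer tokens, skipping the first two tokens and the last one.
        logits = self.classifier(hidden[:, 2:-1, :].mean(dim=1))
        if labels is None:
            return {"logits": logits}
        # Returning a loss alongside the logits lets Trainer optimize the model.
        return {"loss": nn.functional.cross_entropy(logits, labels), "logits": logits}

num_classes = 2
dna_lm = DNA_LM(lm, num_classes=num_classes)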

Since this is a binary classification task, the final linear layer needs only two output nodes.
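
To make the role of the two output nodes concrete, here is a minimal sketch of a linear head applied to pooled embeddings. The embeddings and labels are random stand-ins rather than model outputs, and cross-entropy is assumed as the loss, which is the usual choice for this setup but is not stated in the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size, num_classes = 768, 2

head = nn.Linear(hidden_size, num_classes)  # 768 * 2 weights + 2 biases
pooled = torch.randn(4, hidden_size)        # stand-in for pooled sequence embeddings
labels = torch.tensor([0, 1, 1, 0])         # made-up binary labels

logits = head(pooled)                       # (4, 2): one score per class
loss = nn.functional.cross_entropy(logits, labels)
probs = logits.softmax(dim=-1)              # each row sums to 1

n_params = sum(p.numel() for p in head.parameters())
print(logits.shape, n_params)  # torch.Size([4, 2]) 1538
```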

In [None]:
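from datasets import Dataset
from transformers import TrainingArguments, Trainer

def build_dataset(split):
    # Tokenize a split as in section 4 and attach its labels so Trainer can compute a loss.
    input_ids = [
        tokenizer("candida_glabrata " + " ".join(get_kmers(seq)))["input_ids"]
        for seq in split["sequence"]
    ]
    return Dataset.from_dict({"input_ids": input_ids, "labels": list(split["label"])})

train_dataset = build_dataset(ds["train"])
eval_dataset = build_dataset(ds["test"])  # evaluate on held-out data, not on the training set

training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
)

trainer = Trainer(
    model=dna_lm,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)

trainer.train()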

### TODO: Include implementation of PEFT library