# MindNLP PEFT with DNA Language Models
This notebook demonstrates how to use parameter-efficient fine-tuning (PEFT) techniques from the PEFT library to fine-tune a DNA language model (DNA-LM). The fine-tuned DNA-LM is then applied to a task from the Nucleotide Transformer benchmark dataset. PEFT techniques make it possible to adapt large pre-trained models to specific tasks with limited computational resources.

## 1. Import relevant libraries
We'll start by importing the required libraries, including the mindnlp library and other dependencies.

In [1]:
import mindspore
import mindnlp
import numpy as np
import tqdm



## 2. Load models
We'll load a pre-trained DNA language model, "SpeciesLM", which serves as the base for fine-tuning. The checkpoint is loaded through the `mindnlp.transformers` module, using the `gagneurlab/SpeciesLM` model identifier.

The tokenizer and the model come from the paper "Species-aware DNA language models capture regulatory elements and their evolution". The authors introduce a species-aware DNA language model trained on more than 800 species spanning over 500 million years of evolution.

In [2]:
from mindnlp.transformers import AutoTokenizer, AutoModelForMaskedLM

In [3]:
tokenizer = AutoTokenizer.from_pretrained("gagneurlab/SpeciesLM", revision = "downstream_species_lm")
lm = AutoModelForMaskedLM.from_pretrained("gagneurlab/SpeciesLM", revision = "downstream_species_lm")



In [4]:
lm.eval()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(5504, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear (768 -> 768)
              (key): Linear (768 -> 768)
              (value): Linear (768 -> 768)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear (768 -> 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermedia

## 3. Prepare datasets
We'll load the `nucleotide_transformer_downstream_tasks` dataset, which contains 18 downstream tasks from the Nucleotide Transformer paper. This dataset provides a consistent genomics benchmark with binary classification tasks.

In [5]:
from mindnlp.dataset import load_dataset
raw_data = load_dataset("InstaDeepAI/nucleotide_transformer_downstream_tasks", "H3")
raw_data

{'train': <mindspore.dataset.engine.datasets_user_defined.GeneratorDataset at 0x7f6ee71d69a0>,
 'test': <mindspore.dataset.engine.datasets_user_defined.GeneratorDataset at 0x7f6ee7470a30>}

We'll use the "H3" subset of this dataset, which contains 13,468 rows of training data and 1,497 rows of test data.

In [6]:
train_dataset = raw_data['train']
test_dataset = raw_data['test']

def show_dataset_info(dataset):
    print("dataset column: {}".format(dataset.get_col_names()))
    print("dataset size: {}".format(dataset.get_dataset_size()))
    print("dataset batch size: {}\n".format(dataset.get_batch_size()))

print("train dataset info:")
show_dataset_info(train_dataset)
print("test dataset info:")
show_dataset_info(test_dataset)

train dataset info:
dataset column: ['sequence', 'name', 'label']
dataset size: 13468
dataset batch size: 1

test dataset info:
dataset column: ['sequence', 'name', 'label']
dataset size: 1497
dataset batch size: 1



The dataset consists of three columns: `sequence`, `name` and `label`. A row in this dataset looks like:

In [7]:
for data in train_dataset:
    print(data)
    break

[Tensor(shape=[], dtype=String, value= 'TCACTTCGATTATTGAGGCAGTCTTCATTAAAGTTTATTACAATGGATATGGTATCACCAGTCTTGAACCTACAATCATCTATTTTAGGTGAGCTCGTAGGCATTATTGGAAAAGTGTTCTTTCTCTTAATAGAAGAGATTAAATACCCGATAATCACACCCAAAATTATTGTGGATGCCCAGATATCTTCTTGGTCATTGTTTTTTTTCGCTTCAATCTGTAATCTCTCTGCAAAATTTCGGGAGCCAATAGTGACAACATCGTCAATAATAAGTTTGATGGAATCGGAAAAAGATCTTAAAAATGTAAATGAGTATTTCCAAATAATGGCCAAAATGCTCTTTATATTGGAAAATAAAATAGTTGTTTCGCTCTTCGTAGTATTTAACATTTCCGTTCTTATCATTGTAAAGTCTGAGCCATATTCATATGGAAAAGTGCTTTTTAAACCTAGTTCCTCCATATTTTAGTTTTTTATCGATATTGGAAAAAAAAGAGC'), Tensor(shape=[], dtype=String, value= 'YBR063C_YBR063C_367930|0'), Tensor(shape=[], dtype=Int64, value= 0)]


We split the original training data into training and validation sets (85% and 15%). The test set is already provided separately.

In [8]:
train_dataset, valid_dataset = train_dataset.split([0.85, 0.15])

print("train dataset info:")
show_dataset_info(train_dataset)
print("valid dataset info:")
show_dataset_info(valid_dataset)

train dataset info:
dataset column: ['sequence', 'name', 'label']
dataset size: 11448
dataset batch size: 1

valid dataset info:
dataset column: ['sequence', 'name', 'label']
dataset size: 2020
dataset batch size: 1
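
As a quick sanity check on the sizes printed above, the sketch below recomputes them from the split fractions. It assumes that fractional split sizes are obtained by rounding `fraction * dataset_size`; the helper name `split_sizes` is hypothetical and is not part of MindSpore.

```python
def split_sizes(total, fractions):
    """Hypothetical helper: round each fraction of `total` to a whole row count."""
    return [round(total * f) for f in fractions]

sizes = split_sizes(13468, [0.85, 0.15])
print(sizes)       # [11448, 2020]
print(sum(sizes))  # 13468, so no rows are lost by rounding in this case
```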



Then, we use the tokenizer and a utility function, `get_kmers`, to generate the final inputs and labels. `get_kmers` produces the overlapping 6-mers that the language model (LM) expects. With k=6 and stride=1, consecutive k-mers overlap by five nucleotides, so the model sees the local context around every position in the sequence.

In [9]:
def get_kmers(seq, k=6, stride=1):
    return [seq[i:i + k] for i in range(0, len(seq), stride) if i + k <= len(seq)]
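
To make the behaviour of `get_kmers` concrete, here is a small worked example on a toy sequence. The toy inputs are made up for illustration; the function body is the same as in the cell above.

```python
def get_kmers(seq, k=6, stride=1):
    return [seq[i:i + k] for i in range(0, len(seq), stride) if i + k <= len(seq)]

toy = "ACGTACGT"  # illustrative 8-nucleotide sequence
kmers = get_kmers(toy)
print(kmers)  # ['ACGTAC', 'CGTACG', 'GTACGT']

# A larger stride skips start positions, so fewer k-mers are produced.
print(get_kmers(toy, stride=2))  # ['ACGTAC', 'GTACGT']

# A sequence of length L yields L - k + 1 k-mers at stride 1.
print(len(get_kmers("A" * 500)))  # 495
```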

In [10]:

BATCH_SIZE = 16
DATASET_LIMIT = 200 # NOTE: This dataset limit is set to 200, so that the training runs faster. It can be set to None to use the
                    # entire dataset

def process_sequence(data):
    sequence = data.tolist()
    sequence = "candida_glabrata " + " ".join(get_kmers(sequence))
    sequence = tokenizer(sequence)["input_ids"]
    return sequence

def nucleotide_dataset_process(dataset):
    # remove name column
    dataset = dataset.project(columns=['sequence', 'label'])

    # process sequence
    dataset = dataset.map(process_sequence, input_columns=['sequence'], output_columns=['sequence'])

    # change dataset size
    dataset = dataset.take(DATASET_LIMIT)

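    # batch with padding; pad with the tokenizer's pad token ID, since -100 is
    # not a valid index into the embedding table
    dataset = dataset.padded_batch(batch_size=BATCH_SIZE,
                                   drop_remainder=True,
                                   pad_info={'sequence':(None, tokenizer.pad_token_id)})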
    return dataset
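
The sketch below shows the text that `process_sequence` hands to the tokenizer, without needing the tokenizer itself. `build_model_input` is a hypothetical helper that mirrors the string construction above. The token count at the end assumes a 500-nucleotide input, one token ID per whitespace-separated item, and one start and one end token added by the tokenizer.

```python
def get_kmers(seq, k=6, stride=1):
    return [seq[i:i + k] for i in range(0, len(seq), stride) if i + k <= len(seq)]

def build_model_input(seq, species="candida_glabrata"):
    """Hypothetical helper: species token followed by space-separated 6-mers."""
    return species + " " + " ".join(get_kmers(seq))

text = build_model_input("ACGTACGTA")
print(text)  # candida_glabrata ACGTAC CGTACG GTACGT TACGTA

# Assumed token budget for a 500-nucleotide sequence:
# 495 k-mers + 1 species token + 2 special tokens.
expected_tokens = (500 - 6 + 1) + 1 + 2
print(expected_tokens)  # 498
```

Under these assumptions the count agrees with the `[16, 498]` batch shape produced by the pipeline in this notebook.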


Finally, we create a Dataset object for each of our sets.

In [11]:
train_dataset = train_dataset.apply(nucleotide_dataset_process)
valid_dataset = valid_dataset.apply(nucleotide_dataset_process)
test_dataset = test_dataset.apply(nucleotide_dataset_process)

print("train dataset info:")
show_dataset_info(train_dataset)

print("valid dataset info:")
show_dataset_info(valid_dataset)

print("test dataset info:")
show_dataset_info(test_dataset)

for data in train_dataset:
    print(data)
    break

train dataset info:
dataset column: ['sequence', 'label']
dataset size: 12
dataset batch size: 16

valid dataset info:
dataset column: ['sequence', 'label']
dataset size: 12
dataset batch size: 16

test dataset info:
dataset column: ['sequence', 'label']
dataset size: 12
dataset batch size: 16

[Tensor(shape=[16, 498], dtype=Int64, value=
[[   2, 4724, 2162 ... 1240,  852,    3],
 [   2, 4724, 3690 ... 2394, 1369,    3],
 [   2, 4724,  665 ... 1726, 2794,    3],
 ...
 [   2, 4724, 3446 ...  860, 3427,    3],
 [   2, 4724, 2966 ... 1525, 1992,    3],
 [   2, 4724, 3297 ... 1970, 3770,    3]]), Tensor(shape=[16], dtype=Int64, value= [1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0])]
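
The reported dataset size of 12 counts batches, not rows. The short calculation below reproduces it from `DATASET_LIMIT` and `BATCH_SIZE`, and also derives the number of optimizer steps for a 5-epoch run. It assumes one optimizer step per batch (no gradient accumulation).

```python
DATASET_LIMIT = 200
BATCH_SIZE = 16
NUM_EPOCHS = 5  # matches num_train_epochs used for training in this notebook

# drop_remainder=True discards the final, incomplete batch.
batches_per_epoch = DATASET_LIMIT // BATCH_SIZE
dropped_rows = DATASET_LIMIT % BATCH_SIZE
total_steps = batches_per_epoch * NUM_EPOCHS

print(batches_per_epoch)  # 12
print(dropped_rows)       # 8
print(total_steps)        # 60
```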


## 4. Train model
Now we'll train our DNA language model on the training dataset. We add a linear classification layer on top of the language model's final hidden layer, and then train all the parameters of the model on the training dataset.

In [12]:
import mindspore
from mindnlp.core import nn

class DNA_LM(nn.Module):
    def __init__(self, model, num_labels):
        super(DNA_LM, self).__init__()
        self.model = model.bert
        self.in_features = model.config.hidden_size
        self.out_features = num_labels
        self.classifier = nn.Linear(self.in_features, self.out_features)

    def forward(self, sequence, label=None):
        outputs = self.model(input_ids=sequence, attention_mask=None, output_hidden_states=True)
        sequence_output = outputs.hidden_states[-1]
        # Use the [CLS] token for classification
        cls_output = sequence_output[:, 0, :]
        logits = self.classifier(cls_output)

        loss = None
        if label is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.out_features), label.view(-1))

        return (loss, logits) if loss is not None else logits

# Number of classes for your classification task
num_labels = 2
classification_model = DNA_LM(lm, num_labels)
classification_model

DNA_LM(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(5504, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear (768 -> 768)
              (key): Linear (768 -> 768)
              (value): Linear (768 -> 768)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear (768 -> 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
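
The classification head in `DNA_LM.forward` can be illustrated in isolation with NumPy. This is a minimal sketch, not the MindNLP implementation: the hidden states are random stand-ins for the encoder output, the weights `W` and `b` are hypothetical, and the loss is a plain mean cross-entropy.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, hidden, num_labels = 4, 10, 768, 2

# Stand-in for the encoder's last hidden states: (batch, seq_len, hidden).
hidden_states = rng.normal(size=(batch, seq_len, hidden))
W = rng.normal(scale=0.02, size=(hidden, num_labels))  # hypothetical classifier weights
b = np.zeros(num_labels)
labels = np.array([1, 0, 0, 1])

# Take the first position ([CLS]) of every sequence and project it to logits.
cls_output = hidden_states[:, 0, :]
logits = cls_output @ W + b

# Numerically stable log-softmax, then mean negative log-likelihood.
shifted = logits - logits.max(axis=1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
loss = -log_probs[np.arange(batch), labels].mean()

print(logits.shape)  # (4, 2)
print(float(loss))
```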
    

In [13]:
from mindnlp.engine import Trainer, TrainingArguments

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_steps=1,
    logging_steps=1,
)

# Initialize Trainer
trainer = Trainer(
    model=classification_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
)

# Train the model
trainer.train()

{'loss': 0.8075, 'learning_rate': 1.9666666666666666e-05, 'epoch': 0.08}
{'loss': 0.8228, 'learning_rate': 1.9333333333333333e-05, 'epoch': 0.17}
{'loss': 0.684, 'learning_rate': 1.9e-05, 'epoch': 0.25}
{'loss': 0.6842, 'learning_rate': 1.866666666666667e-05, 'epoch': 0.33}
{'loss': 0.8734, 'learning_rate': 1.8333333333333333e-05, 'epoch': 0.42}
{'loss': 0.6887, 'learning_rate': 1.8e-05, 'epoch': 0.5}
{'loss': 0.635, 'learning_rate': 1.7666666666666668e-05, 'epoch': 0.58}
{'loss': 0.7209, 'learning_rate': 1.7333333333333336e-05, 'epoch': 0.67}
{'loss': 0.7041, 'learning_rate': 1.7e-05, 'epoch': 0.75}
{'loss': 0.7175, 'learning_rate': 1.6666666666666667e-05, 'epoch': 0.83}
{'loss': 0.6137, 'learning_rate': 1.6333333333333335e-05, 'epoch': 0.92}
{'loss': 0.661, 'learning_rate': 1.6000000000000003e-05, 'epoch': 1.0}
{'eval_loss': 0.6326713562011719, 'eval_runtime': 2.2588, 'eval_samples_per_second': 5.312, 'eval_steps_per_second': 0.885, 'epoch': 1.0}
{'loss': 0.6403, 'learning_rate': 1.5666666666666667e-05, 'epoch': 1.08}
{'loss': 0.6801, 'learning_rate': 1.5333333333333334e-05, 'epoch': 1.17}
{'loss': 0.5909, 'learning_rate': 1.5000000000000002e-05, 'epoch': 1.25}
{'loss': 0.553, 'learning_rate': 1.4666666666666666e-05, 'epoch': 1.33}
{'loss': 0.6791, 'learning_rate': 1.4333333333333334e-05, 'epoch': 1.42}
{'loss': 0.5729, 'learning_rate': 1.4e-05, 'epoch': 1.5}
{'loss': 0.5118, 'learning_rate': 1.3666666666666667e-05, 'epoch': 1.58}
{'loss': 0.5608, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.67}
{'loss': 0.6366, 'learning_rate': 1.3000000000000001e-05, 'epoch': 1.75}
{'loss': 0.6166, 'learning_rate': 1.2666666666666667e-05, 'epoch': 1.83}
{'loss': 0.4981, 'learning_rate': 1.2333333333333334e-05, 'epoch': 1.92}
{'loss': 0.5189, 'learning_rate': 1.2e-05, 'epoch': 2.0}
{'eval_loss': 0.5681167244911194, 'eval_runtime': 2.2474, 'eval_samples_per_second': 5.339, 'eval_steps_per_second': 0.89, 'epoch': 2.0}
{'loss': 0.5464, 'learning_rate': 1.1666666666666668e-05, 'epoch': 2.08}
{'loss': 0.6042, 'learning_rate': 1.1333333333333334e-05, 'epoch': 2.17}
{'loss': 0.5416, 'learning_rate': 1.1000000000000001e-05, 'epoch': 2.25}
{'loss': 0.4642, 'learning_rate': 1.0666666666666667e-05, 'epoch': 2.33}
{'loss': 0.6242, 'learning_rate': 1.0333333333333335e-05, 'epoch': 2.42}
{'loss': 0.4907, 'learning_rate': 1e-05, 'epoch': 2.5}
{'loss': 0.4037, 'learning_rate': 9.666666666666667e-06, 'epoch': 2.58}
{'loss': 0.4473, 'learning_rate': 9.333333333333334e-06, 'epoch': 2.67}
{'loss': 0.5584, 'learning_rate': 9e-06, 'epoch': 2.75}
{'loss': 0.5462, 'learning_rate': 8.666666666666668e-06, 'epoch': 2.83}
{'loss': 0.3988, 'learning_rate': 8.333333333333334e-06, 'epoch': 2.92}
{'loss': 0.3943, 'learning_rate': 8.000000000000001e-06, 'epoch': 3.0}
{'eval_loss': 0.509650468826294, 'eval_runtime': 2.2808, 'eval_samples_per_second': 5.261, 'eval_steps_per_second': 0.877, 'epoch': 3.0}
{'loss': 0.4006, 'learning_rate': 7.666666666666667e-06, 'epoch': 3.08}
{'loss': 0.5446, 'learning_rate': 7.333333333333333e-06, 'epoch': 3.17}
{'loss': 0.4831, 'learning_rate': 7e-06, 'epoch': 3.25}
{'loss': 0.4162, 'learning_rate': 6.666666666666667e-06, 'epoch': 3.33}
{'loss': 0.5341, 'learning_rate': 6.333333333333333e-06, 'epoch': 3.42}
{'loss': 0.4316, 'learning_rate': 6e-06, 'epoch': 3.5}
{'loss': 0.2981, 'learning_rate': 5.666666666666667e-06, 'epoch': 3.58}
{'loss': 0.4425, 'learning_rate': 5.333333333333334e-06, 'epoch': 3.67}
{'loss': 0.548, 'learning_rate': 5e-06, 'epoch': 3.75}
{'loss': 0.5135, 'learning_rate': 4.666666666666667e-06, 'epoch': 3.83}
{'loss': 0.3461, 'learning_rate': 4.333333333333334e-06, 'epoch': 3.92}
{'loss': 0.3772, 'learning_rate': 4.000000000000001e-06, 'epoch': 4.0}
{'eval_loss': 0.4830377995967865, 'eval_runtime': 2.2901, 'eval_samples_per_second': 5.24, 'eval_steps_per_second': 0.873, 'epoch': 4.0}
{'loss': 0.3524, 'learning_rate': 3.6666666666666666e-06, 'epoch': 4.08}
{'loss': 0.4848, 'learning_rate': 3.3333333333333333e-06, 'epoch': 4.17}
{'loss': 0.3999, 'learning_rate': 3e-06, 'epoch': 4.25}
{'loss': 0.4042, 'learning_rate': 2.666666666666667e-06, 'epoch': 4.33}
{'loss': 0.5453, 'learning_rate': 2.3333333333333336e-06, 'epoch': 4.42}
{'loss': 0.3801, 'learning_rate': 2.0000000000000003e-06, 'epoch': 4.5}
{'loss': 0.2381, 'learning_rate': 1.6666666666666667e-06, 'epoch': 4.58}
{'loss': 0.4054, 'learning_rate': 1.3333333333333334e-06, 'epoch': 4.67}
{'loss': 0.5075, 'learning_rate': 1.0000000000000002e-06, 'epoch': 4.75}
{'loss': 0.4821, 'learning_rate': 6.666666666666667e-07, 'epoch': 4.83}
{'loss': 0.2894, 'learning_rate': 3.3333333333333335e-07, 'epoch': 4.92}
{'loss': 0.3635, 'learning_rate': 0.0, 'epoch': 5.0}
{'eval_loss': 0.4767491817474365, 'eval_runtime': 2.2767, 'eval_samples_per_second': 5.271, 'eval_steps_per_second': 0.878, 'epoch': 5.0}
{'train_runtime': 60.4651, 'train_samples_per_second': 15.877, 'train_steps_per_second': 0.992, 'train_loss': 0.5313406201700369, 'epoch': 5.0}

TrainOutput(global_step=60, training_loss=0.5313406201700369, metrics={'train_runtime': 60.4651, 'train_samples_per_second': 15.877, 'train_steps_per_second': 0.992, 'train_loss': 0.5313406201700369, 'epoch': 5.0})
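
The `learning_rate` values in the training log are consistent with a linear decay from the configured 2e-5 down to 0 over the 60 training steps, with no warmup. The sketch below reproduces a few logged values under that assumption. `linear_lr` is a hypothetical helper written for this check, and the schedule is inferred from the log rather than taken from the trainer's source.

```python
import math

def linear_lr(step, base_lr=2e-5, total_steps=60):
    """Hypothetical helper: learning rate after `step` updates under linear decay."""
    remaining = max(0, total_steps - step)
    return base_lr * remaining / total_steps

# Compare against values taken from the log above.
print(math.isclose(linear_lr(1), 1.9666666666666666e-05))   # True
print(math.isclose(linear_lr(12), 1.6000000000000003e-05))  # True
print(math.isclose(linear_lr(30), 1e-05))                   # True
print(linear_lr(60))                                        # 0.0
```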