# MindNLP PEFT with DNA Language Models
This notebook demonstrates how to use parameter-efficient fine-tuning (PEFT) techniques from MindNLP's PEFT module to fine-tune a DNA Language Model (DNA-LM). The fine-tuned DNA-LM is then applied to a task from the Nucleotide Transformer downstream benchmark. PEFT techniques are useful for adapting large pre-trained models to specific tasks with limited computational resources.

When running in a ModelArts environment, remove the `RANK_TABLE_FILE` environment variable so that `MindSpore` can run in dynamic graph mode.
```python
import os
if "RANK_TABLE_FILE" in os.environ:
    del os.environ["RANK_TABLE_FILE"]
```

## 1. Import relevant libraries
We'll start by importing the required libraries, including the mindnlp library and other dependencies.

In [4]:
import mindspore
import mindnlp
import numpy as np
import tqdm



## 2. Load models
We'll load a pre-trained DNA Language Model, "SpeciesLM", that serves as the base for fine-tuning. This is done using `mindnlp.transformers`, which provides a HuggingFace-style API.

The tokenizer and the model come from the paper "Species-aware DNA language models capture regulatory elements and their evolution". The authors introduce a species-aware DNA language model trained on more than 800 species spanning over 500 million years of evolution.

In [5]:
from mindnlp.transformers import AutoTokenizer, AutoModelForMaskedLM

In [6]:
tokenizer = AutoTokenizer.from_pretrained("gagneurlab/SpeciesLM", revision = "downstream_species_lm")
lm = AutoModelForMaskedLM.from_pretrained("gagneurlab/SpeciesLM", revision = "downstream_species_lm")





In [7]:
lm.eval()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(5504, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear (768 -> 768)
              (key): Linear (768 -> 768)
              (value): Linear (768 -> 768)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear (768 -> 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermedia

## 3. Prepare datasets
We'll load the `nucleotide_transformer_downstream_tasks` dataset, which contains 18 downstream tasks from the Nucleotide Transformer paper. This dataset provides a consistent genomics benchmark with binary classification tasks.

In [8]:
from mindnlp.dataset import load_dataset
raw_data = load_dataset("InstaDeepAI/nucleotide_transformer_downstream_tasks", "H3")
raw_data

{'train': <mindspore.dataset.engine.datasets_user_defined.GeneratorDataset at 0xfffe3ef8ac40>,
 'test': <mindspore.dataset.engine.datasets_user_defined.GeneratorDataset at 0xfffe3ef741f0>}

We'll use the "H3" subset of this dataset, which contains 13,468 rows in the training data and 1,497 rows in the test data.

In [9]:
train_dataset = raw_data['train']
test_dataset = raw_data['test']

def show_dataset_info(dataset):
    print("dataset column: {}".format(dataset.get_col_names()))
    print("dataset size: {}".format(dataset.get_dataset_size()))
    print("dataset batch size: {}\n".format(dataset.get_batch_size()))

print("train dataset info:")
show_dataset_info(train_dataset)
print("test dataset info:")
show_dataset_info(test_dataset)

train dataset info:
dataset column: ['sequence', 'name', 'label']
dataset size: 13468
dataset batch size: 1

test dataset info:
dataset column: ['sequence', 'name', 'label']
dataset size: 1497
dataset batch size: 1



The dataset consists of three columns: `sequence`, `name`, and `label`. A row in this dataset looks like:

In [10]:
for data in train_dataset:
    print(data)
    break

[Tensor(shape=[], dtype=String, value= 'TCACTTCGATTATTGAGGCAGTCTTCATTAAAGTTTATTACAATGGATATGGTATCACCAGTCTTGAACCTACAATCATCTATTTTAGGTGAGCTCGTAGGCATTATTGGAAAAGTGTTCTTTCTCTTAATAGAAGAGATTAAATACCCGATAATCACACCCAAAATTATTGTGGATGCCCAGATATCTTCTTGGTCATTGTTTTTTTTCGCTTCAATCTGTAATCTCTCTGCAAAATTTCGGGAGCCAATAGTGACAACATCGTCAATAATAAGTTTGATGGAATCGGAAAAAGATCTTAAAAATGTAAATGAGTATTTCCAAATAATGGCCAAAATGCTCTTTATATTGGAAAATAAAATAGTTGTTTCGCTCTTCGTAGTATTTAACATTTCCGTTCTTATCATTGTAAAGTCTGAGCCATATTCATATGGAAAAGTGCTTTTTAAACCTAGTTCCTCCATATTTTAGTTTTTTATCGATATTGGAAAAAAAAGAGC'), Tensor(shape=[], dtype=String, value= 'YBR063C_YBR063C_367930|0'), Tensor(shape=[], dtype=Int64, value= 0)]


We split the original training data into training (85%) and validation (15%) sets; the test set is kept as provided.

In [11]:
train_dataset, valid_dataset = train_dataset.split([0.85, 0.15])

print("train dataset info:")
show_dataset_info(train_dataset)
print("valid dataset info:")
show_dataset_info(valid_dataset)

train dataset info:
dataset column: ['sequence', 'name', 'label']
dataset size: 11448
dataset batch size: 1

valid dataset info:
dataset column: ['sequence', 'name', 'label']
dataset size: 2020
dataset batch size: 1



Then, we use the tokenizer and a utility function, `get_kmers`, to generate the final data and labels. The `get_kmers` function generates the overlapping 6-mers expected by the language model (LM). With k=6 and stride=1, consecutive 6-mers overlap by five nucleotides, so the model sees the local context around every position in the sequence.

In [12]:
def get_kmers(seq, k=6, stride=1):
    return [seq[i:i + k] for i in range(0, len(seq), stride) if i + k <= len(seq)]
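
To make the k-mer step concrete, here is a small sketch that applies the same function to a toy sequence and builds the space-separated string passed to the tokenizer. The toy sequence and the 500-nucleotide length are assumptions for illustration; the token budget also assumes that the tokenizer adds one start and one end token, which is consistent with the ids `2` and `3` that frame each row in the batch printed later in this notebook.

```python
def get_kmers(seq, k=6, stride=1):
    """Return the k-mers of `seq`, starting every `stride` positions."""
    return [seq[i:i + k] for i in range(0, len(seq) - k + 1, stride)]

toy = "ATGCGTAC"                       # toy sequence, not from the dataset
kmers = get_kmers(toy)
print(kmers)                           # ['ATGCGT', 'TGCGTA', 'GCGTAC']

# Same string layout as process_sequence: species token, then the k-mers.
text = "candida_glabrata " + " ".join(kmers)
print(text)

# Token budget for an assumed 500-nucleotide sequence:
n_kmers = 500 - 6 + 1                  # 495 overlapping 6-mers
n_tokens = n_kmers + 1 + 2             # + species token + assumed start/end tokens
print(n_kmers, n_tokens)               # 495 498
```

A sequence of length L yields L - k + 1 k-mers at stride 1, so the input length grows almost one-to-one with the sequence length.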

In [13]:

BATCH_SIZE = 16
DATASET_LIMIT = 200 # NOTE: This dataset limit is set to 200, so that the training runs faster. It can be set to None to use the
                    # entire dataset

def process_sequence(data):
    sequence = data.tolist()
    sequence = "candida_glabrata " + " ".join(get_kmers(sequence))
    sequence = tokenizer(sequence)["input_ids"]
    return sequence

def nucleotide_dataset_process(dataset):
    # remove name column
    dataset = dataset.project(columns=['sequence', 'label'])

    # process sequence
    dataset = dataset.map(process_sequence, input_columns=['sequence'], output_columns=['sequence'])

    # change dataset size
    dataset = dataset.take(DATASET_LIMIT)

    # batch with padding
    dataset = dataset.padded_batch(batch_size=BATCH_SIZE,
                                   drop_remainder=True,
                                   pad_info={'sequence':(None, -100)})
    return dataset


Finally, we create a Dataset object for each of our sets.

In [14]:
train_dataset = train_dataset.apply(nucleotide_dataset_process)
valid_dataset = valid_dataset.apply(nucleotide_dataset_process)
test_dataset = test_dataset.apply(nucleotide_dataset_process)

print("train dataset info:")
show_dataset_info(train_dataset)

print("valid dataset info:")
show_dataset_info(valid_dataset)

print("test dataset info:")
show_dataset_info(test_dataset)

for data in train_dataset:
    print(data)
    break

train dataset info:
dataset column: ['sequence', 'label']
dataset size: 12
dataset batch size: 16

valid dataset info:
dataset column: ['sequence', 'label']
dataset size: 12
dataset batch size: 16

test dataset info:
dataset column: ['sequence', 'label']
dataset size: 12
dataset batch size: 16

[Tensor(shape=[16, 498], dtype=Int64, value=
[[   2, 4724,  200 ... 3930, 3420,    3],
 [   2, 4724,  414 ...  518, 2057,    3],
 [   2, 4724,    8 ... 3674, 2394,    3],
 ...
 [   2, 4724, 1369 ... 3089,   54,    3],
 [   2, 4724, 3365 ... 2606, 2219,    3],
 [   2, 4724, 1630 ... 2727, 2702,    3]]), Tensor(shape=[16], dtype=Int64, value= [1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1])]
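
The reported dataset size of 12 is the number of batches, not the number of samples. The arithmetic below is a minimal sketch of how `take(DATASET_LIMIT)` followed by batching with `drop_remainder=True` leads to that number. `NUM_EPOCHS = 5` is taken from the training arguments used later in this notebook, and the step count assumes one optimizer step per batch (no gradient accumulation).

```python
DATASET_LIMIT = 200
BATCH_SIZE = 16
NUM_EPOCHS = 5  # value used in TrainingArguments later in the notebook

# drop_remainder=True discards the final, incomplete batch.
batches_per_epoch = DATASET_LIMIT // BATCH_SIZE
samples_used = batches_per_epoch * BATCH_SIZE
samples_dropped = DATASET_LIMIT - samples_used

# Assumes one optimizer step per batch.
total_steps = batches_per_epoch * NUM_EPOCHS

print(batches_per_epoch, samples_used, samples_dropped, total_steps)  # 12 192 8 60
```

Only 192 of the 200 selected rows are used in each split, which is worth keeping in mind when interpreting metrics computed on so few examples.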


## 4. Train model
Now, we'll train our DNA Language Model with the training dataset. We'll add a linear classification layer on top of the language model's final hidden layer, and then train all the parameters of the model on the training dataset.

In [15]:
import mindspore
from mindnlp.core import nn

class DNA_LM(nn.Module):
    def __init__(self, model, num_labels):
        super(DNA_LM, self).__init__()
        self.model = model.bert
        self.in_features = model.config.hidden_size
        self.out_features = num_labels
        self.classifier = nn.Linear(self.in_features, self.out_features)

    def forward(self, sequence, label=None):
        outputs = self.model(input_ids=sequence, attention_mask=None, output_hidden_states=True)
        sequence_output = outputs.hidden_states[-1]
        # Use the [CLS] token for classification
        cls_output = sequence_output[:, 0, :]
        logits = self.classifier(cls_output)

        loss = None
        if label is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.out_features), label.view(-1))

        return (loss, logits) if loss is not None else logits

# Number of classes for your classification task
num_labels = 2
classification_model = DNA_LM(lm, num_labels)
classification_model

DNA_LM(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(5504, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear (768 -> 768)
              (key): Linear (768 -> 768)
              (value): Linear (768 -> 768)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear (768 -> 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
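
The loss computed in `forward` is a softmax cross-entropy over the two logits produced by the classifier. The following dependency-free sketch reproduces that computation, assuming mean reduction over the batch (the default in PyTorch-style `CrossEntropyLoss`). The helper names `softmax` and `cross_entropy` are illustrative and not part of MindNLP.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(batch_logits, labels):
    """Mean negative log-likelihood of the true class over a batch."""
    losses = [-math.log(softmax(logits)[y]) for logits, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)

# Equal logits for both classes: the loss equals ln(2).
uninformed = cross_entropy([[0.0, 0.0], [0.0, 0.0]], [0, 1])
print(round(uninformed, 4))  # 0.6931

# Confident and correct predictions give a loss close to zero.
confident = cross_entropy([[3.0, -3.0], [-3.0, 3.0]], [0, 1])
print(confident < uninformed)  # True
```

The value ln(2) ≈ 0.693 is a useful reference point for a two-class problem: a loss near it suggests the model is not yet doing better than an uninformed guess.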
    

In [16]:
from mindnlp.engine import Trainer, TrainingArguments

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_steps=1,
    logging_steps=1,
)

# Initialize Trainer
trainer = Trainer(
    model=classification_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
)

# Train the model
trainer.train()

  0%|          | 0/60 [00:00<?, ?it/s]


  2%|▏         | 1/60 [01:06<21:42, 22.08s/it]

{'loss': 0.8252, 'learning_rate': 1.9666666666666666e-05, 'epoch': 0.08}


  3%|▎         | 2/60 [01:07<34:40, 35.87s/it]

{'loss': 0.7127, 'learning_rate': 1.9333333333333333e-05, 'epoch': 0.17}


  5%|▌         | 3/60 [01:08<18:47, 19.78s/it]

{'loss': 0.7957, 'learning_rate': 1.9e-05, 'epoch': 0.25}


  7%|▋         | 4/60 [01:09<11:25, 12.24s/it]

{'loss': 0.7914, 'learning_rate': 1.866666666666667e-05, 'epoch': 0.33}


  8%|▊         | 5/60 [01:09<07:22,  8.04s/it]

{'loss': 0.9366, 'learning_rate': 1.8333333333333333e-05, 'epoch': 0.42}


 10%|█         | 6/60 [01:10<04:58,  5.53s/it]

{'loss': 0.5876, 'learning_rate': 1.8e-05, 'epoch': 0.5}


 12%|█▏        | 7/60 [01:10<03:28,  3.93s/it]

{'loss': 0.5856, 'learning_rate': 1.7666666666666668e-05, 'epoch': 0.58}


 13%|█▎        | 8/60 [01:11<02:29,  2.88s/it]

{'loss': 0.6507, 'learning_rate': 1.7333333333333336e-05, 'epoch': 0.67}


 15%|█▌        | 9/60 [01:12<01:51,  2.19s/it]

{'loss': 0.8005, 'learning_rate': 1.7e-05, 'epoch': 0.75}


 17%|█▋        | 10/60 [01:12<01:25,  1.71s/it]

{'loss': 0.7325, 'learning_rate': 1.6666666666666667e-05, 'epoch': 0.83}


 18%|█▊        | 11/60 [01:13<01:07,  1.38s/it]

{'loss': 0.6998, 'learning_rate': 1.6333333333333335e-05, 'epoch': 0.92}


 20%|██        | 12/60 [01:14<00:55,  1.16s/it]

{'loss': 0.8102, 'learning_rate': 1.6000000000000003e-05, 'epoch': 1.0}






 20%|██        | 12/60 [01:18<00:55,  1.16s/it]

{'eval_loss': 0.6935629844665527, 'eval_runtime': 4.0479, 'eval_samples_per_second': 2.964, 'eval_steps_per_second': 0.494, 'epoch': 1.0}


 22%|██▏       | 13/60 [01:19<01:47,  2.29s/it]

{'loss': 0.6874, 'learning_rate': 1.5666666666666667e-05, 'epoch': 1.08}


 23%|██▎       | 14/60 [01:19<01:23,  1.81s/it]

{'loss': 0.5356, 'learning_rate': 1.5333333333333334e-05, 'epoch': 1.17}


 25%|██▌       | 15/60 [01:20<01:07,  1.50s/it]

{'loss': 0.717, 'learning_rate': 1.5000000000000002e-05, 'epoch': 1.25}


 27%|██▋       | 16/60 [01:21<00:57,  1.30s/it]

{'loss': 0.711, 'learning_rate': 1.4666666666666666e-05, 'epoch': 1.33}


 28%|██▊       | 17/60 [01:22<00:50,  1.17s/it]

{'loss': 0.7494, 'learning_rate': 1.4333333333333334e-05, 'epoch': 1.42}


 30%|███       | 18/60 [01:23<00:44,  1.07s/it]

{'loss': 0.5204, 'learning_rate': 1.4e-05, 'epoch': 1.5}


 32%|███▏      | 19/60 [01:23<00:39,  1.03it/s]

{'loss': 0.5145, 'learning_rate': 1.3666666666666667e-05, 'epoch': 1.58}


 33%|███▎      | 20/60 [01:24<00:37,  1.07it/s]

{'loss': 0.5374, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.67}


 35%|███▌      | 21/60 [01:25<00:34,  1.13it/s]

{'loss': 0.7067, 'learning_rate': 1.3000000000000001e-05, 'epoch': 1.75}


 37%|███▋      | 22/60 [01:26<00:31,  1.20it/s]

{'loss': 0.6515, 'learning_rate': 1.2666666666666667e-05, 'epoch': 1.83}


 38%|███▊      | 23/60 [01:26<00:28,  1.30it/s]

{'loss': 0.6086, 'learning_rate': 1.2333333333333334e-05, 'epoch': 1.92}


 40%|████      | 24/60 [01:27<00:28,  1.27it/s]

{'loss': 0.6736, 'learning_rate': 1.2e-05, 'epoch': 2.0}


                                               
 40%|████      | 24/60 [01:45<00:28,  1.27it/s]

{'eval_loss': 0.6261480450630188, 'eval_runtime': 4.1712, 'eval_samples_per_second': 2.877, 'eval_steps_per_second': 0.479, 'epoch': 2.0}


 42%|████▏     | 25/60 [01:46<03:34,  6.12s/it]

{'loss': 0.5717, 'learning_rate': 1.1666666666666668e-05, 'epoch': 2.08}


 43%|████▎     | 26/60 [01:46<02:32,  4.49s/it]

{'loss': 0.5132, 'learning_rate': 1.1333333333333334e-05, 'epoch': 2.17}


 45%|████▌     | 27/60 [01:47<01:50,  3.35s/it]

{'loss': 0.6381, 'learning_rate': 1.1000000000000001e-05, 'epoch': 2.25}


 47%|████▋     | 28/60 [01:48<01:21,  2.53s/it]

{'loss': 0.576, 'learning_rate': 1.0666666666666667e-05, 'epoch': 2.33}


 48%|████▊     | 29/60 [01:48<01:01,  1.98s/it]

{'loss': 0.6151, 'learning_rate': 1.0333333333333335e-05, 'epoch': 2.42}


 50%|█████     | 30/60 [01:49<00:47,  1.59s/it]

{'loss': 0.4419, 'learning_rate': 1e-05, 'epoch': 2.5}


 52%|█████▏    | 31/60 [01:50<00:39,  1.38s/it]

{'loss': 0.4393, 'learning_rate': 9.666666666666667e-06, 'epoch': 2.58}


 53%|█████▎    | 32/60 [01:51<00:33,  1.20s/it]

{'loss': 0.4634, 'learning_rate': 9.333333333333334e-06, 'epoch': 2.67}


 55%|█████▌    | 33/60 [01:51<00:28,  1.07s/it]

{'loss': 0.612, 'learning_rate': 9e-06, 'epoch': 2.75}


 57%|█████▋    | 34/60 [01:52<00:25,  1.02it/s]

{'loss': 0.5955, 'learning_rate': 8.666666666666668e-06, 'epoch': 2.83}


 58%|█████▊    | 35/60 [01:53<00:23,  1.09it/s]

{'loss': 0.4584, 'learning_rate': 8.333333333333334e-06, 'epoch': 2.92}


 60%|██████    | 36/60 [01:54<00:20,  1.16it/s]

{'loss': 0.6288, 'learning_rate': 8.000000000000001e-06, 'epoch': 3.0}


                                               
 60%|██████    | 36/60 [02:12<00:20,  1.16it/s]

{'eval_loss': 0.5678010582923889, 'eval_runtime': 4.0958, 'eval_samples_per_second': 2.93, 'eval_steps_per_second': 0.488, 'epoch': 3.0}


 62%|██████▏   | 37/60 [02:13<02:27,  6.42s/it]

{'loss': 0.4614, 'learning_rate': 7.666666666666667e-06, 'epoch': 3.08}


 63%|██████▎   | 38/60 [02:14<01:43,  4.71s/it]

{'loss': 0.4509, 'learning_rate': 7.333333333333333e-06, 'epoch': 3.17}


 65%|██████▌   | 39/60 [02:15<01:13,  3.50s/it]

{'loss': 0.5712, 'learning_rate': 7e-06, 'epoch': 3.25}


 67%|██████▋   | 40/60 [02:15<00:53,  2.69s/it]

{'loss': 0.5515, 'learning_rate': 6.666666666666667e-06, 'epoch': 3.33}


 68%|██████▊   | 41/60 [02:16<00:40,  2.11s/it]

{'loss': 0.6143, 'learning_rate': 6.333333333333333e-06, 'epoch': 3.42}


 70%|███████   | 42/60 [02:17<00:31,  1.73s/it]

{'loss': 0.3716, 'learning_rate': 6e-06, 'epoch': 3.5}


 72%|███████▏  | 43/60 [02:18<00:24,  1.42s/it]

{'loss': 0.3609, 'learning_rate': 5.666666666666667e-06, 'epoch': 3.58}


 73%|███████▎  | 44/60 [02:18<00:19,  1.21s/it]

{'loss': 0.4412, 'learning_rate': 5.333333333333334e-06, 'epoch': 3.67}


 75%|███████▌  | 45/60 [02:19<00:15,  1.05s/it]

{'loss': 0.6236, 'learning_rate': 5e-06, 'epoch': 3.75}


 77%|███████▋  | 46/60 [02:20<00:13,  1.02it/s]

{'loss': 0.5603, 'learning_rate': 4.666666666666667e-06, 'epoch': 3.83}


 78%|███████▊  | 47/60 [02:21<00:11,  1.11it/s]

{'loss': 0.4473, 'learning_rate': 4.333333333333334e-06, 'epoch': 3.92}


 80%|████████  | 48/60 [02:21<00:10,  1.15it/s]

{'loss': 0.5042, 'learning_rate': 4.000000000000001e-06, 'epoch': 4.0}


                                               
 80%|████████  | 48/60 [02:39<00:10,  1.15it/s]

{'eval_loss': 0.5305204391479492, 'eval_runtime': 4.2506, 'eval_samples_per_second': 2.823, 'eval_steps_per_second': 0.471, 'epoch': 4.0}


 82%|████████▏ | 49/60 [02:40<01:07,  6.11s/it]

{'loss': 0.385, 'learning_rate': 3.6666666666666666e-06, 'epoch': 4.08}


 83%|████████▎ | 50/60 [02:40<00:44,  4.46s/it]

{'loss': 0.4341, 'learning_rate': 3.3333333333333333e-06, 'epoch': 4.17}


 85%|████████▌ | 51/60 [02:41<00:29,  3.32s/it]

{'loss': 0.4644, 'learning_rate': 3e-06, 'epoch': 4.25}


 87%|████████▋ | 52/60 [02:42<00:20,  2.55s/it]

{'loss': 0.5216, 'learning_rate': 2.666666666666667e-06, 'epoch': 4.33}


 88%|████████▊ | 53/60 [02:42<00:13,  1.98s/it]

{'loss': 0.5756, 'learning_rate': 2.3333333333333336e-06, 'epoch': 4.42}


 90%|█████████ | 54/60 [02:43<00:09,  1.58s/it]

{'loss': 0.3522, 'learning_rate': 2.0000000000000003e-06, 'epoch': 4.5}


 92%|█████████▏| 55/60 [02:44<00:06,  1.29s/it]

{'loss': 0.2875, 'learning_rate': 1.6666666666666667e-06, 'epoch': 4.58}


 93%|█████████▎| 56/60 [02:44<00:04,  1.10s/it]

{'loss': 0.4042, 'learning_rate': 1.3333333333333334e-06, 'epoch': 4.67}


 95%|█████████▌| 57/60 [02:45<00:02,  1.01it/s]

{'loss': 0.5759, 'learning_rate': 1.0000000000000002e-06, 'epoch': 4.75}


 97%|█████████▋| 58/60 [02:46<00:01,  1.14it/s]

{'loss': 0.5405, 'learning_rate': 6.666666666666667e-07, 'epoch': 4.83}


 98%|█████████▊| 59/60 [02:46<00:00,  1.20it/s]

{'loss': 0.4023, 'learning_rate': 3.3333333333333335e-07, 'epoch': 4.92}


100%|██████████| 60/60 [02:47<00:00,  1.30it/s]

{'loss': 0.5196, 'learning_rate': 0.0, 'epoch': 5.0}


                                               
100%|██████████| 60/60 [02:51<00:00,  2.86s/it]

{'eval_loss': 0.5207817554473877, 'eval_runtime': 4.0536, 'eval_samples_per_second': 2.96, 'eval_steps_per_second': 0.493, 'epoch': 5.0}
{'train_runtime': 171.6193, 'train_samples_per_second': 5.594, 'train_steps_per_second': 0.35, 'train_loss': 0.5752782459060352, 'epoch': 5.0}





TrainOutput(global_step=60, training_loss=0.5752782459060352, metrics={'train_runtime': 171.6193, 'train_samples_per_second': 5.594, 'train_steps_per_second': 0.35, 'train_loss': 0.5752782459060352, 'epoch': 5.0})
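
The `learning_rate` values in the log above are consistent with a linear decay from the configured 2e-5 down to 0 over the 60 training steps, without warmup. The sketch below reproduces them under that assumption; `linear_decay_lr` is a hypothetical helper written for illustration, not a MindNLP function.

```python
import math

def linear_decay_lr(step, base_lr=2e-5, total_steps=60):
    """Learning rate after `step` optimizer steps, assuming linear decay and no warmup."""
    return base_lr * (total_steps - step) / total_steps

for step in (1, 2, 30, 60):
    print(step, linear_decay_lr(step))

# Compare with the logged values for steps 1 and 30.
print(math.isclose(linear_decay_lr(1), 1.9666666666666666e-05))
print(math.isclose(linear_decay_lr(30), 1e-05))
```

Because the rate reaches zero at the last step, the final updates are very small; a run with more epochs spreads the same decay over more steps.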

## 5. Evaluation

In [17]:
# Generate predictions

predictions = trainer.predict(test_dataset)
logits = predictions.predictions
predicted_labels = logits.argmax(axis=-1)
print(predicted_labels)

100%|██████████| 12/12 [00:02<00:00,  4.35it/s]

[1 1 1 1 0 1 1 0 0 0 0 1 1 0 1 0 1 1 1 0 1 1 0 0 1 0 1 0 0 0 1 1 1 0 0 0 1
 1 1 0 0 0 1 1 0 1 0 0 0 0 1 1 1 1 0 0 1 1 0 0 1 1 1 1 0 0 1 0 1 1 0 0 0 1
 0 0 0 1 1 0 1 0 0 1 0 0 1 1 1 1 0 0 1 1 1 1 0 1 0 1 1 1 1 0 0 1 1 0 0 0 1
 1 0 1 1 0 0 1 1 1 0 1 0 1 1 1 0 0 0 1 0 1 1 1 1 1 1 0 1 1 0 1 0 1 1 0 1 1
 0 1 1 1 0 0 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1 1 0 1 0 0 1 1 1 1 1 1 0 0 1 0 0
 1 1 0 1 1 0 1]





Then, we create a function to calculate the accuracy from the test and predicted labels.

In [24]:
def calculate_accuracy(true_labels, predicted_labels):

    assert len(true_labels) == len(predicted_labels), "Arrays must have the same length"
    correct_predictions = np.sum(true_labels == predicted_labels)
    accuracy = correct_predictions / len(true_labels)

    return accuracy

test_labels = []
for data in test_dataset:
    sequence, label = data
    for single_label in label:
        test_labels.append(single_label)

accuracy = calculate_accuracy(test_labels, predicted_labels)
print(f"Accuracy: {accuracy:.2f}")

Accuracy: 0.78
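
Accuracy alone does not show which kind of error the model makes. As a complement, here is a minimal sketch that computes accuracy together with confusion counts on plain Python integers. The toy labels are made up for illustration, and tensor labels would first need to be converted to integers.

```python
def accuracy(true_labels, predicted_labels):
    """Fraction of positions where the predicted label equals the true label."""
    if len(true_labels) != len(predicted_labels):
        raise ValueError("Label sequences must have the same length")
    correct = sum(int(t) == int(p) for t, p in zip(true_labels, predicted_labels))
    return correct / len(true_labels)

def confusion_counts(true_labels, predicted_labels):
    """Count true/false positives and negatives for binary labels (1 = positive)."""
    counts = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}
    for t, p in zip(true_labels, predicted_labels):
        if p == 1:
            counts["tp" if t == 1 else "fp"] += 1
        else:
            counts["tn" if t == 0 else "fn"] += 1
    return counts

y_true = [1, 0, 0, 1, 0, 1, 1, 0]  # toy labels, not from the dataset
y_pred = [1, 0, 1, 1, 0, 0, 1, 0]

print(accuracy(y_true, y_pred))          # 0.75
print(confusion_counts(y_true, y_pred))  # {'tp': 3, 'tn': 3, 'fp': 1, 'fn': 1}
```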


An accuracy of 0.78 leaves room for improvement, which we can attribute to the small training subset (`DATASET_LIMIT = 200`).

## 6. Parameter Efficient Fine-Tuning Techniques
In this section, we demonstrate how to employ parameter-efficient fine-tuning (PEFT) techniques to adapt a pre-trained model for specific genomics tasks using the PEFT library.

The LoraConfig object is instantiated to configure the PEFT parameters:

+ task_type: Specifies the type of task, for example sequence classification (SEQ_CLS). It is not set in the configuration below.
+ r: The rank of the LoRA update matrices.
+ lora_alpha: Scaling factor for the LoRA update (the update is scaled by lora_alpha / r).
+ target_modules: The modules that receive LoRA adapters (query, key, and value in this example).
+ lora_dropout: Dropout rate applied on the LoRA branch during fine-tuning.
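
To illustrate what these parameters control, here is a minimal pure-Python sketch of a LoRA-adapted linear layer, following the standard formulation `y = W x + (lora_alpha / r) * B (A x)`. It is not the MindNLP implementation: the helper names and toy matrices are made up for illustration, and dropout and biases are omitted.

```python
def matvec(M, x):
    """Multiply matrix M (list of rows) by vector x."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]

def lora_forward(W, A, B, x, lora_alpha, r):
    """Frozen base projection plus the scaled low-rank update."""
    base = matvec(W, x)                 # frozen pre-trained weights
    update = matvec(B, matvec(A, x))    # rank-r path: in -> r -> out
    scale = lora_alpha / r
    return [b + scale * u for b, u in zip(base, update)]

W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, -1.0]]                  # out=2, in=3 (toy sizes)
A = [[0.5, -0.5, 1.0]]                  # r=1, in=3
x = [1.0, 2.0, 3.0]

# With B initialised to zeros, the adapted layer matches the base layer.
B_zero = [[0.0], [0.0]]
print(lora_forward(W, A, B_zero, x, lora_alpha=4, r=1))  # [7.0, -1.0]

# Once B is non-zero, the update is added with scale lora_alpha / r = 4.
B = [[1.0], [0.0]]
print(lora_forward(W, A, B, x, lora_alpha=4, r=1))       # [17.0, -1.0]
```

The toy example uses `lora_alpha=4, r=1` so that the scale matches the factor 32 / 8 = 4 implied by the configuration in this notebook.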

In [26]:
# Number of classes for your classification task
num_labels = 2
classification_model = DNA_LM(lm, num_labels)
classification_model

DNA_LM(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(5504, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear (768 -> 768)
              (key): Linear (768 -> 768)
              (value): Linear (768 -> 768)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear (768 -> 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
    

In [27]:
from mindnlp.peft import LoraConfig, TaskType

peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["query", "key", "value"],
    lora_dropout=0.01,
)

In [28]:
from mindnlp.peft import get_peft_model

peft_model = get_peft_model(classification_model, peft_config)
peft_model.print_trainable_parameters()

trainable params: 442,368 || all params: 90,121,730 || trainable%: 0.49085608986866985
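
The trainable-parameter count can be reproduced by hand. The sketch below assumes that each adapted `Linear (768 -> 768)` layer receives a `lora_A` matrix (768 -> 8) and a `lora_B` matrix (8 -> 768), both without bias; the `lora_B` shape is an assumption, since the model printout in this notebook is truncated before it.

```python
hidden_size = 768
rank = 8
num_layers = 12
target_modules = 3          # query, key, value
num_labels = 2
all_params = 90_121_730     # total reported by print_trainable_parameters()

# Each adapted layer adds A (hidden x rank) and B (rank x hidden).
per_module = hidden_size * rank + rank * hidden_size
lora_params = num_layers * target_modules * per_module

# Parameters of the classification head: weights plus bias.
classifier_params = hidden_size * num_labels + num_labels

print(per_module)                         # 12288
print(lora_params)                        # 442368
print(classifier_params)                  # 1538
print(100 * lora_params / all_params)     # about 0.4909 (%)
```

The LoRA matrices alone account for all 442,368 reported trainable parameters, which suggests that the 1,538 parameters of the newly created `classifier` layer are not among them. If that is the case, the classification head stays at its random initialisation during PEFT training, so it may be worth checking which parameters are trainable before comparing results.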


In [29]:
peft_model

PeftModel(
  (base_model): LoraModel(
    (model): DNA_LM(
      (model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(5504, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear (768 -> 768)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.01, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear (768 -> 8)
                    )
                    (lora_B): ModuleDict(

In [34]:
# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_steps=1,
    logging_steps=1,
)

# Initialize Trainer
trainer = Trainer(
    model=peft_model.model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
)

# Train the model
trainer.train()

  0%|          | 0/60 [00:00<?, ?it/s]

  2%|▏         | 1/60 [00:00<00:38,  1.54it/s]

{'loss': 0.865, 'learning_rate': 1.9666666666666666e-05, 'epoch': 0.08}


  3%|▎         | 2/60 [00:01<00:34,  1.66it/s]

{'loss': 0.8191, 'learning_rate': 1.9333333333333333e-05, 'epoch': 0.17}


  5%|▌         | 3/60 [00:01<00:33,  1.72it/s]

{'loss': 0.7792, 'learning_rate': 1.9e-05, 'epoch': 0.25}


  7%|▋         | 4/60 [00:02<00:32,  1.70it/s]

{'loss': 0.6196, 'learning_rate': 1.866666666666667e-05, 'epoch': 0.33}


  8%|▊         | 5/60 [00:02<00:31,  1.75it/s]

{'loss': 0.7699, 'learning_rate': 1.8333333333333333e-05, 'epoch': 0.42}


 10%|█         | 6/60 [00:03<00:30,  1.79it/s]

{'loss': 0.6927, 'learning_rate': 1.8e-05, 'epoch': 0.5}


 12%|█▏        | 7/60 [00:04<00:29,  1.80it/s]

{'loss': 0.7386, 'learning_rate': 1.7666666666666668e-05, 'epoch': 0.58}


 13%|█▎        | 8/60 [00:04<00:29,  1.78it/s]

{'loss': 0.8408, 'learning_rate': 1.7333333333333336e-05, 'epoch': 0.67}


 15%|█▌        | 9/60 [00:05<00:29,  1.75it/s]

{'loss': 0.6235, 'learning_rate': 1.7e-05, 'epoch': 0.75}


 17%|█▋        | 10/60 [00:05<00:29,  1.72it/s]

{'loss': 0.7194, 'learning_rate': 1.6666666666666667e-05, 'epoch': 0.83}


 18%|█▊        | 11/60 [00:06<00:29,  1.67it/s]

{'loss': 0.6711, 'learning_rate': 1.6333333333333335e-05, 'epoch': 0.92}


 20%|██        | 12/60 [00:07<00:28,  1.66it/s]

{'loss': 0.8141, 'learning_rate': 1.6000000000000003e-05, 'epoch': 1.0}


                                               
 20%|██        | 12/60 [00:26<00:28,  1.66it/s]

{'eval_loss': 0.74076908826828, 'eval_runtime': 3.9145, 'eval_samples_per_second': 3.066, 'eval_steps_per_second': 0.511, 'epoch': 1.0}


 22%|██▏       | 13/60 [00:27<05:11,  6.62s/it]

{'loss': 0.8463, 'learning_rate': 1.5666666666666667e-05, 'epoch': 1.08}


 23%|██▎       | 14/60 [00:28<03:39,  4.78s/it]

{'loss': 0.8359, 'learning_rate': 1.5333333333333334e-05, 'epoch': 1.17}


 25%|██▌       | 15/60 [00:28<02:37,  3.49s/it]

{'loss': 0.7896, 'learning_rate': 1.5000000000000002e-05, 'epoch': 1.25}


 27%|██▋       | 16/60 [00:29<01:55,  2.62s/it]

{'loss': 0.5835, 'learning_rate': 1.4666666666666666e-05, 'epoch': 1.33}


 28%|██▊       | 17/60 [00:29<01:25,  2.00s/it]

{'loss': 0.7319, 'learning_rate': 1.4333333333333334e-05, 'epoch': 1.42}


 30%|███       | 18/60 [00:30<01:05,  1.55s/it]

{'loss': 0.734, 'learning_rate': 1.4e-05, 'epoch': 1.5}


 32%|███▏      | 19/60 [00:30<00:51,  1.24s/it]

{'loss': 0.7376, 'learning_rate': 1.3666666666666667e-05, 'epoch': 1.58}


 33%|███▎      | 20/60 [00:31<00:40,  1.02s/it]

{'loss': 0.7403, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.67}


 35%|███▌      | 21/60 [00:31<00:34,  1.14it/s]

{'loss': 0.6485, 'learning_rate': 1.3000000000000001e-05, 'epoch': 1.75}


 37%|███▋      | 22/60 [00:32<00:29,  1.28it/s]

{'loss': 0.6661, 'learning_rate': 1.2666666666666667e-05, 'epoch': 1.83}


 38%|███▊      | 23/60 [00:32<00:25,  1.43it/s]

{'loss': 0.6876, 'learning_rate': 1.2333333333333334e-05, 'epoch': 1.92}


 40%|████      | 24/60 [00:33<00:23,  1.53it/s]

{'loss': 0.7695, 'learning_rate': 1.2e-05, 'epoch': 2.0}


                                               
 40%|████      | 24/60 [00:51<00:23,  1.53it/s]

{'eval_loss': 0.7271575927734375, 'eval_runtime': 4.0855, 'eval_samples_per_second': 2.937, 'eval_steps_per_second': 0.49, 'epoch': 2.0}


 42%|████▏     | 25/60 [00:51<03:29,  5.99s/it]

{'loss': 0.8329, 'learning_rate': 1.1666666666666668e-05, 'epoch': 2.08}


 43%|████▎     | 26/60 [00:52<02:27,  4.35s/it]

{'loss': 0.7254, 'learning_rate': 1.1333333333333334e-05, 'epoch': 2.17}


 45%|████▌     | 27/60 [00:52<01:46,  3.22s/it]

{'loss': 0.73, 'learning_rate': 1.1000000000000001e-05, 'epoch': 2.25}


 47%|████▋     | 28/60 [00:53<01:17,  2.43s/it]

{'loss': 0.6515, 'learning_rate': 1.0666666666666667e-05, 'epoch': 2.33}


 48%|████▊     | 29/60 [00:54<00:58,  1.87s/it]

{'loss': 0.7709, 'learning_rate': 1.0333333333333335e-05, 'epoch': 2.42}


 50%|█████     | 30/60 [00:54<00:43,  1.46s/it]

{'loss': 0.741, 'learning_rate': 1e-05, 'epoch': 2.5}


 52%|█████▏    | 31/60 [00:55<00:34,  1.18s/it]

{'loss': 0.7195, 'learning_rate': 9.666666666666667e-06, 'epoch': 2.58}


 53%|█████▎    | 32/60 [00:55<00:27,  1.00it/s]

{'loss': 0.7575, 'learning_rate': 9.333333333333334e-06, 'epoch': 2.67}


 55%|█████▌    | 33/60 [00:56<00:23,  1.16it/s]

{'loss': 0.6374, 'learning_rate': 9e-06, 'epoch': 2.75}


 57%|█████▋    | 34/60 [00:56<00:19,  1.32it/s]

{'loss': 0.6527, 'learning_rate': 8.666666666666668e-06, 'epoch': 2.83}


 58%|█████▊    | 35/60 [00:57<00:17,  1.41it/s]

{'loss': 0.6743, 'learning_rate': 8.333333333333334e-06, 'epoch': 2.92}


 60%|██████    | 36/60 [00:57<00:16,  1.49it/s]

{'loss': 0.7407, 'learning_rate': 8.000000000000001e-06, 'epoch': 3.0}


                                               
 60%|██████    | 36/60 [01:16<00:16,  1.49it/s]

{'eval_loss': 0.7168386578559875, 'eval_runtime': 4.1825, 'eval_samples_per_second': 2.869, 'eval_steps_per_second': 0.478, 'epoch': 3.0}


 62%|██████▏   | 37/60 [01:17<02:25,  6.35s/it]

{'loss': 0.782, 'learning_rate': 7.666666666666667e-06, 'epoch': 3.08}


 63%|██████▎   | 38/60 [01:18<01:41,  4.63s/it]

{'loss': 0.8069, 'learning_rate': 7.333333333333333e-06, 'epoch': 3.17}


 65%|██████▌   | 39/60 [01:18<01:12,  3.43s/it]

{'loss': 0.7482, 'learning_rate': 7e-06, 'epoch': 3.25}


 67%|██████▋   | 40/60 [01:19<00:51,  2.59s/it]

{'loss': 0.6373, 'learning_rate': 6.666666666666667e-06, 'epoch': 3.33}


 68%|██████▊   | 41/60 [01:20<00:37,  2.00s/it]

{'loss': 0.7137, 'learning_rate': 6.333333333333333e-06, 'epoch': 3.42}


 70%|███████   | 42/60 [01:20<00:28,  1.58s/it]

{'loss': 0.7217, 'learning_rate': 6e-06, 'epoch': 3.5}


 72%|███████▏  | 43/60 [01:21<00:21,  1.28s/it]

{'loss': 0.6666, 'learning_rate': 5.666666666666667e-06, 'epoch': 3.58}


 73%|███████▎  | 44/60 [01:21<00:17,  1.07s/it]

{'loss': 0.8221, 'learning_rate': 5.333333333333334e-06, 'epoch': 3.67}


 75%|███████▌  | 45/60 [01:22<00:14,  1.07it/s]

{'loss': 0.6367, 'learning_rate': 5e-06, 'epoch': 3.75}


 77%|███████▋  | 46/60 [01:23<00:11,  1.20it/s]

{'loss': 0.6161, 'learning_rate': 4.666666666666667e-06, 'epoch': 3.83}


 78%|███████▊  | 47/60 [01:23<00:09,  1.31it/s]

{'loss': 0.6651, 'learning_rate': 4.333333333333334e-06, 'epoch': 3.92}


 80%|████████  | 48/60 [01:24<00:08,  1.37it/s]

{'loss': 0.7697, 'learning_rate': 4.000000000000001e-06, 'epoch': 4.0}


                                               
 80%|████████  | 48/60 [01:42<00:08,  1.37it/s]

{'eval_loss': 0.7099848389625549, 'eval_runtime': 3.9243, 'eval_samples_per_second': 3.058, 'eval_steps_per_second': 0.51, 'epoch': 4.0}


 82%|████████▏ | 49/60 [01:42<01:07,  6.10s/it]

{'loss': 0.8004, 'learning_rate': 3.6666666666666666e-06, 'epoch': 4.08}


 83%|████████▎ | 50/60 [01:43<00:44,  4.43s/it]

{'loss': 0.8062, 'learning_rate': 3.3333333333333333e-06, 'epoch': 4.17}


 85%|████████▌ | 51/60 [01:43<00:29,  3.25s/it]

{'loss': 0.6965, 'learning_rate': 3e-06, 'epoch': 4.25}


 87%|████████▋ | 52/60 [01:44<00:19,  2.43s/it]

{'loss': 0.6102, 'learning_rate': 2.666666666666667e-06, 'epoch': 4.33}


 88%|████████▊ | 53/60 [01:44<00:13,  1.87s/it]

{'loss': 0.7549, 'learning_rate': 2.3333333333333336e-06, 'epoch': 4.42}


 90%|█████████ | 54/60 [01:45<00:08,  1.49s/it]

{'loss': 0.7021, 'learning_rate': 2.0000000000000003e-06, 'epoch': 4.5}


 92%|█████████▏| 55/60 [01:46<00:06,  1.22s/it]

{'loss': 0.717, 'learning_rate': 1.6666666666666667e-06, 'epoch': 4.58}


 93%|█████████▎| 56/60 [01:46<00:04,  1.04s/it]

{'loss': 0.7023, 'learning_rate': 1.3333333333333334e-06, 'epoch': 4.67}


 95%|█████████▌| 57/60 [01:47<00:02,  1.10it/s]

{'loss': 0.6489, 'learning_rate': 1.0000000000000002e-06, 'epoch': 4.75}


 97%|█████████▋| 58/60 [01:47<00:01,  1.26it/s]

{'loss': 0.6614, 'learning_rate': 6.666666666666667e-07, 'epoch': 4.83}


 98%|█████████▊| 59/60 [01:48<00:00,  1.40it/s]

{'loss': 0.6609, 'learning_rate': 3.3333333333333335e-07, 'epoch': 4.92}


100%|██████████| 60/60 [01:49<00:00,  1.50it/s]

{'loss': 0.8036, 'learning_rate': 0.0, 'epoch': 5.0}


                                               
100%|██████████| 60/60 [01:52<00:00,  1.88s/it]

{'eval_loss': 0.7075120806694031, 'eval_runtime': 3.9299, 'eval_samples_per_second': 3.054, 'eval_steps_per_second': 0.509, 'epoch': 5.0}
{'train_runtime': 112.9944, 'train_samples_per_second': 8.496, 'train_steps_per_second': 0.531, 'train_loss': 0.7251347482204438, 'epoch': 5.0}





TrainOutput(global_step=60, training_loss=0.7251347482204438, metrics={'train_runtime': 112.9944, 'train_samples_per_second': 8.496, 'train_steps_per_second': 0.531, 'train_loss': 0.7251347482204438, 'epoch': 5.0})

## 7. Evaluate PEFT Model

In [35]:
# Generate predictions

predictions = trainer.predict(test_dataset)
logits = predictions.predictions
predicted_labels = logits.argmax(axis=-1)
print(predicted_labels)

100%|██████████| 12/12 [00:02<00:00,  4.08it/s]

[0 1 1 0 1 0 0 1 0 1 0 0 0 1 1 0 1 0 1 1 1 1 1 1 0 0 0 0 1 0 1 1 1 0 0 0 0
 1 0 1 1 0 0 1 1 1 0 1 0 0 0 0 1 1 1 0 0 1 1 0 1 0 1 0 0 1 1 0 0 1 0 0 0 0
 1 1 0 1 0 1 1 0 1 1 1 1 1 0 1 1 0 0 0 1 1 0 0 0 1 0 1 1 1 1 1 1 0 0 0 0 0
 0 1 1 0 0 0 0 1 1 0 1 0 0 1 0 1 1 0 0 1 0 1 0 1 0 1 0 1 1 0 1 0 1 1 0 1 1
 0 0 1 0 0 0 1 0 1 0 1 1 0 1 1 0 1 0 0 0 1 0 1 0 0 0 0 1 1 1 1 1 0 0 1 1 0
 1 0 0 0 1 0 0]





In [36]:
def calculate_accuracy(true_labels, predicted_labels):

    assert len(true_labels) == len(predicted_labels), "Arrays must have the same length"
    correct_predictions = np.sum(true_labels == predicted_labels)
    accuracy = correct_predictions / len(true_labels)

    return accuracy

test_labels = []
for data in test_dataset:
    sequence, label = data
    for single_label in label:
        test_labels.append(single_label)

accuracy = calculate_accuracy(test_labels, predicted_labels)
print(f"Accuracy: {accuracy:.2f}")

Accuracy: 0.58


In this run, the PEFT model reaches an accuracy of 0.58 on the test subset, compared with 0.78 for the fully fine-tuned baseline. It does not yet match the baseline here, but it was trained with far fewer trainable parameters and a shorter training runtime (about 113 s versus 172 s).

With PEFT, we only train 442,368 parameters, which is 0.49% of the total parameters in the model. This is a significant reduction in computational resources compared to updating all parameters of the model.

We can improve the results by using a larger dataset, fine-tuning the model for more epochs or changing the hyperparameters (rank, learning rate, etc.).