# Fine-tuning BigBird-Pegasus with MindNLP
Base model: google/bigbird-pegasus-large-arxiv
Tokenizer: google/bigbird-pegasus-large-arxiv
Fine-tuning dataset: databricks/databricks-dolly-15k
Hardware: Ascend910B1
Environment
| Software    | Version                     |
| ----------- | --------------------------- |
| MindSpore   | MindSpore 2.4.0             |
| MindNLP     | MindNLP 0.4.1               |
| CANN        | 8.0                         |
| Python      | Python 3.9                  |
| OS platform | Ubuntu 5.4.0-42-generic     |

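## Introduction
BigBird-Pegasus is a hybrid of BigBird and Pegasus that combines the strengths of both and is designed for long text sequences. BigBird is a Transformer-based model that handles long sequences with a sparse attention mechanism, which lowers the computational cost. Pegasus is a model designed for text summarization; its self-supervised pretraining task, gap-sentence generation (GSG), improves summary generation. BigBird-Pegasus pairs BigBird's long-sequence handling with Pegasus's summarization ability, which makes it suitable for long-text summarization such as academic papers and long documents.
Databricks Dolly 15k is a high-quality instruction-tuning dataset released by Databricks. It contains about 15,000 human-written instruction-response pairs for training and evaluating conversational models, and is intended for fine-tuning NLP models.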
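
To make the sparse attention idea concrete, here is a minimal sketch of a BigBird-style attention mask that combines a sliding window, a few global tokens, and a few random connections. It is a simplified token-level illustration: the function name `bigbird_style_mask` and all parameter values are assumptions made for this example, and the real model works on blocks of tokens rather than on single tokens.

```python
import random


def bigbird_style_mask(seq_len, window=3, num_global=2, num_random=2, seed=0):
    """Return a seq_len x seq_len list of booleans; True means query i may attend to key j."""
    rng = random.Random(seed)
    half = window // 2
    mask = [[False] * seq_len for _ in range(seq_len)]
    for i in range(seq_len):
        for j in range(seq_len):
            # Sliding window: neighbours within `half` positions of the query.
            local = abs(i - j) <= half
            # Global tokens: the first `num_global` tokens see, and are seen by, every token.
            is_global = i < num_global or j < num_global
            mask[i][j] = local or is_global
        # Random attention: a few extra keys chosen at random for each query.
        for j in rng.sample(range(seq_len), num_random):
            mask[i][j] = True
    return mask


seq_len = 16
mask = bigbird_style_mask(seq_len)
for row in mask:
    print("".join("#" if allowed else "." for allowed in row))

attended = sum(sum(row) for row in mask)
print(f"attended pairs: {attended} of {seq_len * seq_len}")
```

Because each query only attends to a bounded number of keys, the number of attended pairs grows roughly linearly with the sequence length instead of quadratically.
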
## train loss

Training loss per epoch for the two fine-tuning runs:

| epoch | mindnlp+mindspore | transformers+torch (4060) |
| ----- | ----------------- | ------------------------- |
| 1     | 2.0958            | 8.7301                    |
| 2     | 1.969             | 8.1557                    |
| 3     | 1.8755            | 7.7516                    |
| 4     | 1.8264            | 7.5017                    |
| 5     | 1.7349            | 7.2614                    |
| 6     | 1.678             | 7.0559                    |
| 7     | 1.6937            | 6.8405                    |
| 8     | 1.654             | 6.7297                    |
| 9     | 1.6365            | 6.7136                    |
| 10    | 1.7003            | 6.6279                    |

## eval loss

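Evaluation loss after the final epoch (the MindNLP value matches the `eval_loss` logged at epoch 10 in the training output):

| epoch | mindnlp+mindspore  | transformers+torch (4060) |
| ----- | ------------------ | ------------------------- |
| 10    | 2.1257965564727783 | 6.3235931396484375        |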

**First, run the following script to set up the environment**

In [None]:
# On Ascend910B1 the following additional setup is required
# !pip install mindnlp
# !pip install mindspore==2.4
# !export LD_PRELOAD=$LD_PRELOAD:/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/torch.libs/libgomp-74ff64e9.so.1.0.0
# !yum install libsndfile

Looking in indexes: http://mirrors.aliyun.com/pypi/simple/
Collecting mindnlp
  Downloading http://mirrors.aliyun.com/pypi/packages/0f/a8/5a072852d28a51417b5e330b32e6ae5f26b491ef01a15ba968e77f785e69/mindnlp-0.4.0-py3-none-any.whl (8.4 MB)
Collecting datasets (from mindnlp)
  Downloading http://mirrors.aliyun.com/pypi/packages/4c/37/22ef7675bef4ffe9577b937ddca2e22791534cbbe11c30714972a91532dc/datasets-3.3.2-py3-none-any.whl (485 kB)
Collecting evaluate (from mindnlp)
  Downloading http://mirrors.aliyun.com/pypi/packages/a2/e7/cbca9e2d2590eb9b5aa8f7ebabe1beb1498f9462d2ecede5c9fd9735faaf/evaluate-0.4.3-py3-none-any.whl (84 kB)

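## Import libraries
Several tokenizer classes are imported here because more than one was tried during testing.
The MindSpore context must be set to use the Ascend device.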

In [2]:
import os
from mindnlp.transformers import (
    BigBirdPegasusForCausalLM, 
    PegasusTokenizer,
    AutoTokenizer
)
from datasets import load_dataset, DatasetDict
from mindspore.dataset import GeneratorDataset
from mindnlp.engine import Trainer, TrainingArguments
import mindspore as ms
# Set the execution mode and the target device
ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")

  from .autonotebook import tqdm as notebook_tqdm
Building prefix dict from the default dictionary ...
Dumping model to file cache /tmp/jieba.cache
Loading model cost 1.375 seconds.
Prefix dict has been built successfully.


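## Prepare the dataset
To make repeated fine-tuning runs fast, the processed dataset is saved to local disk and reloaded on later runs. Note that this notebook uses BigBirdPegasusForCausalLM, a causal language model, so each record has to be converted into a single prompt string.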

In [3]:
# Path where the processed dataset is stored
dataset_path = "./processed_dataset"
# Reuse the processed dataset if it already exists
if os.path.exists(dataset_path):
    dataset = DatasetDict.load_from_disk(dataset_path)
    train_dataset = dataset["train"]
    eval_dataset = dataset["eval"]
else:
    # Load and process the dataset
    dataset = load_dataset("databricks/databricks-dolly-15k")
    print(dataset)

    def format_prompt(sample):
        instruction = f"### Instruction\n{sample['instruction']}"
        context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
        response = f"### Answer\n{sample['response']}"
        prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
        sample["prompt"] = prompt
        return sample

    dataset = dataset.map(format_prompt)
    dataset = dataset.remove_columns(['instruction', 'context', 'response', 'category'])
    train_dataset = dataset["train"].select(range(0, 40))
    eval_dataset = dataset["train"].select(range(40, 50))
    # print(train_dataset)
    # print(eval_dataset)
    # print(train_dataset[0])
    # Save the processed dataset
    dataset = DatasetDict({"train": train_dataset, "eval": eval_dataset})
    dataset.save_to_disk(dataset_path)
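
To show what one formatted record looks like, here is a standalone sketch of the same prompt layout applied to a made-up record. The helper name `build_prompt` and the sample values are hypothetical; only the `### Instruction` / `### Context` / `### Answer` layout comes from the cell above.

```python
def build_prompt(sample):
    """Join the non-empty fields of one Dolly-style record into a single prompt string."""
    instruction = f"### Instruction\n{sample['instruction']}"
    # The context section is skipped entirely when the record has no context.
    context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
    response = f"### Answer\n{sample['response']}"
    parts = [part for part in (instruction, context, response) if part is not None]
    return "\n\n".join(parts)


# Hypothetical records, not taken from the dataset.
no_context = {
    "instruction": "Name a primary color.",
    "context": "",
    "response": "Red.",
}
with_context = {
    "instruction": "Which color is mentioned?",
    "context": "The sky is blue.",
    "response": "Blue.",
}

prompt_a = build_prompt(no_context)
prompt_b = build_prompt(with_context)
print(prompt_a)
print("-----")
print(prompt_b)
```

Records without a context produce two sections and records with a context produce three, so the model sees a consistent header before every field.
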

## Load the model
MindNLP does not appear to provide a class like BigBirdPegasusTokenizer, so AutoTokenizer is used. One of the MindNLP examples also uses PegasusTokenizer, so both were tried.


In [4]:
model_name = "google/bigbird-pegasus-large-arxiv"
tokenizer_name = "google/bigbird-pegasus-large-arxiv"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# tokenizer = PegasusTokenizer.from_pretrained(tokenizer_name)
tokenizer.pad_token = tokenizer.eos_token 
model = BigBirdPegasusForCausalLM.from_pretrained(model_name)

BigBirdPegasusForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`.`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


[MS_ALLOC_CONF]Runtime config:  enable_vmm:True  vmm_align_size:2MB




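## Convert the dataset to the training format
No equivalent of `DataCollatorForLanguageModeling` from transformers was found in MindNLP, so padding and truncation are implemented by hand.
The processed dataset is printed so it can be compared with the torch version and confirmed to be identical.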

In [5]:
class TextDataset:
    def __init__(self, data):
        self.data = data
    # Tokenize one sample, padding or truncating it to max_length
    def __getitem__(self, index):
        index = int(index)
        text = self.data[index]["prompt"]
        inputs = tokenizer(text, padding='max_length', max_length=256, truncation=True)
        return (
            inputs["input_ids"], 
            inputs["attention_mask"],
            inputs["input_ids"]  # labels: the same token IDs serve as the target
        )

    def __len__(self):
        return len(self.data)
train_dataset = GeneratorDataset(
    TextDataset(train_dataset),
    column_names=["input_ids", "attention_mask", "labels"],  # includes the labels column
    shuffle=True
)
eval_dataset = GeneratorDataset(
    TextDataset(eval_dataset),
    column_names=["input_ids", "attention_mask", "labels"],  # includes the labels column
    shuffle=False
)
print("train_dataset:", train_dataset)
print("eval_dataset:", eval_dataset)
for data in train_dataset.create_dict_iterator():
    print(data)
    break

train_dataset: <mindspore.dataset.engine.datasets_user_defined.GeneratorDataset object at 0xffff404b6430>
eval_dataset: <mindspore.dataset.engine.datasets_user_defined.GeneratorDataset object at 0xffff45782430>
{'input_ids': Tensor(shape=[256], dtype=Int64, value= [  110, 63444, 26323,   463,   117,   114,   110, 84040,  5551, 41676,   152,   110, 63444, 30058,   222, 22600,   108,   114,   110, 84040,  5551, 41676,   117,   142, 
  8091, 41676,   120,   117,   263,   112, 37525,   523,   108,   120,   117,   108,   112,  1910,   523,   190,   203, 31059,  2274,   143,   544,  1613,   113,   109, 
 12091,   250, 10008, 44069,   143, 10209,   116,   158,   113,   523,   138,   129, 53136,   141,   109, 41676,   134,   291, 10269,   107,   182,   117,   114,   711, 
   113,   109, 41676,  1001,   131,   116,  4224,   113, 67669,  7775,   122, 30671,   143, 84040,  2928,   250, 10879,   108,   895, 44069,   143,  6388,   158, 11213, 
   114,  1934, 28593,   197,  6306, 44069,   143, 11753
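
One difference from `DataCollatorForLanguageModeling` is worth noting: with `mlm=False`, that collator replaces the labels at padding positions with -100 so they do not contribute to the loss, whereas the dataset class above copies `input_ids` into `labels` unchanged, padding included. If the torch run used that collator, this could be one reason the two loss curves are not directly comparable. The sketch below shows the masking step in plain Python. The helper name `mask_padding_labels` and the token IDs are hypothetical, and it assumes the loss function ignores label -100, which should be checked for the MindNLP model.

```python
IGNORE_INDEX = -100  # assumed ignore value; verify against the loss used by the model


def mask_padding_labels(input_ids, attention_mask, ignore_index=IGNORE_INDEX):
    """Copy input_ids into labels, replacing padded positions with ignore_index."""
    return [
        token if keep == 1 else ignore_index
        for token, keep in zip(input_ids, attention_mask)
    ]


# Hypothetical padded sample: the last three positions are padding.
input_ids = [110, 63444, 26323, 1, 1, 1]
attention_mask = [1, 1, 1, 0, 0, 0]

labels = mask_padding_labels(input_ids, attention_mask)
print(labels)
```

The mask is driven by `attention_mask` rather than by comparing against the pad token ID because the notebook sets `pad_token` to `eos_token`, so the two cannot be told apart by ID alone.
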

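## Configure the trainer and train
The arguments here should match those of the torch run. The training loss is logged at each epoch so the two runs can be compared.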

In [6]:
EPOCHS = 10
BATCH_SIZE = 4
# Define the training arguments
training_args = TrainingArguments(
    output_dir='./MindsporeBigBirdFinetune',
    overwrite_output_dir=True,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    
    save_steps=500,                  # Save checkpoint every 500 steps
    save_total_limit=2,              # Keep only the last 2 checkpoints
    logging_dir="./logs",            # Directory for logs
    logging_steps=100,               # Only applies when logging_strategy="steps"
    logging_strategy="epoch",        # Log the training loss once per epoch
    evaluation_strategy="epoch",     # Evaluate once per epoch
    eval_steps=500,                  # Only applies when evaluation_strategy="steps"
    learning_rate=5e-5,
    weight_decay=0.01,               # Weight decay
)

# Create the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=None
)
trainer.train()

100%|██████████| 100/100 [01:46<00:00,  1.06s/it]

{'loss': 2.0958, 'learning_rate': 4.5e-05, 'epoch': 1.0}
{'eval_loss': 2.592344045639038, 'eval_runtime': 4.9288, 'eval_samples_per_second': 0.609, 'eval_steps_per_second': 0.203, 'epoch': 1.0}
{'loss': 1.969, 'learning_rate': 4e-05, 'epoch': 2.0}
{'eval_loss': 2.486072063446045, 'eval_runtime': 0.2738, 'eval_samples_per_second': 10.956, 'eval_steps_per_second': 3.652, 'epoch': 2.0}
{'loss': 1.8755, 'learning_rate': 3.5e-05, 'epoch': 3.0}
{'eval_loss': 2.367415189743042, 'eval_runtime': 0.2442, 'eval_samples_per_second': 12.283, 'eval_steps_per_second': 4.094, 'epoch': 3.0}
{'loss': 1.8264, 'learning_rate': 3e-05, 'epoch': 4.0}
{'eval_loss': 2.3535046577453613, 'eval_runtime': 0.241, 'eval_samples_per_second': 12.45, 'eval_steps_per_second': 4.15, 'epoch': 4.0}
{'loss': 1.7349, 'learning_rate': 2.5e-05, 'epoch': 5.0}
{'eval_loss': 2.2972629070281982, 'eval_runtime': 0.2457, 'eval_samples_per_second': 12.21, 'eval_steps_per_second': 4.07, 'epoch': 5.0}
{'loss': 1.678, 'learning_rate': 2e-05, 'epoch': 6.0}
{'eval_loss': 2.195664882659912, 'eval_runtime': 0.2324, 'eval_samples_per_second': 12.91, 'eval_steps_per_second': 4.303, 'epoch': 6.0}
{'loss': 1.6937, 'learning_rate': 1.5e-05, 'epoch': 7.0}
{'eval_loss': 2.1624794006347656, 'eval_runtime': 0.2587, 'eval_samples_per_second': 11.596, 'eval_steps_per_second': 3.865, 'epoch': 7.0}
{'loss': 1.654, 'learning_rate': 1e-05, 'epoch': 8.0}
{'eval_loss': 2.159714460372925, 'eval_runtime': 0.2363, 'eval_samples_per_second': 12.696, 'eval_steps_per_second': 4.232, 'epoch': 8.0}
{'loss': 1.6365, 'learning_rate': 5e-06, 'epoch': 9.0}
{'eval_loss': 2.1347262859344482, 'eval_runtime': 0.2604, 'eval_samples_per_second': 11.523, 'eval_steps_per_second': 3.841, 'epoch': 9.0}
{'loss': 1.7003, 'learning_rate': 0.0, 'epoch': 10.0}
{'eval_loss': 2.1257965564727783, 'eval_runtime': 0.2557, 'eval_samples_per_second': 11.733, 'eval_steps_per_second': 3.911, 'epoch': 10.0}
{'train_runtime': 106.4446, 'train_samples_per_second': 3.758, 'train_steps_per_second': 0.939, 'train_loss': 1.7863994789123536, 'epoch': 10.0}





TrainOutput(global_step=100, training_loss=1.7863994789123536, metrics={'train_runtime': 106.4446, 'train_samples_per_second': 3.758, 'train_steps_per_second': 0.939, 'train_loss': 1.7863994789123536, 'epoch': 10.0})
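
The step count and the logged learning rates can be reproduced with a little arithmetic. The sketch below uses the values from this notebook (40 training samples, 10 evaluation samples, batch size 4, 10 epochs, base learning rate 5e-5) and assumes a linear decay to zero with no warmup. That schedule is an assumption inferred from the logged values, not something read from the MindNLP source.

```python
import math

train_samples, eval_samples = 40, 10
batch_size, epochs, base_lr = 4, 10, 5e-5

steps_per_epoch = math.ceil(train_samples / batch_size)  # optimizer steps in one epoch
total_steps = steps_per_epoch * epochs                   # matches global_step in TrainOutput
eval_batches = math.ceil(eval_samples / batch_size)      # batches in one evaluation pass


def linear_lr(step):
    """Learning rate after `step` optimizer steps under linear decay without warmup."""
    return base_lr * (1 - step / total_steps)


print(f"steps per epoch: {steps_per_epoch}, total steps: {total_steps}")
print(f"evaluation batches: {eval_batches}")
for epoch in range(1, epochs + 1):
    print(f"epoch {epoch:2d}: lr = {linear_lr(epoch * steps_per_epoch):.1e}")
```

With only 100 optimizer steps in total, `save_steps=500` is never reached during training, which is why the model is saved explicitly later in the notebook.
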

## View the evaluation results

In [7]:
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

100%|██████████| 3/3 [00:00<00:00, 15.78it/s]

Evaluation results: {'eval_loss': 2.1257965564727783, 'eval_runtime': 0.3007, 'eval_samples_per_second': 9.977, 'eval_steps_per_second': 3.326, 'epoch': 10.0}





## Save the fine-tuned model

In [8]:
model.save_pretrained("./mindNLPModelBigbirdPegasusFinetune")
tokenizer.save_pretrained("./mindNLPTokenizerBigbirdPegasusFinetune")

Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file instead.
Non-default generation parameters: {'max_length': 256, 'num_beams': 5, 'length_penalty': 0.8}


('./mindNLPTokenizerBigbirdPegasusFinetune/tokenizer_config.json',
 './mindNLPTokenizerBigbirdPegasusFinetune/special_tokens_map.json',
 './mindNLPTokenizerBigbirdPegasusFinetune/spiece.model',
 './mindNLPTokenizerBigbirdPegasusFinetune/added_tokens.json',
 './mindNLPTokenizerBigbirdPegasusFinetune/tokenizer.json')

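## Test the fine-tuned model
The loss decreases steadily and ends lower than in the torch run. However, both runs are only brief fine-tunes (40 training samples, 10 epochs), so the language model's actual output quality is poor and the generated text is not meaningful.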

In [9]:
fine_tuned_model = BigBirdPegasusForCausalLM.from_pretrained("./mindNLPModelBigbirdPegasusFinetune")
fine_tuned_tokenizer = PegasusTokenizer.from_pretrained("./mindNLPTokenizerBigbirdPegasusFinetune")
inputs = "Hello, my dog is cute"
input_tokens = fine_tuned_tokenizer(inputs, return_tensors="ms")
outputs = fine_tuned_model(**input_tokens)
logits = outputs.logits
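# Use argmax to get the predicted token ID at each input position
# (next-token predictions for the given input, not autoregressive generation)
from mindspore import ops
predicted_token_ids = ops.argmax(logits, dim=-1)  # argmax over the last (vocab_size) dimension
# Decode the predicted token IDs into text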
generated_text = fine_tuned_tokenizer.decode(predicted_token_ids[0].asnumpy().tolist(), skip_special_tokens=True)
print(generated_text)

in,, have a but
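
The cell above takes the argmax at every input position in a single forward pass, which yields one predicted next token per input token rather than a continuation of the prompt. For comparison, the sketch below contrasts that with a greedy autoregressive loop, where each new token is appended and fed back in. It uses a toy stand-in for the model: `toy_logits` and its "next token is current token plus one" rule are hypothetical and only exist to make the example runnable without MindNLP.

```python
VOCAB_SIZE = 5


def toy_logits(token_ids):
    """Hypothetical model: at each position, favour the token (current + 1) % VOCAB_SIZE."""
    rows = []
    for token in token_ids:
        row = [0.0] * VOCAB_SIZE
        row[(token + 1) % VOCAB_SIZE] = 1.0
        rows.append(row)
    return rows


def argmax(row):
    """Index of the largest value in a list of scores."""
    return max(range(len(row)), key=row.__getitem__)


def greedy_generate(prompt_ids, max_new_tokens):
    """Append the most likely next token, one step at a time."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        last_position_logits = toy_logits(ids)[-1]  # only the final position predicts the new token
        ids.append(argmax(last_position_logits))
    return ids


prompt = [0, 3]
per_position = [argmax(row) for row in toy_logits(prompt)]  # what the cell above computes
generated = greedy_generate(prompt, max_new_tokens=3)       # prompt plus three new tokens

print("per-position argmax:", per_position)
print("greedy generation:  ", generated)
```

With the real model, the same loop would call the model on the growing ID sequence at each step. The warning printed when the model was loaded suggests `generate` may not be available for this class, in which case a manual loop like this is one way to sample a continuation.
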
