# Hands-on LoRA with the PEFT Library: OpenAI Whisper-large-v2

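This tutorial uses LoRA to fine-tune the `OpenAI Whisper-large-v2` model on an automatic speech recognition (ASR) task.

We also apply `int8` quantization to further reduce resource usage during training, with almost no loss of accuracy.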

## Global parameter settings

In [1]:
model_name_or_path = "openai/whisper-large-v2"
language = "Chinese (China)"
language_abbr = "zh-CN"
task = "transcribe"
dataset_name = "mozilla-foundation/common_voice_11_0"

batch_size = 64

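## Download the Common Voice dataset

The Common Voice 11.0 dataset contains many hours of recorded speech in many different languages.

This tutorial uses the Chinese data as an example to show how to fine-tune Whisper-large-v2 with LoRA.

First, initialize a `DatasetDict` and load the Chinese (`zh-CN`) configuration into memory, merging the `train` and `validation` splits into the training set and keeping the `test` split as the test set: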

In [5]:
from datasets import load_dataset, DatasetDict
from datasets import config

# Raw string so that the backslashes in the Windows path are not read as escape sequences
cache_dir = r"D:\cache\huggingface\datasets"
config.HF_DATASETS_CACHE = cache_dir

common_voice = DatasetDict()

# Merge the train and validation splits into one training set; keep test separate
common_voice["train"] = load_dataset(dataset_name, language_abbr, split="train+validation", cache_dir=cache_dir)
common_voice["test"] = load_dataset(dataset_name, language_abbr, split="test", cache_dir=cache_dir)
common_voice["train"][0]

{'client_id': '95368aab163e0387e4fd4991b4f2d8ccfbd4364bf656c860230501fd27dcedf087773e4695a6cf5de9c4f1d406d582283190d065cdfa36b0e2b060cffaca977e',
 'path': 'C:\\Users\\17972\\.cache\\huggingface\\datasets\\downloads\\extracted\\fb5fccc6e0e7604d5611e4748bfd3bb73c081c4fbe2c69cdf2d65cb406bddad9\\zh-CN_train_0/common_voice_zh-CN_33211332.mp3',
 'audio': {'path': 'C:\\Users\\17972\\.cache\\huggingface\\datasets\\downloads\\extracted\\fb5fccc6e0e7604d5611e4748bfd3bb73c081c4fbe2c69cdf2d65cb406bddad9\\zh-CN_train_0/common_voice_zh-CN_33211332.mp3',
  'array': array([-9.09494702e-13, -2.50111043e-12, -2.04636308e-12, ...,
          1.21667417e-05,  3.23003815e-06, -2.43064278e-07]),
  'sampling_rate': 48000},
 'sentence': '性喜温暖润湿气候且耐寒。',
 'up_votes': 2,
 'down_votes': 0,
 'age': '',
 'gender': '',
 'accent': '',
 'locale': 'zh-CN',
 'segment': ''}

## Preprocess the training dataset


In [6]:
from transformers import AutoFeatureExtractor, AutoTokenizer, AutoProcessor

feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)

tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path, language=language, task=task)

processor = AutoProcessor.from_pretrained(
    model_name_or_path, language=language, task=task)



#### Remove unneeded columns from the dataset

In [7]:
common_voice = common_voice.remove_columns(
    ["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"]
)

In [8]:
common_voice["train"][0]

{'audio': {'path': 'C:\\Users\\17972\\.cache\\huggingface\\datasets\\downloads\\extracted\\fb5fccc6e0e7604d5611e4748bfd3bb73c081c4fbe2c69cdf2d65cb406bddad9\\zh-CN_train_0/common_voice_zh-CN_33211332.mp3',
  'array': array([-9.09494702e-13, -2.50111043e-12, -2.04636308e-12, ...,
          1.21667417e-05,  3.23003815e-06, -2.43064278e-07]),
  'sampling_rate': 48000},
 'sentence': '性喜温暖润湿气候且耐寒。'}

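#### Downsample the audio

According to the `common_voice` dataset card, the audio is sampled at 48 kHz.

The `Whisper` model, however, was pretrained on 16 kHz audio, so the input has to be downsampled to match the sampling rate used during pretraining.

To downsample, call the `cast_column` method on the audio column and set `sampling_rate` to 16 kHz.

The audio is then resampled on the fly the next time it is accessed: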

In [9]:
from datasets import Audio

common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

In [10]:
common_voice["train"][0]

{'audio': {'path': 'C:\\Users\\17972\\.cache\\huggingface\\datasets\\downloads\\extracted\\fb5fccc6e0e7604d5611e4748bfd3bb73c081c4fbe2c69cdf2d65cb406bddad9\\zh-CN_train_0/common_voice_zh-CN_33211332.mp3',
  'array': array([ 5.82076609e-11, -2.91038305e-11, -5.82076609e-11, ...,
         -5.96660539e-06,  2.71383760e-05,  1.29687833e-05]),
  'sampling_rate': 16000},
 'sentence': '性喜温暖润湿气候且耐寒。'}
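
As a rough sanity check on what the downsampling does to the data, the sketch below computes how many samples a clip contains before and after resampling. The 5-second clip length and the helper name `resampled_length` are made up for illustration; they are not taken from the dataset or from the `datasets` library.

```python
# Sample counts for a clip before and after resampling from 48 kHz to 16 kHz.
SOURCE_RATE = 48_000  # Common Voice sampling rate (Hz)
TARGET_RATE = 16_000  # sampling rate Whisper was pretrained on (Hz)


def resampled_length(num_samples: int, source_rate: int, target_rate: int) -> int:
    """Approximate number of samples after resampling; the duration is unchanged."""
    return round(num_samples * target_rate / source_rate)


clip_seconds = 5  # hypothetical clip length, for illustration only
n_source = clip_seconds * SOURCE_RATE
n_target = resampled_length(n_source, SOURCE_RATE, TARGET_RATE)

print(n_source, n_target)  # 240000 80000
```

The array shrinks by a factor of three while the clip duration stays the same, which is why the `array` values in the output above differ from the 48 kHz version.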

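### Combine the steps above into one function

The preprocessing function should:
- Resample the audio input to 16 kHz by loading the audio column.
- Use the feature extractor to compute input features from the audio array.
- Tokenize the sentence column to produce the labels.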

In [11]:
def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

In [13]:
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"])

Map: 100%|██████████| 39637/39637 [19:22<00:00, 34.10 examples/s]  
Map: 100%|██████████| 10581/10581 [05:24<00:00, 32.62 examples/s]
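
To get a feel for what `input_features` contains, the following sketch works out the expected feature shape with plain arithmetic. It assumes the usual Whisper front end (every clip padded or truncated to 30 seconds, a spectrogram hop of 160 samples); those two values are assumptions not stated in this tutorial, while the 80 Mel bins and the two convolutions are taken from the model summary printed in the evaluation section. The helper `conv1d_out_len` is defined here only for illustration.

```python
SAMPLING_RATE = 16_000  # Hz, after the cast_column step
CHUNK_SECONDS = 30      # assumed: Whisper pads or truncates every clip to 30 s
HOP_LENGTH = 160        # assumed: hop of the log-Mel spectrogram, in samples
N_MELS = 80             # Mel bins; matches Conv1d(80, 1280, ...) in the model summary


def conv1d_out_len(length: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a 1-D convolution with dilation 1."""
    return (length + 2 * padding - kernel) // stride + 1


n_samples = SAMPLING_RATE * CHUNK_SECONDS  # samples in one padded clip
n_frames = n_samples // HOP_LENGTH         # spectrogram frames per clip
after_conv1 = conv1d_out_len(n_frames, kernel=3, stride=1, padding=1)
after_conv2 = conv1d_out_len(after_conv1, kernel=3, stride=2, padding=1)

print((N_MELS, n_frames), after_conv2)  # (80, 3000) 1500
```

Under these assumptions each example is an 80 x 3000 array, and the stride-2 convolution halves the time axis to 1500 positions, which agrees with the `Embedding(1500, 1280)` position table in the encoder.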


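Create a `DataCollator` class that pads the input features and labels of each batch to the longest length in the batch, then uses the labels' `attention_mask` to replace the padded label positions with `-100` so that the loss function ignores them.

Then initialize an instance of the data collator: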

In [14]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
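
The label handling in the collator can be illustrated without tensors. The sketch below is a pure-Python analogue, not the collator itself: `pad_labels` is a hypothetical helper, the token ids are invented, and it assumes every label sequence is non-empty.

```python
from typing import List

IGNORE_INDEX = -100  # label value that the loss function skips


def pad_labels(label_seqs: List[List[int]], bos_token_id: int) -> List[List[int]]:
    """Pad label sequences to a common length and drop a shared leading BOS token."""
    max_len = max(len(seq) for seq in label_seqs)
    # Fill the padded positions with IGNORE_INDEX instead of a real token id
    padded = [seq + [IGNORE_INDEX] * (max_len - len(seq)) for seq in label_seqs]
    # If every sequence starts with BOS, cut it off, as the collator above does
    if all(seq[0] == bos_token_id for seq in padded):
        padded = [seq[1:] for seq in padded]
    return padded


BOS = 1  # hypothetical token id, for illustration only
batch = [[BOS, 7, 8, 9], [BOS, 5]]
print(pad_labels(batch, BOS))  # [[7, 8, 9], [5, -100, -100]]
```

The shorter sequence ends in `-100` entries, so those positions contribute nothing to the loss.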

## Train the model

In [15]:
from transformers import AutoModelForSpeechSeq2Seq

model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")



In [16]:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

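To prepare the model for int8 training, process it with the `prepare_model_for_int8_training` function (newer PEFT releases supersede this helper with `prepare_model_for_kbit_training`), which:
- Casts all non-int8 modules to full precision (fp32) for stability
- Adds a forward hook to the input embedding layer so that gradients are computed for the input hidden states
- Enables gradient checkpointing for more memory-efficient training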

In [17]:
from peft import prepare_model_for_int8_training

model = prepare_model_for_int8_training(model)



In [18]:
from peft import LoraConfig, PeftModel, LoraModel, get_peft_model

config = LoraConfig(
    r=8,
    lora_alpha=64,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none")

In [19]:
model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 3,932,160 || all params: 1,547,237,120 || trainable%: 0.25414074863974306
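
The trainable parameter count can be reproduced by hand. The sketch below assumes that the decoder also has 32 layers and that each decoder layer contains two attention blocks (self-attention and cross-attention), each with its own `q_proj` and `v_proj`; the hidden size of 1280 and the 32 encoder layers come from the model summary printed in the evaluation section.

```python
D_MODEL = 1280              # hidden size of the attention projections
RANK = 8                    # LoRA rank r from LoraConfig
ENCODER_LAYERS = 32
DECODER_LAYERS = 32         # assumption: same depth as the encoder
TARGETS_PER_ATTENTION = 2   # q_proj and v_proj

# Each adapted projection gets two small matrices: A (r x d_in) and B (d_out x r)
params_per_module = RANK * D_MODEL + D_MODEL * RANK

encoder_modules = ENCODER_LAYERS * 1 * TARGETS_PER_ATTENTION  # self-attention only
decoder_modules = DECODER_LAYERS * 2 * TARGETS_PER_ATTENTION  # self- and cross-attention

trainable = (encoder_modules + decoder_modules) * params_per_module
total = 1_547_237_120  # "all params" reported by print_trainable_parameters()

print(trainable, 100 * trainable / total)
```

This gives 192 adapted projections with 20,480 parameters each, or 3,932,160 in total, which matches the count reported above and amounts to roughly 0.25% of all parameters.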


### For a quick demo you can train for only 100 steps (`max_steps=100`); the run below uses 3 epochs to fully train a Chinese speech recognition model.

In [24]:
from transformers import Seq2SeqTrainingArguments
import os

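# Training arguments for the sequence-to-sequence model
training_args = Seq2SeqTrainingArguments(
    output_dir="models/whisper-large-v2-asr-int8",  # directory where the model and checkpoints are saved
    per_device_train_batch_size=batch_size,  # training batch size per device
    gradient_accumulation_steps=1,  # number of batches whose gradients are accumulated before each optimizer step
    learning_rate=1e-3,  # peak learning rate
    warmup_steps=50,  # steps over which the learning rate ramps up at the start, which helps stabilize training
    # max_steps=100,  # total number of training steps (for a quick demo)
    num_train_epochs=3,  # total number of training epochs
    evaluation_strategy="epoch",  # evaluate at the end of every epoch
    fp16=True,  # mixed-precision training: faster and uses less memory
    per_device_eval_batch_size=batch_size,  # evaluation batch size per device
    generation_max_length=128,  # maximum length for generation
    logging_steps=25,  # log training metrics every 25 steps
    remove_unused_columns=False,  # keep all dataset columns; the PeftModel forward signature differs from the wrapped model's
    label_names=["labels"],  # name of the label column, set explicitly for the same reason
)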
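
The number of optimizer steps implied by these arguments can be estimated up front. The sketch below uses the training-set size reported by the `Map` progress output (39,637 examples) and assumes a single GPU, which the tutorial does not state explicitly.

```python
import math

TRAIN_EXAMPLES = 39_637  # size of the training split, from the Map progress output
BATCH_SIZE = 64          # per_device_train_batch_size
GRAD_ACCUM_STEPS = 1     # gradient_accumulation_steps
NUM_EPOCHS = 3           # num_train_epochs
NUM_DEVICES = 1          # assumption: training on a single GPU

# Batches per epoch; the last, smaller batch still counts as one batch
batches_per_epoch = math.ceil(TRAIN_EXAMPLES / (BATCH_SIZE * NUM_DEVICES))
# One optimizer step per GRAD_ACCUM_STEPS batches
steps_per_epoch = max(batches_per_epoch // GRAD_ACCUM_STEPS, 1)
total_steps = steps_per_epoch * NUM_EPOCHS

print(steps_per_epoch, total_steps)  # 620 1860
```

The result, 620 steps per epoch and 1860 steps overall, is consistent with the progress bar in the training output, where evaluation runs at steps 620, 1240 and 1860.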

#### Callback that saves the adapter state during training; recommended for long training runs

In [25]:
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers import Seq2SeqTrainer, TrainerCallback, Seq2SeqTrainingArguments, TrainerState, TrainerControl

class SavePeftModelCallback(TrainerCallback):
    def on_save(
        self,
        args: Seq2SeqTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        return control

In [26]:
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
    callbacks=[SavePeftModelCallback],
)
model.config.use_cache = False

In [28]:
trainer.train()

  1%|▏         | 25/1860 [04:44<5:42:09, 11.19s/it]

{'loss': 0.3132, 'learning_rate': 0.0005, 'epoch': 0.04}


  3%|▎         | 50/1860 [09:25<5:36:08, 11.14s/it]

{'loss': 0.2698, 'learning_rate': 0.001, 'epoch': 0.08}


  4%|▍         | 75/1860 [14:04<5:34:07, 11.23s/it]

{'loss': 0.3156, 'learning_rate': 0.0009861878453038674, 'epoch': 0.12}


  5%|▌         | 100/1860 [18:44<5:27:50, 11.18s/it]

{'loss': 0.3173, 'learning_rate': 0.0009723756906077348, 'epoch': 0.16}


  7%|▋         | 125/1860 [23:25<5:27:02, 11.31s/it]

{'loss': 0.32, 'learning_rate': 0.0009585635359116023, 'epoch': 0.2}


  8%|▊         | 150/1860 [28:05<5:18:23, 11.17s/it]

{'loss': 0.3224, 'learning_rate': 0.0009447513812154696, 'epoch': 0.24}


  9%|▉         | 175/1860 [32:48<5:17:14, 11.30s/it]

{'loss': 0.3361, 'learning_rate': 0.000930939226519337, 'epoch': 0.28}


 11%|█         | 200/1860 [37:31<5:11:04, 11.24s/it]

{'loss': 0.2991, 'learning_rate': 0.0009171270718232044, 'epoch': 0.32}


 12%|█▏        | 225/1860 [42:11<5:07:03, 11.27s/it]

{'loss': 0.3268, 'learning_rate': 0.0009033149171270718, 'epoch': 0.36}


 13%|█▎        | 250/1860 [46:52<5:01:07, 11.22s/it]

{'loss': 0.3196, 'learning_rate': 0.0008895027624309392, 'epoch': 0.4}


 15%|█▍        | 275/1860 [51:33<4:53:04, 11.09s/it]

{'loss': 0.3295, 'learning_rate': 0.0008756906077348066, 'epoch': 0.44}


 16%|█▌        | 300/1860 [56:08<4:46:42, 11.03s/it]

{'loss': 0.3273, 'learning_rate': 0.0008618784530386741, 'epoch': 0.48}


 17%|█▋        | 325/1860 [1:00:41<4:38:41, 10.89s/it]

{'loss': 0.3084, 'learning_rate': 0.0008480662983425415, 'epoch': 0.52}


 19%|█▉        | 350/1860 [1:05:16<4:36:40, 10.99s/it]

{'loss': 0.3346, 'learning_rate': 0.0008342541436464089, 'epoch': 0.56}


 20%|██        | 375/1860 [1:09:49<4:31:02, 10.95s/it]

{'loss': 0.3257, 'learning_rate': 0.0008204419889502763, 'epoch': 0.6}


 22%|██▏       | 400/1860 [1:14:28<4:33:54, 11.26s/it]

{'loss': 0.332, 'learning_rate': 0.0008066298342541437, 'epoch': 0.65}


 23%|██▎       | 425/1860 [1:19:11<4:29:09, 11.25s/it]

{'loss': 0.3324, 'learning_rate': 0.0007928176795580111, 'epoch': 0.69}


 24%|██▍       | 450/1860 [1:23:52<4:24:01, 11.24s/it]

{'loss': 0.3042, 'learning_rate': 0.0007790055248618785, 'epoch': 0.73}


 26%|██▌       | 475/1860 [1:28:33<4:17:43, 11.17s/it]

{'loss': 0.2899, 'learning_rate': 0.0007651933701657459, 'epoch': 0.77}


 27%|██▋       | 500/1860 [1:33:13<4:14:39, 11.24s/it]Checkpoint destination directory models/whisper-large-v2-asr-int8\checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.


{'loss': 0.3302, 'learning_rate': 0.0007513812154696133, 'epoch': 0.81}


 28%|██▊       | 525/1860 [1:37:54<4:08:24, 11.16s/it]

{'loss': 0.303, 'learning_rate': 0.0007375690607734806, 'epoch': 0.85}


 30%|██▉       | 550/1860 [1:42:34<4:04:35, 11.20s/it]

{'loss': 0.3222, 'learning_rate': 0.0007237569060773481, 'epoch': 0.89}


 31%|███       | 575/1860 [1:47:14<3:59:45, 11.19s/it]

{'loss': 0.3188, 'learning_rate': 0.0007099447513812155, 'epoch': 0.93}


 32%|███▏      | 600/1860 [1:51:55<3:57:42, 11.32s/it]

{'loss': 0.3184, 'learning_rate': 0.0006961325966850829, 'epoch': 0.97}


                                                      
 33%|███▎      | 620/1860 [2:22:00<2:46:08,  8.04s/it]

{'eval_loss': 0.26249778270721436, 'eval_runtime': 1592.2585, 'eval_samples_per_second': 6.645, 'eval_steps_per_second': 0.104, 'epoch': 1.0}


 34%|███▎      | 625/1860 [2:23:03<43:13:25, 126.00s/it] 

{'loss': 0.3244, 'learning_rate': 0.0006823204419889503, 'epoch': 1.01}


 35%|███▍      | 650/1860 [2:27:54<3:55:51, 11.70s/it]  

{'loss': 0.2272, 'learning_rate': 0.0006685082872928176, 'epoch': 1.05}


 36%|███▋      | 675/1860 [2:32:43<3:48:09, 11.55s/it]

{'loss': 0.243, 'learning_rate': 0.0006546961325966851, 'epoch': 1.09}


 38%|███▊      | 700/1860 [2:37:30<3:41:42, 11.47s/it]

{'loss': 0.2614, 'learning_rate': 0.0006408839779005525, 'epoch': 1.13}


 39%|███▉      | 725/1860 [2:42:18<3:33:57, 11.31s/it]

{'loss': 0.2381, 'learning_rate': 0.0006270718232044199, 'epoch': 1.17}


 40%|████      | 750/1860 [2:46:57<3:25:12, 11.09s/it]

{'loss': 0.2707, 'learning_rate': 0.0006132596685082873, 'epoch': 1.21}


 42%|████▏     | 775/1860 [2:51:35<3:21:51, 11.16s/it]

{'loss': 0.2479, 'learning_rate': 0.0005994475138121546, 'epoch': 1.25}


 43%|████▎     | 800/1860 [2:56:12<3:16:30, 11.12s/it]

{'loss': 0.2387, 'learning_rate': 0.000585635359116022, 'epoch': 1.29}


 44%|████▍     | 825/1860 [3:00:49<3:11:13, 11.09s/it]

{'loss': 0.2806, 'learning_rate': 0.0005718232044198896, 'epoch': 1.33}


 46%|████▌     | 850/1860 [3:05:27<3:06:06, 11.06s/it]

{'loss': 0.2505, 'learning_rate': 0.000558011049723757, 'epoch': 1.37}


 47%|████▋     | 875/1860 [3:10:05<3:02:26, 11.11s/it]

{'loss': 0.2406, 'learning_rate': 0.0005441988950276244, 'epoch': 1.41}


 48%|████▊     | 900/1860 [3:14:42<2:57:17, 11.08s/it]

{'loss': 0.2298, 'learning_rate': 0.0005303867403314917, 'epoch': 1.45}


 50%|████▉     | 925/1860 [3:19:19<2:53:17, 11.12s/it]

{'loss': 0.2327, 'learning_rate': 0.0005165745856353591, 'epoch': 1.49}


 51%|█████     | 950/1860 [3:23:58<2:49:35, 11.18s/it]

{'loss': 0.2472, 'learning_rate': 0.0005027624309392266, 'epoch': 1.53}


 52%|█████▏    | 975/1860 [3:28:35<2:43:31, 11.09s/it]

{'loss': 0.2366, 'learning_rate': 0.0004889502762430939, 'epoch': 1.57}


 54%|█████▍    | 1000/1860 [3:33:14<2:39:10, 11.11s/it]

{'loss': 0.2463, 'learning_rate': 0.00047513812154696136, 'epoch': 1.61}


 55%|█████▌    | 1025/1860 [3:37:52<2:35:06, 11.15s/it]

{'loss': 0.2506, 'learning_rate': 0.00046132596685082873, 'epoch': 1.65}


 56%|█████▋    | 1050/1860 [3:42:31<2:30:33, 11.15s/it]

{'loss': 0.2732, 'learning_rate': 0.00044751381215469617, 'epoch': 1.69}


 58%|█████▊    | 1075/1860 [3:47:09<2:25:38, 11.13s/it]

{'loss': 0.276, 'learning_rate': 0.00043370165745856354, 'epoch': 1.73}


 59%|█████▉    | 1100/1860 [3:51:47<2:21:17, 11.15s/it]

{'loss': 0.2655, 'learning_rate': 0.0004198895027624309, 'epoch': 1.77}


 60%|██████    | 1125/1860 [3:56:24<2:15:39, 11.07s/it]

{'loss': 0.2523, 'learning_rate': 0.00040607734806629835, 'epoch': 1.81}


 62%|██████▏   | 1150/1860 [4:01:02<2:11:44, 11.13s/it]

{'loss': 0.2565, 'learning_rate': 0.00039226519337016573, 'epoch': 1.85}


 63%|██████▎   | 1175/1860 [4:05:40<2:07:11, 11.14s/it]

{'loss': 0.271, 'learning_rate': 0.0003784530386740331, 'epoch': 1.9}


 65%|██████▍   | 1200/1860 [4:10:18<2:02:19, 11.12s/it]

{'loss': 0.2327, 'learning_rate': 0.0003646408839779006, 'epoch': 1.94}


 66%|██████▌   | 1225/1860 [4:14:57<1:57:48, 11.13s/it]

{'loss': 0.2422, 'learning_rate': 0.000350828729281768, 'epoch': 1.98}


                                                       
 67%|██████▋   | 1240/1860 [4:43:07<1:22:17,  7.96s/it]

{'eval_loss': 0.25436481833457947, 'eval_runtime': 1534.6165, 'eval_samples_per_second': 6.895, 'eval_steps_per_second': 0.108, 'epoch': 2.0}


 67%|██████▋   | 1250/1860 [4:45:02<5:01:31, 29.66s/it]  

{'loss': 0.2237, 'learning_rate': 0.0003370165745856354, 'epoch': 2.02}


 69%|██████▊   | 1275/1860 [4:49:40<1:48:41, 11.15s/it]

{'loss': 0.18, 'learning_rate': 0.0003232044198895028, 'epoch': 2.06}


 70%|██████▉   | 1300/1860 [4:54:18<1:43:21, 11.07s/it]

{'loss': 0.1752, 'learning_rate': 0.00030939226519337016, 'epoch': 2.1}


 71%|███████   | 1325/1860 [4:58:57<1:39:44, 11.19s/it]

{'loss': 0.1906, 'learning_rate': 0.0002955801104972376, 'epoch': 2.14}


 73%|███████▎  | 1350/1860 [5:03:33<1:34:11, 11.08s/it]

{'loss': 0.1895, 'learning_rate': 0.00028176795580110497, 'epoch': 2.18}


 74%|███████▍  | 1375/1860 [5:08:11<1:29:32, 11.08s/it]

{'loss': 0.1948, 'learning_rate': 0.00026795580110497235, 'epoch': 2.22}


 75%|███████▌  | 1400/1860 [5:12:49<1:25:22, 11.14s/it]

{'loss': 0.1825, 'learning_rate': 0.0002541436464088398, 'epoch': 2.26}


 77%|███████▋  | 1425/1860 [5:17:27<1:20:31, 11.11s/it]

{'loss': 0.1697, 'learning_rate': 0.00024033149171270719, 'epoch': 2.3}


 78%|███████▊  | 1450/1860 [5:22:04<1:15:40, 11.07s/it]

{'loss': 0.172, 'learning_rate': 0.0002265193370165746, 'epoch': 2.34}


 79%|███████▉  | 1475/1860 [5:26:42<1:11:11, 11.09s/it]

{'loss': 0.1787, 'learning_rate': 0.000212707182320442, 'epoch': 2.38}


 81%|████████  | 1500/1860 [5:31:20<1:06:58, 11.16s/it]

{'loss': 0.2128, 'learning_rate': 0.0001988950276243094, 'epoch': 2.42}


 82%|████████▏ | 1525/1860 [5:35:58<1:02:00, 11.11s/it]

{'loss': 0.1743, 'learning_rate': 0.0001850828729281768, 'epoch': 2.46}


 83%|████████▎ | 1550/1860 [5:40:36<57:09, 11.06s/it]  

{'loss': 0.1595, 'learning_rate': 0.0001712707182320442, 'epoch': 2.5}


 85%|████████▍ | 1575/1860 [5:45:14<52:48, 11.12s/it]

{'loss': 0.1781, 'learning_rate': 0.0001574585635359116, 'epoch': 2.54}


 86%|████████▌ | 1600/1860 [5:49:52<48:19, 11.15s/it]

{'loss': 0.1841, 'learning_rate': 0.000143646408839779, 'epoch': 2.58}


 87%|████████▋ | 1625/1860 [5:54:29<43:25, 11.09s/it]

{'loss': 0.1701, 'learning_rate': 0.00012983425414364643, 'epoch': 2.62}


 89%|████████▊ | 1650/1860 [5:59:07<38:41, 11.06s/it]

{'loss': 0.1742, 'learning_rate': 0.0001160220994475138, 'epoch': 2.66}


 90%|█████████ | 1675/1860 [6:03:45<34:17, 11.12s/it]

{'loss': 0.1621, 'learning_rate': 0.00010220994475138122, 'epoch': 2.7}


 91%|█████████▏| 1700/1860 [6:08:23<29:37, 11.11s/it]

{'loss': 0.1998, 'learning_rate': 8.839779005524861e-05, 'epoch': 2.74}


 93%|█████████▎| 1725/1860 [6:13:00<24:53, 11.06s/it]

{'loss': 0.1768, 'learning_rate': 7.458563535911603e-05, 'epoch': 2.78}


 94%|█████████▍| 1750/1860 [6:17:38<20:20, 11.10s/it]

{'loss': 0.1939, 'learning_rate': 6.0773480662983424e-05, 'epoch': 2.82}


 95%|█████████▌| 1775/1860 [6:22:16<15:49, 11.17s/it]

{'loss': 0.185, 'learning_rate': 4.696132596685083e-05, 'epoch': 2.86}


 97%|█████████▋| 1800/1860 [6:26:54<11:08, 11.14s/it]

{'loss': 0.1638, 'learning_rate': 3.3149171270718233e-05, 'epoch': 2.9}


 98%|█████████▊| 1825/1860 [6:31:32<06:28, 11.09s/it]

{'loss': 0.1545, 'learning_rate': 1.9337016574585635e-05, 'epoch': 2.94}


 99%|█████████▉| 1850/1860 [6:36:09<01:51, 11.13s/it]

{'loss': 0.1663, 'learning_rate': 5.524861878453038e-06, 'epoch': 2.98}


                                                     
100%|██████████| 1860/1860 [7:03:29<00:00, 13.66s/it]

{'eval_loss': 0.2585037350654602, 'eval_runtime': 1540.5227, 'eval_samples_per_second': 6.868, 'eval_steps_per_second': 0.108, 'epoch': 3.0}
{'train_runtime': 25409.7802, 'train_samples_per_second': 4.68, 'train_steps_per_second': 0.073, 'train_loss': 0.24917860319537502, 'epoch': 3.0}





TrainOutput(global_step=1860, training_loss=0.24917860319537502, metrics={'train_runtime': 25409.7802, 'train_samples_per_second': 4.68, 'train_steps_per_second': 0.073, 'train_loss': 0.24917860319537502, 'epoch': 3.0})
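
The `learning_rate` values in the log follow a simple pattern. The sketch below assumes the Trainer's default linear schedule (the tutorial does not set a scheduler explicitly): a linear ramp to the peak over the 50 warmup steps, then a linear decay to zero at the last step. The function name `linear_schedule_lr` is chosen here for illustration.

```python
PEAK_LR = 1e-3       # learning_rate
WARMUP_STEPS = 50    # warmup_steps
TOTAL_STEPS = 1860   # total optimizer steps shown in the progress bar


def linear_schedule_lr(step: int) -> float:
    """Learning rate after `step` optimizer steps under linear warmup and decay."""
    if step < WARMUP_STEPS:
        # Warmup: ramp linearly from 0 up to the peak
        return PEAK_LR * step / WARMUP_STEPS
    # Decay: fall linearly from the peak to 0 at TOTAL_STEPS
    remaining = max(0, TOTAL_STEPS - step)
    return PEAK_LR * remaining / (TOTAL_STEPS - WARMUP_STEPS)


for step in (25, 50, 75, 1850):
    print(step, linear_schedule_lr(step))
```

Up to floating-point rounding, these values agree with the logged ones: 0.0005 at step 25, 0.001 at step 50, about 0.000986 at step 75, and about 5.5e-06 at step 1850.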

### Save the LoRA model

In [29]:
model.save_pretrained("models/whisper-large-v2-asr-int8")
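
In a fresh session the saved adapter could be reloaded roughly as follows. This is a sketch under assumptions rather than part of the original notebook: `load_lora_whisper` is a hypothetical helper name, the adapter directory is the one used above, and the base model is loaded in int8 the same way as during training. The imports are placed inside the function so that defining it does not by itself require the libraries or a GPU.

```python
def load_lora_whisper(adapter_dir: str = "models/whisper-large-v2-asr-int8"):
    """Rebuild the int8 base model and attach the saved LoRA adapter (hypothetical helper)."""
    # Imported lazily so that defining this helper has no side effects
    from peft import PeftConfig, PeftModel
    from transformers import AutoModelForSpeechSeq2Seq

    # The adapter config records which base checkpoint the adapter was trained on
    peft_config = PeftConfig.from_pretrained(adapter_dir)
    base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        peft_config.base_model_name_or_path,
        load_in_8bit=True,   # same quantization as in the training section
        device_map="auto",
    )
    # Wrap the base model with the trained LoRA weights
    model = PeftModel.from_pretrained(base_model, adapter_dir)
    model.eval()
    return model
```

Calling `load_lora_whisper()` downloads or reuses the base checkpoint and should return a model that can be used in place of the in-memory `model` in the cells that follow.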

### Load the LoRA model into a Pipeline to perform automatic speech recognition

In [30]:
test_audio = "data/audio/test_zh.flac"

In [31]:
from transformers import AutomaticSpeechRecognitionPipeline

pipeline = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)

forced_decoder_ids = processor.get_decoder_prompt_ids(language="chinese", task=task)

The model 'PeftModel' is not supported for . Supported models are ['Pop2PianoForConditionalGeneration', 'SeamlessM4TForSpeechToText', 'SeamlessM4Tv2ForSpeechToText', 'SpeechEncoderDecoderModel', 'Speech2TextForConditionalGeneration', 'SpeechT5ForSpeechToText', 'WhisperForConditionalGeneration', 'Data2VecAudioForCTC', 'HubertForCTC', 'MCTCTForCTC', 'SEWForCTC', 'SEWDForCTC', 'UniSpeechForCTC', 'UniSpeechSatForCTC', 'Wav2Vec2ForCTC', 'Wav2Vec2ConformerForCTC', 'WavLMForCTC'].


In [32]:
with torch.cuda.amp.autocast():
    text = pipeline(test_audio, generate_kwargs={"forced_decoder_ids": forced_decoder_ids}, max_new_tokens=255)["text"]

In [33]:
text

'这是一段测试用于Whisper Large V2模型的自动语音识别测试。'

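#### Homework 1: Add in-training evaluation to the Chinese training run and observe how the train loss and validation loss change.
#### Homework 2: After the LoRA model has finished training, run a full model evaluation on the test set.

## Evaluate the model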

In [35]:
import evaluate

# Word error rate (WER) is a common metric for evaluating ASR models. Load the WER metric from Evaluate.
metric = evaluate.load("wer")

In [36]:
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import gc

eval_dataloader = DataLoader(common_voice["test"], batch_size=8, collate_fn=data_collator)

model.eval()

PeftModel(
  (base_model): LoraModel(
    (model): WhisperForConditionalGeneration(
      (model): WhisperModel(
        (encoder): WhisperEncoder(
          (conv1): Conv1d(80, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
          (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
          (embed_positions): Embedding(1500, 1280)
          (layers): ModuleList(
            (0-31): 32 x WhisperEncoderLayer(
              (self_attn): WhisperSdpaAttention(
                (k_proj): Linear8bitLt(in_features=1280, out_features=1280, bias=False)
                (v_proj): lora.Linear8bitLt(
                  (base_layer): Linear8bitLt(in_features=1280, out_features=1280, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=1280, out_features=8, bias=False)
                  )
            

In [37]:
for step, batch in enumerate(tqdm(eval_dataloader)):
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            generated_tokens = (
                model.generate(
                    input_features=batch["input_features"].to("cuda"),
                    decoder_input_ids=batch["labels"][:, :4].to("cuda"),
                    max_new_tokens=255,
                )
                .cpu()
                .numpy()
            )
            labels = batch["labels"].cpu().numpy()
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
            metric.add_batch(
                predictions=decoded_preds,
                references=decoded_labels,
            )
    del generated_tokens, labels, batch
    gc.collect()

100%|██████████| 1323/1323 [1:55:20<00:00,  5.23s/it] 


In [38]:
wer = 100 * metric.compute()
print(f"{wer=}")

wer=56.029106029106025
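
A WER this high may partly reflect the metric rather than the model. As far as I understand, WER splits text on whitespace, and Chinese sentences contain no spaces, so a whole sentence tends to count as a single word and one wrong character makes the entire sentence wrong. A character error rate (CER) is often a better fit for Chinese. The sketch below is a self-contained illustration with a hand-written edit distance, not the `evaluate` implementation; the hypothesis string is an invented prediction.

```python
from typing import Sequence


def edit_distance(ref: Sequence, hyp: Sequence) -> int:
    """Levenshtein distance between two sequences (insert, delete, substitute)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    return prev[-1]


def error_rate(ref_tokens: Sequence, hyp_tokens: Sequence) -> float:
    """Edit distance normalized by the reference length."""
    return edit_distance(ref_tokens, hyp_tokens) / len(ref_tokens)


reference = "性喜温暖润湿气候且耐寒。"
hypothesis = "性喜温暖湿润气候且耐寒。"  # hypothetical prediction with two characters swapped

wer = error_rate(reference.split(), hypothesis.split())  # whitespace-separated tokens
cer = error_rate(list(reference), list(hypothesis))      # individual characters

print(f"WER={wer:.2f} CER={cer:.2f}")  # WER=1.00 CER=0.17
```

Two wrong characters out of twelve give a CER of about 17%, while the whitespace-based WER for the same pair is 100%. For the evaluation loop above, one option would be to load the character-level metric with `evaluate.load("cer")` instead of `"wer"`.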
