prepare_funasr_data.py 数据集预处理生成了如下格式的 JSONL 样本:
{"key": "2_009_004_split_032", "source": "/root/autodl-tmp/.autodl/data/temp_split_audio/2_009_004_split_032.wav", "source_len": 15.0, "target": "12 438.2562946400315 O2 所有的脸我都觉得都很一般 438.3465203389006 441.7149464300125 O1 然后我我看到你的第一眼我也没啥子感觉 442.344029842798 449.1184760662173 O1 真的可能还有个原因我比较礼貌就是我刚刚认识的人我不会一直盯到别个看那是我一个习惯 449.92298", "target_len": 183}
项目目录
Config.yaml
Train.py
Run.sh
Config.yaml
# 模型配置
model: iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch
model_path: /root/paraformer/model
max_seq_len: 512
device: cuda

# 数据集配置
dataset_conf:
  train:
    manifest_filepath: /root/autodl-tmp/.autodl/data/train.jsonl
    data_type: audio
    tokenizer_conf:
      type: WordpieceTokenizer
      vocab_file: /root/paraformer/model/vocab.txt
    preprocessor_conf:
      type: AudioPreprocessor
      sample_rate: 16000
      normalize: true
  valid:
    manifest_filepath: /root/autodl-tmp/.autodl/data/dev.jsonl
    data_type: audio
    tokenizer_conf:
      type: WordpieceTokenizer
      vocab_file: /root/paraformer/model/vocab.txt
    preprocessor_conf:
      type: AudioPreprocessor
      sample_rate: 16000
      normalize: true
  batch_size: 16
  num_workers: 4
  pin_memory: true
  shuffle: true

# LoRA配置
lora_config:
  r: 8  # LoRA秩
  lora_alpha: 32  # LoRA缩放因子
  target_modules: ["q_proj", "v_proj"]  # 目标模块
  lora_dropout: 0.05
  bias: "none"
  task_type: "CAUSAL_LM"

# 训练配置
train_conf:
  max_epoch: 10
  log_interval: 10
  eval_interval: 1
  save_interval: 1
  gradient_accumulation_steps: 1
  use_fp16: true
  grad_clip: 5.0
  find_unused_parameters: true

# 优化器配置
optim: adam
optim_conf:
  lr: 1e-4
  betas: [0.9, 0.98]
  eps: 1e-9
  weight_decay: 0.01

# 学习率调度器配置
scheduler: warmuplr
scheduler_conf:
  warmup_steps: 2000
  lr_decay_rate: 0.5
  total_steps: -1

# 输出配置
output_dir: /root/paraformer/outputs/lora_paraformer
tensorboard_dir: /root/paraformer/outputs/lora_paraformer/tensorboard
log_level: INFO
seed: 42

# 分布式训练配置
use_fsdp: false
backend: nccl
Train.py
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
import os
import sys
import torch
import torch.nn as nn
import hydra
import logging
import time
import argparse
from io import BytesIO
from contextlib import nullcontext
import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.algorithms.join import Join
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from tensorboardX import SummaryWriter
from funasr.train_utils.average_nbest_models import average_checkpoints
from funasr.register import tables
from funasr.optimizers import optim_classes
from funasr.train_utils.trainer import Trainer
from funasr.schedulers import scheduler_classes
from funasr.train_utils.initialize import initialize
from funasr.download.download_model_from_hub import download_model
from funasr.models.lora.utils import mark_only_lora_as_trainable
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.misc import prepare_model_dir
from funasr.train_utils.model_summary import model_summary
from funasr import AutoModel
import functools
from typing import Optional, Dict, Union, List
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    """Hydra entry point.

    Resolves where the model comes from (an existing local ``model_path`` or a
    hub download), makes sure ``model_conf`` is populated, then hands the merged
    configuration to :func:`main`.
    """
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()
    assert "model" in kwargs

    # Decide whether a hub download is needed.  The config may carry the model
    # either as a plain name string or as a dict with a local "model_path".
    model_config = kwargs["model"]
    need_download = True
    if isinstance(model_config, dict) and "model_path" in model_config:
        model_path = model_config["model_path"]
        if os.path.exists(model_path):
            logging.info(f"使用本地模型路径: {model_path}")
            # A usable local copy exists; no download required.
            if "model_conf" not in kwargs:
                kwargs["model_conf"] = model_config
            need_download = False
        else:
            logging.warning(f"本地模型路径不存在: {model_path}")
            logging.info("尝试从模型中心下载...")
    else:
        logging.info("使用模型名称从模型中心下载...")

    # The original code duplicated the whole download-and-merge sequence in two
    # branches; it is factored into one helper here.
    if need_download and "model_conf" not in kwargs:
        _download_and_merge(kwargs, model_config)

    main(**kwargs)


def _resolve_model_name(kwargs, model_config) -> str:
    """Pick the model identifier used for a hub download.

    Preference order: top-level ``model_name`` key, then ``model_config["model"]``.
    Raises ValueError when neither is available.
    """
    if "model_name" in kwargs:
        model_name = kwargs["model_name"]
        logging.info(f"使用顶级model_name参数: {model_name}")
    elif isinstance(model_config, dict) and "model" in model_config:
        model_name = model_config["model"]
        logging.info(f"从model配置中提取模型名称: {model_name}")
    else:
        raise ValueError("无法获取模型名称,请在配置文件中设置model_name")
    return model_name


def _download_and_merge(kwargs, model_config) -> None:
    """Download the model from the hub and merge the result into ``kwargs``."""
    model_name = _resolve_model_name(kwargs, model_config)
    download_kwargs = {
        "model": model_name,
        "hub": kwargs.get("hub", "ms"),
        "is_training": kwargs.get("is_training", True),
    }
    downloaded_kwargs = download_model(**download_kwargs)
    kwargs.update(downloaded_kwargs)
    # Make sure model_conf is always present for main().
    if "model_conf" not in kwargs:
        kwargs["model_conf"] = downloaded_kwargs.get("model_conf", {})
def main(**kwargs):
    """Build model, optimizer, scheduler and dataloader, then run LoRA fine-tuning.

    Expects the merged hydra config as keyword arguments: ``model``/``model_path``,
    ``dataset_conf``, ``optim``/``optim_conf``, ``scheduler``/``scheduler_conf``,
    ``train_conf``, optional ``lora_config``, distributed flags, etc.
    """
    # set random seed
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
    # open tf32
    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank == 0:
        tables.print()

    # Check if we are using DDP or FSDP
    use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", False)
    if use_ddp or use_fsdp:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
        torch.cuda.set_device(local_rank)

    logging.info("Build model, frontend, tokenizer")
    device = kwargs.get("device", "cuda")
    model_name = kwargs.get("model")
    model_path = kwargs.get("model_path")
    max_seq_len = kwargs.get("max_seq_len", 512)

    # Auto-download when the local model dir is missing or empty.
    # BUGFIX: the original condition mixed `and`/`or` without parentheses, so a
    # None model_path still reached os.path.exists(None) in the second clause.
    if model_path and (not os.path.exists(model_path) or not os.listdir(model_path)):
        logging.info(f"模型路径 {model_path} 不存在或为空,尝试从模型中心下载...")
        try:
            from funasr.download.download_model_from_hub import download_model

            download_kwargs = {
                "model": model_name,
                "hub": "ms",
                "is_training": True,
                "output_dir": model_path,
            }
            download_model(**download_kwargs)
            logging.info(f"模型下载成功,保存到 {model_path}")
        except Exception as e:
            # Best-effort: fall back to loading straight from the hub.
            logging.warning(f"模型自动下载失败: {str(e)},将尝试直接从模型中心加载")

    model_kwargs = {
        "model": model_name,
        "model_path": model_path,
        "device": "cpu",  # build on CPU first, move to GPU after wrapping
        "max_seq_len": max_seq_len,
    }
    logging.info(f"使用模型配置: {model_kwargs}")
    model = AutoModel(**model_kwargs)

    # save config.yaml — only once, on (distributed) rank 0
    if (
        (use_ddp or use_fsdp)
        and dist.get_rank() == 0
        or not (use_ddp or use_fsdp)
        and local_rank == 0
    ):
        prepare_model_dir(**kwargs)

    # parse kwargs: AutoModel carries tokenizer/frontend and the merged config
    kwargs = model.kwargs
    kwargs["device"] = device
    tokenizer = kwargs["tokenizer"]
    frontend = kwargs["frontend"]
    model = model.model
    del kwargs["model"]

    # freeze_param: mark listed parameter prefixes as non-trainable
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
        if "," in freeze_param:
            # NOTE(review): eval() on a config string — trusted config assumed;
            # freeze_param.split(",") would be safer. Kept for compatibility.
            freeze_param = eval(freeze_param)
        if not isinstance(freeze_param, (list, tuple)):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        for t in freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False

    # LoRA configuration
    lora_config = kwargs.get("lora_config", None)
    if lora_config is not None:
        logging.info("Applying LoRA configuration")
        # Add LoRA adapters to the model
        from funasr.models.lora import add_lora

        model = add_lora(model, **lora_config)
        # Mark only LoRA parameters as trainable
        mark_only_lora_as_trainable(model)
        # Log trainable parameters
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        # BUGFIX: the original logged the raw fraction followed by '%'; scale by 100.
        logging.info(
            f"Trainable params: {trainable_params} "
            f"({100.0 * trainable_params / total_params:.4f}% of total)"
        )

    if local_rank == 0:
        logging.info(f"{model_summary(model)}")

    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(
            model,
            device_ids=[local_rank],
            find_unused_parameters=kwargs.get("train_conf", {}).get(
                "find_unused_parameters", False
            ),
        )
    elif use_fsdp:

        def custom_auto_wrap_policy(
            module: nn.Module,
            recurse: bool,
            nonwrapped_numel: int,
            min_num_params: int = int(1e5),
        ) -> bool:
            # Wrap only modules that are large enough and whose parameters all
            # agree on requires_grad (FSDP cannot mix frozen and trainable
            # parameters inside one flattened unit).
            unwrapped_params = sum(p.numel() for p in module.parameters() if not p.requires_grad)
            is_large = unwrapped_params >= min_num_params
            requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
            return is_large and requires_grad_uniform

        # Configure a custom min_num_params
        my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
        torch.cuda.set_device(local_rank)
        model = FSDP(
            model,
            # BUGFIX: the partial above was built but never used; pass it so the
            # configured min_num_params actually takes effect.
            auto_wrap_policy=my_auto_wrap_policy,
            mixed_precision=None,
            device_id=torch.cuda.current_device(),
        )
    else:
        model = model.to(device=kwargs.get("device", "cuda"))
    kwargs["device"] = next(model.parameters()).device

    # optim — only trainable (LoRA / unfrozen) parameters go to the optimizer
    logging.info("Build optim")
    optim = kwargs.get("optim", "adam")
    assert optim in optim_classes
    optim_class = optim_classes.get(optim)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optim = optim_class(trainable_params, **kwargs.get("optim_conf"))

    # scheduler
    logging.info("Build scheduler")
    scheduler = kwargs.get("scheduler", "warmuplr")
    assert scheduler in scheduler_classes
    scheduler_class = scheduler_classes.get(scheduler)
    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))

    # dataset
    logging.info("Build dataloader")
    dataloader_class = tables.dataloader_classes.get(
        kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle")
    )
    dataloader = dataloader_class(**kwargs)

    trainer = Trainer(
        local_rank=local_rank,
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
        device=kwargs["device"],
        output_dir=kwargs.get("output_dir", "./exp"),
        **kwargs.get("train_conf"),
    )
    # FSDP needs the sharded scaler; plain fp16 uses the regular one.
    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
    scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler

    trainer.resume_checkpoint(
        model=model,
        optim=optim,
        scheduler=scheduler,
        scaler=scaler,
    )

    tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
    os.makedirs(tensorboard_dir, exist_ok=True)
    try:
        writer = SummaryWriter(tensorboard_dir)  # if trainer.rank == 0 else None
    except Exception:
        # BUGFIX: a bare `except:` also swallowed KeyboardInterrupt/SystemExit.
        writer = None

    dataloader_tr, dataloader_val = None, None
    for epoch in range(trainer.start_epoch, trainer.max_epoch):
        time1 = time.perf_counter()
        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
            time_slice_i = time.perf_counter()
            dataloader_tr, dataloader_val = dataloader.build_iter(
                epoch, data_split_i=data_split_i, start_step=trainer.start_step
            )
            trainer.train_epoch(
                model=model,
                optim=optim,
                scheduler=scheduler,
                scaler=scaler,
                dataloader_train=dataloader_tr,
                dataloader_val=dataloader_val,
                epoch=epoch,
                writer=writer,
                data_split_i=data_split_i,
                data_split_num=dataloader.data_split_num,
                start_step=trainer.start_step,
            )
            trainer.start_step = 0

            # Release cached CUDA memory between data slices.
            device = next(model.parameters()).device
            if device.type == "cuda":
                with torch.cuda.device(device):
                    torch.cuda.empty_cache()

            time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
            logging.info(
                f"rank: {local_rank}, "
                f"time_escaped_epoch: {time_escaped:.3f} hours, "
                f"estimated to finish {dataloader.data_split_num} data_slices, remaining: {dataloader.data_split_num-data_split_i} slices, {(dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours, "
                f"epoch: {trainer.max_epoch - epoch} epochs, {((trainer.max_epoch - epoch - 1)*dataloader.data_split_num + dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours\n"
            )
        trainer.start_data_split_i = 0

        trainer.validate_epoch(
            model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
        )
        scheduler.step()
        trainer.step_in_epoch = 0

        # Save LoRA adapters separately if configured
        if lora_config is not None and (trainer.rank == 0 or trainer.rank == -1):
            lora_save_path = os.path.join(trainer.output_dir, f"epoch_{epoch+1}lora_adapters.bin")
            # For DDP, unwrap to the base model first.
            base_model = model.module if use_ddp else model
            # CONSISTENCY FIX: the two branches previously filtered on different
            # substrings ("lora" vs "lora_"); use the "lora_" parameter prefix.
            lora_state_dict = {k: v for k, v in base_model.state_dict().items() if "lora_" in k}
            torch.save(lora_state_dict, lora_save_path)
            logging.info(f"Saved LoRA adapters to {lora_save_path}")

        trainer.save_checkpoint(
            epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
        )

        time2 = time.perf_counter()
        time_escaped = (time2 - time1) / 3600.0
        logging.info(
            f"rank: {local_rank}, "
            f"time_escaped_epoch: {time_escaped:.3f} hours, "
            f"estimated to finish {trainer.max_epoch} "
            f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n"
        )
        trainer.train_acc_avg = 0.0
        trainer.train_loss_avg = 0.0

    if trainer.rank == 0:
        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)

    trainer.close()
# BUGFIX: the paste stripped the dunder underscores; `if name == "main":` would
# raise NameError at import time.
if __name__ == "__main__":
    main_hydra()
Run.sh
#!/bin/bash
set -e
# ==============================================
# Paraformer LoRA fine-tuning one-click launcher
# ==============================================
# (The paste stripped the leading '#' from these header lines; under `set -e`
#  a bare header line would run as a command and abort the script.)

# Configuration parameters
DATA_DIR="/root/autodl-tmp/.autodl/data"
AUDIO_DIR="${DATA_DIR}/audio_files"
TEXT_DIR="${DATA_DIR}/text_files"
OUTPUT_DIR="/root/paraformer/outputs/lora_paraformer"
CONFIG_PATH="/root/paraformer/config.yaml"
TRAIN_SCRIPT="/root/paraformer/train.py"
PREPARE_SCRIPT="/root/paraformer/prepare_funasr_data.py"
MODEL_DIR="/root/paraformer/model"
MODEL_NAME="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"

# Terminal colors
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # 无颜色
# Check that a required command exists; exit otherwise.
check_command() {
    if ! command -v "$1" &> /dev/null; then
        echo -e "${RED}错误: $1 命令未找到,请安装${NC}"
        exit 1
    fi
}
# Check that a directory exists; exit otherwise.
check_directory() {
    if [ ! -d "$1" ]; then
        echo -e "${RED}错误: 目录 $1 不存在${NC}"
        exit 1
    fi
}
# Check that a regular file exists; exit otherwise.
check_file() {
    if [ ! -f "$1" ]; then
        echo -e "${RED}错误: 文件 $1 不存在${NC}"
        exit 1
    fi
}
# Check the Python environment and install missing packages.
check_python_env() {
    echo -e "\n${GREEN}=== 检查Python环境 ==="
    check_command "python3"
    check_command "pip3"
    # BUGFIX: pip package names and import names differ ("hydra-core" imports
    # as "hydra"); `python3 -c "import hydra-core"` is a syntax error, so the
    # check always failed and reinstalled on every run. Entries are "pip:module".
    required_packages=("torch:torch" "soundfile:soundfile" "tqdm:tqdm" "hydra-core:hydra" "tensorboardX:tensorboardX" "funasr:funasr")
    for entry in "${required_packages[@]}"; do
        pkg="${entry%%:*}"
        mod="${entry##*:}"
        if ! python3 -c "import $mod" &> /dev/null; then
            echo -e "${YELLOW}警告: $pkg 未安装,正在安装...${NC}"
            pip3 install "$pkg"
        fi
    done
    echo -e "${GREEN}Python环境检查完成${NC}"
}
# Check CUDA availability and count GPUs (NUM_GPUS is used later by start_training).
check_cuda_env() {
    echo -e "\n${GREEN}=== 检查CUDA环境 ==="
    if ! python3 -c "import torch; print(torch.cuda.is_available())" | grep -q "True"; then
        echo -e "${RED}错误: CUDA不可用,请检查您的GPU环境${NC}"
        exit 1
    fi
    NUM_GPUS=$(python3 -c "import torch; print(torch.cuda.device_count())")
    echo -e "检测到 $NUM_GPUS 个GPU"
    echo -e "${GREEN}CUDA环境检查完成${NC}"
}
# Verify the expected data directory layout and count audio/text files.
check_data_dir() {
    echo -e "\n${GREEN}=== 检查数据目录结构 ==="
    check_directory "$DATA_DIR"
    check_directory "$AUDIO_DIR"
    check_directory "$TEXT_DIR"
    # BUGFIX: find's grouping parens must be escaped and the name patterns need
    # '*' wildcards; `-name ".wav"` matches only a file literally named ".wav",
    # so the count was always 0.
    audio_files=$(find "$AUDIO_DIR" -type f \( -name "*.wav" -o -name "*.flac" -o -name "*.mp3" -o -name "*.m4a" -o -name "*.ogg" \) | wc -l)
    if [ "$audio_files" -eq 0 ]; then
        echo -e "${RED}错误: 在 $AUDIO_DIR 目录下未找到音频文件${NC}"
        echo -e "支持的音频格式: .wav, .flac, .mp3, .m4a, .ogg"
        exit 1
    fi
    echo -e "找到音频文件数量: $audio_files"
    # Text files (same wildcard fix).
    text_files=$(find "$TEXT_DIR" -type f -name "*.txt" | wc -l)
    if [ "$text_files" -eq 0 ]; then
        echo -e "${RED}错误: 在 $TEXT_DIR 目录下未找到文本文件(.txt)${NC}"
        exit 1
    fi
    echo -e "找到文本文件数量: $text_files"
    echo -e "${GREEN}数据目录结构检查完成${NC}"
}
# Verify config/scripts exist; create the model dir if missing (auto-download later).
check_config() {
    echo -e "\n${GREEN}=== 检查配置文件 ==="
    check_file "$CONFIG_PATH"
    check_file "$TRAIN_SCRIPT"
    check_file "$PREPARE_SCRIPT"
    if [ ! -d "$MODEL_DIR" ]; then
        echo -e "${YELLOW}警告: 模型目录 $MODEL_DIR 不存在,将创建并在训练时自动下载${NC}"
        mkdir -p "$MODEL_DIR"
    fi
    echo -e "${GREEN}配置文件检查完成${NC}"
}
# Run the data-preparation script and verify the train/dev/test JSONL outputs.
prepare_data() {
    echo -e "\n${GREEN}=== 开始数据准备 ==="
    echo -e "数据目录: $DATA_DIR"
    echo -e "音频目录: $AUDIO_DIR"
    echo -e "文本目录: $TEXT_DIR"
    # BUGFIX: the '\' line continuations were stripped by the paste, so each
    # flag ran as a separate (failing) command. --split generates train/dev/test.
    python3 "$PREPARE_SCRIPT" \
        --audio-dir "$AUDIO_DIR" \
        --text-dir "$TEXT_DIR" \
        --output-dir "$DATA_DIR" \
        --split
    # Verify the generated manifests.
    TRAIN_FILE="${DATA_DIR}/train.jsonl"
    DEV_FILE="${DATA_DIR}/dev.jsonl"
    TEST_FILE="${DATA_DIR}/test.jsonl"
    if [ ! -f "$TRAIN_FILE" ] || [ ! -f "$DEV_FILE" ] || [ ! -f "$TEST_FILE" ]; then
        echo -e "${RED}错误: 数据准备失败,未生成必要的JSONL文件${NC}"
        exit 1
    fi
    train_samples=$(wc -l < "$TRAIN_FILE")
    dev_samples=$(wc -l < "$DEV_FILE")
    test_samples=$(wc -l < "$TEST_FILE")
    echo -e "训练集样本数: $train_samples"
    echo -e "开发集样本数: $dev_samples"
    echo -e "测试集样本数: $test_samples"
    echo -e "${GREEN}数据准备完成${NC}"
}
# Launch training: torchrun on multi-GPU, plain python3 on a single GPU.
start_training() {
    echo -e "\n${GREEN}=== 开始训练 ==="
    echo -e "输出目录: $OUTPUT_DIR"
    echo -e "配置文件: $CONFIG_PATH"
    echo -e "模型名称: $MODEL_NAME"
    echo -e "模型目录: $MODEL_DIR"
    mkdir -p "$OUTPUT_DIR"
    # Single-line command string (paths must not contain spaces, since it is
    # expanded unquoted below). The original multi-line form had lost its '\'
    # continuations and was broken.
    TRAIN_CMD="python3 $TRAIN_SCRIPT --config-path $(dirname "$CONFIG_PATH") --config-name $(basename "$CONFIG_PATH" .yaml) output_dir=$OUTPUT_DIR model.device=cuda"
    echo -e "训练命令: $TRAIN_CMD"
    # NUM_GPUS is set by check_cuda_env; default to 1 if it was skipped.
    if [ "${NUM_GPUS:-1}" -gt 1 ]; then
        echo -e "使用多GPU分布式训练($NUM_GPUS 个GPU)"
        torchrun --nproc_per_node="$NUM_GPUS" "$TRAIN_SCRIPT" \
            --config-path "$(dirname "$CONFIG_PATH")" \
            --config-name "$(basename "$CONFIG_PATH" .yaml)" \
            output_dir="$OUTPUT_DIR" \
            model.device=cuda
    else
        echo -e "使用单GPU训练"
        $TRAIN_CMD
    fi
}
# Post-training summary: locate saved LoRA adapters and print usage hints.
finish_training() {
    echo -e "\n${GREEN}=== 训练完成 ==="
    echo -e "训练结果保存在: ${OUTPUT_DIR}"
    # BUGFIX: train.py saves files named "epoch_<N>lora_adapters.bin"; the glob
    # lost its '*' in the paste, so nothing ever matched.
    lora_files=$(find "$OUTPUT_DIR" -name "*lora_adapters.bin" | wc -l)
    if [ "$lora_files" -eq 0 ]; then
        echo -e "${YELLOW}警告: 未找到LoRA适配器文件,训练可能未完成${NC}"
    else
        echo -e "生成的LoRA适配器文件: ${OUTPUT_DIR}/epoch_*lora_adapters.bin"
    fi
    echo -e "\n${GREEN}训练流程结束${NC}"
    echo -e "使用方法:"
    echo -e "1. 应用LoRA适配器: 将生成的lora_adapters.bin文件与预训练模型结合使用"
    echo -e "2. 推理: 使用FunASR的推理接口加载微调后的模型"
    echo -e "3. 评估: 使用验证集评估模型性能"
}
# Main flow: run all checks, ask for confirmation, then prepare data and train.
main() {
    clear
    echo -e "${GREEN}========================================"
    echo -e "     Paraformer LoRA 微调"
    echo -e "========================================"
    echo -e " 数据目录: $DATA_DIR"
    echo -e " 输出目录: $OUTPUT_DIR"
    echo -e "========================================${NC}"
    # Environment / layout checks
    check_python_env
    check_cuda_env
    check_data_dir
    check_config
    # Interactive confirmation before the long-running steps
    echo -e "\n${YELLOW}所有检查已通过,准备开始训练..."
    read -p "是否继续?(y/n) " -n 1 -r
    echo -e "\n${NC}"
    if [[ $REPLY =~ ^[Yy]$ ]]; then
        prepare_data
        start_training
        finish_training
    else
        echo -e "${YELLOW}训练已取消${NC}"
        exit 0
    fi
}
# Print usage information.
show_help() {
    echo -e "${GREEN}Paraformer LoRA 微调一键启动脚本${NC}"
    echo -e "用法: $0 [选项]"
    echo -e "选项:"
    echo -e "  --help     显示帮助信息"
    echo -e "  --prepare  仅执行数据准备"
    echo -e "  --train    仅执行训练(需先准备数据)"
    echo -e "  --check    仅执行环境检查"
}
# Dispatch on the first command-line option; no option runs the full flow.
parse_args() {
    case "$1" in
        --help)
            show_help
            exit 0
            ;;
        --prepare)
            check_python_env
            check_data_dir
            check_config
            prepare_data
            exit 0
            ;;
        --train)
            check_python_env
            check_cuda_env
            check_config
            start_training
            finish_training
            exit 0
            ;;
        --check)
            check_python_env
            check_cuda_env
            check_data_dir
            check_config
            echo -e "\n${GREEN}所有检查通过!${NC}"
            exit 0
            ;;
        "")
            main
            ;;
        *)
            echo -e "${RED}错误: 未知选项 $1${NC}"
            show_help
            exit 1
            ;;
    esac
}
# Entry point (the '#' of this header comment was stripped by the paste).
parse_args "$@"
Prepare_funasr_data.py数据集预处理生成了
{"key": "2_009_004_split_032", "source": "/root/autodl-tmp/.autodl/data/temp_split_audio/2_009_004_split_032.wav", "source_len": 15.0, "target": "12 438.2562946400315 O2 所有的脸我都觉得都很一般 438.3465203389006 441.7149464300125 O1 然后我我看到你的第一眼我也没啥子感觉 442.344029842798 449.1184760662173 O1 真的可能还有个原因我比较礼貌就是我刚刚认识的人我不会一直盯到别个看那是我一个习惯 449.92298", "target_len": 183}
项目目录
Config.yaml
Train.py
Run.sh
Config.yaml
模型配置
model: iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch
model_path: /root/paraformer/model
max_seq_len: 512
device: cuda
数据集配置
dataset_conf:
train:
manifest_filepath: /root/autodl-tmp/.autodl/data/train.jsonl
data_type: audio
tokenizer_conf:
type: WordpieceTokenizer
vocab_file: /root/paraformer/model/vocab.txt
preprocessor_conf:
type: AudioPreprocessor
sample_rate: 16000
normalize: true
valid:
manifest_filepath: /root/autodl-tmp/.autodl/data/dev.jsonl
data_type: audio
tokenizer_conf:
type: WordpieceTokenizer
vocab_file: /root/paraformer/model/vocab.txt
preprocessor_conf:
type: AudioPreprocessor
sample_rate: 16000
normalize: true
batch_size: 16
num_workers: 4
pin_memory: true
shuffle: true
LoRA配置
lora_config:
r: 8 # LoRA秩
lora_alpha: 32 # LoRA缩放因子
target_modules: ["q_proj", "v_proj"] # 目标模块
lora_dropout: 0.05
bias: "none"
task_type: "CAUSAL_LM"
训练配置
train_conf:
max_epoch: 10
log_interval: 10
eval_interval: 1
save_interval: 1
gradient_accumulation_steps: 1
use_fp16: true
grad_clip: 5.0
find_unused_parameters: true
优化器配置
optim: adam
optim_conf:
lr: 1e-4
betas: [0.9, 0.98]
eps: 1e-9
weight_decay: 0.01
学习率调度器配置
scheduler: warmuplr
scheduler_conf:
warmup_steps: 2000
lr_decay_rate: 0.5
total_steps: -1
输出配置
output_dir: /root/paraformer/outputs/lora_paraformer
tensorboard_dir: /root/paraformer/outputs/lora_paraformer/tensorboard
log_level: INFO
seed: 42
分布式训练配置
use_fsdp: false
backend: nccl
Train.py
#!/usr/bin/env python3
-- encoding: utf-8 --
import os
import sys
import torch
import torch.nn as nn
import hydra
import logging
import time
import argparse
from io import BytesIO
from contextlib import nullcontext
import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.algorithms.join import Join
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from tensorboardX import SummaryWriter
from funasr.train_utils.average_nbest_models import average_checkpoints
from funasr.register import tables
from funasr.optimizers import optim_classes
from funasr.train_utils.trainer import Trainer
from funasr.schedulers import scheduler_classes
from funasr.train_utils.initialize import initialize
from funasr.download.download_model_from_hub import download_model
from funasr.models.lora.utils import mark_only_lora_as_trainable
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.misc import prepare_model_dir
from funasr.train_utils.model_summary import model_summary
from funasr import AutoModel
import functools
from typing import Optional, Dict, Union, List
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    """Hydra entry point.

    Resolves where the model comes from (an existing local ``model_path`` or a
    hub download), makes sure ``model_conf`` is populated, then hands the merged
    configuration to :func:`main`.
    """
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()
    assert "model" in kwargs

    # Decide whether a hub download is needed.  The config may carry the model
    # either as a plain name string or as a dict with a local "model_path".
    model_config = kwargs["model"]
    need_download = True
    if isinstance(model_config, dict) and "model_path" in model_config:
        model_path = model_config["model_path"]
        if os.path.exists(model_path):
            logging.info(f"使用本地模型路径: {model_path}")
            # A usable local copy exists; no download required.
            if "model_conf" not in kwargs:
                kwargs["model_conf"] = model_config
            need_download = False
        else:
            logging.warning(f"本地模型路径不存在: {model_path}")
            logging.info("尝试从模型中心下载...")
    else:
        logging.info("使用模型名称从模型中心下载...")

    # The original code duplicated the whole download-and-merge sequence in two
    # branches; it is factored into one helper here.
    if need_download and "model_conf" not in kwargs:
        _download_and_merge(kwargs, model_config)

    main(**kwargs)


def _resolve_model_name(kwargs, model_config) -> str:
    """Pick the model identifier used for a hub download.

    Preference order: top-level ``model_name`` key, then ``model_config["model"]``.
    Raises ValueError when neither is available.
    """
    if "model_name" in kwargs:
        model_name = kwargs["model_name"]
        logging.info(f"使用顶级model_name参数: {model_name}")
    elif isinstance(model_config, dict) and "model" in model_config:
        model_name = model_config["model"]
        logging.info(f"从model配置中提取模型名称: {model_name}")
    else:
        raise ValueError("无法获取模型名称,请在配置文件中设置model_name")
    return model_name


def _download_and_merge(kwargs, model_config) -> None:
    """Download the model from the hub and merge the result into ``kwargs``."""
    model_name = _resolve_model_name(kwargs, model_config)
    download_kwargs = {
        "model": model_name,
        "hub": kwargs.get("hub", "ms"),
        "is_training": kwargs.get("is_training", True),
    }
    downloaded_kwargs = download_model(**download_kwargs)
    kwargs.update(downloaded_kwargs)
    # Make sure model_conf is always present for main().
    if "model_conf" not in kwargs:
        kwargs["model_conf"] = downloaded_kwargs.get("model_conf", {})
def main(**kwargs):
# set random seed
set_all_random_seed(kwargs.get("seed", 0))
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
# open tf32
torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if local_rank == 0:
tables.print()
# Check if we are using DDP or FSDP
use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
use_fsdp = kwargs.get("use_fsdp", False)
if use_ddp or use_fsdp:
dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
torch.cuda.set_device(local_rank)
logging.info("Build model, frontend, tokenizer")
device = kwargs.get("device", "cuda")
# 准备模型配置
model_name = kwargs.get("model")
model_path = kwargs.get("model_path")
max_seq_len = kwargs.get("max_seq_len", 512)
# 检查模型路径是否存在,如果不存在则自动下载
if model_path and not os.path.exists(model_path) or (os.path.exists(model_path) and not os.listdir(model_path)):
logging.info(f"模型路径 {model_path} 不存在或为空,尝试从模型中心下载...")
try:
from funasr.download.download_model_from_hub import download_model
download_kwargs = {
"model": model_name,
"hub": "ms",
"is_training": True,
"output_dir": model_path
}
download_model(** download_kwargs)
logging.info(f"模型下载成功,保存到 {model_path}")
except Exception as e:
logging.warning(f"模型自动下载失败: {str(e)},将尝试直接从模型中心加载")
# 构建模型配置
model_kwargs = {
"model": model_name,
"model_path": model_path,
"device": "cpu", # 先在CPU上构建,然后再移动到GPU
"max_seq_len": max_seq_len
}
logging.info(f"使用模型配置: {model_kwargs}")
model = AutoModel(**model_kwargs)
# save config.yaml
if (
(use_ddp or use_fsdp)
and dist.get_rank() == 0
or not (use_ddp or use_fsdp)
and local_rank == 0
):
prepare_model_dir(**kwargs)
# parse kwargs
kwargs = model.kwargs
kwargs["device"] = device
tokenizer = kwargs["tokenizer"]
frontend = kwargs["frontend"]
model = model.model
del kwargs["model"]
# freeze_param
freeze_param = kwargs.get("freeze_param", None)
if freeze_param is not None:
if "," in freeze_param:
freeze_param = eval(freeze_param)
if not isinstance(freeze_param, (list, tuple)):
freeze_param = (freeze_param,)
logging.info("freeze_param is not None: %s", freeze_param)
for t in freeze_param:
for k, p in model.named_parameters():
if k.startswith(t + ".") or k == t:
logging.info(f"Setting {k}.requires_grad = False")
p.requires_grad = False
# LoRA configuration
lora_config = kwargs.get("lora_config", None)
if lora_config is not None:
logging.info("Applying LoRA configuration")
# Add LoRA adapters to the model
from funasr.models.lora import add_lora
model = add_lora(model,** lora_config)
# Mark only LoRA parameters as trainable
mark_only_lora_as_trainable(model)
# Log trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
logging.info(f"Trainable params: {trainable_params} ({trainable_params/total_params:.4f}% of total)")
if local_rank == 0:
logging.info(f"{model_summary(model)}")
if use_ddp:
model = model.cuda(local_rank)
model = DDP(
model,
device_ids=[local_rank],
find_unused_parameters=kwargs.get("train_conf", {}).get(
"find_unused_parameters", False
),
)
elif use_fsdp:
def custom_auto_wrap_policy(
module: nn.Module,
recurse: bool,
nonwrapped_numel: int,
min_num_params: int = int(1e5),
) -> bool:
# Calculate number of parameters in the module
unwrapped_params = sum(p.numel() for p in module.parameters() if not p.requires_grad)
is_large = unwrapped_params >= min_num_params
requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
return is_large and requires_grad_uniform
# Configure a custom
min_num_paramsmy_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
torch.cuda.set_device(local_rank)
model = FSDP(
model,
auto_wrap_policy=custom_auto_wrap_policy,
mixed_precision=None,
device_id=torch.cuda.current_device(),
)
else:
model = model.to(device=kwargs.get("device", "cuda"))
kwargs["device"] = next(model.parameters()).device
# optim
logging.info("Build optim")
optim = kwargs.get("optim", "adam")
assert optim in optim_classes
optim_class = optim_classes.get(optim)
# Filter parameters for optimizer (only trainable ones)
trainable_params = [p for p in model.parameters() if p.requires_grad]
optim = optim_class(trainable_params,** kwargs.get("optim_conf"))
# scheduler
logging.info("Build scheduler")
scheduler = kwargs.get("scheduler", "warmuplr")
assert scheduler in scheduler_classes
scheduler_class = scheduler_classes.get(scheduler)
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
# dataset
logging.info("Build dataloader")
dataloader_class = tables.dataloader_classes.get(
kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle")
)
dataloader = dataloader_class(**kwargs)
trainer = Trainer(
local_rank=local_rank,
use_ddp=use_ddp,
use_fsdp=use_fsdp,
device=kwargs["device"],
output_dir=kwargs.get("output_dir", "./exp"),** kwargs.get("train_conf"),
)
scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
trainer.resume_checkpoint(
model=model,
optim=optim,
scheduler=scheduler,
scaler=scaler,
)
tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
os.makedirs(tensorboard_dir, exist_ok=True)
try:
writer = SummaryWriter(tensorboard_dir) # if trainer.rank == 0 else None
except:
writer = None
dataloader_tr, dataloader_val = None, None
for epoch in range(trainer.start_epoch, trainer.max_epoch):
time1 = time.perf_counter()
for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
time_slice_i = time.perf_counter()
dataloader_tr, dataloader_val = dataloader.build_iter(
epoch, data_split_i=data_split_i, start_step=trainer.start_step
)
trainer.train_epoch(
model=model,
optim=optim,
scheduler=scheduler,
scaler=scaler,
dataloader_train=dataloader_tr,
dataloader_val=dataloader_val,
epoch=epoch,
writer=writer,
data_split_i=data_split_i,
data_split_num=dataloader.data_split_num,
start_step=trainer.start_step,
)
trainer.start_step = 0
device = next(model.parameters()).device
if device.type == "cuda":
with torch.cuda.device(device):
torch.cuda.empty_cache()
time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
logging.info(
f"rank: {local_rank}, "
f"time_escaped_epoch: {time_escaped:.3f} hours, "
f"estimated to finish {dataloader.data_split_num} data_slices, remaining: {dataloader.data_split_num-data_split_i} slices, {(dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours, "
f"epoch: {trainer.max_epoch - epoch} epochs, {((trainer.max_epoch - epoch - 1)*dataloader.data_split_num + dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours\n"
)
trainer.start_data_split_i = 0
trainer.validate_epoch(
model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
)
scheduler.step()
trainer.step_in_epoch = 0
# Save LoRA adapters separately if configured
if lora_config is not None and (trainer.rank == 0 or trainer.rank == -1):
lora_save_path = os.path.join(trainer.output_dir, f"epoch_{epoch+1}lora_adapters.bin")
# For DDP, need to get the base model
if use_ddp:
lora_state_dict = {k: v for k, v in model.module.state_dict().items() if "lora" in k}
else:
lora_state_dict = {k: v for k, v in model.state_dict().items() if "lora_" in k}
torch.save(lora_state_dict, lora_save_path)
logging.info(f"Saved LoRA adapters to {lora_save_path}")
trainer.save_checkpoint(
epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
)
time2 = time.perf_counter()
time_escaped = (time2 - time1) / 3600.0
logging.info(
f"rank: {local_rank}, "
f"time_escaped_epoch: {time_escaped:.3f} hours, "
f"estimated to finish {trainer.max_epoch} "
f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n"
)
trainer.train_acc_avg = 0.0
trainer.train_loss_avg = 0.0
if trainer.rank == 0:
average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
trainer.close()
# Script entry point: delegate to the Hydra-decorated main function.
# (The paste mangled the dunder names; as written, bare `name` would
# raise NameError at import time.)
if __name__ == "__main__":
    main_hydra()
Run.sh
#!/bin/bash
set -e
# ==============================================
# One-click launcher for Paraformer LoRA fine-tuning
# ==============================================
# Configuration parameters (paths and model identifier)
DATA_DIR="/root/autodl-tmp/.autodl/data"
AUDIO_DIR="${DATA_DIR}/audio_files"
TEXT_DIR="${DATA_DIR}/text_files"
OUTPUT_DIR="/root/paraformer/outputs/lora_paraformer"
CONFIG_PATH="/root/paraformer/config.yaml"
TRAIN_SCRIPT="/root/paraformer/train.py"
PREPARE_SCRIPT="/root/paraformer/prepare_funasr_data.py"
MODEL_DIR="/root/paraformer/model"
MODEL_NAME="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
# ANSI color definitions for terminal output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # no color (reset)
# Check that a command exists on PATH; abort with an error otherwise.
check_command() {
    if ! command -v "$1" &> /dev/null; then
        echo -e "${RED}错误: $1 命令未找到,请安装${NC}"
        exit 1
    fi
}
# Check that a directory exists; abort with an error otherwise.
check_directory() {
    if [ ! -d "$1" ]; then
        echo -e "${RED}错误: 目录 $1 不存在${NC}"
        exit 1
    fi
}
# Check that a file exists; abort with an error otherwise.
check_file() {
    if [ ! -f "$1" ]; then
        echo -e "${RED}错误: 文件 $1 不存在${NC}"
        exit 1
    fi
}
# Verify the Python environment: interpreter, pip, and required packages.
# Missing packages are installed on the fly.
check_python_env() {
    echo -e "\n${GREEN}=== 检查Python环境 ==="
    check_command "python3"
    check_command "pip3"
    # Check required Python packages
    required_packages=("torch" "soundfile" "tqdm" "hydra-core" "tensorboardX" "funasr")
    for pkg in "${required_packages[@]}"; do
        # pip package name may differ from the import name
        # (e.g. "hydra-core" is imported as "hydra"); `import hydra-core`
        # is a SyntaxError and would force a reinstall on every run.
        import_name=${pkg%-core}
        if ! python3 -c "import ${import_name}" &> /dev/null; then
            echo -e "${YELLOW}警告: $pkg 未安装,正在安装...${NC}"
            pip3 install $pkg
        fi
    done
    echo -e "${GREEN}Python环境检查完成${NC}"
}
# Verify CUDA availability through PyTorch and record the GPU count
# in the global NUM_GPUS (used later by start_training).
check_cuda_env() {
    echo -e "\n${GREEN}=== 检查CUDA环境 ==="
    if ! python3 -c "import torch; print(torch.cuda.is_available())" | grep -q "True"; then
        echo -e "${RED}错误: CUDA不可用,请检查您的GPU环境${NC}"
        exit 1
    fi
    NUM_GPUS=$(python3 -c "import torch; print(torch.cuda.device_count())")
    echo -e "检测到 $NUM_GPUS 个GPU"
    echo -e "${GREEN}CUDA环境检查完成${NC}"
}
# Verify the data directory layout and that audio/text files exist.
check_data_dir() {
    echo -e "\n${GREEN}=== 检查数据目录结构 ==="
    check_directory "$DATA_DIR"
    check_directory "$AUDIO_DIR"
    check_directory "$TEXT_DIR"
    # Count audio files; the -name patterns need globs and the OR group
    # must be escaped for find, otherwise nothing ever matches.
    audio_files=$(find "$AUDIO_DIR" -type f \( -name "*.wav" -o -name "*.flac" -o -name "*.mp3" -o -name "*.m4a" -o -name "*.ogg" \) | wc -l)
    if [ "$audio_files" -eq 0 ]; then
        echo -e "${RED}错误: 在 $AUDIO_DIR 目录下未找到音频文件${NC}"
        echo -e "支持的音频格式: .wav, .flac, .mp3, .m4a, .ogg"
        exit 1
    fi
    echo -e "找到音频文件数量: $audio_files"
    # Count transcript text files
    text_files=$(find "$TEXT_DIR" -type f -name "*.txt" | wc -l)
    if [ "$text_files" -eq 0 ]; then
        echo -e "${RED}错误: 在 $TEXT_DIR 目录下未找到文本文件(.txt)${NC}"
        exit 1
    fi
    echo -e "找到文本文件数量: $text_files"
    echo -e "${GREEN}数据目录结构检查完成${NC}"
}
# Verify config and script files exist; create the model directory if
# missing (the model is auto-downloaded at training time).
check_config() {
    echo -e "\n${GREEN}=== 检查配置文件 ==="
    check_file "$CONFIG_PATH"
    check_file "$TRAIN_SCRIPT"
    check_file "$PREPARE_SCRIPT"
    # Check the model directory
    if [ ! -d "$MODEL_DIR" ]; then
        echo -e "${YELLOW}警告: 模型目录 $MODEL_DIR 不存在,将创建并在训练时自动下载${NC}"
        mkdir -p "$MODEL_DIR"
    fi
    echo -e "${GREEN}配置文件检查完成${NC}"
}
# Generate train/dev/test JSONL manifests from raw audio + text via the
# preparation script, then verify the expected output files exist.
prepare_data() {
    echo -e "\n${GREEN}=== 开始数据准备 ==="
    echo -e "数据目录: $DATA_DIR"
    echo -e "音频目录: $AUDIO_DIR"
    echo -e "文本目录: $TEXT_DIR"
    # Invoke the preparation script; --split produces train/dev/test sets.
    # Line continuations are required, otherwise each flag runs as its
    # own (broken) command.
    python3 "$PREPARE_SCRIPT" \
        --audio-dir "$AUDIO_DIR" \
        --text-dir "$TEXT_DIR" \
        --output-dir "$DATA_DIR" \
        --split
    # Verify the generated files
    TRAIN_FILE="${DATA_DIR}/train.jsonl"
    DEV_FILE="${DATA_DIR}/dev.jsonl"
    TEST_FILE="${DATA_DIR}/test.jsonl"
    if [ ! -f "$TRAIN_FILE" ] || [ ! -f "$DEV_FILE" ] || [ ! -f "$TEST_FILE" ]; then
        echo -e "${RED}错误: 数据准备失败,未生成必要的JSONL文件${NC}"
        exit 1
    fi
    train_samples=$(wc -l < "$TRAIN_FILE")
    dev_samples=$(wc -l < "$DEV_FILE")
    test_samples=$(wc -l < "$TEST_FILE")
    echo -e "训练集样本数: $train_samples"
    echo -e "开发集样本数: $dev_samples"
    echo -e "测试集样本数: $test_samples"
    echo -e "${GREEN}数据准备完成${NC}"
}
# Launch training: torchrun for multi-GPU, plain python3 for single GPU.
# Hydra-style overrides (key=value) are appended after the config flags.
start_training() {
    echo -e "\n${GREEN}=== 开始训练 ==="
    echo -e "输出目录: $OUTPUT_DIR"
    echo -e "配置文件: $CONFIG_PATH"
    echo -e "模型名称: $MODEL_NAME"
    echo -e "模型目录: $MODEL_DIR"
    # Create the output directory
    mkdir -p "$OUTPUT_DIR"
    # Build the single-GPU command on one line so that unquoted
    # expansion of $TRAIN_CMD below does not leave stray backslashes.
    TRAIN_CMD="python3 $TRAIN_SCRIPT --config-path $(dirname "$CONFIG_PATH") --config-name $(basename "$CONFIG_PATH" .yaml) output_dir=$OUTPUT_DIR model.device=cuda"
    echo -e "训练命令: $TRAIN_CMD"
    # NUM_GPUS is set by check_cuda_env; default to 1 if it was skipped
    if [ "${NUM_GPUS:-1}" -gt 1 ]; then
        echo -e "使用多GPU分布式训练($NUM_GPUS 个GPU)"
        torchrun --nproc_per_node=$NUM_GPUS $TRAIN_SCRIPT \
            --config-path $(dirname "$CONFIG_PATH") \
            --config-name $(basename "$CONFIG_PATH" .yaml) \
            output_dir=$OUTPUT_DIR \
            model.device=cuda
    else
        echo -e "使用单GPU训练"
        $TRAIN_CMD
    fi
}
# Summarize training results and verify LoRA adapter files were written.
finish_training() {
    echo -e "\n${GREEN}=== 训练完成 ==="
    echo -e "训练结果保存在: ${OUTPUT_DIR}"
    # train.py saves adapters per epoch as epoch_<N>lora_adapters.bin,
    # so match with a leading glob — an exact-name find never matches.
    lora_files=$(find "$OUTPUT_DIR" -name "*lora_adapters.bin" | wc -l)
    if [ "$lora_files" -eq 0 ]; then
        echo -e "${YELLOW}警告: 未找到LoRA适配器文件,训练可能未完成${NC}"
    else
        echo -e "生成的LoRA适配器文件: ${OUTPUT_DIR}/epoch_*lora_adapters.bin"
    fi
    echo -e "\n${GREEN}训练流程结束${NC}"
    echo -e "使用方法:"
    echo -e "1. 应用LoRA适配器: 将生成的lora_adapters.bin文件与预训练模型结合使用"
    echo -e "2. 推理: 使用FunASR的推理接口加载微调后的模型"
    echo -e "3. 评估: 使用验证集评估模型性能"
}
# Main interactive flow: run all checks, ask for confirmation, then
# prepare the data and train.
main() {
    clear
    echo -e "${GREEN}========================================"
    echo -e " Paraformer LoRA 微调"
    echo -e "========================================"
    echo -e " 数据目录: $DATA_DIR"
    echo -e " 输出目录: $OUTPUT_DIR"
    echo -e "========================================${NC}"
    # Run pre-flight checks
    check_python_env
    check_cuda_env
    check_data_dir
    check_config
    # Confirm before starting
    echo -e "\n${YELLOW}所有检查已通过,准备开始训练..."
    read -p "是否继续?(y/n) " -n 1 -r
    echo -e "\n${NC}"
    if [[ $REPLY =~ ^[Yy]$ ]]; then
        prepare_data
        start_training
        finish_training
    else
        echo -e "${YELLOW}训练已取消${NC}"
        exit 0
    fi
}
# Print usage information.
show_help() {
    echo -e "${GREEN}Paraformer LoRA 微调一键启动脚本${NC}"
    echo -e "用法: $0 [选项]"
    echo -e "选项:"
    echo -e " --help 显示帮助信息"
    echo -e " --prepare 仅执行数据准备"
    echo -e " --train 仅执行训练(需先准备数据)"
    echo -e " --check 仅执行环境检查"
}
# Parse the command-line argument and dispatch to the requested action.
# With no argument the full interactive flow (main) runs.
parse_args() {
    case "$1" in
        --help)
            show_help
            exit 0
            ;;
        --prepare)
            check_python_env
            check_data_dir
            check_config
            prepare_data
            exit 0
            ;;
        --train)
            check_python_env
            check_cuda_env
            check_config
            start_training
            finish_training
            exit 0
            ;;
        --check)
            check_python_env
            check_cuda_env
            check_data_dir
            check_config
            echo -e "\n${GREEN}所有检查通过!${NC}"
            exit 0
            ;;
        "")
            main
            ;;
        *)
            echo -e "${RED}错误: 未知选项 $1${NC}"
            show_help
            exit 1
            ;;
    esac
}
# Script entry point
parse_args "$@"