Skip to content

怎么用方言数据集lora微调训练paraformer模型 #2731

@aobenhaikai

Description

@aobenhaikai

Prepare_funasr_data.py数据集预处理生成了
{"key": "2_009_004_split_032", "source": "/root/autodl-tmp/.autodl/data/temp_split_audio/2_009_004_split_032.wav", "source_len": 15.0, "target": "12 438.2562946400315 O2 所有的脸我都觉得都很一般 438.3465203389006 441.7149464300125 O1 然后我我看到你的第一眼我也没啥子感觉 442.344029842798 449.1184760662173 O1 真的可能还有个原因我比较礼貌就是我刚刚认识的人我不会一直盯到别个看那是我一个习惯 449.92298", "target_len": 183}

项目目录
Config.yaml
Train.py
Run.sh

Config.yaml

# 模型配置

# Base model: Paraformer-large Chinese ASR (16 kHz, vocab 8404) from ModelScope.
model: iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch
model_path: /root/paraformer/model  # local cache dir; auto-downloaded by train.py when missing/empty
max_seq_len: 512  # maximum token sequence length
device: cuda

# 数据集配置

# Train/valid manifests produced by prepare_funasr_data.py (jsonl lines with
# key / source / source_len / target / target_len fields).
dataset_conf:
  train:
    manifest_filepath: /root/autodl-tmp/.autodl/data/train.jsonl
    data_type: audio
    tokenizer_conf:
      type: WordpieceTokenizer
      vocab_file: /root/paraformer/model/vocab.txt
    preprocessor_conf:
      type: AudioPreprocessor
      sample_rate: 16000  # must match the 16k base model
      normalize: true
  valid:
    manifest_filepath: /root/autodl-tmp/.autodl/data/dev.jsonl
    data_type: audio
    tokenizer_conf:
      type: WordpieceTokenizer
      vocab_file: /root/paraformer/model/vocab.txt
    preprocessor_conf:
      type: AudioPreprocessor
      sample_rate: 16000
      normalize: true
  # Dataloader settings shared by both splits.
  batch_size: 16
  num_workers: 4
  pin_memory: true
  shuffle: true

# LoRA配置

lora_config:
  r: 8  # LoRA rank
  lora_alpha: 32  # LoRA scaling factor
  target_modules: ["q_proj", "v_proj"]  # attention projections to adapt
  lora_dropout: 0.05
  bias: "none"
  # NOTE(review): "CAUSAL_LM" is a peft-style task type; confirm FunASR's
  # add_lora actually consumes this key for a Paraformer (ASR) model.
  task_type: "CAUSAL_LM"

# 训练配置

train_conf:
  max_epoch: 10
  log_interval: 10  # steps between log lines
  eval_interval: 1  # epochs between validations
  save_interval: 1  # epochs between checkpoints
  gradient_accumulation_steps: 1
  use_fp16: true  # enables GradScaler in train.py
  grad_clip: 5.0
  find_unused_parameters: true  # forwarded to DDP by train.py

# 优化器配置

# Optimizer name must be a key of funasr's optim_classes registry.
optim: adam
optim_conf:
  lr: 1e-4
  betas: [0.9, 0.98]
  eps: 1e-9
  weight_decay: 0.01

# 学习率调度器配置

# Scheduler name must be a key of funasr's scheduler_classes registry.
scheduler: warmuplr
scheduler_conf:
  warmup_steps: 2000
  lr_decay_rate: 0.5
  total_steps: -1  # presumably -1 means "derive from actual step count" — TODO confirm

# 输出配置

output_dir: /root/paraformer/outputs/lora_paraformer  # checkpoints + LoRA adapters
tensorboard_dir: /root/paraformer/outputs/lora_paraformer/tensorboard
log_level: INFO
seed: 42  # passed to set_all_random_seed in train.py

# 分布式训练配置

use_fsdp: false  # when true, train.py wraps the model in FSDP instead of DDP
backend: nccl  # torch.distributed process-group backend

Train.py
#!/usr/bin/env python3

# -*- coding: utf-8 -*-

import os
import sys
import torch
import torch.nn as nn
import hydra
import logging
import time
import argparse
from io import BytesIO
from contextlib import nullcontext
import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.algorithms.join import Join
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from tensorboardX import SummaryWriter
from funasr.train_utils.average_nbest_models import average_checkpoints
from funasr.register import tables
from funasr.optimizers import optim_classes
from funasr.train_utils.trainer import Trainer
from funasr.schedulers import scheduler_classes
from funasr.train_utils.initialize import initialize
from funasr.download.download_model_from_hub import download_model
from funasr.models.lora.utils import mark_only_lora_as_trainable
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.misc import prepare_model_dir
from funasr.train_utils.model_summary import model_summary
from funasr import AutoModel
import functools
from typing import Optional, Dict, Union, List

def _download_model_conf(kwargs, model_config):
    """Resolve the model name and download it from the hub into *kwargs*.

    No-op when ``kwargs`` already carries ``model_conf``. The name is taken
    from the top-level ``model_name`` first, then from ``model_config["model"]``.

    Raises:
        ValueError: when no model name can be determined.
    """
    if "model_conf" in kwargs:
        return
    if "model_name" in kwargs:
        model_name = kwargs["model_name"]
        logging.info(f"使用顶级model_name参数: {model_name}")
    elif isinstance(model_config, dict) and "model" in model_config:
        model_name = model_config["model"]
        logging.info(f"从model配置中提取模型名称: {model_name}")
    else:
        raise ValueError("无法获取模型名称,请在配置文件中设置model_name")

    # Only forward the fields download_model needs.
    download_kwargs = {
        "model": model_name,
        "hub": kwargs.get("hub", "ms"),
        "is_training": kwargs.get("is_training", True),
    }
    downloaded_kwargs = download_model(**download_kwargs)
    kwargs.update(downloaded_kwargs)
    # Guarantee model_conf exists even if the hub response lacked it.
    if "model_conf" not in kwargs:
        kwargs["model_conf"] = downloaded_kwargs.get("model_conf", {})


@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    """Hydra entry point: resolve/download the model config, then train.

    The original contained two byte-identical copies of the download logic
    (local-path-missing branch and no-local-path branch); both now share
    ``_download_model_conf``.
    """
    if kwargs.get("debug", False):
        import pdb
        pdb.set_trace()
    assert "model" in kwargs

    model_config = kwargs["model"]
    if isinstance(model_config, dict) and "model_path" in model_config:
        model_path = model_config["model_path"]
        if os.path.exists(model_path):
            logging.info(f"使用本地模型路径: {model_path}")
            # Local model present: reuse its config instead of downloading.
            if "model_conf" not in kwargs:
                kwargs["model_conf"] = model_config
        else:
            logging.warning(f"本地模型路径不存在: {model_path}")
            logging.info("尝试从模型中心下载...")
            _download_model_conf(kwargs, model_config)
    else:
        logging.info("使用模型名称从模型中心下载...")
        _download_model_conf(kwargs, model_config)

    main(**kwargs)

def main(**kwargs):
    """Run LoRA fine-tuning of a FunASR (Paraformer) model.

    Builds the model/tokenizer/frontend via ``AutoModel``, optionally wraps
    it in DDP or FSDP, applies LoRA adapters, then trains epoch by epoch
    with the FunASR ``Trainer``, saving LoRA weights separately per epoch.

    Relevant ``kwargs`` keys: ``model`` (hub name), ``model_path``,
    ``device``, ``seed``, ``dataset_conf``, ``train_conf``,
    ``optim``/``optim_conf``, ``scheduler``/``scheduler_conf``,
    ``lora_config``, ``output_dir``, ``use_fsdp``, ``backend``.
    """
    # Reproducibility and cuDNN/TF32 backend flags.
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
    # open tf32
    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank == 0:
        tables.print()

    # DDP is implied by torchrun's WORLD_SIZE; FSDP is an explicit flag.
    use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", False)

    if use_ddp or use_fsdp:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
        torch.cuda.set_device(local_rank)

    logging.info("Build model, frontend, tokenizer")
    device = kwargs.get("device", "cuda")

    model_name = kwargs.get("model")
    model_path = kwargs.get("model_path")
    max_seq_len = kwargs.get("max_seq_len", 512)

    # Download the model when the local path is missing or empty.
    # BUGFIX: the original mixed and/or without parentheses, so a None
    # model_path still reached os.path.exists(None) and raised TypeError.
    if model_path and (not os.path.exists(model_path) or not os.listdir(model_path)):
        logging.info(f"模型路径 {model_path} 不存在或为空,尝试从模型中心下载...")
        try:
            from funasr.download.download_model_from_hub import download_model
            download_kwargs = {
                "model": model_name,
                "hub": "ms",
                "is_training": True,
                "output_dir": model_path,
            }
            download_model(**download_kwargs)
            logging.info(f"模型下载成功,保存到 {model_path}")
        except Exception as e:
            # Best effort: AutoModel below can still fetch from the hub.
            logging.warning(f"模型自动下载失败: {str(e)},将尝试直接从模型中心加载")

    # Build on CPU first; moved/wrapped onto the target device further down.
    model_kwargs = {
        "model": model_name,
        "model_path": model_path,
        "device": "cpu",
        "max_seq_len": max_seq_len,
    }

    logging.info(f"使用模型配置: {model_kwargs}")
    model = AutoModel(**model_kwargs)

    # Save config.yaml — only on the (global or local) rank-0 process.
    if (
        (use_ddp or use_fsdp)
        and dist.get_rank() == 0
        or not (use_ddp or use_fsdp)
        and local_rank == 0
    ):
        prepare_model_dir(**kwargs)

    # AutoModel merged the full config; from here on use its kwargs.
    kwargs = model.kwargs
    kwargs["device"] = device
    tokenizer = kwargs["tokenizer"]  # resolved here; unused below, kept for parity with FunASR's script
    frontend = kwargs["frontend"]
    model = model.model
    del kwargs["model"]

    # Freeze every parameter whose name matches a configured prefix.
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
        if "," in freeze_param:
            # NOTE(review): eval of a config string — assumes a trusted,
            # tuple-like value such as "('encoder', 'decoder')".
            freeze_param = eval(freeze_param)
        if not isinstance(freeze_param, (list, tuple)):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        for t in freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False

    # LoRA: inject adapters and freeze everything except the adapter weights.
    lora_config = kwargs.get("lora_config", None)
    if lora_config is not None:
        logging.info("Applying LoRA configuration")
        from funasr.models.lora import add_lora
        model = add_lora(model, **lora_config)

        mark_only_lora_as_trainable(model)

        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        # BUGFIX: the ratio is now scaled by 100 to match the "%" label.
        logging.info(
            f"Trainable params: {trainable_params} "
            f"({100.0 * trainable_params / total_params:.4f}% of total)"
        )

    if local_rank == 0:
        logging.info(f"{model_summary(model)}")

    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(
            model,
            device_ids=[local_rank],
            find_unused_parameters=kwargs.get("train_conf", {}).get(
                "find_unused_parameters", False
            ),
        )
    elif use_fsdp:

        def custom_auto_wrap_policy(
            module: nn.Module,
            recurse: bool,
            nonwrapped_numel: int,
            min_num_params: int = int(1e5),
        ) -> bool:
            # NOTE(review): counts *frozen* parameters (not p.requires_grad);
            # with LoRA most weights are frozen, so large frozen modules get
            # their own FSDP unit — confirm this is the intent.
            unwrapped_params = sum(p.numel() for p in module.parameters() if not p.requires_grad)
            is_large = unwrapped_params >= min_num_params
            # Only wrap modules whose params uniformly share requires_grad.
            requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
            return is_large and requires_grad_uniform

        # BUGFIX: the partial carrying the configured min_num_params is now
        # actually passed to FSDP (it was created but never used before).
        my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
        torch.cuda.set_device(local_rank)
        model = FSDP(
            model,
            auto_wrap_policy=my_auto_wrap_policy,
            mixed_precision=None,
            device_id=torch.cuda.current_device(),
        )
    else:
        model = model.to(device=kwargs.get("device", "cuda"))

    kwargs["device"] = next(model.parameters()).device

    # Optimizer over the trainable (LoRA / unfrozen) parameters only.
    logging.info("Build optim")
    optim = kwargs.get("optim", "adam")
    assert optim in optim_classes
    optim_class = optim_classes.get(optim)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optim = optim_class(trainable_params, **kwargs.get("optim_conf"))

    # scheduler
    logging.info("Build scheduler")
    scheduler = kwargs.get("scheduler", "warmuplr")
    assert scheduler in scheduler_classes
    scheduler_class = scheduler_classes.get(scheduler)
    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))

    # dataset
    logging.info("Build dataloader")
    dataloader_class = tables.dataloader_classes.get(
        kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle")
    )
    dataloader = dataloader_class(**kwargs)

    trainer = Trainer(
        local_rank=local_rank,
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
        device=kwargs["device"],
        output_dir=kwargs.get("output_dir", "./exp"),
        **kwargs.get("train_conf"),
    )

    # fp16 needs a GradScaler; FSDP needs the sharded variant.
    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
    scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler

    trainer.resume_checkpoint(
        model=model,
        optim=optim,
        scheduler=scheduler,
        scaler=scaler,
    )

    tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
    os.makedirs(tensorboard_dir, exist_ok=True)
    try:
        writer = SummaryWriter(tensorboard_dir)  # if trainer.rank == 0 else None
    except Exception:
        # BUGFIX: was a bare except; tensorboard is optional — log and go on.
        logging.warning("Failed to create SummaryWriter; tensorboard disabled")
        writer = None

    dataloader_tr, dataloader_val = None, None
    for epoch in range(trainer.start_epoch, trainer.max_epoch):
        time1 = time.perf_counter()
        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
            time_slice_i = time.perf_counter()
            dataloader_tr, dataloader_val = dataloader.build_iter(
                epoch, data_split_i=data_split_i, start_step=trainer.start_step
            )
            trainer.train_epoch(
                model=model,
                optim=optim,
                scheduler=scheduler,
                scaler=scaler,
                dataloader_train=dataloader_tr,
                dataloader_val=dataloader_val,
                epoch=epoch,
                writer=writer,
                data_split_i=data_split_i,
                data_split_num=dataloader.data_split_num,
                start_step=trainer.start_step,
            )
            trainer.start_step = 0
            # Release cached CUDA memory between data slices.
            device = next(model.parameters()).device
            if device.type == "cuda":
                with torch.cuda.device(device):
                    torch.cuda.empty_cache()
            time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
            logging.info(
                f"rank: {local_rank}, "
                f"time_escaped_epoch: {time_escaped:.3f} hours, "
                f"estimated to finish {dataloader.data_split_num} data_slices, remaining: {dataloader.data_split_num-data_split_i} slices, {(dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours, "
                f"epoch: {trainer.max_epoch - epoch} epochs, {((trainer.max_epoch - epoch - 1)*dataloader.data_split_num + dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours\n"
            )
        trainer.start_data_split_i = 0
        trainer.validate_epoch(
            model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
        )
        scheduler.step()
        trainer.step_in_epoch = 0

        # Save LoRA adapters separately so they can be re-applied to the base model.
        if lora_config is not None and (trainer.rank == 0 or trainer.rank == -1):
            # BUGFIX: filename was "epoch_{N}lora_adapters.bin" (missing "_"),
            # breaking the documented epoch_*_lora_adapters.bin pattern.
            lora_save_path = os.path.join(trainer.output_dir, f"epoch_{epoch+1}_lora_adapters.bin")
            # For DDP the real module sits under model.module.
            base_model = model.module if use_ddp else model
            # BUGFIX: the DDP branch filtered on a corrupted "lora" substring
            # with an embedded newline; both paths now use the "lora_"
            # parameter-name convention.
            lora_state_dict = {k: v for k, v in base_model.state_dict().items() if "lora_" in k}
            torch.save(lora_state_dict, lora_save_path)
            logging.info(f"Saved LoRA adapters to {lora_save_path}")

        trainer.save_checkpoint(
            epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
        )
        time2 = time.perf_counter()
        time_escaped = (time2 - time1) / 3600.0
        logging.info(
            f"rank: {local_rank}, "
            f"time_escaped_epoch: {time_escaped:.3f} hours, "
            f"estimated to finish {trainer.max_epoch} "
            f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n"
        )
        trainer.train_acc_avg = 0.0
        trainer.train_loss_avg = 0.0

    if trainer.rank == 0:
        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)

    trainer.close()

# BUGFIX: the pasted code lost the dunder underscores ("name"/"main"),
# which raised NameError at import; restore the standard entry-point guard.
if __name__ == "__main__":
    main_hydra()

Run.sh
#!/bin/bash
set -e  # abort on the first failing command

# ==============================================
# Paraformer LoRA 微调一键启动脚本
# ==============================================

# 配置参数

# All paths are hard-coded for this machine; adjust before reuse elsewhere.
DATA_DIR="/root/autodl-tmp/.autodl/data"  # dataset root
AUDIO_DIR="${DATA_DIR}/audio_files"  # raw audio input
TEXT_DIR="${DATA_DIR}/text_files"  # transcript input
OUTPUT_DIR="/root/paraformer/outputs/lora_paraformer"  # training outputs
CONFIG_PATH="/root/paraformer/config.yaml"
TRAIN_SCRIPT="/root/paraformer/train.py"
PREPARE_SCRIPT="/root/paraformer/prepare_funasr_data.py"
MODEL_DIR="/root/paraformer/model"  # local model cache
MODEL_NAME="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"

# 颜色定义

RED='\033[0;31m'  # errors
GREEN='\033[0;32m'  # success / section headers
YELLOW='\033[1;33m'  # warnings
NC='\033[0m' # no color (reset)

# 检查命令是否存在

check_command() {
    # Abort with an error when the command given as $1 is not on PATH.
    # BUGFIX: "$1" is now quoted so an empty/whitespace argument cannot
    # word-split into a different command -v invocation.
    if ! command -v "$1" &> /dev/null; then
        echo -e "${RED}错误: $1 命令未找到,请安装${NC}"
        exit 1
    fi
}

# 检查目录是否存在

check_directory() {
    # Fail fast when the directory given as $1 does not exist.
    if ! [ -d "$1" ]; then
        echo -e "${RED}错误: 目录 $1 不存在${NC}"
        exit 1
    fi
}

# 检查文件是否存在

check_file() {
    # Fail fast when the regular file given as $1 does not exist.
    if ! [ -f "$1" ]; then
        echo -e "${RED}错误: 文件 $1 不存在${NC}"
        exit 1
    fi
}

# 检查Python环境

check_python_env() {
    # Verify python3/pip3 exist and that each required package is importable,
    # installing any that are missing.
    echo -e "\n${GREEN}=== 检查Python环境 ==="
    check_command "python3"
    check_command "pip3"

    # Parallel arrays: pip package name vs Python import name.
    # BUGFIX: the original ran `python3 -c "import hydra-core"`, which is a
    # Python syntax error, so the check always failed and the package was
    # re-installed on every run. The import name for hydra-core is "hydra".
    required_packages=("torch" "soundfile" "tqdm" "hydra-core" "tensorboardX" "funasr")
    import_names=("torch" "soundfile" "tqdm" "hydra" "tensorboardX" "funasr")

    for i in "${!required_packages[@]}"; do
        pkg="${required_packages[$i]}"
        mod="${import_names[$i]}"
        if ! python3 -c "import $mod" &> /dev/null; then
            echo -e "${YELLOW}警告: $pkg 未安装,正在安装...${NC}"
            pip3 install "$pkg"
        fi
    done

    echo -e "${GREEN}Python环境检查完成${NC}"
}

# 检查CUDA环境

check_cuda_env() {
    # Verify torch can see at least one CUDA device. Sets the global
    # NUM_GPUS, which start_training uses to pick single- vs multi-GPU mode.
    echo -e "\n${GREEN}=== 检查CUDA环境 ==="
    if ! python3 -c "import torch; print(torch.cuda.is_available())" | grep -q "True"; then
        echo -e "${RED}错误: CUDA不可用,请检查您的GPU环境${NC}"
        exit 1
    fi

    NUM_GPUS=$(python3 -c "import torch; print(torch.cuda.device_count())")
    echo -e "检测到 $NUM_GPUS 个GPU"
    echo -e "${GREEN}CUDA环境检查完成${NC}"
}

# 检查数据目录结构

check_data_dir() {
    # Ensure the data/audio/text directories exist and contain usable files.
    echo -e "\n${GREEN}=== 检查数据目录结构 ==="
    check_directory "$DATA_DIR"
    check_directory "$AUDIO_DIR"
    check_directory "$TEXT_DIR"

    # Count audio files.
    # BUGFIX: find's grouping parens must be escaped (\( \)) and the -name
    # patterns need a leading '*' — the pasted script had lost both, so the
    # command errored / matched nothing.
    audio_files=$(find "$AUDIO_DIR" -type f \( -name "*.wav" -o -name "*.flac" -o -name "*.mp3" -o -name "*.m4a" -o -name "*.ogg" \) | wc -l)
    if [ "$audio_files" -eq 0 ]; then
        echo -e "${RED}错误: 在 $AUDIO_DIR 目录下未找到音频文件${NC}"
        echo -e "支持的音频格式: .wav, .flac, .mp3, .m4a, .ogg"
        exit 1
    fi
    echo -e "找到音频文件数量: $audio_files"

    # Count transcript files (BUGFIX: restore the '*' in "*.txt").
    text_files=$(find "$TEXT_DIR" -type f -name "*.txt" | wc -l)
    if [ "$text_files" -eq 0 ]; then
        echo -e "${RED}错误: 在 $TEXT_DIR 目录下未找到文本文件(.txt)${NC}"
        exit 1
    fi
    echo -e "找到文本文件数量: $text_files"

    echo -e "${GREEN}数据目录结构检查完成${NC}"
}

# 检查配置文件

check_config() {
    # Ensure config.yaml, train.py and the data-prep script exist.
    echo -e "\n${GREEN}=== 检查配置文件 ==="
    check_file "$CONFIG_PATH"
    check_file "$TRAIN_SCRIPT"
    check_file "$PREPARE_SCRIPT"

    # Create the model directory when absent; train.py downloads into it.
    if [ ! -d "$MODEL_DIR" ]; then
        echo -e "${YELLOW}警告: 模型目录 $MODEL_DIR 不存在,将创建并在训练时自动下载${NC}"
        mkdir -p "$MODEL_DIR"
    fi

    echo -e "${GREEN}配置文件检查完成${NC}"
}

# 准备数据

prepare_data() {
    # Run the data-prep script, then sanity-check the generated jsonl splits.
    echo -e "\n${GREEN}=== 开始数据准备 ==="
    echo -e "数据目录: $DATA_DIR"
    echo -e "音频目录: $AUDIO_DIR"
    echo -e "文本目录: $TEXT_DIR"

    # BUGFIX: the arguments were on separate lines without continuation
    # backslashes, so each "--flag" line ran as its own (failing) command.
    # --split makes the script emit train/dev/test splits.
    python3 "$PREPARE_SCRIPT" \
        --audio-dir "$AUDIO_DIR" \
        --text-dir "$TEXT_DIR" \
        --output-dir "$DATA_DIR" \
        --split

    TRAIN_FILE="${DATA_DIR}/train.jsonl"
    DEV_FILE="${DATA_DIR}/dev.jsonl"
    TEST_FILE="${DATA_DIR}/test.jsonl"

    if [ ! -f "$TRAIN_FILE" ] || [ ! -f "$DEV_FILE" ] || [ ! -f "$TEST_FILE" ]; then
        echo -e "${RED}错误: 数据准备失败,未生成必要的JSONL文件${NC}"
        exit 1
    fi

    # One sample per jsonl line.
    train_samples=$(wc -l < "$TRAIN_FILE")
    dev_samples=$(wc -l < "$DEV_FILE")
    test_samples=$(wc -l < "$TEST_FILE")

    echo -e "训练集样本数: $train_samples"
    echo -e "开发集样本数: $dev_samples"
    echo -e "测试集样本数: $test_samples"
    echo -e "${GREEN}数据准备完成${NC}"
}

# 开始训练

start_training() {
    # Launch training: torchrun for >1 GPU, plain python3 otherwise.
    echo -e "\n${GREEN}=== 开始训练 ==="
    echo -e "输出目录: $OUTPUT_DIR"
    echo -e "配置文件: $CONFIG_PATH"
    echo -e "模型名称: $MODEL_NAME"
    echo -e "模型目录: $MODEL_DIR"

    mkdir -p "$OUTPUT_DIR"

    # Hydra-style command line; the multi-line string is word-split (on
    # whitespace including newlines) when expanded unquoted below.
    TRAIN_CMD="python3 $TRAIN_SCRIPT
        --config-path $(dirname "$CONFIG_PATH")
        --config-name $(basename "$CONFIG_PATH" .yaml)
        output_dir=$OUTPUT_DIR
        model.device=cuda"

    echo -e "训练命令: $TRAIN_CMD"

    # NUM_GPUS is exported by check_cuda_env.
    if [ "$NUM_GPUS" -gt 1 ]; then
        echo -e "使用多GPU分布式训练($NUM_GPUS 个GPU)"
        # BUGFIX: the torchrun arguments lacked line-continuation
        # backslashes, so every argument line ran as a separate command.
        torchrun --nproc_per_node="$NUM_GPUS" "$TRAIN_SCRIPT" \
            --config-path "$(dirname "$CONFIG_PATH")" \
            --config-name "$(basename "$CONFIG_PATH" .yaml)" \
            output_dir="$OUTPUT_DIR" \
            model.device=cuda
    else
        echo -e "使用单GPU训练"
        $TRAIN_CMD
    fi
}

# 训练完成

finish_training() {
    # Report where results went and whether LoRA adapters were produced.
    echo -e "\n${GREEN}=== 训练完成 ==="
    echo -e "训练结果保存在: ${OUTPUT_DIR}"

    # BUGFIX: train.py writes epoch_<N>_lora_adapters.bin, so the lookup
    # needs a leading wildcard — the pasted script had lost the '*'.
    lora_files=$(find "$OUTPUT_DIR" -name "*lora_adapters.bin" | wc -l)
    if [ "$lora_files" -eq 0 ]; then
        echo -e "${YELLOW}警告: 未找到LoRA适配器文件,训练可能未完成${NC}"
    else
        echo -e "生成的LoRA适配器文件: ${OUTPUT_DIR}/epoch_*_lora_adapters.bin"
    fi

    echo -e "\n${GREEN}训练流程结束${NC}"
    echo -e "使用方法:"
    echo -e "1. 应用LoRA适配器: 将生成的lora_adapters.bin文件与预训练模型结合使用"
    echo -e "2. 推理: 使用FunASR的推理接口加载微调后的模型"
    echo -e "3. 评估: 使用验证集评估模型性能"
}

# 主函数

main() {
    # Interactive end-to-end flow: banner -> preflight checks -> user
    # confirmation -> data prep -> training -> summary.
    clear
    echo -e "${GREEN}========================================"
    echo -e "          Paraformer LoRA 微调"
    echo -e "========================================"
    echo -e "  数据目录: $DATA_DIR"
    echo -e "  输出目录: $OUTPUT_DIR"
    echo -e "========================================${NC}"

    # Run all preflight checks (each exits non-zero on failure).
    check_python_env
    check_cuda_env
    check_data_dir
    check_config

    # Ask for confirmation before the (long) prepare + train run.
    echo -e "\n${YELLOW}所有检查已通过,准备开始训练..."
    read -p "是否继续?(y/n) " -n 1 -r
    echo -e "\n${NC}"

    if [[ $REPLY =~ ^[Yy]$ ]]; then
        prepare_data
        start_training
        finish_training
    else
        echo -e "${YELLOW}训练已取消${NC}"
        exit 0
    fi
}

# 显示帮助信息

show_help() {
    # Print usage and the supported one-shot options.
    echo -e "${GREEN}Paraformer LoRA 微调一键启动脚本${NC}"
    echo -e "用法: $0 [选项]"
    echo -e "选项:"
    echo -e "  --help      显示帮助信息"
    echo -e "  --prepare   仅执行数据准备"
    echo -e "  --train     仅执行训练(需先准备数据)"
    echo -e "  --check     仅执行环境检查"
}

# 解析命令行参数

parse_args() {
    # Dispatch on the single supported CLI option; with no argument the
    # full interactive pipeline (main) runs.
    case "$1" in
        --help)
            show_help
            exit 0
            ;;
        --prepare)
            # Data preparation only.
            check_python_env
            check_data_dir
            check_config
            prepare_data
            exit 0
            ;;
        --train)
            # Training only (assumes data was already prepared).
            check_python_env
            check_cuda_env
            check_config
            start_training
            finish_training
            exit 0
            ;;
        --check)
            # Environment checks only.
            check_python_env
            check_cuda_env
            check_data_dir
            check_config
            echo -e "\n${GREEN}所有检查通过!${NC}"
            exit 0
            ;;
        "")
            main
            ;;
        *)
            echo -e "${RED}错误: 未知选项 $1${NC}"
            show_help
            exit 1
            ;;
    esac
}

# 开始执行

parse_args "$@"  # script entry point: dispatch on the first CLI argument

Metadata

Metadata

Assignees

No one assigned

    Labels

    question — Further information is requested

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions