In [None]:
!pip install sacremoses
!pip install evaluate
!pip install sacrebleu
!pip install optimum[onnxruntime]

In [None]:
import torch
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import numpy as np
import evaluate

In [None]:
class AudioTranslationDataset(Dataset):
    """Tokenized English->Chinese sentence pairs for seq2seq fine-tuning.

    Each item is a dict of 1-D tensors, each ``max_length`` long:
    ``input_ids`` / ``attention_mask`` for the English source and ``labels``
    for the Chinese target.

    When ``ignore_pad_token_for_loss`` is True, padding positions in
    ``labels`` are replaced by -100 so the loss ignores them. This matters
    because the labels are already padded to ``max_length`` here.
    DataCollatorForSeq2Seq only uses -100 for padding it adds itself, so
    without this step the model would also be trained to emit pad tokens.
    """

    def __init__(self, texts, tokenizer, max_length=128, ignore_pad_token_for_loss=True):
        """
        Args:
            texts: sequence of (source_text, target_text) pairs -- English
                source first, Chinese target second.
            tokenizer: Hugging Face tokenizer that accepts the ``text_target``
                keyword (e.g. MarianTokenizer).
            max_length: both sides are padded/truncated to this many tokens.
            ignore_pad_token_for_loss: replace pad ids in ``labels`` with -100.
        """
        self.tokenizer = tokenizer
        self.source_texts = [pair[0] for pair in texts]  # English text
        self.target_texts = [pair[1] for pair in texts]  # Chinese text
        self.max_length = max_length
        self.ignore_pad_token_for_loss = ignore_pad_token_for_loss

    def __len__(self):
        return len(self.source_texts)

    def __getitem__(self, idx):
        source_text = str(self.source_texts[idx])
        target_text = str(self.target_texts[idx])

        # `text_target` tokenizes the target in target mode and returns it
        # under "labels". It replaces the deprecated `as_target_tokenizer()`
        # context manager.
        encoding = self.tokenizer(
            source_text,
            text_target=target_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        labels = encoding['labels'].flatten()
        pad_id = self.tokenizer.pad_token_id
        if self.ignore_pad_token_for_loss and pad_id is not None:
            # -100 is the ignore_index of the seq2seq cross-entropy loss.
            labels = labels.masked_fill(labels == pad_id, -100)

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': labels,
        }

def load_corpus(file_path):
    """Read train/validation sentence pairs from a UTF-8 JSON corpus.

    The file must contain an object with "train" and "validation" keys. Each
    key holds a list of two-element (source, target) entries. Literal newlines
    on either side are replaced by the "<nl>" placeholder, presumably so each
    example is a single line of text for the tokenizer.

    Returns:
        tuple: (train_pairs, val_pairs), each a list of (source, target) tuples.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        corpus = json.load(handle)

    def _escape_newlines(split):
        escaped = []
        for source, target in split:
            escaped.append((source.replace('\n', '<nl>'), target.replace('\n', '<nl>')))
        return escaped

    # Train is processed before validation, matching the unpacking order below.
    return _escape_newlines(corpus['train']), _escape_newlines(corpus['validation'])

In [None]:
def compute_metrics(eval_preds, bleu_tokenize="zh", num_samples=2):
    """Compute corpus BLEU for generated translations during evaluation.

    Intended as the ``compute_metrics`` callback of a Seq2SeqTrainer run with
    ``predict_with_generate=True``. Reads the module-level ``tokenizer`` to
    decode token ids.

    Args:
        eval_preds: (predictions, labels) pair of token-id arrays. -100 marks
            positions to ignore and is mapped to the pad id before decoding.
        bleu_tokenize: sacrebleu tokenizer name. Defaults to "zh" because the
            target side is unsegmented Chinese. sacrebleu's default "13a"
            tokenizer splits on whitespace/punctuation and would score almost
            whole sentences as single tokens.
        num_samples: how many (prediction, reference) pairs to print as a
            sanity check. Pass 0 to disable.

    Returns:
        dict: {"bleu": sacrebleu corpus score}.
    """
    metric = evaluate.load("sacrebleu")
    predictions, labels = eval_preds
    # Some models return a tuple (token ids, extra outputs); keep the ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id

    # Decode predictions (-100 cannot be decoded, so map it to pad first).
    predictions = np.where(predictions != -100, predictions, pad_id)
    decoded_preds = [p.strip() for p in tokenizer.batch_decode(predictions, skip_special_tokens=True)]

    # Decode reference labels the same way.
    labels = np.where(labels != -100, labels, pad_id)
    decoded_labels = [r.strip() for r in tokenizer.batch_decode(labels, skip_special_tokens=True)]

    # sacrebleu expects one list of references per prediction.
    result = metric.compute(
        predictions=decoded_preds,
        references=[[ref] for ref in decoded_labels],
        tokenize=bleu_tokenize,
    )

    # Print a few translation samples for a quick qualitative check.
    if num_samples > 0 and decoded_preds:
        print("\n===== 翻译样本 =====")
        for pred, label in list(zip(decoded_preds, decoded_labels))[:num_samples]:
            print(f"\n预测: {pred}")
            print(f"实际: {label}")

    return {"bleu": result["score"]}

def train_model(model, train_dataset, val_dataset, output_dir, num_epochs=10,
                hf_token=None, push_to_hub=True,
                hub_model_id="opus-mt-en-zh-finetuned-audio-product"):
    """Fine-tune ``model`` with Seq2SeqTrainer and return the trainer.

    The trainer evaluates and checkpoints once per epoch and keeps at most 3
    checkpoints. At the end it reloads the checkpoint with the highest
    ``bleu`` value reported by ``compute_metrics``. Reads the module-level
    ``tokenizer`` (for the data collator and the trainer) and
    ``compute_metrics``.

    Args:
        model: seq2seq model to fine-tune.
        train_dataset: training dataset.
        val_dataset: evaluation dataset.
        output_dir: directory for checkpoints.
        num_epochs: number of training epochs.
        hf_token: Hugging Face Hub token used for pushing. If None, the
            ``HF_TOKEN`` environment variable is used. If that is also unset,
            None is passed and authentication is left to the Hub client's
            default (e.g. a cached login) -- verify for your setup.
        push_to_hub: whether to push the model to the Hub at the end of training.
        hub_model_id: Hub repository to push to.

    Returns:
        Seq2SeqTrainer: the trainer after ``train()`` has finished.
    """
    import os

    # Resolve the token explicitly instead of reading an undeclared global,
    # and never hard-code credentials in the notebook.
    if hf_token is None:
        hf_token = os.environ.get("HF_TOKEN")

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=50,
        # NOTE: newer transformers releases rename this argument to `eval_strategy`.
        evaluation_strategy="epoch",
        # Must match the evaluation strategy for load_best_model_at_end.
        save_strategy="epoch",
        save_total_limit=3,
        # compute_metrics decodes generated token ids, not logits.
        predict_with_generate=True,
        fp16=torch.cuda.is_available(),  # mixed precision only when CUDA is present
        lr_scheduler_type="cosine",
        load_best_model_at_end=True,
        metric_for_best_model="bleu",  # key returned by compute_metrics
        greater_is_better=True,
        report_to=["none"],
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id,
        hub_token=hf_token,
        hub_strategy="end",
    )

    # Batches examples and pads them.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        padding=True,
        return_tensors="pt",
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    return trainer

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Load the pretrained model and tokenizer.
# `tokenizer` is assigned at module scope, which is already what compute_metrics
# and train_model look up. A `global` statement at module level has no effect.
model_name = "Helsinki-NLP/opus-mt-en-zh"
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name).to(device)

# Load the corpus (JSON with "train" and "validation" splits).
corpus_path = 'translation_dataset.json'
train_pairs, val_pairs = load_corpus(corpus_path)

# Wrap the sentence pairs as tokenized datasets.
train_dataset = AudioTranslationDataset(train_pairs, tokenizer)
val_dataset = AudioTranslationDataset(val_pairs, tokenizer)

# Fine-tune, then save the final (best-BLEU) model to output_dir.
output_dir = './model/mt_en_zh'
num_epochs = 20
trainer = train_model(model, train_dataset, val_dataset, output_dir, num_epochs)
trainer.save_model(output_dir)

In [None]:
def _select_onnx_provider(available_providers):
    """Choose a local ONNX Runtime execution provider: CUDA if present, else CPU.

    AzureExecutionProvider is deliberately never chosen. Standard onnxruntime
    builds list it, but it targets remote Azure endpoints rather than local
    inference.
    """
    if 'CUDAExecutionProvider' in available_providers:
        return 'CUDAExecutionProvider'
    return 'CPUExecutionProvider'


def export_to_onnx(model, tokenizer, model_name, onnx_dir, quantize=True):
    """Export a saved seq2seq checkpoint to ONNX and optionally quantize it.

    The export reads the checkpoint from ``model_name`` (a Hub id or a local
    directory such as the trainer's output_dir), not from the in-memory
    ``model``. ``model`` is only switched to eval mode, and ``tokenizer`` is
    unused. Both are kept for interface compatibility.

    Writes to ``onnx_dir`` the exported ONNX model (with a merged decoder).
    When ``quantize`` is True, it also writes dynamically quantized versions of
    ``encoder_model.onnx`` and ``decoder_model_merged.onnx``.

    Raises:
        Re-raises any export or quantization error after printing it.
    """
    print("Exporting model to ONNX format...")
    model.eval()

    try:
        # Optional heavy dependencies are imported lazily.
        from optimum.onnxruntime import ORTModelForSeq2SeqLM, ORTQuantizer
        from optimum.onnxruntime.configuration import AutoQuantizationConfig
        import onnxruntime as ort
        import os

        available_providers = ort.get_available_providers()
        print(f"Available providers: {available_providers}")

        provider = _select_onnx_provider(available_providers)
        print(f"Using provider: {provider}")

        os.makedirs(onnx_dir, exist_ok=True)

        ort_model = ORTModelForSeq2SeqLM.from_pretrained(
            model_id=model_name,
            export=True,
            use_cache=True,
            # IO binding avoids host/device copies on GPU. This matches optimum's
            # own default of enabling it only for CUDAExecutionProvider.
            use_io_binding=(provider == 'CUDAExecutionProvider'),
            use_merged=True,
            provider=provider,
        )

        ort_model.save_pretrained(onnx_dir)
        print(f"Model successfully exported to: {onnx_dir}")

        if quantize:
            # Dynamic quantization config targeting AVX-512 VNNI CPUs. Quantized
            # files are written next to the exported ones in onnx_dir.
            dqconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
            for file_name in ("encoder_model.onnx", "decoder_model_merged.onnx"):
                quantizer = ORTQuantizer.from_pretrained(onnx_dir, file_name=file_name)
                quantizer.quantize(save_dir=onnx_dir, quantization_config=dqconfig)

    except Exception as e:
        print(f"Error exporting model to ONNX: {str(e)}")
        raise

In [None]:
# 导出为ONNX格式
# Export the checkpoint saved in output_dir into an "onnx" subdirectory.
onnx_dir = output_dir + '/onnx'
export_to_onnx(trainer.model, tokenizer, output_dir, onnx_dir)