0 导入包

In [22]:
# 加载数据集与处理数据集
from datasets import load_dataset, DatasetDict
from datasets import Audio

# 加载模型
from transformers import WhisperProcessor
from transformers import WhisperForConditionalGeneration

# 训练
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer


# 评估指标
import evaluate
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

# 计算资源
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

In [23]:
model_id = "openai/whisper-small"

1 加载模型

In [24]:
# 查看特征提取器所支持的语言类型
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
TO_LANGUAGE_CODE

{'english': 'en',
 'chinese': 'zh',
 'german': 'de',
 'spanish': 'es',
 'russian': 'ru',
 'korean': 'ko',
 'french': 'fr',
 'japanese': 'ja',
 'portuguese': 'pt',
 'turkish': 'tr',
 'polish': 'pl',
 'catalan': 'ca',
 'dutch': 'nl',
 'arabic': 'ar',
 'swedish': 'sv',
 'italian': 'it',
 'indonesian': 'id',
 'hindi': 'hi',
 'finnish': 'fi',
 'vietnamese': 'vi',
 'hebrew': 'he',
 'ukrainian': 'uk',
 'greek': 'el',
 'malay': 'ms',
 'czech': 'cs',
 'romanian': 'ro',
 'danish': 'da',
 'hungarian': 'hu',
 'tamil': 'ta',
 'norwegian': 'no',
 'thai': 'th',
 'urdu': 'ur',
 'croatian': 'hr',
 'bulgarian': 'bg',
 'lithuanian': 'lt',
 'latin': 'la',
 'maori': 'mi',
 'malayalam': 'ml',
 'welsh': 'cy',
 'slovak': 'sk',
 'telugu': 'te',
 'persian': 'fa',
 'latvian': 'lv',
 'bengali': 'bn',
 'serbian': 'sr',
 'azerbaijani': 'az',
 'slovenian': 'sl',
 'kannada': 'kn',
 'estonian': 'et',
 'macedonian': 'mk',
 'breton': 'br',
 'basque': 'eu',
 'icelandic': 'is',
 'armenian': 'hy',
 'nepali': 'ne',
 'mongol

In [25]:
# 加载特征提取器
# language为训练数据语音类型
# 语音识别task="transcribe"，语音翻译task= "translate"
processor = WhisperProcessor.from_pretrained(model_id, language="chinese", task="transcribe")


model = WhisperForConditionalGeneration.from_pretrained(
    model_id,
    # load_in_8bit=True,
    
    # 参数与显存
    device_map={"": 0}, # k代表参数名前缀，空代表所有模型参数 v代表在哪张gpu上运行

    )

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2 加载数据集

In [26]:
common_voice_test = load_dataset("mozilla-foundation/common_voice_13_0", "zh-CN", split="test") # 超过10个g大小
print(len(common_voice_test)) # 显示长度

common_voice = DatasetDict()
common_voice["train"] = common_voice_test.shuffle(seed=42).select(range(3000)) # 随机选择前n条数据
common_voice["test"] = common_voice_test.shuffle(seed=42).select(range(1000)) # 随机选择前n条数据
common_voice

10624


DatasetDict({
    train: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],
        num_rows: 1000
    })
})

3 预处理 筛选列

In [27]:
# 筛选需要留下的列# 筛选需要留下的列
common_voice = common_voice.select_columns(["audio", "sentence"]) # 筛选需要留下的列# 筛选需要留下的列

3 预处理 更改采样率

In [28]:
# 更改采样率与特征提取器一致
sampling_rate = processor.feature_extractor.sampling_rate
print(sampling_rate)
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=sampling_rate))
common_voice['train'][0]

16000


{'audio': {'path': 'C:\\Users\\Administrator\\.cache\\huggingface\\datasets\\downloads\\extracted\\db1f345dfb6cd6276cbcdb1894bb538813007b78eaec0c3c146d06a22e625c55\\zh-CN_test_0/common_voice_zh-CN_22207680.mp3',
  'array': array([ 2.27373675e-13,  1.59161573e-12,  9.09494702e-13, ...,
          2.58021191e-06, -2.86191607e-06,  6.12053009e-06]),
  'sampling_rate': 16000},
 'sentence': '可分为接腰型和连腰型两大类。'}

3 预处理 processor = feature + tokenizer

In [29]:
common_voice = common_voice.map(
    lambda x : # for i in dataset , x=dataset[i]
    processor(audio=x['audio']["array"],sampling_rate=x['audio']["sampling_rate"],text=x["sentence"],),
    num_proc=1
)
common_voice

Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['audio', 'sentence', 'input_features', 'labels'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['audio', 'sentence', 'input_features', 'labels'],
        num_rows: 1000
    })
})

3 预处理 留下少于30s的数据

In [30]:
# 计算时长
common_voice = common_voice.map(
    lambda x : # for i in dataset , x=dataset[i]
    {"time" : len(x['audio']["array"]) / x['audio']["sampling_rate"]},
    remove_columns=['audio', 'sentence'], # 删除列
    num_proc=1
)
common_voice

Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels', 'time'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['input_features', 'labels', 'time'],
        num_rows: 1000
    })
})

In [31]:
# 筛选
common_voice= common_voice.filter(
    lambda x: 
    x["time"] < 30.0, 
)
common_voice

Filter:   0%|          | 0/3000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels', 'time'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['input_features', 'labels', 'time'],
        num_rows: 1000
    })
})

In [32]:
common_voice = common_voice.remove_columns(['time'])
common_voice

DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 1000
    })
})

4 设置数据预处理器

In [33]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    def __call__(
        self, 
        dataset: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        
        # 总功能：输入一个dataset，输出{"input_features":torch,"input_ids":torch}。这个是最终输入进模型的数据。

        '''处理音频数据input_features
        1处理成列表:[{"input_features":"[长度为100]"},{"input_features":"[长度为80]"},...] 
        2填充长度 并转换成torch:{'input_values': tensor([ [长度为100],[长度为80+后面全部填充为0.00],...])
        '''
        input_features = [{"input_features": x["input_features"][0]} for x in dataset]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")


        '''处理文本token数据label
        1处理成列表:[{'input_ids': [1, 2, 3, 4]}, {'input_ids': [1, 2, 3]}]
        2转换成torch:
        {'input_ids': tensor([[1, 2, 3, 4], [1, 2, 3, 0]]), 
         'attention_mask': tensor([[1, 1, 1, 1],[1, 1, 1, 0]])}
        3用-100填充mask为0的向量tensor([[1,2,3,4],[1,2,3,-100]])
        4删去bos特殊词汇tensor([[2,3,4],[2,3,-100]])
        
        '''
        label_features = [{"input_ids": x["labels"]} for x in dataset]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # 如果在之前分词时添加了 bos 词元，那就剪切掉，因为之后还会加上的
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch
    
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

5 设置训练完后的评估指标

In [34]:
'''评估指标原理
以单词为标准计算错误率
错误种类：
0在位置相同的情况下
1应替换 S 
2应添加 I
3应删除 D

例子：
the	cat	sat	on	the	mat  答案
the	cat	sit	on	the	     预测
✅  ✅ S   ✅  ✅ I

单词错误率wer = 错误个数/句子单词个数
         = (1 + 1 + 0)/6
         = 0.333

(也有用字母为标准计算错误率的，但用单词更加严格，也要求系统更好理解上下文)
'''

import evaluate

x ='一段语音' 
y = "the cat sat on the mat" # 答案
out = "the cat sit on the" # 预测

wer_metric = evaluate.load("wer")
wer = wer_metric.compute(references=[y], predictions=[out])
print(wer)

0.3333333333333333


In [35]:
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()
metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # 用 pad_token_id 替换 -100
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # 我们希望在计算指标时不要组合起词元
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # 计算普通的 WER
    wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)

    # 计算标准化的 WER
    pred_str_norm = [normalizer(pred) for pred in pred_str]
    label_str_norm = [normalizer(label) for label in label_str]
    # 过滤，从而在评估时只计算 reference 非空的样本
    pred_str_norm = [
        pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0
    ]
    label_str_norm = [
        label_str_norm[i]
        for i in range(len(label_str_norm))
        if len(label_str_norm[i]) > 0
    ]

    wer = 100 * metric.compute(predictions=pred_str_norm, references=label_str_norm)

    return {"wer_ortho": wer_ortho, "wer": wer}

6 设置训练参数

In [38]:
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,

    args=Seq2SeqTrainingArguments(
        # ——————模型保存————————
        output_dir="whisper-fted-common_voice",  # 模型保存路径
        overwrite_output_dir=True, # 同名是否覆盖
        save_strategy="epoch", # 以步数为单位保存
        # save_steps=100, # 多少步保存一次

        # ——————训练量——————————
        # max_steps=500,
        num_train_epochs=10,
        per_device_train_batch_size=12, # 批次
        gradient_accumulation_steps=4, # 累计多少次更新
        gradient_checkpointing=True, # 保存激活值减少显存占用

        # ——————训练精度————————
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),

        # ——————更新参数—————————
        warmup_steps=50,
        learning_rate=1e-5,
        lr_scheduler_type="constant_with_warmup",

        # ——————评估参数————————
        eval_steps=500,
        metric_for_best_model="wer",
        greater_is_better=False, # 代表指标越低越好
        evaluation_strategy="steps", # 以步数为单位评估
        bf16_full_eval=True, # 评估时候用bf16
        per_device_eval_batch_size=4, # 评估批次
        predict_with_generate=True, # 评估时候使用生成模式
        generation_max_length=225,

        # ——————log参数————————
        logging_steps=25,

        # ——————种子————————————

    ),

)

7 开始训练

In [39]:
trainer.train()

  0%|          | 0/620 [00:00<?, ?it/s]

{'loss': 1.7557, 'grad_norm': 15.48286247253418, 'learning_rate': 5e-06, 'epoch': 0.4}
{'loss': 0.775, 'grad_norm': 9.534825325012207, 'learning_rate': 1e-05, 'epoch': 0.8}


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


{'loss': 0.4764, 'grad_norm': 4.3815155029296875, 'learning_rate': 1e-05, 'epoch': 1.2}
{'loss': 0.2813, 'grad_norm': 3.6703641414642334, 'learning_rate': 1e-05, 'epoch': 1.6}


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


{'loss': 0.2807, 'grad_norm': 3.958953380584717, 'learning_rate': 1e-05, 'epoch': 2.0}




{'loss': 0.1543, 'grad_norm': 2.4052417278289795, 'learning_rate': 1e-05, 'epoch': 2.4}
{'loss': 0.1395, 'grad_norm': 3.3555829524993896, 'learning_rate': 1e-05, 'epoch': 2.8}


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


{'loss': 0.0974, 'grad_norm': 1.369393229484558, 'learning_rate': 1e-05, 'epoch': 3.2}
{'loss': 0.0637, 'grad_norm': 1.9268174171447754, 'learning_rate': 1e-05, 'epoch': 3.6}


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


{'loss': 0.0734, 'grad_norm': 1.2431974411010742, 'learning_rate': 1e-05, 'epoch': 4.0}




{'loss': 0.0265, 'grad_norm': 1.0082679986953735, 'learning_rate': 1e-05, 'epoch': 4.4}
{'loss': 0.0325, 'grad_norm': 1.2117432355880737, 'learning_rate': 1e-05, 'epoch': 4.8}


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


{'loss': 0.0252, 'grad_norm': 0.8787457346916199, 'learning_rate': 1e-05, 'epoch': 5.2}
{'loss': 0.0159, 'grad_norm': 0.7080553770065308, 'learning_rate': 1e-05, 'epoch': 5.6}


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


{'loss': 0.0151, 'grad_norm': 0.9923376441001892, 'learning_rate': 1e-05, 'epoch': 6.0}




{'loss': 0.0073, 'grad_norm': 0.35907936096191406, 'learning_rate': 1e-05, 'epoch': 6.4}
{'loss': 0.0119, 'grad_norm': 1.2927442789077759, 'learning_rate': 1e-05, 'epoch': 6.8}


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


{'loss': 0.0074, 'grad_norm': 0.14668883383274078, 'learning_rate': 1e-05, 'epoch': 7.2}
{'loss': 0.0067, 'grad_norm': 0.5330799221992493, 'learning_rate': 1e-05, 'epoch': 7.6}
{'loss': 0.0059, 'grad_norm': 0.3225308656692505, 'learning_rate': 1e-05, 'epoch': 8.0}


Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.


  0%|          | 0/250 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


{'eval_loss': 0.004121796693652868, 'eval_wer_ortho': 113.99999999999999, 'eval_wer': 630.527289546716, 'eval_runtime': 235.8729, 'eval_samples_per_second': 4.24, 'eval_steps_per_second': 1.06, 'epoch': 8.0}




{'loss': 0.005, 'grad_norm': 0.37355801463127136, 'learning_rate': 1e-05, 'epoch': 8.4}
{'loss': 0.0046, 'grad_norm': 0.2525150179862976, 'learning_rate': 1e-05, 'epoch': 8.8}


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


{'loss': 0.0039, 'grad_norm': 0.2607267498970032, 'learning_rate': 1e-05, 'epoch': 9.2}
{'loss': 0.0032, 'grad_norm': 0.08447516709566116, 'learning_rate': 1e-05, 'epoch': 9.6}


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


{'train_runtime': 2343.8049, 'train_samples_per_second': 12.8, 'train_steps_per_second': 0.265, 'train_loss': 0.17221177180688227, 'epoch': 9.92}


TrainOutput(global_step=620, training_loss=0.17221177180688227, metrics={'train_runtime': 2343.8049, 'train_samples_per_second': 12.8, 'train_steps_per_second': 0.265, 'total_flos': 8.5883015135232e+18, 'train_loss': 0.17221177180688227, 'epoch': 9.92})