# NLLB模型微调

NLLB模型微调是基于Google Colab的应用，旨在NLLB模型的基础上加入更多针对特定领域语言翻译的支持，并在其原本支持的200种语言之外，添加对更多语言的支持。

In [None]:
#@title **从谷歌网盘选择用于训练的数据集**

# @markdown <br/>从网盘目录中选择要用于训练的数据集（.csv/.tsv），单击选中文件，然后点击 'Select' 按钮确认。
# @markdown <br/>若希望从本地上传数据集，则跳过此步，直接执行下一单元格。
# @markdown <br/>若在这一步才将数据集上传到谷歌网盘，请重新执行本单元格以刷新文件列表。
!pip install geemap
from google.colab import drive
from google.colab import files
import os
import logging
from IPython.display import clear_output
import geemap

# Clear the pip/import output, then mount Google Drive under /drive.
clear_output()
drive.mount('/drive')

print('Google Drive is mounted，please select file')
print('谷歌云盘挂载完毕，请选择要训练的数据集')

from ipytree import Tree, Node
import ipywidgets as widgets
from ipywidgets import interactive
# import os
from google.colab import output
# Turn on Colab's custom widget manager before the ipytree widget is built.
output.enable_custom_widget_manager()
# Data-source flag read by the dataset-splitting cell; the local-upload cell
# sets it back to False.
use_drive = True
# NOTE: a `global` statement at module scope is a no-op; drive_dir is simply a
# module-level list that the "Select" button callback appends to.
global drive_dir
drive_dir = []

def file_tree():
    """Build a widget-based browser for picking dataset files on Google Drive.

    Walks ``/drive/MyDrive`` and lists every ``.csv``/``.tsv`` file in an
    ``ipytree`` tree. Clicking a file node puts its path into a text box and
    reveals a "Select" button; every press of that button appends the text
    box's current value to the module-level ``drive_dir`` list.

    Directories whose name starts with ``.`` are pruned from the walk, so
    neither they nor anything below them appears in the tree.

    Returns:
        widgets.HBox: the assembled browser, ready to be displayed.
    """
    full_widget = widgets.HBox()
    left_widget = widgets.VBox()
    right_widget = widgets.VBox()

    path_widget = widgets.Text()
    path_widget.layout.min_width = '300px'
    select_widget = widgets.Button(
        description='Select', button_style='primary', tooltip='Select current media file.'
    )

    # The button column is only attached once a readable dataset is clicked.
    right_widget.children = [select_widget]
    full_widget.children = [left_widget]

    tree_widget = widgets.Output()
    tree_widget.layout.max_width = '300px'
    # `overflow` is a trait of the widget's Layout; assigning it on the widget
    # object itself only creates a plain Python attribute with no visual effect.
    tree_widget.layout.overflow = 'auto'

    left_widget.children = [path_widget, tree_widget]

    my_tree = Tree(multiple_selection=False)
    my_tree_dict = {}  # absolute path -> tree Node
    data_suffixes = ('.csv', '.tsv')

    def select_file(b):
        """Record the path currently shown in the text box."""
        drive_dir.append(path_widget.value)
        print('已选择数据集，可以继续选择或执行下个单元格')

    select_widget.on_click(select_file)

    def handle_file_click(event):
        """Show the clicked file's path; enable "Select" only if it opens."""
        if not event['new']:
            return
        cur_node = event['owner']
        for key, node in my_tree_dict.items():
            if cur_node is not node or not os.path.isfile(key):
                continue
            if not key.lower().endswith(data_suffixes):
                continue
            path_widget.value = key
            try:
                # Opening the file is only a readability probe.
                with open(key):
                    pass
            except OSError:
                path_widget.disabled = True
                select_widget.disabled = True
                return
            path_widget.disabled = False
            select_widget.disabled = False
            full_widget.children = [left_widget, right_widget]
            return

    def handle_folder_click(event):
        """Hide the "Select" button again when a folder node is selected."""
        if event['new']:
            full_widget.children = [left_widget]

    # Root of the tree is the default Drive folder.
    my_dir = '/drive/MyDrive'
    my_root_node = Node(os.path.basename(my_dir))
    my_tree_dict[my_dir] = my_root_node
    my_tree.add_node(my_root_node)
    my_root_node.observe(handle_folder_click, 'selected')

    for root, d_names, f_names in os.walk(my_dir):
        # Slice-assign so os.walk itself sees the pruned, sorted list and never
        # descends into hidden directories. Removing items from d_names while
        # iterating over it skips the entry that follows each removal.
        d_names[:] = sorted(d for d in d_names if not d.startswith('.'))
        data_files = sorted(
            f for f in f_names if f.lower().endswith(data_suffixes)
        )

        if root not in my_tree_dict:
            # The walk is top-down, so the parent folder already has a node.
            parent_node = my_tree_dict[os.path.dirname(root)]
            node = Node(os.path.basename(root))
            my_tree_dict[root] = node
            parent_node.add_node(node)
            node.observe(handle_folder_click, 'selected')

        if data_files:
            folder_node = my_tree_dict[root]
            folder_node.opened = False
            for f_name in data_files:
                node = Node(f_name)
                node.icon = 'file'
                my_tree_dict[os.path.join(root, f_name)] = node
                folder_node.add_node(node)
                node.observe(handle_file_click, 'selected')

    with tree_widget:
        tree_widget.clear_output()
        display(my_tree)

    return full_widget


# Build the Drive browser; the bare `tree` expression on the last line of the
# cell is what makes the notebook render the widget.
tree= file_tree()
tree

In [None]:
#@title **从本地上传数据集(可多选）**
# @markdown 若已选择谷歌盘中的数据集，则跳过此步执行下一单元格。</font>

from google.colab import files
# Switch the data source to local upload for the dataset-splitting cell.
use_drive = False
# files.upload() returns a mapping keyed by uploaded file name.
uploaded = files.upload()
if not uploaded:
    raise ValueError("No dataset file was uploaded.")
# The cell advertises multi-select, so keep every uploaded file rather than
# only the first one.
file_names = list(uploaded.keys())
print('已选择数据集，可以执行下个单元格')

In [None]:
#@title **划分数据集**
# @markdown 随机划分数据集为训练集、验证集和测试集：如果数据量少于10000，按8：1：1划分；如果数据量不少于10000，则验证集和测试集各取1000条，其余数据全部作为训练集。

import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split
# Enable model load
from huggingface_hub.utils import _runtime
_runtime._is_google_colab = False
import sys
import warnings
warnings.filterwarnings("ignore")

# Resolve which dataset files to load (file_names), their extension-less names
# (file_basenames) and the directory of the first file (output_dir).
file_basenames = []

if use_drive:
    # drive_dir is filled by the "Select" button of the Drive file browser.
    # De-duplicate while keeping order, so pressing "Select" twice on the same
    # file does not load its rows twice.
    file_names = list(dict.fromkeys(drive_dir))
    if not file_names:
        raise ValueError("No dataset was selected from Google Drive.")
    # splitext strips only the final suffix, so dots in directory names are
    # left alone (a plain split('.') would cut the path at the first dot).
    file_basenames = [os.path.splitext(name)[0] for name in file_names]
    output_dir = os.path.dirname(file_names[0])
else:
    sys.path.append('/drive/content')
    if not file_names:
        raise ValueError("No dataset file was uploaded.")
    # Check every uploaded file, not just the first one.
    missing = [name for name in file_names if not os.path.exists(name)]
    if missing:
        raise ValueError(f"No {missing[0]} found in current path.")
    file_basenames = [Path(name).stem for name in file_names]
    output_dir = Path(file_names[0]).parent.resolve()

# Load every selected file and stack the rows into one frame.
# DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
# NOTE(review): every file is parsed as tab-separated, although the file picker
# also accepts .csv; confirm that comma-separated input is not expected.
frames = [pd.read_csv(name, sep='\t') for name in file_names]
# ignore_index gives the combined frame unique row labels across files.
trans_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
print(trans_df.shape)
print(trans_df.columns)

# Hold-out size: 20% of the rows below 10000 rows (8:1:1 split overall),
# otherwise a fixed 2000 rows (1000 dev + 1000 test).
holdout_size = 0.2 if len(trans_df) < 10000 else 2000
df_train, df_temp = train_test_split(trans_df, test_size=holdout_size, random_state=42)
# The hold-out set is halved into dev and test.
df_dev, df_test = train_test_split(df_temp, test_size=0.5, random_state=42)

# Preview a few random rows; cap the count so small frames do not raise.
trans_df.sample(min(10, len(trans_df)))

In [None]:
#@title **登陆huggingface**
!huggingface-cli login

In [None]:
#@title **选择模型和要训练的语言**
!pip install transformers sacremoses tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from tqdm.auto import tqdm, trange
import re
import sys
import unicodedata
from sacremoses import MosesPunctNormalizer

# Colab form: base checkpoint, dataset column names (*_lang_index) and NLLB
# language codes (*_lang). Each `#@param` option list has to stay on ONE line:
# a line break inside it leaves the remainder as bare, invalid Python.
model_name = "yonyou-sg/nllb-200-distilled-1.3B" # @param ["yonyou-sg/nllb-200-distilled-600M","yonyou-sg/nllb-200-distilled-1.3B"]
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
src_lang_index = "tyv" #@param {type:"string"}
src_lang = "zho_Hans" #@param ["ace_Arab", "ace_Latn", "acm_Arab", "acq_Arab", "aeb_Arab", "afr_Latn", "ajp_Arab", "aka_Latn", "amh_Ethi", "apc_Arab", "arb_Arab", "ars_Arab", "ary_Arab", "arz_Arab", "asm_Beng", "ast_Latn", "awa_Deva", "ayr_Latn", "azb_Arab", "azj_Latn", "bak_Cyrl", "bam_Latn", "ban_Latn", "bel_Cyrl", "bem_Latn", "ben_Beng", "bho_Deva", "bjn_Arab", "bjn_Latn", "bod_Tibt", "bos_Latn", "bug_Latn", "bul_Cyrl", "cat_Latn", "ceb_Latn", "ces_Latn", "cjk_Latn", "ckb_Arab", "crh_Latn", "cym_Latn", "dan_Latn", "deu_Latn", "dik_Latn", "dyu_Latn", "dzo_Tibt", "ell_Grek", "eng_Latn", "epo_Latn", "est_Latn", "eus_Latn", "ewe_Latn", "fao_Latn", "pes_Arab", "fij_Latn", "fin_Latn", "fon_Latn", "fra_Latn", "fur_Latn", "fuv_Latn", "gla_Latn", "gle_Latn", "glg_Latn", "grn_Latn", "guj_Gujr", "hat_Latn", "hau_Latn", "heb_Hebr", "hin_Deva", "hne_Deva", "hrv_Latn", "hun_Latn", "hye_Armn", "ibo_Latn", "ilo_Latn", "ind_Latn", "isl_Latn", "ita_Latn", "jav_Latn", "jpn_Jpan", "kab_Latn", "kac_Latn", "kam_Latn", "kan_Knda", "kas_Arab", "kas_Deva", "kat_Geor", "knc_Arab", "knc_Latn", "kaz_Cyrl", "kbp_Latn", "kea_Latn", "khm_Khmr", "kik_Latn", "kin_Latn", "kir_Cyrl", "kmb_Latn", "kon_Latn", "kor_Hang", "kmr_Latn", "lao_Laoo", "lvs_Latn", "lij_Latn", "lim_Latn", "lin_Latn", "lit_Latn", "lmo_Latn", "ltg_Latn", "ltz_Latn", "lua_Latn", "lug_Latn", "luo_Latn", "lus_Latn", "mag_Deva", "mai_Deva", "mal_Mlym", "mar_Deva", "min_Latn", "mkd_Cyrl", "plt_Latn", "mlt_Latn", "mni_Beng", "khk_Cyrl", "mos_Latn", "mri_Latn", "zsm_Latn", "mya_Mymr", "nld_Latn", "nno_Latn", "nob_Latn", "npi_Deva", "nso_Latn", "nus_Latn", "nya_Latn", "oci_Latn", "gaz_Latn", "ory_Orya", "pag_Latn", "pan_Guru", "pap_Latn", "pol_Latn", "por_Latn", "prs_Arab", "pbt_Arab", "quy_Latn", "ron_Latn", "run_Latn", "rus_Cyrl", "sag_Latn", "san_Deva", "sat_Beng", "scn_Latn", "shn_Mymr", "sin_Sinh", "slk_Latn", "slv_Latn", "smo_Latn", "sna_Latn", "snd_Arab", "som_Latn", "sot_Latn", "spa_Latn", "als_Latn", "srd_Latn", "srp_Cyrl", "ssw_Latn", "sun_Latn", "swe_Latn", "swh_Latn", "szl_Latn", "tam_Taml", "tat_Cyrl", "tel_Telu", "tgk_Cyrl", "tgl_Latn", "tha_Thai", "tir_Ethi", "taq_Latn", "taq_Tfng", "tpi_Latn", "tsn_Latn", "tso_Latn", "tuk_Latn", "tum_Latn", "tur_Latn", "twi_Latn", "tzm_Tfng", "uig_Arab", "ukr_Cyrl", "umb_Latn", "urd_Arab", "uzn_Latn", "vec_Latn", "vie_Latn", "war_Latn", "wol_Latn", "xho_Latn", "ydd_Hebr", "yor_Latn", "yue_Hant", "zho_Hans", "zho_Hant", "zul_Latn"]
target_lang_index = "ru" #@param {type:"string"}
target_lang = "dan_Latn" #@param ["ace_Arab", "ace_Latn", "acm_Arab", "acq_Arab", "aeb_Arab", "afr_Latn", "ajp_Arab", "aka_Latn", "amh_Ethi", "apc_Arab", "arb_Arab", "ars_Arab", "ary_Arab", "arz_Arab", "asm_Beng", "ast_Latn", "awa_Deva", "ayr_Latn", "azb_Arab", "azj_Latn", "bak_Cyrl", "bam_Latn", "ban_Latn", "bel_Cyrl", "bem_Latn", "ben_Beng", "bho_Deva", "bjn_Arab", "bjn_Latn", "bod_Tibt", "bos_Latn", "bug_Latn", "bul_Cyrl", "cat_Latn", "ceb_Latn", "ces_Latn", "cjk_Latn", "ckb_Arab", "crh_Latn", "cym_Latn", "dan_Latn", "deu_Latn", "dik_Latn", "dyu_Latn", "dzo_Tibt", "ell_Grek", "eng_Latn", "epo_Latn", "est_Latn", "eus_Latn", "ewe_Latn", "fao_Latn", "pes_Arab", "fij_Latn", "fin_Latn", "fon_Latn", "fra_Latn", "fur_Latn", "fuv_Latn", "gla_Latn", "gle_Latn", "glg_Latn", "grn_Latn", "guj_Gujr", "hat_Latn", "hau_Latn", "heb_Hebr", "hin_Deva", "hne_Deva", "hrv_Latn", "hun_Latn", "hye_Armn", "ibo_Latn", "ilo_Latn", "ind_Latn", "isl_Latn", "ita_Latn", "jav_Latn", "jpn_Jpan", "kab_Latn", "kac_Latn", "kam_Latn", "kan_Knda", "kas_Arab", "kas_Deva", "kat_Geor", "knc_Arab", "knc_Latn", "kaz_Cyrl", "kbp_Latn", "kea_Latn", "khm_Khmr", "kik_Latn", "kin_Latn", "kir_Cyrl", "kmb_Latn", "kon_Latn", "kor_Hang", "kmr_Latn", "lao_Laoo", "lvs_Latn", "lij_Latn", "lim_Latn", "lin_Latn", "lit_Latn", "lmo_Latn", "ltg_Latn", "ltz_Latn", "lua_Latn", "lug_Latn", "luo_Latn", "lus_Latn", "mag_Deva", "mai_Deva", "mal_Mlym", "mar_Deva", "min_Latn", "mkd_Cyrl", "plt_Latn", "mlt_Latn", "mni_Beng", "khk_Cyrl", "mos_Latn", "mri_Latn", "zsm_Latn", "mya_Mymr", "nld_Latn", "nno_Latn", "nob_Latn", "npi_Deva", "nso_Latn", "nus_Latn", "nya_Latn", "oci_Latn", "gaz_Latn", "ory_Orya", "pag_Latn", "pan_Guru", "pap_Latn", "pol_Latn", "por_Latn", "prs_Arab", "pbt_Arab", "quy_Latn", "ron_Latn", "run_Latn", "rus_Cyrl", "sag_Latn", "san_Deva", "sat_Beng", "scn_Latn", "shn_Mymr", "sin_Sinh", "slk_Latn", "slv_Latn", "smo_Latn", "sna_Latn", "snd_Arab", "som_Latn", "sot_Latn", "spa_Latn", "als_Latn", "srd_Latn", "srp_Cyrl", "ssw_Latn", "sun_Latn", "swe_Latn", "swh_Latn", "szl_Latn", "tam_Taml", "tat_Cyrl", "tel_Telu", "tgk_Cyrl", "tgl_Latn", "tha_Thai", "tir_Ethi", "taq_Latn", "taq_Tfng", "tpi_Latn", "tsn_Latn", "tso_Latn", "tuk_Latn", "tum_Latn", "tur_Latn", "twi_Latn", "tzm_Tfng", "uig_Arab", "ukr_Cyrl", "umb_Latn", "urd_Arab", "uzn_Latn", "vec_Latn", "vie_Latn", "war_Latn", "wol_Latn", "xho_Latn", "ydd_Hebr", "yor_Latn", "yue_Hant", "zho_Hans", "zho_Hant", "zul_Latn"]
# Checkpoints written during training and reloaded for evaluation go here.
MODEL_SAVE_PATH = f'/content/models/nllb-{src_lang}-{target_lang}'

def word_tokenize(text):
    """Split a text into words, numbers, and punctuation marks.

    Intended for languages where words are separated by spaces.

    Args:
        text: the string to split.

    Returns:
        list[str]: runs of word characters and single non-space,
        non-word characters, in order of appearance.
    """
    # Raw string: `\w` and `\s` are regex escapes, not string escapes, and a
    # non-raw literal triggers an invalid-escape warning on current Python.
    return re.findall(r'(\w+|[^\w\s])', text)

# Measure how many tokenizer tokens a "word" costs on a sample of the training
# data. Cap the sample at the training-set size: the split cell supports
# datasets with fewer than 10000 rows, and sample() raises when asked for more
# rows than exist.
sample_size = min(10000, len(df_train))
smpl = df_train.sample(sample_size, random_state=1)
smpl['src_lang_tokens'] = smpl[src_lang_index].apply(tokenizer.tokenize)
smpl['target_lang_tokens'] = smpl[target_lang_index].apply(tokenizer.tokenize)
smpl['src_lang_words'] = smpl[src_lang_index].apply(word_tokenize)
smpl['target_lang_words'] = smpl[target_lang_index].apply(word_tokenize)
length_columns = [
    'src_lang_tokens', 'target_lang_tokens', 'src_lang_words', 'target_lang_words',
]
# Column-wise Series.map replaces the deprecated DataFrame.applymap.
stats = smpl[length_columns].apply(lambda column: column.map(len)).describe()
print("原始语言token转化比：",stats['src_lang_tokens']['mean'] / stats['src_lang_words']['mean'])
print("目标语言token转化比：",stats['target_lang_tokens']['mean'] / stats['target_lang_words']['mean'])
# Source texts whose tokenization contains the unknown-token id.
texts_with_unk = [
    text for text in tqdm(trans_df[src_lang_index])
    if tokenizer.unk_token_id in tokenizer(text).input_ids
]
print("未知符号数量：",len(texts_with_unk))


# Moses punctuation normalizer configured with lang="en".
mpn = MosesPunctNormalizer(lang="en")
# Rebuild the normalizer's (pattern, replacement) list with each pattern
# pre-compiled by re.compile.
mpn.substitutions = [
    (re.compile(r), sub) for r, sub in mpn.substitutions
]

def get_non_printing_char_replacer(replace_by: str = " "):
    """Create a function that replaces non-printing characters in a string.

    A translation table is built once, covering every code point whose
    Unicode general category is in the "Other" group (same as \\p{C} in perl,
    see https://www.unicode.org/reports/tr44/#General_Category_Values).

    Args:
        replace_by: text substituted for each non-printing character.

    Returns:
        A callable taking a string and returning it with those characters
        replaced.
    """
    other_categories = {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    translation_table = {}
    for code_point in range(sys.maxunicode + 1):
        if unicodedata.category(chr(code_point)) in other_categories:
            translation_table[code_point] = replace_by

    def replace_non_printing_char(line) -> str:
        return line.translate(translation_table)

    return replace_non_printing_char

# Replacer that turns every non-printing character into a single space.
replace_nonprint = get_non_printing_char_replacer(" ")

def preproc(text):
    """Normalize a text before tokenization.

    Applies, in order: Moses punctuation normalization, replacement of
    non-printing characters, and Unicode NFKC normalization (which folds
    compatibility characters such as stylized mathematical letters into
    their plain equivalents).
    """
    punct_normalized = mpn.normalize(text)
    printable_only = replace_nonprint(punct_normalized)
    return unicodedata.normalize("NFKC", printable_only)
# Of the texts that produced an unknown token, keep those that still do after
# preproc() normalization.
texts_with_unk_normed = [
    text for text in tqdm(texts_with_unk)
    if tokenizer.unk_token_id in tokenizer(preproc(text)).input_ids
]
print("处理掉非标准标点符号后，未知符号数量：",len(texts_with_unk_normed))
# Last expression of the cell: the notebook renders the length statistics.
stats

In [None]:
#@title **扩展词汇表（可选）【暂时别用，会造成CUDA的异常】**
# @markdown <br/>在上面处理掉非标准标点符号后，如果还有未知符号，那么就可以考虑扩展词汇表，我们会筛选出出现频率大于3的字符，来扩充词汇表。</font>
# @markdown <br/><br/>我们将文本转储到一个纯文本文件中，并在此文件上训练一个新的句子分词器模型，以便将其标记添加到现有的NLLB分词器中。Sentencepiece是训练分词器的流行算法之一。</font>
# @markdown <br/><br/>在训练了一个新的分词器之后，我用它执行了一个“外科手术”：从标准NLLB分词器中提取出sentencepiece模型，并用新的分词器来丰富在原始NLLB分词器中缺少的词条（基于sentencepiece仓库中的示例）。</font>
# @markdown <br/><br/>最后我们需要更新神经网络权重，为新添加的标记添加新的嵌入。在 NLLB 中，标记嵌入位于名为 shared 的参数中。它既用于编码器和解码器输入嵌入，也用于预测下一个令牌分布的最后一个解码器层。</font>
from collections import Counter
import sentencepiece as spm
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
# Gather every non-null source and target sentence.
all_texts = trans_df[src_lang_index].dropna().tolist() + trans_df[target_lang_index].dropna().tolist()
all_text_normalized = [preproc(t) for t in tqdm(all_texts)]
# Characters that occur at least 4 times in the normalized texts (space
# excluded) must be present in the new vocabulary.
chars_cnt = Counter(c for t in all_text_normalized for c in t)
required_chars = ''.join([
    k for k, v in chars_cnt.most_common()
    if v >= 4 and k not in ' '
])
# Dump the texts to a plain-text file and train a new SentencePiece model on
# it, so that its pieces can later be added to the existing NLLB tokenizer.
all_texts_file = 'all_texts_plain.txt'
SPM_PREFIX = 'spm_new_text_16k'
# Write explicitly as UTF-8 instead of relying on the platform's default
# encoding.
with open(all_texts_file, 'w', encoding='utf-8') as f:
    for text in all_texts:
        print(text, file=f)

spm.SentencePieceTrainer.train(
    input=all_texts_file,
    model_prefix=SPM_PREFIX,
    vocab_size=2**14,  # 16K
    character_coverage = 1,
    num_threads=16,
    train_extremely_large_corpus=False,
    add_dummy_prefix=False,
    max_sentencepiece_length=128,
    max_sentence_length=4192*4,
    pad_id=0,
    eos_id=1,
    unk_id=2,
    bos_id=-1,
    required_chars=required_chars,
)
# NOTE(review): the cell title warns that this step currently leads to CUDA
# errors; the cause is not visible in this cell and needs to be confirmed
# before the cell is used.
# Load the NLLB tokenizer and the newly trained SentencePiece model, and parse
# both into ModelProto objects.
tokenizer = NllbTokenizer.from_pretrained(model_name)
sp_trained = spm.SentencePieceProcessor(model_file=f'{SPM_PREFIX}.model')
added_spm = sp_pb2_model.ModelProto()
added_spm.ParseFromString(sp_trained.serialized_model_proto())
old_spm = sp_pb2_model.ModelProto()
old_spm.ParseFromString(tokenizer.sp_model.serialized_model_proto())

# Append the pieces that the NLLB model does not have yet.
nllb_tokens_set = {p.piece for p in old_spm.pieces}
# Score of the last existing piece; used as an offset for the new pieces.
prev_min_score = old_spm.pieces[-1].score
for p in added_spm.pieces:
    piece = p.piece
    if piece not in nllb_tokens_set:
        new_p = sp_pb2_model.ModelProto().SentencePiece()
        new_p.piece = piece
        # for all new tokens, I'll set a lower score (priority)
        new_p.score = p.score + prev_min_score
        old_spm.pieces.append(new_p)

# Save the merged SentencePiece model.
NEW_SPM_NAME = 'spm_nllb_new_268k.model'
with open(NEW_SPM_NAME, 'wb') as f:
    f.write(old_spm.SerializeToString())

# Load the original tokenizer and one backed by the merged vocabulary file.
tokenizer_old = NllbTokenizer.from_pretrained(model_name)
tokenizer = NllbTokenizer.from_pretrained(model_name, vocab_file=NEW_SPM_NAME)
print('原始版本分词器词汇量：',len(tokenizer_old),'新版本分词器词汇量：',len(tokenizer))
# Tokens present in the new vocabulary but not in the old one.
added_vocab = set(tokenizer.get_vocab()).difference(set(tokenizer_old.get_vocab()))
print('新增的词汇量：',len(added_vocab))

# Reload the model and resize its embedding matrix to the new vocabulary size.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.resize_token_embeddings(len(tokenizer))

# Initialize each new token's embedding as the mean of the embeddings of the
# ids the OLD tokenizer produces for that token's text.
for t in tqdm(added_vocab):
    tt = tokenizer_old(t, add_special_tokens=False).input_ids
    if len(tt) == 0:
        # Fall back to the unknown-token id when the old tokenizer yields nothing.
        tt = [tokenizer_old.unk_token_id]
    idx = tokenizer.convert_tokens_to_ids(t)
    model.model.shared.weight.data[idx] = model.model.shared.weight.data[tt].mean(0)

In [None]:
#@title **训练模型**
from transformers.optimization import Adafactor
from transformers import get_constant_schedule_with_warmup
import random
import torch
import gc
import numpy as np
import os
# NOTE(review): presumably meant to make CUDA errors surface at the failing
# call; confirm it is set early enough to take effect in this runtime.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
model.cuda()
# Adafactor over the trainable parameters with a fixed learning rate
# (scale_parameter and relative_step are both disabled).
optimizer = Adafactor(
    [p for p in model.parameters() if p.requires_grad],
    scale_parameter=False,
    relative_step=False,
    lr=1e-4,
    clip_threshold=1.0,
    weight_decay=1e-3,
)
scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=1000)
# (dataset column name, NLLB language code) for the two training languages.
LANGS = [(src_lang_index, src_lang), (target_lang_index, target_lang)]

def get_batch_pairs(batch_size, data=df_train):
    """Draw a random batch of parallel sentences in a random direction.

    The two entries of LANGS are shuffled to decide which language is the
    input side, then `batch_size` rows are drawn from `data` with replacement
    and both sides are passed through preproc().

    Returns:
        tuple: (input texts, output texts, input language code,
        output language code).
    """
    (first_col, first_code), (second_col, second_code) = random.sample(LANGS, 2)
    rows = [
        data.iloc[random.randint(0, len(data)-1)]
        for _ in range(batch_size)
    ]
    xx = [preproc(row[first_col]) for row in rows]
    yy = [preproc(row[second_col]) for row in rows]
    return xx, yy, first_code, second_code

# Smoke test: print one single-sentence batch.
print(get_batch_pairs(1))
def cleanup():
    """Run garbage collection, then empty the CUDA cache to free GPU memory."""
    gc.collect()
    torch.cuda.empty_cache()

batch_size = 8  #@param {type:"integer"}
training_steps = 60000  #@param {type:"integer"}
max_length = 128
losses = []

model.train()
x, y, loss = None, None, None
cleanup()

# Starting the range at len(losses) lets a re-run continue from the number of
# steps already recorded in `losses`.
tq = trange(len(losses), training_steps)
for i in tq:
    xx, yy, lang1, lang2 = get_batch_pairs(batch_size)
    try:
        # Tokenize each side with its own language code set as src_lang.
        tokenizer.src_lang = lang1
        x = tokenizer(xx, return_tensors='pt', padding=True, truncation=True, max_length=max_length).to(model.device)
        tokenizer.src_lang = lang2
        y = tokenizer(yy, return_tensors='pt', padding=True, truncation=True, max_length=max_length).to(model.device)
        # -100 is a magic value ignored in the loss function
        # because we don't want the model to learn to predict padding ids
        y.input_ids[y.input_ids == tokenizer.pad_token_id] = -100

        loss = model(**x, labels=y.input_ids).loss
        loss.backward()
        losses.append(loss.item())
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

    except RuntimeError as e:  # usually, it is out-of-memory
        # Drop the failed batch, release its memory and carry on.
        optimizer.zero_grad(set_to_none=True)
        x, y, loss = None, None, None
        cleanup()
        print('error', max(len(s) for s in xx + yy), e)
        continue

    if i % 1000 == 0:
        # each 1000 steps, I report average loss at these steps
        print('Epoch的批次：',i,'平均损失函数：',np.mean(losses[-1000:]))

    if i % 1000 == 0 and i > 0:
        model.save_pretrained(MODEL_SAVE_PATH)
        tokenizer.save_pretrained(MODEL_SAVE_PATH)

# The periodic checkpoint only fires on multiples of 1000, so without this
# final save the steps after the last checkpoint would be lost.
model.save_pretrained(MODEL_SAVE_PATH)
tokenizer.save_pretrained(MODEL_SAVE_PATH)

In [None]:
#@title **模型评估**
# @markdown <br/> 机器翻译质量的两个最受欢迎的自动指标是 BLEU 和 ChrF++。它们都计算翻译和参考文本之间的相似性百分比。但是，它们对相似性的定义略有不同;例如，BLEU 只奖励全字匹配，而 ChrF++ 即使只有单词部分匹配，也会给出正分（例如，ChrF++ 会将翻译“течёт холод”视为与引用“несёт холодом”的相似度约为 40%，而 BLEU 将报告零相似度）。</font>
# @markdown <br/> BLEU输出的各个部分含义如下：</font>
# @markdown <br/> 总BLEU分数：例如，BLEU = 1.88或BLEU = 3.97，表示整体的翻译质量评分。</font>
# @markdown <br/> n-gram精确度：例如，8.1/2.4/1.3/0.5，分别对应1-gram、2-gram、3-gram和4-gram的精确度百分比。</font>
# @markdown <br/> BP（brevity penalty）：如果机器翻译的长度小于参考翻译的长度，会应用惩罚因子以避免过短的翻译。</font>
# @markdown <br/> ratio：翻译长度与参考长度的比率。</font>
# @markdown <br/> hyp_len和ref_len：分别是翻译文本和参考文本的长度。</font>
# @markdown <br/> chrF++输出的是一个总分，例如chrF2++ = 14.58或chrF2++ = 20.17，表示翻译质量的评分，分数越高翻译质量越好。</font>
!pip install sacrebleu
cleanup()
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
# Reuse `model` and `tokenizer` if an earlier cell already defined them;
# otherwise load the fine-tuned checkpoint from MODEL_SAVE_PATH.
try:
    # Referencing an undefined name raises NameError.
    model
    tokenizer
    print("model and tokenizer are already defined.")
except NameError:
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    model_load_name = MODEL_SAVE_PATH  # replace with your own model name if needed
    model = AutoModelForSeq2SeqLM.from_pretrained(model_load_name).cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_load_name)
    print("model and tokenizer have been defined.")

import sacrebleu
def translate(text, src_lang, tgt_lang, a=32, b=3, max_input_length=1024, num_beams=4, **kwargs):
    """Turn a text or a list of texts into a list of translations.

    Args:
        text: one string or a batch of strings to translate.
        src_lang: language code assigned to tokenizer.src_lang.
        tgt_lang: language code assigned to tokenizer.tgt_lang and forced as
            the first generated token.
        a, b: generation is capped at int(a + b * padded input length) new
            tokens.
        max_input_length: inputs are truncated to this many tokens.
        num_beams: beam-search width.
        **kwargs: forwarded to model.generate.

    Returns:
        list[str]: decoded translations with special tokens removed.
    """
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    encoded = tokenizer(
        text,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_input_length,
    )
    model.eval()  # turn off training mode
    padded_length = encoded.input_ids.shape[1]
    generated = model.generate(
        **encoded.to(model.device),
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        max_new_tokens=int(a + b * padded_length),
        num_beams=num_beams,
        **kwargs
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)


def batched_translate(texts, batch_size=8, **kwargs):
    """Translate texts in batches of similar length.

    Texts are sorted by length (longest first) so that each batch needs
    little padding, translated batch by batch with translate(), and the
    results are put back into the original order.

    Args:
        texts: sequence of strings to translate.
        batch_size: number of texts per translate() call.
        **kwargs: forwarded to translate() (e.g. src_lang, tgt_lang).

    Returns:
        list[str]: one translation per input text, in input order; an empty
        list when `texts` is empty.
    """
    texts = list(texts)
    if not texts:
        # zip(*sorted(...)) cannot be unpacked for an empty input.
        return []
    # Stable sort: equal-length texts keep their relative order.
    order = sorted(range(len(texts)), key=lambda idx: len(texts[idx]), reverse=True)
    sorted_texts = [texts[idx] for idx in order]
    results = []
    for start in trange(0, len(sorted_texts), batch_size):
        results.extend(translate(sorted_texts[start: start + batch_size], **kwargs))
    # Undo the length sort.
    restored = [None] * len(texts)
    for original_idx, translation in zip(order, results):
        restored[original_idx] = translation
    return restored

# Names of the columns that will hold the machine translations.
target_lang_column1 = target_lang + "_translated"
target_lang_column2 = src_lang + "_translated"
# Source -> target direction.
texts_to_translate = df_test[src_lang_index].tolist()
translated_texts = batched_translate(texts_to_translate, src_lang=src_lang, tgt_lang=target_lang)
df_test[target_lang_column1] = translated_texts
# Target -> source direction.
texts_to_translate = df_test[target_lang_index].tolist()
translated_texts = batched_translate(texts_to_translate, src_lang=target_lang, tgt_lang=src_lang)
df_test[target_lang_column2] = translated_texts

bleu_calc = sacrebleu.BLEU()
chrf_calc = sacrebleu.CHRF(word_order=2)  # this metric is called ChrF++

# First two scores: translations into the source language against the
# source-language reference column.
print(bleu_calc.corpus_score(df_test[target_lang_column2].tolist(), [df_test[src_lang_index].tolist()]))
print(chrf_calc.corpus_score(df_test[target_lang_column2].tolist(), [df_test[src_lang_index].tolist()]))
# Last two scores: translations into the target language against the
# target-language reference column.
print(bleu_calc.corpus_score(df_test[target_lang_column1].tolist(), [df_test[target_lang_index].tolist()]))
print(chrf_calc.corpus_score(df_test[target_lang_column1].tolist(), [df_test[target_lang_index].tolist()]))

In [None]:
#@title **模型发布**
# Repository name (under the "yonyou-sg/" namespace) to publish to.
new_model_name = "test" #@param {type:"string"}
# Push both the model and the tokenizer to the same Hub repository.
model.push_to_hub("yonyou-sg/"+new_model_name)
tokenizer.push_to_hub("yonyou-sg/"+new_model_name)