In [None]:
!pip install -q optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html
!pip install -q -e . -f https://storage.googleapis.com/libtpu-releases/index.html
!pip install -q trl peft
!pip install -q ipywidgets widgetsnbextension
!pip install -q datasets evaluate accelerate
!pip install -q nltk jieba evaluate rouge_score sacrebleu

!export PJRT_DEVICE=TPU

In [None]:
import torch
from transformers import AutoModelForCausalLM,  AutoTokenizer, DataCollatorForLanguageModeling
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
from optimum.tpu import fsdp_v2
import os
from transformers import TrainerCallback
import torch_xla.core.xla_model as xm
import json
from datasets import Dataset
import evaluate
import jieba
import numpy as np
from accelerate.utils import extract_model_from_parallel

os.environ["TPU_NAME"] = "v3-8"
# Remove TPU_PROCESS_ADDRESSES if it is set (presumably so a leftover
# multi-process address list does not interfere — confirm). `pop` with a
# default is a no-op when the variable is absent, so no try/except is needed.
# The old bare `except:` also swallowed KeyboardInterrupt and SystemExit.
os.environ.pop('TPU_PROCESS_ADDRESSES', None)

In [None]:
def load_corpus(file_path, seed=42):
    """Load the translation corpus and wrap every pair as a ChatML sample.

    The JSON file must contain ``"train"`` and ``"validation"`` lists. Each entry
    is unpacked as ``src, tgt, *extra``: ``tgt`` becomes a Simplified-Chinese
    sample, and the first extra element (if any) becomes a Traditional-Chinese
    sample.

    Args:
        file_path: Path to the UTF-8 JSON corpus.
        seed: Seed for the local RNG that shuffles the training samples.

    Returns:
        ``(train_data, val_data)``: lists of dicts with keys ``systemprompt``,
        ``userprompt``, ``prompt`` and ``completion``. In each list all
        Simplified-Chinese samples come before all Traditional-Chinese ones.
        Training samples are shuffled within each group; validation samples keep
        file order.
    """
    import random

    # A private RNG keeps the shuffle reproducible without reseeding the
    # interpreter-wide `random` module as a side effect.
    rng = random.Random(seed)

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    def wrap_prompt(src, tgt, lang):
        """Build one ChatML sample that translates `src` into `lang` with reference `tgt`."""
        # NOTE(review): "逻确保" in the last bullet of section 2 looks like a typo
        # for "确保". It is left unchanged so the prompt stays identical to any
        # copy used at inference time.
        system_prompt = f"""<|im_start|>system
你是专注于xxx领域的资深翻译专家，专门负责将英文文档精准翻译成 {lang}。

## 翻译原则：

### 1. 术语准确性
- 严格使用xxx行业标准术语
- 保持技术参数的精确性和专业性
- 遵循xxx领域的权威表达方式

### 2. 句式结构与技术逻辑忠实性
- 英文常见被动语态，在{lang}中需转换为符合习惯的**主动语态或自然表述**，避免生硬直译
- 对复杂英文长句，按{lang}习惯进行合理拆分或重组，确保**逻辑清晰、易于理解**，同时**不丢失任何细节**
- 逻确保技术逻辑连接词翻译准确，因果关系、条件关系明确

**请逐句审阅，严格遵循以上所有规范。**<|im_end|>"""

        # The user turn ends with the assistant header. ChatML requires it before
        # the reply, and DataCollatorForCompletionOnlyLM finds the trainable span
        # by searching for "<|im_start|>assistant\n". Without the header every
        # label in the packed sequence is masked.
        user_prompt = f"<|im_start|>user\n### 请将以下文本准确翻译成{lang}：\n{src}<|im_end|>\n<|im_start|>assistant\n"

        # Assistant reply, closed with <|im_end|> so the model learns to stop.
        assistant_response = f"{tgt}<|im_end|>"

        return {
            "systemprompt": system_prompt,
            "userprompt": user_prompt,
            "prompt": f"{system_prompt}\n{user_prompt}",
            "completion": assistant_response,
        }

    def process_pairs(pairs, shuffle=True):
        """Wrap pairs per target language; Simplified samples first, then Traditional."""
        zhcn_samples = []
        zhtw_samples = []
        for src, tgt, *extra_targets in pairs:
            zhcn_samples.append(wrap_prompt(src, tgt, "简体中文"))
            if extra_targets:
                zhtw_samples.append(wrap_prompt(src, extra_targets[0], "繁体中文"))

        # Groups are never interleaved, so pack_sequences can build
        # single-language sequences.
        if shuffle:
            rng.shuffle(zhcn_samples)
            rng.shuffle(zhtw_samples)
        return zhcn_samples + zhtw_samples

    train_data = process_pairs(data['train'])
    val_data = process_pairs(data['validation'], shuffle=False)
    return train_data, val_data

In [None]:
def pack_sequences(examples, tokenizer, max_length=768, full_system_interval=4, seed=42):
    """Pack translation samples into fixed-length, single-language training texts.

    Samples are taken in input order. Each is appended to the current sequence
    until the next one would exceed ``max_length`` tokens or the target language
    changes.
    - A sequence starts with the full system prompt.
    - Later samples get the short system prompt.
    - The exception is every ``full_system_interval``-th sample, which repeats
      the full prompt (preceded by a newline).

    Each finished sequence is topped up to ``max_length`` with randomly chosen
    samples of the SAME language that still fit (user turn + completion, without
    a system prompt), then with the pad token.

    Args:
        examples: Dicts with ``systemprompt``, ``userprompt`` and ``completion``
            keys (as produced by ``load_corpus``).
        tokenizer: Used to encode the pieces and decode each packed sequence.
        max_length: Target token length of every packed sequence. A single
            sample longer than this is emitted unpadded (presumably truncated
            later by the trainer's max sequence length — confirm).
        full_system_interval: Repeat the full system prompt every this many
            samples within a sequence; must be >= 1.
        seed: Seed of the local RNG used to pick filler samples.

    Returns:
        A list of ``{"text": str}`` dicts, one per packed sequence.

    Raises:
        ValueError: If ``full_system_interval`` < 1.

    NOTE(review): sequences are decoded to text and re-tokenized downstream;
    the token count after re-tokenization may differ slightly from ``max_length``.
    """
    import bisect
    import random

    if full_system_interval < 1:
        raise ValueError("full_system_interval must be >= 1")

    # Local RNG: reproducible filler choice without reseeding the global `random` module.
    rng = random.Random(seed)
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    def encode(text):
        """Tokenize `text` without special tokens into a fresh list of ids."""
        return list(tokenizer(text, add_special_tokens=False)["input_ids"])

    def extract_target_language(user_prompt):
        """Return the target language named in the user prompt."""
        return "简体中文" if "简体中文" in user_prompt else "繁体中文"

    def create_short_system_prompt(language):
        """One-line system prompt used between the periodic full system prompts."""
        return f"\n<|im_start|>system\n你是专注于xxx领域的资深翻译专家，专门负责将英文文档精准翻译成{language}。<|im_end|>"

    # Pass 1: encode every sample (user turn + completion, no system prompt),
    # remember its language, and encode each language's system prompts once.
    encoded_samples = []   # [(language, tokens)] in input order
    pool_by_language = {}  # language -> [tokens], later used as filler
    system_prompts = {}    # language -> (full_tokens, short_tokens)
    for example in examples:
        language = extract_target_language(example["userprompt"])
        tokens = encode("\n" + example["userprompt"]) + encode(example["completion"])
        encoded_samples.append((language, tokens))
        pool_by_language.setdefault(language, []).append(tokens)
        if language not in system_prompts:
            system_prompts[language] = (
                encode(example["systemprompt"]),
                encode(create_short_system_prompt(language)),
            )

    # Filler pools sorted by length, plus a parallel length list for bisect.
    filler_pools = {}
    for language, pool in pool_by_language.items():
        ordered = sorted(pool, key=len)
        filler_pools[language] = (ordered, [len(t) for t in ordered])

    # Separator before a full system prompt repeated mid-sequence. The first
    # full prompt of a sequence needs none, and the short prompt and user turns
    # already start with "\n".
    newline_tokens = encode("\n")

    def pad_to_length(tokens, language):
        """Extend `tokens` in place up to max_length with same-language fillers, then padding."""
        remaining = max_length - len(tokens)
        if remaining <= 0:
            return tokens
        ordered, lengths = filler_pools.get(language, ([], []))
        # Every pool entry whose length is <= remaining.
        candidates = ordered[:bisect.bisect_right(lengths, remaining)]
        while remaining > 0 and candidates:
            chosen = rng.choice(candidates)
            tokens.extend(chosen)
            remaining -= len(chosen)
            # Use each candidate entry once, then drop those that no longer fit.
            candidates.remove(chosen)
            candidates = [t for t in candidates if len(t) <= remaining]
        if remaining > 0:
            tokens.extend([pad_token_id] * remaining)
        return tokens

    packed_inputs = []

    def flush(tokens, language):
        """Pad and decode a finished sequence; no-op for an empty one."""
        if tokens:
            packed_inputs.append({"text": tokenizer.decode(pad_to_length(tokens, language))})

    # Pass 2: pack in input order.
    current_tokens = []
    current_language = None
    samples_in_sequence = 0
    for language, tokens in encoded_samples:
        full_prompt, short_prompt = system_prompts[language]

        if language != current_language:
            # Language switch: close the previous sequence and pad it with
            # fillers of ITS OWN language, never the incoming one.
            flush(current_tokens, current_language)
            current_tokens = full_prompt + tokens
            current_language = language
            samples_in_sequence = 1
            continue

        if samples_in_sequence % full_system_interval == 0:
            candidate = newline_tokens + full_prompt + tokens
        else:
            candidate = short_prompt + tokens

        if len(current_tokens) + len(candidate) > max_length:
            # Does not fit: close this sequence and restart with the full prompt.
            flush(current_tokens, current_language)
            current_tokens = full_prompt + tokens
            samples_in_sequence = 1
        else:
            current_tokens.extend(candidate)
            samples_in_sequence += 1

    flush(current_tokens, current_language)
    return packed_inputs

In [6]:
def storage_ptr(tensor):
    """Data pointer of the tensor's untyped storage; 0 unless the tensor is on CPU or CUDA."""
    on_cpu_or_cuda = tensor.is_cuda or tensor.device.type == 'cpu'
    if not on_cpu_or_cuda:
        return 0
    return tensor.untyped_storage().data_ptr()


def storage_size(tensor):
    """Byte size of the tensor's untyped storage; 0 unless the tensor is on CPU or CUDA."""
    on_cpu_or_cuda = tensor.is_cuda or tensor.device.type == 'cpu'
    if not on_cpu_or_cuda:
        return 0
    return tensor.untyped_storage().nbytes()


def check_weights_on_cpu_with_storage(state_dict):
    """Verify that every tensor in ``state_dict`` is on CPU and has non-empty storage.

    Checks run per entry in this order: device, storage pointer, storage size.

    Returns:
        ``(True, message)`` when every entry passes, otherwise ``(False, message)``
        describing the first failing entry.
    """
    for param_name, weight in state_dict.items():
        if weight.device.type != 'cpu':
            return False, f"Parameter '{param_name}' is on device {weight.device}, not CPU."
        if storage_ptr(weight) == 0:
            return False, f"Parameter '{param_name}' has invalid storage pointer (0)."
        if storage_size(weight) == 0:
            return False, f"Parameter '{param_name}' has invalid storage size (0 bytes)."
    return True, "All weights are on CPU with valid storage."

In [None]:
import shutil
import time

# Custom checkpoint callback: keep the best (lowest eval-loss) trainable weights
# in host memory and write them to disk once training ends.
class CustomCheckpointCallback(TrainerCallback):
    """Track the best eval loss, cache the matching trainable weights, and save them at the end.

    Behaviour:
    - Improvement: on each evaluation that lowers ``eval_loss``, parameters whose
      name contains ``lora_`` or that require grad are copied to CPU memory.
      Nothing is written to disk at this point.
    - Early stopping: after ``patience`` consecutive evaluations without
      improvement, training is asked to stop.
    - Train end: the cached weights — or the model's current state dict if
      nothing was cached — are saved with ``save_pretrained`` to ``output_dir``.
    """

    def __init__(self, patience=3, tokenizer=None, output_dir="./model"):
        self.best_score = float('inf')  # lowest eval loss seen so far (lower is better)
        self.patience = patience
        self.no_improvement_count = 0
        self.tokenizer = tokenizer  # stored for callers; not used by this class
        self.output_dir = output_dir  # final location of the saved checkpoint
        self.best_step = None  # global step of the best evaluation
        self.best_state_dict = None  # CPU copy of the best trainable weights

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Record a new best eval loss (caching its weights) or count toward early stopping."""
        if not metrics:
            return

        current_loss = metrics.get("eval_loss", float('inf'))
        current_step = state.global_step

        print(f"\n当前步骤: {current_step}, 当前评估损失: {current_loss}, 最佳损失: {self.best_score}")

        if current_loss < self.best_score:
            print(f"发现更好的模型 (Loss: {current_loss} < {self.best_score})")

            self.best_score = current_loss
            self.no_improvement_count = 0
            self.best_step = current_step

            # Free the previous cache before copying the new weights so two
            # copies never coexist in host memory.
            self._release_cache()
            self._cache_model_weights(kwargs.get("model"))
        else:
            self.no_improvement_count += 1
            print(f"没有改进，当前耐心计数: {self.no_improvement_count}/{self.patience}")

            if self.no_improvement_count >= self.patience:
                print(f"已经 {self.patience} 次评估没有改进，提前停止训练")
                control.should_training_stop = True

    def on_train_end(self, args, state, control, **kwargs):
        """Save the best cached weights (or the current ones) to ``self.output_dir``.

        The checkpoint is first written to ``<output_dir>.tmp``. It replaces
        ``output_dir`` only after a successful save, so a failed save never
        deletes an existing checkpoint. The cache is released afterwards.
        """
        if self.best_state_dict is not None:
            print(f"\n训练结束，保存最佳模型 (步骤: {self.best_step}, 损失: {self.best_score})")
        # When nothing was cached, _save_cached_weights falls back to the
        # model's current state dict.

        output_dir = self.output_dir
        staging_dir = output_dir.rstrip("/\\") + ".tmp"
        try:
            if os.path.exists(staging_dir):
                shutil.rmtree(staging_dir)
            os.makedirs(staging_dir, exist_ok=True)

            print(f"开始保存模型权重到: {output_dir}")
            if self._save_cached_weights(kwargs.get("model"), staging_dir):
                if os.path.exists(output_dir):
                    print(f"删除现有目标目录: {output_dir}")
                    shutil.rmtree(output_dir)
                os.replace(staging_dir, output_dir)
            else:
                shutil.rmtree(staging_dir, ignore_errors=True)

        except Exception as e:
            print(f"保存最佳模型时出错: {str(e)}")

        finally:
            self._release_cache()

    def _release_cache(self):
        """Drop the cached weights and reset the attribute to None (never `del` it)."""
        if self.best_state_dict is not None:
            import gc
            self.best_state_dict = None
            gc.collect()

    def _cache_model_weights(self, model):
        """Copy the trainable / LoRA parameters of `model` to CPU memory (no disk I/O).

        On failure the error is printed and ``best_state_dict`` is left as None.
        """
        try:
            if model is None:
                print("错误：无法获取模型实例")
                return

            print("正在将模型权重缓存到CPU内存...")

            extract_start_time = time.time()
            # Unwrap any parallel wrappers to reach the underlying (PEFT) model.
            unwrap = extract_model_from_parallel(model, recursive=True)

            # The full state_dict is used only as a diagnostic: parameters are
            # removed from it, and whatever remains (non-parameter entries) is printed.
            state_dict = unwrap.state_dict()
            to_cpu = {}
            size = 0
            for name, param in unwrap.named_parameters():
                if name in state_dict:
                    del state_dict[name]
                else:
                    print(f"参数 {name} 不存在于state_dict中")

                # Keep only adapter weights and other trainable params (e.g. modules_to_save).
                if "lora_" in name or param.requires_grad:
                    to_cpu[name] = param.data
                    size += param.data.numel() * param.data.element_size()
            if len(state_dict) > 0:
                print(f"state_dict剩余的参数: {state_dict.keys()}")

            self.best_state_dict = xm._maybe_convert_to_cpu(to_cpu)
            convert_end_time = time.time()
            print(f"TPU权重转移到CPU耗时: {convert_end_time - extract_start_time:.2f}秒，大小: {size / (1024 * 1024):.2f}MB")

        except Exception as e:
            print(f"缓存模型权重时出错: {str(e)}")
            self.best_state_dict = None

    def _save_cached_weights(self, model, output_dir):
        """Write the cached weights (or the model's current state dict) via ``save_pretrained``.

        Returns:
            True when ``save_pretrained`` completed, False otherwise. Errors are
            printed, not raised.
        """
        try:
            if model is None:
                print("错误：无法获取模型实例")
                return False

            start_time = time.time()
            unwrap = extract_model_from_parallel(model, recursive=True)
            if self.best_state_dict is None:
                # Nothing cached: fall back to the model's current weights.
                print("从TPU获取模型状态字典...")
                state_dict = xm._maybe_convert_to_cpu(unwrap.state_dict())
            else:
                print("使用缓存的模型状态字典...")
                state_dict = self.best_state_dict
            dict_end_time = time.time()
            print(f"获取/准备模型状态字典耗时: {dict_end_time - start_time:.2f}秒")

            save_start_time = time.time()
            unwrap.save_pretrained(
                output_dir,
                save_function=xm.save,
                state_dict=state_dict,
            )
            save_end_time = time.time()
            print(f"save_pretrained调用耗时: {save_end_time - save_start_time:.2f}秒")

            total_time = time.time() - start_time
            print(f"模型检查点已成功保存到 {output_dir}，总耗时: {total_time:.2f}秒")
            return True

        except Exception as e:
            print(f"保存LoRA权重时出错: {str(e)}")
            return False

In [None]:
from typing import Any, Union

class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
    """Causal-LM collator that keeps labels only for assistant replies.

    The parent's ``torch_call`` builds ``labels`` (presumably ``input_ids`` with
    padding already masked — confirm against the installed transformers).

    Masking rules:
    - Every label is set to ``ignore_index`` except the spans that follow an
      ``<|im_start|>assistant`` header, up to and including the next
      ``<|im_end|>``. Keeping ``<|im_end|>`` lets the model learn to stop.
    - A reply whose ``<|im_end|>`` is missing (e.g. cut off by truncation) is
      masked entirely.
    - If a row contains no complete reply at all, a one-time warning is printed,
      because an all-ignored batch yields no training signal.
    """

    def __init__(
        self,
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        **kwargs,
    ):
        super().__init__(*args, mlm=mlm, **kwargs)
        self.ignore_index = ignore_index

        # Markers delimiting an assistant reply in ChatML.
        self.assistant_start = "<|im_start|>assistant\n"
        self.im_end = "<|im_end|>"

        # Token-id sequences of the markers.
        self.assistant_start_tokens = self.tokenizer.encode(self.assistant_start, add_special_tokens=False)
        self.im_end_tokens = self.tokenizer.encode(self.im_end, add_special_tokens=False)
        print(f"assistant_start_tokens: {self.assistant_start_tokens}")
        print(f"im_end_tokens: {self.im_end_tokens}")

        # Warn only once about rows without any assistant reply.
        self._warned_no_reply = False

    @staticmethod
    def _matches_at(row, pos, pattern):
        """True when ``row[pos:pos + len(pattern)]`` equals the token-id list `pattern`."""
        return row[pos:pos + len(pattern)].tolist() == pattern

    def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
        batch = super().torch_call(examples)
        labels = batch["labels"]
        header = self.assistant_start_tokens
        footer = self.im_end_tokens

        for row_idx in range(len(examples)):
            row = labels[row_idx]  # view: writes through `labels` are visible here
            # Tokens before `keep_from` are already decided. [keep_from, reply_start)
            # is prompt text and gets masked once a complete reply is found.
            keep_from = 0

            for header_pos in np.where(row == header[0])[0]:
                # Skip header candidates inside an already accepted reply, and
                # candidates where only the first token matches.
                if header_pos < keep_from or not self._matches_at(row, header_pos, header):
                    continue
                reply_start = header_pos + len(header)

                for offset in np.where(row[reply_start:] == footer[0])[0]:
                    footer_pos = reply_start + offset
                    if self._matches_at(row, footer_pos, footer):
                        labels[row_idx, keep_from:reply_start] = self.ignore_index
                        keep_from = footer_pos + len(footer)
                        break

            # Everything after the last complete reply is masked too.
            if keep_from < len(row):
                labels[row_idx, keep_from:] = self.ignore_index

            if keep_from == 0 and not self._warned_no_reply:
                print("警告: 样本中未找到完整的assistant回复，该样本所有标签均被忽略（后续不再提示）")
                self._warned_no_reply = True

        return batch

In [None]:
# Base checkpoint: the tokenizer is loaded here; the model is loaded inside train().
model_id: str = "Qwen/Qwen2.5-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Metric backends used by compute_metrics (loaded left to right: BLEU, then ROUGE).
bleu, rouge = evaluate.load("sacrebleu"), evaluate.load("rouge")
    
# Evaluation metrics: BLEU and ROUGE on jieba-segmented teacher-forced predictions.
def compute_metrics(eval_preds):
    """Score teacher-forced next-token predictions against the reference replies.

    Only positions whose label is not -100 are compared. A causal LM's logits at
    position t predict token t+1, so predictions are shifted one step relative
    to the labels before comparison.

    Args:
        eval_preds: ``EvalPrediction``.
            - ``predictions``: logits ``(batch, seq, vocab)`` or token ids
              ``(batch, seq)``.
            - ``label_ids``: ``(batch, seq)``.
            Each is either one concatenated array or a list of per-batch arrays
            (``eval_do_concat_batches=False``).

    Returns:
        Dict with keys:
        - ``bleu``: sacrebleu corpus score.
        - ``rouge1``/``rouge2``/``rougeL``: rounded to 4 digits.
        All values are 0.0 when no non-empty reply could be decoded.
    """
    predictions = eval_preds.predictions
    labels = eval_preds.label_ids
    if isinstance(predictions, np.ndarray):
        # Concatenated eval output: treat it as a single batch.
        predictions, labels = [predictions], [labels]

    decoded_preds = []
    decoded_labels = []
    for batch_preds, batch_labels in zip(predictions, labels):
        batch_preds = np.asarray(batch_preds)
        token_ids = batch_preds if batch_preds.ndim == 2 else np.argmax(batch_preds, axis=-1)

        for pred_row, label_row in zip(token_ids, np.asarray(batch_labels)):
            # Align pred[t] (which predicts token t+1) with label[t + 1].
            shifted_preds = pred_row[:-1]
            shifted_labels = label_row[1:]
            valid = shifted_labels != -100

            pred_text = tokenizer.decode(shifted_preds[valid], skip_special_tokens=True)
            label_text = tokenizer.decode(shifted_labels[valid], skip_special_tokens=True)

            if pred_text.strip() and label_text.strip():  # skip empty replies
                decoded_preds.append(pred_text)
                decoded_labels.append(label_text)

    if not decoded_preds:
        return {'bleu': 0.0, 'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0}

    # Chinese has no spaces: segment words with jieba so both metrics see word units.
    decoded_preds = [' '.join(jieba.cut(pred)) for pred in decoded_preds]
    decoded_labels = [' '.join(jieba.cut(label)) for label in decoded_labels]

    bleu_score = bleu.compute(predictions=decoded_preds, references=[[ref] for ref in decoded_labels])
    print(f"bleu_score: {bleu_score}")
    # rouge_score's default tokenizer keeps only [a-z0-9], deleting every Chinese
    # character (all-zero ROUGE); split on the spaces jieba inserted instead.
    rouge_scores = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        tokenizer=lambda text: text.split(),
    )
    print(f"rouge_scores: {rouge_scores}")

    return {
        'bleu': bleu_score['score'],
        'rouge1': round(float(rouge_scores['rouge1']), 4),
        'rouge2': round(float(rouge_scores['rouge2']), 4),
        'rougeL': round(float(rouge_scores['rougeL']), 4),
    }


In [None]:
def train(corpus_path='./translation_dataset.json', max_length=1024):
    """Fine-tune ``model_id`` with LoRA on TPU (FSDP v2) and return the trainer.

    Args:
        corpus_path: JSON corpus passed to ``load_corpus``.
        max_length: Token length used both for packing and as the trainer's
            ``max_seq_length``, so the two values cannot drift apart.

    Returns:
        The ``SFTTrainer`` after ``trainer.train()`` has returned.
    """
    fsdp_v2.use_fsdp_v2()

    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    fsdp_training_args = {
        "fsdp": "full_shard auto_wrap",  # fully shard, auto-wrapping decoder layers
        "fsdp_config": {
            "transformer_layer_cls_to_wrap": ["Qwen2DecoderLayer"],
            "xla": True,
            "xla_fsdp_v2": True,
            "xla_fsdp_grad_ckpt": True,
        },
    }

    lora_config = LoraConfig(
        r=128,
        lora_alpha=256,
        lora_dropout=0.05,
        bias="none",
        target_modules=[
            "q_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        modules_to_save=["lm_head"],
        task_type="CAUSAL_LM",
    )

    train_data, val_data = load_corpus(corpus_path)
    train_dataset = Dataset.from_list(pack_sequences(train_data, tokenizer, max_length=max_length))
    val_dataset = Dataset.from_list(pack_sequences(val_data, tokenizer, max_length=max_length))

    training_args = SFTConfig(
        max_seq_length=max_length,
        # packing=True,  # packing is done by pack_sequences instead
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=10,
        # `evaluation_strategy` is deprecated in recent transformers releases;
        # `eval_strategy` is the supported name.
        eval_strategy="steps",
        eval_steps=10,
        eval_do_concat_batches=False,
        save_strategy="no",  # checkpoints are written by CustomCheckpointCallback
        lr_scheduler_type="cosine",
        learning_rate=2e-5,
        warmup_ratio=0.1,
        weight_decay=0.01,
        metric_for_best_model="eval_loss",
        # metric_for_best_model="bleu",   # select by BLEU instead
        # greater_is_better=True,         # higher BLEU is better
        output_dir="./tmp/output",
        optim="adamw_torch_xla",
        logging_steps=1,
        # Required by FSDP v2. NOTE(review): this presumably also drops the last
        # partial eval batch, so the validation set needs at least
        # per_device_eval_batch_size packed sequences — confirm.
        dataloader_drop_last=True,
        **fsdp_training_args,
    )

    checkpoint_callback = CustomCheckpointCallback(
        patience=3,
        tokenizer=tokenizer,
    )

    completion_collator = DataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer,
        mlm=False,
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        args=training_args,
        processing_class=tokenizer,
        peft_config=lora_config,
        data_collator=completion_collator,
        # compute_metrics=compute_metrics,  # enable BLEU/ROUGE evaluation
        callbacks=[checkpoint_callback],  # best-weights caching + early stopping
    )

    trainer.train()
    return trainer

In [None]:
import torch_xla.runtime as xr
# Enable torch_xla's persistent compilation cache in a writable local directory
# before training starts compiling graphs.
XLA_CACHE_DIR = 'xla_cache'
xr.initialize_cache(XLA_CACHE_DIR, readonly=False)

trainer = train()