In [None]:
!pip uninstall -y mlflow
!pip install -U "mlflow==3.2.0"



[0mCollecting mlflow==3.2.0
  Downloading mlflow-3.2.0-py3-none-any.whl.metadata (29 kB)
Collecting mlflow-skinny==3.2.0 (from mlflow==3.2.0)
  Downloading mlflow_skinny-3.2.0-py3-none-any.whl.metadata (30 kB)
Collecting mlflow-tracing==3.2.0 (from mlflow==3.2.0)
  Downloading mlflow_tracing-3.2.0-py3-none-any.whl.metadata (19 kB)
Collecting docker<8,>=4.0.0 (from mlflow==3.2.0)
  Downloading docker-7.1.0-py3-none-any.whl.metadata (3.8 kB)
Collecting graphene<4 (from mlflow==3.2.0)
  Downloading graphene-3.4.3-py2.py3-none-any.whl.metadata (6.9 kB)
Collecting gunicorn<24 (from mlflow==3.2.0)
  Downloading gunicorn-23.0.0-py3-none-any.whl.metadata (4.4 kB)
Collecting databricks-sdk<1,>=0.20.0 (from mlflow-skinny==3.2.0->mlflow==3.2.0)
  Downloading databricks_sdk-0.65.0-py3-none-any.whl.metadata (39 kB)
Collecting graphql-core<3.3,>=3.1 (from graphene<4->mlflow==3.2.0)
  Downloading graphql_core-3.2.6-py3-none-any.whl.metadata (11 kB)
Collecting graphql-relay<3.3,>=3.1 (from graphene<4

In [None]:
# -*- coding: utf-8 -*-
"""
VB-LoRA DocVQA training - MLflow-integrated version.
Selects the best checkpoint automatically and tracks experiments with MLflow.
"""

# ==============================================
# Dependencies
# ==============================================
# %pip -q install -U "transformers>=4.53.0" "accelerate>=1.0.0" "trl>=0.9.6" \
#   "peft>=0.15.0" "bitsandbytes>=0.43.0" qwen-vl-utils pillow mlflow

import os, json, torch
from PIL import Image
from google.colab import drive
from math import floor

# ---------- 0) Mount Google Drive ----------
drive.mount('/content/drive')

# ========================================
# MLflow setup - reuse the existing tracking database on Drive
# ========================================
import mlflow
import mlflow.pytorch
from mlflow.tracking import MlflowClient

MLFLOW_HOME = "/content/drive/MyDrive/mlflow"
DB_PATH = f"{MLFLOW_HOME}/mlflow (3).db"
ARTIFACTS_PATH = f"{MLFLOW_HOME}/artifacts"
EXPERIMENT_NAME = "vblora_docvqa"
os.makedirs(ARTIFACTS_PATH, exist_ok=True)

# Connect to the existing MLflow database
mlflow.set_tracking_uri(f"sqlite:///{DB_PATH}")

# An experiment's artifact location is fixed when the experiment is created. Without an explicit
# location the client uses its default local artifact root (typically a relative ./mlruns
# directory, i.e. under /content here), which a Colab runtime restart wipes. So create the
# experiment with an artifact location on Drive when it does not exist yet; an existing
# experiment keeps whatever location it already has.
if mlflow.get_experiment_by_name(EXPERIMENT_NAME) is None:
    mlflow.create_experiment(EXPERIMENT_NAME, artifact_location=f"file://{ARTIFACTS_PATH}")
mlflow.set_experiment(EXPERIMENT_NAME)

print(f"✅ 连接到现有 MLflow 数据库: {mlflow.get_tracking_uri()}")

# Summary of the existing history. search_runs returns at most its default max_results runs per
# call, so for very large experiments the total is a lower bound.
try:
    client = MlflowClient()
    experiments = client.search_experiments()
    total_runs = sum(len(client.search_runs([exp.experiment_id])) for exp in experiments)
    print(f"📊 现有实验: {len(experiments)} 个实验, {total_runs} 次运行")
    print("🔄 新的训练将添加到现有实验历史中")
except Exception as e:
    print(f"⚠️ 无法读取实验历史: {e}")

# ---------- 1) Data paths ----------
DRIVE_DATA_DIR = "/content/drive/MyDrive/llama_factory_data/docvqa"
JSON_PATH      = f"{DRIVE_DATA_DIR}/docvqa_validation.json"
IMAGES_ROOT    = f"{DRIVE_DATA_DIR}/images"

assert os.path.exists(JSON_PATH),  f"没找到 {JSON_PATH}"
assert os.path.exists(IMAGES_ROOT), f"没找到 {IMAGES_ROOT}"

print(f"✅ 数据路径确认完毕")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


2025/09/07 19:59:15 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2025/09/07 19:59:15 INFO mlflow.store.db.utils: Updating database tables
INFO  [alembic.runtime.migration] Context impl SQLiteImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.
INFO  [alembic.runtime.migration] Context impl SQLiteImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.


✅ 连接到现有 MLflow 数据库: sqlite:////content/drive/MyDrive/mlflow/mlflow (3).db
📊 现有实验: 4 个实验, 95 次运行
🔄 新的训练将添加到现有实验历史中
✅ 数据路径确认完毕


In [None]:
print("🔧 正在安装缺失的依赖...")

# 安装 triton（NVIDIA GPU 加速库）
!pip install -q triton

# 重新安装 bitsandbytes（确保版本兼容）
!pip install -q --upgrade bitsandbytes

# 如果还有问题，试试这个
!pip install -q --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

print("✅ 依赖安装完成，重启运行时后重新运行")

🔧 正在安装缺失的依赖...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.3/61.3 MB[0m [31m27.5 MB/s[0m eta [36m0:00:00[0m
[?25h✅ 依赖安装完成，重启运行时后重新运行


In [None]:
# Launch the MLflow UI on the tracking database and embed it in the notebook.
# (This cell does not install mlflow; it must already be installed.)


import os, time, socket, subprocess
from google.colab import output

MLFLOW_HOME = "/content/drive/MyDrive/mlflow"
DB_PATH = f"{MLFLOW_HOME}/mlflow (3).db"
ARTIFACTS_PATH = f"{MLFLOW_HOME}/artifacts"
UI_PORT = 5000
UI_LOG_PATH = "/content/mlflow_ui.log"
os.makedirs(MLFLOW_HOME, exist_ok=True)
os.makedirs(ARTIFACTS_PATH, exist_ok=True)

# Clean start: stop leftover MLflow processes. pkill is run directly instead of via `!pkill`,
# because with `!pkill -f "mlflow"` the command line of the shell spawned by `!` also contains
# "mlflow", so the pattern can kill that shell mid-command.
subprocess.run(["pkill", "-f", "mlflow"], check=False)
time.sleep(1)  # let a stopped server release the port before the readiness probe below

# `mlflow ui` with the backend store given as a URI; artifacts for server-created experiments
# go to the same Drive folder the training cells use.
cmd = [
    "mlflow", "ui",
    "--backend-store-uri", f"sqlite:///{DB_PATH}",
    "--default-artifact-root", ARTIFACTS_PATH,
    "--host", "127.0.0.1",
    "--port", str(UI_PORT),
]
# Server output goes to a log file: an unread PIPE eventually fills up and blocks the server, and
# reading the PIPE of a still-running process would block this cell forever.
with open(UI_LOG_PATH, "w") as ui_log:
    proc = subprocess.Popen(cmd, stdout=ui_log, stderr=subprocess.STDOUT)

# Wait until the port really accepts connections (up to ~60 s); stop early if the process exited.
ready = False
for _ in range(60):
    if proc.poll() is not None:
        break
    try:
        with socket.create_connection(("127.0.0.1", UI_PORT), timeout=1):
            ready = True
            break
    except OSError:
        time.sleep(1)

if not ready:
    print("❌ MLflow UI 启动失败，日志如下：\n")
    with open(UI_LOG_PATH, encoding="utf-8", errors="replace") as log_file:
        print(log_file.read())
else:
    print(f"✅ MLflow UI 已启动在 127.0.0.1:{UI_PORT}，下面嵌入页面")
    output.serve_kernel_port_as_iframe(UI_PORT, height=720)
    # If the iframe stays blank, open it in a new window instead: output.serve_kernel_port_as_window(UI_PORT)


^C
✅ MLflow UI 已启动在 127.0.0.1:5000，下面嵌入页面


<IPython.core.display.Javascript object>

In [None]:
#改配置
from mlflow.tracking import MlflowClient
from urllib.parse import urlparse
from datetime import datetime
import mlflow, os, shutil

# ===== Configuration (edit here) =====
# run_id of the run to clone
OLD_RUN_ID = "4ae110a507d14161a24465f39e61035e"
# Output directory used during training (if any); fallback source for artifacts
OUT_DIR_GUESS = "/content/drive/MyDrive/vblora_save/vblora_highlr"
# Parameter values written into the clone (given as strings)
CORRECT_PARAMS = dict(
    learning_rate_logits="3.5e-3",
    learning_rate_vector="1.2e-3",
    gradient_accumulation_steps="8",
    # ...add more as needed...
)
COPY_METRICS = True
COPY_TAGS = True
# When True, everything under OUT_DIR_GUESS is copied as artifacts (can be large)
UPLOAD_ALL_FROM_OUTDIR = False
# =====================================

client = MlflowClient()
# Date stamp (YYYY-MM-DD) used in the clone's name and note
ts = datetime.now().date().isoformat()

def uri_to_path(uri: str) -> str:
    """Convert a local artifact URI into a filesystem path.

    ``file:`` URIs are percent-decoded, because file URIs encode characters such as spaces as
    ``%20``. Values without a scheme are treated as plain paths and returned unchanged.

    Args:
        uri: an artifact URI such as ``file:///content/mlruns/1/<run_id>/artifacts``.

    Returns:
        The corresponding local path.

    Raises:
        ValueError: for any other scheme (e.g. a remote store), which cannot be copied locally.
    """
    from urllib.parse import unquote

    parsed = urlparse(uri)
    if parsed.scheme == "file":
        return unquote(parsed.path)
    if parsed.scheme == "":
        return parsed.path
    raise ValueError(f"不支持的 artifact_uri: {uri}")

# 1) Load the old run. A "Run ... not found" error here can mean the tracking URI points at a
# different database than the one the run was logged to, so show it first.
print("Tracking URI:", mlflow.get_tracking_uri())
old_run = client.get_run(OLD_RUN_ID)
exp_id = old_run.info.experiment_id
old_name = old_run.data.tags.get("mlflow.runName") or OLD_RUN_ID
old_art_dir = uri_to_path(old_run.info.artifact_uri)  # usually ends with .../<exp_id>/<run_id>/artifacts

print(f"Old run: {OLD_RUN_ID} | name: {old_name}")
print("Old artifact dir:", old_art_dir, "| exists:", os.path.exists(old_art_dir))

# 2) Create the new run
new_run = client.create_run(
    experiment_id=exp_id,
    tags={"clone_of": OLD_RUN_ID, "note": f"Cloned with corrected params on {ts}"},
)
NEW_RUN_ID = new_run.info.run_id
client.set_tag(NEW_RUN_ID, "mlflow.runName", f"{old_name} (corrected {ts})")
print("New run:", NEW_RUN_ID)

# 3) Params: inherit every param of the old run, with CORRECT_PARAMS overriding. Merging first
# means each key is logged exactly once (MLflow params cannot be changed once logged).
inherited_params = {k: v for k, v in (old_run.data.params or {}).items() if k not in CORRECT_PARAMS}
for k, v in {**inherited_params, **CORRECT_PARAMS}.items():
    client.log_param(NEW_RUN_ID, k, v)
print(f"🧾 Params: {len(inherited_params)} inherited, {len(CORRECT_PARAMS)} corrected.")

# 4) Prepare the new run's artifact directory
new_art_dir = uri_to_path(client.get_run(NEW_RUN_ID).info.artifact_uri)
os.makedirs(new_art_dir, exist_ok=True)

def copy_all(src_dir, dst_dir):
    """Copy every top-level entry of ``src_dir`` into ``dst_dir``.

    Files are copied with metadata (``copy2``); directories are merged into any existing
    directory of the same name. Returns the number of top-level entries copied.
    """
    entries = os.listdir(src_dir)
    for name in entries:
        src = os.path.join(src_dir, name)
        dst = os.path.join(dst_dir, name)
        if not os.path.isdir(src):
            shutil.copy2(src, dst)
            continue
        shutil.copytree(src, dst, dirs_exist_ok=True)
    return len(entries)

# 4b) Copy artifacts, 5) copy metrics and tags, 6) terminate the new run.
# Wrapped in try/finally so the new run is always terminated - FINISHED on success, FAILED if any
# step raises - instead of being left in the RUNNING state.
copied_source = None
run_status = "FAILED"
try:
    if os.path.exists(old_art_dir):
        n = copy_all(old_art_dir, new_art_dir)
        copied_source = f"old_art_dir ({n} items)"
        print(f"✅ Copied ALL artifacts from old_art_dir: {n} items")
    else:
        print("⚠️ 旧 run 的 artifacts 目录已不存在（多半因为重启清空了 /content）。")
        # Try to recover artifacts from the training output directory
        if OUT_DIR_GUESS and os.path.exists(OUT_DIR_GUESS):
            if UPLOAD_ALL_FROM_OUTDIR:
                n = copy_all(OUT_DIR_GUESS, new_art_dir)
                copied_source = f"OUT_DIR_GUESS ({n} items, FULL)"
                print(f"✅ Uploaded ALL files from OUT_DIR_GUESS: {n} items")
            else:
                # Only pick the usual key outputs
                candidates = [
                    "best_model", "training_summary.json",
                    "checkpoint-1000", "checkpoint-900", "checkpoint-800",  # if present
                ]
                n = 0
                for name in candidates:
                    s = os.path.join(OUT_DIR_GUESS, name)
                    d = os.path.join(new_art_dir, name)
                    if os.path.exists(s):
                        if os.path.isdir(s):
                            shutil.copytree(s, d, dirs_exist_ok=True)
                        else:
                            shutil.copy2(s, d)
                        n += 1
                if n > 0:
                    copied_source = f"OUT_DIR_GUESS (picked {n} items)"
                    print(f"✅ Uploaded key artifacts from OUT_DIR_GUESS: {n} items")
                else:
                    print("⚠️ OUT_DIR_GUESS 未找到可用文件；将只克隆参数/指标/标签。")
        else:
            print("ℹ️ 没提供或找不到 OUT_DIR_GUESS；将只克隆参数/指标/标签。")

    # 5) Metrics
    if COPY_METRICS:
        # data.metrics is dict[str, float] holding the last value of each key
        metric_keys = list((old_run.data.metrics or {}).keys())
        copied = 0
        for key in metric_keys:
            try:
                # Prefer the full history when the backend kept it
                hist = client.get_metric_history(OLD_RUN_ID, key)
                if hist:
                    for pt in hist:
                        client.log_metric(
                            NEW_RUN_ID, key, pt.value,
                            step=getattr(pt, "step", None) or 0,
                            timestamp=getattr(pt, "timestamp", None)
                        )
                else:
                    # No history: at least restore the last value
                    client.log_metric(NEW_RUN_ID, key, old_run.data.metrics[key])
                copied += 1
            except Exception as e:
                # A single failing metric should not abort the whole clone
                print(f"⚠️ metric {key} 复制失败: {e}")
        print(f"📈 Copied metric history/values for {copied} keys.")

    # 5b) Tags. COPY_TAGS was previously ignored. System tags (mlflow.*, e.g. mlflow.runName,
    # which is set for the clone above) and the clone's own bookkeeping tags are not overwritten.
    if COPY_TAGS:
        own_tags = {"clone_of", "note"}
        tag_count = 0
        for key, value in (old_run.data.tags or {}).items():
            if key.startswith("mlflow.") or key in own_tags:
                continue
            client.set_tag(NEW_RUN_ID, key, value)
            tag_count += 1
        print(f"🏷️ Copied {tag_count} tags.")

    run_status = "FINISHED"
finally:
    # 6) Terminate the new run
    client.set_terminated(NEW_RUN_ID, status=run_status)

print("\n🎉 Done.")
print("New artifacts dir:", new_art_dir)
print("Artifacts source:", copied_source or "none (only params/metrics/tags cloned)")
print(f"✅ New corrected run ready: {NEW_RUN_ID}")




MlflowException: Run '4ae110a507d14161a24465f39e61035e' not found

In [None]:
# 1. 先安装缺失的依赖
print("🔧 正在安装 qwen-vl-utils...")
!pip install -q qwen-vl-utils

# 2. 基础导入（不包含 qwen_vl_utils）
import os
import json
import torch
from PIL import Image
from google.colab import drive
from math import floor

# 基础科学计算库
import numpy as np
from sklearn.model_selection import train_test_split

# PyTorch 相关
from torch.utils.data import Subset

# Transformers 和相关库
from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2VLProcessor,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    TrainerCallback,
)

# PEFT 相关
from peft import VBLoRAConfig, get_peft_model

# 优化器和调度器
from bitsandbytes.optim import Adam8bit, PagedAdamW8bit
from transformers.optimization import get_cosine_schedule_with_warmup

# MLflow 相关
import mlflow
import mlflow.pytorch
from mlflow.tracking import MlflowClient

# 其他工具
import shutil
import subprocess
import time

🔧 正在安装 qwen-vl-utils...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m39.9/39.9 MB[0m [31m47.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Dataset locations on Drive (annotation JSON + image folder)
DRIVE_DATA_DIR = "/content/drive/MyDrive/llama_factory_data/docvqa"
JSON_PATH = os.path.join(DRIVE_DATA_DIR, "docvqa_validation.json")
IMAGES_ROOT = os.path.join(DRIVE_DATA_DIR, "images")

print(f"📁 数据路径: {JSON_PATH}")
print(f"🖼️ 图像路径: {IMAGES_ROOT}")

# ---------- 2) 数据集类 ----------

def resize_image_if_needed(image, max_pixels=768*768):
    """Downscale ``image`` so that width * height <= ``max_pixels``, keeping the aspect ratio.

    Images already within the budget are returned unchanged (the same object). Otherwise both
    sides are scaled by sqrt(max_pixels / current_pixels) and resized with LANCZOS resampling.

    Args:
        image: a PIL image (anything exposing ``.size`` and ``.resize``).
        max_pixels: pixel budget; defaults to 768*768.

    Returns:
        The original image, or a resized copy.
    """
    width, height = image.size
    current_pixels = width * height

    if current_pixels <= max_pixels:
        return image

    # Uniform scale factor on both sides
    scale = (max_pixels / current_pixels) ** 0.5
    # Flooring can reach 0 for extreme aspect ratios (e.g. a 1-pixel-high strip). A zero-sized
    # side cannot be resized or processed downstream, so keep every side at least 1 pixel.
    new_width = max(1, int(width * scale))
    new_height = max(1, int(height * scale))

    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)

class DocVQADataset(torch.utils.data.Dataset):
    """DocVQA samples from a JSON annotation file, returned as Qwen2-VL chat messages.

    Each record is read as: ``conversations[0]["value"]`` = question,
    ``conversations[1]["value"]`` = answer, and ``images[0]`` = either a path string or a dict
    with a ``path`` key.
    """

    def __init__(self, json_path, images_root):
        with open(json_path, "r", encoding="utf-8") as f:
            self.samples = json.load(f)
        self.images_root = images_root

    def __len__(self):
        return len(self.samples)

    def _resolve_image_path(self, cand):
        """Return an existing path for the image entry ``cand``.

        Tries, in order: ``cand`` itself when absolute, ``images_root/basename(cand)``, and
        ``images_root/cand``.

        Raises:
            FileNotFoundError: if none of the candidates exists.
        """
        if os.path.isabs(cand) and os.path.exists(cand):
            return cand
        p1 = os.path.join(self.images_root, os.path.basename(cand))
        p2 = os.path.join(self.images_root, cand)
        for path in (p1, p2):
            if os.path.exists(path):
                return path
        raise FileNotFoundError(
            f"找不到图片：\n  原始: {cand}\n  尝试1: {p1}\n  尝试2: {p2}\n"
            f"请确认 {self.images_root} 下是否有对应文件名。"
        )

    def __getitem__(self, idx):
        ex = self.samples[idx]
        q = ex["conversations"][0]["value"]
        a = ex["conversations"][1]["value"]

        images = ex["images"] or []
        if not images:
            raise ValueError(f"sample {idx} has no image entry")
        img_entry = images[0]
        cand = img_entry["path"] if isinstance(img_entry, dict) and "path" in img_entry else str(img_entry)
        img_path = self._resolve_image_path(os.path.expanduser(cand))

        # Open inside a context manager so the file handle is closed deterministically, then cap
        # the resolution to avoid OOM.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")
        image = resize_image_if_needed(image, max_pixels=768*768)

        chat = [
            {"role": "user", "content": [
                {"type": "image", "image": image},
                {"type": "text",  "text": q},
            ]},
            {"role": "assistant", "content": [{"type": "text", "text": a}]},
        ]
        return chat

# Build the full dataset
full_ds = DocVQADataset(JSON_PATH, IMAGES_ROOT)

# Train / eval split: shuffled 90 % / 10 %
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
import numpy as np

# Seed NumPy's global RNG (the split itself is made reproducible by random_state below)
np.random.seed(42)

# Index of every sample
n = len(full_ds)
all_indices = [*range(n)]

train_indices, eval_indices = train_test_split(
    all_indices,
    shuffle=True,       # shuffle before splitting
    random_state=42,    # fixed seed
    test_size=0.1,      # hold out 10 % for evaluation
)

# Dataset views over the two index lists
train_ds, eval_ds = Subset(full_ds, train_indices), Subset(full_ds, eval_indices)

print(f"✅ 数据就绪：Train {len(train_ds)}, Eval {len(eval_ds)}")
print(f"📊 数据划分比例：Train {len(train_ds)/n:.1%}, Eval {len(eval_ds)/n:.1%}")
print(f"🖼️ 图像尺寸已限制为最大 768x768 像素以避免OOM")

# ---------- 3) Model configuration ----------
from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2VLProcessor,
    BitsAndBytesConfig,
)

# bf16 only on GPUs with compute capability >= 8; fp16 otherwise
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
compute_dtype = torch.bfloat16 if use_bf16 else torch.float16

# 4-bit NF4 quantization with double quantization; compute runs in compute_dtype
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

model_id = "Qwen/Qwen2-VL-2B-Instruct"

print(f"🚀 开始加载模型: {model_id}")

# Load model and processor. `dtype=` replaces `torch_dtype=`, which the installed transformers
# version reports as deprecated.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    dtype=compute_dtype,
    quantization_config=bnb_config
)
processor = Qwen2VLProcessor.from_pretrained(model_id)

print(f"✅ 模型加载完成")

# ---------- 4) VB-LoRA configuration and attachment ----------
from peft import VBLoRAConfig, get_peft_model

# VB-LoRA adapter settings; see peft's VBLoRAConfig for the meaning of r / num_vectors /
# vector_length / topk. Adapters are attached to the attention (q/k/v/o) and MLP
# (gate/up/down) projections.
# NOTE(review): save_only_topk_weights=False presumably keeps the full logits in saved
# checkpoints (larger, but usable to resume training) — confirm against the peft docs.
peft_config = VBLoRAConfig(
    task_type="CAUSAL_LM",
    r=4,
    num_vectors=128,
    vector_length=128,
    topk=12,
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
    vblora_dropout=0.0,
    save_only_topk_weights=False,
)

# Key point: wrap the model with the VB-LoRA adapter BEFORE the Trainer is created, so the Trainer
# and the custom optimizer built later from model.named_parameters() see the adapter parameters.
print("🔧 正在挂载 VB-LoRA 适配器...")
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
print("✅ VB-LoRA 适配器挂载完成")

# Training-mode setup: gradient checkpointing to save activation memory, plus
# enable_input_require_grads(), which is commonly needed with checkpointing when the base weights
# are frozen so that gradients still reach the adapters.
# NOTE(review): the 4-bit base model is not passed through peft.prepare_model_for_kbit_training;
# confirm this is intentional. TrainingArguments(gradient_checkpointing=True) also enables it.
model.train()
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

📁 数据路径: /content/drive/MyDrive/llama_factory_data/docvqa/docvqa_validation.json
🖼️ 图像路径: /content/drive/MyDrive/llama_factory_data/docvqa/images


`torch_dtype` is deprecated! Use `dtype` instead!


✅ 数据就绪：Train 4814, Eval 535
📊 数据划分比例：Train 90.0%, Eval 10.0%
🖼️ 图像尺寸已限制为最大 768x768 像素以避免OOM
🚀 开始加载模型: Qwen/Qwen2-VL-2B-Instruct


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/429M [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/272 [00:00<?, ?B/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

preprocessor_config.json:   0%|          | 0.00/347 [00:00<?, ?B/s]

The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

chat_template.json: 0.00B [00:00, ?B/s]

✅ 模型加载完成
🔧 正在挂载 VB-LoRA 适配器...
trainable params: 4,632,576 || all params: 2,213,618,176 || trainable%: 0.2093
✅ VB-LoRA 适配器挂载完成


In [None]:
# 1. 添加缺失的导入
# ========================================
from qwen_vl_utils import process_vision_info
from transformers import TrainerCallback
import shutil

# ========================================
# 2. Globals used by the collate function
# ========================================
# True when the tokenizer carries a non-empty chat template
HAS_TEMPLATE = True if getattr(processor.tokenizer, "chat_template", None) else False
print(f"📝 Chat template 可用: {HAS_TEMPLATE}")

# ========================================
# 3. 修复函数名称不一致问题
# ========================================
# 将 expected_collate_fn 重命名为 fixed_collate_fn，或者修改调用处
def fixed_collate_fn(examples):
    """符合期望的collate函数：掩码所有<|im_end|>"""
    # 文本部分
    if HAS_TEMPLATE:
        texts = [processor.apply_chat_template(ex, tokenize=False) for ex in examples]
    else:
        texts = [ex[0]["content"][1]["text"] for ex in examples]

    # 图像部分
    image_inputs = [process_vision_info(ex)[0] for ex in examples]

    # 处理成张量
    batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)
    batch["input_ids"] = batch["input_ids"].to(torch.long)
    if "attention_mask" in batch:
        batch["attention_mask"] = batch["attention_mask"].to(torch.long)

    # 标签处理
    labels = batch["input_ids"].clone().to(torch.long)

    # 1. 屏蔽padding
    pad_id = processor.tokenizer.pad_token_id
    if pad_id is not None:
        labels[labels == pad_id] = -100

    # 2. 🔥 关键：屏蔽特殊token，包括所有<|im_end|>
    tokens_to_mask = [
        151655,  # <|image_pad|>
        151644,  # <|im_start|>
        151645,  # <|im_end|> ← 🔥 这就是你要的：掩码所有<|im_end|>
    ]
    for token_id in tokens_to_mask:
        labels[labels == token_id] = -100

    # 3. 只保留assistant的实际回答内容
    for i, (input_ids, text) in enumerate(zip(batch["input_ids"], texts)):
        try:
            # 找到assistant开始位置
            assistant_pattern = "<|im_start|>assistant\n"
            if assistant_pattern in text:
                pattern_tokens = processor.tokenizer.encode(assistant_pattern, add_special_tokens=False)

                for j in range(len(input_ids) - len(pattern_tokens) + 1):
                    if torch.equal(input_ids[j:j+len(pattern_tokens)], torch.tensor(pattern_tokens)):
                        # 从回答内容开始保留
                        answer_start = j + len(pattern_tokens)
                        labels[i, :answer_start] = -100

                        # 找到回答结束位置（第一个<|im_end|>）
                        im_end_positions = (input_ids[answer_start:] == 151645).nonzero(as_tuple=False).flatten()
                        if len(im_end_positions) > 0:
                            # 保留到<|im_end|>之前（不包括<|im_end|>）
                            first_im_end = answer_start + im_end_positions[0].item()
                            labels[i, first_im_end:] = -100
                        break
                else:
                    labels[i, :len(input_ids)//2] = -100
            else:
                labels[i, :len(input_ids)//2] = -100

        except Exception as e:
            print(f"处理样本{i}时出错: {e}")
            labels[i, :len(input_ids)//2] = -100

    batch["labels"] = labels
    return batch



# Sanity-check the collate function on one training sample
print("\n🧪 验证修复效果...")
test_batch = next(iter(
    torch.utils.data.DataLoader(train_ds, collate_fn=fixed_collate_fn, batch_size=1)
))
input_ids, labels = test_batch["input_ids"][0], test_batch["labels"][0]

# Every <|im_end|> (id 151645) must carry the ignore label -100
im_end_positions = (input_ids == 151645).nonzero(as_tuple=False).flatten()
print(f"发现 {len(im_end_positions)} 个 <|im_end|> token:")

all_masked = True
for pos in im_end_positions:
    label_value = labels[pos].item()
    if label_value == -100:
        status = "✅ 掩码"
    else:
        status = "❌ 未掩码"
        all_masked = False
    print(f"  位置 {pos}: {status}")

if not all_masked:
    print("\n❌ 修复失败，仍有<|im_end|>未被掩码")
else:
    print("\n🎉 修复成功！所有<|im_end|>都已被掩码")
    print("✅ 现在符合你的期望策略")

# Share of tokens that contribute to the loss
train_tokens = int((labels != -100).sum())
total_tokens = labels.numel()
print(f"\n📊 训练token统计: {train_tokens}/{total_tokens} ({train_tokens/total_tokens:.1%})")

print("\n💡 现在可以继续训练了:")
print("   - 所有<|im_end|>在input_ids中保留（模型能看到）")
print("   - 所有<|im_end|>在labels中掩码（不计算loss）")
print("   - 推理时<|im_end|>会作为停止符工作")

# Hint on how to (re)start training
print("\n🚀 建议现在重新开始训练:")
print("trainer_fixed.train()" if "trainer_fixed" in globals() else "创建新的trainer并开始训练")

📝 Chat template 可用: True

🧪 验证修复效果...
发现 3 个 <|im_end|> token:
  位置 9: ✅ 掩码
  位置 773: ✅ 掩码
  位置 787: ✅ 掩码

🎉 修复成功！所有<|im_end|>都已被掩码
✅ 现在符合你的期望策略

📊 训练token统计: 9/789 (1.1%)

💡 现在可以继续训练了:
   - 所有<|im_end|>在input_ids中保留（模型能看到）
   - 所有<|im_end|>在labels中掩码（不计算loss）
   - 推理时<|im_end|>会作为停止符工作

🚀 建议现在重新开始训练:
创建新的trainer并开始训练


In [None]:
class MLflowCallback(TrainerCallback):
    """Mirror Trainer events into an MLflow run and track the best checkpoint by eval_loss.

    * ``on_train_begin`` starts a run and logs hyper-parameters. Learning rates are read from the
      optimizer the Trainer actually uses rather than hard-coded, so they cannot drift.
    * ``on_log`` forwards every numeric log value as a metric, plus the running best_eval_loss.
    * ``on_evaluate`` updates best_eval_loss / best_step and the expected best checkpoint path.
    * ``on_train_end`` copies the best checkpoint to ``<output_dir>/best_model``, logs the model
      and a JSON summary, and always ends the run.
    """

    def __init__(self):
        self.run_id = None
        self.best_eval_loss = float('inf')
        self.best_checkpoint = None
        self.best_step = None
        self.first_eval_done = False  # the first evaluation always becomes the initial best

    @staticmethod
    def _optimizer_learning_rates(optimizer, model=None):
        """Map each optimizer param group to its base learning rate.

        A group is named after the parameters it holds, using the same name substrings that
        build_vblora_optimizer groups on: "vblora_logits" -> learning_rate_logits,
        "vector_bank" -> learning_rate_vector, anything else -> learning_rate_group{i}.
        ``initial_lr`` (set by LR schedulers) is preferred over ``lr`` because a warmup schedule
        has already lowered ``lr`` when training starts.
        """
        names = {id(p): n for n, p in model.named_parameters()} if model is not None else {}
        lrs = {}
        for idx, group in enumerate(optimizer.param_groups):
            lr = group.get("initial_lr", group.get("lr"))
            params = group.get("params") or []
            pname = names.get(id(params[0]), "") if params else ""
            if "vblora_logits" in pname:
                key = "learning_rate_logits"
            elif "vector_bank" in pname:
                key = "learning_rate_vector"
            else:
                key = f"learning_rate_group{idx}"
            lrs.setdefault(key, lr)
        return lrs

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Start the MLflow run and log hyper-parameters."""
        run = mlflow.start_run(run_name=f"vblora_docvqa_{args.run_name}")
        self.run_id = run.info.run_id

        # The Trainer hands its optimizer to callbacks as the `optimizer` keyword argument.
        optimizer = kwargs.get("optimizer")
        if optimizer is not None:
            lr_params = self._optimizer_learning_rates(optimizer, model)
        else:
            lr_params = {"learning_rate": args.learning_rate}

        mlflow.log_params({
            "model_name": model_id,
            "max_steps": args.max_steps,
            "per_device_train_batch_size": args.per_device_train_batch_size,
            "per_device_eval_batch_size": args.per_device_eval_batch_size,
            "gradient_accumulation_steps": args.gradient_accumulation_steps,
            "warmup_ratio": args.warmup_ratio,
            **lr_params,
            "use_bf16": use_bf16,
            "vblora_r": peft_config.r,
            "vblora_num_vectors": peft_config.num_vectors,
            "vblora_vector_length": peft_config.vector_length,
            "vblora_topk": peft_config.topk,
            "train_size": len(train_ds),
            "eval_size": len(eval_ds),
        })
        print(f"📊 MLflow Run 已启动: {self.run_id}")

    def on_log(self, args, state, control, logs=None, model=None, **kwargs):
        """Forward numeric log values (including eval_loss) to MLflow in one batch call."""
        if not logs or not self.run_id:
            return
        step = state.global_step
        metrics = {key: value for key, value in logs.items() if isinstance(value, (int, float))}
        # Once an evaluation has happened, also log the running best value
        if self.first_eval_done:
            metrics["best_eval_loss"] = self.best_eval_loss
        if metrics:
            mlflow.log_metrics(metrics, step=step)

    def on_evaluate(self, args, state, control, logs=None, model=None, metrics=None, **kwargs):
        """Update the best eval_loss after each evaluation.

        The Trainer passes evaluation results as ``metrics=``; ``logs`` is still accepted for
        backward compatibility. Previously only ``logs`` was read, so this method never fired.
        """
        results = metrics if metrics is not None else logs
        if not results or "eval_loss" not in results:
            return

        current_eval_loss = results["eval_loss"]
        step = state.global_step
        is_first = not self.first_eval_done
        if not is_first and current_eval_loss >= self.best_eval_loss:
            return

        self.best_eval_loss = current_eval_loss
        self.best_step = step
        self.first_eval_done = True
        # The checkpoint for this step is typically written only after evaluation returns, so it
        # may not exist yet; remember the expected path and verify it in on_train_end.
        self.best_checkpoint = f"{args.output_dir}/checkpoint-{step}"

        if is_first:
            print(f"🎯 首次评估! Step {step}, 设置初始 best_eval_loss = {current_eval_loss:.4f}")
        else:
            print(f"🏆 新的最佳模型! Step {step}, Eval Loss: {current_eval_loss:.4f}")

        mlflow.log_metrics({
            "best_eval_loss": self.best_eval_loss,
            "best_step": step
        }, step=step)

    def on_train_end(self, args, state, control, model=None, **kwargs):
        """Export the best checkpoint, log final metrics and summary, and end the MLflow run."""
        try:
            # Prefer the checkpoint recorded by on_evaluate, then the Trainer's own
            # best_model_checkpoint; only a directory that actually exists is used.
            candidates = (self.best_checkpoint, getattr(state, "best_model_checkpoint", None))
            best_checkpoint = next((c for c in candidates if c and os.path.exists(c)), None)

            if best_checkpoint:
                self.best_checkpoint = best_checkpoint
                # Copy the best checkpoint into the final output directory
                best_model_dir = f"{args.output_dir}/best_model"
                if os.path.exists(best_model_dir):
                    shutil.rmtree(best_model_dir)
                shutil.copytree(best_checkpoint, best_model_dir)

                # Log the model to MLflow (best effort)
                try:
                    mlflow.pytorch.log_model(
                        pytorch_model=model,
                        artifact_path="best_model",
                        registered_model_name=f"vblora_docvqa_best"
                    )
                    print(f"✅ 最佳模型已保存到 MLflow: {best_model_dir}")
                except Exception as e:
                    print(f"⚠️ MLflow 模型记录失败: {e}")

                # Numeric values go to metrics ...
                final_metrics = {
                    "final_best_eval_loss": self.best_eval_loss,
                    "total_steps": state.global_step,
                    "training_completed_flag": 1.0,  # 1.0 marks a completed run
                }
                mlflow.log_metrics(final_metrics)

                # ... string values go to params
                mlflow.log_param("best_checkpoint_path", best_checkpoint)
                mlflow.log_param("final_status", "completed")

                # Training summary saved next to the checkpoints and logged as an artifact
                summary = {
                    "experiment_name": "vblora_docvqa",
                    "model_id": model_id,
                    "best_eval_loss": self.best_eval_loss,
                    "best_step": self.best_step,
                    "total_steps": state.global_step,
                    "best_model_path": best_model_dir,
                    "run_id": self.run_id
                }

                summary_path = f"{args.output_dir}/training_summary.json"
                with open(summary_path, "w") as f:
                    json.dump(summary, f, indent=2)
                mlflow.log_artifact(summary_path)
            else:
                print("⚠️ No saved best checkpoint found; skipping best-model export.")
        finally:
            # Always close the run, even if the export above raised
            mlflow.end_run()
        print(f"🎯 训练完成! MLflow Run: {self.run_id}")

In [None]:
# End any MLflow run that an interrupted training attempt left active
try:
    while True:
        if mlflow.active_run() is None:
            break
        mlflow.end_run()
        print("🧹 清理残留的 MLflow run")
except Exception as err:
    print(f"MLflow 清理: {err}")

In [None]:
# Start an MLflow tracking server on the same database the training cells log to, and embed it.
import os, subprocess, time, socket
from google.colab import output

# Stop leftover MLflow processes. pkill is run directly instead of via `!pkill`, because the
# command line of the shell spawned by `!` also contains "mlflow" and can be killed by the pattern.
subprocess.run(["pkill", "-f", "mlflow"], check=False)
time.sleep(1)  # let a stopped server release the port before the readiness probe below

# Must be the same DB as the training tracking URI ("mlflow (3).db"); the previous "mlflow.db"
# made the UI show a different database from the one new runs are written to.
DB_PATH = "/content/drive/MyDrive/mlflow/mlflow (3).db"
ARTIFACTS_PATH = "/content/drive/MyDrive/mlflow/artifacts"
SERVER_PORT = 5000
SERVER_LOG_PATH = "/content/mlflow_server.log"
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
os.makedirs(ARTIFACTS_PATH, exist_ok=True)

# Argument list (no shell), so paths with spaces and option names are passed verbatim
cmd = [
    "mlflow", "server",
    "--backend-store-uri", f"sqlite:///{DB_PATH}",
    "--default-artifact-root", ARTIFACTS_PATH,
    "--host", "127.0.0.1",          # 0.0.0.0 also works; 127.0.0.1 is the most reliable here
    "--port", str(SERVER_PORT),
]
# Server output goes to a log file: an unread PIPE eventually fills up and blocks the server.
with open(SERVER_LOG_PATH, "w") as server_log:
    proc = subprocess.Popen(cmd, stdout=server_log, stderr=subprocess.STDOUT)

# Wait until the port accepts connections (up to ~60 s) instead of a fixed sleep; stop early if
# the process exited.
ready = False
for _ in range(60):
    if proc.poll() is not None:
        break
    try:
        with socket.create_connection(("127.0.0.1", SERVER_PORT), timeout=1):
            ready = True
            break
    except OSError:
        time.sleep(1)

if ready:
    # Show it in an iframe (recommended by Colab)
    output.serve_kernel_port_as_iframe(SERVER_PORT, height=720)
else:
    print(f"MLflow server did not start; log ({SERVER_LOG_PATH}):")
    with open(SERVER_LOG_PATH, encoding="utf-8", errors="replace") as log_file:
        print(log_file.read())



^C


<IPython.core.display.Javascript object>

In [None]:
# ---------- 6) Training arguments ----------
from transformers import Trainer, TrainingArguments
# from torch.optim import AdamW   # <- removed / commented out
from bitsandbytes.optim import Adam8bit, PagedAdamW8bit

from transformers.optimization import get_cosine_schedule_with_warmup

# Output directory
root_out = "/content/drive/MyDrive/vblora_save"
exp_name = "vblora_medlr"
out_dir  = os.path.join(root_out, exp_name)
os.makedirs(out_dir, exist_ok=True)

# Evaluation and saving share the same 100-step cadence, so every evaluation has a matching
# checkpoint-<step> directory for best-model selection. Built-in reporting is disabled; MLflow
# logging is handled by the custom MLflowCallback.
training_args_fixed = TrainingArguments(
    output_dir=out_dir,
    run_name=exp_name,
    max_steps=1000,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    logging_steps=20,

    # Same eval/save frequency; keep the best model by eval_loss
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=None,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    warmup_ratio=0.1,
    bf16=use_bf16,
    fp16=not use_bf16,
    # TF32 requires an Ampere-or-newer GPU, which is the same compute-capability >= 8 condition
    # that sets use_bf16. A hard-coded True makes TrainingArguments raise on older GPUs (e.g. T4).
    tf32=use_bf16,
    remove_unused_columns=False,
    dataloader_num_workers=0,
    dataloader_persistent_workers=False,

    # Disable automatic reporting; the custom callback logs to MLflow
    report_to=[],
)


In [None]:
print("🔧 创建修复后的 Trainer...")

# End any MLflow run left active by an earlier, interrupted attempt
try:
    while True:
        if mlflow.active_run() is None:
            break
        mlflow.end_run()
        print("🧹 清理残留的 MLflow run")
except Exception as err:
    print(f"MLflow 清理: {err}")

# Fresh MLflow callback for this Trainer
mlflow_callback_fixed = MLflowCallback()

# Trainer over the VB-LoRA-wrapped model, using the answer-only collate function
trainer_fixed = Trainer(
    args=training_args_fixed,
    model=model,
    processing_class=processor,
    data_collator=fixed_collate_fn,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    callbacks=[mlflow_callback_fixed],
)


🔧 创建修复后的 Trainer...


In [None]:
# ========================================
# VB-LoRA 双学习率优化器设置
# ========================================

def build_vblora_optimizer(m, lr_logits=3.2e-3, lr_vector=1.6e-4, wd=0.01, eps=1e-8, optimizer_cls=None):
    """Build an optimizer with separate learning rates for VB-LoRA parameters.

    Trainable parameters are grouped by name:
      * names containing ``"vblora_logits"`` -> learning rate ``lr_logits``
      * names containing ``"vector_bank"``   -> learning rate ``lr_vector``
      * every other trainable parameter      -> learning rate ``lr_vector``
    Frozen parameters (``requires_grad=False``) are skipped and empty groups are
    omitted. A summary line is printed for each non-empty group.

    Args:
        m: Model whose ``named_parameters()`` are grouped.
        lr_logits: Learning rate for the logits group.
        lr_vector: Learning rate for the vector-bank group and any other
            trainable parameters.
        wd: ``weight_decay`` passed to the optimizer (applies to all groups).
        eps: ``eps`` passed to the optimizer.
        optimizer_cls: Optimizer class, called as
            ``optimizer_cls(groups, weight_decay=wd, eps=eps)``. Defaults to
            bitsandbytes ``Adam8bit``; e.g. pass ``PagedAdamW8bit`` when GPU
            memory is tighter.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if ``m`` has no trainable parameters.
    """
    logits_params, bank_params, other_params = [], [], []

    for name, param in m.named_parameters():
        if not param.requires_grad:
            continue
        if "vblora_logits" in name:
            logits_params.append(param)
        elif "vector_bank" in name:
            bank_params.append(param)
        else:
            other_params.append(param)

    groups = []
    if logits_params:
        groups.append({"params": logits_params, "lr": lr_logits})
        print(f"📊 VB-LoRA logits 参数组: {len(logits_params)} 个参数, lr={lr_logits}")
    if bank_params:
        groups.append({"params": bank_params, "lr": lr_vector})
        print(f"📊 VB-LoRA bank 参数组: {len(bank_params)} 个参数, lr={lr_vector}")
    if other_params:
        groups.append({"params": other_params, "lr": lr_vector})
        print(f"📊 其他参数组: {len(other_params)} 个参数, lr={lr_vector}")

    if not groups:
        # Fail with a clear message instead of the optimizer's generic
        # "empty parameter list" error (e.g. adapters not attached / all frozen).
        raise ValueError("build_vblora_optimizer: model has no trainable parameters")

    if optimizer_cls is None:
        optimizer_cls = Adam8bit
    return optimizer_cls(groups, weight_decay=wd, eps=eps)

# ========================================
# Attach the dual-LR optimizer and cosine LR scheduler to trainer_fixed
# ========================================
print("🔧 配置 VB-LoRA 双学习率优化器...")

# Logits get a 20x higher learning rate than the shared vector bank.
trainer_fixed.optimizer = build_vblora_optimizer(
    trainer_fixed.model,
    lr_logits=3.2e-3,
    lr_vector=1.6e-4
)

# The cosine schedule needs a finite horizon; this run is step-based (max_steps).
num_training_steps = training_args_fixed.max_steps
if num_training_steps <= 0:
    raise ValueError(
        "training_args_fixed.max_steps must be > 0 to size the cosine LR schedule"
    )
# get_warmup_steps honours an explicit warmup_steps and otherwise uses
# ceil(warmup_ratio * num_training_steps), i.e. the same rule Trainer applies
# when it builds its own scheduler.
num_warmup_steps = training_args_fixed.get_warmup_steps(num_training_steps)

trainer_fixed.lr_scheduler = get_cosine_schedule_with_warmup(
    trainer_fixed.optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

print(f"📈 学习率调度: {num_warmup_steps} warmup steps / {num_training_steps} total steps")
print("✅ 修复后的 Trainer 配置完成")

# ========================================
# Training itself is launched in the next cell
# ========================================
print("\n🚀 开始训练...")
print("📊 MLflow 跟踪已启用，所有指标将自动记录")
print("🏆 最佳模型将根据 eval_loss 自动选择和保存")
print("💡 使用命令: trainer_fixed.train()")



🔧 配置 VB-LoRA 双学习率优化器...
📊 VB-LoRA logits 参数组: 392 个参数, lr=0.0032
📊 VB-LoRA bank 参数组: 1 个参数, lr=0.00016
📈 学习率调度: 100 warmup steps / 1000 total steps
✅ 修复后的 Trainer 配置完成

🚀 开始训练...
📊 MLflow 跟踪已启用，所有指标将自动记录
🏆 最佳模型将根据 eval_loss 自动选择和保存
💡 使用命令: trainer_fixed.train()


In [None]:
# Run the training loop on trainer_fixed (optimizer and scheduler attached in
# the previous cell), starting from scratch rather than resuming a checkpoint.
train_output = trainer_fixed.train(resume_from_checkpoint=None)


The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


📊 MLflow Run 已启动: 5b367f3254674e549ed1c2f22ff0ca08


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss,Validation Loss
100,0.8148,0.678492
200,0.4439,0.483465
300,0.4523,0.447745
400,0.4133,0.434761
500,0.4118,0.428733
600,0.3993,0.423604
700,0.3282,0.423337
800,0.3372,0.421534
900,0.4082,0.420803
1000,0.375,0.420302




🎯 训练完成! MLflow Run: 5b367f3254674e549ed1c2f22ff0ca08


In [None]:
import math, torch

def steps_for_epochs(train_ds, per_device_train_batch_size, grad_accum, epochs=4, world_size=None):
    """Convert a number of epochs into optimizer steps (e.g. for ``max_steps``).

    One optimizer step consumes ``grad_accum`` micro-batches, each holding
    ``per_device_train_batch_size * world_size`` samples. A partial last
    micro-batch and a partial last accumulation round each count as one
    (ceiling division).

    Args:
        train_ds: Training dataset; only ``len(train_ds)`` is used.
        per_device_train_batch_size: Micro-batch size per device.
        grad_accum: Gradient accumulation steps per optimizer step.
        epochs: Number of epochs to convert.
        world_size: Number of devices sharing each micro-batch; defaults to the
            visible CUDA device count (1 when CUDA is unavailable).

    Returns:
        ``(steps_per_epoch, total_steps)`` with ``total_steps = steps_per_epoch * epochs``.

    Raises:
        ValueError: if the batch size, ``grad_accum`` or ``world_size`` is not positive.
    """
    if world_size is None:
        world_size = torch.cuda.device_count() if torch.cuda.is_available() else 1
    if per_device_train_batch_size <= 0 or grad_accum <= 0 or world_size <= 0:
        raise ValueError(
            "per_device_train_batch_size, grad_accum and world_size must all be positive"
        )

    samples_per_batch = per_device_train_batch_size * world_size
    # Micro-batches (forward/backward passes) per epoch -- not optimizer updates.
    batches_per_epoch = -(-len(train_ds) // samples_per_batch)
    # Optimizer updates per epoch: one per grad_accum micro-batches.
    steps_per_epoch = -(-batches_per_epoch // grad_accum)
    return steps_per_epoch, steps_per_epoch * epochs

# Use the batch size and gradient accumulation actually configured for training,
# so the epoch <-> step conversion matches the run above.
steps_ep, steps_4ep = steps_for_epochs(
    train_ds,
    per_device_train_batch_size=training_args_fixed.per_device_train_batch_size,
    grad_accum=training_args_fixed.gradient_accumulation_steps,
    epochs=4,
)
print("每个 epoch 的 steps:", steps_ep)
print("4 个 epoch 的 max_steps:", steps_4ep)
print("想用 max_steps 等价于按 epoch 评估/保存，可设 eval_steps=save_steps=", steps_ep)


In [None]:
# ---------- 10) 简单推理测试 ----------
def chat_once(pil_image: Image.Image, question: str, max_new_tokens=128, model=None):
    """Run one image + question chat turn and return the decoded answer.

    Args:
        pil_image: Input image; downscaled with ``resize_image_if_needed`` to at
            most 768*768 pixels (the same cap used elsewhere in this notebook).
        question: User question text.
        max_new_tokens: Generation budget passed to ``generate``.
        model: Model to run. Defaults to ``trainer_fixed.model`` when
            ``trainer_fixed`` exists, otherwise ``trainer.model``.

    Returns:
        The generated text with the prompt tokens stripped.
    """
    if model is None:
        # Training in this notebook runs on `trainer_fixed`; only fall back to a
        # legacy `trainer` when `trainer_fixed` has not been created.
        model = (trainer_fixed if 'trainer_fixed' in globals() else trainer).model

    # Limit image size at inference time as well.
    pil_image = resize_image_if_needed(pil_image, max_pixels=768*768)

    conv = [{"role":"user","content":[{"type":"image","image":pil_image},{"type":"text","text":question}]}]
    text = processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[[pil_image]], return_tensors="pt").to(model.device)

    with torch.inference_mode():
        out = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Drop the prompt tokens so only the newly generated answer is decoded.
    trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, out)]
    return processor.batch_decode(trimmed, skip_special_tokens=True)[0]

In [None]:
# 🔧 修复的交互式 Chatbox

def _first_upload_content(upload_value):
    """Return the raw bytes of the first file held by a ``FileUpload`` widget value.

    ipywidgets 7 stores ``FileUpload.value`` as a dict mapping file name to
    ``{"metadata": ..., "content": bytes}``; ipywidgets 8 stores a tuple of
    dicts whose ``"content"`` is a ``memoryview``. Both layouts are accepted.

    Raises:
        ValueError: if the value holds no uploaded file.
    """
    if isinstance(upload_value, dict):
        entries = list(upload_value.values())
    else:
        entries = list(upload_value)
    if not entries:
        raise ValueError("no uploaded file")
    return bytes(entries[0]["content"])


def create_fixed_interactive_chatbox():
    """Build an ipywidgets chat UI for asking questions about an image.

    The model dropdown offers ``{out_dir}/best_model`` (when that directory
    exists), every ``{out_dir}/checkpoint-*`` directory in step order, and the
    in-memory training model. Clicking the button loads the selected model
    (cached until a different entry is selected), generates an answer for the
    uploaded image and question, prints it, and logs checkpoint/question/answer/
    image size as params of a new MLflow run.

    Returns:
        The ``widgets.VBox`` holding the whole UI; pass it to ``display``.
    """
    import glob
    import ipywidgets as widgets

    # ---- 1. Collect selectable model sources as (label, path) pairs ----
    checkpoints = []

    best_model_path = f"{out_dir}/best_model"
    if os.path.exists(best_model_path):
        checkpoints.append(("🏆 Best Model", best_model_path))

    def checkpoint_step(path):
        # Numeric sort so checkpoint-1000 follows checkpoint-900; non-numeric suffixes sort first.
        suffix = path.split('-')[-1]
        return int(suffix) if suffix.isdigit() else 0

    for ckpt_dir in sorted(glob.glob(f"{out_dir}/checkpoint-*"), key=checkpoint_step):
        step = ckpt_dir.split('-')[-1]
        checkpoints.append((f"📍 Checkpoint-{step}", ckpt_dir))

    # The in-memory training model is always offered, so `checkpoints` is never empty.
    checkpoints.append(("🔄 Current Model", "current"))

    # ---- 2. Widgets ----
    checkpoint_dropdown = widgets.Dropdown(
        options=checkpoints,
        value=checkpoints[0][1],
        description='选择模型:',
        style={'description_width': '80px'},
        layout=widgets.Layout(width='400px')
    )

    image_upload = widgets.FileUpload(
        accept='image/*',
        multiple=False,
        description='上传图片',
        style={'description_width': '80px'},
        layout=widgets.Layout(width='300px')
    )

    question_input = widgets.Textarea(
        placeholder='请输入您的问题...',
        description='问题:',
        style={'description_width': '80px'},
        layout=widgets.Layout(width='500px', height='80px')
    )

    inference_button = widgets.Button(
        description='🚀 开始推理',
        button_style='primary',
        layout=widgets.Layout(width='120px')
    )

    output_area = widgets.Output()

    # Cache of the loaded model; "path" is the dropdown value it was loaded for.
    current_model_info = {"model": None, "processor": None, "path": None}

    def load_checkpoint(checkpoint_path):
        """Return (model, processor, label) for a dropdown value, or (None, None, None) on failure."""
        try:
            if checkpoint_path == "current":
                # Reuse the model held by the training Trainer (no reload).
                current_trainer = trainer_fixed if 'trainer_fixed' in globals() else trainer
                return current_trainer.model, processor, "current_training_model"

            from peft import PeftModel

            # Fresh quantized base model, then the saved PEFT adapter on top of it.
            base_model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=compute_dtype,
                quantization_config=bnb_config
            )
            peft_model = PeftModel.from_pretrained(base_model, checkpoint_path)
            peft_model.eval()

            return peft_model, processor, checkpoint_path

        except Exception as e:
            print(f"❌ 加载模型失败: {e}")
            return None, None, None

    def release_loaded_model():
        """Drop the cached model; reclaim GPU memory if it was a checkpoint loaded here."""
        owned_checkpoint = (current_model_info["model"] is not None
                            and current_model_info["path"] != "current")
        current_model_info.update({"model": None, "processor": None, "path": None})
        if owned_checkpoint:
            # The training model is still referenced by the Trainer, so only
            # checkpoints loaded by this UI are worth collecting.
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def perform_inference(model, processor, image, question):
        """Generate an answer for (image, question); returns an error string instead of raising."""
        try:
            image = resize_image_if_needed(image, max_pixels=768*768)

            conv = [{"role":"user","content":[
                {"type":"image","image":image},
                {"type":"text","text":question}
            ]}]

            text = processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
            inputs = processor(text=[text], images=[[image]], return_tensors="pt").to(model.device)

            with torch.inference_mode():
                # Sampling (temperature/top_p) means repeated clicks can give different answers.
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=processor.tokenizer.eos_token_id
                )

            # Strip the prompt tokens before decoding.
            trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, outputs)]
            answer = processor.batch_decode(trimmed, skip_special_tokens=True)[0]

            return answer.strip()

        except Exception as e:
            return f"❌ 推理失败: {e}"

    def on_inference_click(b):
        """Button handler: validate inputs, (re)load the model if needed, run inference, log to MLflow."""
        with output_area:
            from IPython.display import clear_output
            clear_output()

            if not image_upload.value:
                print("❌ 请先上传图片！")
                return

            if not question_input.value.strip():
                print("❌ 请输入问题！")
                return

            selected_checkpoint = checkpoint_dropdown.value
            checkpoint_name = [name for name, path in checkpoints if path == selected_checkpoint][0]

            print(f"🔄 正在使用 {checkpoint_name} 进行推理...")

            # Reload only when the selection differs from the cached model.
            if (current_model_info["path"] != selected_checkpoint or
                current_model_info["model"] is None):

                print("📦 正在加载模型...")
                # Free the previous checkpoint first so two base models are never resident together.
                release_loaded_model()
                model, proc, _ = load_checkpoint(selected_checkpoint)

                if model is None:
                    print("❌ 模型加载失败！")
                    return

                # Key the cache by the dropdown value: load_checkpoint labels the
                # training model "current_training_model", which never equals "current".
                current_model_info.update({
                    "model": model,
                    "processor": proc,
                    "path": selected_checkpoint
                })
                print("✅ 模型加载完成")

            # Decode the uploaded image (works with ipywidgets 7 and 8 value layouts).
            try:
                file_content = _first_upload_content(image_upload.value)

                if len(file_content) == 0:
                    print("❌ 上传的文件为空")
                    return

                from io import BytesIO
                image = Image.open(BytesIO(file_content)).convert('RGB')
                image = resize_image_if_needed(image, max_pixels=768*768)
                print(f"✅ 图片加载完成，尺寸: {image.size}")

            except Exception as e:
                print(f"❌ 图片处理失败: {e}")
                print(f"   错误详情: {type(e).__name__}")

                # Best-effort fallback: use the first validation image instead.
                # NOTE: the answer then refers to that image, not the uploaded one.
                try:
                    print("🔄 使用验证集样本代替...")
                    sample = eval_ds[0]
                    image = sample[0]["content"][0]["image"]
                    image = resize_image_if_needed(image, max_pixels=768*768)
                    print(f"✅ 使用验证集样本，尺寸: {image.size}")
                except Exception as e2:
                    print(f"❌ 备用方案也失败: {e2}")
                    return

            question = question_input.value.strip()
            print(f"❓ 问题: {question}")
            print("🤔 正在思考...")

            answer = perform_inference(
                current_model_info["model"],
                current_model_info["processor"],
                image,
                question
            )

            print(f"💡 回答: {answer}")

            # Log the exchange as params of a separate MLflow run (best effort).
            try:
                import mlflow
                with mlflow.start_run(run_name=f"interactive_inference_{exp_name}"):
                    mlflow.log_param("checkpoint_used", checkpoint_name)
                    mlflow.log_param("question", question)
                    mlflow.log_param("answer", answer)
                    mlflow.log_param("image_size", f"{image.size[0]}x{image.size[1]}")
                print("📊 推理结果已记录到 MLflow")
            except Exception as e:
                print(f"⚠️ MLflow 记录失败: {e}")

    inference_button.on_click(on_inference_click)

    # ---- 3. Layout ----
    ui = widgets.VBox([
        widgets.HTML("<h3>🎯 交互式模型测试 Chatbox</h3>"),
        widgets.HBox([
            widgets.VBox([
                checkpoint_dropdown,
                image_upload,
            ]),
            widgets.VBox([
                question_input,
                inference_button,
            ])
        ]),
        widgets.HTML("<hr>"),
        widgets.HTML("<h4>📤 推理结果:</h4>"),
        output_area
    ])

    return ui

# 🔧 简化版本的测试（如果ipywidgets不可用）
def simple_inference_test():
    """Fallback smoke test: answer the first validation sample's question with the current model.

    Assumes ``eval_ds[0][0]`` is a user message whose ``content`` is
    ``[image_part, text_part]`` (the indexing below relies on that layout).
    Uses ``trainer_fixed.model`` when ``trainer_fixed`` exists, otherwise
    ``trainer.model``; generates up to 128 new tokens, prints the answer and
    logs question/answer/image size as params of a new MLflow run.

    Returns:
        The decoded answer string.
    """
    # The sample comes from eval_ds, i.e. the validation set.
    print("🧪 使用验证集的一个样本进行推理测试...")

    sample = eval_ds[0]
    img = sample[0]["content"][0]["image"]
    q = sample[0]["content"][1]["text"]

    print(f"❓ 问题: {q}")

    # Same image-size cap as used elsewhere for inference.
    img = resize_image_if_needed(img, max_pixels=768*768)
    print(f"🖼️ 图像尺寸: {img.size}")

    current_trainer = trainer_fixed if 'trainer_fixed' in globals() else trainer

    conv = [{"role":"user","content":[{"type":"image","image":img},{"type":"text","text":q}]}]
    text = processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[[img]], return_tensors="pt").to(current_trainer.model.device)

    with torch.inference_mode():
        out = current_trainer.model.generate(**inputs, max_new_tokens=128)

    # Strip the prompt tokens before decoding.
    trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, out)]
    answer = processor.batch_decode(trimmed, skip_special_tokens=True)[0]

    print(f"💡 回答: {answer}")

    # Log to MLflow (best effort; failures are reported, not raised).
    try:
        import mlflow
        with mlflow.start_run(run_name=f"simple_inference_test_{exp_name}"):
            mlflow.log_param("sample_question", q)
            mlflow.log_param("sample_answer", answer)
            mlflow.log_param("model_source", "current_training_model")
            mlflow.log_param("image_size", f"{img.size[0]}x{img.size[1]}")
    except Exception as e:
        print(f"⚠️ MLflow 记录失败: {e}")

    return answer

# Prefer the interactive widget UI. If ipywidgets is missing (ImportError) or a
# notebook global it needs is undefined (NameError), run the one-shot
# validation-sample test instead.
print("🎯 创建交互式测试环境...")

try:
    import ipywidgets as widgets
    from IPython.display import display

    chatbox_ui = create_fixed_interactive_chatbox()
    display(chatbox_ui)
    print("\n".join((
        "✅ 交互式 Chatbox 已启动！",
        "💡 使用说明:",
        "   1. 选择要测试的 checkpoint",
        "   2. 上传测试图片",
        "   3. 输入问题",
        "   4. 点击 '🚀 开始推理' 按钮",
        "   5. 查看结果并在 MLflow 中追踪",
    )))
except (ImportError, NameError) as e:
    print(f"⚠️ 交互式界面不可用: {e}")
    print("🔄 使用简化版本测试...")
    simple_inference_test()

# Final summary.
print("\n🎉 所有步骤完成！")
print("📊 实验结果已记录到 MLflow")
print(f"🏆 模型保存在: {out_dir}")
print(
    "🚀 可以使用 trainer_fixed.train() 开始训练"
    if 'trainer_fixed' in globals()
    else "🚀 可以使用 trainer.train() 开始训练"
)