# Finetune Llama-3 with LLaMA Factory

Please use a **free** Tesla T4 Colab GPU to run this!

Project homepage: https://github.com/hiyouga/LLaMA-Factory

## Install Dependencies

In [None]:
#!rm -f /content/drive/MyDrive/mlflow/mlflow.db
#清空实验结果clean all result

In [None]:
# 0. 【手工】在 Colab 菜单点：Runtime → Restart runtime
#    （确保彻底重启，清掉所有旧的进程和线程池）

# 1. 切到 /content，删除旧的 LLaMA‑Factory 源码
%cd /content
!rm -rf LLaMA-Factory

# 2. 拉最新 main 分支
!git clone https://github.com/hiyouga/LLaMA-Factory.git

# 3. 进到源码目录
%cd LLaMA-Factory




/content
Cloning into 'LLaMA-Factory'...
remote: Enumerating objects: 24763, done.[K
remote: Counting objects: 100% (121/121), done.[K
remote: Compressing objects: 100% (93/93), done.[K
remote: Total 24763 (delta 94), reused 28 (delta 28), pack-reused 24642 (from 3)[K
Receiving objects: 100% (24763/24763), 54.79 MiB | 39.15 MiB/s, done.
Resolving deltas: 100% (17815/17815), done.
/content/LLaMA-Factory


In [None]:

!pip install -e .[torch,bitsandbytes,mlflow]

# 6. 再次彻底重启 Runtime（一定要重启一次，让新的包和补丁生效）
#    Colab 菜单：Runtime → Restart runtime

Obtaining file:///content/LLaMA-Factory
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: llamafactory
  Building editable for llamafactory (pyproject.toml) ... [?25l[?25hdone
  Created wheel for llamafactory: filename=llamafactory-0.9.4.dev0-0.editable-py3-none-any.whl size=28206 sha256=b1077b734a5f769d3e334108d62bb282e95f798774f9236dcf8b30523e1014d2
  Stored in directory: /tmp/pip-ephem-wheel-cache-jtp23asw/wheels/68/8b/5e/52f9888e6a91a2651260d603137c052b925af896da6e32a3f7
Successfully built llamafactory
Installing collected packages: llamafactory
  Attempting uninstall: llamafactory
    Found existing installation: llamafactory 0.9.4.dev0
    Uninstalling llamafactory-0.9.4.dev0:
      Successfully uninstalled llamafactory-0.9.4.dev0
Succe

In [None]:
#替换logging
from google.colab import drive
drive.mount("/content/drive", force_remount=True)
!cp "/content/drive/MyDrive/logging.py" /content/LLaMA-Factory/src/llamafactory/extras/logging.py


Mounted at /content/drive


In [None]:
!pip uninstall -y mlflow
!pip install -U "mlflow==3.2.0"


Found existing installation: mlflow 3.2.0
Uninstalling mlflow-3.2.0:
  Successfully uninstalled mlflow-3.2.0
Collecting mlflow==3.2.0
  Using cached mlflow-3.2.0-py3-none-any.whl.metadata (29 kB)
Using cached mlflow-3.2.0-py3-none-any.whl (25.8 MB)
Installing collected packages: mlflow
Successfully installed mlflow-3.2.0


In [None]:
# ========================================
# MLflow Ultimate Setup - 终极一体化启动代码
# Ultimate MLflow setup code for unified experiment management
# Copy to Colab, run once per session - 每次Colab重启运行一次即可
# All experiments in one place - 所有实验统一管理
# ========================================

import os
import subprocess
import time
import shutil
from pathlib import Path

print("🚀 MLflow Ultimate Setup - Starting...")
print("="*60)

# ===== 1. Environment Setup - 环境准备 =====
# Mount Google Drive - 挂载Google Drive
from google.colab import drive, output
if not os.path.exists('/content/drive/MyDrive'):
    drive.mount('/content/drive')
    print("✅ Google Drive mounted successfully")
else:
    print("✅ Google Drive already mounted")

# Install MLflow if needed - 安装MLflow（如果需要）
try:
    import mlflow
    from mlflow.tracking import MlflowClient
    print(f"✅ MLflow installed (v{mlflow.__version__})")
except ImportError:
    print("📦 Installing MLflow...")
    os.system("pip install -q mlflow")
    import mlflow
    from mlflow.tracking import MlflowClient
    print("✅ MLflow installation completed")

# ===== 2. Unified Path Configuration - 统一路径配置 =====
# Use your existing path structure - 使用你现有的路径结构
MLFLOW_HOME = "/content/drive/MyDrive/mlflow"
DB_PATH = f"{MLFLOW_HOME}/mlflow (3).db"
ARTIFACTS_PATH = f"{MLFLOW_HOME}/artifacts"

# Set environment variables - 设置环境变量
os.environ["MLFLOW_TRACKING_URI"] = f"sqlite:///{DB_PATH}"
os.environ["MLFLOW_ARTIFACT_ROOT"] = ARTIFACTS_PATH

print(f"📁 MLflow data directory: {MLFLOW_HOME}")
print(f"📄 Database path: {DB_PATH}")
print(f"📦 Artifacts path: {ARTIFACTS_PATH}")

# Create directories if not exist - 创建目录（如果不存在）
Path(MLFLOW_HOME).mkdir(parents=True, exist_ok=True)
Path(ARTIFACTS_PATH).mkdir(parents=True, exist_ok=True)

# ===== 3. Clean Up Old Processes - 清理旧进程 =====
print("🔄 Cleaning up old MLflow processes...")
os.system("pkill -f 'mlflow' || true")
time.sleep(3)

# ===== 4. Database Setup - 数据库设置 =====
def setup_database():
    """Initialize MLflow database - 初始化MLflow数据库"""
    print("🔧 Setting up MLflow database...")

    # Initialize database if not exists - 如果数据库不存在则初始化
    if not os.path.exists(DB_PATH):
        print("🆕 First time setup - initializing database...")
        result = subprocess.run([
            "mlflow", "db", "upgrade", f"sqlite:///{DB_PATH}"
        ], capture_output=True, text=True)

        if result.returncode != 0:
            print(f"❌ Database initialization failed: {result.stderr}")
            return False
        print("✅ Database initialized successfully")
    else:
        print("✅ Using existing database")
        # Upgrade database schema if needed - 如果需要则升级数据库架构
        subprocess.run([
            "mlflow", "db", "upgrade", f"sqlite:///{DB_PATH}"
        ], capture_output=True, text=True)

    return True

# ===== 5. MLflow Server Startup - MLflow服务器启动 =====
def start_mlflow_server():
    """Start MLflow tracking server - 启动MLflow跟踪服务器"""
    print("🚀 Starting MLflow server...")

    # Start server process - 启动服务器进程
    server_proc = subprocess.Popen([
        "mlflow", "server",
        "--backend-store-uri", f"sqlite:///{DB_PATH}",
        "--default-artifact-root", ARTIFACTS_PATH,
        "--host", "0.0.0.0",
        "--port", "5000"
    ], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

    # Wait for server to start - 等待服务器启动
    print("⏳ Waiting for server to start...")
    time.sleep(8)

    return server_proc

# ===== 6. Experiment Statistics - 实验统计 =====
def show_experiment_stats():
    """Display current experiment statistics - 显示当前实验统计"""
    try:
        # Connect to database - 连接数据库
        mlflow.set_tracking_uri(f"sqlite:///{DB_PATH}")
        client = MlflowClient()

        # Get all experiments - 获取所有实验
        experiments = client.search_experiments()
        total_runs = 0

        print("\n📊 Current Experiment Statistics:")
        print("-" * 60)

        for exp in experiments:
            runs = client.search_runs(exp.experiment_id)
            total_runs += len(runs)
            status = "🟢" if len(runs) > 0 else "⚪"
            print(f"{status} {exp.name}: {len(runs)} runs")

            # Show latest run info - 显示最新run信息
            if runs:
                latest_run = runs[0]  # Most recent run - 最新的run
                run_name = latest_run.data.tags.get('mlflow.runName', 'unnamed')
                print(f"    └─ Latest: {run_name} ({latest_run.info.run_id[:8]})")

        print("-" * 60)
        print(f"📈 Total: {len(experiments)} experiments, {total_runs} runs")

        if total_runs == 0:
            print("💡 Tip: Start your first experiment!")
        else:
            print("🎉 All historical experiments loaded successfully!")

    except Exception as e:
        print(f"⚠️  Failed to get statistics: {e}")

# ===== 7. Client Setup - 客户端设置 =====
def setup_mlflow_client():
    """Setup MLflow client for immediate use - 设置MLflow客户端供立即使用"""
    # Set tracking URI - 设置跟踪URI
    mlflow.set_tracking_uri(f"sqlite:///{DB_PATH}")

    # Set default experiment - 设置默认实验
    mlflow.set_experiment("lora")

    print("✅ MLflow client configured")
    print(f"✅ Global Tracking URI: {mlflow.get_tracking_uri()}")

# ===== 8. Main Execution - 主执行流程 =====
def main():
    """Main setup process - 主要设置流程"""

    # Setup database - 设置数据库
    if not setup_database():
        print("❌ Database setup failed, stopping execution")
        return None

    # Setup MLflow client - 设置MLflow客户端
    setup_mlflow_client()

    # Start server - 启动服务器
    server_proc = start_mlflow_server()

    # Embed UI in Colab - 在Colab中嵌入UI
    print("🎯 Embedding MLflow UI in Colab...")
    output.serve_kernel_port_as_window(5000)

    # Show statistics - 显示统计信息
    show_experiment_stats()

    # Display usage guide - 显示使用指南
    print("\n" + "="*60)
    print("🎉 MLflow Setup Complete!")
    print("="*60)
    print("🌐 Web UI: Click the window above to view all experiments")
    print(f"📍 Tracking URI: sqlite:///{DB_PATH}")
    print(f"📁 Data Location: {MLFLOW_HOME}")

    print("\n💡 Usage Instructions:")
    print("1. All new experiments automatically save to database")
    print("2. Run this code block once per Colab session")
    print("3. Data permanently saved in Google Drive")

    print("\n🔧 Use in training code:")
    print("   import mlflow")
    print(f'   mlflow.set_tracking_uri("sqlite:///{DB_PATH}")')
    print("   # Then use mlflow.log_param(), mlflow.log_metric(), etc.")

    print("\n📝 LLaMA Factory configuration:")
    print('   "report_to": "mlflow",')
    print(f'   "mlflow_tracking_uri": "sqlite:///{DB_PATH}"')

    return server_proc

# ===== 9. Utility Functions - 实用函数 =====
def quick_test():
    """Quick test of MLflow functionality - MLflow功能快速测试"""
    print("\n🧪 Quick MLflow functionality test...")

    mlflow.set_tracking_uri(f"sqlite:///{DB_PATH}")

    with mlflow.start_run(run_name="connectivity_test"):
        mlflow.log_param("test_param", "hello_world")
        mlflow.log_metric("test_metric", 0.95)
        mlflow.set_tag("source", "quick_test")

    print("✅ Test completed! Refresh UI to see new test run")

def restart_server():
    """Restart MLflow server - 重启MLflow服务器"""
    print("🔄 Restarting MLflow server...")
    os.system("pkill -f 'mlflow' || true")
    time.sleep(2)
    main()

def show_config():
    """Display current configuration - 显示当前配置"""
    print("\n⚙️  Current Configuration:")
    print(f"📍 Tracking URI: sqlite:///{DB_PATH}")
    print(f"📁 Artifacts Root: {ARTIFACTS_PATH}")
    print(f"🌐 Web UI: http://localhost:5000")

# ===== 10. Execute Main Setup - 执行主要设置 =====
if __name__ == "__main__":
    server_process = main()

    if server_process:
        print(f"\n✅ MLflow server process ID: {server_process.pid}")

        # Display utility commands - 显示实用命令
        print("\n" + "🔥" * 40)
        print("💡 Utility Commands:")
        print("• quick_test()     - Test MLflow functionality")
        print("• restart_server() - Restart MLflow server")
        print("• show_config()    - Display configuration")
        print("🔥" * 40)
        print("\n🎯 Ready for experiments! All data unified in one location.")
    else:
        print("❌ Setup failed. Please check error messages above.")

🚀 MLflow Ultimate Setup - Starting...
✅ Google Drive already mounted
✅ MLflow installed (v3.2.0)
📁 MLflow data directory: /content/drive/MyDrive/mlflow
📄 Database path: /content/drive/MyDrive/mlflow/mlflow (3).db
📦 Artifacts path: /content/drive/MyDrive/mlflow/artifacts
🔄 Cleaning up old MLflow processes...
🔧 Setting up MLflow database...
✅ Using existing database


2025/09/12 04:09:07 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2025/09/12 04:09:07 INFO mlflow.store.db.utils: Updating database tables
INFO  [alembic.runtime.migration] Context impl SQLiteImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.
INFO  [alembic.runtime.migration] Context impl SQLiteImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.


✅ MLflow client configured
✅ Global Tracking URI: sqlite:////content/drive/MyDrive/mlflow/mlflow (3).db
🚀 Starting MLflow server...
⏳ Waiting for server to start...
🎯 Embedding MLflow UI in Colab...
Try `serve_kernel_port_as_iframe` instead. [0m


<IPython.core.display.Javascript object>


📊 Current Experiment Statistics:
------------------------------------------------------------
🟢 ultrachat_vblora_experiment: 14 runs
    └─ Latest: vblora_stingning_ultra4 (05a9b9c4)
🟢 vblora_docvqa: 4 runs
    └─ Latest: comprehensive_training (19661f1a)
🟢 vblora_docvqa_old: 29 runs
    └─ Latest: vblora_last (22f1d571)
🟢 lora: 61 runs
    └─ Latest: overfittingr8 (6091c35b)
------------------------------------------------------------
📈 Total: 4 experiments, 108 runs
🎉 All historical experiments loaded successfully!

🎉 MLflow Setup Complete!
🌐 Web UI: Click the window above to view all experiments
📍 Tracking URI: sqlite:////content/drive/MyDrive/mlflow/mlflow (3).db
📁 Data Location: /content/drive/MyDrive/mlflow

💡 Usage Instructions:
1. All new experiments automatically save to database
2. Run this code block once per Colab session
3. Data permanently saved in Google Drive

🔧 Use in training code:
   import mlflow
   mlflow.set_tracking_uri("sqlite:////content/drive/MyDrive/mlflow/m

### Check GPU environment

In [None]:
import torch
try:
  assert torch.cuda.is_available() is True
except AssertionError:
  print("Please set up a GPU before using LLaMA Factory: https://medium.com/mlearning-ai/training-yolov4-on-google-colab-316f8fff99c6")

## Update Identity Dataset

In [None]:
import json

%cd /content/LLaMA-Factory/

NAME = "Llama-3"
AUTHOR = "LLaMA Factory"

with open("data/identity.json", "r", encoding="utf-8") as f:
  dataset = json.load(f)

for sample in dataset:
  sample["output"] = sample["output"].replace("{{"+ "name" + "}}", NAME).replace("{{"+ "author" + "}}", AUTHOR)

with open("data/identity.json", "w", encoding="utf-8") as f:
  json.dump(dataset, f, indent=2, ensure_ascii=False)

/content/LLaMA-Factory


In [None]:
import json
import os
from datasets import load_dataset
import base64
from PIL import Image
import io

# 切换到LLaMA Factory目录
%cd /content/LLaMA-Factory/

# 创建数据目录
os.makedirs("data/docvqa", exist_ok=True)

print("开始下载DocVQA validation数据集...")

# 下载DocVQA验证集
# 指定 config_name="DocVQA"
dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation")


print(f"数据集加载完成，共有 {len(dataset)} 个样本")

# 准备转换为LLaMA Factory支持的格式
def process_dataset(dataset):
    """将DocVQA数据集转换为LLaMA Factory的sharegpt格式"""
    processed_data = []

    for idx, sample in enumerate(dataset):
        # 获取图像
        image = sample['image']

        # 保存图像到本地
        image_path = f"data/docvqa/image_{idx}.png"
        image.save(image_path)

        # 构建对话格式
        conversation = {
            "conversations": [
                {
                    "from": "human",
                    "value": f"<image>{sample['question']}"
                },
                {
                    "from": "gpt",
                    "value": ", ".join(sample['answers'])  # DocVQA的答案是列表格式
                }
            ],
            "images": [image_path]
        }

        processed_data.append(conversation)

        # 每处理100个样本显示进度
        if (idx + 1) % 100 == 0:
            print(f"已处理 {idx + 1} 个样本...")

    return processed_data

# 处理数据集
print("开始处理数据集...")
processed_data = process_dataset(dataset)

# 保存处理后的数据
output_file = "data/docvqa/docvqa_validation.json"
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(processed_data, f, ensure_ascii=False, indent=2)

print(f"数据集处理完成，保存到 {output_file}")

# 更新dataset_info.json
dataset_info_path = "data/dataset_info.json"

# 读取现有的dataset_info.json
if os.path.exists(dataset_info_path):
    with open(dataset_info_path, 'r', encoding='utf-8') as f:
        dataset_info = json.load(f)
else:
    dataset_info = {}

# 添加DocVQA数据集信息
dataset_info["docvqa_validation"] = {
    "file_name": "docvqa/docvqa_validation.json",
    "formatting": "sharegpt",
    "columns": {
        "messages": "conversations",
        "images": "images"
    }
}

# 保存更新后的dataset_info.json
with open(dataset_info_path, 'w', encoding='utf-8') as f:
    json.dump(dataset_info, f, ensure_ascii=False, indent=2)

print("dataset_info.json 已更新")
print(f"数据集名称: docvqa_validation")
print(f"总样本数: {len(processed_data)}")
print("现在可以在LLaMA Factory UI中使用 'docvqa_validation' 作为数据集名称进行训练")

/content/LLaMA-Factory
开始下载DocVQA validation数据集...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

validation-00000-of-00006.parquet:   0%|          | 0.00/115M [00:00<?, ?B/s]

validation-00001-of-00006.parquet:   0%|          | 0.00/160M [00:00<?, ?B/s]

validation-00002-of-00006.parquet:   0%|          | 0.00/184M [00:00<?, ?B/s]

validation-00003-of-00006.parquet:   0%|          | 0.00/178M [00:00<?, ?B/s]

validation-00004-of-00006.parquet:   0%|          | 0.00/206M [00:00<?, ?B/s]

validation-00005-of-00006.parquet:   0%|          | 0.00/212M [00:00<?, ?B/s]

test-00000-of-00006.parquet:   0%|          | 0.00/139M [00:00<?, ?B/s]

test-00001-of-00006.parquet:   0%|          | 0.00/161M [00:00<?, ?B/s]

test-00002-of-00006.parquet:   0%|          | 0.00/179M [00:00<?, ?B/s]

test-00003-of-00006.parquet:   0%|          | 0.00/189M [00:00<?, ?B/s]

test-00004-of-00006.parquet:   0%|          | 0.00/211M [00:00<?, ?B/s]

test-00005-of-00006.parquet:   0%|          | 0.00/228M [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/5349 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5188 [00:00<?, ? examples/s]

数据集加载完成，共有 5349 个样本
开始处理数据集...
已处理 100 个样本...
已处理 200 个样本...
已处理 300 个样本...
已处理 400 个样本...
已处理 500 个样本...
已处理 600 个样本...
已处理 700 个样本...
已处理 800 个样本...
已处理 900 个样本...
已处理 1000 个样本...
已处理 1100 个样本...
已处理 1200 个样本...
已处理 1300 个样本...
已处理 1400 个样本...
已处理 1500 个样本...
已处理 1600 个样本...
已处理 1700 个样本...
已处理 1800 个样本...
已处理 1900 个样本...
已处理 2000 个样本...
已处理 2100 个样本...
已处理 2200 个样本...
已处理 2300 个样本...
已处理 2400 个样本...
已处理 2500 个样本...
已处理 2600 个样本...
已处理 2700 个样本...
已处理 2800 个样本...
已处理 2900 个样本...
已处理 3000 个样本...
已处理 3100 个样本...
已处理 3200 个样本...
已处理 3300 个样本...
已处理 3400 个样本...
已处理 3500 个样本...
已处理 3600 个样本...
已处理 3700 个样本...
已处理 3800 个样本...
已处理 3900 个样本...
已处理 4000 个样本...
已处理 4100 个样本...
已处理 4200 个样本...
已处理 4300 个样本...
已处理 4400 个样本...
已处理 4500 个样本...
已处理 4600 个样本...
已处理 4700 个样本...
已处理 4800 个样本...
已处理 4900 个样本...
已处理 5000 个样本...
已处理 5100 个样本...
已处理 5200 个样本...
已处理 5300 个样本...
数据集处理完成，保存到 data/docvqa/docvqa_validation.json
dataset_info.json 已更新
数据集名称: docvqa_validation
总样本数: 5349
现在可以在LLaMA Factory UI中使用 

In [None]:
import json
import os
from datasets import load_dataset
from PIL import Image
from google.colab import drive
import shutil

# 挂载Google Drive
if not os.path.exists('/content/drive/MyDrive'):
    drive.mount('/content/drive')
    print("✅ Google Drive挂载成功")
else:
    print("✅ Google Drive已经挂载")

# 定义持久化路径
DRIVE_DATA_DIR = "/content/drive/MyDrive/llama_factory_data/docvqa"
LOCAL_DATA_DIR = "/content/LLaMA-Factory/data/docvqa"

def download_and_process_docvqa():
    """下载并处理DocVQA数据集，保存到Google Drive"""
    print("🔄 开始下载和处理DocVQA数据集...")

    # 创建Drive目录
    os.makedirs(DRIVE_DATA_DIR, exist_ok=True)
    os.makedirs(f"{DRIVE_DATA_DIR}/images", exist_ok=True)

    # 下载数据集
    print("📥 下载DocVQA validation数据集...")
    dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation")
    print(f"数据集加载完成，共有 {len(dataset)} 个样本")

    # 处理数据集
    def process_dataset(dataset):
        """将DocVQA数据集转换为LLaMA Factory的sharegpt格式"""
        processed_data = []

        for idx, sample in enumerate(dataset):
            # 获取图像
            image = sample['image']

            # 保存图像到Drive（使用相对路径）
            image_filename = f"image_{idx}.png"
            image_path_drive = f"{DRIVE_DATA_DIR}/images/{image_filename}"
            image_path_relative = f"data/docvqa/images/{image_filename}"  # LLaMA Factory中的相对路径

            # 保存图像
            image.save(image_path_drive)

            # 构建对话格式
            conversation = {
                "conversations": [
                    {
                        "from": "human",
                        "value": f"<image>{sample['question']}"
                    },
                    {
                        "from": "gpt",
                        "value": ", ".join(sample['answers'])  # DocVQA的答案是列表格式
                    }
                ],
                "images": [image_path_relative]  # 使用相对路径
            }

            processed_data.append(conversation)

            # 每处理100个样本显示进度
            if (idx + 1) % 100 == 0:
                print(f"已处理 {idx + 1} 个样本...")

        return processed_data

    print("🔄 开始处理数据集...")
    processed_data = process_dataset(dataset)

    # 保存处理后的数据到Drive
    output_file = f"{DRIVE_DATA_DIR}/docvqa_validation.json"
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(processed_data, f, ensure_ascii=False, indent=2)

    print(f"✅ 数据集处理完成，保存到 {output_file}")
    print(f"📊 总样本数: {len(processed_data)}")
    print(f"🖼️ 图像保存位置: {DRIVE_DATA_DIR}/images/")

    # 保存元信息
    meta_info = {
        "total_samples": len(processed_data),
        "dataset_name": "docvqa_validation",
        "source": "lmms-lab/DocVQA",
        "split": "validation",
        "processed_date": str(pd.Timestamp.now()) if 'pd' in globals() else "unknown",
        "format": "sharegpt"
    }

    with open(f"{DRIVE_DATA_DIR}/meta_info.json", 'w', encoding='utf-8') as f:
        json.dump(meta_info, f, ensure_ascii=False, indent=2)

    return processed_data

def load_docvqa_from_drive():
    """从Google Drive加载已处理的DocVQA数据集"""
    print("📂 从Google Drive加载DocVQA数据集...")

    # 检查文件是否存在
    json_file = f"{DRIVE_DATA_DIR}/docvqa_validation.json"
    images_dir = f"{DRIVE_DATA_DIR}/images"

    if not os.path.exists(json_file):
        print("❌ 数据文件不存在，需要先处理数据集")
        return None

    if not os.path.exists(images_dir):
        print("❌ 图像目录不存在，需要先处理数据集")
        return None

    # 读取数据
    with open(json_file, 'r', encoding='utf-8') as f:
        processed_data = json.load(f)

    # 读取元信息
    meta_file = f"{DRIVE_DATA_DIR}/meta_info.json"
    if os.path.exists(meta_file):
        with open(meta_file, 'r', encoding='utf-8') as f:
            meta_info = json.load(f)
        print(f"📊 数据集信息: {meta_info}")

    print(f"✅ 成功加载 {len(processed_data)} 个样本")
    return processed_data

def setup_docvqa_for_llamafactory():
    """为LLaMA Factory设置DocVQA数据集"""
    print("🔧 为LLaMA Factory设置DocVQA数据集...")

    # 切换到LLaMA Factory目录
    os.chdir('/content/LLaMA-Factory/')

    # 创建本地数据目录
    os.makedirs(LOCAL_DATA_DIR, exist_ok=True)
    os.makedirs(f"{LOCAL_DATA_DIR}/images", exist_ok=True)

    # 从Drive复制数据到本地
    print("📋 复制数据文件...")
    shutil.copy2(f"{DRIVE_DATA_DIR}/docvqa_validation.json", f"{LOCAL_DATA_DIR}/docvqa_validation.json")

    # 复制图像（创建软链接更高效）
    print("🔗 创建图像软链接...")
    if os.path.exists(f"{LOCAL_DATA_DIR}/images"):
        shutil.rmtree(f"{LOCAL_DATA_DIR}/images")
    os.symlink(f"{DRIVE_DATA_DIR}/images", f"{LOCAL_DATA_DIR}/images")

    # 更新dataset_info.json
    dataset_info_path = "data/dataset_info.json"

    # 读取现有的dataset_info.json
    if os.path.exists(dataset_info_path):
        with open(dataset_info_path, 'r', encoding='utf-8') as f:
            dataset_info = json.load(f)
    else:
        dataset_info = {}

    # 添加DocVQA数据集信息
    dataset_info["docvqa_validation"] = {
        "file_name": "docvqa/docvqa_validation.json",
        "formatting": "sharegpt",
        "columns": {
            "messages": "conversations",
            "images": "images"
        }
    }

    # 保存更新后的dataset_info.json
    with open(dataset_info_path, 'w', encoding='utf-8') as f:
        json.dump(dataset_info, f, ensure_ascii=False, indent=2)

    print("✅ dataset_info.json 已更新")
    print("✅ DocVQA数据集设置完成")
    print("🚀 现在可以在LLaMA Factory UI中使用 'docvqa_validation' 作为数据集名称进行训练")

# 主函数：智能处理流程
def prepare_docvqa_dataset():
    """智能准备DocVQA数据集"""
    print("🎯 智能DocVQA数据集准备流程")

    # 检查Drive中是否已有处理好的数据
    if os.path.exists(f"{DRIVE_DATA_DIR}/docvqa_validation.json"):
        print("✅ 发现已处理的数据集，直接从Drive加载")
        data = load_docvqa_from_drive()
        if data:
            setup_docvqa_for_llamafactory()
            return True

    # 如果没有，则下载并处理
    print("🔄 未发现已处理的数据集，开始下载和处理...")
    try:
        processed_data = download_and_process_docvqa()
        if processed_data:
            setup_docvqa_for_llamafactory()
            return True
    except Exception as e:
        print(f"❌ 数据处理失败: {e}")
        return False

    return False

# 运行主函数
if __name__ == "__main__":
    success = prepare_docvqa_dataset()
    if success:
        print("\n🎉 DocVQA数据集准备完成！")
        print("💡 下次运行时会直接从Google Drive加载，无需重新下载处理")
    else:
        print("\n❌ DocVQA数据集准备失败，请检查错误信息")

✅ Google Drive已经挂载
🎯 智能DocVQA数据集准备流程
🔄 未发现已处理的数据集，开始下载和处理...
🔄 开始下载和处理DocVQA数据集...
📥 下载DocVQA validation数据集...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

validation-00000-of-00006.parquet:   0%|          | 0.00/115M [00:00<?, ?B/s]

validation-00001-of-00006.parquet:   0%|          | 0.00/160M [00:00<?, ?B/s]

validation-00002-of-00006.parquet:   0%|          | 0.00/184M [00:00<?, ?B/s]

validation-00003-of-00006.parquet:   0%|          | 0.00/178M [00:00<?, ?B/s]

validation-00004-of-00006.parquet:   0%|          | 0.00/206M [00:00<?, ?B/s]

validation-00005-of-00006.parquet:   0%|          | 0.00/212M [00:00<?, ?B/s]

test-00000-of-00006.parquet:   0%|          | 0.00/139M [00:00<?, ?B/s]

test-00001-of-00006.parquet:   0%|          | 0.00/161M [00:00<?, ?B/s]

test-00002-of-00006.parquet:   0%|          | 0.00/179M [00:00<?, ?B/s]

test-00003-of-00006.parquet:   0%|          | 0.00/189M [00:00<?, ?B/s]

test-00004-of-00006.parquet:   0%|          | 0.00/211M [00:00<?, ?B/s]

test-00005-of-00006.parquet:   0%|          | 0.00/228M [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/5349 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5188 [00:00<?, ? examples/s]

数据集加载完成，共有 5349 个样本
🔄 开始处理数据集...
已处理 100 个样本...
已处理 200 个样本...
已处理 300 个样本...
已处理 400 个样本...
已处理 500 个样本...
已处理 600 个样本...
已处理 700 个样本...
已处理 800 个样本...
已处理 900 个样本...
已处理 1000 个样本...
已处理 1100 个样本...
已处理 1200 个样本...
已处理 1300 个样本...
已处理 1400 个样本...
已处理 1500 个样本...
已处理 1600 个样本...
已处理 1700 个样本...
已处理 1800 个样本...
已处理 1900 个样本...
已处理 2000 个样本...
已处理 2100 个样本...
已处理 2200 个样本...
已处理 2300 个样本...
已处理 2400 个样本...
已处理 2500 个样本...
已处理 2600 个样本...
已处理 2700 个样本...
已处理 2800 个样本...
已处理 2900 个样本...
已处理 3000 个样本...
已处理 3100 个样本...
已处理 3200 个样本...
已处理 3300 个样本...
已处理 3400 个样本...
已处理 3500 个样本...
已处理 3600 个样本...
已处理 3700 个样本...
已处理 3800 个样本...
已处理 3900 个样本...
已处理 4000 个样本...
已处理 4100 个样本...
已处理 4200 个样本...
已处理 4300 个样本...
已处理 4400 个样本...
已处理 4500 个样本...
已处理 4600 个样本...
已处理 4700 个样本...
已处理 4800 个样本...
已处理 4900 个样本...
已处理 5000 个样本...
已处理 5100 个样本...
已处理 5200 个样本...
已处理 5300 个样本...
✅ 数据集处理完成，保存到 /content/drive/MyDrive/llama_factory_data/docvqa/docvqa_validation.json
📊 总样本数: 5349
🖼️ 图像保存位置: /content/drive/MyD

In [None]:
# 快速加载DocVQA数据集脚本
# 适用于数据已经在Google Drive中处理好的情况

import json
import os
import shutil
from google.colab import drive

def quick_setup_docvqa():
    """快速设置DocVQA数据集（从Google Drive）"""
    print("⚡ 快速DocVQA数据集设置...")

    # 1. 挂载Google Drive
    if not os.path.exists('/content/drive/MyDrive'):
        drive.mount('/content/drive')
        print("✅ Google Drive挂载成功")
    else:
        print("✅ Google Drive已经挂载")

    # 2. 定义路径
    DRIVE_DATA_DIR = "/content/drive/MyDrive/llama_factory_data/docvqa"
    LOCAL_DATA_DIR = "/content/LLaMA-Factory/data/docvqa"

    # 3. 检查Drive中的数据
    if not os.path.exists(f"{DRIVE_DATA_DIR}/docvqa_validation.json"):
        print("❌ 在Google Drive中未找到处理好的DocVQA数据集")
        print("💡 请先运行完整的数据处理脚本")
        return False

    # 4. 切换到LLaMA Factory目录
    os.chdir('/content/LLaMA-Factory/')

    # 5. 创建本地数据目录结构
    os.makedirs(LOCAL_DATA_DIR, exist_ok=True)
    os.makedirs("data/docvqa/images", exist_ok=True)

    # 6. 复制数据文件
    print("📋 复制数据文件...")
    shutil.copy2(f"{DRIVE_DATA_DIR}/docvqa_validation.json", f"{LOCAL_DATA_DIR}/docvqa_validation.json")

    # 7. 创建图像软链接（更高效）
    print("🔗 链接图像目录...")
    image_link_path = f"{LOCAL_DATA_DIR}/images"
    if os.path.islink(image_link_path):
        os.unlink(image_link_path)
    elif os.path.exists(image_link_path):
        shutil.rmtree(image_link_path)

    os.symlink(f"{DRIVE_DATA_DIR}/images", image_link_path)

    # 8. 更新dataset_info.json
    print("📝 更新dataset_info.json...")
    dataset_info_path = "data/dataset_info.json"

    if os.path.exists(dataset_info_path):
        with open(dataset_info_path, 'r', encoding='utf-8') as f:
            dataset_info = json.load(f)
    else:
        dataset_info = {}

    dataset_info["docvqa_validation"] = {
        "file_name": "docvqa/docvqa_validation.json",
        "formatting": "sharegpt",
        "columns": {
            "messages": "conversations",
            "images": "images"
        }
    }

    with open(dataset_info_path, 'w', encoding='utf-8') as f:
        json.dump(dataset_info, f, ensure_ascii=False, indent=2)

    # 9. 验证设置
    with open(f"{LOCAL_DATA_DIR}/docvqa_validation.json", 'r', encoding='utf-8') as f:
        data = json.load(f)

    # 10. 显示统计信息
    print("✅ DocVQA数据集快速设置完成！")
    print(f"📊 总样本数: {len(data)}")
    print(f"📁 数据文件: {LOCAL_DATA_DIR}/docvqa_validation.json")
    print(f"🖼️ 图像目录: {LOCAL_DATA_DIR}/images")
    print("🚀 现在可以在LLaMA Factory UI中使用 'docvqa_validation' 数据集")

    # 11. 显示示例数据
    if len(data) > 0:
        print("\n📝 数据示例:")
        sample = data[0]
        print(f"  问题: {sample['conversations'][0]['value']}")
        print(f"  答案: {sample['conversations'][1]['value']}")
        print(f"  图像: {sample['images'][0]}")

    return True

# 直接运行
if __name__ == "__main__":
    success = quick_setup_docvqa()
    if not success:
        print("\n💡 如果是第一次使用，请运行完整的数据处理脚本：")

⚡ 快速DocVQA数据集设置...
✅ Google Drive已经挂载
📋 复制数据文件...
🔗 链接图像目录...
📝 更新dataset_info.json...
✅ DocVQA数据集快速设置完成！
📊 总样本数: 5349
📁 数据文件: /content/LLaMA-Factory/data/docvqa/docvqa_validation.json
🖼️ 图像目录: /content/LLaMA-Factory/data/docvqa/images
🚀 现在可以在LLaMA Factory UI中使用 'docvqa_validation' 数据集

📝 数据示例:
  问题: <image>What is the ‘actual’ value per 1000, during the year 1975?
  答案: 0.28
  图像: data/docvqa/images/image_0.png


In [None]:
import json
import os
from datasets import load_dataset
from PIL import Image
from google.colab import drive
import shutil
import pandas as pd

# 挂载Google Drive
if not os.path.exists('/content/drive/MyDrive'):
    drive.mount('/content/drive')
    print("✅ Google Drive挂载成功")
else:
    print("✅ Google Drive已经挂载")

# 定义持久化路径
DRIVE_DATA_DIR = "/content/drive/MyDrive/llama_factory_data/infographicvqa"
LOCAL_DATA_DIR = "/content/LLaMA-Factory/data/infographicvqa"

def download_and_process_infographicvqa():
    """下载并处理InfographicVQA数据集，保存到Google Drive"""
    print("🔄 开始下载和处理InfographicVQA数据集...")

    # 创建Drive目录
    os.makedirs(DRIVE_DATA_DIR, exist_ok=True)
    os.makedirs(f"{DRIVE_DATA_DIR}/images", exist_ok=True)

    # 下载数据集
    print("📥 下载InfographicVQA validation数据集...")
    dataset = load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="validation")
    print(f"数据集加载完成，共有 {len(dataset)} 个样本")

    # 限制到前535条
    max_samples = min(535, len(dataset))
    dataset = dataset.select(range(max_samples))
    print(f"使用前 {max_samples} 个样本进行处理")

    # 处理数据集
    def process_dataset(dataset):
        """将InfographicVQA数据集转换为LLaMA Factory的sharegpt格式"""
        processed_data = []

        for idx, sample in enumerate(dataset):
            # 获取图像
            image = sample['image']

            # 保存图像到Drive（使用相对路径）
            image_filename = f"infographic_{idx}.png"
            image_path_drive = f"{DRIVE_DATA_DIR}/images/{image_filename}"
            image_path_relative = f"data/infographicvqa/images/{image_filename}"  # LLaMA Factory中的相对路径

            # 保存图像
            image.save(image_path_drive)

            # 构建对话格式
            conversation = {
                "conversations": [
                    {
                        "from": "human",
                        "value": f"<image>{sample['question']}"
                    },
                    {
                        "from": "gpt",
                        "value": ", ".join(sample['answers']) if isinstance(sample['answers'], list) else str(sample['answers'])
                    }
                ],
                "images": [image_path_relative]  # 使用相对路径
            }

            processed_data.append(conversation)

            # 每处理50个样本显示进度
            if (idx + 1) % 50 == 0:
                print(f"已处理 {idx + 1} 个样本...")

        return processed_data

    print("🔄 开始处理数据集...")
    processed_data = process_dataset(dataset)

    # 保存处理后的数据到Drive
    output_file = f"{DRIVE_DATA_DIR}/infographicvqa_validation.json"
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(processed_data, f, ensure_ascii=False, indent=2)

    print(f"✅ 数据集处理完成，保存到 {output_file}")
    print(f"📊 总样本数: {len(processed_data)}")
    print(f"🖼️ 图像保存位置: {DRIVE_DATA_DIR}/images/")

    # 保存元信息
    meta_info = {
        "total_samples": len(processed_data),
        "dataset_name": "infographicvqa_validation",
        "source": "lmms-lab/DocVQA",
        "subset": "InfographicVQA",
        "split": "validation",
        "max_samples": max_samples,
        "processed_date": str(pd.Timestamp.now()),
        "format": "sharegpt"
    }

    with open(f"{DRIVE_DATA_DIR}/meta_info.json", 'w', encoding='utf-8') as f:
        json.dump(meta_info, f, ensure_ascii=False, indent=2)

    return processed_data

def load_infographicvqa_from_drive():
    """从Google Drive加载已处理的InfographicVQA数据集"""
    print("📂 从Google Drive加载InfographicVQA数据集...")

    # 检查文件是否存在
    json_file = f"{DRIVE_DATA_DIR}/infographicvqa_validation.json"
    images_dir = f"{DRIVE_DATA_DIR}/images"

    if not os.path.exists(json_file):
        print("❌ 数据文件不存在，需要先处理数据集")
        return None

    if not os.path.exists(images_dir):
        print("❌ 图像目录不存在，需要先处理数据集")
        return None

    # 读取数据
    with open(json_file, 'r', encoding='utf-8') as f:
        processed_data = json.load(f)

    # 读取元信息
    meta_file = f"{DRIVE_DATA_DIR}/meta_info.json"
    if os.path.exists(meta_file):
        with open(meta_file, 'r', encoding='utf-8') as f:
            meta_info = json.load(f)
        print(f"📊 数据集信息: {meta_info}")

    print(f"✅ 成功加载 {len(processed_data)} 个样本")
    return processed_data

def setup_infographicvqa_for_llamafactory():
    """为LLaMA Factory设置InfographicVQA数据集"""
    print("🔧 为LLaMA Factory设置InfographicVQA数据集...")

    # 切换到LLaMA Factory目录
    os.chdir('/content/LLaMA-Factory/')

    # 创建本地数据目录
    os.makedirs(LOCAL_DATA_DIR, exist_ok=True)
    os.makedirs(f"{LOCAL_DATA_DIR}/images", exist_ok=True)

    # 从Drive复制数据到本地
    print("📋 复制数据文件...")
    shutil.copy2(f"{DRIVE_DATA_DIR}/infographicvqa_validation.json", f"{LOCAL_DATA_DIR}/infographicvqa_validation.json")

    # 复制图像（创建软链接更高效）
    print("🔗 创建图像软链接...")
    if os.path.exists(f"{LOCAL_DATA_DIR}/images"):
        shutil.rmtree(f"{LOCAL_DATA_DIR}/images")
    os.symlink(f"{DRIVE_DATA_DIR}/images", f"{LOCAL_DATA_DIR}/images")

    # 更新dataset_info.json
    dataset_info_path = "data/dataset_info.json"

    # 读取现有的dataset_info.json
    if os.path.exists(dataset_info_path):
        with open(dataset_info_path, 'r', encoding='utf-8') as f:
            dataset_info = json.load(f)
    else:
        dataset_info = {}

    # 添加InfographicVQA数据集信息
    dataset_info["infographicvqa_validation"] = {
        "file_name": "infographicvqa/infographicvqa_validation.json",
        "formatting": "sharegpt",
        "columns": {
            "messages": "conversations",
            "images": "images"
        }
    }

    # 保存更新后的dataset_info.json
    with open(dataset_info_path, 'w', encoding='utf-8') as f:
        json.dump(dataset_info, f, ensure_ascii=False, indent=2)

    print("✅ dataset_info.json 已更新")
    print("✅ InfographicVQA数据集设置完成")
    print("🚀 现在可以在LLaMA Factory UI中使用 'infographicvqa_validation' 作为数据集名称进行训练")

# 主函数：智能处理流程
def prepare_infographicvqa_dataset():
    """智能准备InfographicVQA数据集"""
    print("🎯 智能InfographicVQA数据集准备流程")

    # 检查Drive中是否已有处理好的数据
    if os.path.exists(f"{DRIVE_DATA_DIR}/infographicvqa_validation.json"):
        print("✅ 发现已处理的数据集，直接从Drive加载")
        data = load_infographicvqa_from_drive()
        if data:
            setup_infographicvqa_for_llamafactory()
            return True

    # 如果没有，则下载并处理
    print("🔄 未发现已处理的数据集，开始下载和处理...")
    try:
        processed_data = download_and_process_infographicvqa()
        if processed_data:
            setup_infographicvqa_for_llamafactory()
            return True
    except Exception as e:
        print(f"❌ 数据处理失败: {e}")
        return False

    return False

# 运行主函数
if __name__ == "__main__":
    success = prepare_infographicvqa_dataset()
    if success:
        print("\n🎉 InfographicVQA数据集准备完成！")
        print("💡 下次运行时会直接从Google Drive加载，无需重新下载处理")
        print("📌 数据集名称: infographicvqa_validation")
        print("📌 样本数量: 535 (最大)")
    else:
        print("\n❌ InfographicVQA数据集准备失败，请检查错误信息")

✅ Google Drive已经挂载
🎯 智能InfographicVQA数据集准备流程
✅ 发现已处理的数据集，直接从Drive加载
📂 从Google Drive加载InfographicVQA数据集...
📊 数据集信息: {'total_samples': 535, 'dataset_name': 'infographicvqa_validation', 'source': 'lmms-lab/DocVQA', 'subset': 'InfographicVQA', 'split': 'validation', 'max_samples': 535, 'processed_date': '2025-07-31 02:35:12.295270', 'format': 'sharegpt'}
✅ 成功加载 535 个样本
🔧 为LLaMA Factory设置InfographicVQA数据集...
📋 复制数据文件...
🔗 创建图像软链接...


OSError: Cannot call rmtree on a symbolic link

In [None]:
# 快速加载InfographicVQA数据集脚本
# 适用于数据已经在Google Drive中处理好的情况

import json
import os
import shutil
from google.colab import drive

def quick_setup_infographicvqa():
    """快速设置InfographicVQA数据集（从Google Drive）"""
    print("⚡ 快速InfographicVQA数据集设置...")

    # 1. 挂载Google Drive
    if not os.path.exists('/content/drive/MyDrive'):
        drive.mount('/content/drive')
        print("✅ Google Drive挂载成功")
    else:
        print("✅ Google Drive已经挂载")

    # 2. 定义路径
    DRIVE_DATA_DIR = "/content/drive/MyDrive/llama_factory_data/infographicvqa"
    LOCAL_DATA_DIR = "/content/LLaMA-Factory/data/infographicvqa"

    # 3. 检查Drive中的数据
    if not os.path.exists(f"{DRIVE_DATA_DIR}/infographicvqa_validation.json"):
        print("❌ 在Google Drive中未找到处理好的InfographicVQA数据集")
        print("💡 请先运行完整的数据处理脚本")
        return False

    # 4. 切换到LLaMA Factory目录
    os.chdir('/content/LLaMA-Factory/')

    # 5. 创建本地数据目录结构
    os.makedirs(LOCAL_DATA_DIR, exist_ok=True)
    os.makedirs("data/infographicvqa/images", exist_ok=True)

    # 6. 复制数据文件
    print("📋 复制数据文件...")
    shutil.copy2(f"{DRIVE_DATA_DIR}/infographicvqa_validation.json", f"{LOCAL_DATA_DIR}/infographicvqa_validation.json")

    # 7. 创建图像软链接（更高效）
    print("🔗 链接图像目录...")
    image_link_path = f"{LOCAL_DATA_DIR}/images"
    if os.path.islink(image_link_path):
        os.unlink(image_link_path)
    elif os.path.exists(image_link_path):
        shutil.rmtree(image_link_path)

    os.symlink(f"{DRIVE_DATA_DIR}/images", image_link_path)

    # 8. 更新dataset_info.json
    print("📝 更新dataset_info.json...")
    dataset_info_path = "data/dataset_info.json"

    if os.path.exists(dataset_info_path):
        with open(dataset_info_path, 'r', encoding='utf-8') as f:
            dataset_info = json.load(f)
    else:
        dataset_info = {}

    dataset_info["infographicvqa_validation"] = {
        "file_name": "infographicvqa/infographicvqa_validation.json",
        "formatting": "sharegpt",
        "columns": {
            "messages": "conversations",
            "images": "images"
        }
    }

    with open(dataset_info_path, 'w', encoding='utf-8') as f:
        json.dump(dataset_info, f, ensure_ascii=False, indent=2)

    # 9. 验证设置
    with open(f"{LOCAL_DATA_DIR}/infographicvqa_validation.json", 'r', encoding='utf-8') as f:
        data = json.load(f)

    # 10. 加载元信息
    meta_file = f"{DRIVE_DATA_DIR}/meta_info.json"
    if os.path.exists(meta_file):
        with open(meta_file, 'r', encoding='utf-8') as f:
            meta_info = json.load(f)
        print(f"📊 数据集元信息: {meta_info}")

    # 11. 显示统计信息
    print("✅ InfographicVQA数据集快速设置完成！")
    print(f"📊 总样本数: {len(data)}")
    print(f"📁 数据文件: {LOCAL_DATA_DIR}/infographicvqa_validation.json")
    print(f"🖼️ 图像目录: {LOCAL_DATA_DIR}/images")
    print("🚀 现在可以在LLaMA Factory UI中使用 'infographicvqa_validation' 数据集")

    # 12. 显示示例数据
    if len(data) > 0:
        print("\n📝 数据示例:")
        sample = data[0]
        print(f"  问题: {sample['conversations'][0]['value']}")
        print(f"  答案: {sample['conversations'][1]['value']}")
        print(f"  图像: {sample['images'][0]}")

        # 显示更多统计信息
        print(f"\n📈 数据集统计:")
        questions = [item['conversations'][0]['value'] for item in data]
        answers = [item['conversations'][1]['value'] for item in data]
        avg_question_length = sum(len(q.split()) for q in questions) / len(questions)
        avg_answer_length = sum(len(a.split()) for a in answers) / len(answers)
        print(f"  平均问题长度: {avg_question_length:.1f} 个词")
        print(f"  平均答案长度: {avg_answer_length:.1f} 个词")

    return True

# 直接运行
if __name__ == "__main__":
    success = quick_setup_infographicvqa()
    if not success:
        print("\n💡 如果是第一次使用，请运行完整的数据处理脚本")
        print("📋 完整脚本将下载并处理InfographicVQA数据集的前535条数据")
    else:
        print("\n🎯 设置完成，可以开始训练了！")
        print("💡 下次可以直接运行此快速脚本，无需重新处理数据")

⚡ 快速InfographicVQA数据集设置...
✅ Google Drive已经挂载
📋 复制数据文件...
🔗 链接图像目录...
📝 更新dataset_info.json...
📊 数据集元信息: {'total_samples': 535, 'dataset_name': 'infographicvqa_validation', 'source': 'lmms-lab/DocVQA', 'subset': 'InfographicVQA', 'split': 'validation', 'max_samples': 535, 'processed_date': '2025-07-31 02:35:12.295270', 'format': 'sharegpt'}
✅ InfographicVQA数据集快速设置完成！
📊 总样本数: 535
📁 数据文件: /content/LLaMA-Factory/data/infographicvqa/infographicvqa_validation.json
🖼️ 图像目录: /content/LLaMA-Factory/data/infographicvqa/images
🚀 现在可以在LLaMA Factory UI中使用 'infographicvqa_validation' 数据集

📝 数据示例:
  问题: <image>Which social platform has heavy female audience?
  答案: pinterest
  图像: data/infographicvqa/images/infographic_0.png

📈 数据集统计:
  平均问题长度: 10.9 个词
  平均答案长度: 2.3 个词

🎯 设置完成，可以开始训练了！
💡 下次可以直接运行此快速脚本，无需重新处理数据


## Fine-tune model via LLaMA Board

In [None]:

import os, shutil, pathlib, subprocess

repo_saves  = '/content/LLaMA-Factory/saves'
drive_saves = '/content/drive/MyDrive/llama_saves'

# 确保目标目录存在
os.makedirs(drive_saves, exist_ok=True)

# --- ❶ 只要本地 saves 不是符号链接，就先增量同步到 Drive ---
if os.path.exists(repo_saves) and not os.path.islink(repo_saves):
    print("🔄 Sync local → Drive ...")
    subprocess.run(["rsync", "-a", "--update", f"{repo_saves}/", f"{drive_saves}/"], check=True)

    # 同步完再安全删除本地目录
    shutil.rmtree(repo_saves)

# --- ❷ 重建符号链接（避免陈旧） ---
if os.path.islink(repo_saves):
    os.unlink(repo_saves)

os.symlink(drive_saves, repo_saves, target_is_directory=True)
print("✅ symlink OK → 写入 saves/ 将实时保存到 Drive")


✅ symlink OK → 写入 saves/ 将实时保存到 Drive


In [None]:
pip install rouge_chinese


Collecting rouge_chinese
  Downloading rouge_chinese-1.0.3-py3-none-any.whl.metadata (7.6 kB)
Downloading rouge_chinese-1.0.3-py3-none-any.whl (21 kB)
Installing collected packages: rouge_chinese
Successfully installed rouge_chinese-1.0.3


In [None]:
#跑一半继续
import mlflow
mlflow.set_tracking_uri("sqlite:////content/drive/MyDrive/mlflow/mlflow.db")
if mlflow.active_run() is not None:
    mlflow.end_run()
mlflow.start_run(run_id="9f5605f56b0f400daa3d7f5d13b8aeca")  # ← 旧 run ID


<ActiveRun: >

In [None]:
%cd /content/LLaMA-Factory/
!GRADIO_SHARE=1 llamafactory-cli webui

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[INFO|tokenization_utils_base.py:2336] 2025-09-12 04:44:29,062 >> Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
[INFO|video_processing_utils.py:720] 2025-09-12 04:44:29,271 >> loading configuration file preprocessor_config.json from cache at /root/.cache/huggingface/hub/models--Qwen--Qwen2-VL-2B/snapshots/d3a53f2484fce9d62fff115a5ddfc833f873bfde/preprocessor_config.json
[INFO|video_processing_utils.py:764] 2025-09-12 04:44:29,271 >> Video processor Qwen2VLVideoProcessor {
  "crop_size": null,
  "data_format": "channels_first",
  "default_to_square": true,
  "device": null,
  "do_center_crop": null,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_pad": null,
  "do_rescale": true,
  "do_resize": true,
  "do_sample_frames": false,
  "fps": null,
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_std": [
    0.26862954,


In [None]:
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from peft import PeftModel, PeftConfig
import torch

base = "Qwen/Qwen2-VL-2B-Instruct"
adapter = "/content/drive/MyDrive/vblora_save/comprehensive_vblora/final_model"   # 指向含 adapter_config.json 的目录

# 1) 加载 Qwen2-VL 基座
processor = AutoProcessor.from_pretrained(base, trust_remote_code=True)
model_base = Qwen2VLForConditionalGeneration.from_pretrained(
    base, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)

# 2) 不带适配器的输出
prompt = "用一句话解释 VB-LoRA 的核心思想。"
inputs = processor(text=prompt, return_tensors="pt").to(model_base.device)
with torch.inference_mode():
    out_base = model_base.generate(**inputs, max_new_tokens=64, do_sample=False)
print("BASE:", processor.batch_decode(out_base, skip_special_tokens=True)[0])

# 3) 挂载 VB-LoRA 适配器
peft_cfg = PeftConfig.from_pretrained(adapter)
assert peft_cfg.peft_type.upper() == "VBLORA", peft_cfg   # 简单断言类型
model = PeftModel.from_pretrained(model_base, adapter)
model.eval()

print("Loaded adapters:", list(getattr(model, "peft_config", {}).keys()))

# 4) 带适配器的输出（应与 BASE 有明显差异）
with torch.inference_mode():
    out_vb = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print("VB-LORA:", processor.batch_decode(out_vb, skip_special_tokens=True)[0])


preprocessor_config.json:   0%|          | 0.00/347 [00:00<?, ?B/s]

The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.


chat_template.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/429M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/272 [00:00<?, ?B/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


BASE: 用一句话解释 VB-LoRA 的核心思想。VB-LoRA 是一种基于虚拟机的网络协议，它允许在虚拟机之间进行数据传输，从而实现跨平台的通信。核心思想是通过虚拟机的虚拟机（VM）来实现数据的传输，从而实现跨平台的通信。VB-LoRA 的核心思想是通过


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Loaded adapters: ['default']
VB-LORA: 用一句话解释 VB-LoRA 的核心思想。 VB-LoRA 的核心思想是“以用户为中心，以用户需求为导向，以用户体验为核心，以用户满意度为标准，以用户反馈为依据，以用户意见为依据，以用户建议为依据，以用户需求为依据，以用户反馈为依据，以用户意见为依据


In [None]:
#克隆
# ==== 0) 基本设置 ====
import os, shutil, tempfile
import mlflow
from mlflow.tracking import MlflowClient

# 若你有自定义 tracking URI（比如放在 Google Drive 的 sqlite），在这里设：
# mlflow.set_tracking_uri("file:/content/drive/MyDrive/mlflow")

OLD_RUN_ID = "5e360ce1579b4a28b3c6fef84a0ae3e8"       # ← 换成你刚才那条 run 的 ID
NEW_RUN_NAME = "overfittingr16"  # 新 run 的名字
EXTRA_PARAMS = {                  # 这就是你要补记到 MLflow 的参数
    "lora_rank": "12",
    "lora_alpha": "16",
    "lora_dropout": "0",
}

client = MlflowClient()

# 读取旧 run
old_run = client.get_run(OLD_RUN_ID)
exp_id = old_run.info.experiment_id

# ==== 1) 新建 run，并打上关联 tag ====
new_run = client.create_run(
    experiment_id=exp_id,
    tags={**old_run.data.tags, "clone_of": OLD_RUN_ID}
)
new_id = new_run.info.run_id

# 设定新 run 的名字（UI 左侧好看）
client.set_tag(new_id, "mlflow.runName", NEW_RUN_NAME)

# ==== 2) 复制 params（再追加 LoRA 的参数）====
# 注意：MLflow 的 param 一旦写入同键不能改值；我们只新增键
for k, v in old_run.data.params.items():
    client.log_param(new_id, k, v)
for k, v in EXTRA_PARAMS.items():
    client.log_param(new_id, k, v)

# ==== 3) 复制 metrics 全量历史（含 step & timestamp）====
# 这样平行坐标/曲线都会完整重现
for key in old_run.data.metrics.keys():
    for m in client.get_metric_history(OLD_RUN_ID, key):
        client.log_metric(
            run_id=new_id,
            key=key,
            value=m.value,
            timestamp=m.timestamp,
            step=m.step,
        )

# ==== 4) 复制 artifacts（若旧 run 的 artifacts 目录仍存在）====
def _download_and_log_artifacts(src_run_id, dst_run_id):
    tmpdir = tempfile.mkdtemp()
    try:
        # 下载旧 run 的所有 artifacts 到本地临时目录
        client.download_artifacts(src_run_id, "", tmpdir)
        # 上传到新 run
        client.log_artifacts(dst_run_id, tmpdir)
    except Exception as e:
        print(f"⚠️ 复制 artifacts 失败（可能旧目录已丢失）：{e}")
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)

_download_and_log_artifacts(OLD_RUN_ID, new_id)

print(f"✅ Clone 完成，新 run: {new_id}")
print("小贴士：在新 run 的 Artifacts 里也能看到 adapter_config.json，那里有 r/alpha/dropout 的真值。")


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

⚠️ 复制 artifacts 失败（可能旧目录已丢失）：The following failures occurred while downloading one or more artifacts from /content/LLaMA-Factory/mlruns/1/5e360ce1579b4a28b3c6fef84a0ae3e8/artifacts:
##### File  #####
[Errno 2] No such file or directory: '/content/LLaMA-Factory/mlruns/1/5e360ce1579b4a28b3c6fef84a0ae3e8/artifacts/.'
✅ Clone 完成，新 run: fe00769c33074b2da9da0789a15d368e
小贴士：在新 run 的 Artifacts 里也能看到 adapter_config.json，那里有 r/alpha/dropout 的真值。


In [None]:
from llamafactory.hparams.finetuning_args import FinetuningArguments
import inspect, textwrap

# 打印定义，确认字段是否存在
src = inspect.getsource(FinetuningArguments)
print(textwrap.dedent(src))



@dataclass
class FinetuningArguments(
    SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
):
    r"""Arguments pertaining to which techniques we are going to fine-tuning with."""

    pure_bf16: bool = field(
        default=False,
        metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
    )
    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
        default="sft",
        metadata={"help": "Which stage will be performed in training."},
    )
    finetuning_type: Literal["lora", "freeze", "full"] = field(
        default="lora",
        metadata={"help": "Which fine-tuning method to use."},
    )
    use_llama_pro: bool = field(
        default=False,
        metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
    )
    use_adam_mini: bool = field(
        default=False,
        metadata={"help": "Whether or not to 

## Fine-tune model via Command Line

It takes ~30min for training.

In [None]:
import json

args = dict(
  stage="sft",                                               # do supervised fine-tuning
  do_train=True,
  model_name_or_path="unsloth/llama-3-8b-Instruct-bnb-4bit", # use bnb-4bit-quantized Llama-3-8B-Instruct model
  dataset="identity,alpaca_en_demo",                         # use alpaca and identity datasets
  template="llama3",                                         # use llama3 prompt template
  finetuning_type="lora",                                    # use LoRA adapters to save memory
  lora_target="all",                                         # attach LoRA adapters to all linear layers
  output_dir="llama3_lora",                                  # the path to save LoRA adapters
  per_device_train_batch_size=2,                             # the micro batch size
  gradient_accumulation_steps=4,                             # the gradient accumulation steps
  lr_scheduler_type="cosine",                                # use cosine learning rate scheduler
  logging_steps=5,                                           # log every 5 steps
  warmup_ratio=0.1,                                          # use warmup scheduler
  save_steps=1000,                                           # save checkpoint every 1000 steps
  learning_rate=5e-5,                                        # the learning rate
  num_train_epochs=3.0,                                      # the epochs of training
  max_samples=500,                                           # use 500 examples in each dataset
  max_grad_norm=1.0,                                         # clip gradient norm to 1.0
  loraplus_lr_ratio=16.0,                                    # use LoRA+ algorithm with lambda=16.0
  fp16=True,                                                 # use float16 mixed precision training
  report_to="none",                                          # disable wandb logging
)

json.dump(args, open("train_llama3.json", "w", encoding="utf-8"), indent=2)

%cd /content/LLaMA-Factory/

!llamafactory-cli train train_llama3.json

## Infer the fine-tuned model

In [None]:
from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc

%cd /content/LLaMA-Factory/

args = dict(
  model_name_or_path="unsloth/llama-3-8b-Instruct-bnb-4bit", # use bnb-4bit-quantized Llama-3-8B-Instruct model
  adapter_name_or_path="llama3_lora",                        # load the saved LoRA adapters
  template="llama3",                                         # same to the one in training
  finetuning_type="lora",                                    # same to the one in training
)
chat_model = ChatModel(args)

messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
  query = input("\nUser: ")
  if query.strip() == "exit":
    break
  if query.strip() == "clear":
    messages = []
    torch_gc()
    print("History has been removed.")
    continue

  messages.append({"role": "user", "content": query})
  print("Assistant: ", end="", flush=True)

  response = ""
  for new_text in chat_model.stream_chat(messages):
    print(new_text, end="", flush=True)
    response += new_text
  print()
  messages.append({"role": "assistant", "content": response})

torch_gc()

## Merge the LoRA adapter and optionally upload model

NOTE: the Colab free version has merely 12GB RAM, where merging LoRA of a 8B model needs at least 18GB RAM, thus you **cannot** perform it in the free version.

In [None]:
!huggingface-cli login

In [None]:
import json

args = dict(
  model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct", # use official non-quantized Llama-3-8B-Instruct model
  adapter_name_or_path="llama3_lora",                       # load the saved LoRA adapters
  template="llama3",                                        # same to the one in training
  finetuning_type="lora",                                   # same to the one in training
  export_dir="llama3_lora_merged",                          # the path to save the merged model
  export_size=2,                                            # the file shard size (in GB) of the merged model
  export_device="cpu",                                      # the device used in export, can be chosen from `cpu` and `auto`
  # export_hub_model_id="your_id/your_model",               # the Hugging Face hub ID to upload model
)

json.dump(args, open("merge_llama3.json", "w", encoding="utf-8"), indent=2)

%cd /content/LLaMA-Factory/

!llamafactory-cli export merge_llama3.json