# 🚀 FunctionGemma 交互式训练 Notebook

本 Notebook 提供可视化的模型微调环境，支持：
- 📊 实时训练指标可视化
- 🎯 交互式参数配置
- 📈 Loss 曲线动态绘制
- 🔍 训练样本质量检查
- 💾 模型导出与推理测试

## 1. 环境初始化

In [None]:
# 检查 GPU 可用性
import torch
import sys
from pathlib import Path

# Make the project's `src.*` packages importable: the project root is taken to
# be the parent of the current working directory and is put first on sys.path.
project_root = Path().absolute().parent
sys.path.insert(0, str(project_root))

# Report the PyTorch build and, when a CUDA device is present, the name and
# total memory of device 0.
print(f"🔥 PyTorch 版本: {torch.__version__}")
print(f"🎮 CUDA 可用: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("⚠️ 警告: 未检测到 GPU，训练将非常慢")
else:
    print(f"📺 GPU: {torch.cuda.get_device_name(0)}")
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"💾 GPU 显存: {total_bytes / 1e9:.2f} GB")

## 2. 导入依赖

In [None]:
import json
import logging
from datetime import datetime
from typing import Optional, List, Dict, Any

import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, HTML, clear_output
from tqdm.notebook import tqdm
import ipywidgets as widgets

# 项目模块
from src.utils.config_loader import load_config, print_config
from src.data_engine.converter import DataConverter
from src.data_engine.formatter import FunctionGemmaFormatter
from src.training.trainer import FunctionGemmaTrainer
from src.training.callbacks import (
    WandbCallback,
    SampleGenerationCallback,
    EarlyStoppingCallback
)

# Configure root logging for the notebook and create this module's logger.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

print("✅ 所有依赖导入成功！")

## 3. 交互式参数配置

使用下面的控件配置训练参数。控件的取值是在后续单元格运行时读取的，因此修改控件后需要重新运行这些单元格才会生效：

In [None]:
# Build the interactive controls used to configure a training run.
def _field_options():
    """Return the label style and a fresh 500px-wide layout used by every control."""
    # A new Layout is created per call so that widgets do not share one
    # layout object.
    return {
        'style': {'description_width': '150px'},
        'layout': widgets.Layout(width='500px'),
    }


config_widgets = {
    'model_name': widgets.Text(
        value='google/functiongemma-270m-it',
        description='模型名称:',
        **_field_options(),
    ),
    'max_seq_length': widgets.IntSlider(
        value=2048, min=512, max=8192, step=512,
        description='最大序列长度:',
        **_field_options(),
    ),
    'lora_rank': widgets.IntSlider(
        value=16, min=4, max=64, step=4,
        description='LoRA Rank:',
        **_field_options(),
    ),
    'lora_alpha': widgets.IntSlider(
        value=16, min=4, max=64, step=4,
        description='LoRA Alpha:',
        **_field_options(),
    ),
    'epochs': widgets.IntSlider(
        value=3, min=1, max=10, step=1,
        description='训练轮数:',
        **_field_options(),
    ),
    'batch_size': widgets.IntSlider(
        value=4, min=1, max=16, step=1,
        description='Batch Size:',
        **_field_options(),
    ),
    # Log-scale slider: min/max are exponents of `base`, i.e. 1e-5 .. 1e-3.
    'learning_rate': widgets.FloatLogSlider(
        value=2e-4, base=10, min=-5, max=-3, step=0.1,
        description='学习率:',
        **_field_options(),
    ),
    'gradient_accumulation': widgets.IntSlider(
        value=4, min=1, max=16, step=1,
        description='梯度累积:',
        **_field_options(),
    ),
    'data_path': widgets.Text(
        value='data/processed/train.jsonl',
        description='数据路径:',
        **_field_options(),
    ),
    # The default output directory is timestamped when this cell runs.
    'output_dir': widgets.Text(
        value=f'outputs/models/experiment_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
        description='输出目录:',
        **_field_options(),
    ),
}

# Render every control, in definition order, under a short header.
print("🎛️ 训练参数配置")
print("=" * 50)
for widget in config_widgets.values():
    display(widget)

## 4. 数据加载与可视化

In [None]:
def load_and_visualize_data(data_path: str):
    """Load the training dataset and print a short overview of it.

    If ``data_path`` does not exist, a sample file is first written there via
    ``create_sample_data``. The file is then read with
    ``DataConverter.load_dataset`` and wrapped in a DataFrame for inspection.

    Args:
        data_path: Location of the dataset file.

    Returns:
        A ``(dataset, df)`` pair: the object returned by the converter and the
        ``pandas.DataFrame`` built from it.
    """
    source = Path(data_path)

    if not source.exists():
        print(f"⚠️ 数据文件不存在: {source}")
        print("正在创建示例数据...")
        create_sample_data(source)

    dataset = DataConverter().load_dataset(str(source))
    df = pd.DataFrame(dataset)

    print("\n📊 数据概览")
    print(f"总样本数: {len(df)}")
    print(f"列名: {list(df.columns)}")

    print("\n📋 前 3 条样本:")
    preview_count = min(3, len(df))
    for number in range(1, preview_count + 1):
        print(f"\n样本 {number}:")
        row = df.iloc[number - 1]
        for col in df.columns:
            value = row[col]
            # Keep long string fields readable by showing the first 200 chars.
            if isinstance(value, str) and len(value) > 200:
                value = f"{value[:200]}..."
            print(f"  {col}: {value}")

    return dataset, df

def create_sample_data(output_path: Path):
    """Write a placeholder JSONL dataset to ``output_path``.

    Missing parent directories are created. The file holds three sample
    records repeated 100 times (300 lines), one JSON object per line.

    Args:
        output_path: Destination file; overwritten if it already exists.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    templates = [
        {"text": "Sample training text 1"},
        {"text": "Sample training text 2"},
        {"text": "Sample training text 3"},
    ]
    # Repeating the list cycles through the three records 100 times in order.
    records = templates * 100

    with open(output_path, 'w', encoding='utf-8') as handle:
        handle.writelines(
            json.dumps(record, ensure_ascii=False) + '\n' for record in records
        )

    print(f"✅ 已创建示例数据: {output_path}")

# Load the dataset from the path currently entered in the data-path widget.
# The widget value is read when this cell runs, so a later change to the widget
# takes effect only after re-running this cell.
data_path = config_widgets['data_path'].value
dataset, df = load_and_visualize_data(data_path)

## 5. 数据统计分析

In [None]:
# Plot dataset statistics side by side: text-length histogram on the left,
# per-tool sample counts on the right. A panel stays empty when its column is
# missing from the DataFrame.
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Left panel: distribution of string lengths in the `text` column, with a
# dashed red line at the mean.
if 'text' in df.columns:
    text_lengths = df['text'].str.len()
    axes[0].hist(text_lengths, bins=30, edgecolor='black', alpha=0.7)
    axes[0].set(
        xlabel='Text Length',
        ylabel='Frequency',
        title='Distribution of Text Length',
    )
    axes[0].axvline(text_lengths.mean(), color='red', linestyle='--')

# Right panel: number of samples per value of the `tool_name` column.
if 'tool_name' in df.columns:
    tool_counts = df['tool_name'].value_counts()
    tool_counts.plot.bar(ax=axes[1])
    axes[1].set_title('Tool Distribution')

plt.tight_layout()
plt.show()

print("\n📈 数据统计")
print("=" * 50)

## 6. 训练可视化回调类

In [None]:
from transformers import TrainerCallback
from collections import defaultdict

class JupyterVisualizationCallback(TrainerCallback):
    """Trainer callback that plots logged training metrics live in a notebook.

    Every numeric value the trainer logs is recorded in ``metrics_history`` as
    ``(global_step, value)`` pairs. On log events whose step is a multiple of
    ``update_steps`` the 2x2 figure is redrawn; at the end of training it is
    redrawn once more and saved to ``training_metrics.png``.

    Args:
        update_steps: Redraw the figure when ``global_step`` is a multiple of
            this value. Must be at least 1.

    Raises:
        ValueError: If ``update_steps`` is smaller than 1.
    """

    # (key in the trainer logs, panel title, line format), one entry per
    # subplot in row-major order. A panel whose key was never logged stays
    # empty, so keys that a given trainer does not emit are harmless.
    _PANELS = (
        ('loss', 'Training Loss', 'b-'),
        ('eval_loss', 'Evaluation Loss', 'r-'),
        ('learning_rate', 'Learning Rate', 'g-'),
        ('grad_norm', 'Gradient Norm', 'm-'),
    )

    def __init__(self, update_steps: int = 10):
        super().__init__()
        # Fail early: on_log computes `step % update_steps`, which would raise
        # ZeroDivisionError in the middle of training for 0.
        if update_steps < 1:
            raise ValueError(f"update_steps must be >= 1, got {update_steps}")
        self.update_steps = update_steps
        self.metrics_history = defaultdict(list)
        self.fig = None
        self.axes = None
        self._display_handle = None

    def setup_plot(self):
        """Create the 2x2 figure and show it through an updatable display handle."""
        # Imported locally so the class does not depend on another cell's imports.
        from IPython.display import display

        plt.ion()
        self.fig, self.axes = plt.subplots(2, 2, figsize=(14, 10))
        self.fig.suptitle('Training Metrics (Real-time)', fontsize=14)
        # A display handle lets update_plots() re-render the figure in place;
        # the output of a one-off plt.show() is a static image under the
        # inline backend and is not refreshed by later draws.
        self._display_handle = display(self.fig, display_id=True)

    def on_train_begin(self, args, state, control, **kwargs):
        """Announce the start of training and create the figure."""
        print("\n🚀 训练开始！")
        self.setup_plot()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record numeric log values and redraw every ``update_steps`` steps."""
        if logs is None:
            return
        step = state.global_step
        for key, value in logs.items():
            # bool is a subclass of int but is a flag, not a metric.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                self.metrics_history[key].append((step, value))
        if step % self.update_steps == 0:
            self.update_plots()

    def update_plots(self):
        """Redraw every panel from ``metrics_history`` and refresh the output."""
        if self.fig is None or self.axes is None:
            return

        for ax, (key, title, line_format) in zip(self.axes.flat, self._PANELS):
            ax.clear()
            # .get() avoids inserting the key into the defaultdict, and the
            # emptiness check protects the zip(*...) unpacking below.
            points = self.metrics_history.get(key)
            if not points:
                continue
            steps, values = zip(*points)
            ax.plot(steps, values, line_format, linewidth=2)
            ax.set_title(title)
            ax.grid(True, alpha=0.3)

        self.fig.canvas.draw()
        self.fig.canvas.flush_events()
        # display() returns no handle outside an IPython session.
        if self._display_handle is not None:
            self._display_handle.update(self.fig)

    def on_train_end(self, args, state, control, **kwargs):
        """Draw the final state of the metrics and save the figure to disk."""
        print("\n✅ 训练完成！")
        plt.ioff()
        if self.fig is None:
            # on_train_begin never ran, so there is nothing to draw or save.
            return
        # Include the logs received since the last periodic redraw.
        self.update_plots()
        # Save this callback's own figure: plt.savefig() writes pyplot's
        # *current* figure, which is not guaranteed to be this one.
        self.fig.savefig('training_metrics.png', dpi=150)
        # Release the figure from pyplot so it is not rendered a second time
        # and does not stay registered after training.
        plt.close(self.fig)

print("✅ 可视化回调类已定义")

## 7. 开始训练

运行下面的单元格开始训练：

In [None]:
from omegaconf import OmegaConf

# Assemble the training configuration from the current widget values. The
# values are read when this cell runs, so re-run it after changing any widget.
# NOTE(review): the `gradient_accumulation` widget is never read here or in any
# other cell shown in this notebook, so moving that slider has no visible
# effect — confirm which key FunctionGemmaTrainer expects and pass it through.
config_dict = {
    'model': {
        'name': config_widgets['model_name'].value,
        'max_seq_length': config_widgets['max_seq_length'].value,
        'dtype': 'bfloat16',  # fixed here; not exposed as a widget
        'lora': {
            'enabled': True,
            'rank': config_widgets['lora_rank'].value,
            'alpha': config_widgets['lora_alpha'].value,
            # Only these three projection names are listed as LoRA targets.
            'target_modules': ["q_proj", "k_proj", "v_proj"],
        },
    },
    'training': {
        'epochs': config_widgets['epochs'].value,
        'per_device_train_batch_size': config_widgets['batch_size'].value,
        'learning_rate': config_widgets['learning_rate'].value,
    },
    'logging': {
        'output_dir': config_widgets['output_dir'].value,
        'wandb': {'enabled': False}
    }
}

# Wrap the dict in an OmegaConf config (later cells use attribute access such
# as `config.logging.output_dir`) and echo it as YAML for review.
config = OmegaConf.create(config_dict)
print("🔧 训练配置")
print(OmegaConf.to_yaml(config))

In [None]:
# Build the project trainer from the assembled configuration.
trainer = FunctionGemmaTrainer(config)

print("\n📥 加载模型...")
trainer.load_model()
print("✅ 模型加载完成")

# Live plotting callback; it redraws on log events whose global step is a
# multiple of 5.
callbacks = [JupyterVisualizationCallback(update_steps=5)]

print("\n🎯 开始训练...")
# `dataset` comes from the data-loading cell. The return value is stored in
# `train_result`; none of the cells shown here read it afterwards.
train_result = trainer.train(
    train_dataset=dataset,
    output_dir=config.logging.output_dir,
    callbacks=callbacks
)
print(f"\n✅ 训练完成！")

## 8. 保存模型

In [None]:
# Persist the trained model together with the configuration that produced it.
output_dir = config.logging.output_dir
_output_path = Path(output_dir)
_output_path.mkdir(parents=True, exist_ok=True)

print(f"💾 保存模型到: {output_dir}")
trainer.save_model(output_dir)

# Store the exact run configuration next to the model as YAML.
config_save_path = _output_path / 'training_config.yaml'
OmegaConf.save(config, config_save_path)
print("✅ 配置已保存")

## 9. 推理测试

In [None]:
# Interactive inference panel: a prompt box, a run button and an output area.
inference_widget = widgets.Textarea(
    value='查询北京天气',
    description='输入:',
    layout=widgets.Layout(width='100%', height='80px'),
)
run_button = widgets.Button(
    description='运行推理', button_style='success', icon='play'
)
output_area = widgets.Output()


def on_run_button_clicked(b):
    """Run the trainer's inference on the prompt and show the result.

    Everything printed here is captured by ``output_area``, which is cleared
    first so that only the latest run is visible.

    Args:
        b: The clicked button; unused.
    """
    with output_area:
        clear_output()
        print("🤖 正在推理...\n")
        prompt = inference_widget.value
        try:
            result = trainer.inference(prompt, max_new_tokens=128)
            print(f"输入: {prompt}\n")
            print(f"输出: {result}")
        except Exception as exc:
            # UI boundary: report any failure in the output area rather than
            # letting it escape from the widget callback.
            print(f"❌ 推理失败: {exc}")


run_button.on_click(on_run_button_clicked)

display(widgets.VBox(children=[
    widgets.HTML("<h3>🎯 模型推理测试</h3>"),
    inference_widget,
    run_button,
    output_area,
]))

## 10. 批量推理测试

In [None]:
# Smoke-test the trained model on a few fixed prompts and print each result.
test_prompts = [
    "查询北京天气",
    "把背景改成蓝色",
    "创建一个名字叫张三的用户",
]

print("🧪 批量推理测试\n")
for i, prompt in enumerate(test_prompts, 1):
    print(f"\n测试 {i}/{len(test_prompts)}")
    result = trainer.inference(prompt)
    print(f"输入: {prompt}")
    # Append the ellipsis only when the output was actually cut at 200
    # characters, so a short output is not displayed as if it were truncated
    # (the same rule load_and_visualize_data applies to sample previews).
    preview = f"{result[:200]}..." if len(result) > 200 else result
    print(f"输出: {preview}")
    print("-" * 60)

## 11. 模型导出

In [None]:
from src.utils.export import export_model

# Export panel: a format selector, an export button and an output area.
export_format = widgets.Dropdown(
    options=['pytorch', 'gguf'],
    value='pytorch',
    description='格式:'
)

export_button = widgets.Button(
    description='导出模型',
    button_style='primary'
)

export_output = widgets.Output()

def on_export_clicked(b):
    """Handle a click on the export button; output is captured by ``export_output``.

    NOTE(review): this handler only prints a status line. Neither the imported
    ``export_model`` nor ``export_format.value`` is used, so clicking the
    button does not export anything — TODO wire up the actual export call.
    """
    with export_output:
        clear_output()
        print(f"📦 导出模型...")

export_button.on_click(on_export_clicked)

# Stack the header, selector, button and output area vertically.
display(widgets.VBox([
    widgets.HTML("<h3>📦 模型导出</h3>"),
    export_format,
    export_button,
    export_output
]))