In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    BertTokenizer, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup,
    AlbertModel, AlbertConfig, PreTrainedModel, modeling_outputs
)
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix, f1_score, precision_recall_fscore_support
from tqdm import tqdm
import json
import os
import random
import warnings
import optuna
import multiprocessing as mp
import copy
from datetime import datetime
import time
import math
warnings.filterwarnings('ignore')

In [2]:
# 设置随机种子以确保结果可重现
def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), then puts cuDNN into deterministic mode with
    benchmark autotuning switched off.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
# 数据增强函数
def augment_text(text, label, augment_prob=0.3):
    """Produce randomly perturbed copies of ``text`` that share its label.

    The original text is always the first element of the result. Up to three
    perturbations may be appended, each gated by an independent random draw:

    * a blank token inserted between whitespace-separated words (any label);
    * one character removed (label 4 only, texts longer than 5 characters);
    * one to three characters overwritten with punctuation (label 4 only).

    Labels 3 and 4 use twice the given probability.

    Args:
        text: Source string.
        label: Integer class id of ``text``.
        augment_prob: Base probability for each perturbation.

    Returns:
        ``(texts, labels)``: two lists of equal length; every entry of
        ``labels`` equals ``label``.
    """
    # Minority classes are augmented more aggressively
    prob = augment_prob * 2 if label in (3, 4) else augment_prob
    variants = [text]

    # Perturbation 1: insert a blank token at a random word boundary
    if random.random() < prob:
        tokens = text.split()
        if len(tokens) > 1:
            position = random.randint(0, len(tokens))
            tokens.insert(position, ' ')
            variants.append(' '.join(tokens))

    if label == 4:
        # Perturbation 2: delete one random character
        if random.random() < prob and len(text) > 5:
            cut = random.randint(0, len(text) - 1)
            variants.append(text[:cut] + text[cut + 1:])

        # Perturbation 3: overwrite random characters with punctuation
        if random.random() < prob:
            punctuations = ['!', '?', '.', ',', ';', ':', '/', '\\', '|']
            noisy = text
            for _ in range(random.randint(1, 3)):
                if noisy:
                    pos = random.randint(0, len(noisy) - 1)
                    noisy = noisy[:pos] + random.choice(punctuations) + noisy[pos + 1:]
            variants.append(noisy)

    return variants, [label] * len(variants)

# 定义数据集类
class LanguageDataset(Dataset):
    """Map-style dataset that tokenizes one text per item on access.

    Args:
        texts: Indexable collection of raw texts (each is passed through ``str``).
        labels: Indexable collection of integer class ids, aligned with ``texts``.
        tokenizer: Callable invoked with the text plus ``add_special_tokens``,
            ``max_length``, ``padding``, ``truncation`` and ``return_tensors``
            keyword arguments; its result must expose ``'input_ids'`` and
            ``'attention_mask'`` tensors.
        max_length: Length every sequence is padded / truncated to.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        raw_text = str(self.texts[idx])
        target = self.labels[idx]

        # Tokenize to a fixed length so samples can be stacked into a batch
        encoded = self.tokenizer(
            raw_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # The tokenizer output is flattened to 1-D per field
        sample = {
            field: encoded[field].flatten()
            for field in ('input_ids', 'attention_mask')
        }
        sample['label'] = torch.tensor(target, dtype=torch.long)
        return sample

In [None]:
def _parse_labeled_lines(lines, label_map):
    """Parse ``"<text>\\t<label>"`` rows into parallel text / label-id lists.

    Rows whose label is not a key of ``label_map`` are skipped silently; rows
    that cannot be split into text and label are reported and skipped.
    """
    parsed_texts, parsed_labels = [], []
    for line in lines:
        try:
            # Split on the LAST tab so that tabs inside the text do not
            # cause the whole row to be discarded.
            text, label = line.strip().rsplit('\t', 1)
        except (ValueError, AttributeError) as e:
            print(f"Error processing: {line}:  {str(e)}")
            continue
        if label in label_map:
            parsed_texts.append(text)
            parsed_labels.append(label_map[label])
    return parsed_texts, parsed_labels


def load_dataset(file_path):
    """Load the training and validation splits from a JSON file.

    The file must contain a JSON object with ``'train'`` and ``'validation'``
    keys, each a list of ``"<text>\\t<label>"`` strings.

    Args:
        file_path: Path to the UTF-8 encoded JSON file.

    Returns:
        ``(texts, labels, val_texts, val_labels)`` where the label lists hold
        integer class ids (chs=0, cht=1, en=2, symbol=3, error=4).
    """
    # 'sybol' is kept for backward compatibility; 'symbol' is accepted as well
    # because that is the spelling used when reporting evaluation results.
    # NOTE(review): confirm which spelling the data files actually use.
    label_map = {'chs': 0, 'cht': 1, 'en': 2, 'sybol': 3, 'symbol': 3, 'error': 4}

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    texts, labels = _parse_labeled_lines(data['train'], label_map)
    val_texts, val_labels = _parse_labeled_lines(data['validation'], label_map)

    return texts, labels, val_texts, val_labels

In [None]:
class FocalLoss(nn.Module):
    """Focal loss with optional class weights and label smoothing.

    Per-sample loss is ``alpha * (1 - p_t) ** gamma * CE`` where ``p_t`` is the
    softmax probability the model assigns to the true class and ``CE`` is the
    (optionally class-weighted and label-smoothed) cross-entropy. The result
    is the mean over the batch.

    Args:
        alpha: Global scale factor applied to every sample.
        gamma: Focusing exponent; 0 reduces the modulating factor to 1.
        class_weights: Optional 1-D tensor of per-class weights passed to
            ``F.cross_entropy``.
        label_smoothing: Smoothing factor in [0, 1); 0 disables smoothing.
    """

    def __init__(self, alpha=1.0, gamma=2.0, class_weights=None, label_smoothing=0.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.class_weights = class_weights
        self.label_smoothing = label_smoothing

    def forward(self, inputs, targets):
        """Compute the mean focal loss.

        Args:
            inputs: Logits of shape ``(batch, num_classes)``.
            targets: Integer class ids of shape ``(batch,)``.

        Returns:
            Scalar tensor with the batch-mean focal loss.
        """
        if self.label_smoothing > 0:
            num_classes = inputs.size(-1)
            targets_one_hot = F.one_hot(targets, num_classes).float()
            targets_one_hot = targets_one_hot * (1 - self.label_smoothing) + \
                             self.label_smoothing / num_classes
            ce_loss = F.cross_entropy(inputs, targets_one_hot, weight=self.class_weights, reduction='none')
        else:
            ce_loss = F.cross_entropy(inputs, targets, weight=self.class_weights, reduction='none')

        # p_t must be the model's probability for the true class. Deriving it
        # as exp(-ce_loss) is only valid for plain CE: with class weights it
        # would yield p_t ** w, and with smoothing a mixture over all classes,
        # which distorts the (1 - p_t) ** gamma modulating factor.
        log_pt = F.log_softmax(inputs, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        pt = log_pt.exp()

        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()

In [None]:
class AlbertClassifier(PreTrainedModel):
    """ALBERT encoder followed by [CLS]-query attention pooling and an MLP head.

    The head maps ``hidden_size -> 512 -> 256 -> 128 -> num_labels``; each
    hidden stage is Linear, LayerNorm, ReLU, Dropout.

    Args:
        config: ``AlbertConfig`` providing ``hidden_size`` and ``num_labels``.
        dropout_rate: Dropout probability used before and inside the head and
            passed to the ``attention_pool`` module.
    """

    config_class = AlbertConfig

    def __init__(self, config, dropout_rate=0.2):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(dropout_rate)

        # Build the MLP head stage by stage (same layer order as listed above)
        head_layers = []
        in_features = config.hidden_size
        for out_features in (512, 256, 128):
            head_layers += [
                nn.Linear(in_features, out_features),
                nn.LayerNorm(out_features),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ]
            in_features = out_features
        head_layers.append(nn.Linear(in_features, config.num_labels))
        self.classifier = nn.Sequential(*head_layers)

        # Attention pooling module.
        # NOTE(review): forward() computes the pooling with plain tensor ops
        # and never calls this module, so its parameters receive no gradient.
        self.attention_pool = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True
        )

        # Initialize weights
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Encode, pool with the first token as query, and classify.

        Args:
            input_ids: Token ids forwarded to the ALBERT encoder.
            attention_mask: Optional mask; positions where it is 0 are
                excluded from the pooling softmax.
            labels: Optional class ids; when given, an unweighted
                cross-entropy loss is returned alongside the logits.
            **kwargs: Accepted and ignored.

        Returns:
            ``SequenceClassifierOutput`` with ``loss`` (or ``None``),
            ``logits`` and the encoder's ``hidden_states`` / ``attentions``.
        """
        encoder_out = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        hidden = encoder_out.last_hidden_state
        batch_size, seq_len, hidden_size = hidden.shape
        num_heads = 8
        head_dim = hidden_size // num_heads

        def split_heads(states, length):
            # (batch, length, hidden) -> (batch, heads, length, head_dim)
            return states.view(batch_size, length, num_heads, head_dim).transpose(1, 2)

        # The first token ([CLS]) queries the whole sequence
        query = split_heads(hidden[:, :1, :], 1)
        key = split_heads(hidden, seq_len)
        value = split_heads(hidden, seq_len)

        # Scaled dot-product scores
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_dim)

        # Masked positions get -inf so softmax assigns them zero weight
        if attention_mask is not None:
            keep = attention_mask.unsqueeze(1).unsqueeze(1).bool()
            scores = scores.masked_fill(~keep, float('-inf'))

        weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(weights, value)

        # Merge the heads back into a single vector per sample
        pooled = context.transpose(1, 2).contiguous().view(
            batch_size, 1, hidden_size
        ).squeeze(1)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        return modeling_outputs.SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )

In [None]:
def get_class_weights(labels, device):
    """Compute 'balanced' class weights and return them as a float tensor.

    Args:
        labels: Sequence of integer class ids.
        device: Torch device the resulting tensor is moved to.

    Returns:
        1-D float tensor with one weight per distinct value in ``labels``,
        ordered by ascending class id.

    NOTE(review): the vector has an entry only for classes that occur in
    ``labels``; if a class is absent it is shorter than the number of
    model outputs.
    """
    present_classes = np.unique(labels)
    balanced = compute_class_weight(
        class_weight='balanced',
        classes=present_classes,
        y=labels
    )
    weights = torch.FloatTensor(balanced)
    return weights.to(device)

def detailed_evaluation(model, val_loader, device, class_weights=None):
    """Evaluate ``model`` on ``val_loader`` and print a detailed report.

    Prints weighted / macro F1, mean focal loss, a per-class table (precision,
    recall, F1, support, mean confidence), the confusion matrix and the
    scikit-learn classification report.

    Args:
        model: Model returning an object with a ``logits`` attribute.
        val_loader: Iterable of batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'`` tensors.
        device: Device the batch tensors are moved to.
        class_weights: Optional per-class weights for the focal loss.

    Returns:
        ``(f1_weighted, f1_macro, confidence_stats)`` where
        ``confidence_stats`` maps each class name that occurs in the labels
        to its confidence statistics and per-class metrics.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    all_losses = []

    # Evaluate with focal loss
    criterion = FocalLoss(gamma=2.0, class_weights=class_weights, label_smoothing=0.1)

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="评估中"):
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )

                # Loss
                loss = criterion(outputs.logits, labels)
                all_losses.append(loss.item())

                # Predictions and class probabilities
                probs = F.softmax(outputs.logits, dim=1)
                preds = torch.argmax(outputs.logits, dim=1)

                all_probs.extend(probs.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
            except Exception as e:
                # Best effort: a failing batch is reported and skipped
                print(f"评估错误: {str(e)}")
                continue

    label_names = ['chs', 'cht', 'en', 'symbol', 'error']
    # Fixed class ids: without an explicit ``labels=`` scikit-learn only
    # reports classes present in the data, so per-class arrays would no longer
    # line up with ``label_names`` when a class is missing.
    label_ids = list(range(len(label_names)))

    # Aggregate metrics
    f1_weighted = f1_score(all_labels, all_preds, average='weighted')
    f1_macro = f1_score(all_labels, all_preds, average='macro')
    precision, recall, f1_per_class, support = precision_recall_fscore_support(
        all_labels, all_preds, labels=label_ids, average=None
    )

    # Confidence distribution per class (probability given to the true class)
    labels_arr = np.array(all_labels)
    probs_arr = np.array(all_probs)
    confidence_stats = {}

    for i, label_name in enumerate(label_names):
        mask = labels_arr == i
        if mask.sum() > 0:
            confidence = probs_arr[mask, i]
            confidence_stats[label_name] = {
                'mean_confidence': confidence.mean(),
                'std_confidence': confidence.std(),
                'min_confidence': confidence.min(),
                'max_confidence': confidence.max(),
                'precision': precision[i],
                'recall': recall[i],
                'f1': f1_per_class[i],
                'support': support[i]
            }

    # Avoid a NumPy "mean of empty slice" warning when every batch failed
    mean_loss = np.mean(all_losses) if all_losses else float('nan')

    # Detailed report
    print(f"加权F1分数: {f1_weighted:.4f}")
    print(f"宏平均F1分数: {f1_macro:.4f}")
    print(f"平均损失: {mean_loss:.4f}")
    print(f"{'类别':<15} {'精确率':<8} {'召回率':<8} {'F1分数':<8} {'支持数':<8} {'平均置信度':<10}")

    for label_name, stats in confidence_stats.items():
        print(f"{label_name:<15} {stats['precision']:<8.3f} {stats['recall']:<8.3f} "
              f"{stats['f1']:<8.3f} {stats['support']:<8} {stats['mean_confidence']:<10.3f}")

    print("\n混淆矩阵:")
    print(confusion_matrix(all_labels, all_preds, labels=label_ids))

    print("\n分类报告:")
    print(classification_report(all_labels, all_preds, labels=label_ids, target_names=label_names))

    return f1_weighted, f1_macro, confidence_stats

def evaluate_model(model, val_loader, criterion, device, return_predictions=False):
    """Lightweight evaluation used for quick checks during training.

    Args:
        model: Model returning an object with a ``logits`` attribute.
        val_loader: Iterable of batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'`` tensors.
        criterion: Callable ``criterion(logits, labels)`` returning a scalar loss.
        device: Device the batch tensors are moved to.
        return_predictions: When True, also return predictions and labels.

    Returns:
        ``(avg_loss, f1)`` or, with ``return_predictions=True``,
        ``(avg_loss, f1, all_preds, all_labels)``. ``f1`` is the weighted F1
        score; both values are 0 when no batch could be evaluated.
    """
    model.eval()
    total_loss = 0
    evaluated_batches = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="评估中", leave=False):
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )

                loss = criterion(outputs.logits, labels)
                total_loss += loss.item()

                _, predicted = torch.max(outputs.logits, 1)
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                evaluated_batches += 1
            except Exception as e:
                # Best effort: a failing batch is reported and skipped
                print(f"评估错误: {str(e)}")
                continue

    # Weighted F1; guarded because there is nothing to score if all batches failed
    f1 = f1_score(all_labels, all_preds, average='weighted') if all_labels else 0.0
    # Average over batches that actually contributed a loss, so skipped
    # batches do not drag the mean towards zero.
    avg_loss = total_loss / evaluated_batches if evaluated_batches > 0 else 0

    if return_predictions:
        return avg_loss, f1, all_preds, all_labels
    return avg_loss, f1

def progressive_training(model, train_loader, val_loader, device, config):
    """Two-stage training: fit the head on a frozen encoder, then fine-tune everything.

    Stage 1 freezes ``model.albert`` and trains for 5 epochs at lr 1e-3 with
    a linear schedule. Stage 2 unfreezes the encoder and trains with the
    settings read from ``config`` and layered learning rates enabled.

    Args:
        model: Model exposing an ``albert`` sub-module.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        device: Device used for training.
        config: Dict that may provide ``accumulation_steps``, ``num_epochs``,
            ``learning_rate``, ``warmup_ratio``, ``scheduler_type``,
            ``max_patience`` and ``min_delta``.

    Returns:
        Dict with the result dicts of both stages (``'stage1'``, ``'stage2'``)
        plus ``'best_f1'`` and ``'best_epoch'`` taken from stage 2.
    """
    def set_encoder_trainable(trainable):
        for param in model.albert.parameters():
            param.requires_grad = trainable

    # Arguments shared by both stages
    common_args = dict(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        accumulation_steps=config.get('accumulation_steps', 1),
        use_focal_loss=True,
    )

    # Stage 1: encoder frozen, only the remaining parameters learn (larger lr)
    print("\n第一阶段：训练分类头（冻结ALBERT层）")
    set_encoder_trainable(False)
    stage1_results = train_model(
        num_epochs=5,
        learning_rate=1e-3,
        warmup_ratio=0.1,
        scheduler_type='linear',
        max_patience=3,
        min_delta=1e-4,
        stage_name="分类头训练",
        **common_args
    )

    # Stage 2: everything trainable, smaller lr with per-module learning rates
    print("\n第二阶段：微调整个模型（解冻所有层）")
    set_encoder_trainable(True)
    stage2_results = train_model(
        num_epochs=config.get('num_epochs', 20),
        learning_rate=config.get('learning_rate', 2e-5),
        warmup_ratio=config.get('warmup_ratio', 0.15),
        scheduler_type=config.get('scheduler_type', 'cosine'),
        max_patience=config.get('max_patience', 5),
        min_delta=config.get('min_delta', 1e-4),
        use_layered_lr=True,
        stage_name="全模型微调",
        **common_args
    )

    summary = {'stage1': stage1_results, 'stage2': stage2_results}
    for key in ('best_f1', 'best_epoch'):
        summary[key] = stage2_results[key]
    return summary

def _build_optimizer_param_groups(model, learning_rate, use_layered_lr):
    """Split parameters into weight-decay / no-decay groups, optionally per module.

    Parameters whose name contains ``'bias'`` or ``'LayerNorm.weight'`` get no
    weight decay; all others use 0.01. With ``use_layered_lr`` the groups are
    built separately for ``model.albert`` (one tenth of ``learning_rate``),
    ``model.classifier`` and ``model.attention_pool`` (full ``learning_rate``).
    """
    no_decay = ['bias', 'LayerNorm.weight']

    def split_by_decay(named_parameters):
        decay, skip_decay = [], []
        for name, param in named_parameters:
            bucket = skip_decay if any(nd in name for nd in no_decay) else decay
            bucket.append(param)
        return decay, skip_decay

    if not use_layered_lr:
        decay, skip_decay = split_by_decay(model.named_parameters())
        return [
            {'params': decay, 'weight_decay': 0.01},
            {'params': skip_decay, 'weight_decay': 0.0},
        ]

    groups = []
    module_lrs = (
        (model.albert, learning_rate * 0.1),   # encoder: 1/10 of the base lr
        (model.classifier, learning_rate),     # classification head
        (model.attention_pool, learning_rate), # attention pooling module
    )
    for module, lr in module_lrs:
        decay, skip_decay = split_by_decay(module.named_parameters())
        groups.append({'params': decay, 'weight_decay': 0.01, 'lr': lr})
        groups.append({'params': skip_decay, 'weight_decay': 0.0, 'lr': lr})
    return groups


def train_model(model, train_loader, val_loader, device,
                         num_epochs=20, learning_rate=2e-5, warmup_ratio=0.1,
                         scheduler_type='cosine', accumulation_steps=1,
                         max_patience=5, min_delta=1e-4, use_focal_loss=True,
                         use_layered_lr=False, stage_name="训练"):
    """Train ``model`` with AdamW, warmup scheduling, gradient accumulation and early stopping.

    The loss used for optimization is the same criterion used for validation:
    a class-weighted focal loss (``use_focal_loss=True``) or a class-weighted
    cross-entropy. The best model by weighted validation F1 is written to
    ``'best_classifier_optimized.pt'`` and reloaded before returning.

    Args:
        model: Model returning an object with ``logits``; with
            ``use_layered_lr`` it must expose ``albert``, ``classifier`` and
            ``attention_pool`` sub-modules.
        train_loader: Training loader yielding dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'``.
        val_loader: Validation loader with the same batch layout.
        device: Device the batches are moved to.
        num_epochs: Maximum number of epochs.
        learning_rate: Base learning rate.
        warmup_ratio: Fraction of optimizer steps used for warmup.
        scheduler_type: ``'cosine'`` selects the cosine schedule, anything
            else the linear one.
        accumulation_steps: Number of batches accumulated per optimizer step.
        max_patience: Epochs without improvement before stopping early.
        min_delta: Minimum F1 gain that counts as an improvement.
        use_focal_loss: Use focal loss instead of weighted cross-entropy.
        use_layered_lr: Use per-module learning rates.
        stage_name: Label printed in the progress messages.

    Returns:
        Dict with ``best_f1``, ``best_loss``, ``best_epoch`` and the per-epoch
        ``train_loss_history``, ``val_loss_history``, ``val_f1_history`` and
        ``lr_history`` lists.
    """
    checkpoint_path = 'best_classifier_optimized.pt'

    print(f"\n开始{stage_name}阶段...")

    optimizer = torch.optim.AdamW(
        _build_optimizer_param_groups(model, learning_rate, use_layered_lr),
        lr=learning_rate, eps=1e-8
    )

    # Optimizer steps per epoch: the trailing partial accumulation window is
    # flushed too (see the training loop), hence the ceiling.
    num_batches = len(train_loader)
    steps_per_epoch = math.ceil(num_batches / accumulation_steps)
    num_training_steps = steps_per_epoch * num_epochs
    num_warmup_steps = int(num_training_steps * warmup_ratio)

    # Learning-rate scheduler
    if scheduler_type == 'cosine':
        schedule_factory = get_cosine_schedule_with_warmup
    else:
        schedule_factory = get_linear_schedule_with_warmup
    scheduler = schedule_factory(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    # Class weights from the label distribution of the training loader
    all_labels = []
    for batch in train_loader:
        all_labels.extend(batch['label'].numpy())
    class_weights = get_class_weights(all_labels, device)

    # Focal loss or weighted cross-entropy
    if use_focal_loss:
        criterion = FocalLoss(
            gamma=2.0,
            class_weights=class_weights,
            label_smoothing=0.1
        )
        print(f"使用Focal Loss (gamma=2.0, label_smoothing=0.1)")
    else:
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        print(f"使用标准CrossEntropyLoss")

    # Training state
    best_f1 = 0.0
    best_loss = float('inf')
    patience_counter = 0
    best_epoch = 0
    checkpoint_saved = False  # True once THIS call has written a checkpoint

    # History
    train_loss_history = []
    val_loss_history = []
    val_f1_history = []
    lr_history = []

    print(f"开始训练，总共 {num_epochs} 个epoch")
    print(f"训练步数: {num_training_steps}, 预热步数: {num_warmup_steps}")
    print(f"梯度累积步数: {accumulation_steps}")

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_loss = 0
        processed_batches = 0
        optimizer.zero_grad()

        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}')

        for step, batch in enumerate(progress_bar):
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
                # Optimize the configured criterion (not the model's built-in
                # unweighted cross-entropy), scaled for gradient accumulation.
                loss = criterion(outputs.logits, labels) / accumulation_steps
                loss.backward()

                total_loss += loss.item() * accumulation_steps
                processed_batches += 1

                # Step at the end of every accumulation window and on the
                # last batch, so a trailing partial window is not discarded.
                window_complete = (step + 1) % accumulation_steps == 0
                last_batch = (step + 1) == num_batches
                if window_complete or last_batch:
                    # Gradient clipping
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

                # Progress bar
                current_lr = scheduler.get_last_lr()[0] if scheduler.get_last_lr() else learning_rate
                progress_bar.set_postfix({
                    'loss': f'{loss.item() * accumulation_steps:.4f}',
                    'lr': f'{current_lr:.2e}'
                })

            except Exception as e:
                # Best effort: a failing batch is reported and skipped
                print(f"Error in training batch: {str(e)}")
                continue

        # Mean training loss over the batches that actually ran
        avg_train_loss = total_loss / max(processed_batches, 1)
        train_loss_history.append(avg_train_loss)
        lr_history.append(scheduler.get_last_lr()[0] if scheduler.get_last_lr() else learning_rate)

        # Validation phase
        val_loss, val_f1 = evaluate_model(model, val_loader, criterion, device)
        val_loss_history.append(val_loss)
        val_f1_history.append(val_f1)

        print(f'Epoch {epoch + 1}:')
        print(f'训练损失: {avg_train_loss:.4f}')
        print(f'验证损失: {val_loss:.4f}')
        print(f'验证F1: {val_f1:.4f}')
        print(f'学习率: {lr_history[-1]:.2e}')

        # Keep the best model by validation F1
        if val_f1 > best_f1 + min_delta:
            best_f1 = val_f1
            best_loss = val_loss
            best_epoch = epoch
            patience_counter = 0

            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_f1': best_f1,
                'best_loss': best_loss,
                'epoch': epoch,
                'train_loss_history': train_loss_history,
                'val_loss_history': val_loss_history,
                'val_f1_history': val_f1_history,
                'lr_history': lr_history
            }, checkpoint_path)
            checkpoint_saved = True
            print(f'新的最佳模型已保存 (F1: {best_f1:.4f})')
        else:
            patience_counter += 1
            print(f'未改善 (耐心计数: {patience_counter}/{max_patience})')

        # Early stopping
        if patience_counter >= max_patience:
            print(f'早停触发！在第 {epoch + 1} 个epoch停止训练')
            print(f'最佳模型在第 {best_epoch + 1} 个epoch，F1分数: {best_f1:.4f}')
            break

    # Restore the best weights, but only if this call produced a checkpoint;
    # otherwise the file is missing or belongs to an earlier run / stage.
    if checkpoint_saved:
        print('训练完成，加载最佳模型...')
        # weights_only=False: the checkpoint also stores optimizer state and
        # metric histories, and it was written by this function.
        checkpoint = torch.load(checkpoint_path, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print('训练完成，本阶段未保存新的检查点，保留当前模型权重')

    # Final evaluation
    print('\n使用最佳模型进行最终评估:')
    final_val_loss, final_val_f1 = evaluate_model(model, val_loader, criterion, device)
    print(f'最终验证损失: {final_val_loss:.4f}')
    print(f'最终验证F1: {final_val_f1:.4f}')

    return {
        'best_f1': best_f1,
        'best_loss': best_loss,
        'best_epoch': best_epoch,
        'train_loss_history': train_loss_history,
        'val_loss_history': val_loss_history,
        'val_f1_history': val_f1_history,
        'lr_history': lr_history
    }

In [None]:
def hyperparameter_search(train_texts, train_labels, val_texts, val_labels,
                         tokenizer, model_config, device, n_trials=20, n_jobs=2):
    """Compatibility wrapper that delegates to ``hyperparameter_search_parallel``.

    ``device`` is accepted but not forwarded; GPU parallelism is always
    requested from the parallel search.

    Returns:
        Whatever ``hyperparameter_search_parallel`` returns.
    """
    search_options = dict(n_trials=n_trials, n_jobs=n_jobs, use_gpu_parallel=True)
    return hyperparameter_search_parallel(
        train_texts, train_labels, val_texts, val_labels,
        tokenizer, model_config, **search_options
    )

def train_model_fast(model, train_loader, val_loader, device, config, trial_number=0, show_progress=True):
    """Short training loop used by the hyperparameter search.

    Trains with AdamW and a warmup schedule, optimizing a class-weighted
    focal loss, and stops early after 2 epochs without a weighted-F1 gain.

    Args:
        model: Model returning an object with a ``logits`` attribute.
        train_loader: Training loader yielding dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'``.
        val_loader: Validation loader with the same batch layout.
        device: Device the batches are moved to.
        config: Dict with ``learning_rate``, ``num_epochs``, ``warmup_ratio``
            and ``scheduler_type``; ``focal_gamma`` and ``label_smoothing``
            are optional.
        trial_number: Trial index used only in progress messages.
        show_progress: Print progress and per-batch error messages.

    Returns:
        Best weighted validation F1 observed across the epochs.
    """
    if show_progress:
        print(f"Trial {trial_number} - 开始训练")
        print(f"超参数配置:")
        for key, value in config.items():
            print(f"{key}: {value}")
        print(f"设备: {device}")
        print(f"训练批次数: {len(train_loader)}")
        print(f"验证批次数: {len(val_loader)}")

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=0.01)

    # Scheduler: one scheduler step per batch
    num_training_steps = len(train_loader) * config['num_epochs']
    num_warmup_steps = int(num_training_steps * config['warmup_ratio'])
    make_schedule = (
        get_cosine_schedule_with_warmup
        if config['scheduler_type'] == 'cosine'
        else get_linear_schedule_with_warmup
    )
    scheduler = make_schedule(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
    )

    # Loss: focal loss weighted by the training label distribution
    all_labels = [lbl for batch in train_loader for lbl in batch['label'].numpy()]
    class_weights = get_class_weights(all_labels, device)
    criterion = FocalLoss(
        gamma=config.get('focal_gamma', 2.0),
        class_weights=class_weights,
        label_smoothing=config.get('label_smoothing', 0.1)
    )

    def to_device(batch):
        # Move the three tensors of a batch to the target device
        return (
            batch['input_ids'].to(device),
            batch['attention_mask'].to(device),
            batch['label'].to(device),
        )

    best_f1 = 0.0
    patience_counter = 0
    max_patience = 2

    if show_progress:
        print(f"\n开始训练 {config['num_epochs']} 个epochs...")

    for epoch in range(config['num_epochs']):
        # ---- training ----
        model.train()
        train_loss_sum, train_batches = 0, 0

        for batch in train_loader:
            try:
                input_ids, attention_mask, labels = to_device(batch)

                optimizer.zero_grad()
                logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
                loss = criterion(logits, labels)
                loss.backward()

                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                scheduler.step()

                train_loss_sum += loss.item()
                train_batches += 1

            except Exception as e:
                if show_progress:
                    print(f"训练批次错误: {e}")
                continue

        # ---- validation ----
        model.eval()
        predicted_ids, true_ids = [], []
        val_loss_sum, val_batches = 0, 0

        with torch.no_grad():
            for batch in val_loader:
                try:
                    input_ids, attention_mask, labels = to_device(batch)

                    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
                    loss = criterion(logits, labels)
                    preds = torch.argmax(logits, dim=1)

                    predicted_ids.extend(preds.cpu().numpy())
                    true_ids.extend(labels.cpu().numpy())
                    val_loss_sum += loss.item()
                    val_batches += 1

                except Exception as e:
                    if show_progress:
                        print(f"验证批次错误: {e}")
                    continue

        # ---- metrics ----
        f1 = f1_score(true_ids, predicted_ids, average='weighted') if len(predicted_ids) > 0 else 0.0
        avg_train_loss = train_loss_sum / max(train_batches, 1)
        avg_val_loss = val_loss_sum / max(val_batches, 1)

        if show_progress:
            print(f"Trial {trial_number} - Epoch {epoch+1}/{config['num_epochs']}:")
            print(f"训练损失: {avg_train_loss:.4f}")
            print(f"验证损失: {avg_val_loss:.4f}")
            print(f"验证F1: {f1:.4f}")
            print(f"当前最佳F1: {best_f1:.4f}")
            print(f"学习率: {scheduler.get_last_lr()[0]:.2e}")

        # ---- early stopping bookkeeping ----
        if f1 > best_f1:
            best_f1 = f1
            patience_counter = 0
            if show_progress:
                print(f"新的最佳F1分数!")
        else:
            patience_counter += 1
            if show_progress:
                print(f"耐心计数: {patience_counter}/{max_patience}")

        if patience_counter >= max_patience:
            if show_progress:
                print(f"早停触发，停止训练")
            break

    if show_progress:
        print(f"\nTrial {trial_number} 完成!")
        print(f"最终F1分数: {best_f1:.4f}")
        print(f"实际训练epochs: {epoch+1}")

    return best_f1

def objective_function(trial, train_texts, train_labels, val_texts, val_labels,
                      tokenizer, model_config, use_gpu_parallel=True, show_progress=True):
    """Optuna objective: sample hyperparameters, train briefly, return the best F1.

    Args:
        trial: Optuna trial used to sample hyperparameters; ``trial.number``
            also selects the GPU when ``use_gpu_parallel`` is set.
        train_texts, train_labels: Training texts and class ids.
        val_texts, val_labels: Validation texts and class ids.
        tokenizer: Tokenizer passed to ``LanguageDataset``.
        model_config: Model config; deep-copied for every trial.
        use_gpu_parallel: Spread trials across GPUs by ``trial.number``.
        show_progress: Print hyperparameters and timing information.

    Returns:
        Best weighted validation F1 of the trial, or 0.0 if anything raised.
    """
    try:
        # Device: round-robin over GPUs when parallel, otherwise default device
        cuda_ok = torch.cuda.is_available()
        if use_gpu_parallel and cuda_ok:
            gpu_id = trial.number % torch.cuda.device_count()
            device = torch.device(f'cuda:{gpu_id}')
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Sampled hyperparameters (suggestion order is kept fixed)
        config = {}
        config['learning_rate'] = trial.suggest_float('learning_rate', 1e-5, 5e-5, log=True)
        config['batch_size'] = trial.suggest_categorical('batch_size', [256, 512])
        config['dropout_rate'] = trial.suggest_float('dropout_rate', 0.1, 0.3)
        config['warmup_ratio'] = trial.suggest_float('warmup_ratio', 0.05, 0.2)
        config['num_epochs'] = trial.suggest_int('num_epochs', 6, 12)  # short runs for the search
        config['scheduler_type'] = trial.suggest_categorical('scheduler_type', ['linear', 'cosine'])
        config['focal_gamma'] = trial.suggest_float('focal_gamma', 1.0, 3.0)
        config['label_smoothing'] = trial.suggest_float('label_smoothing', 0.0, 0.2)

        if show_progress:
            print(f"Trial {trial.number} 超参数:")
            for key, value in config.items():
                print(f"{key}: {value}")

        # Model: build the classifier, then swap in the pretrained encoder
        trial_model_config = copy.deepcopy(model_config)
        model = AlbertClassifier(trial_model_config, dropout_rate=config['dropout_rate'])
        model = model.to(device)
        pretrained_encoder = AlbertModel.from_pretrained(
            "voidful/albert_chinese_tiny", config=trial_model_config
        )
        model.albert = pretrained_encoder.to(device)

        # Data
        batch_size = config['batch_size']
        train_loader = DataLoader(
            LanguageDataset(train_texts, train_labels, tokenizer),
            batch_size=batch_size, shuffle=True
        )
        val_loader = DataLoader(
            LanguageDataset(val_texts, val_labels, tokenizer),
            batch_size=batch_size
        )

        # Train and time the run
        started = time.time()
        best_f1 = train_model_fast(model, train_loader, val_loader, device, config,
                                  trial_number=trial.number, show_progress=show_progress)
        training_time = time.time() - started

        if show_progress:
            print(f"\nTrial {trial.number} 训练时间: {training_time:.1f}秒")
            print(f"Trial {trial.number} 最终F1分数: {best_f1:.4f}")

        # Release memory
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return best_f1

    except Exception as e:
        # A failed trial is reported and scored 0.0 instead of aborting the study
        print(f"Trial {trial.number} 失败: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return 0.0

def hyperparameter_search_parallel(train_texts, train_labels, val_texts, val_labels,
        tokenizer, model_config, n_trials=30, n_jobs=4, use_gpu_parallel=True, show_progress=True):
    """Run an Optuna hyperparameter search with several trials in flight at once.

    Every trial is delegated to ``objective_function``; the study maximizes the
    value that function returns (reported below as an F1 score).

    Args:
        train_texts, train_labels: Training split, forwarded to every trial.
        val_texts, val_labels: Validation split, forwarded to every trial.
        tokenizer: Tokenizer forwarded to every trial.
        model_config: Model configuration forwarded to every trial.
        n_trials: Number of trials to run.
        n_jobs: Requested concurrency; capped at the machine's CPU count.
        use_gpu_parallel: Forwarded to ``objective_function``; forced to
            ``False`` when CUDA is unavailable.
        show_progress: Forwarded to ``objective_function``.

    Returns:
        Tuple ``(best_params, best_value)`` read from the finished study.

    Raises:
        ValueError: Propagated from Optuna when no trial completed, because
            ``study.best_value`` cannot be read in that case.
    """
    print("并发超参数搜索")
    # One seed drives the global RNGs and the Optuna sampler created below.
    seed = 42
    set_seed(seed)

    # Report available resources; never use more workers than CPU cores.
    cpu_count = mp.cpu_count()
    max_workers = min(n_jobs, cpu_count)
    print(f"CPU核心数: {cpu_count}")
    print(f"使用并发数: {max_workers}")

    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        print(f"可用GPU数: {gpu_count}")
        for i in range(gpu_count):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"GPU并行: {'启用' if use_gpu_parallel else '禁用'}")
    else:
        print("GPU: 不可用")
        # Without CUDA there is nothing to parallelize on the GPU side.
        use_gpu_parallel = False

    # Bind the fixed arguments so Optuna only has to supply the trial.
    def objective(trial):
        return objective_function(
            trial, train_texts, train_labels, val_texts, val_labels,
            tokenizer, model_config, use_gpu_parallel, show_progress
        )

    # The sampler keeps its own RNG, so it must be seeded explicitly: seeding
    # the global generators via set_seed() does not affect it. Results are
    # repeatable for sequential runs; with several workers the order in which
    # trials finish can still vary.
    study = optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(n_startup_trials=8, seed=seed),
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=2)
    )

    print(f"\n开始搜索 {n_trials} 个试验...")
    print(f"并发进程数: {max_workers}")
    print(f"目标: 最大化F1分数")

    # Run the search and time it. NOTE: Optuna's n_jobs runs trials in threads
    # of this process, not in separate processes.
    start_time = time.time()
    study.optimize(objective, n_trials=n_trials, n_jobs=max_workers)
    search_time = time.time() - start_time

    # Headline results
    print("搜索结果")
    print(f"最佳F1分数: {study.best_value:.4f}")
    print(f"总搜索时间: {search_time:.1f}秒 ({search_time/60:.1f}分钟)")
    print(f"平均每个trial: {search_time/len(study.trials):.1f}秒")

    print(f"\n最佳超参数:")
    for key, value in study.best_params.items():
        print(f"{key}: {value}")

    # Break the trials down by final state
    print(f"\n搜索统计:")
    trial_state = optuna.trial.TrialState
    completed_trials = [t for t in study.trials if t.state == trial_state.COMPLETE]
    pruned_trials = [t for t in study.trials if t.state == trial_state.PRUNED]
    failed_trials = [t for t in study.trials if t.state == trial_state.FAIL]

    print(f"完成试验数: {len(completed_trials)}")
    print(f"剪枝试验数: {len(pruned_trials)}")
    print(f"失败试验数: {len(failed_trials)}")

    # Always build the score list so a run with a single completed trial is
    # still summarized, and the record below never needs a guard.
    f1_scores = [t.value for t in completed_trials if t.value is not None]
    if f1_scores:
        print(f"F1分数范围: {min(f1_scores):.4f} - {max(f1_scores):.4f}")
        print(f"F1分数均值: {np.mean(f1_scores):.4f}")
        print(f"F1分数标准差: {np.std(f1_scores):.4f}")

    # Emit a summary record of the search (printed only; nothing is written to disk)
    print({
        'timestamp': datetime.now().isoformat(),
        'best_params': study.best_params,
        'best_value': study.best_value,
        'search_time_seconds': search_time,
        'n_trials': n_trials,
        'n_jobs': max_workers,
        'use_gpu_parallel': use_gpu_parallel,
        'statistics': {
            'completed': len(completed_trials),
            'pruned': len(pruned_trials),
            'failed': len(failed_trials),
            'f1_scores': f1_scores
        },
        'data_info': {
            'train_samples': len(train_texts),
            'val_samples': len(val_texts),
        }
    })

    return study.best_params, study.best_value

In [None]:
# Seed every RNG before anything stochastic runs
set_seed(42)

# Training configuration
config = dict(
    model_name="voidful/albert_chinese_tiny",
    max_length=128,
    batch_size=256,
    num_epochs=25,
    learning_rate=4.8402352659412e-05,   # raised learning rate
    warmup_ratio=0.07065493676186665,    # larger warmup share
    scheduler_type='cosine',             # 'linear' or 'cosine'
    accumulation_steps=1,
    max_patience=5,
    min_delta=1e-4,
    dropout_rate=0.17387168694938498,    # lowered dropout
    train_ratio=0.8,
    val_ratio=0.1,
    test_ratio=0.1,
    use_weighted_sampler=True,
    data_augmentation=True,
    use_hyperparameter_search=False,     # whether to run the hyperparameter search
    use_progressive_training=True,       # whether to use progressive training
    hyperparameter_trials=15,            # number of hyperparameter search trials
)

# Prefer the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the tokenizer and the model configuration
print(f"\n加载模型: {config['model_name']}")
tokenizer = BertTokenizer.from_pretrained(config['model_name'])

# Class names in label-id order; both label mappings are derived from this tuple
LABEL_NAMES = ("chs", "cht", "en", "symbol", "error")
model_config = AlbertConfig.from_pretrained(
    config['model_name'],
    num_labels=len(LABEL_NAMES),
    id2label=dict(enumerate(LABEL_NAMES)),
    label2id={name: idx for idx, name in enumerate(LABEL_NAMES)},
)

# Build the classifier, then replace its encoder with the pretrained ALBERT weights
model = AlbertClassifier(model_config, dropout_rate=config['dropout_rate']).to(device)
model.albert = AlbertModel.from_pretrained(config['model_name'], config=model_config).to(device)

In [None]:
# Load the data
print("\n加载数据...")
# The dataset file already contains the train/validation split
train_texts, train_labels, val_texts, val_labels = load_dataset('./dataset.json')

# Wrap both splits in datasets built with the shared tokenizer
train_dataset = LanguageDataset(train_texts, train_labels, tokenizer)
val_dataset = LanguageDataset(val_texts, val_labels, tokenizer)

# Take the batch size from `config` instead of a separate hard-coded number, so
# the value set in the configuration cell is the one training actually uses.
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config['batch_size'])

In [None]:
# Optional hyperparameter search
if config['use_hyperparameter_search']:
    print("执行超参数搜索")
    # NOTE(review): this calls `hyperparameter_search` with a `device` argument,
    # while the search function defined just above this part of the notebook is
    # `hyperparameter_search_parallel` (no `device` parameter) — confirm that
    # `hyperparameter_search` is defined in an earlier cell.
    best_params, best_score = hyperparameter_search(
        train_texts=train_texts, train_labels=train_labels,
        val_texts=val_texts, val_labels=val_labels,
        tokenizer=tokenizer, model_config=model_config,
        device=device, n_trials=config['hyperparameter_trials'],
    )

    # Adopt only the tuned values whose keys already exist in the configuration
    config.update({name: tuned for name, tuned in best_params.items() if name in config})

    print("\n使用最佳超参数重新创建模型...")
    # Rebuild the model so the tuned dropout rate takes effect
    model = AlbertClassifier(model_config, dropout_rate=config['dropout_rate']).to(device)
    model.albert = AlbertModel.from_pretrained(config['model_name'], config=model_config).to(device)

# Main training run
print("开始主要训练流程")
if config['use_progressive_training']:
    # Progressive training reads its settings from `config` itself
    training_results = progressive_training(
        model=model, train_loader=train_loader, val_loader=val_loader,
        device=device, config=config,
    )
else:
    # Standard training: the schedule settings are copied out of `config` by name
    training_results = train_model(
        model=model, train_loader=train_loader, val_loader=val_loader, device=device,
        **{name: config[name] for name in (
            'num_epochs', 'learning_rate', 'warmup_ratio', 'scheduler_type',
            'accumulation_steps', 'max_patience', 'min_delta',
        )},
        use_focal_loss=True, use_layered_lr=True, stage_name="标准训练",
    )

print("执行评估")
# Gather every training label (one pass over the loader) to derive class weights
all_labels = [label for batch in train_loader for label in batch['label'].numpy()]
class_weights = get_class_weights(all_labels, device)

# Evaluate on the validation loader
f1_weighted, f1_macro, confidence_stats = detailed_evaluation(
    model=model, val_loader=val_loader, device=device, class_weights=class_weights,
)

print("\n最终训练结果:")
print(f"最佳F1分数: {training_results['best_f1']:.4f}")
print(f"最佳epoch: {training_results['best_epoch']}")
print(f"加权F1分数: {f1_weighted:.4f}")
print(f"宏平均F1分数: {f1_macro:.4f}")

In [None]:
# Directory that receives the trained model and its tokenizer
output_dir = 'albert_tiny_chinese_classifier'
os.makedirs(output_dir, exist_ok=True)
# Save the model (with its config) first, then the tokenizer, into the same directory
for artifact in (model, tokenizer):
    artifact.save_pretrained(output_dir)
print(f"模型已保存到 {output_dir}")

导出模型为 ONNX 格式

In [None]:
!pip install optimum[onnxruntime] onnxscript

In [None]:
# Export to ONNX.
# Put the model in inference mode explicitly so the exported graph does not
# depend on whichever train/eval mode the previous cell left it in.
model.eval()

# Dummy batch used only for tracing: one sequence of the configured maximum
# length. Both axes are declared dynamic below, so other sizes work at runtime.
dummy_input = {
    'input_ids': torch.randint(
        0, tokenizer.vocab_size, (1, config['max_length']), dtype=torch.long, device=device
    ),
    'attention_mask': torch.ones(1, config['max_length'], dtype=torch.long, device=device)
}
torch.onnx.export(
    model,
    (dummy_input['input_ids'], dummy_input['attention_mask']),
    output_dir+'/model.onnx',
    input_names=['input_ids', 'attention_mask'],
    output_names=['logits'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'logits': {0: 'batch_size'}
    },
    opset_version=14,
    do_constant_folding=True,
    export_params=True,
    verbose=False
)

量化 ONNX 模型

In [None]:
from optimum.onnxruntime import  ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig
import onnxruntime as ort
import os

# Point the quantizer at the ONNX file previously exported into output_dir
quantizer = ORTQuantizer.from_pretrained(output_dir, file_name='model.onnx')
# AVX512-VNNI preset: dynamic (is_static=False), per-tensor (per_channel=False)
qconfig = AutoQuantizationConfig.avx512_vnni(per_channel=False, is_static=False)
# Quantize and save the result next to the original model
quantizer.quantize(quantization_config=qconfig, save_dir=output_dir)