# transformer改成conformer，增强模型能力，注意力改成四头

# 下载数据

## https://drive.google.com/file/d/1zHjG3F8msz9LBPhp_N5kp_O6G9F2Y5w9/view?usp=drive_link

In [1]:
# !gdown --id '1zHjG3F8msz9LBPhp_N5kp_O6G9F2Y5w9' --output Dataset.zip
# !unzip Dataset.zip

# 导入包

In [1]:
import os
import json
import torch
import random
import math
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from datetime import datetime
import matplotlib.pyplot as plt


# 模型

In [2]:
import torch.nn.functional as F
class LightweightConformerBlock(nn.Module):
    """Simplified Conformer block: FFN -> self-attention -> convolution.

    Each sub-layer is pre-normalised (LayerNorm applied before it) and wrapped
    in a residual connection. Inputs/outputs are (batch, seq_len, d_model)
    because the attention layer is built with ``batch_first=True``.

    Args:
        d_model: feature dimension (last axis); must be divisible by ``num_heads``.
        conv_kernel_size: depthwise kernel size passed to ``LightweightConvModule``.
        dropout: dropout probability used throughout the block.
        num_heads: number of attention heads. Defaults to 2, the value that was
            previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, d_model=80, conv_kernel_size=31, dropout=0.1, num_heads=2):
        super().__init__()

        # Feed-forward sub-layer (2x hidden expansion)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(d_model)

        # Multi-head self-attention
        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = nn.LayerNorm(d_model)

        # Lightweight convolution sub-layer
        self.conv_module = LightweightConvModule(d_model, conv_kernel_size, dropout)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq_len, d_model)."""
        # Feed-forward + residual
        residual = x
        x = self.norm1(x)
        x = self.ffn(x)
        x = self.dropout(x)
        x = residual + x

        # Self-attention + residual
        residual = x
        x = self.norm2(x)
        x, _ = self.self_attn(x, x, x)
        x = self.dropout(x)
        x = residual + x

        # Convolution + residual.
        # NOTE(review): LightweightConvModule already adds its own (normalised)
        # input back, so this branch effectively adds norm3(residual) on top of
        # the outer residual. Kept as-is to preserve trained behaviour.
        residual = x
        x = self.norm3(x)
        x = self.conv_module(x)
        x = self.dropout(x)
        x = residual + x

        return x

class LightweightConvModule(nn.Module):
    """Lightweight convolution sub-layer for (batch, seq_len, d_model) tensors.

    Pipeline: pointwise conv -> depthwise conv -> BatchNorm -> ReLU ->
    pointwise conv -> dropout. The module adds its own input to the result
    (an internal residual connection), so the output shape equals the input.

    Args:
        d_model: number of feature channels.
        kernel_size: depthwise kernel size; must be odd.
        dropout: dropout probability applied after the last pointwise conv.
    """

    def __init__(self, d_model, kernel_size=31, dropout=0.1):
        super().__init__()
        assert kernel_size % 2 == 1, "Kernel size should be odd"

        same_padding = (kernel_size - 1) // 2  # keeps seq_len unchanged for odd kernels
        self.pointwise_conv1 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.depthwise_conv = nn.Conv1d(
            d_model,
            d_model,
            kernel_size=kernel_size,
            padding=same_padding,
            groups=d_model,
        )
        self.batch_norm = nn.BatchNorm1d(d_model)
        self.activation = nn.ReLU()
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``x + conv(x)`` for ``x`` of shape (batch, seq_len, d_model)."""
        # Conv1d wants channels first: (batch, d_model, seq_len)
        hidden = self.pointwise_conv1(x.transpose(1, 2))
        hidden = self.activation(self.batch_norm(self.depthwise_conv(hidden)))
        hidden = self.dropout(self.pointwise_conv2(hidden))
        return hidden.transpose(1, 2) + x

class LightweightConformerClassifier(nn.Module):
    def __init__(self, d_model=80, n_spks=600, dropout=0.1, num_layers=2):
        """Lightweight Conformer speaker classifier.

        Each ``LightweightConformerBlock`` uses two-head self-attention (not
        single-head). Pipeline: Linear(40 -> d_model) + dropout ->
        ``num_layers`` Conformer blocks -> mean pooling over time -> MLP head.

        Args:
            d_model: feature dimension inside the model.
            n_spks: number of speakers (output classes).
            dropout: dropout probability.
            num_layers: number of Conformer blocks.
        """
        super().__init__()

        # Input projection: 40 mel bins -> d_model
        self.prenet = nn.Linear(40, d_model)
        self.prenet_dropout = nn.Dropout(dropout)

        # Conformer encoder (two-head attention per block)
        self.conformer_layers = nn.ModuleList([
            LightweightConformerBlock(
                d_model=d_model,
                conv_kernel_size=31,
                dropout=dropout
            ) for _ in range(num_layers)
        ])

        # Classification head
        self.output_layer = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_spks),
        )

    def forward(self, mels):
        """
        Args:
            mels: (batch_size, seq_len, 40)
        Returns:
            out: (batch_size, n_spks) logits
        """
        # Input projection
        x = self.prenet(mels)  # (batch, seq_len, d_model)
        x = self.prenet_dropout(x)

        # Conformer encoding
        for layer in self.conformer_layers:
            x = layer(x)

        # Global average pooling over time (padded frames are included)
        x = x.mean(dim=1)  # (batch, d_model)

        # Classification
        out = self.output_layer(x)  # (batch, n_spks)

        return out

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EnhancedConformerBlock(nn.Module):
    """Macaron-style Conformer block: FFN -> MHSA -> Conv -> FFN.

    Every sub-layer is pre-normalised and wrapped in a residual connection.
    Tensors are (batch, seq_len, d_model) because attention uses
    ``batch_first=True``.

    Args:
        d_model: feature dimension; must be divisible by ``nhead``.
        nhead: number of attention heads.
        conv_kernel_size: depthwise kernel size for ``EnhancedConvModule``.
        dropout: dropout probability.
        expansion_factor: hidden-size multiplier of both feed-forward sub-layers.
    """

    def __init__(self, d_model=128, nhead=4, conv_kernel_size=31, dropout=0.1, expansion_factor=4):
        super().__init__()
        self.d_model = d_model
        hidden_size = d_model * expansion_factor

        self.ffn1 = self._make_ffn(d_model, hidden_size, dropout)

        self.self_attn = nn.MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=True,
        )
        self.norm_attn = nn.LayerNorm(d_model)

        self.conv_module = EnhancedConvModule(
            d_model,
            kernel_size=conv_kernel_size,
            dropout=dropout,
        )
        self.norm_conv = nn.LayerNorm(d_model)

        self.ffn2 = self._make_ffn(d_model, hidden_size, dropout)

        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _make_ffn(d_model, hidden_size, dropout):
        """Build a pre-norm feed-forward sub-layer: LN -> Linear -> SiLU -> Linear."""
        return nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, hidden_size),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the four residual sub-layers in order."""
        x = x + self.ffn1(x)

        normed = self.norm_attn(x)
        attn_out, _ = self.self_attn(normed, normed, normed)
        x = x + self.dropout(attn_out)

        conv_out = self.conv_module(self.norm_conv(x))
        x = x + self.dropout(conv_out)

        return x + self.ffn2(x)

class EnhancedConvModule(nn.Module):
    def __init__(self, d_model, kernel_size=31, dropout=0.1, expansion_factor=2):
        """Conformer-style convolution module.

        LayerNorm -> pointwise conv (d_model * expansion_factor channels) ->
        GLU -> depthwise conv -> BatchNorm -> SiLU -> Dropout -> pointwise conv
        -> Dropout, followed by an internal residual (the module returns
        ``input + output``).

        Args:
            d_model: number of features (last axis of the input).
            kernel_size: depthwise kernel size; must be odd so padding keeps
                the sequence length.
            dropout: dropout probability.
            expansion_factor: channel multiplier of the first pointwise conv.
                GLU halves the channels, so the depthwise conv (which expects
                ``d_model`` channels) only lines up when this is 2.
        """
        super().__init__()
        assert kernel_size % 2 == 1, "Kernel size should be odd"

        # Index 0 is the LayerNorm; forward() applies it separately (see below).
        self.conv_module = nn.Sequential(
            nn.LayerNorm(d_model),
            # First pointwise conv, expands channels
            nn.Conv1d(d_model, d_model * expansion_factor, kernel_size=1),
            nn.GLU(dim=1),  # gated linear unit over the channel axis
            # Depthwise conv
            nn.Conv1d(
                d_model, d_model, kernel_size=kernel_size,
                padding=(kernel_size-1)//2, groups=d_model,
                bias=False  # followed by BatchNorm, so a bias is redundant
            ),
            nn.BatchNorm1d(d_model),
            nn.SiLU(),
            nn.Dropout(dropout),
            # Second pointwise conv, back to d_model channels
            nn.Conv1d(d_model, d_model, kernel_size=1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Return ``x + conv(x)`` for ``x`` of shape (batch, seq_len, d_model)."""
        residual = x
        # LayerNorm(d_model) must see d_model as the LAST axis, so apply it
        # before switching to channels-first. Applying it after the transpose
        # would normalise over time and fail whenever seq_len != d_model.
        # The Sequential is kept intact so state_dict keys are unchanged.
        layer_norm, conv_stack = self.conv_module[0], self.conv_module[1:]
        x = layer_norm(x)
        x = x.transpose(1, 2)  # (batch, d_model, seq_len) for Conv1d
        x = conv_stack(x)
        x = x.transpose(1, 2)  # back to (batch, seq_len, d_model)
        return x + residual

class EnhancedConformerClassifier(nn.Module):
    """Speaker classifier built from a stack of ``EnhancedConformerBlock``s.

    Pipeline: Linear(40 -> d_model) + LayerNorm + Dropout -> learned position
    embedding -> ``num_layers`` Conformer blocks -> final LayerNorm ->
    (mean + max) pooling over time -> two-layer MLP head.

    Args:
        d_model: feature dimension (default 128); must be divisible by ``nhead``.
        n_spks: number of speakers (output classes).
        dropout: dropout probability used throughout.
        num_layers: number of Conformer blocks (default 4).
        nhead: attention heads per block (default 4).
    """

    def __init__(self, d_model=128, n_spks=600, dropout=0.1, num_layers=4, nhead=4):
        super().__init__()

        self.prenet = nn.Sequential(
            nn.Linear(40, d_model),
            nn.LayerNorm(d_model),
            nn.Dropout(dropout),
        )

        self.pos_encoding = RelativePositionalEncoding(d_model, dropout=dropout)

        blocks = [
            EnhancedConformerBlock(
                d_model=d_model,
                nhead=nhead,
                conv_kernel_size=31,
                dropout=dropout,
                expansion_factor=4,
            )
            for _ in range(num_layers)
        ]
        self.conformer_layers = nn.ModuleList(blocks)

        self.final_norm = nn.LayerNorm(d_model)

        head_width = d_model // 2
        self.classifier = nn.Sequential(
            nn.Linear(d_model, head_width),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(head_width, n_spks),
        )

        # Re-initialise every Linear / LayerNorm in the module tree.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform Linear weights, zero biases; LayerNorm to identity."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)

    def forward(self, mels):
        """Map mels of shape (batch, seq_len, 40) to logits of shape (batch, n_spks)."""
        hidden = self.pos_encoding(self.prenet(mels))

        for block in self.conformer_layers:
            hidden = block(hidden)

        hidden = self.final_norm(hidden)

        # Combine mean- and max-pooling over the time axis
        pooled = hidden.mean(dim=1) + hidden.max(dim=1).values
        return self.classifier(pooled)

class RelativePositionalEncoding(nn.Module):
    """Learned position embedding added to a (batch, seq_len, d_model) input.

    NOTE: despite the class name, this is a *learned absolute* embedding:
    row ``t`` of ``pos_embedding`` is added to frame ``t``; no pairwise or
    relative offsets are computed. The name is kept for compatibility with
    existing code and checkpoints.

    Args:
        d_model: feature dimension (last axis of the input).
        dropout: dropout probability applied after adding the embedding.
        max_len: maximum supported sequence length.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model

        # One learnable vector per position, initialised with small noise.
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_len, d_model))
        nn.init.trunc_normal_(self.pos_embedding, std=0.02)

    def forward(self, x):
        """Add the first ``seq_len`` position vectors to ``x`` and apply dropout.

        Raises:
            ValueError: if ``seq_len`` exceeds the configured ``max_len``.
        """
        _, seq_len, _ = x.size()  # also enforces a 3-D input
        max_len = self.pos_embedding.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={max_len} "
                "of the positional embedding"
            )

        x = x + self.pos_embedding[:, :seq_len, :]
        return self.dropout(x)

def get_enhanced_conformer(d_model=128, n_spks=600, dropout=0.1, num_layers=4, nhead=4):
    """Create an ``EnhancedConformerClassifier``.

    Args:
        d_model: feature dimension (must be divisible by ``nhead``).
        n_spks: number of speakers (output classes).
        dropout: dropout probability.
        num_layers: number of Conformer blocks.
        nhead: attention heads per block; defaults to 4, the classifier's
            own default, so existing callers get the same model.

    Returns:
        A freshly initialised ``EnhancedConformerClassifier``.
    """
    return EnhancedConformerClassifier(
        d_model=d_model,
        n_spks=n_spks,
        dropout=dropout,
        num_layers=num_layers,
        nhead=nhead,
    )

In [3]:
class Classifier(nn.Module):
    """Baseline speaker classifier.

    Pipeline: Linear(40 -> d_model) -> one TransformerEncoderLayer ->
    mean pooling over time -> MLP head producing ``n_spks`` logits.
    """

    def __init__(self, d_model=80, n_spks=600, dropout=0.1):
        """Initialise the classifier.

        Args:
            d_model (int): model feature dimension (default 80).
            n_spks (int): number of speakers / output classes (default 600).
            dropout (float): dropout rate of the Transformer encoder layer
                (default 0.1).
        """
        super().__init__()

        # Project 40-dim mel features to d_model:
        # (batch_size, length, 40) -> (batch_size, length, d_model)
        self.prenet = nn.Linear(40, d_model)

        # TODO: replace the Transformer with a Conformer
        # (https://arxiv.org/abs/2005.08100), which combines CNNs and
        # Transformers and tends to do better on speech tasks.
        # For a deeper encoder, wrap this layer in nn.TransformerEncoder.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,        # feature dimension
            dim_feedforward=256,    # hidden size of the feed-forward sub-layer
            nhead=2,                # number of attention heads
            dropout=dropout,        # honour the constructor argument
        )

        # Prediction head: d_model features -> one score per speaker
        self.pred_layer = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, n_spks),
        )

    def forward(self, mels):
        """Forward pass.

        Args:
            mels: mel spectrograms of shape (batch_size, length, 40).

        Returns:
            out: speaker logits of shape (batch_size, n_spks).
        """
        # (batch_size, length, 40) -> (batch_size, length, d_model)
        out = self.prenet(mels)

        # The encoder layer is sequence-first (batch_first defaults to False):
        # (batch_size, length, d_model) -> (length, batch_size, d_model)
        out = out.permute(1, 0, 2)
        out = self.encoder_layer(out)

        # Back to (batch_size, length, d_model)
        out = out.transpose(0, 1)

        # Mean pooling over time: (batch_size, length, d_model) -> (batch_size, d_model)
        stats = out.mean(dim=1)

        # (batch_size, d_model) -> (batch_size, n_spks)
        out = self.pred_layer(stats)

        return out

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConformerBlock(nn.Module):
    """Conformer block: FFN -> self-attention -> convolution -> FFN.

    Each sub-layer is pre-normalised with its own LayerNorm and wrapped in a
    residual connection. Inputs/outputs are (batch, seq_len, d_model).

    Args:
        d_model: feature dimension; must be divisible by ``nhead``.
        nhead: number of attention heads.
        dim_feedforward: hidden size of both feed-forward sub-layers.
        dropout: dropout probability.
        conv_kernel_size: depthwise kernel size (odd values keep seq_len unchanged).
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1, conv_kernel_size=31):
        super().__init__()
        self.d_model = d_model

        # 1) Feed-forward network
        self.ffn1 = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(d_model)

        # 2) Self-attention
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)

        # 3) Convolution module. Index 0 is a LayerNorm that forward() applies
        #    before the channels-first transpose.
        self.conv_module = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Conv1d(d_model, 2*d_model, 1),  # pointwise conv
            nn.GLU(dim=1),  # GLU over channels: 2*d_model -> d_model
            nn.Conv1d(d_model, d_model, conv_kernel_size,
                     padding=(conv_kernel_size-1)//2, groups=d_model),  # depthwise conv
            nn.BatchNorm1d(d_model),
            nn.SiLU(),
            nn.Conv1d(d_model, d_model, 1),  # pointwise conv
            nn.Dropout(dropout)
        )
        self.norm3 = nn.LayerNorm(d_model)

        # 4) Second feed-forward network
        self.ffn2 = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout)
        )
        self.norm4 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq_len, d_model)."""
        # 1) Feed-forward + residual
        residual = x
        x = self.norm1(x)
        x = self.ffn1(x)
        x = self.dropout(x)
        x = x + residual

        # 2) Self-attention + residual
        residual = x
        x = self.norm2(x)
        x, _ = self.self_attn(x, x, x)
        x = self.dropout(x)
        x = x + residual

        # 3) Convolution + residual
        residual = x
        x = self.norm3(x)
        # The stack's LayerNorm(d_model) must normalise the feature axis, so it
        # runs while d_model is still last; applying it after the transpose
        # would normalise over time and fail whenever seq_len != d_model.
        x_conv = self.conv_module[0](x)
        x_conv = x_conv.transpose(1, 2)  # (batch, d_model, seq_len) for Conv1d
        x_conv = self.conv_module[1:](x_conv)
        x_conv = x_conv.transpose(1, 2)  # back to (batch, seq_len, d_model)
        x = self.dropout(x_conv)
        x = x + residual

        # 4) Feed-forward + residual
        residual = x
        x = self.norm4(x)
        x = self.ffn2(x)
        x = self.dropout(x)
        x = x + residual

        return x

class ConformerClassifier(nn.Module):
    """Speaker classifier built from ``ConformerBlock``s.

    Pipeline: Linear(40 -> d_model) -> ``num_layers`` Conformer blocks ->
    mean pooling over time -> MLP head.

    Args:
        d_model: feature dimension; must be divisible by ``nhead``.
        n_spks: number of speakers (output classes).
        dropout: dropout probability.
        num_layers: number of Conformer blocks.
        nhead: attention heads per block (default 8, as previously hard-coded).
        dim_feedforward: feed-forward hidden size (default 256).
        conv_kernel_size: depthwise kernel size (default 31).
    """

    def __init__(self, d_model=80, n_spks=600, dropout=0.1, num_layers=4,
                 nhead=8, dim_feedforward=256, conv_kernel_size=31):
        super().__init__()

        # Input projection
        self.prenet = nn.Linear(40, d_model)

        # Conformer encoder
        self.conformer_layers = nn.ModuleList([
            ConformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                conv_kernel_size=conv_kernel_size
            ) for _ in range(num_layers)
        ])

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_spks),
        )

    def forward(self, mels):
        """Map mels of shape (batch, seq_len, 40) to logits of shape (batch, n_spks)."""
        x = self.prenet(mels)  # (batch, seq_len, d_model)

        for layer in self.conformer_layers:
            x = layer(x)

        # Global average pooling over time
        x = x.mean(dim=1)  # (batch, d_model)

        out = self.classifier(x)  # (batch, n_spks)

        return out

In [5]:
class SpeakerClassifier(nn.Module):
    """Speaker classifier built on one sequence-first Transformer encoder layer.

    Pipeline: Linear(40 -> d_model) -> TransformerEncoderLayer -> mean over
    time -> MLP head producing ``n_spks`` logits.

    Args:
        d_model (int): feature dimension inside the model.
        n_spks (int): number of speakers (output classes).
        dropout (float): dropout probability for the encoder and the head.
    """

    def __init__(self, d_model=80, n_spks=600, dropout=0.1):
        super().__init__()

        self.prenet = nn.Linear(40, d_model)

        # batch_first=False: the encoder expects (seq_len, batch, d_model).
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            dim_feedforward=256,
            nhead=2,
            dropout=dropout,
            batch_first=False,
        )

        head_layers = [
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, n_spks),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, mels):
        """Map mels of shape (batch, seq_len, 40) to logits of shape (batch, n_spks)."""
        seq_first = self.prenet(mels).transpose(0, 1)           # (seq_len, batch, d_model)
        encoded = self.encoder_layer(seq_first).transpose(0, 1)  # (batch, seq_len, d_model)
        return self.classifier(encoded.mean(dim=1))

## 创建模型的函数

In [6]:
def get_Classifier(d_model=80, n_spks=600, dropout=0.1):
    """Create the baseline Transformer-based speaker ``Classifier``.

    Args:
        d_model: model feature dimension.
        n_spks: number of speakers (output classes).
        dropout: dropout rate forwarded to ``Classifier``.

    Returns:
        A new ``Classifier`` instance.
    """
    return Classifier(d_model=d_model, n_spks=n_spks, dropout=dropout)

# 添加工具函数

## 绘制训练和损失曲线

In [7]:
def plot_loss_curves(train_losses, val_losses, save_path=None):
    """Plot training and validation loss curves on one figure.

    Each series is drawn against its own 1-based x range (matching the other
    plotting helpers), so the two lists may have different lengths.

    Args:
        train_losses: sequence of training-loss values.
        val_losses: sequence of validation-loss values.
        save_path: if given, save the figure there (dpi=300) and close it;
            otherwise display it with ``plt.show()``.
    """
    fig = plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(train_losses) + 1), train_losses, label='训练损失', linewidth=2)
    plt.plot(range(1, len(val_losses) + 1), val_losses, label='验证损失', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('训练和验证损失曲线')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"图片已保存到: {save_path}")
        # Close the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    else:
        plt.show()

## 绘制训练准确率曲线

In [8]:
def plot_accuracy_curves(train_accuracies, val_accuracies, save_path=None):
    """Plot training and validation accuracy curves and mark the best epoch.

    Each series is plotted against its own 1-based x range, so the lists may
    differ in length (e.g. per-step training vs per-epoch validation values);
    previously both used the training length, which fails on a mismatch.

    NOTE(review): the axis label and annotation say "%", but validate() in this
    notebook returns fractions in [0, 1]; callers should scale by 100 if
    percentages are intended.

    Args:
        train_accuracies: sequence of training accuracies.
        val_accuracies: sequence of validation accuracies.
        save_path: if given, save the figure (dpi=300) and close it; otherwise show it.
    """
    plt.figure(figsize=(10, 6))
    train_x = range(1, len(train_accuracies) + 1)
    val_x = range(1, len(val_accuracies) + 1)

    plt.plot(train_x, train_accuracies, 'b-', label='训练准确率', linewidth=2)
    plt.plot(val_x, val_accuracies, 'r-', label='验证准确率', linewidth=2)

    plt.title('训练和验证准确率', fontsize=14, fontweight='bold')
    plt.xlabel('Epochs', fontsize=12)
    plt.ylabel('Accuracy (%)', fontsize=12)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.ylim(bottom=0)

    # Annotate the best validation accuracy (skipped when there is no data).
    val_list = list(val_accuracies)
    if val_list:
        best_val_acc = max(val_list)
        best_epoch = val_list.index(best_val_acc) + 1
        plt.axvline(x=best_epoch, color='gray', linestyle='--', alpha=0.7)
        plt.text(best_epoch, best_val_acc / 2, f'最佳: {best_val_acc:.2f}%\nEpoch: {best_epoch}',
                 ha='center', va='center', fontsize=10,
                 bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

## 绘制综合曲线

In [9]:
def plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies, save_path=None):
    """Plot loss (top panel) and accuracy (bottom panel) curves in one figure.

    Every series is plotted against its own 1-based x range, so training and
    validation histories may have different lengths (this notebook records
    training metrics per step and validation metrics per epoch). With
    equal-length per-epoch lists, the output matches the old behaviour.

    NOTE(review): the accuracy axis says "%", but validate() in this notebook
    returns fractions in [0, 1].

    Args:
        train_losses, val_losses: loss histories.
        train_accuracies, val_accuracies: accuracy histories.
        save_path: if given, save the figure (dpi=300) and close it; otherwise show it.
    """
    def _x(series):
        """1-based x positions for ``series``."""
        return range(1, len(series) + 1)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

    # Loss curves
    ax1.plot(_x(train_losses), train_losses, 'b-', label='训练损失', linewidth=2)
    ax1.plot(_x(val_losses), val_losses, 'r-', label='验证损失', linewidth=2)
    ax1.set_title('训练和验证损失', fontsize=14, fontweight='bold')
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.legend(fontsize=12)
    ax1.grid(True, alpha=0.3)

    # Accuracy curves
    ax2.plot(_x(train_accuracies), train_accuracies, 'b-', label='训练准确率', linewidth=2)
    ax2.plot(_x(val_accuracies), val_accuracies, 'r-', label='验证准确率', linewidth=2)
    ax2.set_title('训练和验证准确率', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Epochs', fontsize=12)
    ax2.set_ylabel('Accuracy (%)', fontsize=12)
    ax2.legend(fontsize=12)
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(bottom=0)

    # Annotate the best validation accuracy (skipped when there is no data).
    val_list = list(val_accuracies)
    if val_list:
        best_val_acc = max(val_list)
        best_epoch = val_list.index(best_val_acc) + 1
        ax2.axvline(x=best_epoch, color='gray', linestyle='--', alpha=0.7)
        ax2.text(best_epoch, best_val_acc / 2, f'最佳: {best_val_acc:.2f}%',
                 ha='center', va='center', fontsize=10,
                 bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()

# dataset

In [10]:
class SpeakerDataset(Dataset):
    """Speaker-classification dataset of pre-extracted mel features.

    ``data_dir`` must contain ``mapping.json`` (with a ``"speaker2id"`` dict)
    and ``metadata.json`` (with a ``"speakers"`` dict mapping each speaker to
    a list of utterances, each carrying a ``"feature_path"`` relative to
    ``data_dir``).
    """

    def __init__(self, data_dir, segment_len=128):
        """Initialise the dataset.

        Args:
            data_dir (str): dataset root directory.
            segment_len (int): frames per random crop (default 128).
        """
        self.data_dir = data_dir
        self.segment_len = segment_len
        root = Path(data_dir)

        # Context managers close the JSON files promptly (they were leaked before).
        with (root / "mapping.json").open(encoding="utf-8") as f:
            mapping = json.load(f)
        self.speaker2id = mapping["speaker2id"]

        with (root / "metadata.json").open(encoding="utf-8") as f:
            metadata = json.load(f)["speakers"]

        self.speaker_num = len(metadata)

        # Flat list of [feature_path, speaker_id] pairs, one per utterance.
        self.data = [
            [utterance["feature_path"], self.speaker2id[speaker]]
            for speaker, utterances in metadata.items()
            for utterance in utterances
        ]

    def __len__(self):
        """Return the number of utterances."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(mel, speaker)``.

        ``mel`` is a random ``segment_len``-frame crop when the utterance is
        longer than ``segment_len``, otherwise the whole utterance. ``speaker``
        is a 0-d long tensor.
        """
        feat_path, speaker = self.data[index]

        # NOTE(review): torch.load unpickles the file; only load trusted data.
        mel = torch.load(os.path.join(self.data_dir, feat_path))

        if len(mel) > self.segment_len:
            start = random.randint(0, len(mel) - self.segment_len)
            mel = torch.FloatTensor(mel[start:start + self.segment_len])
        else:
            mel = torch.FloatTensor(mel)

        speaker = torch.LongTensor([speaker]).squeeze()

        return mel, speaker

    def get_speaker_number(self):
        """Return the number of distinct speakers."""
        return self.speaker_num


# dataloader

In [11]:
import torch
from torch.utils.data import DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence


def collate_batch(batch):
    """Collate ``(mel, speaker)`` samples into padded batch tensors.

    Args:
        batch: sequence of ``(mel, speaker)`` pairs; ``mel`` is a
            (frames, n_mels) tensor and ``speaker`` an int or 0-d integer tensor.

    Returns:
        tuple: ``(mels, speakers)`` where ``mels`` has shape
        (batch, max_frames, n_mels), right-padded with -20, and ``speakers``
        is a 1-D ``torch.long`` tensor.
    """
    mels, speakers = zip(*batch)

    # Pad shorter utterances with -20: a very small value if the features are
    # log-scaled (e.g. 1e-20 for log10) -- assumed, not checked here.
    mels = pad_sequence(mels, batch_first=True, padding_value=-20)

    # Build labels directly as int64. Going through float32 (FloatTensor(...).long())
    # silently corrupts ids above 2**24.
    labels = torch.as_tensor([int(s) for s in speakers], dtype=torch.long)

    return mels, labels


def get_dataloader(data_dir, batch_size, n_workers):
    """Build training and validation DataLoaders.

    The dataset is split randomly 90% / 10% into train / validation.

    Args:
        data_dir (str): dataset root directory.
        batch_size (int): batch size for both loaders.
        n_workers (int): number of loader worker processes.

    Returns:
        tuple: (train_loader, valid_loader, speaker_num)
    """
    dataset = SpeakerDataset(data_dir)
    speaker_num = dataset.get_speaker_number()

    # 90% train / 10% validation
    trainlen = int(0.9 * len(dataset))
    lengths = [trainlen, len(dataset) - trainlen]
    trainset, validset = random_split(dataset, lengths)

    train_loader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,               # reshuffle every epoch
        drop_last=True,             # constant training batch size
        num_workers=n_workers,
        pin_memory=True,            # faster host-to-GPU copies
        collate_fn=collate_batch,   # pads variable-length utterances
    )

    # Keep the final partial batch so every validation sample is evaluated.
    # drop_last=True skipped up to batch_size-1 samples, and produced zero
    # batches when the validation split was smaller than batch_size.
    valid_loader = DataLoader(
        validset,
        batch_size=batch_size,
        num_workers=n_workers,
        drop_last=False,
        pin_memory=True,
        collate_fn=collate_batch,
    )

    return train_loader, valid_loader, speaker_num

# 学习率

In [12]:
import math
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5):
    """Create a LambdaLR that warms up linearly, then follows a cosine curve.

    For steps below ``num_warmup_steps`` the LR factor rises linearly from 0
    to 1. Afterwards it is ``0.5 * (1 + cos(2*pi*num_cycles*progress))``,
    clipped at 0, where ``progress`` runs from 0 at the end of warm-up to 1
    at ``num_training_steps``.

    Args:
        optimizer: optimizer whose LR is scheduled.
        num_warmup_steps: number of linear warm-up steps.
        num_training_steps: total number of training steps.
        num_cycles: number of cosine cycles (0.5 = a single half-cosine decay).

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / warmup_span
        progress = float(current_step - num_warmup_steps) / decay_span
        cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
        return max(0.0, 0.5 * (1.0 + cosine))

    return LambdaLR(optimizer, lr_lambda)

In [13]:
def model_fn(batch, model, criterion, device):
    """Run one forward pass and score it.

    Args:
        batch: ``(mels, labels)`` pair.
        model: callable mapping mels to class logits.
        criterion: loss function taking ``(logits, labels)``.
        device: device the tensors are moved to.

    Returns:
        tuple: ``(loss, accuracy)`` where accuracy is the fraction of samples
        whose argmax prediction equals the label (a 0-d float tensor).
    """
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)

    logits = model(inputs)

    loss = criterion(logits, targets)
    hits = logits.argmax(dim=1) == targets
    accuracy = torch.mean(hits.float())

    return loss, accuracy

In [14]:
def validate(dataloader, model, criterion, device):
    """Evaluate ``model`` on every batch of ``dataloader``.

    Loss and accuracy are averaged over samples: each batch is weighted by its
    size, so a smaller final batch does not skew the result. This assumes
    ``criterion`` uses mean reduction (as ``nn.CrossEntropyLoss`` does by
    default). The model's previous train/eval mode is restored afterwards,
    even if an exception occurs.

    Args:
        dataloader: iterable of ``(mels, labels)`` batches.
        model: model to evaluate.
        criterion: loss function.
        device: device to evaluate on.

    Returns:
        tuple: (mean loss, mean accuracy) over all validation samples.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    was_training = model.training
    model.eval()

    loss_sum = 0.0
    correct_sum = 0.0
    n_samples = 0

    try:
        with torch.no_grad():
            for batch in dataloader:
                loss, accuracy = model_fn(batch, model, criterion, device)
                batch_size = len(batch[1])  # number of labels in this batch
                loss_sum += loss.item() * batch_size
                correct_sum += accuracy.item() * batch_size
                n_samples += batch_size
    finally:
        model.train(was_training)

    if n_samples == 0:
        raise ValueError("validation dataloader yielded no samples")

    return loss_sum / n_samples, correct_sum / n_samples


In [15]:
# Training configuration
data_dir = "./Dataset"            # dataset root (mapping.json, metadata.json, features)
batch_size, n_workers = 32, 4     # DataLoader batch size / worker processes

# Step-based settings
valid_steps, warmup_steps, save_steps = 2000, 1000, 10000
total_steps = 70_000              # LR-schedule length used when the scheduler is built

# Epoch-based loop limit
max_epochs = 100

In [16]:
# Create a timestamped results directory, e.g. ./results/<MM-DD_HH-MM>
now_time = datetime.now()
time_str = now_time.strftime('%m-%d_%H-%M')
log_dir = os.path.join("./results", time_str)

os.makedirs(log_dir, exist_ok=True)  # no error if it already exists
print(f"结果保存目录: {log_dir}")

结果保存目录: ./results\10-20_19-23


In [17]:
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"使用设备: {device}")

使用设备: cpu


In [18]:
# Build the train/validation loaders and read the number of speakers.
train_loader, val_loader, speaker_num = get_dataloader(
    data_dir=data_dir,
    batch_size=batch_size,
    n_workers=n_workers,
)
print(f"完成数据加载! 说话人数量: {speaker_num}")

完成数据加载! 说话人数量: 600


In [19]:
# Build the enhanced Conformer classifier (EnhancedConformerClassifier; the
# attention head count is left at that class's default of 4), plus the loss,
# optimizer and LR scheduler.
model = get_enhanced_conformer(
    d_model=128,           # feature dimension
    n_spks=speaker_num,    # number of speakers
    dropout=0.1,           # dropout rate
    num_layers=4           # number of Conformer blocks
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-3)
# NOTE: `total_steps` here is the configured value (70000), which need not
# equal max_epochs * len(train_loader) -- the number of steps the epoch-based
# loop actually takes.
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
print("完成模型创建!")

完成模型创建!


In [20]:
# Training history. In the epoch-based loop, train_* and learning_rates get one
# entry per optimizer step, and val_* one entry per epoch.
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
learning_rates = []
max_epoch = 50  # NOTE(review): unused by the training loop, which iterates over `max_epochs`

# Best-model tracking
best_val_accuracy, best_state_dict = 0.0, None

In [None]:
# Epoch-based training loop: train, validate once per epoch, keep the best
# weights, write periodic checkpoints, and stop early on stagnation.
model.train()

print("开始训练...")
start_time = datetime.now()

best_epoch = 0
early_stop_patience = 10  # epochs without validation improvement before stopping

# Number of optimizer steps this loop actually takes. The scheduler built
# earlier used the configured `total_steps`; rebuild it with this value so the
# cosine decay ends exactly at the last step. (Past its end, the cosine factor
# of get_cosine_schedule_with_warmup would start rising again.)
total_steps = max_epochs * len(train_loader)
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

for epoch in range(max_epochs):
    # ---- training phase ----
    model.train()
    epoch_train_loss = 0.0
    epoch_train_accuracy = 0.0
    num_batches = len(train_loader)

    print(f"\n--- Epoch {epoch+1}/{max_epochs} ---")

    for batch_idx, batch in enumerate(train_loader):
        # Forward + backward
        train_loss, train_accuracy = model_fn(batch, model, criterion, device)

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        scheduler.step()

        step_loss = train_loss.item()
        step_accuracy = train_accuracy.item()
        current_lr = scheduler.get_last_lr()[0]

        # Epoch accumulators and per-step history
        epoch_train_loss += step_loss
        epoch_train_accuracy += step_accuracy
        train_losses.append(step_loss)
        train_accuracies.append(step_accuracy)
        learning_rates.append(current_lr)

        # Progress report every 100 batches
        if (batch_idx + 1) % 100 == 0:
            print(f"Epoch {epoch+1} | Batch {batch_idx+1}/{num_batches} | "
                  f"当前损失: {step_loss:.4f} | "
                  f"当前准确率: {step_accuracy:.4f} | "
                  f"学习率: {current_lr:.2e}")

    # Epoch averages (max(1, ...) guards against an empty train_loader)
    avg_train_loss = epoch_train_loss / max(1, num_batches)
    avg_train_accuracy = epoch_train_accuracy / max(1, num_batches)
    current_lr = scheduler.get_last_lr()[0]

    # ---- validation phase ----
    val_loss, val_accuracy = validate(val_loader, model, criterion, device)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    elapsed_minutes = (datetime.now() - start_time).total_seconds() / 60

    print(f"\nEpoch {epoch+1} 结果:")
    print(f"训练损失: {avg_train_loss:.4f} | 训练准确率: {avg_train_accuracy:.4f}")
    print(f"验证损失: {val_loss:.4f} | 验证准确率: {val_accuracy:.4f}")
    print(f"学习率: {current_lr:.2e} | 已用时间: {elapsed_minutes:.1f}分钟")

    # ---- best-model tracking ----
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_epoch = epoch + 1
        # state_dict() returns references to the live parameter tensors, so a
        # shallow dict .copy() would keep changing as training continues and
        # the "best" weights would silently become the latest ones. Clone the
        # tensors to freeze a real snapshot.
        best_state_dict = {name: tensor.detach().clone()
                           for name, tensor in model.state_dict().items()}

        model_path = os.path.join(log_dir, f"best_model_epoch{epoch+1}_acc{best_val_accuracy:.4f}.pth")
        torch.save({
            'model_state_dict': best_state_dict,
            'val_accuracy': best_val_accuracy,
            'epoch': epoch + 1
        }, model_path)
        print(f"✅ 保存最佳模型! Epoch {epoch+1}, 验证准确率: {best_val_accuracy:.4f}")

    # ---- periodic checkpoint every 5 epochs ----
    if (epoch + 1) % 5 == 0:
        checkpoint_path = os.path.join(log_dir, f"checkpoint_epoch_{epoch+1}.pth")
        torch.save({
            # Current weights, consistent with the optimizer/scheduler state,
            # so training can be resumed; the best snapshot is stored separately.
            'model_state_dict': model.state_dict(),
            'best_model_state_dict': best_state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'epoch': epoch + 1,
            'best_val_accuracy': best_val_accuracy
        }, checkpoint_path)
        print(f"💾 保存检查点: {checkpoint_path}")

    # ---- early stopping ----
    if epoch + 1 - best_epoch > early_stop_patience:
        print(f"🚨 早停触发! 最佳准确率 {best_val_accuracy:.4f} 在 Epoch {best_epoch}")
        break

# Training finished
total_time = (datetime.now() - start_time).total_seconds()
print(f"\n🎉 训练完成!")
print(f"总用时: {total_time/60:.1f}分钟")
print(f"总epoch数: {epoch + 1}")
print(f"最佳验证准确率: {best_val_accuracy:.4f} (Epoch {best_epoch})")

# Save the best weights (a true snapshot, see above)
if best_state_dict is not None:
    final_model_path = os.path.join(log_dir, "final_best_model.pth")
    torch.save(best_state_dict, final_model_path)
    print(f"最终模型已保存: {final_model_path}")

开始训练...

--- Epoch 1/100 ---


In [None]:
# Collect the training history (per-step training metrics, per-epoch
# validation metrics) and save it next to the checkpoints.
training_history = {
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_accuracies': train_accuracies,
    'val_accuracies': val_accuracies,
    'learning_rates': learning_rates,
    'best_val_accuracy': best_val_accuracy,
    # 1-based epoch of the (first) best validation accuracy, numbered like the
    # training loop. The old value, len(val_accuracies) - 1, was simply the
    # last epoch's 0-based index.
    'best_epoch': (val_accuracies.index(max(val_accuracies)) + 1) if val_accuracies else 0,
}

log_path = os.path.join(log_dir, "training_history.pth")
torch.save(training_history, log_path)
print(f"训练记录已保存: {log_path}")

In [None]:
# Plot the training curves (announce the step before plotting)
print("开始绘制训练曲线...")
def plot_training_results(train_losses, val_losses, train_accuracies, val_accuracies, learning_rates, save_dir=None):
    """Plot complete training results charts
    
    Args:
        train_losses: List of training losses
        val_losses: List of validation losses
        train_accuracies: List of training accuracies
        val_accuracies: List of validation accuracies
        learning_rates: List of learning rates
        save_dir: Directory path to save plots
    """
    import matplotlib.pyplot as plt
    import numpy as np
    
    # Create subplots
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
    
    # Calculate x-axis data
    epochs = range(1, len(val_losses) + 1)
    steps = range(1, len(train_losses) + 1)
    
    # 1. Training and Validation Loss (by epoch)
    ax1.plot(epochs, val_losses, 'r-', label='Validation Loss', linewidth=2, marker='o', markersize=4)
    
    # Calculate average training loss per epoch
    steps_per_epoch = len(train_losses) // len(val_losses)
    epoch_train_losses = []
    for i in range(len(val_losses)):
        start_idx = i * steps_per_epoch
        end_idx = min((i + 1) * steps_per_epoch, len(train_losses))
        epoch_avg_loss = np.mean(train_losses[start_idx:end_idx])
        epoch_train_losses.append(epoch_avg_loss)
    
    ax1.plot(epochs, epoch_train_losses, 'b-', label='Training Loss', linewidth=2, marker='s', markersize=4)
    ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Mark best validation loss
    best_val_loss_idx = np.argmin(val_losses)
    best_val_loss = val_losses[best_val_loss_idx]
    ax1.axvline(x=best_val_loss_idx + 1, color='gray', linestyle='--', alpha=0.7)
    ax1.text(best_val_loss_idx + 1, best_val_loss, f'Best\n{best_val_loss:.4f}', 
             ha='center', va='bottom', fontsize=10,
             bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))
    
    # 2. Training and Validation Accuracy (by epoch)
    ax2.plot(epochs, val_accuracies, 'r-', label='Validation Accuracy', linewidth=2, marker='o', markersize=4)
    
    # Calculate average training accuracy per epoch
    epoch_train_accuracies = []
    for i in range(len(val_accuracies)):
        start_idx = i * steps_per_epoch
        end_idx = min((i + 1) * steps_per_epoch, len(train_accuracies))
        epoch_avg_acc = np.mean(train_accuracies[start_idx:end_idx])
        epoch_train_accuracies.append(epoch_avg_acc)
    
    ax2.plot(epochs, epoch_train_accuracies, 'b-', label='Training Accuracy', linewidth=2, marker='s', markersize=4)
    ax2.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # Mark best validation accuracy
    best_val_acc_idx = np.argmax(val_accuracies)
    best_val_acc = val_accuracies[best_val_acc_idx]
    ax2.axvline(x=best_val_acc_idx + 1, color='gray', linestyle='--', alpha=0.7)
    ax2.text(best_val_acc_idx + 1, best_val_acc, f'Best\n{best_val_acc:.4f}', 
             ha='center', va='bottom', fontsize=10,
             bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))
    
    # 3. Learning Rate Schedule (by step)
    ax3.plot(steps, learning_rates, 'g-', linewidth=2)
    ax3.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
    ax3.set_xlabel('Step')
    ax3.set_ylabel('Learning Rate')
    ax3.set_yscale('log')
    ax3.grid(True, alpha=0.3)
    
    # 4. Detailed Training Process (by step)
    # Apply moving average to smooth training curves
    window_size = min(100, len(train_losses) // 10)
    if window_size > 1:
        smooth_train_losses = np.convolve(train_losses, np.ones(window_size)/window_size, mode='valid')
        smooth_train_accuracies = np.convolve(train_accuracies, np.ones(window_size)/window_size, mode='valid')
        smooth_steps = steps[window_size-1:]
        
        ax4.plot(smooth_steps, smooth_train_losses, 'b-', alpha=0.7, label='Training Loss (smoothed)')
        ax4_twin = ax4.twinx()
        ax4_twin.plot(smooth_steps, smooth_train_accuracies, 'r-', alpha=0.7, label='Training Accuracy (smoothed)')
    else:
        ax4.plot(steps, train_losses, 'b-', alpha=0.7, label='Training Loss')
        ax4_twin = ax4.twinx()
        ax4_twin.plot(steps, train_accuracies, 'r-', alpha=0.7, label='Training Accuracy')
    
    ax4.set_xlabel('Step')
    ax4.set_ylabel('Loss', color='blue')
    ax4_twin.set_ylabel('Accuracy', color='red')
    ax4.set_title('Training Process Details', fontsize=14, fontweight='bold')
    
    # Combine legends
    lines1, labels1 = ax4.get_legend_handles_labels()
    lines2, labels2 = ax4_twin.get_legend_handles_labels()
    ax4.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
    
    ax4.grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Save plot
    if save_dir:
        plot_path = os.path.join(save_dir, "training_analysis.png")
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"Training analysis plot saved: {plot_path}")
    
    plt.show()
    
    return fig


In [None]:
print("Plotting training curves...")
import numpy as np


# Prefer the history saved on disk; if the file is absent, the train/val lists
# already in scope are plotted as-is.
log_path = os.path.join(log_dir, "training_history.pth")
if os.path.exists(log_path):
    training_history = torch.load(log_path)
    train_losses = training_history["train_losses"]
    val_losses = training_history["val_losses"]
    train_accuracies = training_history["train_accuracies"]
    val_accuracies = training_history["val_accuracies"]
    learning_rates = training_history["learning_rates"]

# Full 2x2 analysis figure (also written to log_dir).
plot_training_results(
    train_losses,
    val_losses,
    train_accuracies,
    val_accuracies,
    learning_rates,
    save_dir=log_dir,
)

# Simplified per-epoch view: collapse each epoch's per-step training values into
# their mean so they line up with the per-epoch validation series.
# NOTE(review): plot_training_curves is not defined in this part of the
# notebook — confirm an earlier cell defines it.
steps_per_epoch = len(train_losses) // len(val_losses)
plot_training_curves(
    train_losses=[
        np.mean(train_losses[ep * steps_per_epoch:(ep + 1) * steps_per_epoch])
        for ep in range(len(val_losses))
    ],
    val_losses=val_losses,
    train_accuracies=[
        np.mean(train_accuracies[ep * steps_per_epoch:(ep + 1) * steps_per_epoch])
        for ep in range(len(val_accuracies))
    ],
    val_accuracies=val_accuracies,
    save_path=os.path.join(log_dir, "training_curves.png"),
)

print("All plots generated successfully!")