In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, ConcatDataset, DataLoader
import json
import os

# Load and process data
def load_utterance(feature_path):
    """Load one serialized utterance feature object from ``Dataset/uttr``.

    Args:
        feature_path: File name (or relative path) of the saved object,
            interpreted relative to the ``Dataset/uttr`` directory, which is
            itself relative to the current working directory.

    Returns:
        Whatever object ``torch.load`` deserializes from that file.

    Note:
        ``torch.load`` may unpickle arbitrary objects depending on the
        installed PyTorch version, so only point this at trusted files.
    """
    full_path = os.path.join('Dataset/uttr', feature_path)
    return torch.load(full_path)

class SpeakerDataset(Dataset):
    """Fixed-length utterance dataset for speaker classification.

    Expected on-disk layout (all paths resolved under ``data_dir``)::

        Dataset/metadata.json   {"speakers": {speaker_name: [{"feature_path": ..., "mel_len": ...}, ...]}}
        Dataset/mapping.json    {"speaker2id": {speaker_name: int}}
        Dataset/uttr/<feature_path>   one ``torch.save``-d feature object per utterance

    Utterances shorter than ``segment_len`` frames are skipped (and counted in
    ``ignore_num``) so that every returned item has the same leading length
    and can be stacked by the default ``DataLoader`` collate function.

    Args:
        data_dir: Directory that contains the ``Dataset`` folder. The default
            ``''`` resolves everything relative to the current working
            directory, i.e. ``Dataset/...``.
        segment_len: Number of leading frames kept from every utterance.
            Utterances with ``mel_len < segment_len`` are ignored.

    Raises:
        ValueError: If ``segment_len`` is smaller than 1.
    """

    def __init__(self, data_dir='', segment_len=200):
        if segment_len < 1:
            raise ValueError(f'segment_len must be >= 1, got {segment_len}')

        self.data_dir = data_dir
        self.segment_len = segment_len
        # ``os.path.join('', p) == p``, so the default keeps the historical
        # 'Dataset/...' paths; a falsy data_dir (e.g. None) is treated the same.
        self._root = data_dir or ''

        mapping_path = os.path.join(self._root, 'Dataset/mapping.json')
        metadata_path = os.path.join(self._root, 'Dataset/metadata.json')

        # Load the per-speaker utterance metadata.
        with open(metadata_path, 'r', encoding='utf-8') as f:
            self.metadata = json.load(f)

        # Load the speaker-name -> integer-label mapping.
        with open(mapping_path, 'r', encoding='utf-8') as f:
            self.mapping = json.load(f)

        # Build the flat sample list.
        self.samples = []
        self.ignore_num = 0  # utterances skipped for being too short
        for speaker_id, utterances in self.metadata['speakers'].items():
            speaker_idx = self.mapping['speaker2id'][speaker_id]  # numeric speaker label
            for utterance in utterances:
                mel_len = utterance['mel_len']
                # ``>=`` (not ``>``): an utterance of exactly ``segment_len``
                # frames still yields a full-length crop in __getitem__.
                if mel_len >= segment_len:
                    self.samples.append({
                        'feature_path': utterance['feature_path'],
                        'speaker_id': speaker_idx,
                        'mel_len': mel_len
                    })
                else:
                    self.ignore_num += 1

    def __len__(self):
        """Return the number of utterances that passed the length filter."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(speaker_id, mel_len, features)`` for sample ``idx``.

        ``mel_len`` is the length recorded in the metadata (before cropping);
        ``features`` is the loaded object cropped to its first ``segment_len``
        entries along dimension 0 (presumably a ``(frames, n_mels)`` tensor).
        """
        sample = self.samples[idx]
        # Load the feature file from <data_dir>/Dataset/uttr.
        # NOTE: torch.load may unpickle arbitrary objects; use trusted files only.
        feature_file = os.path.join(self._root, 'Dataset/uttr', sample['feature_path'])
        features = torch.load(feature_file)
        return sample['speaker_id'], sample['mel_len'], features[:self.segment_len]

# Build the training set from the default on-disk layout, then wrap it in a
# loader that reshuffles the samples each time it is iterated.
train_dataset = SpeakerDataset()
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=128)

# Self-attention + classification model
class SpeakerClassifier(nn.Module):
    """Single self-attention layer followed by a two-layer classifier head.

    Args:
        input_dim: Feature size of every input frame; also the attention
            embedding size, so it must be divisible by ``num_heads``.
        hidden_dim: Width of the hidden fully connected layer.
        num_heads: Number of attention heads.
        num_classes: Number of output classes (size of the logit vector).
    """

    def __init__(self, input_dim=40, hidden_dim=256, num_heads=4, num_classes=600):
        super().__init__()

        # Self-attention over the frame sequence (sequence-first layout).
        self.attention = nn.MultiheadAttention(input_dim, num_heads)

        # Classifier head: input_dim -> hidden_dim -> num_classes.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map a batch of utterances to unnormalized class scores.

        Args:
            x: Tensor laid out as ``(batch_size, seq_len, input_dim)``.

        Returns:
            Logits of shape ``(batch_size, num_classes)``. No softmax is
            applied here; ``nn.CrossEntropyLoss`` expects raw logits.
        """
        # The attention module was built with the default sequence-first
        # layout, so swap to (seq_len, batch_size, input_dim).
        frames_first = x.transpose(1, 0)

        # Self-attention: the same tensor serves as query, key and value.
        attended = self.attention(frames_first, frames_first, frames_first)[0]

        # Average over the sequence axis to get one vector per utterance.
        utterance_vec = attended.mean(dim=0)

        hidden = F.relu(self.fc1(utterance_vec))
        return self.fc2(hidden)



# Use the MPS backend when PyTorch reports it as available, else the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Model, loss and optimizer used by the training loop.
model = SpeakerClassifier()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), weight_decay=1e-5, lr=0.0003)

def train(model, train_loader, criterion, optimizer, epoch, num_epochs, log_interval=10):
    """Run one training epoch and periodically print the mean batch loss.

    Args:
        model: Module to optimize; batches are moved to the device of its
            first parameter before the forward pass.
        train_loader: Sized iterable yielding ``(speaker_idx, mel_len, features)``
            batches. ``mel_len`` is not used by the training step.
        criterion: Callable ``criterion(outputs, speaker_idx)`` returning a
            scalar loss tensor.
        optimizer: Optimizer holding ``model``'s parameters.
        epoch: Zero-based index of the current epoch (printed as ``epoch + 1``).
        num_epochs: Total number of epochs, used only in the log line.
        log_interval: Print the loss averaged over the last ``log_interval``
            batches every ``log_interval`` steps.

    Returns:
        float: Mean loss over all batches seen in this epoch, or ``0.0`` if
        the loader yielded no batches.

    Raises:
        ValueError: If ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError(f'log_interval must be >= 1, got {log_interval}')

    model.train()

    # Follow the model's own device rather than a notebook-level global, so
    # the function keeps working after cells are re-run or re-ordered.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    window_loss = 0.0   # sum of losses since the last log line
    epoch_loss = 0.0    # sum of losses over the whole epoch
    num_batches = 0
    for i, (speaker_idx, _mel_len, features) in enumerate(train_loader):
        features = features.to(target_device)
        speaker_idx = speaker_idx.to(target_device)

        optimizer.zero_grad()

        outputs = model(features)
        loss = criterion(outputs, speaker_idx)
        loss.backward()

        optimizer.step()

        batch_loss = loss.item()
        window_loss += batch_loss
        epoch_loss += batch_loss
        num_batches += 1

        if (i + 1) % log_interval == 0:
            # Divide by the number of batches actually accumulated in the
            # window so the printed value is a true per-batch average.
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {window_loss/log_interval:.4f}')
            window_loss = 0.0

    return epoch_loss / num_batches if num_batches else 0.0




Using device: mps


In [33]:
# Training loop: one pass over the training loader per epoch.
num_epochs = 100
min_accuracy = 0  # only referenced by the disabled validation code below

for epoch in range(num_epochs):
    train(
        model=model,
        train_loader=train_loader,
        criterion=criterion,
        optimizer=optimizer,
        epoch=epoch,
        num_epochs=num_epochs,
    )
    # Validation and checkpointing are currently disabled (neither
    # `validate` nor `val_loader` is defined in the cells shown here):
    # accuracy, loss = validate(model, val_loader, criterion)
    # if accuracy > min_accuracy:
    #     min_accuracy = accuracy
    #     torch.save(model.state_dict(), 'speaker_classifier.pth')