In [1]:
import torch
import yaml
import wespeaker.models.resnet  # Import the ResNet model from wespeaker.models module

# Load the pre-trained model configuration.
with open("voxceleb_resnet152_LM/voxceleb_resnet152_LM.yaml", "r") as f:
    config = yaml.safe_load(f)

# Build the ResNet-152 backbone with the architecture arguments from the config.
model = wespeaker.models.resnet.ResNet152(**config["model_args"])

# Load the checkpoint onto the CPU so this cell also works on machines without the
# device the checkpoint was saved from.
pretrained_dict = torch.load("voxceleb_resnet152_LM/voxceleb_resnet152_LM.pt", map_location="cpu")
model_dict = model.state_dict()

# Keep only entries whose name AND shape match the freshly built model:
# load_state_dict(strict=False) tolerates missing/unexpected keys, but it still raises
# on size mismatches.
pretrained_dict = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}

# Update model weights; the returned report (displayed as the cell output) lists any
# model keys that were not loaded.
model.load_state_dict(pretrained_dict, strict=False)




<All keys matched successfully>

In [2]:
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim

class AudioDataset(Dataset):
    """Labelled ``.wav`` clips gathered from a fixed set of augmentation subfolders.

    Every ``.wav`` file directly inside ``root_dir/<subfolder>`` is indexed, for each
    name in ``SUBFOLDERS``. The class label comes from the file name (see
    ``extract_label_from_filename``). Labels are mapped to contiguous integer indices
    in the order they are first encountered. Subfolders are visited in ``SUBFOLDERS``
    order and files in sorted order, so the mapping is reproducible across runs.

    Args:
        root_dir: Directory containing the augmentation subfolders.
        target_length: Number of samples each returned waveform is padded (with zeros)
            or truncated to. The length is enforced *after* ``transform``, so it is
            measured in samples of the transformed signal (e.g. at the resampled rate).
        transform: Optional callable ``transform(waveform, sample_rate) -> waveform``,
            applied to the mono ``(1, frames)`` waveform.
    """

    SUBFOLDERS = ('custom_noised', 'gauss', 'multiplied_amplitude', 'subsampled')

    def __init__(self, root_dir, target_length=32000, transform=None):
        self.files = []
        self.labels = []
        self.label_to_index = {}
        self.index_to_label = {}
        self.target_length = target_length
        self.transform = transform

        for folder in self.SUBFOLDERS:
            folder_path = os.path.join(root_dir, folder)
            # os.listdir returns entries in arbitrary order; sort so that label indices
            # (and therefore the classifier's output units) are reproducible.
            for filename in sorted(os.listdir(folder_path)):
                if not filename.endswith('.wav'):
                    continue
                label = self.extract_label_from_filename(filename)
                if label not in self.label_to_index:
                    index = len(self.label_to_index)
                    self.label_to_index[label] = index
                    self.index_to_label[index] = label
                self.files.append(os.path.join(folder_path, filename))
                self.labels.append(self.label_to_index[label])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(waveform, label_index)`` with ``waveform`` of shape ``(1, target_length)``."""
        waveform, sample_rate = torchaudio.load(self.files[idx])
        label = self.labels[idx]

        # torchaudio.load returns (channels, frames); still guard against a 1-D tensor.
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        # Down-mix multi-channel audio to mono so every item has a single channel.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Transform (e.g. resample) first, then fix the length. Otherwise the final
        # length would depend on each file's original sample rate.
        if self.transform is not None:
            waveform = self.transform(waveform, sample_rate)

        return self._fix_length(waveform), label

    def _fix_length(self, waveform):
        """Truncate or right-pad with zeros along the last axis to ``self.target_length``."""
        length = waveform.shape[-1]
        if length > self.target_length:
            return waveform[..., :self.target_length]
        if length < self.target_length:
            # F.pad keeps the waveform's dtype/device and works for any leading shape.
            return torch.nn.functional.pad(waveform, (0, self.target_length - length))
        return waveform

    @staticmethod
    def extract_label_from_filename(filename):
        """Join the first two ``_``-separated parts with a space (``'john_doe_03.wav'`` -> ``'john doe'``)."""
        parts = filename.split('_')
        return ' '.join(parts[:2])

In [3]:
def audio_transform(waveform, sample_rate, target_sample_rate=16000):
    """Resample ``waveform`` from ``sample_rate`` to ``target_sample_rate``.

    The input tensor is returned unchanged (same object) when the rates already match.
    """
    if sample_rate == target_sample_rate:
        return waveform
    to_target = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
    return to_target(waveform)

In [4]:
def collate_fn(batch):
    """Stack variable-length ``(waveform, label)`` pairs into one batch.

    Each waveform is right-padded with zeros along its last (time) axis to the length
    of the longest waveform in the batch. The result has shape
    ``(batch, channels, max_length)``. Labels are returned as a tensor in the same order.
    All waveforms in a batch must have the same number of channels.
    """
    waveforms, labels = zip(*batch)
    max_length = max(waveform.shape[-1] for waveform in waveforms)
    # F.pad keeps each waveform's dtype/device and works for any channel count,
    # unlike concatenating a hard-coded (1, n) float32 zeros tensor.
    waveforms_padded = torch.stack([
        torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[-1]))
        for waveform in waveforms
    ])
    return waveforms_padded, torch.tensor(labels)

In [5]:
# Seed for the DataLoader's shuffling so the batch order is reproducible across runs.
SEED = 42

dataset = AudioDataset(root_dir='data_noise_train', transform=audio_transform)
data_loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(SEED),
)

criterion = nn.CrossEntropyLoss()
# NOTE: Adam only updates the parameters that exist in `model` at this point. If `model`
# is later wrapped or replaced (e.g. with a new classification head), rebuild the
# optimizer afterwards or the new layers will never be trained.
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [8]:
import torch.nn as nn

class CustomModel(nn.Module):
    """Classification head wrapped around ``original_model``.

    ``forward`` average-pools its input with ``nn.AdaptiveAvgPool2d((1, output_features))``,
    flattens the result per sample, and applies a linear classifier.

    NOTE(review): ``original_model`` is registered as a submodule, so its parameters
    appear in ``parameters()`` / ``state_dict()``. However, ``forward`` never calls it,
    so the pre-trained backbone currently has no effect on the predictions. Confirm
    what input format it expects and route the input through it before pooling if
    that is intended.

    Args:
        original_model: Module stored as ``self.original_model``.
        output_features: Width the input is pooled to along its last axis; also the
            classifier's input size.
        num_classes: Number of output classes. If ``None``, falls back to
            ``len(dataset.label_to_index)`` using the notebook-level ``dataset``.
    """

    def __init__(self, original_model, output_features, num_classes=None):
        super().__init__()
        if num_classes is None:
            # Backward-compatible fallback to the notebook-global dataset.
            num_classes = len(dataset.label_to_index)
        self.original_model = original_model
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, output_features))
        self.classifier = nn.Linear(output_features, num_classes)

    def forward(self, x):
        x = self.adaptive_pool(x)  # last two axes -> (1, output_features)
        x = torch.flatten(x, 1)    # one feature vector per sample
        return self.classifier(x)

# Wrap the backbone with the pooling + classification head. The isinstance guard keeps
# a re-run of this cell from wrapping an already-wrapped model a second time.
if not isinstance(model, CustomModel):
    model = CustomModel(model, 100)  # '100' should be adjusted based on your needs

# Any optimizer created before this point only tracks the pre-wrap parameters, so it
# would never update the new classifier. Rebuild it over the wrapped model's parameters.
optimizer = optim.Adam(model.parameters(), lr=0.01)


In [9]:
def train_model(model, train_loader, criterion, optimizer, num_epochs=100, patience=50,
                checkpoint_path='best_model.pth'):
    """Train ``model`` with early stopping on the mean training loss.

    Each epoch makes one pass over ``train_loader``. Whenever the epoch's mean loss
    beats the best seen so far, the weights are written to ``checkpoint_path``.
    Training ends after ``num_epochs`` epochs, or earlier once the loss has not
    improved for ``patience`` consecutive epochs. If a checkpoint was written, the best
    weights are loaded back into ``model`` before returning.

    Args:
        model: Module to train; called on each waveform batch with an extra axis
            inserted at dim 1.
        train_loader: Iterable of ``(waveforms, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the parameters of ``model`` to be trained.
        num_epochs: Maximum number of epochs.
        patience: Consecutive epochs without improvement tolerated before stopping.
        checkpoint_path: File the best weights are saved to and restored from.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    best_loss = float('inf')
    epochs_no_improve = 0
    saved_checkpoint = False

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0

        for waveforms, labels in train_loader:
            optimizer.zero_grad()
            # Insert an axis at dim 1: with collate_fn's (batch, 1, length) output this
            # gives (batch, 1, 1, length).
            waveforms = waveforms.unsqueeze(1)
            outputs = model(waveforms)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches; cannot compute the training loss.")

        average_loss = total_loss / num_batches
        print(f'Epoch {epoch+1}: Training Loss: {average_loss:.4f}')

        # Early stopping logic based on training loss
        if average_loss < best_loss:
            best_loss = average_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), checkpoint_path)  # Save the best model
            saved_checkpoint = True
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early stopping triggered.")
                break

    # Restore the best weights from the same file they were saved to. Skip this if no
    # epoch ever improved (e.g. NaN loss), since then no checkpoint exists.
    if saved_checkpoint:
        model.load_state_dict(torch.load(checkpoint_path))

# Run training with the default epoch/patience settings; per-epoch losses are printed.
train_model(model, train_loader=data_loader, criterion=criterion, optimizer=optimizer)


KeyboardInterrupt: 