In [None]:
!pip install -qU audiomentations

In [None]:
import functools
import math
import os
import random

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from tqdm import tqdm

In [None]:
# Constants
# Target sample rate: every clip and noise chunk is resampled to this on load.
SAMPLE_RATE = 16_000
# Fixed clip duration in seconds; clips are cropped/padded to this length.
AUDIO_WINDOW = 1.0
AUDIO_LENGTH = int(SAMPLE_RATE * AUDIO_WINDOW)  # samples per clip
# Standardisation applied to log-mel values in build_mel_spectrogram.
# NOTE(review): presumably the mean/std of log-mel energies measured on the
# training data — TODO confirm where these numbers came from.
LOG_MEL_MEAN = 1.4
LOG_MEL_STD = 1.184

In [None]:
# Utility functions
def random_crop(x: np.ndarray, length=AUDIO_LENGTH) -> np.ndarray:
    """Return a random contiguous window of ``length`` samples from ``x``.

    ``x`` must be strictly longer than ``length`` (checked with ``assert``).
    The start offset is drawn uniformly from every valid position, both ends
    inclusive, using the stdlib ``random`` module.
    """
    total = x.shape[0]
    assert total > length
    start = random.randint(0, total - length)
    stop = start + length
    return x[start:stop]

def add_padding(x: np.ndarray, length=AUDIO_LENGTH) -> np.ndarray:
    """Zero-pad ``x`` up to exactly ``length`` samples at a random offset.

    ``x`` must be strictly shorter than ``length``. The number of zeros placed
    before the signal is drawn uniformly from ``[0, length - len(x)]``; the rest
    go after it.
    """
    assert x.shape[0] < length
    missing = length - x.shape[0]
    left = random.randint(0, missing)
    right = missing - left
    padded = np.pad(x, (left, right))
    assert padded.shape[0] == length, f"Error: Padded audio shape is {padded.shape}, expected {length}"
    return padded

def remove_existing_padding(x: np.ndarray) -> np.ndarray:
    """Strip leading and trailing exact-zero samples from ``x``.

    Interior zeros are kept. If ``x`` contains no non-zero sample (including
    the empty case) the input object itself is returned unchanged.
    """
    nonzero_idx = np.nonzero(x)[0]
    if nonzero_idx.size == 0:
        return x
    first, last = nonzero_idx[0], nonzero_idx[-1]
    return x[first:last + 1]

def fix_padding_issues(x: np.ndarray, length=AUDIO_LENGTH) -> np.ndarray:
    """Bring ``x`` to exactly ``length`` samples.

    Existing leading/trailing zero padding is stripped first. The trimmed signal
    is then randomly cropped if too long, randomly zero-padded if too short, or
    returned as-is if it already has the right length.
    """
    trimmed = remove_existing_padding(x)
    size = trimmed.shape[0]
    if size == length:
        return trimmed
    if size > length:
        return random_crop(trimmed, length=length)
    return add_padding(trimmed, length=length)

def add_noise(x: np.ndarray, noise: np.ndarray, noise_factor=0.4) -> np.ndarray:
    """Mix ``noise`` into ``x`` and peak-normalise the result.

    Both signals are first scaled to a peak absolute amplitude of 1. They are
    then blended as ``(1 - noise_factor) * x + noise_factor * noise``, and the
    mixture is scaled back to a peak absolute amplitude of 1.

    Args:
        x: Voice signal.
        noise: Noise signal with the same number of samples as ``x``.
        noise_factor: Weight of the noise in the blend.

    Returns:
        The normalised mixture. Silent inputs or mixtures are left unscaled
        rather than divided by zero.

    Raises:
        AssertionError: If ``x`` and ``noise`` differ in length.
    """
    assert x.shape[0] == noise.shape[0]

    def _peak_normalize(signal):
        # Scale by the absolute peak. Audio is signed, so signal.max() can be
        # tiny or negative, which would blow up or sign-flip the waveform.
        # Skip scaling for silent/empty signals instead of dividing by zero.
        peak = np.max(np.abs(signal)) if signal.size else 0.0
        return signal / peak if peak > 0 else signal

    mixed = (1 - noise_factor) * _peak_normalize(x) + noise_factor * _peak_normalize(noise)
    return _peak_normalize(mixed)

In [None]:
# Audio augmentation
from audiomentations import Compose, TimeStretch, PitchShift, Shift

# Each transform is applied independently with probability 0.4.
augmentation_pipeline = Compose([
    # Tempo change by a random rate in [0.8, 1.25].
    TimeStretch(p=0.4, min_rate=0.8, max_rate=1.25),
    # Pitch change by up to 4 semitones either way.
    PitchShift(p=0.4, min_semitones=-4, max_semitones=4),
    # Time shift in [-0.5, 0.5]. NOTE(review): the unit is whatever Shift's
    # default shift_unit is in the installed audiomentations — confirm.
    Shift(p=0.4, min_shift=-0.5, max_shift=0.5),
])

In [None]:
# Mel spectrogram calculation
@functools.lru_cache(maxsize=None)
def _mel_spectrogram_transform():
    """Build the MelSpectrogram transform once and reuse it for every call.

    NOTE: the transform captures SAMPLE_RATE at first use. Re-run this cell if
    the constants cell changes.
    """
    return torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=500,
        win_length=400,
        hop_length=160,
        n_mels=64,
        f_min=50,
        f_max=8000,
    )


def build_mel_spectrogram(waveform):
    """Return a standardised log-mel spectrogram of ``waveform`` with a leading channel dim.

    The mel power spectrogram (64 mel bands, 50-8000 Hz) is log-compressed with
    a 1e-9 floor, standardised with LOG_MEL_MEAN / LOG_MEL_STD, and given a
    leading channel axis so a single clip can be fed to a 2-D CNN.

    Args:
        waveform: Float tensor of audio samples at SAMPLE_RATE.
            NOTE(review): callers pass a 1-D (time,) tensor — confirm before
            passing batched input.

    Returns:
        Tensor of shape ``(1, *mel_shape)`` where ``mel_shape`` is the
        MelSpectrogram output shape.
    """
    # The transform (and its mel filterbank) is built once instead of per call.
    mel = _mel_spectrogram_transform()(waveform)
    log_mel_spectrogram = (torch.log(mel + 1e-9) - LOG_MEL_MEAN) / LOG_MEL_STD
    return log_mel_spectrogram.unsqueeze(0)  # channel dim

In [None]:
# Custom Dataset
class CustomDataset(Dataset):
    """Pairs of noisy log-mel spectrograms for Siamese keyword matching.

    Each word is a sub-directory of ``dataset_path`` holding audio clips. Each
    noise type is a sub-directory of ``chunked_noise_path`` holding noise
    chunks. Hidden entries (names starting with ``.``) are ignored.

    Item ``2*k`` is a positive pair: two random clips of word ``k``, label 1.0.
    Item ``2*k + 1`` is a negative pair: word ``k`` vs. word ``k + 1``
    (wrapping around), label 0.0. Each side is cropped/padded, augmented, mixed
    with a clip from a different noise type, and converted to a log-mel
    spectrogram.

    Args:
        chunked_noise_path: Root directory with one sub-directory per noise type.
        dataset_path: Root directory with one sub-directory per word.
        training: If True use the first ``train_fraction`` of the sorted words,
            otherwise the remaining words.
        max_noise_factor: Upper bound of the uniformly drawn noise mix weight.
        min_noise_factor: Lower bound of the uniformly drawn noise mix weight.
        train_fraction: Fraction of words assigned to the training split.

    Raises:
        ValueError: If the selected split has fewer than two words, because
            negative pairs would then compare a word with itself.
    """

    def __init__(self, chunked_noise_path, dataset_path, training=True, max_noise_factor=0.3,
                 min_noise_factor=0.1, train_fraction=0.9):
        self.chunked_noise_path = chunked_noise_path
        self.dataset_path = dataset_path
        # NOTE(review): an earlier comment listed "0.2, 0.05" next to these
        # arguments — possibly alternative noise-factor bounds; confirm.
        self.max_noise_factor = max_noise_factor
        self.min_noise_factor = min_noise_factor

        # Sorted so the train/val split is deterministic and the two splits are
        # disjoint: os.listdir order is arbitrary, and the train and val datasets
        # each list the directory on their own.
        all_words = self._visible_subdirs(dataset_path)
        self.noise_types = self._visible_subdirs(chunked_noise_path)

        split = int(train_fraction * len(all_words))
        self.words = all_words[:split] if training else all_words[split:]
        if len(self.words) < 2:
            raise ValueError(
                f"{'training' if training else 'validation'} split has {len(self.words)} word(s); "
                "at least 2 are needed to build negative pairs"
            )

        # List each directory once here instead of on every __getitem__ call.
        self.word_files = {
            word: self._visible_files(os.path.join(dataset_path, word)) for word in self.words
        }
        self.noise_files = {
            noise_type: self._visible_files(os.path.join(chunked_noise_path, noise_type))
            for noise_type in self.noise_types
        }

        self.n = 2 * len(self.words)  # Positive and negative pairs

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        word_idx = index // 2
        is_positive = index % 2 == 0
        word1 = self.words[word_idx]
        # Odd indices pair a word with its successor (wrapping) -> negative pair.
        word2 = word1 if is_positive else self.words[(word_idx + 1) % len(self.words)]

        # Two distinct noise types, so the two sides are corrupted differently.
        # Raises ValueError if fewer than two noise types exist.
        noise_type1, noise_type2 = random.sample(self.noise_types, 2)

        spectrogram1 = self._noisy_spectrogram(word1, noise_type1)
        spectrogram2 = self._noisy_spectrogram(word2, noise_type2)

        # Label: 1.0 for positive pairs, 0.0 for negative pairs
        label = 1.0 if is_positive else 0.0

        return spectrogram1, spectrogram2, torch.tensor(label, dtype=torch.float32)

    def _noisy_spectrogram(self, word, noise_type):
        """Load a random clip of ``word``, augment it, mix in ``noise_type`` noise, and return its log-mel spectrogram."""
        clip = random.choice(self.word_files[word])
        voice, _ = librosa.load(os.path.join(self.dataset_path, word, clip), sr=SAMPLE_RATE)
        voice = fix_padding_issues(voice)
        voice = augmentation_pipeline(samples=voice, sample_rate=SAMPLE_RATE)

        # NOTE(review): add_noise requires equal lengths, so noise chunks are
        # assumed to be exactly AUDIO_LENGTH samples after resampling — confirm.
        noise_clip = random.choice(self.noise_files[noise_type])
        noise, _ = librosa.load(os.path.join(self.chunked_noise_path, noise_type, noise_clip), sr=SAMPLE_RATE)

        noise_factor = random.uniform(self.min_noise_factor, self.max_noise_factor)
        noisy = add_noise(voice, noise, noise_factor)
        return build_mel_spectrogram(torch.tensor(noisy))

    @staticmethod
    def _visible_subdirs(root):
        """Sorted names of the non-hidden sub-directories of ``root``."""
        return sorted(
            name for name in os.listdir(root)
            if not name.startswith('.') and os.path.isdir(os.path.join(root, name))
        )

    @staticmethod
    def _visible_files(directory):
        """Sorted names of the non-hidden entries of ``directory``."""
        return sorted(name for name in os.listdir(directory) if not name.startswith('.'))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
# Let cuDNN benchmark convolution algorithms and cache the fastest one per input shape.
# NOTE(review): autotuning can make GPU results non-deterministic between runs;
# set to False if bitwise reproducibility matters.
torch.backends.cudnn.benchmark = True

# SwiGLU Activation
class SwiGLU(nn.Module):
    """Parameter-free SwiGLU gate.

    Splits the last dimension in half into ``value`` and ``gate`` and returns
    ``value * silu(gate)``.

    NOTE(review): a later cell redefines ``SwiGLU`` with a learned projection
    (``SwiGLU(input_dim, output_dim)``); after Run All that version shadows this one.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.silu(gate)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class SwiGLU(nn.Module):
    """Learned SwiGLU block: project to ``2 * output_dim`` features, then gate.

    The projection is split in half along the last dimension into ``value`` and
    ``gate``; the output is ``value * silu(gate)`` with ``output_dim`` features.

    Args:
        input_dim: Size of the input's last dimension.
        output_dim: Size of the output's last dimension.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The attribute name is part of the state_dict keys ("linear1.*"), so it
        # must stay stable for saved checkpoints to load.
        self.linear1 = nn.Linear(input_dim, 2 * output_dim)

    def forward(self, x):
        value, gate = torch.chunk(self.linear1(x), 2, dim=-1)
        return value * F.silu(gate)

# Siamese Network - ResNet-101 
class SiameseNetworkResNet101(nn.Module):
    """Siamese encoder built on an ImageNet-pretrained ResNet-101 trunk.

    Both inputs go through the same single-channel ResNet-101 (classification
    head removed). The output passes through an MLP head (Linear/LayerNorm/
    Dropout, a SwiGLU block, two more Linear/LayerNorm/Dropout stages) and is
    L2-normalised to a 256-d embedding. ``forward`` returns the pairwise
    Euclidean distance between the two embeddings.

    Module attribute names are kept as-is because they form the state_dict keys
    of saved checkpoints.
    """

    def __init__(self):
        super(SiameseNetworkResNet101, self).__init__()

        # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in
        # favour of `weights=...`; kept here for compatibility with older releases.
        self.base_network = models.resnet101(pretrained=True)

        # The input is a 1-channel spectrogram, so the 3-channel RGB stem is
        # replaced. Seed the new conv with the pretrained kernels summed over the
        # RGB axis rather than a random init, so the pretrained first layer is
        # not thrown away.
        rgb_conv1 = self.base_network.conv1
        mono_conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            mono_conv1.weight.copy_(rgb_conv1.weight.sum(dim=1, keepdim=True))
        self.base_network.conv1 = mono_conv1

        # Drop the final fully connected classifier; keep everything up to avgpool.
        modules = list(self.base_network.children())[:-1]
        self.base_network = nn.Sequential(*modules)

        self.fc1 = nn.Linear(2048, 2048)
        self.ln1 = nn.LayerNorm(2048)
        self.dropout1 = nn.Dropout(p=0.3)

        self.swiglu = SwiGLU(2048, 1024)
        self.ln_swiglu = nn.LayerNorm(1024)

        self.fc2 = nn.Linear(1024, 512)
        self.ln2 = nn.LayerNorm(512)
        self.dropout2 = nn.Dropout(p=0.3)

        self.fc3 = nn.Linear(512, 256)
        self.ln3 = nn.LayerNorm(256)
        self.dropout3 = nn.Dropout(p=0.3)
        self.l2_norm = nn.functional.normalize

    def forward_one(self, x):
        """Embed one batch of spectrograms into unit-norm 256-d vectors."""
        x = self.base_network(x)
        x = torch.flatten(x, 1)  # (batch, 2048, 1, 1) pooled features -> (batch, 2048)

        x = F.relu(self.fc1(x))
        x = self.ln1(x)
        x = self.dropout1(x)
        x = self.swiglu(x)
        x = self.ln_swiglu(x)

        x = F.relu(self.fc2(x))
        x = self.ln2(x)
        x = self.dropout2(x)

        x = F.relu(self.fc3(x))
        x = self.ln3(x)
        x = self.dropout3(x)

        x = self.l2_norm(x, p=2, dim=1)
        return x

    def forward(self, input1, input2):
        """Return the per-pair Euclidean distance between the two embeddings."""
        output1 = self.forward_one(input1)
        output2 = self.forward_one(input2)
        return torch.pairwise_distance(output1, output2)


In [None]:
# Loss function
def triplet_loss(y_true, y_pred, margin=0.2, eps=1e-6):
    """Contrastive-style loss on pair distances (despite the name, it scores pairs, not triplets).

    With ``d = y_pred``:
      * matching pairs (``y_true == 1``) cost ``-2 * log(1 - d / 2)``, which
        grows as ``d`` approaches 2, the largest distance between unit vectors;
      * non-matching pairs (``y_true == 0``) cost ``max(0, -log(d / margin))``,
        which is positive only while ``d < margin``.

    Distances are clamped to ``[eps, 2 - eps]`` before taking logs. Without
    this, a pair at distance 0 or >= 2 makes the *other* branch ``0 * inf = NaN``,
    which poisons the batch mean and its gradients. ``pairwise_distance`` adds a
    small eps, so distances can slightly exceed 2.

    Args:
        y_true: Float tensor of labels (1.0 = match, 0.0 = mismatch).
        y_pred: Tensor of pair distances, same shape as ``y_true``.
        margin: Distance below which non-matching pairs are penalised.
        eps: Safety margin keeping both log arguments strictly positive.

    Returns:
        Scalar tensor: mean loss over the batch.
    """
    distance = y_pred.clamp(min=eps, max=2.0 - eps)
    match_loss = y_true * -2.0 * torch.log(1 - distance / 2)
    mismatch_loss = torch.clamp((1 - y_true) * (-torch.log(distance / margin)), min=0)
    return torch.mean(match_loss + mismatch_loss)

# Accuracy metric
def accuracy(y_true, y_pred, threshold=0.2):
    """Fraction of pairs whose thresholded distance agrees with the label.

    A pair is predicted as a match (1.0) when its distance ``y_pred`` is at
    most ``threshold``, otherwise as a mismatch (0.0). The prediction is
    compared with the float label ``y_true``.

    Args:
        y_true: Float tensor of labels (1.0 = match, 0.0 = mismatch).
        y_pred: Tensor of pair distances, same shape as ``y_true``.
        threshold: Decision distance; defaults to the same 0.2 used as the
            loss margin.

    Returns:
        0-dim float tensor with the mean accuracy.
    """
    predicted_match = (y_pred <= threshold).float()
    return (y_true == predicted_match).float().mean()

In [None]:
# Training loop
def train(model, train_loader, val_loader, num_epochs, device, lr=1e-3, save_every=5):
    """Train a Siamese ``model`` on pair batches and validate after every epoch.

    Uses Adam with ``ReduceLROnPlateau`` driven by the validation loss (factor
    0.1, patience 1, floor 1e-5). Each batch is ``(data1, data2, target)``; the
    model returns one distance per pair, scored with ``triplet_loss`` and
    ``accuracy``. Reported epoch metrics are means over batches.

    Args:
        model: Module called as ``model(data1, data2)``.
        train_loader: Iterable of training batches with a ``len``.
        val_loader: Loader passed to ``validate``.
        num_epochs: Number of passes over ``train_loader``.
        device: Device that batches are moved to.
        lr: Initial Adam learning rate.
        save_every: Save ``model.state_dict()`` to the working directory every
            this many epochs; ``0``/``None`` disables periodic checkpoints.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=1, factor=0.1, min_lr=1e-5)

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_acc = 0.0

        print(f'Epoch {epoch+1}/{num_epochs}:')

        for data1, data2, target in tqdm(train_loader, total=len(train_loader), desc="Training"):
            data1, data2, target = data1.to(device), data2.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data1, data2)
            loss = triplet_loss(target, output)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # float() keeps the running sum a plain Python number rather than a device tensor.
            train_acc += float(accuracy(target, output))

        train_loss /= len(train_loader)
        train_acc /= len(train_loader)

        val_loss, val_acc = validate(model, val_loader, device)
        val_acc = float(val_acc)

        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

        # Periodic checkpoint
        if save_every and (epoch + 1) % save_every == 0:
            model_filename = f"siamese_model_epoch_{epoch+1}_val_acc_{val_acc:.4f}.pth"
            torch.save(model.state_dict(), model_filename)
            print(f"Model saved as {model_filename}")

        scheduler.step(val_loss)

def validate(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` in eval mode without gradient tracking.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over batches. The loss is a
        Python float; the accuracy is the averaged value returned by
        ``accuracy``.
    """
    model.eval()
    loss_sum = 0
    acc_sum = 0

    with torch.no_grad(), tqdm(total=len(val_loader), desc="Validation") as bar:
        for batch1, batch2, labels in val_loader:
            batch1 = batch1.to(device)
            batch2 = batch2.to(device)
            labels = labels.to(device)

            distances = model(batch1, batch2)
            loss_sum += triplet_loss(labels, distances).item()
            acc_sum += accuracy(labels, distances)

            bar.update(1)

    n_batches = len(val_loader)
    return loss_sum / n_batches, acc_sum / n_batches

In [None]:
# Main 
if __name__ == "__main__":
    SEED = 42
    chunked_noise_path = r"" # path to noise (chunked)
    dataset_path = r""       # path to dataset

    # Fail fast with a clear message instead of os.listdir("") raising deep inside CustomDataset.
    if not chunked_noise_path or not dataset_path:
        raise ValueError("Set chunked_noise_path and dataset_path before running training.")

    # Seed the Python, NumPy and torch RNGs used for pair/noise sampling and weight
    # init so runs are repeatable. NOTE(review): audiomentations presumably draws
    # from these too; cudnn.benchmark can still introduce GPU non-determinism.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    train_dataset = CustomDataset(chunked_noise_path, dataset_path, training=True)
    val_dataset = CustomDataset(chunked_noise_path, dataset_path, training=False)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SiameseNetworkResNet101().to(device)

    num_epochs = 50
    train(model, train_loader, val_loader, num_epochs, device)

    torch.save(model.state_dict(), "siamese_model.pth")