In [1]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

Looking in indexes: https://download.pytorch.org/whl/cu121


In [2]:
import torch
print(torch.cuda.is_available())

True


In [3]:
import librosa
import numpy as np
import random
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from audiomentations import Compose, TimeStretch, PitchShift, Shift
import torchaudio
import torchaudio.functional as TF
import torchvision
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import os

# --- Global Settings ---
SAMPLE_RATE = 16000
AUDIO_WINDOW = 1.0  # seconds
AUDIO_LENGTH = int(AUDIO_WINDOW * SAMPLE_RATE)
LOG_MEL_MEAN = 1.4
LOG_MEL_STD = 1.184

# --- Utility Functions ---
def randomCrop(x: np.array, length=AUDIO_LENGTH) -> np.array:
    assert(x.shape[0] > length)
    frontBits = random.randint(0, x.shape[0] - length)
    return x[frontBits:frontBits + length]

def addPadding(x: np.array, length=AUDIO_LENGTH) -> np.array:
    assert(x.shape[0] < length)
    bitCountToBeAdded = length - x.shape[0]
    frontBits = random.randint(0, bitCountToBeAdded)
    new_x = np.append(np.zeros(frontBits), x)
    new_x = np.append(new_x, np.zeros(bitCountToBeAdded - frontBits))
    assert new_x.shape[0] == length, f"Error: Padded audio shape is {new_x.shape}, expected {length}"
    return new_x

def removeExistingPadding(x: np.array) -> np.array:
    lastZeroBitBeforeAudio = 0
    firstZeroBitAfterAudio = len(x)
    for i in range(len(x)):
        if x[i] == 0:
            lastZeroBitBeforeAudio = i
        else:
            break
    for i in range(len(x) - 1, 1, -1):
        if x[i] == 0:
            firstZeroBitAfterAudio = i
        else:
            break
    return x[lastZeroBitBeforeAudio:firstZeroBitAfterAudio]

def fixPaddingIssues(x: np.array, length=AUDIO_LENGTH) -> np.array:
    x = removeExistingPadding(x)
    if x.shape[0] > length:
        return randomCrop(x, length=length)
    elif x.shape[0] < length:
        return addPadding(x, length=length)
    else:
        return x

def addNoise(x: np.array, noise: np.array, noise_factor=0.4) -> np.array:
    assert(x.shape[0] == noise.shape[0])
    out = (1 - noise_factor) * x / x.max() + noise_factor * (noise / noise.max())
    return out / out.max()

def splitNoiseFileToChunks(filename: str, target_folder: str, count=100, sr=16000):
    noiseAudio, _ = librosa.load(filename, sr=sr)
    if len(noiseAudio) <= AUDIO_LENGTH:
        print(f"Warning: Audio file {filename} is shorter than {AUDIO_LENGTH / SAMPLE_RATE} seconds. Skipping this file.")
        return

    for i in range(count):
        noiseAudioCrop = randomCrop(noiseAudio)
        outFilePath = target_folder + "/" + (f"{'.'.join(filename.split('.')[:-1])}_{i}.wav").split("/")[-1]
        sf.write(outFilePath, noiseAudioCrop, sr, 'PCM_24')

# --- Audio Augmentation ---
augmentation_pipeline = Compose([
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
    Shift(min_shift=-0.5, max_shift=0.5, p=0.5),
])

# --- Mel Spectrogram Calculation (Corrected) ---
def get_mel_spectrogram(waveform):
    """Calculates log-Mel spectrogram using torch.stft."""
    waveform = waveform.unsqueeze(1)

    print("Waveform shape:", waveform.shape)
    print("Waveform device:", waveform.device)

    # STFT parameters
    n_fft = 512
    hop_length = 160
    win_length = 512
    window = torch.hann_window(win_length)

    print("Window shape:", window.shape)
    print("Window device:", window.device)

    # Calculate STFT 
    stft_out = torch.stft(waveform,
                          n_fft=n_fft,
                          hop_length=hop_length,
                          win_length=win_length,
                          window=window.to(waveform.device),
                          center=False,
                          onesided=True,
                          return_complex=True)

    print("STFT output shape:", stft_out.shape)  # <--- Print the shape 
    print("STFT output device:", stft_out.device) # <--- Print the device

    # Magnitude spectrogram
    magnitude_spec = torch.abs(stft_out)


    # Magnitude spectrogram
    magnitude_spec = torch.abs(stft_out)

    # Mel filter bank parameters
    n_mels = 64
    f_min = 50
    f_max = 8000

    # Create Mel filter bank
    mel_filter_bank = TF.create_fb_matrix(
        n_freqs=n_fft // 2 + 1,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        sample_rate=SAMPLE_RATE
    )

    # Apply filter bank
    melspec = torch.matmul(magnitude_spec.transpose(1, 2), mel_filter_bank.to(magnitude_spec.device)).transpose(1, 2)

    # Log-scaling and normalization
    log_melspec = (torch.log(melspec + 1e-9) - LOG_MEL_MEAN) / LOG_MEL_STD

    # Replicate channels for ResNet compatibility
    log_melspec = log_melspec.repeat(1, 3, 1, 1) 

    return log_melspec

In [4]:
from torch.utils.data import Dataset , DataLoader

class WakeWordDataset(Dataset):
    def __init__(self, chunked_noise_path, dataset_path, transform=None,
                 training=True, max_noise_factor=0.2, min_noise_factor=0.05,
                 sampling_rate=16000, spectrogram=True, print_words=False):
        super(WakeWordDataset, self).__init__()
        self.chunked_noise_path = chunked_noise_path
        self.dataset_path = dataset_path
        self.transform = transform 
        self.training = training
        self.max_noise_factor = max_noise_factor
        self.min_noise_factor = min_noise_factor
        self.sampling_rate = sampling_rate
        self.spectrogram = spectrogram
        self.print_words = print_words

        self.types_of_noise = self.get_file_list(self.chunked_noise_path)
        self.words_in_dataset = self.get_file_list(self.dataset_path)

        if self.training:
            self.words_in_dataset = self.words_in_dataset[:int(0.9 * len(self.words_in_dataset))]
            print("Train size: ", len(self.words_in_dataset))
        else:
            self.words_in_dataset = self.words_in_dataset[int(0.9 * len(self.words_in_dataset)):]
            print("Test size: ", len(self.words_in_dataset))

        if self.spectrogram:
            self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
                sample_rate=sampling_rate,
                n_fft=512,
                win_length=512,
                hop_length=160,
                n_mels=32
            )

    def get_file_list(self, path):
        return [f for f in os.listdir(path) if not f.startswith('.')]

    def give_joined_audio(self, word1, word2):
        if self.print_words:
            print(word1, word2)

        sample1_path, sample2_path = self.get_audio_paths(word1, word2)
        voice_vector1, _ = librosa.load(sample1_path, sr=self.sampling_rate)
        voice_vector2, _ = librosa.load(sample2_path, sr=self.sampling_rate)

        voice_vector1 = fixPaddingIssues(voice_vector1) 
        voice_vector2 = fixPaddingIssues(voice_vector2)

        noise_vector1, noise_vector2 = self.load_noise()

        random_noise_factor1 = random.uniform(self.min_noise_factor, self.max_noise_factor)
        random_noise_factor2 = random.uniform(self.min_noise_factor, self.max_noise_factor)

        voice_with_noise1 = addNoise(voice_vector1, noise_vector1, random_noise_factor1)
        voice_with_noise2 = addNoise(voice_vector2, noise_vector2, random_noise_factor2)

        if self.spectrogram:
            voice_with_noise1 = torch.tensor(voice_with_noise1, dtype=torch.float32).unsqueeze(0)
            voice_with_noise2 = torch.tensor(voice_with_noise2, dtype=torch.float32).unsqueeze(0)

            voice_with_noise_spectrogram1 = self.mel_spectrogram(voice_with_noise1)
            voice_with_noise_spectrogram2 = self.mel_spectrogram(voice_with_noise2)
            
            # Convert to decibel scale
            voice_with_noise_spectrogram1 = torchaudio.transforms.AmplitudeToDB()(voice_with_noise_spectrogram1)
            voice_with_noise_spectrogram2 = torchaudio.transforms.AmplitudeToDB()(voice_with_noise_spectrogram2)
            
            return voice_with_noise_spectrogram1.squeeze(0), voice_with_noise_spectrogram2.squeeze(0)
        else:
            return voice_with_noise1, voice_with_noise2

    def get_audio_paths(self, word1, word2):
        if word1 == word2:
            sample1, sample2 = random.sample(self.get_file_list(os.path.join(self.dataset_path, word1)), 2)
            return (os.path.join(self.dataset_path, word1, sample1),
                    os.path.join(self.dataset_path, word2, sample2))
        else:
            sample1 = random.choice(self.get_file_list(os.path.join(self.dataset_path, word1)))
            sample2 = random.choice(self.get_file_list(os.path.join(self.dataset_path, word2)))
            return (os.path.join(self.dataset_path, word1, sample1),
                    os.path.join(self.dataset_path, word2, sample2))
    def load_noise(self):
        noise_type1, noise_type2 = random.sample(self.types_of_noise, 2)
        noise_file1 = random.choice(self.get_file_list(os.path.join(self.chunked_noise_path, noise_type1)))
        noise_file2 = random.choice(self.get_file_list(os.path.join(self.chunked_noise_path, noise_type2)))

        noise_vector1, _ = librosa.load(os.path.join(self.chunked_noise_path, noise_type1, noise_file1),
                                        sr=self.sampling_rate)
        noise_vector2, _ = librosa.load(os.path.join(self.chunked_noise_path, noise_type2, noise_file2),
                                        sr=self.sampling_rate)
        return noise_vector1, noise_vector2

    def __getitem__(self, index):
        word1 = self.words_in_dataset[index // 2] 
        word2 = self.words_in_dataset[(index // 2 + index % 2) % (len(self.words_in_dataset) // 2)]

        x1, x2 = self.give_joined_audio(word1, word2)
        
        y = 1.0 if index % 2 == 0 else 0.0  # 1 for match, 0 for mismatch
        return (x1, x2), torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        return 2 * len(self.words_in_dataset) # Double for positive & negative pairs

In [5]:
class SiameseResNet(nn.Module):
    def __init__(self):
        super(SiameseResNet, self).__init__()
        self.resnet = torchvision.models.resnet50(pretrained=True)
        # No need to modify conv1 

        # Freeze early layers
        for param in self.resnet.parameters():
            param.requires_grad = False
        for param in self.resnet.layer4.parameters():
            param.requires_grad = True

        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(2048, 128)

    def forward_one(self, x):
        x = self.resnet(x)
        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.normalize(x, p=2, dim=1)  # L2 normalization
        return x

    def forward(self, input1, input2):
        output1 = self.forward_one(input1)
        output2 = self.forward_one(input2)
        return output1, output2

In [6]:
def triplet_loss(output1, output2, y_true, margin=1.0):
    euclidean_distance = F.pairwise_distance(output1, output2, p=2)
    loss_contrastive = torch.mean((1 - y_true) * torch.pow(euclidean_distance, 2) +
                                    (y_true) * torch.pow(torch.clamp(margin - euclidean_distance, min=0.0), 2))
    return loss_contrastive

def accuracy(output1, output2, y_true, threshold=0.2):
    euclidean_distance = F.pairwise_distance(output1, output2, p=2)
    predictions = (euclidean_distance < threshold).float()
    correct = (predictions == y_true).float().sum()
    return (correct / len(y_true)).item()

In [7]:
if __name__ == "__main__": 
    BATCH_SIZE = 32
    EPOCHS = 10
    LEARNING_RATE = 1e-4

    chunkedNoisePath = r"C:\Users\salos\OneDrive\Desktop\EfficientWord-Net\Efficient_word_net\NoiseChunked" 
    datasetPath = r"C:\Users\salos\OneDrive\Desktop\EfficientWord-Net\Efficient_word_net\test"

    train_dataset = WakeWordDataset(chunkedNoisePath, datasetPath, spectrogram=True, training=True)
    test_dataset = WakeWordDataset(chunkedNoisePath, datasetPath, spectrogram=True, training=False)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = SiameseResNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    writer = SummaryWriter() 

    for epoch in range(EPOCHS):
        for batch_idx, ((data1, data2), y) in enumerate(train_loader):
            data1, data2, y = data1.to(device), data2.to(device), y.to(device)

            optimizer.zero_grad()
            output1, output2 = model(data1, data2) 
            
            loss = triplet_loss(output1, output2, y)
            loss.backward()
            optimizer.step()

Train size:  21157
Test size:  2351




RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[1, 32, 32, 101] to have 3 channels, but got 32 channels instead