# Introduction to digital signal processing

Credits to https://github.com/markovka17/apdl/blob/master/week01/

In [None]:
# !pip install torchaudio==0.12.1

In [None]:
from matplotlib import pyplot as plt
from IPython import display

import torch
import torchaudio

import numpy as np

In [None]:
torchaudio.__version__

# Time domain $\rightarrow$  frequency domain

In [None]:
wav, sr = torchaudio.load('c.wav')

In [None]:
wav

In [None]:
display.Audio(wav, rate=sr)

In [None]:
def visualize_audio(wav: torch.Tensor, sr: int = 22050):
    """Plot a waveform against time in seconds and show an inline audio player.

    Args:
        wav: audio samples. A 2-D tensor is averaged over dim 0 (any-to-mono
            conversion); any other rank is used as is.
        sr: sampling rate in Hz, used both for the time axis and for playback.
    """
    # Any-to-mono audio conversion: average all channels
    if wav.dim() == 2:
        wav = wav.mean(dim=0)

    # matplotlib and IPython work on host arrays, so detach from autograd and
    # move to CPU before handing the samples over
    samples = wav.detach().cpu().numpy()

    # Convert sample indices to seconds so the x-axis matches its label
    seconds = np.arange(samples.shape[-1]) / sr

    plt.figure(figsize=(20, 5))
    plt.plot(seconds, samples, alpha=.7, c='green')
    plt.grid()
    plt.xlabel('Time (s)', size=20)
    plt.ylabel('Amplitude', size=20)
    plt.show()

    display.display(display.Audio(samples, rate=sr))


In [None]:
visualize_audio(wav, sr)

In [None]:
n_fft = 1024
ft = torch.fft.fft(wav.mean(dim=0), n=n_fft)

In [None]:
ft.dtype

In [None]:
# Squared modulus of every FFT bin, i.e. the power spectrum (the name
# `magnitude` is kept because the plotting cells below refer to it)
magnitude = ft.abs().pow(2)
# An N-point FFT has bins at k * sr / N for k = 0..N-1. The value `sr` itself is
# not a bin, so the endpoint must be excluded to get the correct bin spacing.
frequency = np.linspace(0, sr, len(magnitude), endpoint=False)

In [None]:
frequency[:5000].shape

In [None]:
# plot spectrum
plt.figure(figsize=(18, 8))
plt.plot(frequency, magnitude) # magnitude spectrum
plt.xlabel("Frequency (Hz)")
plt.ylabel("Magnitude")
plt.show()

In [None]:
# plot spectrum
plt.figure(figsize=(18, 8))
plt.plot(frequency[:90], magnitude[:90]) # magnitude spectrum
plt.xlabel("Frequency (Hz)")
plt.ylabel("Magnitude")
plt.show()

In [None]:
sr

In [None]:
# Number of samples in one period of a 523 Hz tone (presumably the note in
# c.wav -- confirm). Uses the rate the file was loaded with instead of
# assuming 44100 Hz.
sr / 523

In [None]:
# plot wave
plt.figure(figsize=(18, 8))
plt.plot(wav.mean(0)[100:600])
plt.show()

# Build mel spectrogram

In [None]:
wav, sr = torchaudio.load('example.wav')

In [None]:
n_fft = 1024

In [None]:
spectrum = torch.fft.rfft(wav, n=n_fft)

In [None]:
spectrum.dtype

In [None]:
spectrum = torch.fft.rfft(wav.mean(dim=0), n=n_fft)

In [None]:
spectrogram = spectrum.abs().pow(2)

In [None]:
# Centre frequency in Hz of every rfft bin, so the x-axis matches its label
# (assumes `spectrogram` is the n_fft-point rfft power spectrum computed above)
frequencies = torch.fft.rfftfreq(n_fft, d=1 / sr)

plt.figure(figsize=(20, 5))
plt.plot(frequencies, spectrogram.squeeze(), c='green')
plt.grid()
plt.xlabel('Frequency (Hz)', size=20)
plt.ylabel('Magnitude$^2$', size=20)
plt.show()


In [None]:
window_size = n_fft
window = torch.hann_window(window_size)

plt.figure(figsize=(20, 5))
plt.plot(window, c='green')
plt.grid()
plt.show()


In [None]:
clipped_wav = wav[:, :window_size]
windowed_clipped_wav = window * clipped_wav

fig, axes = plt.subplots(1, 2, figsize=(20, 5))

axes[0].plot(clipped_wav.squeeze(), c='green')
axes[0].set_title('Raw Audio', size=20)

axes[1].plot(windowed_clipped_wav.squeeze(), c='green')
axes[1].set_title('Windowed Audio', size=20)

for i in range(2):
    axes[i].grid()
    axes[i].set_xlabel('Time', size=20)
    axes[i].set_ylabel('Amplitude', size=20)

plt.show()


In [None]:
# Power spectrum of one frame, without and with the Hann window applied
spectrogram = torch.fft.rfft(clipped_wav).abs().pow(2)
windowed_spectrogram = torch.fft.rfft(windowed_clipped_wav).abs().pow(2)

# Centre frequency in Hz of every rfft bin, so the x-axis matches its label.
# NOTE(review): plotting against this axis assumes mono audio, i.e. that
# `.squeeze()` below yields a 1-D tensor -- confirm for multi-channel files.
frequencies = torch.fft.rfftfreq(clipped_wav.size(-1), d=1 / sr)

fig, axes = plt.subplots(1, 2, figsize=(20, 5))

axes[0].plot(frequencies, spectrogram.squeeze(), c='green')
axes[0].set_title('Spectrogram of Raw Audio', size=20)

axes[1].plot(frequencies, windowed_spectrogram.squeeze(), c='green')
axes[1].set_title('Spectrogram of Windowed Audio', size=20)

for ax in axes:
    ax.grid()
    ax.set_xlabel('Frequency (Hz)', size=20)

plt.show()


In [None]:
# `return_complex=False` is deprecated in torch.stft. Request the complex
# result and view it as real, which keeps the trailing (real, imag) dimension
# of size 2 that the following cells rely on.
spectrum = torch.view_as_real(
    torch.stft(
        wav,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        window=torch.hann_window(1024),

        # We don't want to pad input signal
        center=False,

        # Take first (n_fft // 2 + 1) frequencies
        onesided=True,

        # Complex output; converted to (real, imag) pairs by view_as_real above
        return_complex=True,
    )
)


In [None]:
spectrum.shape


In [None]:
spectrogram = spectrum.norm(dim=-1).pow(2)
spectrogram.shape


In [None]:
spectrogram.max(), spectrogram.min(), spectrogram.mean()

In [None]:
plt.figure(figsize=(20, 5))
plt.imshow(spectrogram.squeeze().data.numpy()[::-1,])
plt.xlabel('Time', size=20)
plt.ylabel('Frequency (Hz)', size=20)
plt.colorbar()

plt.show()


In [None]:
plt.figure(figsize=(20, 5))
# Clamp before taking the log: bins with zero power would give log(0) = -inf
# and break the colour scale. `[::-1]` flips the frequency axis so that low
# frequencies are drawn at the bottom.
log_spectrogram = spectrogram.squeeze().clamp(min=1e-5).log()
plt.imshow(log_spectrogram.detach().numpy()[::-1,])
plt.xlabel('Time', size=20)
plt.ylabel('Frequency (Hz)', size=20)
plt.colorbar()

plt.show()


In [None]:
mel_scaler = torchaudio.transforms.MelScale(
    n_mels=80,
    # Use the rate the file was actually loaded with: the mel filter bank
    # spans 0..sample_rate / 2, so a hard-coded rate mis-places the filters
    # for any file recorded at a different rate
    sample_rate=sr,
    n_stft=n_fft // 2 + 1
)

In [None]:
mel_scaler.fb.shape

In [None]:
plt.figure(figsize=(20, 5))
plt.imshow(mel_scaler.fb.T)
plt.xlabel('Hertz Scale', size=20)
plt.ylabel('Mels Scale', size=20)
plt.gca().invert_yaxis()
plt.show()


In [None]:
mel_spectrogram = mel_scaler(spectrogram)


In [None]:
mel_spectrogram.shape


In [None]:
plt.figure(figsize=(20, 5))
# Clamp before taking the log so that empty mel bins do not become -inf
plt.imshow(mel_spectrogram.squeeze().clamp(min=1e-5).log())
plt.xlabel('Time', size=20)
plt.ylabel('Mels', size=20)
plt.show()


# Audio mnist classification

![](https://i.imgur.com/OX1ADxu.png)



Uncomment to download data

In [1]:
# !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1ouSOru91p-ZJCyI6E8cGh7N0r3vffi06' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1ouSOru91p-ZJCyI6E8cGh7N0r3vffi06" -O AudioMNIST.zip && rm -rf /tmp/cookies.txt

Uncomment to unzip archive

In [None]:
# !unzip -q AudioMNIST.zip

In [None]:
from typing import List, Tuple

import pathlib
from tqdm import tqdm
from itertools import islice
from collections import defaultdict

import torch
import torch.nn.functional as F
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader, Subset


Определим класс AudioMnistDataset для загрузки данных. При загрузке извлечём из названия файла метку (label) аудио — произносимую цифру: это часть имени файла до первого знака "_".



In [None]:
class AudioMnistDataset(Dataset):
    """
    Dataset of spoken-digit recordings found recursively under a directory.

    Each wavfile has the following format: digit_speakerid_wavid.wav
        For example, 6_01_47.wav:
            6 -- the number 6 is spoken
            01 -- the number is spoken by speaker 1
            47 -- id of wavfile
    """

    # Sampling rate (Hz) at which every item is returned
    SR = 16_000

    def __init__(self, path_to_data: str):
        """Collect every file matching '?_*_*.wav' below `path_to_data`."""
        self.path_to_data = pathlib.Path(path_to_data)
        # rglob yields files in filesystem order; sort so that an index always
        # refers to the same recording
        self.paths = sorted(self.path_to_data.rglob('?_*_*.wav'))

    @staticmethod
    def _parse_label(path: pathlib.Path) -> int:
        """Return the digit encoded before the first '_' of the file name."""
        return int(path.name.split('_')[0])

    def __getitem__(self, index: int):
        """Load the recording at `index` and return `(wav, label)`."""
        path = self.paths[index]

        # Load wav
        wav, sr = torchaudio.load(path.as_posix())

        # Keep the data consistent with the declared SR: resample only when
        # the file was stored at a different rate
        if sr != self.SR:
            wav = torchaudio.functional.resample(wav, sr, self.SR)

        return wav, self._parse_label(path)

    def __len__(self):
        """Number of recordings found."""
        return len(self.paths)


In [None]:
dataset = AudioMnistDataset('AudioMNIST')


In [None]:
wav, label = dataset[123]
label

In [None]:
visualize_audio(wav, sr=dataset.SR)

Разобьём датасет на датасеты для обучения и валидации.



In [None]:
train_ratio = 0.9
train_size = int(len(dataset) * train_ratio)
validation_size = len(dataset) - train_size

# Seed the permutation so the same shuffled index order is produced on every
# run; otherwise metrics of different runs are not comparable
split_generator = torch.Generator().manual_seed(0)
indexes = torch.randperm(len(dataset), generator=split_generator)
train_indexes = indexes[:train_size]
validation_indexes = indexes[train_size:]

# Hand Subset plain Python ints rather than 0-d tensors
train_dataset = Subset(dataset, train_indexes.tolist())
validation_dataset = Subset(dataset, validation_indexes.tolist())

In [None]:
assert not set(train_indexes.tolist()).intersection(set(validation_indexes.tolist()))

Класс Collator объединяет аудиодорожки в один батч. Так как в нашей задаче все аудио разной длины, для того чтобы собрать в батч, заполним недостающее нулями (сделаем паддинг). Для этого создадим тензор из нулей размера [batch_size, max_wav_len] и заполним его элементами батча.



In [None]:
class Collator:
    """Zero-pad variable-length waveforms and stack them into one batch."""

    def __call__(self, batch: List[Tuple[torch.Tensor, int]]):
        """Merge `(wav, label)` pairs into a dict of batched tensors.

        Args:
            batch: non-empty list of `(wav, label)` pairs. The last dimension
                of `wav` is taken as its length; a 2-D `wav` is averaged over
                dim 0 before being stored.

        Returns:
            dict with
                'wav': tensor [batch_size, max_len], zero-padded on the right,
                'label': long tensor [batch_size],
                'length': long tensor [batch_size] with the unpadded lengths.
        """
        wavs, labels = zip(*batch)
        lengths = [wav.size(-1) for wav in wavs]

        max_len = max(lengths)
        # Zero tensor with the dtype/device of the inputs; each row is then
        # filled with its waveform, leaving the tail as padding
        batch_wavs = wavs[0].new_zeros(len(wavs), max_len)
        for row, (wav, wav_len) in enumerate(zip(wavs, lengths)):
            if wav.dim() == 2:
                # Any-to-mono conversion so that each item is a single row
                wav = wav.mean(dim=0)
            batch_wavs[row, :wav_len] = wav

        labels = torch.tensor(labels).long()
        lengths = torch.tensor(lengths).long()

        return {
            'wav': batch_wavs,
            'label': labels,
            'length': lengths,
        }


In [None]:
train_dataloader = DataLoader(
    train_dataset, batch_size=32,
    shuffle=True, collate_fn=Collator(),
    num_workers=2, pin_memory=True
)

validation_dataloader = DataLoader(
    validation_dataset, batch_size=32,
    collate_fn=Collator(),
    num_workers=2, pin_memory=True
)

Класс Featurizer делает необходимый в нашей задаче препроцессинг: считает мел-спектрограмму, логарифмирует её и считает длину спектрограммы.



In [2]:
NUM_MELS = 80
HOP_LEN = 256
N_FFT = 1024
WIN_LEN = N_FFT
SAMPLE_RATE = 16000

In [None]:
class Featurizer(nn.Module):
    """Log-mel-spectrogram front end built on torchaudio's MelSpectrogram."""

    def __init__(self):
        super().__init__()

        self.featurizer = torchaudio.transforms.MelSpectrogram(
            sample_rate=SAMPLE_RATE,
            n_fft=N_FFT,
            win_length=WIN_LEN,
            hop_length=HOP_LEN,
            n_mels=NUM_MELS,
        )

    def forward(self, wav, length=None):
        """Compute log-mel features and, optionally, their lengths in frames.

        Args:
            wav: waveform(s) passed straight to MelSpectrogram.
            length: optional number of valid samples per waveform.

        Returns:
            The log-mel spectrogram; when `length` is given, a tuple
            `(log_mel_spectrogram, length_in_frames)`.
        """
        mel_spectrogram = self.featurizer(wav)
        # Floor the values so that the log of empty bins stays finite
        mel_spectrogram = mel_spectrogram.clamp(min=1e-5).log()

        if length is None:
            return mel_spectrogram

        # MelSpectrogram runs its STFT with center=True (the default), which
        # pads n_fft // 2 samples on each side. The number of frames is then
        # 1 + (padded_length - n_fft) // hop_length. The previous formula,
        # (length - win_length) // hop_length + 4, was one frame short.
        n_fft = self.featurizer.n_fft
        hop_length = self.featurizer.hop_length
        padded_length = length + 2 * (n_fft // 2)
        mel_length = (padded_length - n_fft) // hop_length + 1

        return mel_spectrogram, mel_length


В модель приходит батч размера [batch_size, num_mels, seq_len].

Последний слой - слой для классификации, возвращающий тензор размера [batch_size, NUM_CLASSES].


In [None]:
class Model(nn.Module):
    """Small 1-D convolutional classifier over mel-spectrogram frames.

    Args:
        input_dim: number of input feature channels (mel bins).
        num_channels: number of channels in the hidden convolutional layers.
        num_classes: size of the output layer (10 spoken digits by default).
    """

    def __init__(self, input_dim, num_channels, num_classes=10):
        super().__init__()
        self.input_dim = input_dim
        self.num_channels = num_channels
        self.num_classes = num_classes

        # kernel_size=3 with padding=1 keeps seq_len unchanged, so `length`
        # can be applied directly to the encoder output in forward()
        self.encoder = nn.Sequential(
            nn.Conv1d(input_dim, num_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(num_channels),
            nn.ReLU(),
            nn.Conv1d(num_channels, num_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(num_channels),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(num_channels, num_classes)

    def forward(self, inputs, length=None):
        """Return class logits of shape [batch_size, num_classes].

        Args:
            inputs: tensor of shape [batch_size, num_mels, seq_len].
            length: optional tensor [batch_size] with the number of valid
                frames per item; frames beyond it are excluded from pooling.
        """
        # inputs of shape [batch_size, num_mels, seq_len]
        hidden = self.encoder(inputs)  # [batch_size, num_channels, seq_len]

        if length is None:
            pooled = hidden.mean(dim=-1)
        else:
            # Average over valid frames only, so that padding added by the
            # collator does not dilute the representation of short clips
            frame_ids = torch.arange(hidden.size(-1), device=hidden.device)
            mask = frame_ids.unsqueeze(0) < length.unsqueeze(1)  # [B, T]
            mask = mask.unsqueeze(1).to(hidden.dtype)            # [B, 1, T]
            # clamp guards against division by zero for zero-length items
            pooled = (hidden * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)

        return self.classifier(pooled)
        

Зададим параметры модели и оптимизатора и функцию потерь



In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = Model(input_dim=NUM_MELS, num_channels=32).to(device)
featurizer = Featurizer().to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()


Создадим класс для подсчета метрики



In [None]:
class AverageMeter:
    """
    Tracks the most recent value together with a running weighted mean.

    Attributes:
        val: value passed to the latest `update` call.
        sum: accumulated `val * n` over all updates.
        count: accumulated `n` over all updates.
        avg: `sum / count` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop all accumulated statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as observed `n` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += n * val
        self.avg = self.sum / self.count


In [None]:
storage = defaultdict(list)
num_epoch = 10

In [None]:
# Train for `num_epoch` epochs; after each one evaluate on the whole validation
# set and redraw the loss / accuracy curves accumulated in `storage`.
for epoch in range(num_epoch):
    train_loss_meter = AverageMeter()
    print(f"Epoch {epoch} out of {num_epoch}")

    model.train()
    for batch in tqdm(train_dataloader):
        # Move batch to device if device != 'cpu'
        wav = batch['wav'].to(device)
        length = batch['length'].to(device)
        label = batch['label'].to(device)

        # Compute mel spectrogram
        mel, mel_length = featurizer(wav, length)

        # feed model
        output = model(mel, mel_length)

        # compute loss
        loss = criterion(output, label)

        # zero out previously computed gradients
        optimizer.zero_grad()

        # compute gradients
        loss.backward()

        # update weights
        optimizer.step()

        # update metrics; weight by batch size because the last batch of an
        # epoch may be smaller than the others
        train_loss_meter.update(loss.item(), n=label.size(0))

    storage['train_loss'].append(train_loss_meter.avg)

    validation_loss_meter = AverageMeter()
    validation_accuracy_meter = AverageMeter()

    model.eval()
    # Evaluate on every validation batch (previously only the first batch was
    # used, which made the validation curves very noisy)
    for batch in tqdm(validation_dataloader):
        # Move batch to device if device != 'cpu'
        wav = batch['wav'].to(device)
        length = batch['length'].to(device)
        label = batch['label'].to(device)

        # in inference mode we don't need to compute gradients
        # so we use `no_grad()` context manager to speed up inference
        with torch.no_grad():
            mel, mel_length = featurizer(wav, length)
            output = model(mel, mel_length)

            loss = criterion(output, label)

            # compute accuracy
            matches = (output.argmax(dim=-1) == label).float().mean()

        batch_size = label.size(0)
        validation_loss_meter.update(loss.item(), n=batch_size)
        validation_accuracy_meter.update(matches.item(), n=batch_size)

    storage['validation_loss'].append(validation_loss_meter.avg)
    storage['validation_accuracy'].append(validation_accuracy_meter.avg)

    display.clear_output()

    fig, axes = plt.subplots(1, 3, figsize=(20, 5))

    axes[0].plot(storage['train_loss'], label='train_loss')
    axes[1].plot(storage['validation_loss'], label='validation_loss')

    axes[2].plot(storage['validation_accuracy'], label='validation_accuracy')

    for ax in axes:
        ax.grid()
        ax.legend()

    plt.show()


In [None]:
def inference(dataloader, take_n=10):
    """
    Display wav and results of NN

    Takes the first batch of `dataloader`, runs the module-level `featurizer`
    and `model` on at most `take_n` items and, for each of them, prints the
    target and predicted labels and shows the waveform with an audio player.

    Args:
        dataloader: loader yielding dicts with 'wav', 'label' and 'length'.
        take_n: maximum number of items of the batch to display.
    """
    batch = next(iter(dataloader))

    # assumes batch['wav'] is [batch_size, max_wav_len] -- TODO confirm against
    # the collator in use
    wav = batch['wav'][:take_n].to(device)
    length = batch['length'][:take_n].to(device)
    label = batch['label'][:take_n]

    model.eval()
    with torch.no_grad():
        mel, mel_length = featurizer(wav, length)
        prediction = model(mel, mel_length).argmax(dim=-1).cpu()

    items = zip(wav.cpu(), length.tolist(), label.tolist(), prediction.tolist())
    for samples, n_samples, target, predicted in items:
        print(f'Target: {target}, predicted: {predicted}')
        # Cut the zero padding off before plotting / playing the clip
        visualize_audio(samples[:n_samples], sr=SAMPLE_RATE)


In [None]:
inference(validation_dataloader)