# Speech command recognition - Day 3 Part 1
## Model improvement - Data augmentation
******
Author: Duowei Tang \
Reference: This exercise is adapted from a PyTorch tutorial: https://pytorch.org/tutorials/intermediate/speech_command_classification_with_torchaudio_tutorial.html \
Pytorch ignite: https://pytorch-ignite.ai/ \
The Aachen RIR dataset: https://www.iks.rwth-aachen.de/en/research/tools-downloads/databases/aachen-impulse-response-database/

In [None]:
import torch
import torch.nn.functional as F
import torchaudio

import matplotlib.pyplot as plt

from torchaudio.datasets import SPEECHCOMMANDS
import os
import glob


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

### Class labels
*******
The code block below lists the selected command classes. All class labels that are commented out are treated as "unknown".

In [None]:
# Target classes, in the order that defines their integer class index.
labels = [
    # directions
    "forward", "backward", "up", "down",
    # digits
    "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "zero",
    # remaining commands
    "left", "right", "go", "stop", "yes", "no", "on", "off",
    # catch-all class for every other word in the dataset
    "unknown",
]
# The following dataset labels are considered unknown
# unknown = ['bed', 'bird', 'cat', 'dog', 'follow', 'happy', 'house', 'learn', 'marvin',
#            'sheila', 'visual', 'wow', 'tree']

### Prepare the datasets
*******
Unlike the previous day, where we first stored the extracted MFCC features on disk and loaded the pre-computed features during training, in this exercise we extract the features on-the-fly during training.

In [None]:
# Folder that holds (or will receive) the Speech Commands data. Set the
# SPEECH_DATA_ROOT environment variable to use your own location; the original
# hard-coded path is kept as the fallback so existing setups keep working.
# An empty variable is treated as unset.
SPEECH_DATA_ROOT = (
    os.environ.get("SPEECH_DATA_ROOT")
    or "/Users/invincibleo/Leo/Projects/Datasets/SpeechCommands"
)
# Load the speech command dataset from pytorch dataset
class SubsetSC(SPEECHCOMMANDS):
    """Speech Commands dataset restricted to one partition.

    ``subset`` selects which files the dataset walks over:

    * ``"validation"`` / ``"testing"``: the files named in
      ``validation_list.txt`` / ``testing_list.txt``.
    * ``"training"``: every file of the parent dataset that is in neither list.
    * anything else (including ``None``): the parent's file list is left
      untouched.

    The parent is rooted at the directory *containing* ``SPEECH_DATA_ROOT`` and
    is created with ``download=True``.
    """

    def __init__(self, subset: str = None):
        super().__init__(os.path.dirname(SPEECH_DATA_ROOT), download=True)

        def read_split(list_name):
            # Each line of a list file is joined onto the dataset folder and
            # normalised so it can be compared with the parent's walker entries.
            with open(os.path.join(self._path, list_name)) as handle:
                return [
                    os.path.normpath(os.path.join(self._path, entry.strip()))
                    for entry in handle
                ]

        if subset in ("validation", "testing"):
            self._walker = read_split(f"{subset}_list.txt")
        elif subset == "training":
            # Training data is whatever is not held out for validation/testing.
            held_out = set(read_split("validation_list.txt"))
            held_out.update(read_split("testing_list.txt"))
            self._walker = [path for path in self._walker if path not in held_out]

# Create the partitions; features will be extracted during training
train_set, valid_set, test_set = (
    SubsetSC(split) for split in ("training", "validation", "testing")
)

In [None]:
def label_to_index(word):
    """Return the class index of ``word`` as a 0-dim tensor.

    Words that are not listed in ``labels`` map to the index of "unknown".
    """
    target = word if word in labels else "unknown"
    return torch.tensor(labels.index(target))

def index_to_label(index):
    """Return the word stored at position ``index`` of ``labels``.

    This is the inverse of ``label_to_index`` for words that are in ``labels``.
    """
    word = labels[index]
    return word

### Data augmentation examples
*******
In this section, we will try out several speech data augmentation techniques.

In [None]:
# Sampling rate (Hz) that the augmentation functions below read.
sample_rate = 16_000
# Folder that receives the augmented / original example clips.
output_root = "dataAugOut/"

def add_exist_noise(waveform):
    """Mix a randomly chosen background-noise recording into ``waveform``.

    A ``.wav`` file is drawn from the dataset's ``_background_noise_`` folder,
    zero-padded or randomly cropped to the length of ``waveform``, and added
    at a random integer SNR drawn from [-5, 10) dB.

    Args:
        waveform: Audio tensor with time on the last dimension.
            NOTE(review): assumed to be mono ``(1, time)`` so that it matches
            the loaded noise and the 1-element SNR tensor -- confirm.

    Returns:
        The noisy waveform produced by ``torchaudio.transforms.AddNoise``.

    Raises:
        FileNotFoundError: If no background-noise ``.wav`` file is found.
    """
    noise_dir = os.path.join(SPEECH_DATA_ROOT, "speech_commands_v0.02", "_background_noise_")
    # glob order is filesystem-dependent; sort so a fixed torch seed always
    # selects the same file.
    noise_files = sorted(glob.glob(os.path.join(noise_dir, "*.wav")))
    if not noise_files:
        raise FileNotFoundError(f"No background noise .wav files found in {noise_dir}")
    noise_file = noise_files[torch.randint(0, len(noise_files), (1,)).item()]
    noise_waveform, _ = torchaudio.load(noise_file)

    # Make waveform and noise the same length
    target_len = waveform.size(-1)
    noise_len = noise_waveform.size(-1)
    if noise_len < target_len:
        # Noise too short: pad with zeros at the end
        noise_waveform = F.pad(noise_waveform, (0, target_len - noise_len))
    else:
        # Noise long enough: randomly crop it. torch.randint's upper bound is
        # exclusive, so use max_offset + 1; this also keeps the call valid when
        # both signals have the same length (max_offset == 0).
        max_offset = noise_len - target_len
        offset = torch.randint(0, max_offset + 1, (1,)).item()
        noise_waveform = noise_waveform[..., offset:offset + target_len]

    snr = torch.randint(-5, 10, (1,))
    return torchaudio.transforms.AddNoise()(waveform, noise_waveform, snr=snr)

def time_shift(waveform):
    """Randomly shift ``waveform`` in time by -0.3 s, 0 s or +0.3 s.

    A positive shift delays the signal (zeros are inserted at the beginning,
    the tail is dropped); a negative shift advances it (the head is dropped,
    zeros are appended). The length of the last dimension is preserved.

    Args:
        waveform: Audio tensor with time on the last dimension.

    Returns:
        The shifted waveform, same shape as the input.
    """
    # torch.randint excludes its upper bound: (-1, 2) yields -1, 0 or 1.
    # (The previous (-1, 1) could never produce a positive shift.)
    direction = torch.randint(-1, 2, (1,)).item()
    shift_amount = int(sample_rate * 0.3 * direction)

    # Never shift by more than the clip itself, otherwise the zero padding
    # would make the output longer than the input.
    num_samples = waveform.size(-1)
    shift_amount = max(-num_samples, min(num_samples, shift_amount))

    if shift_amount > 0:
        # Delay: drop the tail, zero-pad at the beginning
        waveform = F.pad(waveform[..., :num_samples - shift_amount], (shift_amount, 0))
    elif shift_amount < 0:
        # Advance: drop the head, zero-pad at the end
        waveform = F.pad(waveform[..., -shift_amount:], (0, -shift_amount))
    return waveform

def perturbate_speed(waveform):
    """Apply a random speed perturbation and restore the original length.

    The speed factor is drawn by ``SpeedPerturbation`` from
    ``[0.9, 1.1, 1.0, 1.0]``. A result shorter than the input is zero-padded
    at the end; a longer one is centre-cropped.

    Args:
        waveform: Audio tensor with time on the last dimension.

    Returns:
        The perturbed waveform with the same last-dimension length as the input.
    """
    perturb = torchaudio.transforms.SpeedPerturbation(orig_freq=sample_rate, factors=[0.9, 1.1, 1.0, 1.0])
    # The transform returns a tuple; only the waveform (element 0) is needed.
    waveform_aug = perturb(waveform)[0]

    target_len = waveform.size(-1)
    length_diff = waveform_aug.size(-1) - target_len
    if length_diff < 0:
        # Pad by the actual shortfall. (The previous code subtracted the
        # original length from itself, so it always padded by 0.)
        waveform_aug = F.pad(waveform_aug, (0, -length_diff))
    else:
        start = length_diff // 2
        waveform_aug = waveform_aug[..., start:start + target_len]
    return waveform_aug

def adjust_volume(waveform):
    """Scale ``waveform`` by a random integer gain drawn from [-10, 10) dB."""
    gain_db = torch.randint(low=-10, high=10, size=(1,))
    volume = torchaudio.transforms.Vol(gain=gain_db, gain_type='db')
    return volume(waveform)

def add_white_noise(waveform):
    """Add Gaussian noise to ``waveform`` at a random integer SNR in [-5, 10) dB."""
    mixer = torchaudio.transforms.AddNoise()
    # Draw the noise before the SNR so the random stream is consumed in the
    # same order as a single-expression call would consume it.
    noise = torch.randn_like(waveform)
    snr_db = torch.randint(low=-5, high=10, size=(1,))
    return mixer(waveform, noise, snr=snr_db)

def time_mask(waveform):
    """Mask a random stretch of ``waveform`` along its last axis.

    The mask parameter is ``int(0.1 * sample_rate)`` samples, with ``p=1.0``.
    """
    max_mask_len = int(0.1*sample_rate)
    masker = torchaudio.transforms.TimeMasking(time_mask_param=max_mask_len, p=1.0)
    return masker(waveform)

def augment_rir(waveform):
    """Reverberate ``waveform`` with a randomly chosen room impulse response.

    A ``.wav`` file is drawn from the ``AIR_wav_files`` folder, resampled to
    ``sample_rate`` when its own rate differs, convolved with the input, and
    the result is cropped to the input length.

    Args:
        waveform: Audio tensor with time on the last dimension.
            NOTE(review): assumed to be sampled at ``sample_rate`` and to have
            leading dimensions compatible with the RIR file's channel layout
            -- confirm.

    Returns:
        The reverberated waveform, no longer than the input along time.

    Raises:
        FileNotFoundError: If ``AIR_wav_files`` contains no ``.wav`` file.
    """
    # glob order is filesystem-dependent; sort so a fixed torch seed always
    # selects the same file.
    rir_files = sorted(glob.glob(os.path.join("AIR_wav_files", "*.wav")))
    if not rir_files:
        raise FileNotFoundError("No room impulse response .wav files found in AIR_wav_files")
    rir_file = rir_files[torch.randint(0, len(rir_files), (1,)).item()]
    rir_waveform, rir_rate = torchaudio.load(rir_file)

    # Convolution is only meaningful when both signals share a sampling rate.
    if rir_rate != sample_rate:
        rir_waveform = torchaudio.functional.resample(rir_waveform, rir_rate, sample_rate)
    # NOTE(review): the RIR is not energy-normalised, so the output level
    # depends on the chosen file -- decide whether normalisation is wanted.

    waveform_aug = torchaudio.functional.fftconvolve(waveform, rir_waveform)
    # Full convolution is longer than the input; keep the first
    # waveform.size(-1) samples so the clip length is unchanged (previously a
    # fixed `sample_rate` samples were kept regardless of the input length).
    return waveform_aug[..., :waveform.size(-1)]

In [None]:
# Check the data augmentation functions - Add existing noise
# Writes the first 10 training clips (original + augmented) to disk and plots
# the second pair for a quick visual check.
demo_dir = os.path.join(output_root, "addExistNoise")
os.makedirs(demo_dir, exist_ok=True)  # loop-invariant: create the folder once
for i in range(min(10, len(train_set))):
    # Unpack the rate as `clip_rate` so the module-level `sample_rate` that the
    # augmentation functions read is not rebound by this loop.
    waveform, clip_rate, label, speaker_id, utterance_number = train_set[i]
    waveform_aug = add_exist_noise(waveform)
    # save the augmented waveform and original waveform
    stem = f"{speaker_id}_{utterance_number}_{label}"
    torchaudio.save(os.path.join(demo_dir, f"{stem}_aug.wav"), waveform_aug, clip_rate)
    torchaudio.save(os.path.join(demo_dir, f"{stem}_org.wav"), waveform, clip_rate)

    if i == 1:
        for title, signal in (("Original waveform", waveform), ("Augmented waveform", waveform_aug)):
            plt.figure()
            plt.plot(signal.t().numpy())
            plt.title(title)
            plt.show()

In [None]:
# Check the data augmentation functions - Time shift
# Writes the first 10 training clips (original + augmented) to disk and plots
# the second pair for a quick visual check.
demo_dir = os.path.join(output_root, "timeShift")
os.makedirs(demo_dir, exist_ok=True)  # loop-invariant: create the folder once
for i in range(min(10, len(train_set))):
    # Unpack the rate as `clip_rate` so the module-level `sample_rate` that the
    # augmentation functions read is not rebound by this loop.
    waveform, clip_rate, label, speaker_id, utterance_number = train_set[i]
    waveform_aug = time_shift(waveform)
    # save the augmented waveform and original waveform
    stem = f"{speaker_id}_{utterance_number}_{label}"
    torchaudio.save(os.path.join(demo_dir, f"{stem}_aug.wav"), waveform_aug, clip_rate)
    torchaudio.save(os.path.join(demo_dir, f"{stem}_org.wav"), waveform, clip_rate)

    if i == 1:
        for title, signal in (("Original waveform", waveform), ("Augmented waveform", waveform_aug)):
            plt.figure()
            plt.plot(signal.t().numpy())
            plt.title(title)
            plt.show()

In [None]:
# Check the data augmentation functions - Perturbate speed
# Writes the first 10 training clips (original + augmented) to disk and plots
# the second pair for a quick visual check.
demo_dir = os.path.join(output_root, "perturbateSpeed")
os.makedirs(demo_dir, exist_ok=True)  # loop-invariant: create the folder once
for i in range(min(10, len(train_set))):
    # Unpack the rate as `clip_rate` so the module-level `sample_rate` that the
    # augmentation functions read is not rebound by this loop.
    waveform, clip_rate, label, speaker_id, utterance_number = train_set[i]
    waveform_aug = perturbate_speed(waveform)
    # save the augmented waveform and original waveform
    stem = f"{speaker_id}_{utterance_number}_{label}"
    torchaudio.save(os.path.join(demo_dir, f"{stem}_aug.wav"), waveform_aug, clip_rate)
    torchaudio.save(os.path.join(demo_dir, f"{stem}_org.wav"), waveform, clip_rate)

    if i == 1:
        for title, signal in (("Original waveform", waveform), ("Augmented waveform", waveform_aug)):
            plt.figure()
            plt.plot(signal.t().numpy())
            plt.title(title)
            plt.show()

In [None]:
# Check the data augmentation functions - Adjust volume
# Writes the first 10 training clips (original + augmented) to disk and plots
# the second pair for a quick visual check.
demo_dir = os.path.join(output_root, "adjustVolume")
os.makedirs(demo_dir, exist_ok=True)  # loop-invariant: create the folder once
for i in range(min(10, len(train_set))):
    # Unpack the rate as `clip_rate` so the module-level `sample_rate` that the
    # augmentation functions read is not rebound by this loop.
    waveform, clip_rate, label, speaker_id, utterance_number = train_set[i]
    waveform_aug = adjust_volume(waveform)
    # save the augmented waveform and original waveform
    stem = f"{speaker_id}_{utterance_number}_{label}"
    torchaudio.save(os.path.join(demo_dir, f"{stem}_aug.wav"), waveform_aug, clip_rate)
    torchaudio.save(os.path.join(demo_dir, f"{stem}_org.wav"), waveform, clip_rate)

    if i == 1:
        for title, signal in (("Original waveform", waveform), ("Augmented waveform", waveform_aug)):
            plt.figure()
            plt.plot(signal.t().numpy())
            plt.title(title)
            plt.show()

In [None]:
# Check the data augmentation functions - Add white noise
# Writes the first 10 training clips (original + augmented) to disk and plots
# the second pair for a quick visual check.
demo_dir = os.path.join(output_root, "addWhiteNoise")
os.makedirs(demo_dir, exist_ok=True)  # loop-invariant: create the folder once
for i in range(min(10, len(train_set))):
    # Unpack the rate as `clip_rate` so the module-level `sample_rate` that the
    # augmentation functions read is not rebound by this loop.
    waveform, clip_rate, label, speaker_id, utterance_number = train_set[i]
    waveform_aug = add_white_noise(waveform)
    # save the augmented waveform and original waveform
    stem = f"{speaker_id}_{utterance_number}_{label}"
    torchaudio.save(os.path.join(demo_dir, f"{stem}_aug.wav"), waveform_aug, clip_rate)
    torchaudio.save(os.path.join(demo_dir, f"{stem}_org.wav"), waveform, clip_rate)

    if i == 1:
        for title, signal in (("Original waveform", waveform), ("Augmented waveform", waveform_aug)):
            plt.figure()
            plt.plot(signal.t().numpy())
            plt.title(title)
            plt.show()

In [None]:
# Check the data augmentation functions - Time mask
# Writes the first 10 training clips (original + augmented) to disk and plots
# the second pair for a quick visual check.
demo_dir = os.path.join(output_root, "timeMask")
os.makedirs(demo_dir, exist_ok=True)  # loop-invariant: create the folder once
for i in range(min(10, len(train_set))):
    # Unpack the rate as `clip_rate` so the module-level `sample_rate` that the
    # augmentation functions read is not rebound by this loop.
    waveform, clip_rate, label, speaker_id, utterance_number = train_set[i]
    waveform_aug = time_mask(waveform)
    # save the augmented waveform and original waveform
    stem = f"{speaker_id}_{utterance_number}_{label}"
    torchaudio.save(os.path.join(demo_dir, f"{stem}_aug.wav"), waveform_aug, clip_rate)
    torchaudio.save(os.path.join(demo_dir, f"{stem}_org.wav"), waveform, clip_rate)

    if i == 1:
        for title, signal in (("Original waveform", waveform), ("Augmented waveform", waveform_aug)):
            plt.figure()
            plt.plot(signal.t().numpy())
            plt.title(title)
            plt.show()

In [None]:
# Check the data augmentation functions - Augment RIR
# Writes the first 10 training clips (original + augmented) to disk and plots
# the second pair for a quick visual check.
demo_dir = os.path.join(output_root, "augmentRIR")
os.makedirs(demo_dir, exist_ok=True)  # loop-invariant: create the folder once
for i in range(min(10, len(train_set))):
    # Unpack the rate as `clip_rate` so the module-level `sample_rate` that the
    # augmentation functions read is not rebound by this loop.
    waveform, clip_rate, label, speaker_id, utterance_number = train_set[i]
    waveform_aug = augment_rir(waveform)
    # save the augmented waveform and original waveform
    stem = f"{speaker_id}_{utterance_number}_{label}"
    torchaudio.save(os.path.join(demo_dir, f"{stem}_aug.wav"), waveform_aug, clip_rate)
    torchaudio.save(os.path.join(demo_dir, f"{stem}_org.wav"), waveform, clip_rate)

    if i == 1:
        for title, signal in (("Original waveform", waveform), ("Augmented waveform", waveform_aug)):
            plt.figure()
            plt.plot(signal.t().numpy())
            plt.title(title)
            plt.show()