In [39]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
import torchaudio.functional as F
import matplotlib.pyplot as plt
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base")

import warnings
warnings.filterwarnings("ignore")

In [113]:
class LibriSpeechDataset(Dataset):

    """
    LibriSpeechDataset downloaded from OpenSLR: https://www.openslr.org/12

    There are 5 splits downloaded, 3 which are for training and 3 for testing:

        Training: ["train-clean-100", "train-clean-360", "train-other-500"]
        Validation: ["dev-clean", "test-clean"]

    """
    def __init__(self, 
                 path_to_data_root, 
                 include_splits=["train-clean-100", "train-clean-360", "train-other-500"],
                 sampling_rate=16000,
                 num_audio_channels=1):
        
        if isinstance(include_splits, str):
            include_splits = [include_splits]

        self.sampling_rate = sampling_rate
        self.num_audio_channels = num_audio_channels

        ### GET PATH TO ALL AUDIO/TEXT FILES ###
        self.librispeech_data = []
        for split in include_splits:
            path_to_split = os.path.join(path_to_data_root, split)
            
            for speaker in os.listdir(path_to_split):
                path_to_speaker = os.path.join(path_to_split, speaker)

                for section in os.listdir(path_to_speaker):
                    path_to_section = os.path.join(path_to_speaker, section)

                    ### Grab Files and Split FLAC Audios and Text Transcripts ###
                    files = os.listdir(path_to_section)
                    transcript_file = [path for path in files if ".txt" in path][0]

                    ### Load Transcripts ###
                    with open(os.path.join(path_to_section, transcript_file), "r") as f:
                        transcripts = f.readlines()

                    ### Split Transcripts by Audio Filename and Transcript ###
                    for line in transcripts:
                        split_line = line.split()
                        audio_root = split_line[0]
                        audio_file = audio_root + ".flac"
                        full_path_to_audio_file = os.path.join(path_to_section, audio_file)
                        transcript = " ".join(split_line[1:]).strip()

                        self.librispeech_data.append((full_path_to_audio_file, transcript))
   
        self.audio2mels =  T.MelSpectrogram(
            sample_rate=sampling_rate,
            n_mels=80
        )

        self.amp2db = T.AmplitudeToDB(
            top_db=80.0
        )
        
    def __len__(self):
        return len(self.librispeech_data)
    
    def __getitem__(self, idx):
        
        ### Grab Path to Audio and Transcript ###
        path_to_audio, transcript = self.librispeech_data[idx]

        ### Load Audio ###
        audio, orig_sr = torchaudio.load(path_to_audio, normalize=True)

        if orig_sr != self.sampling_rate:
            audio = torchaudio.functional.resample(audio, orig_freq=orig_sr, new_freq=self.sampling_rate)
        
        ### Create Mel Spectrogram ###
        mel = self.audio2mels(audio)

        ### Convert to Decibels ###
        mel = self.amp2db(mel)

        ### Normalize Spectrogram ###
        mel = (mel - mel.mean()) / (mel.std() + 1e-6)

        ### Tokenize Text ###
        tokenized_transcript = torch.tensor(tokenizer.encode(transcript))

        batch = {"input_values": mel[0].T, 
                 "labels": tokenized_transcript}
        
        return batch


def collate_fn(batch):

    """
    This collate function is basically the heart of our implementation! It includes everything we need for training
    such as attention masks, sub_attention_masks, span_masks and our sampled negatives!
    """

    ### Sort Batch from Longest to Shortest (for future packed padding) ###
    batch = sorted(batch, key=lambda x: x["input_values"].shape[0], reverse=True)
    
    ### Grab Audios from our Batch Dictionary ###
    batch_mels = [sample["input_values"] for sample in batch]
    batch_transcripts = [sample["labels"] for sample in batch]

    ### Get Length of Audios ###
    seq_lens = torch.tensor([b.shape[0] for b in batch_mels], dtype=torch.long)

    ### Pad and Stack Spectrograms ###
    spectrograms = torch.nn.utils.rnn.pad_sequence(batch_mels, batch_first=True, padding_value=0)

    ### Convert to Shape Convolution Is Happy With (B x C x H x W) ###
    spectrograms = spectrograms.unsqueeze(1).transpose(-1,-2)

    ### Get Target Lengths ###
    target_lengths = torch.tensor([len(t) for t in batch_transcripts], dtype=torch.long)

    ### Pack Transcripts (CTC Loss Can Take Packed Targets) ###
    packed_transcripts = torch.cat(batch_transcripts)

    ### Create Batch ###
    batch = {"input_values": spectrograms, 
             "seq_lens": seq_lens, 
             "labels": packed_transcripts, 
             "target_lengths": target_lengths}

    return batch
    
dataset = LibriSpeechDataset(path_to_data_root="/mnt/datadrive/data/LibriSpeech", include_splits=["train-clean-100"])
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
for batch in loader:
    print(batch)
    break



{'input_values': tensor([[[[-0.2179,  0.7754,  1.0092,  ...,  1.0028,  0.8619,  0.7192],
          [ 0.2044,  1.1977,  1.4315,  ...,  1.4250,  1.2841,  1.1415],
          [ 1.1554,  1.1019,  1.2418,  ...,  1.2349,  1.1558,  1.1624],
          ...,
          [-0.5171, -0.7380, -0.8145,  ..., -1.1671, -1.2113, -1.4759],
          [-1.0369, -0.8433, -0.8466,  ..., -1.1060, -1.2173, -1.4674],
          [-0.8981, -0.9524, -0.7986,  ..., -0.9846, -0.9079, -1.1360]]],


        [[[ 0.5506,  0.8957,  0.7771,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.9604,  1.3054,  1.1869,  ...,  0.0000,  0.0000,  0.0000],
          [ 1.1985,  1.1731,  1.0924,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [-1.0183, -0.5068, -0.7741,  ...,  0.0000,  0.0000,  0.0000],
          [-0.9513, -1.2879, -1.1026,  ...,  0.0000,  0.0000,  0.0000],
          [-1.0110, -1.2289, -1.0970,  ...,  0.0000,  0.0000,  0.0000]]],


        [[[ 0.7977,  1.0251,  1.0864,  ...,  0.0000,  0.0000,  0.0000],
         

In [183]:
class MaskedConv2d(nn.Conv2d):

    """
    Our spectrograms are padded, so different spectrograms will have 
    a different length. We need to make sure we dont include any padding information
    in our convolution, and update padding masks for the next convolution in the stack!

    Args:
        
    """

    def __init__(self, 
                 in_channels,
                 out_channels, 
                 kernel_size,
                 stride, 
                 padding=0,
                 bias=True,
                 **kwargs):
        
        super(MaskedConv2d, self).__init__(in_channels=in_channels, 
                                           out_channels=out_channels, 
                                           kernel_size=kernel_size, 
                                           stride=stride, 
                                           padding=padding, 
                                           bias=bias, 
                                           **kwargs)

    def forward(self, x, seq_lens):

        """
        Updates convolution forward to zero out padding regions after convolution 
        """

        batch_size, channels, height, width = x.shape
        
        ### Compute Output Seq Lengths of Each Sample After Convolution ###
        output_seq_lens = self._compute_output_seq_len(seq_lens)

        ### Pass Data Through Convolution ###
        conv_out = super().forward(x)

        ### Zero Out Any Values In The Padding Region (After Convolution) So they Dont Contribute ###
        mask = torch.zeros(batch_size, output_seq_lens.max(), device=x.device)
        for i, length in enumerate(output_seq_lens):
            mask[i, :length] = 1

        ### Unsqueeze mask to match image shape ###
        mask = mask.unsqueeze(1).unsqueeze(1)

        ### Apply Mask ###
        conv_out = conv_out * mask

        return conv_out, output_seq_lens

    def _compute_output_seq_len(self, seq_lens):

        """
        To perform masking AFTER the encoding 2D Convolutions, we need to 
        compute what the shape of the output tensor is after each successive convolutions
        is applied.
    
        Convolution formula can be found in PyTorch Docs: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            
        """
        
        return torch.floor((seq_lens + (2 * self.padding[1]) - (self.kernel_size[1] - 1) - 1) // self.stride[1]) + 1

class ConvolutionFeatureExtractor(nn.Module):

    def __init__(self, 
                 in_channels=1, 
                 out_channels=32):

        super(ConvolutionFeatureExtractor, self).__init__()

        self.in_channels = in_channels, 
        self.out_channels = out_channels

        self.conv1 = MaskedConv2d(in_channels, out_channels, kernel_size=(11, 41), stride=(2,2), padding=(5,20), bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        self.conv2 = MaskedConv2d(out_channels, out_channels, kernel_size=(11, 21), stride=(2,1), padding=(5,10), bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        ### I just preset this, I know after I apply my two convolutions above with a kernel size of 
        ### 11, that my output feature vectors will go from 80 down to 20! I could also dynamically
        ### compute this but its probably fine, I wont be changing the data after all!
        self.output_feature_dim = 20

        ### Compute Final Output Features ###
        self.conv_output_features = self.output_feature_dim * self.out_channels
        
    def forward(self, x, seq_lens):

        x, seq_lens = self.conv1(x, seq_lens)
        x = self.bn1(x)
        x = torch.nn.functional.hardtanh(x)
        
        x, seq_lens = self.conv2(x, seq_lens)
        x = self.bn2(x)
        x = torch.nn.functional.hardtanh(x)
 
        x = x.permute(0,3,1,2).flatten(2)

        return x, seq_lens 

class RNNLayer(nn.Module):

    def __init__(self, 
                 input_size,
                 hidden_size = 512, 
                 dropout_p=0.1):

        super(RNNLayer, self).__init__()

        self.hidden_dim = hidden_size
        self.input_size = input_size

        self.rnn = nn.GRU(
            input_size=input_size, 
            hidden_size=hidden_size, 
            batch_first=True, 
            bidirectional=True
        )

        self.layernorm = nn.LayerNorm(2 * hidden_size)

    def forward(self, x, seq_lens):

        batch, seq_len, embed_dim = x.shape 
        
        ### Pack Sequence (For efficient computation that ignores padding) ###
        packed_x = nn.utils.rnn.pack_padded_sequence(x, seq_lens, batch_first=True)

        ### Pass Packed Sequence through RNN ###
        out, _ = self.rnn(packed_x)

        ### Unpack (and repad) sequence ###
        x, _ = nn.utils.rnn.pad_packed_sequence(out, total_length=seq_len, batch_first=True)

        ### Normalize ###
        x = self.layernorm(x)

        return x
        
class DeepSpeech2(nn.Module):

    def __init__(self,
                 conv_in_channels=1, 
                 conv_out_channels=32, 
                 rnn_hidden_size=512,
                 rnn_dropout_p=0.1):

        super(DeepSpeech2, self).__init__()

        self.feature_extractor = ConvolutionFeatureExtractor(
            conv_in_channels, conv_out_channels
        )

        self.output_hidden_features = self.feature_extractor.conv_output_features

        ### Stack Together RNN Layers ###
        ### First Layer has 640 inputs, everything after has 2 * 512 inputs ###
        self.rnns = nn.ModuleList(
            [
                RNNLayer(self.output_hidden_features if i==0 else 2 * rnn_hidden_size,
                         hidden_size=rnn_hidden_size,
                         dropout_p=rnn_dropout_p)
                for i in range(6)
            ]
        )

        ### Classification Head ###
        self.head = nn.Sequential(
            nn.Linear(2 * rnn_hidden_size, rnn_hidden_size), 
            nn.Hardtanh(), 
            nn.Linear(rnn_hidden_size, tokenizer.vocab_size)
        )

    def forward(self, x, seq_lens):

        ### Extract Features ###
        x, final_seq_lens = self.feature_extractor(x, seq_lens)

        ### Pass To RNN Layers ###
        for rnn in self.rnns:
            x = rnn(x, final_seq_lens)

        ### Classification Head ###
        x = self.head(x)

        return x, final_seq_lens

    

model = DeepSpeech2()
model(batch["input_values"], batch["seq_lens"])
# rnnlayer = RNNLayer(640)
# rnnlayer(x, seq_lens)
        
    


(tensor([[[-0.3361, -0.2026,  0.3064,  ...,  0.3774, -0.2509, -0.3006],
          [-0.2287, -0.0897,  0.2544,  ...,  0.3286, -0.2201, -0.1948],
          [-0.1910,  0.0636,  0.2897,  ...,  0.2103, -0.1194, -0.1016],
          ...,
          [-0.0449,  0.6207,  0.2410,  ...,  0.3058,  0.5405, -0.2013],
          [ 0.0043,  0.6845,  0.1449,  ...,  0.3928,  0.5194, -0.1913],
          [ 0.0208,  0.6529,  0.0241,  ...,  0.3828,  0.3839, -0.1984]],
 
         [[ 0.0476, -0.6017,  0.0141,  ...,  0.3220,  0.2874, -0.1674],
          [-0.0563, -0.5016, -0.1267,  ...,  0.3000,  0.4080, -0.0895],
          [-0.0313, -0.3985, -0.2119,  ...,  0.2242,  0.4900,  0.0284],
          ...,
          [-0.0128,  0.0306, -0.0057,  ..., -0.0099,  0.0274,  0.0149],
          [-0.0128,  0.0306, -0.0057,  ..., -0.0099,  0.0274,  0.0149],
          [-0.0128,  0.0306, -0.0057,  ..., -0.0099,  0.0274,  0.0149]],
 
         [[-0.3252,  0.1422, -0.1062,  ...,  0.3323,  0.1096, -0.1799],
          [-0.2652,  0.1368,

In [172]:
tokenizer.vocab_size

32

torch.Size([4, 32, 20, 638])
