In [1]:
!pip3 install Levenshtein

Collecting Levenshtein
  Downloading levenshtein-0.27.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.6 kB)
Collecting rapidfuzz<4.0.0,>=3.9.0 (from Levenshtein)
  Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading levenshtein-0.27.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (161 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m161.7/161.7 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m43.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: rapidfuzz, Levenshtein
Successfully installed Levenshtein-0.27.1 rapidfuzz-3.13.0


In [2]:
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torchaudio
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import Levenshtein
import math

In [3]:
!ls /kaggle/input/

librispeech-datasets


In [4]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:

dataset_path = "/kaggle/input/librispeech-datasets/"
def download(dataset):
    audioset = torchaudio.datasets.LIBRISPEECH(dataset_path+'/'+dataset,url=dataset,download=False)
    return audioset

In [6]:
train_clean = download('train-clean-100')
dev_clean = download('dev-clean')
test_clean = download('test-clean')

In [7]:
class Unigram():
    def __init__(self,word_model_type):
        self.word_model_type = 'unigram'
        self.blank_id = 28
        self.n_class = 29

        self.SPACE = "[space]"
        self.characters = "' " + self.SPACE + " " +" ".join("abcdefghijklmnopqrstuvwxyz")
        self.tokens = self.characters.split(' ')

        self.char_to_id = {char: idx for idx, char in enumerate(self.tokens)}
        self.id_to_char = {idx: char for idx, char in enumerate(self.tokens)}

    def text_to_int(self, sentence: str):
        idx_sequence = []
        for ch in sentence:
            idx = self.char_to_id[self.SPACE] if ch == " " else self.char_to_id[ch]
            idx_sequence.append(idx)
        return idx_sequence

    def int_to_text(self, indices):
        sentence = []
        for i in indices:
            ch = self.id_to_char[i]
            sentence.append(ch)
        return "".join(sentence).replace(self.SPACE, " ")


word_encoding_model = Unigram('unigram')
original = "my name is olan"
encoded = word_encoding_model.text_to_int(original)
reconstructed = word_encoding_model.int_to_text(encoded)
print(original)
print(encoded)
print(reconstructed)

my name is olan
[14, 26, 1, 15, 2, 14, 6, 1, 10, 20, 1, 16, 13, 2, 15]
my name is olan


In [8]:
def preprocess(audioset,split,stride,word_model):

        if split[:5] == "train":
            train_pipe = nn.Sequential(
                torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
                torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
                torchaudio.transforms.TimeMasking(time_mask_param=35))
            augment_fn = train_pipe
        else:
            test_pipe = torchaudio.transforms.MelSpectrogram()
            augment_fn = test_pipe
      
        spectrograms = []
        indices = []
        len_spectrograms = []
        len_indices = []
        
        for waveform, _, transcript, _, _, _ in audioset:
            # Augment audio data
            spec = augment_fn(waveform).squeeze(0).transpose(0,1)
            spectrograms.append(spec)

            # Convert text transcript to sequence of ids
            ids = torch.Tensor(word_model.text_to_int(transcript.lower()))
            indices.append(ids)

            # Append audio and text length
            if stride == 2:
                len_spec = spec.shape[0]//stride
            else:
                len_spec = spec.shape[0]//stride - 2
            
            len_spectrograms.append(len_spec)
            len_indices.append(len(ids))
        
        # Zero pad
        spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
        indices = nn.utils.rnn.pad_sequence(indices, batch_first=True)

        return spectrograms, indices, len_spectrograms, len_indices

In [9]:
def reading_data_sample(loader):
    print("Data length : ",len(loader.dataset))
    for batch_sample in loader:
        print("Spectrogram shape:", list(batch_sample[0].shape))
        print("Label shape:", list(batch_sample[1].shape))
        print("Mel length (length of each spectrogram):", batch_sample[2][:6], "...")
        print("Idx length (length of each label):", batch_sample[3][:6], "...")
        break

In [10]:
# General Hyper-params
batch_size = 8
epochs =10

n_features = 128 
stride = 2      

lr = 0.0005

In [11]:
train_loader  = DataLoader(dataset=train_clean,
                               batch_size=batch_size,
                               shuffle=False,
                               collate_fn=lambda x: preprocess(x, "train-clean-100", stride, word_encoding_model))


dev_clean_loader = DataLoader(dataset=dev_clean,
                               batch_size=batch_size,
                               shuffle=False,
                               collate_fn=lambda x: preprocess(x, "dev-clean", stride, word_encoding_model))


test_clean_loader = DataLoader(dataset=test_clean,
                               batch_size=batch_size,
                               shuffle=False,
                               collate_fn=lambda x: preprocess(x, "test-clean", stride, word_encoding_model))

In [12]:
print(f"Train : {len(train_loader.dataset)} samples ")
print(f"Dev Clean : {len(dev_clean_loader.dataset)} samples ")
print(f"Test Clean : {len(test_clean_loader.dataset)} samples ")
print()

reading_data_sample(train_loader)

Train : 28539 samples 
Dev Clean : 2703 samples 
Test Clean : 2620 samples 

Data length :  28539
Spectrogram shape: [8, 1, 128, 1276]
Label shape: [8, 283]
Mel length (length of each spectrogram): [563, 638, 558, 588, 501, 607] ...
Idx length (length of each label): [201, 283, 250, 268, 227, 263] ...




In [13]:
class ConformerMHSA(nn.Module):
    def __init__(self, num_features, device, num_heads, max_rel_pos=800, drop_rate=0.1):

        super(ConformerMHSA, self).__init__()

        self.emb_dim = num_features
        self.num_heads = num_heads
        self.max_rel_pos = max_rel_pos

        self.norm = nn.LayerNorm(num_features)
        self.attention = nn.MultiheadAttention(num_features, num_heads, batch_first=True)
        self.dropout = nn.Dropout(p=drop_rate)
        
        self.device = device
        self.pos_matrix = self.get_positional_matrix().to(self.device)


    def get_positional_matrix(self):
        """
        Create positional matrix of shape (2*self.max_rel_pos + 1, emb_dim)
        Only (:seq_len, emb_dim) will be summed to input tensors
        """
        matrix = torch.zeros(2*self.max_rel_pos + 1, self.emb_dim)

        pos = torch.arange(0, 2*self.max_rel_pos + 1).unsqueeze(1).float()
        divisor = torch.exp(torch.arange(0, self.emb_dim, 2).float() * -math.log(10000) / self.emb_dim)
        
        matrix[:, 0::2] = torch.sin(pos*divisor)
        matrix[:, 1::2] = torch.cos(pos*divisor)
        final_matrix = matrix.unsqueeze(0)

        return final_matrix

    def forward(self, x):
        # Input shape: [batch, seq_len, num_features]
        batch_size, seq_len, _ = x.size()

        skip = x

        x = self.norm(x)
        pos_emb = self.pos_matrix[:, :seq_len, :].expand(batch_size, seq_len, self.emb_dim)
        x += pos_emb
        x, _ = self.attention(x, x, x)
        x = self.dropout(x)

        x += skip

        return x

class ConformerConv(nn.Module):
    def __init__(self, num_features, kernel_size, exp_factor=2, drop_rate=0.1):

        super(ConformerConv, self).__init__()

        self.layer_norm = nn.LayerNorm(num_features)
        self.point_conv_1 = nn.Conv1d(in_channels=num_features,
                                      out_channels=num_features*exp_factor,
                                      kernel_size=1)
        self.glu = nn.GLU(dim=1)
        self.depth_conv = nn.Conv1d(in_channels=num_features,
                                    out_channels=num_features,
                                    kernel_size=kernel_size,
                                    padding=(kernel_size-1)//2,
                                    groups=num_features)
        
        self.batch_norm = nn.BatchNorm1d(num_features)
        self.swish = nn.SiLU()
        self.point_conv_2 = nn.Conv1d(in_channels=num_features,
                                      out_channels=num_features,
                                      kernel_size=1)
        self.dropout = nn.Dropout(p=drop_rate)

    def forward(self, x):

        skip = x # [batch_size, seq_len, num_features] 
        x = self.layer_norm(x)

        x = x.transpose(1, 2).contiguous() # [batch_size, num_features, seq_len] 

        x = self.point_conv_1(x)
        x = self.glu(x)
        x = self.depth_conv(x)
        x = self.batch_norm(x)
        x = self.swish(x)
        x = self.point_conv_2(x)
        x = self.dropout(x)

        x = x.transpose(1, 2).contiguous() # [batch_size, seq_len, num_features]
        x += skip

        return x

class ConformerFFN(nn.Module):

    def __init__(self, num_features, exp_factor=4, drop_rate=0.1):

        super(ConformerFFN, self).__init__()

        self.norm = nn.LayerNorm(num_features)

        self.linear_1 = nn.Linear(num_features, num_features*exp_factor)
        self.swish = nn.SiLU()
        self.dropout_1 = nn.Dropout(p=drop_rate)

        self.linear_2 = nn.Linear(num_features*exp_factor, num_features)
        self.dropout_2 = nn.Dropout(p=drop_rate)
    
    def forward(self, x):
        
        skip = x # [batch_size, seq_len, num_features]

        x = self.norm(x)
        
        x = self.linear_1(x) # [batch_size, seq_len, num_features * exp_factor]
        x = self.swish(x)
        x = self.dropout_1(x)

        x = self.linear_2(x) # [batch_size, seq_len, num_features]
        x = self.dropout_2(x)

        x = skip + 1/2 * x

        return x

class PostProcess(nn.Module):

    def __init__(self, encoder_dim, hidden_size, n_class):

        super(PostProcess, self).__init__()

        self.lstm = nn.LSTM(input_size=encoder_dim,
                            hidden_size=hidden_size,
                            num_layers=1,
                            batch_first=True)
        self.scoring = nn.Linear(in_features=hidden_size, out_features=n_class)
        
    def forward(self, x):
        x, _ = self.lstm(x)
        x = self.scoring(x)
        return x


class ConvSubsampling(nn.Module):

    def __init__(self, out_channels):
        
        super(ConvSubsampling, self).__init__()
    
        self.sub_stack = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=out_channels,
                      kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels,
                      out_channels=out_channels,
                      kernel_size=3, stride=1, padding='same'),
            nn.ReLU()
        )
    
    def forward(self, x):
        x = self.sub_stack(x)

        batch_size, channels, seq_len, num_features = x.size()
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, seq_len, channels*num_features)

        return x

class PreProcess(nn.Module):
    def __init__(self, in_features, encoder_dim, drop_rate=0.1):

        super(PreProcess, self).__init__()

        self.out_features = self.get_out_features(in_features)

        self.conv_sub = ConvSubsampling(out_channels=in_features)
        self.linear = nn.Linear(in_features=self.out_features, out_features=encoder_dim)
        self.dropout = nn.Dropout(p=drop_rate)
    
    def get_out_features(self, in_features):
        ans = in_features * in_features // 2
        return ans

    def forward(self, x):
        x = self.conv_sub(x)
        x = self.linear(x)
        x = self.dropout(x)
        return x

In [14]:
class ConformerBlock(nn.Module):
    def __init__(self, encoder_dim, num_heads, kernel_size, device):
        super(ConformerBlock, self).__init__()
        self.feed_forward_1 = ConformerFFN(num_features=encoder_dim)
        self.attention = ConformerMHSA(num_features=encoder_dim, device=device, num_heads=num_heads)
        self.convolution = ConformerConv(num_features=encoder_dim, kernel_size=kernel_size)
        self.feed_forward_2 = ConformerFFN(num_features=encoder_dim)
        self.norm = nn.LayerNorm(normalized_shape=encoder_dim)
    
    def forward(self, x):
        x = self.feed_forward_1(x)
        x = self.attention(x)
        x = self.convolution(x)
        x = self.feed_forward_2(x)
        x = self.norm(x)

        return x

class Conformer(nn.Module):
    def __init__(self,
                 in_features,
                 encoder_dim,
                 num_heads,
                 kernel_size,
                 hidden_size,
                 n_class,
                 n_blocks,
                 device):

        super(Conformer, self).__init__()

        self.pre_process = PreProcess(in_features=in_features, encoder_dim=encoder_dim)
        self.conformer_stack = nn.Sequential(
            *[ConformerBlock(encoder_dim=encoder_dim,
                             num_heads=num_heads,
                             kernel_size=kernel_size,
                             device=device)
              for _ in range(n_blocks)])
        self.post_process = PostProcess(encoder_dim=encoder_dim, hidden_size=hidden_size, n_class=n_class)

    def forward(self, x):
        """
        Input:  [batch_size, 1, seq_len, num_features]
        Output: [batch_size, seq_len, n_classes]
        """
        x = x.transpose(2, 3).contiguous()

        x = self.pre_process(x)
        x = self.conformer_stack(x)
        x = self.post_process(x)

        return x

In [15]:
cm3 = Conformer(in_features=128,
                      encoder_dim=256,
                      num_heads=4,
                      kernel_size=31,
                      hidden_size=320,
                      n_class=29,
                      n_blocks=16,
                      device=device).to(device)

tot_params = sum([p.numel() for p in cm3.parameters()])
print(f"Number of parameters: {tot_params}")

Number of parameters: 27362525


In [16]:
model_to_train = cm3
adamW = optim.AdamW(model_to_train.parameters(), lr)
ctc_loss = nn.CTCLoss(blank=word_encoding_model.blank_id).to(device)
one_cycle_lr = optim.lr_scheduler.OneCycleLR(adamW,
                                             max_lr=lr,
                                             steps_per_epoch=int(len(train_loader)),
                                             epochs=epochs,
                                             anneal_strategy="linear")

In [17]:
def train(epoch, dataset_loader, model, optimizer, scheduler, fn_loss):
    print(f"Traininig... (e={epoch})")
    
    # Train mode ON
    model.train()
    total_train_loss = 0  #tracking loss
    n_samples = int(len(dataset_loader.dataset))

    for idx, audio_data in enumerate(dataset_loader):
        
        # Get audio data with shape [batch, 1, n_features, seq_len]
        spectrograms, indices, len_spectrograms, len_indices = audio_data
        spectrograms, indices = spectrograms.to(device), indices.to(device)

        optimizer.zero_grad()

        # Forward pass
        out = model(spectrograms)
        out = F.log_softmax(out, dim=2)
        out = out.transpose(0, 1)
        
        # Backward pass
        loss = fn_loss(out, indices, len_spectrograms, len_indices)
        loss.backward()

        # Step
        optimizer.step()
        scheduler.step()

        total_train_loss += loss.item()
        

        # Log
        if idx % 20 == 0 or idx == n_samples:
            print("Epoch: {}, [{}/{}], Loss: {:.6f}".format(
                epoch, 
                idx*len(spectrograms), 
                n_samples,
                loss.item()))

    avg_train_loss = total_train_loss/n_samples
    return avg_train_loss
    

In [18]:
# Computes Word Error Rate
def compute_wer(hypothesis, reference):
    hypothesis_words = hypothesis.split()
    reference_words = reference.split()
    
    wer = Levenshtein.distance(hypothesis_words, reference_words) / len(reference_words)
    return wer

# Computes Character Error Rate
def compute_cer(hypothesis, reference):

    cer = Levenshtein.distance(hypothesis, reference) / len(reference)
    return cer

#     Calculates CER for each hyp-ref pair, and returns the average
def avg_cer(batch_hyp, batch_ref):
    batch_size = len(batch_ref)
    out = []
    for i in range(batch_size):
        out.append(compute_cer(batch_hyp[i], batch_ref[i]))
    
    return sum(out) / batch_size

#     Calculates WER for each hyp-ref pair, and returns the average
def avg_wer(batch_hyp, batch_ref):
    batch_size = len(batch_ref)
    out = []
    for i in range(batch_size):
        out.append(compute_wer(batch_hyp[i], batch_ref[i]))
    
    return sum(out) / batch_size

In [19]:
def decode_prob(prob, word_encoding_model):
    """
    Decodes (a batch of) log-probabilities
    into characters.
    prob -> shape (e.g.) [16, 650, 29]
    """
    prob = torch.transpose(prob, 0, 1)
    arg_maxes = torch.argmax(prob, dim=-1) # [16, 650]
    decodes = []

    for i, args in enumerate(arg_maxes):
        decode = []
        for j, index in enumerate(args):
             # ignore blank id
            if index == word_encoding_model.blank_id:
                continue
            # avoid repetitions
            if j != 0 and index == args[j-1]:
                continue
            decode.append(index.item())
        decodes.append(word_encoding_model.int_to_text(decode))
    return decodes


def decode_labels(indices, len_indices, word_encoding_model):
    """
    Decodes (a batch of) ids into characters.
    indices -> shape: [32, 300]
    len_indices -> shape: [32]
    word_model -> tool to convert idx into chars
    """
    out = []
    for i, ids in enumerate(indices):
        len_ids = len_indices[i]
        unpad_ids = ids[:len_ids]
        out.append(word_encoding_model.int_to_text(unpad_ids.tolist()))
    return out

In [20]:
def test(epoch, dataset_name, dataset_loader, model, optimizer, fn_loss, debug=False):
    print(f"Testing on {dataset_name} (epoch={epoch})")
    model.eval()

    total_loss = 0
    wer_list = []
    cer_list = []

    n_batch = int(len(dataset_loader))

    with torch.no_grad():
        for idx, audio_data in enumerate(dataset_loader):
        
            # Get audio data
            spectrograms, indices, len_spectrograms, len_indices = audio_data
            spectrograms, indices = spectrograms.to(device), indices.to(device)

            optimizer.zero_grad()
            
            # Forward pass
            out = model(spectrograms)
            out = F.log_softmax(out, dim=2)
            out = out.transpose(0, 1)

            # Compute loss
            loss = fn_loss(out, indices, len_spectrograms, len_indices)
            total_loss += loss.item() / n_batch

            # Metrics
            decode_hypothesis = decode_prob(out, word_encoding_model)
            decode_reference = decode_labels(indices, len_indices, word_encoding_model)

            wer_list.append(avg_wer(decode_hypothesis, decode_reference))
            cer_list.append(avg_cer(decode_hypothesis, decode_reference))
            
            
    print(f"Loss: {total_loss:.6f}")
    print(f"WER: {sum(wer_list)/len(wer_list):.4f}")
    print(f"CER: {sum(cer_list)/len(cer_list):.4f}")

    avg_test_wer = sum(wer_list) / len(wer_list)
    avg_test_cer = sum(cer_list) / len(cer_list)

    return total_loss, avg_test_wer, avg_test_cer

In [None]:
import time

train_loss = []
dev_loss = []
dev_wer = []
dev_cer = []
total_train_time=0
for epoch in range(1, epochs+1):
    #calculate train time
    training_start = time.time()
    
    train_epoch_loss = train(epoch, train_loader, model_to_train, adamW, one_cycle_lr, ctc_loss)

    training_end = time.time()

    training_time = (training_end - training_start)/60
    total_train_time+=training_time
    print(f"Epoch {epoch} training time: {training_time:.2f} minutes")

    train_loss.append(train_epoch_loss)

    #dev
    dev_epoch_loss, dev_epoch_wer, dev_epoch_cer = test(epoch, "dev-clean", dev_clean_loader, model_to_train, adamW, ctc_loss)
    dev_loss.append(dev_epoch_loss)
    dev_wer.append(dev_epoch_wer)
    dev_cer.append(dev_epoch_cer)
print("Total training time : ",total_train_time," minutes")
print("Average time taken per epoch : ",total_train_time/epochs," minutes")

Traininig... (e=1)




Epoch: 1, [0/28539], Loss: 5.770763
Epoch: 1, [160/28539], Loss: 4.938419
Epoch: 1, [320/28539], Loss: 3.548324
Epoch: 1, [480/28539], Loss: 3.191332
Epoch: 1, [640/28539], Loss: 3.118243
Epoch: 1, [800/28539], Loss: 2.987504
Epoch: 1, [960/28539], Loss: 3.000678
Epoch: 1, [1120/28539], Loss: 2.957337
Epoch: 1, [1280/28539], Loss: 2.912342
Epoch: 1, [1440/28539], Loss: 2.912927
Epoch: 1, [1600/28539], Loss: 2.861511
Epoch: 1, [1760/28539], Loss: 3.064484
Epoch: 1, [1920/28539], Loss: 3.000016
Epoch: 1, [2080/28539], Loss: 2.889499
Epoch: 1, [2240/28539], Loss: 2.921304
Epoch: 1, [2400/28539], Loss: 2.914713
Epoch: 1, [2560/28539], Loss: 2.902354
Epoch: 1, [2720/28539], Loss: 2.897294
Epoch: 1, [2880/28539], Loss: 2.906665
Epoch: 1, [3040/28539], Loss: 2.913656
Epoch: 1, [3200/28539], Loss: 2.869197
Epoch: 1, [3360/28539], Loss: 2.862186
Epoch: 1, [3520/28539], Loss: 2.870985
Epoch: 1, [3680/28539], Loss: 2.848508
Epoch: 1, [3840/28539], Loss: 2.857217
Epoch: 1, [4000/28539], Loss: 2.86

In [None]:
import matplotlib.pyplot as plt
plt.plot(range(1, epochs + 1), train_loss, label='Training Loss')
plt.plot(range(1, epochs + 1), dev_loss, label='Dev Loss', linestyle='--')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss curves for Training and Validation sets')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
#WER curve
plt.plot(range(1, epochs + 1), dev_wer, label='WER')
plt.xlabel('Epochs')
plt.ylabel('WER')
plt.title('WER curve for Validation set')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
#CER curve
plt.plot(range(1, epochs + 1), dev_cer, label='CER')
plt.xlabel('Epochs')
plt.ylabel('CER')
plt.title('CER curve for Validation set')
plt.legend()
plt.grid(True)
plt.show()