In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torch.utils.data as data
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import string
from tqdm import tqdm
%pip install Levenshtein
import Levenshtein

%config Completer.use_jedi = False


In [None]:

class TextPreProcess:
    """Map transcript characters to integer labels and back.

    Vocabulary (28 symbols): index 0 is the apostrophe ``'``, index 1 is the word
    separator ``<SPACE>`` (rendered as ``" "`` when decoding) and indices 2-27 are
    ``a``-``z``. Index 28 is not produced here; elsewhere in this notebook it is the
    CTC blank.
    """

    def __init__(self):
        # A single apostrophe: the key must match the character that appears in text,
        # otherwise text_to_int never finds it and maps it to <SPACE>.
        self.vocab = ["'", "<SPACE>"] + list(string.ascii_lowercase)
        self.char_map = {ch: i for i, ch in enumerate(self.vocab)}
        self.index_map = {i: ch for ch, i in self.char_map.items()}
        # Decode the separator token as a literal space.
        self.index_map[self.char_map["<SPACE>"]] = " "

    def text_to_int(self, text: str) -> list[int]:
        """Convert text to label indices.

        The text is lower-cased first; spaces and any character outside the
        vocabulary are mapped to the ``<SPACE>`` index.
        """
        space = self.char_map["<SPACE>"]
        return [self.char_map.get(ch, space) for ch in text.lower()]

    def int_to_text(self, labels: list[int]) -> str:
        """Convert label indices to text; indices not in the vocabulary are skipped."""
        return "".join(self.index_map[i] for i in labels if i in self.index_map)

    def int_to_text_remove_pad(self, labels: list[int]) -> str:
        """Drop trailing pad zeros, then convert. The caller's list is not modified.

        Note: the pad value 0 is also the apostrophe index, so a genuine trailing
        apostrophe is removed as well. When the true length is known, slicing by it
        is more precise.
        """
        end = len(labels)
        while end and labels[end - 1] == 0:
            end -= 1
        return self.int_to_text(labels[:end])

text_transform = TextPreProcess()


In [None]:

def levenshtein_distance(ref, hyp):
    """Return the edit distance (insertions + deletions + substitutions) of two sequences.

    Thin wrapper around ``Levenshtein.distance``. ``ref``/``hyp`` are strings for
    character-level comparison or token lists (as ``calculate_errors`` passes when a
    delimiter is given).
    """
    return Levenshtein.distance(ref, hyp)

def normalize_text(text, ignore_case = False, remove_space = False):
    """Optionally lower-case ``text`` and/or strip every whitespace character from it."""
    normalized = text.lower() if ignore_case else text
    return "".join(normalized.split()) if remove_space else normalized

def calculate_errors(reference, hypothesis, ignore_case=False, remove_space=False, delimiter=None):
    """Return the edit distance between two normalized texts and the reference length.

    Args:
        reference: ground-truth text.
        hypothesis: predicted text.
        ignore_case: lower-case both texts before comparing.
        remove_space: remove all whitespace before comparing.
        delimiter: if truthy, compare token sequences obtained by splitting on it
            (e.g. ``' '`` for words); otherwise compare character by character.

    Returns:
        ``(edit_distance, ref_len)``; ``ref_len`` counts reference tokens, or
        characters when no delimiter is given.
    """
    reference = normalize_text(reference, ignore_case, remove_space)
    hypothesis = normalize_text(hypothesis, ignore_case, remove_space)

    if delimiter:
        # Drop empty tokens: leading/trailing/repeated delimiters (common in greedy
        # CTC output) would otherwise be counted as extra words.
        reference = [token for token in reference.split(delimiter) if token]
        hypothesis = [token for token in hypothesis.split(delimiter) if token]

    edit_distance = levenshtein_distance(reference, hypothesis)
    return float(edit_distance), len(reference)

def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Word error rate: word-level edit distance divided by the reference word count.

    Raises:
        ValueError: if the reference contains no words.
    """
    edit_distance, ref_len = calculate_errors(reference, hypothesis, ignore_case, False, delimiter)
    if ref_len == 0:
        raise ValueError("WER is undefined for a reference with no words.")
    return edit_distance / ref_len

def cer(reference, hypothesis, ignore_case=False, remove_space=False):
    """Character error rate: character edit distance divided by the reference length.

    Raises:
        ValueError: if the (normalized) reference is empty.
    """
    edit_distance, ref_len = calculate_errors(reference, hypothesis, ignore_case, remove_space)
    if ref_len == 0:
        raise ValueError("CER is undefined for an empty reference.")
    return edit_distance / ref_len


In [None]:

# Mel front-end shared by training and evaluation.
SAMPLE_RATE = 16000
N_MELS = 128

# Training features additionally get frequency- and time-masking augmentation.
train_audio_transforms = nn.Sequential(
    torchaudio.transforms.MelSpectrogram(sample_rate=SAMPLE_RATE, n_mels=N_MELS),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
    torchaudio.transforms.TimeMasking(time_mask_param=100),
)

# Evaluation features: the same mel front-end without masking
# (these values are also torchaudio's MelSpectrogram defaults).
valid_audio_transforms = torchaudio.transforms.MelSpectrogram(sample_rate=SAMPLE_RATE, n_mels=N_MELS)


pipeline_params = dict(
    batch_size=10,
    epochs=1,
    learning_rate=5e-4,
    n_cnn_layers=3,
    n_rnn_layers=5,
    rnn_dim=512,
    n_class=29,        # 28 characters of TextPreProcess + CTC blank (index 28)
    n_feats=N_MELS,    # must match the mel bin count above
    stride=2,
    dropout=0.1,
    n_heads=8,
    n_transformer_layers=2,
    transformer_dim=512,
)

kwargs = dict(num_workers=1, pin_memory=True)


In [None]:

def data_processing(data, data_type = 'train', time_stride = 2):
    """Collate LIBRISPEECH items into a padded batch for CTC training/evaluation.

    Args:
        data: iterable of dataset items unpacked as
            ``(waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id)``;
            only ``waveform`` and ``utterance`` are used.
        data_type: ``'train'`` applies ``train_audio_transforms`` (with masking);
            any other value applies ``valid_audio_transforms``.
        time_stride: time downsampling factor of the model's first convolution
            (``pipeline_params['stride']``); CTC input lengths are frame counts
            floor-divided by it.

    Returns:
        ``(spectrograms, labels, input_lengths, label_lengths)``: spectrograms of shape
        ``(B, 1, n_mels, T_max)``, a ``(B, L_max)`` long tensor of labels right-padded
        with 0, and two lists of Python ints.
    """
    spectrograms = []
    labels = []
    input_lengths = []
    label_lengths = []

    audio_transforms = train_audio_transforms if data_type == 'train' else valid_audio_transforms

    for (waveform, _, utterance, _, _, _) in data:
        # assumes mono waveforms of shape (1, num_samples); squeeze(0) drops the channel
        spec = audio_transforms(waveform).squeeze(0).transpose(0, 1)  # (T, n_mels)
        spectrograms.append(spec)
        # CTC targets are class indices, so store them as integers rather than floats
        label = torch.tensor(text_transform.text_to_int(utterance), dtype=torch.long)
        labels.append(label)
        # floor never exceeds the conv output length ceil(T / stride) (kernel 3, padding 1)
        input_lengths.append(spec.shape[0] // time_stride)
        label_lengths.append(len(label))

    # (B, T_max, n_mels) -> (B, 1, T_max, n_mels) -> (B, 1, n_mels, T_max)
    spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
    # pad_sequence pads with 0 by default
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
    return spectrograms, labels, input_lengths, label_lengths


In [None]:

class CNNLayerNorm(nn.Module):
    """LayerNorm over axis 2 of a 4-D activation (batch, channel, feature, time).

    nn.LayerNorm normalises the last axis, so the feature and time axes are swapped
    before normalising and swapped back afterwards.
    """

    def __init__(self, n_feats):
        super().__init__()
        self.layernorm = nn.LayerNorm(n_feats)

    def forward(self, x):
        feature_last = x.permute(0, 1, 3, 2).contiguous()
        normed = self.layernorm(feature_last)
        return normed.permute(0, 1, 3, 2).contiguous()

class ResidualCNN(nn.Module):
    """Residual block: (CNNLayerNorm -> GELU -> Dropout -> Conv2d) twice, plus a skip.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        kernel: square kernel size; padding is ``kernel // 2``.
        stride: stride of both convolutions.
        dropout: dropout probability before each convolution.
        n_feats: size of the feature axis (dim 2) normalised by each CNNLayerNorm.

    The input is added back unchanged, so the block requires
    ``in_channels == out_channels`` and a size-preserving convolution (odd kernel,
    ``stride == 1``), which is how SpeechRecognitionModel uses it.
    """

    def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
        super(ResidualCNN, self).__init__()
        self.cnn1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel,
            stride=stride,
            padding=kernel//2,
        )
        # cnn2 consumes cnn1's output, so its input width is out_channels.
        self.cnn2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel,
            stride=stride,
            padding=kernel//2,
        )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)

    def forward(self, x):
        residual = x
        x = self.layer_norm1(x)
        x = F.gelu(x)
        x = self.dropout1(x)
        x = self.cnn1(x)
        x = self.layer_norm2(x)
        x = F.gelu(x)
        x = self.dropout2(x)
        x = self.cnn2(x)
        return x + residual

class BidirectionalGRU(nn.Module):
    """LayerNorm -> GELU -> bidirectional GRU -> Dropout.

    Args:
        rnn_dim: feature size of the input (normalised, then fed to the GRU).
        hidden_size: GRU hidden size per direction; the output width is 2 * hidden_size.
        dropout: dropout probability applied to the GRU output.
        batch_first: passed to nn.GRU (True means input is (batch, seq, feature)).
    """

    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super().__init__()
        self.BiGRU = nn.GRU(input_size=rnn_dim, hidden_size=hidden_size,
                            batch_first=batch_first, bidirectional=True)
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = F.gelu(self.layer_norm(x))
        recurrent_out = self.BiGRU(activated)[0]  # discard the final hidden state
        return self.dropout(recurrent_out)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttentionHead(nn.Module):
  """
  A single scaled dot-product self-attention head.

  Several heads are concatenated along the last dimension (not stacked) to build
  MultiHeadAttention.

  Args:
    head_size: output width of this head (typically n_embd // n_heads).
    n_embd: width of the input embeddings.
    dropout: dropout probability applied to the attention weights.
    causal: if True (default) each position attends only to itself and earlier
      positions; set False for full bidirectional attention over the sequence.
  """
  def __init__(self, head_size, n_embd, dropout, causal=True):
    super().__init__()
    self.head_size = head_size
    self.n_embd = n_embd
    self.causal = causal

    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)

    self.dropout = nn.Dropout(dropout)

  def forward(self, param):
    """Attend over ``param`` of shape (B, T, n_embd); returns (B, T, head_size)."""
    _, T, _ = param.shape
    key = self.key(param)
    query = self.query(param)
    value = self.value(param)

    # scaled dot-product scores, shape (B, T, T)
    scores = query @ key.transpose(-2, -1) * key.shape[-1] ** -0.5

    if self.causal:
      # lower-triangular mask sized to the actual sequence length
      allowed = torch.ones(T, T, dtype=torch.bool, device=param.device).tril()
      scores = scores.masked_fill(~allowed, float('-inf'))

    weights = self.dropout(F.softmax(scores, dim=-1))
    return weights @ value

In [None]:
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
  """
  Multi-head self-attention built from ``num_heads`` SelfAttentionHead modules.

  Each head maps its input to ``head_size`` features; the head outputs are
  concatenated along the last axis, linearly projected back to ``n_embd`` so they
  can be passed on to further layers, and dropout is applied to the projection.
  """
  def __init__(self, num_heads, head_size, n_embd, dropout):
    super().__init__()
    self.num_heads, self.head_size, self.n_embd = num_heads, head_size, n_embd

    heads = [SelfAttentionHead(head_size, n_embd, dropout) for _ in range(num_heads)]
    self.heads = nn.ModuleList(heads)
    self.proj = nn.Linear(num_heads * head_size, n_embd)
    self.dropout = nn.Dropout(dropout)

  def forward(self, param):
    per_head = [head(param) for head in self.heads]
    merged = torch.cat(per_head, dim=-1)
    return self.dropout(self.proj(merged))

In [None]:
# Transformer-Enhanced Speech Recognition Model
class SpeechRecognitionModel(nn.Module):
    """CNN front-end -> BiGRU stack -> attention/Transformer -> per-frame classifier.

    Pipeline: Conv2d(1 -> 32, ``stride``) -> ``n_cnn_layers`` ResidualCNN blocks ->
    linear to ``rnn_dim`` -> ``n_rnn_layers`` BidirectionalGRU -> projection to
    ``transformer_dim`` -> MultiHeadAttention -> ``n_transformer_layers``
    nn.TransformerEncoderLayer -> MLP classifier.

    Input: spectrogram batch of shape ``(B, 1, n_feats, T)``.
    Output: unnormalised per-frame scores ``(B, T', n_class)``, where ``T'`` is the
    time length after the first conv's stride; apply ``log_softmax`` before CTC.

    NOTE(review): no padding masks are used, so the GRUs and attention layers also
    see the padded frames of shorter utterances; the custom attention block has no
    residual connection around it.
    """

    def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats,
                 stride=2, dropout=0.1, n_heads=8, n_transformer_layers=2, transformer_dim=512):
        super(SpeechRecognitionModel, self).__init__()
        # Feature-axis length after the first conv (kernel 3, padding 1):
        # floor((n_feats - 1) / stride) + 1 -- equals n_feats // 2 for stride 2 and even n_feats.
        n_feats = (n_feats - 1) // stride + 1

        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3//2)
        self.rescnn_layers = nn.Sequential(*[
            ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats)
            for _ in range(n_cnn_layers)
        ])

        self.fully_connected = nn.Linear(n_feats*32, rnn_dim)

        # Every GRU layer receives (B, T, features), so all must be batch_first.
        # With batch_first=False the later layers would recur over the batch axis,
        # mixing different utterances instead of modelling time.
        self.birnn_layers = nn.Sequential(*[
            BidirectionalGRU(rnn_dim=rnn_dim if i == 0 else rnn_dim*2,
                             hidden_size=rnn_dim, dropout=dropout, batch_first=True)
            for i in range(n_rnn_layers)
        ])

        self.transformer_projection = nn.Linear(rnn_dim*2, transformer_dim)

        head_size = transformer_dim // n_heads
        self.multi_head_attention_block = MultiHeadAttention(
            num_heads=n_heads,
            head_size=head_size,
            n_embd=transformer_dim,
            dropout=dropout
        )

        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=transformer_dim,
                nhead=n_heads,
                dim_feedforward=transformer_dim*4,
                dropout=dropout,
                batch_first=True
            )
            for _ in range(n_transformer_layers)
        ])

        self.classifier = nn.Sequential(
            nn.Linear(transformer_dim, rnn_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class)
        )

    def forward(self, x):
        x = self.cnn(x)
        x = self.rescnn_layers(x)

        # local names avoid shadowing torch.nn.functional (imported as F)
        batch, channels, feats, time = x.size()
        x = x.reshape(batch, channels * feats, time)  # (B, C*F, T)
        x = x.transpose(1, 2)                         # (B, T, C*F)
        x = self.fully_connected(x)                   # (B, T, rnn_dim)

        x = self.birnn_layers(x)                      # (B, T, rnn_dim*2)
        x = self.transformer_projection(x)            # (B, T, transformer_dim)
        x = self.multi_head_attention_block(x)        # (B, T, transformer_dim)

        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x)                  # (B, T, transformer_dim)

        return self.classifier(x)                     # (B, T, n_class)


In [None]:

def decode(outputs, blank=28):
    """Greedy (best-path) CTC decoding.

    ``outputs`` is indexed (time, batch, classes), the layout produced by the
    ``transpose(0, 1)`` applied before calling this. For every batch item, take the
    highest-scoring class per frame, collapse consecutive repeats and drop ``blank``.

    Returns:
        One list of int class indices per batch item.
    """
    best_paths = torch.max(outputs, dim=2).indices.transpose(0, 1).tolist()  # (batch, time)
    decoded = []
    for path in best_paths:
        decoded.append([
            token
            for step, token in enumerate(path)
            if token != blank and (step == 0 or token != path[step - 1])
        ])
    return decoded


def validate(model, validation_loader, criterion, device):
    """Evaluate ``model`` on ``validation_loader``; print and return loss, WER and CER.

    The model is switched to eval mode (and left in it). Predictions are decoded
    greedily with ``decode``; WER/CER are means of the per-utterance rates.

    Args:
        model: network mapping spectrograms to per-frame scores (B, T, n_class).
        validation_loader: DataLoader yielding
            ``(spectrograms, labels, input_lengths, label_lengths)`` batches.
        criterion: loss called as ``criterion(log_probs, labels, input_lengths,
            label_lengths)`` with log_probs in (T, B, n_class) layout (nn.CTCLoss here).
        device: device the spectrograms and labels are moved to.

    Returns:
        ``(avg_loss, avg_wer, avg_cer)``.
    """
    model.eval()
    total_loss = 0
    all_predicted_texts = []
    all_true_texts = []
    with torch.no_grad():
        for spectrograms, labels, input_lengths, label_lengths in tqdm(validation_loader):
            spectrograms, labels = spectrograms.to(device), labels.to(device)

            output = model(spectrograms)
            output = F.log_softmax(output, dim=2)
            output = output.transpose(0, 1)  # (T, B, n_class), as CTCLoss expects
            loss = criterion(output, labels, input_lengths, label_lengths)
            total_loss += loss.item()

            decoded_outputs = decode(output)
            all_predicted_texts.extend(text_transform.int_to_text(seq) for seq in decoded_outputs)
            # Trim each target to its real length: the batch's right-padding (label 0)
            # would otherwise be decoded into the reference and inflate WER/CER.
            all_true_texts.extend(
                text_transform.int_to_text(label[:length].tolist())
                for label, length in zip(labels, label_lengths)
            )

    avg_loss = total_loss / len(validation_loader)
    avg_wer = np.mean([wer(ref, hyp) for ref, hyp in zip(all_true_texts, all_predicted_texts)])
    avg_cer = np.mean([cer(ref, hyp) for ref, hyp in zip(all_true_texts, all_predicted_texts)])

    print(f"Validation Loss: {avg_loss}")
    print(f"Average WER: {avg_wer:.4f}")
    print(f"Average CER: {avg_cer:.4f}")

    return avg_loss, avg_wer, avg_cer


In [None]:

root_path = '/kaggle/input/librispeech-clean'
print("Note: Make sure to have the dataset in the input section of Kaggle")

# download=False: the splits must already be attached as Kaggle input data.
train_dataset, test_dataset = (
    torchaudio.datasets.LIBRISPEECH(root_path, url=split, download=False)
    for split in ("train-clean-100", "test-clean")
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

# Plot the first training waveform with its transcript split over two title lines.
waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id = train_dataset[0]
ln = len(utterance)
split_at = ln // 2 - 1
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(waveform.t().numpy())
ax.set_title(f"This is the waveform graph for the text:\n {utterance[:split_at]}\n{utterance[split_at:]}")
ax.set_xlabel('Time')
ax.set_ylabel('Amplitude')
plt.show()


In [None]:
# Create data loaders.
# Only the training set is shuffled. The test set keeps a fixed order so batch
# composition (and hence the padding the GRU/attention layers see) is the same on
# every evaluation, which keeps validation loss/WER/CER reproducible.
train_loader = data.DataLoader(dataset=train_dataset,
                               batch_size=pipeline_params['batch_size'],
                               shuffle=True,
                               collate_fn=lambda x: data_processing(x, 'train'),
                               **kwargs)

test_loader = data.DataLoader(dataset=test_dataset,
                              batch_size=pipeline_params['batch_size'],
                              shuffle=False,
                              collate_fn=lambda x: data_processing(x, 'valid'),
                              **kwargs)

print(f"Train loader batches: {len(train_loader)}")
print(f"Test loader batches: {len(test_loader)}")


In [None]:

# No earlier cell creates `model` or `device`; build them here so the notebook
# runs top-to-bottom on a fresh kernel.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SpeechRecognitionModel(
    pipeline_params['n_cnn_layers'],
    pipeline_params['n_rnn_layers'],
    pipeline_params['rnn_dim'],
    pipeline_params['n_class'],
    pipeline_params['n_feats'],
    pipeline_params['stride'],
    pipeline_params['dropout'],
    pipeline_params['n_heads'],
    pipeline_params['n_transformer_layers'],
    pipeline_params['transformer_dim'],
).to(device)

optimizer = optim.AdamW(model.parameters(), pipeline_params["learning_rate"])
# blank = 28, the last class index (n_class - 1); decode() uses the same default.
criterion = nn.CTCLoss(blank=28).to(device)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=pipeline_params["learning_rate"],
                                          steps_per_epoch=int(len(train_loader)),
                                          epochs=pipeline_params["epochs"],
                                          anneal_strategy="linear")

print("Training setup complete!")
print(f"Device: {device}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Optimizer: {type(optimizer).__name__}")
print(f"Criterion: {type(criterion).__name__}")
print(f"Scheduler: {type(scheduler).__name__}")


In [None]:

data_len = len(train_loader.dataset)
logging_idx = 0
logging_freq = 100

print("Starting training...")
for epoch in range(pipeline_params["epochs"]):
    # validate() leaves the model in eval mode; make sure dropout is active again.
    model.train()
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader),
                        desc=f"Epoch {epoch+1}/{pipeline_params['epochs']}", unit="batches")

    for batch_idx, (spectrograms, labels, input_lengths, label_lengths) in progress_bar:
        spectrograms, labels = spectrograms.to(device), labels.to(device)

        optimizer.zero_grad()

        output = model(spectrograms)
        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1)  # (T, B, n_class) as CTCLoss expects

        loss = criterion(output, labels, input_lengths, label_lengths)
        loss.backward()

        optimizer.step()
        scheduler.step()  # OneCycleLR was configured with one step per batch

        loss_value = loss.item()
        if logging_idx % logging_freq == 0:
            # tqdm.write prints above the bar instead of garbling it like print()
            tqdm.write("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch,
                batch_idx * len(spectrograms),
                data_len, 100. * batch_idx / len(train_loader),
                loss_value
            ))
        logging_idx += 1
        progress_bar.set_postfix({'loss': loss_value})

print("Training completed!")


In [None]:

# Persist the trained weights (state_dict only) under both the .pth and .pt names.
model_path = '/kaggle/working/transformer_speech_recognition_model.pth'
model_path_pt = '/kaggle/working/transformer_speech_recognition_model.pt'

for save_path, message in ((model_path, "Model saved to: "),
                           (model_path_pt, "Model also saved to: ")):
    torch.save(model.state_dict(), save_path)
    print(f"{message}{save_path}")


In [None]:

print("Running validation on test set...")
# (avg_loss, avg_wer, avg_cer); shown as the cell's output
validation_results = validate(model, test_loader, criterion, device)
validation_results


In [None]:
# Model loading and inference functions
def load_model(model_path, device=None):
    """Rebuild SpeechRecognitionModel from ``pipeline_params`` and load saved weights.

    Args:
        model_path: path to a state_dict written with ``torch.save(model.state_dict(), ...)``.
        device: device to place the model on; ``None`` selects CUDA when available,
            otherwise CPU.

    Returns:
        The model on ``device``, in eval mode.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    loaded_model = SpeechRecognitionModel(
        pipeline_params['n_cnn_layers'],
        pipeline_params['n_rnn_layers'],
        pipeline_params['rnn_dim'],
        pipeline_params['n_class'],
        pipeline_params['n_feats'],
        pipeline_params['stride'],
        pipeline_params['dropout'],
        pipeline_params['n_heads'],
        pipeline_params['n_transformer_layers'],
        pipeline_params['transformer_dim']
    ).to(device)

    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    loaded_model.load_state_dict(torch.load(model_path, map_location=device))
    loaded_model.eval()

    return loaded_model

def speech_to_text(audio_path, model, device, text_transform, valid_audio_transform):
    """Transcribe a single audio file with greedy CTC decoding.

    Multi-channel audio is averaged down to mono (the model's first conv takes one
    input channel). The audio is resampled to ``valid_audio_transform.sample_rate``
    (16 kHz if the transform has no such attribute) so features match training.
    The model is run in eval mode; its previous train/eval mode is restored afterwards.

    Args:
        audio_path: path of an audio file readable by ``torchaudio.load``.
        model: trained SpeechRecognitionModel.
        device: device to run inference on.
        text_transform: TextPreProcess used to map indices back to characters.
        valid_audio_transform: feature extractor (mel spectrogram) used for evaluation.

    Returns:
        The predicted transcript string.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    target_rate = getattr(valid_audio_transform, "sample_rate", 16000)
    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_rate)

    spectrogram = valid_audio_transform(waveform).unsqueeze(0).to(device)  # (1, 1, n_mels, T)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(spectrogram)
            output = F.log_softmax(output, dim=2)
            output = output.transpose(0, 1)  # (T, 1, n_class) layout expected by decode
            decoded_output = decode(output)
    finally:
        model.train(was_training)

    return text_transform.int_to_text(decoded_output[0])

print("Model loading and inference functions defined!")


In [None]:
# Inference demo: reload the saved weights (eval mode, mapped to the available
# device) and transcribe a single utterance.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model('/kaggle/working/transformer_speech_recognition_model.pt', device)

# NOTE(review): LibriSpeech ships .flac files, and this path points at a different
# Kaggle input ('librispeech') than root_path above -- confirm this file exists.
audio_path = '/kaggle/input/librispeech/LibriSpeech/dev-clean/174/50561/174-50561-0002.wav'
predicted_text = speech_to_text(audio_path, model, device, text_transform, valid_audio_transforms)
print(f"Predicted Text: {predicted_text}")