In [2]:
!pip install jiwer

Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.14.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Downloading jiwer-4.0.0-py3-none-any.whl (23 kB)
Downloading rapidfuzz-3.14.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m51.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, jiwer
Successfully installed jiwer-4.0.0 rapidfuzz-3.14.1


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import os
from jiwer import wer, cer
import kagglehub

def padcollate(batch):
    spectrograms, transcripts, inputlengths, targetlengths = [], [], [], []
    for waveform, transcript, inputlength, targetlength in batch:
        spectrograms.append(waveform.squeeze(0).transpose(0, 1))
        transcripts.append(transcript)
        inputlengths.append(inputlength)
        targetlengths.append(targetlength)
    paddedspecs = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
    paddedtranscripts = nn.utils.rnn.pad_sequence(transcripts, batch_first=True)
    return paddedspecs, paddedtranscripts, torch.tensor(inputlengths), torch.tensor(targetlengths)

class SpeechDataset(Dataset):
    def __init__(self, manifestfile, audiopath, charmap):
        self.audiolist = pd.read_csv(manifestfile, sep='|', header=None, quoting=3)
        self.audiopath = audiopath
        self.charmap = charmap

        melkwargs = {
            'n_fft': 1024,
            'win_length': 1024,
            'hop_length': 256,
            'n_mels': 80
        }
        self.featuremaker = torchaudio.transforms.MFCC(
            sample_rate=22050,
            n_mfcc=80,
            melkwargs=melkwargs
        )

    def __len__(self):
        return len(self.audiolist)
    def __getitem__(self, index):
        audioname = self.audiolist.iloc[index, 0]
        transcripttext = self.audiolist.iloc[index, 2]
        audiopath = os.path.join(self.audiopath, f"{audioname}.wav")
        waveform, samplerate = torchaudio.load(audiopath)
        features = self.featuremaker(waveform)
        encodedtext = [self.charmap[char] for char in transcripttext.lower() if char in self.charmap]
        inputlength = features.shape[1]
        targetlength = len(encodedtext)
        return features, torch.tensor(encodedtext), inputlength, targetlength

class SpeechModel(nn.Module):
    def __init__(self, numfeatures, numclasses):
        super(SpeechModel, self).__init__()
        self.lstm = nn.LSTM(input_size=numfeatures, hidden_size=512, num_layers=3, bidirectional=True, batch_first=True)
        self.classifier = nn.Linear(1024, numclasses)
    def forward(self, x):
        x, _ = self.lstm(x)
        x = self.classifier(x)
        return x

def greedydecoder(output, indexmap, blanklabel='-'):
    arg_maxes = torch.argmax(output, dim=2)
    decodedtexts = []
    for i, args in enumerate(arg_maxes):
        decoded = []
        for j, index in enumerate(args):
            if index != charmap[blanklabel]:
                if j == 0 or index != args[j-1]:
                    decoded.append(index.item())
        decodedtexts.append("".join([indexmap[c] for c in decoded if c in indexmap]))
    return decodedtexts

if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"--- Using device: {device} ---")

    print("Downloading the LJ Speech dataset...")
    path = kagglehub.dataset_download("mathurinache/the-lj-speech-dataset")
    print("Dataset downloaded to:", path)

    basepath = os.path.join(path, "LJSpeech-1.1")
    metadatapath = os.path.join(basepath, "metadata.csv")
    audiopath = os.path.join(basepath, "wavs")

    characters = "'-abcdefghijklmnopqrstuvwxyz "
    charmap = {char: i for i, char in enumerate(characters)}
    indexmap = {i: char for i, char in enumerate(characters)}

    dataset = SpeechDataset(metadatapath, audiopath, charmap)
    traindata = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=padcollate, pin_memory=True, num_workers=2)

    model = SpeechModel(numfeatures=80, numclasses=len(characters)).to(device)
    lossfunction = nn.CTCLoss(blank=charmap['-'], zero_infinity=True).to(device) # Added zero_infinity
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    epochs = 1
    print("\n--- Starting model training (QUICK TEST)... ---")
    for epoch in range(epochs):
        model.train()
        for i, data in enumerate(traindata):
            if i > 200:
                print("...Stopping training early for this quick test...")
                break

            spectrograms, transcripts, inputlengths, targetlengths = data
            spectrograms, transcripts = spectrograms.to(device), transcripts.to(device)
            optimizer.zero_grad()
            output = model(spectrograms)
            output = nn.functional.log_softmax(output, dim=2).transpose(0, 1)
            loss = lossfunction(output, transcripts, inputlengths, targetlengths)
            loss.backward()


            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)

            optimizer.step()
            if i % 50 == 0:
                print(f"  Epoch {epoch+1}, Batch {i}, Current Loss: {loss.item():.4f}")

    print("\n--- Starting evaluation... ---")
    model.eval()
    allpredictions, alltargets = [], []
    with torch.no_grad():
        for i, data in enumerate(traindata):
            if i > 20:
                break
            spectrograms, transcripts, _, _ = data
            spectrograms = spectrograms.to(device)
            output = model(spectrograms)
            predictions = greedydecoder(output.cpu(), indexmap)
            allpredictions.extend(predictions)
            targets = ["".join([indexmap[c.item()] for c in text]) for text in transcripts]
            alltargets.extend(targets)

    worderror = wer(alltargets, allpredictions)
    charerror = cer(alltargets, allpredictions)
    print("\n--- Quick Test Evaluation Complete ---")
    print(f"Sample Prediction: '{allpredictions[0]}'")
    print(f"Sample Target:     '{alltargets[0]}'")
    print(f"\nWord Error Rate (WER): {worderror:.4f}")
    print(f"Character Error Rate (CER): {charerror:.4f}")

--- Using device: cuda ---
Downloading the LJ Speech dataset...
Using Colab cache for faster access to the 'the-lj-speech-dataset' dataset.
Dataset downloaded to: /kaggle/input/the-lj-speech-dataset

--- Starting model training (QUICK TEST)... ---


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


  Epoch 1, Batch 0, Current Loss: 0.4750
  Epoch 1, Batch 50, Current Loss: 1.1134
  Epoch 1, Batch 100, Current Loss: 1.1374
  Epoch 1, Batch 150, Current Loss: 1.3784
  Epoch 1, Batch 200, Current Loss: 0.5944
...Stopping training early for this quick test...

--- Starting evaluation... ---


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)
  s = torchaudio.io.StreamReader(src, format, None, buffer_size)



--- Quick Test Evaluation Complete ---
Sample Prediction: ''
Sample Target:     'it was particularly recommended by the committee on jails in eighteen fourteen''''''''''''''''''''''''''''''''''''''''''''''''''''

Word Error Rate (WER): 1.0000
Character Error Rate (CER): 1.0000
