In [2]:
pip install Levenshtein

Collecting Levenshtein
  Downloading levenshtein-0.26.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.2 kB)
Collecting rapidfuzz<4.0.0,>=3.9.0 (from Levenshtein)
  Downloading rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Downloading levenshtein-0.26.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (162 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.6/162.6 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m45.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, Levenshtein
Successfully installed Levenshtein-0.26.1 rapidfuzz-3.10.1


In [3]:

import Levenshtein as Lev
import torch


class Decoder(object):
    """
    Basic decoder class from which all other decoders inherit. Implements several
    helper functions. Subclasses should implement the decode() method.

    Arguments:
        labels (list): mapping from integers to characters
        blank_index (int, optional): index for the blank '_' character. Defaults to 0.
        spece_index (int, optional): index for the space ' ' character. Defaults to 28.
    """

    def __init__(self, labels, blank_index=0):
        self.labels = labels
        self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)])
        self.blank_index = blank_index
        space_index = len(labels) # To prevent errors in decode, we add an out of bounds index for the space
        if ' ' in labels:
            space_index = labels.index(' ')
        self.space_index = space_index

    def wer(self, s1, s2):


    # Tokenize the sentences into words
         words1 = s1.split()
         words2 = s2.split()

    # Calculate the Levenshtein distance directly on word sequences
         return Lev.distance(' '.join(words1), ' '.join(words2)) / float(len(words1) or 1)


    def cer(self, s1, s2):
        """
        Computes the character Error Rate, defined as the edit distace.

        Arguments:
            s1 (string): space-separated sentence
            s2 (string): space-separated sentence
        """
        s1, s2 = s1.replace(' ', ''), s2.replace(' ', '')
        if len(s1) == 0 and len(s2) == 0:
            return 0.0  # No error if both strings are empty
        return Lev.distance(s1, s2)

    def decode(self, probs, sizes=None):
        """
        Given a matrix of character probabilities, returns the decoder's
        best guess of the transcription
        Arguments:
            probs: Tensor of character probabilities, where probs[c,t]
                            is the probability of character c at time t
            sizes(optional): Size of each sequence in the mini-batch
        Returns:
            string: sequence of the model's best guess for the transcription
        """
        raise NotImplementedError

class BeamCTCDecoder(Decoder):
    def __init__(self, labels, lm_path=None, alpha=0, beta=0, cutoff_top_n=40, cutoff_prob=1.0, beam_width=100, num_processes=4, blank_index=0):
        super(BeamCTCDecoder, self).__init__(labels)
        try:
            from ctcdecode import CTCBeamDecoder
        except ImportError:
            raise ImportError("BeamCTCDecoder requires paddledecoder package.")
        labels = list(labels)  # Ensure labels are a list before passing to decoder
        self._decoder = CTCBeamDecoder(labels, lm_path, alpha, beta, cutoff_top_n,  cutoff_prob, beam_width,
                                       num_processes, blank_index)

    def convert_to_strings(self, out, seq_len):
        results = []
        for b, batch in enumerate(out):
            utterances = []
            for p, utt in enumerate(batch):
                size = seq_len[b][p]
                if size > 0:
                    transcript = ''.join(map(lambda x: self.int_to_char[x.item()], utt[0:size]))
                else:
                    transcript = ''
                utterances.append(transcript)
            results.append(utterances)
        return results

    def convert_tensor(self, offsets, sizes):
        results = []
        for b, batch in enumerate(offsets):
            utterances = []
            for p, utt in enumerate(offsets):
                size = sizes[b][p]
                if sizes[b][p] > 0:
                    utterances.append(utt[0:size])
                else:
                    utterances.append(torch.tensor([], dtype=torch.int))
            results.append(utterances)
        return results

    def decode(self, probs, sizes=None):
        """
        Decodes probability output using ctcdecode package.
        Arguments:
            probs: Tensor of character probabilities, where probs[c,t]
                            is the probability of character c at time t
            sizes: Size of each sequence in the mini-batch
        Returns:
            string: sequences of the model's best guess for the transcription
        """
        probs = probs.cpu()
        out, scores, offsets, seq_lens = self._decoder.decode(probs, sizes)

        strings = self.convert_to_strings(out, seq_lens)
        offsets = self.convert_tensor(offsets, seq_lens)
        return strings, offsets


class GreedyDecoder(Decoder):
    def __init__(self, labels, blank_index=0):
        super(GreedyDecoder, self).__init__(labels, blank_index)

    def convert_to_strings(self, sequences, sizes=None, remove_repetitions=False, return_offsets=False):
        """Given a list of numeric sequences, returns the corresponding strings"""
        strings = []
        offsets = [] if return_offsets else None
        for x in range(len(sequences)):
            seq_len = sizes[x] if sizes is not None else len(sequences[x])
            string, string_offsets = self.process_string(sequences[x], seq_len, remove_repetitions)
            strings.append([string])  # We only return one path
            if return_offsets:
                offsets.append([string_offsets])
        if return_offsets:
            return strings, offsets
        else:
            return strings

    def process_string(self, sequence, size, remove_repetitions=False):
        string = ''
        offsets = []
        for i in range(size):
            char = self.int_to_char[sequence[i].item()]
            if char != self.int_to_char[self.blank_index]:
                # if this char is a repetition and remove_repetitions=true, then skip
                if remove_repetitions and i != 0 and char == self.int_to_char[sequence[i - 1].item()]:
                    pass
                elif char == self.labels[self.space_index]:
                    string += ' '
                    offsets.append(i)
                else:
                    string = string + char
                    offsets.append(i)
        return string, torch.tensor(offsets, dtype=torch.int)

    def decode(self, probs, sizes=None):
        """
        Returns the argmax decoding given the probability matrix. Removes
        repeated elements in the sequence, as well as blanks.
        Arguments:
            probs: Tensor of character probabilities from the network. Expected shape of batch x seq_length x output_dim
            sizes(optional): Size of each sequence in the mini-batch
        Returns:
            strings: sequences of the model's best guess for the transcription on inputs
            offsets: time step per character predicted
        """
        _, max_probs = torch.max(probs, 2)
        strings, offsets = self.convert_to_strings(max_probs.view(max_probs.size(0), max_probs.size(1)), sizes,
                                                   remove_repetitions=True, return_offsets=True)
        return strings, offsets


In [4]:
!apt-get update
!apt-get install -y swig libboost-all-dev


0% [Working]            Get:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,626 B]
0% [Connecting to archive.ubuntu.com (185.125.190.81)] [Connecting to security.ubuntu.com (91.189.910% [Connecting to archive.ubuntu.com (185.125.190.81)] [Connecting to security.ubuntu.com (91.189.91                                                                                                    Get:2 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Get:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Get:4 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Hit:5 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:6 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [1,194 kB]
Get:7 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:8 http://security.ubuntu.com/ubuntu jammy-security/main amd64 Packages [2,458 kB]
Hit:9 https://ppa.launch

In [5]:
!rm -rf ctcdecode
!git clone --recursive https://github.com/parlance/ctcdecode.git
!cd ctcdecode && pip install .


Cloning into 'ctcdecode'...
remote: Enumerating objects: 1102, done.[K
remote: Counting objects: 100% (39/39), done.[K
remote: Compressing objects: 100% (25/25), done.[K
remote: Total 1102 (delta 16), reused 32 (delta 14), pack-reused 1063 (from 1)[K
Receiving objects: 100% (1102/1102), 782.27 KiB | 10.57 MiB/s, done.
Resolving deltas: 100% (529/529), done.
Submodule 'third_party/ThreadPool' (https://github.com/progschj/ThreadPool.git) registered for path 'third_party/ThreadPool'
Submodule 'third_party/kenlm' (https://github.com/kpu/kenlm.git) registered for path 'third_party/kenlm'
Cloning into '/content/ctcdecode/third_party/ThreadPool'...
remote: Enumerating objects: 82, done.        
remote: Counting objects: 100% (26/26), done.        
remote: Compressing objects: 100% (9/9), done.        
remote: Total 82 (delta 19), reused 17 (delta 17), pack-reused 56 (from 1)        
Receiving objects: 100% (82/82), 13.34 KiB | 2.67 MiB/s, done.
Resolving deltas: 100% (36/36), done.
Clonin

In [6]:
!pip install ctcdecode-tensorflow

[31mERROR: Could not find a version that satisfies the requirement ctcdecode-tensorflow (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for ctcdecode-tensorflow[0m[31m
[0m

In [7]:
pip install git+https://github.com/parlance/ctcdecode.git@master


Collecting git+https://github.com/parlance/ctcdecode.git@master
  Cloning https://github.com/parlance/ctcdecode.git (to revision master) to /tmp/pip-req-build-0ccrsi0_
  Running command git clone --filter=blob:none --quiet https://github.com/parlance/ctcdecode.git /tmp/pip-req-build-0ccrsi0_
  Resolved https://github.com/parlance/ctcdecode.git to commit c90ad94a0b19554f80804fb7812f2a1447a34a70
  Running command git submodule update --init --recursive -q
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: ctcdecode
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py bdist_wheel[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Building wheel for ctcdecode (setup.py) ... [?25lerror
[31m  ERROR: Failed building wheel for ctcdecode[0m[31m
[0m[?25h 

In [8]:
!apt-get install -y sox

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  libopencore-amrnb0 libopencore-amrwb0 libsox-fmt-alsa libsox-fmt-base libsox3 libwavpack1
Suggested packages:
  libsox-fmt-all
The following NEW packages will be installed:
  libopencore-amrnb0 libopencore-amrwb0 libsox-fmt-alsa libsox-fmt-base libsox3 libwavpack1 sox
0 upgraded, 7 newly installed, 0 to remove and 58 not upgraded.
Need to get 617 kB of archives.
After this operation, 1,764 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 libopencore-amrnb0 amd64 0.1.5-1 [94.8 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 libopencore-amrwb0 amd64 0.1.5-1 [49.1 kB]
Get:3 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 libsox3 amd64 14.4.2+git20190427-2+deb11u2ubuntu0.22.04.1 [240 kB]
Get:4 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 libs

In [9]:
import math
import os
import librosa
import numpy as np
import scipy.signal
import soundfile as sf
import torch

from torch.utils.data import Dataset, Sampler, DistributedSampler, DataLoader


LABELS = [
  "_",
  "'",
  "A",
  "B",
  "C",
  "D",
  "E",
  "F",
  "G",
  "H",
  "I",
  "J",
  "K",
  "L",
  "M",
  "N",
  "O",
  "P",
  "Q",
  "R",
  "S",
  "T",
  "U",
  "V",
  "W",
  "X",
  "Y",
  "Z",
  " "
]


windows = {
    'hamming': lambda N: scipy.signal.get_window('hamming', N),  # Use get_window function
    'hann': lambda N: scipy.signal.get_window('hann', N),
    'blackman': lambda N: scipy.signal.get_window('blackman', N),
    'bartlett': lambda N: scipy.signal.get_window('bartlett', N)
}


def load_audio(path):
    sound, sample_rate = sf.read(path, dtype='int16')
    # TODO this should be 32768.0 to get twos-complement range.
    # TODO the difference is negligible but should be fixed for new models.
    sound = sound.astype('float32') / 32767  # normalize audio
    if len(sound.shape) > 1:
        if sound.shape[1] == 1:
            sound = sound.squeeze()
        else:
            sound = sound.mean(axis=1)  # multiple channels, average
    return sound

class SpeechDataset:
    def __init__(self, args, df):

        self.args = args
        self.audio_path = df.audio_path.values.tolist()
        self.transcript_path = df.txt_path.values.tolist()
        self.labels_map = dict([(LABELS[i], i) for i in range(len(LABELS))])

    def __getitem__(self, item):
        audio_path = self.audio_path[item]
        transcript_path = self.transcript_path[item]
        try:
            spect = self.parse_audio(audio_path)
            transcript = self.parse_transcript(transcript_path)
        except Exception as e:
            print(f"Error loading item {item}: {audio_path}, {transcript_path}")
            raise e

        return spect, transcript


    def __len__(self):
        return len(self.audio_path)

    def parse_audio(self, audio_path):

        y = load_audio(audio_path)

        n_fft = int(self.args.sample_rate * self.args.window_size)
        win_length = n_fft
        hop_length = int(self.args.sample_rate * self.args.window_stride)
        # STFT
        D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length,
                         win_length=win_length, window=self.args.window)
        spect, phase = librosa.magphase(D)
        # S = log(S+1)
        spect = np.log1p(spect)
        spect = torch.FloatTensor(spect)
        if self.args.normalize:
            mean = spect.mean()
            std = spect.std()
            spect.add_(-mean)
            spect.div_(std)

        return spect

    def parse_transcript(self, transcript_path):
        with open(transcript_path, 'r', encoding='utf8') as transcript_file:
            transcript = transcript_file.read().replace('\n', '')
        transcript = list(filter(None, [self.labels_map.get(x) for x in list(transcript)]))
        return transcript



def _collate_fn(batch):
    def func(p):
        return p[0].size(1)

    batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
    longest_sample = max(batch, key=func)[0]
    freq_size = longest_sample.size(0)
    minibatch_size = len(batch)
    max_seqlength = longest_sample.size(1)
    inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength)
    input_percentages = torch.FloatTensor(minibatch_size)
    target_sizes = torch.IntTensor(minibatch_size)
    targets = []
    for x in range(minibatch_size):
        sample = batch[0]
        tensor = sample[0]
        target = sample[1]
        seq_length = tensor.size(1)
        inputs[x][0].narrow(1, 0, seq_length).copy_(tensor)
        input_percentages[x] = seq_length / float(max_seqlength)
        target_sizes[x] = len(target)
        targets.extend(target)
    targets = torch.IntTensor(targets)
    return inputs, targets, input_percentages, target_sizes

class AudioDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        """
        Creates a data loader for AudioDatasets.
        """
        super(AudioDataLoader, self).__init__(*args, **kwargs)
        self.collate_fn = _collate_fn



In [10]:
class args:

    # training
    seed = 42
    epochs = 100
    early_stopping_patience = 5  # Decreased patience for early stopping

    # data
    train_path = r"/content/drive/MyDrive/datasets/train_manifest.xlsx"
    test_path = r"/content/drive/MyDrive/datasets/test_manifest.xlsx"
    sample_rate = 16000
    batch_size = 32
    num_workers = 12
    window_size = .002
    window_stride = .01
    window = "hamming"
    normalize = True
    pin_memory=True

    # model
    rnn_type = "lstm"
    hidden_size = 1024
    hidden_layers = 5
    bidirectional = True
    dropout_rate = 0.5 # Added dropout to RNN layers

    # optimizer
    learning_rate = 0.001
    weight_decay = 1e-4
    momentum = 0.9
    eps = 1e-8
    betas = (0.9, 0.999)
    max_norm = 100  # Gradient clipping threshold
    learning_anneal = 1.05


    lr_scheduler = "ReduceLROnPlateau"  # Changed to a scheduler that reduces lr based on plateau

    # Early Stopping
    early_stopping_patience = 5  # Decreased patience for early stopping

    # Gradient Accumulation
    accumulation_steps = 4  # Used to simulate a larger batch size

    # Regularization
    max_grad_norm = 1  # Gradient clipping value


In [11]:
import math
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

supported_rnns = {
    'lstm': nn.LSTM,
    'rnn': nn.RNN,
    'gru': nn.GRU
}
supported_rnns_inv = dict((v, k) for k, v in supported_rnns.items())


class SequenceWise(nn.Module):
    def __init__(self, module):
        """
        Collapses input of dim T*N*H to (T*N)*H, and applies to a module.
        Allows handling of variable sequence lengths and minibatch sizes.
        :param module: Module to apply input to.
        """
        super(SequenceWise, self).__init__()
        self.module = module

    def forward(self, x):
        t, n = x.size(0), x.size(1)
        x = x.view(t * n, -1)
        x = self.module(x)
        x = x.view(t, n, -1)
        return x

    def __repr__(self):
        tmpstr = self.__class__.__name__ + ' (\n'
        tmpstr += self.module.__repr__()
        tmpstr += ')'
        return tmpstr


class MaskConv(nn.Module):
    def __init__(self, seq_module):
        """
        Adds padding to the output of the module based on the given lengths. This is to ensure that the
        results of the model do not change when batch sizes change during inference.
        Input needs to be in the shape of (BxCxDxT)
        :param seq_module: The sequential module containing the conv stack.
        """
        super(MaskConv, self).__init__()
        self.seq_module = seq_module

    def forward(self, x, lengths):
        """
        :param x: The input of size BxCxDxT
        :param lengths: The actual length of each sequence in the batch
        :return: Masked output from the module
        """
        for module in self.seq_module:
            x = module(x)
            mask = torch.BoolTensor(x.size()).fill_(0)
            if x.is_cuda:
                mask = mask.cuda()
            for i, length in enumerate(lengths):
                length = length.item()
                if (mask[i].size(2) - length) > 0:
                    mask[i].narrow(2, length, mask[i].size(2) - length).fill_(1)
            x = x.masked_fill(mask, 0)
        return x, lengths


class InferenceBatchSoftmax(nn.Module):
    def forward(self, input_):
        if not self.training:
            return F.softmax(input_, dim=-1)
        else:
            return input_


class BatchRNN(nn.Module):
    def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=False, batch_norm=True):
        super(BatchRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size)) if batch_norm else None
        self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size,
                            bidirectional=bidirectional, bias=True)
        self.num_directions = 2 if bidirectional else 1

    def flatten_parameters(self):
        self.rnn.flatten_parameters()

    def forward(self, x, output_lengths):
        if self.batch_norm is not None:
            x = self.batch_norm(x)
        x = nn.utils.rnn.pack_padded_sequence(x, output_lengths)
        x, h = self.rnn(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x)
        if self.bidirectional:
            x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1)  # (TxNxH*2) -> (TxNxH) by sum
        return x


class Lookahead(nn.Module):
    # Wang et al 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks
    # input shape - sequence, batch, feature - TxNxH
    # output shape - same as input
    def __init__(self, n_features, context):
        super(Lookahead, self).__init__()
        assert context > 0
        self.context = context
        self.n_features = n_features
        self.pad = (0, self.context - 1)
        self.conv = nn.Conv1d(self.n_features, self.n_features, kernel_size=self.context, stride=1,
                              groups=self.n_features, padding=0, bias=None)

    def forward(self, x):
        x = x.transpose(0, 1).transpose(1, 2)
        x = F.pad(x, pad=self.pad, value=0)
        x = self.conv(x)
        x = x.transpose(1, 2).transpose(0, 1).contiguous()
        return x

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'n_features=' + str(self.n_features) \
               + ', context=' + str(self.context) + ')'


class DeepSpeech(nn.Module):
    def __init__(self, rnn_type, labels, rnn_hidden_size, nb_layers, audio_conf,
                 bidirectional, context=20):
        super(DeepSpeech, self).__init__()

        self.hidden_size = rnn_hidden_size
        self.hidden_layers = nb_layers
        self.rnn_type = rnn_type
        self.audio_conf = audio_conf
        self.labels = labels
        self.bidirectional = bidirectional

        sample_rate = self.audio_conf["sample_rate"]
        window_size = self.audio_conf["window_size"]
        num_classes = len(self.labels)

        self.conv = MaskConv(nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)),
            nn.BatchNorm2d(32),
            nn.Hardtanh(0, 20, inplace=True),
            nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)),
            nn.BatchNorm2d(32),
            nn.Hardtanh(0, 20, inplace=True)
        ))
        # Based on above convolutions and spectrogram size using conv formula (W - F + 2P)/ S+1
        rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1)
        rnn_input_size *= 32

        rnns = []
        rnn = BatchRNN(input_size=rnn_input_size, hidden_size=rnn_hidden_size, rnn_type=rnn_type,
                       bidirectional=bidirectional, batch_norm=False)
        rnns.append(('0', rnn))
        for x in range(nb_layers - 1):
            rnn = BatchRNN(input_size=rnn_hidden_size, hidden_size=rnn_hidden_size, rnn_type=rnn_type,
                           bidirectional=bidirectional)
            rnns.append(('%d' % (x + 1), rnn))
        self.rnns = nn.Sequential(OrderedDict(rnns))
        self.lookahead = nn.Sequential(
            # consider adding batch norm?
            Lookahead(rnn_hidden_size, context=context),
            nn.Hardtanh(0, 20, inplace=True)
        ) if not bidirectional else None

        fully_connected = nn.Sequential(
            nn.BatchNorm1d(rnn_hidden_size),
            nn.Linear(rnn_hidden_size, num_classes, bias=False)
        )
        self.fc = nn.Sequential(
            SequenceWise(fully_connected),
        )
        self.inference_softmax = InferenceBatchSoftmax()

    def forward(self, x, lengths):
        lengths = lengths.cpu().int()
        output_lengths = self.get_seq_lens(lengths)
        x, _ = self.conv(x, output_lengths)

        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # Collapse feature dimension
        x = x.transpose(1, 2).transpose(0, 1).contiguous()  # TxNxH

        for rnn in self.rnns:
            x = rnn(x, output_lengths)

        if not self.bidirectional:  # no need for lookahead layer in bidirectional
            x = self.lookahead(x)

        x = self.fc(x)
        x = x.transpose(0, 1)
        # identity in training mode, softmax in eval mode
        x = self.inference_softmax(x)
        return x, output_lengths

    def get_seq_lens(self, input_length):
        """
        Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable
        containing the size sequences that will be output by the network.
        :param input_length: 1D Tensor
        :return: 1D Tensor scaled by model
        """
        seq_len = input_length
        for m in self.conv.modules():
            if type(m) == nn.modules.conv.Conv2d:
                seq_len = ((seq_len + 2 * m.padding[1] - m.dilation[1] * (m.kernel_size[1] - 1) - 1) / m.stride[1] + 1)
        return seq_len.int()

    @classmethod
    def load_model(cls, path):
        package = torch.load(path, map_location=lambda storage, loc: storage)
        model = DeepSpeech.load_model_package(package)
        return model

    @classmethod
    def load_model_package(cls, package):
        model = cls(rnn_hidden_size=package['hidden_size'],
                    nb_layers=package['hidden_layers'],
                    labels=package['labels'],
                    audio_conf=package['audio_conf'],
                    rnn_type=supported_rnns[package['rnn_type']],
                    bidirectional=package.get('bidirectional', True))
        model.load_state_dict(package['state_dict'])
        return model

    def serialize_state(self):
        return {
            'hidden_size': self.hidden_size,
            'hidden_layers': self.hidden_layers,
            'rnn_type': supported_rnns_inv.get(self.rnn_type, self.rnn_type.__name__.lower()),
            'audio_conf': self.audio_conf,
            'labels': self.labels,
            'state_dict': self.state_dict(),
            'bidirectional': self.bidirectional,
        }

    @staticmethod
    def get_param_size(model):
        params = 0
        for p in model.parameters():
            tmp = 1
            for x in p.size():
                tmp *= x
            params += tmp
        return params



In [None]:
import torch
import numpy as np
import random
from torch.nn import CTCLoss
from tqdm import tqdm  # for progress bar
import pandas as pd  # for reading CSV files
# Ensure you import your custom modules as well


import json
import os
import random
import time
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch.nn import CTCLoss


class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def train_fn(args, train_loader, model, optimizer, criterion, epoch):
    model.train()
    losses = AverageMeter()

    t = tqdm(train_loader)
    for i, data in enumerate(t):

        inputs, targets, input_percentages, target_sizes = data
        input_sizes = input_percentages.mul_(int(inputs.size(3))).int()

        inputs = inputs.to(args.device)
        target = targets.to(args.device)
        target_sizes = target_sizes.to(args.device)

        out, output_sizes = model(inputs, input_sizes)
        out = out.transpose(0, 1) # TxNxH


        float_out = out.float() # ensure float32 for loss

        output_sizes = output_sizes.to(args.device)
        loss = criterion(float_out, targets, output_sizes, target_sizes).to(args.device)
        loss = loss / inputs.size(0) # average the loss by minibatch
        loss_value = loss.item()
        loss_value = loss.item()



        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        current_lr = optimizer.param_groups[0]['lr']


        losses.update(loss_value, inputs.size(0))

    return losses.avg


def test_fn(args, test_loader, model, decoder, target_decoder):

    model.eval()

    total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
    output_data = []

    with torch.no_grad():
        t = tqdm(test_loader)
        for i, data in enumerate(t):

            inputs, targets, input_percentages, target_sizes = data
            input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
            inputs = inputs.to(args.device)


            inputs = inputs.to(args.device)
            targets = targets.to(args.device)
            target_sizes = target_sizes.to(args.device)

            # unflatten targets
            split_targets = []
            offset = 0
            for size in target_sizes:
                split_targets.append(targets[offset:offset + size])
                offset += size

            out, output_sizes = model(inputs, input_sizes)

            decoded_output, _ = decoder.decode(out, output_sizes)
            target_strings = target_decoder.convert_to_strings(split_targets)

            # add output to data array, and continue
            output_data.append((out.cpu(), output_sizes, target_strings))

            for x in range(len(target_strings)):
                transcript, reference = decoded_output[x][0], target_strings[x][0]

                wer_inst = decoder.wer(transcript, reference)
                cer_inst = decoder.cer(transcript, reference)
                total_wer += wer_inst
                total_cer += cer_inst
                num_tokens += len(reference.split())
                num_chars += len(reference.replace(' ',''))

        wer = float(total_wer) / num_tokens
        cer = float(total_cer) / num_chars
        return wer * 100, cer * 100, output_data


def main(args):


    # Set seeds for determinism
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", args.device)



    labels = LABELS

    audio_conf = dict(sample_rate=args.sample_rate,
                      window_size=args.window_size,
                      window_stride=args.window_stride,
                      window=args.window)

    rnn_type = args.rnn_type.lower()
    assert rnn_type in supported_rnns, "rnn_type should be either lstm, rnn or gru"
    model = DeepSpeech(rnn_hidden_size=args.hidden_size,
                       nb_layers=args.hidden_layers,
                       labels=labels,
                       rnn_type=supported_rnns[rnn_type],
                       audio_conf=audio_conf,
                       bidirectional=args.bidirectional)


    # Data setup
    evaluation_decoder = GreedyDecoder(model.labels) # Decoder used for validation

    train_df = pd.read_csv(args.train_path, names=['audio_path', 'txt_path'])
    train_dataset = SpeechDataset(args=args, df=train_df)

    test_df = pd.read_csv(args.test_path, names=['audio_path', 'txt_path'])
    test_dataset = SpeechDataset(args=args, df=test_df)



    train_loader = AudioDataLoader(dataset=train_dataset,
                                   num_workers=args.num_workers,
                                   batch_size=args.batch_size)

    test_loader = AudioDataLoader(dataset=test_dataset,
                                  num_workers=args.num_workers,
                                  batch_size=args.batch_size)

    model = model.to(args.device)
    parameters = model.parameters()

    optimizer = torch.optim.AdamW(parameters,
                                  lr=args.learning_rate,
                                  betas=args.betas,
                                  eps=args.eps,
                                  weight_decay=args.weight_decay)

    criterion = CTCLoss()

    best_score = 99999

    for epoch in range(args.epochs):
        train_loss = train_fn(args, train_loader, model, optimizer, criterion, epoch)
        wer, cer, output_data = test_fn(args=args,
                                        test_loader=test_loader,
                                        model=model,
                                        decoder=evaluation_decoder,
                                        target_decoder=evaluation_decoder)

        print('Validation Summary Epoch: [{0}]\t'
              'Average WER {wer:.3f}\t'
              'Average CER {cer:.3f}\t'.format(epoch + 1, wer=wer, cer=cer))

        if (wer+cer)/2 < best_score:
            print("**** Model Improved !!!! Saving Model")
            torch.save(model.state_dict(), f"best_model.bin")
            best_score = (wer+cer)/2

if __name__ == '__main__':
    Args = args()
    main(args=args)

!nvidia-smi

if os.path.exists("best_model.bin"):
    model.load_state_dict(torch.load("best_model.bin"))
    print("Loaded the saved model.")

from google.colab import drive
drive.mount('/content/drive')
torch.save(model.state_dict(), "/content/drive/MyDrive/advanced machine learning/best_model.bin")

scaler = torch.cuda.amp.GradScaler()
for data in train_loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        output = model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()


with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'),
    record_shapes=True
) as prof:
    train_fn(...)
print(prof.key_averages().table(sort_by="cuda_time_total"))





Using device: cuda


100%|██████████| 557/557 [34:52<00:00,  3.76s/it]
100%|██████████| 336/336 [25:27<00:00,  4.55s/it]


Validation Summary Epoch: [1]	Average WER 137.660	Average CER 83.014	
**** Model Improved !!!! Saving Model


100%|██████████| 557/557 [03:09<00:00,  2.93it/s]
100%|██████████| 336/336 [04:56<00:00,  1.13it/s]


Validation Summary Epoch: [2]	Average WER 145.767	Average CER 81.608	


100%|██████████| 557/557 [03:10<00:00,  2.92it/s]
100%|██████████| 336/336 [04:54<00:00,  1.14it/s]


Validation Summary Epoch: [3]	Average WER 50.576	Average CER 78.953	
**** Model Improved !!!! Saving Model


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [04:50<00:00,  1.16it/s]


Validation Summary Epoch: [4]	Average WER 30.366	Average CER 81.012	
**** Model Improved !!!! Saving Model


100%|██████████| 557/557 [03:12<00:00,  2.90it/s]
100%|██████████| 336/336 [04:35<00:00,  1.22it/s]


Validation Summary Epoch: [5]	Average WER 30.579	Average CER 85.693	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [04:21<00:00,  1.29it/s]


Validation Summary Epoch: [6]	Average WER 52.925	Average CER 89.073	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [04:14<00:00,  1.32it/s]


Validation Summary Epoch: [7]	Average WER 78.711	Average CER 89.347	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [04:14<00:00,  1.32it/s]


Validation Summary Epoch: [8]	Average WER 54.356	Average CER 88.836	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [04:02<00:00,  1.38it/s]


Validation Summary Epoch: [9]	Average WER 47.705	Average CER 89.166	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:56<00:00,  1.42it/s]


Validation Summary Epoch: [10]	Average WER 64.831	Average CER 92.206	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:51<00:00,  1.45it/s]


Validation Summary Epoch: [11]	Average WER 76.366	Average CER 94.128	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:49<00:00,  1.47it/s]


Validation Summary Epoch: [12]	Average WER 128.690	Average CER 96.511	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:44<00:00,  1.50it/s]


Validation Summary Epoch: [13]	Average WER 180.569	Average CER 97.465	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:38<00:00,  1.54it/s]


Validation Summary Epoch: [14]	Average WER 258.373	Average CER 98.121	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:31<00:00,  1.59it/s]


Validation Summary Epoch: [15]	Average WER 349.652	Average CER 98.522	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:32<00:00,  1.58it/s]


Validation Summary Epoch: [16]	Average WER 288.469	Average CER 97.907	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:26<00:00,  1.62it/s]


Validation Summary Epoch: [17]	Average WER 301.258	Average CER 97.881	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:25<00:00,  1.64it/s]


Validation Summary Epoch: [18]	Average WER 202.314	Average CER 96.865	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:16<00:00,  1.71it/s]


Validation Summary Epoch: [19]	Average WER 538.166	Average CER 99.937	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:17<00:00,  1.70it/s]


Validation Summary Epoch: [20]	Average WER 522.500	Average CER 99.764	


100%|██████████| 557/557 [03:12<00:00,  2.90it/s]
100%|██████████| 336/336 [03:17<00:00,  1.70it/s]


Validation Summary Epoch: [21]	Average WER 536.845	Average CER 99.903	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:19<00:00,  1.69it/s]


Validation Summary Epoch: [22]	Average WER 458.818	Average CER 99.439	


100%|██████████| 557/557 [03:12<00:00,  2.90it/s]
100%|██████████| 336/336 [03:21<00:00,  1.67it/s]


Validation Summary Epoch: [23]	Average WER 108.548	Average CER 94.070	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:23<00:00,  1.65it/s]


Validation Summary Epoch: [24]	Average WER 49.113	Average CER 89.782	


100%|██████████| 557/557 [03:12<00:00,  2.90it/s]
100%|██████████| 336/336 [03:26<00:00,  1.63it/s]


Validation Summary Epoch: [25]	Average WER 28.574	Average CER 86.185	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:28<00:00,  1.61it/s]


Validation Summary Epoch: [26]	Average WER 16.093	Average CER 83.971	
**** Model Improved !!!! Saving Model


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:29<00:00,  1.60it/s]


Validation Summary Epoch: [27]	Average WER 13.076	Average CER 84.028	
**** Model Improved !!!! Saving Model


100%|██████████| 557/557 [03:12<00:00,  2.89it/s]
100%|██████████| 336/336 [03:29<00:00,  1.61it/s]


Validation Summary Epoch: [28]	Average WER 13.777	Average CER 83.848	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:32<00:00,  1.58it/s]


Validation Summary Epoch: [29]	Average WER 9.410	Average CER 82.342	
**** Model Improved !!!! Saving Model


100%|██████████| 557/557 [03:12<00:00,  2.90it/s]
100%|██████████| 336/336 [03:29<00:00,  1.60it/s]


Validation Summary Epoch: [30]	Average WER 11.163	Average CER 82.898	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:28<00:00,  1.61it/s]


Validation Summary Epoch: [31]	Average WER 10.962	Average CER 82.503	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:29<00:00,  1.60it/s]


Validation Summary Epoch: [32]	Average WER 13.042	Average CER 83.927	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:30<00:00,  1.60it/s]


Validation Summary Epoch: [33]	Average WER 9.848	Average CER 82.193	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:48<00:00,  1.47it/s]


Validation Summary Epoch: [34]	Average WER 13.204	Average CER 79.805	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:31<00:00,  1.59it/s]


Validation Summary Epoch: [35]	Average WER 9.405	Average CER 89.063	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:30<00:00,  1.59it/s]


Validation Summary Epoch: [36]	Average WER 10.176	Average CER 87.170	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:29<00:00,  1.61it/s]


Validation Summary Epoch: [37]	Average WER 11.734	Average CER 90.409	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:40<00:00,  1.52it/s]


Validation Summary Epoch: [38]	Average WER 68.559	Average CER 93.065	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:39<00:00,  1.53it/s]


Validation Summary Epoch: [39]	Average WER 17.750	Average CER 79.539	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:40<00:00,  1.52it/s]


Validation Summary Epoch: [40]	Average WER 16.292	Average CER 81.069	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:41<00:00,  1.52it/s]


Validation Summary Epoch: [41]	Average WER 16.726	Average CER 81.569	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:40<00:00,  1.53it/s]


Validation Summary Epoch: [42]	Average WER 16.864	Average CER 80.647	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:41<00:00,  1.52it/s]


Validation Summary Epoch: [43]	Average WER 14.327	Average CER 78.853	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:42<00:00,  1.51it/s]


Validation Summary Epoch: [44]	Average WER 13.447	Average CER 78.191	
**** Model Improved !!!! Saving Model


100%|██████████| 557/557 [03:12<00:00,  2.89it/s]
100%|██████████| 336/336 [03:39<00:00,  1.53it/s]


Validation Summary Epoch: [45]	Average WER 14.073	Average CER 81.008	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:42<00:00,  1.51it/s]


Validation Summary Epoch: [46]	Average WER 13.192	Average CER 81.138	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:40<00:00,  1.52it/s]


Validation Summary Epoch: [47]	Average WER 14.670	Average CER 83.862	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:41<00:00,  1.52it/s]


Validation Summary Epoch: [48]	Average WER 15.849	Average CER 81.394	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:40<00:00,  1.52it/s]


Validation Summary Epoch: [49]	Average WER 13.852	Average CER 82.591	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:40<00:00,  1.52it/s]


Validation Summary Epoch: [50]	Average WER 14.655	Average CER 81.049	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:40<00:00,  1.53it/s]


Validation Summary Epoch: [51]	Average WER 15.104	Average CER 79.004	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:39<00:00,  1.53it/s]


Validation Summary Epoch: [52]	Average WER 12.970	Average CER 80.820	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:39<00:00,  1.53it/s]


Validation Summary Epoch: [53]	Average WER 12.851	Average CER 84.479	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:41<00:00,  1.52it/s]


Validation Summary Epoch: [54]	Average WER 15.792	Average CER 82.502	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:38<00:00,  1.54it/s]


Validation Summary Epoch: [55]	Average WER 16.599	Average CER 80.916	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:39<00:00,  1.53it/s]


Validation Summary Epoch: [56]	Average WER 15.509	Average CER 82.511	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:47<00:00,  1.48it/s]


Validation Summary Epoch: [57]	Average WER 17.061	Average CER 80.769	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:27<00:00,  1.62it/s]


Validation Summary Epoch: [58]	Average WER 24.650	Average CER 86.823	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:30<00:00,  1.59it/s]


Validation Summary Epoch: [59]	Average WER 44.962	Average CER 80.881	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:34<00:00,  1.57it/s]


Validation Summary Epoch: [60]	Average WER 24.138	Average CER 81.280	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [05:04<00:00,  1.10it/s]


Validation Summary Epoch: [61]	Average WER 510.460	Average CER 93.734	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:20<00:00,  1.67it/s]


Validation Summary Epoch: [62]	Average WER 261.933	Average CER 92.121	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:22<00:00,  1.66it/s]


Validation Summary Epoch: [63]	Average WER 233.796	Average CER 92.586	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:23<00:00,  1.65it/s]


Validation Summary Epoch: [64]	Average WER 222.607	Average CER 94.277	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:25<00:00,  1.64it/s]


Validation Summary Epoch: [65]	Average WER 179.023	Average CER 88.065	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:27<00:00,  1.62it/s]


Validation Summary Epoch: [66]	Average WER 105.740	Average CER 87.659	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:23<00:00,  1.65it/s]


Validation Summary Epoch: [67]	Average WER 139.171	Average CER 90.745	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:24<00:00,  1.64it/s]


Validation Summary Epoch: [68]	Average WER 223.674	Average CER 90.724	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:30<00:00,  1.59it/s]


Validation Summary Epoch: [69]	Average WER 34.202	Average CER 82.101	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:36<00:00,  1.55it/s]


Validation Summary Epoch: [70]	Average WER 40.899	Average CER 78.418	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:32<00:00,  1.58it/s]


Validation Summary Epoch: [71]	Average WER 82.394	Average CER 87.007	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:26<00:00,  1.62it/s]


Validation Summary Epoch: [72]	Average WER 257.998	Average CER 92.366	


100%|██████████| 557/557 [03:11<00:00,  2.90it/s]
100%|██████████| 336/336 [03:28<00:00,  1.61it/s]


Validation Summary Epoch: [73]	Average WER 226.408	Average CER 87.679	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:31<00:00,  1.59it/s]


Validation Summary Epoch: [74]	Average WER 235.464	Average CER 90.718	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:25<00:00,  1.63it/s]


Validation Summary Epoch: [75]	Average WER 277.493	Average CER 89.713	


100%|██████████| 557/557 [03:11<00:00,  2.91it/s]
100%|██████████| 336/336 [03:26<00:00,  1.62it/s]


Validation Summary Epoch: [76]	Average WER 238.720	Average CER 88.147	


 73%|███████▎  | 409/557 [02:10<00:47,  3.13it/s]


KeyboardInterrupt: 

In [None]:
import os
print("Number of CPU cores available:", os.cpu_count())


Number of CPU cores available: 12


In [12]:
pip install gradio

Collecting gradio
  Downloading gradio-5.8.0-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.6-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.4.0-py3-none-any.whl.metadata (2.9 kB)
Collecting gradio-client==1.5.1 (from gradio)
  Downloading gradio_client-1.5.1-py3-none-any.whl.metadata (7.1 kB)
Collecting markupsafe~=2.0 (from gradio)
  Downloading MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.19-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.2.2 (from gradio)
  Downloading ruff-0.8.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metad

In [15]:
# Install required libraries
!pip install gradio torchaudio

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import MelSpectrogram, Resample
import numpy as np


class DeepSpeech(torch.nn.Module):
    @classmethod
    def load_model_package(cls, package):
        try:
            model = cls(
                hidden_size=package['hidden_size'],  # Key might need verification
                hidden_layers=package['hidden_layers'],
                labels=package['labels'],
                audio_conf=package['audio_conf'],
                rnn_type=supported_rnns[package['rnn_type']],
                bidirectional=package.get('bidirectional', True)
            )
            model.load_state_dict(package['state_dict'])
            return model
        except KeyError as e:
            raise ValueError(f"KeyError: {e}. Please verify the structure of your checkpoint.")


# Function to load the model
def load_deepspeech_model(model_path):
    package = torch.load(model_path, map_location=torch.device('cpu'))
    model = DeepSpeech.load_model_package(package)
    labels = package['labels']
    audio_conf = package['audio_conf']
    return model, labels, audio_conf


# Function for preprocessing and inference
def transcribe(audio, model, labels, audio_conf):
    try:
        # Load the waveform and sample rate
        waveform, sample_rate = torchaudio.load(audio.name)

        # Resample if necessary
        target_sample_rate = audio_conf["sample_rate"]
        if sample_rate != target_sample_rate:
            resampler = Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
            waveform = resampler(waveform)

        # Convert to MelSpectrogram
        mel_spectrogram = MelSpectrogram(
            sample_rate=target_sample_rate,
            n_fft=int(target_sample_rate * audio_conf["window_size"]),
            win_length=int(target_sample_rate * audio_conf["window_size"]),
            hop_length=int(target_sample_rate * audio_conf["window_stride"]),
            n_mels=80
        )
        spectrogram = mel_spectrogram(waveform)
        spectrogram = torch.log1p(spectrogram)

        # Normalize
        spectrogram -= spectrogram.mean()
        spectrogram /= spectrogram.std()

        # Add batch and channel dimensions
        spectrogram = spectrogram.unsqueeze(0).unsqueeze(0)

        # Lengths for RNN
        lengths = torch.tensor([spectrogram.size(3)], dtype=torch.int32)

        # Run inference
        model.eval()
        with torch.no_grad():
            output, _ = model(spectrogram, lengths)
            output = output[0].cpu().numpy()
            output = np.argmax(output, axis=-1)

        # Decode the transcription
        transcription = "".join([labels[i] for i in output if i != len(labels) - 1])
        return transcription

    except Exception as e:
        return f"Error during transcription: {str(e)}"


# Load the model
model_path = "/content/drive/MyDrive/advanced-machine-learning/best_model.bin"
model, labels, audio_conf = load_deepspeech_model(model_path)

# Create the Gradio interface
def gradio_transcribe(audio):
    return transcribe(audio, model, labels, audio_conf)

interface = gr.Interface(
    fn=gradio_transcribe,
    inputs=gr.Audio(source="upload", type="file"),
    outputs="text",
    title="Speech Recognition",
    description="Upload an audio file and get its transcription."
)

# Launch the interface
interface.launch()




  package = torch.load(model_path, map_location=torch.device('cpu'))


ValueError: KeyError: 'hidden_size'. Please verify the structure of your checkpoint.