In [33]:
import librosa
import datasets
# Load the on-disk Hugging Face dataset and keep only its validation split.
ds = datasets.load_from_disk('../data/hf')['validation']

In [34]:
class ASRAlphabet:
    """Character vocabulary for CTC-style ASR with a '<blank>' token appended last.

    Args:
        alphabet_string: the vocabulary characters, in index order.
    """

    def __init__(self, alphabet_string):
        self.alphabet = list(alphabet_string)
        self.alphabet.append('<blank>')  # Adding blank token
        self.char_to_index = {char: index for index, char in enumerate(self.alphabet)}
        self.index_to_char = {index: char for index, char in enumerate(self.alphabet)}
        # The blank is always the last entry; exposed so callers don't recompute len(...) - 1.
        self.blank_index = len(self.alphabet) - 1

    def text_to_array(self, text):
        """Lower-case ``text`` and map each character to its index, silently dropping unknown characters."""
        return [self.char_to_index[char] for char in str.lower(text) if char in self.char_to_index]

    def array_to_text(self, array):
        """Map indices back to characters; unknown indices become '<UNK>', the blank index becomes '<blank>'."""
        return ''.join(self.index_to_char.get(int(index), '<UNK>') for index in array)

alphabet = ASRAlphabet(alphabet_string='abcdefghijklmnopqrstuvwxyzåäö ')

In [35]:
import torchaudio
import torch

TARGET_SAMPLE_RATE = 16000  # rate the mel front-end below is configured for

mel_spectrogram_converter = torchaudio.transforms.MelSpectrogram(
        sample_rate=TARGET_SAMPLE_RATE,
        n_fft=400,
        n_mels=40
)

# from audio_utils
def preprocess_audio(batch):
    """Turn a list of dataset rows into padded log-mel features and integer label sequences.

    Each row is read as ``item['audio']['array']``, ``item['audio']['sampling_rate']``
    and ``item['sentence']``.  Audio is resampled to TARGET_SAMPLE_RATE when needed,
    pre-emphasised, converted to a log-mel spectrogram, normalised per utterance
    (zero mean / unit std) and transposed to (time, n_mels).  Sentences are mapped
    to indices with the global ``alphabet``.

    Args:
        batch: iterable of dataset rows.

    Returns:
        Tuple ``(audio, labels, audio_lengths, label_lengths)``:
        ``audio`` is a list of (max_time, n_mels) float tensors zero-padded along time,
        ``labels`` is a list of int64 tensors right-padded with the blank index, and
        the two length lists hold the unpadded lengths.
    """
    # The blank is the last alphabet entry (appended by the alphabet class); used as label padding.
    blank_index = len(alphabet.alphabet) - 1
    # One Resample transform per source rate, reused for every clip in this batch.
    resamplers = {}

    audio = []
    label = []
    audio_length = []
    label_length = []

    for item in batch:
        t1 = torch.tensor(item['audio']['array']).float()

        sample_rate = item['audio']['sampling_rate']
        if sample_rate != TARGET_SAMPLE_RATE:
            if sample_rate not in resamplers:
                resamplers[sample_rate] = torchaudio.transforms.Resample(sample_rate, TARGET_SAMPLE_RATE)
            t1 = resamplers[sample_rate](t1)

        t1 = torchaudio.functional.preemphasis(t1)

        mel_spectrogram = mel_spectrogram_converter(t1)

        # Log-compress; the epsilon avoids log(0) on empty bins.
        mel_spectrogram = torch.log(mel_spectrogram + 1e-9)

        # Per-utterance normalisation.  clamp_min guards a constant (e.g. all-silent)
        # clip, whose std is 0 and would otherwise turn every value into NaN.
        mel_spectrogram = (mel_spectrogram - mel_spectrogram.mean()) / mel_spectrogram.std().clamp_min(1e-9)

        # (n_mels, time) -> (time, n_mels).  Assumes a mono 1-D waveform — TODO confirm
        # for multi-channel data, which would add a leading channel dimension.
        mel_spectrogram = mel_spectrogram.transpose(0, 1)

        # dtype=torch.long: a sentence with no in-alphabet characters gives an empty
        # list, and torch.tensor([]) would default to float32, breaking torch.stack
        # against the other (int64) labels.  text_to_array already lower-cases.
        sentence = torch.tensor(alphabet.text_to_array(item['sentence']), dtype=torch.long)

        audio.append(mel_spectrogram)
        label.append(sentence)
        # Time is dim 0 after the transpose (an earlier bug used the wrong dimension here).
        audio_length.append(mel_spectrogram.shape[0])
        label_length.append(len(sentence))

    max_input_length = max(audio_length, default=0)
    max_label_length = max(label_length, default=0)

    # Zero-pad features along the time axis; pad labels on the right with the blank index.
    audio_padded = [
        torch.nn.functional.pad(x, (0, 0, 0, max_input_length - x.size(0)))
        for x in audio
    ]
    labels_padded = [
        torch.nn.functional.pad(x, (0, max_label_length - len(x)), value=blank_index)
        for x in label
    ]

    return (audio_padded, labels_padded, audio_length, label_length)

In [36]:
# Smoke test: featurize a batch containing only the first validation utterance.
audio, label, audio_len, label_len = preprocess_audio(batch=[ds[0]])


In [37]:
import torch.nn as nn
import torch.nn.functional as F

class LSTMCTC(nn.Module):
    """Bidirectional multi-layer LSTM with a linear head giving per-frame class logits for CTC.

    Args:
        input_size: feature dimension of each frame.
        hidden_size: LSTM hidden size per direction; the head sees 2 * hidden_size.
        num_classes: number of output classes (the caller's alphabet, blank included).
        num_layers: number of stacked LSTM layers.
        dropout_rate: dropout between LSTM layers and before the linear head.
    """

    def __init__(self, input_size, hidden_size, num_classes, num_layers, dropout_rate):
        super().__init__()
        # Attribute names (and registration order) define the state_dict keys a
        # saved checkpoint is loaded against, and the printed module repr.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            dropout=dropout_rate,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_size * 2, num_classes)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, lengths=None):
        """Return logits of shape (batch_size, seq_len, num_classes).

        Args:
            x: padded input, (batch_size, seq_len, input_size).
            lengths: optional unpadded length of each sequence (list or 1-D tensor).
                When given, the batch is packed so the backward direction starts at
                each sequence's real last frame instead of first reading trailing
                padding; padded positions are zero-filled before the head.  When
                omitted, padding is processed like ordinary frames (original behaviour).
        """
        if lengths is None:
            x, _ = self.lstm(x)
        else:
            # pack_padded_sequence requires the lengths on the CPU as int64.
            lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
            packed = nn.utils.rnn.pack_padded_sequence(
                x, lengths, batch_first=True, enforce_sorted=False
            )
            packed_out, _ = self.lstm(packed)
            # total_length keeps the time axis equal to the input's, so the output
            # shape matches the unpacked path.
            x, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True, total_length=x.size(1)
            )
        x = self.dropout(x)
        return self.fc(x)

In [38]:
# Rebuild the architecture the checkpoint was trained with; load_state_dict is
# strict by default and raises if the layer shapes or names don't match.
model = LSTMCTC(
        input_size=40,  # mels
        hidden_size=320,
        num_layers=4,
        num_classes=len(alphabet.alphabet),
        dropout_rate=0.2
)
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine; load_state_dict then copies the weights into the model.
model.load_state_dict(torch.load('../gd.pth', map_location='cpu', weights_only=True))
model.eval()

LSTMCTC(
  (lstm): LSTM(40, 320, num_layers=4, batch_first=True, dropout=0.2, bidirectional=True)
  (fc): Linear(in_features=640, out_features=31, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [39]:
import torch
from typing import List

class SimpleDecoder:
    """Greedy (best-path) CTC decoder over a list of output symbols.

    Args:
        alphabet: list of symbols indexed by class id, with the CTC blank as the
            LAST entry (e.g. ``ASRAlphabet(...).alphabet``).
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet
        # The blank is the last entry.  len(alphabet) would point one past the end,
        # so the blank would never be skipped and would leak into the text.
        self.blank_index = len(alphabet) - 1

    def decode(self, log_probs: torch.Tensor) -> List[str]:
        """
        Decode log probabilities to text using greedy decoding.

        Args:
        log_probs (torch.Tensor): Log probabilities from the model
                                  Shape: (batch_size, sequence_length, num_classes)

        Returns:
        List[str]: Decoded texts for each item in the batch
        """
        # Get the most likely class at each step
        predictions = torch.argmax(log_probs, dim=-1)  # Shape: (batch_size, sequence_length)

        batch_texts = []
        for batch_item in predictions:
            text = self._decode_prediction(batch_item)
            batch_texts.append(text)

        return batch_texts

    def _decode_prediction(self, prediction: torch.Tensor) -> str:
        """
        Decode a single prediction sequence to text: collapse repeats, then drop blanks.

        Args:
        prediction (torch.Tensor): Prediction sequence for a single item
                                   Shape: (sequence_length,)

        Returns:
        str: Decoded text
        """
        decoded = []
        previous = None
        for p in prediction:
            p = p.item()
            if p != previous and p != self.blank_index:
                if p < len(self.alphabet):
                    decoded.append(self.alphabet[p])
            # Updated for blanks too, so "a <blank> a" yields "aa" rather than "a".
            previous = p

        return ''.join(decoded)


In [40]:
audio = torch.stack(audio, dim=0)
# Inference only: no_grad skips building the autograd graph (less memory, faster).
with torch.no_grad():
    output = model(audio)
print(output.shape)

torch.Size([1, 505, 31])


In [41]:
dec = SimpleDecoder(alphabet=alphabet.alphabet)
log_probs = F.log_softmax(output, dim=-1)  # batch-first layout; decoding needs no permute
print(
    log_probs.shape,
    str(dec.decode(log_probs)).replace('<blank>', ''),
    ds[0]['sentence'],
)

torch.Size([1, 505, 31]) ['hirvityshyökkäsiitä lähinnä olevan tankin kimppuun'] Hirvitys hyökkäsi sitä lähinnä olevan tankin kimppuun.


In [42]:
# Three-utterance batch, to exercise padding across different lengths.
audio, label, audio_len, label_len = preprocess_audio([ds[i] for i in range(3)])

In [43]:
# Collate the padded per-utterance features into one batch tensor.
audio = torch.stack(audio)
print(audio.size())

torch.Size([3, 505, 40])


In [44]:
# (batch, max_label_len) id matrix, right-padded with the blank id.
labels = torch.stack(label)
label_len = torch.tensor(label_len)
print(labels.size(), label_len)

torch.Size([3, 54]) tensor([53, 54, 20])


In [45]:
# Raw padded label ids; trailing padding uses the blank id.
print(str(labels))

tensor([[ 7,  8, 17, 21,  8, 19, 24, 18, 29,  7, 24, 28, 10, 10, 27, 18,  8, 29,
         18,  8, 19, 27, 29, 11, 27,  7,  8, 13, 13, 27, 29, 14, 11,  4, 21,  0,
         13, 29, 19,  0, 13, 10,  8, 13, 29, 10,  8, 12, 15, 15, 20, 20, 13, 30],
        [27, 27, 13,  4, 18, 19, 27, 12, 12,  4, 29, 18,  8,  8, 18, 29, 19, 27,
         12, 27, 13, 18, 20, 20, 13, 19,  0,  8, 18, 19,  4, 13, 29, 19,  0, 17,
         10,  8, 18, 19, 20, 18, 19,  4, 13, 29, 15, 20, 14, 11,  4, 18, 19,  0],
        [10, 20,  8, 13, 29, 10,  8, 12, 15, 15, 20, 20, 13, 29,  0, 12, 15, 20,
          4, 13, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
         30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30]])


In [46]:
print(f"Labels shape: {labels.shape}")
print(f"Label_len shape: {label_len.shape}")

# Strip the blank padding from each label and turn it back into text.
ref = []
for i in range(labels.shape[0]):  # one pass per utterance in the batch
    single_label = labels[i]
    length = label_len[i]

    print(f"\nProcessing sample {i}")
    clipped_label = single_label[:length]
    print(f"Clipped label: {clipped_label}")

    text = alphabet.array_to_text(clipped_label)
    print(f"Converted text: '{text}'")
    print(f"Text length: {len(text)}")

    ref += [text]

print(f"\nFinal ref: {ref}")

Labels shape: torch.Size([3, 54])
Label_len shape: torch.Size([3])

Processing sample 0
Clipped label: tensor([ 7,  8, 17, 21,  8, 19, 24, 18, 29,  7, 24, 28, 10, 10, 27, 18,  8, 29,
        18,  8, 19, 27, 29, 11, 27,  7,  8, 13, 13, 27, 29, 14, 11,  4, 21,  0,
        13, 29, 19,  0, 13, 10,  8, 13, 29, 10,  8, 12, 15, 15, 20, 20, 13])
Converted text: 'hirvitys hyökkäsi sitä lähinnä olevan tankin kimppuun'
Text length: 53

Processing sample 1
Clipped label: tensor([27, 27, 13,  4, 18, 19, 27, 12, 12,  4, 29, 18,  8,  8, 18, 29, 19, 27,
        12, 27, 13, 18, 20, 20, 13, 19,  0,  8, 18, 19,  4, 13, 29, 19,  0, 17,
        10,  8, 18, 19, 20, 18, 19,  4, 13, 29, 15, 20, 14, 11,  4, 18, 19,  0])
Converted text: 'äänestämme siis tämänsuuntaisten tarkistusten puolesta'
Text length: 54

Processing sample 2
Clipped label: tensor([10, 20,  8, 13, 29, 10,  8, 12, 15, 15, 20, 20, 13, 29,  0, 12, 15, 20,
         4, 13])
Converted text: 'kuin kimppuun ampuen'
Text length: 20

Final ref: ['hirv

In [47]:
# Inference only: no_grad skips building the autograd graph (less memory, faster).
with torch.no_grad():
    output = model(audio)
print(output.shape)

torch.Size([3, 505, 31])


In [48]:
dec = SimpleDecoder(alphabet=alphabet.alphabet)
log_probs = F.log_softmax(output, dim=-1)
print(log_probs.shape)
decoded_batch = dec.decode(log_probs)
# Remove any '<blank>' strings the decoder left in its text.
cleaned_decoded = list(map(lambda text: str(text).replace('<blank>', ''), decoded_batch))
print(cleaned_decoded, ref, sep='\n')

torch.Size([3, 505, 31])
['hirvityshyökkäsiitä lähinnä olevan tankin kimppuun', 'äänestämme siis tämäin suntaisten tarkistusten puolestan', 'kuin kimppuun ampuen']
['hirvitys hyökkäsi sitä lähinnä olevan tankin kimppuun', 'äänestämme siis tämänsuuntaisten tarkistusten puolesta', 'kuin kimppuun ampuen']


In [49]:
import torch
from typing import List

class Alphabet:
    """Character vocabulary (with '<blank>' appended last) plus a greedy CTC decoder.

    Args:
        alphabet_string: vocabulary characters in index order.
    """

    def __init__(self, alphabet_string = 'abcdefghijklmnopqrstuvwxyzåäö '):
        self.alphabet = [*alphabet_string, '<blank>']
        self.char_to_index = {}
        self.index_to_char = {}
        for idx, ch in enumerate(self.alphabet):
            self.char_to_index[ch] = idx
            self.index_to_char[idx] = ch
        self.blank_index = len(self.alphabet) - 1  # the blank sits at the very end

    def text_to_array(self, text):
        """Lower-case ``text`` and map characters to indices, skipping unknown ones."""
        lookup = self.char_to_index
        return [lookup[ch] for ch in str.lower(text) if ch in lookup]

    def array_to_text(self, array):
        """Map indices to characters; unknown indices render as '<UNK>'."""
        chars = (self.index_to_char.get(int(idx), '<UNK>') for idx in array)
        return ''.join(chars)

    def decode(self, log_probs: torch.Tensor, remove_blanks: bool = False) -> List[str]:
        """
        Greedy (best-path) decoding of a batch of log probabilities.

        Args:
        log_probs (torch.Tensor): (batch_size, sequence_length, num_classes) scores.
        remove_blanks (bool): drop the blank token instead of emitting '<blank>'.

        Returns:
        List[str]: one decoded string per batch item.
        """
        best_path = torch.argmax(log_probs, dim=-1)  # (batch_size, sequence_length)
        return [self._decode_prediction(row, remove_blanks) for row in best_path]

    def _decode_prediction(self, prediction: torch.Tensor, remove_blanks: bool) -> str:
        """
        Collapse consecutive repeats of one (sequence_length,) id sequence into text.

        Args:
        prediction (torch.Tensor): class ids for a single item.
        remove_blanks (bool): exclude the blank token when True, keep it when False.

        Returns:
        str: decoded text.
        """
        size = len(self.alphabet)
        if remove_blanks:
            def wanted(token):
                return token != self.blank_index and token < size - 1
        else:
            def wanted(token):
                return token < size

        pieces = []
        last = None
        for token in (t.item() for t in prediction):
            if token != last and wanted(token):
                pieces.append(self.alphabet[token])
            last = token
        return ''.join(pieces)

ap = Alphabet()

In [50]:
# Decode keeping blanks as '<blank>' (the default), then strip them from the text.
log_probs = F.log_softmax(output, dim=-1)
print(log_probs.size())
decoded_batch = ap.decode(log_probs, remove_blanks=False)
cleaned_decoded = [text.replace('<blank>', '') for text in map(str, decoded_batch)]
print(cleaned_decoded, ref, sep='\n')

torch.Size([3, 505, 31])
['hirvityshyökkäsiitä lähinnä olevan tankin kimppuun', 'äänestämme siis tämäin suntaisten tarkistusten puolestan', 'kuin kimppuun ampuen']
['hirvitys hyökkäsi sitä lähinnä olevan tankin kimppuun', 'äänestämme siis tämänsuuntaisten tarkistusten puolesta', 'kuin kimppuun ampuen']


In [51]:
# Same decode, but the decoder drops blanks itself, so no string cleanup is needed.
log_probs = F.log_softmax(output, dim=-1)
print(log_probs.size())
decoded_batch = ap.decode(log_probs, remove_blanks=True)
print(decoded_batch, ref, sep='\n')

torch.Size([3, 505, 31])
['hirvityshyökkäsiitä lähinnä olevan tankin kimppuun', 'äänestämme siis tämäin suntaisten tarkistusten puolestan', 'kuin kimppuun ampuen']
['hirvitys hyökkäsi sitä lähinnä olevan tankin kimppuun', 'äänestämme siis tämänsuuntaisten tarkistusten puolesta', 'kuin kimppuun ampuen']


In [53]:
# Raw (audio, labels, audio_lengths, label_lengths) tuple for the evaluation cell below.
batch = preprocess_audio([ds[i] for i in (0, 1, 2)])

In [57]:

audio_padded, labels_padded, audio_lengths, label_lengths = batch

audio = torch.stack(audio_padded, dim=0)
# Shape: (batch_size, max_label_length)
labels = torch.stack(labels_padded, dim=0)
input_lengths = torch.tensor(audio_lengths)
target_lengths = torch.tensor(label_lengths)

# Evaluation-only forward pass: no_grad avoids building the autograd graph.
# (Remove it if this cell is ever turned into a training step with backward().)
with torch.no_grad():
    output = model(audio)
# (time, batch, classes) layout, as CTC loss expects; the loss itself is not computed here yet.
log_probs = F.log_softmax(output, dim=2).permute(1, 0, 2)
# Batch-first layout for greedy decoding (no permute).
log_probs_decode = F.log_softmax(output, dim=-1)
# Greedy decode with the Alphabet helper, dropping blanks.
decoded_text = ap.decode(log_probs_decode, remove_blanks=True)

# Reference transcripts from the unpadded label ids (WER/CER not computed yet).
# A comprehension keeps the loop variables from overwriting the notebook-level `label`.
ref = [
    alphabet.array_to_text(label_ids[:length].cpu().numpy())
    for label_ids, length in zip(labels, label_lengths)
]
print(f"""
ref0 {ref[0]}
dec0 {decoded_text[0]}
logprobs decoder shape {log_probs_decode.shape}
""")


ref0 hirvitys hyökkäsi sitä lähinnä olevan tankin kimppuun
dec0 hirvityshyökkäsiitä lähinnä olevan tankin kimppuun
logprobs decoder shape torch.Size([3, 505, 31])

