In [47]:
import torchaudio
import torch
import IPython
import matplotlib
import matplotlib.pyplot as plt
import tensorflow as tf
import os
import numpy as np
import random

In [48]:
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H


print('Sample Rate: ', bundle.sample_rate)
print('Labels: ', bundle.get_labels())

Sample Rate:  16000
Labels:  ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')


In [39]:
model = bundle.get_model()

In [40]:
DATASET_PATH = './speech_commands_v0.01/'

train_ds, val_ds = tf.keras.utils.audio_dataset_from_directory(
    directory = DATASET_PATH,
    batch_size = 64,
    output_sequence_length = 16000,
    validation_split = 0.2,
    subset = "both",
    seed = 0
)

test_ds = val_ds.shard(num_shards = 2, index = 0)
val_ds = val_ds.shard(num_shards = 2, index = 1)

Found 64727 files belonging to 31 classes.
Using 51782 files for training.
Using 12945 files for validation.


In [41]:
label_names = np.array(train_ds.class_names)
print(label_names)

['_background_noise_' 'bed' 'bird' 'cat' 'dog' 'down' 'eight' 'five'
 'four' 'go' 'happy' 'house' 'left' 'marvin' 'nine' 'no' 'off' 'on' 'one'
 'right' 'seven' 'sheila' 'six' 'stop' 'three' 'tree' 'two' 'up' 'wow'
 'yes' 'zero']


In [42]:
def squeeze(audio, label):
    audio = tf.squeeze(audio, axis=-1)
    return audio, label

train_ds = train_ds.map(squeeze, tf.data.AUTOTUNE)
val_ds = val_ds.map(squeeze, tf.data.AUTOTUNE)
test_ds = test_ds.map(squeeze, tf.data.AUTOTUNE)

In [43]:
def get_spectrogram(audio):
    spectrogram = tf.signal.stft(audio, frame_length=255, frame_step=128)
    spectrogram = tf.abs(spectrogram)
    spectrogram = spectrogram[..., tf.newaxis]
    
    return spectrogram


def make_spectrogram_ds(ds):
    return ds.map(map_func = lambda audio, label: (get_spectrogram(audio), label), num_parallel_calls=tf.data.AUTOTUNE)


train_spectrogram_ds = make_spectrogram_ds(train_ds)
test_spectrogram_ds = make_spectrogram_ds(test_ds)
val_spectrogram_ds = make_spectrogram_ds(val_ds)

train_spectrogram_ds= train_spectrogram_ds.cache().shuffle(2030).prefetch(tf.data.AUTOTUNE)
test_spectrogram_ds = test_spectrogram_ds.cache().prefetch(tf.data.AUTOTUNE)
val_spectrogram_ds = val_spectrogram_ds.cache().prefetch(tf.data.AUTOTUNE)

In [44]:
for sample_ds, sample_ds_labels in test_ds:
    break
def convert_to_label(labels):
    # return [label_names[label.numpy()] for label in labels]
    return [label_names[label] for label in labels]

In [45]:
with torch.inference_mode():
    emission, _ = model(torch.from_numpy(np.asarray(sample_ds)))

In [30]:
class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels):
        super().__init__()
        self.labels = labels
        self.blank = 0
        
    def forward(self, emission: torch.Tensor) -> str:
        indices = torch.argmax(emission, dim = -1)
        indices = torch.unique_consecutive(indices, dim = -1)
        indices = [i for i in indices if i != self.blank]
        kalimat = ''.join([self.labels[i] for i in indices])
        kalimat = kalimat.replace('|', ' ')
        return kalimat.lower()

In [31]:
decoder = GreedyCTCDecoder(labels = bundle.get_labels())
text = decoder(emission[0])

In [32]:
num = len(sample_ds_labels)

In [51]:
import jiwer
count = 0
total = 0
for sample_ds, sample_ds_labels in test_ds:
    with torch.inference_mode():
        emission, _ = model(torch.from_numpy(np.asarray(sample_ds)))
        decoder = GreedyCTCDecoder(labels = bundle.get_labels())
        predict_label = decoder(emission[0])
    
    
    # for _, label in test_ds:
    #     true_label.append(label)
    
    # true_label = tf.concat(label, axis = 0)
    # true_label = convert_to_label(true_label)    

    true_label = convert_to_label(sample_ds_labels)
    true_label = np.array(true_label)
    # for i in range(len(sample_ds_labels)):
    print("True Label :" , true_label[0] ,"Predicted Label :", predict_label)
    count+=1
    wer = jiwer.wer(true_label[0], predict_label)
    total+=wer

True Label : stop Predicted Label : stop 
True Label : three Predicted Label : three 
True Label : on Predicted Label : on 
True Label : seven Predicted Label : seven 
True Label : stop Predicted Label : stop 
True Label : up Predicted Label : up 
True Label : down Predicted Label : down 
True Label : five Predicted Label : five 
True Label : two Predicted Label : two 
True Label : four Predicted Label : four 
True Label : five Predicted Label : five 
True Label : stop Predicted Label : stop 
True Label : zero Predicted Label : zero 
True Label : marvin Predicted Label : marvin 
True Label : no Predicted Label : no 
True Label : four Predicted Label : four 
True Label : marvin Predicted Label : marvin 
True Label : down Predicted Label : don 
True Label : zero Predicted Label : zero 
True Label : seven Predicted Label : seven 
True Label : sheila Predicted Label : shila 
True Label : one Predicted Label : 
True Label : stop Predicted Label : stop 
True Label : three Predicted Label : t

In [52]:
print(total/count)

0.2647058823529412


In [53]:
print(f"Word Error Rate: {(total/count)*100} %")

Word Error Rate: 26.47058823529412 %
