## **USER ACTION REQUIRED**

- Upload `silero_vad.onnx`
- Upload the audio file and its reference transcript (`.txt`)

## **Define ASR Model**

In [1]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0


In [2]:
!pip uninstall -y torch torchvision torchaudio triton nvidia-cublas-cu12 nvidia-cuda-runtime-cu12 nvidia-cudnn-cu12 nvidia-cufft-cu12 nvidia-curand-cu12 nvidia-cusolver-cu12 nvidia-cusparse-cu12 nvidia-nccl-cu12

Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Found existing installation: triton 3.2.0
Uninstalling triton-3.2.0:
  Successfully uninstalled triton-3.2.0
Found existing installation: nvidia-cublas-cu12 12.5.3.2
Uninstalling nvidia-cublas-cu12-12.5.3.2:
  Successfully uninstalled nvidia-cublas-cu12-12.5.3.2
Found existing installation: nvidia-cuda-runtime-cu12 12.5.82
Uninstalling nvidia-cuda-runtime-cu12-12.5.82:
  Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82
Found existing installation: nvidia-cudnn-cu12 9.3.0.75
Uninstalling nvidia-cudnn-cu12-9.3.0.75:
  Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75
Found exi

In [3]:
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu125

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu125
Collecting torch
  Downloading torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Downloading torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading torchaudio-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6

In [None]:
!pip install -q datasets bitsandbytes accelerate onnxruntime evaluate jiwer

In [None]:
!pip install pydub

In [4]:
import torch
import triton
import torch.backends.cudnn as cudnn

print("Torch Version:", torch.__version__)
print("Torch CUDA Version:", torch.version.cuda)
print("CUDA Available:", torch.cuda.is_available())
print("Triton Version:", triton.__version__)
print("cuDNN Version:", cudnn.version())

Torch Version: 2.6.0+cu124
Torch CUDA Version: 12.4
CUDA Available: True
Triton Version: 3.2.0
cuDNN Version: 90100
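`cudnn.version()` packs the cuDNN version into a single integer. The minimal sketch below reads it back. It assumes cuDNN 9 encodes the version as `major * 10000 + minor * 100 + patch`, and cuDNN 8 and earlier as `major * 1000 + minor * 100 + patch`. Under that assumption, `90100` is cuDNN 9.1.0, which matches the `nvidia-cudnn-cu12==9.1.0.70` wheel collected above. `decode_cudnn_version` is a hypothetical helper, not part of PyTorch.

```python
def decode_cudnn_version(v: int) -> tuple[int, int, int]:
    """Split an integer cuDNN version into (major, minor, patch).

    Assumption: cuDNN >= 9 uses major*10000 + minor*100 + patch,
    cuDNN <= 8 uses major*1000 + minor*100 + patch.
    """
    if v >= 90000:  # cuDNN 9+ encoding
        return v // 10000, (v % 10000) // 100, v % 100
    return v // 1000, (v % 1000) // 100, v % 100

print(decode_cudnn_version(90100))  # (9, 1, 0)
print(decode_cudnn_version(8902))   # (8, 9, 2)
```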


In [None]:
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from google.colab import drive
import time
import warnings
import numpy as np
import onnxruntime
from datasets import load_dataset
from itertools import islice
from torch.utils.data import IterableDataset
import json
import os
import wave
import threading
import textwrap
from IPython.display import clear_output
import evaluate
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
import collections
import pickle
import random
from pydub import AudioSegment


**User Action Required**
- Define the model to use: set `whisper_ver` and `checkpoint_num` below

In [None]:
# Google Drive Env Setup
whisper_ver = 'whisper-base'
checkpoint_num = '2600'
drive.mount('/content/drive')
checkpoint_path = f'/content/drive/My Drive/{whisper_ver}-noiseaugmented-minieval-checkpoints/checkpoint-{checkpoint_num}'

# Model setup code for fine-tuned whisper
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = WhisperForConditionalGeneration.from_pretrained(checkpoint_path).to(device)
model.config.use_cache = True
processor = WhisperProcessor.from_pretrained(f"openai/{whisper_ver}", language="en", task="transcribe")

Mounted at /content/drive



## **Define VAD model-related dependencies**

In [None]:
class OnnxWrapper():
    '''
    Code taken from: https://github.com/snakers4/silero-vad/blob/master/src/silero_vad/utils_vad.py

    See https://github.com/collabora/WhisperLive/blob/main/whisper_live/vad.py
    '''
    def __init__(self, path, force_onnx_cpu=False):
        opts = onnxruntime.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1

        if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
            self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
        else:
            self.session = onnxruntime.InferenceSession(path, sess_options=opts)

        self.reset_states()
        if '16k' in path:
            warnings.warn('This model supports only a 16000 Hz sampling rate!')
            self.sample_rates = [16000]
        else:
            self.sample_rates = [8000, 16000]

    def _validate_input(self, x, sr: int):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if x.dim() > 2:
            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")

        if sr != 16000 and (sr % 16000 == 0):
            step = sr // 16000
            x = x[:,::step]
            sr = 16000

        if sr not in self.sample_rates:
            raise ValueError(f"Supported sampling rates: {self.sample_rates} (or a multiple of 16000)")
        if sr / x.shape[1] > 31.25:
            raise ValueError("Input audio chunk is too short")

        return x, sr

    def reset_states(self, batch_size=1):
        self._state = torch.zeros((2, batch_size, 128)).float()
        self._context = torch.zeros(0)
        self._last_sr = 0
        self._last_batch_size = 0

    def __call__(self, x, sr: int):

        x, sr = self._validate_input(x, sr)
        num_samples = 512 if sr == 16000 else 256

        if x.shape[-1] != num_samples:
            raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")

        batch_size = x.shape[0]
        context_size = 64 if sr == 16000 else 32

        if not self._last_batch_size:
            self.reset_states(batch_size)
        if (self._last_sr) and (self._last_sr != sr):
            self.reset_states(batch_size)
        if (self._last_batch_size) and (self._last_batch_size != batch_size):
            self.reset_states(batch_size)

        if not len(self._context):
            self._context = torch.zeros(batch_size, context_size)

        x = torch.cat([self._context, x], dim=1)
        if sr in [8000, 16000]:
            ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
            ort_outs = self.session.run(None, ort_inputs)
            out, state = ort_outs
            self._state = torch.from_numpy(state)
        else:
            raise ValueError()

        self._context = x[..., -context_size:]
        self._last_sr = sr
        self._last_batch_size = batch_size

        out = torch.from_numpy(out)
        return out

    def audio_forward(self, x, sr: int):
        outs = []
        x, sr = self._validate_input(x, sr)
        self.reset_states()
        num_samples = 512 if sr == 16000 else 256

        if x.shape[1] % num_samples:
            pad_num = num_samples - (x.shape[1] % num_samples)
            x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)

        for i in range(0, x.shape[1], num_samples):
            wavs_batch = x[:, i:i+num_samples]
            out_chunk = self.__call__(wavs_batch, sr)
            outs.append(out_chunk)

        stacked = torch.cat(outs, dim=1)
        return stacked.cpu()




class VoiceActivityDetector:
    '''
    See https://github.com/collabora/WhisperLive/blob/main/whisper_live/vad.py

    ONNX model: Silero VAD (silero_vad.onnx)
    '''
    def __init__(self, path, threshold=0.5, frame_rate=16000):
        """
        Initializes the VoiceActivityDetector with a voice activity detection model and a threshold.

        Args:
            path (str): Path to the Silero VAD ONNX model file.
            threshold (float, optional): The probability threshold for detecting voice activity. Defaults to 0.5.
            frame_rate (int, optional): Sampling rate of the incoming audio in Hz. Defaults to 16000.
        """
        self.model = OnnxWrapper(path=path,force_onnx_cpu=True)
        self.threshold = threshold
        self.frame_rate = frame_rate

    def __call__(self, audio_frame):
        """
        Determines if the given audio frame contains speech by comparing the detected speech probability against
        the threshold.

        Args:
            audio_frame (np.ndarray): The audio frame to be analyzed for voice activity. It is expected to be a
                                      NumPy array of audio samples.

        Returns:
            bool: True if the speech probability exceeds the threshold, indicating the presence of voice activity;
                  False otherwise.
        """
        speech_probs = self.model.audio_forward(torch.from_numpy(audio_frame.copy()), self.frame_rate)[0]
        return torch.any(speech_probs > self.threshold).item()
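`OnnxWrapper.audio_forward` first validates the chunk. At 16 kHz it must hold at least 16000 / 31.25 = 512 samples. It then zero-pads the chunk to a whole number of windows (512 samples at 16 kHz, 256 at 8 kHz) and scores each window. `VoiceActivityDetector` reports speech if any window's probability exceeds `threshold`.

The plain-Python sketch below mirrors only that framing and decision logic. The Silero model is swapped for a hypothetical energy-based `toy_speech_prob`, so its scores are illustrative and are not what Silero would return.

```python
import math

def window_size(sr: int) -> int:
    # Window length used by OnnxWrapper: 512 samples at 16 kHz, 256 at 8 kHz
    return 512 if sr == 16000 else 256

def frame_with_padding(samples: list[float], sr: int) -> list[list[float]]:
    """Zero-pad to a multiple of the window length and split into windows,
    mirroring the checks and padding in OnnxWrapper.audio_forward."""
    if len(samples) < sr / 31.25:  # same "too short" rule as _validate_input
        raise ValueError("Input audio chunk is too short")
    n = window_size(sr)
    remainder = len(samples) % n
    if remainder:
        samples = samples + [0.0] * (n - remainder)
    return [samples[i:i + n] for i in range(0, len(samples), n)]

def toy_speech_prob(window: list[float]) -> float:
    """HYPOTHETICAL stand-in for the Silero model: maps RMS energy to [0, 0.999]."""
    rms = math.sqrt(sum(s * s for s in window) / len(window))
    return min(rms * 10, 0.999)

def has_voice(samples: list[float], sr: int = 16000, threshold: float = 0.5) -> bool:
    # Speech if ANY window scores above the threshold, as in VoiceActivityDetector.__call__
    return any(toy_speech_prob(w) > threshold for w in frame_with_padding(samples, sr))

silent_chunk = [0.0] * 4096  # one silent 0.256 s chunk at 16 kHz
tone = [0.2 * math.sin(2 * math.pi * 440 * t / 16000) for t in range(4096)]
print(len(frame_with_padding(silent_chunk, 16000)))  # 8
print(has_voice(silent_chunk), has_voice(tone))      # False True
```

With the notebook's 4096-sample chunks, each VAD call therefore scores 8 windows. A single window above the threshold is enough to mark the whole chunk as speech.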


## **Loop to simulate live transcription**

Re-run the cells below for each new audio file.

## **Define fnames**

In [None]:
audio_fname = 'ryan_jone_2.wav'

In [None]:
reference_fname = 'ryan_jone_2.txt'
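The loading cell converts the file to 16 kHz, mono, 16-bit PCM with pydub and scales it by 1/32768. If you want to check a WAV file up front, the standard-library-only sketch below is one way to do it. `load_pcm16_mono` is a hypothetical helper. It only handles WAV input, whereas pydub also reads other formats.

```python
import array
import sys
import wave

def load_pcm16_mono(path: str, expected_rate: int = 16000) -> list[float]:
    """Read a 16-bit mono PCM WAV file and return floats in [-1.0, 1.0).

    Raises ValueError if the file is not expected_rate Hz / mono / 16-bit,
    the format the notebook's pydub step converts to.
    """
    with wave.open(path, "rb") as wf:
        rate, channels, width = wf.getframerate(), wf.getnchannels(), wf.getsampwidth()
        if rate != expected_rate or channels != 1 or width != 2:
            raise ValueError(
                f"expected {expected_rate} Hz mono 16-bit, got "
                f"{rate} Hz, {channels} channel(s), {8 * width}-bit"
            )
        raw = wf.readframes(wf.getnframes())
    samples = array.array("h", raw)  # signed 16-bit integers
    if sys.byteorder == "big":       # WAV sample data is little-endian
        samples.byteswap()
    # Same scaling as the notebook: int16 / 32768 -> float in [-1, 1)
    return [s / 32768 for s in samples]
```

For example, `load_pcm16_mono(audio_fname)` raises a `ValueError` describing the mismatch if the WAV still needs conversion. Otherwise it returns samples scaled the same way as `audio_samples` before the silence padding is appended.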

## **Define Server-Related Classes**

In [None]:
# Helper function provided by Alexander Veysov: converts int16 PCM samples to float32 in [-1, 1)
# Not called in this notebook: the audio file is scaled the same way (/ 32768) when it is loaded
# See
# - https://github.com/snakers4/silero-vad/blob/master/examples/pyaudio-streaming/pyaudio-streaming-examples.ipynb
# Or (different implementations)
# - https://github.com/snakers4/silero-vad/blob/master/examples/microphone_and_webRTC_integration/microphone_and_webRTC_integration.py
def int2float(sound):
    abs_max = np.abs(sound).max()
    sound = sound.astype('float32')
    if abs_max > 0:
        sound *= 1/32768
    sound = sound.squeeze()  # depends on the use case
    return sound


class Client:
    def __init__(self, transcriber, transcriber_processor, RATE=16000, idle_time=0):
        self.frames_np = None # Store frames buffer as a numpy array
        self.frames_offset = 0.0 # Duration (s) of audio discarded from the front of the buffer
        self.timestamp_offset = 0.0 # Duration (s) of audio finalized so far, measured from the very start. Transcription progress tracker
        self.eos = False # End-Of-Speech flag
        self.transcriber = transcriber # Whisper ASR model
        self.transcriber_processor = transcriber_processor # Whisper ASR processor
        self.transcript = [] # Store fully transcribed segments
        self.lock = threading.Lock() # Guards shared state (frames_np, frames_offset, timestamp_offset, eos). https://realpython.com/python-thread-lock/#threadinglock-for-primitive-locking
        self.RATE = RATE
        self.idle_time = idle_time # Set before the thread starts, because speech_to_text reads it

        # Main "entry point" for the client: a thread that runs the speech-to-text loop
        self.running = True
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()


    def add_frames(self, frame_np):
        '''
        Add new audio chunks to the client's frames buffer, managing the buffer's size
        and updating frames_offset and timestamp_offset as necessary
        '''
        # frames_np, frames_offset and timestamp_offset are shared with the speech_to_text thread,
        # so the whole update is a critical section (no two threads can modify them at once).
        # The context manager releases the lock even if an exception is raised.
        with self.lock:
            # If the buffer holds more than 45 s of audio
            if self.frames_np is not None and self.frames_np.shape[0] > 45*self.RATE:
                # Discard the oldest 30 s of audio and record it in frames_offset
                self.frames_offset += 30.0
                self.frames_np = self.frames_np[int(30*self.RATE):]
                # Never let the transcription tracker point at discarded audio
                if self.timestamp_offset < self.frames_offset:
                    self.timestamp_offset = self.frames_offset

            # If the frame buffer is empty, initialise it with the new audio frames
            if self.frames_np is None:
                self.frames_np = frame_np.copy()
            # Otherwise, append the new audio chunk to the existing buffer
            else:
                self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)


    def speech_to_text(self):
        while self.running:
            # Wait for the first chunks to arrive
            if self.frames_np is None:
                time.sleep(0.02)
                continue

            # Skip ahead if too much untranscribed audio has piled up
            self.clip_audio_if_no_valid_segment()

            # Get the audio still to be transcribed, starting at timestamp_offset
            input_bytes, duration = self.get_audio_chunk_for_processing()
            # Very short audio tends to make Whisper hallucinate; wait for more instead of busy-looping
            if duration < 0.4:
                time.sleep(0.02)
                continue

            # Transcribe the pending audio and print the result. If prolonged silence (EOS)
            # was detected, the segment is finalized and timestamp_offset moves past it.
            # (get_audio_chunk_for_processing already returns a copy of the buffer slice.)
            self.transcribe_audio(input_bytes, duration)


    def clip_audio_if_no_valid_segment(self):
        '''
        If more than 25 s of audio is waiting to be transcribed,
        move timestamp_offset so that only the last 5 s of the buffer remain pending.

        If 25 s or less is pending, nothing changes.
        '''
        # Lock is required because frames_np, frames_offset, timestamp_offset are shared resources
        # Written as a context manager -> automatically acquires and releases the lock
        with self.lock:
            # If there is more than 25s of audio to transcribe
            if self.frames_np[int((self.timestamp_offset - self.frames_offset)*self.RATE):].shape[0] > 25 * self.RATE:
                # Adjust timestamp_offset s.t it is only 5s behind the total audio added so far
                duration = self.frames_np.shape[0] / self.RATE
                self.timestamp_offset = self.frames_offset + duration - 5


    def get_audio_chunk_for_processing(self):
        '''
        Get the audio to be transcribed from the buffer calculated using timestamp_offset
        '''
        # Use timestamp_offset to slice out the part of the buffer still to be transcribed
        # Lock is required because frames_np, frames_offset, timestamp_offset are shared resources
        # Written as a context manager -> automatically acquires and releases the lock
        with self.lock:
            samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
            input_bytes = self.frames_np[int(samples_take):].copy()
        duration = input_bytes.shape[0] / self.RATE
        return input_bytes, duration


    def transcribe_audio(self, input_bytes, duration):
        '''
        Transcribe the pending audio with the Whisper model
        '''
        # Log-Mel features at the client's sampling rate (the audio is already 16 kHz)
        input_features = self.transcriber_processor.feature_extractor(input_bytes, sampling_rate=self.RATE).input_features[0]
        # Move the features to whichever device the model lives on
        input_features = torch.tensor(input_features).unsqueeze(0).to(self.transcriber.device)
        generated_ids = self.transcriber.generate(input_features)
        last_segment = self.transcriber_processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Print the transcription and, if there was prolonged silence,
        # - save the latest speech segment
        # - advance timestamp_offset so the finalized portion is not transcribed again
        self.handle_transcription_output(last_segment, duration)


    def handle_transcription_output(self, last_segment, duration):
        '''
        Print the latest transcription.
        If the client's end-of-speech flag is True
        (prolonged silence, meaning the speech segment is finalized):
        - advance timestamp_offset past the segment
        - save the segment
        '''
        self.send_transcription_to_client({"text": last_segment})
        if self.get_eos():
            self.update_timestamp_offset(last_segment, duration)
            # Append the finalized segment (transcribed with full context) to the global
            # current_prediction list, which is created in the simulation cell below
            current_prediction.append(last_segment)


    def send_transcription_to_client(self, segments):
        '''
        Print the transcript segment, wrapped to 60 characters
        '''
        text = segments["text"]

        # For command-line output, clear the terminal instead:
        # os.system("cls" if os.name == "nt" else "clear")
        clear_output(wait=True) # Delays clearing the output until new output is available -> reduces flickering

        wrapper = textwrap.TextWrapper(width=60)
        for line in wrapper.wrap(text=text):
            print(line)

        # Optional pause after each update to rate-limit the transcription loop
        time.sleep(self.idle_time)


    def update_timestamp_offset(self, last_segment, duration):
        '''
        Mark the transcribed portion of the buffer as finalized by advancing
        timestamp_offset by its duration, so it is not transcribed again.
        '''
        # Disabled alternative: also store finalized segments in self.transcript
        # # If self.transcript is empty, add the last segment received
        # if not len(self.transcript):
        #     self.transcript.append({"text": last_segment + " "})
        # # If the last element of self.transcript differs from the last segment, add the last segment
        # elif self.transcript[-1]["text"].strip() != last_segment:
        #     self.transcript.append({"text": last_segment + " "}) # store all finalized transcription

        # Lock is required because timestamp_offset is shared with add_frames and the transcription thread
        # Written as a context manager -> automatically acquires and releases the lock
        with self.lock:
            self.timestamp_offset += duration


    def get_eos(self):
        """
        Returns the End-of-Speech (EOS) flag.

        Returns:
            bool: The current value of the EOS flag.
        """
        with self.lock:
            return self.eos

    def set_eos(self, eos):
        """
        Sets the End-of-Speech (EOS) flag.

        Args:
            eos (bool): The value to set for the EOS flag.
        """
        with self.lock:
            self.eos = eos


    def stop(self):
        '''
        Signals the thread to stop gracefully.
        '''
        print("Stopping transcription thread...")
        self.running = False
        if self.trans_thread.is_alive():
            self.trans_thread.join()  # Wait for the thread to finish





class ClientManager:
    '''
    Custom client manager class to handle clients
    "connected over the WebSocket server" (not necessarily)
    '''
    def __init__(self):
        self.clients = {}

    def add_client(self, websocket, client):
        '''
        Add a WebSocket server connection info (or just a key) and its associated client
        '''
        self.clients[websocket] = client

    def get_client(self, websocket):
        '''
        Retrieve a client associated with the WebSocket server connection info (or just a key) provided
        '''
        if websocket in self.clients:
            return self.clients[websocket]
        return False




class Server:
    def __init__(self, vad_detector, no_voice_activity_threshold=3):
        '''
        In charge of receiving audio, checking it for voice activity (VA) and passing it to the client
        '''
        self.client_manager = None
        self.vad_detector = vad_detector # VAD detector
        self.no_voice_activity_chunks = 0 # Help to track prolonged silence
        self.no_voice_activity_threshold = no_voice_activity_threshold


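    def recv_audio(self, websocket, audio_array, transcriber, transcriber_processor, RATE=16000, idle_time=0):
        """
        DEPRECATED entry point for transcription
        - kept in case of conversion to WebSockets
        Handle the new connection
        Continuously process audio frames

        Note: process_audio_frames always returns True, so this loop never exits
        on its own; a WebSocket version would need to break when the connection closes.
        """
        # Handle the new connection
        if not self.handle_new_connection(websocket, transcriber, transcriber_processor, RATE, idle_time):
            return

        # Continuously process audio frames
        while True:
            if not self.process_audio_frames(websocket, audio_array):
                break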


    def handle_new_connection(self, websocket, transcriber, transcriber_processor, RATE=16000, idle_time=0):
        '''
        Initialise the client manager
        Initialise the new client and add it to the client manager
        '''

        # Initialise the client manager if not done
        if self.client_manager is None:
            self.client_manager = ClientManager()

        # Initialise the new client and add it to the client manager.
        # The websocket acts as the key, or just use any other key
        self.initialize_client(websocket, transcriber, transcriber_processor, RATE, idle_time)

        return True


    def initialize_client(self, websocket, transcriber, transcriber_processor, RATE=16000, idle_time=0):
        '''
        Initialize the new client and add it to the client manager
        '''
        # Initialize the new client
        client = Client(transcriber, transcriber_processor, RATE, idle_time)

        # Add the client to the client manager
        self.client_manager.add_client(websocket, client)


    def process_audio_frames(self, websocket, audio_array):
        '''
        Responsible for
        - checking voice activity
        - adding audio chunks to the retrieved client's buffer

        Get the audio chunk
        If it has voice activity
        - Reset the no voice activity settings
        - Add the audio chunk to the client's buffer
        - return True
        If it has no voice activity
        - return True
        '''
        # Get the audio chunk as a numpy array
        frame_np, audio_float32 = self.get_audio_from_websocket(websocket, audio_array)

        # Get the client using its associated key
        client = self.client_manager.get_client(websocket)

        # Check for voice activity in the audio chunk
        # - if there is no voice activity return False
        # - if there is prolonged silence (accumulated over multiple chunks), set the eos flag of the client to True
        # - if there is voice activity return True
        voice_active = self.voice_activity(websocket, audio_float32, self.no_voice_activity_threshold)

        # If there is voice activity, reset the
        # - no_voice_activity_chunks
        # - eos flag
        # - add audio chunk to the client's buffer
        if voice_active:
            self.no_voice_activity_chunks = 0
            client.set_eos(False)
            client.add_frames(frame_np)
        return True


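    def get_audio_from_websocket(self, websocket, frame_data):
        '''
        Return the audio chunk as numpy arrays for buffering and for VAD
        '''
        # In this simulation the chunk already arrives as a float32 numpy array, so it is
        # returned as-is, once for the client's buffer and once for the VAD.
        # Over a real WebSocket the received bytes would have to be decoded first,
        # e.g. np.frombuffer(frame_data, dtype=np.float32) if the sender streams raw float32 PCM.
        # MAKE SURE DATA FORMATS ARE CORRECT
        # VAD:
        # - https://github.com/snakers4/silero-vad/blob/master/examples/pyaudio-streaming/pyaudio-streaming-examples.ipynb
        # - https://github.com/snakers4/silero-vad/blob/master/examples/microphone_and_webRTC_integration/microphone_and_webRTC_integration.py

        return frame_data, frame_data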


    def voice_activity(self, websocket, frame_np, no_voice_activity_threshold=3):
        '''
        Whenever no voice activity is detected, increment no_voice_activity_chunks.

        If the counter exceeds no_voice_activity_threshold, i.e. prolonged silence:
        - set the end-of-speech flag of the client to True
        - wait for 0.1 seconds

        Return False for no voice activity, True for voice activity.
        '''
        if not self.vad_detector(frame_np):
            self.no_voice_activity_chunks += 1
            if self.no_voice_activity_chunks > no_voice_activity_threshold:
                client = self.client_manager.get_client(websocket)
                client.set_eos(True)
                time.sleep(0.1)    # Sleep 100 ms while waiting for voice activity
            return False
        return True

    def shutdown(self):
        '''
        Stops all client threads before shutting down the server.
        '''
        print("Stopping all clients")
        for websocket, client in self.client_manager.clients.items():
            client.stop()
        print("Server shutdown complete.")
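The offset bookkeeping shared by `Client.add_frames` and `Client.clip_audio_if_no_valid_segment` is easier to follow in seconds than in sample arrays. The sketch below tracks only durations:

- Once the buffer holds more than 45 s, the oldest 30 s are dropped and added to `frames_offset`.
- Whenever more than 25 s is still pending, `timestamp_offset` jumps to 5 s before the end of the buffer.

`OffsetTracker` is a hypothetical helper. It ignores locking, EOS and the actual transcription.

```python
from dataclasses import dataclass

@dataclass
class OffsetTracker:
    """Duration-only model of Client's buffer offsets (hypothetical helper)."""
    buffer_s: float = 0.0          # seconds currently held in frames_np
    frames_offset: float = 0.0     # seconds discarded from the front of the buffer
    timestamp_offset: float = 0.0  # absolute time (s) up to which audio is finalized

    def add_frames(self, chunk_s: float) -> None:
        # Mirrors add_frames: drop the oldest 30 s once the buffer exceeds 45 s
        if self.buffer_s > 45.0:
            self.frames_offset += 30.0
            self.buffer_s -= 30.0
            self.timestamp_offset = max(self.timestamp_offset, self.frames_offset)
        self.buffer_s += chunk_s

    def pending_s(self) -> float:
        # Audio not yet finalized = end of buffer (absolute) - timestamp_offset
        return self.frames_offset + self.buffer_s - self.timestamp_offset

    def clip_if_no_valid_segment(self) -> None:
        # Mirrors clip_audio_if_no_valid_segment: keep only the last 5 s pending
        if self.pending_s() > 25.0:
            self.timestamp_offset = self.frames_offset + self.buffer_s - 5.0

t = OffsetTracker()
for _ in range(200):  # 200 chunks of 0.256 s = 51.2 s of uninterrupted speech
    t.add_frames(0.256)
    t.clip_if_no_valid_segment()
print(t.frames_offset, round(t.buffer_s, 3), round(t.timestamp_offset, 3))  # 30.0 21.2 30.0
```

This models speech that never pauses long enough to set EOS. The pending audio stays capped at about 25 s, under Whisper's 30 s input window, while the buffer itself stays bounded by the 45 s trim.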

In [None]:
##################SETTINGS########################
chunk = 4096 # number of samples to take per read
RATE = 16000 # number of samples taken per second
# chunk duration = chunk / RATE = 4096 / 16000 = 0.256 s
# number of chunks per second of audio = RATE / chunk = 16000 / 4096, about 3.9
idle_time = 0 # optional pause (s) after each printed transcription

vad_model_path = './silero_vad.onnx'
vad_threshold = 0.5
vad_detector = VoiceActivityDetector(vad_model_path, vad_threshold, RATE) # Load the Silero VAD model
no_voice_activity_threshold = 3 # a pause is declared once MORE than this many consecutive silent chunks are seen; the current speech segment is then finalized
##################SETTINGS########################




##################INITIALIZATIONS########################
# Initialise Server
server = Server(vad_detector,no_voice_activity_threshold)
server.handle_new_connection("websocket",model,processor,RATE,idle_time) # manual call. Just use "websocket" as a key

# Simulation delay
simulation_delay = chunk/RATE

# Initialize list to store current transcriptions. Global variable
current_prediction = []
##################INITIALIZATIONS########################



##################LOAD AUDIO FILE########################
# Load audio file
audio = AudioSegment.from_file(audio_fname)
audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)  # ensure 16kHz, mono, 16-bit PCM

# Get the length of the audio file
audio_duration_seconds = audio.duration_seconds

# Convert to raw samples
audio_samples = np.array(audio.get_array_of_samples()).astype(np.float32) / 32768

# Append (no_voice_activity_threshold + 5) chunks of silence so the end-of-speech flag fires
# after the last utterance and the latest segment is saved
silence = np.zeros(chunk*(no_voice_activity_threshold+5), dtype=np.float32)
audio_samples = np.concatenate((audio_samples, silence))


# Get the number of samples of the concatenated/continuous audio
num_audio_samples = len(audio_samples)
# Initialize samples read to be 0
samples_read = 0
##################LOAD AUDIO FILE########################




#################PROCESS##########################
while samples_read < num_audio_samples:

    if samples_read + chunk > num_audio_samples:
        data = audio_samples[samples_read:]
        # FOR TESTING ONLY
        # Silero VAD needs at least 512 samples, so zero-pad a shorter final chunk
        if len(data) < 512:
            data = np.concatenate((data, np.zeros(512-len(data), dtype=np.float32)))
    else:
        data = audio_samples[samples_read:samples_read+chunk]

    # Update the number of samples read (padding included, which also ends the loop)
    samples_read += len(data)
    # Simulate the time the chunk takes to be 'spoken'
    time.sleep(simulation_delay)
    # Directly call process_audio_frames with the new float32 audio chunk
    server.process_audio_frames("websocket", data)
#################PROCESS##########################

server.shutdown()
del server

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



Prediction: you know ya she was from france she shared to me about our culture so uh she would she want to join us she would when i would visit lah when i travel the house looking forward to that it's for them what interesting generation gap what's the biggest similarity or difference between your generation and your parent's generation okay
Reference : you know yeah she was from france so that's she shared with me about her culture lah so you know err she would show me around when she would when i would visit lah when i travel to france looking forward to that it's really not interesting generation gap what's the biggest similarity or difference between your generation and your parent's generation okay


Prediction: would you say it's more positive or negative it will throw lah good for the country in the in having as in globalization it's good for the country okay okay mm never ya okay next  tekam tekam how do you choose four day or total numbers mm okay
Reference : would you say it
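The settings above determine when end-of-speech is detected.

With `chunk = 4096` and `RATE = 16000`, each chunk lasts 0.256 s. `Server.voice_activity` sets the EOS flag once the consecutive-silence counter exceeds `no_voice_activity_threshold`. With the default of 3, that happens from the 4th silent chunk onward, after about 1.02 s of silence. The `no_voice_activity_threshold + 5` = 8 chunks of silence appended to the file (2.048 s) leave room for the final segment to be finalized.

The sketch below replays that counter over a list of per-chunk VAD decisions. `chunk_seconds` and `eos_chunks` are hypothetical helpers that ignore the transcription thread and the 0.1 s sleep.

```python
def chunk_seconds(num_samples: int, rate: int) -> float:
    # Duration of a block of samples, e.g. one read of `chunk` samples
    return num_samples / rate

def eos_chunks(vad_flags: list[bool], threshold: int = 3) -> list[int]:
    """Return the chunk indices at which Server.voice_activity would set EOS=True.

    A silent chunk increments the counter, a voiced chunk resets it (as
    process_audio_frames does), and EOS is set whenever the counter
    exceeds `threshold`.
    """
    silent_run = 0
    hits = []
    for i, voiced in enumerate(vad_flags):
        if voiced:
            silent_run = 0
        else:
            silent_run += 1
            if silent_run > threshold:
                hits.append(i)
    return hits

print(chunk_seconds(4096, 16000))            # 0.256
print(chunk_seconds(4096 * (3 + 5), 16000))  # 2.048 (the appended silence)
flags = [True] * 5 + [False] * 8             # speech, then the 8 padding chunks
print(eos_chunks(flags))                     # [8, 9, 10, 11, 12]
```

The counter lives on `Server`, not on `Client`, so this timing assumes a single client, as in this notebook.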


## **Calculate WER**

In [None]:
metric = evaluate.load("wer")
normalizer = BasicTextNormalizer()


In [None]:
current_prediction

In [None]:
prediction = ' '.join(current_prediction)

In [None]:
prediction

29.186155285313376

In [None]:
with open(reference_fname, "r", encoding="utf-8") as file:
    # Join lines with a space so words on adjacent lines are not glued together
    reference = file.read().replace('\n', ' ')

In [None]:
prediction_normalized = normalizer(prediction)
reference_normalized = normalizer(reference)

In [None]:
wer = 100 * metric.compute(predictions=[prediction_normalized], references=[reference_normalized])
print(f'The WER is {wer}%')

7564.655187499998

In [None]:
print(f'The length of the audio is {audio_duration_seconds} seconds')

2.101293107638888
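The `wer` metric loaded from `evaluate` is the word-level edit distance (substitutions + deletions + insertions) divided by the number of reference words. It can therefore exceed 100% when the prediction contains many inserted words.

The sketch below computes that quantity from scratch for one pair of strings. `simple_normalize` is a hypothetical, simplified stand-in that lowercases, strips punctuation and collapses whitespace. It is not identical to `BasicTextNormalizer`.

```python
import re

def simple_normalize(text: str) -> str:
    """Simplified stand-in for BasicTextNormalizer (assumption, not identical)."""
    text = re.sub(r"[^\w\s']", " ", text.lower())
    return " ".join(text.split())

def word_error_rate(reference: str, prediction: str) -> float:
    """WER = (S + D + I) / N, via word-level Levenshtein distance."""
    ref, hyp = reference.split(), prediction.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] = edit distance between the reference words seen so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion (reference word missing)
                curr[j - 1] + 1,         # insertion (extra predicted word)
                prev[j - 1] + (r != h),  # substitution, or match at no cost
            )
        prev = curr
    return prev[-1] / len(ref)

ref = simple_normalize("She shared with me about her culture.")
hyp = simple_normalize("she shared to me about our culture")
print(f"{100 * word_error_rate(ref, hyp):.2f}%")  # 28.57% (2 substitutions / 7 words)
```

The example strings are fragments of the reference and prediction printed by the simulation cell.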
