## **Define ASR Model**

In [None]:
!pip install -q datasets bitsandbytes accelerate

In [None]:
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from google.colab import drive
import time

# Google Drive environment setup
whisper_ver = 'whisper-base'   # Whisper variant; also selects the matching processor below
checkpoint_num = '2100'        # Suffix of the checkpoint directory to load
# Directory on Drive that holds the fine-tuning checkpoints for this variant
google_drive_path = f'/content/drive/My Drive/{whisper_ver}-checkpoints'
drive.mount('/content/drive')
# Built from google_drive_path so the two locations cannot drift apart
checkpoint_path = f'{google_drive_path}/checkpoint-{checkpoint_num}'

# Model setup for the fine-tuned Whisper checkpoint
# Use the first CUDA device when one is available, otherwise fall back to CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = WhisperForConditionalGeneration.from_pretrained(checkpoint_path).to(device)
model.config.use_cache = True
# The processor is loaded from the base openai model, not from the checkpoint directory
processor = WhisperProcessor.from_pretrained(f"openai/{whisper_ver}", language="en", task="transcribe")

## **Define VAD model-related dependencies**

In [None]:
!pip install onnxruntime

In [None]:
import torch
import warnings
import numpy as np
import onnxruntime


class OnnxWrapper():
    '''
    Stateful wrapper that runs a VAD ONNX model through onnxruntime.

    The wrapper keeps the model state and a short tail of the previous chunk
    (the "context") between calls, so consecutive chunks must be fed in order.

    Code taken from: https://github.com/snakers4/silero-vad/blob/master/src/silero_vad/utils_vad.py

    See https://github.com/collabora/WhisperLive/blob/main/whisper_live/vad.py
    '''
    def __init__(self, path, force_onnx_cpu=False):
        '''
        Create the inference session and reset the streaming state.

        Args:
            path: Path of the ONNX model file. If it contains '16k', only the
                16000 Hz sampling rate is accepted.
            force_onnx_cpu: When True (and the CPU provider is available), run
                the session on 'CPUExecutionProvider' only.
        '''
        opts = onnxruntime.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1

        if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
            self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
        else:
            self.session = onnxruntime.InferenceSession(path, sess_options=opts)

        self.reset_states()
        if '16k' in path:
            warnings.warn('This model support only 16000 sampling rate!')
            self.sample_rates = [16000]
        else:
            self.sample_rates = [8000, 16000]

    def _validate_input(self, x, sr: int):
        '''
        Normalise an audio chunk to a 2-D tensor at a supported sampling rate.

        Args:
            x: 1-D or 2-D torch tensor; a 1-D input gets a leading batch axis.
            sr: Sampling rate of `x`. Multiples of 16000 are decimated to 16000
                by keeping every `sr // 16000`-th sample.

        Returns:
            Tuple `(x, sr)` with the possibly reshaped/decimated tensor and the
            effective sampling rate.

        Raises:
            ValueError: If `x` has more than two dimensions, the rate is not
                supported, or the chunk is empty / too short.
        '''
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if x.dim() > 2:
            raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")

        if sr != 16000 and (sr % 16000 == 0):
            step = sr // 16000
            x = x[:,::step]
            sr = 16000

        if sr not in self.sample_rates:
            raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
        # An empty chunk would otherwise divide by zero; report it as "too short"
        # like any other undersized chunk.
        if x.shape[1] == 0 or sr / x.shape[1] > 31.25:
            raise ValueError("Input audio chunk is too short")

        return x, sr

    def reset_states(self, batch_size=1):
        '''
        Clear the model state, the context and the "last call" bookkeeping.

        Args:
            batch_size: Batch dimension of the zero-initialised state tensor.
        '''
        self._state = torch.zeros((2, batch_size, 128)).float()
        self._context = torch.zeros(0)
        self._last_sr = 0
        self._last_batch_size = 0

    def __call__(self, x, sr: int):
        '''
        Run the model on exactly one window of audio.

        Args:
            x: Audio window; after validation it must hold 512 samples per row
                at 16000 Hz or 256 samples per row at 8000 Hz.
            sr: Sampling rate of `x`.

        Returns:
            torch tensor built from the first output of the ONNX session.

        Raises:
            ValueError: If the input is invalid or the window size is wrong.
        '''
        x, sr = self._validate_input(x, sr)
        num_samples = 512 if sr == 16000 else 256

        if x.shape[-1] != num_samples:
            raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")

        batch_size = x.shape[0]
        context_size = 64 if sr == 16000 else 32

        # Start from a clean state on the first call and whenever the sampling
        # rate or batch size differs from the previous call.
        first_call = not self._last_batch_size
        sr_changed = bool(self._last_sr) and self._last_sr != sr
        batch_changed = bool(self._last_batch_size) and self._last_batch_size != batch_size
        if first_call or sr_changed or batch_changed:
            self.reset_states(batch_size)

        if not len(self._context):
            self._context = torch.zeros(batch_size, context_size)

        # Prepend the tail of the previous window so the model sees continuity.
        x = torch.cat([self._context, x], dim=1)
        if sr in [8000, 16000]:
            ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
            ort_outs = self.session.run(None, ort_inputs)
            out, state = ort_outs
            self._state = torch.from_numpy(state)
        else:
            # Defensive guard; validation above should already have rejected this.
            raise ValueError(f"Unsupported sampling rate: {sr}")

        self._context = x[..., -context_size:]
        self._last_sr = sr
        self._last_batch_size = batch_size

        out = torch.from_numpy(out)
        return out

    def audio_forward(self, x, sr: int):
        '''
        Run the model over a whole audio clip, window by window.

        The state is reset first, the clip is zero-padded on the right to a
        whole number of windows, and every window is passed to `__call__`.

        Args:
            x: 1-D or 2-D torch tensor with the audio samples.
            sr: Sampling rate of `x`.

        Returns:
            CPU tensor with the per-window outputs concatenated along dim 1.
        '''
        outs = []
        x, sr = self._validate_input(x, sr)
        self.reset_states()
        num_samples = 512 if sr == 16000 else 256

        remainder = x.shape[1] % num_samples
        if remainder:
            pad_num = num_samples - remainder
            x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)

        for i in range(0, x.shape[1], num_samples):
            wavs_batch = x[:, i:i+num_samples]
            out_chunk = self.__call__(wavs_batch, sr)
            outs.append(out_chunk)

        stacked = torch.cat(outs, dim=1)
        return stacked.cpu()




class VoiceActivityDetector:
    '''
    Speech / no-speech decision built on top of an OnnxWrapper model.

    See https://github.com/collabora/WhisperLive/blob/main/whisper_live/vad.py
    '''
    def __init__(self, path, threshold=0.5, frame_rate=16000):
        """
        Load the VAD model and remember the decision parameters.

        Args:
            path (str): Path of the ONNX model handed to OnnxWrapper (always run with force_onnx_cpu=True).
            threshold (float, optional): Speech probability above which a frame counts as voiced. Defaults to 0.5.
            frame_rate (int, optional): Sampling rate passed along with every frame. Defaults to 16000.
        """
        self.model = OnnxWrapper(path=path, force_onnx_cpu=True)
        self.frame_rate = frame_rate
        self.threshold = threshold

    def __call__(self, audio_frame):
        """
        Tell whether any window of the given frame looks like speech.

        Args:
            audio_frame (np.ndarray): Audio samples to analyse. The array is copied before being turned
                                      into a tensor, so the caller's buffer is left alone.

        Returns:
            bool: True when at least one probability in the first row of the model output is strictly
                  greater than the threshold; False otherwise.
        """
        frame_tensor = torch.from_numpy(audio_frame.copy())
        all_probs = self.model.audio_forward(frame_tensor, self.frame_rate)
        first_row = all_probs[0]
        above_threshold = first_row > self.threshold
        return torch.any(above_threshold).item()

## **Define Server-Related Classes**

In [None]:
import json
import numpy as np
import os
import wave
import threading
import textwrap
from IPython.display import clear_output

# Helper function Provided by Alexander Veysov
# See
# - https://github.com/snakers4/silero-vad/blob/master/examples/pyaudio-streaming/pyaudio-streaming-examples.ipynb
# Or (different implementations)
# - https://github.com/snakers4/silero-vad/blob/master/examples/microphone_and_webRTC_integration/microphone_and_webRTC_integration.py 
def int2float(sound):
    """
    Convert integer PCM samples to float32 scaled by 1/32768.

    Args:
        sound (np.ndarray): Array of integer samples (int16 in this file).

    Returns:
        np.ndarray: float32 copy of the samples multiplied by 1/32768, with
        length-1 axes removed by `squeeze()`.
    """
    # Scale unconditionally. Guarding on np.abs(sound).max() > 0 fails on an
    # empty array and is fooled by int16 overflow (abs(-32768) == -32768);
    # scaling an all-zero array is a no-op, so no guard is needed.
    sound = sound.astype('float32')
    sound *= 1/32768
    sound = sound.squeeze()  # depends on the use case
    return sound


class Client:
    def __init__(self, transcriber, transcriber_processor, send_last_n_segments=3, RATE=16000):
        '''
        In charge of adding audio chunks, managing the audio buffer, adjusting the amount of audio to transcribe,
        transcribing the audio, printing the latest segment + send_last_n_segments,

        when a prolonged silence is encountered
        - save the latest (finalized) transcription
        - update timestamp_offset to go past the transcription

        Creating a Client immediately starts a background thread running speech_to_text();
        call stop() to end it.

        Args:
            transcriber: Model object exposing generate() (the Whisper model in this file).
            transcriber_processor: Object exposing feature_extractor and tokenizer.
            send_last_n_segments: Number of finalized segments printed before the newest one.
            RATE: Sampling rate (samples per second) of the audio added to the buffer.
        '''
        self.frames_np = None # Store frames buffer as a numpy array
        self.frames_offset = 0.0 # Track frames offset from the very start/Duration of audio discarded
        self.timestamp_offset = 0.0 # Track transcription offset from the very start. Transcription progress tracker
        self.send_last_n_segments = send_last_n_segments  # Number of last transcribed segments that will be 'sent' to the client
        self.eos = False # End-Of-Speech Flag
        self.transcriber = transcriber # Initialize Whisper ASR model
        self.transcriber_processor = transcriber_processor # Initialize Whisper ASR processor
        self.transcript = [] # Store fully transcribed segments
        self.lock = threading.Lock() # for shared resources: frames_np. https://realpython.com/python-thread-lock/#threadinglock-for-primitive-locking
        self.RATE = RATE

        # Thread to run the speech-to-text function: the main "entry point" for the client
        self.running = True
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()


    def add_frames(self, frame_np):
        '''
        Manage the ongoing buffer's size - update frames_offset and timestamp_offset as necessary
        Add new audio chunks to the client's frames buffer

        Args:
            frame_np (np.ndarray): New audio samples to append to the buffer.
        '''
        # frames_np, frames_offset and timestamp_offset are shared with the speech_to_text thread.
        # The context manager guarantees the lock is released even if appending raises
        # (e.g. a shape mismatch in np.concatenate), which would otherwise deadlock the worker.
        with self.lock:
            # If the buffer is more than 45s
            if self.frames_np is not None and self.frames_np.shape[0] > 45*self.RATE:
                # Increase frames_offset by 30s
                # Discard oldest 30s of audio from buffer
                self.frames_offset += 30.0
                self.frames_np = self.frames_np[int(30*self.RATE):]
                # Never let the transcription tracker point into discarded audio
                if self.timestamp_offset < self.frames_offset:
                    self.timestamp_offset = self.frames_offset

            # If the frame buffer is empty, initialise it with the new audio frames
            if self.frames_np is None:
                self.frames_np = frame_np.copy()
            # Else, append the new audio chunk to the existing buffer
            else:
                self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)


    def speech_to_text(self):
        '''
        Process audio buffer in a loop until stop() is called, continuously transcribing speech
        - Adjust timestamp_offset s.t there is an appropriate amount of audio to transcribe
        - Get the audio to be transcribed
        - Transcribe the audio to be transcribed
        - Print the latest segment along with the send_last_n_segments

        If the client's end-of-speech flag is True
        (prolonged silence, means that speech segment is finalized),
        - Save the latest speech segment
        - Update the timestamp_offset as this portion of the buffer has been finalized
        (doesn't need to be transcribed anymore)
        '''
        while self.running:
            # Wait for some chunks to arrive
            if self.frames_np is None:
                time.sleep(0.02)
                continue

            # Adjust the timestamp_offset/transcription tracker
            self.clip_audio_if_no_valid_segment()

            # Get the audio to be transcribed using the timestamp_offset
            input_bytes, duration = self.get_audio_chunk_for_processing()
            if duration < 0.4:
                # Too little audio pending: yield briefly instead of spinning at
                # full speed while waiting for more frames.
                time.sleep(0.02)
                continue

            # Transcribe the audio, print the transcriptions to be printed,
            # Update the timestamp_offset if prolonged silence was encountered
            # - it means the latest segment is finalized
            # (input_bytes is already a private copy made under the lock)
            self.transcribe_audio(input_bytes, duration)


    def clip_audio_if_no_valid_segment(self):
        '''
        If there is > 25s of audio to transcribe in the buffer,
        adjust timestamp_offset such that its only 5s behind the total audio added so far

        If there is <=25s of audio to transcribe, its okay
        '''
        # Lock is required because frames_np, frames_offset, timestamp_offset is a shared resource
        with self.lock:
            pending_start = int((self.timestamp_offset - self.frames_offset)*self.RATE)
            # If there is more than 25s of audio to transcribe
            if self.frames_np[pending_start:].shape[0] > 25 * self.RATE:
                # Adjust timestamp_offset s.t it is only 5s behind the total audio added so far
                duration = self.frames_np.shape[0] / self.RATE
                self.timestamp_offset = self.frames_offset + duration - 5


    def get_audio_chunk_for_processing(self):
        '''
        Get the audio to be transcribed from the buffer calculated using timestamp_offset

        Returns:
            tuple: (copy of the pending samples, their duration in seconds).
        '''
        # Use timestamp_offset to help subset the buffer to get the audio to be transcribed
        # Lock is required because frames_np, frames_offset, timestamp_offset is a shared resource
        with self.lock:
            samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
            input_bytes = self.frames_np[int(samples_take):].copy()
        duration = input_bytes.shape[0] / self.RATE
        return input_bytes, duration


    def transcribe_audio(self, input_bytes, duration):
        '''
        Transcribe the pending audio and hand the text to handle_transcription_output().

        Args:
            input_bytes (np.ndarray): Audio samples to transcribe.
            duration (float): Duration of `input_bytes` in seconds.
        '''
        # NOTE: the feature extractor is told the audio is 16000 Hz regardless of self.RATE,
        # and `device` is the module-level device chosen when the model was loaded.
        input_features = self.transcriber_processor.feature_extractor(input_bytes, sampling_rate=16000).input_features[0]
        input_features = torch.tensor(input_features).unsqueeze(0).to(device)
        generated_ids = self.transcriber.generate(input_features)
        last_segment = self.transcriber_processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Print the transcriptions to be printed and
        # if there is prolonged silence,
        # - Save the latest speech segment
        # - Update the timestamp_offset s.t the portion is completed/does not need to be transcribed anymore
        self.handle_transcription_output(last_segment, duration)


    def handle_transcription_output(self, last_segment, duration):
        '''
        Print the transcriptions to be printed
        If the client's end-of-speech flag is True
        (prolonged silence, means that speech segment is finalized),
        - Save the latest speech segment
        - Update the timestamp_offset

        Args:
            last_segment (str): Text of the segment that was just transcribed.
            duration (float): Duration in seconds of the audio behind `last_segment`.
        '''
        segments = self.prepare_segments({"text": last_segment})
        self.send_transcription_to_client(segments)
        if self.eos:
            self.update_timestamp_offset(last_segment, duration)


    def prepare_segments(self, last_segment=None):
        '''
        Prepare the segments to be printed

        Args:
            last_segment (dict | None): Newest segment ({"text": ...}); appended last when given.

        Returns:
            list: Up to send_last_n_segments finalized segments followed by `last_segment`.
        '''
        # If the length of self.transcript is more than or equal send_last_n_segments,
        # take the last send_last_n_segments elements of self.transcript
        if len(self.transcript) >= self.send_last_n_segments:
            segments = self.transcript[-self.send_last_n_segments:].copy()
        # If not, just take all of self.transcript
        else:
            segments = self.transcript.copy()
        # If the segment that was just transcribed is not None, add it to segments
        if last_segment is not None:
            segments = segments + [last_segment]
        return segments


    def send_transcription_to_client(self, segments):
        '''
        Print the transcription

        Consecutive segments with identical text are printed once. The previous
        notebook output is cleared first and the text is wrapped at 60 columns.

        Args:
            segments (list): Segment dicts, each with a "text" key.
        '''
        text = []
        for seg in segments:
            # Skip a segment whose text equals the one just collected
            if not text or text[-1] != seg["text"]:
                text.append(seg["text"])

        # For cmd line like outputs use instead:
        #   os.system("cls" if os.name == "nt" else "clear")
        clear_output(wait=True) # Delays clearing the output until new output is available -> reduce flickering

        wrapper = textwrap.TextWrapper(width=60)
        for line in wrapper.wrap(text="".join(text)):
            print(line)


    def update_timestamp_offset(self, last_segment, duration):
        '''
        Store a finalized segment and move the transcription tracker past it.

        Args:
            last_segment (str): Text of the finalized segment.
            duration (float): Seconds of audio covered by the segment.
        '''
        # If self.transcript is empty, add the last segment received
        if not len(self.transcript):
            self.transcript.append({"text": last_segment + " "})
        # If the last element of self.transcript is != to the last segment, add the last segment
        elif self.transcript[-1]["text"].strip() != last_segment:
            self.transcript.append({"text": last_segment + " "}) #store all finalized transcription
        # Lock is required because frames_np, frames_offset, timestamp_offset is a shared resource
        # Update the timestamp_offset as this portion of the buffer has been finalized
        # (doesn't need to be transcribed anymore)
        with self.lock:
            self.timestamp_offset += duration


    def set_eos(self, eos):
        """
        Sets the End of Speech (EOS) flag.

        Args:
            eos (bool): The value to set for the EOS flag.
        """
        with self.lock:
            self.eos = eos


    def stop(self):
        '''
        Signals the thread to stop gracefully.
        '''
        print("Stopping transcription thread...")
        self.running = False
        if self.trans_thread.is_alive():
            self.trans_thread.join()  # Wait for the thread to finish




class ClientManager:
    '''
    Registry mapping a connection key to its client object.

    The key is "the WebSocket connection" in name only; any hashable value
    works as the lookup key.
    '''
    def __init__(self):
        # key -> client
        self.clients = {}

    def add_client(self, websocket, client):
        '''
        Register `client` under the key `websocket`, replacing any client
        already stored under that key.
        '''
        self.clients[websocket] = client

    def get_client(self, websocket):
        '''
        Look up the client stored under `websocket`.

        Returns the client, or False when the key is unknown.
        '''
        return self.clients.get(websocket, False)




class Server:
    def __init__(self, vad_detector, no_voice_activity_threshold=3):
        '''
        In charge of receiving, checking VA and passing audio to client

        Args:
            vad_detector: Callable taking an audio frame and returning a truthy value
                when the frame contains voice activity.
            no_voice_activity_threshold: Number of consecutive silent chunks that must be
                exceeded before a client's end-of-speech flag is set.
        '''
        self.client_manager = None # Created lazily by handle_new_connection
        self.vad_detector = vad_detector # VAD detector
        self.no_voice_activity_chunks = 0 # Help to track prolonged silence
        self.no_voice_activity_threshold = no_voice_activity_threshold


    def recv_audio(self, websocket, audio_array, transcriber=None, transcriber_processor=None,
                   send_last_n_segments=3, RATE=16000):
        """
        DEPRECATED entry point for transcription
        - kept in case of conversion to websockets

        Registers a client for `websocket` if none exists yet, then processes
        `audio_array` once. (Looping here would re-feed the same chunk forever,
        because process_audio_frames always returns True.)

        Args:
            websocket: Key identifying the client.
            audio_array: Raw audio chunk handed to process_audio_frames.
            transcriber: Model for a new client; required when `websocket` has no client yet.
            transcriber_processor: Processor for a new client; required likewise.
            send_last_n_segments: Forwarded to handle_new_connection for a new client.
            RATE: Forwarded to handle_new_connection for a new client.

        Raises:
            TypeError: If a new client must be created but `transcriber` or
                `transcriber_processor` was not supplied.
        """
        already_connected = (
            self.client_manager is not None
            and self.client_manager.get_client(websocket)
        )
        if not already_connected:
            if transcriber is None or transcriber_processor is None:
                raise TypeError(
                    "recv_audio() needs 'transcriber' and 'transcriber_processor' "
                    "to set up a new client"
                )
            # Handle the new connection
            if not self.handle_new_connection(websocket, transcriber, transcriber_processor,
                                              send_last_n_segments, RATE):
                return

        self.process_audio_frames(websocket, audio_array)


    def handle_new_connection(self, websocket, transcriber, transcriber_processor, send_last_n_segments=3, RATE=16000):
        '''
        Initialise the client manager
        Initialise the new client and add it to the client manager

        Returns:
            bool: Always True.
        '''

        # Initialise the client manager if not done
        if self.client_manager is None:
            self.client_manager = ClientManager()

        # Initialise the new client and add it to the client manager.
        # The websocket acts as the key, or just use any other key
        self.initialize_client(websocket, transcriber, transcriber_processor, send_last_n_segments, RATE)

        return True


    def initialize_client(self, websocket, transcriber, transcriber_processor, send_last_n_segments=3, RATE=16000):
        '''
        Initialize the new client and add it to the client manager
        '''
        # Initialize the new client
        client = Client(transcriber, transcriber_processor, send_last_n_segments, RATE)

        # Add the client to the client manager
        self.client_manager.add_client(websocket, client)


    def process_audio_frames(self, websocket, audio_array):
        '''
        Responsible for
        - checking voice activity
        - adding audio chunks to the retrieved client's buffer

        Get the audio chunk
        If it has voice activity
        - Reset the no voice activity settings
        - Add the audio chunk to the client's buffer
        - return True
        If it has no voice activity
        - return True

        Args:
            websocket: Key of the client that owns this audio.
            audio_array: Raw audio chunk (interpreted as int16 samples).

        Returns:
            bool: Always True.
        '''
        # Get the audio chunk as a numpy array
        frame_np, audio_float32 = self.get_audio_from_websocket(websocket, audio_array)

        # Get the client using its associated key
        client = self.client_manager.get_client(websocket)

        # Check for voice activity in the audio chunk
        # - if there is no voice activity return False
        # - if there is prolonged silence (accumulated over multiple chunks), set the eos flag of the client to True
        # - if there is voice activity return True
        voice_active = self.voice_activity(websocket, audio_float32, self.no_voice_activity_threshold)

        # If there is voice activity, reset the
        # - no_voice_activity_chunks
        # - eos flag
        # - add audio chunk to the client's buffer
        if voice_active:
            self.no_voice_activity_chunks = 0
            client.set_eos(False)
            client.add_frames(frame_np)
        return True


    def get_audio_from_websocket(self, websocket, frame_data):
        '''
        Receive audio chunks and create a numpy array out of it

        Args:
            websocket: Unused here; kept for a uniform signature.
            frame_data: Buffer of int16 samples.

        Returns:
            tuple: (frame_np, audio_float32). audio_float32 is the output of
            int2float(); frame_np is a 1-D float32 view of the same samples.
        '''
        # MAKE SURE DATA FORMATS ARE CORRECT
        # VAD:
        # - https://github.com/snakers4/silero-vad/blob/master/examples/pyaudio-streaming/pyaudio-streaming-examples.ipynb
        # - https://github.com/snakers4/silero-vad/blob/master/examples/microphone_and_webRTC_integration/microphone_and_webRTC_integration.py
        # Convert once and reuse the result instead of converting the same buffer twice.
        # The two returned arrays share memory; both consumers copy before modifying.
        audio_float32 = int2float(np.frombuffer(frame_data, dtype=np.int16))
        frame_np = np.frombuffer(audio_float32, dtype=np.float32)
        return frame_np, audio_float32


    def voice_activity(self, websocket, frame_np, no_voice_activity_threshold=3):
        '''
        Whenever no voice activity is detected, increment no_voice_activity_chunks

        If the counter is > no_voice_activity_threshold i.e. prolonged silence:
        - set the end-of-speech flag of the client to True
        - wait for .1 seconds.

        return False for no voice activity

        return True for voice activity
        '''
        if not self.vad_detector(frame_np):
            self.no_voice_activity_chunks += 1
            if self.no_voice_activity_chunks > no_voice_activity_threshold:
                client = self.client_manager.get_client(websocket)
                if not client.eos:
                    client.set_eos(True)
                time.sleep(0.1)    # Sleep 100ms; wait for some voice activity
            return False
        return True

    def shutdown(self):
        '''
        Stops all client threads before shutting down the server.

        Safe to call even when no client was ever connected.
        '''
        print("Stopping all clients")
        # client_manager stays None until the first handle_new_connection call
        if self.client_manager is not None:
            for client in self.client_manager.clients.values():
                client.stop()
        print("Server shutdown complete.")

## **Loop to simulate live transcription (from a .wav file)**

In [None]:
### SETTINGS ###

chunk = 4096 # number of samples to take per read
RATE = 16000 # number of samples taken per second
# chunk duration = chunk_size/sampling_rate
# number of chunks in a single second of data = sampling_rate/chunk_size
# 4096/16000 = 0.256s

send_last_n_segments = 3 # number of previous segments to display along with the current segment

vad_model_path = './silero_vad.onnx'
vad_threshold = 0.5
vad_detector = VoiceActivityDetector(vad_model_path,vad_threshold,RATE) # Load the Silero VAD model

no_voice_activity_threshold = 3 # how many consecutive silent audio chunks constitute a pause/prolonged silence, i.e. the speech segment is finalized

### SETTINGS ###

# Initialise Server
server = Server(vad_detector,no_voice_activity_threshold)
server.handle_new_connection("websocket",model,processor,send_last_n_segments,RATE) # manual call. Just use "websocket" as a key

# Simulation delay: the time one chunk takes to be 'spoken'
simulation_delay = chunk/RATE

filename = "./sample.wav"
# The context manager closes the file; no explicit close() is needed.
with wave.open(filename, "rb") as wavfile:
    try:
        while True:
            data = wavfile.readframes(chunk)
            if data == b"":
                break
            time.sleep(simulation_delay) # to simulate the time the chunk takes to be 'spoken'
            # Directly call process_audio_frames with the new audio chunk in bytes
            server.process_audio_frames("websocket", data)
    except KeyboardInterrupt:
        print("KeyboardInterrupt")
    finally:
        # Always stop the client threads, including when processing fails,
        # so no transcription thread is left running in the background.
        server.shutdown()