## **Define ASR Model**

In [1]:
!pip install -q datasets bitsandbytes accelerate

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m35.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.7/69.7 MB[0m [31m32.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.3/179.3 kB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m108.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m91.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, WhisperForConditionalGeneration, WhisperProcessor
from google.colab import drive
import time

# Google Drive Env Setup
whisper_ver = 'whisper-base'
checkpoint_num = '2100'
google_drive_path = f'/content/drive/My Drive/{whisper_ver}-checkpoints'
drive.mount('/content/drive')
checkpoint_path = f'/content/drive/My Drive/{whisper_ver}-checkpoints/checkpoint-{checkpoint_num}'

# Model setup code for fine-tuned whisper
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = WhisperForConditionalGeneration.from_pretrained(checkpoint_path).to(device)
model.config.use_cache = True
processor = WhisperProcessor.from_pretrained(f"openai/{whisper_ver}", language="en", task="transcribe")

Mounted at /content/drive


preprocessor_config.json:   0%|          | 0.00/185k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/836k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

## **Define Dataset**

In [1]:
from datasets import load_dataset

# Define dataset for testing
dataset_repo = "johnlohjy/imda_nsc_p3_same_closemic_test"
dataset = load_dataset(dataset_repo, split='test', streaming=True, trust_remote_code=True)
dataset_iter = iter(dataset)
sample = next(dataset_iter)
sample = sample["audio"]

imda_nsc_p3_same_closemic_test.py:   0%|          | 0.00/3.96k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [3]:
sample

{'path': 'waves/3000-1_11.wav',
 'array': array([-9.15527344e-05, -1.22070312e-04, -9.15527344e-05, ...,
         3.05175781e-04,  3.05175781e-04,  2.44140625e-04]),
 'sampling_rate': 16000}

In [4]:
'''
Convert to sample to wav file

import wave
import scipy.io.wavfile as wav
import numpy as np
wav.write('sample.wav', sample["sampling_rate"], (sample["array"] * 32767).astype(np.int16))
'''

In [6]:
'''
Transcribe sample

start_time = time.time()
input_features = processor.feature_extractor(sample["array"], sampling_rate=16000).input_features[0]
input_features = torch.tensor(input_features).unsqueeze(0).to(device)
generated_ids = model.generate(input_features)
predicted_transcription = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
stop_time = time.time()
print(predicted_transcription)
print(f'\n The time taken is {stop_time-start_time}')
'''

you know ya she was from france so that's she shared with me about her culture lah so you know uh she would show me around when she would when i would visit lah when i travel to france looking forward to that it's really not interesting generation gap what's the biggest similarity or difference between your generation and your parent's generation okay

 The time taken is 0.49919724464416504


## **Define VAD model-related dependencies**

In [None]:
import numpy as np
from vad import VoiceActivityDetector

# Provided by Alexander Veysov
def int2float(sound):
    abs_max = np.abs(sound).max()
    sound = sound.astype('float32')
    if abs_max > 0:
        sound *= 1/32768
    sound = sound.squeeze()  # depends on the use case
    return sound

# Load the Silero VAD model
vad_model_path = './silero_vad.onnx'
vad_detector = VoiceActivityDetector(vad_model_path)

## **Define Server-Related Classes, Adapted for Colab**

In [None]:
import json
import numpy as np
import os
import wave
import threading
import textwrap

class Client:
    def __init__(self, transcriber, send_last_n_segments=3):
        '''
        In charge of adding audio chunks, managing the audio buffer, adjusting the amount of audio to transcribe,
        transcribing the audio, printing the latest segment + send_last_n_segments, 

        when a prolonged silence is encountered
        - save the latest (finalized) transcription
        - update timestamp_offset to go past the transcription
        '''
        self.frames_np = None # Store frames buffer as a numpy array
        self.frames_offset = 0.0 # Track frames offset from the very start/Duration of audio discarded
        self.timestamp_offset = 0.0 # Track transcription offset from the very start
        self.send_last_n_segments = send_last_n_segments  # Number of last transcribed segments that will be 'sent' to the client
        self.eos # End-Of-Speech Flag
        self.transcriber = transcriber # Initialize Whisper ASR model
        self.transcript = [] # Store fully transcribed segments
        self.lock = threading.Lock() # for shared resources: frames_np. https://realpython.com/python-thread-lock/#threadinglock-for-primitive-locking

        # SINGLE_MODEL = None # only necessary for multiple clients?
        # SINGLE_MODEL_LOCK = threading.Lock() # only necessary for multiple clients?

        # Thread to run speech-to-text function
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()


    def add_frames(self, frame_np):
        '''
        Manage the ongoing buffer's size - update frames_offset and timestamp_offset as necessary
        Add new audio chunks to the client's frames buffer
        '''
        # Manage the ongoing buffer
        # Lock required as frames_np, frames_offset, timestamp_offset is a shared resource with the speech_to_text thread code
        # lock the critical section of code below (lock shared resources, no 2 threads can modify)
        self.lock.acquire()
        # If the buffer is more than 45s
        if self.frames_np is not None and self.frames_np.shape[0] > 45*self.RATE:
            # Increase frames_offset by 30s
            # Discard oldest 30s of audio from buffer
            self.frames_offset += 30.0
            self.frames_np = self.frames_np[int(30*self.RATE):]
            # Update timestamp_offset
            if self.timestamp_offset < self.frames_offset:
                self.timestamp_offset = self.frames_offset

        # If the frame buffer is empty, initialise it with the new audio frames
        if self.frames_np is None:
            self.frames_np = frame_np.copy()
        # Else, append the new audio chunk to the existing buffer
        else:
            self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
        # unlock the critical section of code above
        self.lock.release()


    def speech_to_text(self):
        '''
        Process audio buffer in an infinite loop, continuously transcribing speech
        - Adjust timestamp_offset s.t there is an appropriate amount of audio to transcribe
        - Get the audio to be transcribed 
        - Transcribe the audio to be transcribed
        - Print the latest segment along with the send_last_n_segments
        
        If the client's end-of-speech flag is True
        (prolonged silence, means that speech segment is finalized),
        - Save the latest speech segment
        - Update the timestamp_offset as this portion of the buffer has been finalized
        (doesn't need to be transcribed anymore)
        '''
        while True:
            # Wait for some chunks to arrive
            if self.frames_np is None:
                time.sleep(0.02)
                continue

            # Adjust the timestamp_offset
            self.clip_audio_if_no_valid_segment()

            # Get the audio to be transcribed using the timestamp_offset
            input_bytes, duration = self.get_audio_chunk_for_processing()
            if duration < 0.4:
                continue

            # Transcribe the audio and
            input_sample = input_bytes.copy()
            self.transcribe_audio(input_sample)


    def clip_audio_if_no_valid_segment(self):
        '''
        If there is > 25s of audio to transcribe in the buffer,
        adjust timestamp_offset such that its only 5s behind the total audio added so far

        If there is <=25s of audio to transcribe, its okay
        '''
        # Lock is required because frames_np, frames_offset, timestamp_offset is a shared resource
        # Written as a context manager -> Auto acquires and release lock
        with self.lock:
            # If there is more than 25s of audio to transcribe
            if self.frames_np[int((self.timestamp_offset - self.frames_offset)*self.RATE):].shape[0] > 25 * self.RATE:
                # Adjust timestamp_offset s.t it is only 5s behind the total audio added so far
                duration = self.frames_np.shape[0] / self.RATE
                self.timestamp_offset = self.frames_offset + duration - 5


    def get_audio_chunk_for_processing(self):
        '''
        Get the audio to be transcribed from the buffer calculated using timestamp_offset
        '''
        # Use timestamp_offset to help subset the buffer to get the audio to be transcribed
        # Lock is required because frames_np, frames_offset, timestamp_offset is a shared resource
        # Written as a context manager -> Auto acquires and release lock
        with self.lock:
            samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
            input_bytes = self.frames_np[int(samples_take):].copy()
        duration = input_bytes.shape[0] / self.RATE
        return input_bytes, duration


    def transcribe_audio(self, input_bytes, duration):
        '''
        Transcribe the audio to be transcribed
        '''
        # Transcribe the audio to be transcribed
        last_segment = self.transcriber.transcribe(input_bytes)
        # Print the transcriptions to be printed and
        # if there is prolonged silence,
        # - Save the latest speech segment
        # - Update the timestamp_offset
        self.handle_transcription_output(last_segment, duration)


    def handle_transcription_output(self, last_segment, duration):
        '''
        Print the transcriptions to be printed
        If the client's end-of-speech flag is True
        (prolonged silence, means that speech segment is finalized),
        - Save the latest speech segment
        - Update the timestamp_offset
        '''
        segments = self.prepare_segments({"text": last_segment})
        self.send_transcription_to_client(segments)
        if self.eos:
            self.update_timestamp_offset(last_segment, duration)


    def prepare_segments(self, last_segment=None):
        '''
        Prepare the segments to be printed
        '''
        segments = []
        # If the length of self.transcript is more than or equal send_last_n_segments,
        # set segments to be the last send_last_n_segments number of elements of self.transcript
        if len(self.transcript) >= self.send_last_n_segments:
            segments = self.transcript[-self.send_last_n_segments:].copy()
        # If not, just set segments to be self.transcript
        else:
            segments = self.transcript.copy()
        # If the segment that was just transcribed is not None, add it to segments
        if last_segment is not None:
            segments = segments + [last_segment]
        return segments


    def send_transcription_to_client(self, text):
        '''
        Print the transcription for testing on Colab
        '''
        wrapper = textwrap.TextWrapper(width=60)
        for line in wrapper.wrap(text="".join(text)):
            print(line)


    def update_timestamp_offset(self, last_segment, duration):
        # If self.transcript is empty, add the last segment received
        if not len(self.transcript):
            self.transcript.append({"text": last_segment + " "})
        # If the last element of self.transcript is != to the last segment, add the last segment
        elif self.transcript[-1]["text"].strip() != last_segment:
            self.transcript.append({"text": last_segment + " "})
        # Lock is required because frames_np, frames_offset, timestamp_offset is a shared resource
        # Written as a context manager -> Auto acquires and release lock
        # Update the timestamp_offset as this portion of the buffer has been finalized
        # (doesn't need to be transcribed anymore)
        with self.lock:
            self.timestamp_offset += duration




class ClientManager:
    '''
    Custom client manager class to handle clients connected over the
    WebSocket server
    '''
    def __init__(self):
        self.clients = {}

    def add_client(self, websocket, client):
        '''
        Add a WebSocket server connection info and its associated client
        '''
        self.clients[websocket] = client

    def get_client(self, websocket):
        '''
        Retrieve a client associated with the WebSocket server connection info provided
        '''
        if websocket in self.clients:
            return self.clients[websocket]
        return False




class Server:
    def __init__(self, vad_detector):
        '''
        In charge of receiving, checking VA and passing audio to client
        '''
        self.client_manager = None
        self.vad_detector = vad_detector # VAD detector
        self.no_voice_activity_chunks = 0 # Help to track prolonged silence
        self.no_voice_activity_threshold = 3


    def recv_audio(self,websocket):
        """
        Handle the new connection
        Continously process audio frames

        todo: when writing the loop, pass "websocket" as the websocket argument
        """
        # Handle the new connection
        if not self.handle_new_connection(websocket):
            return

        # Continously process audio frames
        while True:
            if not self.process_audio_frames(websocket):
                break


    def handle_new_connection(self,websocket):
        '''
        Initialise the client manager
        Set the VAD
        Initialise the new client and add it to the client manager
        '''

        # Initialise the client manager if not done
        if self.client_manager is None:
            self.client_manager = ClientManager()

        # Initialise the new client and add it to the client manager
        # todo: Change websocket key
        self.initialize_client(websocket)

        return True


    def initialize_client(self, websocket):
        '''
        Initialize the new client and add it to the client manager
        '''
        # Initialize the new client
        client = Client()

        # Add the client to the client manager
        self.client_manager.add_client(websocket, client)


    def process_audio_frames(self, websocket):
        '''
        Get the audio chunk from the WebSocket
        If it has voice activity
        - Reset the no voice activity settings
        - Add the audio chunk to the client's buffer
        - return True
        If it has no voice activity
        - return True
        '''
        # Get the audio chunk from the WebSocket as a numpy array
        # todo: Change websocket arg to receive from audio file in a loop to be used in colab
        frame_np, audio_float32 = self.get_audio_from_websocket(websocket)

        # Get the client using its associated WebSocket
        # todo: Change key used in colab
        client = self.client_manager.get_client(websocket)

        # Check for voice activity in the audio chunk
        # - if there is no voice activity return False
        # - if there is prolonged silence (accumulated), set the eos flag of the client to True
        # - if there is voice activity return True
        voice_active = self.voice_activity(websocket, audio_float32, self.no_voice_activity_threshold)

        # If there is voice activity, reset the
        # - no_voice_activity_chunks
        # - eos flag
        # - add audio chunk to the client's buffer
        if voice_active:
            self.no_voice_activity_chunks = 0
            client.set_eos(False)
            client.add_frames(frame_np)
        return True


    def get_audio_from_websocket(self, websocket):
        '''
        Receive audio chunks from the WebSocket and create a numpy array out of it
        '''
        # Subsequently, receive audio data (message) over the WebSocket server connection
        # todo: change the way audio chunk is received for use in colab
        frame_data = websocket.recv()
        # Creates numpy array without copying it (more efficient)
        return ( np.frombuffer(frame_data, dtype=np.float32), int2float(np.frombuffer(frame_data, np.int16)) )


    def voice_activity(self, websocket, frame_np, no_voice_activity_threshold=3):
        '''
        Whenever no voice activity is detected, increment no_voice_activity_chunks

        If the counter is > 3 i.e. prolonged silence:
        - set the end-of-speech flag of the client to True
        - wait for .1 seconds. todo: is this needed

        return False for no voice activity

        return True for voice activity
        '''
        if not self.vad_detector(frame_np):
            self.no_voice_activity_chunks += 1
            if self.no_voice_activity_chunks > no_voice_activity_threshold:
                client = self.client_manager.get_client(websocket)
                if not client.eos:
                    client.set_eos(True)
                time.sleep(0.1)    # Sleep 100m; wait some voice activity
            return False
        return True

## **Loop to simulate sending server audio data (from file) and Server sending client transcription (print)**