## **Define the ASR Model and Dataset**

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import time

# Build the ASR pipeline around the distil-whisper small (English) checkpoint.
# Half precision is used only when a CUDA device is available; otherwise run on CPU in fp32.
model_id = "distil-whisper/distil-small.en"
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

# Test data: open the dataset in streaming mode and keep the "audio" field of its first example.
dataset_repo = "johnlohjy/imda_nsc_p3_same_closemic_train"
dataset = load_dataset(
    dataset_repo, split='train', streaming=True, trust_remote_code=True
)
dataset_iter = iter(dataset)
sample = next(dataset_iter)["audio"]

## **Define Server-Related Classes, Adapted for Colab**

In [None]:
import json
import numpy as np
import os
import wave
import threading

class Client:
    '''
    Per-connection state for one streaming ASR client.

    Holds the rolling audio buffer of float32 samples, the offsets that map
    buffer positions back to absolute stream time, and the end-of-speech flag
    that the server toggles via set_eos().
    '''

    RATE = 16000  # Sample rate (Hz) used to convert seconds <-> sample counts

    def __init__(self):
        self.frames_np = None  # Audio buffer as a numpy array (None until the first chunk)
        self.timestamp_offset = 0.0  # Transcription offset (seconds) from the very start
        self.frames_offset = 0.0  # Seconds of audio discarded from the front of the buffer
        self.send_last_n_segments = 10  # Number of transcribed segments that will be 'sent' to the client
        self.eos = False  # End-Of-Speech flag; set by the server's VAD handling
        self.transcriber = None  # TODO: initialise with the Whisper ASR model

        # Thread to run the speech-to-text function.
        # NOTE(review): add_frames() mutates frames_np from the caller's thread; once
        # speech_to_text() reads the buffer, guard both sides with a threading.Lock.
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()

    def set_eos(self, eos):
        '''
        Set the End-Of-Speech flag.

        Args:
            eos: True once prolonged silence is detected, False when speech resumes.
        '''
        self.eos = eos

    def add_frames(self, frame_np):
        '''
        Append a new audio chunk to the frames buffer.

        Once the buffer holds more than 45 s of audio, the oldest 30 s are
        discarded. frames_offset records how much was dropped, and
        timestamp_offset is moved forward so it never points into discarded audio.

        Args:
            frame_np: 1-D numpy array of samples at RATE Hz.
        '''
        if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
            self.frames_offset += 30.0
            self.frames_np = self.frames_np[int(30 * self.RATE):]

            # Transcription cannot resume from audio that no longer exists.
            if self.timestamp_offset < self.frames_offset:
                self.timestamp_offset = self.frames_offset

        if self.frames_np is None:
            # Copy so the buffer does not alias (possibly read-only) caller memory.
            self.frames_np = frame_np.copy()
        else:
            self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)

    def save_frames(self, fp=None):
        '''
        Save the buffered audio as a mono 16-bit PCM WAV file (e.g. on client disconnect).

        Args:
            fp: Output path. Defaults to "test_frames.wav" in the current working directory.

        Returns:
            The path written to, or None if no audio has been buffered.
        '''
        if self.frames_np is None:
            return None
        if fp is None:
            fp = os.path.join(os.getcwd(), "test_frames.wav")

        # The buffer holds float32 samples (presumably in [-1, 1] — confirm with the
        # sender), but the WAV header declares 2-byte samples, so convert to int16.
        pcm = (np.clip(self.frames_np, -1.0, 1.0) * 32767).astype(np.int16)
        with wave.open(fp, "wb") as wavfile:
            wavfile.setnchannels(1)
            wavfile.setsampwidth(2)
            wavfile.setframerate(self.RATE)
            wavfile.writeframes(pcm.tobytes())
        return fp

    def clip_audio_if_no_valid_segment(self):
        '''
        If there is > 25s of audio to transcribe in the buffer,
        adjust timestamp_offset such that its only 5s behind the total audio added so far
        '''
        pass

    def get_audio_chunk_for_processing(self):
        '''TODO: return the untranscribed portion of the buffer.'''
        pass

    def transcribe_audio(self):
        '''TODO: run the ASR model on the current audio chunk.'''
        pass

    def speech_to_text(self):
        '''TODO: transcription loop run on trans_thread (currently returns immediately).'''
        pass




class ClientManager:
    '''
    Registry that maps each WebSocket server connection to its Client,
    so the server can find per-connection state when audio arrives.
    '''

    def __init__(self):
        # websocket connection -> Client
        self.clients = {}

    def add_client(self, websocket, client):
        '''
        Register `client` under `websocket`, replacing any previous entry for it.
        '''
        self.clients.update({websocket: client})

    def get_client(self, websocket):
        '''
        Look up the client registered for `websocket`.

        Returns:
            The registered client, or False when the connection is unknown.
        '''
        return self.clients.get(websocket, False)




class Server:
    '''
    Server class handles
    - New client connections
    - Receiving and processing audio from client
    '''

    def __init__(self):
        self.client_manager = None  # Created lazily on the first connection
        self.vad_detector = None  # TODO: VAD detector; None means every chunk counts as speech
        # Consecutive chunks without voice activity, used to detect prolonged silence.
        # NOTE(review): this is server-wide, not per client — fine for the single-client
        # Colab simulation, but move it onto Client before serving several clients.
        self.no_voice_activity_chunks = 0

    def initialize_client(self, websocket):
        '''
        Create a new Client and register it with the client manager under `websocket`.

        EXPAND ON CLIENT CLASS (see ServeClientTensorRT, ServeClientBase)
        '''
        client = Client()
        self.client_manager.add_client(websocket, client)

    def handle_new_connection(self, websocket):
        '''
        Create the client manager if needed, then initialise and register the new client.

        Returns:
            True (the connection is always accepted for now).
        '''
        if self.client_manager is None:
            self.client_manager = ClientManager()

        # TODO: use a Voice Activity Detector
        # self.vad_detector = VoiceActivityDetector()

        self.initialize_client(websocket)
        return True

    def get_audio_from_websocket(self, websocket):
        '''
        Receive one audio message and view it as a float32 numpy array.

        Returns:
            A float32 array over the received bytes (np.frombuffer, no copy), or
            None when the message is empty. An empty message is treated as end of
            stream, e.g. a file-reading stand-in returning b"" at EOF.
        '''
        # https://websockets.readthedocs.io/en/stable/reference/sync/server.html#websockets.sync.server.ServerConnection.recv
        frame_data = websocket.recv()
        if not frame_data:
            return None
        return np.frombuffer(frame_data, dtype=np.float32)

    def voice_activity(self, websocket, frame_np):
        '''
        Return True if `frame_np` contains voice activity, else False.

        If no VAD is configured (vad_detector is None), every chunk counts as speech.
        Each silent chunk increments no_voice_activity_chunks. Once the count
        exceeds 3 (prolonged silence), the client's end-of-speech flag is set to
        True and the call sleeps 100 ms before returning False.

        TODO: make the silence threshold configurable; check whether the sleep is needed.
        '''
        if self.vad_detector is None or self.vad_detector(frame_np):
            return True

        self.no_voice_activity_chunks += 1
        if self.no_voice_activity_chunks > 3:
            client = self.client_manager.get_client(websocket)
            # get_client returns False for unknown connections.
            if client and not client.eos:
                client.set_eos(True)
            time.sleep(0.1)  # Sleep 100 ms; wait for some voice activity
        return False

    def process_audio_frames(self, websocket):
        '''
        Receive one audio chunk and add it to the client's buffer if it contains speech.

        Returns:
            False to stop the receive loop (end of stream, or unknown client);
            True otherwise. Silent chunks are dropped but still return True.
        '''
        # TODO: change to receive from an audio file in a loop for use in Colab
        frame_np = self.get_audio_from_websocket(websocket)
        if frame_np is None or frame_np.size == 0:
            return False

        # TODO: change the key used for Colab
        client = self.client_manager.get_client(websocket)
        if not client:
            return False

        # Silent chunks are dropped. After prolonged silence, voice_activity
        # also sets the client's eos flag to True.
        if not self.voice_activity(websocket, frame_np):
            return True

        # Speech detected: reset silence tracking, then buffer the chunk.
        self.no_voice_activity_chunks = 0
        client.set_eos(False)
        client.add_frames(frame_np)
        return True

    def recv_audio(self, websocket):
        '''
        Register the new connection, then process audio frames until
        process_audio_frames returns False (end of stream or unknown client).
        '''
        if not self.handle_new_connection(websocket):
            return

        while self.process_audio_frames(websocket):
            pass

## **Loop to simulate the client sending audio data (read from a file) to the server, and the server sending transcriptions back to the client (printed)**