## **Define ASR Model and dataset**

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import time

# Model setup code for distil-whisper small
# Checkpoint identifier passed to from_pretrained() for both the model and the processor
model_id = "distil-whisper/distil-small.en"
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# float16 when CUDA is available, float32 otherwise
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
# The processor supplies the tokenizer and feature extractor handed to the pipeline below
processor = AutoProcessor.from_pretrained(model_id)
# Speech-recognition pipeline assembled from the model and processor components loaded above
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

# Define dataset for testing
dataset_repo = "johnlohjy/imda_nsc_p3_same_closemic_train"
# streaming=True: presumably yields examples lazily instead of downloading the whole split - TODO confirm
# NOTE(review): trust_remote_code=True looks like it lets the dataset repo's loading script run locally - confirm the repo is trusted
dataset = load_dataset(dataset_repo, split='train', streaming=True, trust_remote_code=True)
dataset_iter = iter(dataset)
# Take the first example yielded by the iterator and keep only its "audio" field
sample = next(dataset_iter)
sample = sample["audio"]

## **Define Server-Related Classes, Adapted for Colab**

In [None]:
import json
import numpy as np
import os
import wave

class Client:
    '''
    Per-connection state: the audio received so far plus transcription bookkeeping
    '''
    def __init__(self):
        self.frames_np = None # To store frames buffer as a numpy array
        self.timestamp_offset = 0.0 # Track transcription offset from the very start
        self.frames_offset = 0.0 # Track frames offset from the very start/Duration of audio discarded
        self.send_last_n_segments = 10  # Number of transcribed segments that will be 'sent' to the client

    def add_frames(self, frame_np):
        '''
        Add new audio chunks to frames buffer

        The first chunk is copied so the buffer does not alias the caller's array;
        later chunks are concatenated onto the end along axis 0

        Check if need lock in the future
        '''
        if self.frames_np is None:
            # If the frames buffer is empty, initialise it with the new audio frames
            self.frames_np = frame_np.copy()
        else:
            # Append the new audio chunk to the existing frames buffer
            self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)

    def save_frames(self):
        '''
        Sample code to save the audio when client disconnects

        Writes the buffered audio to "test_frames.wav" in the current working
        directory as mono, 16-bit PCM at 16000 Hz

        The WAV header declares 2-byte samples, so the buffer is converted to
        int16 before writing. Floating-point buffers are clipped to [-1.0, 1.0]
        and scaled to the int16 range (assumes float samples are normalised to
        that range - TODO confirm against the sender). Writing the raw float32
        bytes would double the frame count and produce noise

        An empty buffer (no frames added yet) produces a valid WAV file with
        zero frames instead of raising
        '''
        if self.frames_np is None:
            samples = np.zeros(0, dtype=np.int16)
        else:
            samples = np.asarray(self.frames_np)

        if np.issubdtype(samples.dtype, np.floating):
            # Clip first so out-of-range samples saturate instead of wrapping around
            samples = np.clip(samples, -1.0, 1.0) * 32767.0
        pcm = samples.astype(np.int16)

        fp = os.path.join(os.getcwd(), "test_frames.wav")
        with wave.open(fp, "wb") as wavfile:
            wavfile: wave.Wave_write
            wavfile.setnchannels(1)
            wavfile.setsampwidth(2)
            wavfile.setframerate(16000)
            wavfile.writeframes(pcm.tobytes())





class ClientManager:
    '''
    Registry that maps each WebSocket server connection to the client
    created for it
    '''
    def __init__(self):
        # Keyed by the WebSocket connection info, value is the associated client
        self.clients = dict()

    def add_client(self, websocket, client):
        '''
        Register a client under its WebSocket server connection info,
        replacing whatever was stored for that connection before
        '''
        self.clients[websocket] = client

    def get_client(self, websocket):
        '''
        Look up the client registered for the WebSocket server connection info provided

        Returns False when nothing is registered for that connection
        '''
        return self.clients.get(websocket, False)




class Server:
    '''
    Server class handles
    - New client connections
    - Receiving and processing audio from client
    '''
    def __init__(self):
        # Created lazily by handle_new_connection() on the first connection
        self.client_manager = None

    def initialize_client(self, websocket):
        '''
        Initialize the new client and add it to the client manager

        Expects self.client_manager to be set already
        (handle_new_connection() takes care of that)

        EXPAND ON CLIENT CLASS (see ServeClientTensorRT, ServeClientBase)
        '''
        # Initialize the client
        client = Client()

        # Add the client to the client manager
        self.client_manager.add_client(websocket, client)

    def handle_new_connection(self, websocket):
        '''
        Initialise the client manager
        Initialise the new client and add it to the client manager

        Always returns True at present; recv_audio() treats a falsy
        return value as "do not serve this connection"
        '''
        # Initialise the client manager if not done
        if self.client_manager is None:
            self.client_manager = ClientManager()

        # Initialise the new client and add it to the client manager
        self.initialize_client(websocket)

        return True

    def get_audio_from_websocket(self, websocket):
        '''
        Receive audio chunks from the WebSocket and create a numpy array out of it

        The received message is interpreted as a flat sequence of float32 samples
        '''
        # Receive audio data (message) over the WebSocket server connection
        # https://websockets.readthedocs.io/en/stable/reference/sync/server.html#websockets.sync.server.ServerConnection.recv
        frame_data = websocket.recv()

        # Creates numpy array without copying it (more efficient)
        return np.frombuffer(frame_data, dtype=np.float32)

    def process_audio_frames(self, websocket):
        '''
        Get the audio chunk from the WebSocket as a numpy array and add it
        to the frames buffer of the client associated with that WebSocket

        Send a dummy transcription back to the client first

        Returns True once the chunk has been buffered and answered.
        Returns False, without receiving or sending anything, when no client
        is registered for this WebSocket, so that recv_audio() can stop its loop
        '''
        # Get the client using its associated WebSocket
        # get_client() signals "unknown connection" with False, so check for it
        # before touching the socket instead of failing on False.add_frames
        manager = self.client_manager
        client = manager.get_client(websocket) if manager is not None else False
        if client is False or client is None:
            return False

        # Get the audio chunk from the WebSocket as a numpy array
        frame_np = self.get_audio_from_websocket(websocket)
        print('"Received" audio chunk')

        client.add_frames(frame_np)

        # Send a dummy transcription
        # https://websockets.readthedocs.io/en/stable/reference/sync/server.html#websockets.sync.server.ServerConnection.send
        # ServerConnection provides recv() and send() methods for receiving and sending messages.
        websocket.send(
            json.dumps({
                "test": "test response",
            })
        )

        return True

    def recv_audio(self, websocket):
        """
        First handle the new connection

        Continuously process audio frames until process_audio_frames()
        reports that it could not handle the connection
        """

        # Try to handle the new connection
        if not self.handle_new_connection(websocket):
            return

        # Continuously process audio frames
        while True:
            if not self.process_audio_frames(websocket):
                break

## **Loop to simulate the client sending audio data (from file) to the server and the server sending a transcription back to the client (print)**