In [None]:
!pip install websockets
!pip install openai-whisper
!pip install nest_asyncio


In [None]:

import asyncio
import websockets
import json
import base64
import numpy as np
import whisper
import torch  # Importing the torch module
from IPython.display import display, HTML, clear_output
import nest_asyncio
from datetime import datetime
import time
import wave
import io

# Colab/Jupyter kernels already run an asyncio event loop. Patch asyncio so
# that loop.run_until_complete() can be called re-entrantly from inside it
# (this file drives its client with run_until_complete).
nest_asyncio.apply()

def get_server_ip():
    """Ask for the server address until a non-blank answer is given.

    Returns:
        The entered text with surrounding whitespace removed. Only emptiness
        is checked; the value is not validated as an IP address.
    """
    prompt = "Enter the server's public IP (example: 192.168.1.1): "
    answer = input(prompt).strip()
    while not answer:
        # Blank input: complain and ask again.
        print("Invalid IP. Please try again.")
        answer = input(prompt).strip()
    return answer

def get_use_cuda():
    """Ask whether transcription should run on the GPU.

    Returns:
        True for 'y'/'yes', False for 'n'/'no' (case-insensitive). Any other
        answer triggers a re-prompt.
    """
    answers = {'y': True, 'yes': True, 'n': False, 'no': False}
    while True:
        reply = input("Do you want to use CUDA (GPU)? (y/n): ").strip().lower()
        if reply in answers:
            return answers[reply]
        print("Invalid input. Please respond with 'y' or 'n'.")

def get_language():
    """Ask whether to auto-detect the spoken language or use a fixed one.

    Returns:
        The string "automatic" when detection is requested, otherwise the
        lower-cased language code typed by the user. A blank code or an
        unrecognised y/n answer restarts the whole question.
    """
    while True:
        reply = input("Do you want the language to be detected automatically? (y/n): ").strip().lower()
        if reply in ('y', 'yes'):
            return "automatic"
        if reply not in ('n', 'no'):
            print("Invalid input. Please respond with 'y' or 'n'.")
            continue

        # The user wants a fixed language: ask for its code.
        code = input("Specify the language code (example: 'pt' for Portuguese): ").strip().lower()
        if code:
            return code
        print("Invalid language code. Please try again.")

# Interactive run-time settings (asked in this order: server, device, language)
SERVER_IP = get_server_ip()
SERVER_PORT = 9024  # Fixed port; could be made configurable as well

USE_CUDA = get_use_cuda()

LANGUAGE = get_language()

# Use the GPU only when it was requested AND torch can actually see one;
# torch.cuda.is_available() is not consulted when the user declined CUDA.
device = "cuda" if USE_CUDA and torch.cuda.is_available() else "cpu"
if USE_CUDA and device == "cpu":
    print("CUDA is not available. Using CPU instead.")

print("Loading Whisper model...")
model = whisper.load_model("medium", device=device)
print(f"Model loaded on device: {device}")
print(f"Language set to: {'Automatic' if LANGUAGE == 'automatic' else LANGUAGE}")

class AudioClient:
    """WebSocket client that receives WAV audio and sends back transcriptions.

    The client connects to ``server_url``, reads JSON messages carrying
    base64-encoded audio, transcribes every ``complete_audio`` message with
    the module-level Whisper ``model`` and returns the text to the server as a
    ``{"type": "transcription", ...}`` JSON message. Progress is reported in
    the notebook through ``update_status``.
    """

    # Background colour of the status box for each status type.
    _STATUS_COLORS = {
        'info': '#e6f3ff',
        'success': '#d4edda',
        'warning': '#fff3cd',
        'error': '#f8d7da'
    }

    def __init__(self, server_url):
        """Store the connection settings and draw the initial banner.

        Args:
            server_url: WebSocket URL of the audio server (e.g. ``ws://host:port``).
        """
        self.server_url = server_url
        self.files_processed = 0          # number of non-empty transcriptions produced
        self.running = True               # set to False to stop the connect/receive loops
        self.connected = False            # True only while a WebSocket session is open
        self.reconnect_delay = 1          # seconds to wait before the next reconnect
        self.max_reconnect_delay = 30     # upper bound for the exponential backoff
        self.last_activity = time.time()  # time of the last connect / received message
        self.connection_timeout = 60      # seconds of silence before reconnecting
        self.audio_buffer = []            # buffer for audio chunks (not used by this class yet)
        self.websocket = None             # current WebSocket, or None when disconnected
        self._setup_display()

    def _setup_display(self):
        """Render the initial banner naming the server being contacted."""
        import html  # local import: the file's import list does not include it

        display(HTML(f"""
        <div style="padding: 20px; background-color: #f0f0f0; border-radius: 5px;">
            <h3>Audio Client</h3>
            <p>Initializing connection to server at {html.escape(str(self.server_url))}...</p>
            <div id="status"></div>
        </div>
        """))

    @classmethod
    def _render_status_html(cls, message, status_type, timestamp):
        """Build the HTML snippet for one status line.

        The message is HTML-escaped because it can contain text that did not
        originate in this program (transcriptions, exception messages).
        """
        import html  # local import: the file's import list does not include it

        background = cls._STATUS_COLORS.get(status_type, '#e6f3ff')
        return f"""
        <div style="padding: 10px; background-color: {background};
                    border-radius: 5px; margin: 10px 0;">
            [{timestamp}] {html.escape(str(message))}
        </div>
        """

    def update_status(self, message, status_type='info'):
        """Show a timestamped status line in the notebook output.

        Args:
            message: Text to show (escaped before being rendered as HTML).
            status_type: One of 'info', 'success', 'warning', 'error'; unknown
                values fall back to the 'info' colour.
        """
        timestamp = datetime.now().strftime("%H:%M:%S")
        display(HTML(self._render_status_html(message, status_type, timestamp)))

    def validate_audio_format(self, wf):
        """Check that an open WAV file is mono, 16 kHz, 16-bit PCM.

        The sample width matters because the frames are later decoded as
        int16; any other width would be silently misread as noise.

        Args:
            wf: A ``wave.Wave_read`` object.

        Returns:
            True when the format is acceptable; otherwise False, after
            reporting a warning status.
        """
        channels = wf.getnchannels()
        sample_rate = wf.getframerate()
        sample_width = wf.getsampwidth()
        if channels != 1 or sample_rate != 16000 or sample_width != 2:
            self.update_status(
                f"Invalid audio format - Channels: {channels}, "
                f"Sample Rate: {sample_rate} Hz, "
                f"Sample Width: {sample_width} bytes",
                'warning'
            )
            return False
        return True

    async def transcribe_audio(self, audio_float):
        """Transcribe audio with the module-level Whisper model and send the text.

        Args:
            audio_float: 1-D float32 numpy array of samples scaled to [-1, 1).

        Returns:
            The stripped transcription, or None when it is empty or when
            transcription failed (the failure is reported as a status).
        """
        def _run_model():
            return model.transcribe(
                audio=audio_float,
                language=None if LANGUAGE == "automatic" else LANGUAGE,
                task="transcribe",
                fp16=(device == "cuda")  # Use fp16 if on GPU
            )

        try:
            # Inference can take many seconds. Running it in a worker thread
            # keeps the event loop free to service the WebSocket meanwhile.
            result = await asyncio.get_running_loop().run_in_executor(None, _run_model)

            transcription = result["text"].strip()
            if not transcription:
                return None

            self.files_processed += 1
            self.update_status(
                f"Transcription {self.files_processed}: {transcription}",
                'success'
            )
            # Send the transcription back to the server
            await self.send_transcription(transcription)
            return transcription
        except Exception as e:
            self.update_status(f"Error during transcription: {str(e)}", 'error')
            return None

    async def send_transcription(self, transcription):
        """Send a transcription to the server as a JSON message.

        Does nothing when there is no open connection. Send failures are
        reported as a status rather than raised.
        """
        if not (self.websocket and self.connected):
            return

        message = {
            "type": "transcription",
            "timestamp": datetime.now().strftime("%H:%M:%S"),
            "text": transcription
        }
        try:
            await self.websocket.send(json.dumps(message))
            self.update_status(f"Transcription sent to server: {transcription}", 'info')
        except Exception as e:
            self.update_status(f"Error sending transcription: {str(e)}", 'error')

    async def process_audio_data(self, audio_data_b64, message_info):
        """Decode one audio message and, for complete audio, transcribe it.

        Args:
            audio_data_b64: Base64-encoded payload. For ``complete_audio``
                messages it must be a WAV file (mono, 16 kHz, 16-bit).
            message_info: Dict with optional keys 'type' ('chunk' by default
                or 'complete_audio'), 'timestamp' and 'duration' (seconds).

        'chunk' messages are only logged. Other message types are ignored.
        Any error is reported as a status instead of being raised.
        """
        try:
            payload = base64.b64decode(audio_data_b64)
            message_type = message_info.get('type', 'chunk')

            # Ongoing audio chunk: just log that it arrived
            if message_type == "chunk":
                self.update_status(
                    f"Receiving audio... ({len(payload)/1024:.1f}KB)",
                    'info'
                )
                return

            if message_type != "complete_audio":
                return

            with wave.open(io.BytesIO(payload), 'rb') as wf:
                if not self.validate_audio_format(wf):
                    return
                sample_rate = wf.getframerate()
                frames = wf.readframes(wf.getnframes())

            # validate_audio_format guaranteed 16-bit mono, so int16 is correct here
            samples = np.frombuffer(frames, dtype=np.int16)
            if samples.size == 0:
                self.update_status("Received empty audio; skipping transcription", 'warning')
                return
            audio_float = samples.astype(np.float32) / 32768.0

            # Prefer the duration announced by the server; otherwise derive it
            # from the WAV content itself.
            duration = message_info.get('duration') or samples.size / sample_rate

            # Size is that of the received WAV payload in bytes, not the sample count
            self.update_status(
                f"Starting transcription - Size: {len(payload)/1024:.1f}KB, "
                f"Duration: {duration:.1f}s",
                'info'
            )

            await self.transcribe_audio(audio_float)

        except Exception as e:
            self.update_status(f"Error processing audio: {str(e)}", 'error')

    async def _handle_message(self, message):
        """Parse one raw WebSocket message and dispatch its audio payload."""
        try:
            data = json.loads(message)
        except ValueError as e:  # JSONDecodeError, or undecodable bytes
            self.update_status(f"Error decoding message: {str(e)}", 'error')
            return

        if not isinstance(data, dict) or 'audio_data' not in data:
            self.update_status("Ignoring message without 'audio_data' field", 'warning')
            return

        await self.process_audio_data(
            data['audio_data'],
            {
                'type': data.get('type', 'chunk'),
                'timestamp': data.get('timestamp'),
                'duration': data.get('duration', 0)
            }
        )

    async def _receive_messages(self, websocket):
        """Read and handle messages until the session ends or the client stops.

        Returns normally on receive timeout or when the peer closes the
        connection. Any other receive error propagates to the caller, which
        treats it as a failed connection.
        """
        while self.running:
            try:
                message = await asyncio.wait_for(
                    websocket.recv(),
                    timeout=self.connection_timeout
                )
            except asyncio.TimeoutError:
                self.update_status("Connection timeout. Reconnecting...", 'warning')
                return
            except websockets.exceptions.ConnectionClosed:
                self.update_status("Connection lost. Reconnecting...", 'warning')
                return

            # Traffic is flowing: reset the backoff and the activity clock
            self.reconnect_delay = 1
            self.last_activity = time.time()

            try:
                await self._handle_message(message)
            except Exception as e:
                # One bad message must not end the session
                self.update_status(f"Error: {str(e)}", 'error')

    async def connect_to_server(self):
        """Connect to the server and keep reconnecting while ``self.running``.

        After a failed connection attempt the wait before the next attempt
        doubles, up to ``max_reconnect_delay``. After a session that was
        established and then ended, the client waits the base delay, so a
        server that accepts and immediately drops connections cannot cause a
        busy reconnect loop.
        """
        while self.running:
            session_failed = False
            try:
                async with websockets.connect(
                    self.server_url,
                    ping_interval=None,
                    max_size=20 * 1024 * 1024,
                    close_timeout=5
                ) as websocket:
                    self.websocket = websocket  # Store the WebSocket reference
                    self.connected = True
                    self.reconnect_delay = 1
                    self.last_activity = time.time()
                    self.update_status("Connected to the server!", 'success')

                    await self._receive_messages(websocket)

            except Exception as e:
                session_failed = True
                self.update_status(
                    f"Connection error: {str(e)}\nAttempting to reconnect in {self.reconnect_delay} seconds...",
                    'error'
                )
            finally:
                # Never leave a stale socket marked as connected, otherwise
                # send_transcription would try to use it.
                self.connected = False
                self.websocket = None

            if not self.running:
                break

            await asyncio.sleep(self.reconnect_delay)
            if session_failed:
                self.reconnect_delay = min(self.reconnect_delay * 2, self.max_reconnect_delay)

    async def handle_incoming_messages(self):
        """Optional: Handles incoming messages from the server if necessary"""
        # If you need to handle specific messages from the server, implement here
        pass

async def main():
    """Create the audio client for the configured server and run it.

    Runs until the client's connection loop ends; the client is always
    flagged as stopped on the way out.
    """
    server_url = f"ws://{SERVER_IP}:{SERVER_PORT}"
    audio_client = AudioClient(server_url)
    try:
        await audio_client.connect_to_server()
    except KeyboardInterrupt:
        print("\nClosing connection...")
    finally:
        # Tell the client's loops to stop, however we got here.
        audio_client.running = False

# Run the client
loop = asyncio.get_event_loop()

# In Colab/Jupyter the loop returned above is the kernel's own loop and is
# already running. It must not be closed, and its unrelated tasks must not be
# cancelled. Only when this script started the loop itself may it clean up
# everything.
owns_loop = not loop.is_running()
client_task = loop.create_task(main())
try:
    loop.run_until_complete(client_task)
except KeyboardInterrupt:
    print("\nClosing connection...")
finally:
    if owns_loop:
        # The loop is ours: clean up all pending tasks
        pending = asyncio.all_tasks(loop)
    else:
        # Shared (kernel) loop: only touch the task created here
        pending = set() if client_task.done() else {client_task}

    for task in pending:
        task.cancel()
    if pending:
        # Let the cancelled tasks unwind; their exceptions are collected, not raised
        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))

    if owns_loop:
        loop.close()
