<a href="https://colab.research.google.com/github/hadil-sgh/-Machine-Learning-Projects/blob/main/Speaker_Diarization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import sounddevice as sd
import numpy as np
import torch
from faster_whisper import WhisperModel
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import pairwise_distances
from pynput import keyboard
import time
import signal
import threading
import warnings
from pathlib import Path
import tempfile
import wave

# Suppress warnings
# Silences every Python warning for the whole process, not only the ones
# emitted by the model libraries.
warnings.filterwarnings("ignore")
# NOTE(review): presumably meant to stop huggingface_hub from creating
# symlinks in its cache -- confirm this variable name is honoured by the
# installed huggingface_hub version.
os.environ["HF_HUB_DISABLE_SYMLINKS"] = "1"

# Audio settings
SAMPLE_RATE = 16000     # samples per second used for capture, WAV export and windowing
CHANNELS = 1            # number of input channels requested from the device
DEVICE_INDEX = 1        # preferred input device; the device check below may replace it
MIN_RECORD_TIME = 1.0   # seconds the space bar must be held before a take is accepted
SEGMENT_DURATION = 2.0  # seconds; length of each window cut by process_audio_thread
STEP_DURATION = 1.0     # seconds; hop between window starts (windows overlap by 1 s)

# Model settings
WHISPER_MODEL_SIZE = "tiny"  # model size string passed to WhisperModel
DEVICE = "cpu"               # device string passed to the models

# Globals
audio_buffer = []            # chunks appended by the audio callback while RECORDING is True
RECORDING = False            # toggled by the keyboard handlers
recording_start_time = None  # time.time() value stored when recording starts
processing_thread = None     # most recent threading.Thread running process_audio_thread
exit_flag = False            # set by ESC / SIGINT; polled by the main loop

print("Loading models...")

# Defaults for every optional backend. Defining them up front guarantees that
# the flags and model handles exist whichever branch below succeeds, so later
# code can test USE_SPEECHBRAIN / USE_PYANNOTE without risking a NameError.
diarization_pipeline = None
speaker_model = None
USE_PYANNOTE = False
USE_SPEECHBRAIN = False

# 1. Initialize Whisper for transcription
whisper_model = WhisperModel(WHISPER_MODEL_SIZE, device=DEVICE)
print("✅ Whisper model loaded")

# 2. Initialize PyAnnote for speaker diarization
try:
    from pyannote.audio import Pipeline
    from pyannote.audio.pipelines.utils.hook import ProgressHook

    # Use a temporary directory to avoid permission issues
    temp_dir = tempfile.mkdtemp()
    os.environ["HF_HOME"] = temp_dir

    # Initialize the pipeline
    diarization_pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=False  # Set to your HF token if needed
    )
    # Fail with a readable message instead of an AttributeError on `.to`
    # if the loader handed back nothing (e.g. missing access to the model).
    if diarization_pipeline is None:
        raise RuntimeError(
            "Pipeline.from_pretrained returned None "
            "(the model may require a Hugging Face access token)"
        )

    # Move to the configured device. A torch.device object is passed rather
    # than the bare string, since the pipeline's `to` may reject a plain str.
    diarization_pipeline.to(torch.device(DEVICE))
    print("✅ PyAnnote diarization model loaded")
    USE_PYANNOTE = True
except Exception as e:
    print(f"⚠️ Could not load PyAnnote: {e}")
    print("⚠️ Falling back to basic diarization")
    USE_PYANNOTE = False
    diarization_pipeline = None

    # Try to load SpeechBrain as fallback
    try:
        from speechbrain.inference import EncoderClassifier

        # Try to load the model without symlinks
        speaker_model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            run_opts={"device": DEVICE}
        )
        print("✅ SpeechBrain model loaded as fallback")
        USE_SPEECHBRAIN = True
    except Exception as e:
        print(f"⚠️ Could not load SpeechBrain: {e}")
        USE_SPEECHBRAIN = False
        speaker_model = None

# Validate the configured input device; on any failure scan all devices and
# switch DEVICE_INDEX to the first one offering enough input channels.
try:
    device_info = sd.query_devices(DEVICE_INDEX)
    has_enough_inputs = device_info['max_input_channels'] >= CHANNELS
    if not has_enough_inputs:
        raise ValueError("Selected device does not support enough input channels.")
except Exception as e:
    print(f"⚠️ Error: {e}")
    print("🔍 Searching for a valid input device...")
    usable_devices = (
        (index, info)
        for index, info in enumerate(sd.query_devices())
        if info['max_input_channels'] >= CHANNELS
    )
    first_usable = next(usable_devices, None)
    if first_usable is None:
        raise RuntimeError("❌ No suitable input device found.")
    DEVICE_INDEX, chosen_device = first_usable
    print(f"✅ Auto-selected input device #{DEVICE_INDEX}: {chosen_device['name']}")

# Stream callback: collects incoming audio while a recording is in progress
def callback(indata, frames, time, status):
    """Store a copy of each incoming audio block while RECORDING is set.

    The signature is the one the input stream calls positionally; note that
    the ``time`` parameter hides the ``time`` module inside this function.
    A truthy ``status`` is reported on stdout; blocks arriving while
    RECORDING is False are dropped.
    """
    if status:
        print(f"⚠️ Audio callback status: {status}")
    if not RECORDING:
        return
    # Copy before storing: the block is kept beyond this call.
    chunk = indata.copy()
    audio_buffer.append(chunk)

# Save audio to WAV file
def save_audio_to_file(audio_data, filename="temp_audio.wav"):
    """Write float audio samples to a 16-bit PCM WAV file.

    Args:
        audio_data: Array-like of float samples, nominally in [-1.0, 1.0].
            Values outside that range are clipped so they saturate instead
            of wrapping around when converted to int16.
        filename: Destination path of the WAV file.

    Returns:
        The ``filename`` that was written.
    """
    # Clip first: a sample of e.g. 1.2 would otherwise overflow int16 and
    # come out as a large negative value (audible click).
    clipped = np.clip(audio_data, -1.0, 1.0)
    audio_int = (clipped * 32767).astype(np.int16)

    with wave.open(filename, 'wb') as wf:
        wf.setnchannels(CHANNELS)
        wf.setsampwidth(2)  # 16-bit
        wf.setframerate(SAMPLE_RATE)
        wf.writeframes(audio_int.tobytes())
    return filename

# Speaker embeddings via the SpeechBrain fallback model
def extract_embeddings(audio_segments):
    """Compute one speaker embedding per audio segment.

    Args:
        audio_segments: Iterable of ``(start, end, samples)`` tuples; only
            the ``samples`` NumPy array of each tuple is used.

    Returns:
        The embeddings stacked row-wise with ``np.vstack``, or ``None`` when
        SpeechBrain is disabled or no segments were supplied.
    """
    if not USE_SPEECHBRAIN:
        return None

    def _embed(samples):
        # Add a leading batch dimension, encode, then return a NumPy array.
        batch = torch.from_numpy(samples).float().unsqueeze(0)
        encoded = speaker_model.encode_batch(batch)
        return encoded.squeeze().detach().cpu().numpy()

    vectors = [_embed(samples) for _, _, samples in audio_segments]
    if not vectors:
        return None
    return np.vstack(vectors)

# Process audio in a separate thread
def _transcribe_audio(audio):
    """Return the stripped Whisper transcript of an audio array."""
    segments_result, _ = whisper_model.transcribe(audio.astype(np.float32))
    return " ".join(s.text for s in segments_result).strip()


def _print_plain_transcript(audio):
    """Transcribe the audio without speaker labels and print the result."""
    transcript = _transcribe_audio(audio)
    if transcript:
        print(f"📝 Transcript: {transcript}")
    else:
        print("⚠️ No speech detected")


def _cluster_speakers(embeddings, num_speakers):
    """Cluster embeddings by cosine distance and return one label per row."""
    distance_matrix = pairwise_distances(embeddings, metric='cosine')
    options = {"n_clusters": num_speakers, "linkage": "average"}
    try:
        # Current scikit-learn spells the precomputed-distance option `metric`.
        clustering = AgglomerativeClustering(metric="precomputed", **options)
    except TypeError:
        # Older scikit-learn releases only accept the former name `affinity`.
        clustering = AgglomerativeClustering(affinity="precomputed", **options)
    return clustering.fit_predict(distance_matrix)


def _run_pyannote_diarization(audio_data):
    """Run the PyAnnote pipeline on the audio and return its speaker turns."""
    # The pipeline is given a file path, so export the audio to a WAV first.
    audio_file = save_audio_to_file(audio_data)
    try:
        print(f"Audio saved to {audio_file}")
        print("Running PyAnnote diarization...")
        with ProgressHook() as hook:
            diarization = diarization_pipeline(audio_file, hook=hook)
    finally:
        # Always remove the temporary WAV, even when diarization fails.
        try:
            os.remove(audio_file)
        except OSError as cleanup_error:
            print(f"⚠️ Could not remove temporary file {audio_file}: {cleanup_error}")

    return [
        {'start': turn.start, 'end': turn.end, 'speaker': speaker}
        for turn, _, speaker in diarization.itertracks(yield_label=True)
    ]


def process_audio_thread(audio_data):
    """Diarize and transcribe one recording, printing the results.

    The audio is peak-normalized, then handled by the first available path:
    PyAnnote diarization (needs more than one second of audio), SpeechBrain
    embeddings clustered into at most two speakers (needs more than one
    window), or plain transcription without speaker labels.

    Args:
        audio_data: 1-D NumPy array of samples recorded at SAMPLE_RATE.

    Returns:
        None. All output goes to stdout; any exception is caught and printed
        with its traceback so the worker thread never dies silently.
    """
    print("🎙️ Processing audio...")

    try:
        # Normalize audio (skip silent input to avoid dividing by zero)
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak
        print(f"🔍 Audio data length: {len(audio_data)}")

        # 1. SEGMENT THE AUDIO into overlapping fixed-length windows
        num_samples = len(audio_data)
        segment_samples = int(SAMPLE_RATE * SEGMENT_DURATION)
        step_samples = int(SAMPLE_RATE * STEP_DURATION)
        segments = [
            (start, start + segment_samples, audio_data[start:start + segment_samples])
            for start in range(0, num_samples - segment_samples + 1, step_samples)
        ]
        print(f"📝 {len(segments)} audio segments generated")

        # The SpeechBrain flag may be missing when the fallback loader never
        # ran; treat "not defined" as "not available".
        use_speechbrain = globals().get("USE_SPEECHBRAIN", False)

        # 2. SPEAKER DIARIZATION
        if USE_PYANNOTE and len(audio_data) > SAMPLE_RATE:  # At least 1 second of audio
            speaker_turns = _run_pyannote_diarization(audio_data)
            print(f"🔖 {len(set(turn['speaker'] for turn in speaker_turns))} speakers detected")

            # 3. TRANSCRIBE AND ASSIGN SPEAKERS
            print("Transcribing with speaker labels...")
            for turn in speaker_turns:
                # Convert the turn boundaries from seconds to sample indices
                start_sample = int(turn['start'] * SAMPLE_RATE)
                end_sample = min(int(turn['end'] * SAMPLE_RATE), len(audio_data))
                if end_sample <= start_sample:
                    continue

                transcript = _transcribe_audio(audio_data[start_sample:end_sample])
                if transcript:
                    print(f"[{turn['start']:.1f}s-{turn['end']:.1f}s] {turn['speaker']}: {transcript}")

        elif use_speechbrain and len(segments) > 1:
            print("Extracting speaker embeddings...")
            embeddings = extract_embeddings(segments)

            if embeddings is not None and len(embeddings) > 1:
                # At most two speakers are distinguished by this fallback.
                num_speakers = min(2, len(embeddings))
                labels = _cluster_speakers(embeddings, num_speakers)
                print(f"🔖 {len(set(labels))} speakers detected")

                # Transcribe each window and tag it with its cluster label
                print("Transcribing with speaker labels...")
                for (start, end, segment_audio), label in zip(segments, labels):
                    transcript = _transcribe_audio(segment_audio)
                    if transcript:
                        speaker_label = f"Speaker {label + 1}"
                        start_time = start / SAMPLE_RATE
                        end_time = end / SAMPLE_RATE
                        print(f"[{start_time:.1f}s-{end_time:.1f}s] {speaker_label}: {transcript}")
            else:
                print("⚠️ Not enough segments for diarization")
                # Fall back to regular transcription
                _print_plain_transcript(audio_data)
        else:
            # Just do regular transcription without diarization
            print("Starting transcription (without speaker diarization)...")
            _print_plain_transcript(audio_data)

    except Exception as e:
        # Top-level boundary of the worker thread: report and carry on.
        import traceback
        print(f"⚠️ Error in process_audio: {str(e)}")
        print(traceback.format_exc())

    print("✅ Audio processing complete")

# Keyboard handling
def on_press(key):
    """Handle a key press: ESC requests shutdown, SPACE starts a recording.

    Returns:
        ``False`` for ESC, which stops the keyboard listener; ``None`` for
        every other key.
    """
    global RECORDING, recording_start_time, exit_flag

    if key == keyboard.Key.esc:
        exit_flag = True
        return False  # Stop listener

    is_space = key == keyboard.Key.space
    if not is_space or RECORDING:
        # Unrelated key, or a repeated press while already recording.
        return None

    RECORDING = True
    recording_start_time = time.time()
    audio_buffer.clear()
    print("🎤 Recording started...")
    return None

def on_release(key):
    """Handle a key release: releasing SPACE ends the current recording.

    A take shorter than MIN_RECORD_TIME is discarded; otherwise the buffered
    chunks are concatenated and handed to ``process_audio_thread`` on a
    daemon thread. Releases of other keys, or of SPACE while not recording,
    are ignored.
    """
    global RECORDING, recording_start_time, audio_buffer, processing_thread
    if not (key == keyboard.Key.space and RECORDING):
        return

    # Stop capturing as soon as the key is up, whatever the take's length.
    # Leaving RECORDING set on a short press would keep the callback
    # buffering audio after the user has let go of the key.
    RECORDING = False
    elapsed = time.time() - recording_start_time
    if elapsed < MIN_RECORD_TIME:
        audio_buffer = []  # discard the too-short take
        print(f"⏳ Hold space for at least {MIN_RECORD_TIME} sec to record.")
        return

    print("🛑 Recording stopped.")

    if not audio_buffer:
        print("⚠️ No audio recorded.")
        return

    # Concatenate the buffered chunks into one flat sample array
    audio_data = np.concatenate(audio_buffer, axis=0).flatten()
    audio_buffer = []

    # Only one processing job runs at a time. NOTE: this join blocks the
    # keyboard listener thread until the previous job has finished.
    if processing_thread and processing_thread.is_alive():
        print("⚠️ Previous processing still running, please wait...")
        processing_thread.join()

    processing_thread = threading.Thread(target=process_audio_thread, args=(audio_data,))
    processing_thread.daemon = True
    processing_thread.start()

# Cleanup and signal handling
def cleanup():
    """Release the audio stream and give the worker thread a moment to end.

    Safe to call before the stream exists and safe to call more than once.
    Errors during cleanup are printed rather than raised.
    """
    print("\n🛑 Exiting... Cleaning up.")
    try:
        # `stream` only exists once the main section has created it.
        open_stream = globals().get('stream')
        # Close any stream that is still open, not only an active one, so a
        # stream that was opened but never started (or already stopped) is
        # released as well.
        if open_stream is not None and not open_stream.closed:
            open_stream.close()
        if processing_thread and processing_thread.is_alive():
            processing_thread.join(timeout=1.0)  # Wait for processing to finish with timeout
    except Exception as e:
        print(f"⚠️ Cleanup error: {e}")
    print("✅ Goodbye!")

def signal_handler(sig, frame):
    """SIGINT handler: flag shutdown, clean up, and end the process.

    Args:
        sig: Signal number (unused).
        frame: Current stack frame (unused).

    Raises:
        SystemExit: Always, with exit status 0.
    """
    global exit_flag
    exit_flag = True
    cleanup()
    # Raise SystemExit directly: the `exit` helper is injected by the `site`
    # module and is not guaranteed to exist in every interpreter setup.
    raise SystemExit(0)


signal.signal(signal.SIGINT, signal_handler)

# Main
if __name__ == "__main__":
    # Open the microphone stream. `stream` stays a module-level name because
    # cleanup() looks it up in globals().
    stream = sd.InputStream(
        device=DEVICE_INDEX,
        samplerate=SAMPLE_RATE,
        channels=CHANNELS,
        dtype=np.float32,
        blocksize=8000,  # frames handed to the callback per call
        callback=callback,
    )

    print("🎙️ Hold [SPACE] to record, release to process. Press [ESC] to exit.")
    print("🔴 [CTRL+C] exits safely.")

    try:
        stream.start()

        listener = keyboard.Listener(on_press=on_press, on_release=on_release)
        with listener:
            # Poll until ESC / SIGINT sets exit_flag or the listener stops.
            while True:
                if exit_flag or not listener.running:
                    break
                time.sleep(0.1)

        cleanup()
    except Exception as e:
        print(f"⚠️ Error in main: {e}")
        cleanup()



ModuleNotFoundError: No module named 'sounddevice'