<a href="https://colab.research.google.com/github/beruscoder/4-7audioseperation/blob/main/huggingfacewavlm_GMM_fixed.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive (Colab only) so the dataset directories under
# /content/drive/MyDrive used further below are readable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:

# Install / upgrade dependencies via IPython `!` shell escapes (only valid inside Jupyter/Colab).
# NOTE(review): upgrading torch/torchaudio over a preinstalled Colab build may require a
# runtime restart before the new version is importable — confirm in your environment.
!pip install torch torchvision torchaudio --upgrade
!pip install transformers soundfile librosa matplotlib scikit-learn tqdm pesq scipy sentencepiece

# 2. Speaker separation pipeline

import os
import torch
import numpy as np
import soundfile as sf
import librosa
import matplotlib.pyplot as plt
from scipy.signal import stft, istft
from sklearn.cluster import SpectralClustering
from sklearn.mixture import GaussianMixture
from tqdm import tqdm
import json
import warnings

from transformers import WavLMModel, Wav2Vec2FeatureExtractor

# Silence every Python warning for the rest of the session (presumably to keep notebook output short).
# NOTE(review): this also hides diagnostics from sklearn/librosa (e.g. clustering or GMM
# convergence problems); consider narrowing the filter while debugging separation quality.
warnings.filterwarnings('ignore')


# Configuration

class Config:
    """Pipeline-wide constants, read as class attributes (``Config.SAMPLE_RATE``)."""

    # Audio: all signals are resampled to this rate (Hz) on load.
    SAMPLE_RATE = 16000

    # STFT geometry, in samples.
    N_FFT = 1024
    WIN_LENGTH = 1024
    HOP_LENGTH = 256

    # Speaker / embedding settings.
    MAX_SPEAKERS = 7
    EMBEDDING_DIM = 1024
    # NOTE(review): EMBEDDING_DIM and CHUNK_SIZE are not read by the visible code;
    # CHUNK_SIZE is presumably seconds — confirm before relying on it.
    CHUNK_SIZE = 4.0

    # Compute device: first CUDA GPU when available, otherwise CPU.
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# HuggingFace WavLM Embedder

class WavLMEmbedder:
    """Clip-level embeddings from a pretrained HuggingFace WavLM model.

    The model is loaded once, moved to ``Config.DEVICE`` and switched to eval mode.

    Args:
        model_name: HuggingFace model id or local path passed to ``from_pretrained``.
    """

    def __init__(self, model_name="microsoft/wavlm-large"):
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
        self.model = WavLMModel.from_pretrained(model_name)
        self.model.to(Config.DEVICE)
        self.model.eval()
        # Exposed so callers (e.g. SpeakerClustering) can read SAMPLE_RATE from the embedder.
        self.config = Config

    def extract_embeddings(self, audio):
        """Embed ``audio`` as the time-average of WavLM's last hidden layer.

        Args:
            audio: mono waveform sampled at ``Config.SAMPLE_RATE``; a NumPy array or a
                torch tensor, either 1-D ``(time,)`` or a batch ``(batch, time)``.

        Returns:
            NumPy array of shape ``(hidden_size,)`` for a single clip, or
            ``(batch, hidden_size)`` for a batch (size-1 axes are squeezed).
        """
        # The feature extractor recognises batches only as 2-D NumPy arrays or lists of
        # arrays. A torch tensor of shape (1, time) would be wrapped as a single item and
        # come back as 3-D ``input_values``, which WavLM's convolutional feature encoder
        # rejects. Hand it plain float32 NumPy (1-D = one clip, 2-D = batch) instead.
        if isinstance(audio, torch.Tensor):
            audio = audio.detach().cpu().numpy()
        audio = np.asarray(audio, dtype=np.float32)

        with torch.no_grad():
            inputs = self.feature_extractor(
                audio,
                sampling_rate=Config.SAMPLE_RATE,
                return_tensors="pt",
                padding=True,
            )
            input_values = inputs.input_values.to(Config.DEVICE)
            attention_mask = (
                inputs.attention_mask.to(Config.DEVICE) if "attention_mask" in inputs else None
            )

            outputs = self.model(input_values, attention_mask=attention_mask, output_hidden_states=True)
            # Last layer's frame features, averaged over time -> one vector per clip.
            hidden_states = outputs.hidden_states[-1]
            embeddings = hidden_states.mean(dim=1)
            return embeddings.cpu().numpy().squeeze()


# Audio Processing Utilities

class AdvancedAudioProcessor:
    """Audio loading plus forward/inverse STFT using the sizes defined in ``Config``."""

    def __init__(self):
        self.config = Config()

    def _stft_params(self):
        """Keyword arguments shared by ``scipy.signal.stft`` and ``istft``."""
        cfg = self.config
        return {
            "fs": cfg.SAMPLE_RATE,
            "nperseg": cfg.WIN_LENGTH,
            "noverlap": cfg.WIN_LENGTH - cfg.HOP_LENGTH,
            "nfft": cfg.N_FFT,
        }

    def load_audio(self, path):
        """Read ``path`` as mono float32 at ``Config.SAMPLE_RATE``, peak-normalised to ~1.

        Multi-channel files keep only their first channel. Any failure while reading or
        resampling is printed and ``None`` is returned.
        """
        try:
            waveform, file_sr = sf.read(path)
            if waveform.ndim > 1:
                waveform = waveform[:, 0]
            target_sr = self.config.SAMPLE_RATE
            if file_sr != target_sr:
                waveform = librosa.resample(waveform, orig_sr=file_sr, target_sr=target_sr)
            peak = np.max(np.abs(waveform))
            # The epsilon keeps an all-silent file from dividing by zero.
            return (waveform / (peak + 1e-7)).astype(np.float32)
        except Exception as e:
            print(f"Error loading {path}: {e}")
            return None

    def compute_spectrogram(self, audio):
        """Return ``(magnitude, phase, freqs, times)`` of the STFT of ``audio``."""
        freqs, times, spec = stft(audio, **self._stft_params())
        return np.abs(spec), np.angle(spec), freqs, times

    def reconstruct_audio(self, magnitude, phase):
        """Inverse STFT of ``magnitude * exp(1j * phase)``; returns the time-domain signal."""
        _, waveform = istft(magnitude * np.exp(1j * phase), **self._stft_params())
        return waveform


# Speaker Clustering

class SpeakerClustering:
    """Speaker-count estimation and frame-level speaker clustering.

    Args:
        embedder: object with ``extract_embeddings(audio)`` returning a NumPy vector and a
            ``config`` exposing ``SAMPLE_RATE`` (e.g. ``WavLMEmbedder``).
    """

    def __init__(self, embedder):
        self.embedder = embedder

    def estimate_speakers(self, audio, max_speakers=7):
        """Estimate how many speakers are present in ``audio``.

        The signal is cut into 2-second chunks with 50% overlap and each chunk is embedded.
        Diagonal-covariance Gaussian mixtures with 2..max_speakers components are fitted to
        the chunk embeddings, and the component count with the lowest BIC is returned.
        (Training log-likelihood alone nearly always grows with more components, so it
        would systematically favour ``max_speakers``.)

        Args:
            audio: 1-D waveform at ``embedder.config.SAMPLE_RATE``.
            max_speakers: largest speaker count considered.

        Returns:
            Estimated speaker count; 2 when no mixture could be fitted (for example when
            the audio yields fewer than two chunks).
        """
        fallback = 2
        chunk_length = int(self.embedder.config.SAMPLE_RATE * 2)  # 2 sec
        hop = chunk_length // 2
        # "+ 1" so the final chunk ending exactly at len(audio) is included.
        embeddings = [
            np.asarray(self.embedder.extract_embeddings(audio[i:i + chunk_length])).flatten()
            for i in range(0, len(audio) - chunk_length + 1, hop)
        ]
        if len(embeddings) < 2:
            return fallback
        embeddings = np.array(embeddings)

        best_bic = np.inf
        best_n_speakers = fallback
        # A mixture cannot have more components than samples.
        for n in range(2, min(max_speakers, len(embeddings)) + 1):
            # Full covariance is singular with ~1024-D embeddings and a handful of chunks.
            gmm = GaussianMixture(n_components=n, covariance_type="diag", random_state=42)
            try:
                gmm.fit(embeddings)
            except ValueError:
                # sklearn raises ValueError for ill-conditioned fits; skip this count.
                continue
            bic = gmm.bic(embeddings)
            if bic < best_bic:
                best_bic = bic
                best_n_speakers = n
        return best_n_speakers

    def cluster_frames(self, spectrogram, n_speakers):
        """Assign every STFT frame to one of ``n_speakers`` clusters.

        Each magnitude column is projected onto 40 mel bands (one vectorised librosa call
        for all frames) and the per-frame feature vectors are spectrally clustered.

        Args:
            spectrogram: magnitude spectrogram, ``(freq_bins, time_frames)``.
            n_speakers: number of clusters.

        Returns:
            Integer label per frame, shape ``(time_frames,)``.
        """
        mel = librosa.feature.melspectrogram(
            S=spectrogram,
            sr=self.embedder.config.SAMPLE_RATE,
            n_mels=40,
        )
        features = mel.T  # (time_frames, n_mels)
        # NOTE(review): the rbf affinity uses sklearn's default gamma on unscaled mel
        # energies, so results are sensitive to the STFT amplitude scale; consider
        # log-compression / standardisation of ``features``.
        clustering = SpectralClustering(
            n_clusters=n_speakers,
            random_state=42,
            affinity='rbf'
        )
        return clustering.fit_predict(features)


# Mask Generator

class MaskGenerator:
    """Builds per-speaker soft time-frequency masks from frame-level speaker labels."""

    def __init__(self):
        # Not read by generate_soft_masks; kept so the instance carries the shared config.
        self.config = Config()

    def generate_soft_masks(self, spectrogram, labels, n_speakers):
        """Turn per-frame labels into smoothed masks that sum to ~1 at every T-F bin.

        Args:
            spectrogram: 2-D array; only its ``(freq_bins, time_frames)`` shape is used.
            labels: speaker index for each frame (at least ``time_frames`` entries).
            n_speakers: number of masks to produce.

        Returns:
            Float array of shape ``(n_speakers, freq_bins, time_frames)``.
        """
        from scipy.ndimage import gaussian_filter

        n_freq, n_frames = spectrogram.shape
        frames = np.arange(n_frames)
        owner = np.asarray(labels)[frames]

        soft = np.zeros((n_speakers, n_freq, n_frames))
        # One-hot: every frequency bin of frame t belongs to speaker owner[t].
        soft[owner, :, frames] = 1.0

        # Blur each speaker's mask across frequency and time (sigma = 1 bin / frame).
        for k in range(n_speakers):
            soft[k] = gaussian_filter(soft[k], sigma=1.0)

        # Renormalise so the masks sum to ~1 at every bin.
        total = soft.sum(axis=0, keepdims=True)
        return soft / (total + 1e-7)

# Main Separation Pipeline

class ModernSpeakerSeparator:
    """End-to-end pipeline: load -> STFT -> cluster frames -> soft masks -> inverse STFT."""

    def __init__(self):
        self.audio_processor = AdvancedAudioProcessor()
        self.embedder = WavLMEmbedder()
        self.clustering = SpeakerClustering(self.embedder)
        self.mask_generator = MaskGenerator()

    def separate_speakers(self, audio_path, known_speakers=None):
        """Separate the mixture stored at ``audio_path`` into per-speaker signals.

        Args:
            audio_path: path to the mixture audio file.
            known_speakers: number of speakers to extract; when ``None`` it is estimated
                with ``SpeakerClustering.estimate_speakers``.

        Returns:
            ``(mixture, separated_audio, n_speakers)``, where ``separated_audio`` is a list
            of ``n_speakers`` waveforms, each exactly ``len(mixture)`` samples long, or
            ``(None, None, 0)`` if the mixture could not be loaded.
        """
        print(f"Processing: {audio_path}")
        mixture = self.audio_processor.load_audio(audio_path)
        if mixture is None:
            return None, None, 0
        magnitude, phase, _, _ = self.audio_processor.compute_spectrogram(mixture)
        if known_speakers is None:
            n_speakers = self.clustering.estimate_speakers(mixture)
        else:
            n_speakers = known_speakers
        print(f"Estimated {n_speakers} speakers")
        labels = self.clustering.cluster_frames(magnitude, n_speakers)
        masks = self.mask_generator.generate_soft_masks(magnitude, labels, n_speakers)

        n_samples = len(mixture)
        separated_audio = []
        for i in range(n_speakers):
            masked_magnitude = magnitude * masks[i]
            separated = self.audio_processor.reconstruct_audio(masked_magnitude, phase)
            # The inverse STFT's length follows the STFT framing, not the input length;
            # align it so every output lines up sample-for-sample with the mixture.
            separated_audio.append(self._match_length(separated, n_samples))
        return mixture, separated_audio, n_speakers

    @staticmethod
    def _match_length(signal, n_samples):
        """Trim or zero-pad ``signal`` so it has exactly ``n_samples`` samples."""
        if len(signal) >= n_samples:
            return signal[:n_samples]
        return np.pad(signal, (0, n_samples - len(signal)))

# Evaluation

class ImprovedEvaluator:
    """Separation metrics: SI-SDR, wide-band PESQ, and permutation-solved evaluation."""

    @staticmethod
    def calculate_si_sdr(reference, estimate):
        """Scale-invariant SDR (dB) of ``estimate`` with respect to ``reference``.

        Both 1-D signals are mean-removed, ``reference`` is optimally scaled onto
        ``estimate``, and the target-to-residual energy ratio is returned in dB.

        Small epsilons keep the result finite: a silent reference or a silent estimate
        scores about -70 dB instead of NaN / -inf. NaN or inf entries can otherwise make
        the Hungarian assignment in ``evaluate_separation`` fail.
        """
        eps = 1e-7
        reference = reference - np.mean(reference)
        estimate = estimate - np.mean(estimate)
        alpha = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)
        scaled_reference = alpha * reference
        noise = estimate - scaled_reference
        ratio = np.sum(scaled_reference ** 2) / (np.sum(noise ** 2) + eps)
        return 10 * np.log10(ratio + eps)

    @staticmethod
    def calculate_pesq(reference, estimate, sr=16000):
        """Wide-band PESQ of ``estimate`` against ``reference``.

        Returns:
            The PESQ score, or ``None`` when the ``pesq`` package is not installed or it
            fails to score this pair.
        """
        try:
            from pesq import pesq
        except ImportError:
            print("PESQ library not available. Install with: pip install pesq")
            return None
        try:
            return pesq(sr, reference, estimate, 'wb')
        except Exception as e:
            # The pesq package may raise on inputs it cannot score (presumably e.g. silence);
            # treat the metric as unavailable instead of aborting the whole evaluation.
            print(f"PESQ computation failed: {e}")
            return None

    @staticmethod
    def _load_reference(path, sr):
        """Read a ground-truth file as mono audio at ``sr`` (first channel, resampled if needed)."""
        audio, file_sr = sf.read(path)
        if audio.ndim > 1:
            audio = audio[:, 0]
        if file_sr != sr:
            audio = librosa.resample(audio, orig_sr=file_sr, target_sr=sr)
        return audio

    @staticmethod
    def evaluate_separation(separated_audio, ground_truth_paths, sr=16000):
        """Score separated signals against ground-truth files under the best permutation.

        An SI-SDR matrix (estimates x references) is computed on overlapping lengths, the
        assignment maximising total SI-SDR is found with the Hungarian algorithm, and
        SI-SDR plus PESQ are recorded for each assigned pair.

        Args:
            separated_audio: list of 1-D estimated source signals at ``sr``.
            ground_truth_paths: paths to the reference source files; they are converted to
                mono and resampled to ``sr`` so they are comparable with the estimates.
            sr: sample rate of ``separated_audio`` (Hz).

        Returns:
            Dict with lists ``'si_sdr_scores'``, ``'pesq_scores'`` (only successful scores)
            and ``'assignments'`` (``(estimate_index, reference_index)`` pairs).
        """
        from scipy.optimize import linear_sum_assignment

        results = {
            'si_sdr_scores': [],
            'pesq_scores': [],
            'assignments': []
        }
        gt_sources = [ImprovedEvaluator._load_reference(path, sr) for path in ground_truth_paths]
        n_sources = len(gt_sources)
        n_separated = len(separated_audio)
        si_sdr_matrix = np.zeros((n_separated, n_sources))
        for i, sep_audio in enumerate(separated_audio):
            for j, gt_audio in enumerate(gt_sources):
                min_len = min(len(sep_audio), len(gt_audio))
                si_sdr_matrix[i, j] = ImprovedEvaluator.calculate_si_sdr(
                    gt_audio[:min_len],
                    sep_audio[:min_len]
                )
        # Negate: linear_sum_assignment minimises cost, we want maximal SI-SDR.
        row_ind, col_ind = linear_sum_assignment(-si_sdr_matrix)
        for i, j in zip(row_ind, col_ind):
            results['si_sdr_scores'].append(si_sdr_matrix[i, j])
            results['assignments'].append((i, j))
            min_len = min(len(separated_audio[i]), len(gt_sources[j]))
            pesq_score = ImprovedEvaluator.calculate_pesq(
                gt_sources[j][:min_len],
                separated_audio[i][:min_len],
                sr=sr,
            )
            if pesq_score is not None:
                results['pesq_scores'].append(pesq_score)
        return results

# ====================
# Visualization
# ====================
def create_visualization(mixture, separated, save_dir, mix_id):
    """Save stacked waveform plots of the mixture and every separated source.

    The mixture is drawn in the top panel and each separated signal below it, one panel
    per source. The figure is written to ``save_dir/separation_results.png`` at 150 dpi
    and then closed.
    """
    n_rows = len(separated) + 1
    fig = plt.figure(figsize=(15, 3 * n_rows))

    panels = [(f"Original Mixture - Mix {mix_id}", mixture)]
    panels += [(f"Separated Speaker {i}", source) for i, source in enumerate(separated)]

    for row, (title, waveform) in enumerate(panels, start=1):
        ax = fig.add_subplot(n_rows, 1, row)
        ax.plot(waveform)
        ax.set_title(title)
        ax.set_ylabel("Amplitude")
    # Only the bottom panel carries the shared x-axis label.
    ax.set_xlabel("Time (samples)")

    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, "separation_results.png"), dpi=150)
    plt.close(fig)

# Main Dataset Processing Loop

def _collect_source_paths(sources_dir, mix_id):
    """Return existing ``{mix_id}_spk{k}.wav`` paths for k = 0, 1, 2, ... in order.

    Scanning stops at the first missing index, so the result is a contiguous run
    starting at ``spk0`` (possibly empty).
    """
    found = []
    while True:
        candidate = os.path.join(sources_dir, f"{mix_id}_spk{len(found)}.wav")
        if not os.path.exists(candidate):
            return found
        found.append(candidate)


def _print_summary(results):
    """Print aggregate SI-SDR statistics over all processed mixtures."""
    per_mix = [entry['avg_si_sdr'] for entry in results]
    print("\nFinal Results:")
    print(f"Processed mixtures: {len(results)}")
    print(f"Average SI-SDR: {np.mean(per_mix):.2f} dB")
    print(f"Median SI-SDR: {np.median(per_mix):.2f} dB")
    print(f"Max SI-SDR: {np.max(per_mix):.2f} dB")
    print(f"Min SI-SDR: {np.min(per_mix):.2f} dB")


def process_dataset(mixtures_dir, sources_dir, results_dir, start_id=14999, end_id=14990):
    """Separate and evaluate mixtures ``start_id`` down to ``end_id`` (inclusive).

    For each ``mix_{id}.wav`` that has at least one ``{id}_spk{k}.wav`` reference, the
    mixture is separated using the reference count as the known speaker count, then
    scored. The mixture, separated sources and a waveform plot are written to
    ``results_dir/<id>/``. Per-mixture scores are collected in ``results_dir/results.json``.
    Errors for one mixture are printed and that mixture is skipped.

    Returns:
        List of per-mixture result dicts (also written to ``results.json``).
    """
    separator = ModernSpeakerSeparator()
    os.makedirs(results_dir, exist_ok=True)
    results = []
    mix_ids = range(start_id, end_id - 1, -1)  # descending, end_id inclusive
    for mix_id in tqdm(mix_ids, desc="Processing mixtures"):
        try:
            mix_path = os.path.join(mixtures_dir, f"mix_{mix_id}.wav")
            if not os.path.exists(mix_path):
                continue
            gt_paths = _collect_source_paths(sources_dir, mix_id)
            if not gt_paths:
                continue
            mixture, separated, n_speakers = separator.separate_speakers(
                mix_path, known_speakers=len(gt_paths)
            )
            if mixture is None:
                continue

            scores = ImprovedEvaluator.evaluate_separation(separated, gt_paths)
            si_sdrs = scores['si_sdr_scores']
            pesqs = scores['pesq_scores']
            mix_results = {
                'mix_id': mix_id,
                'n_speakers': n_speakers,
                'si_sdr_scores': si_sdrs,
                'avg_si_sdr': np.mean(si_sdrs),
                'pesq_scores': pesqs,
                'avg_pesq': np.mean(pesqs) if pesqs else None,
            }
            results.append(mix_results)

            mix_dir = os.path.join(results_dir, str(mix_id))
            os.makedirs(mix_dir, exist_ok=True)
            sf.write(os.path.join(mix_dir, "mixture.wav"), mixture, Config.SAMPLE_RATE)
            for idx, source in enumerate(separated):
                sf.write(os.path.join(mix_dir, f"separated_{idx}.wav"), source, Config.SAMPLE_RATE)
            create_visualization(mixture, separated, mix_dir, mix_id)
            print(f"Mix {mix_id}: {n_speakers} speakers, SI-SDR = {mix_results['avg_si_sdr']:.2f} dB")
        except Exception as e:
            # Batch boundary: report and move on to the next mixture.
            print(f"Error processing mix {mix_id}: {str(e)}")

    with open(os.path.join(results_dir, "results.json"), "w") as f:
        json.dump(results, f, indent=2)
    if results:
        _print_summary(results)
    return results


# Input dataset (on the mounted Google Drive) and local output folder.
MIXTURES_DIR = "/content/drive/MyDrive/mixed_data_4to7/mixtures"
SOURCES_DIR = "/content/drive/MyDrive/mixed_data_4to7/sources"
RESULTS_DIR = "/content/results"

# Evaluate mixtures 14999 down to 14990 (inclusive); process_dataset uses the number of
# reference files as the known speaker count.
results = process_dataset(MIXTURES_DIR, SOURCES_DIR, RESULTS_DIR, start_id=14999, end_id=14990)
print("Processing complete!")

Collecting torch
  Downloading torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting torchvision
  Downloading torchvision-0.22.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading torchaudio-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.6 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.6.80 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch)
  Downloading nvidia_cu

preprocessor_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/2.22k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.26G [00:00<?, ?B/s]


Processing mixtures:   0%|          | 0/10 [00:00<?, ?it/s][A
Processing mixtures:  10%|█         | 1/10 [03:06<27:56, 186.25s/it][A
Processing mixtures:  20%|██        | 2/10 [05:06<19:38, 147.30s/it][A

Processing: /content/drive/MyDrive/mixed_data_4to7/mixtures/mix_14997.wav
Estimated 6 speakers



Processing mixtures:  30%|███       | 3/10 [06:25<13:32, 116.01s/it][A

Mix 14997: 6 speakers, SI-SDR = -30.59 dB
Processing: /content/drive/MyDrive/mixed_data_4to7/mixtures/mix_14996.wav
Estimated 6 speakers



Processing mixtures:  40%|████      | 4/10 [06:38<07:32, 75.41s/it] [A

Mix 14996: 6 speakers, SI-SDR = -40.04 dB
Processing: /content/drive/MyDrive/mixed_data_4to7/mixtures/mix_14995.wav
Estimated 7 speakers



Processing mixtures:  50%|█████     | 5/10 [06:51<04:24, 52.90s/it][A

Mix 14995: 7 speakers, SI-SDR = -30.69 dB
Processing: /content/drive/MyDrive/mixed_data_4to7/mixtures/mix_14994.wav
Estimated 7 speakers



Processing mixtures:  60%|██████    | 6/10 [07:03<02:36, 39.16s/it][A

Mix 14994: 7 speakers, SI-SDR = -31.69 dB
Processing: /content/drive/MyDrive/mixed_data_4to7/mixtures/mix_14993.wav
Estimated 7 speakers



Processing mixtures:  70%|███████   | 7/10 [07:16<01:31, 30.42s/it][A

Mix 14993: 7 speakers, SI-SDR = -36.02 dB
Processing: /content/drive/MyDrive/mixed_data_4to7/mixtures/mix_14992.wav
Estimated 5 speakers



Processing mixtures:  80%|████████  | 8/10 [07:24<00:46, 23.30s/it][A

Mix 14992: 5 speakers, SI-SDR = -33.50 dB
Processing: /content/drive/MyDrive/mixed_data_4to7/mixtures/mix_14991.wav
Estimated 7 speakers



Processing mixtures:  90%|█████████ | 9/10 [07:37<00:20, 20.22s/it][A

Mix 14991: 7 speakers, SI-SDR = -29.24 dB
Processing: /content/drive/MyDrive/mixed_data_4to7/mixtures/mix_14990.wav
Estimated 4 speakers



Processing mixtures: 100%|██████████| 10/10 [07:44<00:00, 46.47s/it]

Mix 14990: 4 speakers, SI-SDR = -33.98 dB

Final Results:
Processed mixtures: 8
Average SI-SDR: -33.22 dB
Median SI-SDR: -32.59 dB
Max SI-SDR: -29.24 dB
Min SI-SDR: -40.04 dB
Processing complete!



