<a href="https://colab.research.google.com/github/martingkc/voice_extraction/blob/main/voice_extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install  pyannote.audio
!pip install git+https://github.com/openai/whisper.git

Collecting pyannote.audio
  Downloading pyannote.audio-3.3.2-py2.py3-none-any.whl.metadata (11 kB)
Collecting asteroid-filterbanks>=0.4 (from pyannote.audio)
  Downloading asteroid_filterbanks-0.4.0-py3-none-any.whl.metadata (3.3 kB)
Collecting lightning>=2.0.1 (from pyannote.audio)
  Downloading lightning-2.5.2-py3-none-any.whl.metadata (38 kB)
Collecting pyannote.core>=5.0.0 (from pyannote.audio)
  Downloading pyannote.core-5.0.0-py3-none-any.whl.metadata (1.4 kB)
Collecting pyannote.database>=5.0.1 (from pyannote.audio)
  Downloading pyannote.database-5.1.3-py3-none-any.whl.metadata (1.1 kB)
Collecting pyannote.metrics>=3.2 (from pyannote.audio)
  Downloading pyannote.metrics-3.2.1-py3-none-any.whl.metadata (1.3 kB)
Collecting pyannote.pipeline>=3.0.1 (from pyannote.audio)
  Downloading pyannote.pipeline-3.0.1-py3-none-any.whl.metadata (897 bytes)
Collecting pytorch-metric-learning>=2.1.0 (from pyannote.audio)
  Downloading pytorch_metric_learning-2.8.1-py3-none-any.whl.metadata (18

In [11]:
import os
import tempfile
from pathlib import Path
from transformers import WhisperProcessor, WhisperModel, Wav2Vec2Model, Wav2Vec2Processor
import torch
import whisper
from pyannote.audio import Pipeline
from pydub import AudioSegment
from tqdm import tqdm
from numpy import dot
from numpy.linalg import norm
import soundfile as sf
import random
from collections import defaultdict
from google.colab import userdata
import librosa




In [12]:
AUDIO_FILE = "/content/output2.wav"      # input audio
SPEAKER_AUDIO_FILE = "/content/fatih_solo_final_0133_3088p679s_to_3154p627s_score_0.069.wav"
EMBEDDING_MODEL = "facebook/wav2vec2-base-960h"
DIAR_MODEL = "pyannote/speaker-diarization-3.1"  # pretrained pipeline
WHISPER_MODEL = "medium"
LANGUAGE = "tr"
OUTPUT_DIR = Path("speaker_transcripts")
SAMPLE_RATE = 16_000
OUTPUT_DIR.mkdir(exist_ok=True)


In [4]:
print("Loading diarization pipeline…")
diar_pipeline = Pipeline.from_pretrained(DIAR_MODEL, use_auth_token = userdata.get('HF_TOKEN'))

print("Loading Whisper model…")
device = "cuda" if torch.cuda.is_available() else "cpu"
whisper_model = whisper.load_model(WHISPER_MODEL, device=device)
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base")




Loading diarization pipeline…


config.yaml:   0%|          | 0.00/469 [00:00<?, ?B/s]

DEBUG:speechbrain.utils.checkpoints:Registered checkpoint save hook for _speechbrain_save
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint load hook for _speechbrain_load
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint save hook for save
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint load hook for load
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint save hook for _save
DEBUG:speechbrain.utils.checkpoints:Registered checkpoint load hook for _recover


pytorch_model.bin:   0%|          | 0.00/5.91M [00:00<?, ?B/s]

config.yaml:   0%|          | 0.00/399 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/26.6M [00:00<?, ?B/s]

config.yaml:   0%|          | 0.00/221 [00:00<?, ?B/s]

Loading Whisper model…


100%|█████████████████████████████████████| 1.42G/1.42G [00:25<00:00, 59.5MiB/s]


preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

normalizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

In [13]:
wav2vec_model = Wav2Vec2Model.from_pretrained(EMBEDDING_MODEL)
w2v2_processor = Wav2Vec2Processor.from_pretrained(EMBEDDING_MODEL)

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/378M [00:00<?, ?B/s]

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


preprocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/163 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

In [22]:
def get_embeddings(path):
  audio, sr = librosa.load(path, sr=SAMPLE_RATE)
  input = w2v2_processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)

  with torch.no_grad():
      features = wav2vec_model(**input).last_hidden_state
      emb = features.mean(dim=1).squeeze(0)
  return emb

def cosine_similarity(a, b):
    return dot(a, b) / (norm(a) * norm(b))



def select_speaker(path_speaker, segments, audio_file):
    ref_emb = get_embeddings(path_speaker).detach().cpu().numpy()

    audio = AudioSegment.from_file(audio_file)

    max_sims = {}
    for speaker_id, seg_list in segments.items():
        sims = []
        segment = random.choice(seg_list)
        t0, t1 = segment
        clip = audio[int(t0*1000):int(t1*1000)]

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
            clip.export(tmp.name, format="wav")
            seg_emb = get_embeddings(tmp.name).detach().cpu().numpy()


        sim = cosine_similarity(ref_emb, seg_emb)
        max_sims[speaker_id] = sim

    best_speaker = max(max_sims, key=max_sims.get)
    return best_speaker



In [6]:
print("Running diarization…")
diarization = diar_pipeline(AUDIO_FILE)

Running diarization…


  std = sequences.std(dim=-1, correction=1)


In [7]:
segments_dict = defaultdict(list)
for turn, _, speaker in diarization.itertracks(yield_label=True):
    segments_dict[int(speaker.split("_")[-1])].append((turn.start, turn.end))

print(len(segments_dict))


2


In [24]:
target_speaker = select_speaker(SPEAKER_AUDIO_FILE, segments_dict, AUDIO_FILE)

segments =list(segments_dict[target_speaker])
print(f"Target speaker: {target_speaker}")

Target speaker: 1


In [25]:

from datasets import Dataset, Features, Audio, Value
audio = AudioSegment.from_file(AUDIO_FILE)
ds = {"id":[], "text":[], "audio":[], "speaker":[]}


In [27]:
features = Features(
    {
        "id": Value("string"),
        "audio": Audio(sampling_rate=SAMPLE_RATE),
        "text": Value("string"),
        "speaker": Value("string"),
    }
)

In [26]:
for idx, (t0, t1) in enumerate(tqdm(segments, desc="Transcribing")):
    # pydub works in ms
    seg = audio[int(t0 * 1000) : int(t1 * 1000)]
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        seg.export(tmp.name, format="wav")
        result = whisper_model.transcribe(tmp.name, language=LANGUAGE, verbose=False)
        text = result["text"].strip()
        ds["text"].append(text)
        ds["audio"].append(tmp.name)
        ds["speaker"].append("1")




Transcribing:   0%|          | 0/82 [00:00<?, ?it/s]
  0%|          | 0/521 [00:00<?, ?frames/s][A
100%|██████████| 521/521 [00:01<00:00, 293.30frames/s]
Transcribing:   1%|          | 1/82 [00:02<02:49,  2.09s/it]
  0%|          | 0/756 [00:00<?, ?frames/s][A
100%|██████████| 756/756 [00:02<00:00, 318.26frames/s]
Transcribing:   2%|▏         | 2/82 [00:04<03:07,  2.34s/it]
  0%|          | 0/3 [00:00<?, ?frames/s][A
100%|██████████| 3/3 [00:00<00:00,  7.77frames/s]
Transcribing:   4%|▎         | 3/82 [00:05<01:58,  1.50s/it]
  0%|          | 0/374 [00:00<?, ?frames/s][A
100%|██████████| 374/374 [00:00<00:00, 594.06frames/s]
Transcribing:   5%|▍         | 4/82 [00:05<01:34,  1.21s/it]
  0%|          | 0/302 [00:00<?, ?frames/s][A
100%|██████████| 302/302 [00:00<00:00, 340.60frames/s]
Transcribing:   6%|▌         | 5/82 [00:06<01:27,  1.14s/it]
  0%|          | 0/33 [00:00<?, ?frames/s][A
100%|██████████| 33/33 [00:00<00:00, 42.73frames/s]
Transcribing:   7%|▋         | 6/82 [00:0

KeyError: 'text'

In [28]:
ds["id"] = list(range(len(ds["text"])))

dataset = Dataset.from_dict(ds, features=features)
dataset.push_to_hub("Martingkc/processed_audio_tr", private=True)

Map:   0%|          | 0/82 [00:00<?, ? examples/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]