# W2V-BERT Embedding Demo

Run `scripts/setup_w2vbert_notebook_env.py` before executing this notebook to provision the dedicated virtual environment and Jupyter kernel.

In [None]:
from pathlib import Path
import sys

def find_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above *start* that is the repo root.

    A directory qualifies when it contains both a ``recipes`` and a
    ``deeplab`` entry. *start* itself is checked first, then each of its
    parents from nearest to farthest.

    Args:
        start: Directory to begin the upward search from. A relative path is
            anchored to the current working directory before searching.

    Returns:
        The first directory that contains both marker entries.

    Raises:
        RuntimeError: If neither *start* nor any of its ancestors qualifies.
    """
    # A relative path has a truncated ``parents`` chain (``Path(".")`` has
    # none at all), which would silently limit the search. ``absolute()``
    # returns an already-absolute path unchanged and does not follow symlinks.
    start = start.absolute()
    for candidate in (start, *start.parents):
        if all((candidate / marker).exists() for marker in ("recipes", "deeplab")):
            return candidate
    raise RuntimeError("Unable to locate the repository root.")

# The notebook's working directory is the starting point for the upward
# search for the repository root.
NOTEBOOK_DIR = Path.cwd()
REPO_ROOT = find_repo_root(NOTEBOOK_DIR)

# Directories that have to be importable for the cells below.
SRC_PATHS = [
    REPO_ROOT,
    REPO_ROOT / "recipes/DeepASV",
    REPO_ROOT / "deeplab/pretrained/audio2vector/module/transformers/src",
]

# Register each directory once, skipping any that are absent on disk.
# NOTE(review): entries are appended, so anything already importable under the
# same name from an earlier sys.path entry takes precedence — confirm intended.
for candidate in SRC_PATHS:
    resolved = str(candidate)
    if not candidate.exists():
        continue
    if resolved in sys.path:
        continue
    sys.path.append(resolved)

print(f"Repository root: {REPO_ROOT}")

In [None]:
import torch

from recipes.DeepASV.utils.inference import W2VBERT_SPK_Module

# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Construct the W2V-BERT speaker module for the chosen device and keep whatever
# load_model() returns; the later cells call that object directly.
embedding_model = W2VBERT_SPK_Module(device=device).load_model()

In [None]:
import librosa

import soundfile as sf

import torch



# Locate the demo utterance: first in a datasets/ directory next to the
# repository, then in one inside the repository itself.
target_audio = REPO_ROOT.parent / "datasets/voxceleb1test/wav/id10270/5r0dWxy17C8/00001.wav"
if not target_audio.exists():
    alt_path = REPO_ROOT / "datasets/voxceleb1test/wav/id10270/5r0dWxy17C8/00001.wav"
    if not alt_path.exists():
        raise FileNotFoundError(f"Audio file not found at {target_audio} or {alt_path}")
    target_audio = alt_path

# Decode as float32; a multi-dimensional result is averaged over axis 1
# (presumably the channel axis — confirm against soundfile's layout) to get mono.
signal, sr = sf.read(str(target_audio), dtype="float32")
signal = signal.mean(axis=1) if signal.ndim > 1 else signal

# Resample only when the file's rate differs from the rate in the model's
# hparams (16 kHz is used when the model does not declare one).
target_sr = embedding_model.hparams.get("sample_rate", 16000)
if sr != target_sr:
    signal = librosa.resample(signal, orig_sr=sr, target_sr=target_sr)
    sr = target_sr

# Add a leading axis of size 1 in front of the samples and make sure the
# tensor is float32.
waveform = torch.from_numpy(signal)
waveform = waveform.unsqueeze(0).to(torch.float32)
print(f"Loaded waveform shape: {waveform.shape}, sample rate: {sr}")


In [None]:
# Run the forward pass with autograd disabled: this cell only reads the
# resulting embedding, so there is no reason to record a computation graph.
# NOTE(review): `waveform` is created on the CPU while the model was built with
# `device`; presumably the module moves its input itself — confirm.
with torch.no_grad():
    embeddings = embedding_model(waveform)

# squeeze(0) removes the leading axis when it has size 1 (presumably the batch
# axis added by unsqueeze(0) on the waveform — confirm the output layout), then
# the tensor is copied to host memory as a NumPy array. detach() is kept so the
# conversion stays valid even if the model re-enables gradients internally.
embedding_vector = embeddings.squeeze(0).detach().cpu().numpy()

print(f"Embedding shape: {embedding_vector.shape}")
# Last expression of the cell: the notebook displays the first eight values.
embedding_vector[:8]