# W2V-BERT Embedding Demo (Packaged)

This variant uses the published `w2vbert_speaker` package inside the dedicated virtual environment.

In [3]:
from pathlib import Path

def find_repo_root(start: Path) -> Path:
    """Locate the repository root by searching upward from ``start``.

    A directory qualifies as the root when it contains both a ``recipes`` and
    a ``deeplab`` entry. ``start`` itself is examined first, followed by each
    of its parents from nearest to farthest.

    Args:
        start: Directory from which the upward search begins.

    Returns:
        The nearest directory (``start`` or one of its parents) that holds
        both marker entries.

    Raises:
        RuntimeError: If neither ``start`` nor any of its parents contains
            both markers.
    """
    marker_names = ("recipes", "deeplab")
    lineage = (start, *start.parents)
    root = next(
        (
            folder
            for folder in lineage
            if all((folder / name).exists() for name in marker_names)
        ),
        None,
    )
    if root is None:
        raise RuntimeError(f"Unable to locate the repository root from {start}.")
    return root

# Anchor every later path lookup on the repository root, which is discovered
# by walking upward from the kernel's current working directory.
# NOTE(review): this assumes the kernel was started inside the repository
# (e.g. in the notebook's own folder) — confirm for other launch setups.
NOTEBOOK_DIR = Path.cwd()
REPO_ROOT = find_repo_root(NOTEBOOK_DIR)

print(
    f"Notebook directory: {NOTEBOOK_DIR}",
    f"Resolved repository root: {REPO_ROOT}",
    sep="\n",
)

Notebook directory: /Users/zb/NWG/w2v-BERT-2.0_SV/recipes/DeepASV/notebooks
Resolved repository root: /Users/zb/NWG/w2v-BERT-2.0_SV


In [4]:
import torch
from pathlib import Path
from w2vbert_speaker import W2VBERT_SPK_Module

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def resolve_checkpoint(repo_root: Path) -> Path:
    """Find the fine-tuned checkpoint file ``model_lmft_0.14.pth``.

    Candidate locations are tried in order, each interpreted relative to
    ``repo_root``: the ``deeplab/pretrained`` tree inside the repository, a
    ``pretrained`` tree next to the repository, and finally the repository
    root itself.

    Args:
        repo_root: Repository root that the candidate paths are relative to.

    Returns:
        The resolved path of the first candidate that is a regular file.

    Raises:
        FileNotFoundError: If none of the candidates is a regular file. The
            message lists every location that was checked.
    """
    candidate_relatives = [
        "deeplab/pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0/model_lmft_0.14.pth",
        "../pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0/model_lmft_0.14.pth",
        "model_lmft_0.14.pth",
    ]
    searched = []
    for relative in candidate_relatives:
        candidate = (repo_root / relative).resolve()
        # is_file() rather than exists(): a directory that merely carries the
        # checkpoint's name cannot be loaded and must not be returned.
        if candidate.is_file():
            return candidate
        searched.append(candidate)
    raise FileNotFoundError(
        "Checkpoint model_lmft_0.14.pth was not found in known locations. "
        "Searched: " + ", ".join(str(path) for path in searched)
    )

def resolve_model_dir(repo_root: Path) -> Path:
    """Find the local directory that holds the encoder weights.

    Two locations are tried in order, each relative to ``repo_root``: the
    ``deeplab/pretrained`` tree inside the repository, then a ``pretrained``
    tree next to the repository. A location is accepted only when it is a
    directory containing ``model.safetensors``.

    Args:
        repo_root: Repository root that the candidate paths are relative to.

    Returns:
        The resolved path of the first accepted directory.

    Raises:
        FileNotFoundError: If no candidate directory contains
            ``model.safetensors``.
    """
    weights_name = "model.safetensors"
    search_order = (
        "deeplab/pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0",
        "../pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0",
    )
    # Lazily resolve each location so later ones are only touched if needed.
    resolved_dirs = ((repo_root / location).resolve() for location in search_order)
    hit = next(
        (
            folder
            for folder in resolved_dirs
            if folder.is_dir() and (folder / weights_name).exists()
        ),
        None,
    )
    if hit is not None:
        return hit
    raise FileNotFoundError("Local Transformer weights not found. Expected model.safetensors under a known directory.")

# Locate the fine-tuned checkpoint and the local encoder directory relative to
# the repository root; each resolver raises FileNotFoundError if nothing is found.
checkpoint_path = resolve_checkpoint(REPO_ROOT)
model_dir = resolve_model_dir(REPO_ROOT)

print(f"Using checkpoint: {checkpoint_path}")
print(f"Using local encoder weights from: {model_dir}")

# Build the packaged module on DEVICE, pointing it at the local encoder
# directory, then load the checkpoint.
# NOTE(review): `embedding_model` is bound to whatever load_model() returns —
# presumably the module itself; confirm against the w2vbert_speaker package.
embedding_model = W2VBERT_SPK_Module(device=DEVICE, model_path=str(model_dir)).load_model(checkpoint_path)
# Read the sample rate from the speaker model's feature extractor so that the
# audio loaded later can be resampled to match it.
spk_model = embedding_model.modules_dict["spk_model"]
target_sr = spk_model.front.feature_extractor.sampling_rate
print(f"Using device: {DEVICE}, target sample rate: {target_sr}")

Using checkpoint: /Users/zb/NWG/pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0/model_lmft_0.14.pth
Using local encoder weights from: /Users/zb/NWG/pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0
Using device: cpu, target sample rate: 16000
Using device: cpu, target sample rate: 16000


In [5]:
import soundfile as sf
import torchaudio

def resolve_audio_path(repo_root: Path) -> Path:
    """Find the sample VoxCeleb1 utterance used by this demo.

    Two locations are tried in order, each relative to ``repo_root``: a
    ``datasets`` tree next to the repository, then a ``datasets`` tree inside
    the repository.

    Args:
        repo_root: Repository root that the candidate paths are relative to.

    Returns:
        The resolved path of the first candidate that is a regular file.

    Raises:
        FileNotFoundError: If none of the candidates is a regular file. The
            message lists every location that was checked.
    """
    candidate_relatives = [
        "../datasets/voxceleb1test/wav/id10270/5r0dWxy17C8/00001.wav",
        "datasets/voxceleb1test/wav/id10270/5r0dWxy17C8/00001.wav",
    ]
    searched = []
    for relative in candidate_relatives:
        candidate = (repo_root / relative).resolve()
        # is_file() rather than exists(): a directory with the sample's name
        # cannot be decoded as audio and must not be returned.
        if candidate.is_file():
            return candidate
        searched.append(candidate)
    raise FileNotFoundError(
        "Audio file not found. Update resolve_audio_path with a valid sample. "
        "Searched: " + ", ".join(str(path) for path in searched)
    )

def load_waveform(path: Path, target_sr: int) -> tuple[torch.Tensor, int]:
    """Read an audio file as a mono float32 tensor at ``target_sr``.

    The file is decoded with ``soundfile``. Multi-dimensional input is
    averaged over axis 1 to obtain a single channel, and the signal is
    resampled with ``torchaudio`` only when the file's sample rate differs
    from ``target_sr``.

    Args:
        path: Location of the audio file to read.
        target_sr: Sample rate the returned waveform should have.

    Returns:
        A ``(waveform, sample_rate)`` pair, where ``waveform`` is a float32
        tensor with a leading axis of size 1 added in front of the samples.
    """
    samples, native_sr = sf.read(str(path), dtype="float32")
    # Collapse multi-channel audio to mono by averaging across axis 1.
    mono = samples.mean(axis=1) if samples.ndim > 1 else samples
    audio = torch.from_numpy(mono)

    # Fast path: the file is already at the requested rate.
    if native_sr == target_sr:
        return audio.unsqueeze(0).to(torch.float32), native_sr

    to_target = torchaudio.transforms.Resample(orig_freq=native_sr, new_freq=target_sr)
    # The resampler is fed a batched tensor; the extra axis is removed again
    # before the final leading axis is added for the return value.
    audio = to_target(audio.unsqueeze(0)).squeeze(0)
    return audio.unsqueeze(0).to(torch.float32), target_sr

# Pick the demo utterance and load it at the sample rate the model expects.
AUDIO_PATH = resolve_audio_path(REPO_ROOT)
waveform, sr = load_waveform(AUDIO_PATH, target_sr)

print(
    f"Using audio file: {AUDIO_PATH}",
    f"Loaded waveform shape: {waveform.shape}, sample rate: {sr}",
    sep="\n",
)

Using audio file: /Users/zb/NWG/datasets/voxceleb1test/wav/id10270/5r0dWxy17C8/00001.wav
Loaded waveform shape: torch.Size([1, 133761]), sample rate: 16000


In [6]:
# Run the forward pass inside inference mode so no autograd state is recorded.
with torch.inference_mode():
    embeddings = embedding_model(waveform.to(DEVICE))

# squeeze(0) removes the leading axis when it has size 1; the result is then
# moved to host memory and exposed as a NumPy array.
embedding_vector = (
    embeddings
    .squeeze(0)
    .detach()
    .cpu()
    .numpy()
)

print(f"Embedding shape: {embedding_vector.shape}")
# Show only the first few values instead of dumping the whole vector.
embedding_vector[:8]

Embedding shape: (256,)


array([-0.16015907, -0.8298738 ,  0.70560724, -0.14280552,  0.35778913,
        0.47867072, -0.12521656, -0.90179497], dtype=float32)