# W2V-BERT TorchScript Demo

Run `scripts/export_w2vbert_torchscript.py` first to generate the TorchScript artifact.

In [3]:
from pathlib import Path

# Starting point for the repository-root search below: the kernel's current
# working directory (so the notebook must be launched from inside the repository).
NOTEBOOK_DIR = Path.cwd()

def find_repo_root(start: Path) -> Path:
    """Walk upwards from ``start`` until a directory that looks like the repo root is found.

    A directory qualifies when it contains both a ``recipes`` and a ``deeplab``
    entry. ``start`` itself is checked first, then each of its parents in order.

    Args:
        start: Directory from which the upward search begins.

    Returns:
        The first qualifying directory (``start`` or one of ``start.parents``).

    Raises:
        RuntimeError: If neither ``start`` nor any of its parents qualifies.
    """
    markers = ("recipes", "deeplab")
    search_chain = (start, *start.parents)
    root = next(
        (
            directory
            for directory in search_chain
            if all((directory / marker).exists() for marker in markers)
        ),
        None,
    )
    if root is None:
        raise RuntimeError("Unable to locate the repository root.")
    return root

# Resolve the repository root by searching upwards from the notebook's working
# directory, then echo both locations so a wrong launch directory is obvious.
REPO_ROOT = find_repo_root(NOTEBOOK_DIR)
print(f"Notebook directory: {NOTEBOOK_DIR}")
print(f"Repository root: {REPO_ROOT}")

Notebook directory: /Users/zb/NWG/w2v-BERT-2.0_SV/recipes/DeepASV/notebooks
Repository root: /Users/zb/NWG/w2v-BERT-2.0_SV


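Prepare a clean environment with the minimal dependencies. The eager-model comparison further down also imports the project's `w2vbert_speaker` package, so that package must be importable from this kernel as well: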

```bash
pip install torch soundfile librosa numpy
```

In [4]:
import torch
from typing import Iterable

def pick_first(paths: Iterable[Path]) -> Path:
    """Return the first existing path among ``paths``, resolved.

    Args:
        paths: Candidate locations, checked in iteration order. Any iterable is
            accepted, including one-shot iterators such as generators.

    Returns:
        ``candidate.resolve()`` for the first candidate whose ``exists()`` is true.

    Raises:
        FileNotFoundError: If no candidate exists. The message lists every
            candidate that was tried, one per line.
    """
    # Materialise once: a one-shot iterable would already be exhausted by the
    # search loop, which would leave the error message below without any paths.
    candidates = list(paths)
    for candidate in candidates:
        if candidate.exists():
            return candidate.resolve()
    raise FileNotFoundError("None of the provided paths exist:\n" + "\n".join(str(p) for p in candidates))

# Waveform artifact lookup is currently disabled; uncomment to re-enable it.
# ARTIFACT_WAVEFORM = pick_first([
#     REPO_ROOT / "packages/w2vbert_speaker/artifacts/w2vbert_speaker_script.pt",
# ])

# TorchScript artifact that is fed already-extracted input features later on.
preprocessed_artifact_candidates = [
    REPO_ROOT / "packages/w2vbert_speaker/artifacts/w2vbert_speaker_script_preprocessed.pt",
]
ARTIFACT_PREPROCESSED = pick_first(preprocessed_artifact_candidates)

# Checkpoint for the eager model: a sibling `pretrained/` tree next to the
# repository is tried first, then the copy under `deeplab/` inside it.
checkpoint_candidates = [
    REPO_ROOT.parent / "pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0/model_lmft_0.14.pth",
    REPO_ROOT / "deeplab/pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0/model_lmft_0.14.pth",
]
CHECKPOINT_PATH = pick_first(checkpoint_candidates)

# Model directory, searched in the same two locations as the checkpoint.
model_dir_candidates = [
    REPO_ROOT.parent / "pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0",
    REPO_ROOT / "deeplab/pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0",
]
MODEL_DIR = pick_first(model_dir_candidates)

# print(f"Waveform artifact: {ARTIFACT_WAVEFORM}")
print(f"Preprocessed artifact: {ARTIFACT_PREPROCESSED}")
print(f"Checkpoint: {CHECKPOINT_PATH}")
print(f"Model directory: {MODEL_DIR}")

# `torch.jit.load` fills the `_extra_files` dict it is given in place with the
# metadata bundled into the archive. A dict that is never passed to a load call
# keeps its "" placeholders, which are falsy and therefore skipped below.
metadata_waveform = {"sample_rate": "", "embedding_dim": ""}
# scripted_waveform = torch.jit.load(str(ARTIFACT_WAVEFORM), map_location="cpu", _extra_files=metadata_waveform)

metadata_preprocessed = {"sample_rate": "", "embedding_dim": "", "preprocessed": ""}
scripted_preprocessed = torch.jit.load(str(ARTIFACT_PREPROCESSED), map_location="cpu", _extra_files=metadata_preprocessed)

# Precedence: waveform metadata (only populated when that load is enabled),
# then the preprocessed artifact's metadata, then the literal fallback.
sample_rate = int(metadata_waveform["sample_rate"] or metadata_preprocessed["sample_rate"] or 16000)
embedding_dim = int(metadata_waveform["embedding_dim"] or metadata_preprocessed["embedding_dim"] or -1)

# Report what was actually loaded: with the waveform load commented out above,
# only the preprocessed artifact is in memory, so do not claim a waveform model.
print(f"TorchScript metadata: embedding_dim={embedding_dim}, sample_rate={sample_rate}")
print(f"Loaded preprocessed TorchScript (preprocessed flag={metadata_preprocessed['preprocessed']!r})")

Preprocessed artifact: /Users/zb/NWG/w2v-BERT-2.0_SV/packages/w2vbert_speaker/artifacts/w2vbert_speaker_script_preprocessed.pt
Checkpoint: /Users/zb/NWG/pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0/model_lmft_0.14.pth
Model directory: /Users/zb/NWG/pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0
Loaded waveform TorchScript (embedding_dim=256, sample_rate=16000)
Loaded preprocessed TorchScript (preprocessed flag=b'true')


In [5]:
from w2vbert_speaker import W2VBERT_SPK_Module

# Use CUDA when the runtime reports it as available, otherwise stay on CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Construct the eager module first, then load the checkpoint; the object
# returned by `load_model` is what the rest of the notebook treats as the model.
speaker_module = W2VBERT_SPK_Module(device=DEVICE, model_path=str(MODEL_DIR))
eager_model = speaker_module.load_model(CHECKPOINT_PATH)

# Reuse the eager model's own feature extractor so that its sampling rate
# drives the audio preparation in the following cells.
spk_module = eager_model.modules_dict["spk_model"]
feature_extractor = spk_module.front.feature_extractor
target_sr = int(feature_extractor.sampling_rate)
print(f"Using device: {DEVICE}, target sample rate: {target_sr}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu, target sample rate: 16000


In [6]:
import librosa
import soundfile as sf

# Sample utterance: a `datasets/` tree next to the repository is tried first,
# then one inside it.
audio_candidates = [
    REPO_ROOT.parent / "datasets/voxceleb1test/wav/id10270/5r0dWxy17C8/00001.wav",
    REPO_ROOT / "datasets/voxceleb1test/wav/id10270/5r0dWxy17C8/00001.wav",
]

# First candidate that exists on disk, resolved; None when nothing matches.
target_audio = next(
    (audio_path.resolve() for audio_path in audio_candidates if audio_path.exists()),
    None,
)
if target_audio is None:
    raise FileNotFoundError("Audio file not found. Update audio_candidates with a valid sample path.")

# Read the file as float32 samples together with its native sample rate.
raw_audio, native_sr = sf.read(str(target_audio), dtype="float32")

# Multi-dimensional reads are averaged over axis 1 to get a single channel.
mono_audio = raw_audio.mean(axis=1) if raw_audio.ndim > 1 else raw_audio

# Resample only when the file's rate differs from the feature extractor's rate.
if native_sr == target_sr:
    signal = mono_audio
else:
    signal = librosa.resample(mono_audio, orig_sr=native_sr, target_sr=target_sr)
sr = target_sr

# Add a leading dimension of size 1 in front of the sample axis.
waveform = torch.from_numpy(signal).unsqueeze(0).to(torch.float32)
print(f"Loaded {target_audio} at sample rate {sr} with waveform shape {waveform.shape}")

Loaded /Users/zb/NWG/datasets/voxceleb1test/wav/id10270/5r0dWxy17C8/00001.wav at sample rate 16000 with waveform shape torch.Size([1, 133761])


In [8]:
import torch.nn.functional as F
import numpy as np

# Reference embedding: eager model on DEVICE, result brought back to CPU.
with torch.inference_mode():
    eager_output = eager_model(waveform.to(DEVICE))
    eager_embedding = eager_output.squeeze(0).cpu()
# with torch.inference_mode():
#     scripted_embedding = scripted_waveform(waveform).squeeze(0).cpu()

# Run the eager model's feature extractor by hand to build the input that the
# preprocessed TorchScript artifact consumes.
mono_samples = waveform.squeeze(0).cpu().numpy()
features = feature_extractor(
    mono_samples,
    sampling_rate=target_sr,
    return_tensors="pt",
    padding=False,
    truncation=False,
    return_attention_mask=False,
)
input_features = features["input_features"].to(torch.float32)

with torch.inference_mode():
    scripted_output = scripted_preprocessed(input_features)
    preprocessed_embedding = scripted_output.squeeze(0).cpu()

# Parity metrics between the eager and the TorchScript embeddings.
# cos_waveform = F.cosine_similarity(eager_embedding.unsqueeze(0), scripted_embedding.unsqueeze(0)).item()
# l2_waveform = torch.norm(eager_embedding - scripted_embedding).item()
eager_row = eager_embedding.unsqueeze(0)
preprocessed_row = preprocessed_embedding.unsqueeze(0)
cos_preprocessed = F.cosine_similarity(eager_row, preprocessed_row).item()
embedding_delta = eager_embedding - preprocessed_embedding
l2_preprocessed = torch.norm(embedding_delta).item()

def preview(vec: torch.Tensor, n: int = 8) -> str:
    """Format the first ``n`` entries of ``vec`` for display.

    Args:
        vec: Tensor to preview; it is sliced with ``vec[:n]``.
        n: Maximum number of leading entries to include (default 8).

    Returns:
        The selected entries, each rendered with four decimal places and joined
        by single spaces. An empty slice yields an empty string.
    """
    leading = vec[:n].tolist()
    rendered = [f"{component:.4f}" for component in leading]
    return " ".join(rendered)

# Side-by-side report: the leading entries of each embedding, followed by the
# cosine similarity and L2 distance computed above. The waveform-artifact lines
# stay commented out while that artifact is not loaded.
print("Eager embedding preview:", preview(eager_embedding))
# print("Waveform TorchScript preview:", preview(scripted_embedding))
print("Preprocessed TorchScript preview:", preview(preprocessed_embedding))
print()
# print(f"cosine(eager, waveform_script) = {cos_waveform:.6f}")
# print(f"    L2(eager, waveform_script) = {l2_waveform:.6e}")
print(f"cosine(eager, preprocessed_script) = {cos_preprocessed:.6f}")
print(f"    L2(eager, preprocessed_script) = {l2_preprocessed:.6e}")

Eager embedding preview: -0.1602 -0.8299 0.7056 -0.1428 0.3578 0.4787 -0.1252 -0.9018
Preprocessed TorchScript preview: -0.1602 -0.8299 0.7056 -0.1428 0.3578 0.4787 -0.1252 -0.9018

cosine(eager, preprocessed_script) = 1.000000
    L2(eager, preprocessed_script) = 0.000000e+00


In [None]:
# Parity gate: fail loudly when the preprocessed TorchScript embedding drifts
# from the eager embedding by more than `tolerance` in L2 distance.
tolerance = 1e-5
# Written as `not (x <= tol)` rather than `x > tol`: every comparison with NaN is
# False, so the original form would let a NaN distance pass silently.
if not (l2_preprocessed <= tolerance):
    raise AssertionError(f"Preprocessed TorchScript deviates from eager more than {tolerance}: L2={l2_preprocessed}")
print(f"Preprocessed TorchScript within tolerance of eager model (L2={l2_preprocessed:.6e} <= {tolerance}).")
# if l2_waveform > 1e-3:
#     print(f"Warning: waveform TorchScript deviates from eager (L2={l2_waveform:.6e}). Use the preprocessed artifact for parity.")
# else:
#     print("Waveform TorchScript within tolerance of eager model.")