# W2V-BERT Parity Demo: Packaged Eager vs Scripted Preprocessed

This notebook demonstrates how to install the packaged eager model and the lightweight scripted runtime, how to load them from the notebook's virtual environment, and how to verify numeric parity between the embedding produced by the eager model and the embedding produced by the preprocessed TorchScript artifact.

Notes:
- The notebook will attempt to install the local packages into the kernel's Python using `sys.executable -m pip install -e ...` if they are not already installed.
- The scripted runtime is expected to use the saved Hugging Face feature_extractor to compute deterministic `input_features` that match the exporter.
- This demo assumes the repository root contains `packages/w2vbert_speaker` and `packages/w2vbert_speaker_scripted`.

In [1]:
# Cell 1: locate repository root and set up common helpers
from pathlib import Path
import sys

def find_repo_root(start: Path) -> Path:
    """Return the nearest directory, at or above ``start``, that looks like the repo root.

    A directory qualifies when it contains both a ``recipes`` entry and a
    ``deeplab`` entry. ``start`` itself is examined first, then each of its
    parents in order.

    Args:
        start: Directory from which the upward search begins.

    Returns:
        The first qualifying directory found.

    Raises:
        RuntimeError: If neither ``start`` nor any of its parents qualifies.
    """
    markers = ('recipes', 'deeplab')
    lineage = (start, *start.parents)
    root = next(
        (folder for folder in lineage if all((folder / marker).exists() for marker in markers)),
        None,
    )
    if root is None:
        raise RuntimeError(f'Unable to locate the repository root from {start}.')
    return root

# The kernel's working directory is taken as the notebook location; every later
# path lookup in this notebook is anchored on the repository root found from it.
NOTEBOOK_DIR = Path.cwd()
REPO_ROOT = find_repo_root(NOTEBOOK_DIR)
print(
    f'Notebook directory: {NOTEBOOK_DIR}',
    f'Resolved repository root: {REPO_ROOT}',
    sep='\n',
)


Notebook directory: /Users/zb/NWG/w2v-BERT-2.0_SV/recipes/DeepASV/notebooks
Resolved repository root: /Users/zb/NWG/w2v-BERT-2.0_SV


In [2]:
# Cell 3: imports and device selection (minimal)
import torch
from importlib import reload

# Import both packages under aliases and immediately re-execute them with
# importlib.reload, so that edits made on disk since a previous import in this
# kernel are picked up without restarting (useful during iterative development).
# NOTE(review): reload() only re-runs the module object it is given; submodules
# such as w2vbert_speaker_scripted.runtime are not reloaded by these calls —
# confirm this is sufficient for your workflow.
import w2vbert_speaker as eager_pkg
import w2vbert_speaker_scripted as scripted_pkg
reload(eager_pkg)
reload(scripted_pkg)

# Bind the two entry points used by the rest of the notebook. These imports must
# stay after the reload() calls above so the names come from the reloaded packages.
from w2vbert_speaker import W2VBERT_SPK_Module
from w2vbert_speaker_scripted.runtime import W2VBERT_SPK_Scripted

# Prefer CUDA when torch reports it as available; otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', DEVICE)


  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


In [3]:
# Cell 4: resolve model/checkpoint and artifact locations (adjust paths if needed)
from pathlib import Path

def resolve_checkpoint(repo_root: Path) -> Path:
    """Locate the ``model_lmft_0.14.pth`` checkpoint relative to ``repo_root``.

    Each known relative location is joined to ``repo_root`` and resolved; the
    first one that exists on disk wins. Locations are tried in the order listed.

    Args:
        repo_root: Directory the candidate relative paths are joined to.

    Returns:
        The resolved path of the first existing candidate.

    Raises:
        FileNotFoundError: If none of the candidate locations exists.
    """
    search_order = (
        'deeplab/pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0/model_lmft_0.14.pth',
        '../pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0/model_lmft_0.14.pth',
        'model_lmft_0.14.pth',
    )
    # Lazily resolve candidates so later ones are only touched if earlier ones miss.
    resolved_candidates = ((repo_root / rel).resolve() for rel in search_order)
    hit = next((path for path in resolved_candidates if path.exists()), None)
    if hit is None:
        raise FileNotFoundError('Checkpoint model_lmft_0.14.pth was not found in known locations.')
    return hit


def resolve_model_dir(repo_root: Path) -> Path:
    """Locate the local encoder weights directory relative to ``repo_root``.

    A candidate is accepted only if it is a directory and contains a
    ``model.safetensors`` entry. Candidates are tried in the order listed.

    Args:
        repo_root: Directory the candidate relative paths are joined to.

    Returns:
        The resolved path of the first accepted directory.

    Raises:
        FileNotFoundError: If no candidate directory is accepted.
    """
    search_order = (
        'deeplab/pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0',
        '../pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0',
    )

    def _has_weights(folder: Path) -> bool:
        # Same two conditions as the acceptance rule above, short-circuiting on is_dir().
        return folder.is_dir() and (folder / 'model.safetensors').exists()

    resolved_candidates = ((repo_root / rel).resolve() for rel in search_order)
    hit = next((folder for folder in resolved_candidates if _has_weights(folder)), None)
    if hit is None:
        raise FileNotFoundError('Local Transformer weights not found. Expected model.safetensors under a known directory.')
    return hit

# Resolve the checkpoint file and the local encoder weights directory.
checkpoint_path = resolve_checkpoint(REPO_ROOT)
model_dir = resolve_model_dir(REPO_ROOT)
print(
    f'Using checkpoint: {checkpoint_path}',
    f'Using local encoder weights from: {model_dir}',
    sep='\n',
)

# Paths to the exported TorchScript artifact and the HF feature extractor saved by
# the exporter; both live under packages/w2vbert_speaker/artifacts in the repo.
artifact_preprocessed, artifact_feature_extractor = (
    (REPO_ROOT / 'packages' / 'w2vbert_speaker' / 'artifacts' / leaf).resolve()
    for leaf in ('w2vbert_speaker_script_preprocessed.pt', 'feature_extractor')
)
print('Preprocessed scripted artifact:', artifact_preprocessed)
print('Saved HF feature_extractor dir:', artifact_feature_extractor)


Using checkpoint: /Users/zb/NWG/pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0/model_lmft_0.14.pth
Using local encoder weights from: /Users/zb/NWG/pretrained/audio2vector/ckpts/facebook/w2v-bert-2.0
Preprocessed scripted artifact: /Users/zb/NWG/w2v-BERT-2.0_SV/packages/w2vbert_speaker/artifacts/w2vbert_speaker_script_preprocessed.pt
Saved HF feature_extractor dir: /Users/zb/NWG/w2v-BERT-2.0_SV/packages/w2vbert_speaker/artifacts/feature_extractor


In [4]:
# Cell 5: load eager packaged model and determine target sample rate
# Build the packaged eager module on DEVICE, pointing it at the local encoder
# weights directory, then load the checkpoint. The value returned by load_model()
# is what the rest of the notebook calls as the eager model.
# NOTE(review): presumably load_model() returns the module itself — confirm.
embedding_model = W2VBERT_SPK_Module(device=DEVICE, model_path=str(model_dir)).load_model(checkpoint_path)
# Reach into the wrapped speaker model to read the sampling rate its feature
# extractor is configured with; audio loaded later is resampled to this rate.
spk_model = embedding_model.modules_dict['spk_model']
target_sr = spk_model.front.feature_extractor.sampling_rate
print(f'Using device: {DEVICE}, target sample rate: {target_sr}')


Using device: cpu, target sample rate: 16000


In [5]:
# Cell 6: load a small audio sample (adjust path to a local file if necessary)
import soundfile as sf
import torchaudio

from pathlib import Path

def resolve_audio_path(repo_root: Path) -> Path:
    """Locate the demo audio sample relative to ``repo_root``.

    The sibling ``../datasets`` location is tried before the in-repo
    ``datasets`` location; the first path that exists is returned.

    Args:
        repo_root: Directory the candidate relative paths are joined to.

    Returns:
        The resolved path of the first existing candidate.

    Raises:
        FileNotFoundError: If none of the candidate locations exists.
    """
    search_order = (
        '../datasets/voxceleb1test/wav/id10270/5r0dWxy17C8/00001.wav',
        'datasets/voxceleb1test/wav/id10270/5r0dWxy17C8/00001.wav',
    )
    resolved_candidates = ((repo_root / rel).resolve() for rel in search_order)
    hit = next((path for path in resolved_candidates if path.exists()), None)
    if hit is None:
        raise FileNotFoundError('Audio file not found. Update resolve_audio_path with a valid sample.')
    return hit


def load_waveform(path: Path, target_sr: int) -> tuple[torch.Tensor, int]:
    """Read an audio file and return it as a ``(1, num_samples)`` float32 tensor.

    The file is read with ``soundfile`` as float32. Input with more than one
    dimension is averaged over axis 1 (presumably the channel axis — verify
    against the soundfile layout). If the file's sample rate differs from
    ``target_sr`` the signal is resampled with ``torchaudio``.

    Args:
        path: Audio file to read.
        target_sr: Sample rate the returned waveform should have.

    Returns:
        A ``(waveform, sample_rate)`` pair. ``waveform`` has a leading dimension
        of size 1 added; ``sample_rate`` is the file's own rate when it already
        equals ``target_sr``, otherwise ``target_sr``.
    """
    samples, native_sr = sf.read(str(path), dtype='float32')
    mono = samples.mean(axis=1) if samples.ndim > 1 else samples
    audio = torch.from_numpy(mono)

    if native_sr == target_sr:
        effective_sr = native_sr
    else:
        # Resample on a temporary leading dimension, then drop it again.
        to_target = torchaudio.transforms.Resample(orig_freq=native_sr, new_freq=target_sr)
        audio = to_target(audio.unsqueeze(0)).squeeze(0)
        effective_sr = target_sr

    return audio.unsqueeze(0).to(torch.float32), effective_sr

# Pick the demo sample and load it at the sample rate the eager model's feature
# extractor reported (target_sr).
AUDIO_PATH = resolve_audio_path(REPO_ROOT)
waveform, sr = load_waveform(AUDIO_PATH, target_sr)
print(
    f'Using audio file: {AUDIO_PATH}',
    f'Loaded waveform shape: {waveform.shape}, sample rate: {sr}',
    sep='\n',
)


Using audio file: /Users/zb/NWG/datasets/voxceleb1test/wav/id10270/5r0dWxy17C8/00001.wav
Loaded waveform shape: torch.Size([1, 133761]), sample rate: 16000


In [6]:
# Cell 7: run the eager packaged model to get an embedding (eager)
import torch.nn.functional as F
# Run the eager model on the loaded waveform with autograd disabled; the waveform
# is moved to DEVICE first so it matches where the model was built.
with torch.inference_mode():
    embeddings = embedding_model(waveform.to(DEVICE))
# Drop the leading dimension (if it has size 1) and move the result to a NumPy
# array on the CPU; the parity check later converts it back with torch.from_numpy.
embedding_vector = embeddings.squeeze(0).detach().cpu().numpy()
print(f'Eager embedding shape: {embedding_vector.shape}')


Eager embedding shape: (256,)


In [7]:
# Cell 8: load the scripted preprocessed artifact via the scripted wrapper and run it (safe)
# Use the wrapper (not the underlying scripted module) so shapes are handled correctly.
# Build the scripted runtime wrapper from the exported TorchScript artifact and
# the feature extractor directory saved alongside it.
scripted_wrapper = W2VBERT_SPK_Scripted(
    scripted_path=str(artifact_preprocessed),
    feature_extractor_dir=str(artifact_feature_extractor),
    device=DEVICE,
)

with torch.inference_mode():
    # The wrapper is handed the raw waveform (with its leading dimension, on the
    # CPU); the result has its leading dimension dropped and is kept as a CPU tensor.
    emb_scripted = scripted_wrapper(waveform.cpu()).squeeze(0).detach().cpu()

print(f'Scripted (preprocessed) embedding shape: {tuple(emb_scripted.shape)}')


Scripted (preprocessed) embedding shape: (256,)


In [8]:
# Cell 9: compute parity metrics between eager embedding and scripted embedding
import numpy as np
from numpy.linalg import norm

# Convert eager to torch for cosine similarity
import torch.nn.functional as F

# Bring both embeddings into torch form: the eager result was stored as a NumPy
# array, the scripted result is already a CPU tensor.
eager = torch.from_numpy(embedding_vector)
scrip = emb_scripted

# Fail fast on a shape mismatch. Without this guard NumPy broadcasting in the
# subtraction below (e.g. (256,) against (256, 1)) would yield a meaningless L2
# value instead of an error.
if eager.shape != scrip.shape:
    raise AssertionError(
        f'Embedding shape mismatch: eager {tuple(eager.shape)} vs scripted {tuple(scrip.shape)}'
    )

cos_sim = F.cosine_similarity(eager.unsqueeze(0), scrip.unsqueeze(0)).item()
l2 = float(norm(eager.numpy() - scrip.numpy()))
print(f'cosine(eager, scripted_preprocessed) = {cos_sim:.6f}')
print(f'    L2(eager, scripted_preprocessed) = {l2:.6e}')

# Assert within a tight tolerance for the preprocessed artifact (expected near-exact parity)
tolerance = 1e-5
# `not (l2 <= tolerance)` is deliberate: it is also True when l2 is NaN, whereas
# `l2 > tolerance` evaluates to False for NaN and would let NaN embeddings pass.
if not (l2 <= tolerance):
    raise AssertionError(f'Preprocessed TorchScript deviates from eager more than {tolerance}: L2={l2}')
else:
    print('Preprocessed TorchScript within tolerance of eager model.')


cosine(eager, scripted_preprocessed) = 1.000000
    L2(eager, scripted_preprocessed) = 0.000000e+00
Preprocessed TorchScript within tolerance of eager model.
