# W2V-BERT TorchScript Demo

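Run `scripts/export_w2vbert_torchscript.py` first to generate the TorchScript artifact. The first code cell looks for it at `../../../packages/w2vbert_speaker/artifacts/w2vbert_speaker_script.pt`, relative to the notebook's working directory. The notebook then loads that module, extracts a speaker embedding from one VoxCeleb1 test utterance, and prints its first few values.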

Prepare a clean environment with the minimal dependencies:

```bash
pip install torch soundfile librosa numpy
```

In [2]:
from pathlib import Path
import torch

# The notebook is expected to be run from its own directory; the exported artifact lives
# in the shared `packages/` tree three levels up.
NOTEBOOK_DIR = Path.cwd()
ARTIFACT_PATH = (NOTEBOOK_DIR / "../../../packages/w2vbert_speaker/artifacts/w2vbert_speaker_script.pt").resolve()

if not ARTIFACT_PATH.exists():
    # Fail with an actionable message instead of torch's generic "file does not exist".
    raise FileNotFoundError(
        f"TorchScript artifact not found at {ARTIFACT_PATH}; "
        "run scripts/export_w2vbert_torchscript.py first."
    )

# `torch.jit.load` fills these keys in place with the extra files stored in the archive.
# Keys absent from the archive come back empty, which triggers the fallbacks below.
metadata = {"sample_rate": "", "embedding_dim": ""}
scripted_model = torch.jit.load(str(ARTIFACT_PATH), map_location="cpu", _extra_files=metadata)
# Force inference behaviour (e.g. dropout disabled) regardless of the mode the module
# was in when it was exported.
scripted_model.eval()

# Values may come back as bytes; int() parses both bytes and str.
sample_rate = int(metadata["sample_rate"]) if metadata["sample_rate"] else 16000
embedding_dim = int(metadata["embedding_dim"]) if metadata["embedding_dim"] else -1

if embedding_dim < 0:
    # No dimension recorded in the metadata: infer it from one second of silence.
    with torch.inference_mode():
        probe = scripted_model(torch.zeros(1, sample_rate, dtype=torch.float32))
    embedding_dim = int(probe.shape[-1])

print(f"Loaded TorchScript module from {ARTIFACT_PATH}")
print(f"Metadata: sample_rate={sample_rate}, embedding_dim={embedding_dim}")

Loaded TorchScript module from /Users/zb/NWG/w2v-BERT-2.0_SV/packages/w2vbert_speaker/artifacts/w2vbert_speaker_script.pt
Metadata: sample_rate=16000, embedding_dim=256


In [4]:
import librosa
import soundfile as sf

# The demo utterance's location relative to the notebook depends on how the repository
# and the datasets checkout are laid out, so try each candidate in order.
AUDIO_RELATIVE_PATH = "datasets/voxceleb1test/wav/id10270/5r0dWxy17C8/00001.wav"
audio_candidates = [
    (NOTEBOOK_DIR / "../../../.." / AUDIO_RELATIVE_PATH).resolve(),
    (NOTEBOOK_DIR / "../../.." / AUDIO_RELATIVE_PATH).resolve(),
]
target_audio = next((path for path in audio_candidates if path.exists()), None)
if target_audio is None:
    # Report every location tried, not just the first, so the failure is diagnosable.
    tried = "\n  ".join(str(path) for path in audio_candidates)
    raise FileNotFoundError(f"Audio file not found; tried:\n  {tried}")

signal, sr = sf.read(str(target_audio), dtype="float32")
if signal.ndim > 1:
    # Multi-channel audio arrives as (frames, channels); average the channels to mono.
    signal = signal.mean(axis=1)
if sr != sample_rate:
    # Match the sample rate the model expects (`sample_rate` from the model-loading cell).
    signal = librosa.resample(signal, orig_sr=sr, target_sr=sample_rate)
    sr = sample_rate

# Add a leading batch dimension: shape (1, num_samples).
waveform = torch.from_numpy(signal).unsqueeze(0).to(torch.float32)
print(f"Loaded {target_audio} at sample rate {sr} with waveform shape {waveform.shape}")

Loaded /Users/zb/NWG/datasets/voxceleb1test/wav/id10270/5r0dWxy17C8/00001.wav at sample rate 16000 with waveform shape torch.Size([1, 133761])


In [5]:
with torch.inference_mode():
    embedding = scripted_model(waveform)

# Drop the batch dimension of the single-utterance batch and move to NumPy for display.
vector = embedding.squeeze(0).cpu().numpy()
# Sanity check: the embedding width must match the dimension read from (or probed for)
# the artifact when it was loaded.
assert vector.shape == (embedding_dim,), f"expected shape ({embedding_dim},), got {vector.shape}"
print(f"Embedding shape: {vector.shape}")
vector[:8]

Embedding shape: (256,)


array([ 2.2525663 ,  0.41569418,  0.71409535,  1.3925596 ,  1.1965709 ,
       -1.364686  , -0.46425554,  2.1635387 ], dtype=float32)